Showing static visualizations
This page is generated from a Jupyter notebook and demonstrates how to generate static visualizations with matplotlib, pandas, and seaborn.
Start by importing the packages we need:
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
Load the “Palmer penguins” dataset from week 2:
# Load data on Palmer penguins
penguins = pd.read_csv("https://raw.githubusercontent.com/MUSA-550-Fall-2023/week-2/main/data/penguins.csv")
# Show the first ten rows
penguins.head(n=10)
| | species | island | bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | sex | year |
---|---|---|---|---|---|---|---|---|
0 | Adelie | Torgersen | 39.1 | 18.7 | 181.0 | 3750.0 | male | 2007 |
1 | Adelie | Torgersen | 39.5 | 17.4 | 186.0 | 3800.0 | female | 2007 |
2 | Adelie | Torgersen | 40.3 | 18.0 | 195.0 | 3250.0 | female | 2007 |
3 | Adelie | Torgersen | NaN | NaN | NaN | NaN | NaN | 2007 |
4 | Adelie | Torgersen | 36.7 | 19.3 | 193.0 | 3450.0 | female | 2007 |
5 | Adelie | Torgersen | 39.3 | 20.6 | 190.0 | 3650.0 | male | 2007 |
6 | Adelie | Torgersen | 38.9 | 17.8 | 181.0 | 3625.0 | female | 2007 |
7 | Adelie | Torgersen | 39.2 | 19.6 | 195.0 | 4675.0 | male | 2007 |
8 | Adelie | Torgersen | 34.1 | 18.1 | 193.0 | 3475.0 | NaN | 2007 |
9 | Adelie | Torgersen | 42.0 | 20.2 | 190.0 | 4250.0 | NaN | 2007 |
A simple visualization, 3 different ways
I want to make a scatter plot of flipper length vs. bill length, colored by penguin species.
Using matplotlib
# Set up a dict to hold colors for each species
color_map = {"Adelie": "#1f77b4", "Gentoo": "#ff7f0e", "Chinstrap": "#D62728"}
# Initialize the figure "fig" and axes "ax"
fig, ax = plt.subplots(figsize=(10, 6))
# Group the data frame by species and loop over each group
# NOTE: "group_df" will be the DataFrame holding the data for "species"
for species, group_df in penguins.groupby("species"):
    # Plot flipper length vs bill length for this group
    # Note: we are adding this plot to the existing "ax" object
    ax.scatter(
        group_df["flipper_length_mm"],
        group_df["bill_length_mm"],
        marker="o",
        label=species,
        color=color_map[species],
        alpha=0.75,
        zorder=10,
    )
# Plotting is done...format the axes!
## Add a legend to the axes
ax.legend(loc="best")
## Add x-axis and y-axis labels
ax.set_xlabel("Flipper Length (mm)")
ax.set_ylabel("Bill Length (mm)")
## Add the grid of lines
ax.grid(True);
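The loop above depends on how `groupby()` iteration works. Each step yields a pair of (group key, sub-DataFrame).

The sketch below shows this on a tiny hand-built DataFrame so that it runs offline. The values are made up for illustration and are not real penguin measurements.

```python
import pandas as pd

# Made-up illustrative values, NOT real penguin measurements
df = pd.DataFrame(
    {
        "species": ["Adelie", "Gentoo", "Adelie", "Chinstrap"],
        "flipper_length_mm": [180.0, 215.0, 185.0, 195.0],
    }
)

# Iterating over a groupby yields (group key, sub-DataFrame) pairs;
# with the default sort=True, keys come out in sorted order
groups = {}
for species, group_df in df.groupby("species"):
    groups[species] = group_df
    print(species, len(group_df))
# prints:
# Adelie 2
# Chinstrap 1
# Gentoo 1
```

Each sub-DataFrame keeps the original row index. Because colors are looked up by name through `color_map`, the order in which the groups are drawn does not change which color each species gets.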
How about in pandas?
DataFrames have a built-in “plot” function that can make the most common types of matplotlib plots (line, bar, histogram, scatter, and more)!
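The `.plot.scatter()` accessor and `.plot(kind="scatter")` are two spellings of the same call. Both draw onto a matplotlib Axes and return that Axes.

Here is a quick sketch using a tiny made-up DataFrame rather than the penguins data. The `Agg` backend line is only there so the sketch runs without a display.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Made-up illustrative values
df = pd.DataFrame({"x": [1, 2, 3, 4], "y": [2.0, 4.0, 1.0, 3.0]})

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3))

# The accessor form and the kind= form draw the same scatter plot;
# when an existing Axes is passed via ax=, that same Axes is returned
returned = df.plot.scatter(x="x", y="y", ax=ax1)
df.plot(kind="scatter", x="x", y="y", ax=ax2)

# Other plot types follow the same pattern, e.g. df.plot.line(), df.plot.bar(), df.plot.hist()
```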
First, we need to add a new “color” column specifying the color to use for each species type.
Use the Series.replace() method, which uses a dict to replace values in a DataFrame column.
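Here is a minimal sketch of how `replace()` behaves with a dict, using a hand-built Series. "Emperor" is a hypothetical species, added only to show what happens to a value that is missing from the dict. For comparison, the sketch also shows `Series.map()`, a common alternative.

```python
import pandas as pd

color_map = {"Adelie": "#1f77b4", "Gentoo": "#ff7f0e", "Chinstrap": "#D62728"}

# "Emperor" is a hypothetical value that is missing from color_map
species = pd.Series(["Adelie", "Gentoo", "Emperor"])

# replace() swaps matching values and leaves everything else unchanged
replaced = species.replace(color_map)

# map() looks every value up in the dict; values not found become NaN
mapped = species.map(color_map)

print(replaced.tolist())  # ['#1f77b4', '#ff7f0e', 'Emperor']
print(mapped.tolist())    # ['#1f77b4', '#ff7f0e', nan]
```

For the penguins data, every species appears in `color_map`, so both approaches give the same column. The difference only matters for unmatched values:

- `replace()` lets unmatched values pass through unchanged.
- `map()` turns them into NaN, which makes them easy to spot.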
# Define a color for each species
color_map = {"Adelie": "#1f77b4", "Gentoo": "#ff7f0e", "Chinstrap": "#D62728"}
# Map species name to color
penguins["color"] = penguins["species"].replace(color_map)
penguins.head()
| | species | island | bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | sex | year | color |
---|---|---|---|---|---|---|---|---|---|
0 | Adelie | Torgersen | 39.1 | 18.7 | 181.0 | 3750.0 | male | 2007 | #1f77b4 |
1 | Adelie | Torgersen | 39.5 | 17.4 | 186.0 | 3800.0 | female | 2007 | #1f77b4 |
2 | Adelie | Torgersen | 40.3 | 18.0 | 195.0 | 3250.0 | female | 2007 | #1f77b4 |
3 | Adelie | Torgersen | NaN | NaN | NaN | NaN | NaN | 2007 | #1f77b4 |
4 | Adelie | Torgersen | 36.7 | 19.3 | 193.0 | 3450.0 | female | 2007 | #1f77b4 |
Now plot!
# Same as before: Start by initializing the figure and axes
fig, myAxes = plt.subplots(figsize=(10, 6))
# Scatter plot two columns, colored by a third
# Use the built-in pandas plot.scatter function
penguins.plot.scatter(
    x="flipper_length_mm",
    y="bill_length_mm",
    c="color",
    alpha=0.75,
    ax=myAxes,  # IMPORTANT: Make sure to plot on the axes object we created already!
    zorder=10,
)
# Format the axes finally
myAxes.set_xlabel("Flipper Length (mm)")
myAxes.set_ylabel("Bill Length (mm)")
myAxes.grid(True);
Note: there's no easy way to add a legend to the plot in this case, because all species are drawn in a single scatter call with no per-species labels.
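One possible workaround, shown here on made-up data rather than the real dataset, is to build the legend entries yourself:

- Create one proxy artist per species, using `matplotlib.lines.Line2D` markers.
- Pass the proxies to `ax.legend(handles=...)`.

The `colorbar=False` argument does not appear in the original code. It is added only to make explicit that no colorbar is wanted when the colors are strings.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.lines import Line2D

color_map = {"Adelie": "#1f77b4", "Gentoo": "#ff7f0e", "Chinstrap": "#D62728"}

# Made-up illustrative values, NOT real penguin measurements
df = pd.DataFrame(
    {
        "species": ["Adelie", "Gentoo", "Chinstrap"],
        "flipper_length_mm": [185.0, 215.0, 195.0],
        "bill_length_mm": [39.0, 47.0, 49.0],
    }
)
df["color"] = df["species"].replace(color_map)

fig, ax = plt.subplots(figsize=(10, 6))
df.plot.scatter(
    x="flipper_length_mm", y="bill_length_mm", c="color", colorbar=False, ax=ax
)

# One proxy handle per species: a marker-only Line2D with no line drawn
handles = [
    Line2D([0], [0], marker="o", linestyle="none", color=color, label=name)
    for name, color in color_map.items()
]
ax.legend(handles=handles, loc="best")
```

The proxies are never drawn on the axes. They exist only so the legend has a colored marker and a label for each species.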
Seaborn: statistical data visualization
Seaborn is designed for exactly this: plotting two columns colored by a third column, with the legend handled automatically.
# Initialize the figure and axes
fig, ax = plt.subplots(figsize=(10, 6))
# Style keywords as a dict
color_map = {"Adelie": "#1f77b4", "Gentoo": "#ff7f0e", "Chinstrap": "#D62728"}
style = dict(palette=color_map, s=60, edgecolor="none", alpha=0.75, zorder=10)
# Use the scatterplot() function
sns.scatterplot(
    x="flipper_length_mm",  # the x column
    y="bill_length_mm",  # the y column
    hue="species",  # the third dimension (color)
    data=penguins,  # pass in the data
    ax=ax,  # plot on the axes object we made
    **style,  # add our style keywords
)
# Format with matplotlib commands
ax.set_xlabel("Flipper Length (mm)")
ax.set_ylabel("Bill Length (mm)")
ax.grid(True)
ax.legend(loc="best");
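The seaborn call relies on Python's `**` unpacking: `**style` expands the dict into keyword arguments. Below is a minimal sketch using a hypothetical stand-in function. `fake_scatterplot` is made up for illustration and is not part of seaborn.

```python
def fake_scatterplot(x=None, y=None, **kwargs):
    """Hypothetical stand-in for a plotting function: it just returns the keywords it received."""
    return {"x": x, "y": y, **kwargs}


style = dict(s=60, edgecolor="none", alpha=0.75, zorder=10)

# These two calls are equivalent: **style expands the dict into keyword arguments
a = fake_scatterplot(x="flipper_length_mm", y="bill_length_mm", **style)
b = fake_scatterplot(
    x="flipper_length_mm", y="bill_length_mm", s=60, edgecolor="none", alpha=0.75, zorder=10
)
print(a == b)  # True

# Passing the same keyword both explicitly and inside the dict is an error
duplicate_error = None
try:
    fake_scatterplot(alpha=0.5, **style)
except TypeError as err:
    duplicate_error = str(err)  # "... got multiple values for keyword argument 'alpha'"
```

Bundling shared keywords in a dict keeps the styling consistent when you reuse it across several plots. To override a single value, build a new dict (for example `{**style, "alpha": 0.5}`) rather than passing the keyword twice.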