
3.5. Subplots

In Matplotlib, a single figure can be divided into several smaller plots, called subplots. Each subplot is an independent plot area (an Axes object) within the larger figure. Subplots are useful when you want to display several plots side by side or arranged in a grid.
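A quick way to see that each subplot is independent is to plot data with very different ranges in two side-by-side subplots. The minimal sketch below uses plt.subplots (described in the next paragraph); the data values are arbitrary, and the Agg backend line is an assumption added only so the script runs without a display. By default each subplot scales its own y-axis, while passing sharey=True links the y-axes of all subplots.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

# Two independent subplots: each autoscales its y-axis to its own data
fig1, (ax_a, ax_b) = plt.subplots(nrows=1, ncols=2)
ax_a.plot([0, 1], [0, 10])   # large range
ax_b.plot([0, 1], [0, 1])    # small range
print("independent:", ax_a.get_ylim(), ax_b.get_ylim())

# Same data, but with shared y-axes both subplots show the same limits
fig2, (ax_c, ax_d) = plt.subplots(nrows=1, ncols=2, sharey=True)
ax_c.plot([0, 1], [0, 10])
ax_d.plot([0, 1], [0, 1])
print("shared:     ", ax_c.get_ylim(), ax_d.get_ylim())
```

Sharing an axis is handy when subplots show quantities in the same unit and you want to compare them by eye.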

The subplots() function of the matplotlib.pyplot module (usually imported as plt) creates a figure and a grid of subplots in one call. Its nrows and ncols arguments set the number of rows and columns in the grid, which determines the overall layout, and it returns the figure together with the subplot Axes.
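plt.subplots returns a pair (fig, axes), and what axes is depends on the grid. With the default squeeze=True, a grid with more than one row and more than one column gives a 2-D NumPy array indexed as axes[row, col], a single row or column gives a 1-D array, and a 1 x 1 grid gives a single Axes object. The sketch below checks this; the grid sizes are arbitrary, and the Agg backend line is only there so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

# 2 rows x 3 columns -> a 2-D array of Axes objects
fig, axes = plt.subplots(nrows=2, ncols=3)
print(axes.shape)        # (2, 3)

# Select a subplot by [row, column]
axes[1, 2].set_title("Bottom-right subplot")

# A single row gives a 1-D array, indexed with one number
fig_row, axes_row = plt.subplots(nrows=1, ncols=5)
print(axes_row.shape)    # (5,)

# A 1 x 1 grid returns a single Axes object rather than an array
fig_one, ax_one = plt.subplots()
print(type(ax_one).__name__)
```

The 1-D case is the one used in the histogram example of this section, which is why its subplots are selected with axes[i].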

Here’s an example to help you understand subplots. We will use a dataset containing a sample of 100 vehicles of the 3-axle vehicle type 3C (remember the maximum bending moment on a simply supported bridge example in the NumPy section?).

First, we read the dataset using pandas:

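import pandas as pd

# Sample of 3-axle (type 3C) vehicles: axle weights and inter-axle distances
dataset = pd.read_csv("https://raw.githubusercontent.com/"
                      "mike-mendoza/Bivariate_NPBN_workshop_files/"
                      "a991bc3d9391a92437af1c3d69ae9fdfe6baf6da/"
                      "files_pyhton_book_test/V3AX_WIM_BR.csv")
dataset.head()  # show the first five rows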
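   A1_kN  A2_kN  A3_kN  D1_m  D2_m
0   42.1   77.5   65.3   5.1   1.4
1   48.7   80.1   50.2   5.4   1.2
2   51.7   90.2   61.6   5.2   1.2
3   41.2   75.7   58.6   5.4   1.2
4   25.0   48.4   33.5   5.6   1.2

Columns A1_kN, A2_kN and A3_kN are the weights of axles 1 to 3 in kN; D1_m and D2_m are the two inter-axle distances in metres.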
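If the URL above cannot be reached (for example, when working offline), one option is to build a small stand-in DataFrame from the five rows printed by dataset.head(). This is only a hypothetical substitute for trying out the plotting code: it holds 5 vehicles rather than the full sample of 100, so histograms drawn from it will not match the ones in this section.

```python
import pandas as pd

# The five rows shown by dataset.head(); NOT the full dataset
sample = pd.DataFrame({
    "A1_kN": [42.1, 48.7, 51.7, 41.2, 25.0],  # axle 1 weight [kN]
    "A2_kN": [77.5, 80.1, 90.2, 75.7, 48.4],  # axle 2 weight [kN]
    "A3_kN": [65.3, 50.2, 61.6, 58.6, 33.5],  # axle 3 weight [kN]
    "D1_m":  [5.1, 5.4, 5.2, 5.4, 5.6],       # inter-axle distance 1 [m]
    "D2_m":  [1.4, 1.2, 1.2, 1.2, 1.2],       # inter-axle distance 2 [m]
})
print(sample.shape)   # (5, 5)
print(sample.describe().loc[["min", "max"]])
```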

Let’s create one figure with one histogram per column in the dataset using a for loop.

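import matplotlib.pyplot as plt

# Subplot titles and x-axis labels, one entry per column of the dataset
variable_names = ['Axle 1', 'Axle 2', 'Axle 3',
                  'Inter-axle distance 1', 'Inter-axle distance 2']

xlabels = ['Weight [kN]', 'Weight [kN]', 'Weight [kN]',
           'Distance [m]', 'Distance [m]']

# One row of five subplots; axes is a 1-D array of Axes objects
fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(12, 3))
for i, column in enumerate(dataset.columns):
    axes[i].hist(dataset[column])          # histogram of column i
    axes[i].set_xlabel(xlabels[i])
    axes[i].set_ylabel('Frequency')
    axes[i].set_title(variable_names[i])
plt.suptitle('Histograms of variables in the dataset')  # title for the whole figure
plt.tight_layout()  # adjust spacing so labels and titles do not overlap
plt.show()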
[Figure: Histograms of variables in the dataset, one panel each for Axle 1, Axle 2, Axle 3, Inter-axle distance 1 and Inter-axle distance 2]

Let’s break it down

  1. variable_names is a list containing the display names of the five variables in the dataset, used as the subplot titles.

  2. xlabels is a list containing the x-axis labels for each histogram.

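  3. plt.subplots(nrows=1, ncols=5, figsize=(12, 3)) creates a figure (fig) with a grid of 1 row and 5 columns of subplots and returns it together with axes, a 1-D array of the five Axes objects. figsize sets the width and height of the figure in inches.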

  4. for i, column in enumerate(dataset.columns) starts a loop over the column names of the dataset. The enumerate() function yields both the position (i) and the column name (column) at each iteration, so i can be used to select the matching subplot, x-axis label and title.

  5. In each iteration, axes[i].hist(dataset[column]) draws a histogram of the current column on the i-th subplot.

  6. The x-axis label, y-axis label, and title of each subplot are set using axes[i].set_xlabel(xlabels[i]), axes[i].set_ylabel('Frequency'), and axes[i].set_title(variable_names[i]), respectively.

  7. The overall title (suptitle) for the entire figure is set using plt.suptitle('Histograms of variables in the dataset').

  8. plt.tight_layout() adjusts the spacing between subplots so that axis labels and titles do not overlap.

  9. Finally, plt.show() displays the figure with all the subplots.
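A single row works for five variables, but with more variables a grid of rows and columns is often easier to read. The minimal sketch below arranges the same histograms on a 2 x 3 grid. Because axes is then a 2-D array, the loop iterates over axes.flat (which walks the grid row by row) together with zip instead of an index, and the sixth, unused subplot is removed with fig.delaxes. To stay self-contained it uses the five rows shown by dataset.head() as stand-in data; with the full dataset loaded you would pass dataset instead. The grid size, figsize and Agg backend line are illustrative assumptions.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; drop this line in a notebook
import matplotlib.pyplot as plt
import pandas as pd

# Stand-in data: the five rows printed by dataset.head() (use the full dataset instead)
data = pd.DataFrame({
    "A1_kN": [42.1, 48.7, 51.7, 41.2, 25.0],
    "A2_kN": [77.5, 80.1, 90.2, 75.7, 48.4],
    "A3_kN": [65.3, 50.2, 61.6, 58.6, 33.5],
    "D1_m":  [5.1, 5.4, 5.2, 5.4, 5.6],
    "D2_m":  [1.4, 1.2, 1.2, 1.2, 1.2],
})
variable_names = ['Axle 1', 'Axle 2', 'Axle 3',
                  'Inter-axle distance 1', 'Inter-axle distance 2']
xlabels = ['Weight [kN]', 'Weight [kN]', 'Weight [kN]',
           'Distance [m]', 'Distance [m]']

fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(10, 6))

# axes has shape (2, 3); axes.flat visits the subplots row by row.
# zip stops after the 5 columns, so the last subplot receives no data.
for ax, column, xlabel, title in zip(axes.flat, data.columns, xlabels, variable_names):
    ax.hist(data[column])
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')
    ax.set_title(title)

# Remove the sixth, unused subplot (bottom-right)
fig.delaxes(axes[-1, -1])

fig.suptitle('Histograms of variables in the dataset')
fig.tight_layout()
# In an interactive session, call plt.show() here to display the figure
```

Pairing each Axes with its column, label and title through zip avoids keeping index arithmetic in sync across several lists.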