


Sometimes it is useful for problem solvers to include a couple of plots in the same figure window. This can be done in Matplotlib using subplots. Matplotlib's plt.subplots() function accepts two positional arguments: the number of rows of subplots in the figure and the number of columns of subplots in the figure. The general format is:

fig, <ax objects> = plt.subplots(rows, cols)

Here rows and cols set how the grid of subplots is laid out, and the <ax objects> on the left-hand side must have dimensions that match rows and cols.
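A quick way to see what "dimensions that match rows and cols" means is to inspect the object that plt.subplots() returns. The sketch below assumes the non-interactive Agg backend so it runs without a display, and records the shape of the returned NumPy array of Axes for a few grid sizes. With the default squeeze=True, a grid with a single row or column comes back as a 1-D array, and a 1 by 1 grid returns a single Axes object rather than an array.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Record the shape of the Axes container for several grid sizes
shapes = {}
for rows, cols in [(2, 2), (2, 3), (1, 2), (3, 1)]:
    fig, axs = plt.subplots(rows, cols)
    shapes[(rows, cols)] = axs.shape
    plt.close(fig)  # only the shape is needed, so free the figure

# A 1 x 1 grid returns a single Axes object, not an array
fig, ax = plt.subplots(1, 1)
single_is_array = isinstance(ax, np.ndarray)
plt.close(fig)

print(shapes)
print("1 x 1 grid returns an array:", single_is_array)
```

Because the (1, 2) and (3, 1) grids come back as 1-D arrays, a single-row layout is unpacked without the inner parentheses, for example fig, (ax1, ax2) = plt.subplots(1, 2).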

If a 2 row by 2 column array of plots is created, the <ax objects> must be arrayed as shown below:

fig, ( (ax1,ax2), (ax3,ax4) ) = plt.subplots(2,2)   

If a 2 row by 3 column array of plots is created, the <ax objects> must be arrayed to correspond to these dimensions:

fig, ( (ax1,ax2,ax3), (ax4,ax5,ax6) ) = plt.subplots(2, 3)
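The nested tuple on the left is filled row by row: the first inner tuple receives the top row of subplots, left to right, and the second inner tuple receives the bottom row. The sketch below (again assuming the headless Agg backend) unpacks a 2 row by 3 column grid, looks up where each name sits in the full array, and writes the name into its subplot so the layout is visible if the figure is displayed or saved.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; omit this line in Jupyter
import matplotlib.pyplot as plt

fig, axs = plt.subplots(2, 3)

# Unpack the 2 x 3 array of Axes row by row into six names
(ax1, ax2, ax3), (ax4, ax5, ax6) = axs

names = {"ax1": ax1, "ax2": ax2, "ax3": ax3,
         "ax4": ax4, "ax5": ax5, "ax6": ax6}

positions = {}
for name, ax in names.items():
    # Find the (row, col) index of this Axes in the full array
    for row in range(2):
        for col in range(3):
            if axs[row, col] is ax:
                positions[name] = (row, col)
    # Label the subplot with its variable name, centered in axes coordinates
    ax.text(0.5, 0.5, name, ha="center", va="center", transform=ax.transAxes)

print(positions)
```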

Subplots are useful if you want to show the same data on different scales. The plot of an exponential function looks different on a linear scale compared to a logarithmic scale. In addition to ax.plot(), which scales both axes linearly, Matplotlib provides three plotting methods that scale the x-axis, the y-axis, or both logarithmically. The table below summarizes these methods.
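A short calculation shows why the choice of scale matters for an exponential such as y = e^(-t/5), the function plotted in the code later in this section. Taking the base-10 logarithm of both sides turns the exponential into a linear function of t:

$$
y = e^{-t/5}
\quad\Longrightarrow\quad
\log_{10} y = -\frac{t}{5}\,\log_{10} e \approx -0.0869\,t
$$

So on a plot with a logarithmic y-axis and a linear x-axis, the curve appears as a straight line that falls about 0.087 decades per unit of t, while on linear axes the same data curve toward zero.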

| Matplotlib method | Axis scaling |
|---|---|
| ax.plot() | linear x, linear y |
| ax.semilogy() | linear x, logarithmic y |
| ax.semilogx() | logarithmic x, linear y |
| ax.loglog() | logarithmic x, logarithmic y |
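Each method in the table draws the line and sets the axis scales in a single call. One way to confirm the table, assuming the headless Agg backend, is to call each method on its own Axes and read the scales back with ax.get_xscale() and ax.get_yscale():

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

t = np.arange(0.01, 20.0, 0.01)
y = np.exp(-t / 5.0)

fig, axs = plt.subplots(2, 2)
methods = ["plot", "semilogy", "semilogx", "loglog"]

# Call each plotting method by name on its own subplot, then record the scales
scales = {}
for ax, name in zip(axs.flat, methods):
    getattr(ax, name)(t, y)  # e.g. ax.semilogy(t, y)
    scales[name] = (ax.get_xscale(), ax.get_yscale())

for name, (xscale, yscale) in scales.items():
    print(f"ax.{name}(): x is {xscale}, y is {yscale}")
```

Looping over axs.flat with method names is just one compact alternative to unpacking four named Axes; both produce the same kind of figure.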

The code section below builds a 2 row by 2 column array of subplots in one figure. The axes of each subplot are scaled in a different way.

import matplotlib.pyplot as plt
import numpy as np
# if using a Jupyter notebook, include:
%matplotlib inline

# Data for plotting
t = np.arange(0.01, 20.0, 0.01)

# Create a figure with 2 rows and 2 cols of subplots
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)

# linear x and y axis
ax1.plot(t, np.exp(-t / 5.0))
ax1.set(title='linear x and y')
ax1.grid()

# log y axis
ax2.semilogy(t, np.exp(-t / 5.0))
ax2.set(title='semilogy')
ax2.grid()

# log x axis
ax3.semilogx(t, np.exp(-t / 5.0))
ax3.set(title='semilogx')
ax3.grid()

# log x and y axis; newer Matplotlib no longer accepts basex=2, so set the x base separately
ax4.loglog(t, 20 * np.exp(-t / 5.0))
ax4.set_xscale('log', base=2)
ax4.set(title='loglog base 2 on x')
ax4.grid()
fig.tight_layout()  # keep subplot titles and tick labels from overlapping