This post collects my notes on various tricks for using Matplotlib.

Twin axes in the same plot with one legend

Often, we need two axes that share the same x axis but have different y-axis scales. Fortunately, Matplotlib’s Axes class has a twinx() method that fulfills this need. But another question arises: how do we make the two axes share one legend?

If we call the legend() method of each axes separately, the figure will have two separate legends, which is a bit redundant. To show the labels of the lines from both axes in one legend, we need to collect all the lines together with their labels and pass them to a single legend() call.
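To make the problem concrete, here is a minimal sketch of what happens when legend() is called on each axes separately. The data and the variable names (leg1, leg2, labels1, labels2) are only illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove this line to open a window
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

fig, ax1 = plt.subplots()
ax1.plot(x, 2 * x, "r", label="2x")
ax2 = ax1.twinx()
ax2.plot(x, np.exp(0.5 * x), "b", label="exp(0.5x)")

# Each call only sees the lines drawn on its own axes,
# so the figure ends up with two separate legend boxes.
leg1 = ax1.legend(loc="upper left")
leg2 = ax2.legend(loc="upper right")

labels1 = [t.get_text() for t in leg1.get_texts()]
labels2 = [t.get_text() for t in leg2.get_texts()]
print(labels1, labels2)  # ['2x'] ['exp(0.5x)']
```

Neither legend knows about the line on the other axes, which is why the handles have to be gathered by hand.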

Two different ways to achieve what we want

We can use two slightly different ways to achieve this. First, you can try the following:

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y1 = 2*x
y2 = np.sqrt(x)
y3 = np.exp(0.5*x)

fig = plt.figure()
ax1 = fig.add_subplot(111)
line1 = ax1.plot(x, y1, 'r', label="2x")
line2 = ax1.plot(x, y2, 'g', label="sqrt(x)")

ax2 = ax1.twinx()
line3 = ax2.plot(x, y3, 'b', label="exp(0.5x)")
lines = line1+line2+line3
labels = [l.get_label() for l in lines]
ax2.legend(lines, labels)
plt.show()

Or you can try the second method:

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y1 = 2*x
y2 = np.sqrt(x)
y3 = np.exp(0.5*x)

fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.plot(x, y1, 'r', label="2x")
ax1.plot(x, y2, 'g', label="sqrt(x)")
ax2 = ax1.twinx()
ax2.plot(x, y3, 'b', label="exp(0.5x)")

handles, labels = [], []
for ax in [ax1, ax2]:
    h, l = ax.get_legend_handles_labels()
    handles.extend(h)
    labels.extend(l)
ax2.legend(handles, labels)

plt.show()

Both snippets produce the same plot, with a single legend that lists all three lines.
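If you need this in several figures, the loop from the second method can be wrapped in a small helper. The function name combined_legend and its signature are my own invention for illustration and are not part of Matplotlib. Treat it as a sketch.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove this line to open a window
import matplotlib.pyplot as plt
import numpy as np


def combined_legend(target_ax, *axes, **legend_kwargs):
    """Hypothetical helper: draw one legend on target_ax that
    contains the labelled artists of every axes passed in."""
    handles, labels = [], []
    for ax in axes:
        h, l = ax.get_legend_handles_labels()
        handles.extend(h)
        labels.extend(l)
    return target_ax.legend(handles, labels, **legend_kwargs)


x = np.linspace(0, 10, 100)
fig, ax1 = plt.subplots()
ax1.plot(x, 2 * x, "r", label="2x")
ax1.plot(x, np.sqrt(x), "g", label="sqrt(x)")
ax2 = ax1.twinx()
ax2.plot(x, np.exp(0.5 * x), "b", label="exp(0.5x)")

# Draw the legend on the twin axes, as in the snippets above.
leg = combined_legend(ax2, ax1, ax2, loc="upper left")
legend_labels = [t.get_text() for t in leg.get_texts()]
print(legend_labels)  # ['2x', 'sqrt(x)', 'exp(0.5x)']
```

Extra keyword arguments such as loc are simply forwarded to legend(), so the helper does not restrict how the legend is placed or styled.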

A side note

It is also worth noting that in the first snippet we use line1+line2+line3. This works because each of the three variables is actually a Python list: Matplotlib’s plot() method returns a list of lines even if you plot just one line.
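A quick check of this behaviour, as a small sketch with illustrative variable names. It also shows the common tuple-unpacking idiom, which gives you the bare Line2D object instead of a one-element list.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove this line to open a window
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots()

# plot() returns a list, even for a single line.
result = ax.plot(x, 2 * x, label="2x")
print(type(result).__name__, len(result))  # list 1

# The trailing comma unpacks the one-element list into a Line2D.
line, = ax.plot(x, np.sqrt(x), label="sqrt(x)")

# A bare Line2D must be wrapped in a list before concatenating.
lines = result + [line]
labels = [l.get_label() for l in lines]
print(labels)  # ['2x', 'sqrt(x)']
```

So if you unpack the return value of plot(), build the list of handles with [line1, line2, line3] rather than with the + operator.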

Log-scale plot with correct ticks and tick labels

Sometimes, we want to plot in log scale. This is easy to achieve in Matplotlib. We can use the normal plotting command and then set the x axis to log scale, or we can plot in log scale directly using one of the semi-log methods such as semilogx(). A simple snippet is shown below:

import matplotlib.pyplot as plt
import numpy as np
import matplotlib

colors = ["#e6194b",
          "#3cb44b",
          "#ffe119"]

m1 = np.array([ 0.6125, 0.775, 0.8375, 0.875, 0.9125, 0.9875])
m2 = np.array([0.6750, 0.8125, 0.8625, 0.9000, 0.9375, 0.9750])
m3 = np.array([0.8000, 0.8625, 0.9250, 0.9625, 0.9625, 0.9750])

position = [1, 2, 4, 8, 16, 32]
fig, ax = plt.subplots()

ax.semilogx(position, m1*100, linestyle='-', color=colors[0])
ax.semilogx(position, m2*100, linestyle='-', color=colors[1])
ax.semilogx(position, m3*100, linestyle='-', color=colors[2])

ax.set_xticks(position)
ax.grid(linestyle='--')
ax.set_xlabel("K")
ax.set_ylabel("Recall@k score (%)")

plt.show()

The above snippet will produce the following figure.

Notice that in the above figure, some minor ticks appear in the plot and also the tick labels are not properly shown.

We can use the tick_params() method of the axes, or its minorticks_off() method, to turn off the minor ticks. As for the improperly displayed tick labels on the x axis, we can use set_major_formatter() to correct them. A working example is shown below:

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.ticker  # import the submodule explicitly for ScalarFormatter

colors = ["#e6194b",
          "#3cb44b",
          "#ffe119"]

m1 = np.array([ 0.6125, 0.775, 0.8375, 0.875, 0.9125, 0.9875])
m2 = np.array([0.6750, 0.8125, 0.8625, 0.9000, 0.9375, 0.9750])
m3 = np.array([0.8000, 0.8625, 0.9250, 0.9625, 0.9625, 0.9750])

position = [1, 2, 4, 8, 16, 32]
fig, ax = plt.subplots()

ax.semilogx(position, m1*100, linestyle='-', color=colors[0])
ax.semilogx(position, m2*100, linestyle='-', color=colors[1])
ax.semilogx(position, m3*100, linestyle='-', color=colors[2])

ax.set_xticks(position)
ax.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())

# turn off the minor ticks on the x axis
# (use the boolean False; the string 'off' is not accepted by recent Matplotlib versions)
ax.tick_params(axis='x', which='minor', bottom=False)
# ax.minorticks_off()  # or use the minorticks_off method

ax.grid(linestyle='--')
ax.set_xlabel("K")
ax.set_ylabel("Recall@k score (%)")

plt.show()

The newly generated plot is shown below.

We can see that all the issues of the first-version plot have been corrected.
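For completeness, here is a sketch of the other route mentioned at the start of this section: plot with the normal plot() command, then switch the x axis to log scale with set_xscale(). It applies the same fixes as above, using only one of the three curves to keep it short. Note the ordering assumption in the comments: changing the scale resets the tick locators and formatters, so the customizations have to come after set_xscale().

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove this line to open a window
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import NullLocator, ScalarFormatter

m1 = np.array([0.6125, 0.775, 0.8375, 0.875, 0.9125, 0.9875])
position = [1, 2, 4, 8, 16, 32]

fig, ax = plt.subplots()
ax.plot(position, m1 * 100, linestyle="-", color="#e6194b")

# Switch to log scale first: this resets locators and formatters,
# so the tick customizations below must come afterwards.
ax.set_xscale("log")

ax.set_xticks(position)                           # major ticks at the data points
ax.xaxis.set_major_formatter(ScalarFormatter())   # plain numbers instead of powers of ten
ax.minorticks_off()                               # drop the minor ticks

ax.grid(linestyle="--")
ax.set_xlabel("K")
ax.set_ylabel("Recall@k score (%)")

print(ax.get_xscale())  # log
```

Whether you use semilogx() or plot() followed by set_xscale("log") is mostly a matter of taste. The tick handling is the same in both cases.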

Change line width in the legend

Sometimes, we want to increase the line width in the legend for better readability. The following snippet shows how to do this in Matplotlib using the object-oriented API:

import numpy as np
import matplotlib.pyplot as plt

# make some data

x = np.linspace(0, 2*np.pi)
y1 = np.sin(x)
y2 = np.cos(x)

fig, ax = plt.subplots()
ax.plot(x, y1, linewidth=1.0, label='sin(x)')
ax.plot(x, y2, linewidth=1.0, label='cos(x)')

leg = ax.legend()

for line in leg.get_lines():
    line.set_linewidth(4.0)

plt.show()

In the generated image, the lines in the legend are drawn thicker than the lines in the plot.
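The reason this works without touching the plot itself is that the legend holds its own line objects, separate from the ones drawn on the axes. The sketch below checks this by comparing the widths after the change. The names legend_widths and plot_widths are just illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove this line to open a window
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi)

fig, ax = plt.subplots()
ax.plot(x, np.sin(x), linewidth=1.0, label="sin(x)")
ax.plot(x, np.cos(x), linewidth=1.0, label="cos(x)")

leg = ax.legend()
for line in leg.get_lines():
    line.set_linewidth(4.0)

# The legend lines were changed; the plotted lines keep their width.
legend_widths = [line.get_linewidth() for line in leg.get_lines()]
plot_widths = [line.get_linewidth() for line in ax.get_lines()]
print(legend_widths)  # [4.0, 4.0]
print(plot_widths)    # [1.0, 1.0]
```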

References