This article is reproduced from stackoverflow.com.
Tags: legend, matplotlib, plot, python

Matplotlib legend makes the image too large

Posted on 2020-05-03 00:28:41

I'm plotting this figure with matplotlib; the for loop just colors the background:

import matplotlib.pyplot as plt

# data, node, x and y come from the rest of my script
fig, ax = plt.subplots()

ax.set_ylabel('Number of contacts')
ax.set_xlabel('Time [s]')

for m in range(len(data[node])):
    if data[node][m] == -1:
        ax.axvline(m, color='r', linewidth=5, alpha=0.2, label="OUT")
    if data[node][m] == 0:
        ax.axvline(m, color='g', linewidth=5, alpha=0.2, label="RZ0")
    if data[node][m] == 1:
        ax.axvline(m, color='y', linewidth=5, alpha=0.2, label="RZ1")

ax.plot(x, y, 'b+')
# ax.legend() # HERE is the problem
plt.show()

This produces the following plot:

[image: the resulting plot with the colored background]

Now I want a legend that explains what each background color means, but when I include ax.legend() I get the following error:

ValueError: Image size of 392x648007 pixels is too large. It must be less than 2^16 in each direction.

<Figure size 432x288 with 1 Axes>
<Figure size 432x288 with 0 Axes>

How am I supposed to label each background color? There are 43200 vertical lines but only 3 colors. Does the error have anything to do with the number of lines?

Questioner: BlueMountain
JohanC 2020-02-14 20:16

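Every call to axvline with a label adds a separate legend entry (identical labels are not merged), so with 43200 labeled lines the legend gets 43200 rows, and that is what makes the image so tall. The trick is to set each label only once: keep a variable for each label and replace it with None once it has been used. Note also that using axvline to draw a background has a drawback: the line width is measured in points (display space) rather than in data units, so neighboring lines will either overlap or leave a small white gap in between. It is better to use axvspan. To avoid white space at the left and right, you can set the x-limits explicitly.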
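A plain-Python sketch of the label-once trick and of why a fixed line width misbehaves, runnable without matplotlib. `labels_once` and `line_width_in_data_units` are hypothetical helper names, not matplotlib functions. The numbers in the second part are assumptions: a figure 432 px wide (the size printed in the question), 72 dpi, and an axes covering matplotlib's default 77.5% of the figure width.

```python
# Hypothetical helpers illustrating the answer's two points in plain Python.


def labels_once(values, names):
    """Return one legend label per sample, using each name only once.

    values: indicator for each sample (one axvline per sample).
    names:  mapping from indicator to legend text.
    Later occurrences get None; matplotlib leaves artists whose label
    is None out of the legend.
    """
    remaining = dict(names)  # copy so the caller's mapping is not modified
    # dict.pop(key, None) hands out each name the first time, then None
    return [remaining.pop(v, None) for v in values]


def line_width_in_data_units(linewidth_pt, dpi, axes_width_px, x_range):
    """Width, in x data units, of a line drawn `linewidth_pt` points wide.

    Line widths are in points (1/72 inch), so the width on screen stays
    fixed no matter how many data units the x axis spans.
    """
    width_px = linewidth_pt * dpi / 72       # points -> pixels
    units_per_px = x_range / axes_width_px   # x data units per pixel
    return width_px * units_per_px


NAMES = {-1: "OUT", 0: "RZ0", 1: "RZ1"}  # labels from the question
print(labels_once([-1, -1, 0, 1, 0, -1], NAMES))
# ['OUT', None, 'RZ0', 'RZ1', None, None]

# Assumed: 432 px wide figure at 72 dpi, axes covering 77.5% of the
# width, and 43200 samples on the x axis.
print(round(line_width_in_data_units(5, 72, 0.775 * 432, 43200)))
# 645
```

In the question's loop, the label for line m would be passed as `label=labels[m]` to axvline, so the legend ends up with at most three entries. Under the stated assumptions, each 5-point line covers roughly 645 data units while the lines are only one unit apart, so they overlap almost completely; drawing one axvspan per run avoids that.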

The code can be simplified somewhat using a loop.

Updated code:

  • group consecutive spans together for drawing
  • precalculate the effect of alpha so the background can be drawn without the need for transparency
from matplotlib import pyplot as plt
from matplotlib import colors as mcolors
import numpy as np
import itertools

fig, ax = plt.subplots()

# create some random data
x = np.arange(100)
y = np.sinh(x/20)
indicators = [-1, 0, 1]
node = 0
data = [np.random.choice(indicators, len(x), p=[10/16,1/16,5/16])]

labels = ["OUT", "RZ0", "RZ1"]
colors = ['lime', 'purple', 'gold']

alpha = 0.4
# precalculate the effect of alpha so the colors can be applied with alpha=1
colors = [[1 + (v - 1) * alpha for v in mcolors.to_rgb(c)] for c in colors]

m = 0
for val, group in itertools.groupby(data[node]):
    width = len(list(group))
    ind = indicators.index(val)
    ax.axvspan(m, m + width, color=colors[ind], linewidth=0, alpha=1, label=labels[ind])
    labels[ind] = None  # reset the label to make sure it is only used once
    m += width

ax.plot(x, y, 'b+')
ax.set_xlim(0, len(data[node]))
ax.legend(framealpha=1) # to make the legend background opaque
plt.show()

[image: sample plot]
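The two changes listed with the updated code, grouping consecutive spans and precalculating the effect of alpha, can be checked in plain Python without drawing anything. `runs_to_spans` and `blend_over_white` are hypothetical helper names for this sketch: the first does the same bookkeeping as the `itertools.groupby` loop, and the second assumes the axes background is white, as in matplotlib's default style.

```python
from itertools import groupby


def runs_to_spans(values):
    """Collapse consecutive equal indicators into (start, end, value) spans.

    Each span covers the x interval [start, end), so one axvspan call can
    draw a whole run instead of one axvline per sample.
    """
    spans = []
    start = 0
    for value, group in groupby(values):
        width = sum(1 for _ in group)  # length of this run
        spans.append((start, start + width, value))
        start += width
    return spans


def blend_over_white(rgb, alpha):
    """Opaque RGB equivalent of drawing `rgb` with transparency `alpha` on white.

    Compositing over white gives alpha * c + (1 - alpha) * 1, which equals
    1 + (c - 1) * alpha, the expression used in the answer's code.
    """
    return tuple(1 + (c - 1) * alpha for c in rgb)


print(runs_to_spans([-1, -1, 0, 0, 0, 1, -1]))
# [(0, 2, -1), (2, 5, 0), (5, 6, 1), (6, 7, -1)]
print(blend_over_white((1.0, 0.0, 0.0), 0.4))
```

Each `(start, end, value)` span would then be drawn with one `ax.axvspan(start, end, ...)` call, and its color could be `blend_over_white(mcolors.to_rgb(c), 0.4)`, where `mcolors.to_rgb` is matplotlib's converter from a color name to an RGB tuple. If the axes background is not white, the blend target has to change accordingly.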