Fix your matplotlib colorbars (animations)
Some time ago I saw a post by Joseph Lang on fixing matplotlib colorbars. One of the major problems with matplotlib is that it works quite well out of the box, until it doesn't. As Joseph Lang showed, getting figures for publication just right can become tedious, because it requires digging deep into the docs to figure out something as simple as aligning colorbars to axes.
While browsing Stack Overflow, I came across a post asking for something fairly trivial, yet the proposed solution seemed far too complicated. The original poster wanted to update the colorbar while animating a heatmap. Animating in matplotlib can be very slow if the end user does not know a few tricks, such as blitting, avoiding the reconstruction of large objects, setting the data of existing artists directly, and redrawing the canvas manually.
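To make the "avoid reconstructing objects" trick concrete, here is a minimal sketch that compares two update styles. The function names `update_in_place` and `update_rebuild` are hypothetical and chosen only for illustration; the Agg backend is assumed so the snippet runs without a display. Swapping the data of an existing `AxesImage` with `set_data` keeps a single artist on the axes, whereas calling `imshow` every frame stacks a new image on top of the old ones.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
fig, ax = plt.subplots()
im = ax.imshow(rng.standard_normal((10, 10)))


def update_in_place(frame):
    # Reuse the existing AxesImage: only its pixel array is replaced.
    im.set_data(rng.standard_normal((10, 10)))
    return (im,)


def update_rebuild(frame):
    # Anti-pattern: every call adds a brand-new AxesImage to the axes.
    return (ax.imshow(rng.standard_normal((10, 10))),)


for frame in range(5):
    update_in_place(frame)
n_in_place = len(ax.images)  # still only the original image

for frame in range(5):
    update_rebuild(frame)
n_rebuilt = len(ax.images)  # the original image plus one per rebuild call

plt.close(fig)
```

In an animation the rebuilt images keep piling up frame after frame, so every redraw has more work to do; the in-place version keeps the cost per frame constant.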
After digging through the docs, I found that instead of rebuilding the colorbar by hand, one only needs to force an update of the mappable's normalization (its vmin and vmax), after which the colorbar can simply be refreshed from it:
from matplotlib.pyplot import subplots, close
from matplotlib.animation import FuncAnimation
from matplotlib import rc
from mpl_toolkits.axes_grid1 import make_axes_locatable
from IPython.display import HTML
import numpy as np

rc('animation', html='html5')

# define data
d = np.random.randn(10, 10)

# set up figure
fig, ax = subplots(figsize=(4, 4))
# append a colorbar axes that matches the height of the image axes (J. Lang)
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)

# create mappable
h = ax.imshow(d)
# create colorbar
cb = fig.colorbar(h, cax=cax)

def update(i):
    # generate new data with a random offset so the value range changes
    h.set_data(np.random.randn(*d.shape) + np.random.randn())
    # rescale the norm (vmin/vmax) to the new data; the colorbar reads this norm
    h.norm.autoscale(h.get_array())
    # refresh the colorbar from its (updated) mappable
    cb.update_normal(h)
    # show the current frame number
    ax.set_title(f't = {i}')
    return (h,)

fig.tight_layout(h_pad=1)
# blit=False: blitting only redraws the returned artists, but the title and
# colorbar change every frame, so the whole figure has to be redrawn anyway
anim = FuncAnimation(fig, update, blit=False, repeat=True,
                     frames=20, interval=100)
html_vid = anim.to_html5_video(embed_limit=5)  # embed_limit in MB; needs ffmpeg
close(fig)  # prevent the static figure from being displayed a second time
HTML(html_vid)
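If IPython or ffmpeg is not available, the same idea can be written to a GIF with matplotlib's `PillowWriter` (Pillow is a matplotlib dependency). This is a sketch under a few assumptions that are not in the original: the Agg backend, a temporary output path named `colorbar.gif`, and a seeded NumPy generator instead of `np.random.randn`. The colorbar mechanism is unchanged: set the data, autoscale the norm, refresh the colorbar.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # render off-screen; no notebook or display needed
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation, PillowWriter
from mpl_toolkits.axes_grid1 import make_axes_locatable

rng = np.random.default_rng(2)
fig, ax = plt.subplots(figsize=(4, 4))
cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.1)
h = ax.imshow(rng.standard_normal((10, 10)))
cb = fig.colorbar(h, cax=cax)


def update(i):
    # new frame with a random offset so the colour range actually moves
    h.set_data(rng.standard_normal((10, 10)) + rng.standard_normal())
    h.norm.autoscale(h.get_array())  # reset vmin/vmax to the new frame
    cb.update_normal(h)              # propagate the new limits to the colorbar
    ax.set_title(f"t = {i}")
    return (h,)


# Saving always redraws the full figure, so blitting brings no benefit here.
anim = FuncAnimation(fig, update, frames=20, interval=100, blit=False)

out_path = os.path.join(tempfile.mkdtemp(), "colorbar.gif")
anim.save(out_path, writer=PillowWriter(fps=10))
plt.close(fig)
```

An alternative to touching the norm directly is `h.set_clim(arr.min(), arr.max())`, which sets the same vmin and vmax through the mappable's public API; calling `cb.update_normal(h)` afterwards keeps the colorbar in sync either way.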