# Source code for pygeode.plot.wrappers

# Wrapper for matplotlib.pyplot

from matplotlib import pyplot as pyl
import matplotlib as mpl

# Allows plots to be constructed in a slightly more object-oriented way.
# Also allows plot objects to be saved to / loaded from files, something which
# normal matplotlib plots can't do.

# Interface for wrapping a matplotlib axes object
class AxesWrapper: # {{{
    """Deferred description of one matplotlib axes and its plot operations.

    A wrapper stores plot operations (``PlotOp`` instances) and, optionally,
    child wrappers positioned inside it.  Nothing is drawn until ``render``
    is called, at which point the matplotlib axes are created and every
    stored operation is replayed onto them.
    """

    def __init__(self, parent=None, rect=None, size=None, pad=None, make_axis=False, name='', **kwargs): # {{{
        """Set up an empty wrapper.

        parent    -- enclosing AxesWrapper, or None for a top-level wrapper
        rect      -- [left, bottom, width, height] of the axes within this
                     wrapper's unit square; defaults to the
                     ``figure.subplot.*`` rcParams
        size      -- (width, height), used as the figure size when rendering;
                     defaults to the ``figure.figsize`` rcParam
        pad       -- optional 4-sequence in figure inches; when given it is
                     used instead of ``rect`` to inset the axes
                     (left, bottom, then amounts removed from width, height)
        make_axis -- whether a matplotlib axes is created for this wrapper
        name      -- free-form label stored on the wrapper
        kwargs    -- stored as ``axes_args``
        """
        self.parent = parent

        # Plot operations attached directly to this wrapper
        self.nplots = 0
        self.plots = []
        self.make_axis = make_axis

        # Child wrappers and their bounding boxes (extents in this wrapper's
        # unit square)
        self.naxes = 0
        self.axes = []
        self.ax_boxes = []

        if rect is None:
            # rcParams give (left, bottom, right, top); convert the last two
            # entries to width and height.
            rect = [pyl.rcParams['figure.subplot.' + k] for k in ['left', 'bottom', 'right', 'top']]
            rect[2] -= rect[0]
            rect[3] -= rect[1]
        self.rect = rect
        self.pad = pad

        if size is None:
            size = pyl.rcParams['figure.figsize']
        self.size = (float(size[0]), float(size[1]))

        # Properties applied to the axes / xaxis / yaxis at render time
        self.args = {}
        self.axes_args = kwargs
        self.projection = None
        self.xaxis_args = {}
        self.yaxis_args = {}
        self.name = name
    # }}}

    def add_axis(self, axis, rect): # {{{
        """Attach a child wrapper occupying the extents ``rect``.

        ``rect`` is (x0, y0, x1, y1) in this wrapper's unit coordinates and
        must enclose a non-zero area.
        """
        bb = mpl.transforms.Bbox.from_extents(rect)
        assert all([s > 0. for s in bb.size]), 'Bounding box must not vanish'
        axis.parent = self
        self.axes.append(axis)
        self.ax_boxes.append(rect)
        self.naxes = len(self.axes)
    # }}}

    def add_plot(self, plot, order=None, make_axes=True): # {{{
        """Attach a plot operation to this wrapper.

        plot      -- the PlotOp to add; its ``axes`` attribute is set to self
        order     -- list index at which to insert the op; appended if None
        make_axes -- if true, mark this wrapper as needing a matplotlib axes
        """
        plot.axes = self
        if order is None:
            self.plots.append(plot)
        else:
            # list.insert takes (index, object); the arguments were
            # previously swapped, which raised a TypeError.
            self.plots.insert(order, plot)
        if make_axes:
            self.make_axis = True
        self.nplots = len(self.plots)
    # }}}

    def pop_plot(self, order=-1): # {{{
        """Remove the plot operation at index ``order`` and return it."""
        removed = self.plots.pop(order)
        self.nplots = len(self.plots)
        return removed
    # }}}

    def render(self, fig = None, show = True, **kwargs): # {{{
        """Create the matplotlib axes and replay all plot operations.

        fig    -- an existing Figure to draw into, a figure number to pass to
                  ``pyplot.figure`` as ``num``, or None for a new figure
        show   -- if true, show the figure (on GUI backends) and draw it
        kwargs -- extra arguments for ``pyplot.figure`` when a figure has to
                  be created

        Returns the matplotlib Figure.
        """
        # Suspend interactive mode so the figure is not redrawn after every
        # individual operation.
        wason = pyl.isinteractive()
        if wason:
            pyl.ioff()

        if not isinstance(fig, mpl.figure.Figure):
            figparm = dict(figsize = self.size)
            figparm.update(kwargs)
            if fig is not None:
                figparm['num'] = fig
            fig = pyl.figure(**figparm)

        fig.clf()

        self._build_axes(fig, self)
        self._do_plots(fig)

        if wason:
            pyl.ion()

        if show:
            # Check if the current backend has a GUI
            if mpl.get_backend() in mpl.rcsetup.interactive_bk + ['module://ipympl.backend_nbagg', 'nbAgg']:
                pyl.show()
            pyl.draw()
        return fig
    # }}}

    def get_transform(self, root = None): # {{{
        """Return the transform from this wrapper's unit square to ``root``.

        The identity is returned when this wrapper is ``root`` or has no
        parent; otherwise the wrapper's bounding box within its parent is
        composed with the parent's own transform.
        """
        if self is root or self.parent is None:
            return mpl.transforms.IdentityTransform()

        ia = self.parent.axes.index(self)
        rect = self.parent.ax_boxes[ia]
        box = mpl.transforms.Bbox.from_extents(rect)
        t_self = mpl.transforms.BboxTransformTo(box)
        # Forward ``root`` so the walk up the tree stops at the wrapper being
        # rendered rather than continuing to the topmost ancestor.
        t_parent = self.parent.get_transform(root)
        return mpl.transforms.CompositeAffine2D(t_self, t_parent)
    # }}}

    def _build_axes(self, fig, root): # {{{
        """Create matplotlib axes on ``fig`` for this wrapper and its children.

        Sets ``self.ax`` to the new axes, or to None when ``make_axis`` is
        false.
        """
        if self.make_axis:
            tfm = self.get_transform(root)
            if self.pad is not None:
                # Padding is in inches: start from the full box of this
                # wrapper and inset it, scaling by the figure size.
                l, b = tfm.transform_point((0., 0.))
                r, t = tfm.transform_point((1., 1.))
                fsize = fig.get_size_inches()
                l += self.pad[0] / fsize[0]
                b += self.pad[1] / fsize[1]
                w = r - l - self.pad[2] / fsize[0]
                h = t - b - self.pad[3] / fsize[1]
            else:
                # rect is (left, bottom, width, height) in local coordinates
                l, b = self.rect[0], self.rect[1]
                r, t = self.rect[0] + self.rect[2], self.rect[1] + self.rect[3]
                l, b = tfm.transform_point((l, b))
                r, t = tfm.transform_point((r, t))
                w = r - l
                h = t - b
            self.ax = fig.add_axes([l, b, w, h], projection = self.projection)
        else:
            self.ax = None

        # Build children
        for a in self.axes:
            a._build_axes(fig, root)

        #Draw bounding boxes of children for debugging purposes
        #if root is self:
            #ax = pyl.gca()
            #for b in self.ax_boxes:
                #xy = b[0], b[1]
                #w = b[2] - b[0]
                #h = b[3] - b[1]
                #ax.add_patch(pyl.Rectangle(xy, w, h, lw=2., fill=False, transform=fig.transFigure, clip_on=False))
    # }}}

    def _do_plots(self, fig): # {{{
        """Replay stored operations onto the axes made by ``_build_axes``.

        Children are handled first.  For this wrapper, ops with ``pre`` set
        run before the stored axes properties are applied and the remaining
        ops run afterwards.
        """
        # Plot children
        for a in self.axes:
            a._do_plots(fig)

        if self.ax is None:
            return

        preops = [p for p in self.plots if p.pre]
        postops = [p for p in self.plots if not p.pre]

        # Perform plotting operations
        for p in preops:
            p.render(self.ax)

        # Handle scaling first, because setting this screws up other custom attributes like ticks
        args = self.args.copy()
        if 'xscale' in args:
            self.ax.set_xscale(args.pop('xscale'))
        if 'yscale' in args:
            self.ax.set_yscale(args.pop('yscale'))

        if len(args) > 0:
            pyl.setp(self.ax, **args)
        if len(self.xaxis_args) > 0:
            pyl.setp(self.ax.xaxis, **self.xaxis_args)
        if len(self.yaxis_args) > 0:
            pyl.setp(self.ax.yaxis, **self.yaxis_args)

        # Perform plotting operations
        for p in postops:
            p.render(self.ax)
    # }}}

    def setp(self, children=True, **kwargs): # {{{
        """Record axes properties to apply at render time.

        When ``children`` is true the same properties are recorded on every
        descendant wrapper as well.
        """
        self.args.update(kwargs)
        if children:
            for a in self.axes:
                a.setp(children, **kwargs)
    # }}}

    def setp_xaxis(self, children=True, **kwargs): # {{{
        """Record x-axis properties to apply at render time (see ``setp``)."""
        self.xaxis_args.update(kwargs)
        if children:
            for a in self.axes:
                a.setp_xaxis(children, **kwargs)
    # }}}

    def setp_yaxis(self, children=True, **kwargs): # {{{
        """Record y-axis properties to apply at render time (see ``setp``)."""
        self.yaxis_args.update(kwargs)
        if children:
            for a in self.axes:
                a.setp_yaxis(children, **kwargs)
    # }}}

    def find_plot(self, cl): # {{{
        ''' Returns last instance of plot class cl in this axes plots. '''
        for p in reversed(self.plots):
            if isinstance(p, cl):
                return p
        return None
    # }}}
# }}}

# Generic object for holding plot information
class PlotOp: # {{{
    """A recorded plotting call: positional and keyword arguments to replay.

    ``axes`` is set to the owning AxesWrapper by ``AxesWrapper.add_plot``.
    ``pre`` selects whether the op runs before (True) or after (False) the
    wrapper's axes properties are applied.
    """

    def __init__(self, *plot_args, **kwargs):
        self.plot_args = plot_args
        self.plot_kwargs = kwargs
        self.axes = None
        self.pre = True

    # Draw the thing
    # This is pretty much the only public-facing method.
    def render(self, axes=None):
        """Draw onto the matplotlib axes ``axes``; the base class does nothing."""
        pass
# }}}

# 1D plots
class Plot(PlotOp): # {{{
    """Replays ``axes.plot``."""
    def render(self, axes):
        axes.plot(*self.plot_args, **self.plot_kwargs)
# }}}

class FillBetween(PlotOp): # {{{
    """Replays ``axes.fill_between`` and then autoscales the axes."""
    def render(self, axes):
        axes.fill_between(*self.plot_args, **self.plot_kwargs)
        axes.autoscale()
# }}}

class Scatter(PlotOp): # {{{
    """Replays ``axes.scatter``."""
    def render(self, axes):
        axes.scatter(*self.plot_args, **self.plot_kwargs)
# }}}

class Errorbar(PlotOp): # {{{
    """Replays ``axes.errorbar``."""
    def render(self, axes):
        axes.errorbar(*self.plot_args, **self.plot_kwargs)
# }}}

class Bar(PlotOp): # {{{
    """Replays ``axes.bar``."""
    def render(self, axes):
        axes.bar(*self.plot_args, **self.plot_kwargs)
# }}}

class Histogram(PlotOp): # {{{
    """Replays ``axes.hist``."""
    def render(self, axes):
        axes.hist(*self.plot_args, **self.plot_kwargs)
# }}}

class AxHLine(PlotOp): # {{{
    """Replays ``axes.axhline``."""
    def render(self, axes):
        axes.axhline(*self.plot_args, **self.plot_kwargs)
# }}}

class AxVLine(PlotOp): # {{{
    """Replays ``axes.axvline``."""
    def render(self, axes):
        axes.axvline(*self.plot_args, **self.plot_kwargs)
# }}}

class Legend(PlotOp): # {{{
    """Replays ``axes.legend``."""
    def render(self, axes):
        axes.legend(*self.plot_args, **self.plot_kwargs)
# }}}

class Text(PlotOp): # {{{
    """Replays ``axes.text``.

    The ``transform`` keyword is given as a string: 'Axes' selects
    ``axes.transAxes`` and 'Data' (the default) selects ``axes.transData``.
    Any other value is dropped and no transform is passed on.
    """
    def render(self, axes):
        kwargs = self.plot_kwargs.copy()
        tr = kwargs.pop('transform', 'Data')
        if tr == 'Axes':
            kwargs['transform'] = axes.transAxes
        if tr == 'Data':
            kwargs['transform'] = axes.transData
        axes.text(*self.plot_args, **kwargs)
# }}}

# Contour
class Contour(PlotOp): # {{{
    """Replays ``axes.contour``; the result is kept in ``_cnt``."""
    def render(self, axes):
        self._cnt = axes.contour(*self.plot_args, **self.plot_kwargs)
# }}}

# Filled Contour
class Contourf(PlotOp): # {{{
    """Replays ``axes.contourf``; the result is kept in ``_cnt``."""
    def render(self, axes):
        self._cnt = axes.contourf(*self.plot_args, **self.plot_kwargs)
# }}}

# Op to modify contours
class ModifyContours(PlotOp): # {{{
    """Sets properties on the collections of an already rendered contour op.

    cnt -- the contour op whose ``_cnt.collections`` are modified
    ind -- indices of the collections to modify, or None for all of them
    """
    def __init__(self, cnt, ind=None, **kwargs):
        self.cnt = cnt
        self.ind = ind
        PlotOp.__init__(self, **kwargs)

    def render(self, axes):
        coll = self.cnt._cnt.collections
        if self.ind is None:
            pyl.setp(coll, **self.plot_kwargs)
        else:
            pyl.setp([coll[i] for i in self.ind], **self.plot_kwargs)
# }}}

# Op to add contour labels
class CLabel(PlotOp): # {{{
    """Adds labels to the contours of a rendered contour op ``cnt``.

    Runs after the axes properties have been applied (``pre`` is False).
    """
    def __init__(self, cnt, **kwargs):
        self.cnt = cnt
        PlotOp.__init__(self, **kwargs)
        self.pre = False

    def render(self, axes):
        # Use the axes being rendered instead of pyplot's current axes.
        axes.clabel(self.cnt._cnt, **self.plot_kwargs)
# }}}

# PColor
class PColor(PlotOp): # {{{
    """Replays ``axes.pcolor``; the result is kept in ``_cnt``."""
    def render(self, axes):
        self._cnt = axes.pcolor(*self.plot_args, **self.plot_kwargs)
# }}}

# Streamplot
class Streamplot(PlotOp): # {{{
    """Replays ``axes.streamplot``; the result is kept in ``_sp``."""
    def render(self, axes):
        self._sp = axes.streamplot(*self.plot_args, **self.plot_kwargs)
# }}}

# Quiver
class Quiver(PlotOp): # {{{
    """Replays ``axes.quiver``; the result is kept in ``_cnt``."""
    def render(self, axes):
        self._cnt = axes.quiver(*self.plot_args, **self.plot_kwargs)
# }}}

# Op to add a quiver key
class QuiverKey(PlotOp): # {{{
    """Adds a key for the rendered quiver op ``cnt``."""
    def __init__(self, cnt, *args, **kwargs):
        self.cnt = cnt
        PlotOp.__init__(self, *args, **kwargs)

    def render(self, axes):
        # Attach the key to the axes being rendered; going through pyplot
        # would add it to whichever axes happens to be current.
        axes.quiverkey(self.cnt._cnt, *self.plot_args, **self.plot_kwargs)
# }}}

# imshow
class ImShow(PlotOp): # {{{
    """Replays ``axes.imshow``; the result is kept in ``_cnt``."""
    def render(self, axes):
        self._cnt = axes.imshow(*self.plot_args, **self.plot_kwargs)
# }}}

# Colorbar
class Colorbar(PlotOp): # {{{
    """Draws a colorbar for the rendered op ``cnt`` into the wrapper ``cax``.

    The optional ``lcnt`` keyword names a second rendered op whose ``_cnt``
    is added to the colorbar with ``add_lines``.
    """
    def __init__(self, cnt, cax, *plot_args, **kwargs):
        self.cnt = cnt
        self.cax = cax
        self.lcnt = kwargs.pop('lcnt', None)
        PlotOp.__init__(self, *plot_args, **kwargs)

    def render(self, axes):
        self._cbar = pyl.colorbar(self.cnt._cnt, *self.plot_args, cax=self.cax.ax, **self.plot_kwargs)
        if self.lcnt is not None:
            self._cbar.add_lines(self.lcnt._cnt)
        # Make the plotted axes current again after drawing the colorbar.
        pyl.sca(axes)
# }}}

def colorbar(axes, cnt, cax=None, rect=None, *args, **kwargs): # {{{
    """Attach a colorbar for the plot op ``cnt``.

    axes -- wrapper holding the plot; used for layout when ``cax`` is None
    cnt  -- plot op to describe; the Colorbar op is added to ``cnt.axes``
    cax  -- wrapper to draw the colorbar into; when None a new wrapper is
            created and laid out beside ``axes`` with ``grid``
    rect -- [left, bottom, width, height] of the colorbar axes inside the
            new wrapper; defaults depend on the orientation

    Keywords consumed here (all optional): ``pos`` ('r', 'l', 'b' or 't';
    anything else is treated as 'r'), ``height`` / ``width`` of the new
    wrapper, ``rl``/``rb``/``rr``/``rt`` to build ``rect``, ``ticklabels``,
    and ``spacing`` (defaults to 'proportional').  Remaining arguments are
    stored on the Colorbar op.

    Returns the combined grid wrapper, or None when ``cax`` was supplied.
    """
    if cax is None:
        pos = kwargs.pop('pos', 'r')
        # Bars above/below default to horizontal, everything else to
        # vertical; an explicit 'orientation' keyword always wins.
        default_orient = 'horizontal' if pos in ['b', 't'] else 'vertical'
        orient = kwargs.setdefault('orientation', default_orient)

        if orient == 'horizontal':
            height = kwargs.pop('height', 0.4)
            size = axes.size[0], height
            if rect is None:
                l = kwargs.pop('rl', 0.15)
                b = kwargs.pop('rb', 0.5)
                r = kwargs.pop('rr', 0.75)
                t = kwargs.pop('rt', 0.4)
                rect = [l, b, r, t]
        else:
            width = kwargs.pop('width', 0.8)
            size = width, axes.size[1]
            if rect is None:
                l = kwargs.pop('rl', 0.1)
                b = kwargs.pop('rb', 0.15)
                r = kwargs.pop('rr', 0.2)
                t = kwargs.pop('rt', 0.75)
                rect = [l, b, r, t]

        cax = AxesWrapper(size=size, rect=rect, make_axis=True)
        if pos == 'b':
            ret = grid([[axes], [cax]])
        elif pos == 'l':
            ret = grid([[cax, axes]])
        elif pos == 't':
            ret = grid([[cax], [axes]])
        else:
            ret = grid([[axes, cax]])
    else:
        ret = None
        # Needed below for the tick labels; previously unbound on this path.
        orient = kwargs.get('orientation', 'vertical')

    ticklabels = kwargs.pop('ticklabels', None)
    kwargs['spacing'] = kwargs.pop('spacing', 'proportional')
    cnt.axes.add_plot(Colorbar(cnt, cax, *args, **kwargs))

    if ticklabels is not None:
        if orient == 'horizontal':
            cax.setp(xticklabels = ticklabels)
        else:
            cax.setp(yticklabels = ticklabels)
    return ret
# }}}

def make_plot_func(fclass, make_axes=True):
    """Build a module-level plotting function for the op class ``fclass``.

    The generated function accepts an optional ``axes`` keyword (a new
    AxesWrapper is created when it is missing or None), constructs the op
    from the remaining arguments, adds it to the wrapper and returns the op.
    """
    def f(*args, **kwargs):
        axes = kwargs.pop('axes', None)
        if axes is None:
            axes = AxesWrapper()
        plotop = fclass(*args, **kwargs)
        axes.add_plot(plotop, make_axes=make_axes)
        return plotop
    return f

def make_plot_member(f):
    """Turn a function made by ``make_plot_func`` into an AxesWrapper method."""
    def g(self, *args, **kwargs):
        return f(*args, axes = self, **kwargs)
    return g

plot = make_plot_func(Plot)
fill_between = make_plot_func(FillBetween)
scatter = make_plot_func(Scatter)
errorbar = make_plot_func(Errorbar)
bar = make_plot_func(Bar)
hist = make_plot_func(Histogram)
axhline = make_plot_func(AxHLine)
axvline = make_plot_func(AxVLine)
legend = make_plot_func(Legend)
text = make_plot_func(Text, make_axes=False)
contour = make_plot_func(Contour)
contourf = make_plot_func(Contourf)
modifycontours = make_plot_func(ModifyContours)
clabel = make_plot_func(CLabel)
pcolor = make_plot_func(PColor)
streamplot = make_plot_func(Streamplot)
quiver = make_plot_func(Quiver)
quiverkey = make_plot_func(QuiverKey)
imshow = make_plot_func(ImShow)

__all__ = ['AxesWrapper', 'plot', 'fill_between', 'scatter', 'errorbar', 'bar', 'hist', 'axhline', 'axvline', 'legend', 'text', 'contour', 'contourf', 'pcolor', 'quiver', 'quiverkey', 'imshow', 'colorbar']

AxesWrapper.plot = make_plot_member(plot)
AxesWrapper.fill_between = make_plot_member(fill_between)
AxesWrapper.scatter = make_plot_member(scatter)
AxesWrapper.errorbar = make_plot_member(errorbar)
AxesWrapper.bar = make_plot_member(bar)
AxesWrapper.hist = make_plot_member(hist)
AxesWrapper.axhline = make_plot_member(axhline)
AxesWrapper.axvline = make_plot_member(axvline)
AxesWrapper.legend = make_plot_member(legend)
AxesWrapper.text = make_plot_member(text)
AxesWrapper.contour = make_plot_member(contour)
AxesWrapper.contourf = make_plot_member(contourf)
AxesWrapper.modifycontours = make_plot_member(modifycontours)
AxesWrapper.clabel = make_plot_member(clabel)
AxesWrapper.pcolor = make_plot_member(pcolor)
AxesWrapper.streamplot = make_plot_member(streamplot)
AxesWrapper.quiver = make_plot_member(quiver)
AxesWrapper.quiverkey = make_plot_member(quiverkey)
AxesWrapper.imshow = make_plot_member(imshow)

# Routine for saving this plot to a file
def save (fig, filename): # {{{
    """Pickle the plot object ``fig`` to the file ``filename``."""
    import pickle
    # pickle writes bytes, so the file must be opened in binary mode.
    with open(filename, 'wb') as outfile:
        pickle.dump(fig, outfile)
# }}}

# Module-level routine for loading a plot from file
def load (filename): # {{{
    """Return the plot object pickled in ``filename`` by ``save``.

    NOTE: unpickling can execute arbitrary code; only load files that come
    from a trusted source.
    """
    import pickle
    with open(filename, 'rb') as infile:
        return pickle.load(infile)
# }}}
def grid(axes, size = None): # {{{
    """Lay out a 2D arrangement of wrappers inside a new parent AxesWrapper.

    axes -- list of rows, each a list of AxesWrapper objects or None for an
            empty cell; every row must have the same length
    size -- (width, height) of the returned wrapper; defaults to the sum of
            the column widths and the sum of the row heights

    Each row is as tall as its tallest entry and each column as wide as its
    widest entry; an entry smaller than its cell is centred in it.  Rows are
    placed from the top of the parent downwards.

    Returns the new parent wrapper.  Raises ValueError when every cell is
    None.
    """
    # Expect a 2d grid; first index is the row, second the column
    ny = len(axes)
    nx = len(axes[0])
    assert all([len(x) == nx for x in axes[1:]]), 'Each row must have the same number of axes'

    if not any(a is not None for row in axes for a in row):
        raise ValueError('grid requires at least one axes')

    # A row or column holding only None placeholders takes up no space
    # (previously this crashed inside max()).
    rowh = [max([a.size[1] for a in row if a is not None], default=0.) for row in axes]
    colw = [max([axes[i][j].size[0] for i in range(ny) if axes[i][j] is not None], default=0.) for j in range(nx)]

    tsize = [float(sum(colw)), float(sum(rowh))]
    if size is None:
        size = tsize

    Ax = AxesWrapper(size = size)

    # (x, y) is the top-left corner of the current cell, in the parent's
    # unit coordinates.
    x, y = 0., 1.
    for i in range(ny):
        for j in range(nx):
            ax = axes[i][j]
            if ax is not None:
                # Fractional extent of the child and the padding needed to
                # centre it within its cell.
                w = ax.size[0] / tsize[0]
                px = (colw[j] - ax.size[0]) / tsize[0] / 2.
                h = ax.size[1] / tsize[1]
                py = (rowh[i] - ax.size[1]) / tsize[1] / 2.
                r = [x + px, y - h - py, x + w + px, y - py]
                Ax.add_axis(ax, r)
            x += colw[j] / tsize[0]
        x = 0.
        y -= rowh[i] / tsize[1]
    return Ax
# }}}

def annotate(axes, text, pos='b'): # {{{
    """Wrap ``axes`` in a taller wrapper with a line of text beneath it.

    The new wrapper is 0.5 taller than ``axes``; the original axes fill the
    area above that strip and ``text`` is centred horizontally at the bottom
    of the new wrapper.  ``pos`` is accepted but not used.

    Returns the new wrapper.
    """
    strip = 0.5
    width, height = axes.size[0], axes.size[1] + strip
    Ax = AxesWrapper(size=(width, height))
    # The original axes occupy everything above the text strip
    Ax.add_axis(axes, (0, strip / float(height), 1., 1.))
    Ax.text(0.5, 0., text, ha='center', va='bottom')#, transform='Ax')
    return Ax
# }}}

__all__.extend(['save', 'load', 'grid'])

# Flags recording which optional map back ends could be imported
basemap_avail = True
cartopy_avail = True

try:
    from .basemap import *
    from .basemap import __all__ as bm_all
except ImportError:
    basemap_avail = False

    def isbasemapaxis(axes): # {{{
        return False
    # }}}
else:
    __all__.extend(bm_all)

    def isbasemapaxis(axes): # {{{
        return isinstance(axes, BasemapAxes)
    # }}}

try:
    from .cartopy import *
    from .cartopy import __all__ as crt_all
except ImportError:
    cartopy_avail = False

    def iscartopyaxis(axes): # {{{
        return False
    # }}}
else:
    __all__.extend(crt_all)

    def iscartopyaxis(axes): # {{{
        return isinstance(axes, CartopyAxes)
    # }}}

def ismapaxis(axis):
    """True when ``axis`` is a basemap or cartopy axes wrapper."""
    if isbasemapaxis(axis):
        return True
    return iscartopyaxis(axis)

if not basemap_avail and not cartopy_avail:
    import warnings
    warnings.warn('Neither Cartopy nor Basemap functionality is available.')