Source code for pygeode.ensemble

#TODO: allow variables to be omitted in some ensembles

from pygeode.axis import Index
class Ensemble(Index): pass
del Index

def make_ensemble (n, ensdict={}):
  if n not in ensdict: ensdict[n] = Ensemble(n)
  return ensdict[n]

from pygeode.var import Var
class EnsembleVar(Var):
  def __init__(self, varlist):
    from pygeode.var import Var, combine_meta
    self.varlist = varlist
    # assume the vars have already been checked for consistency
    var0 = varlist[0]
    axes = list(var0.axes)
#    axes = [Ensemble(len(varlist))] + axes
    axes = [make_ensemble(len(varlist))] + axes
    Var.__init__(self, axes, dtype=var0.dtype)
#    copy_meta (var0, self)
#    self.atts = common_dict(var.atts for var in varlist)
#    self.plotatts = common_dict(var.plotatts for var in varlist)
    combine_meta (varlist, self)
    self.name = varlist[0].name
  def getview (self, view, pbar):
    import numpy as np
    subview = view.remove(0)
    N = len(view.integer_indices[0])
    chunks = [[subview.get(self.varlist[i], pbar=pbar.part(n,N))]
              for n,i in enumerate(view.integer_indices[0])]
    return np.concatenate(chunks, axis=0)

del Var

# Collect vars into an ensemble

[docs]def ensemble (*varlists): """ Creates an ensemble out of a set of similar variables. The corresponding variable must have the same axes and the same name. If a bunch of vars are passed as inputs, then a single ensemble var is returned. If a bunch of datasets are passed as inputs, then a single dataset is returned, consisting of an ensemble of the internal vars. Each input dataset must have matching vars. """ from pygeode.var import Var from pygeode.dataset import Dataset, asdataset from pygeode.tools import common_dict datasets = [asdataset(v) for v in varlists] varnames = [v.name for v in datasets[0].vars] # Make sure we have the same varnames in each dataset for dataset in datasets: assert set(dataset.vardict.keys()) == set(varnames), "inconsistent variable names between datasets" # Make sure the varlists are all in the same order for i, dataset in enumerate(datasets): varlist = [dataset[varname] for varname in varnames] datasets[i] = Dataset(varlist, atts=dataset.atts) for varname in varnames: var0 = datasets[0][varname] for dataset in datasets: var = dataset[varname] # Make sure the axes are the same between ensemble vars assert var.axes == var0.axes, "inconsistent axes for %s"%varname # Collect the ensembles together ensembles = [] for varname in varnames: ensemble = EnsembleVar([dataset[varname] for dataset in datasets]) ensembles.append(ensemble) # Global attributes atts = common_dict(dataset.atts for dataset in datasets) if isinstance(varlists[0], Dataset): return Dataset(ensembles, atts=atts) if isinstance(varlists[0], Var): assert len(ensembles) == 1 return ensembles[0] return ensembles