# Source code for pygeode.ext_xarray

# Functions for converting PyGeode objects to other (external) projects, and
# vice-versa.

def to_xarray(dataset):
    """
    Converts a PyGeode Dataset into an xarray Dataset.

    Parameters
    ----------
    dataset : pygeode.Dataset
      The dataset to be converted.

    Returns
    -------
    out : xarray.Dataset
      An object which can be used with the xarray package.

    Notes
    -----
    Axes/variables that already expose ``.values`` are passed to xarray
    directly.  Every other variable is wrapped as a dask array whose chunks
    call ``var.getview`` on the pieces produced by ``View.loop_mem()``.
    """
    from pygeode.dataset import asdataset
    from pygeode.formats.cfmeta import encode_cf
    import xarray as xr

    dataset = asdataset(dataset)

    # Encode the axes/variables with CF metadata.
    dataset = encode_cf(dataset)

    out = dict()

    # Loop over each axis and variable.
    for var in list(dataset.axes) + list(dataset.vars):
        dims = [a.name for a in var.axes]
        if hasattr(var, 'values'):
            # Special case: already have the values in memory.
            data = var.values
        else:
            data = _as_dask_array(var)
        # Wrap this into an xarray.DataArray (with metadata and named axes).
        out[var.name] = xr.DataArray(data, dims=dims, attrs=var.atts,
                                     name=var.name)

    # Build the final xarray.Dataset.
    out = xr.Dataset(out, attrs=dataset.atts)

    # Re-decode the CF metadata on the xarray side.
    out = xr.conventions.decode_cf(out)

    return out


def _as_dask_array(var):
    """Wrap a PyGeode Var as a dask array with one chunk per loop_mem() piece.

    Parameters
    ----------
    var : pygeode.Var
      The variable to wrap.  Its data is fetched through ``var.getview``.

    Returns
    -------
    dask.array.Array
      Array with ``var``'s shape and dtype.  Each graph task is
      ``(var.getview, outview, False)`` for one view from ``loop_mem()``.
    """
    from pygeode.view import View
    from dask.base import tokenize
    import dask.array as da

    # Generate a unique name to identify it with dask.
    name = var.name + "-" + tokenize(var)
    dsk = dict()

    # Keep track of all the slices that were made over each dimension.
    # This information is used to determine the "chunking" that was done
    # on the variable by inview.loop_mem().
    slice_order = [[] for _ in var.axes]

    # Break up the variable into portions that are small enough to fit
    # in memory.  These become the "chunks" for dask.
    inview = View(var.axes)
    for outview in inview.loop_mem():
        integer_indices = list(map(tuple, outview.integer_indices))
        # Determine *how* loop_mem is splitting the axes, and define the chunk
        # sizes accordingly.  A little indirect, but loop_mem doesn't make its
        # chunking choices available to the caller.
        for seen, sl in zip(slice_order, integer_indices):
            if sl not in seen:
                seen.append(sl)
        # Position of this piece along each axis == its dask block index.
        ind = [seen.index(sl) for seen, sl in zip(slice_order, integer_indices)]
        # Add this chunk to the dask graph.
        key = tuple([name] + ind)
        dsk[key] = (var.getview, outview, False)

    # Chunk sizes along each axis, in the order the pieces were first seen.
    chunks = [list(map(len, sl)) for sl in slice_order]
    return da.Array(dsk, name, chunks, dtype=var.dtype)
# Helper method - convert unicode attributes to str.
def _fix_atts(atts):
    """Return a copy of *atts* with ``str`` keys and ``str`` text values.

    Parameters
    ----------
    atts : mapping
      Attribute dictionary (e.g. an xarray ``.attrs``).  Not modified.

    Returns
    -------
    dict
      New dictionary where every key has been passed through ``str()`` and
      every text (unicode) value has been converted to ``str``.  Non-text
      values are kept unchanged.
    """
    # type(u'') is ``unicode`` on Python 2 and ``str`` on Python 3.  Binding a
    # local named ``unicode`` inside a conditional would make the name local
    # to the whole function and raise UnboundLocalError on Python 2.
    text_type = type(u'')
    fixed = dict((str(k), v) for k, v in atts.items())
    for k, v in list(fixed.items()):
        if isinstance(v, text_type):
            fixed[k] = str(v)
    return fixed


from pygeode.var import Var


class XArray_DataArray(Var):
    """
    A wrapper for accessing xarray.DataArray objects as pygeode.Var objects.

    Any object with ``shape``, ``dims``, ``attrs`` and ``dtype`` attributes
    that supports indexing works; ``from_xarray`` passes ``xarray.Variable``
    objects taken from ``Dataset.variables``.
    """

    def __init__(self, name, arr):
        # ``Var`` is deleted from the module namespace after this class is
        # defined, so it has to be re-imported at call time.
        from pygeode.var import Var
        from pygeode.axis import NamedAxis
        self._arr = arr
        # Extract axes and metadata.
        # Convert unicode strings to str for compatibility with PyGeode.
        axes = [NamedAxis(n, str(d)) for n, d in zip(arr.shape, arr.dims)]
        atts = _fix_atts(arr.attrs)
        Var.__init__(self, axes, name=str(name), dtype=arr.dtype, atts=atts)

    def getview(self, view, pbar):
        """Load the portion of the wrapped array selected by *view*.

        Indexes the wrapped object with ``view.slices`` and materializes the
        result as a numpy array, then reports completion to *pbar*.
        """
        import numpy as np
        out = np.asarray(self._arr[view.slices])
        pbar.update(100)
        return out


# Remove the module-level ``Var`` name imported above (presumably so it is
# not re-exported from this module).
del Var
def from_xarray(dataset):
    """
    Converts an xarray Dataset into a PyGeode Dataset.

    Parameters
    ----------
    dataset : xarray.Dataset
      The dataset to be converted.

    Returns
    -------
    out : pygeode.Dataset
      An object which can be used with the pygeode package.
    """
    import xarray as xr
    from pygeode.dataset import Dataset
    from pygeode.formats.netcdf import dims2axes
    from pygeode.formats.cfmeta import decode_cf

    conventions = xr.conventions

    # Choose the datetime/timedelta encoders once, based on the installed
    # xarray.  Some releases expose module-level helpers in
    # xarray.conventions; otherwise fall back to the coder classes in
    # xarray.coding.times.  Detecting the API up front (instead of catching
    # AttributeError around the calls) keeps genuine errors raised inside an
    # encoder from being mistaken for a missing API.
    if (hasattr(conventions, 'maybe_encode_datetime')
            and hasattr(conventions, 'maybe_encode_timedelta')):
        time_encoders = [conventions.maybe_encode_datetime,
                         conventions.maybe_encode_timedelta]
    else:
        time_encoders = [xr.coding.times.CFDatetimeCoder().encode,
                         xr.coding.times.CFTimedeltaCoder().encode]

    # Not available in every xarray version (e.g. older than 0.10.0); skipped
    # when absent.
    encode_string_dtype = getattr(conventions, 'maybe_encode_string_dtype',
                                  None)

    out = []

    # Loop over each axis and variable, and wrap as a pygeode.Var object.
    for varname, var in dataset.variables.items():
        # Apply a subset of conventions that are relevant to PyGeode.
        for encode in time_encoders:
            var = encode(var)
        if encode_string_dtype is not None:
            var = encode_string_dtype(var)
        out.append(XArray_DataArray(varname, var))

    # Wrap all the Var objects into a pygeode.Dataset object.
    out = Dataset(out, atts=_fix_atts(dataset.attrs))

    # Re-construct the axes as pygeode.axis.NamedAxis objects.
    out = dims2axes(out)

    # Re-decode the CF metadata on the PyGeode end.
    # This will get the appropriate axis types for lat, lon, time, etc.
    out = decode_cf(out)

    return out