#TODO: remove the 'concat' class method - put the code directly in the static 'concat' method. There are no longer any cases where Axis subclasses need to overload the merging logic.
#TODO: change the arguments for 'concat' from axes to *axes
#TODO: remove NamedAxis class - mostly redundant
#TODO: make map_to a light wrapper for common_map, since the latter is a more powerful version of the method
from pygeode.var import Var
try:
import pylab as pyl
class AxisFormatter(pyl.Formatter):
# {{{
def __init__(self, axis, fmt=None, unitstr=None, units=False):
if fmt is None:
fmt = axis.plotatts.get('plotfmt', None)
if fmt is None: fmt = axis.formatstr
if unitstr is None:
unitstr = axis.plotatts.get('plotunits', None)
if unitstr is None: unitstr = axis.units
self.fmt = fmt
self.unitstr = unitstr
self.showunits = units
self.pygaxis = axis
def __call__(self, x, pos=None):
return self.pygaxis.formatvalue(x, fmt=self.fmt, unitstr=self.unitstr, units=self.showunits)
# }}}
except ImportError:
from warnings import warn
warn ("Matplotlib not available; plotting functionality will be absent.")
# Axis parent class
[docs]class Axis(Var):
# {{{
"""
An object that describes a single dimension of a :class:`Var` object.
It is a subclass of :class:`Var`, so it can be used anywhere a Var would be
used.
Parameters
----------
values : numpy.ndarray
The coordinate values for each point along the axis. Should be monotonic.
name : string
A name used to reference this axis. Should be unique among the other
axes of a variable.
atts : dict
Any additional metadata to associate with the axis. The dictionary keys
should be strings.
See Also
--------
:doc:`var`
"""
# Default dictionaries: these are class defaults and are overwritten by child class defaults
#: Auxiliary arrays. These contain additionnal fields beyond the regular value array.
auxarrays = {}
#: Auxiliary attributes. These are preserved during merge/slice/etc operations.
auxatts = {}
#: Format specification for plotting values.
formatstr = '%g'
#: Relative tolerance for identifying two values of this axis as equal
rtol = 1e-5
#: Dictionary of attributes for plotting; see plotting documentation.
plotatts = Var.plotatts.copy()
[docs] def __init__(self, values, name=None, atts=None, plotatts=None, rtol=None, **kwargs):
# {{{
"""
Create a new Axis object with the given values.
Parameters
----------
values : numpy.ndarray
A one-dimensional coordinate defining the axis grid.
name : string (optional)
What to call the axis (i.e. for plot titles & when saving to file)
atts : dict (optional)
Any additional metadata to associate with the axis. The dictionary
keys should be strings.
plotatts : dict (optional)
Parameters that control plotting behaviour; default values are available.
The dictionary keys should be strings.
rtol : float
A relative tolerance used for identifying an element of this axis.
Notes
-----
All subclasses of :class:`Axis` need to call this __init__ method within
their own __init__, to properly initialize all attributes.
"""
import numpy as np
from pygeode.var import Var
# If a single integer given, expand to an integer range
#TODO: get rid of this? if you want to use integer indices, then make an appropriate 'Index' axis subclass?
if isinstance(values,int):
values = list(range(values))
values = np.asarray(values)
# Read configuration details
self.__class__._readaxisconfig(self)
# Note: Call init before hasattr (or don't use hasattr at all in here)
# (__getattr__ is overridden to call getaxis, which assumes axes are defined, otherwise __getattr__ is called to find an 'axes' property, ....)
Var.__init__(self, [self], values=values, name=name, atts=atts, plotatts=plotatts)
# Compute size of spacing relative to magnitude for relative tolerances when mapping
if rtol is None:
rtol = 1e-5
rtol = self.rtol
inz = np.where(values != 0.)[0]
if len(inz) > 1:
vnz = np.sort(values[inz]).astype('d')
logr = np.floor(np.min( np.log10(np.abs(np.diff(vnz) / vnz[:-1])) ))
if not np.isinf(logr) and 10**logr < rtol: rtol = 10**logr
#: The relative tolerance for identifying an element of this axis.
self.rtol = rtol
# Add auxilliary arrays after calling Var.__init__ - the weights
# array, if present, will be added here, not by the logic in Var.__init___
auxarrays = {}; auxatts = {}
for key, val in kwargs.items():
if isinstance(val,Var): val = val.get()
if isinstance(val,(list,tuple,np.ndarray)):
val = np.asarray(val)
if val.shape != self.values.shape:
raise ValueError('Auxilliary array %s has the wrong shape. Expected %s, got %s' % (key,self.values.shape, val.shape))
auxarrays[key] = val
else:
auxatts[key] = val
# update auxiliary attribute (make copy to not change class defaults)
self.auxarrays = self.__class__.auxarrays.copy()
self.auxarrays.update(auxarrays.copy())
self.auxatts = self.__class__.auxatts.copy()
self.auxatts.update(auxatts.copy())
# }}}
#
[docs] @classmethod
def isparentof(cls,other):
# {{{
"""
Determines if an axis object is an instance of a base class (or the same
class) of another axis.
Parameters
==========
other : :class:`Axis` object to compare against this one.
Returns
=======
bool : boolean
True if ``other`` is an instance of this object's class
"""
return isinstance(other,cls)
# }}}
@classmethod
def _readaxisconfig(cls, ax):
# {{{
from pygeode import _config
c = cls
nm = c.__name__
while c is not Axis:
if _config.has_option('Axes', nm + '.name'):
ax.name = str(_config.get('Axes', nm + '.name'))
break
else:
c = c.__bases__[0]
nm = c.__name__
if c is Axis: ax.name = cls.__name__.lower()
# Set basic attributes
for p in ['formatstr', 'units']:
if _config.has_option('Axes', nm + '.' + p):
setattr(ax, p, _config.get('Axes', nm + '.' + p))
for p in ['rtol']:
if _config.has_option('Axes', nm + '.' + p):
setattr(ax, p, _config.getfloat('Axes', nm + '.' + p))
# Set plot attributes
for p in ['plottitle', 'plotfmt', 'plotscale']:
if _config.has_option('Axes', nm + '.' + p):
ax.plotatts[p] = _config.get('Axes', nm + '.' + p)
for p in ['plotorder']:
if _config.has_option('Axes', nm + '.' + p):
ax.plotatts[p] = int(_config.getfloat('Axes', nm + '.' + p))
# }}}
#TODO: fix inconsistency between Axis and Var, for == and !=
# Vars produce a boolean mask under those operations, Axes return scalar True/False
# I.e., "lat == 30" and "(lat*1) == 30" give very different results!
def __ne__ (self, other): return not self.__eq__(other)
def __eq__ (self, other):
# {{{
'''override Var's ufunc stuff here
this is a weak comparison, in that we only require the other
axis to be a *subclass* of this one.
this allows things like "lat in [time,gausslat,lev]" to evaluate to True
If you want a more strict comparison, then check the classes explicitly yourself.'''
#TODO: do some testing to see just how many times this is called, if it will be a bottleneck for large axes
# print '<< Axis.__eq__ on', repr(self), 'and', repr(other), '>>'
# exact same object?
if self is other: return True
# incomparable?
if not isinstance(other,Axis):
# print 'not an axis?'
return False
if not self.isparentof(other) and not other.isparentof(self):
# print 'parent issues'
return False
# If they are generic Axis objects, an additional requirement is that they have the same name
if self.__class__ is Axis and other.__class__ is Axis:
if self.name != other.name: return False
# Check if they have the same lengths
if len(self.values) != len(other.values):
# print 'false by length'
return False
# Check the values
from numpy import allclose
if not allclose(self.values, other.values):
# print 'values mismatch'
return False
# Check auxiliary attributes
if set(self.auxatts.keys()) != set(other.auxatts.keys()): return False
for fname in self.auxatts.keys():
if self.auxatts[fname] != other.auxatts[fname]: return False
# Check any associated arrays
if set(self.auxarrays.keys()) != set(other.auxarrays.keys()):
# print 'false by mismatched set of auxarrays'
return False
# Check values of associated arrays
for fname in self.auxarrays.keys():
if not allclose(self.auxarrays[fname], other.auxarrays[fname]):
# print 'false by mismatched auxarray "%s":'%fname
return False
return True
# }}}
def alleq (self, *others):
# {{{
''' alleq(self, *others) - returns True if self matches with all axes in others.'''
for other in others:
if not self.__eq__(other): return False
return True
# }}}
#TODO: include associated arrays when doing the mapping?
[docs] def map_to (self, other):
# {{{
'''Returns indices of this axis which correspond to the axis ``other``.
Parameters
----------
other : :class:`Axis`
Axis to find mapping to
Returns
-------
mapping : integer array or None
Notes
-----
Returns an ordered indices of the elements of this axis that correspond to those of
the axis ``other``, if one exists, otherwise None. This axis must be a
parent class of ``other`` or vice versa in order for the mapping to
exist. The mapping may include only a subset of this axis object, but
must be as long as the other axis, if it is not None. The mapping
identifies equivalent elements based on equality up to a tolerance
specified by self.rtol.
'''
from pygeode.tools import map_to
import numpy as np
if not self.isparentof(other) and not other.isparentof(self): return None
# special case: both axes are identical
if self == other: return np.arange(len(self))
# Use less conservative tolerance?
#rtol = max(self.auxatts.get('rtol', 1e-5), other.auxatts.get('rtol', 1e-5))
return map_to(self.values, other.values, self.rtol)
# }}}
[docs] def sorted (self, reverse=None):
# {{{
"""
Sorts the points of the Axis.
Parameters
----------
reverse : boolean (optional)
If ``True``, sorts in descending order. If ``False``, sorts in ascending order.
By default the sign of self.plotorder is used.
Returns
-------
sorted_axis : Axis
A sorted version of the input axis.
Examples
--------
>>> from pygeode import Lat
>>> x = Lat([30,20,10])
>>> print(x)
lat <Lat> : 30 N to 10 N (3 values)
>>> y = x.sorted()
>>> print(y)
lat <Lat> : 10 N to 30 N (3 values)
See Also
--------
argsort
"""
S = self.argsort(reverse=reverse)
return self.slice[S]
# }}}
[docs] def argsort (self, reverse=None):
# {{{
"""
Generates a list of indices that would sort the Axis.
Parameters
----------
reverse : boolean (optional)
If ``False``, indices are in ascending order. If ``True``, will produce
indices for a *reverse* sort instead. By default, sign of self.plotorder is used.
Returns
-------
indices : list
The indices which will produces a sorted version of the Axis.
Examples
--------
>>> from pygeode import Lat
>>> x = Lat([20,30,10])
>>> print(x)
lat <Lat> : 20 N to 10 N (3 values)
>>> indices = x.argsort()
>>> print(indices)
[2 0 1]
>>> print(x.slice[indices])
lat <Lat> : 10 N to 30 N (3 values)
See Also
--------
sorted
"""
import numpy as np
S = np.argsort(self.values)
step = 1
if reverse is None: step = self.plotatts.get('plotorder', 1)
if reverse is True: step = -1
return S[::step]
# }}}
#TODO: implement and test this (if it's ever needed?)
# (make sure to check auxiliary arrays)
"""
def common_map (self, other):
'''return the indices that map common elements from one axis to another'''
from pygeode.tools import common_map
assert self.isparentof(other) or other.isparentof(self)
return common_map(self.values, other.values)
"""
# The length of an axis (equal to the length of the array of values)
def __len__ (self): return len(self.values)
"""
# Avoid accidentally iterating over axes
# (Would cause a new, 1-element axis to be created for each iteration)
def __iter__ (self):
raise Exception ("Axes cannot be iterated over")
"""
# Iterating over an axis iterates over the values
def __iter__ (self): return iter(self.values)
# Slice an axis -> construct a new one with the sliced values
# (Overridden from Var to preserve our Axis status)
[docs] def _getitem_asvar (self, slices):
# {{{
import numpy as np
values = np.array(self.values[slices], ndmin=1, copy=False)
# Check if we even need to do any slicing
if len(values) == len(self.values) and np.all(values==self.values): return self
aux = {}
# Keep auxiliary attributes
for key,val in self.auxatts.items():
aux[key] = val
# Slice auxiliary arrays
for key,val in self.auxarrays.items():
aux[key] = np.array(val[slices], ndmin=1)
axis = type(self)(values, name=self.name, atts=self.atts, **aux)
return axis
# }}}
[docs] def str_as_val(self, key, s):
# {{{
'''str_as_val(self, key, s) - converts string s to a value corresponding to this axis. Default
implementation returns float(s); derived classes can return different conversions depending
on whether key corresponds to the axis itself or an auxiliary array.'''
return float(s)
# }}}
def get_slice(self, kwargs, ignore_mismatch=False):
# {{{
import numpy as np
from pygeode.view import simplify, expand
# boolean flags indicating which axis indices will be used
n = len(self)
keep = np.ones(n,bool)
matched = []
for k, v in kwargs.items():
# Split off prefix if present
if '_' in k and not self.has_alias(k):
prefix, ax = k.split('_', 1)
else:
prefix, ax = '', k
if 'i' in prefix:
################### Select by index; key must correspond to this axis
if not self.has_alias(ax):
if ignore_mismatch: continue
raise Exception("'%s' is not associated with this %s axis" % (ax, self.name))
# Build mask
kp = np.zeros(n, bool)
if not hasattr(v, '__len__'): # Treat as an index
kp[v] = True
elif len(v) > 3 or len(v) < 2 or 'l' in prefix: # Treat as integer array
kp[np.array(v, 'int')] = True
elif len(v) == 2: # Treat as slice
kp[v[0]:v[1]] = True
elif len(v) == 3: # Treat as slice with stride
kp[v[0]:v[1]:v[2]] = True
else:
raise ValueException("'%s' is not associated with this %s axis" % (ax, self.name))
else:
################### Select by value
if self.has_alias(ax): # Does key match this axis?
vals = self.values
elif ax in self.auxarrays: # What about an aux. array?
vals = self.auxarrays[ax]
else:
if ignore_mismatch: continue
raise Exception("'%s' is not associated with this %s axis" % (ax, self.name))
# Build mask
kp = np.zeros(n, bool)
# Convert string representation if necessary
if isinstance(v, str): v = self.str_as_val(ax, v)
if isinstance(v,str) or not hasattr(v,'__len__'): # Single value given
if vals.dtype.name.startswith('float'): # closest match
kp[np.argmin( np.abs(v-vals) )] = True
else: # otherwise require an exact match
kp[vals == v] = True
elif 'l' in prefix:
for V in v:
# Convert string representation if necessary
if isinstance(V, str): V = self.str_as_val(ax, V)
if vals.dtype.name.startswith('float'): # closest match
kp[np.argmin( np.abs(V-vals) )] = True
else: # otherwise require an exact match
kp[vals == V] = True
elif len(v) == 2: # Select within range
# Convert string representations if necessary
v = [self.str_as_val(ax, V) if isinstance(V, str) else V for V in v]
lower, upper = min(v), max(v)
kp[(lower <= vals) & (vals <= upper)] = True
else: # Don't know what to do with more than 2 values
raise Exception('A range must be specified')
# Use complement of requested set
if 'n' in prefix: kp = ~kp
# Compute intersection of index sets
keep &= kp
matched.append(k) # Mark for removal from slice list
# Pop kw arguments that have been handled
for m in matched:
kwargs.pop(m)
# Convert boolean mask to integer indices
sl = np.flatnonzero(keep)
# Filter through view.simplify() to construct slice objects wherever possible
# (otherwise, it's a generic integer array)
return simplify(sl)
# }}}
# Keyword/value based slicing of an axis
# (Overridden from Var to preserve our Axis status)
[docs] def __call__ (self, **kwargs):
# {{{
sl = self.get_slice(kwargs)
return self._getitem_asvar(sl)
# }}}
# Get an axis attribute
# Overloaded from pygeode.Var to allow shortcuts to auxiliary arrays
def __getattr__ (self, name):
# {{{
# Disregard metaclass stuff
if name.startswith('__'): raise AttributeError
# print 'axis getattr ??', name
from pygeode.var import Var
if name in self.auxarrays: return self.auxarrays[name]
if name in self.auxatts: return self.auxatts[name]
return Var.__getattr__(self, name)
# }}}
[docs] def auxasvar (self, name):
# {{{
''' Returns auxiliary array as a new :class:`Var` object.
Parameters
==========
name : string
Name of auxiliary array to return
Returns
=======
var : :class:`Var`
Variable with values of requested auxilliary array
See Also
========
auxarrays
'''
from pygeode.var import Var
return Var([self], values=self.auxarrays[name], name=name)
# }}}
# Pretty printing
def __repr__ (self): return '<' + self.__class__.__name__ + '>'
def __str__ (self):
# {{{
if len(self) > 0:
first = self.formatvalue(self.values[0])
last = self.formatvalue(self.values[-1])
else: first = last = "<empty>"
num = str(len(self.values))
if self.name != '': head = self.name + ' ' + repr(self)
else: head = repr(self)
if len(self) > 1:
out = head.ljust(15) + ': '+first+' to '+last+' ('+num+' values)'
else:
out = head.ljust(15) + ': '+first
# return out+"\n"
return out
# }}}
# Rename an axis
[docs] def rename (self, name):
# {{{
"""
Assigns a new name to this axis.
Parameters
----------
name : string
The new name of this axis.
Returns
-------
renamed_axis : Axis
An instance of the same axis class with the new name.
"""
aux = {}
for k,v in self.auxatts.items(): aux[k] = v
for k,v in self.auxarrays.items(): aux[k] = v
return type(self)(values=self.values, name=name, atts=self.atts, **aux)
# }}}
# Check if a given string is a meaningful alias to the axis
# (i.e., if the string matches the name, or the class name, or one of the base class names)
[docs] @classmethod
def class_has_alias (cls, name):
# {{{
# if cls is Axis: return False # need a uniquely identifiable subclass of Axis
# Default string name for the class
if cls.name.lower() == name.lower(): return True
# A stringified version of the class name (i.e. Time => 'time')
if cls.__name__.lower() == name.lower(): return True
for subclass in cls.__bases__:
if not issubclass(subclass,Axis): continue # only iterate over Axis subclasses
if subclass.class_has_alias(name): return True
return False
# }}}
# Now, a function which can work on instances of a class
# Extends the above to also check the name given to the particular instance
# (which can depend on the source of the data loaded at runtime)
def has_alias (self, name):
# {{{
assert isinstance(name, str)
if self.name.lower() == name.lower(): return True
return self.class_has_alias(name)
# }}}
# Concatenate multiple axes together
# Use numpy arrays
# Assume the segments are pre-sorted
@classmethod
def concat (cls, axes):
# {{{
from numpy import concatenate
from pygeode.tools import common_dict
# Must all be same type of axis
for a in axes: assert isinstance(a,cls), 'axes must be the same type'
values = concatenate([a.values for a in axes])
# Get common attributes
atts = common_dict([a.atts for a in axes])
aux = {}
# Check that all pieces have the same auxiliary attributes, and propogate them to the output.
auxkeys = set(axes[0].auxatts.keys())
for a in axes[1:]:
auxkeys = auxkeys.intersection(list(a.auxatts.keys()))
for k in auxkeys:
vals = [a.auxatts[k] for a in axes]
v1 = vals[0]
# assert all(v == v1 for v in vals), "inconsistent '%s' attribute"%k
# Only use consistent aux atts
if all(v == v1 for v in vals):
aux[k] = axes[0].auxatts[k]
# Find and concatenate auxilliary arrays common to all axes being concatenated
auxkeys = set(axes[0].auxarrays.keys())
for a in axes[1:]: # set.intersection takes multiple arguments only in python 2.6 and later..
auxkeys = auxkeys.intersection(list(a.auxarrays.keys()))
for k in auxkeys:
aux[k] = concatenate([a.auxarrays[k] for a in axes])
name = axes[0].name #TODO: check all names?
return cls(values, name=name, atts=atts, **aux)
# }}}
# Replace the values of an axis
# (any auxiliary arrays from the old axis are ignored)
def withnewvalues (self, values):
# {{{
# Assume any auxiliary scalars are the same for the new axis
return type(self)(values, name=self.name, atts=self.atts, **self.auxatts)
# }}}
# }}}
# Useful axis subclasses
# Named axis
[docs]class NamedAxis (Axis):
# {{{
'''Generic axis object identified by its name.'''
[docs] def __init__ (self, values, name, **kwargs):
# {{{
Axis.__init__(self, values, **kwargs)
self.name = name
# }}}
def __eq__ (self, other):
# {{{
if type(other) is not NamedAxis: return False
# Check the names
if self.name != other.name: return False
# If the names match, check the values
return Axis.__eq__(self,other)
# }}}
def __repr__ (self): return "<%s '%s'>"%(self.__class__.__name__, self.name)
# Need more restrictions on mapping for named axes
# (not only do both axes need to be a NamedAxis, but they need to have the same name)
# (The name is the only way to uniquely identify them)
[docs] def map_to (self, other):
if not isinstance(other, NamedAxis): return None
if other.name != self.name: return None
return Axis.map_to(self, other)
# }}}
# Dummy axis (values are just placeholders).
# Useful when there is no intrinsic coordinate system for a dimension.
# Example could be a dimension in a netCDF file that has no variable
# associated with it.
class DummyAxis (NamedAxis): pass
[docs]class XAxis (Axis):
name = 'xaxis'
[docs]class YAxis (Axis):
name = 'yaxis'
[docs]class Lon (XAxis):
# {{{
''' Longitude axis. '''
name = 'lon'
#name = _config.get('Axes', 'Lon.name')
formatstr = '%.3g E<360'
plotatts = XAxis.plotatts.copy()
plotatts['plottitle'] = ''
plotatts['plotfmt'] = '%.3g E'
# }}}
[docs] def locator(self):
# {{{
import pylab as pyl
return pyl.MaxNLocator(nbins=9, steps=[1, 3, 6, 10])
# }}}
# }}}
[docs]def regularlon(n, origin=0., order=1, repeat_origin=False):
# {{{
'''Constructs a regularly spaced :class:`Lon` axis with n longitudes. The
values range from origin to origin + 360. If repeat_origin is set to True,
the final point is equal to origin + 360. '''
import numpy as np
vals = np.linspace(0., 360, n, endpoint=repeat_origin)[::order] + origin
return Lon(vals)
# }}}
[docs]def rotatelon(v, origin, duplicate = False):
# {{{
''' Rotates longitude axis to start at a new origin.
Parameters
----------
v : :class:`Var`
Variable with :class:`Lon` axis to modify
origin: float
New origin for longitude axis
duplicate: boolean; optional
If true, duplicates the origin value at the end of the axis (can be useful
for plotting). Default is ``False``.
Returns
-------
vn : :class:`Var`
Variable with modified longitude axis
Examples
========
>>> import pygeode as pyg; from pygeode.tutorial import t1
>>> print(t1.Temp.lon)
lon <Lon> : 0 E to 354 E (60 values)
>>> print(pyg.rotatelon(t1.Temp, -180))
<Var 'Temp'>:
Units: K Shape: (lat,lon) (31,60)
Axes:
lat <Lat> : 90 S to 90 N (31 values)
lon <Lon> : 180 E to 174 E (60 values)
Attributes:
{}
Type: SortedVar (dtype="float64")
>>> print(pyg.rotatelon(t1.Temp, 0, duplicate=True).lon[:])
[ 0. 6. 12. 18. 24. 30. 36. 42. 48. 54. 60. 66. 72. 78.
84. 90. 96. 102. 108. 114. 120. 126. 132. 138. 144. 150. 156. 162.
168. 174. 180. 186. 192. 198. 204. 210. 216. 222. 228. 234. 240. 246.
252. 258. 264. 270. 276. 282. 288. 294. 300. 306. 312. 318. 324. 330.
336. 342. 348. 354. 360.]
'''
from . import concatenate
if not v.hasaxis('lon'): return v
lons = v.lon[:]
o = origin % 360
off = o - origin
lp = Lon(values=(lons - o) % 360 + o - off)
v0 = v.replace_axes(lon=lp).sorted('lon')
if duplicate:
l360 = Lon(values = [origin + 360])
v0 = concatenate([v0, v0(lon=0).replace_axes(lon = l360)])
return v0
# }}}
[docs]class Lat (YAxis):
# {{{
''' Latitude axis. '''
name = 'lat'
formatstr = '%.2g N'
plotatts = YAxis.plotatts.copy()
plotatts['plottitle'] = ''
# Make sure we get some weights
[docs] def __init__(self, values, weights=None, **kwargs):
# {{{
from numpy import cos, asarray, pi
# Input weights are only along latitude, not area weight
# Output weights are area weights
# If no input weights given, assume uniform
#TODO: handle non-uniform latitudes?
if weights is None:
weights = cos(asarray(values) * pi / 180.)
Axis.__init__(self, values, weights=weights, **kwargs)
# }}}
# }}}
[docs] def locator(self):
# {{{
import pylab as pyl
return pyl.MaxNLocator(nbins=9, steps=[1, 1.5, 3, 5, 10])
# }}}
# }}}
[docs]def gausslat (n, order=1, axis_dict={}):
# {{{
'''Constructs a Gaussian :class:`Lat` axis with n latitudes.'''
from pygeode.quadrulepy import legendre_compute
import numpy as np
from math import pi
if (n,order) in axis_dict: return axis_dict[(n,order)]
x, w = legendre_compute(n)
x = np.arcsin(x) / pi * 180
x = x[::order]
w = w[::order]
axis = Lat (x, weights=w)
axis_dict[(n,order)] = axis
return axis
# }}}
[docs]def regularlat(n, order=1, inc_poles=True):
# {{{
'''Constructs a regularly spaced :class:`Lat` axis with n latitudes.
If inc_poles is set to True, the grid includes the poles. '''
import numpy as np
if inc_poles: vals = np.linspace(-90, 90, n)[::order]
else: vals = np.linspace(-90, 90, n+2)[1:-1][::order]
return Lat(vals)
# }}}
# Spectral axes
# Note: XAxis/YAxis is used by cccma code to put these in the proper order (XAxis is fastest-increasing)
class SpectralM(YAxis): name = 'm'
class SpectralN(XAxis): name = 'n'
# Vertical axes
[docs]class ZAxis (Axis):
# {{{
name = 'lev'
formatstr = '%3g'
# }}}
# Geometric height
#TODO: weights
#TODO: attributes
[docs]class Height(ZAxis):
# {{{
''' Geometric height axis. '''
name = 'z' # default name
formatstr = '%d'
units = 'm'
plotatts = ZAxis.plotatts.copy()
plotatts['plotname'] = 'Height' # name displayed in plots (axis label)
# }}}
# Model hybrid levels
#TODO: weights!
[docs]class Hybrid (ZAxis):
# {{{
''' Hybridized vertical coordinate axis. '''
name = 'eta' #TODO: rename this to 'hybrid'? (keep 'eta' for now, for compatibility with existing code)
formatstr = '%g'
plotatts = ZAxis.plotatts.copy()
plotatts['plotorder'] = -1
plotatts['plotscale'] = 'log'
[docs] def __init__ (self, values, A, B, **kwargs):
# {{{
# Just pass all the stuff to the superclass
# (All we do here is enforce the existence of 'A' and 'B' associated arrays
ZAxis.__init__ (self, values, A=A, B=B, **kwargs)
# }}}
def __eq__ (self, other):
# {{{
if not ZAxis.__eq__(self, other): return False
from numpy import allclose
if not allclose(self.A, other.A): return False
if not allclose(self.B, other.B): return False
return True
# }}}
[docs] def locator(self):
import pylab as pyl, numpy as np
ndecs = np.log10(np.max(self.values) / np.min(self.values))
if ndecs < 1.2: return pyl.LogLocator(subs=[1., 2., 4., 7.])
elif ndecs < 3.: return pyl.LogLocator(subs=[1., 3.])
else: return pyl.LogLocator()
# }}}
[docs]class Pres (ZAxis):
# {{{
''' Pressure height axis. '''
name = 'pres'
units = 'hPa'
formatstr = '%.2g<100'
plotatts = ZAxis.plotatts.copy()
plotatts['plotname'] = 'Pressure'
plotatts['plotscale'] = 'log'
plotatts['plotorder'] = -1
[docs] def logPAxis(self, p0=1000., H=7.1):
# {{{
'''logPAxis(p0, H) - returns a pygeode axis with log pressure heights
corresponding to this axis. By default p0 = 1000 hPa and H = 7.1 (km).'''
import numpy as np
z = ZAxis(H * np.log(p0 / self.values))
z.plotatts['plotname'] = 'Log-p Height'
return z
# }}}
[docs] def locator(self):
# {{{
import pylab as pyl, numpy as np
numdecs = np.log10(np.max(self.values) / np.min(self.values))
if numdecs < 1.2: subs = [1., 2., 4., 7.]
elif numdecs < 3.: subs = [1., 3.]
else: subs = [1.]
return pyl.LogLocator(subs = subs, numdecs = numdecs)
# }}}
# }}}
# }}}
# NOTE: Time axis is in pygeode.timeaxis
# It's a fairly heavyweight class, worthy of its own module.
# Importing it here would create a circular reference that would bugger things up
class Freq(Axis):
# {{{
name = 'freq'
# plotscale = 'log'
def __init__ (self, values, inv_units=None, *args, **kwargs):
if inv_units is not None:
self.units = '/'+inv_units
self.plotatts['plottitle'] = 'Frequency (per %s)'%inv_units
kwargs = kwargs.copy()
kwargs['inv_units'] = inv_units
Axis.__init__ (self, values, *args, **kwargs)
# }}}
# Indexing arrays (represent discrete items, such as EOF number, ensemble number, etc.)
class Index(Axis):
# {{{
def __init__ (self, n, *args, **kwargs):
values = n if hasattr(n,'__len__') else list(range(n))
Axis.__init__(self, values, *args, **kwargs)
# }}}
# Coefficient number (when returning a set of coefficients, such as for a polynomial fit)
class Coef(Index): pass
[docs]class NonCoordinateAxis(Axis):
# {{{
'''Non-coordinate axis (disables nearest-neighbour value matching, etc.)'''
# Refresh the coordinate values (should always be monotonically increasing integers).
[docs] def __init__ (self, *args, **kwargs):
# {{{
import numpy as np
lengths = [len(kw) for kw in list(kwargs.values()) if isinstance(kw,(list,tuple,np.ndarray))]
if len(lengths) == 0:
raise ValueError("Unable to determine a length for the non-coordinate axis.")
N = lengths[0]
kwargs['values'] = np.arange(N)
Axis.__init__(self, **kwargs)
# Remember original name
self._name = self.name
# }}}
# Modify test for equality to look for an exact match
#TODO: Make the default Axis.__eq__ logic do this, and move the "close"
# matching to a subclass of Axis.
[docs] def __eq__ (self, other):
# {{{
# For simplicity, expect them to be the same type of axis.
if type(self) != type(other): return False
if set(self.auxarrays.keys()) != set(other.auxarrays.keys()): return False
for k in list(self.auxarrays.keys()):
if list(self.auxarrays[k]) != list(other.auxarrays[k]): return False
return True
# }}}
# How to map string values to dummy indices
[docs] def str_as_val(self, key, s):
# {{{
# Special case: referencing an aux array with the same name as the axis.
if self._name in self.auxarrays:
values = list(self.auxarrays[self._name])
if s in values:
return values.index(s)
# Otherwise, return an invalid index (no match)
return -1
# }}}
# Modify formatvalue to convert dummy indices to the appropriate values
# }}}
# Modify map_to do use exact matching.
# (Avoids use of tools.map_to, which assumes the values are numerical)
#TODO: Make this the default for Axis (don't assume we have numerical values?)
[docs] def map_to (self, other):
# {{{
import numpy as np
# Only allow mapping non-coordinate axes if they're the exact same type.
if not isinstance(other,type(self)): return None
# Get keys to use for comparing aux arrays
keys = list(set(self.auxarrays.keys()) & set(other.auxarrays.keys()))
if len(keys) == 0: return None
values = list(zip(*[self.auxarrays[k] for k in keys]))
other_values = list(zip(*[other.auxarrays[k] for k in keys]))
#TODO: Speed this up? (move this to tools.c?)
values_set = set(values)
indices = []
for v in other_values:
if v not in values_set: continue
indices.append(values.index(v))
return indices
# }}}
# }}}
class Station(NonCoordinateAxis):
# {{{
'''Station axis (for timeseries data at fixed station locations)'''
name = "station"
# }}}
# Concatenate a bunch of axes together.
# Find a common parent class for all of them, and call that class's concat function.
def concat (axes):
# {{{
axes = list(axes) # in case we're given a set, generator, etc.
assert len(axes) > 0, 'nothing to concatenate!'
# Degenerate case: only 1 axis provided
if len(axes) == 1: return axes[0]
cls = type(axes[0])
for a in axes:
cls2 = type(a)
if issubclass(cls, cls2): cls = cls2
assert issubclass(cls2, cls), "can't concatenate incompatible axes"
return cls.concat(axes)
# }}}
# List of axes provided in this module (for easy importing)
standard_axes = [Axis, NamedAxis, XAxis, YAxis, ZAxis, TAxis, Lon, NonCoordinateAxis, regularlon, rotatelon, Lat, gausslat, regularlat, Pres, Hybrid, Height, SpectralM, SpectralN, Freq, Index]