#TODO: remove the 'concat' class method - put the code directly in the static 'concat' method.  There are no longer any cases where Axis subclasses need to overload the merging logic.
#TODO: change the arguments for 'concat' from axes to *axes
#TODO: remove NamedAxis class - mostly redundant
#TODO: make map_to a light wrapper for common_map, since the latter is a more powerful version of the method
from pygeode.var import Var
try:
  import pylab as pyl
  class AxisFormatter(pyl.Formatter):
  # {{{
    ''' Matplotlib tick formatter that renders each tick value through the
        ``formatvalue`` method of a PyGeode axis.

        The format string and unit string default to the axis's ``plotfmt`` /
        ``plotunits`` plot attributes, falling back to ``axis.formatstr`` /
        ``axis.units`` when those attributes are missing or None. '''
    def __init__(self, axis, fmt=None, unitstr=None, units=False):
      # Format string: explicit argument, then 'plotfmt', then formatstr.
      if fmt is None:
        fmt = axis.plotatts.get('plotfmt', None)
      if fmt is None:
        fmt = axis.formatstr
      # Unit string: explicit argument, then 'plotunits', then axis.units.
      if unitstr is None:
        unitstr = axis.plotatts.get('plotunits', None)
      if unitstr is None:
        unitstr = axis.units
      self.fmt, self.unitstr, self.showunits = fmt, unitstr, units
      self.pygaxis = axis
    def __call__(self, x, pos=None):
      # 'pos' (tick position) is part of the matplotlib Formatter protocol but unused here.
      target = self.pygaxis
      return target.formatvalue(x, fmt=self.fmt, unitstr=self.unitstr, units=self.showunits)
  # }}}
except ImportError:
  from warnings import warn
  warn ("Matplotlib not available; plotting functionality will be absent.")
# Axis parent class
class Axis(Var):
# {{{
  """
    An object that describes a single dimension of a :class:`Var` object.
    It is a subclass of :class:`Var`, so it can be used anywhere a Var would be
    used.

    Parameters
    ----------
    values : numpy.ndarray
      The coordinate values for each point along the axis.  Should be monotonic.
    name : string
      A name used to reference this axis.  Should be unique among the other
      axes of a variable.
    atts : dict
      Any additional metadata to associate with the axis.  The dictionary keys
      should be strings.

    See Also
    --------
    :doc:`var`
  """
  # Default dictionaries: these are class defaults and are overwritten by child class defaults
  #: Auxiliary arrays. These contain additional fields beyond the regular value array.
  auxarrays = {}
  #: Auxiliary attributes. These are preserved during merge/slice/etc operations.
  auxatts = {}
  #: Format specification for plotting values.
  formatstr = '%g'
  #: Relative tolerance for identifying two values of this axis as equal
  rtol = 1e-5
  #: Dictionary of attributes for plotting; see plotting documentation.
  plotatts = Var.plotatts.copy()

  def __init__(self, values, name=None, atts=None, plotatts=None, rtol=None, **kwargs):
  # {{{
    """
    Create a new Axis object with the given values.

    Parameters
    ----------
    values : numpy.ndarray or int
        A one-dimensional coordinate defining the axis grid.  A plain ``int``
        ``n`` is expanded to the integer range ``0 .. n-1``.
    name : string (optional)
        What to call the axis (i.e. for plot titles & when saving to file)
    atts : dict (optional)
        Any additional metadata to associate with the axis. The dictionary
        keys should be strings.
    plotatts : dict (optional)
        Parameters that control plotting behaviour; default values are available.
        The dictionary keys should be strings.
    rtol : float (optional)
        A relative tolerance used for identifying an element of this axis.
        If omitted, the class (or configuration) default is used, lowered to
        ``10**floor(log10(s))`` when the smallest relative spacing ``s``
        between sorted non-zero values requires a tighter tolerance.
    **kwargs
        Extra keywords become auxiliary data.  Lists, tuples and arrays (and
        :class:`Var` objects, via their ``get()`` values) are stored in
        ``auxarrays`` and must match the shape of ``values``.  Anything else
        is stored in ``auxatts``.

    Raises
    ------
    ValueError
        If an auxiliary array does not have the same shape as ``values``.

    Notes
    -----
    All subclasses of :class:`Axis` need to call this __init__ method within
    their own __init__, to properly initialize all attributes.
    """
    import numpy as np
    from pygeode.var import Var
    # If a single integer given, expand to an integer range
    #TODO: get rid of this?  if you want to use integer indices, then make an appropriate 'Index' axis subclass?
    if isinstance(values,int):
      values = list(range(values))
    values = np.asarray(values)
    # Read configuration details (name, format, units, rtol, plot attributes)
    self.__class__._readaxisconfig(self)
    # Note: Call init before hasattr (or don't use hasattr at all in here)
    # (__getattr__ is overridden to call getaxis, which assumes axes are defined, otherwise __getattr__ is called to find an 'axes' property, ....)
    Var.__init__(self, [self], values=values, name=name, atts=atts, plotatts=plotatts)
    # Compute size of spacing relative to magnitude for relative tolerances when mapping
    if rtol is None:
      # Start from the class/config default and tighten it if needed.
      rtol = self.rtol
      inz = np.where(values != 0.)[0]
      if len(inz) > 1:
        vnz = np.sort(values[inz]).astype('d')
        # Repeated values give log10(0) = -inf.  That case is rejected by the
        # isinf() test below, so the divide-by-zero warning is just noise.
        with np.errstate(divide='ignore'):
          logr = np.floor(np.min( np.log10(np.abs(np.diff(vnz) / vnz[:-1])) ))
        if not np.isinf(logr) and 10**logr < rtol: rtol = 10**logr
    #: The relative tolerance for identifying an element of this axis.
    self.rtol = rtol
    # Add auxiliary arrays after calling Var.__init__ - the weights
    # array, if present, will be added here, not by the logic in Var.__init__
    auxarrays = {}; auxatts = {}
    for key, val in kwargs.items():
      if isinstance(val,Var): val = val.get()
      if isinstance(val,(list,tuple,np.ndarray)):
        val = np.asarray(val)
        if val.shape != self.values.shape:
          raise ValueError('Auxilliary array %s has the wrong shape.  Expected %s, got %s' % (key,self.values.shape, val.shape))
        auxarrays[key] = val
      else:
        auxatts[key] = val
    # Update auxiliary dictionaries (copy the class defaults so they are never modified)
    self.auxarrays = self.__class__.auxarrays.copy()
    self.auxarrays.update(auxarrays)
    self.auxatts = self.__class__.auxatts.copy()
    self.auxatts.update(auxatts)
  # }}}

  @classmethod
  def isparentof(cls,other):
  # {{{
    """
    Determines if an axis object is an instance of a base class (or the same
    class) of another axis.

    Parameters
    ----------
    other : :class:`Axis` object to compare against this one.

    Returns
    -------
    bool : boolean
      True if ``other`` is an instance of this object's class
    """
    return isinstance(other,cls)
  # }}}

  @classmethod
  def _readaxisconfig(cls, ax):
  # {{{
    ''' Applies settings from the 'Axes' section of the PyGeode configuration
        to axis instance ``ax``.  The axis name comes from the nearest class in
        the inheritance chain (following first bases) with a ``<Class>.name``
        option, otherwise the lower-cased class name.  formatstr, units, rtol
        and plot attributes are read for that same class name. '''
    from pygeode import _config
    c = cls
    nm = c.__name__
    while c is not Axis:
      if _config.has_option('Axes',  nm + '.name'):
        ax.name = str(_config.get('Axes', nm + '.name'))
        break
      else:
        c = c.__bases__[0]
        nm = c.__name__
    if c is Axis: ax.name = cls.__name__.lower()
    # Set basic attributes
    for p in ['formatstr', 'units']:
      if _config.has_option('Axes',  nm + '.' + p):
        setattr(ax, p, _config.get('Axes', nm + '.' + p))
    for p in ['rtol']:
      if _config.has_option('Axes',  nm + '.' + p):
        setattr(ax, p, _config.getfloat('Axes', nm + '.' + p))
    # Set plot attributes.
    # NOTE(review): this is called before Var.__init__, so ax.plotatts is
    # presumably still the class-level dict here; confirm that is intended.
    for p in ['plottitle', 'plotfmt', 'plotscale']:
      if _config.has_option('Axes',  nm + '.' + p):
        ax.plotatts[p] = _config.get('Axes', nm + '.' + p)
    for p in ['plotorder']:
      if _config.has_option('Axes',  nm + '.' + p):
        ax.plotatts[p] = int(_config.getfloat('Axes', nm + '.' + p))
  # }}}

  #TODO: fix inconsistency between Axis and Var, for == and !=
  #      Vars produce a boolean mask under those operations, Axes return scalar True/False
  # I.e., "lat == 30" and "(lat*1) == 30" give very different results!
  def __ne__ (self, other): return not self.__eq__(other)
  def __eq__ (self, other):
  # {{{
    '''override Var's ufunc stuff here
       this is a weak comparison, in that we only require the other
       axis to be a *subclass* of this one.
       this allows things like "lat in [time,gausslat,lev]" to evaluate to True
       If you want a more strict comparison, then check the classes explicitly yourself.'''
    #TODO: do some testing to see just how many times this is called, if it will be a bottleneck for large axes
    # exact same object?
    if self is other: return True
    # incomparable?
    if not isinstance(other,Axis):
      return False
    if not self.isparentof(other) and not other.isparentof(self):
      return False
    # If they are generic Axis objects, an additional requirement is that they have the same name
    if self.__class__ is Axis and other.__class__ is Axis:
      if self.name != other.name: return False
    # Check if they have the same lengths
    if len(self.values) != len(other.values):
      return False
    # Check the values (numpy's default tolerances)
    from numpy import allclose
    if not allclose(self.values, other.values):
      return False
    # Check auxiliary attributes
    if set(self.auxatts.keys()) != set(other.auxatts.keys()): return False
    for fname in self.auxatts.keys():
      if self.auxatts[fname] != other.auxatts[fname]: return False
    # Check any associated arrays
    if set(self.auxarrays.keys()) != set(other.auxarrays.keys()):
      return False
    # Check values of associated arrays
    for fname in self.auxarrays.keys():
      if not allclose(self.auxarrays[fname], other.auxarrays[fname]):
        return False
    return True
  # }}}

  def alleq (self, *others):
  # {{{
    ''' alleq(self, *others) - returns True if self matches with all axes in others.'''
    for other in others:
      if not self.__eq__(other): return False
    return True
  # }}}

  #TODO: include associated arrays when doing the mapping?
  def map_to (self, other):
  # {{{
    '''Returns indices of this axis which correspond to the axis ``other``.

       Parameters
       ----------
       other : :class:`Axis`
         Axis to find mapping to

       Returns
       -------
       mapping : integer array or None

       Notes
       -----
       Returns an ordered indices of the elements of this axis that correspond to those of
       the axis ``other``, if one exists, otherwise None. This axis must be a
       parent class of ``other`` or vice versa in order for the mapping to
       exist. The mapping may include only a subset of this axis object, but
       must be as long as the other axis, if it is not None. The mapping
       identifies equivalent elements based on equality up to a tolerance
       specified by self.rtol.
       '''
    from pygeode.tools import map_to
    import numpy as np
    if not self.isparentof(other) and not other.isparentof(self): return None
    # special case: both axes are identical
    if self == other: return np.arange(len(self))
    return map_to(self.values, other.values, self.rtol)
  # }}}

  def sorted (self, reverse=None):
  # {{{
    """
    Sorts the points of the Axis.

    Parameters
    ----------
    reverse : boolean (optional)
      If ``True``, sorts in descending order. If ``False``, sorts in ascending order.
      By default the sign of self.plotorder is used.

    Returns
    -------
    sorted_axis : Axis
      A sorted version of the input axis.

    Examples
    --------
    >>> from pygeode import Lat
    >>> x = Lat([30,20,10])
    >>> print(x)
    lat <Lat>      :  30 N to 10 N (3 values)
    >>> y = x.sorted()
    >>> print(y)
    lat <Lat>      :  10 N to 30 N (3 values)

    See Also
    --------
    argsort
    """
    S = self.argsort(reverse=reverse)
    return self.slice[S]
  # }}}

  def argsort (self, reverse=None):
  # {{{
    """
      Generates a list of indices that would sort the Axis.

      Parameters
      ----------
      reverse : boolean (optional)
        If ``False``, indices are in ascending order. If ``True``, will produce
        indices for a *reverse* sort instead. By default, sign of self.plotorder is used.

      Returns
      -------
      indices : list
        The indices which will produces a sorted version of the Axis.

      Examples
      --------
      >>> from pygeode import Lat
      >>> x = Lat([20,30,10])
      >>> print(x)
      lat <Lat>      :  20 N to 10 N (3 values)
      >>> indices = x.argsort()
      >>> print(indices)
      [2 0 1]
      >>> print(x.slice[indices])
      lat <Lat>      :  10 N to 30 N (3 values)

      See Also
      --------
      sorted
    """
    import numpy as np
    S = np.argsort(self.values)
    step = 1
    # The 'plotorder' plot attribute (+1 / -1) is used as the slice step.
    if reverse is None: step = self.plotatts.get('plotorder', 1)
    if reverse is True: step = -1
    return S[::step]
  # }}}

  #TODO: implement and test this (if it's ever needed?)
  # (make sure to check auxiliary arrays)
  #
  # def common_map (self, other):
  #   '''return the indices that map common elements from one axis to another'''
  #   from pygeode.tools import common_map
  #   assert self.isparentof(other) or other.isparentof(self)
  #   return common_map(self.values, other.values)

  # The length of an axis (equal to the length of the array of values)
  def __len__ (self): return len(self.values)

  # Previously disabled alternative, kept for reference:
  # # Avoid accidentally iterating over axes
  # # (Would cause a new, 1-element axis to be created for each iteration)
  # def __iter__ (self):
  #   raise Exception ("Axes cannot be iterated over")

  # Iterating over an axis iterates over the values
  def __iter__ (self):  return iter(self.values)

  # Slice an axis -> construct a new one with the sliced values
  # (Overridden from Var to preserve our Axis status)
  def _getitem_asvar (self, slices):
  # {{{
    ''' Returns a new axis of the same class holding ``values[slices]``,
        with auxiliary arrays sliced the same way and auxiliary attributes
        carried over.  Returns ``self`` if the slice selects every value
        unchanged. '''
    import numpy as np
    # np.atleast_1d keeps the result 1-D (e.g. for a scalar index) without
    # relying on np.array(..., copy=False), which raises under NumPy >= 2
    # whenever a copy is unavoidable.
    values = np.atleast_1d(self.values[slices])
    # Check if we even need to do any slicing
    if len(values) == len(self.values) and np.all(values==self.values): return self
    aux = {}
    # Keep auxiliary attributes
    for key,val in self.auxatts.items():
      aux[key] = val
    # Slice auxiliary arrays
    for key,val in self.auxarrays.items():
      aux[key] = np.array(val[slices], ndmin=1)
    axis = type(self)(values, name=self.name, atts=self.atts, **aux)
    return axis
  # }}}

  def str_as_val(self, key, s):
  # {{{
    '''str_as_val(self, key, s) - converts string s to a value corresponding to this axis. Default
        implementation returns float(s); derived classes can return different conversions depending
        on whether key corresponds to the axis itself or an auxiliary array.'''
    return float(s)
  # }}}

  def get_slice(self, kwargs, ignore_mismatch=False):
  # {{{
    ''' Builds the index selection on this axis requested by keyword arguments.

        Each key is this axis's name/alias or the name of one of its auxiliary
        arrays, optionally preceded by a prefix and an underscore.  Prefix
        letters: 'i' selects by index instead of by value, 'l' treats the
        value as a list of individual items, 'n' takes the complement of the
        selection.  Requests on several keys are intersected.

        Parameters
        ----------
        kwargs : dict
          Selection keywords.  Every key handled here is *removed* from this
          dictionary.
        ignore_mismatch : boolean (optional)
          If True, keys that do not refer to this axis are skipped instead of
          raising an exception.

        Returns
        -------
        The selected indices, passed through ``pygeode.view.simplify`` (a
        slice object where possible, otherwise an integer array).
    '''
    import numpy as np
    from pygeode.view import simplify, expand
    # boolean flags indicating which axis indices will be used
    n = len(self)
    keep = np.ones(n,bool)
    matched = []
    for k, v in kwargs.items():
      # Split off prefix if present
      if '_' in k and not self.has_alias(k):
        prefix, ax = k.split('_', 1)
      else:
        prefix, ax = '', k
      if 'i' in prefix:
        ################### Select by index; key must correspond to this axis
        if not self.has_alias(ax):
          if ignore_mismatch: continue
          raise Exception("'%s' is not associated with this %s axis" % (ax, self.name))
        # Build mask
        kp = np.zeros(n, bool)
        if not hasattr(v, '__len__'): # Treat as an index
          kp[v] = True
        elif len(v) > 3 or len(v) < 2 or 'l' in prefix: # Treat as integer array
          kp[np.array(v, 'int')] = True
        elif len(v) == 2:         # Treat as slice
          kp[v[0]:v[1]] = True
        elif len(v) == 3:         # Treat as slice with stride
          kp[v[0]:v[1]:v[2]] = True
        else:
          # Not reachable given the length tests above; kept as a safeguard.
          raise ValueError("Unable to interpret index selection %r for the %s axis" % (v, self.name))
      else:
        ################### Select by value
        if self.has_alias(ax):     # Does key match this axis?
          vals = self.values
        elif ax in self.auxarrays: # What about an aux. array?
          vals = self.auxarrays[ax]
        else:
          if ignore_mismatch: continue
          raise Exception("'%s' is not associated with this %s axis" % (ax, self.name))
        # Build mask
        kp = np.zeros(n, bool)
        # Convert string representation if necessary
        if isinstance(v, str): v = self.str_as_val(ax, v)
        if isinstance(v,str) or not hasattr(v,'__len__'): # Single value given
          if vals.dtype.name.startswith('float'): # closest match
            kp[np.argmin( np.abs(v-vals) )] = True
          else:                 # otherwise require an exact match
            kp[vals == v] = True
        elif 'l' in prefix:
          for V in v:
            # Convert string representation if necessary
            if isinstance(V, str): V = self.str_as_val(ax, V)
            if vals.dtype.name.startswith('float'): # closest match
              kp[np.argmin( np.abs(V-vals) )] = True
            else:                 # otherwise require an exact match
              kp[vals == V] = True
        elif len(v) == 2:       # Select within range
          # Convert string representations if necessary
          v = [self.str_as_val(ax, V) if isinstance(V, str) else V for V in v]
          lower, upper = min(v), max(v)
          kp[(lower <= vals) & (vals <= upper)] = True
        else:                   # Don't know what to do with more than 2 values
          raise Exception('A range must be specified')
      # Use complement of requested set
      if 'n' in prefix: kp = ~kp
      # Compute intersection of index sets
      keep &= kp
      matched.append(k)       # Mark for removal from slice list
    # Pop kw arguments that have been handled
    for m in matched:
      kwargs.pop(m)
    # Convert boolean mask to integer indices
    sl = np.flatnonzero(keep)
    # Filter through view.simplify() to construct slice objects wherever possible
    # (otherwise, it's a generic integer array)
    return simplify(sl)
  # }}}

  # Keyword/value based slicing of an axis
  # (Overridden from Var to preserve our Axis status)
  def __call__ (self, **kwargs):
  # {{{
    ''' Returns the sub-axis selected by keyword arguments (see get_slice). '''
    sl = self.get_slice(kwargs)
    return self._getitem_asvar(sl)
  # }}}

  # Get an axis attribute
  # Overloaded from pygeode.Var to allow shortcuts to auxiliary arrays
  def __getattr__ (self, name):
  # {{{
    # Disregard metaclass stuff
    if name.startswith('__'): raise AttributeError(name)
    from pygeode.var import Var
    if name in self.auxarrays: return self.auxarrays[name]
    if name in self.auxatts: return self.auxatts[name]
    return Var.__getattr__(self, name)
  # }}}

  def auxasvar (self, name):
  # {{{
    ''' Returns auxiliary array as a new :class:`Var` object.

        Parameters
        ----------
        name : string
            Name of auxiliary array to return

        Returns
        -------
        var : :class:`Var`
            Variable (defined on this axis) with values of requested auxiliary array

        See Also
        --------
        auxarrays
    '''
    from pygeode.var import Var
    return Var([self], values=self.auxarrays[name], name=name)
  # }}}

  # Pretty printing
  def __repr__ (self): return '<' + self.__class__.__name__ + '>'
  def __str__ (self):
  # {{{
    ''' One-line summary: name, class, first/last formatted value and count. '''
    if len(self) > 0:
      first = self.formatvalue(self.values[0])
      last = self.formatvalue(self.values[-1])
    else: first = last = "<empty>"
    num = str(len(self.values))
    if self.name != '': head = self.name + ' ' + repr(self)
    else: head = repr(self)
    if len(self) > 1:
      out = head.ljust(15) + ':  '+first+' to '+last+' ('+num+' values)'
    else:
      out = head.ljust(15) + ':  '+first
    return out
  # }}}

  # Rename an axis
  def rename (self, name):
  # {{{
    """
    Assigns a new name to this axis.

    Parameters
    ----------
    name : string
      The new name of this axis.

    Returns
    -------
    renamed_axis : Axis
      An instance of the same axis class with the new name.
    """
    aux = {}
    for k,v in self.auxatts.items(): aux[k] = v
    for k,v in self.auxarrays.items(): aux[k] = v
    return type(self)(values=self.values, name=name, atts=self.atts, **aux)
  # }}}

  # Check if a given string is a meaningful alias to the axis
  # (i.e., if the string matches the name, or the class name, or one of the base class names)
  @classmethod
  def class_has_alias (cls, name):
  # {{{
    ''' True if ``name`` (case-insensitive) matches this class's default name,
        its class name, or (recursively) those of any Axis base class. '''
    # Default string name for the class
    if cls.name.lower() == name.lower(): return True
    # A stringified version of the class name (i.e. Time => 'time')
    if cls.__name__.lower() == name.lower(): return True
    for subclass in cls.__bases__:
      if not issubclass(subclass,Axis): continue  # only iterate over Axis subclasses
      if subclass.class_has_alias(name): return True
    return False
  # }}}

  # Now, a function which can work on instances of a class
  # Extends the above to also check the name given to the particular instance
  # (which can depend on the source of the data loaded at runtime)
  def has_alias (self, name):
  # {{{
    assert isinstance(name, str)
    if self.name.lower() == name.lower(): return True
    return self.class_has_alias(name)
  # }}}

  # Concatenate multiple axes together
  # Use numpy arrays
  # Assume the segments are pre-sorted
  @classmethod
  def concat (cls, axes):
  # {{{
    ''' Concatenates a sequence of axes (all instances of ``cls``) into one
        axis.  Only attributes common to all inputs, auxiliary attributes with
        identical values on every input, and auxiliary arrays present on every
        input are kept.  The name of the first axis is used. '''
    from numpy import concatenate
    from pygeode.tools import common_dict
    # Must all be same type of axis
    for a in axes: assert isinstance(a,cls), 'axes must be the same type'
    values = concatenate([a.values for a in axes])
    # Get common attributes
    atts = common_dict([a.atts for a in axes])
    aux = {}
    # Find auxiliary attributes present on all pieces, and propagate the consistent ones.
    auxkeys = set(axes[0].auxatts.keys())
    for a in axes[1:]:
      auxkeys = auxkeys.intersection(list(a.auxatts.keys()))
    for k in auxkeys:
      vals = [a.auxatts[k] for a in axes]
      v1 = vals[0]
      # Only use consistent aux atts
      if all(v == v1 for v in vals):
        aux[k] = axes[0].auxatts[k]
    # Find and concatenate auxiliary arrays common to all axes being concatenated
    auxkeys = set(axes[0].auxarrays.keys())
    for a in axes[1:]:
      auxkeys = auxkeys.intersection(list(a.auxarrays.keys()))
    for k in auxkeys:
      aux[k] = concatenate([a.auxarrays[k] for a in axes])
    name = axes[0].name  #TODO: check all names?
    return cls(values, name=name, atts=atts, **aux)
  # }}}

  # Replace the values of an axis
  # (any auxiliary arrays from the old axis are ignored)
  def withnewvalues (self, values):
  # {{{
    ''' Returns a new axis of the same class with the given values, the same
        name and attributes, and the same auxiliary attributes (auxiliary
        arrays are dropped). '''
    # Assume any auxiliary scalars are the same for the new axis
    return type(self)(values, name=self.name, atts=self.atts, **self.auxatts)
  # }}}
# }}}
# Useful axis subclasses
# Named axis
class NamedAxis (Axis):
# {{{
  '''Generic axis object identified by its name.

  Two named axes compare equal, and map onto each other, only when their
  names match as well as their values.
  '''
  def __init__ (self, values, name, **kwargs):
  # {{{
    Axis.__init__(self, values, **kwargs)
    # Assign the name after Axis.__init__, which sets a default name
    # from the class / configuration.
    self.name = name
  # }}}

  def __eq__ (self, other):
  # {{{
    ''' True if ``other`` is a NamedAxis (or subclass) with the same name
        and equal values / auxiliary data (see Axis.__eq__). '''
    if self is other: return True
    # Use isinstance, not an exact type() test: an exact test makes subclass
    # instances (e.g. DummyAxis) unequal even to themselves, and makes
    # NamedAxis-vs-subclass comparisons asymmetric.  This also matches map_to.
    if not isinstance(other, NamedAxis): return False
    # Check the names
    if self.name != other.name: return False
    # If the names match, check the values
    return Axis.__eq__(self,other)
  # }}}

  def __repr__ (self): return "<%s '%s'>"%(self.__class__.__name__, self.name)

  # Need more restrictions on mapping for named axes
  # (not only do both axes need to be a NamedAxis, but they need to have the same name)
  # (The name is the only way to uniquely identify them)
  def map_to (self, other):
  # {{{
    ''' Like Axis.map_to, but returns None unless ``other`` is a NamedAxis
        with the same name. '''
    if not isinstance(other, NamedAxis): return None
    if other.name != self.name: return None
    return Axis.map_to(self, other)
  # }}}
# }}}
# Dummy axis (values are just placeholders).
# Useful when there is no intrinsic coordinate system for a dimension.
# Example could be a dimension in a netCDF file that has no variable
# associated with it.
class DummyAxis (NamedAxis):
  '''Named axis whose values are only placeholders (no intrinsic coordinates).'''
class XAxis (Axis):
  ''' Generic horizontal "x" axis; parent class of :class:`Lon` and :class:`SpectralN`. '''
  name = 'xaxis'
class YAxis (Axis):
  ''' Generic horizontal "y" axis; parent class of :class:`Lat` and :class:`SpectralM`. '''
  name = 'yaxis'
class Lon (XAxis):
# {{{
  ''' Longitude axis. '''
  name = 'lon'
  formatstr = '%.3g E<360'
  plotatts = XAxis.plotatts.copy()
  plotatts['plottitle'] = ''
  plotatts['plotfmt'] = '%.3g E'

  def locator(self):
  # {{{
    ''' Returns a matplotlib MaxNLocator (9 bins, steps of 1, 3, 6 or 10)
        used for longitude tick marks. '''
    import pylab as pyl
    return pyl.MaxNLocator(nbins=9, steps=[1, 3, 6, 10])
  # }}}
# }}}
def regularlon(n, origin=0., order=1, repeat_origin=False):
# {{{
  '''Constructs a regularly spaced :class:`Lon` axis with n longitudes.

  Parameters
  ----------
  n : int
    Number of longitudes.
  origin : float (optional)
    First longitude; values span origin to origin + 360.  Default 0.
  order : int (optional)
    Slice step applied to the values; -1 gives decreasing longitudes.
  repeat_origin : boolean (optional)
    If True, the final point is equal to origin + 360 (the origin is
    repeated).  Otherwise the spacing is 360/n and origin + 360 is excluded.

  Returns
  -------
  lon : :class:`Lon`
  '''
  import numpy as np
  vals = np.linspace(0., 360, n, endpoint=repeat_origin)[::order] + origin
  return Lon(vals)
# }}}
def rotatelon(v, origin, duplicate = False):
# {{{
  ''' Rotates longitude axis to start at a new origin.

  Parameters
  ----------
  v : :class:`Var`
    Variable with :class:`Lon` axis to modify
  origin: float
    New origin for longitude axis
  duplicate: boolean; optional
    If true, duplicates the origin value at the end of the axis (can be useful
    for plotting).  Default is ``False``.

  Returns
  -------
  vn : :class:`Var`
    Variable with modified longitude axis.  If ``v`` has no longitude axis,
    it is returned unchanged.

  Examples
  --------
  >>> import pygeode as pyg; from pygeode.tutorial import t1
  >>> print(t1.Temp.lon)
  lon <Lon>      :  0 E to 354 E (60 values)
  >>> print(pyg.rotatelon(t1.Temp, -180))
  <Var 'Temp'>:
    Units: K  Shape:  (lat,lon)  (31,60)
    Axes:
      lat <Lat>      :  90 S to 90 N (31 values)
      lon <Lon>      :  180 E to 174 E (60 values)
    Attributes:
      {}
    Type:  SortedVar (dtype="float64")
  >>> print(pyg.rotatelon(t1.Temp, 0, duplicate=True).lon[:])
  [  0.   6.  12.  18.  24.  30.  36.  42.  48.  54.  60.  66.  72.  78.
    84.  90.  96. 102. 108. 114. 120. 126. 132. 138. 144. 150. 156. 162.
   168. 174. 180. 186. 192. 198. 204. 210. 216. 222. 228. 234. 240. 246.
   252. 258. 264. 270. 276. 282. 288. 294. 300. 306. 312. 318. 324. 330.
   336. 342. 348. 354. 360.]
  '''
  from . import concatenate
  if not v.hasaxis('lon'): return v
  lons = v.lon[:]
  o = origin % 360
  off = o - origin
  # Wrap every longitude into [origin, origin + 360)  (o - off == origin).
  lp = Lon(values=(lons - o) % 360 + o - off)
  v0 = v.replace_axes(lon=lp).sorted('lon')
  if duplicate:
    # After rotation the first point sits at ``origin``, so that is the slice
    # to repeat at origin + 360.  (Selecting lon=0 is only right when
    # origin % 360 == 0.)
    l360 = Lon(values = [origin + 360])
    v0 = concatenate([v0, v0(lon=origin).replace_axes(lon = l360)])
  return v0
# }}}
class Lat (YAxis):
# {{{
  ''' Latitude axis. '''
  name = 'lat'
  formatstr = '%.2g N'
  plotatts = YAxis.plotatts.copy()
  plotatts['plottitle'] = ''

  # Make sure we get some weights
  def __init__(self, values, weights=None, **kwargs):
  # {{{
    ''' Creates a latitude axis.

    Parameters
    ----------
    values : array_like
      Latitudes in degrees.
    weights : array_like (optional)
      Weights stored as the ``weights`` auxiliary array.  Defaults to
      cos(latitude).
    **kwargs
      Passed on to :meth:`Axis.__init__`.
    '''
    from numpy import cos, asarray, pi
    # Input weights are only along latitude, not area weight
    # Output weights are area weights
    # If no input weights given, assume uniform
    #TODO: handle non-uniform latitudes?
    if weights is None:
      weights = cos(asarray(values) * pi / 180.)
    Axis.__init__(self, values, weights=weights, **kwargs)
  # }}}

  def locator(self):
  # {{{
    ''' Returns a matplotlib MaxNLocator (9 bins, steps of 1, 1.5, 3, 5 or 10)
        used for latitude tick marks. '''
    import pylab as pyl
    return pyl.MaxNLocator(nbins=9, steps=[1, 1.5, 3, 5, 10])
  # }}}
# }}}
def gausslat (n, order=1, axis_dict={}):
# {{{
  '''Constructs a Gaussian :class:`Lat` axis with n latitudes.

  Latitudes are arcsin of the Gauss-Legendre nodes (in degrees), with the
  corresponding quadrature weights as the ``weights`` auxiliary array.
  ``order`` is the slice step (-1 reverses the axis).

  Notes
  -----
  ``axis_dict`` is a deliberately shared mutable default used as a memo
  cache keyed on ``(n, order)``: repeated calls return the *same* Lat
  object, so callers should not modify the returned axis in place.
  '''
  from pygeode.quadrulepy import legendre_compute
  import numpy as np
  from math import pi
  if (n,order) in axis_dict: return axis_dict[(n,order)]
  x, w = legendre_compute(n)
  x = np.arcsin(x) / pi * 180
  x = x[::order]
  w = w[::order]
  axis = Lat (x, weights=w)
  axis_dict[(n,order)] = axis
  return axis
# }}}
def regularlat(n, order=1, inc_poles=True):
# {{{
  '''Constructs a regularly spaced :class:`Lat` axis with n latitudes.

  If inc_poles is True, the grid runs from -90 to 90 inclusive.  Otherwise
  the n points are the interior points of an (n+2)-point grid from -90 to 90,
  so the poles are excluded.  ``order`` is the slice step (-1 reverses the
  axis).
  '''
  import numpy as np
  if inc_poles: vals = np.linspace(-90, 90, n)[::order]
  else: vals = np.linspace(-90, 90, n+2)[1:-1][::order]
  return Lat(vals)
# }}}
# Spectral axes
# Note: XAxis/YAxis is used by cccma code to put these in the proper order (XAxis is fastest-increasing)
class SpectralM(YAxis):
  ''' Spectral axis with default name 'm' (treated as a y-axis). '''
  name = 'm'
class SpectralN(XAxis):
  ''' Spectral axis with default name 'n' (treated as an x-axis). '''
  name = 'n'
# Vertical axes
class ZAxis (Axis):
# {{{
  ''' Generic vertical axis; parent class of the Height, Hybrid and Pres axes. '''
  name = 'lev'
  formatstr = '%3g'
# }}}
# Geometric height
#TODO: weights
#TODO: attributes
class Height(ZAxis):
# {{{
  ''' Geometric height axis (units of metres). '''
  name = 'z' # default name
  formatstr = '%d'
  units = 'm'
  plotatts = ZAxis.plotatts.copy()
  plotatts['plotname'] = 'Height' # name displayed in plots (axis label)
# }}}
# Model hybrid levels
#TODO: weights!
class Hybrid (ZAxis):
# {{{
  ''' Hybridized vertical coordinate axis.

  Requires the ``A`` and ``B`` auxiliary arrays (one value per level).
  Plotted on a log scale with the order reversed.
  '''
  name = 'eta'  #TODO: rename this to 'hybrid'?  (keep 'eta' for now, for compatibility with existing code)
  formatstr = '%g'
  plotatts = ZAxis.plotatts.copy()
  plotatts['plotorder'] = -1
  plotatts['plotscale'] = 'log'

  def __init__ (self, values, A, B, **kwargs):
  # {{{
    # Just pass all the stuff to the superclass
    # (All we do here is enforce the existence of 'A' and 'B' associated arrays)
    ZAxis.__init__ (self, values, A=A, B=B, **kwargs)
  # }}}

  def __eq__ (self, other):
  # {{{
    ''' ZAxis equality, plus approximately equal A and B arrays. '''
    if not ZAxis.__eq__(self, other): return False
    from numpy import allclose
    if not allclose(self.A, other.A): return False
    if not allclose(self.B, other.B): return False
    return True
  # }}}

  def locator(self):
  # {{{
    ''' Returns a matplotlib LogLocator whose minor subdivisions depend on
        how many decades the axis values span. '''
    import pylab as pyl, numpy as np
    ndecs = np.log10(np.max(self.values) / np.min(self.values))
    if ndecs < 1.2: return pyl.LogLocator(subs=[1., 2., 4., 7.])
    elif ndecs < 3.: return pyl.LogLocator(subs=[1., 3.])
    else: return pyl.LogLocator()
  # }}}
# }}}
class Pres (ZAxis):
# {{{
  ''' Pressure height axis (units of hPa; plotted on a reversed log scale). '''
  name = 'pres'
  units = 'hPa'
  formatstr = '%.2g<100'
  plotatts = ZAxis.plotatts.copy()
  plotatts['plotname'] = 'Pressure'
  plotatts['plotscale'] = 'log'
  plotatts['plotorder'] = -1

  def logPAxis(self, p0=1000., H=7.1):
  # {{{
    '''logPAxis(p0, H) - returns a pygeode axis with log pressure heights
          corresponding to this axis. By default p0 = 1000 hPa and H = 7.1 (km).'''
    import numpy as np
    z = ZAxis(H * np.log(p0 / self.values))
    # NOTE(review): if Var.__init__ does not give each instance its own
    # plotatts dict, this modifies the class-level default; confirm.
    z.plotatts['plotname'] = 'Log-p Height'
    return z
  # }}}

  def locator(self):
  # {{{
    ''' Returns a matplotlib LogLocator whose minor subdivisions depend on
        how many decades the pressure values span. '''
    import pylab as pyl, numpy as np
    numdecs = np.log10(np.max(self.values) / np.min(self.values))
    if numdecs < 1.2: subs = [1., 2., 4., 7.]
    elif numdecs < 3.: subs = [1., 3.]
    else: subs = [1.]
    # numdecs only selects subs here.  LogLocator's own 'numdecs' argument was
    # deprecated (and unused) in matplotlib 3.8 and is rejected by later
    # releases, so it is not passed on.
    return pyl.LogLocator(subs = subs)
  # }}}
# }}}
# NOTE: The Time axis is defined in pygeode.timeaxis.
# It's a fairly heavyweight class, worthy of its own module.
# Importing it here would create a circular import between the two modules.
class Freq(Axis):
# {{{
  ''' Frequency axis.

  Parameters
  ----------
  values : array_like
    Frequency values.
  inv_units : string (optional)
    Units of the reciprocal quantity (e.g. a time unit).  If given, the axis
    units become ``'/' + inv_units`` and the plot title
    ``'Frequency (per <inv_units>)'``.  Also stored as an auxiliary attribute.
  '''
  name = 'freq'
#  plotscale = 'log'
  def __init__ (self, values, inv_units=None, *args, **kwargs):
    if inv_units is not None:
      self.units = '/'+inv_units
    kwargs = kwargs.copy()
    kwargs['inv_units'] = inv_units
    Axis.__init__ (self, values, *args, **kwargs)
    if inv_units is not None:
      # Respect an explicit plot title passed by keyword.
      explicit = kwargs.get('plotatts') or {}
      if 'plottitle' not in explicit:
        # Work on a private copy: writing through self.plotatts before it is
        # an instance attribute would modify the class-level dict shared by
        # Axis and every subclass that does not define its own plotatts.
        self.plotatts = dict(self.plotatts)
        self.plotatts['plottitle'] = 'Frequency (per %s)'%inv_units
# }}}
# Indexing arrays (represent discrete items, such as EOF number, ensemble number, etc.)
class Index(Axis):
# {{{
  ''' Integer index axis for discrete items (EOF number, ensemble member, ...). '''
  def __init__ (self, n, *args, **kwargs):
    # A sequence is used directly as the coordinate values; anything without
    # a length (an item count) is expanded to 0 .. n-1.
    if hasattr(n, '__len__'):
      coords = n
    else:
      coords = list(range(n))
    Axis.__init__(self, coords, *args, **kwargs)
# }}}
# Coefficient number (when returning a set of coefficients, such as for a polynomial fit)
class Coef(Index):
  ''' Index axis numbering a set of coefficients (e.g. of a polynomial fit). '''
class NonCoordinateAxis(Axis):
# {{{
  '''Non-coordinate axis (disables nearest-neighbour value matching, etc.)

  The coordinate values are always regenerated as ``0 .. N-1``; the
  meaningful data are held in the auxiliary arrays.
  '''
  def __init__ (self, *args, **kwargs):
  # {{{
    '''
    Refresh the coordinate values (should always be monotonically increasing integers).

    Positional arguments are ignored.  The axis length ``N`` is taken from
    the first keyword argument (in keyword order, including any ``values``
    keyword) that is a list, tuple or numpy array.  ``values`` is then set
    to ``numpy.arange(N)`` and all keywords are passed to
    :meth:`Axis.__init__`.

    Raises
    ------
    ValueError
      If no keyword argument is a list, tuple or array.
    '''
    import numpy as np
    lengths = [len(kw) for kw in kwargs.values() if isinstance(kw,(list,tuple,np.ndarray))]
    if len(lengths) == 0:
      raise ValueError("Unable to determine a length for the non-coordinate axis.")
    N = lengths[0]
    kwargs['values'] = np.arange(N)
    Axis.__init__(self, **kwargs)
    # Remember original name (str_as_val looks up the aux array with this name)
    self._name = self.name
  # }}}

  # Modify test for equality to look for an exact match
  #TODO: Make the default Axis.__eq__ logic do this, and move the "close"
  # matching to a subclass of Axis.
  def __eq__ (self, other):
  # {{{
    ''' Exact equality: same class, same set of auxiliary arrays, and
        element-by-element identical auxiliary array contents.  Coordinate
        values and auxiliary attributes are not compared. '''
    # For simplicity, expect them to be the same type of axis.
    if type(self) != type(other): return False
    if set(self.auxarrays.keys()) != set(other.auxarrays.keys()): return False
    for k in self.auxarrays.keys():
      if list(self.auxarrays[k]) != list(other.auxarrays[k]): return False
    return True
  # }}}

  # How to map string values to dummy indices
  def str_as_val(self, key, s):
  # {{{
    ''' Converts ``s`` to the index of its first occurrence in the auxiliary
        array named after the axis, or -1 (matches nothing) if there is no
        such array or ``s`` is not in it.  ``key`` is not used. '''
    # Special case: referencing an aux array with the same name as the axis.
    if self._name in self.auxarrays:
      values = list(self.auxarrays[self._name])
      if s in values:
        return values.index(s)
    # Otherwise, return an invalid index (no match)
    return -1
  # }}}

  # NOTE(review): a formatvalue override (converting dummy indices to the
  # corresponding values) was referenced here but is not defined in this class.

  # Modify map_to to use exact matching.
  # (Avoids use of tools.map_to, which assumes the values are numerical)
  #TODO: Make this the default for Axis (don't assume we have numerical values?)
  def map_to (self, other):
  # {{{
    ''' Maps elements of ``other`` onto this axis by exact matching of the
        auxiliary arrays the two axes share.

        Returns
        -------
        indices : list of int, or None
          For each element of ``other`` (in order) whose tuple of shared
          auxiliary values also occurs in this axis, the index of its first
          occurrence here; elements with no match are skipped.  None if
          ``other`` is not an instance of this axis's class, or if no
          auxiliary arrays are shared.
    '''
    # other must be an instance of this axis's class (subclasses allowed).
    if not isinstance(other,type(self)): return None
    # Get keys to use for comparing aux arrays
    keys = list(set(self.auxarrays.keys()) & set(other.auxarrays.keys()))
    if len(keys) == 0: return None
    values = list(zip(*[self.auxarrays[k] for k in keys]))
    other_values = list(zip(*[other.auxarrays[k] for k in keys]))
    # Index of the first occurrence of each value tuple: same result as
    # values.index(v), but O(1) per lookup instead of O(len(values)).
    first_index = {}
    for i, v in enumerate(values):
      first_index.setdefault(v, i)
    return [first_index[v] for v in other_values if v in first_index]
  # }}}
# }}}
class Station(NonCoordinateAxis):
# {{{
  '''Non-coordinate axis of fixed station locations (e.g. for station timeseries).'''
  name = 'station'
# }}}
# Concatenate a bunch of axes together.
# Find a common parent class for all of them, and call that class's concat function.
def concat (axes):
# {{{
  ''' Concatenates axes by delegating to the ``concat`` classmethod of their
      most general common class.

      Parameters
      ----------
      axes : iterable of :class:`Axis`
        The pieces to join (any iterable; it is consumed once).

      Returns
      -------
      The single input axis unchanged if only one is given, otherwise the
      result of ``cls.concat(list_of_axes)``.

      Raises
      ------
      AssertionError
        If no axes are given, or two pieces are not related by inheritance.
  '''
  pieces = list(axes)
  assert len(pieces) > 0, 'nothing to concatenate!'
  # Degenerate case: nothing to join
  if len(pieces) == 1:
    return pieces[0]
  target = type(pieces[0])
  for piece in pieces:
    piece_cls = type(piece)
    # Widen the target class whenever this piece is of a parent class
    if issubclass(target, piece_cls):
      target = piece_cls
    assert issubclass(piece_cls, target), "can't concatenate incompatible axes"
  return target.concat(pieces)
# }}}
# List of axes provided in this module (for easy importing)
# NOTE(review): TAxis is not defined in the visible part of this module;
# confirm it is defined (or imported) earlier in the file, otherwise this
# line raises NameError when the module is imported.
standard_axes = [Axis, NamedAxis, XAxis, YAxis, ZAxis, TAxis, Lon, NonCoordinateAxis, regularlon, rotatelon, Lat, gausslat, regularlat, Pres, Hybrid, Height, SpectralM, SpectralN, Freq, Index]