# Source code for sympl._core.tracers

from .exceptions import InvalidPropertyDictError
import numpy as np
from .units import units_are_same
from .get_np_arrays import get_numpy_arrays_with_properties
from .restore_dataarray import restore_data_arrays_with_properties

# Registry of tracers: maps each registered tracer quantity name to its unit
# string. Written by register_tracer().
_tracer_unit_dict = {}
# Tracer names in the order register_tracer() appended them.
_tracer_names = []
# TracerPacker instances; each packer adds itself on construction and is
# re-checked by register_tracer() whenever a tracer is registered.
_packers = set()

def reset_tracers():
    """
    Forget every registered tracer.

    Empties the module-level tracer registry (the name -> units dictionary
    and the ordered list of names). Both containers are emptied in place
    rather than rebound, so any other reference to them observes the reset.
    """
    _tracer_unit_dict.clear()
    # Slice deletion empties the list in place and also works on Python 2,
    # where list.clear() does not exist.
    del _tracer_names[:]

def reset_packers():
    """
    Forget every TracerPacker that has been recorded.

    Empties the module-level set of packers in place, so packers created
    earlier are no longer re-checked when a tracer is registered.
    """
    _packers.clear()

def register_tracer(name, units):
    """
    Register a quantity as a tracer.

    Registering a name that is already registered with compatible units is
    allowed; the name keeps its original position in the tracer ordering
    and is not listed twice.

    Parameters
    ----------
    name : str
        Quantity name to register as a tracer.
    units : str
        Unit string of that quantity.

    Raises
    ------
    ValueError
        If the name is already registered with units that are not the same
        as the input units.
    InvalidPropertyDictError
        Propagated from an existing TracerPacker whose component already
        defines this quantity as an output. The tracer has already been
        recorded in the registry when this is raised.
    """
    if name in _tracer_unit_dict and not units_are_same(_tracer_unit_dict[name], units):
        raise ValueError(
            'Tracer {} is already registered with units {} which are different '
            'from input units {}'.format(
                name, _tracer_unit_dict[name], units
            )
        )
    _tracer_unit_dict[name] = units
    # Guard against duplicates: get_tracer_names() exposes this list directly,
    # so a repeated registration must not add a second entry.
    if name not in _tracer_names:
        _tracer_names.append(name)
    # Existing packers must not already produce this quantity as an output.
    for packer in _packers:
        packer.ensure_tracer_not_in_outputs(name, units)
def get_tracer_unit_dict():
    """
    Return a copy of the registered tracer units.

    Returns
    -------
    unit_dict : dict
        A dictionary whose keys are tracer quantity names as str and values
        are units of those quantities as str. This is a new dictionary, so
        modifying it does not affect the registry.
    """
    return dict(_tracer_unit_dict)
def get_tracer_names():
    """
    Return the names of all registered tracers.

    Returns
    -------
    tracer_names : tuple of str
        Tracer names in the order that they will appear in tracer arrays.
    """
    return tuple(_tracer_names)
def get_tracer_input_properties(prepend_tracers, tracer_dims):
    """
    Build a properties dictionary describing every tracer quantity.

    Args:
        prepend_tracers (list of tuple): Pairs of (name, units) describing
            tracers that are to be included in addition to any registered
            tracers. Units given here take precedence over registered units
            for the same name.
        tracer_dims (list): Dimensions to use for each tracer
            (e.g. ['dim1', 'dim2']). A 'tracer' entry, if present, is
            dropped since it is not a dimension of an individual tracer.

    Returns:
        input_properties (dict): A properties dictionary for registered
            and additional tracers, with 'units' and 'dims' for each name.
    """
    tracer_dims = list(tracer_dims)
    if 'tracer' in tracer_dims:
        tracer_dims.remove('tracer')
    tracer_units = {}
    tracer_units.update(get_tracer_unit_dict())
    tracer_units.update(dict(prepend_tracers))
    tracer_properties = {}
    for name, units in tracer_units.items():
        # Each entry gets its own dims list so that editing one tracer's
        # properties cannot silently change another's.
        tracer_properties[name] = {'units': units, 'dims': list(tracer_dims)}
    return tracer_properties


def get_quantity_dims(tracer_dims):
    """
    Return the dimensions of a single tracer quantity.

    Args:
        tracer_dims (iterable of str): Dimensions of a packed tracer array,
            which must include 'tracer'.

    Returns:
        quantity_dims (tuple of str): tracer_dims with the first 'tracer'
            entry removed.

    Raises:
        ValueError: If tracer_dims has no 'tracer' dimension.
    """
    if 'tracer' not in tracer_dims:
        raise ValueError("Tracer dims must include a dimension named 'tracer'")
    quantity_dims = list(tracer_dims)
    quantity_dims.remove('tracer')
    return tuple(quantity_dims)


class TracerPacker(object):
    """
    Packs tracer quantities from a state into one array and unpacks them.

    The packed array has the dimensions given as tracer_dims, one of which
    must be 'tracer'. Tracers given as prepend_tracers come first along that
    dimension, followed by registered tracers in registration order.
    """

    def __init__(self, component, tracer_dims, prepend_tracers=None):
        """
        Args:
            component: Object having a tendency_properties or
                output_properties attribute. It must not already define any
                tracer as a tendency or output.
            tracer_dims (iterable of str): Dimensions of the packed tracer
                array, including 'tracer'.
            prepend_tracers (list of tuple, optional): Pairs of
                (name, units) for tracers to place before the registered
                tracers.

        Raises:
            ValueError: If tracer_dims has no 'tracer' dimension.
            TypeError: If component has neither tendency_properties nor
                output_properties.
            InvalidPropertyDictError: If the component already defines a
                tracer quantity as a tendency or output.
        """
        self._prepend_tracers = prepend_tracers or ()
        self._tracer_dims = tuple(tracer_dims)
        self._tracer_quantity_dims = get_quantity_dims(tracer_dims)
        if hasattr(component, 'tendency_properties') or hasattr(component, 'output_properties'):
            self.component = component
        else:
            raise TypeError(
                'Expected a component object subclassing type Stepper, '
                'ImplicitTendencyComponent, or TendencyComponent but received component of '
                'type {}'.format(component.__class__.__name__))
        for name, units in self._prepend_tracers:
            # Registered names are checked by the loop below.
            if name not in _tracer_unit_dict.keys():
                self.ensure_tracer_not_in_outputs(name, units)
        for name, units in _tracer_unit_dict.items():
            self.ensure_tracer_not_in_outputs(name, units)
        # Record this packer so tracers registered later are checked too.
        _packers.add(self)

    def ensure_tracer_not_in_outputs(self, name, units):
        """
        Raise InvalidPropertyDictError if the component already outputs name.

        Checks tendency_properties when the component has that attribute,
        otherwise output_properties when it has that one.
        """
        if hasattr(self.component, 'tendency_properties'):
            self._ensure_tracer_not_in_tendency_properties(name, units)
        elif hasattr(self.component, 'output_properties'):
            self.ensure_tracer_not_in_output_properties(name, units)

    def ensure_tracer_not_in_output_properties(self, name, units):
        """Raise InvalidPropertyDictError if name is a component output."""
        if name in self.component.output_properties.keys():
            raise InvalidPropertyDictError(
                'Attempted to insert {} as tracer to component of type {} but '
                'it already has that quantity defined as an output.'.format(
                    name, self.component.__class__.__name__
                )
            )

    def _ensure_tracer_not_in_tendency_properties(self, name, units):
        """Raise InvalidPropertyDictError if name is a component tendency."""
        if name in self.component.tendency_properties.keys():
            raise InvalidPropertyDictError(
                'Attempted to insert {} as tracer to component of type {} but '
                'it already has that quantity defined as a tendency '
                'output.'.format(
                    name, self.component.__class__.__name__
                )
            )

    def _ensure_tracer_quantity_dims(self, dims):
        """Raise InvalidPropertyDictError if dims differ from the tracer dims."""
        if tuple(self._tracer_quantity_dims) != tuple(dims):
            raise InvalidPropertyDictError(
                'Tracers have conflicting dims {} and {}'.format(
                    self._tracer_quantity_dims, dims)
            )

    def register_tracers(self, unit_dict, input_properties):
        """
        Add tracer entries to input_properties for names it lacks.

        Args:
            unit_dict (dict): Maps tracer name to unit string.
            input_properties (dict): Properties dictionary, modified in
                place. Names already present are left untouched.
        """
        for name, units in unit_dict.items():
            if name not in input_properties.keys():
                input_properties[name] = {
                    'dims': self._tracer_quantity_dims,
                    'units': units,
                    'tracer': True,
                }

    @property
    def tracer_names(self):
        """
        tuple of str: Tracer names in packed order, prepended tracers first,
        then registered tracers that were not prepended.
        """
        return_list = []
        for name, units in self._prepend_tracers:
            return_list.append(name)
        for name in _tracer_names:
            if name not in return_list:
                return_list.append(name)
        return tuple(return_list)

    @property
    def _tracer_index(self):
        """int: Position of the 'tracer' dimension in the packed array."""
        return self._tracer_dims.index('tracer')

    def _tracer_slice(self, shape, i):
        """Return the index selecting tracer i from an array of this shape."""
        tracer_slice = [slice(0, d) for d in shape]
        tracer_slice[self._tracer_index] = i
        # NumPy requires a tuple for multidimensional indexing; a list of
        # slices is deprecated and rejected by recent versions.
        return tuple(tracer_slice)

    def pack(self, state):
        """
        Args:
            state (dict): A state dictionary.

        Returns:
            tracer_array (ndarray): An array containing the tracer data, with
                dimensions as specified by tracer_dims on initializing
                this object.
        """
        tracer_properties = get_tracer_input_properties(
            self._prepend_tracers, self._tracer_quantity_dims)
        raw_state = get_numpy_arrays_with_properties(state, tracer_properties)
        # The property rebuilds its tuple on each access, so read it once.
        names = self.tracer_names
        if len(names) == 0:
            shape = [0 for dim in self._tracer_dims]
        else:
            shape = list(raw_state[names[0]].shape)
            shape.insert(self._tracer_index, len(names))
        array = np.empty(shape, dtype=np.float64)
        for i, name in enumerate(names):
            array[self._tracer_slice(shape, i)] = raw_state[name]
        return array

    def unpack(self, tracer_array, input_state, multiply_unit=''):
        """
        Args:
            tracer_array (ndarray): An array containing tracer values, with
                dimensions as specified by tracer_dims on initializing
                this object.
            input_state (dict): A state dictionary from which the tracer
                array was originally packed.
            multiply_unit (str, optional): A unit string which should be
                multiplied to the units of each output in the returned
                DataArrays, for example to represent the units over which
                a time difference was taken.

        Returns:
            tracer_dict (dict): A dictionary whose keys are tracer names
                and values are DataArrays containing the values of
                each tracer.
        """
        tracer_properties = get_tracer_input_properties(
            self._prepend_tracers, self._tracer_quantity_dims)
        raw_state = {}
        for i, name in enumerate(self.tracer_names):
            raw_state[name] = tracer_array[
                self._tracer_slice(tracer_array.shape, i)]
        out_properties = {}
        for name, properties in tracer_properties.items():
            out_properties[name] = properties.copy()
            # Compare by value; identity against a literal is unreliable.
            if multiply_unit != '':
                out_properties[name]['units'] = '{} {}'.format(
                    out_properties[name]['units'], multiply_unit)
        return_state = restore_data_arrays_with_properties(
            raw_state, out_properties, input_state, tracer_properties)
        return return_state