Source code for ezmsg.sigproc.aggregate

import typing

import numpy as np
import numpy.typing as npt
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import (
    AxisArray,
    slice_along_axis,
    AxisBase,
    replace,
)

from .spectral import OptionsEnum
from .base import (
    BaseStatefulTransformer,
    BaseTransformerUnit,
    processor_state,
)


class AggregationFunction(OptionsEnum):
    """Names of the reductions that a :obj:`ranged_aggregate` operation can apply to each band.

    Every member is a key of the module-level ``AGGREGATORS`` table, which
    supplies the NumPy callable that performs the reduction.
    """

    # Reduced with ``np.all`` (see ``AGGREGATORS``).
    NONE = "None (all)"

    # Plain reductions.
    MAX = "max"
    MIN = "min"
    MEAN = "mean"
    MEDIAN = "median"
    STD = "std"
    SUM = "sum"

    # Counterparts mapped to NumPy's ``nan*`` functions.
    NANMAX = "nanmax"
    NANMIN = "nanmin"
    NANMEAN = "nanmean"
    NANMEDIAN = "nanmedian"
    NANSTD = "nanstd"
    NANSUM = "nansum"

    # Position-of-extremum reductions.
    ARGMIN = "argmin"
    ARGMAX = "argmax"

    # Integration; needs the axis coordinates as ``x``.
    TRAPEZOID = "trapezoid"
# Lookup table from each ``AggregationFunction`` member to the NumPy callable
# that implements it. Every callable is invoked with an ``axis=`` keyword.
AGGREGATORS: dict[AggregationFunction, typing.Callable[..., typing.Any]] = {
    AggregationFunction.NONE: np.all,
    # Plain reductions.
    AggregationFunction.MAX: np.max,
    AggregationFunction.MIN: np.min,
    AggregationFunction.MEAN: np.mean,
    AggregationFunction.MEDIAN: np.median,
    AggregationFunction.STD: np.std,
    AggregationFunction.SUM: np.sum,
    # ``nan*`` counterparts.
    AggregationFunction.NANMAX: np.nanmax,
    AggregationFunction.NANMIN: np.nanmin,
    AggregationFunction.NANMEAN: np.nanmean,
    AggregationFunction.NANMEDIAN: np.nanmedian,
    AggregationFunction.NANSTD: np.nanstd,
    AggregationFunction.NANSUM: np.nansum,
    # Index-returning reductions.
    AggregationFunction.ARGMIN: np.argmin,
    AggregationFunction.ARGMAX: np.argmax,
    # Note: Some methods require x-coordinates and
    #  are handled specially in `_process`.
    AggregationFunction.TRAPEZOID: np.trapezoid,
}
class RangedAggregateSettings(ez.Settings):
    """
    Configuration for ``RangedAggregate``.
    """

    axis: str | None = None
    """Name of the axis the bands are defined on. When unset, the first dimension of the incoming message is used."""

    bands: list[tuple[float, float]] | None = None
    """
    Sequence of ``(band_min, band_max)`` pairs, one per output band:
    [(band1_min, band1_max), (band2_min, band2_max), ...]
    Leaving this as ``None`` turns the node into a passthrough.
    """

    operation: AggregationFunction = AggregationFunction.MEAN
    """The :obj:`AggregationFunction` used to reduce each band."""
@processor_state
class RangedAggregateState:
    """Cached, per-stream values computed by ``RangedAggregateTransformer._reset_state``."""

    # One index expression per configured band, selecting that band along
    # the target axis.
    slices: list[tuple[typing.Any, ...]] | None = None

    # Replacement axis attached to outgoing messages (one entry per band).
    out_axis: AxisBase | None = None

    # Coordinate value of every sample along the target axis of the input.
    ax_vec: npt.NDArray | None = None
class RangedAggregateTransformer(
    BaseStatefulTransformer[
        RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateState
    ]
):
    """Reduce each configured band of the target axis to a single value.

    For every ``(start, stop)`` pair in ``settings.bands`` the samples whose
    axis coordinate lies in ``[start, stop]`` (inclusive) are reduced with the
    function selected by ``settings.operation``. The target axis of the output
    has one entry per band. When ``settings.bands`` is ``None`` messages are
    returned unchanged.
    """

    def __call__(self, message: AxisArray) -> AxisArray:
        """Process ``message``, short-circuiting when no bands are configured."""
        # Override for shortcut passthrough mode.
        if self.settings.bands is None:
            return message
        return super().__call__(message)

    def _hash_message(self, message: AxisArray) -> int:
        """Return a hash of the message properties the cached state depends on.

        The hash combines the message key with either the number of
        coordinates (axes exposing ``data``) or the gain/offset pair
        (``AxisArray.LinearAxis``) of the target axis.
        """
        axis = self.settings.axis or message.dims[0]
        target_axis = message.get_axis(axis)

        hash_components = (message.key,)
        if hasattr(target_axis, "data"):
            hash_components += (len(target_axis.data),)
        elif isinstance(target_axis, AxisArray.LinearAxis):
            hash_components += (target_axis.gain, target_axis.offset)
        return hash(hash_components)

    def _reset_state(self, message: AxisArray) -> None:
        """Recompute the band slices and the output axis from ``message``.

        Raises:
            IndexError: If a band does not contain any coordinate of the
                target axis.
        """
        axis = self.settings.axis or message.dims[0]
        target_axis = message.get_axis(axis)
        ax_idx = message.get_axis_idx(axis)

        # Axes that carry explicit coordinates expose them as ``data``; any
        # other axis is asked to convert sample indices into coordinates.
        has_coords = hasattr(target_axis, "data")
        if has_coords:
            ax_vec = target_axis.data
        else:
            ax_vec = target_axis.value(np.arange(message.data.shape[ax_idx]))
        self._state.ax_vec = ax_vec
        is_label_axis = has_coords and ax_vec.dtype.type is np.str_

        ax_dat = []
        slices = []
        for start, stop in self.settings.bands:
            inds = np.where(np.logical_and(ax_vec >= start, ax_vec <= stop))[0]
            if inds.size == 0:
                # Same exception type that ``inds[0]`` would raise, but with
                # a message that identifies the offending band.
                raise IndexError(
                    f"Band ({start}, {stop}) does not contain any coordinate "
                    f"of axis '{axis}'."
                )
            # Contiguous slice from the first to the last matching sample.
            slices.append(np.s_[inds[0] : inds[-1] + 1])

            # Exactly one output coordinate per band, whatever the axis type.
            if is_label_axis:
                # Label the band by its first and last matched entries; the
                # band bounds themselves are not valid positions in ax_vec.
                sl_dat = f"{ax_vec[inds[0]]} - {ax_vec[inds[-1]]}"
            elif has_coords:
                sl_dat = np.mean(ax_vec[inds])
            else:
                sl_dat = target_axis.value(np.mean(inds))
            ax_dat.append(sl_dat)

        self._state.slices = slices
        self._state.out_axis = AxisArray.CoordinateAxis(
            data=np.array(ax_dat),
            dims=[axis],
            unit=target_axis.unit,
        )

    def _process(self, message: AxisArray) -> AxisArray:
        """Apply the configured reduction to every band of ``message``.

        Returns:
            A copy of ``message`` whose data holds one reduced value per band
            along the target axis and whose target axis is the cached output
            axis. For ``ARGMIN``/``ARGMAX`` the within-band positions are
            translated into the corresponding axis coordinates.
        """
        axis = self.settings.axis or message.dims[0]
        ax_idx = message.get_axis_idx(axis)
        operation = self.settings.operation
        agg_func = AGGREGATORS[operation]
        slices = self._state.slices
        ax_vec = self._state.ax_vec

        if operation in (AggregationFunction.TRAPEZOID,):
            # Special handling for methods that require x-coordinates.
            out_data = [
                agg_func(
                    slice_along_axis(message.data, sl, axis=ax_idx),
                    x=ax_vec[sl],
                    axis=ax_idx,
                )
                for sl in slices
            ]
        else:
            out_data = [
                agg_func(slice_along_axis(message.data, sl, axis=ax_idx), axis=ax_idx)
                for sl in slices
            ]

        # Each reduction dropped the target axis; stacking re-creates it with
        # one entry per band.
        msg_out = replace(
            message,
            data=np.stack(out_data, axis=ax_idx),
            axes={**message.axes, axis: self._state.out_axis},
        )

        if operation in (AggregationFunction.ARGMIN, AggregationFunction.ARGMAX):
            # The reductions returned positions relative to the start of each
            # band; look them up in that band's coordinates.
            out_data = []
            for sl_ix, sl in enumerate(slices):
                offsets = np.take(msg_out.data, [sl_ix], axis=ax_idx)
                out_data.append(ax_vec[sl][offsets])
            msg_out.data = np.concatenate(out_data, axis=ax_idx)

        return msg_out
class RangedAggregate(
    BaseTransformerUnit[
        RangedAggregateSettings,
        AxisArray,
        AxisArray,
        RangedAggregateTransformer,
    ]
):
    """Unit wrapper around :obj:`RangedAggregateTransformer`, configured by :obj:`RangedAggregateSettings`."""

    SETTINGS = RangedAggregateSettings
def ranged_aggregate(
    axis: str | None = None,
    bands: list[tuple[float, float]] | None = None,
    operation: AggregationFunction = AggregationFunction.MEAN,
) -> RangedAggregateTransformer:
    """
    Build a transformer that aggregates data over one or more bands.

    Args:
        axis: Name of the axis the bands are defined on.
        bands: [(band1_min, band1_max), (band2_min, band2_max), ...]
            When left as ``None`` the resulting node passes messages through.
        operation: The :obj:`AggregationFunction` used to reduce each band.

    Returns:
        A :obj:`RangedAggregateTransformer` configured with the given settings.

    """
    settings = RangedAggregateSettings(
        axis=axis,
        bands=bands,
        operation=operation,
    )
    return RangedAggregateTransformer(settings)