Source code for ezmsg.sigproc.aggregate

"""
Aggregation operations over arrays.

.. note::
    :obj:`AggregateTransformer` supports the :doc:`Array API standard </guides/explanations/array_api>`,
    enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
    :obj:`RangedAggregateTransformer` currently requires NumPy arrays.
"""

import typing

import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
from array_api_compat import get_namespace
from ezmsg.baseproc import (
    BaseStatefulTransformer,
    BaseTransformer,
    BaseTransformerUnit,
    processor_state,
)
from ezmsg.util.messages.axisarray import (
    AxisArray,
    AxisBase,
    replace,
    slice_along_axis,
)

from .spectral import OptionsEnum


class AggregationFunction(OptionsEnum):
    """Enum of the aggregation functions that can be selected for the :obj:`ranged_aggregate` operation.

    Apart from ``NONE``, each member's value is the name of the NumPy function
    it is mapped to in ``AGGREGATORS``.
    """

    NONE = "None (all)"

    # Plain reductions.
    MAX = "max"
    MIN = "min"
    MEAN = "mean"
    MEDIAN = "median"
    STD = "std"
    SUM = "sum"

    # ``nan``-prefixed counterparts of the reductions above.
    NANMAX = "nanmax"
    NANMIN = "nanmin"
    NANMEAN = "nanmean"
    NANMEDIAN = "nanmedian"
    NANSTD = "nanstd"
    NANSUM = "nansum"

    # Reductions that yield a position rather than a value.
    ARGMIN = "argmin"
    ARGMAX = "argmax"

    # Integration; needs the axis coordinates in addition to the data.
    TRAPEZOID = "trapezoid"
# Lookup table from each :obj:`AggregationFunction` member to the NumPy callable
# that implements it. Every callable is invoked with an ``axis=`` keyword.
AGGREGATORS = {
    AggregationFunction.NONE: np.all,
    AggregationFunction.MAX: np.max,
    AggregationFunction.MIN: np.min,
    AggregationFunction.MEAN: np.mean,
    AggregationFunction.MEDIAN: np.median,
    AggregationFunction.STD: np.std,
    AggregationFunction.SUM: np.sum,
    AggregationFunction.NANMAX: np.nanmax,
    AggregationFunction.NANMIN: np.nanmin,
    AggregationFunction.NANMEAN: np.nanmean,
    AggregationFunction.NANMEDIAN: np.nanmedian,
    AggregationFunction.NANSTD: np.nanstd,
    AggregationFunction.NANSUM: np.nansum,
    AggregationFunction.ARGMIN: np.argmin,
    AggregationFunction.ARGMAX: np.argmax,
    # Note: Some methods require x-coordinates and
    #  are handled specially in `_process`.
    # `np.trapezoid` only exists in NumPy >= 2.0. Looking it up unconditionally
    # would raise AttributeError while this module is being imported on older
    # NumPy, so fall back to the older name `np.trapz` (same signature). The
    # `or` short-circuits, so `np.trapz` is never touched when the new name exists.
    AggregationFunction.TRAPEZOID: getattr(np, "trapezoid", None) or np.trapz,
}
class RangedAggregateSettings(ez.Settings):
    """
    Configuration for ``RangedAggregate``.
    """

    axis: str | None = None
    """Name of the axis along which the bands are applied."""

    bands: list[tuple[float, float]] | None = None
    """
    Bands to aggregate over, as [(band1_min, band1_max), (band2_min, band2_max), ...].
    When left unset, the node passes messages through unchanged.
    """

    operation: AggregationFunction = AggregationFunction.MEAN
    """The :obj:`AggregationFunction` applied to each band."""
@processor_state
class RangedAggregateState:
    """State cached by :obj:`RangedAggregateTransformer` between resets."""

    # One slice per configured band, selecting that band along the target axis.
    slices: list[tuple[typing.Any, ...]] | None = None

    # Axis object placed on the output message in place of the target axis.
    out_axis: AxisBase | None = None

    # Coordinate value of every sample along the target axis of the input.
    ax_vec: npt.NDArray | None = None
class RangedAggregateTransformer(
    BaseStatefulTransformer[RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateState]
):
    """
    Apply an aggregation operation over one or more bands of an axis.

    The target axis (``settings.axis``, or the first dimension of the message when
    unset) is kept in the output with one entry per configured band. When
    ``settings.bands`` is ``None`` messages are returned untouched.
    """

    def __call__(self, message: AxisArray) -> AxisArray:
        """Process ``message``, short-circuiting to a passthrough when no bands are set."""
        # Override for shortcut passthrough mode.
        if self.settings.bands is None:
            return message
        return super().__call__(message)

    def _hash_message(self, message: AxisArray) -> int:
        """
        Build a hash from the message key and a summary of the target axis.

        Coordinate-style axes (anything exposing ``data``) contribute their length;
        linear axes contribute their gain and offset.
        """
        axis = self.settings.axis or message.dims[0]
        target_axis = message.get_axis(axis)

        hash_components = (message.key,)
        if hasattr(target_axis, "data"):
            hash_components += (len(target_axis.data),)
        elif isinstance(target_axis, AxisArray.LinearAxis):
            hash_components += (target_axis.gain, target_axis.offset)
        return hash(hash_components)

    def _reset_state(self, message: AxisArray) -> None:
        """
        Precompute the per-band slices and the output axis from ``message``.

        Raises:
            IndexError: If a band does not contain any coordinate of the target axis.
        """
        axis = self.settings.axis or message.dims[0]
        target_axis = message.get_axis(axis)
        ax_idx = message.get_axis_idx(axis)
        has_coords = hasattr(target_axis, "data")

        # Coordinates of every sample along the target axis: taken as-is from a
        # coordinate-style axis, otherwise evaluated from the axis at 0..N-1.
        if has_coords:
            ax_vec = target_axis.data
        else:
            ax_vec = target_axis.value(np.arange(message.data.shape[ax_idx]))
        self._state.ax_vec = ax_vec

        ax_dat = []
        slices = []
        for start, stop in self.settings.bands:
            # Positions whose coordinate falls inside the (inclusive) band.
            inds = np.where(np.logical_and(ax_vec >= start, ax_vec <= stop))[0]
            if inds.size == 0:
                # Same exception type as the bare `inds[0]` lookup would raise,
                # but with a message that names the offending band.
                raise IndexError(f"Band ({start}, {stop}) does not contain any coordinate of axis '{axis}'.")
            slices.append(np.s_[inds[0] : inds[-1] + 1])

            # Compute exactly one output coordinate per band, then append it once.
            if has_coords:
                if ax_vec.dtype.type is np.str_:
                    # Label the band by its first and last member. The band edges
                    # are values, not positions, so index with `inds` instead.
                    sl_dat = f"{ax_vec[inds[0]]} - {ax_vec[inds[-1]]}"
                else:
                    sl_dat = np.mean(ax_vec[inds])
            else:
                sl_dat = target_axis.value(np.mean(inds))
            ax_dat.append(sl_dat)

        self._state.slices = slices
        self._state.out_axis = AxisArray.CoordinateAxis(
            data=np.array(ax_dat),
            dims=[axis],
            unit=target_axis.unit,
        )

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Aggregate every band of ``message`` and stack the results along the target axis.

        For ``ARGMIN``/``ARGMAX`` the within-band positions returned by the
        aggregator are converted to the corresponding axis coordinates.
        """
        axis = self.settings.axis or message.dims[0]
        ax_idx = message.get_axis_idx(axis)
        operation = self.settings.operation
        agg_func = AGGREGATORS[operation]
        slices = self._state.slices
        ax_vec = self._state.ax_vec

        if operation in [
            AggregationFunction.TRAPEZOID,
        ]:
            # Special handling for methods that require x-coordinates.
            out_data = [
                agg_func(
                    slice_along_axis(message.data, sl, axis=ax_idx),
                    x=ax_vec[sl],
                    axis=ax_idx,
                )
                for sl in slices
            ]
        else:
            out_data = [agg_func(slice_along_axis(message.data, sl, axis=ax_idx), axis=ax_idx) for sl in slices]

        agg_data = np.stack(out_data, axis=ax_idx)

        if operation in [
            AggregationFunction.ARGMIN,
            AggregationFunction.ARGMAX,
        ]:
            # `agg_data` holds positions relative to the start of each band;
            # look them up in that band's coordinates.
            agg_data = np.concatenate(
                [ax_vec[sl][np.take(agg_data, [sl_ix], axis=ax_idx)] for sl_ix, sl in enumerate(slices)],
                axis=ax_idx,
            )

        return replace(
            message,
            data=agg_data,
            axes={**message.axes, axis: self._state.out_axis},
        )
class RangedAggregate(
    BaseTransformerUnit[
        RangedAggregateSettings,
        AxisArray,
        AxisArray,
        RangedAggregateTransformer,
    ]
):
    """Unit wrapping :obj:`RangedAggregateTransformer`, configured by :obj:`RangedAggregateSettings`."""

    SETTINGS = RangedAggregateSettings
def ranged_aggregate(
    axis: str | None = None,
    bands: list[tuple[float, float]] | None = None,
    operation: AggregationFunction = AggregationFunction.MEAN,
) -> RangedAggregateTransformer:
    """
    Build a transformer that aggregates over one or more bands.

    Args:
        axis: Name of the axis along which the bands are applied.
        bands: [(band1_min, band1_max), (band2_min, band2_max), ...]
            When not set, the resulting node acts as a passthrough.
        operation: The :obj:`AggregationFunction` applied to each band.

    Returns:
        A :obj:`RangedAggregateTransformer` configured with the given arguments.
    """
    settings = RangedAggregateSettings(axis=axis, bands=bands, operation=operation)
    return RangedAggregateTransformer(settings)
class AggregateSettings(ez.Settings):
    """Configuration for :obj:`Aggregate`."""

    axis: str
    """Name of the axis to aggregate over; the output no longer contains this axis."""

    operation: AggregationFunction = AggregationFunction.MEAN
    """The :obj:`AggregationFunction` to apply."""
class AggregateTransformer(BaseTransformer[AggregateSettings, AxisArray, AxisArray]):
    """
    Transformer that collapses one whole axis with a chosen operation.

    :obj:`RangedAggregateTransformer` reduces selected ranges/bands and keeps the
    axis (one value per band). This transformer instead reduces the complete axis
    and drops it from the output, so the result has one dimension fewer.
    """

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Reduce ``message`` over ``settings.axis`` and return it without that axis.

        Raises:
            ValueError: If the configured operation is ``AggregationFunction.NONE``.
        """
        xp = get_namespace(message.data)
        target_name = self.settings.axis
        axis_idx = message.get_axis_idx(target_name)
        operation = self.settings.operation

        if operation == AggregationFunction.NONE:
            raise ValueError("AggregationFunction.NONE is not supported for full-axis aggregation")

        if operation == AggregationFunction.TRAPEZOID:
            # Integration needs the x-coordinates of the samples along the axis.
            target_axis = message.get_axis(target_name)
            x = (
                target_axis.data
                if hasattr(target_axis, "data")
                else target_axis.value(np.arange(message.data.shape[axis_idx]))
            )
            # The table entry for TRAPEZOID is the NumPy integrator, so the data
            # is converted to a NumPy array first.
            reduced = AGGREGATORS[operation](np.asarray(message.data), x=x, axis=axis_idx)
        else:
            # Prefer the function of the same name from the array's own namespace;
            # otherwise use the NumPy implementation from the lookup table.
            name = operation.value
            reducer = getattr(xp, name) if hasattr(xp, name) else AGGREGATORS[operation]
            reduced = reducer(message.data, axis=axis_idx)

        # Drop the aggregated axis from both the dimension list and the axes mapping.
        remaining_dims = [dim for pos, dim in enumerate(message.dims) if pos != axis_idx]
        remaining_axes = {key: ax for key, ax in message.axes.items() if key != target_name}

        return replace(
            message,
            data=reduced,
            dims=remaining_dims,
            axes=remaining_axes,
        )
class AggregateUnit(
    BaseTransformerUnit[
        AggregateSettings,
        AxisArray,
        AxisArray,
        AggregateTransformer,
    ]
):
    """Unit that reduces a whole axis with the configured operation."""

    SETTINGS = AggregateSettings