Source code for ezmsg.sigproc.transpose

"""
Transpose or permute array dimensions.

.. note::
    This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
    enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
    Memory layout optimization (C/F order) only applies to NumPy arrays.
"""

from types import EllipsisType

import ezmsg.core as ez
import numpy as np
from array_api_compat import get_namespace, is_numpy_array
from ezmsg.baseproc import (
    BaseStatefulTransformer,
    BaseTransformerUnit,
    processor_state,
)
from ezmsg.util.messages.axisarray import (
    AxisArray,
    replace,
)


[docs] class TransposeSettings(ez.Settings): """ Settings for :obj:`Transpose` node. Fields: axes: """ axes: tuple[int | str | EllipsisType, ...] | None = None order: str | None = None
[docs] @processor_state class TransposeState: axes_ints: tuple[int, ...] | None = None
[docs] class TransposeTransformer(BaseStatefulTransformer[TransposeSettings, AxisArray, AxisArray, TransposeState]): """ Downsampled data simply comprise every `factor`th sample. This should only be used following appropriate lowpass filtering. If your pipeline does not already have lowpass filtering then consider using the :obj:`Decimate` collection instead. """ def _hash_message(self, message: AxisArray) -> int: return hash(tuple(message.dims)) def _reset_state(self, message: AxisArray) -> None: if self.settings.axes is None: self._state.axes_ints = None else: ell_ix = [ix for ix, ax in enumerate(self.settings.axes) if ax is Ellipsis] if len(ell_ix) > 1: raise ValueError("Only one Ellipsis is allowed in axes.") ell_ix = ell_ix[0] if len(ell_ix) == 1 else len(message.dims) prefix = [] for ax in self.settings.axes[:ell_ix]: if isinstance(ax, int): prefix.append(ax) else: if ax not in message.dims: raise ValueError(f"Axis {ax} not found in message dims.") prefix.append(message.dims.index(ax)) suffix = [] for ax in self.settings.axes[ell_ix + 1 :]: if isinstance(ax, int): suffix.append(ax) else: if ax not in message.dims: raise ValueError(f"Axis {ax} not found in message dims.") suffix.append(message.dims.index(ax)) ells = [_ for _ in range(message.data.ndim) if _ not in prefix and _ not in suffix] re_ix = tuple(prefix + ells + suffix) if re_ix == tuple(range(message.data.ndim)): self._state.axes_ints = None else: self._state.axes_ints = re_ix if self.settings.order is not None and self.settings.order.upper()[0] not in [ "C", "F", ]: raise ValueError("order must be 'C' or 'F'.") def __call__(self, message: AxisArray) -> AxisArray: if self.settings.axes is None and self.settings.order is None: # Passthrough return message return super().__call__(message) def _process(self, message: AxisArray) -> AxisArray: xp = get_namespace(message.data) if self.state.axes_ints is None: # No transpose required if self.settings.order is None: # No memory relayout required # Note: We should not be able to reach here because it should be shortcutted at passthrough. msg_out = message else: # Memory layout optimization only applies to numpy arrays if is_numpy_array(message.data): msg_out = replace( message, data=np.require(message.data, requirements=self.settings.order.upper()[0]), ) else: msg_out = message else: dims_out = [message.dims[ix] for ix in self.state.axes_ints] data_out = xp.permute_dims(message.data, axes=self.state.axes_ints) if self.settings.order is not None and is_numpy_array(data_out): # Memory layout optimization only applies to numpy arrays data_out = np.require(data_out, requirements=self.settings.order.upper()[0]) msg_out = replace( message, data=data_out, dims=dims_out, ) return msg_out
[docs] class Transpose(BaseTransformerUnit[TransposeSettings, AxisArray, AxisArray, TransposeTransformer]): SETTINGS = TransposeSettings
[docs] def transpose( axes: tuple[int | str | EllipsisType, ...] | None = None, order: str | None = None ) -> TransposeTransformer: return TransposeTransformer(TransposeSettings(axes=axes, order=order))