Array API Support
=================

ezmsg-sigproc provides support for the Python Array API standard, enabling many transformers to work with arrays from different backends such as NumPy, CuPy, PyTorch, and JAX.

What is the Array API?
----------------------

The Array API is a standardized interface for array operations across different Python array libraries. By coding to this standard, ezmsg-sigproc transformers can process data regardless of which array library created it, enabling:

- **GPU acceleration** via CuPy or PyTorch tensors
- **Framework interoperability** for integration with ML pipelines
- **Hardware flexibility** without code changes

How It Works
------------

Compatible transformers use ``array-api-compat`` to detect the input array's namespace and then call the matching backend's operations:

.. code-block:: python

    from array_api_compat import get_namespace
    from ezmsg.util.messages.axisarray import AxisArray
    from ezmsg.util.messages.util import replace

    # Excerpt: a method on a transformer class
    def _process(self, message: AxisArray) -> AxisArray:
        xp = get_namespace(message.data)  # numpy, cupy, torch, etc.
        result = xp.abs(message.data)  # Uses the correct backend
        return replace(message, data=result)

Usage Example
-------------

Using Array API compatible transformers with CuPy for GPU acceleration:

.. code-block:: python

    import cupy as cp
    from ezmsg.util.messages.axisarray import AxisArray
    from ezmsg.sigproc.math.abs import AbsTransformer
    from ezmsg.sigproc.math.clip import ClipTransformer, ClipSettings

    # Create data on GPU
    gpu_data = cp.random.randn(1000, 64).astype(cp.float32)
    message = AxisArray(gpu_data, dims=["time", "ch"])

    # Process entirely on GPU - no data transfer!
    abs_transformer = AbsTransformer()
    clip_transformer = ClipTransformer(ClipSettings(min=0.0, max=1.0))
    result = clip_transformer(abs_transformer(message))

    # result.data is still a CuPy array on GPU

Compatible Modules
------------------

The following transformers fully support the Array API standard:

Math Operations
^^^^^^^^^^^^^^^

.. list-table::
   :header-rows: 1
   :widths: 30 70

   * - Module
     - Description
   * - :mod:`ezmsg.sigproc.math.abs`
     - Absolute value
   * - :mod:`ezmsg.sigproc.math.clip`
     - Clip values to a range
   * - :mod:`ezmsg.sigproc.math.log`
     - Logarithm with configurable base
   * - :mod:`ezmsg.sigproc.math.scale`
     - Multiply by a constant
   * - :mod:`ezmsg.sigproc.math.invert`
     - Compute 1/x
   * - :mod:`ezmsg.sigproc.math.difference`
     - Subtract a constant (ConstDifferenceTransformer)

Signal Processing
^^^^^^^^^^^^^^^^^

.. list-table::
   :header-rows: 1
   :widths: 30 70

   * - Module
     - Description
   * - :mod:`ezmsg.sigproc.diff`
     - Compute differences along an axis
   * - :mod:`ezmsg.sigproc.transpose`
     - Transpose/permute array dimensions
   * - :mod:`ezmsg.sigproc.linear`
     - Per-channel linear transform (scale + offset)
   * - :mod:`ezmsg.sigproc.aggregate`
     - Aggregate operations (AggregateTransformer only)

Coordinate Transforms
^^^^^^^^^^^^^^^^^^^^^

.. list-table::
   :header-rows: 1
   :widths: 30 70

   * - Module
     - Description
   * - :mod:`ezmsg.sigproc.coordinatespaces`
     - Cartesian/polar coordinate conversions

Limitations
-----------

Some operations remain NumPy-only due to lack of Array API equivalents:

- **Random number generation**: Modules using ``np.random`` (e.g., ``denormalize``)
- **SciPy operations**: Filtering (``scipy.signal.lfilter``), FFT, wavelets
- **Advanced indexing**: Some slicing operations for metadata handling
- **Memory layout**: ``np.require`` for contiguous array optimization (NumPy only)

Metadata arrays (axis labels, coordinates) typically remain as NumPy arrays since they are not performance-critical.

Adding Array API Support
------------------------

When contributing new transformers, follow this pattern:

.. code-block:: python

    from array_api_compat import get_namespace
    from ezmsg.baseproc import BaseTransformer
    from ezmsg.util.messages.axisarray import AxisArray
    from ezmsg.util.messages.util import replace

    # MySettings stands in for your transformer's settings class
    class MyTransformer(BaseTransformer[MySettings, AxisArray, AxisArray]):
        def _process(self, message: AxisArray) -> AxisArray:
            xp = get_namespace(message.data)

            # Use xp instead of np for array operations
            result = xp.sqrt(xp.abs(message.data))

            return replace(message, data=result)

Key guidelines:

1. Call ``get_namespace(message.data)`` at the start of ``_process``
2. Use ``xp.function_name`` instead of ``np.function_name``
3. Note that some functions have different names:

   - ``np.concatenate`` → ``xp.concat``
   - ``np.transpose`` → ``xp.permute_dims``

4. Keep metadata operations (axis labels, etc.) as NumPy
5. Use in-place operations (``/=``, ``*=``) where possible for efficiency
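
The renamed functions in guideline 3 can be tried out with plain NumPy inputs, without any ezmsg machinery. The following is a minimal sketch, not code from ezmsg-sigproc: the helper name ``concat_and_swap`` is hypothetical, and it uses ``array_namespace``, the primary documented name in array-api-compat (``get_namespace``, used elsewhere on this page, is understood to be an alias of it).

.. code-block:: python

    import numpy as np
    from array_api_compat import array_namespace


    def concat_and_swap(a, b):
        """Concatenate two 2-D arrays along axis 0, then swap the two axes.

        Hypothetical helper for illustration only.
        """
        # Both inputs are expected to come from the same array library.
        xp = array_namespace(a, b)
        joined = xp.concat([a, b], axis=0)  # Array API name for np.concatenate
        return xp.permute_dims(joined, (1, 0))  # Array API name for np.transpose


    a = np.ones((2, 3), dtype=np.float32)
    b = np.zeros((4, 3), dtype=np.float32)
    out = concat_and_swap(a, b)
    print(type(out).__name__, out.shape, out.dtype)

Because the function only touches ``xp``, the same body should work unchanged if ``a`` and ``b`` are created with another supported backend, such as CuPy, provided that library is installed.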