Array API Support

ezmsg-sigproc supports the Python Array API standard, so many of its transformers work with arrays from different backends such as NumPy, CuPy, PyTorch, and JAX.

What is the Array API?

The Array API is a standardized interface for array operations across different Python array libraries. By coding to this standard, ezmsg-sigproc transformers can process data regardless of which array library created it, enabling:

  • GPU acceleration via CuPy or PyTorch tensors

  • Framework interoperability for integration with ML pipelines

  • Hardware flexibility without code changes

How It Works

Compatible transformers use array-api-compat to detect the input array’s namespace, then call that namespace’s operations instead of NumPy’s:

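from array_api_compat import get_namespace
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

# Method of a transformer class (class definition omitted for brevity)
def _process(self, message: AxisArray) -> AxisArray:
    xp = get_namespace(message.data)  # numpy, cupy, torch, etc.
    result = xp.abs(message.data)     # Uses the correct backend
    return replace(message, data=result)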

Usage Example

The following example chains two Array API compatible transformers on CuPy data for GPU acceleration:

import cupy as cp
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.sigproc.math.abs import AbsTransformer
from ezmsg.sigproc.math.clip import ClipTransformer, ClipSettings

# Create data on GPU
gpu_data = cp.random.randn(1000, 64).astype(cp.float32)
message = AxisArray(gpu_data, dims=["time", "ch"])

# Process entirely on GPU - no data transfer!
abs_transformer = AbsTransformer()
clip_transformer = ClipTransformer(ClipSettings(min=0.0, max=1.0))

result = clip_transformer(abs_transformer(message))
# result.data is still a CuPy array on GPU

Compatible Modules

The following transformers fully support the Array API standard:

Math Operations

| Module | Description |
| --- | --- |
| ezmsg.sigproc.math.abs | Absolute value |
| ezmsg.sigproc.math.clip | Clip values to a range |
| ezmsg.sigproc.math.log | Logarithm with configurable base |
| ezmsg.sigproc.math.scale | Multiply by a constant |
| ezmsg.sigproc.math.invert | Compute 1/x |
| ezmsg.sigproc.math.difference | Subtract a constant (ConstDifferenceTransformer) |

Signal Processing

| Module | Description |
| --- | --- |
| ezmsg.sigproc.diff | Compute differences along an axis |
| ezmsg.sigproc.transpose | Transpose/permute array dimensions |
| ezmsg.sigproc.linear | Per-channel linear transform (scale + offset) |
| ezmsg.sigproc.aggregate | Aggregate operations (AggregateTransformer only) |

Coordinate Transforms

| Module | Description |
| --- | --- |
| ezmsg.sigproc.coordinatespaces | Cartesian/polar coordinate conversions |

Limitations

Some operations remain NumPy-only due to lack of Array API equivalents:

  • Random number generation: Modules using np.random (e.g., denormalize)

  • SciPy operations: Filtering (scipy.signal.lfilter), FFT, wavelets

  • Advanced indexing: Some slicing operations for metadata handling

  • Memory layout: np.require for contiguous array optimization (NumPy only)

Metadata arrays (axis labels, coordinates) typically remain as NumPy arrays since they are not performance-critical.

Adding Array API Support

When contributing new transformers, follow this pattern:

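import ezmsg.core as ez
from array_api_compat import get_namespace
from ezmsg.baseproc import BaseTransformer
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace


class MySettings(ez.Settings):
    """Placeholder settings; add fields as your transformer needs them."""


class MyTransformer(BaseTransformer[MySettings, AxisArray, AxisArray]):
    def _process(self, message: AxisArray) -> AxisArray:
        # Resolve the namespace (numpy, cupy, torch, ...) that owns the data
        xp = get_namespace(message.data)

        # Use xp instead of np for array operations
        result = xp.sqrt(xp.abs(message.data))

        # Return a copy of the message that carries the new data
        return replace(message, data=result)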

Key guidelines:

  1. Call get_namespace(message.data) at the start of _process

  2. Use xp.function_name instead of np.function_name

  3. Note that some functions have different names: np.concatenate → xp.concat, np.transpose → xp.permute_dims

  4. Keep metadata operations (axis labels, etc.) as NumPy

  5. Use in-place operations (/=, *=) where possible for efficiency