# Array API Support
ezmsg-sigproc provides support for the Python Array API standard, enabling many transformers to work with arrays from different backends such as NumPy, CuPy, PyTorch, and JAX.
## What is the Array API?

The Array API is a standardized interface for array operations across different Python array libraries. By coding to this standard, ezmsg-sigproc transformers can process data regardless of which array library created it, enabling:

- GPU acceleration via CuPy or PyTorch tensors
- Framework interoperability for integration with ML pipelines
- Hardware flexibility without code changes
## How It Works

Compatible transformers use `array-api-compat` to detect the input array's namespace and use the appropriate operations:
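```python
from array_api_compat import get_namespace
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

# Method of a compatible transformer class:
def _process(self, message: AxisArray) -> AxisArray:
    xp = get_namespace(message.data)  # numpy, cupy, torch, etc.
    result = xp.abs(message.data)  # Uses the correct backend
    return replace(message, data=result)
```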
## Usage Example

Using Array API compatible transformers with CuPy for GPU acceleration:

```python
import cupy as cp
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.sigproc.math.abs import AbsTransformer
from ezmsg.sigproc.math.clip import ClipTransformer, ClipSettings

# Create data on GPU
gpu_data = cp.random.randn(1000, 64).astype(cp.float32)
message = AxisArray(gpu_data, dims=["time", "ch"])

# Process entirely on GPU - no data transfer!
abs_transformer = AbsTransformer()
clip_transformer = ClipTransformer(ClipSettings(min=0.0, max=1.0))
result = clip_transformer(abs_transformer(message))

# result.data is still a CuPy array on GPU
```
## Compatible Modules

The following transformers fully support the Array API standard:

### Math Operations

| Module | Description |
|---|---|
| `ezmsg.sigproc.math.abs` | Absolute value |
| `ezmsg.sigproc.math.clip` | Clip values to a range |
|  | Logarithm with configurable base |
|  | Multiply by a constant |
|  | Compute 1/x |
|  | Subtract a constant (`ConstDifferenceTransformer`) |
### Signal Processing

| Module | Description |
|---|---|
|  | Compute differences along an axis |
|  | Transpose/permute array dimensions |
|  | Per-channel linear transform (scale + offset) |
|  | Aggregate operations (`AggregateTransformer` only) |
### Coordinate Transforms

| Module | Description |
|---|---|
|  | Cartesian/polar coordinate conversions |
## Limitations

Some operations remain NumPy-only because they rely on NumPy- or SciPy-specific functionality:

- Random number generation: modules using `np.random` (e.g., `denormalize`)
- SciPy operations: filtering (`scipy.signal.lfilter`), FFT, wavelets
- Advanced indexing: some slicing operations for metadata handling
- Memory layout: `np.require` for contiguous array optimization (NumPy only)
Metadata arrays (axis labels, coordinates) typically remain as NumPy arrays since they are not performance-critical.
## Adding Array API Support
When contributing new transformers, follow this pattern:
```python
import ezmsg.core as ez
from array_api_compat import get_namespace
from ezmsg.baseproc import BaseTransformer
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace


class MySettings(ez.Settings):
    pass  # placeholder settings for this example


class MyTransformer(BaseTransformer[MySettings, AxisArray, AxisArray]):
    def _process(self, message: AxisArray) -> AxisArray:
        xp = get_namespace(message.data)
        # Use xp instead of np for array operations
        result = xp.sqrt(xp.abs(message.data))
        return replace(message, data=result)
```
Key guidelines:

- Call `get_namespace(message.data)` at the start of `_process`
- Use `xp.function_name` instead of `np.function_name`
- Note that some functions have different names:
  - `np.concatenate` → `xp.concat`
  - `np.transpose` → `xp.permute_dims`
- Keep metadata operations (axis labels, etc.) as NumPy
- Use in-place operations (`/=`, `*=`) where possible for efficiency, but only on arrays the transformer owns: applying them to `message.data` directly modifies the caller's data