Array API Support#
ezmsg-sigproc provides support for the Python Array API standard, enabling many transformers to work with arrays from different backends such as NumPy, CuPy, PyTorch, JAX, and MLX.
What is the Array API?#
The Array API is a standardized interface for array operations across different Python array libraries. By coding to this standard, ezmsg-sigproc transformers can process data regardless of which array library created it, enabling:
GPU acceleration via CuPy, PyTorch, or JAX tensors
Apple Silicon acceleration via MLX
Framework interoperability for integration with ML pipelines
Hardware flexibility without code changes
How It Works#
Compatible transformers use array-api-compat to detect the input array’s namespace and use the appropriate operations:
from array_api_compat import get_namespace

def _process(self, message: AxisArray) -> AxisArray:
    xp = get_namespace(message.data)  # numpy, cupy, torch, mlx.core, etc.
    result = xp.abs(message.data)  # Uses the correct backend
    return replace(message, data=result)
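The same idea can be tried outside a transformer class. The sketch below is an illustration, not ezmsg-sigproc code: `rectify` is a hypothetical helper, and the NumPy-only fallback used when array-api-compat is not installed is an assumption made so the snippet runs anywhere.

```python
import numpy as np

try:
    # Same helper the transformers above rely on.
    from array_api_compat import get_namespace
except ImportError:
    # Assumption for this sketch only: without array-api-compat,
    # treat every input as a NumPy array.
    def get_namespace(*arrays):
        return np


def rectify(data):
    """Return |data| using whichever array library produced `data`."""
    xp = get_namespace(data)
    return xp.abs(data)


samples = np.array([-1.5, 0.0, 2.0], dtype=np.float32)
rectified = rectify(samples)
print(type(rectified).__name__, rectified.dtype)
```

Passing a CuPy, PyTorch, or MLX array instead should keep the result in that library, provided the backend is installed and array-api-compat recognises it.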
Usage Example#
Using Array API compatible transformers with CuPy for GPU acceleration:
import cupy as cp
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.sigproc.math.abs import AbsTransformer
from ezmsg.sigproc.math.clip import ClipTransformer, ClipSettings
# Create data on GPU
gpu_data = cp.random.randn(1000, 64).astype(cp.float32)
message = AxisArray(gpu_data, dims=["time", "ch"])
# Process entirely on GPU - no data transfer!
abs_transformer = AbsTransformer()
clip_transformer = ClipTransformer(ClipSettings(min=0.0, max=1.0))
result = clip_transformer(abs_transformer(message))
# result.data is still a CuPy array on GPU
Compatible Modules#
The following transformers fully support the Array API standard:
Math Operations#
| Module | Description |
|---|---|
| | Absolute value |
| | Clip values to a range |
| | Logarithm with configurable base |
| | Multiply by a constant |
| | Compute 1/x |
| | Subtract a constant (ConstDifferenceTransformer) |
Signal Processing#
| Module | Description |
|---|---|
| | FFT-based spectrum (SpectrumTransformer) |
| | Aggregate operations (AggregateTransformer, RangedAggregateTransformer) |
| | Compute differences along an axis |
| | Transpose/permute array dimensions |
| | Per-channel linear transform (scale + offset) |
Coordinate Transforms#
| Module | Description |
|---|---|
| | Cartesian/polar coordinate conversions |
Composite Pipelines#
These CompositeProcessor pipelines chain Array API-aware steps together.
When fed non-NumPy arrays, each step in the pipeline preserves the backend:
| Module | Description |
|---|---|
| | BandPowerTransformer (spectrogram + ranged aggregate) |
| | RMSBandPowerTransformer (with explicit |
MLX on Apple Silicon#
MLX is an array library for Apple Silicon that provides GPU-accelerated operations with a NumPy-like API. ezmsg-sigproc’s Array API support enables MLX acceleration for spectral analysis and other pipelines without code changes to the transformers themselves.
Basic usage#
Pass MLX arrays in your AxisArray messages:
import mlx.core as mx
import numpy as np
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.sigproc.spectrum import SpectrumTransformer, SpectrumSettings

# Create data as MLX array
np_data = np.random.randn(1000, 64).astype(np.float32)
message = AxisArray(
    data=mx.array(np_data),
    dims=["time", "ch"],
    axes={"time": AxisArray.TimeAxis(fs=1000.0)},
)

proc = SpectrumTransformer(SpectrumSettings(axis="time"))
result = proc(message)
# result.data is an mlx.core.array
Lazy evaluation and mx.eval#
MLX uses lazy evaluation — computations are not executed until their results are needed. This allows MLX to fuse operations and optimize the computation graph. However, it means that timing code or downstream consumers may see artificially fast “processing” that is actually deferred.
To force evaluation, call mx.eval():
result = proc(message)
mx.eval(result.data) # Forces computation to complete
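When benchmarking, the evaluation has to happen before the clock stops. A minimal sketch of such a timing helper follows; `materialize` and `timed` are hypothetical names, not part of ezmsg-sigproc, and the demo runs on NumPy so it works without MLX installed.

```python
import time

import numpy as np


def materialize(data):
    """Force pending MLX work on `data`; other backends pass through unchanged."""
    try:
        import mlx.core as mx
    except ImportError:
        return data  # MLX not installed, so nothing here is lazy
    if isinstance(data, mx.array):
        mx.eval(data)
    return data


def timed(fn, data):
    """Call fn(data) and return (result, seconds), including deferred compute."""
    start = time.perf_counter()
    result = materialize(fn(data))
    return result, time.perf_counter() - start


result, seconds = timed(np.abs, np.linspace(-1.0, 1.0, 5))
print(result, f"{seconds:.6f} s")
```

For a transformer, `fn` would be something like `lambda m: proc(m).data`, so that the array itself is what gets evaluated.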
For CompositeProcessor pipelines (like BandPowerTransformer), you can
override _post_process to call mx.eval() automatically so that every
output is fully materialized:
class BandPowerTransformer(CompositeProcessor[BandPowerSettings, AxisArray, AxisArray]):
    @staticmethod
    def _initialize_processors(settings):
        return {
            "spectrogram": SpectrogramTransformer(settings=settings.spectrogram_settings),
            "aggregate": RangedAggregateTransformer(...),
        }

    def _post_process(self, result: AxisArray | None) -> AxisArray | None:
        if result is not None:
            try:
                import mlx.core as mx
                if isinstance(result.data, mx.array):
                    mx.eval(result.data)
            except ImportError:
                pass
        return result
This pattern is used by BandPowerTransformer and RMSBandPowerTransformer.
It ensures downstream consumers (ezmsg Units, visualization, logging) receive
fully evaluated arrays without needing to know about MLX internals.
It also acts as a safety valve: without it, the lazy computation graph could
keep growing whenever nothing downstream evaluates it in time.
Note
The _post_process hook is defined on CompositeProcessor in
ezmsg-baseproc. It runs after the entire processor chain completes and
receives the final output. The try/except ImportError pattern
keeps MLX as an optional dependency.
MLX quirks#
MLX’s Array API coverage is nearly complete but has a few gaps that ezmsg-sigproc works around internally:
| Feature | MLX status | Workaround |
|---|---|---|
| | Not supported | Manual normalization ( |
| | Needs tuple | Always pass |
| | Not available | Computed with NumPy (metadata only) |
| | No | |
| Window functions | Not available | Computed with NumPy, converted via |
| | Not available | Falls back to NumPy automatically |
| Boolean indexing | Not supported | Avoided in hot paths; used only in NumPy metadata code |
| Slice with | Rejected | Slice bounds cast to Python |
These workarounds are handled inside the transformers — user code does not need to account for them.
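For contributors, two of these workarounds can be sketched roughly as follows. This is an illustration under stated assumptions, not the library's implementation: `backend_window` and `slice_time` are hypothetical helpers, the conversion function is passed in by the caller because the exact call used internally is not shown here, and the slicing example assumes the problematic bounds are NumPy integers.

```python
import numpy as np


def backend_window(n_samples, to_backend):
    """Build a Hann window with NumPy, then hand it to the target backend."""
    window = np.hanning(n_samples).astype(np.float32)
    return to_backend(window)


def slice_time(data, start, stop):
    """Slice the first axis using plain Python ints as bounds."""
    return data[int(start):int(stop)]


# NumPy stands in for the backend here; with MLX, `to_backend` could be
# mx.array, the constructor used in the basic usage example.
window = backend_window(8, to_backend=np.asarray)
chunk = slice_time(np.arange(10), np.int64(2), np.int64(5))
print(window.shape, chunk)
```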
Limitations#
Some operations remain NumPy-only due to lack of Array API equivalents:
SciPy operations: Butterworth filtering (scipy.signal.sosfilt) and other scipy-dependent steps. Use AsArrayTransformer to convert between backends at pipeline boundaries (see RMSBandPowerTransformer for an example).
Random number generation: Modules using np.random (e.g., denormalize).
Trapezoidal integration: np.trapezoid has no Array API equivalent. RangedAggregateTransformer falls back to NumPy transparently.
Memory layout: np.require for contiguous array optimization.
Metadata arrays (axis labels, coordinates) always remain as NumPy arrays since they are not performance-critical.
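The NumPy round trip behind such fallbacks can be sketched as below. This is a simplified illustration, not the code in RangedAggregateTransformer or AsArrayTransformer: `integrate_via_numpy` and its `to_backend` parameter are hypothetical, and NumPy stands in for the caller's backend so the example runs anywhere.

```python
import numpy as np

# np.trapezoid is the NumPy 2 name; older releases call it np.trapz.
_trapezoid = getattr(np, "trapezoid", None) or getattr(np, "trapz")


def integrate_via_numpy(data, to_backend, axis=0, dx=1.0):
    """Run a NumPy-only reduction, then convert the result back."""
    as_numpy = np.asarray(data)  # leave the original backend
    area = _trapezoid(as_numpy, dx=dx, axis=axis)
    return to_backend(area)  # return to the caller's backend


power = np.array([[0.0, 1.0], [1.0, 1.0], [2.0, 1.0]])
area = integrate_via_numpy(power, to_backend=np.asarray)
print(area)
```

The conversion copies data between devices for GPU backends, which is why such steps are best kept at pipeline boundaries.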
Adding Array API Support#
When contributing new transformers, follow this pattern:
from array_api_compat import get_namespace
from ezmsg.baseproc import BaseTransformer
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

class MyTransformer(BaseTransformer[MySettings, AxisArray, AxisArray]):
    def _process(self, message: AxisArray) -> AxisArray:
        xp = get_namespace(message.data)
        # Use xp instead of np for array operations
        result = xp.sqrt(xp.abs(message.data))
        return replace(message, data=result)
Key guidelines:
Call get_namespace(message.data) at the start of _process (or _reset_state for stateful transformers).
Use xp.function_name instead of np.function_name for all operations on message.data.
Note that some functions have different names: np.concatenate → xp.concat, np.transpose → xp.permute_dims.
Keep metadata operations (axis labels, etc.) as NumPy.
When a backend lacks a function (e.g., MLX has no nanmean), fall back gracefully:
    func_name = "nanmean"
    if hasattr(xp, func_name):
        result = getattr(xp, func_name)(data, axis=axis_idx)
    else:
        # Backend has no such function: compute with NumPy instead
        result = getattr(np, func_name)(np.asarray(data), axis=axis_idx)
For CompositeProcessor subclasses that may produce MLX output, add a _post_process override to call mx.eval() (see the MLX section above).
Use portable helpers from ezmsg.sigproc.util.array when needed: is_complex_dtype, is_float_dtype, xp_asarray.
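The graceful-fallback guideline can be exercised without any optional backend installed. The sketch below is illustrative: `reduce_with_fallback` is a hypothetical helper, and `limited_backend` is a stand-in object that simply lacks nanmean, not a real array library.

```python
from types import SimpleNamespace

import numpy as np


def reduce_with_fallback(xp, data, func_name, axis):
    """Use xp.<func_name> when the backend has it, otherwise compute with NumPy."""
    func = getattr(xp, func_name, None)
    if func is not None:
        return func(data, axis=axis)
    # Backend lacks the function: convert to NumPy and compute there.
    return getattr(np, func_name)(np.asarray(data), axis=axis)


# Hypothetical stand-in for a backend without nanmean.
limited_backend = SimpleNamespace(mean=np.mean)

data = np.array([[1.0, np.nan], [3.0, 5.0]])
native = reduce_with_fallback(np, data, "nanmean", axis=0)
fallback = reduce_with_fallback(limited_backend, data, "nanmean", axis=0)
print(native, fallback)
```

Note that the fallback branch returns a NumPy array; a real transformer would decide whether to convert the result back to the input's backend.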