# Array API Compatibility

ezmsg-learn uses the Array API standard to allow processors to operate on arrays from different backends (NumPy, CuPy, PyTorch, and others) without code changes.

## How It Works
Modules that support the Array API derive the array namespace from their input
data using array_api_compat.get_namespace():
```python
from array_api_compat import get_namespace

def process(self, data):
    xp = get_namespace(data)       # numpy, cupy, torch, etc.
    result = xp.linalg.inv(data)   # dispatches to the right backend
    return result
```
This means that if you pass a CuPy array, all computation stays on the GPU. If you pass a NumPy array, it behaves exactly as before.
Helper utilities from `ezmsg.sigproc.util.array` handle device placement and creation functions portably:

- `array_device(x)`: returns the device of an array, or `None`
- `xp_create(fn, *args, dtype=None, device=None)`: calls creation functions (`zeros`, `eye`) with an optional device
- `xp_asarray(xp, obj, dtype=None, device=None)`: portable `asarray`
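Only the helpers' signatures are given here. To illustrate why a helper like `xp_create` is useful, the following is a hypothetical re-implementation (`xp_create_sketch`, `array_device_sketch`), not the actual ezmsg.sigproc code. It forwards `dtype` and `device` only when they are set, so the same call also works with creation functions that do not accept a `device` keyword (NumPy releases before 2.0, for example). That motivation is an assumption about the design.

```python
import numpy as np


def xp_create_sketch(fn, *args, dtype=None, device=None):
    """Hypothetical stand-in for xp_create: call an array-creation function,
    passing dtype/device keywords only when they were actually given."""
    kwargs = {}
    if dtype is not None:
        kwargs["dtype"] = dtype
    if device is not None:
        kwargs["device"] = device
    return fn(*args, **kwargs)


def array_device_sketch(x):
    """Hypothetical stand-in for array_device: the array's device, or None
    if the object does not expose one."""
    return getattr(x, "device", None)


I = xp_create_sketch(np.eye, 3)  # no device keyword is passed to np.eye
z = xp_create_sketch(np.zeros, (2, 4), dtype=np.float32)
```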
## Module Compatibility

The table below summarises the Array API status of each module.

### Fully compatible

These modules perform all computation in the source array namespace.

| Module | Notes |
|---|---|
| | LRR / self-supervised regression. Full Array API. |
| | Incremental CCA. Replaced |
| | PyTorch-native; operates on |
### Mostly compatible (with NumPy boundaries)

These modules use the Array API for data manipulation but fall back to NumPy at specific points where a dependency requires it.

| Module | NumPy boundary | Reason |
|---|---|---|
| | | |
| | | Per-sample velocity remapping uses |
| | Inherits boundaries from model | State init and output arrays use the source namespace. |
| | | sklearn |
| | | sklearn and river models require NumPy / pandas input. |
| | | sklearn |
### Not converted

These modules use NumPy directly. Conversion would provide little benefit because the underlying estimator is the bottleneck.

| Module | Reason |
|---|---|
| | Thin wrapper around sklearn |
| | sklearn |
| | Generic wrapper for arbitrary models; cannot assume Array API support. |
| | Delegates to |
## sklearn Array API Dispatch

scikit-learn 1.8+ has experimental support for Array API dispatch on a subset of estimators. Two estimators used in ezmsg-learn are on the supported list:

| Estimator | Used in | Constraint |
|---|---|---|
| | | Requires |
| | | Requires |
To use dispatch, enable it before the estimator is fitted (the setting is read when `fit`, `predict`, and similar methods run):

```python
from sklearn import set_config

set_config(array_api_dispatch=True)
```
Warning:

- `array_api_dispatch` is marked experimental in sklearn.
- Solver constraints (`solver="svd"`) may produce slightly different numerical results compared to other solvers.
- Enabling dispatch globally may affect other sklearn estimators in the same process.

ezmsg-learn does not enable dispatch by default.

Estimators that do not support Array API dispatch:

- `IncrementalPCA`, `MiniBatchNMF`: only batch `PCA` is supported
- `SGDClassifier`, `SGDRegressor`, `PassiveAggressiveRegressor`
- All river models
## Writing Array API Compatible Code

When adding or modifying processors in ezmsg-learn, follow these patterns.

### Deriving the namespace

Always derive `xp` from the input data rather than hardcoding `numpy`:

```python
from array_api_compat import get_namespace
from ezmsg.sigproc.util.array import array_device, xp_create

def _process(self, message):
    xp = get_namespace(message.data)   # array namespace of the incoming data
    dev = array_device(message.data)   # its device, or None
    # ... use xp and dev for every subsequent array operation
```
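### Transposing matrices

The Array API standard defines `.T` only for two-dimensional arrays, and NumPy's `.T` reverses all axes of an N-D array, so `.T` is not a portable way to transpose. Use `xp.linalg.matrix_transpose()`, which swaps the last two axes:

```python
# Before (NumPy only)
result = A.T @ B

# After (Array API)
_mT = xp.linalg.matrix_transpose
result = _mT(A) @ B
```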
### Creating arrays

Use `xp_create` to handle device placement portably:

```python
# Before
I = np.eye(n)
z = np.zeros((m, n), dtype=np.float64)

# After
I = xp_create(xp.eye, n, device=dev)
z = xp_create(xp.zeros, (m, n), dtype=xp.float64, device=dev)
```
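### Handling sklearn boundaries

When calling into sklearn (or other NumPy-only libraries), convert to NumPy at the boundary and convert the result back. GPU arrays (CuPy arrays, CUDA tensors) cannot be passed to `np.asarray` directly; they must be copied to host memory first:

```python
import numpy as np
from array_api_compat import is_numpy_array, to_device

# Convert to NumPy for sklearn (copying device arrays to the host)
X_np = X if is_numpy_array(X) else np.asarray(to_device(X, "cpu"))
result_np = estimator.predict(X_np)

# Convert back to the source namespace, on the source device
result = result_np if is_numpy_array(X) else xp.asarray(result_np, device=dev)
```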
### Checking for NaN

Use `xp.isnan` instead of `np.isnan`:

```python
if xp.any(xp.isnan(message.data)):
    return
```
### Norms

Use `xp.linalg.matrix_norm` (Frobenius by default) instead of `np.linalg.norm` for matrices. For vectors, use `xp.linalg.vector_norm`.