# Source code for ezmsg.sigproc.util.array
"""Portable helpers for Array API interoperability.
These utilities smooth over differences between Array API libraries
(NumPy, PyTorch, MLX, CuPy, etc.) — in particular around ``device``
placement and ``dtype`` introspection, which are not uniformly supported.
"""
import numpy as np
[docs]
def array_device(x):
    """Look up the device that holds array *x*.

    The lookup is delegated to ``array_api_compat.device``, which is imported
    lazily at call time. If the import or the lookup raises ``AttributeError``
    or ``TypeError``, ``None`` is returned instead of propagating the error.
    """
    try:
        from array_api_compat import device as lookup_device

        located = lookup_device(x)
    except (AttributeError, TypeError):
        # The array (or its library) exposes no usable device information.
        located = None
    return located
[docs]
def xp_asarray(xp, obj, *, dtype=None, device=None):
    """Call ``xp.asarray(obj, ...)`` while leaving out unset keyword arguments.

    Some Array API libraries (e.g. MLX) don't accept a ``device`` keyword, so
    ``dtype`` and ``device`` are forwarded only when the caller supplied a
    value other than ``None``.
    """
    requested = (("dtype", dtype), ("device", device))
    # Keep only the options that were actually set; order is dtype, device.
    forwarded = {key: value for key, value in requested if value is not None}
    return xp.asarray(obj, **forwarded)
[docs]
def xp_create(fn, *args, dtype=None, device=None, **extra):
    """Invoke an array creation function (``zeros``, ``ones``, ``eye``) portably.

    Positional arguments and any extra keyword arguments are passed straight
    through to *fn*. ``dtype`` and ``device`` are added only when they are not
    ``None``, so libraries without ``device`` support never see that keyword.
    """
    # Copy so the caller-provided keyword mapping is never mutated.
    call_kwargs = {**extra}
    for key, value in (("dtype", dtype), ("device", device)):
        if value is not None:
            call_kwargs[key] = value
    return fn(*args, **call_kwargs)
[docs]
def is_complex_dtype(dtype) -> bool:
    """Report whether *dtype* denotes a complex type, across backends.

    When *dtype* exposes a ``kind`` attribute, the answer is whether that
    attribute equals ``"c"``. Otherwise the decision falls back to looking
    for ``"complex"`` in the lower-cased string form of *dtype*.
    """
    try:
        kind = dtype.kind
    except AttributeError:
        # No ``kind`` attribute: inspect the textual name instead.
        return "complex" in str(dtype).lower()
    return kind == "c"
[docs]
def is_float_dtype(xp, dtype) -> bool:
    """Check whether *dtype* is a real floating-point type, portably.

    Strategies are tried in order; the next one is used whenever the current
    one is unavailable (``AttributeError``) or rejects *dtype*
    (``TypeError``):

    1. ``xp.isdtype(dtype, "real floating")``.
    2. ``xp.issubdtype(dtype, xp.floating)``.
    3. ``np.issubdtype`` after coercing *dtype* with ``np.dtype``.
    4. A name-based check on ``str(dtype)``, in the same spirit as the
       string fallback in ``is_complex_dtype``.

    Args:
        xp: Array namespace (module-like object) to query.
        dtype: The dtype to classify.

    Returns:
        ``True`` if *dtype* is judged to be a real floating-point type,
        ``False`` otherwise.
    """
    try:
        return xp.isdtype(dtype, "real floating")
    except (AttributeError, TypeError):
        # AttributeError: the namespace has no ``isdtype`` (e.g. MLX).
        # TypeError: ``isdtype`` exists but refuses this dtype spelling
        # (e.g. NumPy >= 2 given a dtype string); try the fallbacks rather
        # than letting the result depend on the library version.
        pass

    try:
        return xp.issubdtype(dtype, xp.floating)
    except (AttributeError, TypeError):
        pass

    try:
        return np.issubdtype(np.dtype(dtype), np.floating)
    except TypeError:
        # NumPy cannot interpret this dtype object at all; decide from its
        # name. Complex names are excluded so that only real floating-point
        # types match.
        name = str(dtype).lower()
        return "float" in name and "complex" not in name