"""Portable helpers for Array API interoperability.
These utilities smooth over differences between Array API libraries
(NumPy, PyTorch, MLX, CuPy, etc.) — in particular around ``device``
placement and ``dtype`` introspection, which are not uniformly supported.
Design rule for ``xp_*`` helpers: **prefer the native op only when it
differs semantically from the fallback.** For pure stride/metadata
tricks (reshape, transpose, slicing) every backend's implementation is
equivalent in cost, so the simplest path wins. For ops that do real
work (empty vs. zeros, compiled kernels) we route to the native op when
available. When backends disagree on API — e.g. ``torch.flip(dims=...)``
vs ``numpy.flip(axis=...)``, or torch's refusal of negative-step slicing
— we absorb that here rather than leaking it to callers.
"""
import numpy as np
from array_api_compat import get_namespace
[docs]
def array_device(x):
    """Look up the device that ``x`` lives on.

    Delegates to ``array_api_compat.device``. If that lookup raises
    ``AttributeError`` or ``TypeError`` (the array type has no usable
    notion of a device), ``None`` is returned instead of propagating.
    """
    try:
        # Imported lazily, inside the guarded region, so the helper is
        # resolved at call time.
        from array_api_compat import device as _device_of

        found = _device_of(x)
    except (AttributeError, TypeError):
        found = None
    return found
[docs]
def xp_asarray(xp, obj, *, dtype=None, device=None):
    """Convert ``obj`` with ``xp.asarray``, forwarding only the options set.

    Some Array API libraries (e.g. MLX) reject a ``device`` keyword, so
    ``dtype`` and ``device`` are passed through only when the caller gave
    a value other than ``None``.

    Parameters
    ----------
    xp : array namespace exposing ``asarray``.
    obj : object to convert.
    dtype, device : optional keywords; each is dropped when ``None``.

    Returns
    -------
    Whatever ``xp.asarray`` returns.
    """
    optional = {"dtype": dtype, "device": device}
    # Keep insertion order (dtype, then device) and drop unset entries.
    forwarded = {key: value for key, value in optional.items() if value is not None}
    return xp.asarray(obj, **forwarded)
[docs]
def xp_create(fn, *args, dtype=None, device=None, **extra):
    """Invoke an array-creation callable (``zeros``, ``ones``, ``eye``, ...).

    Positional arguments and any ``extra`` keywords are passed through
    unchanged. ``dtype`` and ``device`` are added to the call only when
    they are not ``None``, so libraries that lack a ``device`` keyword are
    never handed one implicitly.

    Returns
    -------
    Whatever ``fn`` returns.
    """
    call_kwargs = {**extra}
    for key, value in (("dtype", dtype), ("device", device)):
        if value is not None:
            call_kwargs[key] = value
    return fn(*args, **call_kwargs)
[docs]
def xp_empty(xp, shape, *, dtype=None, device=None):
    """Portable ``xp.empty`` with a ``zeros`` fallback.

    Backends such as MLX don't expose ``empty``; for those we call
    ``xp.zeros`` instead. MLX is lazy so the extra zero init is near-free;
    on eager backends ``empty`` is preferred when available.

    Parameters
    ----------
    xp : array namespace exposing ``empty`` and/or ``zeros``.
    shape : shape passed as the first positional argument.
    dtype : optional dtype; omitted from the call when ``None``.
    device : optional device; omitted from the call when ``None`` so that
        libraries without a ``device`` keyword keep working. Mirrors the
        handling in ``xp_asarray`` / ``xp_create``.

    Returns
    -------
    The array produced by ``xp.empty`` (or ``xp.zeros`` as a fallback).
    """
    # ``or`` (rather than a getattr default) means ``xp.zeros`` is only
    # looked up when ``empty`` is missing or falsy.
    fn = getattr(xp, "empty", None) or xp.zeros
    kwargs = {}
    if dtype is not None:
        kwargs["dtype"] = dtype
    if device is not None:
        kwargs["device"] = device
    return fn(shape, **kwargs)
[docs]
def xp_flip(arr, axis):
    """Reverse ``arr`` along ``axis``, portable across backends.

    Dispatches: ``numpy.flip(axis=)`` / ``cupy.flip(axis=)`` / ``torch.flip(dims=)``
    when the namespace exposes ``flip``, else negative-step slicing (MLX).
    Torch is the reason we can't make slicing the universal path — it
    rejects negative steps with ``ValueError``.

    Note on cost: numpy/cupy return a strided view (O(1)); torch's flip
    materializes a copy (no view equivalent exists there); MLX's slicing
    returns a view.

    Parameters
    ----------
    arr : array whose namespace is resolved with ``get_namespace``.
    axis : int, sequence of ints, or ``None``. An int flips one axis; a
        tuple/list flips each listed axis; ``None`` flips every axis.
        The native ``flip(axis=...)`` call receives ``axis`` unchanged;
        the ``dims=`` and slicing fallbacks receive the normalised list so
        all three paths accept the same inputs.

    Returns
    -------
    The reversed array, as produced by the backend.
    """
    xp = get_namespace(arr)

    # Normalise ``axis`` for the fallback paths only. Previously these
    # paths assumed a single int: ``dims=[axis]`` nested a tuple inside a
    # list, and ``idx[axis]`` failed for ``None`` or a tuple.
    if axis is None:
        axes = list(range(arr.ndim))
    elif isinstance(axis, (tuple, list)):
        axes = list(axis)
    else:
        axes = [axis]

    flip = getattr(xp, "flip", None)
    if flip is not None:
        try:
            return flip(arr, axis=axis)
        except TypeError:
            # torch.flip takes ``dims=[...]``, not ``axis=``.
            return flip(arr, dims=axes)

    # MLX: no module-level flip; negative-step slicing works (view).
    idx = [slice(None)] * arr.ndim
    for ax in axes:
        idx[ax] = slice(None, None, -1)
    return arr[tuple(idx)]
[docs]
def xp_itemsize(dtype) -> int:
    """Return the number of bytes one element of ``dtype`` occupies.

    Lookup order:

    1. ``dtype.itemsize`` when it is a concrete ``int`` (numpy/cupy dtype
       *instances*, torch dtypes).
    2. ``dtype.size`` when it is a concrete ``int`` (MLX dtypes).
    3. ``np.dtype(dtype).itemsize`` — covers NumPy scalar *types* such as
       the class ``np.float32``, whose ``.itemsize`` is an attribute
       descriptor rather than an int.

    Raises
    ------
    TypeError
        If none of the lookups yields a size.
    """
    for attr in ("itemsize", "size"):
        nbytes = getattr(dtype, attr, None)
        if isinstance(nbytes, int):
            return nbytes

    try:
        resolved = np.dtype(dtype)
        return int(resolved.itemsize)
    except TypeError:
        # NumPy could not interpret it either; report with our own message
        # below (raised outside this handler, so no exception context).
        pass
    raise TypeError(f"Cannot determine byte size of dtype {dtype!r}")
[docs]
def is_complex_dtype(dtype) -> bool:
    """Report whether *dtype* denotes a complex type.

    Objects exposing a ``kind`` attribute are complex exactly when
    ``kind == "c"``. Anything else is judged by whether its ``str()``
    contains ``"complex"`` (case-insensitive).
    """
    if not hasattr(dtype, "kind"):
        # No kind code available: fall back to the textual name.
        return "complex" in str(dtype).lower()
    return dtype.kind == "c"
[docs]
def is_float_dtype(xp, dtype) -> bool:
    """Check whether *dtype* is a real floating-point type, portably.

    Strategies are tried in order, each one only if the previous could not
    answer:

    1. ``xp.isdtype(dtype, "real floating")``.
    2. ``dtype.is_floating_point`` when it is a ``bool`` (torch dtypes
       advertise this; it excludes complex).
    3. ``xp.issubdtype(dtype, xp.floating)`` for libraries without
       ``isdtype`` (e.g. MLX).
    4. ``np.issubdtype(np.dtype(dtype), np.floating)`` as a last resort.

    Parameters
    ----------
    xp : array namespace to query.
    dtype : dtype object, scalar type, or anything ``np.dtype`` accepts.

    Returns
    -------
    bool
        ``True`` for real floating-point dtypes, ``False`` otherwise.

    Raises
    ------
    TypeError
        If every strategy fails and NumPy cannot interpret ``dtype``.
    """
    try:
        return bool(xp.isdtype(dtype, "real floating"))
    except (AttributeError, TypeError):
        # AttributeError: the namespace has no ``isdtype``.
        # TypeError: ``isdtype`` exists but rejects this spelling of the
        # dtype (e.g. ``np.isdtype`` refuses the string "float32"). Without
        # this, string dtypes worked only where ``isdtype`` was missing.
        pass

    # torch dtypes advertise ``is_floating_point`` (excludes complex).
    is_fp = getattr(dtype, "is_floating_point", None)
    if isinstance(is_fp, bool):
        return is_fp

    # Fallback for libraries without isdtype (e.g. MLX).
    try:
        return bool(xp.issubdtype(dtype, xp.floating))
    except (AttributeError, TypeError):
        return bool(np.issubdtype(np.dtype(dtype), np.floating))