"""
Coordinate space transformations for streaming data.

This module provides utilities and ezmsg nodes for transforming between
Cartesian (x, y) and polar (r, theta) coordinate systems.

.. note::
    This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
    enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
"""
from enum import Enum
from typing import Tuple
import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
from array_api_compat import get_namespace, is_array_api_obj
from ezmsg.baseproc import (
BaseTransformer,
BaseTransformerUnit,
)
from ezmsg.util.messages.axisarray import AxisArray, replace
# -- Utility functions for coordinate transformations --
def _get_namespace_or_numpy(*args: npt.ArrayLike):
    """
    Pick the array namespace to use for a set of inputs.

    Args:
        *args: Candidate inputs, inspected in order.

    Returns:
        The namespace reported by ``get_namespace`` for the first argument
        that ``is_array_api_obj`` recognises, or the ``numpy`` module when
        no argument is recognised (e.g. only Python scalars were passed).
    """
    # ``None`` is a safe sentinel: the filter only yields recognised arrays.
    first_array = next(filter(is_array_api_obj, args), None)
    if first_array is None:
        return np
    return get_namespace(first_array)
[docs]
def polar2z(r: npt.ArrayLike, theta: npt.ArrayLike) -> npt.ArrayLike:
    """
    Build the complex value ``r * exp(1j * theta)`` from polar coordinates.

    Args:
        r: Radius (magnitude) component.
        theta: Angle component, used as the argument of the complex exponential.

    Returns:
        The complex representation of the polar point(s).
    """
    # The namespace comes from the first array-like input, else NumPy.
    xp = _get_namespace_or_numpy(r, theta)
    unit_phasor = xp.exp(1j * theta)
    return r * unit_phasor
[docs]
def z2polar(z: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """
    Convert complex number(s) to polar coordinates.

    Args:
        z: Complex value(s) to decompose.

    Returns:
        A ``(r, theta)`` tuple where ``r`` is ``abs(z)`` and ``theta`` is
        ``atan2(imag(z), real(z))`` in radians.
    """
    xp = _get_namespace_or_numpy(z)
    # For non-array input the helper falls back to the raw ``numpy`` module,
    # which only gained the Array API spelling ``atan2`` in NumPy 2.0. Fall
    # back to the classic ``arctan2`` name so plain scalars keep working on
    # older NumPy releases; namespaces that provide ``atan2`` are unaffected.
    atan2 = getattr(xp, "atan2", None)
    if atan2 is None:
        atan2 = xp.arctan2
    return xp.abs(z), atan2(xp.imag(z), xp.real(z))
[docs]
def cart2z(x: npt.ArrayLike, y: npt.ArrayLike) -> npt.ArrayLike:
    """
    Combine Cartesian components into the complex value ``x + 1j * y``.

    Args:
        x: Real-axis component.
        y: Imaginary-axis component.

    Returns:
        The complex representation of the Cartesian point(s).
    """
    imaginary_part = 1j * y
    return x + imaginary_part
[docs]
def z2cart(z: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """
    Split complex number(s) into Cartesian components.

    Args:
        z: Complex value(s) to decompose.

    Returns:
        A ``(x, y)`` tuple holding the real and imaginary parts of ``z``.
    """
    xp = _get_namespace_or_numpy(z)
    x_part = xp.real(z)
    y_part = xp.imag(z)
    return x_part, y_part
[docs]
def cart2pol(x: npt.ArrayLike, y: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """
    Convert Cartesian coordinates to polar coordinates.

    The conversion goes through the complex plane: ``(x, y)`` is packed into
    a complex value by :func:`cart2z` and then decomposed by :func:`z2polar`.

    Args:
        x: Cartesian x component.
        y: Cartesian y component.

    Returns:
        The ``(r, theta)`` tuple produced by :func:`z2polar`.
    """
    as_complex = cart2z(x, y)
    return z2polar(as_complex)
[docs]
def pol2cart(r: npt.ArrayLike, theta: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """
    Convert polar coordinates to Cartesian coordinates.

    The conversion goes through the complex plane: ``(r, theta)`` is packed
    into a complex value by :func:`polar2z` and then split by :func:`z2cart`.

    Args:
        r: Radius (magnitude) component.
        theta: Angle component.

    Returns:
        The ``(x, y)`` tuple produced by :func:`z2cart`.
    """
    as_complex = polar2z(r, theta)
    return z2cart(as_complex)
# -- ezmsg transformer classes --
[docs]
class CoordinateMode(str, Enum):
    """
    Direction of a coordinate conversion.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    CART2POL = "cart2pol"
    """Cartesian (x, y) in, polar (r, theta) out."""

    POL2CART = "pol2cart"
    """Polar (r, theta) in, Cartesian (x, y) out."""
[docs]
class CoordinateSpacesSettings(ez.Settings):
    """
    Configuration for :obj:`CoordinateSpaces`.

    See :obj:`coordinate_spaces` for argument details.
    """

    mode: CoordinateMode = CoordinateMode.CART2POL
    """Which conversion to apply: 'cart2pol' (the default) or 'pol2cart'."""

    axis: str | None = None
    """
    Name of the axis that holds the two coordinate components.

    When left as ``None`` the last axis is used. The axis must have exactly
    2 elements: (x, y) or (r, theta).
    """
[docs]
class CoordinateSpaces(
    BaseTransformerUnit[
        CoordinateSpacesSettings,
        AxisArray,
        AxisArray,
        CoordinateSpacesTransformer,
    ]
):
    """
    ezmsg unit that converts between Cartesian and polar coordinates.

    The unit takes and produces :obj:`AxisArray` messages and is configured
    through :obj:`CoordinateSpacesSettings`.
    """

    SETTINGS = CoordinateSpacesSettings