Source code for ezmsg.sigproc.linear
"""
Apply a linear transformation: output = scale * input + offset.
Supports per-element scale and offset along a specified axis.
For full matrix transformations, use :obj:`AffineTransformTransformer` instead.
.. note::
This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
"""
import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
from array_api_compat import get_namespace
from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace
[docs]
class LinearTransformSettings(ez.Settings):
scale: float | list[float] | npt.ArrayLike = 1.0
"""Scale factor(s). Can be a scalar (applied to all elements) or an array
matching the size of the specified axis for per-element scaling."""
offset: float | list[float] | npt.ArrayLike = 0.0
"""Offset value(s). Can be a scalar (applied to all elements) or an array
matching the size of the specified axis for per-element offset."""
axis: str | None = None
"""Axis along which to apply per-element scale/offset. If None, scalar
scale/offset are broadcast to all elements."""
[docs]
@processor_state
class LinearTransformState:
scale: npt.NDArray = None
"""Prepared scale array for broadcasting."""
offset: npt.NDArray = None
"""Prepared offset array for broadcasting."""
[docs]
class LinearTransformTransformer(
BaseStatefulTransformer[LinearTransformSettings, AxisArray, AxisArray, LinearTransformState]
):
"""Apply linear transformation: output = scale * input + offset.
This transformer is optimized for element-wise linear operations with
optional per-channel (or per-axis) coefficients. For full matrix
transformations, use :obj:`AffineTransformTransformer` instead.
Examples:
# Uniform scaling and offset
>>> transformer = LinearTransformTransformer(LinearTransformSettings(scale=2.0, offset=1.0))
# Per-channel scaling (e.g., for 3-channel data along "ch" axis)
>>> transformer = LinearTransformTransformer(LinearTransformSettings(
... scale=[0.5, 1.0, 2.0],
... offset=[0.0, 0.1, 0.2],
... axis="ch"
... ))
"""
def _hash_message(self, message: AxisArray) -> int:
"""Hash based on shape and axis to detect when broadcast shapes need recalculation."""
axis = self.settings.axis
if axis is not None:
axis_idx = message.get_axis_idx(axis)
return hash((message.data.ndim, axis_idx, message.data.shape[axis_idx]))
return hash(message.data.ndim)
def _reset_state(self, message: AxisArray) -> None:
"""Prepare scale/offset arrays with proper broadcast shapes."""
xp = get_namespace(message.data)
ndim = message.data.ndim
scale = self.settings.scale
offset = self.settings.offset
# Convert settings to arrays
if isinstance(scale, (list, np.ndarray)):
scale = xp.asarray(scale, dtype=xp.float64)
else:
# Scalar: create a 0-d array
scale = xp.asarray(float(scale), dtype=xp.float64)
if isinstance(offset, (list, np.ndarray)):
offset = xp.asarray(offset, dtype=xp.float64)
else:
# Scalar: create a 0-d array
offset = xp.asarray(float(offset), dtype=xp.float64)
# If axis is specified and we have 1-d arrays, reshape for proper broadcasting
if self.settings.axis is not None and ndim > 0:
axis_idx = message.get_axis_idx(self.settings.axis)
if scale.ndim == 1:
# Create shape for broadcasting: all 1s except at axis_idx
broadcast_shape = [1] * ndim
broadcast_shape[axis_idx] = scale.shape[0]
scale = xp.reshape(scale, broadcast_shape)
if offset.ndim == 1:
broadcast_shape = [1] * ndim
broadcast_shape[axis_idx] = offset.shape[0]
offset = xp.reshape(offset, broadcast_shape)
self._state.scale = scale
self._state.offset = offset
def _process(self, message: AxisArray) -> AxisArray:
result = message.data * self._state.scale + self._state.offset
return replace(message, data=result)
[docs]
class LinearTransform(BaseTransformerUnit[LinearTransformSettings, AxisArray, AxisArray, LinearTransformTransformer]):
"""Unit wrapper for LinearTransformTransformer."""
SETTINGS = LinearTransformSettings