Source code for ezmsg.sigproc.math.difference
"""
Take the difference between two signals or between a signal and a constant value.

.. note::
    :obj:`ConstDifferenceTransformer` supports the :doc:`Array API standard </guides/explanations/array_api>`,
    enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
    :obj:`DifferenceProcessor` (two-input difference) currently requires NumPy arrays.
"""
import asyncio
import typing
from dataclasses import dataclass, field
import ezmsg.core as ez
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
from ezmsg.baseproc.util.asio import run_coroutine_sync
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace
[docs]
class ConstDifferenceSettings(ez.Settings):
    """Settings for taking the difference between the input data and a constant."""

    value: float = 0.0
    """The constant operand: either subtracted from the input data or the number
    the input data is subtracted from, depending on ``subtrahend``."""

    subtrahend: bool = True
    """Role of ``value``. If True (default), ``value`` is subtracted from the input data.
    If False, the input data is subtracted from ``value``."""
[docs]
class ConstDifference(
    BaseTransformerUnit[
        ConstDifferenceSettings,
        AxisArray,
        AxisArray,
        ConstDifferenceTransformer,
    ]
):
    """Unit built on :obj:`BaseTransformerUnit` for :obj:`ConstDifferenceTransformer`.

    Configured through :obj:`ConstDifferenceSettings`. The two ``AxisArray`` type
    arguments are presumably the input and output message types -- confirm against
    ``BaseTransformerUnit``.
    """

    SETTINGS = ConstDifferenceSettings
[docs]
def const_difference(value: float = 0.0, subtrahend: bool = True) -> ConstDifferenceTransformer:
    """
    Create a transformer that takes the difference between the input data and a constant.

    result = (in_data - value) if subtrahend else (value - in_data)

    https://en.wikipedia.org/wiki/Template:Arithmetic_operations

    Args:
        value: number to subtract or be subtracted from the input data
        subtrahend: If True (default) then value is subtracted from the input data.
            If False, the input data is subtracted from value.

    Returns: :obj:`ConstDifferenceTransformer`.
    """
    # Bundle the two arguments into a settings object, then hand it to the transformer.
    settings = ConstDifferenceSettings(value=value, subtrahend=subtrahend)
    return ConstDifferenceTransformer(settings)
# --- Two-input Difference ---
[docs]
@dataclass
class DifferenceState:
    """Container for the two input queues used by the two-input difference."""

    # Both fields use a default factory, so every instance gets its own pair of
    # freshly constructed queues rather than sharing one.
    queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
    queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
[docs]
class DifferenceProcessor:
    """Subtract one AxisArray stream from another (A - B).

    Messages for each input are held in their own queue inside a
    :obj:`DifferenceState`. Each call takes one message from queue A and one
    from queue B and subtracts their ``data`` element-wise. Both inputs are
    assumed to have compatible shapes and aligned time spans.
    """

    def __init__(self):
        self._state = DifferenceState()

    @property
    def state(self) -> DifferenceState:
        """The state object currently holding both input queues."""
        return self._state

    @state.setter
    def state(self, state: DifferenceState | bytes | None) -> None:
        # ``None`` leaves the current state untouched.
        if state is None:
            return
        # NOTE(review): a ``bytes`` value is stored as-is; nothing here decodes
        # it into a DifferenceState -- confirm this is intended.
        self._state = state

    def push_a(self, msg: AxisArray) -> None:
        """Enqueue ``msg`` on queue A (the minuend side) without waiting."""
        queue = self._state.queue_a
        queue.put_nowait(msg)

    def push_b(self, msg: AxisArray) -> None:
        """Enqueue ``msg`` on queue B (the subtrahend side) without waiting."""
        queue = self._state.queue_b
        queue.put_nowait(msg)

    async def __acall__(self) -> AxisArray:
        """Wait for the next message on each queue and return their difference (A - B)."""
        # A is awaited first; B is only awaited once A has been dequeued.
        minuend = await self._state.queue_a.get()
        subtrahend = await self._state.queue_b.get()
        difference = minuend.data - subtrahend.data
        # The result is derived from the A message via ``replace`` with ``data`` overridden.
        return replace(minuend, data=difference)

    def __call__(self) -> AxisArray:
        """Synchronous counterpart of ``__acall__``, driven through ``run_coroutine_sync``."""
        pending = self.__acall__()
        return run_coroutine_sync(pending)

    # Aliases for legacy interface
    async def __anext__(self) -> AxisArray:
        result = await self.__acall__()
        return result

    def __next__(self) -> AxisArray:
        result = self.__call__()
        return result
[docs]
class Difference(ez.Unit):
    """Unit that publishes the difference of its two inputs (A - B).

    Assumes compatible/similar axes/dimensions and aligned time spans.
    Messages are paired by arrival order (oldest from each queue).

    OUTPUT = INPUT_SIGNAL_A - INPUT_SIGNAL_B
    """

    INPUT_SIGNAL_A = ez.InputStream(AxisArray)
    INPUT_SIGNAL_B = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    async def initialize(self) -> None:
        """Create the :obj:`DifferenceProcessor` shared by the subscribers and the publisher."""
        self.processor = DifferenceProcessor()

    @ez.subscriber(INPUT_SIGNAL_A)
    async def on_a(self, msg: AxisArray) -> None:
        """Forward a message from INPUT_SIGNAL_A to the processor's A side."""
        self.processor.push_a(msg)

    @ez.subscriber(INPUT_SIGNAL_B)
    async def on_b(self, msg: AxisArray) -> None:
        """Forward a message from INPUT_SIGNAL_B to the processor's B side."""
        self.processor.push_b(msg)

    @ez.publisher(OUTPUT_SIGNAL)
    async def output(self) -> typing.AsyncGenerator:
        """Yield one (OUTPUT_SIGNAL, difference) pair per result from the processor, indefinitely."""
        while True:
            # Suspends here until the processor has a result to return.
            difference = await self.processor.__acall__()
            yield self.OUTPUT_SIGNAL, difference