Source code for ezmsg.sigproc.signalinjector
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace
import numpy as np
import numpy.typing as npt
from .base import (
BaseAsyncTransformer,
BaseTransformerUnit,
processor_state,
)
[docs]
class SignalInjectorSettings(ez.Settings):
time_dim: str = "time" # Input signal needs a time dimension with units in sec.
frequency: float | None = None # Hz
amplitude: float = 1.0
mixing_seed: int | None = None
[docs]
@processor_state
class SignalInjectorState:
cur_shape: tuple[int, ...] | None = None
cur_frequency: float | None = None
cur_amplitude: float | None = None
mixing: npt.NDArray | None = None
[docs]
class SignalInjectorTransformer(
BaseAsyncTransformer[
SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorState
]
):
def _hash_message(self, message: AxisArray) -> int:
time_ax_idx = message.get_axis_idx(self.settings.time_dim)
sample_shape = (
message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
)
return hash((message.key,) + sample_shape)
def _reset_state(self, message: AxisArray) -> None:
if self._state.cur_frequency is None:
self._state.cur_frequency = self.settings.frequency
if self._state.cur_amplitude is None:
self._state.cur_amplitude = self.settings.amplitude
time_ax_idx = message.get_axis_idx(self.settings.time_dim)
self._state.cur_shape = (
message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
)
rng = np.random.default_rng(self.settings.mixing_seed)
self._state.mixing = rng.random((1, message.shape2d(self.settings.time_dim)[1]))
self._state.mixing = (self._state.mixing * 2.0) - 1.0
async def _aprocess(self, message: AxisArray) -> AxisArray:
if self._state.cur_frequency is None:
return message
out_msg = replace(message, data=message.data.copy())
t = out_msg.ax(self.settings.time_dim).values[..., np.newaxis]
signal = np.sin(2 * np.pi * self._state.cur_frequency * t)
mixed_signal = signal * self._state.mixing * self._state.cur_amplitude
with out_msg.view2d(self.settings.time_dim) as view:
view[...] = view + mixed_signal.astype(view.dtype)
return out_msg
[docs]
class SignalInjector(
BaseTransformerUnit[
SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorTransformer
]
):
SETTINGS = SignalInjectorSettings
INPUT_FREQUENCY = ez.InputStream(float | None)
INPUT_AMPLITUDE = ez.InputStream(float)
[docs]
@ez.subscriber(INPUT_FREQUENCY)
async def on_frequency(self, msg: float | None) -> None:
self.processor.state.cur_frequency = msg
[docs]
@ez.subscriber(INPUT_AMPLITUDE)
async def on_amplitude(self, msg: float) -> None:
self.processor.state.cur_amplitude = msg