import asyncio
import traceback
from dataclasses import dataclass, field
import time
import typing
import numpy as np
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace
from .butterworthfilter import ButterworthFilterSettings, ButterworthFilterTransformer
from .base import (
BaseStatefulProducer,
BaseProducerUnit,
BaseTransformer,
BaseTransformerUnit,
CompositeProducer,
ProducerType,
SettingsType,
MessageInType,
MessageOutType,
processor_state,
)
from .util.asio import run_coroutine_sync
from .util.profile import profile_subpub
[docs]
@dataclass
class AddState:
    """Pending operands for :obj:`AddProcessor`: one FIFO queue per input."""

    # Messages received on input A that are waiting for a partner from input B.
    queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
    # Messages received on input B that are waiting for a partner from input A.
    queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
[docs]
class AddProcessor:
    """
    Pair messages from two input queues and emit their element-wise sum.

    Messages pushed via :meth:`push_a` and :meth:`push_b` are buffered in
    ``asyncio.Queue`` objects (FIFO). Each output combines the oldest pending
    message from each side.
    """

    def __init__(self):
        self._state = AddState()

    @property
    def state(self) -> AddState:
        """The processor's queue state."""
        return self._state

    @state.setter
    def state(self, state: AddState | bytes | None) -> None:
        # ``None`` leaves the current state untouched.
        # TODO: Support hydrating state from bytes (e.g. ``pickle.loads``).
        #  Until then, any non-None value is stored exactly as given.
        if state is None:
            return
        self._state = state

    def push_a(self, msg: AxisArray) -> None:
        """Enqueue ``msg`` on input A without blocking."""
        self._state.queue_a.put_nowait(msg)

    def push_b(self, msg: AxisArray) -> None:
        """Enqueue ``msg`` on input B without blocking."""
        self._state.queue_b.put_nowait(msg)

    async def __acall__(self) -> AxisArray:
        """
        Await one message from A, then one from B, and return their sum.

        The result is built with ``replace`` from the A message, with only
        ``data`` swapped for ``a.data + b.data``.
        """
        left = await self._state.queue_a.get()
        right = await self._state.queue_b.get()
        total = left.data + right.data
        return replace(left, data=total)

    def __call__(self) -> AxisArray:
        """Synchronous wrapper around :meth:`__acall__`."""
        return run_coroutine_sync(self.__acall__())

    # Aliases for legacy interface
    async def __anext__(self) -> AxisArray:
        return await self.__acall__()

    def __next__(self) -> AxisArray:
        return self.__call__()
[docs]
class Add(ez.Unit):
    """
    Sum two AxisArray streams element-wise.

    Assumes the two inputs have compatible/similar axes and dimensions; the
    output reuses the metadata of the message from input A.
    """

    INPUT_SIGNAL_A = ez.InputStream(AxisArray)
    INPUT_SIGNAL_B = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    async def initialize(self) -> None:
        # One processor buffers both inputs and pairs them up in arrival order.
        self.processor = AddProcessor()

    @ez.subscriber(INPUT_SIGNAL_A)
    async def on_a(self, msg: AxisArray) -> None:
        """Buffer a message arriving on input A."""
        self.processor.push_a(msg)

    @ez.subscriber(INPUT_SIGNAL_B)
    async def on_b(self, msg: AxisArray) -> None:
        """Buffer a message arriving on input B."""
        self.processor.push_b(msg)

    @ez.publisher(OUTPUT_SIGNAL)
    async def output(self) -> typing.AsyncGenerator:
        """Publish a sum whenever one message from each input is available."""
        while True:
            summed = await self.processor.__acall__()
            yield self.OUTPUT_SIGNAL, summed
[docs]
class ClockSettings(ez.Settings):
    """Configuration for :obj:`ClockProducer` and the :obj:`Clock` unit."""

    #: Tick rate in Hz. Non-numeric values (``None``, or a string such as
    #: ``'realtime'``) are accepted; :obj:`ClockProducer` only paces numeric
    #: rates and otherwise ticks as fast as it is called.
    dispatch_rate: float | str | None = None
[docs]
@processor_state
class ClockState:
    """Mutable timing state for :obj:`ClockProducer`."""

    #: Wall-clock reference time from which tick targets are measured.
    t_0: float = field(default_factory=time.time)
    #: Number of ticks dispatched so far.
    n_dispatch: int = 0
[docs]
class ClockProducer(BaseStatefulProducer[ClockSettings, ez.Flag, ClockState]):
    """
    Produces clock ticks (:obj:`ez.Flag`) at a specified rate.

    Can be used to drive periodic operations. When ``settings.dispatch_rate`` is a
    number, the n-th tick is held until ``t_0 + n / dispatch_rate``; targets are
    anchored to ``t_0`` so per-tick scheduling jitter does not accumulate. Any
    other value (``None`` or a string) ticks as fast as the producer is called.
    """

    def _reset_state(self) -> None:
        """Reset internal state: restart the reference time and the tick count."""
        self._state.t_0 = time.time()
        self._state.n_dispatch = 0

    def _seconds_until_next_tick(self) -> float:
        """
        Return how long to wait before the next tick may be dispatched.

        Shared by the sync (:meth:`__call__`) and async (:meth:`_produce`) paths.
        Returns 0.0 when no numeric dispatch rate is configured or when the target
        time has already passed.
        """
        rate = self.settings.dispatch_rate
        if not isinstance(rate, (int, float)):
            # Not a numeric rate: 'as fast as possible'.
            return 0.0
        target_time = self.state.t_0 + (self.state.n_dispatch + 1) / rate
        return max(0.0, target_time - time.time())

    def __call__(self) -> ez.Flag:
        """Synchronous clock production. We override __call__ (which uses run_coroutine_sync) to avoid async overhead."""
        if self._hash == -1:
            # Not yet initialized: start the timeline now.
            self._reset_state()
            self._hash = 0
        delay = self._seconds_until_next_tick()
        if delay > 0:
            time.sleep(delay)
        self.state.n_dispatch += 1
        return ez.Flag()

    async def _produce(self) -> ez.Flag:
        """Generate next clock tick, yielding to the event loop while waiting."""
        delay = self._seconds_until_next_tick()
        if delay > 0:
            await asyncio.sleep(delay)
        self.state.n_dispatch += 1
        return ez.Flag()
[docs]
def aclock(dispatch_rate: float | str | None) -> ClockProducer:
    """
    Construct a clock producer that emits :obj:`ez.Flag` ticks at a specified rate.

    Args:
        dispatch_rate: Tick rate in Hz. Non-numeric values (e.g. ``None``) make the
            producer tick as fast as it is called. The accepted types match
            :obj:`ClockSettings.dispatch_rate`.

    Returns:
        A :obj:`ClockProducer` object. It can be called synchronously, or awaited
        via ``await producer.__acall__()``.
    """
    return ClockProducer(ClockSettings(dispatch_rate=dispatch_rate))
# Name expected by synchronous call sites; the returned ClockProducer serves both
# sync and async use, so no separate implementation is needed.
clock = aclock
"""
Alias for :obj:`aclock` expected by synchronous methods. :obj:`ClockProducer` can be used in sync or async.
"""
[docs]
class Clock(
    BaseProducerUnit[
        ClockSettings,  # SettingsType
        ez.Flag,  # MessageType
        ClockProducer,  # ProducerType
    ]
):
    """Unit that publishes :obj:`ez.Flag` ticks generated by a :obj:`ClockProducer`."""

    SETTINGS = ClockSettings

    @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
    async def produce(self) -> typing.AsyncGenerator:
        # Overridden so that falsy producer outputs are skipped instead of published.
        while True:
            tick = await self.producer.__acall__()
            if not tick:
                continue
            yield self.OUTPUT_SIGNAL, tick
# COUNTER - Generate incrementing integer. fs and dispatch_rate parameters combine to give many options. #
[docs]
class CounterSettings(ez.Settings):
    """
    Settings for :obj:`Counter`.
    See :obj:`acounter` for a description of the parameters.
    """

    # TODO: Adapt this to use ezmsg.util.rate?

    #: Number of samples in each output block.
    n_time: int
    #: Sampling rate of the output signal, in Hz.
    fs: float
    #: Number of channels to synthesize.
    n_ch: int = 1
    #: Message dispatch rate (Hz), 'realtime', 'ext_clock', or None (as fast as
    #: possible). With a numeric rate the time offsets are synthetic, so the system
    #: can run faster or slower than wall-clock time.
    dispatch_rate: float | str | None = None
    #: If set, counter values wrap around modulo this integer.
    mod: int | None = None
[docs]
@processor_state
class CounterState:
    """Mutable bookkeeping for :obj:`CounterProducer`."""

    #: Value of the first sample in the next block.
    counter_start: int = 0
    #: Total number of samples emitted so far.
    n_sent: int = 0
    #: Wall-clock reference time from which pacing and realtime offsets are computed.
    clock_zero: float | None = None
    #: One of "realtime", "ext_clock", "manual", or "unspecified".
    timer_type: str = "unspecified"
    #: Event set whenever the counter state is reset.
    new_generator: asyncio.Event | None = None
[docs]
class CounterProducer(BaseStatefulProducer[CounterSettings, AxisArray, CounterState]):
    """
    Produces incrementing integer blocks as AxisArray.

    Each block has dims ``["time", "ch"]`` and shape ``(n_time, n_ch)``; every
    channel carries the same counter values. Whether the producer sleeps before
    returning a block, and how the block's time offset is computed, depend on the
    timer type derived from ``settings.dispatch_rate``:

    * ``"realtime"``: sleep until wall-clock time reaches the end of the block
      (``clock_zero + (n_sent + n_time) / fs``); offsets are wall-clock based.
    * numeric (``"manual"``): sleep so blocks are emitted at ``dispatch_rate``
      blocks per second; offsets are synthetic (``n_sent / fs``).
    * ``"ext_clock"``: no sleep; the offset is the wall-clock time of the call.
    * ``None`` (``"unspecified"``): no sleep; offsets are synthetic (``n_sent / fs``).
    """

    # TODO: Adapt this to use ezmsg.util.rate?

    @classmethod
    def get_message_type(cls, dir: str) -> typing.Optional[type[AxisArray]]:
        """
        Return the message type for a direction: no input (``None``), AxisArray output.

        Raises:
            ValueError: If ``dir`` is neither ``"in"`` nor ``"out"``.
        """
        if dir == "in":
            return None
        elif dir == "out":
            return AxisArray
        else:
            raise ValueError(f"Invalid direction: {dir}. Use 'in' or 'out'.")

    def __init__(self, *args, **kwargs):
        """
        Validate ``dispatch_rate`` and initialize state.

        Raises:
            ValueError: If ``dispatch_rate`` is a string other than ``"realtime"``
                or ``"ext_clock"``.
        """
        super().__init__(*args, **kwargs)
        if isinstance(
            self.settings.dispatch_rate, str
        ) and self.settings.dispatch_rate not in ["realtime", "ext_clock"]:
            raise ValueError(f"Unknown dispatch_rate: {self.settings.dispatch_rate}")
        self._reset_state()
        # NOTE(review): a _hash of -1 appears to mean "state needs reset"; 0 marks
        # the freshly reset state as current.
        self._hash = 0

    def _reset_state(self) -> None:
        """Reset internal state and signal the reset through ``new_generator``."""
        self._state.counter_start = 0
        self._state.n_sent = 0
        self._state.clock_zero = time.time()
        dispatch_rate = self.settings.dispatch_rate
        if dispatch_rate is None:
            # Reset explicitly so a timer type left over from earlier settings
            # cannot survive a change to dispatch_rate=None.
            self._state.timer_type = "unspecified"
        elif isinstance(dispatch_rate, str):
            self._state.timer_type = dispatch_rate.lower()
        else:
            self._state.timer_type = "manual"
        if self._state.new_generator is None:
            self._state.new_generator = asyncio.Event()
        # Set the event to indicate that the state has been reset.
        self._state.new_generator.set()

    async def _produce(self) -> AxisArray:
        """Generate the next counter block, sleeping first if the timer type requires it."""
        # 1. Prepare counter data: one 1-D run of values, replicated across channels.
        samples = np.arange(
            self.state.counter_start, self.state.counter_start + self.settings.n_time
        )
        if self.settings.mod is not None:
            samples %= self.settings.mod
        block_samp = np.tile(samples[:, np.newaxis], (1, self.settings.n_ch))

        # 2. Sleep if necessary. 3. Calculate time offset.
        if self._state.timer_type == "realtime":
            n_next = self.state.n_sent + self.settings.n_time
            t_next = self.state.clock_zero + n_next / self.settings.fs
            await asyncio.sleep(t_next - time.time())
            offset = t_next - self.settings.n_time / self.settings.fs
        elif self._state.timer_type == "manual":
            # Block k (k = n_sent / n_time) is due at clock_zero + (k + 1) / dispatch_rate.
            n_disp_next = 1 + self.state.n_sent / self.settings.n_time
            t_disp_next = (
                self.state.clock_zero + n_disp_next / self.settings.dispatch_rate
            )
            await asyncio.sleep(t_disp_next - time.time())
            offset = self.state.n_sent / self.settings.fs
        elif self._state.timer_type == "ext_clock":
            # ext_clock -- no sleep. Assume this is called at appropriate intervals.
            offset = time.time()
        else:
            # Was "unspecified"
            offset = self.state.n_sent / self.settings.fs

        # 4. Create output AxisArray
        # Note: We can make this a bit faster by preparing a template for self._state
        result = AxisArray(
            data=block_samp,
            dims=["time", "ch"],
            axes={
                "time": AxisArray.TimeAxis(fs=self.settings.fs, offset=offset),
                "ch": AxisArray.CoordinateAxis(
                    data=np.array([f"Ch{_}" for _ in range(self.settings.n_ch)]),
                    dims=["ch"],
                ),
            },
            key="acounter",
        )

        # 5. Update state. Read the last value from the 1-D samples (not the tiled
        #    block) so n_ch == 0 does not raise IndexError; an empty block leaves
        #    counter_start unchanged.
        if samples.size:
            self.state.counter_start = samples[-1] + 1
        self.state.n_sent += self.settings.n_time
        return result
[docs]
def acounter(
    n_time: int,
    fs: float | None,
    n_ch: int = 1,
    dispatch_rate: float | str | None = None,
    mod: int | None = None,
) -> CounterProducer:
    """
    Construct a producer of counter blocks (AxisArray) at a specified rate and
    with the specified sampling rate.

    Args:
        n_time: Number of samples per block.
        fs: Sampling rate of the output, in Hz.
        n_ch: Number of channels (each carries the same counter values).
        dispatch_rate: Dispatch rate in Hz, ``'realtime'``, ``'ext_clock'``, or
            ``None`` (as fast as possible).
        mod: If set, counter values wrap around modulo this integer.

    NOTE: In realtime mode this producer uses ``asyncio.sleep`` to delay
    appropriately. That way of sleeping/yielding execution priority behaves
    quirkily with sub-millisecond sleep periods, which may give unexpected rates
    (e.g. ``fs=2000, n_time=1, dispatch_rate='realtime'`` may yield ~1400 msgs/sec).

    Returns:
        A :obj:`CounterProducer`.
    """
    return CounterProducer(
        CounterSettings(
            n_time=n_time, fs=fs, n_ch=n_ch, dispatch_rate=dispatch_rate, mod=mod
        )
    )
[docs]
class Counter(
    BaseProducerUnit[
        CounterSettings,  # SettingsType
        AxisArray,  # MessageOutType
        CounterProducer,  # ProducerType
    ]
):
    """Generates monotonically increasing counter. Unit for :obj:`CounterProducer`."""

    SETTINGS = CounterSettings

    INPUT_CLOCK = ez.InputStream(ez.Flag)

    @ez.subscriber(INPUT_CLOCK)
    @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
    async def on_clock(self, _: ez.Flag):
        """Produce one block per clock tick, but only when ``dispatch_rate == "ext_clock"``."""
        if self.producer.settings.dispatch_rate == "ext_clock":
            out = await self.producer.__acall__()
            yield self.OUTPUT_SIGNAL, out

    @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
    async def produce(self) -> typing.AsyncGenerator:
        """
        Generate counter output when the counter is self-timed.

        This is an infinite loop. With an internal clock we typically enter the
        outer loop once. With an external clock we enter it twice and then block.

        Internal clock: wait for ``new_generator`` (set when the counter state is
        reset), clear it, then repeatedly await ``__acall__`` so the producer's own
        timer decides when each output appears. The inner loop runs until the event
        is set again by a subsequent reset.

        External clock: wait for the event, clear it, then ``continue`` back to
        waiting, potentially forever. In this case ``on_clock`` is expected to
        produce the output.
        """
        try:
            while True:
                # Once-only, enter the generator loop
                await self.producer.state.new_generator.wait()
                self.producer.state.new_generator.clear()
                if self.producer.settings.dispatch_rate == "ext_clock":
                    # We shouldn't even be here. Cycle around and wait on the event again.
                    continue
                # We are not using an external clock. Run the generator.
                while not self.producer.state.new_generator.is_set():
                    out = await self.producer.__acall__()
                    yield self.OUTPUT_SIGNAL, out
        except Exception:
            # The handler sits outside the loop, so production stops here; report it
            # at error level rather than burying it in INFO output.
            ez.logger.error(traceback.format_exc())
[docs]
class SinGeneratorSettings(ez.Settings):
    """
    Configuration for :obj:`SinGenerator`.
    The factory function :obj:`sin` takes the same parameters.
    """

    #: Name of the axis along which the sinusoid evolves. The axis must exist in
    #: ``msg.axes`` and be of type ``AxisArray.LinearAxis``.
    axis: str | None = "time"
    #: Frequency of the sinusoid, in Hz.
    freq: float = 1.0
    #: Amplitude of the sinusoid.
    amp: float = 1.0
    #: Initial phase of the sinusoid, in radians.
    phase: float = 0.0
[docs]
class SinGenerator(
    BaseTransformerUnit[SinGeneratorSettings, AxisArray, AxisArray, SinTransformer]
):
    """Transformer unit built on :obj:`SinTransformer` that emits sinusoidal waveforms."""

    SETTINGS = SinGeneratorSettings
[docs]
def sin(
    axis: str | None = "time",
    freq: float = 1.0,
    amp: float = 1.0,
    phase: float = 0.0,
) -> SinTransformer:
    """
    Construct a transformer that turns AxisArrays of sample counts into sinusoids.

    Args:
        axis: Name of the axis over which the sinusoid passes.
        freq: Frequency of the sinusoid, in Hz.
        amp: Amplitude of the sinusoid.
        phase: Initial phase of the sinusoid, in radians.

    Returns:
        A :obj:`SinTransformer` configured with these settings. It takes AxisArray
        messages of sample counts and returns AxisArray messages of sinusoids.
    """
    return SinTransformer(
        SinGeneratorSettings(axis=axis, freq=freq, amp=amp, phase=phase)
    )
[docs]
class RandomGeneratorSettings(ez.Settings):
    """Settings for :obj:`RandomGenerator`."""

    #: ``loc`` argument for :obj:`numpy.random.normal`.
    loc: float = 0.0
    #: ``scale`` argument for :obj:`numpy.random.normal`.
    scale: float = 1.0
[docs]
class RandomGenerator(
    BaseTransformerUnit[
        RandomGeneratorSettings,
        AxisArray,
        AxisArray,
        RandomTransformer,
    ]
):
    """Transformer unit built on :obj:`RandomTransformer`, configured by :obj:`RandomGeneratorSettings`."""

    SETTINGS = RandomGeneratorSettings
[docs]
class OscillatorSettings(ez.Settings):
    """Configuration for :obj:`Oscillator` and :obj:`OscillatorProducer`."""

    #: Number of samples in each output block.
    n_time: int
    #: Sampling rate of the output signal, in Hz.
    fs: float
    #: Number of channels in each output block.
    n_ch: int = 1
    #: Dispatch rate: (Hz) | 'realtime' | 'ext_clock'.
    dispatch_rate: float | str | None = None
    #: Oscillation frequency, in Hz.
    freq: float = 1.0
    #: Amplitude.
    amp: float = 1.0
    #: Phase offset, in radians.
    phase: float = 0.0
    #: If True, adjust ``freq`` so it syncs with the sampling rate.
    sync: bool = False
[docs]
class OscillatorProducer(CompositeProducer[OscillatorSettings, AxisArray]):
    """
    Composite producer built from a counter (``"counter"``) and a sine
    transformer (``"sin"``).

    When ``settings.sync`` is True, the counter wraps every ``mod`` samples and the
    frequency is replaced by ``fs / mod``. Each period therefore spans a whole
    number of samples.
    """

    @staticmethod
    def _initialize_processors(
        settings: OscillatorSettings,
    ) -> dict[str, CounterProducer | SinTransformer]:
        """
        Build the counter and sine processors from ``settings``.

        Raises:
            ValueError: If ``sync`` is requested with a non-positive ``freq``, or with
                a ``freq`` whose period rounds to fewer than one sample at ``fs``.
        """
        # Calculate synchronous settings if necessary
        freq = settings.freq
        mod = None
        if settings.sync:
            if settings.freq <= 0:
                raise ValueError(
                    f"sync=True requires a positive freq; got freq={settings.freq}"
                )
            period = 1.0 / settings.freq
            mod = round(period * settings.fs)
            if mod < 1:
                raise ValueError(
                    f"Cannot sync freq={settings.freq} Hz to fs={settings.fs} Hz: "
                    "the period rounds to fewer than one sample."
                )
            freq = 1.0 / (mod / settings.fs)
        return {
            "counter": CounterProducer(
                CounterSettings(
                    n_time=settings.n_time,
                    fs=settings.fs,
                    n_ch=settings.n_ch,
                    dispatch_rate=settings.dispatch_rate,
                    mod=mod,
                )
            ),
            "sin": SinTransformer(
                SinGeneratorSettings(freq=freq, amp=settings.amp, phase=settings.phase)
            ),
        }
[docs]
class BaseCounterFirstProducerUnit(
    BaseProducerUnit[SettingsType, MessageOutType, ProducerType],
    typing.Generic[SettingsType, MessageInType, MessageOutType, ProducerType],
):
    """
    Base class for units whose primary processor is a composite producer with a CounterProducer as the first
    processor (producer) in the chain.

    With ``dispatch_rate == "ext_clock"``, one output is produced per message on
    ``INPUT_SIGNAL``. Otherwise :meth:`produce` drives the producer on the
    counter's own timer.
    """

    INPUT_SIGNAL = ez.InputStream(MessageInType)

    def create_producer(self):
        """Create the producer, then locate its leading counter and store it as ``self._counter``."""
        super().create_producer()

        def recurse_get_counter(proc) -> CounterProducer:
            # Descend through nested composites (anything exposing ``_procs``),
            # always taking the first child, until a leaf processor is reached.
            if hasattr(proc, "_procs"):
                return recurse_get_counter(list(proc._procs.values())[0])
            return proc

        self._counter = recurse_get_counter(self.producer)

    @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, _: ez.Flag):
        """In ``"ext_clock"`` mode, produce one output per incoming message."""
        if self.producer.settings.dispatch_rate == "ext_clock":
            out = await self.producer.__acall__()
            yield self.OUTPUT_SIGNAL, out

    @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
    async def produce(self) -> typing.AsyncGenerator:
        """
        Self-timed production loop.

        Waits for the counter's ``new_generator`` event (set when its state is
        reset), clears it, and then produces continuously until the event is set
        again. In ``"ext_clock"`` mode it only waits, leaving output to
        :meth:`on_signal`.
        """
        try:
            while True:
                # Look the counter up on every access instead of caching its state:
                # create_producer may replace self._counter, and a replacement
                # CounterProducer sets its own event during __init__.
                await self._counter.state.new_generator.wait()
                self._counter.state.new_generator.clear()
                if self.producer.settings.dispatch_rate == "ext_clock":
                    # We shouldn't even be here. Cycle around and wait on the event again.
                    continue
                # We are not using an external clock. Run the generator.
                while not self._counter.state.new_generator.is_set():
                    out = await self.producer.__acall__()
                    yield self.OUTPUT_SIGNAL, out
        except Exception:
            # Production stops once an exception escapes the loop; log at error level.
            ez.logger.error(traceback.format_exc())
[docs]
class Oscillator(
    BaseCounterFirstProducerUnit[
        OscillatorSettings, AxisArray, AxisArray, OscillatorProducer
    ]
):
    """Unit emitting sinusoids from a counter feeding a sine transformer (:obj:`OscillatorProducer`)."""

    SETTINGS = OscillatorSettings
[docs]
class NoiseSettings(ez.Settings):
    """
    Settings shared by the white- and pink-noise producers and units.
    See :obj:`CounterSettings` and :obj:`RandomGeneratorSettings`.
    """

    #: Number of samples to output per block.
    n_time: int
    #: Sampling rate of the output signal, in Hz.
    fs: float
    #: Number of channels to output.
    n_ch: int = 1
    #: Dispatch rate: (Hz), 'realtime', or 'ext_clock'.
    dispatch_rate: float | str | None = None
    #: DC offset.
    loc: float = 0.0
    #: Scale, in standard deviations.
    scale: float = 1.0


# White noise uses the generic noise settings unchanged.
WhiteNoiseSettings = NoiseSettings
[docs]
class WhiteNoiseProducer(CompositeProducer[NoiseSettings, AxisArray]):
    """Composite producer built from a counter (``"counter"``) and a random transformer (``"random"``)."""

    @staticmethod
    def _initialize_processors(
        settings: NoiseSettings,
    ) -> dict[str, CounterProducer | RandomTransformer]:
        # The counter must stay the first entry: units locate it by position.
        counter = CounterProducer(
            CounterSettings(
                n_time=settings.n_time,
                fs=settings.fs,
                n_ch=settings.n_ch,
                dispatch_rate=settings.dispatch_rate,
                mod=None,
            )
        )
        noise = RandomTransformer(
            RandomGeneratorSettings(loc=settings.loc, scale=settings.scale)
        )
        return {"counter": counter, "random": noise}
[docs]
class WhiteNoise(
    BaseCounterFirstProducerUnit[
        NoiseSettings, AxisArray, AxisArray, WhiteNoiseProducer
    ]
):
    """Unit for :obj:`WhiteNoiseProducer`; chains a :obj:`Counter` and :obj:`RandomGenerator`."""

    SETTINGS = NoiseSettings
# Pink noise uses the generic noise settings unchanged.
PinkNoiseSettings = NoiseSettings


class PinkNoiseProducer(CompositeProducer[PinkNoiseSettings, AxisArray]):
    """
    Composite producer: white noise (``"white_noise"``) followed by a first-order
    Butterworth filter (``"filter"``) along ``"time"``, with its cutoff at 1% of ``fs``.
    """

    @staticmethod
    def _initialize_processors(
        settings: PinkNoiseSettings,
    ) -> dict[str, WhiteNoiseProducer | ButterworthFilterTransformer]:
        source = WhiteNoiseProducer(settings=settings)
        # NOTE(review): the filter type is left at ButterworthFilterSettings' default
        # (presumably low-pass); confirm in butterworthfilter.py.
        shaping_filter = ButterworthFilterTransformer(
            settings=ButterworthFilterSettings(
                axis="time",
                order=1,
                cutoff=settings.fs * 0.01,  # Hz
            )
        )
        return {"white_noise": source, "filter": shaping_filter}
[docs]
class PinkNoise(
    BaseCounterFirstProducerUnit[NoiseSettings, AxisArray, AxisArray, PinkNoiseProducer]
):
    """Unit for :obj:`PinkNoiseProducer`; chains :obj:`WhiteNoise` and :obj:`ButterworthFilter`."""

    SETTINGS = NoiseSettings
[docs]
class EEGSynthSettings(ez.Settings):
    """Settings for :obj:`EEGSynth`. See :obj:`OscillatorSettings`."""

    #: Sampling rate, in Hz.
    fs: float = 500.0
    #: Samples per output block.
    n_time: int = 100
    #: Oscillator frequency, in Hz.
    alpha_freq: float = 10.5
    #: Number of channels.
    n_ch: int = 8
[docs]
class EEGSynth(ez.Collection):
    """
    A :obj:`Collection` that chains a :obj:`Clock` to both :obj:`PinkNoise`
    and :obj:`Oscillator`, then :obj:`Add` s the result.

    Unlike the Oscillator, WhiteNoise, and PinkNoise composite processors which have linear
    flows, this class has a diamond flow, with clock branching to both PinkNoise and Oscillator,
    which then are combined in Add.

    Optional: Refactor as a ProducerUnit, similar to Clock, but we manually add all the other
    transformers.
    """

    SETTINGS = EEGSynthSettings

    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    CLOCK = Clock()
    NOISE = PinkNoise()
    OSC = Oscillator()
    ADD = Add()

    def configure(self) -> None:
        """
        Derive the child units' settings from :obj:`EEGSynthSettings`.

        The clock ticks once per block (``fs / n_time`` Hz). The oscillator and the
        pink-noise source both run in ``"ext_clock"`` mode, so each tick produces
        exactly one block from each and keeps the two inputs of ``ADD`` paired.
        """
        self.CLOCK.apply_settings(
            ClockSettings(dispatch_rate=self.SETTINGS.fs / self.SETTINGS.n_time)
        )
        self.OSC.apply_settings(
            OscillatorSettings(
                n_time=self.SETTINGS.n_time,
                fs=self.SETTINGS.fs,
                n_ch=self.SETTINGS.n_ch,
                dispatch_rate="ext_clock",
                freq=self.SETTINGS.alpha_freq,
            )
        )
        self.NOISE.apply_settings(
            PinkNoiseSettings(
                n_time=self.SETTINGS.n_time,
                fs=self.SETTINGS.fs,
                n_ch=self.SETTINGS.n_ch,
                dispatch_rate="ext_clock",
            )
        )

    def network(self) -> ez.NetworkDefinition:
        """Wire the clock to both sources, both sources into ADD, and ADD to the output."""
        return (
            (self.CLOCK.OUTPUT_SIGNAL, self.OSC.INPUT_SIGNAL),
            (self.CLOCK.OUTPUT_SIGNAL, self.NOISE.INPUT_SIGNAL),
            (self.OSC.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_A),
            (self.NOISE.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_B),
            (self.ADD.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL),
        )