Source code for ezmsg.learn.process.seqseqsampler

"""
TODO: This needs a lot of work and some testig.
This node will assume that all incoming data will be aligned on the same time basis.
INPUT_TRIGGER = ez.InputStream(SampleTriggerMessage)
The .value is a mostly useless string. Reserved for special commands like "stop"
We may need task-specific TriggerParser or similar node to interpret task events
and convert them to SampleTriggerMessage.
INPUT_VALUE = ez.InputStream(AxisArray) takes in the continuous signal to be used as the labels and buffers them.
INPUT_SIGNAL = ez.InputStream(AxisArray) takes in the continuous data and buffers them.
OUTPUT_SAMPLE = ez.OutputStream(AxisArray)

max_buffer_size: int = 512*1024*1024 to put a cap on memory usage for each of _value_buffer and _signal_buffer

"""

import asyncio
import copy

import ezmsg.core as ez
import numpy as np
from ezmsg.sigproc.sampler import SampleTriggerMessage
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
from ezmsg.util.messages.util import replace

MAX_ONE_TO_ONE_SAMPLE_MISMATCH = 1


def _time_vector(message: AxisArray) -> np.ndarray:
    time_axis = message.axes["time"]
    time_idx = message.get_axis_idx("time")
    n_times = message.data.shape[time_idx]
    if hasattr(time_axis, "data"):
        return np.asarray(time_axis.data)
    return np.asarray(time_axis.value(np.arange(n_times)))


def _sample_spacing(tvec: np.ndarray) -> float | None:
    if len(tvec) < 2:
        return None

    dt = np.abs(np.diff(tvec))
    dt = dt[dt > 0]
    if dt.size == 0:
        return None
    return float(np.median(dt))


def _boundary_slack(*tvecs: np.ndarray) -> float:
    spacings = [_sample_spacing(tvec) for tvec in tvecs]
    spacings = [spacing for spacing in spacings if spacing is not None]
    if not spacings:
        return 0.0
    return max(spacings)


def _select_best_aligned_window(
    longer_inds: np.ndarray,
    longer_tvec: np.ndarray,
    shorter_tvec: np.ndarray,
) -> tuple[np.ndarray, int, float]:
    """Crop a longer contiguous slice to the subwindow best aligned to shorter_tvec."""
    target_len = len(shorter_tvec)
    extra = len(longer_inds) - target_len

    if target_len <= 0:
        return longer_inds[:0], 0, 0.0

    best_offset = 0
    best_cost = float("inf")
    for offset in range(extra + 1):
        candidate_tvec = longer_tvec[offset : offset + target_len]
        cost = float(np.mean(np.abs(candidate_tvec - shorter_tvec)))
        if cost < best_cost:
            best_offset = offset
            best_cost = cost

    aligned_inds = longer_inds[best_offset : best_offset + target_len]
    return aligned_inds, best_offset, best_cost


def _align_keep_indices(
    signal_inds: np.ndarray,
    signal_tvec: np.ndarray,
    value_inds: np.ndarray,
    value_tvec: np.ndarray,
    trig: SampleTriggerMessage,
) -> tuple[np.ndarray, np.ndarray] | None:
    length_delta = abs(len(signal_inds) - len(value_inds))

    if length_delta == 0:
        return signal_inds, value_inds

    if len(signal_inds) == 0 or len(value_inds) == 0:
        return None

    if length_delta > MAX_ONE_TO_ONE_SAMPLE_MISMATCH:
        ez.logger.warning(
            "SeqSeqSampler could not align mismatched trigger window: "
            f"signal_len={len(signal_inds)} value_len={len(value_inds)} "
            f"trigger_timestamp={trig.timestamp} trigger_period={trig.period}"
        )
        return None

    if len(signal_inds) > len(value_inds):
        aligned_signal_inds, offset, cost = _select_best_aligned_window(signal_inds, signal_tvec, value_tvec)
        ez.logger.warning(
            "SeqSeqSampler aligned mismatched trigger window by trimming signal: "
            f"signal_len={len(signal_inds)} value_len={len(value_inds)} "
            f"offset={offset} mean_abs_dt={cost:.6f}s "
            f"trigger_timestamp={trig.timestamp} trigger_period={trig.period}"
        )
        return aligned_signal_inds, value_inds

    aligned_value_inds, offset, cost = _select_best_aligned_window(value_inds, value_tvec, signal_tvec)
    ez.logger.warning(
        "SeqSeqSampler aligned mismatched trigger window by trimming value: "
        f"signal_len={len(signal_inds)} value_len={len(value_inds)} "
        f"offset={offset} mean_abs_dt={cost:.6f}s "
        f"trigger_timestamp={trig.timestamp} trigger_period={trig.period}"
    )
    return signal_inds, aligned_value_inds


[docs] class SeqSeqSampler:
[docs] def __init__( self, max_buffer_dur: float = 5.0, ): self._max_buffer_dur = max_buffer_dur self._trig_queue: asyncio.Queue[SampleTriggerMessage] = asyncio.Queue() self._value_buffer: AxisArray | None = None self._signal_buffer: AxisArray | None = None
def __aiter__(self): self._trig_queue: asyncio.Queue[SampleTriggerMessage] = asyncio.Queue() self._value_buffer: AxisArray | None = None self._signal_buffer: AxisArray | None = None return self
[docs] async def asend(self, message: AxisArray): await self.enqueue_signal(message) return await self.__anext__()
[docs] async def enqueue_signal(self, message: AxisArray): self._update_buffer(message, "signal")
[docs] async def enqueue_value(self, message: AxisArray): self._update_buffer(message, "value")
[docs] async def enqueue_trigger(self, message: SampleTriggerMessage): if isinstance(message.value, str) and message.value == "end": # TODO: For each trigger currently in self._trig_queue, overwrite its # `.period[1]` with the incoming trigger's .timestamp print("TODO") else: await self._trig_queue.put(message)
async def __anext__(self): try: trig = self._trig_queue.get_nowait() except asyncio.QueueEmpty: return None samp_msg, keep_waiting = self._process_trigger(trig) if keep_waiting: # Trigger could not be processed fully because buffers did not satisfy the period. await self._trig_queue.put(trig) self._trig_queue.task_done() return samp_msg def _process_trigger(self, trig: SampleTriggerMessage) -> tuple[AxisArray | None, bool]: if trig.period is None: ez.logger.warning("SeqSeqSampler dropped trigger without a period.") return None, False trig_range = trig.timestamp + np.array(trig.period) if self._value_buffer is None or self._signal_buffer is None: return None, True val_tvec = _time_vector(self._value_buffer) sig_tvec = _time_vector(self._signal_buffer) if val_tvec.size == 0 or sig_tvec.size == 0: return None, True boundary_slack = _boundary_slack(val_tvec, sig_tvec) if trig_range[0] < (val_tvec[0] - boundary_slack) or trig_range[0] < (sig_tvec[0] - boundary_slack): ez.logger.warning( "SeqSeqSampler dropped trigger before buffers could satisfy it: " f"signal_span=({sig_tvec[0]}, {sig_tvec[-1]}) " f"value_span=({val_tvec[0]}, {val_tvec[-1]}) " f"trigger_timestamp={trig.timestamp} trigger_period={trig.period}" ) return None, False if trig_range[1] > val_tvec[-1] or trig_range[1] > sig_tvec[-1]: return None, True value_keep_inds = np.where(np.logical_and(val_tvec >= trig_range[0], val_tvec < trig_range[1]))[0] signal_keep_inds = np.where(np.logical_and(sig_tvec >= trig_range[0], sig_tvec < trig_range[1]))[0] if len(value_keep_inds) == 0 or len(signal_keep_inds) == 0: ez.logger.warning( "SeqSeqSampler could not slice trigger window: " f"signal_len={len(signal_keep_inds)} value_len={len(value_keep_inds)} " f"trigger_timestamp={trig.timestamp} trigger_period={trig.period}" ) return None, False aligned_keep_inds = _align_keep_indices( signal_keep_inds, sig_tvec[signal_keep_inds], value_keep_inds, val_tvec[value_keep_inds], trig, ) if aligned_keep_inds is None: return None, False signal_keep_inds, value_keep_inds = aligned_keep_inds messages: dict[str, AxisArray] = {} for buf_name, buffer, tvec, keep_inds in [ ("value", self._value_buffer, val_tvec, value_keep_inds), ("signal", self._signal_buffer, sig_tvec, signal_keep_inds), ]: if hasattr(buffer.axes["time"], "data"): new_time_ax = replace(buffer.axes["time"], data=tvec[keep_inds]) else: new_time_ax = replace(buffer.axes["time"], offset=tvec[keep_inds[0]]) new_dat = slice_along_axis( buffer.data, slice(keep_inds[0], keep_inds[-1] + 1), axis=buffer.get_axis_idx("time"), ) new_msg = replace( buffer, data=new_dat, axes={ **buffer.axes, "time": new_time_ax, }, ) messages[buf_name] = new_msg sample_trigger = copy.copy(trig) sample_trigger.value = messages["value"] samp_msg = replace( messages["signal"], attrs={**messages["signal"].attrs, "trigger": sample_trigger}, ) return samp_msg, False def _update_buffer(self, message: AxisArray, target: str): if target == "value": buffer = self._value_buffer elif target == "signal": buffer = self._signal_buffer else: raise ValueError(f"Invalid target: {target}") ax_ix = message.get_axis_idx("time") # TODO: Check if we need to reset the buffer because the input changed. if buffer is None: buffer = copy.deepcopy(message) if target == "value": self._value_buffer = buffer elif target == "signal": self._signal_buffer = buffer else: buffer.data = np.concatenate([buffer.data, message.data], axis=ax_ix) if hasattr(buffer.axes["time"], "data"): buffer.axes["time"].data = np.concatenate([buffer.axes["time"].data, message.axes["time"].data], axis=0) # No need for `else:` condition because offset does not change. # Trim down to self._max_buffer_dur if hasattr(buffer.axes["time"], "data"): tvec = buffer.axes["time"].data else: n_times = buffer.data.shape[buffer.get_axis_idx("time")] tvec = buffer.axes["time"].value(np.arange(n_times)) t_min = tvec[-1] - self._max_buffer_dur b_keep = tvec >= t_min if not np.all(b_keep): keep_inds = np.where(b_keep)[0] buffer.data = slice_along_axis(buffer.data, slice(keep_inds[0], keep_inds[-1] + 1), ax_ix) tvec = tvec[keep_inds] if hasattr(buffer.axes["time"], "data"): buffer.axes["time"].data = tvec else: buffer.axes["time"].offset = tvec[0]
[docs] class SeqSeqSamplerSettings(ez.Settings): max_buffer_dur: float = 5.0
[docs] class SeqSeqSamplerState(ez.State): core: SeqSeqSampler
[docs] class SeqSeqSamplerUnit(ez.Unit): SETTINGS = SeqSeqSamplerSettings STATE = SeqSeqSamplerState INPUT_TRIGGER = ez.InputStream(SampleTriggerMessage) INPUT_VALUE = ez.InputStream(AxisArray) INPUT_SIGNAL = ez.InputStream(AxisArray) OUTPUT_SAMPLE = ez.OutputStream(AxisArray)
[docs] async def initialize(self): self.STATE.core = SeqSeqSampler(max_buffer_dur=self.SETTINGS.max_buffer_dur)
[docs] @ez.subscriber(INPUT_TRIGGER) async def on_trigger(self, message: SampleTriggerMessage): await self.STATE.core.enqueue_trigger(message)
[docs] @ez.subscriber(INPUT_VALUE) async def on_value(self, message: AxisArray): await self.STATE.core.enqueue_value(message)
[docs] @ez.subscriber(INPUT_SIGNAL) async def on_signal(self, message: AxisArray): await self.STATE.core.enqueue_signal(message)
[docs] @ez.publisher(OUTPUT_SAMPLE) async def send_sample(self): while True: result: AxisArray = await anext(self.STATE.core) if result is not None: yield self.OUTPUT_SAMPLE, result else: # No sample could be produced. Try again later. await asyncio.sleep(0.005)