Source code for ezmsg.sigproc.util.axisarray_buffer

import math
import typing

from array_api_compat import get_namespace
import numpy as np
from ezmsg.util.messages.axisarray import AxisArray, LinearAxis, CoordinateAxis
from ezmsg.util.messages.util import replace

from .buffer import HybridBuffer


# Generic placeholder for a backend array type; used only in the
# `searchsorted` / `axis_searchsorted` type annotations in this module.
Array = typing.TypeVar("Array")


class HybridAxisBuffer:
    """
    A buffer that intelligently handles ezmsg.util.messages.AxisArray _axes_ objects.

    LinearAxis is maintained internally by tracking its offset, gain, and the
    number of samples that have passed through; no per-sample storage is used.
    CoordinateAxis has its data values maintained in a `HybridBuffer`.

    The axis kind is decided by the first axis passed to `write` (or
    `_initialize`): an object with a ``data`` attribute is treated as a
    CoordinateAxis, anything else as a LinearAxis.

    Args:
        duration: The desired duration of the buffer in seconds. For a
            LinearAxis this only bounds how many samples are reported as
            available (``ceil(duration / gain)``).
        **kwargs: Additional keyword arguments to pass to the underlying
            HybridBuffer (e.g., `update_strategy`, `threshold`,
            `overflow_strategy`, `max_size`).
    """

    # Storage for CoordinateAxis values; None until initialized with one.
    _coords_buffer: HybridBuffer | None
    # Zero-length copy of the first CoordinateAxis, used to rebuild outputs.
    _coords_template: CoordinateAxis | None
    # Mean step between coordinates in the first CoordinateAxis chunk.
    _coords_gain_estimate: float | None = None
    # Private copy of the LinearAxis; its offset tracks the oldest sample.
    _linear_axis: LinearAxis | None
    # Number of unread samples when tracking a LinearAxis.
    _linear_n_available: int

    def __init__(self, duration: float, **kwargs):
        self.duration = duration
        self.buffer_kwargs = kwargs
        # Delay initialization until the first message arrives
        self._coords_buffer = None
        self._coords_template = None
        self._linear_axis = None
        self._linear_n_available = 0

    @property
    def capacity(self) -> int:
        """The maximum number of samples that can be stored in the buffer.

        Returns 0 while the buffer is still uninitialized.
        """
        if self._coords_buffer is not None:
            return self._coords_buffer.capacity
        elif self._linear_axis is not None:
            return int(math.ceil(self.duration / self._linear_axis.gain))
        else:
            return 0

    def available(self) -> int:
        """Return the number of unread samples currently tracked."""
        if self._coords_buffer is None:
            return self._linear_n_available
        return self._coords_buffer.available()

    def is_empty(self) -> bool:
        """Return True when no unread samples are available."""
        return self.available() == 0

    def is_full(self) -> bool:
        """Return True when the buffer holds exactly `capacity` samples.

        An uninitialized buffer (capacity 0) is never reported as full.
        """
        if self._coords_buffer is not None:
            return self._coords_buffer.is_full()
        return 0 < self.capacity == self.available()

    def _initialize(self, first_axis: LinearAxis | CoordinateAxis) -> None:
        """Configure the buffer from the first axis object it sees."""
        if hasattr(first_axis, "data"):
            # Initialize a CoordinateAxis buffer
            if len(first_axis.data) > 1:
                # Mean step between the first and last coordinate.
                _axis_gain = (first_axis.data[-1] - first_axis.data[0]) / (
                    len(first_axis.data) - 1
                )
            else:
                # A single coordinate gives no step information.
                _axis_gain = 1.0
            # NOTE(review): a zero or negative estimate (repeated or descending
            # coordinates) makes the capacity computation below fail or go
            # negative -- confirm whether such input can occur.
            self._coords_gain_estimate = _axis_gain
            capacity = int(self.duration / _axis_gain)
            self._coords_buffer = HybridBuffer(
                get_namespace(first_axis.data),
                capacity,
                other_shape=(),
                dtype=first_axis.data.dtype,
                **self.buffer_kwargs,
            )
            self._coords_template = replace(first_axis, data=first_axis.data[:0].copy())
        else:
            # Initialize a LinearAxis buffer
            # Keep a private copy so later offset updates do not touch the caller's axis.
            self._linear_axis = replace(first_axis, offset=first_axis.offset)
            self._linear_n_available = 0

    def write(self, axis: LinearAxis | CoordinateAxis, n_samples: int) -> None:
        """Record the axis information that accompanies `n_samples` new samples.

        Args:
            axis: Axis object of the incoming chunk. Must be of the same class
                as the axis the buffer was initialized with.
            n_samples: Number of samples in the chunk. Only used for a
                LinearAxis; a CoordinateAxis carries its own length in `data`.

        Raises:
            TypeError: If `axis` is of a different class than the first axis.
            ValueError: If a LinearAxis arrives with a different gain.
        """
        if self._linear_axis is None and self._coords_buffer is None:
            self._initialize(axis)

        if self._coords_buffer is not None:
            if axis.__class__ is not self._coords_template.__class__:
                raise TypeError(
                    f"Buffer initialized with {self._coords_template.__class__.__name__}, "
                    f"but received {axis.__class__.__name__}."
                )
            self._coords_buffer.write(axis.data)
        else:
            if axis.__class__ is not self._linear_axis.__class__:
                raise TypeError(
                    f"Buffer initialized with {self._linear_axis.__class__.__name__}, "
                    f"but received {axis.__class__.__name__}."
                )
            if axis.gain != self._linear_axis.gain:
                raise ValueError(
                    f"Buffer initialized with gain={self._linear_axis.gain}, "
                    f"but received gain={axis.gain}."
                )
            if self._linear_n_available + n_samples > self.capacity:
                # Simulate overflow by advancing the offset and decreasing
                # the number of available samples.
                # NOTE(review): a single chunk larger than `capacity` still
                # leaves more than `capacity` samples available afterwards;
                # whether that is right depends on the HybridBuffer overflow
                # strategy used for the matching data -- confirm.
                n_to_discard = self._linear_n_available + n_samples - self.capacity
                self.seek(n_to_discard)
            # Update the offset corresponding to the oldest sample in the buffer
            # by anchoring on the new offset and accounting for the samples already available.
            self._linear_axis.offset = (
                axis.offset - self._linear_n_available * axis.gain
            )
            self._linear_n_available += n_samples

    def peek(self, n_samples: int | None = None) -> LinearAxis | CoordinateAxis:
        """Return the axis for the oldest unread samples without consuming them.

        For a CoordinateAxis, `n_samples` is forwarded to the coordinate
        buffer. For a LinearAxis it is ignored: the returned axis only carries
        the offset of the oldest sample and the gain.
        """
        if self._coords_buffer is not None:
            return replace(
                self._coords_template, data=self._coords_buffer.peek(n_samples)
            )
        else:
            # Return a shallow copy.
            return replace(self._linear_axis, offset=self._linear_axis.offset)

    def seek(self, n_samples: int) -> int:
        """Advance the read position by up to `n_samples`.

        Returns:
            The number of samples actually skipped.
        """
        if self._coords_buffer is not None:
            return self._coords_buffer.seek(n_samples)
        else:
            # Never skip past what is available.
            n_to_seek = min(n_samples, self._linear_n_available)
            self._linear_n_available -= n_to_seek
            self._linear_axis.offset += n_to_seek * self._linear_axis.gain
            return n_to_seek

    def prune(self, n_samples: int) -> int:
        """Discards all but the last n_samples from the buffer.

        Returns:
            The number of samples discarded (0 if nothing had to be dropped).
        """
        n_to_discard = self.available() - n_samples
        if n_to_discard <= 0:
            return 0
        return self.seek(n_to_discard)

    @property
    def final_value(self) -> float | None:
        """
        The axis-value (timestamp, typically) of the last sample in the buffer.
        This does not advance the read head.

        Returns None when the buffer is empty or uninitialized, mirroring
        `first_value`.
        """
        # Without this guard an empty linear axis would report value(-1),
        # i.e. one step *before* the current offset.
        if self.available() == 0:
            return None
        if self._coords_buffer is not None:
            return self._coords_buffer.peek_last()[0]
        elif self._linear_axis is not None:
            return self._linear_axis.value(self._linear_n_available - 1)
        else:
            return None

    @property
    def first_value(self) -> float | None:
        """
        The axis-value (timestamp, typically) of the first sample in the buffer.
        This does not advance the read head.

        Returns None when the buffer is empty or uninitialized.
        """
        if self.available() == 0:
            return None
        if self._coords_buffer is not None:
            return self._coords_buffer.peek_at(0)[0]
        elif self._linear_axis is not None:
            return self._linear_axis.value(0)
        else:
            return None

    @property
    def gain(self) -> float | None:
        """The step between consecutive samples along the axis.

        For a CoordinateAxis this is the estimate made from the first chunk;
        for a LinearAxis it is the axis gain. None before initialization.
        """
        if self._coords_buffer is not None:
            return self._coords_gain_estimate
        elif self._linear_axis is not None:
            return self._linear_axis.gain
        else:
            return None

    def searchsorted(
        self, values: typing.Union[float, Array], side: str = "left"
    ) -> typing.Union[int, Array]:
        """Find insertion indices for `values` along the buffered axis.

        Args:
            values: A scalar or an array of axis values (timestamps, typically).
            side: "left" or "right", as for `searchsorted` in array libraries.

        Returns:
            For a CoordinateAxis, the result of the backend's `searchsorted`
            over all unread coordinates. For a LinearAxis, indices computed
            arithmetically from offset and gain (these are not clipped to
            ``[0, available()]``); a scalar query yields a Python int.
        """
        if self._coords_buffer is not None:
            return self._coords_buffer.xp.searchsorted(
                self._coords_buffer.peek(self.available()), values, side=side
            )

        # Real scalars (Python or NumPy) take a scalar path; everything else
        # is treated as an array.
        is_scalar = isinstance(values, (int, float, np.integer, np.floating))

        if self.available() == 0:
            if is_scalar:
                return 0
            _xp = get_namespace(values)
            return _xp.zeros_like(values, dtype=int)

        # Fractional sample index of each query value.
        f_inds = (values - self._linear_axis.offset) / self._linear_axis.gain
        res = np.ceil(f_inds)
        if side == "right":
            # Values that land on a sample go *after* it. `np.where` is used
            # instead of in-place masked assignment so scalars work too.
            # NOTE(review): np.isclose uses its default tolerances here, which
            # scale with the index magnitude -- confirm this is acceptable for
            # very long buffers.
            res = np.where(np.isclose(f_inds, res), res + 1, res)
        res = res.astype(int)
        return int(res) if is_scalar else res
class HybridAxisArrayBuffer:
    """A buffer that intelligently handles ezmsg.util.messages.AxisArray objects.

    This buffer defers its own initialization until the first message arrives,
    allowing it to automatically configure its size, shape, dtype, and array
    backend (e.g., NumPy, CuPy) based on the message content and a desired
    buffer duration.

    Sample data is kept in a `HybridBuffer`; the matching axis information is
    kept in a `HybridAxisBuffer`. Both are advanced together.

    Args:
        duration: The desired duration of the buffer in seconds.
        axis: The name of the axis to buffer along.
        **kwargs: Additional keyword arguments to pass to the underlying
            HybridBuffer (e.g., `update_strategy`, `threshold`,
            `overflow_strategy`, `max_size`).
    """

    # Sample storage; None until the first message has been written.
    _data_buffer: HybridBuffer | None
    # Tracks the target axis (offset/gain or coordinate values).
    _axis_buffer: HybridAxisBuffer
    # Zero-length copy of the first message without the target axis.
    _template_msg: AxisArray | None

    def __init__(self, duration: float, axis: str = "time", **kwargs):
        self.duration = duration
        self._axis = axis
        self.buffer_kwargs = kwargs
        self._axis_buffer = HybridAxisBuffer(duration=duration, **kwargs)
        # Delay initialization until the first message arrives
        self._data_buffer = None
        self._template_msg = None

    def available(self) -> int:
        """The total number of unread samples currently available in the buffer."""
        if self._data_buffer is None:
            return 0
        return self._data_buffer.available()

    def is_empty(self) -> bool:
        """Return True when no unread samples are available."""
        return self.available() == 0

    def is_full(self) -> bool:
        """Return True when the data buffer holds exactly its capacity.

        A buffer that has not received any message yet is never full.
        """
        if self._data_buffer is None:
            return False
        return 0 < self._data_buffer.capacity == self.available()

    @property
    def axis_first_value(self) -> float | None:
        """The axis-value (timestamp, typically) of the first sample in the buffer."""
        return self._axis_buffer.first_value

    @property
    def axis_final_value(self) -> float | None:
        """The axis-value (timestamp, typically) of the last sample in the buffer."""
        return self._axis_buffer.final_value

    def _initialize(self, first_msg: AxisArray) -> None:
        """Build the template message and both buffers from the first message.

        `first_msg` is expected to already have the target axis as its first
        dimension (`write` takes care of that).
        """
        # Create a template message that has everything except the data are length 0
        # and the target axis is missing.
        self._template_msg = replace(
            first_msg,
            data=first_msg.data[:0],
            axes={k: v for k, v in first_msg.axes.items() if k != self._axis},
        )

        in_axis = first_msg.axes[self._axis]
        self._axis_buffer._initialize(in_axis)

        # Size the data buffer from the axis buffer so both agree on when an
        # overflow happens; computing the capacity separately with a different
        # rounding could leave them one sample apart.
        capacity = self._axis_buffer.capacity
        self._data_buffer = HybridBuffer(
            get_namespace(first_msg.data),
            capacity,
            other_shape=first_msg.data.shape[1:],
            dtype=first_msg.data.dtype,
            **self.buffer_kwargs,
        )

    def write(self, msg: AxisArray) -> None:
        """Adds an AxisArray message to the buffer, initializing on the first call."""
        in_axis_idx = msg.get_axis_idx(self._axis)
        if in_axis_idx > 0:
            # This class assumes that the target axis is the first axis.
            # If it is not, we move it to the front.
            dims = list(msg.dims)
            dims.insert(0, dims.pop(in_axis_idx))
            _xp = get_namespace(msg.data)
            msg = replace(msg, data=_xp.moveaxis(msg.data, in_axis_idx, 0), dims=dims)

        if self._data_buffer is None:
            self._initialize(msg)

        self._data_buffer.write(msg.data)
        self._axis_buffer.write(msg.axes[self._axis], msg.shape[0])

    def peek(self, n_samples: int | None = None) -> AxisArray | None:
        """Retrieves the oldest unread data as a new AxisArray without advancing the read head.

        Returns None if nothing has been written yet or the data buffer
        yields no data.
        """
        if self._data_buffer is None:
            return None

        data_array = self._data_buffer.peek(n_samples)
        if data_array is None:
            return None

        out_axis = self._axis_buffer.peek(n_samples)
        return replace(
            self._template_msg,
            data=data_array,
            axes={**self._template_msg.axes, self._axis: out_axis},
        )

    def peek_axis(
        self, n_samples: int | None = None
    ) -> LinearAxis | CoordinateAxis | None:
        """Retrieves the axis data without advancing the read head."""
        if self._data_buffer is None:
            return None
        return self._axis_buffer.peek(n_samples)

    def seek(self, n_samples: int) -> int:
        """Advances the read pointer by n_samples.

        Returns:
            The number of samples actually skipped in the data buffer.
        """
        if self._data_buffer is None:
            return 0

        skipped_data_count = self._data_buffer.seek(n_samples)
        axis_skipped = self._axis_buffer.seek(skipped_data_count)
        # Internal invariant: data and axis must always move in lock-step.
        assert (
            axis_skipped == skipped_data_count
        ), f"Axis buffer skipped {axis_skipped} samples, but data buffer skipped {skipped_data_count}."

        return skipped_data_count

    def read(self, n_samples: int | None = None) -> AxisArray | None:
        """Retrieves the oldest unread data as a new AxisArray and advances the read head."""
        retrieved_axis_array = self.peek(n_samples)

        if retrieved_axis_array is None or retrieved_axis_array.shape[0] == 0:
            return None

        # Only consume what was actually returned.
        self.seek(retrieved_axis_array.shape[0])

        return retrieved_axis_array

    def prune(self, n_samples: int) -> int:
        """Discards all but the last n_samples from the buffer.

        Returns:
            The number of samples discarded.
        """
        if self._data_buffer is None:
            return 0

        n_to_discard = self.available() - n_samples
        if n_to_discard <= 0:
            return 0

        return self.seek(n_to_discard)

    @property
    def axis_gain(self) -> float | None:
        """
        The gain of the target axis, which is the time step between samples.
        This is typically the sampling period (e.g., 1 / fs).
        """
        return self._axis_buffer.gain

    def axis_searchsorted(
        self, values: typing.Union[float, Array], side: str = "left"
    ) -> typing.Union[int, Array]:
        """
        Find the indices into which the given values would be inserted
        into the target axis data to maintain order.
        """
        return self._axis_buffer.searchsorted(values, side=side)