Source code for ezmsg.event.kernel_insert

"""
Insert kernels at sparse event locations to produce dense signals.

This module provides efficient sparse-to-dense conversion by inserting
kernel waveforms at event locations. Overlapping kernels are summed.
"""

import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
from ezmsg.util.messages.axisarray import AxisArray, replace

from .kernel import Kernel, MultiKernel


[docs] class SparseKernelInserterSettings(ez.Settings): """Settings for SparseKernelInserter.""" kernel: Kernel | MultiKernel | None = None """ Kernel to insert at event locations. - Kernel: Same kernel for all events. - MultiKernel: Different kernels based on event value. - None: Events are treated as unit impulses (delta functions). """ scale_by_value: bool = False """ If True, scale kernel amplitude by event value. If False, event value is used only for MultiKernel selection. """ output_dtype: npt.DTypeLike = np.float64 """Data type for output array."""
[docs] @processor_state class SparseKernelInserterState: """State for SparseKernelInserter.""" # Pending contributions that overlap into next chunk # Shape: (pending_samples, n_channels) pending: npt.NDArray[np.floating] | None = None # Number of pending samples (may be less than pending.shape[0] if reused) pending_length: int = 0
[docs] class SparseKernelInserter( BaseStatefulTransformer[ SparseKernelInserterSettings, AxisArray, AxisArray, SparseKernelInserterState, ] ): """ Insert kernels at sparse event locations, producing dense output. Input: AxisArray with sparse.COO data where: - coords[0]: sample indices (time) - coords[1]: channel indices - data: event values (used for MultiKernel selection or scaling) Output: AxisArray with dense data containing inserted kernels. Features: - Handles chunk boundaries seamlessly (kernel tails carry over) - Overlapping kernels are summed additively - Supports acausal kernels (pre_samples > 0) - Efficient O(n_events * kernel_length) instead of dense convolution """ def _get_max_kernel_length(self) -> int: """Get maximum kernel length for buffer allocation.""" kernel = self.settings.kernel if kernel is None: return 1 elif isinstance(kernel, MultiKernel): return kernel.max_length else: return kernel.length def _get_max_pre_samples(self) -> int: """Get maximum pre_samples for acausal kernel handling.""" kernel = self.settings.kernel if kernel is None: return 0 elif isinstance(kernel, MultiKernel): return kernel.max_pre_samples else: return kernel.pre_samples def _reset_state(self, message: AxisArray) -> None: """Initialize state for new input stream.""" n_channels = message.data.shape[1] if message.data.ndim > 1 else 1 max_overlap = self._get_max_kernel_length() - 1 if max_overlap > 0: self._state.pending = np.zeros( (max_overlap, n_channels), dtype=self.settings.output_dtype, ) else: self._state.pending = None self._state.pending_length = 0 def _get_kernel_for_value(self, value: int | float) -> tuple[Kernel | None, float]: """ Get kernel and scale factor for an event value. Returns: (kernel, scale) tuple. """ kernel = self.settings.kernel scale = float(value) if self.settings.scale_by_value else 1.0 if kernel is None: return None, scale elif isinstance(kernel, MultiKernel): return kernel.get(int(value)), scale else: return kernel, scale def _process(self, message: AxisArray) -> AxisArray: """Insert kernels at sparse event locations.""" sparse_data = message.data n_samples = sparse_data.shape[0] n_channels = sparse_data.shape[1] if sparse_data.ndim > 1 else 1 # Initialize output array output = np.zeros((n_samples, n_channels), dtype=self.settings.output_dtype) # Add pending contributions from previous chunk if self._state.pending is not None and self._state.pending_length > 0: overlap = min(self._state.pending_length, n_samples) output[:overlap] += self._state.pending[:overlap] # Reset pending for this chunk max_overlap = self._get_max_kernel_length() - 1 if max_overlap > 0: if self._state.pending is None or self._state.pending.shape[0] < max_overlap: self._state.pending = np.zeros( (max_overlap, n_channels), dtype=self.settings.output_dtype, ) else: self._state.pending[:] = 0 self._state.pending_length = 0 # Process each event if hasattr(sparse_data, "coords") and hasattr(sparse_data, "data"): # sparse.COO format coords = sparse_data.coords values = sparse_data.data for event_idx in range(len(values)): sample_idx = int(coords[0, event_idx]) channel_idx = int(coords[1, event_idx]) if coords.shape[0] > 1 else 0 value = values[event_idx] kernel, scale = self._get_kernel_for_value(value) if kernel is None: # Unit impulse if 0 <= sample_idx < n_samples: output[sample_idx, channel_idx] += scale else: # Insert kernel pre = kernel.pre_samples length = kernel.length # Calculate kernel placement range kernel_start = sample_idx - pre kernel_end = kernel_start + length # Portion within current chunk chunk_start = max(0, kernel_start) chunk_end = min(n_samples, kernel_end) if chunk_start < chunk_end: # Kernel indices for this portion k_start = chunk_start - kernel_start k_end = k_start + (chunk_end - chunk_start) # Get kernel values t = np.arange(k_start, k_end, dtype=np.float64) - pre kernel_values = kernel.evaluate(t) * scale output[chunk_start:chunk_end, channel_idx] += kernel_values # Portion that extends into next chunk (pending) if kernel_end > n_samples and self._state.pending is not None: pending_start = max(0, n_samples - kernel_start) pending_end = length pending_samples = pending_end - pending_start if pending_samples > 0: # Kernel indices for pending portion t = np.arange(pending_start, pending_end, dtype=np.float64) - pre kernel_values = kernel.evaluate(t) * scale # Add to pending buffer self._state.pending[:pending_samples, channel_idx] += kernel_values self._state.pending_length = max( self._state.pending_length, pending_samples, ) # Create output message return replace(message, data=output)
[docs] class SparseKernelInserterUnit( BaseTransformerUnit[ SparseKernelInserterSettings, AxisArray, AxisArray, SparseKernelInserter, ] ): """Unit for SparseKernelInserter.""" SETTINGS = SparseKernelInserterSettings