# Source code for ezmsg.simbiophys.dnss.spike

"""Produce data mimicking Blackrock Neurotech's Digital Neural Signal Simulator"""

from typing import Generator

import numpy as np
import numpy.typing as npt
import sparse
from ezmsg.baseproc import (
    BaseClockDrivenProducer,
    BaseClockDrivenUnit,
    ClockDrivenSettings,
    ClockDrivenState,
    processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray, LinearAxis, replace

"""
## Spike Pattern

We know the overall spiking pattern and the individual waveforms from the DNSS source code.

The pattern during single spikes:

```
 ** Ch 1 1-----------2-----------3-----------1-----------2- ... and
 ** Ch 2 ---2-----------3-----------1-----------2---------- ... on
 ** Ch 3 ------3-----------1-----------2-----------3------- ... and
 ** Ch 4 ---------1-----------2-----------3-----------1---- ... on
 ** Ch 5 1-----------2-----------3-----------1-----------2- ... and
```

The first waveform is inserted immediately into channels 1::4 as soon as the pattern starts.
Waveforms are inserted every 7500 samples (0.25 seconds). At each iteration, the next waveform (% 3 waveforms)
is inserted into the next channel (% 4 channels).
This repeats until there are 36 spike + gap periods (36 * 7500), for 9 seconds total.


During bursting:

```
 ** Ch 1 1---2---3---1---2---3---1---2---3---1---2---3---1- ... and
 ** Ch 2 1---2---3---1---2---3---1---2---3---1---2---3---1- ... on
 ** Ch 3 1---2---3---1---2---3---1---2---3---1---2---3---1- ... and
 ** Ch 4 1---2---3---1---2---3---1---2---3---1---2---3---1- ... on
 ** Ch 5 1---2---3---1---2---3---1---2---3---1---2---3---1- ... and
```

Bursts begin immediately after the last gap following the 36th spike during the slow period.
During bursting, a spike occurs in all channels simultaneously with the same waveform.
Each spike in a burst comprises a spike waveform + gap for 300 total samples (0.01 s).
Waveforms cycle in order, 1, 2, 3, 1, 2, 3, and so on.
There are 99 total spikes in a burst for 0.99 second total.

A burst ends with a 300 sample gap (so 600 from the onset of the last spike in the burst)
before the slow period begins again.

** WARNING **

The above pattern describes what's supposed to happen, but there are 2 bugs when using the DNSS' HDMI output.

* First bug: The pattern actually starts on channel 2, not channel 1.
* Second bug: The first channel in a bank has its spikes (but not LFPs) delayed by 1 sample.
"""


# Timing constants of the DNSS spike pattern. All durations are in samples.

# Slow phase: N_SLOW_SPIKES spikes, one every INT_SLOW samples.
INT_SLOW = 7500
N_SLOW_SPIKES = 36
SAMPS_SLOW = N_SLOW_SPIKES * INT_SLOW  # 36 * 7500 = 270,000 samples (9 seconds)

# Burst phase: N_BURST_SPIKES spikes, one every INT_BURST samples, followed by
# GAP_BURST extra spike-free samples after the last spike's own interval.
INT_BURST = 300
N_BURST_SPIKES = 99
GAP_BURST = 300
SAMPS_BURST = N_BURST_SPIKES * INT_BURST + GAP_BURST  # 30,000 samples (1 second)

# One complete slow + burst cycle.
FULL_PERIOD = SAMPS_SLOW + SAMPS_BURST  # 300,000 samples (10 seconds)

# Sample rate in Hz (default for DNSS).
FS = 30_000


def _spikes_in_range(start: int, end: int, burst: bool = False) -> npt.NDArray[np.int_]:
    """
    List the spike numbers whose onset falls inside the half-open range [start, end).

    Within a phase, spike number i starts at sample i * interval, where the
    interval and the total number of spikes depend on the phase (slow or burst).

    Args:
        start: First sample of the range (inclusive), relative to the phase start.
        end: One past the last sample of the range (exclusive).
        burst: Select burst-phase parameters when True, slow-phase otherwise.

    Returns:
        1-D integer array of consecutive spike numbers i satisfying
        start <= i * interval < end and 0 <= i < number of spikes in the phase.
        The array is empty when no spike qualifies.
    """
    if burst:
        step, count = INT_BURST, N_BURST_SPIKES
    else:
        step, count = INT_SLOW, N_SLOW_SPIKES

    no_spikes = np.array([], dtype=np.int_)
    if start >= end:
        return no_spikes

    # Smallest i with i * step >= start, never before the first spike.
    first = max(0, int(np.ceil(start / step)))
    # One past the largest i with i * step < end, never beyond the last spike.
    stop = min(count, (end - 1) // step + 1)

    if first >= stop:
        return no_spikes
    return np.arange(first, stop, dtype=np.int_)


def spike_event_generator(
    mode: str = "hdmi",
    n_chans: int = 4,
) -> Generator[
    tuple[npt.NDArray[np.int_], npt.NDArray[np.int_]],
    int,
    None,
]:
    """
    Generator yielding spike event indices for the DNSS pattern.

    This is a send-able generator. After priming with next(), use send(n_samples)
    to get spikes for the next n_samples window. The generator maintains internal
    state tracking the current sample position.

    In "hdmi" mode, spikes on channel 0 are delayed by one sample. A delayed spike
    is reported in the window that contains its delayed position, so every
    returned sample index lies in [0, n_samples) even when the nominal spike
    position is the last sample of the previous window.

    Args:
        mode: "hdmi" (case-insensitive) to reproduce HDMI bugs; any other value
            produces the ideal pattern.
        n_chans: Number of output channels. The 4-channel slow-phase rotation is
            tiled across them; bursts fire on every channel (default 4).

    Yields:
        Tuple of (coords, waveform_ids):
        - coords: Shape (2, n_spikes) array of [sample_indices, channel_indices],
          with sample indices relative to the start of the requested window.
        - waveform_ids: Waveform shape identifiers (1, 2, or 3)

        The priming next() call yields None. Sending None or a non-positive
        n_samples yields empty arrays and does not advance the position.

    Example:
        gen = spike_event_generator()
        next(gen)  # Prime the generator
        coords, waveforms = gen.send(30000)  # Get spikes in first 30000 samples
        coords, waveforms = gen.send(15000)  # Get spikes in next 15000 samples
    """
    hdmi_mode = mode.lower() == "hdmi"
    # HDMI bug #1: the slow pattern starts one channel late (0-based channel 1).
    ch_offset = 1 if hdmi_mode else 0
    # HDMI bug #2 moves channel-0 spikes one sample later. Spikes are located by
    # their nominal position, so the scan must begin one sample before the
    # window to catch a spike that is pushed into it from the previous window.
    # NOTE(review): the module notes say "first channel in a bank", but only
    # channel 0 is delayed here - confirm whether other bank starts are affected.
    lookback = 1 if hdmi_mode else 0

    current_sample = 0
    empty_coords = np.array([[], []], dtype=np.int_)
    empty_waveforms = np.array([], dtype=np.int_)

    # Position within a period one past the last burst spike's interval; the
    # remaining GAP_BURST samples of the period contain no spikes.
    burst_spike_end = SAMPS_SLOW + N_BURST_SPIKES * INT_BURST

    n_samples = yield None  # Prime - caller does next(gen)

    while True:
        if n_samples is None or n_samples <= 0:
            n_samples = yield (empty_coords, empty_waveforms)
            continue

        window_start = current_sample
        window_end = current_sample + n_samples

        result_arrays: list[tuple[npt.NDArray[np.int_], npt.NDArray[np.int_], npt.NDArray[np.int_]]] = []

        # Process the (look-back extended) window in chunks that stay within a
        # single period. Python's % keeps pos_in_period non-negative even for
        # the pos == -1 look-back sample of the very first window.
        pos = window_start - lookback
        while pos < window_end:
            pos_in_period = pos % FULL_PERIOD
            remaining = window_end - pos
            chunk_end_in_period = min(pos_in_period + remaining, FULL_PERIOD)
            window_offset = pos - window_start  # Offset from window_start for this chunk

            # === Slow phase: [0, SAMPS_SLOW) ===
            if pos_in_period < SAMPS_SLOW:
                slow_start = pos_in_period
                slow_end = min(chunk_end_in_period, SAMPS_SLOW)
                spike_indices = _spikes_in_range(slow_start, slow_end, burst=False)

                if len(spike_indices) > 0:
                    n_spikes = len(spike_indices)

                    # Compute base values for each spike
                    spike_pos_in_period = spike_indices * INT_SLOW
                    base_sample_idxs = spike_pos_in_period - pos_in_period + window_offset
                    base_channels = (spike_indices + ch_offset) % 4  # Always 4-channel base pattern
                    base_waveforms = (spike_indices % 3) + 1

                    # The 4-channel pattern tiles across all channels
                    # e.g., with n_chans=5, base ch 0 fires on ch 0 and 4
                    # For n_chans < 4, only spikes on channels 0..n_chans-1 are present
                    max_tiles = (n_chans + 3) // 4  # Ceiling division
                    sample_idxs = np.repeat(base_sample_idxs, max_tiles)
                    waveforms = np.repeat(base_waveforms, max_tiles)
                    base_channels_repeated = np.repeat(base_channels, max_tiles)
                    tile_indices = np.tile(np.arange(max_tiles, dtype=np.int_), n_spikes)
                    channels = base_channels_repeated + tile_indices * 4

                    # Filter to valid channels (handles both partial tiles and n_chans < 4)
                    valid_mask = channels < n_chans
                    sample_idxs = sample_idxs[valid_mask]
                    waveforms = waveforms[valid_mask]
                    channels = channels[valid_mask]

                    # HDMI bug #2: channel 0 spikes delayed by 1 sample
                    if hdmi_mode:
                        sample_idxs = sample_idxs + (channels == 0).astype(np.int_)

                    if len(channels) > 0:
                        result_arrays.append((sample_idxs, channels, waveforms))

            # === Burst phase: [SAMPS_SLOW, SAMPS_SLOW + N_BURST_SPIKES * INT_BURST) ===
            # Spikes occur at SAMPS_SLOW + i * INT_BURST for i in 0..98
            # The final GAP_BURST samples have no spikes
            if chunk_end_in_period > SAMPS_SLOW and pos_in_period < burst_spike_end:
                burst_start = max(pos_in_period, SAMPS_SLOW)
                burst_end = min(chunk_end_in_period, burst_spike_end)

                # Convert to relative positions within burst phase
                rel_start = burst_start - SAMPS_SLOW
                rel_end = burst_end - SAMPS_SLOW
                spike_indices = _spikes_in_range(rel_start, rel_end, burst=True)

                if len(spike_indices) > 0:
                    n_spikes = len(spike_indices)

                    # Compute base positions and waveforms for each spike
                    spike_pos_in_period = SAMPS_SLOW + spike_indices * INT_BURST
                    base_sample_idxs = spike_pos_in_period - pos_in_period + window_offset
                    base_waveforms = (spike_indices % 3) + 1

                    # Expand: each spike fires on all n_chans channels
                    sample_idxs = np.repeat(base_sample_idxs, n_chans)
                    waveforms = np.repeat(base_waveforms, n_chans)
                    channels = np.tile(np.arange(n_chans, dtype=np.int_), n_spikes)

                    # HDMI bug #2: channel 0 spikes delayed by 1 sample
                    if hdmi_mode:
                        sample_idxs = sample_idxs + (channels == 0).astype(np.int_)

                    result_arrays.append((sample_idxs, channels, waveforms))

            pos += chunk_end_in_period - pos_in_period

        current_sample = window_end

        if not result_arrays:
            n_samples = yield (empty_coords, empty_waveforms)
            continue

        # Concatenate all result arrays
        sample_idxs = np.concatenate([r[0] for r in result_arrays])
        chan_idxs = np.concatenate([r[1] for r in result_arrays])
        waveform_ids = np.concatenate([r[2] for r in result_arrays])

        # Keep only events whose final (possibly delayed) position is inside this
        # window. This drops look-back spikes already reported by the previous
        # window and defers a delayed spike at index n_samples to the next one.
        in_window = (sample_idxs >= 0) & (sample_idxs < n_samples)
        sample_idxs = sample_idxs[in_window]
        chan_idxs = chan_idxs[in_window]
        waveform_ids = waveform_ids[in_window]

        coords = np.array([sample_idxs, chan_idxs], dtype=np.int_)
        n_samples = yield (coords, waveform_ids)
# =============================================================================
# Transformer-based implementation
# =============================================================================
class DNSSSpikeSettings(ClockDrivenSettings):
    """Configuration for the DNSS spike producer."""

    fs: float = FS
    """Sample rate in Hz. The DNSS pattern is defined at 30 kHz; the producer rejects other rates."""

    n_ch: int = 256
    """Number of output channels."""

    mode: str = "hdmi"
    """Pattern variant: "hdmi" reproduces the HDMI output bugs, "ideal" gives the ideal pattern."""
@processor_state
class DNSSSpikeState(ClockDrivenState):
    """Mutable state held by the DNSS spike producer between chunks."""

    # Primed spike_event_generator instance; None until the state is reset.
    spike_gen: Generator | None = None

    # Empty AxisArray carrying the fixed "ch" axis, used as the base for each
    # output message; None until the state is reset.
    template: AxisArray | None = None
class DNSSSpikeProducer(BaseClockDrivenProducer[DNSSSpikeSettings, DNSSSpikeState]):
    """
    Clock-driven source of the DNSS spike pattern.

    Every call to ``_produce`` returns one block of spike events as a sparse
    COO array of shape (n_samples, n_ch), whose stored values are the
    waveform identifiers of the spikes.
    """

    def _reset_state(self, time_axis: LinearAxis) -> None:
        """
        Create a fresh spike generator and the reusable output template.

        Args:
            time_axis: Time axis stored in the template message.

        Raises:
            ValueError: If ``settings.fs`` is not (numerically) equal to FS.
        """
        cfg = self.settings
        state = self._state

        # The spike pattern is defined in samples at 30 kHz, so any other
        # sample rate is refused rather than silently mis-timed.
        rate_ok = np.isclose(cfg.fs, FS, rtol=1e-6)
        if not rate_ok:
            raise ValueError(
                f"DNSSSpikeProducer requires fs={FS} Hz, "
                f"but settings.fs={self.settings.fs:.1f} Hz. "
                f"Spike patterns cannot be resampled to other rates."
            )

        state.spike_gen = spike_event_generator(mode=cfg.mode, n_chans=cfg.n_ch)
        # Advance to the first yield so the generator accepts send().
        next(state.spike_gen)

        # Build the zero-length template once; only data and the time axis
        # are swapped in for each produced chunk.
        no_events = sparse.COO(
            coords=np.array([[], []], dtype=np.int_),
            data=np.array([], dtype=np.int_),
            shape=(0, cfg.n_ch),
        )
        channel_axis = AxisArray.CoordinateAxis(
            data=np.arange(cfg.n_ch),
            dims=["ch"],
        )
        state.template = AxisArray(
            data=no_events,
            dims=["time", "ch"],
            axes={"time": time_axis, "ch": channel_axis},
        )

    def _produce(self, n_samples: int, time_axis: LinearAxis) -> AxisArray:
        """
        Build the spike message for one chunk.

        Args:
            n_samples: Number of samples in this chunk.
            time_axis: Time axis to attach to the output message.

        Returns:
            Copy of the template whose data is a sparse COO array of shape
            (n_samples, n_ch) and whose "time" axis is ``time_axis``.
        """
        state = self._state

        # Ask the generator for the events of the next n_samples samples.
        coords, waveform_ids = state.spike_gen.send(n_samples)

        chunk = sparse.COO(
            coords=coords,
            data=waveform_ids,
            shape=(n_samples, self.settings.n_ch),
        )

        # Keep every template axis except "time", which is replaced per chunk.
        axes = dict(state.template.axes)
        axes["time"] = time_axis
        return replace(state.template, data=chunk, axes=axes)
class DNSSSpikeUnit(BaseClockDrivenUnit[DNSSSpikeSettings, DNSSSpikeProducer]):
    """ezmsg unit that wraps DNSSSpikeProducer to emit DNSS spikes from clock input."""

    # Settings class used to configure this unit.
    SETTINGS = DNSSSpikeSettings