import ezmsg.core as ez
import numba
import numpy as np
import numpy.typing as npt
import sparse
from ezmsg.baseproc import (
BaseStatefulTransformer,
BaseTransformerUnit,
processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace
@numba.jit(nopython=True, cache=True)
def _inhomogeneous_poisson_generator(
    rates: np.ndarray,  # (n_bins,) rates for this channel
    accumulated: float,  # initial accumulated value
    threshold: float,  # initial threshold
    bin_duration: float,
    output_fs: float,
    max_events: int,
) -> tuple[np.ndarray, int, float, float]:
    """
    Inhomogeneous Poisson process event generator using the integration method.

    The rate is held constant within each bin of length ``bin_duration``, so
    the integrated rate grows linearly inside a bin. An event is emitted each
    time the integral reaches ``threshold``. The integral then resets to zero
    and a new Exp(1) threshold is drawn. Returning the final integral and
    threshold lets the process continue seamlessly into the next chunk.

    Bins whose rate is not a finite, strictly positive number are skipped:
    they add nothing to the integral and emit no events. Without this guard
    the inner loop misbehaves for bad rates:

    - A zero rate divides by zero.
    - A negative, NaN, or infinite rate makes the loop never terminate.

    Args:
        rates: Per-bin rates for a single channel.
        accumulated: Integrated rate carried over from the previous chunk.
        threshold: Event threshold carried over from the previous chunk.
        bin_duration: Length of one bin. It must use the same time unit as
            ``1 / rate`` and ``1 / output_fs`` (presumably seconds).
        output_fs: Output sampling rate used to convert event times into
            sample indices. Conversion truncates via ``int()``.
        max_events: Capacity of the returned sample buffer. Events beyond
            this count are not recorded, but the process state still advances.

    Returns:
        event_samples: pre-allocated array of event sample indices. Only the
            first ``n_events`` entries are valid.
        n_events: actual number of events recorded
        accumulated: updated accumulated value for next chunk
        threshold: updated threshold for next chunk
    """
    event_samples = np.empty(max_events, dtype=np.int64)
    n_events = 0
    n_bins = len(rates)
    for t in range(n_bins):
        rate = rates[t]
        # `not (rate > 0.0)` is also True for NaN. Non-positive or non-finite
        # rates would otherwise raise or spin forever in the loop below.
        if not (rate > 0.0 and np.isfinite(rate)):
            continue
        bin_start = t * bin_duration
        time_in_bin = 0.0
        while True:
            # Time until the integral (growing at `rate`) reaches the threshold.
            time_to_event = (threshold - accumulated) / rate
            event_time = time_in_bin + time_to_event
            if event_time >= bin_duration:
                # No more events in this bin: integrate up to the bin's end.
                accumulated += rate * (bin_duration - time_in_bin)
                break
            # Record the event only while there is room in the buffer.
            if n_events < max_events:
                event_sample = int((event_time + bin_start) * output_fs)
                event_samples[n_events] = event_sample
                n_events += 1
            # Reset the integral and draw the next Exp(1) threshold.
            time_in_bin = event_time
            accumulated = 0.0
            threshold = np.random.exponential(1.0)
    return event_samples, n_events, accumulated, threshold
@numba.jit(nopython=True, parallel=True, cache=True)
def _generate_events_all_channels(
    rates_array: np.ndarray,  # (n_bins, n_channels)
    accumulated: np.ndarray,  # (n_channels,)
    threshold: np.ndarray,  # (n_channels,)
    bin_duration: float,
    output_fs: float,
    max_events_per_channel: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Generate events for all channels in parallel.

    Each column of ``rates_array`` is simulated independently by
    ``_inhomogeneous_poisson_generator``. Channels are distributed across
    threads with ``numba.prange``. The input ``accumulated`` and
    ``threshold`` arrays are only read; updated values are returned in new
    arrays.

    Returns:
        all_event_samples: (n_channels, max_events_per_channel) event sample
            indices. Row ``ch`` is valid only up to ``event_counts[ch]``; the
            remaining entries are uninitialized.
        event_counts: (n_channels,) number of events per channel
        accumulated_out: (n_channels,) updated accumulated values
        threshold_out: (n_channels,) updated thresholds
    """
    n_channels = rates_array.shape[1]
    # Pre-allocate output arrays (uninitialized; filled per channel below).
    all_event_samples = np.empty((n_channels, max_events_per_channel), dtype=np.int64)
    event_counts = np.empty(n_channels, dtype=np.int64)
    accumulated_out = np.empty(n_channels, dtype=np.float64)
    threshold_out = np.empty(n_channels, dtype=np.float64)
    # Process each channel in parallel; iterations write disjoint rows/entries.
    for ch in numba.prange(n_channels):
        samples, count, acc, thresh = _inhomogeneous_poisson_generator(
            rates_array[:, ch],
            accumulated[ch],
            threshold[ch],
            bin_duration,
            output_fs,
            max_events_per_channel,
        )
        # Only the first `count` entries of `samples` are meaningful; copying
        # the uninitialized tail is wasted work.
        all_event_samples[ch, :count] = samples[:count]
        event_counts[ch] = count
        accumulated_out[ch] = acc
        threshold_out[ch] = thresh
    return all_event_samples, event_counts, accumulated_out, threshold_out
@numba.jit(nopython=True, cache=True)
def _flatten_events_unsorted(
    all_event_samples: np.ndarray,  # (n_channels, max_events)
    event_counts: np.ndarray,  # (n_channels,)
) -> tuple[np.ndarray, np.ndarray]:
    """Concatenate the valid prefix of every channel's event row.

    Channels are visited in index order. Within a channel, events keep their
    original order, but the result is not sorted by sample index overall.

    Returns:
        A pair of equal-length int64 arrays (sample indices, channel indices).
        Both are empty when no channel has any events.
    """
    n_total = np.sum(event_counts)
    if n_total == 0:
        return np.zeros(0, dtype=np.int64), np.zeros(0, dtype=np.int64)
    flat_samples = np.empty(n_total, dtype=np.int64)
    flat_channels = np.empty(n_total, dtype=np.int64)
    write_pos = 0
    for channel in range(event_counts.shape[0]):
        n_here = event_counts[channel]
        if n_here <= 0:
            continue
        stop = write_pos + n_here
        flat_samples[write_pos:stop] = all_event_samples[channel, :n_here]
        flat_channels[write_pos:stop] = channel
        write_pos = stop
    return flat_samples, flat_channels
class PoissonEventSettings(ez.Settings):
    """Settings for turning per-bin firing rates (or counts) into a sparse event train.

    Controls the output sampling rate, the sparse layout of the output array,
    how the input values are interpreted, and the rate bounds used during
    event generation.
    """

    output_fs: float = 30_000
    """Output sampling rate."""
    layout: str = "coo"
    """Layout of the output event train sparse array. Options are 'coo' or 'gcxs'"""
    compress_dims: list[int] | None = None
    """Dimensions to compress. Ignored if layout is 'coo'."""
    # NOTE(review): presumably forwarded to the GCXS constructor; confirm in the transformer.
    assume_counts: bool = False
    """If True, input is event counts per bin. If False, input is firing rate in Hz."""
    min_rate: float = 1e-6
    """Minimum rate to avoid division by zero."""
    max_rate: float = 500.0
    """Maximum expected firing rate (Hz). Used to pre-allocate event arrays."""
@processor_state
class PoissonEventState:
    """Per-channel Poisson-process state carried from one chunk to the next.

    Both fields default to ``None``. They are presumably allocated by the
    transformer on the first message, one entry per channel. TODO confirm
    against the transformer's reset logic, which is not shown here.
    """

    accumulated: npt.NDArray | None = None
    """Integrated rate since last event for each channel."""
    threshold: npt.NDArray | None = None
    """Exp(1) threshold for next event for each channel."""
class PoissonEventUnit(BaseTransformerUnit[PoissonEventSettings, AxisArray, AxisArray, PoissonEventTransformer]):
    """ezmsg unit wrapper around ``PoissonEventTransformer``.

    The generic parameters declare ``PoissonEventSettings`` as its settings
    type and ``AxisArray`` as both its input and output message types.
    """

    # NOTE(review): PoissonEventTransformer is not defined in the visible part
    # of this file; it must be defined (or imported) before this class.
    SETTINGS = PoissonEventSettings