"""
Kernel abstractions for sparse event processing.
Kernels can be applied to sparse events to produce either:
1. Dense signals (via SparseKernelInserter)
2. Binned activation features (via BinnedKernelActivation)
"""
from abc import ABC, abstractmethod
from typing import Callable
import numpy as np
import numpy.typing as npt
class Kernel(ABC):
    """
    Abstract base for kernel shapes placed at sparse event locations.

    A kernel describes a waveform that is inserted (or convolved) at each
    event. Its support spans ``pre_samples`` samples before the event and
    ``post_samples`` samples at/after it, ``length`` samples in total.
    ``pre_samples == 0`` marks a causal (forward-looking) kernel; a positive
    value allows acausal kernels such as symmetric ones.

    Subclasses must implement ``length``, ``pre_samples`` and ``evaluate``.
    """

    @property
    @abstractmethod
    def length(self) -> int:
        """Total number of samples covered by the kernel."""
        ...

    @property
    @abstractmethod
    def pre_samples(self) -> int:
        """How many of the ``length`` samples lie before t=0."""
        ...

    @abstractmethod
    def evaluate(self, t: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]:
        """
        Compute kernel values at the given offsets.

        Args:
            t: Offsets in samples relative to the event (t=0 is the event).

        Returns:
            Kernel values at each offset.
        """
        ...

    @property
    def post_samples(self) -> int:
        """Samples at or after t=0, i.e. ``length - pre_samples``."""
        total = self.length
        return total - self.pre_samples

    @property
    def is_causal(self) -> bool:
        """Whether the kernel has no support before t=0 (``pre_samples == 0``)."""
        return self.pre_samples == 0
class ArrayKernel(Kernel):
    """
    Kernel from an explicit array of samples (e.g., spike waveforms).

    Sample ``data[i]`` sits at offset ``i - pre_samples`` relative to the
    event, so ``data[pre_samples]`` is the value at t=0.

    Args:
        data: 1D array of kernel values (stored as float64).
        pre_samples: Number of samples before t=0. Default 0 (causal kernel).
            For a waveform centered at t=0, use pre_samples = len(data) // 2.

    Raises:
        ValueError: If ``data`` is not 1-dimensional, or ``pre_samples`` is
            outside ``[0, len(data)]``.
    """

    def __init__(self, data: npt.NDArray, pre_samples: int = 0):
        # np.asarray only copies when a conversion is needed, so a 1D float64
        # input array is shared with the caller rather than copied.
        self._data = np.asarray(data, dtype=np.float64)
        if self._data.ndim != 1:
            raise ValueError("Kernel data must be 1-dimensional")
        self._pre_samples = pre_samples
        if pre_samples < 0 or pre_samples > len(self._data):
            raise ValueError(f"pre_samples must be in [0, {len(self._data)}]")

    @property
    def length(self) -> int:
        """Total kernel length in samples (``len(data)``)."""
        return len(self._data)

    @property
    def pre_samples(self) -> int:
        """Number of samples before t=0."""
        return self._pre_samples

    @property
    def data(self) -> npt.NDArray[np.float64]:
        """Raw kernel data array (the internal array itself, not a copy)."""
        return self._data

    def evaluate(self, t: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]:
        """
        Evaluate the kernel at time offsets by sample lookup.

        An offset ``t`` returns ``data[floor(t + pre_samples)]``, so sample
        ``i`` covers the half-open interval
        ``[i - pre_samples, i - pre_samples + 1)``. Offsets outside the
        kernel's support yield 0.

        Args:
            t: Offsets in samples relative to the event (t=0 is the event).

        Returns:
            float64 array with the same shape as ``t``.
        """
        t = np.asarray(t)
        # Use floor rather than a bare astype(int): the cast truncates toward
        # zero, which would map offsets just before the support (e.g. t=-0.5
        # on a causal kernel) onto data[0] instead of treating them as outside.
        indices = np.floor(t + self._pre_samples).astype(np.intp)
        valid = (indices >= 0) & (indices < len(self._data))
        result = np.zeros(t.shape, dtype=self._data.dtype)
        result[valid] = self._data[indices[valid]]
        return result
class FunctionalKernel(Kernel):
    """
    Kernel from a function (e.g., exponential decay, Gaussian).

    The function is called as ``func(t, sigma_samples)``, where ``t`` is a
    float64 array of offsets in samples and ``sigma_samples`` is the time
    constant converted to samples (``sigma * fs``).

    The support length is derived from ``L = max(1, int(truncate_at * sigma * fs))``:
    causal kernels cover offsets ``[0, L)`` (``L`` samples), acausal kernels
    cover ``[-L, L]`` (``2 * L + 1`` samples).

    Args:
        func: Kernel function f(t, sigma) -> values.
        sigma: Time constant in seconds. Must be positive.
        fs: Sample rate in Hz (for converting sigma to samples). Must be positive.
        truncate_at: Truncate kernel at this many time constants. Default 5.0.
        causal: If True (default), the support starts at t=0
            (``pre_samples == 0``). This only sets the support; ``func`` itself
            is responsible for returning 0 for t < 0. If False, the support is
            symmetric around t=0.

    Raises:
        ValueError: If ``sigma`` or ``fs`` is not a positive number.

    Example:
        >>> kernel = FunctionalKernel(
        ...     func=lambda t, s: (t >= 0) * np.exp(-t / s) / s,
        ...     sigma=0.010,  # 10ms
        ...     fs=30000,
        ... )
    """

    def __init__(
        self,
        func: Callable[[npt.NDArray, float], npt.NDArray],
        sigma: float,
        fs: float,
        truncate_at: float = 5.0,
        causal: bool = True,
    ):
        # `not x > 0` also rejects NaN; a zero/negative time constant would
        # otherwise lead to division by zero or sign-flipped kernels in func.
        for name, value in (("sigma", sigma), ("fs", fs)):
            if not value > 0:
                raise ValueError(f"{name} must be positive, got {value!r}")
        self._func = func
        self._sigma = sigma
        self._fs = fs
        self._sigma_samples = sigma * fs
        self._causal = causal
        # Samples spanned by `truncate_at` time constants; at least 1 so the
        # kernel never has an empty support.
        span = max(1, int(truncate_at * self._sigma_samples))
        if causal:
            self._pre = 0
            self._length = span
        else:
            # Symmetric kernel: `span` samples on each side plus t=0.
            self._pre = span
            self._length = 2 * span + 1

    @property
    def length(self) -> int:
        """Total kernel length in samples."""
        return self._length

    @property
    def pre_samples(self) -> int:
        """Number of samples before t=0 (0 for causal kernels)."""
        return self._pre

    @property
    def sigma(self) -> float:
        """Time constant in seconds."""
        return self._sigma

    @property
    def sigma_samples(self) -> float:
        """Time constant in samples (``sigma * fs``)."""
        return self._sigma_samples

    def evaluate(self, t: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]:
        """Evaluate the kernel function at time offsets (converted to float64)."""
        return self._func(np.asarray(t, dtype=np.float64), self._sigma_samples)
[docs]
class MultiKernel:
"""
Dictionary of kernels indexed by event value.
Useful when different event types (e.g., waveform IDs 1, 2, 3)
should produce different kernel shapes.
Args:
kernels: Dictionary mapping event values to Kernel objects.
default_key: Key to use for unknown event values. Default is first key.
"""
[docs]
def __init__(self, kernels: dict[int, Kernel], default_key: int | None = None):
if not kernels:
raise ValueError("kernels dict cannot be empty")
self._kernels = kernels
self._default_key = default_key if default_key is not None else next(iter(kernels))
[docs]
def get(self, value: int) -> Kernel:
"""Get kernel for event value, falling back to default."""
return self._kernels.get(value, self._kernels[self._default_key])
def __getitem__(self, value: int) -> Kernel:
"""Get kernel for event value (raises KeyError if not found)."""
return self._kernels[value]
def __contains__(self, value: int) -> bool:
return value in self._kernels
@property
def max_length(self) -> int:
"""Maximum kernel length across all kernels."""
return max(k.length for k in self._kernels.values())
@property
def max_pre_samples(self) -> int:
"""Maximum pre_samples across all kernels."""
return max(k.pre_samples for k in self._kernels.values())
@property
def max_post_samples(self) -> int:
"""Maximum post_samples across all kernels."""
return max(k.post_samples for k in self._kernels.values())
@property
def keys(self) -> list[int]:
"""Available kernel keys."""
return list(self._kernels.keys())
# =============================================================================
# Common kernel functions
# =============================================================================
def exponential_kernel(t: npt.NDArray, sigma: float) -> npt.NDArray:
    """
    Causal exponential decay kernel: k(t) = exp(-t/sigma) / sigma for t >= 0.

    Zero for t < 0. Normalized so that the integral from 0 to inf equals 1.

    Args:
        t: Time offsets, in the same units as ``sigma``.
        sigma: Time constant.

    Returns:
        float64 array with the same shape as ``t``.
    """
    # Work in float64: negating unsigned-integer offsets would otherwise wrap
    # around (e.g. -uint8(1) == 255) and produce huge exponentials.
    t = np.asarray(t, dtype=np.float64)
    result = np.zeros_like(t)
    valid = t >= 0
    result[valid] = np.exp(-t[valid] / sigma) / sigma
    return result
def alpha_kernel(t: npt.NDArray, sigma: float) -> npt.NDArray:
    """
    Alpha function kernel: k(t) = (t/sigma^2) * exp(-t/sigma) for t >= 0.

    Zero for t < 0. Peaks at t = sigma. Normalized so that the integral from
    0 to inf equals 1.

    Args:
        t: Time offsets, in the same units as ``sigma``.
        sigma: Time constant.

    Returns:
        float64 array with the same shape as ``t``.
    """
    # Work in float64: negating unsigned-integer offsets would otherwise wrap
    # around and blow up the exponential.
    t = np.asarray(t, dtype=np.float64)
    result = np.zeros_like(t)
    valid = t >= 0
    t_valid = t[valid]
    result[valid] = (t_valid / sigma**2) * np.exp(-t_valid / sigma)
    return result
def gaussian_kernel(t: npt.NDArray, sigma: float) -> npt.NDArray:
    """
    Gaussian kernel: k(t) = exp(-t^2 / (2*sigma^2)) / (sigma * sqrt(2*pi)).

    Symmetric (acausal). Normalized so that the integral equals 1.

    Args:
        t: Time offsets, in the same units as ``sigma``.
        sigma: Standard deviation.

    Returns:
        float64 array with the same shape as ``t``.
    """
    # Work in float64: for small or unsigned integer dtypes, t**2 can overflow
    # and its negation can wrap around, giving exp() of a large positive number.
    t = np.asarray(t, dtype=np.float64)
    return np.exp(-(t**2) / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi))
def boxcar_kernel(t: npt.NDArray, sigma: float) -> npt.NDArray:
    """
    Symmetric rectangular window of total width 2*sigma.

    k(t) = 1/(2*sigma) when abs(t) < sigma, else 0, so the continuous
    integral is 1. Acausal: nonzero on both sides of t=0.

    Args:
        t: Time offsets, in the same units as ``sigma``.
        sigma: Half-width of the window.

    Returns:
        float64 array with the same shape as ``t``.
    """
    offsets = np.asarray(t)
    out = np.zeros_like(offsets, dtype=np.float64)
    inside = np.abs(offsets) < sigma
    out[inside] = 0.5 / sigma
    return out
def causal_boxcar_kernel(t: npt.NDArray, sigma: float) -> npt.NDArray:
    """
    One-sided rectangular window starting at t=0.

    k(t) = 1/sigma when 0 <= t < sigma, else 0, so the integral is 1.

    Args:
        t: Time offsets, in the same units as ``sigma``.
        sigma: Width of the window.

    Returns:
        float64 array with the same shape as ``t``.
    """
    offsets = np.asarray(t)
    out = np.zeros_like(offsets, dtype=np.float64)
    in_window = np.logical_and(offsets >= 0, offsets < sigma)
    out[in_window] = 1.0 / sigma
    return out