"""CerePlex impedance measurement pipeline.
The CerePlex headstage injects a 1 kHz, 1 nA sine wave for 100 ms per channel,
cycling sequentially through all channels. Channels not under test read exactly
zero (filters must be disabled). Impedance is extracted via single-bin DFT at
1 kHz: Z(kOhm) = V_peak(uV) / I_peak(nA).
Multiple headstages are tracked independently — each may be at a different point
in its impedance sweep.
"""
import dataclasses
import logging
import typing
import ezmsg.core as ez
import numpy as np
import scipy.signal as ss
from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
from ezmsg.util.messages.axisarray import AxisArray, replace
# Module-scoped logger; named after the module so it inherits package handlers.
logger: logging.Logger = logging.getLogger(__name__)
class CerePlexImpedanceSettings(ez.Settings):
    """Configuration for the CerePlex impedance processor."""

    headstage_channel_offsets: tuple[int, ...] = (0,)
    """First channel index of each CerePlex headstage. A headstage spans from
    its own offset up to the next offset (or up to ``n_ch`` for the last one).
    Example: two 128-channel headstages → ``(0, 128)``."""

    collect_duration_s: float = 0.1
    """Longest burst, in seconds, buffered per channel (CerePlex bursts are 100 ms)."""

    fft_duration_s: float = 0.09227
    """Seconds of data passed to the FFT, taken from the tail of each burst;
    the samples before it act as settling time."""

    freq_lo: float = 960.0
    """Lower edge (Hz) of the band searched for the test-tone peak."""

    freq_hi: float = 1050.0
    """Upper edge (Hz) of the band searched for the test-tone peak."""

    test_current_nA: float = 1.0
    """Injected test-current amplitude (nA), passed to ``extract_impedance``.
    NOTE(review): confirm whether this is a peak or a peak-to-peak value; the
    module docstring's formula Z = V_peak / I_peak implies a peak amplitude."""
class _HeadstageTracker:
"""Per-headstage sequential channel tracker."""
__slots__ = ("ch_start", "ch_end", "tracking_ch", "buffer", "buf_len")
def __init__(self, ch_start: int, ch_end: int, buffer: np.ndarray):
self.ch_start = ch_start
self.ch_end = ch_end # exclusive
self.tracking_ch = -1 # absolute index; -1 = scanning
self.buffer = buffer
self.buf_len = 0
@processor_state
class CerePlexImpedanceState:
    """Mutable state of the CerePlex impedance processor.

    Filled in by ``CerePlexImpedanceProcessor._reset_state``.
    """

    # Per-headstage trackers (list[_HeadstageTracker]); None until reset.
    trackers: list | None = None
    # Burst-buffer capacity in samples (collect_duration_s * fs, truncated).
    max_buffer_samples: int = 0
    # Number of samples taken from the end of each burst for the FFT.
    fft_samples: int = 0
    # Sampling rate in Hz (reciprocal of the time-axis gain).
    fs: float = 0.0
    # Latest impedance per channel in kOhm, shape (n_ch,); NaN = unmeasured.
    impedance: np.ndarray | None = None
    # Most recently seen "ch" axis object.
    ch_axis: typing.Any = None
def _scan_for_active(data: np.ndarray, pos: int, hs: _HeadstageTracker) -> None:
"""Find the currently-active channel and tag the next one for measurement."""
remaining = data[pos:, hs.ch_start : hs.ch_end]
has_data = np.any(remaining != 0, axis=0)
candidates = np.flatnonzero(has_data)
if len(candidates) == 0:
return
# Pick the channel whose first non-zero sample is earliest
first_nz = np.argmax(remaining[:, candidates] != 0, axis=0)
active_local = int(candidates[first_nz.argmin()])
do_next = remaining[0, active_local] != 0 # True if we don't know when active_local started.
n_hs = hs.ch_end - hs.ch_start
hs.tracking_ch = hs.ch_start + (active_local + int(do_next)) % n_hs
hs.buf_len = 0
class CerePlexImpedanceProcessor(
    BaseStatefulTransformer[
        CerePlexImpedanceSettings,
        AxisArray,
        AxisArray | None,
        CerePlexImpedanceState,
    ]
):
    """Stateful transformer that extracts per-channel impedance from a CerePlex sweep.

    Expects a stream of ``AxisArray`` messages with dims ``["time", "ch"]``
    where the data is in **microvolts**. When using :class:`CereLinkSource`,
    set ``microvolts=True``; raw ADC counts will produce incorrect results.

    The processor tracks one or more headstages independently (configured via
    :attr:`CerePlexImpedanceSettings.headstage_channel_offsets`). Each
    headstage's impedance sweep cycles sequentially through its channels:
    exactly one channel is non-zero at a time while the others read zero.
    Headstage ranges are clipped to ``[0, n_ch)``; offsets at or beyond the
    stream's channel count are ignored.

    On each impedance update the processor emits an ``AxisArray`` whose data
    is a ``(1, n_ch)`` array of impedance values in kOhm (``NaN`` for channels
    not yet measured).
    """

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the message properties whose change requires a state reset.

        These are the channel count and the time-axis gain (sample period).
        """
        ch_idx = message.dims.index("ch")
        n_ch = message.data.shape[ch_idx]
        time_axis = message.axes.get("time")
        # Sample period (1 / fs); 0 when the time axis exposes no ``gain``.
        gain = time_axis.gain if hasattr(time_axis, "gain") else 0
        return hash((n_ch, gain))

    def _reset_state(self, message: AxisArray) -> None:
        """Derive sample counts from fs, build trackers and clear the impedance array."""
        s = self.state
        settings = self.settings
        s.fs = 1.0 / message.axes["time"].gain
        ch_idx = message.dims.index("ch")
        n_ch = message.data.shape[ch_idx]
        s.max_buffer_samples = int(settings.collect_duration_s * s.fs)
        s.fft_samples = int(settings.fft_duration_s * s.fs)
        if s.fft_samples > s.max_buffer_samples:
            # A burst is only analysed once buf_len >= fft_samples, but buf_len
            # never exceeds max_buffer_samples, so nothing could ever be stored.
            logger.warning(
                "CerePlexImpedance: fft_duration_s (%g s -> %d samples) exceeds "
                "collect_duration_s (%g s -> %d samples); no impedance can be measured.",
                settings.fft_duration_s,
                s.fft_samples,
                settings.collect_duration_s,
                s.max_buffer_samples,
            )
        self._build_trackers(n_ch)
        s.impedance = np.full(n_ch, np.nan, dtype=np.float64)
        s.ch_axis = message.axes.get("ch")

    def _build_trackers(self, n_ch: int) -> None:
        """Build per-headstage trackers from the current settings.

        Split out from ``_reset_state`` so a settings update that only changes
        ``headstage_channel_offsets`` can rebuild the tracker layout without
        clearing the accumulated ``state.impedance`` array. Any in-flight
        per-tracker burst buffer is discarded; new bursts buffer fresh.

        Each headstage range is clamped to ``[0, n_ch]``. Ranges that end up
        empty (duplicate offsets, or offsets at/after ``n_ch``) get no tracker.
        This prevents a range extending past the last channel from producing
        an out-of-bounds ``tracking_ch``/``next_ch`` index.
        """
        s = self.state
        offsets = sorted(self.settings.headstage_channel_offsets)
        s.trackers = []
        for start, end in zip(offsets, [*offsets[1:], n_ch]):
            start = min(max(start, 0), n_ch)
            end = min(max(end, start), n_ch)
            if end <= start:
                continue
            buf = np.zeros(s.max_buffer_samples, dtype=np.float64)
            s.trackers.append(_HeadstageTracker(start, end, buf))

    # --- Per-headstage helpers ---

    def _complete_channel(self, hs: _HeadstageTracker) -> bool:
        """FFT the buffered burst, store its impedance, and advance to the next channel.

        The stored impedance is only updated if the burst contained at least
        ``fft_samples`` samples and ``extract_impedance`` returned a value.
        Truncated bursts (e.g. from a file-loop boundary or impedance mode
        being disabled mid-sweep) are discarded so they don't overwrite a
        previous good measurement. Tracking always moves on to the next
        channel of the headstage (wrapping) and the buffer is cleared.

        Returns True if ``state.impedance`` was updated.
        """
        s = self.state
        settings = self.settings
        updated = False
        if hs.buf_len >= s.fft_samples:
            imp = extract_impedance(
                hs.buffer[: hs.buf_len],
                s.fft_samples,
                s.fs,
                settings.freq_lo,
                settings.freq_hi,
                settings.test_current_nA,
            )
            if imp is not None:
                s.impedance[hs.tracking_ch] = imp
                updated = True
        n_hs = hs.ch_end - hs.ch_start
        local = hs.tracking_ch - hs.ch_start
        hs.tracking_ch = hs.ch_start + (local + 1) % n_hs
        hs.buf_len = 0
        return updated

    def _buffer_channel(
        self,
        data: np.ndarray,
        pos: int,
        hs: _HeadstageTracker,
    ) -> tuple[int, bool, bool]:
        """Buffer samples from the tracked channel's column, starting at ``pos``.

        Termination: tracked channel is zero AND next channel is non-zero
        (the headstage has handed off to the next channel).

        Returns (samples_consumed, channel_done, impedance_updated).
        ``channel_done`` False means the chunk ended (mid-burst or in a gap)
        and buffering resumes on the next call.
        """
        s = self.state
        col = data[pos:, hs.tracking_ch]
        n = len(col)
        if n == 0:
            return 0, False, False

        # Next channel in the headstage sequence (wraps; for a single-channel
        # headstage this is the tracked channel itself).
        n_hs = hs.ch_end - hs.ch_start
        local = hs.tracking_ch - hs.ch_start
        next_ch = hs.ch_start + (local + 1) % n_hs
        next_col = data[pos:, next_ch]

        # Skip leading zeros if not yet buffering.
        start = 0
        if hs.buf_len == 0:
            nz = col != 0
            first_nz = int(np.argmax(nz))
            if not nz[first_nz]:
                return n, False, False  # all zero — channel not active yet
            start = first_nz
            if n_hs > 1 and next_col[start] != 0:
                # Tracked and next channel live at once: not a clean sweep.
                hs.tracking_ch = -1
                hs.buf_len = 0
                return n, True, False

        tail = col[start:]
        next_tail = next_col[start:]

        # End of the tracked channel's non-zero run (len(tail) if it runs to
        # the end of the chunk).
        zeros = tail == 0.0
        first_zero = len(tail)
        if np.any(zeros):
            first_zero = int(np.argmax(zeros))

        # Exclusivity check: if the next channel in sequence is non-zero
        # while we're buffering, impedance mode was toggled off (or tracking
        # is out of sync). Only meaningful with >1 channel per headstage.
        if n_hs > 1 and first_zero > 0 and np.any(next_tail[:first_zero] != 0):
            hs.tracking_ch = -1
            hs.buf_len = 0
            return start + first_zero, True, False

        # Buffer the non-zero portion only, up to the buffer's capacity.
        space = s.max_buffer_samples - hs.buf_len
        n_copy = min(first_zero, space)
        if n_copy > 0:
            hs.buffer[hs.buf_len : hs.buf_len + n_copy] = tail[:n_copy]
            hs.buf_len += n_copy

        # Buffer full → complete regardless.
        if hs.buf_len >= s.max_buffer_samples:
            return start + first_zero, True, self._complete_channel(hs)

        # Tracked channel went to zero — determine what happened.
        if first_zero < len(tail):
            # 1. Expected sequential handoff: next channel becomes non-zero.
            remainder_next = next_tail[first_zero:]
            if np.any(remainder_next != 0):
                term_pos = first_zero + int(np.argmax(remainder_next != 0))
                consumed = start + term_pos
                if hs.buf_len >= s.fft_samples:
                    return consumed, True, self._complete_channel(hs)
                hs.tracking_ch = -1
                hs.buf_len = 0
                return consumed, True, False
            # 2. Next channel not active, but some headstage channel is
            #    (sequence break: file wrap, channel skip, etc.).
            hs_remainder = data[pos + start + first_zero :, hs.ch_start : hs.ch_end]
            if np.any(hs_remainder != 0):
                consumed = start + first_zero
                updated = False
                if hs.buf_len >= s.fft_samples:
                    updated = self._complete_channel(hs)
                hs.tracking_ch = -1  # force re-scan to re-lock sequence
                hs.buf_len = 0
                return consumed, True, updated
            # 3. No channels active — gap: fall through and wait.

        # Burst runs to the end of the chunk, or a gap follows it: consume the
        # rest and keep the partial buffer for the next call.
        return start + len(tail), False, False

    # --- Per-headstage processing ---

    def _process_headstage(
        self,
        data: np.ndarray,
        n_time: int,
        hs: _HeadstageTracker,
    ) -> bool:
        """Advance one headstage's tracker through a chunk.

        Returns True if any of its channels got a new impedance value.
        """
        any_updated = False
        pos = 0
        while pos < n_time:
            if hs.tracking_ch == -1:
                _scan_for_active(data, pos, hs)
                if hs.tracking_ch == -1:
                    break  # no active channel in the rest of this chunk
            consumed, done, updated = self._buffer_channel(data, pos, hs)
            any_updated |= updated
            pos += consumed
            if not done:
                break  # chunk ended mid-burst, continue next call
        return any_updated

    # --- Main entry point ---

    def _process(self, message: AxisArray) -> AxisArray | None:
        """Feed a ``(time, ch)`` chunk to every headstage tracker.

        Returns the full impedance vector as a ``(1, n_ch)`` message when any
        channel was updated or a new ``ch`` axis object arrived; otherwise None.
        """
        s = self.state
        data = message.data
        n_time = data.shape[0]
        if n_time == 0:
            return None

        # Identity comparison: a new "ch" axis object triggers an emission
        # even if no impedance changed in this chunk.
        incoming_ch = message.axes.get("ch")
        ch_axis_changed = incoming_ch is not None and incoming_ch is not s.ch_axis
        if ch_axis_changed:
            s.ch_axis = incoming_ch

        any_updated = False
        for hs in s.trackers:
            any_updated |= self._process_headstage(data, n_time, hs)

        if not (any_updated or ch_axis_changed):
            return None

        # Stamp the output at the time of the chunk's last sample.
        time_ax = message.axes["time"]
        time_ix = message.get_axis_idx("time")
        new_time_ax = replace(time_ax, offset=time_ax.value(message.data.shape[time_ix] - 1))
        return replace(
            message,
            data=s.impedance.copy()[None, :],
            axes={**message.axes, "time": new_time_ax},
        )
class CerePlexImpedance(
    BaseTransformerUnit[
        CerePlexImpedanceSettings,
        AxisArray,
        AxisArray,
        CerePlexImpedanceProcessor,
    ]
):
    """ezmsg unit wrapping :class:`CerePlexImpedanceProcessor`."""

    SETTINGS = CerePlexImpedanceSettings

    @ez.subscriber(BaseTransformerUnit.INPUT_SETTINGS)
    async def on_settings(self, msg: CerePlexImpedanceSettings) -> None:
        """Apply new settings.

        If only ``headstage_channel_offsets`` changed, rebuild the trackers in
        place — the accumulated impedance values for previously-measured
        channels remain valid. Any other change recreates the processor as
        usual (state is reset on the next message).
        """
        old = self.SETTINGS
        only_offsets = False
        # Before settings are applied, SETTINGS may not be a settings instance.
        if isinstance(old, CerePlexImpedanceSettings):
            new_offsets = tuple(msg.headstage_channel_offsets)
            offsets_changed = tuple(old.headstage_channel_offsets) != new_offsets
            # For the "everything else matches" check, normalise the offsets
            # field on both sides before comparing the full dataclasses.
            old_norm = dataclasses.replace(old, headstage_channel_offsets=new_offsets)
            msg_norm = dataclasses.replace(msg, headstage_channel_offsets=new_offsets)
            only_offsets = offsets_changed and old_norm == msg_norm

        self.apply_settings(msg)

        proc = getattr(self, "processor", None)
        state = getattr(proc, "state", None) if proc is not None else None
        impedance = getattr(state, "impedance", None) if state is not None else None
        if only_offsets and impedance is not None:
            proc.settings = msg
            proc._build_trackers(impedance.shape[0])
            # Count real measurements; unmeasured channels are NaN.
            n_measured = int(np.count_nonzero(~np.isnan(impedance)))
            logger.info(
                "CerePlexImpedance.on_settings: rebuilt trackers (offsets=%s); "
                "preserved %d measured impedance values across %d channels.",
                msg.headstage_channel_offsets,
                n_measured,
                impedance.shape[0],
            )
        else:
            self.create_processor()