# Source code for ezmsg.neo.source

import asyncio
import os
import time
import typing
from collections import deque
from pathlib import Path

import ezmsg.core as ez
import neo.rawio.baserawio
import numpy as np
import sparse
from ezmsg.util.generator import GenState
from ezmsg.util.messages.axisarray import AxisArray, replace


class NeoIteratorSettings(ez.Settings):
    """Settings for :obj:`NeoIterator`.

    Attributes:
        filepath: Recording to play back. Supported extensions are ``.vhdr``
            (BrainVision) and ``.ns*`` / ``.nev`` (Blackrock).
        chunk_dur: Length of each playback chunk, in seconds of file time.
        self_terminating: When playback is exhausted, the unit raises
            ``NormalTermination`` if True; otherwise it publishes a flag on its
            termination output.
        t_offset: Value added to every emitted timestamp. ``None`` means
            "use the wall-clock time at which playback (re)starts".
    """

    filepath: os.PathLike
    chunk_dur: float = 0.05
    self_terminating: bool = True
    t_offset: typing.Optional[float] = None
class NeoIterator:
    """Iterate over a neo-readable recording in fixed-duration chunks.

    Supported inputs are BrainVision (``.vhdr``) and Blackrock (``.ns*`` /
    ``.nev``) files that contain a single block with a single segment.

    Each chunk spans ``settings.chunk_dur`` seconds of file time and enqueues:

    * one ``("time", "ch")`` :obj:`AxisArray` per analog signal stream that has
      samples in the chunk (``key`` is the neo stream name);
    * one ``"events"`` :obj:`AxisArray` per event channel that has events in
      the chunk;
    * one ``"spike"`` :obj:`AxisArray` holding a sparse ``("unit", "time")``
      raster, if the file has spike channels.

    Every time axis is expressed as ``t_offset + file_time``, where
    ``file_time`` is in seconds and ``t_offset`` is ``settings.t_offset`` or,
    when that is ``None``, the wall-clock time at which playback (re)started.

    Calling ``iter()`` rewinds playback to the beginning of the file.
    """

    def __init__(self, settings: NeoIteratorSettings):
        self._settings = settings
        self._reader: typing.Optional[neo.rawio.baserawio.BaseRawIO] = None
        self._playback_state: typing.Optional[dict] = None
        self._reset()

    def _reset(self) -> None:
        """Rewind playback to the start of the file and rebuild stream metadata."""
        t_offset = self._settings.t_offset
        self._playback_state = {
            "t_offset": time.time() if t_offset is None else t_offset,
            # Lowered to the earliest analog stream start by _preload.
            "t_start": np.inf,
            "chunk_ix": 0,
            "msg_queue": deque(),
            "streams": {},
        }
        self._preload()

    def _open_reader(self):
        """Open and parse the file on first use; later resets reuse the parsed reader.

        Raises:
            FileNotFoundError: ``settings.filepath`` does not exist.
            ValueError: the file extension is not supported.
            NotImplementedError: the file has more than one block or segment.
        """
        if self._reader is not None:
            return self._reader
        fpath = Path(self._settings.filepath)
        if not fpath.exists():
            raise FileNotFoundError(f"File not found: {fpath}")
        if fpath.suffix == ".vhdr":
            from neo.rawio import BrainVisionRawIO as RawIO
        elif fpath.suffix.startswith(".ns") or fpath.suffix == ".nev":
            from neo.rawio import BlackrockRawIO as RawIO
        else:
            raise ValueError(f"Unsupported file type: {fpath.suffix}")
        reader = RawIO(filename=str(fpath))
        reader.parse_header()
        if reader.block_count() > 1:
            raise NotImplementedError("Only single-block files are supported.")
        if reader.segment_count(0) > 1:
            raise NotImplementedError("Only single-segment files are supported.")
        # Cache only a reader that passed validation.
        self._reader = reader
        return reader

    def _preload(self) -> None:
        """Fill the playback state with per-stream templates and the chunk count."""
        reader = self._open_reader()
        state = self._playback_state
        nb_sig_streams = reader.signal_streams_count()
        t_stop = -np.inf
        for strm_ix in range(nb_sig_streams):
            t_stop = max(t_stop, self._register_signal_stream(strm_ix))
        if nb_sig_streams == 0:
            # Without analog streams the bounds above would stay at +/-inf and the
            # chunk count would overflow; use the segment bounds instead.
            state["t_start"] = reader.segment_t_start(0, 0)
            t_stop = reader.segment_t_stop(0, 0)
        self._register_event_stream()
        self._register_spike_stream()
        t_elapsed = t_stop - state["t_start"]
        state["n_chunks"] = max(0, int(np.ceil(t_elapsed / self._settings.chunk_dur)))

    def _register_signal_stream(self, strm_ix: int) -> float:
        """Add metadata for analog signal stream ``strm_ix``; return its end time in s."""
        reader = self._reader
        state = self._playback_state
        t_start = reader.get_signal_t_start(0, 0, strm_ix)
        state["t_start"] = min(state["t_start"], t_start)
        nb_chans = reader.signal_channels_count(strm_ix)
        fs = reader.get_signal_sampling_rate(strm_ix)
        nb_samps = reader.get_signal_size(0, 0, strm_ix)
        stream_info = reader.header["signal_streams"][strm_ix]
        # header["signal_channels"] lists the channels of *every* stream; keep only
        # this stream's rows so the labels line up with the data columns.
        all_chans = reader.header["signal_channels"]
        chan_names = all_chans["name"][all_chans["stream_id"] == stream_info["id"]]
        key = stream_info["name"]
        template = AxisArray(
            data=np.zeros((0, nb_chans), dtype=float),
            dims=["time", "ch"],
            axes={
                "time": AxisArray.TimeAxis(fs=fs, offset=0.0),
                "ch": AxisArray.CoordinateAxis(data=chan_names, dims=["ch"], unit="label"),
            },
            key=key,
        )
        state["streams"][key] = {
            "idx": strm_ix,
            "type": "analogsignal",
            "t_start": t_start,
            "fs": fs,
            "n_samps": nb_samps,
            "template": template,
            "prev_samp": 0,
        }
        return t_start + nb_samps / fs

    def _register_event_stream(self) -> None:
        """Add a single ``"events"`` stream covering all event channels, if any."""
        nb_event_channel = self._reader.event_channels_count()
        if nb_event_channel == 0:
            return
        # TODO: Event should probably use SampleTriggerMessage
        self._playback_state["streams"]["events"] = {
            "type": "event",
            "nchan": nb_event_channel,
            "template": AxisArray(
                data=np.array([""]),
                dims=["time"],
                axes={"time": AxisArray.CoordinateAxis(data=np.array([0]), dims=["time"], unit="s")},
                key="events",
            ),
        }

    def _register_spike_stream(self) -> None:
        """Add a single ``"spike"`` stream covering all spike channels, if any."""
        nb_unit = self._reader.spike_channels_count()
        if nb_unit == 0:
            return
        spk_chans = self._reader.header["spike_channels"]
        field_names = spk_chans.dtype.names
        if "wf_sampling_rate" in field_names:
            spike_fs = float(spk_chans["wf_sampling_rate"][0])
        else:
            spike_fs = 30_000.0
        if not (np.isfinite(spike_fs) and spike_fs > 0):
            # In case a reader reports an unset rate (0 / NaN), which would break
            # the time axis below.
            spike_fs = 30_000.0
        if "name" in field_names:
            spk_ch_labels = spk_chans["name"]
        else:
            spk_ch_labels = np.arange(1, 1 + nb_unit).astype(str)
        self._playback_state["streams"]["spike"] = {
            "type": "spiketrain",
            "nchan": nb_unit,
            "template": AxisArray(
                data=sparse.SparseArray((nb_unit, 0)),
                dims=["unit", "time"],
                axes={
                    "unit": AxisArray.CoordinateAxis(data=spk_ch_labels, dims=["unit"], unit="unit"),
                    "time": AxisArray.TimeAxis(fs=spike_fs, offset=0.0),
                },
                key="spike",
            ),
        }

    def __iter__(self):
        self._reset()
        return self

    def _chunk_step(self) -> None:
        """Read the next chunk of every stream and append the messages to the queue."""
        state = self._playback_state
        # Start and stop of this chunk in file time (s), anchored at the global t_start.
        t_range = (np.arange(2) + state["chunk_ix"]) * self._settings.chunk_dur
        t_range += state["t_start"]
        for strm in state["streams"].values():
            if strm["type"] == "analogsignal":
                self._step_analogsignal(strm, t_range)
            elif strm["type"] == "event":
                self._step_event(strm, t_range)
            elif strm["type"] == "spiketrain":
                self._step_spiketrain(strm, t_range)
        state["chunk_ix"] += 1

    def _step_analogsignal(self, strm: dict, t_range: np.ndarray) -> None:
        """Enqueue the samples of one analog stream from its last read up to ``t_range[1]``."""
        state = self._playback_state
        fs = strm["fs"]
        prev_samp = strm["prev_samp"]
        # Clamp to the stream length so the final chunk never reads past the end.
        next_samp = min(strm["n_samps"], max(0, int((t_range[1] - strm["t_start"]) * fs)))
        if next_samp <= prev_samp:
            return
        dat = self._reader.get_analogsignal_chunk(
            block_index=0,
            seg_index=0,
            stream_index=strm["idx"],
            i_start=prev_samp,
            i_stop=next_samp,
        )
        if dat.size:
            # stream_index is required by neo when the file has several streams.
            dat = self._reader.rescale_signal_raw_to_float(dat, dtype=float, stream_index=strm["idx"])
            template = strm["template"]
            msg = replace(
                template,
                data=dat,
                axes={
                    **template.axes,
                    "time": replace(
                        template.axes["time"],
                        # Absolute file time of sample `prev_samp`, plus t_offset,
                        # so analog data shares a timebase with events and spikes.
                        offset=state["t_offset"] + strm["t_start"] + prev_samp / fs,
                    ),
                },
            )
            state["msg_queue"].append(msg)
        strm["prev_samp"] = next_samp

    def _step_event(self, strm: dict, t_range: np.ndarray) -> None:
        """Enqueue one message per event channel that has events within ``t_range``."""
        # TODO: Event should probably use SampleTriggerMessage
        state = self._playback_state
        template = strm["template"]
        for ev_ch_ix in range(strm["nchan"]):
            ev_timestamps, _ev_durations, ev_labels = self._reader.get_event_timestamps(
                block_index=0,
                seg_index=0,
                event_channel_index=ev_ch_ix,
                t_start=t_range[0],
                t_stop=t_range[1],
            )
            if len(ev_timestamps) == 0:
                continue
            # Rescale with this channel's parameters, not channel 0's.
            ev_times = self._reader.rescale_event_timestamp(
                ev_timestamps, dtype=float, event_channel_index=ev_ch_ix
            )
            msg = replace(
                template,
                data=ev_labels,
                axes={
                    **template.axes,
                    "time": replace(template.axes["time"], data=ev_times + state["t_offset"]),
                },
            )
            state["msg_queue"].append(msg)

    def _step_spiketrain(self, strm: dict, t_range: np.ndarray) -> None:
        """Enqueue a sparse ``(unit, time)`` raster of the spikes within ``t_range``."""
        state = self._playback_state
        template = strm["template"]
        samp_step = template.axes["time"].gain
        # round() rather than int(): the ratio is often an integer that floating
        # point lands just below.
        n_times = max(1, int(round((t_range[1] - t_range[0]) / samp_step)))
        chan_parts = []
        samp_parts = []
        for spk_ch_ix in range(strm["nchan"]):
            spike_ts = self._reader.get_spike_timestamps(
                block_index=0,
                seg_index=0,
                spike_channel_index=spk_ch_ix,
                t_start=t_range[0],
                t_stop=t_range[1],
            )
            spike_times = self._reader.rescale_spike_timestamp(spike_ts, dtype="float64")
            # Nearest sample on this chunk's grid, clipped so a spike in the last
            # sample period (or at the chunk stop) cannot index past the raster.
            samp_ix = np.rint((spike_times - t_range[0]) / samp_step).astype(int)
            samp_parts.append(np.clip(samp_ix, 0, n_times - 1))
            chan_parts.append(np.full(len(spike_times), spk_ch_ix, dtype=int))
        coords = np.vstack((np.concatenate(chan_parts), np.concatenate(samp_parts)))
        result = sparse.COO(coords, data=1, shape=(strm["nchan"], n_times))
        msg = replace(
            template,
            data=result,
            axes={
                **template.axes,
                # Include t_offset so spikes share a timebase with analog and event data.
                "time": replace(template.axes["time"], offset=state["t_offset"] + t_range[0]),
            },
        )
        state["msg_queue"].append(msg)

    def __next__(self) -> AxisArray:
        state = self._playback_state
        # A chunk may produce no messages (e.g. no events in that window and no
        # analog samples yet), so keep stepping until something is queued.
        while not state["msg_queue"] and state["chunk_ix"] < state["n_chunks"]:
            self._chunk_step()
        if not state["msg_queue"]:
            # TODO Close file
            raise StopIteration
        return state["msg_queue"].popleft()
class NeoIteratorUnit(ez.Unit):
    """ezmsg Unit that plays back a recording through :obj:`NeoIterator`.

    Every message produced by the iterator is published on ``OUTPUT_SIGNAL``.
    Once the file is exhausted, the unit raises ``ez.NormalTermination`` when
    ``SETTINGS.self_terminating`` is True. Otherwise it publishes ``ez.Flag`` on
    ``OUTPUT_TERM``.
    """

    STATE = GenState
    SETTINGS = NeoIteratorSettings

    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
    OUTPUT_TERM = ez.OutputStream(typing.Any)

    def initialize(self) -> None:
        """Build the file iterator before publishing starts."""
        self.construct_generator()

    def construct_generator(self):
        """Store a fresh :obj:`NeoIterator` built from the current settings in STATE."""
        iterator = NeoIterator(settings=self.SETTINGS)
        self.STATE.gen = iterator

    @ez.publisher(OUTPUT_SIGNAL)
    async def pub_chunk(self) -> typing.AsyncGenerator:
        """Publish every message from the iterator, then handle end-of-file."""
        for message in self.STATE.gen:
            # TODO: Direct msg to OUTPUT_TRIGGER if type is SampleTriggerMessage
            yield self.OUTPUT_SIGNAL, message
            # Hand control back to the event loop between messages.
            await asyncio.sleep(0)
        settings = self.SETTINGS
        ez.logger.debug(f"File ({settings.filepath}) exhausted.")
        if not settings.self_terminating:
            yield self.OUTPUT_TERM, ez.Flag
        else:
            raise ez.NormalTermination