Source code for ezmsg.blackrock.nsp

import asyncio
from ctypes import Structure
import functools
import typing

import numpy as np
from pycbsdk import cbsdk, cbhw
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray, replace
from ezmsg.event.message import EventMessage

from .util import ClockSync


# Sampling rate (Hz) of each continuous sample group, keyed by group index
# (1-6). Used to size the per-group buffers and to build each group's time axis.
grp_fs = {1: 500, 2: 1_000, 3: 2_000, 4: 10_000, 5: 30_000, 6: 30_000}


class NSPSourceSettings(ez.Settings):
    """Connection and output options for :class:`NSPSource`.

    The address, port, receive-buffer and protocol fields are forwarded
    unchanged to ``cbsdk.create_params``; the remaining fields control how
    continuous data is buffered, scaled and time-stamped.
    """

    #: Address of the instrument, passed through to ``cbsdk.create_params``.
    inst_addr: str = ""

    #: Port of the instrument.
    inst_port: int = 51001

    #: Local (client) address.
    client_addr: str = ""

    #: Local (client) port.
    client_port: int = 51002

    #: Receive buffer size; ``None`` is passed through as-is.
    recv_bufsize: typing.Optional[int] = None

    #: Protocol version string handed to the SDK.
    protocol: str = "3.11"

    #: Duration of the continuous buffer that holds received packets.
    #: Up to ~15 MB / second.
    cont_buffer_dur: float = 0.5

    #: Convert continuous data to uV (True) or keep raw integers (False).
    microvolts: bool = True

    #: Use Cerebus time for time stamps (True) or local ``time.time`` (False).
    #: Note that ``time.time`` is delayed by the network transmission latency
    #: relative to Cerebus time.
    cbtime: bool = True
class NSPSourceState(ez.State):
    """Runtime state of :class:`NSPSource`.

    All per-group containers are keyed by sample-group index 1 through 6 and
    start out as empty placeholders.
    """

    # Handle to the connected device.
    device: cbsdk.NSPDevice

    # Spike events waiting to be published.
    spike_queue: asyncio.Queue[EventMessage]

    # NOTE(review): the five dicts below carry no annotation, so in plain
    # Python they are class attributes rather than per-instance fields --
    # confirm whether ez.State copies them per instance.

    # group -> (timestamps, samples); samples is laid out (time, channel).
    cont_buffer = {
        grp: (np.array([], dtype=int), np.array([[]], dtype=np.int16))
        for grp in range(1, 7)
    }

    # group -> next index to read from / write to in that group's buffer.
    cont_read_idx = dict.fromkeys(range(1, 7), 0)
    cont_write_idx = dict.fromkeys(range(1, 7), 0)

    # group -> AxisArray used as the template for outgoing messages.
    template_cont = {
        grp: AxisArray(data=np.array([[]]), dims=["time", "ch"])
        for grp in range(1, 7)
    }

    # group -> per-channel scale factors applied to the raw samples.
    scale_cont = {grp: np.array([]) for grp in range(1, 7)}

    sysfreq: int = 30_000  # Default for pre-Gemini system
    n_channels: int = 0
class NSPSource(ez.Unit):
    """Stream spike events and continuous sample groups from an NSP.

    Spike packets are turned into :class:`EventMessage` objects and published
    on ``OUTPUT_SPIKE``. Continuous packets for sample groups 1-6 are written
    into per-group circular buffers by SDK callbacks and drained into
    :class:`AxisArray` messages on ``OUTPUT_SIGNAL``.
    """

    SETTINGS = NSPSourceSettings
    STATE = NSPSourceState

    OUTPUT_SPIKE = ez.OutputStream(EventMessage)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    async def initialize(self) -> None:
        """Connect to the device, seed the clock sync and register callbacks.

        Waits (polling every 0.1 s) until the device reports at least one
        received packet before registering the spike and group callbacks.

        Raises:
            ConnectionError: if ``connect`` returns a falsy run level.
        """
        self.STATE.spike_queue = asyncio.Queue()

        # The per-group containers are declared as class-level dicts on the
        # state class and are mutated by item assignment below. Rebind
        # shallow copies on this instance so that several NSPSource units in
        # one process do not write into each other's buffers.
        self.STATE.cont_buffer = dict(self.STATE.cont_buffer)
        self.STATE.cont_read_idx = dict(self.STATE.cont_read_idx)
        self.STATE.cont_write_idx = dict(self.STATE.cont_write_idx)
        self.STATE.template_cont = dict(self.STATE.template_cont)
        self.STATE.scale_cont = dict(self.STATE.scale_cont)

        params = cbsdk.create_params(
            inst_addr=self.SETTINGS.inst_addr,
            inst_port=self.SETTINGS.inst_port,
            client_addr=self.SETTINGS.client_addr,
            client_port=self.SETTINGS.client_port,
            recv_bufsize=self.SETTINGS.recv_bufsize,
            protocol=self.SETTINGS.protocol,
        )
        self.STATE.device = cbsdk.NSPDevice(params)
        run_level = self.STATE.device.connect(startup_sequence=False)
        if not run_level:
            raise ConnectionError(f"Failed to connect to NSP; {params=}")

        config = cbsdk.get_config(self.STATE.device, force_refresh=True)
        # Gemini systems use 1e9 timestamp ticks per second; older systems
        # report their own tick rate in the config.
        self.STATE.sysfreq = 1e9 if config["b_gemini"] else config["sysfreq"]
        self._clock_sync = ClockSync(alpha=0.1, sysfreq=self.STATE.sysfreq)

        # Do not proceed until the device has actually delivered a packet,
        # so the first clock pair is meaningful.
        monitor_state = self.STATE.device.get_monitor_state()
        while monitor_state["pkts_received"] < 1:
            await asyncio.sleep(0.1)
            monitor_state = self.STATE.device.get_monitor_state()
        self._clock_sync.add_pair(monitor_state["time"], monitor_state["sys_time"])

        _ = cbsdk.register_spk_callback(self.STATE.device, self.on_spike)
        for grp_idx in range(1, 7):
            self._reset_buffer(grp_idx)
            _ = cbsdk.register_group_callback(
                self.STATE.device,
                grp_idx,
                functools.partial(self.on_smp_group, grp_idx=grp_idx),
            )

    def _reset_buffer(self, grp_idx: int) -> None:
        """(Re)build buffer, message template and scale factors for a group.

        Reads the group's channel set from the device config, allocates a
        zeroed buffer of ``cont_buffer_dur * grp_fs[grp_idx]`` samples, resets
        the group's read/write indices and recomputes the channel labels and
        raw-to-analog scale factors.

        Args:
            grp_idx: sample-group index (1-6).
        """
        config: dict = self.STATE.device.config
        chanset = config["group_infos"][grp_idx]
        n_channels = len(chanset)
        buff_samples = int(self.SETTINGS.cont_buffer_dur * grp_fs[grp_idx])

        # Kept up to date for backward compatibility; the per-group width is
        # taken from the buffer itself (see on_smp_group).
        self.STATE.n_channels = n_channels

        self.STATE.cont_buffer[grp_idx] = (
            # Explicit 64-bit: the platform default integer can be 32-bit,
            # which is too narrow for device timestamps.
            np.zeros((buff_samples,), dtype=np.int64),
            np.zeros((buff_samples, n_channels), dtype=np.int16),
        )
        self.STATE.cont_read_idx[grp_idx] = 0
        self.STATE.cont_write_idx[grp_idx] = 0

        time_ax = AxisArray.TimeAxis(grp_fs[grp_idx], offset=0.0)
        chan_labels = []
        scale_factors = []
        for ch_idx in chanset:
            pkt: cbhw.packet.packets.CBPacketChanInfo = config["channel_infos"][ch_idx]
            chan_labels.append(pkt.label.decode("utf-8"))
            # Analog units per digital count.
            scale_fac = (pkt.scalin.anamax - pkt.scalin.anamin) / (
                pkt.scalin.digmax - pkt.scalin.digmin
            )
            if pkt.scalin.anaunit.decode("utf-8") == "mV":
                # Output is labelled uV, and 1 mV == 1000 uV.
                scale_fac *= 1000
            scale_factors.append(scale_fac)

        ch_ax = AxisArray.CoordinateAxis(
            data=np.array(chan_labels), dims=["ch"], unit="label"
        )
        self.STATE.template_cont[grp_idx] = AxisArray(
            np.zeros((0, 0)),
            dims=["time", "ch"],
            axes={"time": time_ax, "ch": ch_ax},
            key=f"ns{grp_idx}",
            attrs={"unit": "uV" if self.SETTINGS.microvolts else "raw"},
        )
        self.STATE.scale_cont[grp_idx] = np.array(scale_factors)

    def shutdown(self) -> None:
        """Disconnect from the device if one was created."""
        if hasattr(self.STATE, "device") and self.STATE.device is not None:
            self.STATE.device.disconnect()

    def on_smp_group(self, pkt: Structure, grp_idx: int = 5):
        """SDK callback: store one continuous sample packet in its buffer.

        If the packet's channel count differs from the width of the group's
        current buffer, the buffer is rebuilt from the device config first.

        Args:
            pkt: group packet; ``pkt.data`` holds one sample per channel and
                ``pkt.header.time`` the device timestamp.
            grp_idx: sample-group index (1-6) this packet belongs to.
        """
        # Compare against this group's own buffer width. A single shared
        # channel count would be wrong as soon as two groups differ in size.
        if self.STATE.cont_buffer[grp_idx][1].shape[1] != len(pkt.data):
            self._reset_buffer(grp_idx)

        # Look these up only after a possible reset, so the write never goes
        # to a discarded buffer or a stale index.
        timestamps, samples = self.STATE.cont_buffer[grp_idx]
        write_idx = self.STATE.cont_write_idx[grp_idx]

        samples[write_idx, :] = memoryview(pkt.data[: samples.shape[1]])
        timestamps[write_idx] = pkt.header.time
        self.STATE.cont_write_idx[grp_idx] = (write_idx + 1) % len(timestamps)

    @ez.task
    async def update_clock(self) -> None:
        """Feed a fresh (device time, system time) pair to the clock sync every second."""
        while True:
            await asyncio.sleep(1.0)
            if self.STATE.device is not None:
                monitor_state = self.STATE.device.get_monitor_state()
                self._clock_sync.add_pair(
                    monitor_state["time"], monitor_state["sys_time"]
                )

    @ez.publisher(OUTPUT_SIGNAL)
    async def pub_cont(self) -> typing.AsyncGenerator:
        """Drain the per-group buffers into ``AxisArray`` messages.

        For each group, everything between the read index and the write index
        is published as one message. When the writer has wrapped around, only
        the part up to the end of the buffer is sent and the remainder follows
        on the next pass. Sleeps 1 ms when no group had new data.
        """
        while True:
            b_any = False
            for grp_idx in range(1, 7):
                timestamps, samples = self.STATE.cont_buffer[grp_idx]
                read_idx = self.STATE.cont_read_idx[grp_idx]
                write_idx = self.STATE.cont_write_idx[grp_idx]
                buff_len = len(timestamps)

                # Writer ahead of reader: read up to it. Writer wrapped:
                # read to the end of the buffer only.
                read_term = write_idx if write_idx >= read_idx else buff_len
                if read_idx == read_term:
                    continue
                b_any = True

                # Copy so later callback writes cannot alter the message.
                out_dat = samples[read_idx:read_term].copy()
                if self.SETTINGS.microvolts:
                    out_dat = out_dat * self.STATE.scale_cont[grp_idx][None, :]

                # The message offset is the time of its first sample.
                if self.SETTINGS.cbtime:
                    new_offset: float = timestamps[read_idx] / self.STATE.sysfreq
                else:
                    new_offset: float = self._clock_sync.nsp2system(
                        timestamps[read_idx]
                    )

                _templ = self.STATE.template_cont[grp_idx]
                new_time_ax = replace(_templ.axes["time"], offset=new_offset)
                out_msg = replace(
                    _templ,
                    data=out_dat,
                    axes={**_templ.axes, **{"time": new_time_ax}},
                )
                # read_term == buff_len wraps back to 0.
                self.STATE.cont_read_idx[grp_idx] = read_term % buff_len
                yield self.OUTPUT_SIGNAL, out_msg

            if not b_any:
                await asyncio.sleep(0.001)

    def on_spike(self, spk_pkt: Structure):
        """SDK callback: enqueue one spike packet as an ``EventMessage``.

        The channel index is made zero-based and unit numbers above 6 are
        clamped to 6.
        """
        # NOTE(review): asyncio.Queue is not thread-safe. If the SDK invokes
        # this callback from a thread other than the event loop's, this
        # should go through loop.call_soon_threadsafe -- confirm.
        self.STATE.spike_queue.put_nowait(
            EventMessage(
                offset=spk_pkt.header.time / self.STATE.sysfreq
                if self.SETTINGS.cbtime
                else self._clock_sync.nsp2system(spk_pkt.header.time),
                ch_idx=spk_pkt.header.chid - 1,
                sub_idx=min(spk_pkt.unit, 6),  # 0=unsorted, 1-5 sorted unit, >5=noise
                value=1,
            )
        )

    @ez.publisher(OUTPUT_SPIKE)
    async def spikes(self) -> typing.AsyncGenerator:
        """Publish queued spike events as they arrive."""
        while True:
            spike_event = await self.STATE.spike_queue.get()
            yield self.OUTPUT_SPIKE, spike_event