import asyncio
import time
import typing
from dataclasses import dataclass, field, fields
import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
import pylsl
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace
from .util import ClockSync
# Maps pylsl channel-format constants to the dtype used when pre-allocating
# the numpy buffer that chunks are pulled into. Formats that are absent from
# this table (e.g. strings) get no pre-allocated buffer.
fmt2npdtype = {
    pylsl.cf_double64: float,  # Prefer native type for float64
    pylsl.cf_int64: int,  # Prefer native type for int64
    pylsl.cf_float32: np.float32,
    pylsl.cf_int32: np.int32,
    pylsl.cf_int16: np.int16,
    pylsl.cf_int8: np.int8,
    # pylsl.cf_string:  # For now we don't provide a pre-allocated buffer for string data type.
}
[docs]
@dataclass
class LSLInfo:
    """
    Identifies the LSL stream an inlet should connect to.

    When streams are looked up through the resolver, an empty `name`, `type`
    or `host` matches any stream; a non-empty value must match exactly.
    """

    # Stream name to match; "" matches any name.
    name: str = ""
    # Stream type to match; "" matches any type.
    type: str = ""
    host: str = ""  # Use socket.gethostname() for local host.
    # Number of channels. Together with `channel_format`, a non-None value lets
    # the inlet be built from a locally constructed pylsl.StreamInfo.
    channel_count: typing.Optional[int] = None
    # Overwritten with the rate reported by the inlet once a stream is opened.
    nominal_srate: float = 0.0
    # Channel format handed to pylsl.StreamInfo when building the info directly.
    channel_format: typing.Optional[str] = None
def _sanitize_kwargs(kwargs: dict) -> dict:
if "info" not in kwargs:
replace_keys = set()
for k, v in kwargs.items():
if k.startswith("stream_"):
replace_keys.add(k)
if len(replace_keys) > 0:
ez.logger.warning(
f"LSLInlet kwargs beginning with 'stream_' deprecated. Found {replace_keys}. See LSLInfo dataclass."
)
for k in replace_keys:
kwargs[k[7:]] = kwargs.pop(k)
known_fields = [_.name for _ in fields(LSLInfo)]
info_kwargs = {k: v for k, v in kwargs.items() if k in known_fields}
for k in info_kwargs.keys():
kwargs.pop(k)
kwargs["info"] = LSLInfo(**info_kwargs)
return kwargs
[docs]
class LSLInletSettings(ez.Settings):
    """
    Settings for an LSL inlet: which stream to connect to, how large the local
    pull buffer is, and which clock the output timestamps are expressed in.
    """

    # Identity of the stream to connect to.
    info: LSLInfo = field(default_factory=LSLInfo)

    # Duration of the pre-allocated pull buffer; multiplied by the stream's
    # nominal rate to obtain the buffer length in samples.
    local_buffer_dur: float = 1.0

    use_arrival_time: bool = False
    """
    Whether to ignore the LSL timestamps and use the time.time of the pull (True).
    If False (default), the LSL (send) timestamps are used.
    Send times may be converted from LSL clock to time.time clock. See `use_lsl_clock`.
    """

    use_lsl_clock: bool = False
    """
    Whether the AxisArray.Axis.offset should use LSL's clock (True) or time.time's clock (False -- default).
    """

    processing_flags: int = pylsl.proc_ALL
    """
    The processing flags option passed to pylsl.StreamInlet. Default is proc_ALL which includes all flags.
    Many users will want to set this to pylsl.proc_clocksync to disable dejittering.
    """
[docs]
class LSLInletState(ez.State):
    """Mutable runtime state owned by an LSLInletGenerator."""

    # Continuously tracks streams visible on the network; None after a full shutdown.
    resolver: typing.Optional[pylsl.ContinuousResolver] = None
    # The open inlet, or None until a matching stream has been found.
    inlet: typing.Optional[pylsl.StreamInlet] = None
    # Converts timestamps between the LSL clock and the system clock.
    # NOTE(review): this default is created once at class-definition time --
    # confirm that sharing it between state instances is intended.
    clock_sync: ClockSync = ClockSync(run_thread=False)
    # Empty AxisArray carrying the axes/key for the current stream; copied for each output.
    msg_template: typing.Optional[AxisArray] = None
    # Pre-allocated (samples, channels) destination for pulls; None for formats
    # without an entry in fmt2npdtype.
    fetch_buffer: typing.Optional[npt.NDArray] = None
[docs]
class LSLInletGenerator:
    """
    Iterator that finds an LSL stream, opens an inlet on it, and returns pulled chunks.

    Each call to ``next()`` returns one of:

    * ``None`` while no inlet exists (every such call attempts to create one),
    * the empty message template when the inlet had no new samples,
    * an :obj:`AxisArray` with dims ``["time", "ch"]`` holding the pulled chunk.

    Args:
        *args: Either a single :obj:`LSLInletSettings`, or positional arguments
            forwarded to the :obj:`LSLInletSettings` constructor.
        settings: Ready-made settings. When given, ``args`` and ``kwargs`` are ignored.
        **kwargs: Keyword arguments for :obj:`LSLInletSettings`. Fields of
            :obj:`LSLInfo` (and their deprecated ``stream_``-prefixed spellings)
            are accepted at the top level and folded into ``info``.
    """

    def __init__(self, *args, settings: typing.Optional[LSLInletSettings] = None, **kwargs):
        kwargs = _sanitize_kwargs(kwargs)
        if settings is None:
            if args and isinstance(args[0], LSLInletSettings):
                settings = args[0]
            else:
                # With no args and no kwargs this is simply the default settings.
                settings = LSLInletSettings(*args, **kwargs)
        self._state: LSLInletState = LSLInletState()
        self.settings = settings
        self.shutdown()
        self._reset_resolver()

    def __iter__(self):
        # self.shutdown() to reset?
        return self

    @property
    def state(self) -> LSLInletState:
        """The generator's runtime state (resolver, inlet, clock sync, buffers)."""
        return self._state

    def _reset_resolver(self) -> None:
        """Replace the resolver with a fresh one that has no stream filter."""
        self._state.resolver = pylsl.ContinuousResolver(pred=None, forget_after=30.0)

    def shutdown(self, shutdown_resolver: bool = True):
        """
        Close the inlet (if any) and drop the cached template and fetch buffer.

        Args:
            shutdown_resolver: Also discard the resolver. Pass False to keep
                tracking streams while only the inlet is being rebuilt.
        """
        self._state.msg_template = None
        self._state.fetch_buffer = None
        if self._state.inlet is not None:
            self._state.inlet.close_stream()
            del self._state.inlet
        self._state.inlet = None
        if shutdown_resolver:
            self._state.resolver = None

    def _reset_inlet(self):
        """
        Tear down the current inlet and try to open a new one for the current settings.

        On success the fetch buffer and message template are rebuilt for the new
        stream and ``settings.info.nominal_srate`` is updated with the stream's
        rate. If no stream can be found the inlet stays ``None``.
        """
        self.shutdown(shutdown_resolver=False)

        wanted = self.settings.info
        # If name, type, channel_count and channel_format are all not None, build
        # the StreamInfo locally and create the inlet directly from it.
        # NOTE(review): name and type default to "" (never None), so in practice
        # only channel_count and channel_format decide whether this path is taken.
        direct_fields = (wanted.name, wanted.type, wanted.channel_count, wanted.channel_format)
        if all(value is not None for value in direct_fields):
            info = pylsl.StreamInfo(
                name=wanted.name,
                type=wanted.type,
                channel_count=wanted.channel_count,
                channel_format=wanted.channel_format,
            )
            self._state.inlet = pylsl.StreamInlet(info, max_chunklen=1, processing_flags=self.settings.processing_flags)
        elif self._state.resolver is not None:
            # Take the first resolved stream where every non-empty criterion matches.
            for strm_info in self._state.resolver.results():
                if wanted.name and strm_info.name() != wanted.name:
                    continue
                if wanted.type and strm_info.type() != wanted.type:
                    continue
                if wanted.host and strm_info.hostname() != wanted.host:
                    continue
                self._state.inlet = pylsl.StreamInlet(
                    strm_info,
                    max_chunklen=1,
                    processing_flags=self.settings.processing_flags,
                )
                break

        if self._state.inlet is None:
            return

        self._state.inlet.open_stream()
        inlet_info = self._state.inlet.info()
        fs = inlet_info.nominal_srate()
        # It's bad practice to write directly to settings but here we
        # are filling in a value that was optional.
        self.settings.info.nominal_srate = fs

        # If possible, create a destination buffer for faster pulls
        fmt = inlet_info.channel_format()
        n_ch = inlet_info.channel_count()
        if fmt in fmt2npdtype:
            # Irregular-rate streams (rate 0) fall back to a 1000-sample buffer.
            n_buff = int(self.settings.local_buffer_dur * fs) or 1000
            self._state.fetch_buffer = np.zeros((n_buff, n_ch), dtype=fmt2npdtype[fmt])

        # Channel labels come from the stream description; any that are missing
        # are filled in with 1-based channel numbers.
        ch_labels = []
        chans = inlet_info.desc().child("channels")
        if not chans.empty():
            ch = chans.first_child()
            while not ch.empty():
                ch_labels.append(ch.child_value("label"))
                ch = ch.next_sibling()
        while len(ch_labels) < n_ch:
            ch_labels.append(str(len(ch_labels) + 1))

        # Pre-allocate a message template.
        time_ax = (
            AxisArray.TimeAxis(fs=fs)
            if fs
            else AxisArray.CoordinateAxis(data=np.array([]), dims=["time"], unit="s")
        )
        self._state.msg_template = AxisArray(
            data=np.empty((0, n_ch)),
            dims=["time", "ch"],
            axes={
                "time": time_ax,
                "ch": AxisArray.CoordinateAxis(data=np.array(ch_labels), dims=["ch"]),
            },
            key=inlet_info.name(),
        )

    def update_settings(self, new_settings: LSLInletSettings) -> None:
        """
        Apply new settings and, if they differ from the current ones, rebuild the inlet.

        Args:
            new_settings: Full :obj:`LSLInletSettings`, a dict of settings, just
                an :obj:`LSLInfo`, or a dict of just the info fields. A bare
                :obj:`LSLInfo` keeps the other current settings; a dict is
                turned into a fresh :obj:`LSLInletSettings`.

        Raises:
            TypeError: If ``new_settings`` is none of the accepted types.
        """
        if isinstance(new_settings, dict):
            # Make sure the info is in the right place, then convert to a settings object.
            new_settings = LSLInletSettings(**_sanitize_kwargs(new_settings))
        elif isinstance(new_settings, LSLInfo):
            new_settings = LSLInletSettings(
                info=new_settings,
                local_buffer_dur=self.settings.local_buffer_dur,
                use_arrival_time=self.settings.use_arrival_time,
                use_lsl_clock=self.settings.use_lsl_clock,
                processing_flags=self.settings.processing_flags,
            )
        elif not isinstance(new_settings, LSLInletSettings):
            raise TypeError(
                f"Expected LSLInletSettings, LSLInfo or dict, got {type(new_settings).__name__}"
            )

        if new_settings != self.settings:
            # Store the new settings first so the inlet is rebuilt from them
            # rather than from the previous configuration.
            self.settings = new_settings
            self._reset_resolver()
            self._reset_inlet()

    def __next__(self) -> typing.Optional[AxisArray]:
        """
        Pull the next chunk from the inlet.

        Returns:
            ``None`` if there is no inlet yet, the empty template if nothing was
            pulled, otherwise a message holding the pulled samples with its time
            axis set according to ``use_arrival_time`` / ``use_lsl_clock``.
        """
        if self._state.inlet is None:
            # Inlet not yet created, or recently destroyed because settings changed.
            self._reset_inlet()
            return None

        if self._state.fetch_buffer is not None:
            samples, timestamps = self._state.inlet.pull_chunk(
                max_samples=self._state.fetch_buffer.shape[0],
                dest_obj=self._state.fetch_buffer,
            )
        else:
            samples, timestamps = self._state.inlet.pull_chunk()
            samples = np.array(samples)

        if not len(timestamps):
            return self._state.msg_template

        # pull_chunk hands back timestamps as a plain list; the arithmetic below
        # needs an array.
        timestamps = np.asarray(timestamps, dtype=float)
        # The fetch buffer is reused by the next pull, so the slice must be copied.
        data = self._state.fetch_buffer[: len(timestamps)].copy() if samples is None else samples

        # `timestamps` is currently in the LSL clock stamped by the sender.
        if self.settings.use_arrival_time:
            # Drop the sender stamps: the first sample is stamped "now" and later
            # samples keep their spacing relative to it.
            timestamps = time.time() + (timestamps - timestamps[0])
            if self.settings.use_lsl_clock:
                timestamps = self._state.clock_sync.system2lsl(timestamps)
        elif not self.settings.use_lsl_clock:
            # Keep the sender clock but convert to system time.
            timestamps = self._state.clock_sync.lsl2system(timestamps)

        template = self._state.msg_template
        if self.settings.info.nominal_srate <= 0.0:
            # Irregular rate stream uses CoordinateAxis for time so each sample has a timestamp.
            out_time_ax = replace(template.axes["time"], data=np.array(timestamps))
        else:
            # Regular rate uses a LinearAxis for time so we only need the time of the first sample.
            out_time_ax = replace(template.axes["time"], offset=timestamps[0])

        return replace(
            template,
            data=data,
            axes={
                **template.axes,
                "time": out_time_ax,
            },
        )
[docs]
class LSLInletUnitState(ez.State):
    """State of LSLInletUnit."""

    # The generator that owns the inlet; None until one has been created.
    generator: typing.Optional[LSLInletGenerator] = None
[docs]
class LSLInletUnit(ez.Unit):
    """
    Represents a node in a graph that creates an LSL inlet and
    forwards the pulled data to the unit's output.

    The stream to connect to and the timestamp handling are configured through
    :obj:`LSLInletSettings`. New settings received on ``INPUT_SETTINGS`` are
    applied to the unit and forwarded to the underlying generator.
    """

    SETTINGS = LSLInletSettings
    STATE = LSLInletUnitState

    INPUT_SETTINGS = ez.InputStream(LSLInletSettings)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    async def initialize(self) -> None:
        """Create the generator from the unit's settings."""
        self._create_generator()

    def _create_generator(self):
        self.STATE.generator = LSLInletGenerator(settings=self.SETTINGS)

    def shutdown(self) -> None:
        """Shut down the generator, if one was created."""
        # The generator is still None if the unit is shut down before (or
        # because) initialization completed.
        if self.STATE.generator is not None:
            self.STATE.generator.shutdown()

    @ez.task
    async def update_clock(self) -> None:
        """Run one clock-sync update every 0.1 s while an inlet exists."""
        while True:
            # Look the generator up on every pass so a replaced or not-yet-created
            # generator is handled.
            gen = self.STATE.generator
            if gen is not None and gen.state.inlet is not None:
                gen.state.clock_sync.run_once()
            await asyncio.sleep(0.1)

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: LSLInletSettings) -> None:
        """Apply incoming settings to the unit and pass them on to the generator."""
        self.apply_settings(msg)
        self.STATE.generator.update_settings(msg)

    @ez.publisher(OUTPUT_SIGNAL)
    async def lsl_pull(self) -> typing.AsyncGenerator:
        """Publish every non-empty chunk; sleep 1 ms whenever there is nothing to send."""
        while True:
            out_msg = next(self.STATE.generator)
            if out_msg is not None and out_msg.data.size > 0:
                yield self.OUTPUT_SIGNAL, out_msg
            else:
                await asyncio.sleep(0.001)