from pathlib import Path
import queue
import numpy as np
import numpy.typing as npt
import pyxdf
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace
class XDFIterator:
def __init__(
self,
filepath: Path | str,
select: set[str]
| None = None,  # If set, the iterator yields data only for the selected stream(s).
# If None (default), the yielded dicts have a key for every stream in the file.
chunk_dur: float = 1.0,  # Attempt to chunk data into chunks of this duration (in seconds).
start_time: float | None = None,
stop_time: float | None = None,
rezero: bool = True,
):
"""
An Iterator that yields chunks from an XDF.
A typical offline analysis might load the entire file into memory, then perform a processing step on the entire
recording duration, and the next step on the entire result of the first step, and so on. This might require a
tremendous amount of memory and, if one is not careful about memory layout, can be incredibly slow. An
alternative procedure is to load the file into memory a chunk at a time (see Note1), then pass that chunk
through the entire processing pipeline, then proceed to the next chunk (see Note2). We create an Iterator to
provide our chunks.
> Note1: I have not written a true lazy-loader for XDF because it has not yet been necessary as the files are
all small. Thus, I use pyxdf.load_xdf which loads the entire raw data into memory. The processing is still
done chunk-by-chunk.
> Note2: It should be possible to start on chunk[ix+1] while chunk[ix] is still going through the pipeline.
Indeed, this is (optionally) how it works online. However, the overhead of setting this up for offline
analysis is not worth the gain, at least not at this stage.
Args:
filepath: The path to the file to load and iterate over.
select: (Optional) A set of stream names to select. If None, then all streams are selected.
chunk_dur: The duration of each chunk in seconds.
start_time: Start playback at this time. If rezero is True then this is relative to the file start time.
If rezero is False then this is relative to the original timestamps.
stop_time: Truncate the playback to stop at this time. If rezero is True then this is relative to the file
start time. If rezero is False then this is relative to the original timestamps.
rezero: The absolute timestamps in an XDF file are useful for synchronization WITHIN the file, but they
are meaningless outside that exact file, such as in an ezmsg application. Thus, by default we
rezero the timestamps to start at t=0.0 for simplicity. However, there may be rare circumstances where
one wants to compare the timestamps produced by ezmsg to timestamps produced by another XDF analysis
tool that does not rezero. In that case, set rezero=False.
"""
if isinstance(filepath, str):
filepath = Path(filepath).expanduser()
self._filepath = filepath
self._select = select
self._chunk_dur = chunk_dur
self._rezero = rezero
self._n_chunks = 0
self._t0 = 0.0
self._chunk_ix = 0
self._last_time = 0.0
self._metadata = {}
self._prev_file_read_s: float = (
0.0  # File read head position (in seconds) as of the previous iteration
)
self._time_range: tuple[float | None, float | None] = (start_time, stop_time)
self._scan_file()
def _scan_file(self):
# Note: For larger datafiles we wouldn't want to load the entire thing into memory with load_xdf.
# Instead, get a file handle, then
# - Scan the file for chunk boundaries and timestamps
# - Maintain a list of chunk boundaries
# - Perform timestamp corrections (maintain corrected ts in memory or use func to correct during next pass?)
# - Iterator operates on original chunk-boundaries, but using corrected timestamps.
# However, we would need a custom file parser for that. For now, we load the relatively small
# file into memory simply with pyxdf.load_xdf then iterate over the items in memory
# at a user-defined chunk boundary (`chunk_dur`).
# Load xdf
self._streams, fileheader = pyxdf.load_xdf(
self._filepath,
select_streams=None
if (self._select is None or self._rezero)
else [{"name": _} for _ in self._select],
)
self._metadata = {}
self._file_read_s = 0
self._prev_file_read_s = 0
xdf_t0 = np.inf
xdf_tmax = 0
for strm in self._streams:
# Convert empty data to an array for easier slicing
if type(strm["time_series"]) is list:
strm["time_series"] = np.array(strm["time_series"])
# Get more digestible metadata
info = strm["info"]
new_meta = {
"name": info["name"][0],
"type": info["type"][0],
"channel_count": int(info["channel_count"][0]),
"nominal_srate": float(info["nominal_srate"][0]),
}
self._metadata[new_meta["name"]] = new_meta
# Update time range limits
tvec = strm["time_stamps"]
if len(tvec) > 0:
xdf_t0 = min(xdf_t0, tvec[0])
xdf_tmax = max(xdf_tmax, tvec[-1])
# Permanently modify streams' time stamps
if self._rezero:
for strm in self._streams:
strm["time_stamps"] = strm["time_stamps"] - xdf_t0
xdf_tmax -= xdf_t0
xdf_t0 = 0
# Adjust for provided time bounds
for strm in self._streams:
tvec = strm["time_stamps"]
if len(tvec) > 0:
b_keep = np.ones(len(tvec), dtype=bool)
if self._time_range[0] is not None:
b_keep = np.logical_and(b_keep, tvec >= self._time_range[0])
if self._time_range[1] is not None:
b_keep = np.logical_and(b_keep, tvec <= self._time_range[1])
if np.any(~b_keep):
strm["time_stamps"] = tvec[b_keep]
strm["timeseries"] = strm["timeseries"][b_keep]
# Recalculate tmax
xdf_dur = 0
for strm in self._streams:
tvec = strm["time_stamps"]
srate = float(strm["info"]["nominal_srate"][0])
adj = (1 / srate if srate > 0 else 0) - xdf_t0
if len(tvec) > 0:
xdf_dur = max(xdf_dur, tvec[-1] + adj)
# Chunking
self._n_chunks = int(np.ceil(xdf_dur / self._chunk_dur))
self._t0 = xdf_t0
# Drop streams that were not selected. (Could not drop earlier due to timestamp rezero)
if self._rezero and self._select is not None:
stream_names = [_["info"]["name"][0] for _ in self._streams]
self._streams = [self._streams[stream_names.index(_)] for _ in self._select]
self._metadata = {k: self._metadata[k] for k in self._select}
print(
f"Imported {len(self._streams)} streams from {self._filepath} "
f"spanning {xdf_dur:.2f} s beginning at t={xdf_t0:.2f}."
)
@property
def stream_meta(self) -> dict[str, dict]:
return self._metadata
@property
def n_chunks(self) -> int:
return self._n_chunks
def __iter__(self):
self._chunk_ix = 0
return self
def __next__(self) -> dict[str, tuple[npt.NDArray, npt.NDArray]]:
if self._chunk_ix >= self.n_chunks:
raise StopIteration
else:
out_dict = {}
t_start, t_stop = (
self._chunk_ix * self._chunk_dur + self._t0,
(self._chunk_ix + 1) * self._chunk_dur + self._t0,
)
for strm in self._streams:
b_chunk = np.logical_and(
strm["time_stamps"] >= t_start, strm["time_stamps"] < t_stop
)
out_tvec = strm["time_stamps"][b_chunk]
out_data = strm["time_series"][b_chunk]
out_dict[strm["info"]["name"][0]] = (out_data, out_tvec)
if len(out_tvec) > 0:
self._last_time = max(self._last_time, out_tvec[-1])
self._chunk_ix += 1
return out_dict
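# --- Usage sketch (illustrative only; not part of the original API) ---
# A minimal example of driving XDFIterator directly. The file path and the stream
# name "EEG" are placeholders/assumptions, not names defined by this module.
def _example_iterate_chunks(filepath: Path | str) -> None:
    it = XDFIterator(filepath, select={"EEG"}, chunk_dur=1.0)
    for chunk in it:
        # Each chunk is a dict keyed by stream name; values are (data, timestamps) tuples.
        data, tvec = chunk["EEG"]
        if len(tvec) > 0:
            print(f"{data.shape[0]} samples spanning {tvec[0]:.3f}-{tvec[-1]:.3f} s")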
def labels_from_strm(strm: dict) -> list[str]:
"""Return the per-channel labels from the stream's desc, or 1-based indices if absent."""
desc = strm["info"]["desc"][0]
if desc is not None and "channels" in desc:
labels = [_["label"][0] for _ in desc["channels"][0]["channel"]]
else:
n_ch = int(strm["info"]["channel_count"][0])
labels = [str(_ + 1) for _ in range(n_ch)]
return labels
class XDFAxisArrayIterator(XDFIterator):
def __init__(self, *args, select: str, **kwargs):
"""
This Iterator loads only a single stream and yields a single :obj:`AxisArray` object per chunk.
Args:
*args: Positional arguments forwarded to :obj:`XDFIterator`.
select: Unlike :obj:`XDFIterator`, this must be a single string, the name of the stream to select.
**kwargs: Keyword arguments forwarded to :obj:`XDFIterator`.
"""
kwargs["select"] = set((select,))
super().__init__(*args, **kwargs)
_sel = next(iter(self._select))
labels = labels_from_strm(self._streams[0])
if self._metadata[_sel].get("nominal_srate", None):
time_ax = AxisArray.TimeAxis(
fs=self._metadata[_sel]["nominal_srate"], offset=0
)
else:
time_ax = AxisArray.CoordinateAxis(
data=np.array([]),
dims=["time"],
unit="s"
)
self._template = AxisArray(
data=np.zeros(
(0, len(labels)), dtype=self._streams[0]["time_series"].dtype
),
dims=["time", "ch"],
axes={
"time": time_ax,
"ch": AxisArray.CoordinateAxis(data=np.array(labels), dims=["ch"]),
},
key=self._streams[0]["info"]["name"][0],
)
def __next__(self) -> AxisArray:
result: AxisArray | None = None
chunk_dict = super().__next__()
# There should be only one name in self._select. If there are more, result is overwritten with the last one.
for strm_name in self._select:
if strm_name in chunk_dict:
data, tvec = chunk_dict[strm_name]
if isinstance(self._template.axes["time"], AxisArray.CoordinateAxis):
t_kwargs = {"data": tvec}
else:
t_kwargs = {"offset": tvec[0] if len(tvec) else self._last_time}
result = replace(
self._template,
data=data,
axes={
**self._template.axes,
"time": replace(
self._template.axes["time"],
**t_kwargs,
),
},
)
return result
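# --- Usage sketch (illustrative only; not part of the original API) ---
# A minimal sketch of consuming one stream as AxisArray chunks. The file path and
# the stream name "EEG" are placeholders/assumptions, not names defined by this module.
def _example_iterate_axisarray(filepath: Path | str) -> None:
    it = XDFAxisArrayIterator(filepath, select="EEG", chunk_dur=0.5)
    for msg in it:
        if msg is not None:
            # msg.dims is ["time", "ch"]; msg.data has shape (n_samples, n_channels).
            print(msg.key, msg.data.shape)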
class XDFMultiAxArrIterator(XDFIterator):
def __init__(self, *args, force_single_sample: set = set(), **kwargs):
"""
This Iterator loads multiple streams and yields one :obj:`AxisArray` object per iteration,
but the source stream may differ from one iteration to the next.
Args:
*args: Positional arguments forwarded to :obj:`XDFIterator`.
force_single_sample: Use this to identify irregular-rate streams that might conceivably have more than one
event within the defined chunk_dur, for which :obj:`AxisArray` cannot represent timestamps properly.
**kwargs: Keyword arguments forwarded to :obj:`XDFIterator`.
"""
super().__init__(*args, **kwargs)
self._force_single_sample = force_single_sample
stream_names = [_["info"]["name"][0] for _ in self._streams]
# Create template messages for each stream
self._templates = {}
for stream_name, stream_meta in self._metadata.items():
stream = self._streams[stream_names.index(stream_name)]
labels = labels_from_strm(stream)
fs = stream_meta["nominal_srate"]
time_ax = (
AxisArray.TimeAxis(fs=fs, offset=0.0)
if fs
else AxisArray.CoordinateAxis(data=np.array([]), dims=["time"], unit="s")
)
self._templates[stream_name] = AxisArray(
data=np.zeros(
(0, stream_meta["channel_count"]), dtype=stream["time_series"].dtype
),
dims=["time", "ch"],
axes={
"time": time_ax,
"ch": AxisArray.CoordinateAxis(data=np.array(labels), dims=["ch"]),
},
key=stream_name,
)
self._pubqueue: queue.SimpleQueue[AxisArray] = queue.SimpleQueue()
def __next__(self) -> AxisArray | None:
if self._pubqueue.empty():
chunk_dict = super().__next__()
for k, template in self._templates.items():
if k in chunk_dict and len(chunk_dict[k][1]) > 0:
data, tvec = chunk_dict[k]
if k in self._force_single_sample:
if isinstance(template.axes["time"], AxisArray.CoordinateAxis):
t_kwargs = {"data": np.array([])}
else:
t_kwargs = {"offset": 0.0}
for ix, _t in enumerate(tvec):
if "data" in t_kwargs:
t_kwargs["data"] = np.array([_t])
else:
t_kwargs["offset"] = _t
self._pubqueue.put_nowait(
replace(
template,
data=data[ix : ix + 1],
axes={
**template.axes,
"time": replace(
template.axes["time"], **t_kwargs
),
},
)
)
else:
if isinstance(template.axes["time"], AxisArray.CoordinateAxis):
t_kwargs = {"data": tvec if len(tvec) else np.array([])}
else:
t_kwargs = {
"offset": tvec[0] if len(tvec) else self._last_time
}
self._pubqueue.put_nowait(
replace(
template,
data=data,
axes={
**template.axes,
"time": replace(
template.axes["time"],
**t_kwargs,
),
},
)
)
try:
return self._pubqueue.get_nowait()
except queue.Empty:
return None
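# --- Usage sketch (illustrative only; not part of the original API) ---
# A minimal sketch of interleaving multiple streams into one sequence of AxisArray
# messages. The file path and stream names ("EEG", "Markers") are placeholders;
# force_single_sample makes each marker event its own single-sample message.
def _example_iterate_multi(filepath: Path | str) -> None:
    it = XDFMultiAxArrIterator(
        filepath,
        select={"EEG", "Markers"},
        chunk_dur=1.0,
        force_single_sample={"Markers"},
    )
    for msg in it:
        if msg is None:
            continue  # No stream had samples in this chunk.
        print(msg.key, msg.data.shape)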