Source code for ezmsg.panel.lineplot

import asyncio
from functools import partial
from typing import Dict, List, Optional

import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
import panel
from bokeh.models import ColumnDataSource
from bokeh.models.renderers import GlyphRenderer
from bokeh.plotting import figure
from ezmsg.util.messages.axisarray import AxisArray
from param.parameterized import Event

from .util import AxisScale

# Column name under which the shared x-axis values are stored in the Bokeh
# ColumnDataSource; every line glyph created by LinePlot.plot uses it as `x`.
CDS_X_DIM = "__x__"


class LinePlotSettings(ez.Settings):
    """Static configuration for a :class:`LinePlot` unit.

    ``name`` is used as the figure title.  ``x_axis`` selects which axis of
    the incoming message is plotted along x.  The two ``*_scale`` fields
    choose between a linear and a logarithmic Bokeh axis, and the two
    ``*_label`` fields are forwarded to the figure only when they are set.
    """

    name: str = "LinePlot"

    # Name of the message axis to plot along; when None, dim 0 is used.
    x_axis: Optional[str] = None

    x_axis_scale: AxisScale = AxisScale.LINEAR
    y_axis_scale: AxisScale = AxisScale.LINEAR

    y_axis_label: Optional[str] = None
    x_axis_label: Optional[str] = None
class LinePlotState(ez.State):
    """Mutable runtime state shared between the tasks of a :class:`LinePlot`.

    All fields are populated by ``LinePlot.initialize``.
    """

    # Most recently computed plot data: x values shared by every line, and
    # one y array per channel name.
    x_data: npt.NDArray
    cds_data: Dict[str, npt.NDArray]

    # Visualization controls (Panel widgets exposed via LinePlot.controls).
    channelize: panel.widgets.Checkbox
    gain: panel.widgets.FloatInput

    # Set whenever a new message arrives or a control changes, to request
    # that the plot data be recomputed.
    update_ev: asyncio.Event

    # Last message received on the input stream (None clears the plot).
    cur_signal: Optional[AxisArray]
class LinePlot(ez.Unit):
    """ezmsg unit that draws the latest incoming ``AxisArray`` as Bokeh lines.

    Messages received on ``INPUT_SIGNAL`` are stored in the unit state; a
    background task converts the most recent message into per-channel arrays,
    and a periodic Panel callback pushes those arrays into the figure's
    ``ColumnDataSource``.  Two widgets ("Channelize" and "Gain") adjust how
    the data is displayed.
    """

    SETTINGS = LinePlotSettings
    STATE = LinePlotState

    INPUT_SIGNAL = ez.InputStream(Optional[AxisArray])

    def initialize(self) -> None:
        """Create empty plot data, the update event and the control widgets."""
        self.STATE.x_data = np.arange(0)
        self.STATE.cds_data = {}
        # A freshly constructed Event starts out unset.
        self.STATE.update_ev = asyncio.Event()
        self.STATE.cur_signal = None

        self.STATE.channelize = panel.widgets.Checkbox(name="Channelize", value=True)
        self.STATE.gain = panel.widgets.FloatInput(name="Gain", value=1.0)

        def on_vis_control(*events: Event) -> None:
            # A control changed: ask update_data to recompute the plot data
            # from the current signal using the new widget values.
            self.STATE.update_ev.set()

        self.STATE.channelize.param.watch(on_vis_control, "value")
        self.STATE.gain.param.watch(on_vis_control, "value")

    def plot(self) -> panel.viewable.Viewable:
        """Build the Bokeh figure and register its periodic refresh.

        Returns:
            A ``panel.pane.Bokeh`` wrapping a figure whose line glyphs are
            kept in sync with ``STATE.cds_data`` by a callback registered
            with ``panel.state.add_periodic_callback`` (period=50).
        """
        cds = ColumnDataSource()

        x_axis_type = "log" if self.SETTINGS.x_axis_scale == AxisScale.LOG else "linear"
        y_axis_type = "log" if self.SETTINGS.y_axis_scale == AxisScale.LOG else "linear"

        # Only forward labels that were actually configured.
        axis_labels: Dict[str, str] = {}
        if self.SETTINGS.x_axis_label is not None:
            axis_labels["x_axis_label"] = self.SETTINGS.x_axis_label
        if self.SETTINGS.y_axis_label is not None:
            axis_labels["y_axis_label"] = self.SETTINGS.y_axis_label

        fig = figure(
            sizing_mode="stretch_width",
            title=self.SETTINGS.name,
            output_backend="webgl",
            x_axis_type=x_axis_type,
            y_axis_type=y_axis_type,
            tooltips=[("x", "$x"), ("y", "$y")],
            **axis_labels,
        )

        # Channel name -> renderer currently drawn for that channel.
        lines: Dict[str, GlyphRenderer] = {}

        @panel.io.with_lock
        def _update(
            fig: figure, cds: ColumnDataSource, lines: Dict[str, GlyphRenderer]
        ) -> None:
            # Read the shared state once so the key diff below and the data
            # pushed at the end are guaranteed to describe the same snapshot.
            y_data = self.STATE.cds_data
            x_data = self.STATE.x_data

            # Drop lines for channels that no longer exist.
            for key in list(lines.keys() - y_data.keys()):
                cds.remove(key)
                fig.renderers.remove(lines[key])
                del lines[key]

            # Create lines for channels seen for the first time.
            for key in list(y_data.keys() - lines.keys()):
                cds.add([], key)
                lines[key] = fig.line(x=CDS_X_DIM, y=key, source=cds)

            # Replace all columns at once; the x column is written last so it
            # takes precedence over any channel that happens to share its name.
            cds.data = {**y_data, CDS_X_DIM: x_data}

        panel.state.add_periodic_callback(
            partial(_update, fig, cds, lines), period=50
        )

        return panel.pane.Bokeh(fig)

    @property
    def controls(self) -> List[panel.viewable.Viewable]:
        """The visualization control widgets, in display order."""
        return [
            self.STATE.channelize,
            self.STATE.gain,
        ]

    def panel(self) -> panel.viewable.Viewable:
        """Return the plot laid out in a row next to a column of its controls."""
        return panel.Row(
            self.plot(), panel.Column("__Line Plot Controls__", *self.controls)
        )

    @ez.subscriber(INPUT_SIGNAL)
    async def on_signal(self, msg: Optional[AxisArray]) -> None:
        """Store the incoming message and request a plot-data refresh."""
        self.STATE.cur_signal = msg
        self.STATE.update_ev.set()

    @ez.task
    async def update_data(self) -> None:
        """Recompute ``x_data`` / ``cds_data`` each time ``update_ev`` is set.

        A ``None`` signal clears the plot.  Otherwise the message is viewed
        in 2-D along the configured x axis (dim 0 when unset), scaled by the
        gain widget and, when "Channelize" is checked, each channel is
        shifted by its index so the traces do not overlap.
        """
        while True:
            await self.STATE.update_ev.wait()
            self.STATE.update_ev.clear()

            msg = self.STATE.cur_signal

            if msg is None:
                # clear the plot
                self.STATE.x_data = np.arange(0)
                self.STATE.cds_data = {}
                continue

            axis_name = self.SETTINGS.x_axis
            if axis_name is None:
                axis_name = msg.dims[0]
            axis = msg.get_axis(axis_name)

            # The FloatInput value can be None (e.g. while the field is
            # empty).  Multiplying by None would raise inside this loop and
            # stop all further plot updates, so fall back to unity gain.
            gain = self.STATE.gain.value
            if gain is None:
                gain = 1.0

            with msg.view2d(axis_name) as view:
                # view is indexed as [position along axis_name, channel].
                ch_names = getattr(msg, "ch_names", None)
                if ch_names is None:
                    ch_names = [f"ch_{i}" for i in range(view.shape[1])]

                self.STATE.x_data = (np.arange(view.shape[0]) * axis.gain) + axis.offset

                # The multiplication allocates a new array, so the in-place
                # offset below does not touch the message's own data.
                vis_view = view * gain
                if self.STATE.channelize.value:
                    vis_view += np.arange(len(ch_names))

                self.STATE.cds_data = {
                    ch_name: vis_view[:, ch_idx]
                    for ch_idx, ch_name in enumerate(ch_names)
                }