Source code for ezmsg.learn.collection.sample_adapt_regressor
from dataclasses import field
import ezmsg.core as ez
from ezmsg.baseproc import SampleTriggerMessage
from ezmsg.sigproc.resample import ResampleSettings, ResampleUnit
from ezmsg.sigproc.window import Window, WindowSettings
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.learn.process.adaptive_linear_regressor import (
AdaptiveLinearRegressorSettings,
AdaptiveLinearRegressorUnit,
)
from ezmsg.learn.process.flatten import Flatten, FlattenSettings
from ezmsg.learn.process.seqseqsampler import SeqSeqSamplerSettings, SeqSeqSamplerUnit
from ezmsg.learn.util import AdaptiveLinearRegressor
[docs]
class SampleAdaptRegressorSettings(ez.Settings):
    """Settings for the :class:`SampleAdaptRegressor` collection.

    The fields are grouped by the component they relate to: the adaptive
    regressor, the resampler, the sequence-to-sequence sampler, and the
    optional inference-side windowing stage.
    """

    # --- AdaptiveLinearRegressor settings ---

    model_type: AdaptiveLinearRegressor = AdaptiveLinearRegressor.LINEAR
    """Which adaptive regressor backend/model to use."""

    model_path: str | None = None
    """Path to a pickled, preconfigured model; ``None`` when not provided."""

    model_kwargs: dict = field(default_factory=dict)
    """Additional keyword arguments handed to the underlying regressor."""

    # --- Resampling settings ---

    resample_axis: str = "time"
    """Name of the axis along which resampling is performed."""

    resample_buffer_duration: float = 2.0
    """Length of the resampling buffer, in seconds."""

    # --- SeqSeqSampler settings ---

    sampler_max_buffer_dur: float = 5.0
    """Upper bound on the SeqSeqSampler buffer duration, in seconds."""

    # --- Inference-side windowing settings ---

    decode_window_dur: float | None = None
    """Duration in seconds of the inference-side feature window (optional).

    When ``None``, the collection wires the input signal directly to the
    regressor instead of through the window/flatten stages.
    """

    decode_window_shift: float | None = None
    """Shift in seconds of the inference-side feature window (optional)."""
[docs]
class SampleAdaptRegressor(ez.Collection):
    """Collection wiring a sampler-fed adaptive linear regressor.

    Topics:
        INPUT_LABELS: label stream, routed into the resampler's signal input.
        INPUT_SIGNAL: feature stream; used as the resampler's reference, as
            the sampler's signal, and as the regressor's inference input
            (directly, or via WINDOW -> FLATTEN when ``decode_window_dur``
            is set).
        INPUT_TRIGGER: sample triggers forwarded to the sampler.
        OUTPUT_SIGNAL: the regressor's output.

    NOTE(review): only ``decode_window_dur`` is read in the code visible
    here; no ``configure()`` that forwards the remaining settings to the
    child components is visible in this view -- confirm it exists elsewhere.
    """

    SETTINGS = SampleAdaptRegressorSettings

    INPUT_LABELS = ez.InputTopic(AxisArray)
    INPUT_SIGNAL = ez.InputTopic(AxisArray)
    INPUT_TRIGGER = ez.InputTopic(SampleTriggerMessage)
    OUTPUT_SIGNAL = ez.OutputTopic(AxisArray)

    RESAMPLE = ResampleUnit()
    SEQSEQSAMPLER = SeqSeqSamplerUnit()
    WINDOW = Window()
    FLATTEN = Flatten()
    REGRESSOR = AdaptiveLinearRegressorUnit()

    def network(self) -> ez.NetworkDefinition:
        """Return the connections between topics and child components.

        The result is a tuple of ``(source, destination)`` pairs made of
        three consecutive segments: the sample-generation path, the
        inference path (direct or windowed, depending on
        ``decode_window_dur``), and the output connection.
        """
        # Labels and signal feed the resampler and sampler; sampled pairs
        # go to the regressor's sample input.
        sampling_path = (
            (self.INPUT_LABELS, self.RESAMPLE.INPUT_SIGNAL),
            (self.INPUT_SIGNAL, self.RESAMPLE.INPUT_REFERENCE),
            (self.RESAMPLE.OUTPUT_SIGNAL, self.SEQSEQSAMPLER.INPUT_VALUE),
            (self.INPUT_SIGNAL, self.SEQSEQSAMPLER.INPUT_SIGNAL),
            (self.INPUT_TRIGGER, self.SEQSEQSAMPLER.INPUT_TRIGGER),
            (self.SEQSEQSAMPLER.OUTPUT_SAMPLE, self.REGRESSOR.INPUT_SAMPLE),
        )

        if self.SETTINGS.decode_window_dur is not None:
            # A decode window is requested: window then flatten the signal
            # before it reaches the regressor.
            inference_path = (
                (self.INPUT_SIGNAL, self.WINDOW.INPUT_SIGNAL),
                (self.WINDOW.OUTPUT_SIGNAL, self.FLATTEN.INPUT_SIGNAL),
                (self.FLATTEN.OUTPUT_SIGNAL, self.REGRESSOR.INPUT_SIGNAL),
            )
        else:
            # No decode window: the raw signal goes straight to the regressor.
            inference_path = ((self.INPUT_SIGNAL, self.REGRESSOR.INPUT_SIGNAL),)

        output_path = ((self.REGRESSOR.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL),)

        return sampling_path + inference_path + output_path