Source code for ezmsg.sigproc.downsample
import numpy as np
from ezmsg.util.messages.axisarray import (
AxisArray,
slice_along_axis,
replace,
)
import ezmsg.core as ez
from .base import (
BaseStatefulTransformer,
BaseTransformerUnit,
processor_state,
)
[docs]
class DownsampleSettings(ez.Settings):
    """
    Configuration for the :obj:`Downsample` node.

    The downsampling factor comes from ``factor`` when it is given, otherwise it
    is derived from ``target_rate``; when neither is given the factor is 1.
    """

    axis: str = "time"
    """Name of the axis along which samples are dropped."""

    target_rate: float | None = None
    """Desired output rate. The factor used is the largest integer that keeps the
    resulting rate at or above this value (never less than 1)."""

    factor: int | None = None
    """Explicit integer downsampling factor. When set, ``target_rate`` is ignored."""
[docs]
@processor_state
class DownsampleState:
    """Per-stream state of :obj:`DownsampleTransformer`."""

    q: int = 0
    """Integer downsampling factor, resolved from the settings when the state is reset."""

    s_idx: int = 0
    """Position (taken modulo ``q``) of the next message's first sample in the rotating
    downsampling counter; samples at position 0 are the ones kept."""
[docs]
class DownsampleTransformer(
    BaseStatefulTransformer[DownsampleSettings, AxisArray, AxisArray, DownsampleState]
):
    """
    Downsampled data simply comprise every `factor`th sample.
    This should only be used following appropriate lowpass filtering.
    If your pipeline does not already have lowpass filtering then consider
    using the :obj:`Decimate` collection instead.

    The position within the downsampling cycle is carried across messages in
    the state, so a stream split into chunks of arbitrary length keeps the same
    samples as the unsplit stream.
    """

    def _hash_message(self, message: AxisArray) -> int:
        # The state is rebuilt whenever the sample period (gain) of the target
        # axis or the message key changes.
        return hash((message.axes[self.settings.axis].gain, message.key))

    def _reset_state(self, message: AxisArray) -> None:
        """
        Resolve the integer downsampling factor ``q`` and restart the cycle.

        ``settings.factor`` takes precedence; otherwise ``q`` is derived from
        ``settings.target_rate`` and the input sample period; with neither,
        ``q`` is 1. Any factor below 1 is replaced by 1 with a warning.
        """
        axis_info = message.get_axis(self.settings.axis)
        if self.settings.factor is not None:
            q = self.settings.factor
            if q < 1:
                # A factor of 0 would make the modulo counter degenerate and a
                # negative one would produce a negative output gain.
                ez.logger.warning(
                    f"Downsample factor {q} is invalid (must be >= 1). "
                    "Setting factor to 1."
                )
                q = 1
        elif self.settings.target_rate is None:
            q = 1
        else:
            q = self._factor_from_target_rate(axis_info.gain)
        self._state.q = q
        self._state.s_idx = 0

    def _factor_from_target_rate(self, gain: float) -> int:
        """
        Return the largest integer ``q`` with ``(1 / gain) / q >= target_rate``.

        The comparison allows a small floating-point tolerance. If no factor of
        at least 1 satisfies it, including when ``target_rate <= 0``, a warning
        is logged and 1 is returned.
        """
        target_rate = self.settings.target_rate
        q = 0
        if target_rate > 0:
            ratio = 1 / (gain * target_rate)
            q = int(ratio)
            # ``ratio`` is mathematically input_rate / target_rate. Rounding error
            # can leave an exact integer ratio slightly below that integer
            # (e.g. 29.999999999999996), which int() would truncate one too low.
            if np.isclose(ratio, q + 1, rtol=1e-9, atol=0.0):
                q += 1
        if q < 1:
            ez.logger.warning(
                f"Target rate {target_rate} cannot be achieved with input rate of {1 / gain}. "
                "Setting factor to 1."
            )
            q = 1
        return q

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Keep the samples whose position in the rotating ``q``-cycle is 0.

        The output axis gain is multiplied by ``q`` and its offset is moved to the
        first kept sample. If no sample is kept, the output holds an empty slice
        along the axis and the offset is left unchanged.
        """
        axis = self.settings.axis
        axis_info = message.get_axis(axis)
        axis_idx = message.get_axis_idx(axis)
        n_samples = message.data.shape[axis_idx]
        q = self._state.q

        # Position of every sample of this message in the rotating 0..q-1 counter.
        samples = np.arange(self._state.s_idx, self._state.s_idx + n_samples) % q
        if n_samples > 0:
            # Continue the counter where this message left off; store a plain
            # Python number rather than a NumPy scalar.
            self._state.s_idx = (samples[-1] + 1).item()

        pub_samples = np.where(samples == 0)[0]
        if len(pub_samples) > 0:
            n_step = pub_samples[0].item()
            data_slice = pub_samples
        else:
            n_step = 0
            data_slice = slice(None, 0, None)

        return replace(
            message,
            data=slice_along_axis(message.data, data_slice, axis=axis_idx),
            axes={
                **message.axes,
                axis: replace(
                    axis_info,
                    gain=axis_info.gain * q,
                    offset=axis_info.offset + axis_info.gain * n_step,
                ),
            },
        )
[docs]
class Downsample(
    BaseTransformerUnit[DownsampleSettings, AxisArray, AxisArray, DownsampleTransformer]
):
    """Unit wrapper that runs :obj:`DownsampleTransformer` on incoming :obj:`AxisArray` messages."""

    SETTINGS = DownsampleSettings
[docs]
def downsample(
    axis: str = "time",
    target_rate: float | None = None,
    factor: int | None = None,
) -> DownsampleTransformer:
    """
    Build a :obj:`DownsampleTransformer` from keyword settings.

    Args:
        axis: Name of the axis to downsample along.
        target_rate: Desired output rate; ignored when ``factor`` is given.
        factor: Explicit integer downsampling factor.

    Returns:
        A new :obj:`DownsampleTransformer` configured with these settings.
    """
    settings = DownsampleSettings(
        axis=axis,
        target_rate=target_rate,
        factor=factor,
    )
    return DownsampleTransformer(settings)