Source code for ezmsg.sigproc.singlebandpow
"""
Time-domain single-band power estimation.
Two methods are provided:
1. **RMS Band Power** — Bandpass filter, square, window into non-overlapping bins, take the mean of each bin, optionally take the square root.
2. **Square-Law + LPF Band Power** — Bandpass filter, square, lowpass filter (smoothing), downsample.
"""
from dataclasses import field
import ezmsg.core as ez
from ezmsg.baseproc import (
BaseProcessor,
BaseStatefulProcessor,
BaseTransformerUnit,
CompositeProcessor,
)
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.modify import ModifyAxisSettings, ModifyAxisTransformer
from .aggregate import AggregateSettings, AggregateTransformer, AggregationFunction
from .butterworthfilter import ButterworthFilterSettings, ButterworthFilterTransformer
from .downsample import DownsampleSettings, DownsampleTransformer
from .math.pow import PowSettings, PowTransformer
from .window import WindowTransformer
[docs]
class RMSBandPowerSettings(ez.Settings):
    """
    Configuration for :obj:`RMSBandPowerTransformer`.

    Controls the bandpass stage, the length of the non-overlapping averaging bins,
    and whether a final square root is applied.
    """

    #: Butterworth bandpass filter settings (defaults: ``order=4``, ``coef_type="sos"``).
    #: Set ``cuton`` and ``cutoff`` to define the band.
    bandpass: ButterworthFilterSettings = field(
        default_factory=lambda: ButterworthFilterSettings(order=4, coef_type="sos")
    )

    #: Length, in seconds, of each non-overlapping averaging bin.
    bin_duration: float = 0.05

    #: ``True`` -> output is RMS (root-mean-square); ``False`` -> output is mean-square power.
    apply_sqrt: bool = True
[docs]
class RMSBandPowerTransformer(CompositeProcessor[RMSBandPowerSettings, AxisArray, AxisArray]):
    """
    RMS band power estimation.

    Pipeline: bandpass -> square -> window(bins) -> mean(time) -> rename bin->time -> [sqrt]

    The squared, band-limited signal is split into non-overlapping bins of
    ``settings.bin_duration`` seconds (window length equals window shift). Each bin is
    averaged along ``"time"``, and the ``"bin"`` axis is then renamed to ``"time"``.
    If ``settings.apply_sqrt`` is set, the mean-square power is converted to RMS.
    """

    @staticmethod
    def _initialize_processors(
        settings: RMSBandPowerSettings,
    ) -> dict[str, BaseProcessor | BaseStatefulProcessor]:
        """Build the sub-processors for ``settings``, keyed by stage name in pipeline order.

        Args:
            settings: RMS band power configuration.

        Returns:
            Dict mapping stage name to processor. The optional ``"sqrt"`` stage is
            only present when ``settings.apply_sqrt`` is true.
        """
        procs: dict[str, BaseProcessor | BaseStatefulProcessor] = {
            "bandpass": ButterworthFilterTransformer(settings.bandpass),
            "square": PowTransformer(PowSettings(exponent=2.0)),
            # window_shift == window_dur -> non-overlapping bins; the bin index is placed
            # on a new "bin" axis.
            "window": WindowTransformer(
                axis="time",
                newaxis="bin",
                window_dur=settings.bin_duration,
                window_shift=settings.bin_duration,
                zero_pad_until="none",
            ),
            # Average along "time" (within each bin) -> mean-square power per bin.
            "aggregate": AggregateTransformer(AggregateSettings(axis="time", operation=AggregationFunction.MEAN)),
            # Promote the per-bin axis to be the output's "time" axis.
            "rename": ModifyAxisTransformer(settings=ModifyAxisSettings(name_map={"bin": "time"})),
        }
        if settings.apply_sqrt:
            procs["sqrt"] = PowTransformer(PowSettings(exponent=0.5))
        return procs

    def _post_process(self, result: AxisArray | None) -> AxisArray | None:
        """Evaluate MLX-backed output before it leaves the composite.

        If ``result.data`` is an ``mlx.core.array``, ``mx.eval`` is called on it.
        Any other payload (e.g. numpy) is passed through untouched.

        Args:
            result: Output of the final stage, or None if nothing was produced.

        Returns:
            ``result`` itself (same object), or None.
        """
        if result is None:
            return None
        # An mlx array can only exist if ``mlx.core`` has already been imported, so look the
        # module up in sys.modules rather than importing it. This avoids loading mlx in
        # numpy-only pipelines and avoids retrying a failing import on every message when
        # mlx is not installed. It also no longer hides ImportErrors raised by mx.eval itself.
        import sys

        mx = sys.modules.get("mlx.core")
        if mx is not None and isinstance(result.data, mx.array):
            mx.eval(result.data)
        return result
[docs]
class RMSBandPower(
    BaseTransformerUnit[RMSBandPowerSettings, AxisArray, AxisArray, RMSBandPowerTransformer]
):
    """ezmsg Unit that runs :obj:`RMSBandPowerTransformer`, configured by :obj:`RMSBandPowerSettings`."""

    SETTINGS = RMSBandPowerSettings
[docs]
class SquareLawBandPowerSettings(ez.Settings):
    """
    Configuration for :obj:`SquareLawBandPowerTransformer`.

    Holds the settings for the three configurable stages: bandpass, lowpass smoothing of
    the squared signal, and downsampling.
    """

    #: Butterworth bandpass filter settings (defaults: ``order=4``, ``coef_type="sos"``).
    #: Set ``cuton`` and ``cutoff`` to define the band.
    bandpass: ButterworthFilterSettings = field(
        default_factory=lambda: ButterworthFilterSettings(order=4, coef_type="sos")
    )

    #: Butterworth lowpass filter settings used to smooth the squared signal
    #: (defaults: ``order=4``, ``coef_type="sos"``).
    lowpass: ButterworthFilterSettings = field(
        default_factory=lambda: ButterworthFilterSettings(order=4, coef_type="sos")
    )

    #: Downsample settings for rate reduction after the lowpass smoothing stage.
    downsample: DownsampleSettings = field(default_factory=DownsampleSettings)
[docs]
class SquareLawBandPowerTransformer(CompositeProcessor[SquareLawBandPowerSettings, AxisArray, AxisArray]):
    """
    Square-law + LPF band power estimation.

    Pipeline: bandpass -> square -> lowpass -> downsample

    The bandpassed signal is squared, smoothed by a lowpass filter, and then downsampled
    to a lower output rate.
    """

    @staticmethod
    def _initialize_processors(
        settings: SquareLawBandPowerSettings,
    ) -> dict[str, BaseProcessor | BaseStatefulProcessor]:
        """Create the four sub-processors, keyed by stage name in pipeline order.

        Args:
            settings: Square-law band power configuration.

        Returns:
            Dict with keys ``"bandpass"``, ``"square"``, ``"lowpass"``, ``"downsample"``.
        """
        band_filter = ButterworthFilterTransformer(settings.bandpass)
        squarer = PowTransformer(PowSettings(exponent=2.0))
        smoother = ButterworthFilterTransformer(settings.lowpass)
        decimator = DownsampleTransformer(settings.downsample)
        return {
            "bandpass": band_filter,
            "square": squarer,
            "lowpass": smoother,
            "downsample": decimator,
        }
[docs]
class SquareLawBandPower(
    BaseTransformerUnit[SquareLawBandPowerSettings, AxisArray, AxisArray, SquareLawBandPowerTransformer]
):
    """ezmsg Unit that runs :obj:`SquareLawBandPowerTransformer`, configured by :obj:`SquareLawBandPowerSettings`."""

    SETTINGS = SquareLawBandPowerSettings