Source code for ezmsg.sigproc.denormalize
import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
from ezmsg.baseproc import (
BaseStatefulTransformer,
BaseTransformerUnit,
processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace
[docs]
class DenormalizeSettings(ez.Settings):
    """Configuration for :class:`DenormalizeTransformer`.

    ``low_rate`` and ``high_rate`` bound the per-channel offsets (baseline
    rates) that are generated for the denormalized output; ``distribution``
    selects how those offsets are chosen within that range.
    """

    low_rate: float = 2.0
    """Lower bound (Hz) of the plausible rate after denormalization."""

    high_rate: float = 40.0
    """Upper bound (Hz) of the plausible rate after denormalization."""

    distribution: str = "uniform"
    """How per-channel rates are drawn: one of 'uniform', 'normal', or 'constant'."""
[docs]
@processor_state
class DenormalizeState:
    """Per-channel scaling parameters held by the denormalizer.

    Both arrays start as ``None`` and are populated by
    ``DenormalizeTransformer._reset_state``.
    """

    # Multiplicative factor applied to the incoming data (offsets / 3.29).
    gains: npt.NDArray | None = None

    # Additive per-channel baseline added after the gain is applied.
    offsets: npt.NDArray | None = None
[docs]
class DenormalizeTransformer(BaseStatefulTransformer[DenormalizeSettings, AxisArray, AxisArray, DenormalizeState]):
    """
    Scales data from a normalized distribution (mean=0, std=1) to a denormalized
    distribution using random per-channel offsets and gains designed to keep the
    99.9% CIs between 0 and 2x the offset.

    This is useful for simulating realistic firing rates from normalized data.

    The per-channel offsets are drawn according to ``settings.distribution``:

    * ``"uniform"``  -- uniformly between ``low_rate`` and ``high_rate``.
    * ``"normal"``   -- normal, centred on the midpoint of the range with
      ``std = (high_rate - low_rate) / 6``, then clipped to the range.
    * ``"constant"`` -- every channel gets the midpoint of the range.
    """

    def _reset_state(self, message: AxisArray) -> None:
        """Generate fresh per-channel offsets and gains sized to ``message``.

        Args:
            message: Message whose ``"ch"`` axis determines how many
                offset/gain pairs are generated.

        Raises:
            ValueError: If ``settings.distribution`` is not one of
                ``'uniform'``, ``'normal'`` or ``'constant'``.
        """
        low = self.settings.low_rate
        high = self.settings.high_rate
        midpoint = (low + high) / 2.0

        ax_ix = message.get_axis_idx("ch")
        nch = message.data.shape[ax_ix]
        # Singleton on every axis except "ch" so the parameters broadcast
        # against the data. For 2-D data this is (nch, 1) or (1, nch).
        bcast_shape = [1] * message.data.ndim
        bcast_shape[ax_ix] = nch
        arr_size = tuple(bcast_shape)

        distribution = self.settings.distribution
        if distribution == "uniform":
            # Honour the configured range rather than hard-coded bounds.
            self.state.offsets = np.random.uniform(low, high, size=arr_size)
        elif distribution == "normal":
            # +/- 3 std devs span the configured range; stragglers are clipped.
            drawn = np.random.normal(
                loc=midpoint,
                scale=(high - low) / 6.0,
                size=arr_size,
            )
            self.state.offsets = np.clip(drawn, low, high)
        elif distribution == "constant":
            self.state.offsets = np.full(shape=arr_size, fill_value=midpoint)
        else:
            raise ValueError(f"Invalid distribution: {self.settings.distribution}")

        # Input has std == 1.
        # Desired output ranges from 0 to 2 * offset within the 99.9% confidence interval.
        # For a standard normal distribution, 99.9% of data is within +/- 3.29 std devs,
        # so gain = offset / 3.29 scales the std dev appropriately.
        z_999 = 3.29
        self.state.gains = self.state.offsets / z_999

    def _process(self, message: AxisArray) -> AxisArray:
        """Apply the per-channel gain and offset, flooring the result at zero.

        Args:
            message: Message carrying the normalized data.

        Returns:
            A copy of ``message`` whose ``data`` is
            ``max(data * gains + offsets, 0)``.
        """
        denorm = message.data * self.state.gains + self.state.offsets
        # Output values below zero are clamped to zero; there is no upper bound.
        return replace(
            message,
            data=np.clip(denorm, 0.0, None),
        )
[docs]
class DenormalizeUnit(
    BaseTransformerUnit[
        DenormalizeSettings,
        AxisArray,
        AxisArray,
        DenormalizeTransformer,
    ]
):
    """Unit wrapper around :class:`DenormalizeTransformer`, configured by :class:`DenormalizeSettings`."""

    SETTINGS = DenormalizeSettings