Source code for ezmsg.sigproc.wavelets

import typing

import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
import pywt
from ezmsg.baseproc import (
    BaseStatefulTransformer,
    BaseTransformerUnit,
    processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from .filterbank import FilterbankMode, MinPhaseMode, filterbank


[docs] class CWTSettings(ez.Settings): """ Settings for :obj:`CWT` See :obj:`cwt` for argument details. """ frequencies: list | tuple | npt.NDArray | None wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet min_phase: MinPhaseMode = MinPhaseMode.NONE axis: str = "time" scales: list | tuple | npt.NDArray | None = None
[docs] @processor_state class CWTState: neg_rt_scales: npt.NDArray | None = None int_psi_scales: list[npt.NDArray] | None = None template: AxisArray | None = None fbgen: typing.Generator[AxisArray, AxisArray, None] | None = None last_conv_samp: npt.NDArray | None = None
[docs] class CWTTransformer(BaseStatefulTransformer[CWTSettings, AxisArray, AxisArray, CWTState]): def _hash_message(self, message: AxisArray) -> int: ax_idx = message.get_axis_idx(self.settings.axis) in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :] return hash( ( message.data.dtype.kind, message.axes[self.settings.axis].gain, in_shape, message.key, ) ) def _reset_state(self, message: AxisArray) -> None: precision = 10 # Process wavelet wavelet = ( self.settings.wavelet if isinstance(self.settings.wavelet, (pywt.ContinuousWavelet, pywt.Wavelet)) else pywt.DiscreteContinuousWavelet(self.settings.wavelet) ) # Process wavelet integration int_psi, wave_xvec = pywt.integrate_wavelet(wavelet, precision=precision) int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi # Calculate scales and frequencies if self.settings.frequencies is not None: frequencies = np.sort(np.array(self.settings.frequencies)) scales = pywt.frequency2scale( wavelet, frequencies * message.axes[self.settings.axis].gain, precision=precision, ) else: scales = np.sort(self.settings.scales)[::-1] self._state.neg_rt_scales = -np.sqrt(scales)[:, None] # Convert to appropriate dtype dt_data = message.data.dtype dt_cplx = np.result_type(dt_data, np.complex64) dt_psi = dt_cplx if int_psi.dtype.kind == "c" else dt_data int_psi = np.asarray(int_psi, dtype=dt_psi) # Note: Currently int_psi cannot be made non-complex once it is complex. # Calculate waves for each scale wave_xvec = np.asarray(wave_xvec, dtype=message.data.real.dtype) wave_range = wave_xvec[-1] - wave_xvec[0] step = wave_xvec[1] - wave_xvec[0] self._state.int_psi_scales = [] for scale in scales: reix = (np.arange(scale * wave_range + 1) / (scale * step)).astype(int) if reix[-1] >= int_psi.size: reix = np.extract(reix < int_psi.size, reix) self._state.int_psi_scales.append(int_psi[reix][::-1]) # Setup filterbank generator self._state.fbgen = filterbank( self._state.int_psi_scales, mode=FilterbankMode.CONV, min_phase=self.settings.min_phase, axis=self.settings.axis, ) # Create output template ax_idx = message.get_axis_idx(self.settings.axis) in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :] freqs = pywt.scale2frequency(wavelet, scales, precision) / message.axes[self.settings.axis].gain dummy_shape = in_shape + (len(scales), 0) self._state.template = AxisArray( np.zeros(dummy_shape, dtype=dt_cplx if wavelet.complex_cwt else dt_data), dims=message.dims[:ax_idx] + message.dims[ax_idx + 1 :] + ["freq", self.settings.axis], axes={ **message.axes, "freq": AxisArray.CoordinateAxis(unit="Hz", data=freqs, dims=["freq"]), }, key=message.key, ) self._state.last_conv_samp = np.zeros(dummy_shape[:-1] + (1,), dtype=self._state.template.data.dtype) def _process(self, message: AxisArray) -> AxisArray: conv_msg = self._state.fbgen.send(message) # Prepend with last_conv_samp before doing diff dat = np.concatenate((self._state.last_conv_samp, conv_msg.data), axis=-1) coef = self._state.neg_rt_scales * np.diff(dat, axis=-1) # Store last_conv_samp for next iteration self._state.last_conv_samp = conv_msg.data[..., -1:] if self._state.template.data.dtype.kind != "c": coef = coef.real # pywt.cwt slices off the beginning and end of the result where the convolution overran. We don't have # that luxury when streaming. # d = (coef.shape[-1] - msg_in.data.shape[ax_idx]) / 2. # coef = coef[..., math.floor(d):-math.ceil(d)] return replace( self._state.template, data=coef, axes={ **self._state.template.axes, self.settings.axis: message.axes[self.settings.axis], }, )
[docs] class CWT(BaseTransformerUnit[CWTSettings, AxisArray, AxisArray, CWTTransformer]): SETTINGS = CWTSettings
[docs] def cwt( frequencies: list | tuple | npt.NDArray | None, wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet, min_phase: MinPhaseMode = MinPhaseMode.NONE, axis: str = "time", scales: list | tuple | npt.NDArray | None = None, ) -> CWTTransformer: """ Perform a continuous wavelet transform. The function is equivalent to the :obj:`pywt.cwt` function, but is designed to work with streaming data. Args: frequencies: The wavelet frequencies to use in Hz. If `None` provided then the scales will be used. Note: frequencies will be sorted from smallest to largest. wavelet: Wavelet object or name of wavelet to use. min_phase: See filterbank MinPhaseMode for details. axis: The target axis for operation. Note that this will be moved to the -1th dimension because fft and matrix multiplication is much faster on the last axis. This axis must be in the msg.axes and it must be of type AxisArray.LinearAxis. scales: The scales to use. If None, the scales will be calculated from the frequencies. Note: Scales will be sorted from largest to smallest. Note: Use of scales is deprecated in favor of frequencies. Convert scales to frequencies using `pywt.scale2frequency(wavelet, scales, precision=10) * fs` where fs is the sampling frequency. Returns: A primed Generator object that expects an :obj:`AxisArray` via `.send(axis_array)` of continuous data and yields an :obj:`AxisArray` with a continuous wavelet transform in its data. """ return CWTTransformer( CWTSettings( frequencies=frequencies, wavelet=wavelet, min_phase=min_phase, axis=axis, scales=scales, ) )