# Source code for ezmsg.learn.process.slda

"""Shrinkage LDA classifier processor.

.. note::
    This module supports the Array API standard via
    ``array_api_compat.get_namespace()``.  Input data is manipulated using
    Array API operations (``permute_dims``, ``reshape``); a NumPy boundary
    is applied before ``sklearn.predict_proba``.
"""

import typing

import ezmsg.core as ez
import numpy as np
from array_api_compat import get_namespace, is_numpy_array
from ezmsg.baseproc import (
    BaseStatefulTransformer,
    BaseTransformerUnit,
    processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA

from ..util import ClassifierMessage


class SLDASettings(ez.Settings):
    """Settings for :class:`SLDATransformer`.

    Configures where the pre-trained classifier is loaded from and which
    axis of the incoming :class:`AxisArray` holds individual samples.
    """

    # Path to the saved classifier: a ``.mat`` file in a project-specific
    # MATLAB format, or (any other extension) a pickled sklearn LDA object.
    settings_path: str
    # Axis along which each entry is classified independently (one
    # probability row per sample on this axis).
    axis: str = "time"
@processor_state
class SLDAState:
    # Loaded (or manually reconstructed) sklearn LDA used for predict_proba.
    lda: LDA
    # Zero-length ClassifierMessage cached at reset time and reused as the
    # template for every output message.
    out_template: typing.Optional[ClassifierMessage] = None
class SLDATransformer(
    BaseStatefulTransformer[SLDASettings, AxisArray, ClassifierMessage, SLDAState]
):
    """Apply a pre-trained shrinkage-LDA classifier to streaming data.

    On state reset, the classifier is loaded from ``settings.settings_path``
    (either a project-specific MATLAB ``.mat`` file or a pickled sklearn LDA)
    and a zero-length output template is built. Each processed message yields
    one row of class probabilities per sample along ``settings.axis``.
    """

    def _reset_state(self, message: AxisArray) -> None:
        """Load the classifier and build the cached output template.

        Args:
            message: First incoming message; used to size the weight vector
                (``.mat`` path) and to seed the template's axes/key/dtype.
        """
        if self.settings.settings_path[-4:] == ".mat":
            # Expects a very specific format from a specific project. Not for general use.
            import scipy.io as sio

            matlab_sLDA = sio.loadmat(self.settings.settings_path, squeeze_me=True)
            # Row 1 of "weights": column 0 is the intercept, the rest are coefs.
            temp_weights = matlab_sLDA["weights"][1, 1:]
            temp_intercept = matlab_sLDA["weights"][1, 0]
            # Create weights and use zeros for channels we do not keep.
            channels = matlab_sLDA["channels"] - 4
            channels -= channels[0]  # Offsets are wrong somehow.
            n_channels = message.data.shape[message.dims.index("ch")]
            # Drop any stored channel index that falls outside this stream.
            valid_indices = [ch for ch in channels if ch < n_channels]
            full_weights = np.zeros(n_channels)
            full_weights[valid_indices] = temp_weights[: len(valid_indices)]
            # Hand-assemble a fitted binary LDA rather than calling .fit().
            lda = LDA(solver="lsqr", shrinkage="auto")
            lda.classes_ = np.asarray([0, 1])
            lda.coef_ = np.expand_dims(full_weights, axis=0)
            lda.intercept_ = temp_intercept  # TODO: Is this supposed to be per-channel? Why the [1, 0]?
            self.state.lda = lda
            # mean = matlab_sLDA['mXtrain']
            # std = matlab_sLDA['sXtrain']
            # lags = matlab_sLDA['lags'] + 1
        else:
            import pickle

            # NOTE: pickle.load on an untrusted path executes arbitrary code;
            # settings_path is assumed to be operator-controlled.
            with open(self.settings.settings_path, "rb") as f:
                self.state.lda = pickle.load(f)

        # Create template ClassifierMessage using lda.classes_
        out_labels = self.state.lda.classes_.tolist()
        zero_shape = (0, len(out_labels))
        self.state.out_template = ClassifierMessage(
            data=np.zeros(zero_shape, dtype=message.data.dtype),
            dims=[self.settings.axis, "classes"],
            axes={
                self.settings.axis: message.axes[self.settings.axis],
                "classes": AxisArray.CoordinateAxis(
                    data=np.array(out_labels), dims=["classes"]
                ),
            },
            labels=out_labels,
            key=message.key,
        )

    def _process(self, message: AxisArray) -> ClassifierMessage:
        """Classify each sample of ``message`` along the configured axis.

        Args:
            message: Input data; all non-sample axes are flattened into the
                feature vector fed to the LDA.

        Returns:
            A ClassifierMessage of shape (n_samples, n_classes) with
            per-sample class probabilities, or the zero-length template when
            the message has no samples on the configured axis.
        """
        xp = get_namespace(message.data)
        samp_ax_idx = message.dims.index(self.settings.axis)
        # Move sample axis to front
        perm = (samp_ax_idx,) + tuple(
            i for i in range(message.data.ndim) if i != samp_ax_idx
        )
        X = xp.permute_dims(message.data, perm)
        if X.shape[0]:
            if (
                isinstance(self.settings.settings_path, str)
                and self.settings.settings_path[-4:] == ".mat"
            ):
                # Assumes F-contiguous weights — need numpy for predict_proba
                X_np = np.asarray(X) if not is_numpy_array(X) else X
                pred_probas = []
                for samp in X_np:
                    # Column-major flatten to match MATLAB weight ordering;
                    # 1e-6 scale presumably converts units (uV?) — TODO confirm.
                    tmp = samp.flatten(order="F") * 1e-6
                    tmp = np.expand_dims(tmp, axis=0)
                    probas = self.state.lda.predict_proba(tmp)
                    pred_probas.append(probas)
                pred_probas = np.concatenate(pred_probas, axis=0)
            else:
                # Numpy boundary before sklearn predict_proba
                X_np = np.asarray(X) if not is_numpy_array(X) else X
                X_np = X_np.reshape(X_np.shape[0], -1)
                pred_probas = self.state.lda.predict_proba(X_np)
            # Build a fresh axis carrying the incoming offset instead of
            # mutating the cached template's axis in place (the previous
            # in-place assignment leaked per-message state into the shared
            # template and made the subsequent replace() a no-op).
            out_ax = replace(
                self.state.out_template.axes[self.settings.axis],
                offset=message.axes[self.settings.axis].offset,
            )
            return replace(
                self.state.out_template,
                data=pred_probas,
                axes={
                    # `replace` will copy the minimal set of fields
                    **self.state.out_template.axes,
                    self.settings.axis: out_ax,
                },
            )
        else:
            # No samples: return the pristine zero-length template.
            return self.state.out_template
class SLDA(
    BaseTransformerUnit[SLDASettings, AxisArray, ClassifierMessage, SLDATransformer]
):
    """ezmsg unit exposing :class:`SLDATransformer` in a pipeline."""

    SETTINGS = SLDASettings