# Source code for ezmsg.learn.process.slda
"""Shrinkage LDA classifier processor.
.. note::
This module supports the Array API standard via
``array_api_compat.get_namespace()``. Input data is manipulated using
Array API operations (``permute_dims``, ``reshape``); a NumPy boundary
is applied before ``sklearn.predict_proba``.
"""
import typing
import ezmsg.core as ez
import numpy as np
from array_api_compat import get_namespace, is_numpy_array
from ezmsg.baseproc import (
BaseStatefulTransformer,
BaseTransformerUnit,
processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from ..util import ClassifierMessage
# [docs]
@processor_state
class SLDAState:
    """Mutable state for :class:`SLDATransformer`."""

    # Fitted sklearn LDA classifier. Populated in ``_reset_state`` either from
    # a project-specific MATLAB ``.mat`` export or from a pickled estimator.
    lda: LDA
    # Cached zero-length ClassifierMessage reused as the template for every
    # output; rebuilt whenever the transformer state resets.
    out_template: typing.Optional[ClassifierMessage] = None
# [docs]
class SLDATransformer(
    BaseStatefulTransformer[SLDASettings, AxisArray, ClassifierMessage, SLDAState]
):
    """Apply a pre-trained shrinkage-LDA classifier to streaming ``AxisArray`` data.

    The classifier is loaded from ``settings.settings_path``: a ``.mat`` file is
    interpreted with a project-specific MATLAB layout (see ``_load_matlab_lda``);
    any other path is treated as a pickled sklearn estimator.
    """

    def _load_matlab_lda(self, message: AxisArray) -> LDA:
        """Build an LDA from a MATLAB ``.mat`` export.

        Expects a very specific format from a specific project. Not for general use.
        """
        import scipy.io as sio

        matlab_slda = sio.loadmat(self.settings.settings_path, squeeze_me=True)
        temp_weights = matlab_slda["weights"][1, 1:]
        temp_intercept = matlab_slda["weights"][1, 0]
        # Map MATLAB channel indices onto this stream's channel axis; channels
        # we do not keep get zero weight.
        channels = matlab_slda["channels"] - 4
        channels -= channels[0]  # Offsets are wrong somehow.
        n_channels = message.data.shape[message.dims.index("ch")]
        valid_indices = [ch for ch in channels if ch < n_channels]
        full_weights = np.zeros(n_channels)
        full_weights[valid_indices] = temp_weights[: len(valid_indices)]
        lda = LDA(solver="lsqr", shrinkage="auto")
        lda.classes_ = np.asarray([0, 1])
        lda.coef_ = np.expand_dims(full_weights, axis=0)
        # NOTE(review): sklearn normally stores intercept_ as a 1-D array; the
        # scalar broadcasts in decision_function so this works, but is fragile.
        # TODO: Is this supposed to be per-channel? Why the [1, 0]?
        lda.intercept_ = temp_intercept
        # Unused fields present in the .mat file (not applied here):
        # mean = matlab_slda['mXtrain']; std = matlab_slda['sXtrain']
        # lags = matlab_slda['lags'] + 1
        return lda

    def _reset_state(self, message: AxisArray) -> None:
        """Load the classifier and build the zero-length output template."""
        # endswith(".mat") matches the original `[-4:] == ".mat"` check for str
        # paths and additionally tolerates pathlib.Path (which the old slicing
        # check would have crashed on).
        if str(self.settings.settings_path).endswith(".mat"):
            self.state.lda = self._load_matlab_lda(message)
        else:
            import pickle

            # NOTE(review): pickle.load executes arbitrary code on load; only
            # use model files from trusted sources.
            with open(self.settings.settings_path, "rb") as f:
                self.state.lda = pickle.load(f)

        # Template ClassifierMessage built from lda.classes_; real outputs are
        # produced via `replace` on this template so the axes/labels are shared.
        out_labels = self.state.lda.classes_.tolist()
        zero_shape = (0, len(out_labels))
        self.state.out_template = ClassifierMessage(
            data=np.zeros(zero_shape, dtype=message.data.dtype),
            dims=[self.settings.axis, "classes"],
            axes={
                self.settings.axis: message.axes[self.settings.axis],
                "classes": AxisArray.CoordinateAxis(
                    data=np.array(out_labels), dims=["classes"]
                ),
            },
            labels=out_labels,
            key=message.key,
        )

    def _process(self, message: AxisArray) -> ClassifierMessage:
        """Classify each sample along ``settings.axis``.

        Returns a ClassifierMessage of per-class probabilities with shape
        (n_samples, n_classes). An empty input returns the cached template.
        """
        xp = get_namespace(message.data)
        samp_ax_idx = message.dims.index(self.settings.axis)
        # Move the sample axis to the front so each leading row is one observation.
        perm = (samp_ax_idx,) + tuple(
            i for i in range(message.data.ndim) if i != samp_ax_idx
        )
        X = xp.permute_dims(message.data, perm)
        if not X.shape[0]:
            # No samples to classify.
            return self.state.out_template

        # NumPy boundary: sklearn's predict_proba requires numpy arrays.
        X_np = X if is_numpy_array(X) else np.asarray(X)
        if str(self.settings.settings_path).endswith(".mat"):
            # MATLAB-derived weights assume F-order flattening and microvolt
            # scaling, so classify one sample at a time.
            pred_probas = np.concatenate(
                [
                    self.state.lda.predict_proba(
                        np.expand_dims(samp.flatten(order="F") * 1e-6, axis=0)
                    )
                    for samp in X_np
                ],
                axis=0,
            )
        else:
            pred_probas = self.state.lda.predict_proba(
                X_np.reshape(X_np.shape[0], -1)
            )

        # Build the output axis via `replace` instead of mutating the cached
        # template in place (the original assigned update_ax.offset directly,
        # silently mutating self.state.out_template, then performed a redundant
        # replace with the same offset).
        out_axes = {
            **self.state.out_template.axes,
            self.settings.axis: replace(
                self.state.out_template.axes[self.settings.axis],
                offset=message.axes[self.settings.axis].offset,
            ),
        }
        return replace(self.state.out_template, data=pred_probas, axes=out_axes)
# [docs]
class SLDA(BaseTransformerUnit[SLDASettings, AxisArray, ClassifierMessage, SLDATransformer]):
    """ezmsg Unit that runs :class:`SLDATransformer` on incoming ``AxisArray`` messages."""

    # Settings class consumed by the base unit machinery.
    SETTINGS = SLDASettings