Source code for ezmsg.learn.model.refit_kalman

# refit_kalman.py
"""Refit Kalman filter for adaptive neural decoding.

.. note::
    This module supports the Array API standard via
    ``array_api_compat.get_namespace()``.  All linear algebra in :meth:`fit`,
    :meth:`predict`, and :meth:`update` stays in the source array namespace.
    The DARE solver in :meth:`_compute_gain` and the per-sample mutation loop
    in :meth:`refit` use NumPy regardless of input backend.
"""

import numpy as np
from array_api_compat import get_namespace
from ezmsg.sigproc.util.array import array_device, xp_asarray, xp_create
from numpy.linalg import LinAlgError
from scipy.linalg import solve_discrete_are


[docs] class RefitKalmanFilter: """ Refit Kalman filter for adaptive neural decoding. This class implements a Kalman filter that can be refitted online during operation. Unlike the standard Kalman filter, this version can adapt its observation model (H and Q matrices) based on new data while maintaining the state transition model (A and W matrices). This is particularly useful for brain-computer interfaces where the relationship between neural activity and intended movements may change over time. The filter operates in two phases: 1. Initial fitting: Learns all system matrices (A, W, H, Q) from training data 2. Refitting: Updates only the observation model (H, Q) based on new data Attributes: A_state_transition_matrix: The state transition matrix A (n_states x n_states). W_process_noise_covariance: The process noise covariance matrix W (n_states x n_states). H_observation_matrix: The observation matrix H (n_observations x n_states). Q_measurement_noise_covariance: The measurement noise covariance matrix Q (n_observations x n_observations). K_kalman_gain: The Kalman gain matrix (n_states x n_observations). P_state_covariance: The state error covariance matrix (n_states x n_states). steady_state: Whether to use steady-state Kalman gain computation. is_fitted: Whether the model has been fitted with data. Example: >>> # Create and fit the filter >>> rkf = RefitKalmanFilter(steady_state=True) >>> rkf.fit(X_train, y_train) >>> >>> # Refit with new data >>> rkf.refit(X_new, Y_state, velocity_indices, targets, cursors, holds) >>> >>> # Predict with updated model >>> x_updated = rkf.predict_and_update(measurement, current_state) """
[docs] def __init__( self, A_state_transition_matrix=None, W_process_noise_covariance=None, H_observation_matrix=None, Q_measurement_noise_covariance=None, steady_state=False, enforce_state_structure=False, alpha_fading_memory=1.000, process_noise_scale=1, measurement_noise_scale=1.2, ): self.A_state_transition_matrix = A_state_transition_matrix self.W_process_noise_covariance = W_process_noise_covariance self.H_observation_matrix = H_observation_matrix self.Q_measurement_noise_covariance = Q_measurement_noise_covariance self.K_kalman_gain = None self.P_state_covariance = None self.alpha_fading_memory = alpha_fading_memory # Noise scaling factors for smoothing control self.process_noise_scale = process_noise_scale self.measurement_noise_scale = measurement_noise_scale self.steady_state = steady_state self.enforce_state_structure = enforce_state_structure self.is_fitted = False
def _validate_state_vector(self, Y_state): """ Validate that the state vector has proper dimensions. Args: Y_state: State vector to validate Raises: ValueError: If state vector has invalid dimensions """ if Y_state.ndim != 2: raise ValueError(f"State vector must be 2D, got {Y_state.ndim}D") if not hasattr(self, "H_observation_matrix") or self.H_observation_matrix is None: raise ValueError("Model must be fitted before refitting") expected_states = self.H_observation_matrix.shape[1] if Y_state.shape[1] != expected_states: raise ValueError(f"State vector has {Y_state.shape[1]} dimensions, expected {expected_states}")
[docs] def fit(self, X_train, y_train): """ Fit the Refit Kalman filter to the training data. This method learns all system matrices (A, W, H, Q) from training data using least-squares estimation, then computes the steady-state solution. This is the initial fitting phase that establishes the baseline model. Args: X_train: Neural activity (n_samples, n_neurons). y_train: Outputs being predicted (n_samples, n_states). Raises: ValueError: If training data has invalid dimensions. LinAlgError: If matrix operations fail during fitting. """ # self._validate_state_vector(y_train) xp = get_namespace(X_train, y_train) _mT = xp.linalg.matrix_transpose X = xp.asarray(y_train) Z = xp.asarray(X_train) n_samples = X.shape[0] # Calculate the transition matrix (from x_t to x_t+1) using least-squares X2 = X[1:, :] # x_{t+1} X1 = X[:-1, :] # x_t A = _mT(X2) @ X1 @ xp.linalg.inv(_mT(X1) @ X1) # Transition matrix W = _mT(X2 - X1 @ _mT(A)) @ (X2 - X1 @ _mT(A)) / (n_samples - 1) # Covariance of transition matrix # Calculate the measurement matrix (from x_t to z_t) using least-squares H = _mT(Z) @ X @ xp.linalg.inv(_mT(X) @ X) # Measurement matrix Q = _mT(Z - X @ _mT(H)) @ (Z - X @ _mT(H)) / Z.shape[0] # Covariance of measurement matrix self.A_state_transition_matrix = A self.W_process_noise_covariance = W * self.process_noise_scale self.H_observation_matrix = H self.Q_measurement_noise_covariance = Q * self.measurement_noise_scale self._compute_gain() self.is_fitted = True
[docs] def refit( self, X_neural, Y_state, intention_velocity_indices: int | None = None, target_positions=None, cursor_positions=None, hold_indices=None, ): """ Refit the observation model based on new data. This method updates only the observation model (H and Q matrices) while keeping the state transition model (A and W matrices) unchanged. The refitting process modifies the intended states based on target positions and hold flags to better align with user intentions. The refitting process: 1. Modifies intended states based on target positions and hold flags 2. Recalculates the observation matrix H using least-squares 3. Recalculates the measurement noise covariance Q 4. Updates the Kalman gain accordingly Args: X_neural: Neural activity data (n_samples, n_neurons). Y_state: State estimates (n_samples, n_states). intention_velocity_indices: Index of velocity components in state vector. target_positions: Target positions for each sample (n_samples, 2). cursor_positions: Current cursor positions (n_samples, 2). hold_indices: Boolean flags indicating hold periods (n_samples,). Raises: ValueError: If input data has invalid dimensions or the model is not fitted. """ self._validate_state_vector(Y_state) # Check if velocity indices are provided if intention_velocity_indices is None: # Assume (x, y, vx, vy) vel_idx = 2 if Y_state.shape[1] >= 4 else 0 print(f"[RefitKalmanFilter] No velocity index provided — defaulting to {vel_idx}") else: if isinstance(intention_velocity_indices, (list, tuple)): if len(intention_velocity_indices) != 1: raise ValueError("Only one velocity start index should be provided.") vel_idx = intention_velocity_indices[0] else: vel_idx = intention_velocity_indices # The per-sample mutation loop uses numpy for element-wise operations # on small vectors (np.linalg.norm on 2-element vectors, scalar indexing). Y_state_np = np.asarray(Y_state) target_positions_np = np.asarray(target_positions) if target_positions is not None else None cursor_positions_np = np.asarray(cursor_positions) if cursor_positions is not None else None # Only remap velocity if target and cursor positions are provided intended_states = Y_state_np.copy() if target_positions_np is not None and cursor_positions_np is not None: # Calculate intended velocities for each sample for i, (state, pos, target) in enumerate(zip(Y_state_np, cursor_positions_np, target_positions_np)): is_hold = hold_indices[i] if hold_indices is not None else False if is_hold: # During hold periods, intended velocity is zero intended_states[i, vel_idx : vel_idx + 2] = 0.0 if i > 0: intended_states[i, :2] = intended_states[i - 1, :2] # Same position as previous else: # Calculate direction to target to_target = target - pos target_distance = np.linalg.norm(to_target) if target_distance > 1e-5: # Avoid division by zero # Get current decoded velocity magnitude current_velocity = state[vel_idx : vel_idx + 2] current_speed = np.linalg.norm(current_velocity) # Calculate intended velocity: same speed, but toward target target_direction = to_target / target_distance intended_velocity = target_direction * current_speed # Update intended state with new velocity intended_states[i, vel_idx : vel_idx + 2] = intended_velocity # If target is very close, keep original velocity else: intended_states[i, vel_idx : vel_idx + 2] = state[vel_idx : vel_idx + 2] # Convert back to source namespace for final linalg xp = get_namespace(X_neural) dev = array_device(X_neural) _mT = xp.linalg.matrix_transpose intended_states = xp_asarray(xp, intended_states, device=dev) Z = xp.asarray(X_neural) # Recalculate observation matrix and noise covariance H = ( _mT(Z) @ intended_states @ xp.linalg.pinv(_mT(intended_states) @ intended_states) ) # Using pinv() instead of inv() to avoid singular matrix errors Q = _mT(Z - intended_states @ _mT(H)) @ (Z - intended_states @ _mT(H)) / Z.shape[0] self.H_observation_matrix = H self.Q_measurement_noise_covariance = Q self._compute_gain()
def _compute_gain(self): """ Compute the Kalman gain matrix. This method computes the Kalman gain matrix based on the current system parameters. In steady-state mode, it solves the discrete-time algebraic Riccati equation to find the optimal steady-state gain. In non-steady-state mode, it computes the gain using the current covariance matrix. The DARE solver requires NumPy arrays; results are converted back to the source array namespace. Raises: LinAlgError: If the Riccati equation cannot be solved or matrix operations fail. """ xp = get_namespace(self.A_state_transition_matrix) dev = array_device(self.A_state_transition_matrix) _mT = xp.linalg.matrix_transpose # Convert to numpy for DARE (no Array API equivalent) A_np = np.asarray(self.A_state_transition_matrix) H_np = np.asarray(self.H_observation_matrix) W_np = np.asarray(self.W_process_noise_covariance) Q_np = np.asarray(self.Q_measurement_noise_covariance) try: P_np = solve_discrete_are(A_np.T, H_np.T, W_np, Q_np) self.P_state_covariance = xp_asarray(xp, P_np, device=dev) S = ( self.H_observation_matrix @ self.P_state_covariance @ _mT(self.H_observation_matrix) + self.Q_measurement_noise_covariance ) self.K_kalman_gain = self.P_state_covariance @ _mT(self.H_observation_matrix) @ xp.linalg.inv(S) except LinAlgError: Q_reg_np = Q_np + 1e-7 * np.eye(Q_np.shape[0]) try: P_np = solve_discrete_are(A_np.T, H_np.T, W_np, Q_reg_np) self.P_state_covariance = xp_asarray(xp, P_np, device=dev) Q_reg = xp_asarray(xp, Q_reg_np, device=dev) S = self.H_observation_matrix @ self.P_state_covariance @ _mT(self.H_observation_matrix) + Q_reg self.K_kalman_gain = self.P_state_covariance @ _mT(self.H_observation_matrix) @ xp.linalg.inv(S) print("Warning: Used regularized matrices for DARE solution") except LinAlgError: # Fallback to identity or manual initialization print("Warning: DARE failed, using identity covariance") self.P_state_covariance = xp_create(xp.eye, self.A_state_transition_matrix.shape[0], device=dev) # else: # n_states = self.A_state_transition_matrix.shape[0] # self.P_state_covariance = ( # np.eye(n_states) * 1000 # ) # Large initial uncertainty # P_m = ( # self.A_state_transition_matrix # @ self.P_state_covariance # @ self.A_state_transition_matrix.T # + self.W_process_noise_covariance # ) # S = ( # self.H_observation_matrix @ P_m @ self.H_observation_matrix.T # + self.Q_measurement_noise_covariance # ) # self.K_kalman_gain = P_m @ self.H_observation_matrix.T @ np.linalg.pinv(S) # I_mat = np.eye(self.A_state_transition_matrix.shape[0]) # self.P_state_covariance = ( # I_mat - self.K_kalman_gain @ self.H_observation_matrix # ) @ P_m
[docs] def predict(self, x_current): """ Predict the next state and covariance. This method predicts the next state and covariance using the current state. """ xp = get_namespace(x_current) _mT = xp.linalg.matrix_transpose x_predicted = self.A_state_transition_matrix @ x_current if self.steady_state is True: return x_predicted, None else: P_predicted = self.alpha_fading_memory**2 * ( self.A_state_transition_matrix @ self.P_state_covariance @ _mT(self.A_state_transition_matrix) + self.W_process_noise_covariance ) return x_predicted, P_predicted
[docs] def update( self, z_measurement, x_predicted, P_predicted=None, ): """Update state estimate and covariance based on measurement z.""" xp = get_namespace(z_measurement, x_predicted) dev = array_device(x_predicted) _mT = xp.linalg.matrix_transpose # Compute residual innovation = z_measurement - self.H_observation_matrix @ x_predicted if self.steady_state: x_updated = x_predicted + self.K_kalman_gain @ innovation return x_updated if P_predicted is None: raise ValueError("P_predicted must be provided for non-steady-state mode") # Non-steady-state mode # System uncertainty S = ( self.H_observation_matrix @ P_predicted @ _mT(self.H_observation_matrix) + self.Q_measurement_noise_covariance ) # Kalman gain K = P_predicted @ _mT(self.H_observation_matrix) @ xp.linalg.pinv(S) # Updated state x_updated = x_predicted + K @ innovation # Covariance update n = self.A_state_transition_matrix.shape[0] I_mat = xp_create(xp.eye, n, device=dev) P_updated = (I_mat - K @ self.H_observation_matrix) @ P_predicted @ _mT( I_mat - K @ self.H_observation_matrix ) + K @ self.Q_measurement_noise_covariance @ _mT(K) # Save updated values self.P_state_covariance = P_updated self.K_kalman_gain = K # self.S = S # Optional: for diagnostics return x_updated