Source code for ezmsg.learn.model.refit_kalman

# refit_kalman.py

import numpy as np
from numpy.linalg import LinAlgError
from scipy.linalg import solve_discrete_are


class RefitKalmanFilter:
    """
    Refit Kalman filter for adaptive neural decoding.

    This class implements a Kalman filter that can be refitted online during
    operation. Unlike the standard Kalman filter, this version can adapt its
    observation model (H and Q matrices) based on new data while maintaining
    the state transition model (A and W matrices). This is particularly useful
    for brain-computer interfaces, where the relationship between neural
    activity and intended movements may change over time.

    The filter operates in two phases:

    1. Initial fitting: learns all system matrices (A, W, H, Q) from training data.
    2. Refitting: updates only the observation model (H, Q) based on new data.

    Attributes:
        A_state_transition_matrix: The state transition matrix A (n_states x n_states).
        W_process_noise_covariance: The process noise covariance matrix W (n_states x n_states).
        H_observation_matrix: The observation matrix H (n_observations x n_states).
        Q_measurement_noise_covariance: The measurement noise covariance matrix Q
            (n_observations x n_observations).
        K_kalman_gain: The Kalman gain matrix (n_states x n_observations).
        P_state_covariance: The state error covariance matrix (n_states x n_states).
        steady_state: Whether to use steady-state Kalman gain computation.
        is_fitted: Whether the model has been fitted with data.

    Example:
        >>> # Create and fit the filter
        >>> rkf = RefitKalmanFilter(steady_state=True)
        >>> rkf.fit(X_train, y_train)
        >>>
        >>> # Refit with new data
        >>> rkf.refit(X_new, Y_state, velocity_indices, targets, cursors, holds)
        >>>
        >>> # Predict and update with the refitted model
        >>> x_pred, _ = rkf.predict(current_state)
        >>> x_updated = rkf.update(measurement, x_pred)
    """
    def __init__(
        self,
        A_state_transition_matrix=None,
        W_process_noise_covariance=None,
        H_observation_matrix=None,
        Q_measurement_noise_covariance=None,
        steady_state=False,
        enforce_state_structure=False,
        alpha_fading_memory=1.000,
        process_noise_scale=1,
        measurement_noise_scale=1.2,
    ):
        self.A_state_transition_matrix = A_state_transition_matrix
        self.W_process_noise_covariance = W_process_noise_covariance
        self.H_observation_matrix = H_observation_matrix
        self.Q_measurement_noise_covariance = Q_measurement_noise_covariance
        self.K_kalman_gain = None
        self.P_state_covariance = None
        self.alpha_fading_memory = alpha_fading_memory

        # Noise scaling factors for smoothing control
        self.process_noise_scale = process_noise_scale
        self.measurement_noise_scale = measurement_noise_scale

        self.steady_state = steady_state
        self.enforce_state_structure = enforce_state_structure
        self.is_fitted = False
    def _validate_state_vector(self, Y_state):
        """
        Validate that the state vector has proper dimensions.

        Args:
            Y_state: State vector to validate.

        Raises:
            ValueError: If the state vector has invalid dimensions.
        """
        if Y_state.ndim != 2:
            raise ValueError(f"State vector must be 2D, got {Y_state.ndim}D")

        if (
            not hasattr(self, "H_observation_matrix")
            or self.H_observation_matrix is None
        ):
            raise ValueError("Model must be fitted before refitting")

        expected_states = self.H_observation_matrix.shape[1]
        if Y_state.shape[1] != expected_states:
            raise ValueError(
                f"State vector has {Y_state.shape[1]} dimensions, expected {expected_states}"
            )
    def fit(self, X_train, y_train):
        """
        Fit the Refit Kalman filter to the training data.

        This method learns all system matrices (A, W, H, Q) from training data
        using least-squares estimation, then computes the steady-state solution.
        This is the initial fitting phase that establishes the baseline model.

        Args:
            X_train: Neural activity (n_samples, n_neurons).
            y_train: Outputs being predicted (n_samples, n_states).

        Raises:
            ValueError: If training data has invalid dimensions.
            LinAlgError: If matrix operations fail during fitting.
        """
        # self._validate_state_vector(y_train)
        X = np.array(y_train)
        Z = np.array(X_train)
        n_samples = X.shape[0]

        # Calculate the transition matrix (from x_t to x_{t+1}) using least-squares
        X2 = X[1:, :]  # x_{t+1}
        X1 = X[:-1, :]  # x_t
        A = X2.T @ X1 @ np.linalg.inv(X1.T @ X1)  # Transition matrix
        W = (
            (X2 - X1 @ A.T).T @ (X2 - X1 @ A.T) / (n_samples - 1)
        )  # Covariance of the transition residuals

        # Calculate the measurement matrix (from x_t to z_t) using least-squares
        H = Z.T @ X @ np.linalg.inv(X.T @ X)  # Measurement matrix
        Q = (
            (Z - X @ H.T).T @ (Z - X @ H.T) / Z.shape[0]
        )  # Covariance of the measurement residuals

        self.A_state_transition_matrix = A
        self.W_process_noise_covariance = W * self.process_noise_scale
        self.H_observation_matrix = H
        self.Q_measurement_noise_covariance = Q * self.measurement_noise_scale

        self._compute_gain()
        self.is_fitted = True
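    # Illustrative sketch (not part of the original module): shapes of the
    # learned matrices after a successful fit. `X_train` and `y_train` are
    # placeholder arrays of shape (n_samples, n_neurons) and
    # (n_samples, n_states), respectively.
    #
    #     >>> rkf = RefitKalmanFilter(steady_state=True)
    #     >>> rkf.fit(X_train, y_train)
    #     >>> rkf.A_state_transition_matrix.shape  # (n_states, n_states)
    #     >>> rkf.H_observation_matrix.shape       # (n_neurons, n_states)
    #     >>> rkf.K_kalman_gain.shape              # (n_states, n_neurons)
    #     >>> rkf.is_fitted
    #     True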
    def refit(
        self,
        X_neural: np.ndarray,
        Y_state: np.ndarray,
        intention_velocity_indices: int | list[int] | tuple[int, ...] | None = None,
        target_positions: np.ndarray | None = None,
        cursor_positions: np.ndarray | None = None,
        hold_indices: np.ndarray | None = None,
    ):
        """
        Refit the observation model based on new data.

        This method updates only the observation model (H and Q matrices) while
        keeping the state transition model (A and W matrices) unchanged. The
        refitting process modifies the intended states based on target positions
        and hold flags to better align with user intentions.

        The refitting process:

        1. Modifies intended states based on target positions and hold flags.
        2. Recalculates the observation matrix H using least-squares.
        3. Recalculates the measurement noise covariance Q.
        4. Updates the Kalman gain accordingly.

        Args:
            X_neural: Neural activity data (n_samples, n_neurons).
            Y_state: State estimates (n_samples, n_states).
            intention_velocity_indices: Start index of the velocity components in
                the state vector (a single index, optionally wrapped in a
                one-element list or tuple).
            target_positions: Target positions for each sample (n_samples, 2).
            cursor_positions: Current cursor positions (n_samples, 2).
            hold_indices: Boolean flags indicating hold periods (n_samples,).

        Raises:
            ValueError: If input data has invalid dimensions or the model is not fitted.
        """
        self._validate_state_vector(Y_state)

        # Check if velocity indices are provided
        if intention_velocity_indices is None:
            # Assume (x, y, vx, vy)
            vel_idx = 2 if Y_state.shape[1] >= 4 else 0
            print(
                f"[RefitKalmanFilter] No velocity index provided; defaulting to {vel_idx}"
            )
        else:
            if isinstance(intention_velocity_indices, (list, tuple)):
                if len(intention_velocity_indices) != 1:
                    raise ValueError(
                        "Only one velocity start index should be provided."
                    )
                vel_idx = intention_velocity_indices[0]
            else:
                vel_idx = intention_velocity_indices

        # Only remap velocity if target and cursor positions are provided
        if target_positions is None or cursor_positions is None:
            intended_states = Y_state.copy()
        else:
            intended_states = Y_state.copy()

            # Calculate intended velocities for each sample
            for i, (state, pos, target) in enumerate(
                zip(Y_state, cursor_positions, target_positions)
            ):
                is_hold = hold_indices[i] if hold_indices is not None else False

                if is_hold:
                    # During hold periods, intended velocity is zero
                    intended_states[i, vel_idx : vel_idx + 2] = 0.0
                    if i > 0:
                        intended_states[i, :2] = intended_states[
                            i - 1, :2
                        ]  # Same position as previous
                else:
                    # Calculate direction to target
                    to_target = target - pos
                    target_distance = np.linalg.norm(to_target)

                    if target_distance > 1e-5:  # Avoid division by zero
                        # Get current decoded velocity magnitude
                        current_velocity = state[vel_idx : vel_idx + 2]
                        current_speed = np.linalg.norm(current_velocity)

                        # Calculate intended velocity: same speed, but toward target
                        target_direction = to_target / target_distance
                        intended_velocity = target_direction * current_speed

                        # Update intended state with new velocity
                        intended_states[i, vel_idx : vel_idx + 2] = intended_velocity
                    else:
                        # If the target is very close, keep the original velocity
                        intended_states[i, vel_idx : vel_idx + 2] = state[
                            vel_idx : vel_idx + 2
                        ]

        intended_states = np.array(intended_states)
        Z = np.array(X_neural)

        # Recalculate the observation matrix and noise covariance
        H = (
            Z.T
            @ intended_states
            @ np.linalg.pinv(intended_states.T @ intended_states)
        )  # Using pinv() instead of inv() to avoid singular matrix errors
        Q = (Z - intended_states @ H.T).T @ (Z - intended_states @ H.T) / Z.shape[0]

        self.H_observation_matrix = H
        self.Q_measurement_noise_covariance = Q

        self._compute_gain()
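    # Illustrative sketch (not part of the original module): refitting with
    # behavioral context on an already-fitted filter. The arrays below are
    # placeholders; shapes assume a 4-dimensional state (x, y, vx, vy), so the
    # velocity block starts at index 2.
    #
    #     >>> rkf.refit(
    #     ...     X_neural=new_neural,              # (n_samples, n_neurons)
    #     ...     Y_state=decoded_states,           # (n_samples, 4)
    #     ...     intention_velocity_indices=2,
    #     ...     target_positions=targets,         # (n_samples, 2)
    #     ...     cursor_positions=cursors,         # (n_samples, 2)
    #     ...     hold_indices=holds,               # (n_samples,) bool
    #     ... )
    #
    # Only H and Q (and hence the gain) change; A and W are left untouched.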
    def _compute_gain(self):
        """
        Compute the Kalman gain matrix.

        This method computes the Kalman gain matrix based on the current system
        parameters. In steady-state mode, it solves the discrete-time algebraic
        Riccati equation to find the optimal steady-state gain. In
        non-steady-state mode, it computes the gain using the current covariance
        matrix.

        Raises:
            LinAlgError: If the Riccati equation cannot be solved or matrix
                operations fail.
        """
        ## TODO: consider removing non-steady-state from _compute_gain();
        ## non-steady-state updates will occur during predict() and update()
        # if self.steady_state:
        try:
            # Try with the original matrices
            self.P_state_covariance = solve_discrete_are(
                self.A_state_transition_matrix.T,
                self.H_observation_matrix.T,
                self.W_process_noise_covariance,
                self.Q_measurement_noise_covariance,
            )
            self.K_kalman_gain = (
                self.P_state_covariance
                @ self.H_observation_matrix.T
                @ np.linalg.inv(
                    self.H_observation_matrix
                    @ self.P_state_covariance
                    @ self.H_observation_matrix.T
                    + self.Q_measurement_noise_covariance
                )
            )
        except LinAlgError:
            # Apply regularization and retry
            # A_reg = self.A_state_transition_matrix * 0.999  # Slight damping
            # W_reg = self.W_process_noise_covariance + 1e-7 * np.eye(
            #     self.W_process_noise_covariance.shape[0]
            # )
            Q_reg = self.Q_measurement_noise_covariance + 1e-7 * np.eye(
                self.Q_measurement_noise_covariance.shape[0]
            )
            try:
                self.P_state_covariance = solve_discrete_are(
                    self.A_state_transition_matrix.T,
                    self.H_observation_matrix.T,
                    self.W_process_noise_covariance,
                    Q_reg,
                )
                self.K_kalman_gain = (
                    self.P_state_covariance
                    @ self.H_observation_matrix.T
                    @ np.linalg.inv(
                        self.H_observation_matrix
                        @ self.P_state_covariance
                        @ self.H_observation_matrix.T
                        + Q_reg
                    )
                )
                print("Warning: Used regularized matrices for DARE solution")
            except LinAlgError:
                # Fallback to identity or manual initialization
                print("Warning: DARE failed, using identity covariance")
                self.P_state_covariance = np.eye(
                    self.A_state_transition_matrix.shape[0]
                )
                # Also derive a gain from the fallback covariance so that
                # steady-state updates remain well-defined.
                self.K_kalman_gain = (
                    self.P_state_covariance
                    @ self.H_observation_matrix.T
                    @ np.linalg.pinv(
                        self.H_observation_matrix
                        @ self.P_state_covariance
                        @ self.H_observation_matrix.T
                        + self.Q_measurement_noise_covariance
                    )
                )
        # else:
        #     n_states = self.A_state_transition_matrix.shape[0]
        #     self.P_state_covariance = (
        #         np.eye(n_states) * 1000
        #     )  # Large initial uncertainty
        #     P_m = (
        #         self.A_state_transition_matrix
        #         @ self.P_state_covariance
        #         @ self.A_state_transition_matrix.T
        #         + self.W_process_noise_covariance
        #     )
        #     S = (
        #         self.H_observation_matrix @ P_m @ self.H_observation_matrix.T
        #         + self.Q_measurement_noise_covariance
        #     )
        #     self.K_kalman_gain = P_m @ self.H_observation_matrix.T @ np.linalg.pinv(S)
        #     I_mat = np.eye(self.A_state_transition_matrix.shape[0])
        #     self.P_state_covariance = (
        #         I_mat - self.K_kalman_gain @ self.H_observation_matrix
        #     ) @ P_m

    def predict(self, x_current: np.ndarray) -> tuple[np.ndarray, np.ndarray | None]:
        """
        Predict the next state and covariance.

        Propagates the current state estimate one step through the state
        transition model. In steady-state mode only the state is propagated and
        the returned covariance is None.
        """
        x_predicted = self.A_state_transition_matrix @ x_current

        if self.steady_state:
            return x_predicted, None
        else:
            P_predicted = self.alpha_fading_memory**2 * (
                self.A_state_transition_matrix
                @ self.P_state_covariance
                @ self.A_state_transition_matrix.T
                + self.W_process_noise_covariance
            )
            return x_predicted, P_predicted
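    # Illustrative sketch (not part of the original module): how predict()
    # pairs with update() depending on the mode.
    #
    #     >>> x_pred, P_pred = rkf.predict(x_est)
    #     >>> # steady_state=True:  P_pred is None; update() uses the DARE gain
    #     >>> # steady_state=False: pass P_pred on to update(z, x_pred, P_pred)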
    def update(
        self,
        z_measurement: np.ndarray,
        x_predicted: np.ndarray,
        P_predicted: np.ndarray | None = None,
    ) -> np.ndarray:
        """Update the state estimate and covariance based on the measurement z."""
        # Compute the innovation (measurement residual)
        innovation = z_measurement - self.H_observation_matrix @ x_predicted

        if self.steady_state:
            x_updated = x_predicted + self.K_kalman_gain @ innovation
            return x_updated

        if P_predicted is None:
            raise ValueError("P_predicted must be provided for non-steady-state mode")

        # Non-steady-state mode
        # Innovation (system uncertainty) covariance
        S = (
            self.H_observation_matrix @ P_predicted @ self.H_observation_matrix.T
            + self.Q_measurement_noise_covariance
        )

        # Kalman gain
        K = P_predicted @ self.H_observation_matrix.T @ np.linalg.pinv(S)

        # Updated state
        x_updated = x_predicted + K @ innovation

        # Covariance update (Joseph form for numerical stability)
        I_mat = np.eye(self.A_state_transition_matrix.shape[0])
        P_updated = (I_mat - K @ self.H_observation_matrix) @ P_predicted @ (
            I_mat - K @ self.H_observation_matrix
        ).T + K @ self.Q_measurement_noise_covariance @ K.T

        # Save updated values
        self.P_state_covariance = P_updated
        self.K_kalman_gain = K
        # self.S = S  # Optional: for diagnostics

        return x_updated
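
# The block below is an illustrative, self-contained sketch of the intended
# workflow (synthetic data; not part of the library itself). It assumes a
# 4-dimensional state and 32 simulated "neurons" whose activity is a noisy
# linear function of the state.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n_samples, n_states, n_neurons = 500, 4, 32

    # Synthetic training data: random states and a hypothetical linear tuning map
    states = rng.standard_normal((n_samples, n_states))
    tuning = rng.standard_normal((n_neurons, n_states))
    neural = states @ tuning.T + 0.1 * rng.standard_normal((n_samples, n_neurons))

    rkf = RefitKalmanFilter(steady_state=True)
    rkf.fit(neural, states)

    # Closed-loop style decoding: predict from the previous estimate, then
    # correct with the current neural measurement using the steady-state gain.
    x_est = np.zeros(n_states)
    for z in neural[:10]:
        x_pred, _ = rkf.predict(x_est)
        x_est = rkf.update(z, x_pred)
    print("Decoded state after 10 steps:", x_est)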