ezmsg.learn.model.refit_kalman
Classes
- class RefitKalmanFilter(A_state_transition_matrix=None, W_process_noise_covariance=None, H_observation_matrix=None, Q_measurement_noise_covariance=None, steady_state=False, enforce_state_structure=False, alpha_fading_memory=1.0, process_noise_scale=1, measurement_noise_scale=1.2)
Bases: object
Refit Kalman filter for adaptive neural decoding.
This class implements a Kalman filter that can be refitted online during operation. Unlike the standard Kalman filter, this version can adapt its observation model (H and Q matrices) based on new data while maintaining the state transition model (A and W matrices). This is particularly useful for brain-computer interfaces where the relationship between neural activity and intended movements may change over time.
The filter operates in two phases:
1. Initial fitting: learns all system matrices (A, W, H, Q) from training data.
2. Refitting: updates only the observation model (H, Q) based on new data.
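For reference, the four matrix names correspond to the conventional linear-Gaussian state-space model below. This is the textbook form these names usually denote; the class may add details not documented here (for example the fading-memory and noise-scaling parameters in its signature). Here x_t is the state at time step t (one row of y_train) and z_t is the neural observation (one row of X_train).

```latex
\begin{aligned}
\mathbf{x}_t &= A\,\mathbf{x}_{t-1} + \mathbf{w}_t, & \mathbf{w}_t &\sim \mathcal{N}(\mathbf{0}, W) \\
\mathbf{z}_t &= H\,\mathbf{x}_t + \mathbf{q}_t, & \mathbf{q}_t &\sim \mathcal{N}(\mathbf{0}, Q)
\end{aligned}
```

Initial fitting estimates both lines; refitting re-estimates only the second line (H and Q), so the assumed dynamics stay fixed.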
- A_state_transition_matrix
The state transition matrix A (n_states x n_states).
- W_process_noise_covariance
The process noise covariance matrix W (n_states x n_states).
- H_observation_matrix
The observation matrix H (n_observations x n_states).
- Q_measurement_noise_covariance
The measurement noise covariance matrix Q (n_observations x n_observations).
- K_kalman_gain
The Kalman gain matrix (n_states x n_observations).
- P_state_covariance
The state error covariance matrix (n_states x n_states).
- steady_state
Whether to use steady-state Kalman gain computation.
- is_fitted
Whether the model has been fitted with data.
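One common way to obtain a steady-state gain, as the steady_state attribute refers to, is to iterate the covariance (Riccati) recursion until K stops changing. The sketch below does that with NumPy; steady_state_gain is a hypothetical helper, not part of ezmsg.learn, and the class may compute its gain differently.

```python
import numpy as np


def steady_state_gain(A, W, H, Q, tol=1e-10, max_iter=10_000):
    """Hypothetical helper: iterate the Riccati recursion to a fixed gain.

    Returns (K, P): the converged Kalman gain (n_states x n_observations)
    and the converged posterior state error covariance (n_states x n_states).
    """
    n_states = A.shape[0]
    P = np.eye(n_states)                          # arbitrary positive-definite start
    K = np.zeros((n_states, H.shape[0]))
    for _ in range(max_iter):
        P_pred = A @ P @ A.T + W                  # predicted covariance
        S = H @ P_pred @ H.T + Q                  # innovation covariance
        K_new = np.linalg.solve(S, H @ P_pred).T  # P_pred H^T S^-1
        P = (np.eye(n_states) - K_new @ H) @ P_pred
        if np.max(np.abs(K_new - K)) < tol:       # gain has stopped changing
            return K_new, P
        K = K_new
    raise RuntimeError("Kalman gain did not converge")
```

Once K is fixed, each decoding step only needs the state update, which is the usual motivation for steady-state decoding. As an alternative, scipy.linalg.solve_discrete_are(A.T, H.T, W, Q) returns the steady-state predicted covariance directly.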
Example
>>> # Create and fit the filter
>>> rkf = RefitKalmanFilter(steady_state=True)
>>> rkf.fit(X_train, y_train)
>>>
>>> # Refit with new data
>>> rkf.refit(X_new, Y_state, velocity_indices, targets, cursors, holds)
>>>
>>> # Predict with updated model
>>> x_updated = rkf.predict_and_update(measurement, current_state)
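The example calls predict_and_update, whose internals are not documented in this section. For orientation, here is a textbook Kalman predict/update step using the same A, W, H, Q naming; kalman_step is a hypothetical helper and is not claimed to match the library's implementation.

```python
import numpy as np


def kalman_step(x, P, z, A, W, H, Q):
    """Hypothetical helper: one textbook Kalman predict/update step.

    x: previous state estimate (n_states,)
    P: previous state error covariance (n_states, n_states)
    z: new neural measurement (n_observations,)
    """
    # Predict: propagate the estimate through the state transition model.
    x_pred = A @ x
    P_pred = A @ P @ A.T + W
    # Update: correct the prediction with the measurement residual.
    S = H @ P_pred @ H.T + Q                  # innovation covariance
    K = np.linalg.solve(S, H @ P_pred).T      # Kalman gain, P_pred H^T S^-1
    x_new = x_pred + K @ (z - H @ x_pred)
    P_new = (np.eye(len(x)) - K @ H) @ P_pred
    return x_new, P_new


# Usage with random matrices of the documented shapes.
rng = np.random.default_rng(0)
n_states, n_obs = 4, 10
A, W = np.eye(n_states), 0.01 * np.eye(n_states)
H, Q = rng.standard_normal((n_obs, n_states)), np.eye(n_obs)
x, P = np.zeros(n_states), np.eye(n_states)
x, P = kalman_step(x, P, rng.standard_normal(n_obs), A, W, H, Q)
```

With a steady-state gain, a decoder would typically reuse a fixed K and skip the covariance updates.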
- __init__(A_state_transition_matrix=None, W_process_noise_covariance=None, H_observation_matrix=None, Q_measurement_noise_covariance=None, steady_state=False, enforce_state_structure=False, alpha_fading_memory=1.0, process_noise_scale=1, measurement_noise_scale=1.2)
- fit(X_train, y_train)
Fit the Refit Kalman filter to the training data.
This method learns all system matrices (A, W, H, Q) from training data using least-squares estimation, then computes the steady-state solution. This is the initial fitting phase that establishes the baseline model.
- Parameters:
X_train – Neural activity (n_samples, n_neurons).
y_train – Outputs being predicted (n_samples, n_states).
- Raises:
ValueError – If training data has invalid dimensions.
LinAlgError – If matrix operations fail during fitting.
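The least-squares fitting described above can be sketched with the standard estimators for a linear Kalman decoder: regress each state on the previous one to get A, regress neural activity on the state to get H, and take the residual covariances as W and Q. fit_kalman_matrices is a hypothetical helper; whether fit adds offsets, regularization, or the scaling parameters from the constructor is not documented here.

```python
import numpy as np


def fit_kalman_matrices(X_neural, y_states):
    """Hypothetical helper: least-squares estimates of A, W, H, Q.

    X_neural: (n_samples, n_neurons) observations, one row per time bin.
    y_states: (n_samples, n_states) states aligned with X_neural.
    """
    x_prev, x_next = y_states[:-1], y_states[1:]
    # State transition: x_next ≈ x_prev @ A.T, solved by least squares.
    A = np.linalg.lstsq(x_prev, x_next, rcond=None)[0].T
    resid_w = x_next - x_prev @ A.T
    W = resid_w.T @ resid_w / len(resid_w)    # process noise covariance
    # Observation model: X_neural ≈ y_states @ H.T.
    H = np.linalg.lstsq(y_states, X_neural, rcond=None)[0].T
    resid_q = X_neural - y_states @ H.T
    Q = resid_q.T @ resid_q / len(resid_q)    # measurement noise covariance
    return A, W, H, Q


# Usage on synthetic data generated from known matrices.
rng = np.random.default_rng(1)
A_true = np.array([[0.9, 0.1], [0.0, 0.8]])
y_states = np.zeros((5000, 2))
for t in range(1, len(y_states)):
    y_states[t] = A_true @ y_states[t - 1] + rng.standard_normal(2)
H_true = rng.standard_normal((6, 2))
X_neural = y_states @ H_true.T + 0.1 * rng.standard_normal((5000, 6))
A, W, H, Q = fit_kalman_matrices(X_neural, y_states)
```

Note the naming: X_train holds the observations and y_train the states, the reverse of the usual regression convention.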
- refit(X_neural, Y_state, intention_velocity_indices=None, target_positions=None, cursor_positions=None, hold_indices=None)
Refit the observation model based on new data.
This method updates only the observation model (H and Q matrices) while keeping the state transition model (A and W matrices) unchanged. The refitting process modifies the intended states based on target positions and hold flags to better align with user intentions.
The refitting process:
1. Modifies intended states based on target positions and hold flags.
2. Recalculates the observation matrix H using least squares.
3. Recalculates the measurement noise covariance Q.
4. Updates the Kalman gain accordingly.
- Parameters:
X_neural (ndarray) – Neural activity data (n_samples, n_neurons).
Y_state (ndarray) – State estimates (n_samples, n_states).
intention_velocity_indices (int | None) – Index of velocity components in state vector.
target_positions (ndarray | None) – Target positions for each sample (n_samples, 2).
cursor_positions (ndarray | None) – Current cursor positions (n_samples, 2).
hold_indices (ndarray | None) – Boolean flags indicating hold periods (n_samples,).
- Raises:
ValueError – If input data has invalid dimensions or the model is not fitted.
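The refit steps can be sketched under the common ReFIT convention: rotate each decoded velocity to point from the cursor toward the target while keeping its speed, set it to zero during holds, then re-estimate H and Q by least squares with A and W left unchanged. Both functions are hypothetical helpers, and the sketch assumes the 2-D velocity occupies columns intention_velocity_indices and intention_velocity_indices + 1 of Y_state; the library's actual convention is not stated here.

```python
import numpy as np


def refit_intended_states(Y_state, vel_idx, target_positions, cursor_positions, hold_indices=None):
    """Hypothetical ReFIT-style intention estimate.

    ASSUMPTION: the 2-D velocity occupies columns vel_idx and vel_idx + 1.
    """
    Y = np.array(Y_state, dtype=float)             # work on a copy
    vel = Y[:, vel_idx:vel_idx + 2]                # view into the copy
    speed = np.linalg.norm(vel, axis=1, keepdims=True)
    to_target = np.asarray(target_positions, float) - np.asarray(cursor_positions, float)
    dist = np.linalg.norm(to_target, axis=1, keepdims=True)
    # Unit vector toward the target (zero if the cursor is already on it).
    direction = np.divide(to_target, dist, out=np.zeros_like(to_target), where=dist > 0)
    vel[:] = speed * direction                     # keep speed, redirect at target
    if hold_indices is not None:
        vel[np.asarray(hold_indices, bool)] = 0.0  # holds: intended velocity is zero
    return Y


def refit_observation_model(X_neural, Y_intended):
    """Re-estimate H and Q by least squares; A and W are left untouched."""
    H = np.linalg.lstsq(Y_intended, X_neural, rcond=None)[0].T
    resid = X_neural - Y_intended @ H.T
    Q = resid.T @ resid / len(resid)
    return H, Q
```

After H and Q change, the Kalman gain would be recomputed from the unchanged A and W together with the new H and Q.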