ezmsg.learn.model.refit_kalman

Classes

class RefitKalmanFilter(A_state_transition_matrix=None, W_process_noise_covariance=None, H_observation_matrix=None, Q_measurement_noise_covariance=None, steady_state=False, enforce_state_structure=False, alpha_fading_memory=1.0, process_noise_scale=1, measurement_noise_scale=1.2)

Bases: object

Refit Kalman filter for adaptive neural decoding.

This class implements a Kalman filter that can be refitted online during operation. Unlike the standard Kalman filter, this version can adapt its observation model (H and Q matrices) based on new data while maintaining the state transition model (A and W matrices). This is particularly useful for brain-computer interfaces where the relationship between neural activity and intended movements may change over time.

The filter operates in two phases:

  1. Initial fitting: learns all system matrices (A, W, H, Q) from training data.

  2. Refitting: updates only the observation model (H, Q) based on new data.

A_state_transition_matrix

The state transition matrix A (n_states x n_states).

W_process_noise_covariance

The process noise covariance matrix W (n_states x n_states).

H_observation_matrix

The observation matrix H (n_observations x n_states).

Q_measurement_noise_covariance

The measurement noise covariance matrix Q (n_observations x n_observations).

K_kalman_gain

The Kalman gain matrix (n_states x n_observations).

P_state_covariance

The state error covariance matrix (n_states x n_states).

steady_state

Whether to use steady-state Kalman gain computation.

is_fitted

Whether the model has been fitted with data.
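For orientation, these attributes are assumed to parameterize the usual linear-Gaussian state-space model (this page does not write it out). Here x_t is the state vector (n_states) and z_t is the neural observation vector (n_observations). A and W describe how the state evolves, and H and Q describe how the neural activity relates to the state:

```latex
\begin{aligned}
x_{t+1} &= A\,x_t + w_t, & w_t &\sim \mathcal{N}(0,\, W) \\
z_t     &= H\,x_t + q_t, & q_t &\sim \mathcal{N}(0,\, Q)
\end{aligned}
```

Under this reading, K_kalman_gain maps the innovation z_t - H x_t (n_observations) back into state space (n_states), which matches its documented shape.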

Example

>>> # Create and fit the filter
>>> rkf = RefitKalmanFilter(steady_state=True)
>>> rkf.fit(X_train, y_train)
>>>
>>> # Refit with new data
>>> rkf.refit(X_new, Y_state, velocity_indices, targets, cursors, holds)
>>>
>>> # Predict with updated model
>>> x_updated = rkf.predict_and_update(measurement, current_state)

__init__(A_state_transition_matrix=None, W_process_noise_covariance=None, H_observation_matrix=None, Q_measurement_noise_covariance=None, steady_state=False, enforce_state_structure=False, alpha_fading_memory=1.0, process_noise_scale=1, measurement_noise_scale=1.2)

fit(X_train, y_train)

Fit the Refit Kalman filter to the training data.

This method learns all system matrices (A, W, H, Q) from training data using least-squares estimation, then computes the steady-state solution. This is the initial fitting phase that establishes the baseline model.

Parameters:
  • X_train – Neural activity (n_samples, n_neurons).

  • y_train – Outputs being predicted (n_samples, n_states).

Raises:
  • ValueError – If training data has invalid dimensions.

  • LinAlgError – If matrix operations fail during fitting.
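The page names the estimator ("least-squares") but not the formulas. The NumPy sketch below shows one common way to do it, which is an assumption here and not confirmed as this class's code. A and W come from regressing each state on the previous one. H and Q come from regressing neural activity on the state. A steady-state gain is then obtained by iterating the discrete Riccati recursion until K stops changing. `fit_kalman_lstsq` and `steady_state_gain` are hypothetical helper names. The library may also apply `alpha_fading_memory`, `process_noise_scale`, or `measurement_noise_scale`, none of which is modeled here.

```python
import numpy as np


def fit_kalman_lstsq(X_train, y_train):
    """Hypothetical helper: least-squares estimates of A, W, H, Q.

    X_train: neural activity, shape (n_samples, n_neurons).
    y_train: states, shape (n_samples, n_states).
    """
    Z = np.asarray(X_train, dtype=float)
    Y = np.asarray(y_train, dtype=float)
    if Z.ndim != 2 or Y.ndim != 2 or len(Z) != len(Y):
        raise ValueError("expected 2-D arrays with the same number of samples")
    n_samples = len(Y)

    # State transition: Y[1:] ~= Y[:-1] @ A.T, solved by least squares.
    A = np.linalg.lstsq(Y[:-1], Y[1:], rcond=None)[0].T
    state_resid = Y[1:] - Y[:-1] @ A.T
    W = state_resid.T @ state_resid / (n_samples - 1)

    # Observation model: Z ~= Y @ H.T, solved by least squares.
    H = np.linalg.lstsq(Y, Z, rcond=None)[0].T
    obs_resid = Z - Y @ H.T
    Q = obs_resid.T @ obs_resid / n_samples
    return A, W, H, Q


def steady_state_gain(A, W, H, Q, tol=1e-10, max_iter=10_000):
    """Hypothetical helper: iterate the Riccati recursion until K converges."""
    n_states = A.shape[0]
    P = np.eye(n_states)
    K = np.zeros((n_states, H.shape[0]))
    for _ in range(max_iter):
        P_pred = A @ P @ A.T + W
        S = H @ P_pred @ H.T + Q
        # K = P_pred H^T S^{-1}; S and P_pred are symmetric, so solve instead of inverting.
        K_new = np.linalg.solve(S, H @ P_pred).T
        P = (np.eye(n_states) - K_new @ H) @ P_pred
        if np.max(np.abs(K_new - K)) < tol:
            return K_new, P
        K = K_new
    return K, P
```

Iterating the recursion keeps the sketch dependency-free. A dedicated discrete algebraic Riccati solver would reach the same fixed point more directly.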

refit(X_neural, Y_state, intention_velocity_indices=None, target_positions=None, cursor_positions=None, hold_indices=None)

Refit the observation model based on new data.

This method updates only the observation model (H and Q matrices) while keeping the state transition model (A and W matrices) unchanged. The refitting process modifies the intended states based on target positions and hold flags to better align with user intentions.

The refitting process:

  1. Modifies the intended states based on target positions and hold flags.

  2. Recalculates the observation matrix H using least squares.

  3. Recalculates the measurement noise covariance Q.

  4. Updates the Kalman gain accordingly.

Parameters:
  • X_neural (ndarray) – Neural activity data (n_samples, n_neurons).

  • Y_state (ndarray) – State estimates (n_samples, n_states).

  • intention_velocity_indices (int | None) – Index of the velocity components in the state vector.

  • target_positions (ndarray | None) – Target positions for each sample (n_samples, 2).

  • cursor_positions (ndarray | None) – Current cursor positions (n_samples, 2).

  • hold_indices (ndarray | None) – Boolean flags indicating hold periods (n_samples,).

Raises:
  • ValueError – If input data has invalid dimensions or the model is not fitted.
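The page does not say exactly how intended states are modified. The sketch below follows the common ReFIT-style rule, which is an assumption here. It treats `intention_velocity_indices` as the start of a 2-D velocity in the state vector. For each sample it rotates the velocity to point from the cursor toward the target while keeping its speed, and it sets the velocity to zero during hold periods. It then refits only H and Q by least squares. `refit_observation_model` and `vel_start` are hypothetical names. Step 4 would recompute the Kalman gain from the new H and Q with the unchanged A and W; that step is not shown.

```python
import numpy as np


def refit_observation_model(X_neural, Y_state, vel_start=None,
                            target_positions=None, cursor_positions=None,
                            hold_indices=None):
    """Hypothetical ReFIT-style helper: adjust intended velocities, then refit H and Q.

    Assumes the 2-D velocity occupies columns vel_start and vel_start + 1 of Y_state.
    """
    Z = np.asarray(X_neural, dtype=float)
    Y = np.array(Y_state, dtype=float)  # copy, so the caller's array is untouched
    if Z.ndim != 2 or Y.ndim != 2 or len(Z) != len(Y):
        raise ValueError("expected 2-D arrays with the same number of samples")

    if vel_start is not None:
        vel = Y[:, vel_start:vel_start + 2]  # a view: writes below modify Y
        if target_positions is not None and cursor_positions is not None:
            to_target = (np.asarray(target_positions, dtype=float)
                         - np.asarray(cursor_positions, dtype=float))
            dist = np.linalg.norm(to_target, axis=1, keepdims=True)
            speed = np.linalg.norm(vel, axis=1, keepdims=True)
            unit = np.divide(to_target, dist, out=np.zeros_like(to_target), where=dist > 0)
            # Rotate each velocity to point at the target while keeping its speed;
            # samples already on the target keep their original velocity.
            vel[:] = np.where(dist > 0, unit * speed, vel)
        if hold_indices is not None:
            # During holds the user is assumed to intend zero velocity.
            vel[np.asarray(hold_indices, dtype=bool)] = 0.0

    # Refit only the observation model: Z ~= Y @ H.T.
    H = np.linalg.lstsq(Y, Z, rcond=None)[0].T
    resid = Z - Y @ H.T
    Q = resid.T @ resid / len(Y)
    return H, Q, Y
```

Keeping the speed and changing only the direction reflects the usual ReFIT assumption. The user is taken to want to reach the target, but the decoded speed is kept as a proxy for how fast they intended to move.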

predict(x_current)

Predict the next state and covariance.

It propagates the current state through the state transition model (A, W) and returns the predicted state together with its predicted error covariance.

Return type:

tuple[ndarray, ndarray]

Parameters:
  • x_current (ndarray) – Current state estimate.
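As a reference point, the standard Kalman time update is sketched below. `predict` itself takes only `x_current`, so it presumably reads A, W, and the current P from the instance. Whether it also applies `alpha_fading_memory` or `process_noise_scale` is not stated here, and the sketch ignores both. `kalman_predict` is a hypothetical helper name.

```python
import numpy as np


def kalman_predict(x_current, P_current, A, W):
    """Hypothetical helper: standard Kalman time update."""
    x_pred = A @ x_current             # propagate the state mean
    P_pred = A @ P_current @ A.T + W   # propagate the error covariance
    return x_pred, P_pred
```

For example, with a position/velocity state and a 0.1 s step, `A = [[1, 0.1], [0, 1]]` moves the position by 0.1 times the velocity.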

update(z_measurement, x_predicted, P_predicted=None)

Update state estimate and covariance based on measurement z.

Return type:

ndarray

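Parameters:
  • z_measurement – Measurement vector z.

  • x_predicted – Predicted state.

  • P_predicted – Predicted state covariance (optional).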
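The sketch below shows the standard Kalman measurement update for comparison. It returns both the new state and the new covariance, whereas `update` is documented to return only the state. When `steady_state=True`, the class presumably reuses the precomputed `K_kalman_gain` instead of recomputing K at each step; that is an assumption. `kalman_update` is a hypothetical helper name.

```python
import numpy as np


def kalman_update(z_measurement, x_predicted, P_predicted, H, Q):
    """Hypothetical helper: standard Kalman measurement update."""
    S = H @ P_predicted @ H.T + Q                # innovation covariance
    K = np.linalg.solve(S, H @ P_predicted).T    # K = P H^T S^{-1} (S, P symmetric)
    innovation = z_measurement - H @ x_predicted
    x_new = x_predicted + K @ innovation         # correct the state toward the measurement
    P_new = (np.eye(len(x_predicted)) - K @ H) @ P_predicted
    return x_new, P_new
```

In the scalar case with unit prior variance and unit measurement noise, the gain is 0.5. The update then moves the estimate halfway toward the measurement and halves the variance.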