ezmsg.learn.process.refit_kalman#
Classes
- class RefitKalmanFilterProcessor(*args, **kwargs)[source]#
Bases: BaseAdaptiveTransformer[RefitKalmanFilterSettings, AxisArray, AxisArray, RefitKalmanFilterState]
Processor for implementing a Refit Kalman filter in the ezmsg framework.
This processor integrates the RefitKalmanFilter model into the ezmsg message passing system. It handles the conversion between AxisArray messages and the internal Refit Kalman filter operations.
The processor performs the following operations:
1. Configures the Refit Kalman filter model with provided settings
2. Processes incoming measurement messages
3. Performs prediction and update steps
4. Logs data for potential refitting
5. Supports online refitting of the observation model
6. Returns filtered state estimates as AxisArray messages
7. Maintains state between message processing calls
The processor can operate in two modes:
1. Pre-trained mode: Loads parameters from checkpoint_path
2. Learning mode: Collects data and fits the model when buffer is full
Key features:
- Online refitting capability for adaptive neural decoding
- Data logging for retrospective analysis
- Position tracking for cursor control applications
- Hold period detection and handling
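The prediction and update steps listed above can be illustrated with a textbook linear Kalman filter step. This is a minimal NumPy sketch, not the processor's actual code: the function name `kalman_step` is hypothetical, and it assumes that A is the state transition matrix, W the process noise covariance, H the observation matrix, and Q the observation noise covariance.

```python
import numpy as np


def kalman_step(x, P, z, A, W, H, Q):
    """One predict/update cycle of a linear Kalman filter (illustrative only)."""
    # Predict: propagate the state estimate and its covariance through the dynamics.
    x_pred = A @ x
    P_pred = A @ P @ A.T + W

    # Update: weigh the measurement residual by the Kalman gain.
    S = H @ P_pred @ H.T + Q                # innovation covariance
    K = P_pred @ H.T @ np.linalg.inv(S)     # Kalman gain
    x_new = x_pred + K @ (z - H @ x_pred)
    P_new = (np.eye(len(x)) - K @ H) @ P_pred
    return x_new, P_new
```

With a very small observation noise Q and an identity H, the updated state follows the measurement almost exactly, which is a quick sanity check for the sketch.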
- settings#
Configuration settings for the Refit Kalman filter.
- _state#
Internal state management object.
Example
>>> # Create settings with checkpoint path
>>> settings = RefitKalmanFilterSettings(
...     checkpoint_path="path/to/checkpoint.pkl",
...     steady_state=True
... )
>>>
>>> # Create processor
>>> processor = RefitKalmanFilterProcessor(settings)
>>>
>>> # Process measurement message
>>> result = processor(measurement_message)
>>>
>>> # Log data for refitting
>>> processor.log_for_refit(message, target_pos, hold_flag)
>>>
>>> # Refit the model
>>> processor.refit_model()
- load_from_checkpoint(checkpoint_path)[source]#
Load model parameters from a serialized checkpoint file.
- Side Effects:
Initializes a new model if not already set.
Sets model matrices A, W, H, Q from the checkpoint.
Computes Kalman gain based on restored parameters.
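A checkpoint round trip for the four matrices named above might look like the following sketch. The on-disk layout is an assumption: the library's real checkpoint format is not documented here, so the pickled dictionary and the helpers `save_params` and `load_params` are hypothetical.

```python
import os
import pickle
import tempfile

import numpy as np

REQUIRED_KEYS = ("A", "W", "H", "Q")  # matrix names taken from the text above


def save_params(path, params):
    """Write the four model matrices to `path` (hypothetical file layout)."""
    missing = [k for k in REQUIRED_KEYS if k not in params]
    if missing:
        raise ValueError(f"model is not fitted; missing {missing}")
    with open(path, "wb") as f:
        pickle.dump({k: np.asarray(params[k]) for k in REQUIRED_KEYS}, f)


def load_params(path):
    """Read the matrices back and check that all of them are present."""
    with open(path, "rb") as f:
        params = pickle.load(f)
    missing = [k for k in REQUIRED_KEYS if k not in params]
    if missing:
        raise ValueError(f"checkpoint is missing {missing}")
    return params


# Round trip through a temporary file.
original = {"A": np.eye(4), "W": 0.01 * np.eye(4), "H": np.ones((6, 4)), "Q": np.eye(6)}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pkl")
    save_params(path, original)
    restored = load_params(path)
```

Pickle files can execute arbitrary code when loaded, so a scheme like this should only read checkpoints from trusted sources.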
- save_checkpoint(checkpoint_path)[source]#
Save current model parameters to a checkpoint file.
- Parameters:
checkpoint_path (str) – Destination file path for saving model parameters.
- Raises:
ValueError – If the model is not initialized or has not been fitted.
- Return type:
- partial_fit(message)[source]#
Perform refitting using externally provided data.
- Return type:
- Parameters:
message (SampleMessage)
- Expects message.sample.data (neural input) and message.trigger.value as a dict with:
Y_state: (n_samples, n_states) array
intention_velocity_indices: Optional[int]
target_positions: Optional[np.ndarray]
cursor_positions: Optional[np.ndarray]
hold_flags: Optional[list[bool]]
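The dictionary described above can be assembled and sanity-checked before it is attached to a message. The sketch below builds only the dictionary; the helper name `build_refit_trigger` is hypothetical, and the check that every optional sequence has one entry per sample is an assumption rather than documented behaviour.

```python
import numpy as np


def build_refit_trigger(Y_state, intention_velocity_indices=None,
                        target_positions=None, cursor_positions=None,
                        hold_flags=None):
    """Assemble the trigger dict with the keys listed above (illustrative)."""
    Y_state = np.asarray(Y_state, dtype=float)
    if Y_state.ndim != 2:
        raise ValueError("Y_state must have shape (n_samples, n_states)")
    n_samples = Y_state.shape[0]

    # Assumption: optional per-sample sequences align with Y_state row by row.
    optional = {
        "target_positions": target_positions,
        "cursor_positions": cursor_positions,
        "hold_flags": hold_flags,
    }
    for name, value in optional.items():
        if value is not None and len(value) != n_samples:
            raise ValueError(f"{name} must have {n_samples} entries")

    return {
        "Y_state": Y_state,
        "intention_velocity_indices": intention_velocity_indices,
        **optional,
    }
```

Validating lengths up front keeps a misaligned buffer from silently producing a badly fitted observation model later on.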
- log_for_refit(message, target_position=None, hold_flag=None)[source]#
Log data for potential refitting of the model.
This method stores measurement data, state estimates, and contextual information (target positions, cursor positions, hold flags) in buffers for later use in refitting the observation model. This data is used to adapt the model to changing neural-to-behavioral relationships.
- refit_model()[source]#
Refit the observation model (H, Q) using buffered measurements and contextual data.
This method updates the model’s understanding of the neural-to-state mapping by calculating a new observation matrix and noise covariance, based on:
Logged neural data
Cursor state estimates
Hold flags and target positions
- Parameters:
velocity_indices (tuple) – Indices in the state vector corresponding to velocity components. Default assumes 2D velocity at indices (0, 1).
- Raises:
ValueError – If no buffered data exists.
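One common way to carry out a refit of this kind is to replace the decoded velocities with "intended" velocities that point at the target (zero during holds) and then re-estimate H and Q by least squares. The sketch below follows that general recipe under stated assumptions; it is not the library's implementation. The function names are hypothetical, and the default `velocity_indices=(2, 3)` is taken from the RefitKalmanFilterSettings signature.

```python
import numpy as np


def intended_states(states, cursor_pos, target_pos, hold_flags,
                    velocity_indices=(2, 3)):
    """Rotate each decoded velocity toward the target; zero it during holds."""
    states = np.array(states, dtype=float)  # work on a copy
    cursor_pos = np.asarray(cursor_pos, dtype=float)
    target_pos = np.asarray(target_pos, dtype=float)
    vi = list(velocity_indices)
    for t in range(len(states)):
        if hold_flags[t]:
            states[t, vi] = 0.0
            continue
        direction = target_pos[t] - cursor_pos[t]
        dist = np.linalg.norm(direction)
        if dist > 0:
            speed = np.linalg.norm(states[t, vi])  # keep the decoded speed
            states[t, vi] = speed * direction / dist
    return states


def refit_observation_model(neural, states):
    """Least-squares fit of neural ~ states @ H.T, plus residual covariance Q."""
    neural = np.asarray(neural, dtype=float)   # (n_samples, n_channels)
    states = np.asarray(states, dtype=float)   # (n_samples, n_states)
    if len(neural) == 0:
        raise ValueError("no buffered data to refit on")
    Ht, *_ = np.linalg.lstsq(states, neural, rcond=None)
    H = Ht.T
    resid = neural - states @ H.T
    Q = resid.T @ resid / len(neural)
    return H, Q
```

In use, the buffered state estimates would first pass through `intended_states`, and the result would be paired with the buffered neural data in `refit_observation_model`.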
- class RefitKalmanFilterSettings(checkpoint_path=None, steady_state=False, velocity_indices=(2, 3))[source]#
Bases: Settings
Settings for the Refit Kalman filter processor.
This class defines the configuration parameters for the Refit Kalman filter processor. The RefitKalmanFilter is designed for online processing and playback.
- checkpoint_path#
Path to saved model parameters (optional). If provided, loads pre-trained parameters instead of learning from data.
- steady_state#
Whether to use steady-state Kalman filter. If True, uses pre-computed Kalman gain; if False, updates dynamically.
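A pre-computed steady-state gain is typically obtained by iterating the covariance recursion until the gain stops changing. The following sketch shows that idea in NumPy; the function name `steady_state_gain`, the iteration limit, and the tolerance are assumptions, and the library may compute its gain differently.

```python
import numpy as np


def steady_state_gain(A, W, H, Q, n_iter=1000, tol=1e-10):
    """Iterate the Riccati recursion until the Kalman gain converges."""
    n_states = A.shape[0]
    P = np.array(W, dtype=float)             # initial posterior covariance guess
    K = np.zeros((n_states, H.shape[0]))
    for _ in range(n_iter):
        P_pred = A @ P @ A.T + W
        S = H @ P_pred @ H.T + Q
        K_new = P_pred @ H.T @ np.linalg.inv(S)
        P = (np.eye(n_states) - K_new @ H) @ P_pred
        if np.max(np.abs(K_new - K)) < tol:
            return K_new
        K = K_new
    return K
```

Once the gain is fixed, each filter step reduces to a prediction followed by a single gain multiplication, which avoids a matrix inversion per message.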
- class RefitKalmanFilterState[source]#
Bases: object
State management for the Refit Kalman filter processor.
This class manages the persistent state of the Refit Kalman filter processor, including the model instance, current state estimates, and data buffers for refitting.
- model#
The RefitKalmanFilter model instance.
- x#
Current state estimate (n_states,).
- P#
Current state covariance matrix (n_states x n_states).
- buffer_neural#
Buffer for storing neural activity data for refitting.
- buffer_state#
Buffer for storing state estimates for refitting.
- buffer_cursor_positions#
Buffer for storing cursor positions for refitting.
- buffer_target_positions#
Buffer for storing target positions for refitting.
- buffer_hold_flags#
Buffer for storing hold flags for refitting.
- current_position#
Current cursor position estimate (2,).
- model: RefitKalmanFilter | None = None#
- class RefitKalmanFilterUnit(*args, settings=None, **kwargs)[source]#
Bases: BaseAdaptiveTransformerUnit[RefitKalmanFilterSettings, AxisArray, AxisArray, RefitKalmanFilterProcessor]
- Parameters:
settings (Settings | None)
- SETTINGS#
alias of RefitKalmanFilterSettings