ezmsg.learn.process.refit_kalman

Classes

class RefitKalmanFilterProcessor(*args, **kwargs)

Bases: BaseAdaptiveTransformer[RefitKalmanFilterSettings, AxisArray, AxisArray, RefitKalmanFilterState]

Processor for implementing a Refit Kalman filter in the ezmsg framework.

This processor integrates the RefitKalmanFilter model into the ezmsg message passing system. It handles the conversion between AxisArray messages and the internal Refit Kalman filter operations.

The processor performs the following operations:

  1. Configures the Refit Kalman filter model with the provided settings.

  2. Processes incoming measurement messages.

  3. Performs the prediction and update steps.

  4. Logs data for potential refitting.

  5. Supports online refitting of the observation model.

  6. Returns filtered state estimates as AxisArray messages.

  7. Maintains state between message-processing calls.

The processor can operate in two modes:

  1. Pre-trained mode: loads parameters from checkpoint_path.

  2. Learning mode: collects data and fits the model once the buffer is full.

Key features:

  • Online refitting for adaptive neural decoding

  • Data logging for retrospective analysis

  • Position tracking for cursor-control applications

  • Hold-period detection and handling
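The prediction and update steps, together with the steady_state setting, can be illustrated with a generic linear Kalman filter step. This is a minimal NumPy sketch, assuming the standard model x_t = A x_{t-1} + w (covariance W) and y_t = H x_t + q (covariance Q), with the same matrix names the checkpoint uses. The function `kalman_step` is hypothetical and is not part of ezmsg.learn.

```python
import numpy as np


def kalman_step(x, P, y, A, W, H, Q, K_ss=None):
    """One predict/update cycle of a linear Kalman filter (hypothetical helper).

    x: (n_states,) previous state estimate
    P: (n_states, n_states) previous state covariance
    y: (n_channels,) current neural measurement
    K_ss: optional precomputed steady-state gain (n_states, n_channels)
    """
    # Predict: propagate the state through the dynamics model.
    x_pred = A @ x
    if K_ss is not None:
        # Steady-state mode: reuse a fixed gain; the covariance is not propagated.
        return x_pred + K_ss @ (y - H @ x_pred), P
    P_pred = A @ P @ A.T + W
    # Update: weight the measurement residual by the Kalman gain.
    S = H @ P_pred @ H.T + Q
    K = np.linalg.solve(S, H @ P_pred).T  # K = P_pred H^T S^-1 (S, P_pred symmetric)
    x_new = x_pred + K @ (y - H @ x_pred)
    P_new = (np.eye(len(x)) - K @ H) @ P_pred
    return x_new, P_new
```

Passing a precomputed `K_ss` mirrors the steady_state=True trade-off: each step becomes a matrix-vector product, at the cost of ignoring transient uncertainty right after initialization.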

settings

Configuration settings for the Refit Kalman filter.

_state

Internal state management object.

Example

>>> # Create settings with checkpoint path
>>> settings = RefitKalmanFilterSettings(
...     checkpoint_path="path/to/checkpoint.pkl",
...     steady_state=True
... )
>>>
>>> # Create processor
>>> processor = RefitKalmanFilterProcessor(settings)
>>>
>>> # Process measurement message
>>> result = processor(measurement_message)
>>>
>>> # Log data for refitting
>>> processor.log_for_refit(measurement_message, target_pos, hold_flag)
>>>
>>> # Refit the model
>>> processor.refit_model()

fit(X, y)

Return type:

None

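The fit(X, y) entry does not document its arguments. As a hedged illustration of what fitting a Kalman decoder commonly involves, the sketch below estimates A, W, H and Q by ordinary least squares from time-aligned samples. It assumes X holds neural features of shape (n_samples, n_channels) and y holds kinematic states of shape (n_samples, n_states); both that convention and the helper name `fit_kalman_params` are assumptions, not the documented behavior of fit().

```python
import numpy as np


def fit_kalman_params(X_neural, Y_state):
    """Least-squares estimates of A, W, H, Q (hypothetical helper).

    X_neural: (n_samples, n_channels) neural features
    Y_state:  (n_samples, n_states) kinematic states aligned in time
    """
    Z = np.asarray(X_neural, dtype=float).T  # (n_channels, T) observations
    S = np.asarray(Y_state, dtype=float).T   # (n_states, T) states
    S1, S2 = S[:, :-1], S[:, 1:]
    T = S.shape[1]
    # State transition: S2 ~= A @ S1, with W the residual covariance.
    A = S2 @ S1.T @ np.linalg.pinv(S1 @ S1.T)
    W = (S2 - A @ S1) @ (S2 - A @ S1).T / (T - 1)
    # Observation model: Z ~= H @ S, with Q the residual covariance.
    H = Z @ S.T @ np.linalg.pinv(S @ S.T)
    Q = (Z - H @ S) @ (Z - H @ S).T / T
    return A, W, H, Q
```

Using `pinv` rather than `inv` keeps the sketch usable when S Sᵀ is singular, for example when a state dimension is identically zero in the training data, at the cost of silently returning a minimum-norm solution.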
load_from_checkpoint(checkpoint_path)

Load model parameters from a serialized checkpoint file.

Parameters:

checkpoint_path (str) – Path to the saved checkpoint file.

Return type:

None

Side Effects:
  • Initializes a new model if not already set.

  • Sets model matrices A, W, H, Q from the checkpoint.

  • Computes Kalman gain based on restored parameters.
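For the steady-state case, computing the Kalman gain from restored A, W, H, Q can be sketched by iterating the Riccati recursion until the gain stops changing. This is a generic NumPy sketch; `steady_state_gain`, its tolerance and its iteration cap are hypothetical, and ezmsg.learn may compute the gain differently.

```python
import numpy as np


def steady_state_gain(A, W, H, Q, tol=1e-10, max_iter=10_000):
    """Iterate the Riccati recursion until the Kalman gain converges (hypothetical helper)."""
    n = A.shape[0]
    P = np.eye(n)
    K = np.zeros((n, H.shape[0]))
    for _ in range(max_iter):
        # Predicted covariance, innovation covariance, then the gain.
        P_pred = A @ P @ A.T + W
        S = H @ P_pred @ H.T + Q
        K_new = np.linalg.solve(S, H @ P_pred).T
        P = (np.eye(n) - K_new @ H) @ P_pred
        # Stop once successive gains agree to within tol.
        if np.max(np.abs(K_new - K)) < tol:
            return K_new
        K = K_new
    raise RuntimeError("Kalman gain did not converge")
```

In the scalar case A = W = H = Q = 1 the gain converges to (√5 − 1)/2 ≈ 0.618, which makes a convenient sanity check.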

save_checkpoint(checkpoint_path)

Save current model parameters to a checkpoint file.

Parameters:

checkpoint_path (str) – Destination file path for saving model parameters.

Raises:

ValueError – If the model is not initialized or has not been fitted.

Return type:

None
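The checkpoint file layout is not documented here; the processor example uses a .pkl path. Assuming a pickled dictionary keyed by the matrix names A, W, H, Q (an assumption, not the confirmed format), saving and restoring could look like the sketch below. `save_params` and `load_params` are hypothetical helpers that mirror the documented ValueError for an unfitted model.

```python
import pickle

import numpy as np

# Assumed checkpoint keys, matching the matrices restored by load_from_checkpoint.
REQUIRED_KEYS = ("A", "W", "H", "Q")


def save_params(path, params):
    """Pickle a dict of model matrices; refuse to save an unfitted model."""
    if params is None or any(params.get(k) is None for k in REQUIRED_KEYS):
        raise ValueError("Model is not initialized or has not been fitted.")
    with open(path, "wb") as f:
        pickle.dump({k: np.asarray(params[k]) for k in REQUIRED_KEYS}, f)


def load_params(path):
    """Load a dict written by save_params and check that every matrix is present."""
    with open(path, "rb") as f:
        params = pickle.load(f)
    missing = [k for k in REQUIRED_KEYS if k not in params]
    if missing:
        raise ValueError(f"Checkpoint is missing {missing}")
    return params
```

Because unpickling can execute arbitrary code, only load checkpoint files from trusted sources.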

partial_fit(message)

Perform refitting using externally provided data.

Return type:

None

Parameters:

message (SampleMessage)

Expects message.sample.data (neural input) and message.trigger.value, a dict with the following keys:
  • Y_state: (n_samples, n_states) array

  • intention_velocity_indices: Optional[int]

  • target_positions: Optional[np.ndarray]

  • cursor_positions: Optional[np.ndarray]

  • hold_flags: Optional[list[bool]]
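The payload expected in message.trigger.value can be assembled as a plain dictionary. The sketch below builds one and checks that the per-sample fields line up with Y_state; wrapping it and the neural array into a SampleMessage is omitted because that constructor is not documented here. `make_refit_payload` is a hypothetical helper, and the value used for intention_velocity_indices in any call is illustrative only.

```python
import numpy as np


def make_refit_payload(Y_state, target_positions=None, cursor_positions=None,
                       hold_flags=None, intention_velocity_indices=None):
    """Bundle refit data into the dict layout listed for partial_fit (hypothetical helper)."""
    Y_state = np.asarray(Y_state, dtype=float)
    n = Y_state.shape[0]
    # Every optional per-sample field must have one entry per row of Y_state.
    for name, value in (("target_positions", target_positions),
                        ("cursor_positions", cursor_positions),
                        ("hold_flags", hold_flags)):
        if value is not None and len(value) != n:
            raise ValueError(f"{name} has {len(value)} rows, expected {n}")
    return {
        "Y_state": Y_state,
        "intention_velocity_indices": intention_velocity_indices,
        "target_positions": None if target_positions is None else np.asarray(target_positions),
        "cursor_positions": None if cursor_positions is None else np.asarray(cursor_positions),
        "hold_flags": None if hold_flags is None else list(hold_flags),
    }
```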

log_for_refit(message, target_position=None, hold_flag=None)

Log data for potential refitting of the model.

This method stores measurement data, state estimates, and contextual information (target positions, cursor positions, hold flags) in buffers for later use in refitting the observation model. This data is used to adapt the model to changing neural-to-behavioral relationships.

Parameters:
  • message (AxisArray) – AxisArray message containing measurement data.

  • target_position (ndarray | None) – Target position for the current time point, shape (2,).

  • hold_flag (bool | None) – True if the current time point falls within a hold period.
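Logging into parallel per-sample buffers, as described for log_for_refit, can be sketched with plain Python lists. The class name `RefitBuffers`, the choice to pass the current state estimate explicitly, and reading the cursor position from state entries (0, 1) are assumptions for illustration only.

```python
import numpy as np


class RefitBuffers:
    """Parallel per-sample buffers for later refitting (hypothetical sketch)."""

    def __init__(self, position_indices=(0, 1)):
        # Assumed: the cursor position occupies these entries of the state vector.
        self.position_indices = list(position_indices)
        self.neural = []
        self.state = []
        self.cursor_positions = []
        self.target_positions = []
        self.hold_flags = []

    def log(self, measurement, state_estimate, target_position=None, hold_flag=None):
        """Append one time point; every buffer grows by exactly one entry."""
        state_estimate = np.asarray(state_estimate, dtype=float)
        self.neural.append(np.asarray(measurement, dtype=float).ravel())
        self.state.append(state_estimate.copy())
        self.cursor_positions.append(state_estimate[self.position_indices])
        self.target_positions.append(
            None if target_position is None else np.asarray(target_position, dtype=float)
        )
        # Treat a missing hold flag as "not holding".
        self.hold_flags.append(bool(hold_flag) if hold_flag is not None else False)

    def __len__(self):
        return len(self.neural)
```

Keeping the buffers the same length at every step means a later refit can convert each one to an array and index them jointly.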

refit_model()

Refit the observation model (H, Q) using buffered measurements and contextual data.

This method re-estimates the neural-to-state mapping, computing a new observation matrix (H) and observation noise covariance (Q) from:

  • Logged neural data

  • Cursor state estimates

  • Hold flags and target positions
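Hold flags and target positions typically enter a ReFIT-style refit through an intention estimate: each decoded velocity is assumed to reflect an intent to move toward the target, so it is rotated to point from the cursor to the target while keeping its speed, and it is set to zero during hold periods. The corrected states are then regressed against the logged neural data to obtain H and Q. The helper below sketches only the correction step; its name and exact rules are assumptions and may differ from what refit_model does internally.

```python
import numpy as np


def intention_velocities(velocities, cursor_positions, target_positions, hold_flags):
    """Rotate each velocity toward its target, keeping speed; zero it during holds.

    velocities, cursor_positions, target_positions: (n_samples, 2)
    hold_flags: (n_samples,) booleans
    """
    v = np.asarray(velocities, dtype=float)
    to_target = np.asarray(target_positions, dtype=float) - np.asarray(cursor_positions, dtype=float)
    dist = np.linalg.norm(to_target, axis=1, keepdims=True)
    speed = np.linalg.norm(v, axis=1, keepdims=True)
    # Unit vector toward the target; zero when the cursor is already on the target.
    direction = np.divide(to_target, dist, out=np.zeros_like(to_target), where=dist > 0)
    corrected = speed * direction
    # During holds the intended velocity is assumed to be zero.
    corrected[np.asarray(hold_flags, dtype=bool)] = 0.0
    return corrected
```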

Velocity components:

refit_model() takes no arguments. The indices of the velocity components in the state vector are configured through RefitKalmanFilterSettings.velocity_indices, which defaults to (2, 3).

Raises:

ValueError – If no buffered data exists.

class RefitKalmanFilterSettings(checkpoint_path=None, steady_state=False, velocity_indices=(2, 3))

Bases: Settings

Settings for the Refit Kalman filter processor.

This class defines the configuration parameters for the Refit Kalman filter processor. The RefitKalmanFilter is designed for online processing and playback.

Parameters:
checkpoint_path

Path to saved model parameters (optional). If provided, pre-trained parameters are loaded instead of learning from data.

steady_state

Whether to use a steady-state Kalman filter. If True, a precomputed Kalman gain is used; if False, the gain is updated at every step.

velocity_indices

Indices of the velocity components in the state vector (default (2, 3)).

checkpoint_path: str | None = None
steady_state: bool = False
velocity_indices: tuple[int, int] = (2, 3)
__init__(checkpoint_path=None, steady_state=False, velocity_indices=(2, 3))

Return type:

None

class RefitKalmanFilterState

Bases: object

State management for the Refit Kalman filter processor.

This class manages the persistent state of the Refit Kalman filter processor, including the model instance, current state estimates, and data buffers for refitting.

model

The RefitKalmanFilter model instance.

x

Current state estimate, shape (n_states,).

P

Current state covariance matrix, shape (n_states, n_states).

buffer_neural

Buffer for storing neural activity data for refitting.

buffer_state

Buffer for storing state estimates for refitting.

buffer_cursor_positions

Buffer for storing cursor positions for refitting.

buffer_target_positions

Buffer for storing target positions for refitting.

buffer_hold_flags

Buffer for storing hold flags for refitting.

current_position

Current cursor position estimate, shape (2,).

model: RefitKalmanFilter | None = None
x: ndarray | None = None
P: ndarray | None = None
buffer_neural: list | None = None
buffer_state: list | None = None
buffer_cursor_positions: list | None = None
buffer_target_positions: list | None = None
buffer_hold_flags: list | None = None
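The state keeps a 2-D current_position for cursor control. One common way to maintain such an estimate is to integrate the decoded velocity over each sample period. The sketch below assumes velocity sits at the state indices given by velocity_indices (default (2, 3)) and that the sample period `dt` is known; `integrate_position` and its optional `bounds` argument are hypothetical, not necessarily how the processor updates current_position.

```python
import numpy as np


def integrate_position(position, state, dt, velocity_indices=(2, 3), bounds=None):
    """Advance a 2-D cursor position by the decoded velocity (hypothetical helper)."""
    velocity = np.asarray(state, dtype=float)[list(velocity_indices)]
    new_position = np.asarray(position, dtype=float) + velocity * dt
    if bounds is not None:
        # Optionally keep the cursor inside a workspace given as (low, high).
        new_position = np.clip(new_position, bounds[0], bounds[1])
    return new_position
```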

class RefitKalmanFilterUnit(*args, settings=None, **kwargs)

Bases: BaseAdaptiveTransformerUnit[RefitKalmanFilterSettings, AxisArray, AxisArray, RefitKalmanFilterProcessor]

Parameters:

settings (Settings | None)

SETTINGS

alias of RefitKalmanFilterSettings
