Source code for ezmsg.sigproc.slicer
import numpy as np
import numpy.typing as npt
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import (
AxisArray,
slice_along_axis,
AxisBase,
replace,
)
from .base import (
BaseStatefulTransformer,
BaseTransformerUnit,
processor_state,
)
"""
Slicer: Select a subset of data along a particular axis.
"""
[docs]
def parse_slice(
s: str,
axinfo: AxisArray.CoordinateAxis | None = None,
) -> tuple[slice | int, ...]:
"""
Parses a string representation of a slice and returns a tuple of slice objects.
- "" -> slice(None, None, None) (take all)
- ":" -> slice(None, None, None)
- '"none"` (case-insensitive) -> slice(None, None, None)
- "{start}:{stop}" or {start}:{stop}:{step} -> slice(start, stop, step)
- "5" (or any integer) -> (5,). Take only that item.
applying this to a ndarray or AxisArray will drop the dimension.
- A comma-separated list of the above -> a tuple of slices | ints
- A comma-separated list of values and axinfo is provided and is a CoordinateAxis -> a tuple of ints
Args:
s: The string representation of the slice.
axinfo: (Optional) If provided, and of type CoordinateAxis,
and `s` is a comma-separated list of values, then the values
in s will be checked against the values in axinfo.data.
Returns:
A tuple of slice objects and/or ints.
"""
if s.lower() in ["", ":", "none"]:
return (slice(None),)
if "," not in s:
parts = [part.strip() for part in s.split(":")]
if len(parts) == 1:
if (
axinfo is not None
and hasattr(axinfo, "data")
and parts[0] in axinfo.data
):
return tuple(np.where(axinfo.data == parts[0])[0])
return (int(parts[0]),)
return (slice(*(int(part.strip()) if part else None for part in parts)),)
suplist = [parse_slice(_, axinfo=axinfo) for _ in s.split(",")]
return tuple([item for sublist in suplist for item in sublist])
[docs]
class SlicerSettings(ez.Settings):
    """Configuration for the slicer: what to select and along which axis."""

    # The default empty selection is parsed by parse_slice as "take all".
    selection: str = ""
    """selection: See :obj:`ezmsg.sigproc.slicer.parse_slice` for details."""
    axis: str | None = None
    """The name of the axis to slice along. If None, the last axis is used."""
[docs]
@processor_state
class SlicerState:
    """Cached slicing state, populated by SlicerTransformer._reset_state."""

    # Indexer applied along the target axis: a slice, a single integer index,
    # or an array of integer indices (when several items were selected).
    slice_: slice | int | npt.NDArray | None = None
    # Replacement for the sliced axis when the input axis carries non-empty
    # coordinate `data`; None otherwise.
    new_axis: AxisBase | None = None
    # True when a single integer index was selected; the sliced axis is then
    # removed from the output's dims and axes.
    b_change_dims: bool = False
[docs]
class SlicerTransformer(
    BaseStatefulTransformer[SlicerSettings, AxisArray, AxisArray, SlicerState]
):
    """
    Select a subset of an AxisArray along one axis.

    The indexer is derived from ``settings.selection`` (see
    :obj:`ezmsg.sigproc.slicer.parse_slice`) in ``_reset_state`` and stored in
    the state; ``_process`` then applies it to each message.
    """

    def _resolve_axis(self, message: AxisArray) -> tuple[str, int]:
        """Return (name, index) of the axis to slice; defaults to the last dim."""
        axis = self.settings.axis or message.dims[-1]
        return axis, message.get_axis_idx(axis)

    def _hash_message(self, message: AxisArray) -> int:
        # Combines the message key and the length of the sliced axis;
        # presumably used by the base class to decide when to reset state.
        _, axis_idx = self._resolve_axis(message)
        return hash((message.key, message.data.shape[axis_idx]))

    def _reset_state(self, message: AxisArray) -> None:
        """Compute the indexer and (if applicable) the replacement output axis."""
        axis, axis_idx = self._resolve_axis(message)
        self._state.new_axis = None
        self._state.b_change_dims = False

        _slices = parse_slice(self.settings.selection, message.axes.get(axis, None))
        if len(_slices) == 1:
            self._state.slice_ = _slices[0]
            # A scalar index drops the dimension. Accept numpy integer scalars
            # too (label lookups may yield them), not only Python int.
            self._state.b_change_dims = isinstance(
                self._state.slice_, (int, np.integer)
            )
        else:
            # Several items: resolve each against the axis positions and
            # concatenate into one array of integer indices.
            indices = np.arange(message.data.shape[axis_idx])
            self._state.slice_ = np.hstack([indices[sl] for sl in _slices])

        # Build the output axis only when the input axis carries coordinate data.
        # NOTE(review): axes without `data` (presumably linear axes) are passed
        # through unchanged, so e.g. their offset is not shifted by a non-zero
        # slice start -- confirm this is intended.
        if (
            axis in message.axes
            and hasattr(message.axes[axis], "data")
            and len(message.axes[axis].data) > 0
        ):
            in_data = np.array(message.axes[axis].data)
            if self._state.b_change_dims:
                # Fancy-index with a one-element list to keep a 1-D result;
                # in_data[i:i + 1] would be empty for i == -1.
                out_data = in_data[[self._state.slice_]]
            else:
                out_data = in_data[self._state.slice_]
            self._state.new_axis = replace(message.axes[axis], data=out_data)

    def _process(self, message: AxisArray) -> AxisArray:
        """Apply the cached indexer and update dims/axes to match."""
        axis, axis_idx = self._resolve_axis(message)
        replace_kwargs = {}
        if self._state.b_change_dims:
            # Scalar selection removed the dimension: drop it from dims and axes.
            replace_kwargs["dims"] = [
                dim for dim_ix, dim in enumerate(message.dims) if dim_ix != axis_idx
            ]
            replace_kwargs["axes"] = {
                k: v for k, v in message.axes.items() if k != axis
            }
        elif self._state.new_axis is not None:
            replace_kwargs["axes"] = {
                k: (v if k != axis else self._state.new_axis)
                for k, v in message.axes.items()
            }
        return replace(
            message,
            data=slice_along_axis(message.data, self._state.slice_, axis_idx),
            **replace_kwargs,
        )
[docs]
class Slicer(
    BaseTransformerUnit[SlicerSettings, AxisArray, AxisArray, SlicerTransformer]
):
    """Transformer unit parameterized with SlicerTransformer and SlicerSettings."""

    SETTINGS = SlicerSettings
[docs]
def slicer(selection: str = "", axis: str | None = None) -> SlicerTransformer:
    """
    Build a transformer that slices incoming AxisArrays along one axis.

    Args:
        selection: Selection string; see :obj:`ezmsg.sigproc.slicer.parse_slice`.
        axis: Name of the axis to slice. When None, the last axis is used.

    Returns:
        :obj:`SlicerTransformer` configured with these settings.
    """
    settings = SlicerSettings(selection=selection, axis=axis)
    transformer = SlicerTransformer(settings)
    return transformer