import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
from ezmsg.baseproc import (
BaseStatefulTransformer,
BaseTransformerUnit,
processor_state,
)
from ezmsg.util.messages.axisarray import (
AxisArray,
AxisBase,
replace,
slice_along_axis,
)
"""
Slicer: Select a subset of data along a particular axis.
"""
[docs]
def parse_slice(
s: str,
axinfo: AxisArray.CoordinateAxis | None = None,
) -> tuple[slice | int, ...]:
"""
Parses a string representation of a slice and returns a tuple of slice objects.
- "" -> slice(None, None, None) (take all)
- ":" -> slice(None, None, None)
- '"none"` (case-insensitive) -> slice(None, None, None)
- "{start}:{stop}" or {start}:{stop}:{step} -> slice(start, stop, step)
- "5" (or any integer) -> (5,). Take only that item.
applying this to a ndarray or AxisArray will drop the dimension.
- A comma-separated list of the above -> a tuple of slices | ints
- A comma-separated list of values and axinfo is provided and is a CoordinateAxis -> a tuple of ints
Args:
s: The string representation of the slice.
axinfo: (Optional) If provided, and of type CoordinateAxis,
and `s` is a comma-separated list of values, then the values
in s will be checked against the values in axinfo.data.
Returns:
A tuple of slice objects and/or ints.
"""
if s.lower() in ["", ":", "none"]:
return (slice(None),)
if "," not in s:
parts = [part.strip() for part in s.split(":")]
if len(parts) == 1:
if axinfo is not None and hasattr(axinfo, "data") and parts[0] in axinfo.data:
return tuple(np.where(axinfo.data == parts[0])[0])
return (int(parts[0]),)
return (slice(*(int(part.strip()) if part else None for part in parts)),)
suplist = [parse_slice(_, axinfo=axinfo) for _ in s.split(",")]
return tuple([item for sublist in suplist for item in sublist])
[docs]
class SlicerSettings(ez.Settings):
    """Settings consumed by :obj:`Slicer` and built by :obj:`slicer`."""

    #: Selection string; see :obj:`ezmsg.sigproc.slicer.parse_slice` for the accepted syntax.
    selection: str = ""

    #: Name of the axis to slice along. If None, the last axis is used.
    axis: str | None = None
[docs]
@processor_state
class SlicerState:
    """State container for the slicer, declared through ``@processor_state``."""

    # Resolved selection (slice, single index, or index array); None until set.
    # NOTE(review): presumably populated from parse_slice output by SlicerTransformer — confirm.
    slice_: slice | int | npt.NDArray | None = None

    # Optional replacement axis object for the sliced dimension; None until set.
    new_axis: AxisBase | None = None

    # Flag; presumably True when the selection changes the output dimensions
    # (e.g. an integer selection dropping the axis) — confirm against SlicerTransformer.
    b_change_dims: bool = False
[docs]
# NOTE(review): SlicerTransformer is not visible in this chunk. Because the generic
# subscript below is evaluated at class-creation time, it must be bound before this class.
class Slicer(
    BaseTransformerUnit[
        SlicerSettings,
        AxisArray,
        AxisArray,
        SlicerTransformer,
    ]
):
    """
    Transformer unit that binds :obj:`SlicerSettings` to :obj:`SlicerTransformer`.

    Its generic parameters declare :obj:`AxisArray` as both the input and output
    message type. See :obj:`ezmsg.sigproc.slicer.parse_slice` for the selection syntax.
    """

    SETTINGS = SlicerSettings
[docs]
def slicer(selection: str = "", axis: str | None = None) -> SlicerTransformer:
    """
    Build a transformer that slices data along a single axis.

    Args:
        selection: Selection string; see :obj:`ezmsg.sigproc.slicer.parse_slice`
            for the accepted syntax.
        axis: Name of the axis to slice along. If None, the last axis is used.

    Returns:
        A :obj:`SlicerTransformer` configured with ``selection`` and ``axis``.
    """
    settings = SlicerSettings(selection=selection, axis=axis)
    return SlicerTransformer(settings)