# Source code for ezmsg.simbiophys.dnss.pattern_match
import numpy as np
import scipy.signal
# [docs]
def find_pattern_start(
data: np.ndarray,
full_pattern: np.ndarray,
pattern_window: tuple[int, int] | None = None,
subpattern_samples: int | None = None, # e.g., use 30_000 for Neural Signal repeating sine waves
) -> int:
n_pattern = full_pattern.shape[1]
if pattern_window is None:
pattern_window = [0, full_pattern.shape[1]]
n_window = int(np.diff(pattern_window))
# In a worst-case scenario, our `data` might start exactly 1 sample after our pattern_window, in which case
# we need to finish the current pattern and continue until we have reached the next pattern and the
# pattern window.
req_samples = n_window - 1 + n_pattern
assert data.shape[1] >= req_samples, f"At least {req_samples} required to ensure optimal match."
search_data = data[:, :req_samples]
# Search for bursts in each channel independently.
# Note: I tried to do a multi-channel search with correlate2d but it was incredibly slow.
# A multi-channel search might be made faster with pytorch.
match_onsets = np.zeros((search_data.shape[0],), dtype=int)
for chan_ix, chan_data in enumerate(search_data):
xcorr = scipy.signal.correlate(
chan_data,
full_pattern[chan_ix % full_pattern.shape[0], pattern_window[0] : pattern_window[1]],
mode="valid",
)
match_onsets[chan_ix] = np.argmax(xcorr)
match_onset = int(np.median(match_onsets[:96])) # Simple hack for when using HDMI
if subpattern_samples is not None:
# The signal might have a subpattern that can cause mis-matches at subpattern repeat intervals.
# Count how many channels matched best at +/- 1 subpattern repeat relative to the match.
# If -1 is bigger than +1 then we may have detected window offset, not onset.
bin_edges = [
match_onset + shift + win_half
for shift in [-subpattern_samples, 0, subpattern_samples]
for win_half in [-10, 10]
]
bin_counts = np.histogram(match_onsets, bin_edges)[0]
if bin_counts[0] > bin_counts[4]:
match_onset -= subpattern_samples
# Do a multi-channel alignment in a small space around the putative window onset.
test_shifts = np.arange(-5, 6)
tiled_burst = np.tile(
full_pattern[:, pattern_window[0] : pattern_window[1]],
(search_data.shape[0] // full_pattern.shape[0], 1),
)
rms = np.zeros((test_shifts.size,))
for ix, shift in enumerate(test_shifts):
temp = search_data[:, match_onset + shift : match_onset + shift + n_window]
rms[ix] = np.sqrt(np.mean((temp - tiled_burst) ** 2))
match_onset += test_shifts[np.argmin(rms)]
# match_onset tells us when we matched the window. We want to know when the full pattern was matched.
return (match_onset - pattern_window[0]) % n_pattern