diff --git a/doc/conf.py b/doc/conf.py index 742f46662e..ea93b4ddfe 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -125,6 +125,7 @@ 'subsection_order': ExplicitOrder([ '../examples/tutorials/core', '../examples/tutorials/extractors', + '../examples/tutorials/preprocessing', '../examples/tutorials/curation', '../examples/tutorials/qualitymetrics', '../examples/tutorials/comparison', diff --git a/doc/tutorials_custom_index.rst b/doc/tutorials_custom_index.rst index 82f2c06eed..fb0e2cafa0 100644 --- a/doc/tutorials_custom_index.rst +++ b/doc/tutorials_custom_index.rst @@ -85,6 +85,20 @@ These tutorials focus on the :py:mod:`spikeinterface.core` module. :class-card: gallery-card :text-align: center +Preprocessing tutorials +----------------------- + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: Inter-session Alignment + :link-type: ref + :link: sphx_glr_tutorials_preprocessing_plot_7_inter_session_alignment.py + :img-top: /tutorials/preprocessing/images/thumb/sphx_glr_plot_7_inter_session_alignment_thumb.png + :img-alt: Inter-session Alignment + :class-card: gallery-card + :text-align: center + Extractors tutorials -------------------- diff --git a/examples/tutorials/preprocessing/README.rst b/examples/tutorials/preprocessing/README.rst new file mode 100644 index 0000000000..50ef0bd6f1 --- /dev/null +++ b/examples/tutorials/preprocessing/README.rst @@ -0,0 +1,2 @@ +Preprocessing Tutorials +----------------------- diff --git a/examples/tutorials/preprocessing/plot_7_inter_session_alignment.py b/examples/tutorials/preprocessing/plot_7_inter_session_alignment.py new file mode 100644 index 0000000000..54762da865 --- /dev/null +++ b/examples/tutorials/preprocessing/plot_7_inter_session_alignment.py @@ -0,0 +1,349 @@ +""" +How to perform inter-session alignment +====================================== + +In this tutorial we will assess and correct changes in probe position across +multiple experimental sessions using `inter-session alignment`. + +This is often valuable for chronic-recording experiments, where the goal is to track units across sessions + + +Running inter-session alignment +------------------------------- + +In SpikeInterface, it is recommended to perform inter-session alignment +following within-session motion correction (if used) and before whitening / sorting. +If you are running inter-session alignment after motion correction, see +:ref:`inter-session alignment after motion correction `. + +Preprocessed recordings should first be stored in a list: + +.. code-block:: python + + recordings_list = [prepro_session_1, prepro_session_2, ...] + +Here, we will simulate an experiment with two sessions by generating a pair of sessions in +which the probe is displaced 200 micrometers (μm) along its y-axis (depth). +First, we will import all required packages and functions: +""" + +import spikeinterface.full as si +from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings +from spikeinterface.preprocessing.inter_session_alignment import session_alignment +from spikeinterface.widgets import plot_session_alignment, plot_activity_histogram_2d +import matplotlib.pyplot as plt + + +# %% +# and then generate the test recordings: + +recordings_list, _ = generate_session_displacement_recordings( # TODO: add to spikeinterface.full ? + num_units=8, + recording_durations=[10, 10], + recording_shifts=((0, 0), (0, 200)), # (x offset, y offset) pairs + seed=42 +) + +# %% +# We won't preprocess the simulated recordings in this tutorial, but you can imagine +# preprocessing steps have already been run (e.g. filtering, common reference etc.). +# +# To run inter-session alignment, peaks must be detected and localised +# as the locations of firing neurons are used to anchor the sessions' alignment. +# +# If you are **running inter-session alignment following motion correction**, the peaks will +# already be detected and localised. In this case, please jump to +# :ref:`inter-session alignment after motion correction `. +# +# In this section of the tutorial, we will assume motion correction was not run, so we need to compute the peaks: + +peaks_list, peak_locations_list = session_alignment.compute_peaks_locations_for_session_alignment( + recordings_list, + detect_kwargs={"method": "locally_exclusive"}, + localize_peaks_kwargs={"method": "grid_convolution"}, +) + +# %% +# The peak locations (before correction) can be visualised with the plotting function: + +plot_session_alignment( + recordings_list, + peaks_list, + peak_locations_list, +) +plt.show() + +# %% +# we are now ready to perform inter-session alignment. There are many options associated +# with this method (see sections below). To edit the configurations, fetch the default options +# with the available getters function and make select changes as required: + +estimate_histogram_kwargs = session_alignment.get_estimate_histogram_kwargs() +estimate_histogram_kwargs["histogram_type"] = "2d" + +corrected_recordings_list, extra_info = session_alignment.align_sessions( + recordings_list, + peaks_list, + peak_locations_list, + estimate_histogram_kwargs=estimate_histogram_kwargs +) + +# %% +# To assess the performance of inter-session alignment, ``plot_session_alignment()`` +# will plot both the original and corrected recordings: + +plot_session_alignment( + recordings_list, + peaks_list, + peak_locations_list, + extra_info["session_histogram_list"], + **extra_info["corrected"], + spatial_bin_centers=extra_info["spatial_bin_centers"], + drift_raster_map_kwargs={"clim":(-250, 0)} +) +plt.show() + +# %% +# As we have used 2d histograms for alignment, we can also plot these with ``plot_activity_histogram_2d()``: + +plot_activity_histogram_2d( + extra_info["session_histogram_list"], + extra_info["spatial_bin_centers"], + extra_info["corrected"]["corrected_session_histogram_list"] +) +plt.show() + +# +# .. _with_motion_correction: + +# %% +# Inter-session alignment after motion correction +# ----------------------------------------------- +# +# If motion correction has already been performed, it is possible to reuse the +# previously computed peaks and peak locations, avoiding the need for re-computation. +# We will use the special function` `align_sessions_after_motion_correction()`` for this case. +# +# Critically, the last preprocessing step prior to inter-session alignment should be motion correction. +# This ensures the correction for inter-session alignment will be **added directly to the motion correction**. +# This is beneficial as it avoids interpolating the data (i.e. shifting the traces) more than once. +# +# .. admonition:: Warning +# :class: warning +# +# To ensure that inter-session alignment adds the displacement directly to the motion-corrected recording +# to avoid repeated interpolation, motion correction must be the final operation applied to the recording +# prior to inter-session alignment. +# +# You can verify this by confirming the recording is an ``InterpolateMotionRecording`` with: +# +# .. code-block:: +# type(recording)`` # quick check, should print `InterpolateMotionRecording` +# +# from spikeinterface.sortingcomponents.motion.motion_interpolation import InterpolateMotionRecording +# +# assert isinstance(recording, InterpolateMotionRecording) # error if not true +# +# +# ``align_sessions_after_motion_correction()`` will raise an error if the passed recordings +# are not all ``InterpolateMotionRecordings``. +# +# Let's first create some test data. We can create a recording with motion errors, +# then split it in two to simulate two separate sessions: + +# Generate the recording with motion artefact +motion_recording = si.generate_drifting_recording(duration=100)[0] +total_duration = motion_recording.get_duration() +split_time = total_duration / 2 + +# Split in two to simulate two sessions +recording_part1 = motion_recording.time_slice(start_time=0, end_time=split_time) +recording_part2 = motion_recording.time_slice(start_time=split_time, end_time=total_duration) + +# %% +# Next, motion correction is performed, storing the results in a list: + +# perform motion correction on each session, storing the outputs in lists +recordings_list_motion = [] +motion_info_list = [] +for recording in [recording_part1, recording_part2]: + + rec, motion_info = si.correct_motion(recording, output_motion_info=True, preset="rigid_fast") + + recordings_list_motion.append(rec) + motion_info_list.append(motion_info) + +# %% +# Now, we are ready to use ``align_sessions_after_motion_correction()``. +# We can pass any arguments directly to ``align_sessions`` using the ``align_sessions_kwargs`` argument: + +estimate_histogram_kwargs = session_alignment.get_estimate_histogram_kwargs() +estimate_histogram_kwargs["histogram_type"] = "2d" + +align_sessions_kwargs = {"estimate_histogram_kwargs": estimate_histogram_kwargs} + +corrected_recordings_list_motion, _ = session_alignment.align_sessions_after_motion_correction( + recordings_list_motion, motion_info_list, align_sessions_kwargs +) + +# %% +# As above, the inter-session alignment can be assessed using ``plot_session_alignment()``. + +# %% +# Inter-session alignment settings +# -------------------------------- +# +# Below, the settings that control how inter-session alignment is performed +# are explored. These configs can be accessed by the getter functions +# ``get_estimate_histogram_kwargs``, ``get_compute_alignment_kwargs``, +# ``get_non_rigid_window_kwargs``, ``get_interpolate_motion_kwargs``. +# +# TODO: cannot add inter-session alignment to imports due to circular import error +# +# Estimate Histogram Kwargs +# ~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# The settings control how the activity histogram (used for alignment) is estimated +# for each session. They can be obtained with ``get_estimate_histogram_kwargs``. +# +# The ``"bin_um"`` parameter controls the bin-size of the activity histogram. +# Along the probe's y-axis, spatial bins will be generated according to this bin size. +# +# To compute the histogram, the session is split into chunks across time, and either +# the mean or median (bin-wise) taken across chunks. This generates the summary +# histgoram for that session to be used to estimate inter-session displacement. +# +# The ``"method"`` parameter controls whether the mean (``"chunked_mean"``) +# or median (``"chunked_median"``) is used. The idea of using the median is to +# reduce the effect periods of the recording which may be outliers +# due to noise or other signal contamination. +# ``"chunked_bin_size_s"`` sets the size of the temporal chunks. By default is +# ``"estimate"`` which estimates the chunk size based on firing frequency +# (see XXXX). Otherwise, can taFke a float for chunk size in seconds. +# +# The ``histogram_type`` can be ``"1d"` or ``"2d"``, +# if 1D the firing rate x spatial bin histogram is generated. Otherwise +# a firing rate x amplitude x spatial bin histogram is generated. +# +# We can visualise the histograms for each time chunk with: + +estimate_histogram_kwargs = session_alignment.get_estimate_histogram_kwargs() +estimate_histogram_kwargs["histogram_type"] = "1d" +estimate_histogram_kwargs["chunked_bin_size_s"] = 1.0 + +_, extra_info_rigid = session_alignment.align_sessions( + recordings_list, + peaks_list, + peak_locations_list, + estimate_histogram_kwargs=estimate_histogram_kwargs, +) + +plt.plot(extra_info_rigid["histogram_info_list"][0]["chunked_histograms"].T) +plt.xlabel("Spatial bim (um)") +plt.show() + +# %% +# Compute Alignment Kwargs +# ~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Once the histograms have been generated for each session, the displacement +# between sessions is computed. ``get_compute_alignment_kwargs()`` set how this +# displacement is estimated. +# +# The estimation proceeds similar to the `Kilosort motion-correction `_ +# method (see also the "kilosort-like" option in `:func:`correct_motion``.). Briefly, the cross-correlation +# of activity histograms is performed and the peak position used as a linear estimate of the displacement. +# For-non rigid alignment, first linear alignment is performed, then the probe y-axis is split into segments +# and linear estimation performed in each bin. Then, the displacement set at each bin center are interpolated acoss channels. +# +# Most compute-alignment kwargs are similar to those used in motion correction. +# Key arguments and those related to inter-session alignment +# include: +# +# ``"num_shifts_global"``: This is the number of shifts to perform cross-correlation across for linear alignment. +# Put differently, this is the maximum allowed displacement to consider for rigid alignment. +# ``"num_shifts_block"``: The number of shifts to perform cross-correlation across for non-linear alignment (within each spatial bin). +# ``"akima_interp_nonrigid"``: If ``True``, perform akima interpolation across non-rigid spatial bins (rather than linear). +# ``"min_crosscorr_threshold"``: To estimate alignment, normalised cross-correlation is performed. In some cases, particularly +# for non-rigid alignment, there may be little correlation within a bin. To stop aberrant shifts estimated on poor correlations, +# this sets a minimum value for the correlation used to estimate the shift. If less than this value, the shift will be set to zero. +# +# Non-rigid window kwargs +# ~~~~~~~~~~~~~~~~~~~~~~~ +# Non-rigid window kwargs determine how the non-rigid alignment is performed, +# in particular around how the y-axis of the probe is segmented into blocks +# (each which will be aligned using rigid alignment) are found here. +# (and see ``get_non_rigid_window_kwargs()``. +# +# We can see how the ``compute_alignment_kwargs`` control the non-rigid alignment +# by inspecting the output of inter-session alignment. First, we generate a +# pair of recordings with non-rigid displacement and perform rigid alignment: + +recordings_list, _ = generate_session_displacement_recordings( + num_units=8, + recording_durations=[10, 10], + recording_shifts=((0, 0), (0, 200)), # (x offset, y offset) pairs + non_rigid_gradient=0.1, + seed=42 +) + +peaks_list, peak_locations_list = session_alignment.compute_peaks_locations_for_session_alignment( + recordings_list, + detect_kwargs={"method": "locally_exclusive"}, + localize_peaks_kwargs={"method": "grid_convolution"}, +) + + +non_rigid_window_kwargs = session_alignment.get_non_rigid_window_kwargs() +non_rigid_window_kwargs["rigid"] = True + +_, extra_info_rigid = session_alignment.align_sessions( + recordings_list, + peaks_list, + peak_locations_list, + estimate_histogram_kwargs=estimate_histogram_kwargs, + non_rigid_window_kwargs=non_rigid_window_kwargs, +) + +plt.plot(extra_info_rigid["corrected"]["corrected_session_histogram_list"][0]) +plt.plot(extra_info_rigid["corrected"]["corrected_session_histogram_list"][1]) +plt.show() + +# %% +# Above, you can see there rigid alignemnt has well-matched one peak but +# the second peak is offset. Next, we can apply non-rigid alignment, +# and visualise the non-rigid segments that the probe is split into. +# Note that by default, Gaussian windows are used: + +# %% + +non_rigid_window_kwargs = session_alignment.get_non_rigid_window_kwargs() +non_rigid_window_kwargs["rigid"] = False +non_rigid_window_kwargs["win_step_um"] = 200 +non_rigid_window_kwargs["win_scale_um"] = 100 + +compute_alignment_kwargs = session_alignment.get_compute_alignment_kwargs() +compute_alignment_kwargs["akima_interp_nonrigid"] = True + +_, extra_info_nonrigid = session_alignment.align_sessions( + recordings_list, + peaks_list, + peak_locations_list, + estimate_histogram_kwargs=estimate_histogram_kwargs, + non_rigid_window_kwargs=non_rigid_window_kwargs, +) + +plt.plot(extra_info_nonrigid["corrected"]["corrected_session_histogram_list"][0]) +plt.plot(extra_info_nonrigid["corrected"]["corrected_session_histogram_list"][1]) +plt.plot(extra_info_nonrigid["non_rigid_windows"].T) +plt.show() + +# %% +# It is notable that the non-rigid alignment is not perfect in this case. This +# is because for each bin, the displacement is computed and we can imagine +# its value being 'positioned' in the center of the bin. Then, the bin center +# values are interpolates across all channels. This leads to non-perfect alignment. +# +# This is in part because this is a simulated test case with only a few peaks and +# the spatial footprint of the APs is small. Nonetheless, the non-rigid window kwargs +# may be adjusting to maximise the performance of non-rigid alignment in the real world case. diff --git a/pyproject.toml b/pyproject.toml index 15ce837774..b6d0d0b0c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -150,7 +150,7 @@ test = [ "pytest-dependency", "pytest-cov", "psutil", - + "pytest-mock", "huggingface_hub", # preprocessing diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index ed18b815de..dd8d292281 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -2437,12 +2437,19 @@ def generate_ground_truth_recording( parent_recording=noise_rec, upsample_vector=upsample_vector, ) - recording.annotate(is_filtered=True) - recording.set_probe(probe, in_place=True) - recording.set_channel_gains(1.0) - recording.set_channel_offsets(0.0) - + setup_inject_templates_recording(recording, probe) recording.name = "GroundTruthRecording" sorting.name = "GroundTruthSorting" return recording, sorting + + +def setup_inject_templates_recording(recording: BaseRecording, probe: Probe) -> None: + """ + Convenience function to modify a generated + recording in-place with annotation and probe details + """ + recording.annotate(is_filtered=True) + recording.set_probe(probe, in_place=True) + recording.set_channel_gains(1.0) + recording.set_channel_offsets(0.0) diff --git a/src/spikeinterface/core/motion.py b/src/spikeinterface/core/motion.py index 3ebf6ad371..e26b91de19 100644 --- a/src/spikeinterface/core/motion.py +++ b/src/spikeinterface/core/motion.py @@ -67,7 +67,13 @@ def __repr__(self): else: rigid_txt = f"non-rigid - {nbins} spatial bins" - interval_s = self.temporal_bins_s[0][1] - self.temporal_bins_s[0][0] + if self.temporal_bins_s[0].size > 1: + interval_s = self.temporal_bins_s[0][1] - self.temporal_bins_s[0][0] + else: + # If there is only one temporal bin (entire session), assume the bin + # left edge is zero, and take twice it for the bin size. + interval_s = self.temporal_bins_s[0][0] * 2 + txt = f"Motion {rigid_txt} - interval {interval_s}s - {self.num_segments} segments" return txt @@ -149,6 +155,12 @@ def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_inde # reshape to grid domain shape if necessary displacement = displacement.reshape(out_shape) + # For the inter-session alignment case + if self.temporal_bins_s[segment_index].size == 1 and self.spatial_bins_um.size == 1: + assert np.all(np.isnan(displacement)) + assert self.displacement[segment_index].size == 1 + displacement[:] = self.displacement[segment_index] + return displacement def to_dict(self): diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 7bd3bbd860..bd14320de9 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -548,7 +548,7 @@ def run_node_pipeline( Here a "spike" is a spike with any a label so already sorted. The main idea is to have a graph of nodes. - Every node is doing a computaion of some peaks and related traces. + Every node is doing a computation of some peaks and related traces. The first node is PeakSource so either a peak detector PeakDetector or peak/spike replay (PeakRetriever/SpikeRetriever) Every node can have one or several output that can be directed to other nodes (aka nodes have parents). @@ -587,7 +587,7 @@ def run_node_pipeline( Skip the computation after n_peaks. This is not an exact because internally this skip is done per worker in average. recording_slices : None | list[tuple] - Optionaly give a list of slices to run the pipeline only on some chunks of the recording. + Optionally give a list of slices to run the pipeline only on some chunks of the recording. It must be a list of (segment_index, frame_start, frame_stop). If None (default), the function iterates over the entire duration of the recording. diff --git a/src/spikeinterface/generation/drifting_generator.py b/src/spikeinterface/generation/drifting_generator.py index 9395823e03..a7e29b9aef 100644 --- a/src/spikeinterface/generation/drifting_generator.py +++ b/src/spikeinterface/generation/drifting_generator.py @@ -8,6 +8,7 @@ """ +from __future__ import annotations import numpy as np from probeinterface import generate_multi_columns_probe @@ -21,6 +22,7 @@ ) from .drift_tools import DriftingTemplates, make_linear_displacement, InjectDriftingTemplatesRecording from .noise_tools import generate_noise +from probeinterface import Probe # this should be moved in probeinterface but later @@ -181,7 +183,7 @@ def generate_displacement_vector( duration : float Duration of the displacement vector in seconds unit_locations : np.array - The unit location with shape (num_units, 3) + The unit location with shape (num_units, 2) displacement_sampling_frequency : float, default: 5. The sampling frequency of the displacement vector drift_start_um : list of float, default: [0, 20.] @@ -240,22 +242,70 @@ def generate_displacement_vector( if non_rigid_gradient is None: displacement_unit_factor[:, m] = 1 else: - gradient_direction = drift_stop_um - drift_start_um - gradient_direction /= np.linalg.norm(gradient_direction) - - proj = np.dot(unit_locations, gradient_direction).squeeze() - factors = (proj - np.min(proj)) / (np.max(proj) - np.min(proj)) - if non_rigid_gradient < 0: - # reverse - factors = 1 - factors - f = np.abs(non_rigid_gradient) - displacement_unit_factor[:, m] = factors * (1 - f) + f + displacement_unit_factor[:, m] = calculate_displacement_unit_factor( + non_rigid_gradient, unit_locations, drift_start_um, drift_stop_um + ) displacement_vectors = np.concatenate(displacement_vectors, axis=2) return displacement_vectors, displacement_unit_factor, displacement_sampling_frequency, displacements_steps +def calculate_displacement_unit_factor( + non_rigid_gradient: float, unit_locations: np.array, drift_start_um: np.array, drift_stop_um: np.array +) -> np.array: + """ + Introduces a non-rigid drift across the probe, this is a linear + scaling of the displacement based on the unit position. + + To introduce non-rigid drift, a set of scaling factors (one per unit) + are generated. These scale the displacement applied to each unit + as a function of unit position. The smaller the `non_rigid_gradient`, + the larger the influence of the unit position is on scaling the + displacement (more non-linearity). + + The projections of the gradient vector (x, y) + and unit locations (x, y) are normalised to range between + 0 and 1 (i.e. based on relative location to the gradient). + + Parameters + ---------- + + non_rigid_gradient : float + A number in the range [0, 1] by which to scale the scaling factors + that are based on unit location. This sets the weighting given to the factors + based on unit locations. When 1, the factors will all equal 1 (no effect), + when 0, the scaling factor based on unit location will be used directly. + Smaller number results in more nonlinearity. + unit_locations : np.array + The unit location with shape (num_units, 2) + drift_start_um : np.array + The start boundary of the motion in the x and y direction. + drift_stop_um : np.array + The stop boundary of the motion in the x and y direction. + + Returns + ------- + displacement_unit_factor : np.array + An array of scaling factors (one per unit) by which + to scale the displacement. + """ + gradient_direction = drift_stop_um - drift_start_um + gradient_direction /= np.linalg.norm(gradient_direction) + + proj = np.dot(unit_locations, gradient_direction).squeeze() + factors = (proj - np.min(proj)) / (np.max(proj) - np.min(proj)) + + if non_rigid_gradient < 0: # reverse + factors = 1 - factors + + f = np.abs(non_rigid_gradient) + + displacement_unit_factor = factors * (1 - f) + f + + return displacement_unit_factor + + def generate_drifting_recording( num_units=250, duration=600.0, @@ -351,12 +401,9 @@ def generate_drifting_recording( This can be helpfull for motion benchmark. """ # probe - if generate_probe_kwargs is None: - generate_probe_kwargs = _toy_probes[probe_name] - probe = generate_multi_columns_probe(**generate_probe_kwargs) - num_channels = probe.get_contact_count() - probe.set_device_channel_indices(np.arange(num_channels)) + probe = generate_probe(generate_probe_kwargs, probe_name) channel_locations = probe.contact_positions + # import matplotlib.pyplot as plt # import probeinterface.plotting # fig, ax = plt.subplots() @@ -384,9 +431,7 @@ def generate_drifting_recording( unit_displacements[:, :, direction] += m # unit_params need to be fixed before the displacement steps - generate_templates_kwargs = generate_templates_kwargs.copy() - unit_params = _ensure_unit_params(generate_templates_kwargs.get("unit_params", {}), num_units, seed) - generate_templates_kwargs["unit_params"] = unit_params + generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed) # generate templates templates_array = generate_templates( @@ -479,3 +524,50 @@ def generate_drifting_recording( return static_recording, drifting_recording, sorting, extra_infos else: return static_recording, drifting_recording, sorting + + +def generate_probe(generate_probe_kwargs: dict, probe_name: str | None = None) -> Probe: + """ + Generate a probe for use in certain ground-truth recordings. + + Parameters + ---------- + + generate_probe_kwargs : dict + The kwargs to pass to `generate_multi_columns_probe()` + probe_name : str | None + The probe type if generate_probe_kwargs is None. + """ + if generate_probe_kwargs is None: + assert probe_name is not None, "`probe_name` must be set if `generate_probe_kwargs` is `None`." + generate_probe_kwargs = _toy_probes[probe_name] + probe = generate_multi_columns_probe(**generate_probe_kwargs) + num_channels = probe.get_contact_count() + probe.set_device_channel_indices(np.arange(num_channels)) + + return probe + + +def fix_generate_templates_kwargs(generate_templates_kwargs: dict, num_units: int, seed: int) -> dict: + """ + Fix the generate_template_kwargs such that the same units are created + across calls to `generate_template`. We must explicitly pre-set + the parameters for each unit, done in `_ensure_unit_params()`. + + Parameters + ---------- + + generate_templates_kwargs : dict + These kwargs will have the "unit_params" entry edited such that the + parameters are explicitly set for each unit to create (rather than + generated randomly on the fly). + num_units : int + Number of units to fix the kwargs for + seed : int + Random seed. + """ + generate_templates_kwargs = generate_templates_kwargs.copy() + unit_params = _ensure_unit_params(generate_templates_kwargs.get("unit_params", {}), num_units, seed) + generate_templates_kwargs["unit_params"] = unit_params + + return generate_templates_kwargs diff --git a/src/spikeinterface/generation/session_displacement_generator.py b/src/spikeinterface/generation/session_displacement_generator.py new file mode 100644 index 0000000000..241d6ce7a7 --- /dev/null +++ b/src/spikeinterface/generation/session_displacement_generator.py @@ -0,0 +1,529 @@ +import copy + +from spikeinterface.generation.drifting_generator import ( + generate_probe, + fix_generate_templates_kwargs, + calculate_displacement_unit_factor, +) +from spikeinterface.core.generate import ( + generate_unit_locations, + generate_sorting, + generate_templates, +) +import numpy as np +from spikeinterface.generation.noise_tools import generate_noise +from spikeinterface.core.generate import setup_inject_templates_recording, _ensure_firing_rates +from spikeinterface.core import InjectTemplatesRecording + + +def generate_session_displacement_recordings( + num_units=250, + recording_durations=(10, 10, 10), + recording_shifts=((0, 0), (0, 25), (0, 50)), + non_rigid_gradient=None, + recording_amplitude_scalings=None, + shift_units_outside_probe=False, + sampling_frequency=30000.0, + probe_name="Neuropixel-128", + generate_probe_kwargs=None, + generate_unit_locations_kwargs=dict( + margin_um=20.0, + minimum_z=5.0, + maximum_z=45.0, + minimum_distance=18.0, + max_iteration=100, + distance_strict=False, + ), + generate_templates_kwargs=dict( + ms_before=1.5, + ms_after=3.0, + mode="ellipsoid", + unit_params=dict( + alpha=(150.0, 500.0), + spatial_decay=(10, 45), + ), + ), + generate_sorting_kwargs=dict(firing_rates=(2.0, 8.0), refractory_period_ms=4.0), + generate_noise_kwargs=dict(noise_levels=(12.0, 15.0), spatial_decay=25.0), + extra_outputs=False, + seed=None, +): + """ + Generate a set of recordings simulating probe drift across recording + sessions. + + Rigid drift can be added in the (x, y) direction in `recording_shifts`. + These drifts can be made non-rigid (scaled dependent on the unit location) + with the `non_rigid_gradient` parameter. Amplitude of units can be scaled + (e.g. template signal removed by scaling with zero) by specifying scaling + factors in `recording_amplitude_scalings`. + + Parameters + ---------- + + num_units : int + The number of units in the generated recordings. + recording_durations : list + An array of length (num_recordings,) specifying the + duration that each created recording should be, in seconds. + recording_shifts : list + An array of length (num_recordings,) in which each element + is a 2-element array specifying the (x, y) shift for the recording. + Typically, the first recording will have shift (0, 0) so all further + recordings are shifted relative to it. e.g. to create two recordings, + the second shifted by 50 um in the x-direction and 250 um in the y + direction : ((0, 0), (50, 250)). + non_rigid_gradient : None | float | list[float] + Factor which sets the level of non-rigidty in the displacement. + See `calculate_displacement_unit_factor` for details. + `None` is linear gradient. + recording_amplitude_scalings : dict + A dict with keys: + "method" - order by which to apply the scalings. + "by_passed_order" - scalings are applied to the unit templates + in order passed + "by_firing_rate" - scalings are applied to the units in order of + maximum to minimum firing rate + "by_amplitude_and_firing_rate" - scalings are applied to the units + in order of amplitude * firing_rate (maximum to minimum) + "scalings" - a list of numpy arrays, one for each recording, with + each entry an array of length num_units holding the unit scalings. + e.g. for 3 recordings, 2 units: ((1, 1), (1, 1), (0.5, 0.5)). + shift_units_outside_probe : bool + By default (`False`), when units are shifted across sessions, new units are + not introduced into the recording (e.g. the region in which units + have been shifted out of is left at baseline level). In reality, + when the probe shifts new units from outside the original recorded + region are shifted into the recording. When `True`, new units + are shifted into the generated recording. + generate_sorting_kwargs : dict + Only `firing_rates` and `refractory_period_ms` are expected if passed. + + All other parameters are used as in from `generate_drifting_recording()`. + + Returns + ------- + output_recordings : list + A list of recordings with units shifted (i.e. replicated probe shift). + output_sortings : list + A list of corresponding sorting objects. + extra_outputs_dict (options) : dict + When `extra_outputs` is `True`, a dict containing variables used + in the generation process. + "unit_locations" : A list (length num records) of shifted unit locations + "templates_array_moved" : list[np.array] + A list (length num records) of (num_units, num_samples, num_channels) + arrays of templates that have been shifted. + + Notes + ----- + It is important to consider what unit properties are maintained + across the session. Here, all `generate_template_kwargs` are fixed + across sessions, to be sure the unit properties do not change. + The firing rates passed to `generate_sorting` for each unit are + also fixed across sessions. When a seed is set, the exact spike times + will also be fixed across recordings. otherwise, when seed is `None` + the actual spike times will be different across recordings, although + all other unit properties will be maintained (except any location + shifting and template scaling applied). + """ + # temporary fix + generate_unit_locations_kwargs = copy.deepcopy(generate_unit_locations_kwargs) + generate_templates_kwargs = copy.deepcopy(generate_templates_kwargs) + generate_sorting_kwargs = copy.deepcopy(generate_sorting_kwargs) + generate_noise_kwargs = copy.deepcopy(generate_noise_kwargs) + + _check_generate_session_displacement_arguments( + num_units, + recording_durations, + recording_shifts, + recording_amplitude_scalings, + shift_units_outside_probe, + non_rigid_gradient, + ) + + probe = generate_probe(generate_probe_kwargs, probe_name) + channel_locations = probe.contact_positions + + # Create the starting unit locations (which will be shifted). + unit_locations = generate_unit_locations( + num_units, + channel_locations, + seed=seed, + **generate_unit_locations_kwargs, + ) + + # Fix generate template kwargs, so they are the same for every created recording. + # Also fix unit firing rates across recordings. + fixed_generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed) + + fixed_firing_rates = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed) + fixed_generate_sorting_kwargs = copy.deepcopy(generate_sorting_kwargs) + fixed_generate_sorting_kwargs["firing_rates"] = fixed_firing_rates + + if shift_units_outside_probe: + # Create a new set of templates one probe-width above and + # one probe-width below the original templates. The number of + # units is duplicated for each section, so the new num units + # is 3x the old num units. + num_units, unit_locations, fixed_generate_templates_kwargs, fixed_generate_sorting_kwargs = ( + _update_kwargs_for_extended_units( + num_units, + channel_locations, + unit_locations, + generate_unit_locations_kwargs, + generate_templates_kwargs, + generate_sorting_kwargs, + fixed_generate_templates_kwargs, + fixed_generate_sorting_kwargs, + seed, + ) + ) + + # Start looping over parameters, creating recordings shifted + # and scaled as required + extra_outputs_dict = { + "unit_locations": [], + "templates_array_moved": [], + "firing_rates": [], + } + output_recordings = [] + output_sortings = [] + + for rec_idx, (shift, duration) in enumerate(zip(recording_shifts, recording_durations)): + + displacement_vector, displacement_unit_factor = _get_inter_session_displacements( + shift, + non_rigid_gradient[rec_idx] if isinstance(non_rigid_gradient, list) else non_rigid_gradient, + num_units, + unit_locations, + ) + + # Move the canonical `unit_locations` according to the set (x, y) shifts + unit_locations_moved = unit_locations.copy() + unit_locations_moved[:, :2] += displacement_vector[0, :][np.newaxis, :] * displacement_unit_factor + + # Generate the sorting (e.g. spike times) for the recording + + sorting = generate_sorting( + num_units=num_units, + sampling_frequency=sampling_frequency, + durations=[duration], + **fixed_generate_sorting_kwargs, + seed=seed, + ) + sorting.set_property("gt_unit_locations", unit_locations_moved) + + # Generate the noise in the recording + noise = generate_noise( + probe=probe, + sampling_frequency=sampling_frequency, + durations=[duration], + seed=seed, + **generate_noise_kwargs, + ) + + # Generate the (possibly shifted, scaled) unit templates + templates_array_moved = generate_templates( + channel_locations, + unit_locations_moved, + sampling_frequency=sampling_frequency, + seed=seed, + **fixed_generate_templates_kwargs, + ) + + if recording_amplitude_scalings is not None: + + first_rec_templates = ( + templates_array_moved if rec_idx == 0 else extra_outputs_dict["templates_array_moved"][0] + ) + _amplitude_scale_templates_in_place( + first_rec_templates, + templates_array_moved, + recording_amplitude_scalings, + fixed_generate_sorting_kwargs, + rec_idx, + ) + + # Bring it all together in a `InjectTemplatesRecording` and + # propagate all relevant metadata to the recording. + ms_before = fixed_generate_templates_kwargs["ms_before"] + nbefore = int(sampling_frequency * ms_before / 1000.0) + + recording = InjectTemplatesRecording( + sorting=sorting, + templates=templates_array_moved, + nbefore=nbefore, + amplitude_factor=None, + parent_recording=noise, + num_samples=noise.get_num_samples(0), + upsample_vector=None, + check_borders=False, + ) + + setup_inject_templates_recording(recording, probe) + + recording.name = "InterSessionDisplacementRecording" + sorting.name = "InterSessionDisplacementSorting" + + output_recordings.append(recording) + output_sortings.append(sorting) + extra_outputs_dict["unit_locations"].append(unit_locations_moved) + extra_outputs_dict["templates_array_moved"].append(templates_array_moved) + extra_outputs_dict["firing_rates"].append(fixed_generate_sorting_kwargs["firing_rates"]) + + if extra_outputs: + return output_recordings, output_sortings, extra_outputs_dict + else: + return output_recordings, output_sortings + + +def _get_inter_session_displacements(shift, non_rigid_gradient, num_units, unit_locations): + """ + Get the formatted `displacement_vector` and `displacement_unit_factor` + used to shift the `unit_locations`. + + Parameters + --------- + shift : np.array | list | tuple + A 2-element array with the shift in the (x, y) direction. + non_rigid_gradient : float + Factor which sets the level of non-rigidty in the displacement. + See `calculate_displacement_unit_factor` for details. + num_units : int + Number of units + unit_locations : np.array + (num_units, 3) array of unit locations (x, y, z). + + Returns + ------- + displacement_vector : np.array + A (:, 2) array of (x, y) of displacements + to add to (i.e. move) unit_locations. + e.g. array([[1, 2]]) + displacement_unit_factor : np.array + A (num_units, :) array of scaling values to apply to the + displacement vector in order to add nonrigid shift to + the displacement. Note the same scaling is applied to the + x and y dimension. + """ + displacement_vector = np.atleast_2d(shift) + + if non_rigid_gradient is None or (shift[0] == 0 and shift[1] == 0): + displacement_unit_factor = np.ones((num_units, 1)) + else: + displacement_unit_factor = calculate_displacement_unit_factor( + non_rigid_gradient, + unit_locations[:, :2], + drift_start_um=np.array([0, 0], dtype=float), + drift_stop_um=np.array(shift, dtype=float), + ) + displacement_unit_factor = displacement_unit_factor[:, np.newaxis] + + return displacement_vector, displacement_unit_factor + + +def _amplitude_scale_templates_in_place( + first_rec_templates, moved_templates, recording_amplitude_scalings, fixed_generate_sorting_kwargs, rec_idx +): + """ + Scale a set of templates given a set of scaling values. The scaling + values can be applied in the order passed, or instead in order of + the unit firing range (max to min) or unit amplitude * firing rate (max to min). + This will chang the `templates_array` in place. This must be done after + the templates are moved. + + Parameters + ---------- + first_rec_templates : np.array + The (num_units, num_samples, num_channels) templates array from the + first recording. Scaling by amplitude scales based on the amplitudes in + the first session. + moved_templates : np.array + A (num_units, num_samples, num_channels) array moved templates to the + current recording, that will be scaled. + recording_amplitude_scalings : dict + see `generate_session_displacement_recordings()`. + fixed_generate_sorting_kwargs : dict + Dict holding the firing frequency of all units. + The unit order is assumed to match the templates. + rec_idx : int + The index of the recording for which the templates are being scaled. + + Notes + ----- + This method is used in the context of inter-session displacement. Often, + units may drop out of the recording across sessions. This simulates this by + directly scaling the template (e.g. if scaling by 0, the template is completely + dropped out). The provided scalings can be applied in the order passed, or + in the order of unit firing rate or firing rate * amplitude. The idea is + that it may be desired to remove to downscale the most activate neurons + that contribute most significantly to activity histograms. Similarly, + if amplitude is included in activity histograms the amplitude may + also want to be considered when ordering the units to downscale. + """ + method = recording_amplitude_scalings["method"] + + if method in ["by_amplitude_and_firing_rate", "by_firing_rate"]: + + firing_rates_hz = fixed_generate_sorting_kwargs["firing_rates"] + + if method == "by_amplitude_and_firing_rate": + neg_ampl = np.min(np.min(first_rec_templates, axis=2), axis=1) + assert np.all(neg_ampl < 0), "assumes all amplitudes are negative here." + score = firing_rates_hz * neg_ampl + else: + score = firing_rates_hz + + order_idx = np.argsort(score) + ordered_rec_scalings = recording_amplitude_scalings["scalings"][rec_idx][order_idx, np.newaxis, np.newaxis] + + elif method == "by_passed_order": + + ordered_rec_scalings = recording_amplitude_scalings["scalings"][rec_idx][:, np.newaxis, np.newaxis] + + else: + raise ValueError("`recording_amplitude_scalings['method']` not recognised.") + + moved_templates *= ordered_rec_scalings + + +def _check_generate_session_displacement_arguments( + num_units, + recording_durations, + recording_shifts, + recording_amplitude_scalings, + shift_units_outside_probe, + non_rigid_gradient, +): + """ + Function to check the input arguments related to recording + shift and scale parameters are the correct size. + """ + expected_num_recs = len(recording_durations) + + if len(recording_shifts) != expected_num_recs: + raise ValueError( + "`recording_shifts` and `recording_durations` must be " + "the same length, the number of recordings to generate." + ) + + if recording_amplitude_scalings and shift_units_outside_probe: + raise ValueError( + "At present, using `recording_amplitude_scalings` and " + "`shift_units_outside_probe` together is not supported." + ) + + shifts_are_2d = [len(shift) == 2 for shift in recording_shifts] + if not all(shifts_are_2d): + raise ValueError("Each record entry for `recording_shifts` must have two elements, the x and y shift.") + + if recording_amplitude_scalings is not None: + + keys = recording_amplitude_scalings.keys() + if not "method" in keys or not "scalings" in keys: + raise ValueError("`recording_amplitude_scalings` must be a dict with keys `method` and `scalings`.") + + allowed_methods = ["by_passed_order", "by_amplitude_and_firing_rate", "by_firing_rate"] + if not recording_amplitude_scalings["method"] in allowed_methods: + raise ValueError(f"`recording_amplitude_scalings` must be one of {allowed_methods}") + + rec_scalings = recording_amplitude_scalings["scalings"] + if not len(rec_scalings) == expected_num_recs: + raise ValueError("`recording_amplitude_scalings` 'scalings' must have one array per recording.") + + if not all([len(scale) == num_units for scale in rec_scalings]): + raise ValueError( + "The entry for each recording in `recording_amplitude_scalings` " + "must have the same length as the number of units." + ) + + if isinstance(non_rigid_gradient, list): + if not len(non_rigid_gradient) == expected_num_recs: + raise ValueError("If `non_rigid_gradient` is a list, it must " "contain one value for each recording.") + + +def _update_kwargs_for_extended_units( + num_units, + channel_locations, + unit_locations, + generate_unit_locations_kwargs, + generate_templates_kwargs, + generate_sorting_kwargs, + fixed_generate_templates_kwargs, + fixed_generate_sorting_kwargs, + seed, +): + """ + In a real world situation, if the probe moves up / down + not only will previously recorded units be shifted, but + new units will be introduced into the recording. + + This function extends the default num units, unit locations, + and template / sorting kwargs to extend the unit of units + one probe's height (y dimension) above and below the probe. + Now, when the unit locations are shifted, new units will be + introduced into the recording (from below or above). + + It is important that the unit kwargs for the units are kept the + same across runs when seeded (i.e. whether `shift_units_outside_probe` + is `True` or `False`). To achieve this, the fixed unit kwargs + are extended with new units located above and below these fixed + units. The seeds are shifted slightly, so the introduced + units do not duplicate the existing units. Note that this maintains + the density of neurons above / below the probe (it is not random). + """ + seed_top = seed + 1 if seed is not None else None + seed_bottom = seed - 1 if seed is not None else None + + # Set unit locations above and below the probe and extend + # the `unit_locations` array. + channel_locations_extend_top = channel_locations.copy() + channel_locations_extend_top[:, 1] -= np.max(channel_locations[:, 1]) + + extend_top_locations = generate_unit_locations( + num_units, + channel_locations_extend_top, + seed=seed_top, + **generate_unit_locations_kwargs, + ) + + channel_locations_extend_bottom = channel_locations.copy() + channel_locations_extend_bottom[:, 1] += np.max(channel_locations[:, 1]) + + extend_bottom_locations = generate_unit_locations( + num_units, + channel_locations_extend_bottom, + seed=seed_bottom, + **generate_unit_locations_kwargs, + ) + + unit_locations = np.r_[extend_bottom_locations, unit_locations, extend_top_locations] + + # For the new units located above and below the probe, generate a set of + # firing rates and template kwargs. + + # Extend the template kwargs + template_kwargs_top = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed_top) + template_kwargs_bottom = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed_bottom) + + for key in fixed_generate_templates_kwargs["unit_params"].keys(): + fixed_generate_templates_kwargs["unit_params"][key] = np.r_[ + template_kwargs_top["unit_params"][key], + fixed_generate_templates_kwargs["unit_params"][key], + template_kwargs_bottom["unit_params"][key], + ] + + # Extend the firing rates + firing_rates_top = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed_top) + firing_rates_bottom = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed_bottom) + + fixed_generate_sorting_kwargs["firing_rates"] = np.r_[ + firing_rates_top, fixed_generate_sorting_kwargs["firing_rates"], firing_rates_bottom + ] + + # Update the number of units (3x as a + # new set above and below the existing units) + num_units *= 3 + + return num_units, unit_locations, fixed_generate_templates_kwargs, fixed_generate_sorting_kwargs diff --git a/src/spikeinterface/generation/tests/test_session_displacement_generator.py b/src/spikeinterface/generation/tests/test_session_displacement_generator.py new file mode 100644 index 0000000000..9736a0d11e --- /dev/null +++ b/src/spikeinterface/generation/tests/test_session_displacement_generator.py @@ -0,0 +1,518 @@ +import pytest + +from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings +from spikeinterface.generation.drifting_generator import generate_drifting_recording +from spikeinterface.core.generate import _ensure_firing_rates +from spikeinterface.core import order_channels_by_depth +import numpy as np +from spikeinterface.sortingcomponents.peak_detection import detect_peaks +from spikeinterface.sortingcomponents.peak_localization import localize_peaks + + +class TestSessionDisplacementGenerator: + """ + This class tests the `generate_session_displacement_recordings` that + returns a recordings / sorting in which the units are shifted + across sessions. This is achieved by shifting the unit locations + in both (x, y) on the generated templates that are used in + `InjectTemplatesRecording()`. + """ + + @pytest.fixture(scope="function") + def options(self): + """ + Set a set of base options that can be used in + `generate_session_displacement_recordings() ("kwargs") + and provide general information on the generated recordings. + These can be edited in the tests as required. + """ + options = { + "kwargs": { + "recording_durations": (10, 10, 25, 33), + "recording_shifts": ((0, 0), (2, -100), (-3, 275), (4, 1e6)), + "num_units": 5, + "extra_outputs": True, + "seed": 42, + }, + "num_recs": 4, + "y_bin_um": 10, + } + options["kwargs"]["generate_probe_kwargs"] = dict( + num_columns=1, + num_contact_per_column=128, + xpitch=16, + ypitch=options["y_bin_um"], + contact_shapes="square", + contact_shape_params={"width": 12}, + ) + + return options + + ### Tests + def test_x_y_rigid_shifts_are_properly_set(self, options): + """ + The session displacement works by generating a set of + templates shared across all recordings, but set with + different `unit_locations()`. Check here that the + (x, y) displacements passed in `recording_shifts` are properly + propagated. + + First, check the set `unit_locations` are moved as expected according + to the (x, y) shifts). Next, check the templates themselves are + moved as expected. The x-axis shift has the effect of changing + the template amplitude, and is not possible to test. However, + the y-axis shift shifts the maximum signal channel, so we check + the maximum signal channel o fthe templates is shifted as expected. + This implicitly tests the x-axis case as if the x-axis `unit_locations` + are shifted as expected, and the unit-locations are propagated + to the template, then (x, y) will both be working. + """ + output_recordings, _, extra_outputs = generate_session_displacement_recordings(**options["kwargs"]) + num_units = options["kwargs"]["num_units"] + recording_shifts = options["kwargs"]["recording_shifts"] + + # test unit locations are shifted as expected according + # to the record shifts + locations_1 = extra_outputs["unit_locations"][0] + + for rec_idx in range(1, 4): + + shifts = recording_shifts[rec_idx] + + assert np.array_equal( + locations_1 + np.r_[shifts, 0].astype(np.float32), extra_outputs["unit_locations"][rec_idx] + ) + + # Check that the generated templates are correctly shifted + # For each generated unit, check that the max loading channel is + # shifted as expected. In the case that the unit location is off the + # probe, check the maximum signal channel is the min / max channel on + # the probe, or zero (the unit is too far to reach the probe). + min_channel_loc = output_recordings[0].get_channel_locations()[0, 1] + max_channel_loc = output_recordings[0].get_channel_locations()[-1, 1] + for unit_idx in range(num_units): + + start_pos = self._get_peak_chan_loc_in_um( + extra_outputs["templates_array_moved"][0][unit_idx], + options["y_bin_um"], + ) + + for rec_idx in range(1, options["num_recs"]): + + new_pos = self._get_peak_chan_loc_in_um( + extra_outputs["templates_array_moved"][rec_idx][unit_idx], options["y_bin_um"] + ) + + y_shift = recording_shifts[rec_idx][1] + if start_pos + y_shift > max_channel_loc: + assert new_pos == max_channel_loc or new_pos == 0 + elif start_pos + y_shift < min_channel_loc: + assert new_pos == min_channel_loc or new_pos == 0 + else: + assert np.isclose(new_pos, start_pos + y_shift, options["y_bin_um"]) + + # Confidence check the correct templates are + # loaded to the recording object. + for rec_idx in range(options["num_recs"]): + assert np.array_equal( + output_recordings[rec_idx].templates, + extra_outputs["templates_array_moved"][rec_idx], + ) + + def _get_peak_chan_loc_in_um(self, template_array, y_bin_um): + """ + Convenience function to get the maximally loading + channel y-position in um for the template. + """ + return np.argmax(np.max(template_array, axis=0)) * y_bin_um + + def test_recordings_length(self, options): + """ + Test that the `recording_durations` that sets the + length of each recording changes the recording + length as expected. + """ + output_recordings = generate_session_displacement_recordings(**options["kwargs"])[0] + + for rec, expected_rec_length in zip(output_recordings, options["kwargs"]["recording_durations"]): + assert rec.get_total_duration() == expected_rec_length + + def test_spike_times_and_firing_rates_across_recordings(self, options): + """ + Check the randomisation of spike times across recordings. + When a seed is set, this is passed to `generate_sorting` + and so the spike times across all records are expected + to be identical. However, if no seed is set, then the spike + times will be different across recordings. + """ + options["kwargs"]["recording_durations"] = (10,) * options["num_recs"] + + output_sortings_same, extra_outputs_same = generate_session_displacement_recordings(**options["kwargs"])[1:3] + + options["kwargs"]["seed"] = None + output_sortings_different, extra_outputs_different = generate_session_displacement_recordings( + **options["kwargs"] + )[1:3] + + for unit_idx in range(options["kwargs"]["num_units"]): + for rec_idx in range(1, options["num_recs"]): + + # Exact spike times are not preserved when seed is None + assert np.array_equal( + output_sortings_same[0].get_unit_spike_train(str(unit_idx)), + output_sortings_same[rec_idx].get_unit_spike_train(str(unit_idx)), + ) + assert not np.array_equal( + output_sortings_different[0].get_unit_spike_train(str(unit_idx)), + output_sortings_different[rec_idx].get_unit_spike_train(str(unit_idx)), + ) + # Firing rates should always be preserved. + assert np.array_equal( + extra_outputs_same["firing_rates"][0][unit_idx], + extra_outputs_same["firing_rates"][rec_idx][unit_idx], + ) + assert np.array_equal( + extra_outputs_different["firing_rates"][0][unit_idx], + extra_outputs_different["firing_rates"][rec_idx][unit_idx], + ) + + def test_ensure_unit_params_assumption(self): + """ + Test the assumption that `_ensure_unit_params` does not + change an array of firing rates, otherwise `generate_sorting` + will internally change our firing rates. + """ + array = np.random.randn(5) + assert np.array_equal(_ensure_firing_rates(array, 5, None), array) + + @pytest.mark.parametrize("dim_idx", [0, 1]) + def test_x_y_shift_non_rigid(self, options, dim_idx): + """ + Check that the non-rigid shift changes the channel location + as expected. Non-rigid shifts are calculated depending on the + position of the channel. The `non_rigid_gradient` parameter + determines how much the position or 'distance' of the channel + (w.r.t the gradient of movement) affects the scaling. When + 0, the displacement is scaled by the distance. When 0, the + distance is ignored and all scalings are 1. + + This test checks the generated `unit_locations` under extreme + cases, when `non_rigid_gradient` is `None` or 0, which are equivalent, + and when it is `1`, and the displacement is directly propotional to + the unit position. + """ + options["kwargs"]["recording_shifts"] = ((0, 0), (10, 15), (15, 20), (20, 25)) + + _, _, rigid_info = generate_session_displacement_recordings( + **options["kwargs"], + non_rigid_gradient=None, + ) + _, _, nonrigid_max_info = generate_session_displacement_recordings( + **options["kwargs"], + non_rigid_gradient=0, + ) + _, _, nonrigid_none_info = generate_session_displacement_recordings( + **options["kwargs"], + non_rigid_gradient=1, + ) + + initial_locations = rigid_info["unit_locations"][0] + + # For each recording (i.e. each recording as different displacement + # w.r.t the first recording), check the rigid and nonrigid shifts + # are as expected. + for rec_idx in range(1, options["num_recs"]): + + shift = options["kwargs"]["recording_shifts"][rec_idx][dim_idx] + + # Get the rigid shift between the first recording and this shifted recording + # Check shifts for all unit locations are all the same. + shifts_rigid = self._get_shifts(rigid_info, rec_idx, dim_idx, initial_locations) + shifts_rigid = np.round(shifts_rigid, 5) + + assert np.unique(shifts_rigid).size == 1 + + # Get the nonrigid shift between the first recording and this recording. + # The shift for each unit should be directly proportional to its position. + y_shifts_nonrigid = self._get_shifts(nonrigid_max_info, rec_idx, dim_idx, initial_locations) + + distance = np.linalg.norm(initial_locations, axis=1) + norm_distance = (distance - np.min(distance)) / (np.max(distance) - np.min(distance)) + + assert np.unique(y_shifts_nonrigid).size == options["kwargs"]["num_units"] + + # There is some small rounding error due to difference in distance computation, + # the main thing is the relative order not the absolute value. + assert np.allclose(y_shifts_nonrigid, shift * norm_distance, rtol=0, atol=0.5) + + # then do again with non-ridig-gradient 1 and check it matches rigid case + shifts_rigid_2 = self._get_shifts(nonrigid_none_info, rec_idx, dim_idx, initial_locations) + assert np.array_equal(shifts_rigid, np.round(shifts_rigid_2, 5)) + + def test_non_rigid_shifts_list(self, options): + """ + Quick check that non-rigid gradients are indeed different across + recordings when a list of different gradients is passed. + """ + options["kwargs"]["recording_shifts"] = ((0, 0), (0, 10), (0, 10), (0, 10)) + options["kwargs"]["seed"] = 42 + + _, _, same_info = generate_session_displacement_recordings( + **options["kwargs"], + non_rigid_gradient=0.50, + ) + _, _, different_info = generate_session_displacement_recordings( + **options["kwargs"], + non_rigid_gradient=[0.25, 0.50, 0.75, 1.0], + ) + + # Just check the first two recordings + assert np.array_equal(same_info["unit_locations"][1], same_info["unit_locations"][2]) + assert not np.array_equal(different_info["unit_locations"][1], different_info["unit_locations"][2]) + + def _get_shifts(self, extras_dict, rec_idx, dim_idx, initial_locations): + return extras_dict["unit_locations"][rec_idx][:, dim_idx] - initial_locations[:, dim_idx] + + def test_displacement_with_peak_detection(self, options): + """ + This test checks that the session displacement occurs + as expected under normal usage. Create a recording with a + single unit and a y-axis displacement. Find the peak + locations and check the shifted peak location is as expected, + within the tolerate of the y-axis pitch + some small error. + """ + # The seed is important here, otherwise the unit positions + # might go off the end of the probe. These kwargs are + # chosen to make the recording as small as possible as this + # test is slow for larger recordings. + y_shift = 50 + options["kwargs"]["recording_shifts"] = ((0, 0), (0, y_shift)) + options["kwargs"]["recording_durations"] = (0.5, 0.5) + options["num_recs"] = 2 + options["kwargs"]["num_units"] = 1 + options["kwargs"]["generate_probe_kwargs"]["num_contact_per_column"] = 18 + + output_recordings, _, _ = generate_session_displacement_recordings( + **options["kwargs"], generate_noise_kwargs=dict(noise_levels=(1.0, 2.0), spatial_decay=1.0) + ) + + first_recording = output_recordings[0] + + # Peak location of unshifted recording + peaks = detect_peaks(first_recording, method="by_channel") + peak_locs = localize_peaks(first_recording, peaks, method="center_of_mass") + first_pos = np.mean(peak_locs["y"]) + + # Find peak location on shifted recording and check it is + # the original location + shift. + shifted_recording = output_recordings[1] + peaks = detect_peaks(shifted_recording, method="by_channel") + peak_locs = localize_peaks(shifted_recording, peaks, method="center_of_mass") + + new_pos = np.mean(peak_locs["y"]) + + # Completely arbitrary 0.5 offset to pass tests on macOS which fail around ~0.2 + # over the bin, probably due to small amount of noise. + assert np.isclose(new_pos, first_pos + y_shift, rtol=0, atol=options["y_bin_um"] + 0.5) + + def test_amplitude_scalings(self, options): + """ + Test that the templates are scaled by the passed scaling factors + in the specified order. The order can be in the passed order, + in the order of highest-to-lowest firing unit, or in the order + of (amplitude * firing_rate) (highest to lowest unit). + """ + # Setup arguments to create an unshifted set of recordings + # where the templates are to be scaled with `true_scalings` + options["kwargs"]["recording_durations"] = (10, 10) + options["kwargs"]["recording_shifts"] = ((0, 0), (0, 0)) + options["kwargs"]["num_units"] == 5, + + true_scalings = np.array([0.1, 0.2, 0.3, 0.4, 0.5]) + + recording_amplitude_scalings = { + "method": "by_passed_order", + "scalings": (np.ones(5), true_scalings), + } + + _, output_sortings, extra_outputs = generate_session_displacement_recordings( + **options["kwargs"], + recording_amplitude_scalings=recording_amplitude_scalings, + ) + + # Check that the unit templates are scaled in the order + # the scalings were passed. + test_scalings = self._calculate_scalings_from_output(extra_outputs) + assert np.allclose(test_scalings, true_scalings) + + # Now run, again applying the scalings in the order of + # unit firing rates (highest to lowest). + firing_rates = np.array([5, 4, 3, 2, 1]) + generate_sorting_kwargs = dict(firing_rates=firing_rates, refractory_period_ms=4.0) + recording_amplitude_scalings["method"] = "by_firing_rate" + _, output_sortings, extra_outputs = generate_session_displacement_recordings( + **options["kwargs"], + recording_amplitude_scalings=recording_amplitude_scalings, + generate_sorting_kwargs=generate_sorting_kwargs, + ) + + test_scalings = self._calculate_scalings_from_output(extra_outputs) + assert np.allclose(test_scalings, true_scalings[np.argsort(firing_rates)]) + + # Finally, run again applying the scalings in the order of + # unit amplitude * firing_rate + recording_amplitude_scalings["method"] = "by_amplitude_and_firing_rate" # TODO: method -> order + amplitudes = np.min(np.min(extra_outputs["templates_array_moved"][0], axis=2), axis=1) + firing_rate_by_amplitude = np.argsort(amplitudes * firing_rates) + + _, output_sortings, extra_outputs = generate_session_displacement_recordings( + **options["kwargs"], + recording_amplitude_scalings=recording_amplitude_scalings, + generate_sorting_kwargs=generate_sorting_kwargs, + ) + + test_scalings = self._calculate_scalings_from_output(extra_outputs) + assert np.allclose(test_scalings, true_scalings[firing_rate_by_amplitude]) + + def _calculate_scalings_from_output(self, extra_outputs): + first, second = extra_outputs["templates_array_moved"] + first_min = np.min(np.min(first, axis=2), axis=1) + second_min = np.min(np.min(second, axis=2), axis=1) + test_scalings = second_min / first_min + return test_scalings + + def test_metadata(self, options): + """ + Check that metadata required to be set of generated recordings is present + on all output recordings. + """ + output_recordings, output_sortings, extra_outputs = generate_session_displacement_recordings( + **options["kwargs"], generate_noise_kwargs=dict(noise_levels=(1.0, 2.0), spatial_decay=1.0) + ) + num_chans = output_recordings[0].get_num_channels() + + for i in range(len(output_recordings)): + assert output_recordings[i].name == "InterSessionDisplacementRecording" + assert output_recordings[i]._annotations["is_filtered"] is True + assert output_recordings[i].has_probe() + assert np.array_equal(output_recordings[i].get_channel_gains(), np.ones(num_chans)) + assert np.array_equal(output_recordings[i].get_channel_offsets(), np.zeros(num_chans)) + + assert np.array_equal( + output_sortings[i].get_property("gt_unit_locations"), extra_outputs["unit_locations"][i] + ) + assert output_sortings[i].name == "InterSessionDisplacementSorting" + + def test_shift_units_outside_probe(self, options): + """ + When `shift_units_outside_probe` is `True`, a new set of + units above and below the probe (y dimension) are created, + such that they may be shifted into the recording. + + Here, check that these new units are created when `shift_units_outside_probe` + is on and that the kwargs for the central set of units match those + as when `shift_units_outside_probe` is `False`. + """ + num_sessions = len(options["kwargs"]["recording_durations"]) + _, _, baseline_outputs = generate_session_displacement_recordings( + **options["kwargs"], + ) + + _, _, outside_probe_outputs = generate_session_displacement_recordings( + **options["kwargs"], shift_units_outside_probe=True + ) + + num_units = options["kwargs"]["num_units"] + num_extended_units = num_units * 3 + + for ses_idx in range(num_sessions): + + # There are 3x the number of units when new units are created + # (one new set above, and one new set below the probe). + for key in ["unit_locations", "templates_array_moved", "firing_rates"]: + assert outside_probe_outputs[key][ses_idx].shape[0] == num_extended_units + + assert np.array_equal( + baseline_outputs[key][ses_idx], outside_probe_outputs[key][ses_idx][num_units:-num_units] + ) + + # The kwargs of the units in the central positions should be identical + # to those when `shift_units_outside_probe` is `False`. + lower_unit_pos = outside_probe_outputs["unit_locations"][ses_idx][-num_units:][:, 1] + upper_unit_pos = outside_probe_outputs["unit_locations"][ses_idx][:num_units][:, 1] + middle_unit_pos = baseline_outputs["unit_locations"][ses_idx][:, 1] + + assert np.min(upper_unit_pos) > np.max(middle_unit_pos) + assert np.max(lower_unit_pos) < np.min(middle_unit_pos) + + def test_same_as_generate_ground_truth_recording(self): + """ + It is expected that inter-session displacement randomly + generated recording and injected motion recording will + use exactly the same method to generate the ground-truth + recording (without displacement or motion). To check this, + set their kwargs equal and seed, then generate a non-displaced + recording. It should be identical to the static recroding + generated by `generate_drifting_recording()`. + """ + + # Set some shared kwargs + num_units = 5 + duration = 10 + sampling_frequency = 30000.0 + probe_name = "Neuropixel-128" + generate_probe_kwargs = None + generate_unit_locations_kwargs = dict() + generate_templates_kwargs = dict(ms_before=1.5, ms_after=3) + generate_sorting_kwargs = dict(firing_rates=1) + generate_noise_kwargs = dict() + seed = 42 + + # Generate a inter-session displacement recording with no displacement + no_shift_recording, _ = generate_session_displacement_recordings( + num_units=num_units, + recording_durations=[duration], + recording_shifts=((0, 0),), + sampling_frequency=sampling_frequency, + probe_name=probe_name, + generate_probe_kwargs=generate_probe_kwargs, + generate_unit_locations_kwargs=generate_unit_locations_kwargs, + generate_templates_kwargs=generate_templates_kwargs, + generate_sorting_kwargs=generate_sorting_kwargs, + generate_noise_kwargs=generate_noise_kwargs, + seed=seed, + ) + no_shift_recording = no_shift_recording[0] + + # Generate a drifting recording with no drift + static_recording, _, _ = generate_drifting_recording( + num_units=num_units, + duration=duration, + sampling_frequency=sampling_frequency, + probe_name=probe_name, + generate_probe_kwargs=generate_probe_kwargs, + generate_unit_locations_kwargs=generate_unit_locations_kwargs, + generate_templates_kwargs=generate_templates_kwargs, + generate_sorting_kwargs=generate_sorting_kwargs, + generate_noise_kwargs=generate_noise_kwargs, + generate_displacement_vector_kwargs=dict( + motion_list=[ + dict( + drift_mode="zigzag", + non_rigid_gradient=None, + t_start_drift=1.0, + t_end_drift=None, + period_s=200, + ), + ] + ), + seed=seed, + ) + + # Check the templates and raw data match exactly. + assert np.array_equal( + no_shift_recording.get_traces(start_frame=0, end_frame=10), + static_recording.get_traces(start_frame=0, end_frame=10), + ) + + assert np.array_equal(no_shift_recording.templates, static_recording.drifting_templates.templates_array) diff --git a/src/spikeinterface/preprocessing/__init__.py b/src/spikeinterface/preprocessing/__init__.py index 3343217090..9e66044bb9 100644 --- a/src/spikeinterface/preprocessing/__init__.py +++ b/src/spikeinterface/preprocessing/__init__.py @@ -2,6 +2,20 @@ from .motion import correct_motion, load_motion_info, save_motion_info, get_motion_parameters_preset, get_motion_presets +""" For this to work, I think `get_spatial_interpolation_kernel` could go to core + (or elsewhere) to avoid circular imports. Currently sortingcomponents.motion + requires it but inter-session-alignment requires sortingcomponents.motion. + +from .inter_session_alignment.session_alignment import ( + get_estimate_histogram_kwargs, + get_compute_alignment_kwargs, + get_non_rigid_window_kwargs, + get_interpolate_motion_kwargs, + align_sessions, + align_sessions_after_motion_correction, + compute_peaks_locations_for_session_alignment, +) +""" from .preprocessing_tools import get_spatial_interpolation_kernel from .detect_bad_channels import detect_bad_channels from .correct_lsb import correct_lsb diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/__init__.py b/src/spikeinterface/preprocessing/inter_session_alignment/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py b/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py new file mode 100644 index 0000000000..05e6700708 --- /dev/null +++ b/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py @@ -0,0 +1,484 @@ +import time + +from spikeinterface import BaseRecording +import numpy as np + +from spikeinterface.sortingcomponents.motion.motion_utils import make_2d_motion_histogram +from spikeinterface.sortingcomponents.motion.iterative_template import kriging_kernel +from packaging.version import Version + + +# ############################################################################# +# Get Histograms +# ############################################################################# + + +def get_2d_activity_histogram( + recording: BaseRecording, + peaks: np.ndarray, + peak_locations: np.ndarray, + spatial_bin_edges: np.ndarray, + bin_s: float | None, + depth_smooth_um: float | None, + scale_to_hz: bool = False, + weight_with_amplitude: bool = False, + avg_in_bin: bool = True, +): + """ + Generate a 2D activity histogram for the session. Wraps the underlying + spikeinterface function with some adjustments for scaling to time and + log transform. + + Parameters + ---------- + + recording: BaseRecording, + A SpikeInterface recording object. + peaks: np.ndarray, + A SpikeInterface `peaks` array. + peak_locations: np.ndarray, + A SpikeInterface `peak_locations` array. + spatial_bin_edges: np.ndarray, + A (1 x n_bins + 1) array of spatial (probe y dimension) bin edges. + bin_s | None: float, + If `None`, a single histogram will be generated from all session + peaks. Otherwise, multiple histograms will be generated, one for + each time bin. + depth_smooth_um: float | None + If not `None`, smooth the histogram across the spatial + axis. see `make_2d_motion_histogram()` for details. + + TODO + ---- + - ask Sam whether it makes sense to integrate this function with `make_2d_motion_histogram`. + """ + activity_histogram, temporal_bin_edges, generated_spatial_bin_edges = make_2d_motion_histogram( + recording, + peaks, + peak_locations, + weight_with_amplitude=weight_with_amplitude, + direction="y", + bin_s=(bin_s if bin_s is not None else recording.get_duration(segment_index=0)), + bin_um=None, + hist_margin_um=None, + spatial_bin_edges=spatial_bin_edges, + depth_smooth_um=depth_smooth_um, + avg_in_bin=avg_in_bin, + ) + + if scale_to_hz: + if bin_s is None: + scaler = 1 / recording.get_duration() + else: + scaler = 1 / np.diff(temporal_bin_edges)[:, np.newaxis] + + activity_histogram *= scaler + + temporal_bin_centers = get_bin_centers(temporal_bin_edges) + spatial_bin_centers = get_bin_centers(spatial_bin_edges) + + return activity_histogram, temporal_bin_centers, spatial_bin_centers + + +def get_bin_centers(bin_edges): + return (bin_edges[1:] + bin_edges[:-1]) / 2 + + +def estimate_chunk_size(scaled_activity_histogram): + """ + Estimate a chunk size based on the firing rate. Intuitively, we + want longer chunk size to better estimate low firing rates. The + estimation computes a summary of the firing rates for the session + by taking the value 25% of the max of the activity histogram. + + Then, the chunk size that will accurately estimate this firing rate + within 90% accuracy, 90% of the time based on assumption of Poisson + firing (based on CLT) is computed. + + Parameters + ---------- + + scaled_activity_histogram: np.ndarray + The activity histogram scaled to firing rate in Hz. + """ + print("scaled max", np.max(scaled_activity_histogram)) + + firing_rate = np.max(scaled_activity_histogram) * 0.25 + + lambda_hat_s = firing_rate + range_percent = 0.1 + confidence_z = 1.645 # 90% of samples in the normal distribution + e = lambda_hat_s * range_percent + + t = lambda_hat_s / (e / confidence_z) ** 2 + + print( + f"Chunked histogram window size of: {t}s estimated " + f"for firing rate (25% of histogram peak) of {lambda_hat_s}" + ) + + return 10 + + +# ############################################################################# +# Chunked Histogram estimation methods +# ############################################################################# +# Given a set off chunked_session_histograms (num time chunks x num spatial bins) +# take the summary statistic over the time axis. + + +def get_chunked_hist_mean(chunked_session_histograms): + """ """ + mean_hist = np.mean(chunked_session_histograms, axis=0) + + return mean_hist + + +def get_chunked_hist_median(chunked_session_histograms): + """ """ + median_hist = np.median(chunked_session_histograms, axis=0) + + return median_hist + + +# ############################################################################# +# TODO: MOVE creating recordings +# ############################################################################# + + +# TODO: a good test here is to give zero shift for even and off numbered hist and check the output is zero! +def compute_histogram_crosscorrelation( + session_histogram_list: np.ndarray, + non_rigid_windows: np.ndarray, + num_shifts: int, + interpolate: bool, + interp_factor: int, + kriging_sigma: float, + kriging_p: float, + kriging_d: float, + smoothing_sigma_bin: None | float, + smoothing_sigma_window: None | float, + min_crosscorr_threshold: float, +) -> tuple[np.ndarray, np.ndarray]: + """ + Given a list of session activity histograms, cross-correlate + all histograms returning the peak correlation shift (in indices) + in a symmetric (num_session x num_session) matrix. + + Supports non-rigid estimation by windowing the activity histogram + and performing separate cross-correlations on each window separately. + + Parameters + ---------- + + session_histogram_list : np.ndarray + (num_sessions, num_bins) (1d histogram) or (num_sessions, num_bins, 2) (2d histogram) + array of session activity histograms. + non_rigid_windows : np.ndarray + A (num windows x num_bins) binary of weights by which to window + the activity histogram for non-rigid-registration. For example, if + 2 rectangular masks were used, there would be a two row binary mask + the first row with mask of the first half of the probe and the second + row a mask for the second half of the probe. + num_shifts : int + Number of indices by which to shift the histogram to find the maximum + of the cross correlation. If `None`, the entire activity histograms + are cross-correlated. + interpolate : bool + If `True`, the cross-correlation is interpolated before maximum is taken. + interp_factor: + Factor by which to interpolate the cross-correlation. + kriging_sigma : float + sigma parameter for kriging_kernel function. See `kriging_kernel`. + kriging_p : float + p parameter for kriging_kernel function. See `kriging_kernel`. + kriging_d : float + d parameter for kriging_kernel function. See `kriging_kernel`. + smoothing_sigma_bin : float + sigma parameter for the gaussian smoothing kernel over the + spatial bins. + smoothing_sigma_window : float + sigma parameter for the gaussian smoothing kernel over the + non-rigid windows. + + Returns + ------- + + shift_matrix : ndarray + A (num_session x num_session) symmetric matrix of shifts + (indices) between pairs of session activity histograms. + + Notes + ----- + + - This function is very similar to the IterativeTemplateRegistration + function used in motion correct, though slightly difference in scope. + It was not convenient to merge them at this time, but worth looking + into in future. + + - Some obvious performances boosts, not done so because already fast + 1) the cross correlations for each session comparison are performed + twice. They are slightly different due to interpolation, but + still probably better to calculate once and flip. + 2) `num_shifts` is implemented by simply making the full + cross correlation. Would probably be nicer to explicitly calculate + only where needed. However, in general these cross correlations are + only a few thousand datapoints and so are already extremely + fast to cross correlate. + + Notes + ----- + + - The original kilosort method does not work in the inter-session + context because it averages over time bins to form a template to + align too. In this case, averaging over a small number of possibly + quite different session histograms does not work well. + + - In the nonrigid case, this strategy can completely fail when the xcorr + is very bad for a certain window. The smoothing and interpolation + make it much worse, because bad xcorr are merged together. The x-corr + can be bad when the recording is shifted a lot and so there are empty + regions that are correlated with non-empty regions in the nonrigid + approach. A different approach will need to be taken in this case. + + Note that kilosort method does not work because creating a + mean does not make sense over sessions. + """ + # scipy is not a core dependency + from scipy.ndimage import gaussian_filter + + num_sessions = session_histogram_list.shape[0] + num_bins = session_histogram_list.shape[1] # all hists are same length + num_windows = non_rigid_windows.shape[0] + + shift_matrix = np.zeros((num_sessions, num_sessions, num_windows)) + + center_bin = np.floor((num_bins * 2 - 1) / 2).astype(int) + + # Create the (num windows, num_bins) matrix for this pair of sessions + if num_shifts is None: + num_shifts = num_bins - 1 + shifts_array = np.arange(-(num_shifts), num_shifts + 1) + num_iter = shifts_array.size + + for i in range(num_sessions): + for j in range(i, num_sessions): + + xcorr_matrix = np.zeros((non_rigid_windows.shape[0], num_iter)) + + # For each window, window the session histograms (`window` is binary) + # and perform the cross correlations + for win_idx, window in enumerate(non_rigid_windows): + + if session_histogram_list.ndim == 3: + # For 2D histogram (spatial, amplitude), manually loop through shifts along + # the spatial axis of the histogram. This is faster than using correlate2d + # because we are not shifting along the amplitude axis. + + windowed_histogram_i = session_histogram_list[i, :] * window[:, np.newaxis] + windowed_histogram_j = session_histogram_list[j, :] * window[:, np.newaxis] + + windowed_histogram_i = (windowed_histogram_i - np.mean(windowed_histogram_i)) / ( + np.std(windowed_histogram_i) + 1e-8 + ) + windowed_histogram_j = (windowed_histogram_j - np.mean(windowed_histogram_j)) / ( + np.std(windowed_histogram_j) + 1e-8 + ) + + xcorr = np.zeros(num_iter) + + for idx, shift in enumerate(shifts_array): + + shifted_i = shift_array_fill_zeros(windowed_histogram_i, shift) + flatten_i = shifted_i.flatten() + + xcorr[idx] = np.correlate(flatten_i, windowed_histogram_j.flatten()) / flatten_i.size + + else: + # For a 1D histogram, compute the full cross-correlation and + # window the desired shifts ( this is faster than manual looping). + windowed_histogram_i = session_histogram_list[i, :] * window + windowed_histogram_j = session_histogram_list[j, :] * window + + windowed_histogram_i = (windowed_histogram_i - np.mean(windowed_histogram_i)) / ( + np.std(windowed_histogram_i) + 1e-8 + ) + windowed_histogram_j = (windowed_histogram_j - np.mean(windowed_histogram_j)) / ( + np.std(windowed_histogram_j) + 1e-8 + ) + + xcorr = np.correlate( + windowed_histogram_i, + windowed_histogram_j, + mode="full", + ) / (windowed_histogram_i.size) + + if num_shifts: + window_indices = np.arange(center_bin - num_shifts, center_bin + num_shifts + 1) + xcorr = xcorr[window_indices] + + xcorr_matrix[win_idx, :] = xcorr + + # Smooth the cross-correlations across the bins + if smoothing_sigma_bin: + xcorr_matrix = gaussian_filter(xcorr_matrix, sigma=smoothing_sigma_bin, axes=1) + + # Smooth the cross-correlations across the windows + if num_windows > 1 and smoothing_sigma_window: + xcorr_matrix = gaussian_filter(xcorr_matrix, sigma=smoothing_sigma_window, axes=0) + + # Upsample the cross-correlation + if interpolate: + + shifts_upsampled = np.linspace(shifts_array[0], shifts_array[-1], shifts_array.size * interp_factor) + + K = kriging_kernel( + np.c_[np.ones_like(shifts_array), shifts_array], + np.c_[np.ones_like(shifts_upsampled), shifts_upsampled], + sigma=kriging_sigma, + p=kriging_p, + d=kriging_d, + ) + + xcorr_matrix = np.matmul(xcorr_matrix, K, axes=[(-2, -1), (-2, -1), (-2, -1)]) + xcorr_peak = np.argmax(xcorr_matrix, axis=1) + xcorr_value = np.max(xcorr_matrix, axis=1) + shifts_to_idx = shifts_upsampled + else: + xcorr_peak = np.argmax(xcorr_matrix, axis=1) + xcorr_value = np.max(xcorr_matrix, axis=1) + shifts_to_idx = shifts_array + + shift = shifts_to_idx[xcorr_peak] + + shift[np.where(xcorr_value < min_crosscorr_threshold)] = 0 + + shift_matrix[i, j, :] = shift + + # As xcorr shifts are symmetric, the shift matrix is skew symmetric, so fill + # the (empty) lower triangular with the negative (already computed) upper triangular to save computation + for k in range(shift_matrix.shape[2]): + lower_i, lower_j = np.tril_indices_from(shift_matrix[:, :, k], k=-1) + upper_i, upper_j = np.triu_indices_from(shift_matrix[:, :, k], k=1) + shift_matrix[lower_i, lower_j, k] = shift_matrix[upper_i, upper_j, k] * -1 + + return shift_matrix, xcorr_matrix + + +def shift_array_fill_zeros(array: np.ndarray, shift: int) -> np.ndarray: + """ + Shift an array by `shift` indices, padding with zero. + Samples going out of bounds are dropped i,e, the array is not + extended and samples are not wrapped around to the start of the array. + + Parameters + ---------- + + array : np.ndarray + The array to pad. + shift : int + Number of indices why which to shift the array. If positive, the + zeros are added from the end of the array. If negative, the zeros + are added from the start of the array. + + Returns + ------- + + cut_padded_array : np.ndarray + The `array` padded with zeros and cut down (i.e. out of bounds + samples dropped). + + """ + abs_shift = np.abs(shift) + pad_tuple = (0, abs_shift) if shift > 0 else (abs_shift, 0) + + if array.ndim == 2: + pad_tuple = (pad_tuple, (0, 0)) + + padded_hist = np.pad(array, pad_tuple, mode="constant") + + if padded_hist.ndim == 2: + cut_padded_array = padded_hist[abs_shift:, :] if shift >= 0 else padded_hist[:-abs_shift, :] + else: + cut_padded_array = padded_hist[abs_shift:] if shift >= 0 else padded_hist[:-abs_shift] + + return cut_padded_array + + +def akima_interpolate_nonrigid_shifts( + non_rigid_shifts: np.ndarray, + non_rigid_window_centers: np.ndarray, + spatial_bin_centers: np.ndarray, +): + """ + Perform Akima spline interpolation on a set of non-rigid shifts. + The non-rigid shifts are per segment of the probe, each segment + containing a number of channels. Interpolating these non-rigid + shifts to the spatial bin centers gives a more accurate shift + per channel. + + Parameters + ---------- + non_rigid_shifts : np.ndarray + non_rigid_window_centers : np.ndarray + spatial_bin_centers : np.ndarray + + Returns + ------- + interp_nonrigid_shifts : np.ndarray + An array (length num_spatial_bins) of shifts + interpolated from the non-rigid shifts. + + """ + import scipy + + if Version(scipy.__version__) < Version("1.14.0"): + raise ImportError("Scipy version 14 or higher is required fro Akima interpolation.") + + from scipy.interpolate import Akima1DInterpolator + + x = non_rigid_window_centers + xs = spatial_bin_centers + + num_sessions = non_rigid_shifts.shape[0] + num_bins = spatial_bin_centers.shape[0] + + interp_nonrigid_shifts = np.zeros((num_sessions, num_bins)) + for ses_idx in range(num_sessions): + + y = non_rigid_shifts[ses_idx] + y_new = Akima1DInterpolator(x, y, method="akima", extrapolate=True)(xs) + interp_nonrigid_shifts[ses_idx, :] = y_new + + return interp_nonrigid_shifts + + +def get_shifts_from_session_matrix(alignment_order: str, session_offsets_matrix: np.ndarray): + """ + Given a matrix of displacements between all sessions, find the + shifts (one per session) to bring the sessions into alignment. + Assumes `session_offsets_matrix` is skew symmetric. + + Parameters + ---------- + alignment_order : "to_middle" or "to_session_X" where + "N" is the number of the session to align to. + session_offsets_matrix : np.ndarray + The num_sessions x num_sessions symmetric matrix + of displacements between all sessions, generated by + `_compute_session_alignment()`. + + Returns + ------- + optimal_shift_indices : np.ndarray + A 1 x num_sessions array of shifts to apply to + each session in order to bring all sessions into + alignment. + """ + if alignment_order == "to_middle": + optimal_shift_indices = -np.mean(session_offsets_matrix, axis=0) + else: + ses_idx = int(alignment_order.split("_")[-1]) - 1 + optimal_shift_indices = -session_offsets_matrix[ses_idx, :, :] + + return optimal_shift_indices diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py new file mode 100644 index 0000000000..a1be2eaa78 --- /dev/null +++ b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py @@ -0,0 +1,1026 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from spikeinterface.core.baserecording import BaseRecording + +import warnings +import numpy as np +from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording +from spikeinterface.sortingcomponents.motion.motion_utils import get_spatial_windows, get_spatial_bins +from spikeinterface.sortingcomponents.motion.motion_interpolation import correct_motion_on_peaks +from spikeinterface.sortingcomponents.motion.motion_utils import make_3d_motion_histograms +from spikeinterface.core.motion import Motion +from spikeinterface.preprocessing.motion import run_peak_detection_pipeline_node +from spikeinterface.preprocessing.inter_session_alignment import alignment_utils + +import copy + + +def get_estimate_histogram_kwargs() -> dict: + """ + A dictionary controlling how the histogram for each session is + computed. The session histograms are estimated by chunking + the recording into time segments and computing histograms + for each chunk, then performing some summary statistic over + the chunked histograms. + + Returns + ------- + A dictionary with entries: + + "bin_um" : number of spatial histogram bins. As the estimated peak + locations are continuous (i.e. real numbers) this is not constrained + by the number of channels. + "method" : may be "chunked_mean", "chunked_median" + "chunked_poisson". Determines the summary statistic used over + the histograms computed across a session. See `alignment_utils.py + for details on each method. + "chunked_bin_size_s" : The length in seconds (float) to chunk the recording + for estimating the chunked histograms. Can be set to "estimate" (str), + and the size is estimated from firing frequencies. + "log_transform" : if `True`, histograms are log transformed. + "depth_smooth_um" : if `None`, no smoothing is applied. See + `make_2d_motion_histogram`. + """ + return { + "bin_um": 2, + "method": "chunked_mean", + "chunked_bin_size_s": "estimate", + "log_transform": True, + "depth_smooth_um": None, + "histogram_type": "1d", + "weight_with_amplitude": False, + "avg_in_bin": False, + } + + +def get_compute_alignment_kwargs() -> dict: + """ + A dictionary with settings controlling how inter-session + alignment is estimated and computed given a set of + session activity histograms. + + All keys except for "non_rigid_window_kwargs" determine + how alignment is estimated, based on the kilosort ("kilosort_like" + in spikeinterface) motion correction method. See + `iterative_template_registration` for details. + + "non_rigid_window_kwargs" : if nonrigid alignment + is performed, this determines the nature of the + windows along the probe depth. See `get_spatial_windows`. + """ + return { + "num_shifts_global": None, + "num_shifts_block": 20, + "interpolate": False, + "interp_factor": 10, + "kriging_sigma": 1, + "kriging_p": 2, + "kriging_d": 2, + "smoothing_sigma_bin": 0.5, + "smoothing_sigma_window": 0.5, + "akima_interp_nonrigid": False, + "min_crosscorr_threshold": 0.001, + } + + +def get_non_rigid_window_kwargs(): + """ + see get_spatial_windows() for parameters. + + TODO + ---- + merge with motion correction kwargs which are + defined in the function signature. + """ + return { + "rigid": True, + "win_shape": "gaussian", + "win_step_um": 50, + "win_scale_um": 50, + "win_margin_um": None, + "zero_threshold": None, + } + + +def get_interpolate_motion_kwargs(): + """ + Settings to pass to `InterpolateMotionRecording`, + see that class for parameter descriptions. + """ + return { + "border_mode": "force_zeros", # fixed as this until can figure out probe + "spatial_interpolation_method": "kriging", + "sigma_um": 20.0, + "p": 2, + } + + +############################################################################### +# Public Entry Level Functions +############################################################################### + + +def align_sessions( + recordings_list: list[BaseRecording], + peaks_list: list[np.ndarray], + peak_locations_list: list[np.ndarray], + alignment_order: str = "to_middle", + non_rigid_window_kwargs: dict = get_non_rigid_window_kwargs(), + estimate_histogram_kwargs: dict = get_estimate_histogram_kwargs(), + compute_alignment_kwargs: dict = get_compute_alignment_kwargs(), + interpolate_motion_kwargs: None | dict = get_interpolate_motion_kwargs(), +) -> tuple[list[BaseRecording], dict]: + """ + Estimate probe displacement across recording sessions and + return interpolated, displacement-corrected recording. Displacement + is only estimated along the "y" dimension. + + This assumes peaks and peak locations have already been computed. + See `compute_peaks_locations_for_session_alignment` for generating + `peaks_list` and `peak_locations_list` from a `recordings_list`. + + If a recording in `recordings_list` is already an `InterpolateMotionRecording`, + the displacement will be added to the existing shifts to avoid duplicate + interpolations. Note the returned, corrected recording is a copy + (recordings in `recording_list` are not edited in-place). + + Parameters + ---------- + recordings_list : list[BaseRecording] + A list of recordings to be aligned. + peaks_list : list[np.ndarray] + A list of peaks detected from the recordings in `recordings_list`, + as returned from the `detect_peaks` function. Each entry in + `peaks_list` should be from the corresponding entry in `recordings_list`. + peak_locations_list : list[np.ndarray] + A list of peak locations, as computed by `localize_peaks`. Each entry + in `peak_locations_list` should be matched to the corresponding entry + in `peaks_list` and `recordings_list`. + alignment_order : str + "to_middle" will align all sessions to the mean position. + Alternatively, "to_session_N" where "N" is a session number + will align to the Nth session. + non_rigid_window_kwargs : dict + see `get_non_rigid_window_kwargs` + estimate_histogram_kwargs : dict + see `get_estimate_histogram_kwargs()` + compute_alignment_kwargs : dict + see `get_compute_alignment_kwargs()` + interpolate_motion_kwargs : dict + see `get_interpolate_motion_kwargs()` Will not be used if passed + recording is InterpolateMotionRecording (in which case, do not + use this function but use `compute_peaks_locations_for_session_alignment`. + + Returns + ------- + `corrected_recordings_list : list[BaseRecording] + List of displacement-corrected recordings (corresponding + in order to `recordings_list`). If an input recordings is + an InterpolateMotionRecording` recording, the corrected + output recording will be a copy of the input recording with + the additional displacement correction added. + + extra_outputs_dict : dict + Dictionary of features used in the alignment estimation and correction. + + shifts_array : np.ndarray + A (num_sessions x num_rigid_windows) array of shifts. + session_histogram_list : list[np.ndarray] + A list of histograms (one per session) used for the alignment. + spatial_bin_centers : np.ndarray + The spatial bin centers, shared between all recordings. + temporal_bin_centers_list : list[np.ndarray] + List of temporal bin centers. As alignment is based on a single + histogram per session, this contains only 1 value per recording, + which is the mid-timepoint of the recording. + non_rigid_window_centers : np.ndarray + Window centers of the probe segments used for non-rigid alignment. + If rigid alignment is performed, this is a single value (mid-probe). + non_rigid_windows : np.ndarray + A (num nonrigid windows, num spatial_bin_centers) binary array used to mask + the probe segments for non-rigid alignment. If rigid alignment is performed, + this a vector of ones with length (spatial_bin_centers,) + histogram_info_list :list[dict] + see `_get_single_session_activity_histogram()` for details. + motion_objects_list : + List of motion objects containing the shifts and spatial and temporal + bins for each recording. Note this contains only displacement + associated with the inter-session alignment, and so will differ from + the motion on corrected recording objects if the recording is + already an `InterpolateMotionRecording` object containing + within-session motion correction. + corrected : dict + Dictionary containing corrected-histogram + information. + corrected_peak_locations_list : + Displacement-corrected `peak_locations`. + corrected_session_histogram_list : + Corrected activity histogram (computed from the corrected peak locations). + """ + non_rigid_window_kwargs = copy.deepcopy(non_rigid_window_kwargs) + estimate_histogram_kwargs = copy.deepcopy(estimate_histogram_kwargs) + compute_alignment_kwargs = copy.deepcopy(compute_alignment_kwargs) + interpolate_motion_kwargs = copy.deepcopy(interpolate_motion_kwargs) + + # Ensure list lengths match and all channel locations are the same across recordings. + _check_align_sessions_inputs( + recordings_list, + peaks_list, + peak_locations_list, + alignment_order, + estimate_histogram_kwargs, + interpolate_motion_kwargs, + ) + + print("Computing a single activity histogram from each session...") + + (session_histogram_list, temporal_bin_centers_list, spatial_bin_centers, spatial_bin_edges, histogram_info_list) = ( + _compute_session_histograms(recordings_list, peaks_list, peak_locations_list, **estimate_histogram_kwargs) + ) + + print("Aligning the activity histograms across sessions...") + + contact_depths = recordings_list[0].get_channel_locations()[:, 1] + + shifts_array, non_rigid_windows, non_rigid_window_centers = _compute_session_alignment( + session_histogram_list, + contact_depths, + spatial_bin_centers, + alignment_order, + non_rigid_window_kwargs, + compute_alignment_kwargs, + ) + shifts_array *= estimate_histogram_kwargs["bin_um"] + + print("Creating corrected recordings...") + + corrected_recordings_list, motion_objects_list = _create_motion_recordings( + recordings_list, shifts_array, temporal_bin_centers_list, non_rigid_window_centers, interpolate_motion_kwargs + ) + + print("Creating corrected peak locations and histograms...") + + corrected_peak_locations_list, corrected_session_histogram_list = _correct_session_displacement( + corrected_recordings_list, + peaks_list, + peak_locations_list, + motion_objects_list, + spatial_bin_edges, + estimate_histogram_kwargs, + ) + + extra_outputs_dict = { + "shifts_array": shifts_array, + "session_histogram_list": session_histogram_list, + "spatial_bin_centers": spatial_bin_centers, + "temporal_bin_centers_list": temporal_bin_centers_list, + "non_rigid_window_centers": non_rigid_window_centers, + "non_rigid_windows": non_rigid_windows, + "histogram_info_list": histogram_info_list, + "motion_objects_list": motion_objects_list, + "corrected": { + "corrected_peak_locations_list": corrected_peak_locations_list, + "corrected_session_histogram_list": corrected_session_histogram_list, + }, + } + return corrected_recordings_list, extra_outputs_dict + + +def align_sessions_after_motion_correction( + recordings_list: list[BaseRecording], motion_info_list: list[dict], align_sessions_kwargs: dict | None +) -> tuple[list[BaseRecording], dict]: + """ + Convenience function to run `align_sessions` to correct for + inter-session displacement from the outputs of motion correction. + + The estimated displacement will be added directly to the recording. + + Parameters + ---------- + recordings_list : list[BaseRecording] + A list of motion-corrected (`InterpolateMotionRecording`) recordings. + motion_info_list : list[dict] + A list of `motion_info` objects, as output from `correct_motion`. + Each entry should correspond to a recording in `recording_list`. + align_sessions_kwargs : dict + A dictionary of keyword arguments passed to `align_sessions`. + + """ + # Check motion kwargs are the same across all recordings + if not all(isinstance(rec, InterpolateMotionRecording) for rec in recordings_list): + raise ValueError( + "All passed recordings have been run with motion correction as the last step. " + "They must be InterpolateMotionRecording." + ) + + motion_kwargs_list = [info["parameters"]["estimate_motion_kwargs"] for info in motion_info_list] + if not all(kwargs == motion_kwargs_list[0] for kwargs in motion_kwargs_list): + raise ValueError( + "The motion correct settings used on the `recordings_list` must be identical for all recordings" + ) + + motion_window_kwargs = copy.deepcopy(motion_kwargs_list[0]) + + if "direction" in motion_window_kwargs and motion_window_kwargs["direction"] != "y": + raise ValueError("motion correct must have been performed along the 'y' dimension.") + + if align_sessions_kwargs is None: + align_sessions_kwargs = {} + else: + # If motion correction was nonrigid, we must use the same settings for + # inter-session alignment, or we will not be able to add the nonrigid + # shifts together. + if ( + "non_rigid_window_kwargs" in align_sessions_kwargs + and not align_sessions_kwargs["non_rigid_window_kwargs"]["rigid"] + ): + if not motion_window_kwargs["rigid"]: + warnings.warn( + "Nonrigid inter-session alignment must use the motion correct " + "nonrigid settings.\n!Now overwriting any passed `non_rigid_window_kwargs` " + "with the motion object's non_rigid_window_kwargs !" + ) + non_rigid_window_kwargs = get_non_rigid_window_kwargs() + + for ( + key, + value, + ) in motion_window_kwargs.items(): + if key in non_rigid_window_kwargs: + non_rigid_window_kwargs[key] = value + + align_sessions_kwargs = copy.deepcopy(align_sessions_kwargs) + align_sessions_kwargs["non_rigid_window_kwargs"] = non_rigid_window_kwargs + + if "interpolate_motion_kwargs" in align_sessions_kwargs: + raise ValueError( + "Cannot set `interpolate_motion_kwargs` when using this function. " + "The interpolate kwargs from the original motion recording will be used." + ) + # This does not do anything, just makes it explicit these are not used. + align_sessions_kwargs["interpolate_motion_kwargs"] = None + + corrected_peak_locations = [ + correct_motion_on_peaks(info["peaks"], info["peak_locations"], info["motion"], recording) + for info, recording in zip(motion_info_list, recordings_list) + ] + + return align_sessions( + recordings_list, + [info["peaks"] for info in motion_info_list], + corrected_peak_locations, + **align_sessions_kwargs, + ) + + +def compute_peaks_locations_for_session_alignment( + recording_list: list[BaseRecording], + detect_kwargs: dict, + localize_peaks_kwargs: dict, + job_kwargs: dict | None = None, + gather_mode: str = "memory", +): + """ + A convenience function to compute `peaks_list` and `peak_locations_list` + from a list of recordings, for `align_sessions`. + + Parameters + ---------- + recording_list : list[BaseRecording] + A list of recordings to compute `peaks` and + `peak_locations` for. + detect_kwargs : dict + Arguments to be passed to `detect_peaks`. + localize_peaks_kwargs : dict + Arguments to be passed to `localise_peaks`. + job_kwargs : dict | None + `job_kwargs` for `run_node_pipeline()`. + gather_mode : str + The mode for `run_node_pipeline()`. + """ + if job_kwargs is None: + job_kwargs = {} + + peaks_list = [] + peak_locations_list = [] + + for recording in recording_list: + peaks, peak_locations, _ = run_peak_detection_pipeline_node( + recording, gather_mode, detect_kwargs, localize_peaks_kwargs, job_kwargs + ) + peaks_list.append(peaks) + peak_locations_list.append(peak_locations) + + return peaks_list, peak_locations_list + + +############################################################################### +# Private Functions +############################################################################### + + +def _compute_session_histograms( + recordings_list: list[BaseRecording], + peaks_list: list[np.ndarray], + peak_locations_list: list[np.ndarray], + histogram_type, # TODO think up better names + bin_um: float, + method: str, + chunked_bin_size_s: float | "estimate", + depth_smooth_um: float, + log_transform: bool, + weight_with_amplitude: bool, + avg_in_bin: bool, +) -> tuple[list[np.ndarray], list[np.ndarray], np.ndarray, np.ndarray, list[dict]]: + """ + Compute a 1d activity histogram for the session. As + sessions may be long, the approach taken is to chunk + the recording into time segments and compute separate + histograms for each. Then, a summary statistic is computed + over the histograms. This accounts for periods of noise + in the recording or segments of irregular spiking. + + Parameters + ---------- + see `align_sessions` for `recording_list`, `peaks_list`, + `peak_locations_list`. + + see `get_estimate_histogram_kwargs()` for all other kwargs. + + Returns + ------- + + session_histogram_list : list[np.ndarray] + A list of activity histograms (1 x n_bins), one per session. + This is the histogram which summarises all chunked histograms. + + temporal_bin_centers_list : list[np.ndarray] + A list of temporal bin centers, one per session. We have one + histogram per session, the temporal bin has 1 entry, the + mid-time point of the session. + + spatial_bin_centers : np.ndarray + A list of spatial bin centers corresponding to the session + activity histograms. + + spatial_bin_edges : np.ndarray + The corresponding spatial bin edges + + histogram_info_list : list[dict] + A list of extra information on the histograms generation + (e.g. chunked histograms). One per session. See + `_get_single_session_activity_histogram()` for details. + """ + # Get spatial windows (shared across all histograms) + # and estimate the session histograms + temporal_bin_centers_list = [] + + spatial_bin_centers, spatial_bin_edges, _ = get_spatial_bins( + recordings_list[0], direction="y", hist_margin_um=0, bin_um=bin_um + ) + + session_histogram_list = [] + histogram_info_list = [] + + for recording, peaks, peak_locations in zip(recordings_list, peaks_list, peak_locations_list): + + session_hist, temporal_bin_centers, histogram_info = _get_single_session_activity_histogram( + recording, + peaks, + peak_locations, + histogram_type=histogram_type, + spatial_bin_edges=spatial_bin_edges, + method=method, + log_transform=log_transform, + chunked_bin_size_s=chunked_bin_size_s, + depth_smooth_um=depth_smooth_um, + weight_with_amplitude=weight_with_amplitude, + avg_in_bin=avg_in_bin, + ) + temporal_bin_centers_list.append(temporal_bin_centers) + session_histogram_list.append(session_hist) + histogram_info_list.append(histogram_info) + + return ( + session_histogram_list, + temporal_bin_centers_list, + spatial_bin_centers, + spatial_bin_edges, + histogram_info_list, + ) + + +def _get_single_session_activity_histogram( + recording: BaseRecording, + peaks: np.ndarray, + peak_locations: np.ndarray, + histogram_type, + spatial_bin_edges: np.ndarray, + method: str, + log_transform: bool, + chunked_bin_size_s: float | "estimate", + depth_smooth_um: float, + weight_with_amplitude: bool, + avg_in_bin: bool, +) -> tuple[np.ndarray, np.ndarray, dict]: + """ + Compute an activity histogram for a single session. + The recording is chunked into time segments, histograms + estimated and a summary statistic calculated across the histograms + + Note if `chunked_bin_size_is` is set to `"estimate"` the + histogram for the entire session is first created to get a good + estimate of the firing rates. + The firing rates are used to use a time segment size that will + allow a good estimation of the firing rate. + + Parameters + ---------- + `spatial_bin_edges : np.ndarray + The spatial bin edges for the created histogram. This is + explicitly required as for inter-session alignment, the + session histograms must share bin edges. + + see `_compute_session_histograms()` for all other keyword arguments. + + Returns + ------- + session_histogram : np.ndarray + Summary activity histogram for the session. + temporal_bin_centers : np.ndarray + Temporal bin center (session mid-point as we only have + one time point) for the session. + histogram_info : dict + A dict of additional info including: + "chunked_histograms" : The chunked histograms over which + the summary histogram was calculated. + "chunked_temporal_bin_centers" : The temporal vin centers + for the chunked histograms, with length num_chunks. + "session_std" : The mean across bin-wise standard deviation + of the chunked histograms. + "chunked_bin_size_s" : time of each chunk used to + calculate the chunked histogram. + """ + times = recording.get_times() + temporal_bin_centers = np.atleast_1d((times[-1] + times[0]) / 2) + + # Estimate an entire session histogram if requested or doing + # full estimation for chunked bin size + if chunked_bin_size_s == "estimate": + + scaled_hist, _, _ = alignment_utils.get_2d_activity_histogram( + recording, + peaks, + peak_locations, + spatial_bin_edges, + bin_s=None, + depth_smooth_um=None, + scale_to_hz=True, + weight_with_amplitude=False, + avg_in_bin=False, + ) + + # It is important that the passed histogram is scaled to firing rate in Hz + chunked_bin_size_s = alignment_utils.estimate_chunk_size(scaled_hist) + chunked_bin_size_s = np.min([chunked_bin_size_s, recording.get_duration()]) + + if histogram_type == "1d": + + chunked_histograms, chunked_temporal_bin_centers, _ = alignment_utils.get_2d_activity_histogram( + recording, + peaks, + peak_locations, + spatial_bin_edges, + bin_s=chunked_bin_size_s, + depth_smooth_um=depth_smooth_um, + weight_with_amplitude=weight_with_amplitude, + avg_in_bin=avg_in_bin, + scale_to_hz=True, + ) + + elif histogram_type in ["2d"]: + + if histogram_type == "2d": + + chunked_histograms, chunked_temporal_bin_edges, _ = make_3d_motion_histograms( + recording, + peaks, + peak_locations, + direction="y", + bin_s=chunked_bin_size_s, + bin_um=None, + hist_margin_um=50, + num_amp_bins=20, + log_transform=False, + spatial_bin_edges=spatial_bin_edges, + ) + + chunked_temporal_bin_centers = alignment_utils.get_bin_centers(chunked_temporal_bin_edges) + + if method == "chunked_mean": + session_histogram = alignment_utils.get_chunked_hist_mean(chunked_histograms) + + elif method == "chunked_median": + session_histogram = alignment_utils.get_chunked_hist_median(chunked_histograms) + + if log_transform: + session_histogram = np.log2(1 + session_histogram) + + histogram_info = { + "chunked_histograms": chunked_histograms, + "chunked_temporal_bin_centers": chunked_temporal_bin_centers, + "chunked_bin_size_s": chunked_bin_size_s, + } + + return session_histogram, temporal_bin_centers, histogram_info + + +def _create_motion_recordings( + recordings_list: list[BaseRecording], + shifts_array: np.ndarray, + temporal_bin_centers_list: list[np.ndarray], + non_rigid_window_centers: np.ndarray, + interpolate_motion_kwargs: dict, +) -> tuple[list[BaseRecording], list[Motion]]: + """ + Given a set of recordings, motion shifts and bin information per-recording, + generate an InterpolateMotionRecording. If the recording is already an + InterpolateMotionRecording, then the shifts will be added to a copy + of it. Copies of the Recordings are made, nothing is changed in-place. + + Parameters + ---------- + shifts_array : num_sessions x num_nonrigid bins + + Returns + ------- + corrected_recordings_list : list[BaseRecording] + A list of InterpolateMotionRecording recordings of shift-corrected + recordings corresponding to `recordings_list`. + + motion_objects_list : list[Motion] + A list of Motion objects. If the recording in `recordings_list` + is already an InterpolateMotionRecording, this will be `None`, as + no motion object is created (the existing motion object is added to) + """ + assert all(array.ndim == 1 for array in shifts_array), "time dimension should be 1 for session displacement" + + corrected_recordings_list = [] + motion_objects_list = [] + for ses_idx, recording in enumerate(recordings_list): + + session_shift = shifts_array[ses_idx][np.newaxis, :] + + motion = Motion([session_shift], [temporal_bin_centers_list[ses_idx]], non_rigid_window_centers, direction="y") + motion_objects_list.append(motion) + + if isinstance(recording, InterpolateMotionRecording): + + print("Recording is already an `InterpolateMotionRecording. Adding shifts directly the recording object.") + + corrected_recording = _add_displacement_to_interpolate_recording(recording, motion) + else: + corrected_recording = InterpolateMotionRecording( + recording, + motion, + interpolation_time_bin_centers_s=motion.temporal_bins_s, + interpolation_time_bin_edges_s=[np.array(recording.get_times()[0], recording.get_times()[-1])], + **interpolate_motion_kwargs, + ) + corrected_recording = corrected_recording.set_probe( + recording.get_probe() + ) # TODO: if this works, might need to do above + + corrected_recordings_list.append(corrected_recording) + + return corrected_recordings_list, motion_objects_list + + +def _add_displacement_to_interpolate_recording( + original_recording: BaseRecording, + session_displacement_motion: Motion, +): + """ + This function adds a shift to an InterpolateMotionRecording. + + There are four cases: + - The original recording is rigid and new shift is rigid (shifts are added). + - The original recording is rigid and new shifts are non-rigid (sets the + non-rigid shifts onto the recording, then adds back the original shifts). + - The original recording is nonrigid and the new shifts are rigid (rigid + shift added to all nonlinear shifts) + - The original recording is nonrigid and the new shifts are nonrigid + (respective non-rigid shifts are added, must have same number of + non-rigid windows). + + Parameters + ---------- + see `_create_motion_recordings()` + + Returns + ------- + corrected_recording : InterpolateMotionRecording + A copy of the `recording` with new shifts added. + + TODO + ---- + Check + ask Sam if any other fields need to be changed. This is a little + hairy (4 possible combinations of new and old displacement shapes, + rigid or nonrigid, so test thoroughly. + """ + # Everything is done in place, so keep a short variable + # name reference to the new recordings `motion` object + # and update it.okay + corrected_recording = copy.deepcopy(original_recording) + + shifts_to_add = session_displacement_motion.displacement[0] + new_non_rigid_window_centers = session_displacement_motion.spatial_bins_um + + motion_ref = corrected_recording._recording_segments[0].motion + recording_bins = motion_ref.displacement[0].shape[1] + + # If the new displacement is a scalar (i.e. rigid), + # just add it to the existing displacements + if shifts_to_add.shape[1] == 1: + motion_ref.displacement[0] += shifts_to_add[0, 0] + + else: + if recording_bins == 1: + # If the new displacement is nonrigid (multiple windows) but the motion + # recording is rigid, we update the displacement at all time bins + # with the new, nonrigid displacement added to the old, rigid displacement. + num_time_bins = motion_ref.displacement[0].shape[0] + tiled_nonrigid_displacement = np.repeat(shifts_to_add, num_time_bins, axis=0) + shifts_to_add = tiled_nonrigid_displacement + motion_ref.displacement + + motion_ref.displacement = shifts_to_add + motion_ref.spatial_bins_um = new_non_rigid_window_centers + else: + # Otherwise, if both the motion and new displacement are + # nonrigid, we need to make sure the nonrigid windows + # match exactly. + assert np.array_equal(motion_ref.spatial_bins_um, new_non_rigid_window_centers) + assert motion_ref.displacement[0].shape[1] == shifts_to_add.shape[1] + + motion_ref.displacement[0] += shifts_to_add + + return corrected_recording + + +def _correct_session_displacement( + recordings_list: list[BaseRecording], + peaks_list: list[np.ndarray], + peak_locations_list: list[np.ndarray], + motion_objects_list: list[Motion], + spatial_bin_edges: np.ndarray, + estimate_histogram_kwargs: dict, +): + """ + Internal function to apply the correction from `align_sessions` + to build a corrected histogram for comparison. First, create + new shifted peak locations. Then, create a new 'corrected' + activity histogram from the new peak locations. + + Parameters + ---------- + see `align_sessions()` for parameters. + + Returns + ------- + corrected_peak_locations_list : list[np.ndarray] + A list of peak locations corrected by the inter-session + shifts (one entry per session). + corrected_session_histogram_list : list[np.ndarray] + A list of histograms calculated from the corrected peaks (one per session). + """ + corrected_peak_locations_list = [] + + for recording, peaks, peak_locations, motion in zip( + recordings_list, peaks_list, peak_locations_list, motion_objects_list + ): + + # Note this `motion` is not necessarily the same as the motion on the recording. If the recording + # is an `InterpolateMotionRecording`, it will contain correction for both motion and inter-session displacement. + # Here we want to correct only the motion associated with inter-session displacement. + corrected_peak_locs = correct_motion_on_peaks( + peaks, + peak_locations, + motion, + recording, + ) + corrected_peak_locations_list.append(corrected_peak_locs) + + corrected_session_histogram_list = [] + + for recording, peaks, corrected_locations in zip(recordings_list, peaks_list, corrected_peak_locations_list): + session_hist, _, _ = _get_single_session_activity_histogram( + recording, + peaks, + corrected_locations, + estimate_histogram_kwargs["histogram_type"], + spatial_bin_edges, + estimate_histogram_kwargs["method"], + estimate_histogram_kwargs["log_transform"], + estimate_histogram_kwargs["chunked_bin_size_s"], + estimate_histogram_kwargs["depth_smooth_um"], + estimate_histogram_kwargs["weight_with_amplitude"], + estimate_histogram_kwargs["avg_in_bin"], + ) + corrected_session_histogram_list.append(session_hist) + + return corrected_peak_locations_list, corrected_session_histogram_list + + +######################################################################################################################## + + +def _compute_session_alignment( + session_histogram_list: list[np.ndarray], + contact_depths: np.ndarray, + spatial_bin_centers: np.ndarray, + alignment_order: str, + non_rigid_window_kwargs: dict, + compute_alignment_kwargs: dict, +) -> tuple[np.ndarray, ...]: + """ + Given a list of activity histograms (one per session) compute + rigid or non-rigid set of shifts (one per session) that will bring + all sessions into alignment. + + For rigid shifts, a cross-correlation between activity + histograms is performed. For non-rigid shifts, the probe + is split into segments, and linear estimation of shift + performed for each segment. + + Parameters + ---------- + See `align_sessions()` for parameters + + Returns + ------- + shifts : np.ndarray + A (num_sessions x num_rigid_windows) array of shifts to bring + the histograms in `session_histogram_list` into alignment. + non_rigid_windows : np.ndarray + An array (num_non_rigid_windows x num_spatial_bins) of weightings + for each bin in each window. For rect, these are in the range [0, 1], + for Gaussian these are gaussian etc. + non_rigid_window_centers : np.ndarray + The centers (spatial, in um) of each non-rigid window. + """ + session_histogram_array = np.array(session_histogram_list) + + akima_interp_nonrigid = compute_alignment_kwargs.pop("akima_interp_nonrigid") + num_shifts_global = compute_alignment_kwargs.pop("num_shifts_global") + num_shifts_block = compute_alignment_kwargs.pop("num_shifts_block") + + non_rigid_windows, non_rigid_window_centers = get_spatial_windows( + contact_depths, spatial_bin_centers, **non_rigid_window_kwargs + ) + + rigid_shifts = _estimate_rigid_alignment( + session_histogram_array, + alignment_order, + num_shifts_global, + compute_alignment_kwargs, + ) + + if non_rigid_window_kwargs["rigid"]: + return rigid_shifts, non_rigid_windows, non_rigid_window_centers + + # For non-rigid, first shift the histograms according to the rigid shift + shifted_histograms = np.zeros_like(session_histogram_array) + for ses_idx, orig_histogram in enumerate(session_histogram_array): + + shifted_histogram = alignment_utils.shift_array_fill_zeros( + array=orig_histogram, shift=int(rigid_shifts[ses_idx, 0]) + ) + shifted_histograms[ses_idx, :] = shifted_histogram + + # Then compute the nonrigid shifts + nonrigid_session_offsets_matrix, _ = alignment_utils.compute_histogram_crosscorrelation( + shifted_histograms, non_rigid_windows, num_shifts=num_shifts_block, **compute_alignment_kwargs + ) + non_rigid_shifts = alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix) + + # Akima interpolate the nonrigid bins if required. + if akima_interp_nonrigid: + interp_nonrigid_shifts = alignment_utils.akima_interpolate_nonrigid_shifts( + non_rigid_shifts, non_rigid_window_centers, spatial_bin_centers + ) + shifts = rigid_shifts + interp_nonrigid_shifts + non_rigid_window_centers = spatial_bin_centers + else: + shifts = rigid_shifts + non_rigid_shifts + + return shifts, non_rigid_windows, non_rigid_window_centers + + +def _estimate_rigid_alignment( + session_histogram_array: np.ndarray, + alignment_order: str, + num_shifts: None | int, + compute_alignment_kwargs: dict, +): + """ + Estimate the rigid alignment from a set of activity + histograms, using simple cross-correlation. + + Parameters + ---------- + session_histogram_array : np.ndarray + A (num_sessions x num_spatial_bins) array of activity + histograms to align + alignment_order : str + Align "to_middle" or "to_session_N" (where "N" is the session number) + compute_alignment_kwargs : dict + See `get_compute_alignment_kwargs()`. + + Returns + ------- + optimal_shift_indices : np.ndarray + An array (num_sessions x 1) of shifts to bring all + session histograms into alignment. + """ + compute_alignment_kwargs = copy.deepcopy(compute_alignment_kwargs) + + rigid_window = np.ones(session_histogram_array.shape[1])[np.newaxis, :] + + rigid_session_offsets_matrix, _ = alignment_utils.compute_histogram_crosscorrelation( + session_histogram_array, + rigid_window, + num_shifts=num_shifts, + **compute_alignment_kwargs, # TODO: remove the copy above and pass directly. Consider removing this function... + ) + optimal_shift_indices = alignment_utils.get_shifts_from_session_matrix( + alignment_order, rigid_session_offsets_matrix + ) + + return optimal_shift_indices + + +# ----------------------------------------------------------------------------- +# Checkers +# ----------------------------------------------------------------------------- + + +def _check_align_sessions_inputs( + recordings_list: list[BaseRecording], + peaks_list: list[np.ndarray], + peak_locations_list: list[np.ndarray], + alignment_order: str, + estimate_histogram_kwargs: dict, + interpolate_motion_kwargs: dict, +): + """ + Perform checks on the input of `align_sessions()` + """ + num_sessions = len(recordings_list) + + if len(peaks_list) != num_sessions or len(peak_locations_list) != num_sessions: + raise ValueError( + "`recordings_list`, `peaks_list` and `peak_locations_list` " + "must be the same length. They must contains list of corresponding " + "recordings, peak and peak location objects." + ) + + if not all(rec.get_num_segments() == 1 for rec in recordings_list): + raise ValueError( + "Multi-segment recordings not supported. All recordings in `recordings_list` but have only 1 segment." + ) + + channel_locs = [rec.get_channel_locations() for rec in recordings_list] + if not all([np.array_equal(locs, channel_locs[0]) for locs in channel_locs]): + raise ValueError( + "The recordings in `recordings_list` do not all have " + "the same channel locations. All recordings must be " + "performed using the same probe." + ) + + accepted_hist_methods = [ + "entire_session", + "chunked_mean", + "chunked_median", + ] + method = estimate_histogram_kwargs["method"] + if method not in accepted_hist_methods: + raise ValueError(f"`method` option must be one of: {accepted_hist_methods}") + + if alignment_order != "to_middle": + + split_name = alignment_order.split("_") + if not "_".join(split_name[:2]) == "to_session": + raise ValueError( + "`alignment_order` must take be 'to_middle' or take the form 'to_session_X' where X is the session number to align to." + ) + + ses_num = int(split_name[-1]) + if ses_num > num_sessions: + raise ValueError( + f"`alignment_order` session {ses_num} is larger than the number of sessions in `recordings_list`." + ) + + if ses_num == 0: + raise ValueError("`alignment_order` required the session number, not session index.") diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 96751604fe..82f6abc6eb 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -337,11 +337,10 @@ def correct_motion( for plotting. See `plot_motion_info()` """ # local import are important because "sortingcomponents" is not important by default - from spikeinterface.sortingcomponents.peak_detection import detect_peaks, detect_peak_methods + from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_selection import select_peaks - from spikeinterface.sortingcomponents.peak_localization import localize_peaks, localize_peak_methods + from spikeinterface.sortingcomponents.peak_localization import localize_peaks from spikeinterface.sortingcomponents.motion import estimate_motion, InterpolateMotionRecording - from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline # get preset params and update if necessary params = motion_options_preset[preset] @@ -385,34 +384,11 @@ def correct_motion( if not do_selection: # maybe do this directly in the folder when not None, but might be slow on external storage gather_mode = "memory" - # node detect - method = detect_kwargs.pop("method", "locally_exclusive") - method_class = detect_peak_methods[method] - node0 = method_class(recording, **detect_kwargs) - - node1 = ExtractDenseWaveforms(recording, parents=[node0], ms_before=0.1, ms_after=0.3) - - # node detect + localize - method = localize_peaks_kwargs.pop("method", "center_of_mass") - method_class = localize_peak_methods[method] - node2 = method_class(recording, parents=[node0, node1], return_output=True, **localize_peaks_kwargs) - pipeline_nodes = [node0, node1, node2] - t0 = time.perf_counter() - peaks, peak_locations = run_node_pipeline( - recording, - pipeline_nodes, - job_kwargs, - job_name="detect and localize", - gather_mode=gather_mode, - gather_kwargs=None, - squeeze_output=False, - folder=None, - names=None, - ) - t1 = time.perf_counter() - run_times = dict( - detect_and_localize=t1 - t0, + + peaks, peak_locations, peaks_run_time = run_peak_detection_pipeline_node( + recording, gather_mode, detect_kwargs, localize_peaks_kwargs, job_kwargs ) + run_times = dict(detect_and_localize=peaks_run_time) else: # localization is done after select_peaks() pipeline_nodes = None @@ -462,6 +438,47 @@ def correct_motion( return out +def run_peak_detection_pipeline_node(recording, gather_mode, detect_kwargs, localize_peaks_kwargs, job_kwargs): + """ + TODO: add docstring + """ + from spikeinterface.sortingcomponents.peak_detection import detect_peak_methods + from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline + from spikeinterface.sortingcomponents.peak_localization import localize_peak_methods + + # Don't modify the kwargs in place in case the caller requires them + detect_kwargs = copy.deepcopy(detect_kwargs) + localize_peaks_kwargs = copy.deepcopy(localize_peaks_kwargs) + + # node detect + method = detect_kwargs.pop("method", "locally_exclusive") + method_class = detect_peak_methods[method] + node0 = method_class(recording, **detect_kwargs) + + node1 = ExtractDenseWaveforms(recording, parents=[node0], ms_before=0.1, ms_after=0.3) + + # node detect + localize + method = localize_peaks_kwargs.pop("method", "center_of_mass") + method_class = localize_peak_methods[method] + node2 = method_class(recording, parents=[node0, node1], return_output=True, **localize_peaks_kwargs) + pipeline_nodes = [node0, node1, node2] + t0 = time.perf_counter() + peaks, peak_locations = run_node_pipeline( + recording, + pipeline_nodes, + job_kwargs, + job_name="detect and localize", + gather_mode=gather_mode, + gather_kwargs=None, + squeeze_output=False, + folder=None, + names=None, + ) + run_time = time.perf_counter() - t0 + + return peaks, peak_locations, run_time + + _doc_presets = "\n" for k, v in motion_options_preset.items(): if k == "": diff --git a/src/spikeinterface/preprocessing/tests/test_inter_session_alignment.py b/src/spikeinterface/preprocessing/tests/test_inter_session_alignment.py new file mode 100644 index 0000000000..ba5c7e665d --- /dev/null +++ b/src/spikeinterface/preprocessing/tests/test_inter_session_alignment.py @@ -0,0 +1,964 @@ +import numpy as np +import pytest + +from spikeinterface.preprocessing.inter_session_alignment import session_alignment, alignment_utils +from spikeinterface.generation.session_displacement_generator import * +import spikeinterface # required for monkeypatching +import spikeinterface.full as si +from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording + +DEBUG = False + + +class TestInterSessionAlignment: + + @pytest.fixture(scope="session") + def recording_1(self): + """ + Generate a set of session recordings with displacement. + These parameters are chosen such that simulated AP signal is strong + on the probe to avoid noise in the AP positions. This is important + for checking that the estimated shift matches the known shift. + """ + shifts = ((0, 0), (0, -200), (0, 150)) + + recordings_list, _ = generate_session_displacement_recordings( + num_units=15, + recording_durations=[0.1, 0.2, 0.3], + recording_shifts=shifts, + non_rigid_gradient=None, + seed=55, + generate_sorting_kwargs=dict(firing_rates=(100, 250), refractory_period_ms=4.0), + generate_unit_locations_kwargs=dict( + margin_um=0.0, + minimum_z=0.0, + maximum_z=2.0, + minimum_distance=18.0, + max_iteration=100, + distance_strict=False, + ), + generate_noise_kwargs=dict(noise_levels=(0.0, 0.0), spatial_decay=1.0), + ) + + peaks_list, peak_locations_list = session_alignment.compute_peaks_locations_for_session_alignment( + recordings_list, + detect_kwargs={"method": "locally_exclusive"}, + localize_peaks_kwargs={"method": "grid_convolution"}, + ) + return (recordings_list, shifts, peaks_list, peak_locations_list) + + # TODO: need to make this a fixtures somehow, I guess rigid or nonrigid case. + @pytest.fixture(scope="session") + def recording_2(self): + """ + Get a shifted inter-session alignment recording. First, interpolate-motion + within session. The purpose of these tests is then to run inter-session alignment + and check the displacements are properly added. + """ + shifts = ((0, 0), (0, 250)) + + recordings_list, _ = generate_session_displacement_recordings( + num_units=5, + recording_durations=[0.3, 0.3], + recording_shifts=shifts, + non_rigid_gradient=0.2, + seed=55, # 52 + generate_sorting_kwargs=dict(firing_rates=(100, 250), refractory_period_ms=4.0), + generate_unit_locations_kwargs=dict( + margin_um=0.0, + minimum_z=0.0, + maximum_z=2.0, + minimum_distance=18.0, + max_iteration=100, + distance_strict=False, + ), + generate_noise_kwargs=dict(noise_levels=(0.0, 0.5), spatial_decay=1.0), + # must have some noise, or peak detection becomes completely stoachastic + # because it relies on the std to set the threshold. + ) + + peaks_list, peak_locations_list = session_alignment.compute_peaks_locations_for_session_alignment( + recordings_list, + detect_kwargs={"method": "locally_exclusive"}, + localize_peaks_kwargs={"method": "grid_convolution"}, + ) + + return (recordings_list, shifts, peaks_list, peak_locations_list) + + def motion_correct_recordings_list(self, recordings_list, rigid_motion): + # Unfortunately this is necessary only in the test environemnt + # because adding offsets means copying the motion object does not work + interpolate_motion_kwargs = {"border_mode": "force_zeros"} + localize_peaks_kwargs = {"method": "grid_convolution"} + + preset = "rigid_fast" if rigid_motion else "kilosort_like" + + # Perform a motion correction, note this is just to make the + # motion correction object with the correct displacment, but + # the displacements should be zero here. These are manulally + # added in the tetsts. + mc_recording_list = [] + mc_motion_info_list = [] + for rec in recordings_list: + corrected_rec, motion_info = si.correct_motion( + rec, + preset=preset, + interpolate_motion_kwargs=interpolate_motion_kwargs, + output_motion_info=True, + localize_peaks_kwargs=localize_peaks_kwargs, + ) + mc_recording_list.append(corrected_rec) + mc_motion_info_list.append(motion_info) + + return mc_recording_list, mc_motion_info_list # , shifts + + ########################################################################### + # Functional Tests + ############################################################################ + + @pytest.mark.parametrize("histogram_type", ["1d", "2d"]) + @pytest.mark.parametrize("num_shifts_global", [None, 200]) + def test_align_sessions_finds_correct_shifts(self, num_shifts_global, recording_1, histogram_type): + """ + Test that `align_sessions` recovers the correct (linear) shifts. + """ + recordings_list, shifts, peaks_list, peak_locations_list = recording_1 + + assert shifts == ( + (0, 0), + (0, -200), + (0, 150), + ), "expected shifts are hard-coded into this test ahould should be set in the fixture.." + + compute_alignment_kwargs = session_alignment.get_compute_alignment_kwargs() + compute_alignment_kwargs["smoothing_sigma_bin"] = None + compute_alignment_kwargs["smoothing_sigma_window"] = None + compute_alignment_kwargs["num_shifts_global"] = num_shifts_global + + estimate_histogram_kwargs = session_alignment.get_estimate_histogram_kwargs() + estimate_histogram_kwargs["bin_um"] = 2 + estimate_histogram_kwargs["histogram_type"] = histogram_type + estimate_histogram_kwargs["log_transform"] = True + + for mode, expected in zip( + ["to_session_1", "to_session_2", "to_session_3", "to_middle"], + [ + (0, -200, 150), + (200, 0, 350), + (-150, -350, 0), + (16.66, -183.33, 166.66), + ], + ): + corrected_recordings_list, extra_info = session_alignment.align_sessions( + recordings_list, + peaks_list, + peak_locations_list, + alignment_order=mode, + compute_alignment_kwargs=compute_alignment_kwargs, + estimate_histogram_kwargs=estimate_histogram_kwargs, + ) + + if DEBUG: + from spikeinterface.widgets import plot_session_alignment, plot_activity_histogram_2d + import matplotlib.pyplot as plt + + plot = plot_session_alignment( + recordings_list, + peaks_list, + peak_locations_list, + extra_info["session_histogram_list"], + **extra_info["corrected"], + spatial_bin_centers=extra_info["spatial_bin_centers"], + drift_raster_map_kwargs={"clim": (-250, 0), "scatter_decimate": 10}, + ) + plt.show() + + assert np.allclose(expected, extra_info["shifts_array"].squeeze(), rtol=0, atol=0.02) + + corr_peaks_list, corr_peak_loc_list = session_alignment.compute_peaks_locations_for_session_alignment( + corrected_recordings_list, + detect_kwargs={"method": "locally_exclusive"}, + localize_peaks_kwargs={"method": "grid_convolution"}, + ) + + new_histograms = session_alignment._compute_session_histograms( + corrected_recordings_list, corr_peaks_list, corr_peak_loc_list, **estimate_histogram_kwargs + )[0] + + rows, cols = np.triu_indices(len(new_histograms), k=1) + assert np.all( + np.abs(np.corrcoef([hist.flatten() for hist in new_histograms])[rows, cols]) + - np.abs(np.corrcoef([hist.flatten() for hist in extra_info["session_histogram_list"]])[rows, cols]) + >= 0 + ) + + def test_histogram_generation(self, recording_1): + """ """ + recordings_list, _, peaks_list, peak_locations_list = recording_1 + + recording = recordings_list[0] + + channel_locations = recording.get_channel_locations() + loc_start = np.min(channel_locations[:, 1]) + loc_end = np.max(channel_locations[:, 1]) + + # Test some floats as slightly more complex than integer case + bin_s = 1.5 + bin_um = 1.5 + + estimate_histogram_kwargs = session_alignment.get_estimate_histogram_kwargs() + estimate_histogram_kwargs["bin_um"] = bin_um + estimate_histogram_kwargs["chunked_bin_size_s"] = bin_s + estimate_histogram_kwargs["log_transform"] = False + + ( + session_histogram_list, + temporal_bin_centers_list, + spatial_bin_centers, + spatial_bin_edges, + histogram_info_list, + ) = session_alignment._compute_session_histograms( + recordings_list, peaks_list, peak_locations_list, **estimate_histogram_kwargs + ) + + num_bins = (loc_end - loc_start) / bin_um + + bin_edges = np.linspace(loc_start, loc_end, int(num_bins) + 1) + bin_centers = bin_edges[:-1] + bin_um / 2 + + assert np.array_equal(bin_edges, spatial_bin_edges) + assert np.array_equal(bin_centers, spatial_bin_centers) + + for recording, temporal_bin_center in zip(recordings_list, temporal_bin_centers_list): + times = recording.get_times() + centers = (np.max(times) - np.min(times)) / 2 + assert temporal_bin_center == centers + + for ses_idx, (recording, chunked_histogram_info) in enumerate(zip(recordings_list, histogram_info_list)): + + # TODO: this is direct copy from above, can merge + times = recording.get_times() + + chunk_time_window = chunked_histogram_info["chunked_bin_size_s"] + + num_windows = (np.ceil(np.max(times)) - np.min(times)) / chunk_time_window + temp_bin_edges = np.arange(np.ceil(num_windows) + 1) * chunk_time_window + centers = temp_bin_edges[:-1] + chunk_time_window / 2 + + assert chunked_histogram_info["chunked_bin_size_s"] == chunk_time_window + assert np.array_equal(chunked_histogram_info["chunked_temporal_bin_centers"], centers) + + for edge_idx in range(len(temp_bin_edges) - 1): + + lower = temp_bin_edges[edge_idx] + upper = temp_bin_edges[edge_idx + 1] + + lower_idx = recording.time_to_sample_index(lower) + upper_idx = recording.time_to_sample_index(upper) + + new_peak_locs = peak_locations_list[ses_idx][ + np.where( + np.logical_and( + peaks_list[ses_idx]["sample_index"] >= lower_idx, + peaks_list[ses_idx]["sample_index"] < upper_idx, + ) + ) + ] + assert np.allclose( + np.histogram(new_peak_locs["y"], bins=bin_edges)[0] / (upper - lower), + chunked_histogram_info["chunked_histograms"][edge_idx, :], + rtol=0, + atol=1e-6, + ) + + @pytest.mark.parametrize("histogram_type", ["1d", "2d"]) + @pytest.mark.parametrize("operator", ["mean", "median"]) + def test_histogram_log_tranform(self, recording_1, histogram_type, operator): + """ """ + # Run histogram compute with a set of kwargs + recordings_list, _, peaks_list, peak_locations_list = recording_1 + + estimate_histogram_kwargs = session_alignment.get_estimate_histogram_kwargs() + estimate_histogram_kwargs["log_transform"] = False + estimate_histogram_kwargs["method"] = f"chunked_{operator}" + estimate_histogram_kwargs["histogram_type"] = histogram_type + + _, extra_info_log_false = session_alignment.align_sessions( + recordings_list, peaks_list, peak_locations_list, estimate_histogram_kwargs=estimate_histogram_kwargs + ) + + # Now, run it again with log transform and check the + # summary histogram is indeed the log-transformed mean / median + # of the chunked histograms. + estimate_histogram_kwargs["log_transform"] = True + + _, extra_info_log_true = session_alignment.align_sessions( + recordings_list, peaks_list, peak_locations_list, estimate_histogram_kwargs=estimate_histogram_kwargs + ) + for ses_idx in range(len(recordings_list)): + + summary_hist_log_false = extra_info_log_false["session_histogram_list"][ses_idx] + summary_hist_log_true = extra_info_log_true["session_histogram_list"][ses_idx] + chunked_histograms_log_true = extra_info_log_true["histogram_info_list"][ses_idx]["chunked_histograms"] + + assert np.array_equal(np.log2(summary_hist_log_false + 1), summary_hist_log_true) + + summary_func = np.median if operator == "median" else np.mean + assert np.array_equal(summary_hist_log_true, np.log2(1 + summary_func(chunked_histograms_log_true, axis=0))) + + ########################################################################### + # Following Motion Correction + ########################################################################### + # These tests check that the displacement found by the inter-session alignment + # is correctly added to any existing motion-correction results. + + def test_rigid_motion_rigid_intersession(self, recording_1): + """ + Create an inter-session alignment recording and motion correct it so that + it is an InterpolateMotion recording. Add some shifts to the existing displacement + on the InterpolateMotion recordings and check the inter-session alignment shifts + are properly added to this. + """ + recordings_list, shifts, _, _ = recording_1 + + mc_recording_list, mc_motion_info_list = self.motion_correct_recordings_list( + recordings_list, + rigid_motion=True, + ) + first_ses_mc_displacement = mc_recording_list[0]._recording_segments[0].motion.displacement + second_ses_mc_displacement = mc_recording_list[1]._recording_segments[0].motion.displacement + + # Ensure the motion was generated rigid by the test suite + assert first_ses_mc_displacement[0].size == 1 + assert second_ses_mc_displacement[0].size == 1 + + # Add some shifts to represent an existing motion correction + first_ses_mc_displacement[0] += 0.01 + second_ses_mc_displacement[0] += 0.02 + + non_rigid_window_kwargs = session_alignment.get_non_rigid_window_kwargs() + non_rigid_window_kwargs["rigid"] = True + + corrected_recordings, extra_info = session_alignment.align_sessions_after_motion_correction( + mc_recording_list, + mc_motion_info_list, + align_sessions_kwargs={ + "alignment_order": "to_session_1", + "non_rigid_window_kwargs": non_rigid_window_kwargs, + }, + ) + first_ses_total_displacement = corrected_recordings[0]._recording_segments[0].motion.displacement + second_ses_total_displacement = corrected_recordings[1]._recording_segments[0].motion.displacement + + # Check that the shift is the existing motion correction + the inter-session shift + assert first_ses_total_displacement == [np.array([[shifts[0][1] + 0.01]])] + assert second_ses_total_displacement == [np.array([[shifts[1][1] + 0.02]])] + + self.assert_interpolate_recording_not_duplicate(corrected_recordings[0]) + + def test_rigid_motion_nonrigid_intersession(self, recording_2): + """ + Test that non-rigid shifts estimated in inter-session alignment are properly + added to rigid shifts estimated in motion correction. + """ + recordings_list, _, peaks_list, _ = recording_2 + + mc_recording_list, mc_motion_info_list = self.motion_correct_recordings_list( + recordings_list, + rigid_motion=True, + ) + + first_ses_mc_displacement = mc_recording_list[0]._recording_segments[0].motion.displacement + second_ses_mc_displacement = mc_recording_list[1]._recording_segments[0].motion.displacement + + # Ensure the motion was generated rigid by the test suite + assert first_ses_mc_displacement[0].size == 1 + assert second_ses_mc_displacement[0].size == 1 + + # Add some shifts to represent an existing motion correction + first_ses_mc_displacement[0] += 0.01 + second_ses_mc_displacement[0] += 0.02 + + # All of this is direct copy between these tests... + non_rigid_window_kwargs = session_alignment.get_non_rigid_window_kwargs() + non_rigid_window_kwargs["rigid"] = False + + corrected_recordings, extra_info = session_alignment.align_sessions_after_motion_correction( + mc_recording_list, + mc_motion_info_list, + align_sessions_kwargs={ + "alignment_order": "to_session_1", + "non_rigid_window_kwargs": non_rigid_window_kwargs, + }, + ) + + first_ses_total_displacement = corrected_recordings[0]._recording_segments[0].motion.displacement + second_ses_total_displacement = corrected_recordings[1]._recording_segments[0].motion.displacement + + # The shift themselves are not expected to be correct to align this tricky test case + # (see test_interesting_debug_case) but the shift on the motion objects should + # match the estimateed shifts from inter-session alignment + the motion shifts set above + assert np.all(extra_info["shifts_array"][0] + 0.01 == first_ses_total_displacement) + assert np.all(extra_info["shifts_array"][1] + 0.02 == second_ses_total_displacement) + + self.assert_interpolate_recording_not_duplicate(corrected_recordings[0]) + + @pytest.mark.parametrize("rigid_intersession", [True, False]) + def test_nonrigid_motion(self, rigid_intersession, recording_1, recording_2): + """ + Now test that non-rigid motion estimates are properly combined with the + rigid or non-rigid inter-session alignment estimates. + """ + if rigid_intersession: + recordings_list, _, peaks_list, _ = recording_1 + else: + recordings_list, _, peaks_list, _ = recording_2 + + mc_recording_list, mc_motion_info_list = self.motion_correct_recordings_list( + recordings_list, + rigid_motion=False, + ) + + # Now the motion data has multiple displcements (per probe segment window). + # Add offsets to these different windows. + first_ses_mc_displacement = mc_recording_list[0]._recording_segments[0].motion.displacement + second_ses_mc_displacement = mc_recording_list[1]._recording_segments[0].motion.displacement + + offsets1 = np.linspace(0, 0.1, first_ses_mc_displacement[0].size) + offsets2 = np.linspace(0, 0.1, first_ses_mc_displacement[0].size) + + first_ses_mc_displacement[0] += offsets1 + second_ses_mc_displacement[0] += offsets2 + + # Now run inter-session alignment (either rigid or non-rigid). Check that + # the final displacement on the motion objects is a combination of the + # nonrigid motion estimate and rigid or nonrigid inter-session alignment estimate. + non_rigid_window_kwargs = session_alignment.get_non_rigid_window_kwargs() + non_rigid_window_kwargs["rigid"] = rigid_intersession + + corrected_recordings, extra_info = session_alignment.align_sessions_after_motion_correction( + mc_recording_list, + mc_motion_info_list, + align_sessions_kwargs={ + "alignment_order": "to_session_1", + "non_rigid_window_kwargs": non_rigid_window_kwargs, + }, + ) + + first_ses_total_displacement = corrected_recordings[0]._recording_segments[0].motion.displacement + second_ses_total_displacement = corrected_recordings[1]._recording_segments[0].motion.displacement + + assert np.all(extra_info["shifts_array"][0] + offsets1 == first_ses_total_displacement) + assert np.all(extra_info["shifts_array"][1] + offsets2 == second_ses_total_displacement) + + self.assert_interpolate_recording_not_duplicate(corrected_recordings[0]) + + def assert_interpolate_recording_not_duplicate(self, recording): + """ + Do a quick check that indeed the interpolate recording is not duplicate + (i.e. only one interpolate recording, and the previous is the generated simulation + recording. + """ + assert ( + isinstance(recording, InterpolateMotionRecording) + and recording._parent.name == "InterSessionDisplacementRecording" + ) + + def test_motion_correction_peaks_are_converted(self, mocker, recording_1): + """ + When `align_sessions_after_motion_correction` is run, the peaks locations + used should be those that are already motion corrected, which requires + correcting the peak locations in the function. + + Therefore, check that the final peak locations passed to `align_sessions` + are motion-corrected. + """ + recordings_list, _, peaks_list, peak_locations_list = recording_1 + + # Motion correct recordings, and add a known motion-displacement + mc_recording_list, mc_motion_info_list = self.motion_correct_recordings_list( + recordings_list, + rigid_motion=True, + ) + + first_ses_mc_displacement = mc_recording_list[0]._recording_segments[0].motion.displacement + second_ses_mc_displacement = mc_recording_list[1]._recording_segments[0].motion.displacement + + first_ses_mc_displacement[0] += 0.1 + second_ses_mc_displacement[0] += 0.2 + + # mock the `align_sessions` function to check what was passed + spy_align_sessions = mocker.spy( + spikeinterface.preprocessing.inter_session_alignment.session_alignment, "align_sessions" + ) + + # Call the function, and check that the passed peak-locations are corrected + corrected_recordings, _ = session_alignment.align_sessions_after_motion_correction( + mc_recording_list, mc_motion_info_list, None + ) + + passed_peak_locations_1 = spy_align_sessions.call_args_list[0][0][2][0] + passed_peak_locations_2 = spy_align_sessions.call_args_list[0][0][2][1] + + assert np.allclose( + passed_peak_locations_1["y"], mc_motion_info_list[0]["peak_locations"]["y"] - 0.1, rtol=0, atol=1e-4 + ) + assert np.allclose( + passed_peak_locations_2["y"], mc_motion_info_list[1]["peak_locations"]["y"] - 0.2, rtol=0, atol=1e-4 + ) + + def test_motion_correction_kwargs(self, mocker, recording_1): + """ + For `align_sessions_after_motion_correction`, if the motion-correct is non-rigid + then the non-rigid window kwargs must match for inter-session alignment, + otherwise it will not be possible to add the displacement. + """ + recordings_list, _, _, _ = recording_1 + + mc_recording_list, mc_motion_info_list = self.motion_correct_recordings_list( + recordings_list, + rigid_motion=False, + ) + + spy_align_sessions = mocker.spy( + spikeinterface.preprocessing.inter_session_alignment.session_alignment, "align_sessions" + ) + + # Run `align_sessions_after_motion_correction` with non-rigid window kwargs + # that do not mach those used for motion correction + changed = session_alignment.get_non_rigid_window_kwargs() + changed["rigid"] = False + changed["win_step_um"] = 51 + + session_alignment.align_sessions_after_motion_correction( + mc_recording_list, mc_motion_info_list, {"non_rigid_window_kwargs": changed} + ) + + # Now remove kwargs from the motion-correct and inter-session alignment (passed) non-rigid + # window kwargs that don't match (Some from motion are not relevant for inter-session alignment, + # some for inter-session and not set on motion (they may? predate their introduction). + # Check that these core kwargs match (i.e. align_sessions is using the non-rigid-window + # settings that motion uses. + non_rigid_windows = mc_motion_info_list[0]["parameters"]["estimate_motion_kwargs"] + non_rigid_windows.pop("method") + non_rigid_windows.pop("bin_s") + non_rigid_windows.pop( + "hist_margin_um" + ) # TODO: I think some kwargs are not exposed, this is probably okay, could mention to Sam + + passed_non_rigid_windows = spy_align_sessions.call_args_list[0][1]["non_rigid_window_kwargs"] + passed_non_rigid_windows.pop("zero_threshold") + passed_non_rigid_windows.pop("win_margin_um") + + assert sorted(passed_non_rigid_windows) == sorted(non_rigid_windows) + + ########################################################################### + # Unit Tests + ########################################################################### + + def test_shift_array_fill_zeros(self): + """ + The tested function shifts a 1d array or 2d array (along a certain axis) + and fills space with zero. Check that arrays are shifted as expected. + """ + # Test 1d array + test_1d = np.random.random((10)) + + # shift leftwards + res = alignment_utils.shift_array_fill_zeros(test_1d, 2) + assert np.all(res[8:] == 0) + assert np.array_equal(res[:8], test_1d[2:]) + + # shift rightwards + res = alignment_utils.shift_array_fill_zeros(test_1d, -2) + assert np.all(res[:2] == 0) + assert np.array_equal(res[2:], test_1d[:8]) + + # Test 2d array + test_2d = np.random.random((10, 10)) + + # shift upwards + res = alignment_utils.shift_array_fill_zeros(test_2d, 2) + assert np.all(res[8:, :] == 0) + assert np.array_equal(res[:8, :], test_2d[2:, :]) + + # shift downwards. + res = alignment_utils.shift_array_fill_zeros(test_2d, -2) + assert np.all(res[:2, :] == 0) + assert np.array_equal(res[2:, :], test_2d[:8, :]) + + def test_get_shifts_from_session_matrix(self): + """ + Given a 'session matrix' of shifts (a matrix Mij where each element + is the shift to get from session i to session j). It is skew-symmetric. + """ + matrix = np.random.random((10, 10, 2)) + + res = alignment_utils.get_shifts_from_session_matrix("to_middle", matrix) + assert np.array_equal(res, -np.mean(matrix, axis=0)) + + res = alignment_utils.get_shifts_from_session_matrix("to_session_1", matrix) + assert np.array_equal(res, -matrix[0, :, :]) + + res = alignment_utils.get_shifts_from_session_matrix("to_session_5", matrix) + assert np.array_equal(res, -matrix[4, :, :]) + + res = alignment_utils.get_shifts_from_session_matrix("to_session_10", matrix) + assert np.array_equal(res, -matrix[9, :, :]) + + @pytest.mark.parametrize("interpolate", [True, False]) + @pytest.mark.parametrize("odd_hist_size", [True, False]) + @pytest.mark.parametrize("shifts", [3, -2]) + def test_compute_histogram_crosscorrelation_rigid(self, interpolate, odd_hist_size, shifts): + """ + Create some toy array and shift it, then check that the cross-correlattion + correctly finds the shifts, under a number of conditions. + """ + if odd_hist_size: + hist = np.array([1, 0, 1, 1, 1, 0]) + else: + hist = np.array([0, 0, 1, 1, 0, 1, 0, 1]) + + hist_shift = alignment_utils.shift_array_fill_zeros(hist, shifts) + + session_histogram_list = np.vstack([hist, hist_shift]) + + interp_factor = 50 # not used when interpolate = False + shifts_matrix, xcorr_matrix_unsmoothed = alignment_utils.compute_histogram_crosscorrelation( + session_histogram_list, + non_rigid_windows=np.ones((1, hist.size)), + num_shifts=None, + interpolate=interpolate, + interp_factor=interp_factor, + kriging_sigma=0.2, + kriging_p=2, + kriging_d=2, + smoothing_sigma_bin=None, + smoothing_sigma_window=None, + min_crosscorr_threshold=0.001, + ) + assert np.isclose( + alignment_utils.get_shifts_from_session_matrix("to_session_1", shifts_matrix)[-1], + -shifts, + rtol=0, + atol=0.01, + ) + + num_shifts = hist.size * 2 - 1 + if interpolate: + assert xcorr_matrix_unsmoothed.shape[1] == num_shifts * interp_factor + else: + assert xcorr_matrix_unsmoothed.shape[1] == num_shifts + + @pytest.mark.parametrize("histogram_mode", ["1d", "2d"]) + def test_compute_histogram_crosscorrelation_nonrigid(self, histogram_mode): + """ """ + # fmt: off + # Window 1 | Window 2 | Window 3 + hist_1 = np.array([0.5, 1, 0, 0, 0, 1e-3, 0, 0, 0, 0, 0, 1]) + hist_2 = np.array([0, 0, 0.5, 1, 1e-12, 0, 0, 0, 1, 0, 0, 0]) + + if histogram_mode == "2d": + hist_1 = np.vstack([hist_1, hist_1 * 2]).T + hist_2 = np.vstack([hist_2, hist_2 * 2]).T + + winds = np.array([[ 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]]) + + shifts_matrix, xcorr_matrix_unsmoothed = alignment_utils.compute_histogram_crosscorrelation( + np.stack([hist_1, hist_2], axis=0), + non_rigid_windows=winds, + num_shifts=None, + interpolate=None, + interp_factor=1.0, + kriging_sigma=0.2, + kriging_p=2, + kriging_d=2, + smoothing_sigma_bin=None, + smoothing_sigma_window=None, + min_crosscorr_threshold=0.001, + ) + # fmt: on + + wind_1_shift = shifts_matrix[1, 0][0] + wind_2_shift = shifts_matrix[1, 0][1] + wind_3_shift = shifts_matrix[1, 0][2] + + assert wind_1_shift == 2 # the first window is shifted +2 + assert wind_2_shift == 0 # the second window has poor correlation under 0.001, so set to 0 + assert wind_3_shift == -3 # the third window is shifted -3 + + ########################################################################### + # Kwargs Tests + ########################################################################### + + def test_get_estimate_histogram_kwargs(self, mocker, recording_1): + + recordings_list, _, peaks_list, peak_locations_list = recording_1 + + default_kwargs = { + "bin_um": 2, + "method": "chunked_mean", + "chunked_bin_size_s": "estimate", + "log_transform": True, + "depth_smooth_um": None, + "histogram_type": "1d", + "weight_with_amplitude": False, + "avg_in_bin": False, + } + + assert ( + default_kwargs == session_alignment.get_estimate_histogram_kwargs() + ), "Default `get_estimate_histogram_kwargs` were changed." + + different_kwargs = session_alignment.get_estimate_histogram_kwargs() + different_kwargs.update( + { + "chunked_bin_size_s": 6, + "depth_smooth_um": 5, + "weight_with_amplitude": True, + "avg_in_bin": True, + } + ) + + spy_2d_histogram = mocker.spy( + spikeinterface.preprocessing.inter_session_alignment.alignment_utils, "make_2d_motion_histogram" + ) + session_alignment.align_sessions( + [recordings_list[0]], [peaks_list[0]], [peak_locations_list[0]], estimate_histogram_kwargs=different_kwargs + ) + first_call = spy_2d_histogram.call_args_list[0] + args, kwargs = first_call + + assert kwargs["bin_s"] == different_kwargs["chunked_bin_size_s"] + assert kwargs["bin_um"] is None + assert np.unique(np.diff(kwargs["spatial_bin_edges"])) == different_kwargs["bin_um"] + assert kwargs["depth_smooth_um"] == different_kwargs["depth_smooth_um"] + assert kwargs["weight_with_amplitude"] == different_kwargs["weight_with_amplitude"] + assert kwargs["avg_in_bin"] == different_kwargs["avg_in_bin"] + + def test_compute_alignment_kwargs(self, mocker, recording_1): + + recordings_list, _, peaks_list, peak_locations_list = recording_1 + + default_kwargs = { + "num_shifts_global": None, + "num_shifts_block": 20, + "interpolate": False, + "interp_factor": 10, + "kriging_sigma": 1, + "kriging_p": 2, + "kriging_d": 2, + "smoothing_sigma_bin": 0.5, + "smoothing_sigma_window": 0.5, + "akima_interp_nonrigid": False, + "min_crosscorr_threshold": 0.001, + } + assert ( + session_alignment.get_compute_alignment_kwargs() == default_kwargs + ), "Default `get_compute_alignment_kwargs` were changed." + + different_kwargs = session_alignment.get_compute_alignment_kwargs() + different_kwargs.update( + { + "interpolate": True, + "kriging_sigma": 5, + "kriging_p": 10, + "kriging_d": 20, + "smoothing_sigma_bin": 1.2, + "smoothing_sigma_window": 1.3, + } + ) + import scipy + + spy_kriging = mocker.spy(spikeinterface.preprocessing.inter_session_alignment.alignment_utils, "kriging_kernel") + spy_gaussian_filter = mocker.spy(scipy.ndimage, "gaussian_filter") + non_rigid_window_kwargs = session_alignment.get_non_rigid_window_kwargs() + non_rigid_window_kwargs["rigid"] = False + + session_alignment.align_sessions( + [recordings_list[0]], + [peaks_list[0]], + [peak_locations_list[0]], + compute_alignment_kwargs=different_kwargs, + non_rigid_window_kwargs=non_rigid_window_kwargs, + ) + kwargs = spy_kriging.call_args_list[0][1] + assert kwargs["sigma"] == different_kwargs["kriging_sigma"] + assert kwargs["p"] == different_kwargs["kriging_p"] + assert kwargs["d"] == different_kwargs["kriging_d"] + + # First call is overall rigid, then nonrigid smooth bin + kwargs = spy_gaussian_filter.call_args_list[0][1] + assert kwargs["sigma"] == different_kwargs["smoothing_sigma_bin"] + + # then nonrigid smooth window + kwargs = spy_gaussian_filter.call_args_list[2][1] + assert kwargs["sigma"] == different_kwargs["smoothing_sigma_window"] + + def test_non_rigid_window_kwargs(self, mocker, recording_1): + + import spikeinterface # TODO: place up top with a note + + default_kwargs = { + "rigid": True, + "win_shape": "gaussian", + "win_step_um": 50, + "win_scale_um": 50, + "win_margin_um": None, + "zero_threshold": None, + } + assert ( + session_alignment.get_non_rigid_window_kwargs() == default_kwargs + ), "Default `get_non_rigid_window_kwargs` were changed." + + different_kwargs = { + "rigid": False, + "win_shape": "rect", + "win_step_um": 55, + "win_scale_um": 65, + "win_margin_um": 10, + "zero_threshold": 4, + } + + recordings_list, _, peaks_list, peak_locations_list = recording_1 + + spy_get_spatial_windows = mocker.spy( + spikeinterface.preprocessing.inter_session_alignment.session_alignment, "get_spatial_windows" + ) + session_alignment.align_sessions( + [recordings_list[0]], [peaks_list[0]], [peak_locations_list[0]], non_rigid_window_kwargs=different_kwargs + ) + kwargs = spy_get_spatial_windows.call_args_list[0][1] + assert kwargs["rigid"] == different_kwargs["rigid"] + assert kwargs["win_shape"] == different_kwargs["win_shape"] + assert kwargs["win_step_um"] == different_kwargs["win_step_um"] + assert kwargs["win_scale_um"] == different_kwargs["win_scale_um"] + assert kwargs["win_margin_um"] == different_kwargs["win_margin_um"] + assert kwargs["zero_threshold"] == different_kwargs["zero_threshold"] + + def test_interpolate_motion_kwargs(self, mocker, recording_1): + """ """ + + default_kwargs = { + "border_mode": "force_zeros", + "spatial_interpolation_method": "kriging", + "sigma_um": 20.0, + "p": 2, + } + assert ( + session_alignment.get_interpolate_motion_kwargs() == default_kwargs + ), "Default `get_non_rigid_window_kwargs` were changed." + + different_kwargs = { + "border_mode": "force_zeros", + "spatial_interpolation_method": "nearest", + "sigma_um": 25.0, + "p": 3, + } + + recordings_list, _, peaks_list, peak_locations_list = recording_1 + + spy_get_2d_activity_histogram = mocker.spy( + spikeinterface.preprocessing.inter_session_alignment.session_alignment.InterpolateMotionRecording, + "__init__", + ) + session_alignment.align_sessions( + [recordings_list[0]], [peaks_list[0]], [peak_locations_list[0]], interpolate_motion_kwargs=different_kwargs + ) + first_call = spy_get_2d_activity_histogram.call_args_list[0] + args, kwargs = first_call + + assert kwargs["border_mode"] == different_kwargs["border_mode"] + assert kwargs["spatial_interpolation_method"] == different_kwargs["spatial_interpolation_method"] + assert kwargs["sigma_um"] == different_kwargs["sigma_um"] + assert kwargs["p"] == different_kwargs["p"] + + @pytest.mark.parametrize("histogram_type", ["1d", "2d"]) + def test_interesting_debug_case(self, histogram_type, recording_2): + """ + This is an interseting debug case that is included in the tests to act as + both a regression test and highlight how the alignment works and can lead + to imperfect results in the test setting. + + In this case we take a non-rigid alignment, and we see that the right-edge + of the histogram is aligned well, but the middle area (which lies within + the same nonrigid window) is not well aligned. This is in spite of the + shifts being estimated correctly for that segment. + + The problem is that the nonrigid bins are interpolated to get the shifts + for each channel. In this case, the histogram peak being aligned + lies in between two nonrigid window middle points, and so is interpolated. + It is in the window with shift ~144 but next to a window with ~190 + and so ends up around ~170, resulting in the incorrect shift for this segment. + But, in the real world it is necessary to interpolate channels like this + or shifting whole windows would end up with very strange results! + """ + recording_list, _, _, _ = recording_2 + + mc_recording_list, mc_motion_info_list = self.motion_correct_recordings_list(recording_list, rigid_motion=True) + + # Run alignment + non_rigid_window_kwargs = session_alignment.get_non_rigid_window_kwargs() + non_rigid_window_kwargs["rigid"] = False + non_rigid_window_kwargs["win_shape"] = "rect" + non_rigid_window_kwargs["win_step_um"] = 250 + non_rigid_window_kwargs["win_scale_um"] = 250 + + compute_alignment_kwargs = session_alignment.get_compute_alignment_kwargs() + compute_alignment_kwargs["num_shifts_block"] = 75 + compute_alignment_kwargs["smoothing_sigma_bin"] = None + compute_alignment_kwargs["smoothing_sigma_window"] = None + + estimate_histogram_kwargs = session_alignment.get_estimate_histogram_kwargs() + estimate_histogram_kwargs["histogram_type"] = histogram_type + + corrected_recordings, extra_info = session_alignment.align_sessions_after_motion_correction( + mc_recording_list, + mc_motion_info_list, + align_sessions_kwargs={ + "alignment_order": "to_session_1", # to_center + "non_rigid_window_kwargs": non_rigid_window_kwargs, + "compute_alignment_kwargs": compute_alignment_kwargs, + "estimate_histogram_kwargs": estimate_histogram_kwargs, + }, + ) + + if DEBUG: + from spikeinterface.widgets import plot_session_alignment, plot_activity_histogram_2d + import matplotlib.pyplot as plt + + # Plot the results, as well as the shift non-rigid window centers and + # the shifts amount. You can see where the shift is reduced to align + # the peaks in the middle of the histogram, the orange peak is + # not sufficiently moved (because of the interpolation). + peaks_list = [info["peaks"] for info in mc_motion_info_list] + peak_locations_list = [info["peak_locations"] for info in mc_motion_info_list] + plot = plot_session_alignment( + mc_recording_list, + peaks_list, + peak_locations_list, + extra_info["session_histogram_list"], + **extra_info["corrected"], + spatial_bin_centers=extra_info["spatial_bin_centers"], + drift_raster_map_kwargs={"clim": (-250, 0), "scatter_decimate": 10}, + ) + + window_edges = np.r_[0, np.cumsum(np.diff(extra_info["non_rigid_window_centers"]))] + window_edges[:-1] += np.diff(window_edges) / 2 + + y_window = extra_info["shifts_array"][1] + x_bin = extra_info["non_rigid_window_centers"] + + ax4_twin = plot.figure.axes[4].twinx() + ax5_twin = plot.figure.axes[5].twinx() + ax4_twin.scatter(x_bin, y_window, color="red", s=100, edgecolor="black", zorder=3, label="Points") + ax5_twin.scatter(x_bin, y_window, color="red", s=100, edgecolor="black", zorder=3, label="Points") + + for x_n, y_n in zip(x_bin, y_window): + ax4_twin.plot([x_n, x_n], [0, y_n], color="black", linestyle="--", linewidth=1.5, zorder=2) + ax5_twin.plot([x_n, x_n], [0, y_n], color="black", linestyle="--", linewidth=1.5, zorder=2) + + plt.suptitle("test_interesting_debug_case") + plt.show() diff --git a/src/spikeinterface/sortingcomponents/motion/decentralized.py b/src/spikeinterface/sortingcomponents/motion/decentralized.py index 956f23efba..5d0b5a0ad0 100644 --- a/src/spikeinterface/sortingcomponents/motion/decentralized.py +++ b/src/spikeinterface/sortingcomponents/motion/decentralized.py @@ -3,7 +3,15 @@ from tqdm.auto import tqdm, trange from spikeinterface.core.motion import Motion -from .motion_utils import get_spatial_windows, get_spatial_bin_edges, make_2d_motion_histogram, scipy_conv1d + +from .motion_utils import ( + get_spatial_windows, + get_spatial_bin_edges, + make_2d_motion_histogram, + scipy_conv1d, + get_spatial_bins, +) + from .dredge import normxcorr1d @@ -135,13 +143,9 @@ def run( lsqr_robust_n_iter=20, weight_with_amplitude=False, ): - - dim = ["x", "y", "z"].index(direction) - contact_depths = recording.get_channel_locations()[:, dim] - - # spatial histogram bins - spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um) - spatial_bin_centers = 0.5 * (spatial_bin_edges[1:] + spatial_bin_edges[:-1]) + spatial_bin_centers, spatial_bin_edges, contact_depths = get_spatial_bins( + recording, direction, hist_margin_um, bin_um + ) # get spatial windows non_rigid_windows, non_rigid_window_centers = get_spatial_windows( diff --git a/src/spikeinterface/sortingcomponents/motion/iterative_template.py b/src/spikeinterface/sortingcomponents/motion/iterative_template.py index 7bb067b5bd..905ba7cde1 100644 --- a/src/spikeinterface/sortingcomponents/motion/iterative_template.py +++ b/src/spikeinterface/sortingcomponents/motion/iterative_template.py @@ -289,6 +289,8 @@ def iterative_template_registration( return optimal_shift_indices, target_spikecount_hist, shift_covs_block +# TODO: this is duplicate of get_kriging_kernel_distance() but that +# doesnt expose d parameter, could combine? def kriging_kernel(source_location, target_location, sigma=1, p=2, d=2): from scipy.spatial.distance import cdist diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 2a64f9f7ea..5663c73ff9 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -354,7 +354,6 @@ def __init__( **spatial_interpolation_kwargs, ): # assert recording.get_num_segments() == 1, "correct_motion() is only available for single-segment recordings" - channel_locations = recording.get_channel_locations() assert channel_locations.ndim >= motion.dim, ( f"'direction' {motion.direction} not available. " diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py index c72ff3de58..965e411d63 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -168,6 +168,18 @@ def get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um): return spatial_bins +def get_spatial_bins(recording, direction, hist_margin_um, bin_um): + # TODO: could this be merged with the above function? + dim = ["x", "y", "z"].index(direction) + contact_depths = recording.get_channel_locations()[:, dim] + + # spatial histogram bins + spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um) + spatial_bin_centers = 0.5 * (spatial_bin_edges[1:] + spatial_bin_edges[:-1]) + + return spatial_bin_centers, spatial_bin_edges, contact_depths + + def make_2d_motion_histogram( recording, peaks, diff --git a/src/spikeinterface/widgets/inter_session_alignment.py b/src/spikeinterface/widgets/inter_session_alignment.py new file mode 100644 index 0000000000..89dbd16d01 --- /dev/null +++ b/src/spikeinterface/widgets/inter_session_alignment.py @@ -0,0 +1,431 @@ +import itertools + +from spikeinterface.core import BaseRecording +import numpy as np +from spikeinterface.widgets.base import BaseWidget +from spikeinterface.widgets.base import to_attr +from spikeinterface.widgets.motion import DriftRasterMapWidget + + +class SessionAlignmentWidget(BaseWidget): + """ + Widget to display the output of inter-session alignment. + In the top section, `DriftRasterMapWidget`s are used to display + the raster maps for each session, before and after alignment. + The order of all lists should correspond to the same recording. + + If histograms are provided, `ActivityHistogram1DWidget` + are used to show the activity histograms, before and after alignment. + See `align_sessions` for context. + + Corrected and uncorrected activity histograms are generated + as part of the `align_sessions` step. + + Parameters + ---------- + + recordings_list : list[BaseRecording] + List of recordings to plot. + peaks_list : list[np.ndarray] + List of detected peaks for each session. + peak_locations_list : list[np.ndarray] + List of detected peak locations for each session. + session_histogram_list : np.ndarray | None + A list of activity histograms as output from `align_sessions`. + If `None`, no histograms will be displayed. + spatial_bin_centers=None : np.ndarray | None + Spatial bin centers for the histogram (each session activity + histogram will have the same spatial bin centers). + corrected_peak_locations_list : list[np.ndarray] | None + A list of corrected peak locations. If provided, the corrected + raster plots will be displayed. + corrected_session_histogram_list : list[np.ndarray] + A list of corrected session activity histograms, as + output from `align_sessions`. + drift_raster_map_kwargs : dict | None + Kwargs to be passed to `DriftRasterMapWidget`. + session_alignment_histogram_kwargs : dict | None + Kwargs to be passed to `ActivityHistogram1DWidget`. + **backend_kwargs + """ + + def __init__( + self, + recordings_list: list[BaseRecording], + peaks_list: list[np.ndarray], + peak_locations_list: list[np.ndarray], + session_histogram_list: list[np.ndarray] | None = None, + spatial_bin_centers: np.ndarray | None = None, + corrected_peak_locations_list: list[np.ndarray] | None = None, + corrected_session_histogram_list: list[np.ndarray] = None, + drift_raster_map_kwargs: dict | None = None, + session_alignment_histogram_kwargs: dict | None = None, + **backend_kwargs, + ): + assert len(recordings_list) <= 8, ( + "At present, this widget supports plotting up to 8 sessions. " + "Please contact SpikeInterface to discuss increasing." + ) + if corrected_session_histogram_list is not None: + if not len(corrected_session_histogram_list) == len(session_histogram_list): + raise ValueError( + "`corrected_session_histogram_list` must be the same length as `session_histogram_list`. " + "Entries should correspond exactly, with the histogram in each position being the corrected" + "version of `session_histogram_list`." + ) + if corrected_peak_locations_list is not None: + if not len(corrected_peak_locations_list) == len(peak_locations_list): + raise ValueError( + "`corrected_peak_locations_list` must be the same length as `peak_locations_list`. " + "Entries should correspond exactly, with the histogram in each position being the corrected" + "version of `peak_locations_list`." + ) + if (corrected_peak_locations_list is None) != (corrected_session_histogram_list is None): + raise ValueError( + "If either `corrected_peak_locations_list` or `corrected_session_histogram_list` " + "is passed, they must both be passed." + ) + + if drift_raster_map_kwargs is None: + drift_raster_map_kwargs = {} + + if session_alignment_histogram_kwargs is None: + session_alignment_histogram_kwargs = {} + + plot_data = dict( + recordings_list=recordings_list, + peaks_list=peaks_list, + peak_locations_list=peak_locations_list, + session_histogram_list=session_histogram_list, + spatial_bin_centers=spatial_bin_centers, + corrected_peak_locations_list=corrected_peak_locations_list, + corrected_session_histogram_list=corrected_session_histogram_list, + drift_raster_map_kwargs=drift_raster_map_kwargs, + session_alignment_histogram_kwargs=session_alignment_histogram_kwargs, + ) + + BaseWidget.__init__(self, plot_data, backend="matplotlib", **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + """ + Create the `SessionAlignmentWidget` for matplotlib. + """ + from spikeinterface.widgets.utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + assert backend_kwargs["axes"] is None, "axes argument is not allowed in SessionAlignmentWidget" + assert backend_kwargs["ax"] is None, "ax argument is not allowed in SessionAlignmentWidget" + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + fig = self.figure + fig.clear() + + # Find the min and max y peak position across all sessions so the + # axis can be set the same across all sessions + min_y = np.min(np.hstack([locs["y"] for locs in dp.peak_locations_list])) + max_y = np.max(np.hstack([locs["y"] for locs in dp.peak_locations_list])) + + # First, plot the peak location raster plots + + if dp.corrected_peak_locations_list is None: + # In this case, we only have uncorrected peak locations. We plot only the + # uncorrected raster maps. If there are more than 4 sessions, move + # onto the second row (usually reserved for the corrected peak raster). + + num_cols = np.min([4, len(dp.peak_locations_list)]) + num_rows = 1 if num_cols <= 4 else 2 + + ordered_row_col = list(itertools.product(range(num_rows), range(num_cols))) + + gs = fig.add_gridspec(num_rows + 1, num_cols, wspace=0.3, hspace=0.5) + + for i, row_col in enumerate(ordered_row_col): + + ax = fig.add_subplot(gs[row_col]) + + DriftRasterMapWidget( + dp.peaks_list[i], + dp.peak_locations_list[i], + recording=dp.recordings_list[i], + ax=ax, + **dp.drift_raster_map_kwargs, + ) + ax.set_ylim((min_y, max_y)) + else: + # In this case, we have corrected and unncorrected peak locations to + # plot in the raster. Uncorrected are on the first row and corrected are + # on the second. Each session is a new column. + + # Own function, then see if can compare + num_cols = len(dp.peak_locations_list) + num_rows = 2 + + gs = fig.add_gridspec(num_rows + 1, num_cols, wspace=0.3, hspace=0.5) + + for i in range(num_cols): + + ax_top = fig.add_subplot(gs[0, i]) + ax_bottom = fig.add_subplot(gs[1, i]) + + # Uncorrected session (row 1) + DriftRasterMapWidget( + dp.peaks_list[i], + dp.peak_locations_list[i], + recording=dp.recordings_list[i], + ax=ax_top, + **dp.drift_raster_map_kwargs, + ) + ax_top.set_title(f"Session {i + 1}") + ax_top.set_xlabel(None) + ax_top.set_ylim((min_y, max_y)) + + # Corrected session (row 2) + DriftRasterMapWidget( + dp.peaks_list[i], + dp.corrected_peak_locations_list[i], + recording=dp.recordings_list[i], + ax=ax_bottom, + **dp.drift_raster_map_kwargs, + ) + ax_bottom.set_title(f"Corrected Session {i + 1}") + ax_bottom.set_ylim((min_y, max_y)) + + # Next, plot the activity histograms under the raster plots + # If we only have uncorrected, plot taking up two columns. + # Otherwise, uncorrected histogram on the left column and + # corrected histgoram on the right column + if dp.session_histogram_list: + num_sessions = len(dp.session_histogram_list) + + if "legend" not in dp.session_alignment_histogram_kwargs: + sessions = [f"session {i + 1}" for i in range(num_sessions)] + dp.session_alignment_histogram_kwargs["legend"] = sessions + + if not dp.corrected_session_histogram_list: + + ax = fig.add_subplot(gs[num_rows, :]) + + ActivityHistogram1DWidget( + dp.session_histogram_list, + dp.spatial_bin_centers, + ax=ax, + **dp.session_alignment_histogram_kwargs, + ) + ax.legend(loc="upper left") + else: + + gs_sub = gs[num_rows, :].subgridspec(1, 2) + + ax_left = fig.add_subplot(gs_sub[0]) + ax_right = fig.add_subplot(gs_sub[1]) + + ActivityHistogram1DWidget( + dp.session_histogram_list, + dp.spatial_bin_centers, + ax=ax_left, + **dp.session_alignment_histogram_kwargs, + ) + ActivityHistogram1DWidget( + dp.corrected_session_histogram_list, + dp.spatial_bin_centers, + ax=ax_right, + **dp.session_alignment_histogram_kwargs, + ) + ax_left.get_legend().set_loc("upper right") + ax_left.set_title("Original Histogram") + ax_right.get_legend().set_loc("upper right") + ax_right.set_title("Corrected Histogram") + + +class ActivityHistogram1DWidget(BaseWidget): + """ + Plot 1D session activity histograms, overlaid on the same plot. + See SessionAlignmentWidget for detail. + + Parameters + ---------- + + session_histogram_list: list[np.ndarray] + List of 1D activity histograms to plot + spatial_bin_centers: list[np.ndarray] | np.ndarray | None + x-axis tick labels (bin centers of the histogram) + legend: None | list[str] = None + List of str to set as plot legend + linewidths: None | float | list[float] = 2, + Linewidths (list of linewidth for different across histograms, + otherwise `None` or specify shared linewidth with `float`. + colors: None | str | list[str] = None, + Colors to set the activity histograms. `None` uses matplotlib defautl colors. + """ + + def __init__( + self, + session_histogram_list: list[np.ndarray], + spatial_bin_centers: list[np.ndarray] | np.ndarray | None, + legend: None | list[str] = None, + linewidths: None | float | list[float] = 2.0, + colors: None | list = None, + **backend_kwargs, + ): + + plot_data = dict( + session_histogram_list=session_histogram_list, + spatial_bin_centers=spatial_bin_centers, + legend=legend, + linewidths=linewidths, + colors=colors, + ) + + BaseWidget.__init__(self, plot_data, backend="matplotlib", **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + from spikeinterface.widgets.utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + legend = dp.legend + colors = dp.colors + linewidths = dp.linewidths + spatial_bin_centers = dp.spatial_bin_centers + + assert backend_kwargs["axes"] is None, "`axes` argument not supported. Use `ax` to pass an axis to set." + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + num_histograms = len(dp.session_histogram_list) + + # If passed parameters are not unique across plots, then + # create as lists to set them for all plots + if isinstance(colors, int) or colors is None: + colors = [colors] * num_histograms + + if not isinstance(linewidths, (list, tuple)): + linewidths = [linewidths] * num_histograms + + spatial_bin_centers = [spatial_bin_centers] * num_histograms + + # If 2D, average across amplitude axis + if dp.session_histogram_list[0].ndim == 2: + histogram_list = [np.sum(hist_, axis=1) for hist_ in dp.session_histogram_list] + print( + "2D histogram passed, will be summed across first (i.e. amplitude) axis.\n" + "Use ActivityHistogram1DWidget to plot the 2D histograms directly." + ) + else: + histogram_list = dp.session_histogram_list + + # Plot the activity histograms + for i in range(num_histograms): + self.ax.plot(spatial_bin_centers[i], histogram_list[i], color=colors[i], linewidth=linewidths[i]) + + if legend is not None: + self.ax.legend(legend) + + self.ax.set_xlabel("Spatial bins (um)") + self.ax.set_ylabel("Activity (p.d.u)") + + +class ActivityHistogram2DWidget(BaseWidget): + """ + Plot 2D (spatial bin, amplitude bin) histograms following inter-session alignment. + The first column is uncorrected histograms, the second (if passed) is the corrected histogram. + + Parameters + ---------- + session_histogram_list : list[np.ndarray] + List of 2D activity histograms (one per sesson) + spatial_bin_centers : np.ndarray + Array of spatial bin centers (shared between all histograms) + corrected_session_histogram_list : None | list[np.ndarray] + A list of 2D corrected activity histograms (one per session, order + corresponding to `session_histogram_list`. + """ + + def __init__( + self, + session_histogram_list: list[np.ndarray], + spatial_bin_centers: np.ndarray, + corrected_session_histogram_list: None | list[np.ndarray] = None, + **backend_kwargs, + ): + if corrected_session_histogram_list: + if not (len(corrected_session_histogram_list) == len(session_histogram_list)): + raise ValueError( + "`corrected_session_histogram_list` must be the same" + "length as `session_histogram_list`, containing a " + "corrected histogram corresponding to each entry." + ) + + plot_data = dict( + session_histogram_list=session_histogram_list, + spatial_bin_centers=spatial_bin_centers, + corrected_session_histogram_list=corrected_session_histogram_list, + ) + + BaseWidget.__init__(self, plot_data, backend="matplotlib", **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + """ + Create the `SessionAlignmentWidget` for matplotlib. + """ + from spikeinterface.widgets.utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + assert backend_kwargs["axes"] is None, "axes argument is not allowed in ActivityHistogram1DWidget" + assert backend_kwargs["ax"] is None, "ax argument is not allowed in ActivityHistogram1DWidget" + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + fig = self.figure + fig.clear() + + num_sessions = len(dp.session_histogram_list) + has_corrected = dp.corrected_session_histogram_list is not None + + num_cols = 2 if has_corrected else 1 + gs = fig.add_gridspec(num_sessions, num_cols, wspace=0.3, hspace=0.5) + + # Show 8 (arbitrary numbers) ticks on the spatial bin axis + bin_centers = dp.spatial_bin_centers + divisor = int(bin_centers.size // 8) + xlabels = bin_centers[::divisor] + + for idx in range(num_sessions): + + # Plot uncorrected 2d histograms in the first column + ax = fig.add_subplot(gs[idx, 0]) + + num_bins = dp.session_histogram_list[idx].shape[0] + ax.imshow(dp.session_histogram_list[idx].T, aspect="auto") + + ax.set_title(f"Session {idx + 1}") + + self._set_plot_tick_labels(idx, num_sessions, ax, num_bins, xlabels, col=0) + + # If passed, plot corrected 2d histograms in the second column + if has_corrected: + ax = fig.add_subplot(gs[idx, 1]) + + ax.imshow(dp.corrected_session_histogram_list[idx].T, aspect="auto") + ax.set_title(f"Corrected Session {idx + 1}") + + self._set_plot_tick_labels(idx, num_sessions, ax, num_bins, xlabels, col=1) + + def _set_plot_tick_labels(self, idx, num_sessions, ax, num_bins, xlabels, col): + """ + Setup the plot labels. Only the bottom plots should show the x-axis (spatial) + bin ticks. On the left plots should show the y-axis (amplitude) bin label. + The amplitude bins are not specified in units (just bin number: TODO: Q: is this confusing?) + """ + if col == 0: + ax.set_ylabel("Amplitude Bin") + + if idx == num_sessions - 1: + # Set the x-ticks on the bottom plot only + ax.set_xticks(np.linspace(0, num_bins - 1, xlabels.size)) + ax.set_xticklabels([f"{i}" for i in xlabels], rotation=45) + ax.set_xlabel("Spatial Bins (µm)") + else: + ax.set_xticks([]) + ax.set_xticklabels([]) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index a0c7e1e28c..9979a35696 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -203,6 +203,7 @@ def __init__( peak_amplitudes = peak_amplitudes[peak_mask] from matplotlib.pyplot import colormaps + from matplotlib.colors import Normalize if color_amplitude: amps = peak_amplitudes @@ -214,7 +215,7 @@ def __init__( amps /= q_95 c = cmap(amps) else: - norm_function = Normalize(vmin=dp.clim[0], vmax=dp.clim[1], clip=True) + norm_function = Normalize(vmin=clim[0], vmax=clim[1], clip=True) c = cmap(norm_function(amps)) color_kwargs = dict( color=None, @@ -325,7 +326,6 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): from .utils_matplotlib import make_mpl_figure - from spikeinterface.sortingcomponents.motion import correct_motion_on_peaks dp = to_attr(data_plot) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 8590aab948..6166f0036a 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -9,6 +9,7 @@ from .amplitudes import AmplitudesWidget from .autocorrelograms import AutoCorrelogramsWidget from .crosscorrelograms import CrossCorrelogramsWidget +from .inter_session_alignment import SessionAlignmentWidget, ActivityHistogram1DWidget, ActivityHistogram2DWidget from .isi_distribution import ISIDistributionWidget from .motion import DriftRasterMapWidget, MotionWidget, MotionInfoWidget from .multicomparison import MultiCompGraphWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget @@ -38,6 +39,8 @@ from .collision import ComparisonCollisionBySimilarityWidget, StudyComparisonCollisionBySimilarityWidget widget_list = [ + ActivityHistogram1DWidget, + ActivityHistogram2DWidget, AgreementMatrixWidget, AllAmplitudesDistributionsWidget, AmplitudesWidget, @@ -59,6 +62,7 @@ ProbeMapWidget, QualityMetricsWidget, RasterWidget, + SessionAlignmentWidget, SortingSummaryWidget, SpikeLocationsWidget, SpikesOnTracesWidget, @@ -112,6 +116,8 @@ # make function for all widgets +plot_activity_histogram_1d = ActivityHistogram1DWidget +plot_activity_histogram_2d = ActivityHistogram2DWidget plot_agreement_matrix = AgreementMatrixWidget plot_all_amplitudes_distributions = AllAmplitudesDistributionsWidget plot_amplitudes = AmplitudesWidget @@ -133,6 +139,7 @@ plot_probe_map = ProbeMapWidget plot_quality_metrics = QualityMetricsWidget plot_rasters = RasterWidget +plot_session_alignment = SessionAlignmentWidget plot_sorting_summary = SortingSummaryWidget plot_spike_locations = SpikeLocationsWidget plot_spikes_on_traces = SpikesOnTracesWidget