Skip to content

Commit 05a460a

Browse files
committed
Add shift_units_outside_probe.
1 parent c471c21 commit 05a460a

File tree

2 files changed

+159
-7
lines changed

2 files changed

+159
-7
lines changed

src/spikeinterface/generation/session_displacement_generator.py

Lines changed: 117 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def generate_session_displacement_recordings(
2222
recording_shifts=((0, 0), (0, 25), (0, 50)),
2323
non_rigid_gradient=None,
2424
recording_amplitude_scalings=None,
25+
shift_units_outside_probe=False,
2526
sampling_frequency=30000.0,
2627
probe_name="Neuropixel-128",
2728
generate_probe_kwargs=None,
@@ -87,8 +88,16 @@ def generate_session_displacement_recordings(
8788
"scalings" - a list of numpy arrays, one for each recording, with
8889
each entry an array of length num_units holding the unit scalings.
8990
e.g. for 3 recordings, 2 units: ((1, 1), (1, 1), (0.5, 0.5)).
91+
shift_units_outside_probe : bool
92+
By default (`False`), when units are shifted across sessions, new units are
93+
not introduced into the recording (e.g. the region in which units
94+
have been shifted out of is left at baseline level). In reality,
95+
when the probe shifts new units from outside the original recorded
96+
region are shifted into the recording. When `True`, new units
97+
are shifted into the generated recording.
9098
generate_sorting_kwargs : dict
9199
Only `firing_rates` and `refractory_period_ms` are expected if passed.
100+
92101
All other parameters are used as in from `generate_drifting_recording()`.
93102
94103
Returns
@@ -105,7 +114,6 @@ def generate_session_displacement_recordings(
105114
A list (length num records) of (num_units, num_samples, num_channels)
106115
arrays of templates that have been shifted.
107116
108-
109117
Notes
110118
-----
111119
It is important to consider what unit properties are maintained
@@ -141,12 +149,28 @@ def generate_session_displacement_recordings(
141149

142150
# Fix generate template kwargs, so they are the same for every created recording.
143151
# Also fix unit firing rates across recordings.
144-
generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed)
152+
fixed_generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed)
145153

146154
fixed_firing_rates = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed)
147-
generate_sorting_kwargs["firing_rates"] = fixed_firing_rates
155+
fixed_generate_sorting_kwargs = copy.deepcopy(generate_sorting_kwargs)
156+
fixed_generate_sorting_kwargs["firing_rates"] = fixed_firing_rates
157+
158+
if shift_units_outside_probe:
159+
num_units, unit_locations, fixed_generate_templates_kwargs, fixed_generate_sorting_kwargs = (
160+
_update_kwargs_for_extended_units(
161+
num_units,
162+
channel_locations,
163+
unit_locations,
164+
generate_unit_locations_kwargs,
165+
generate_templates_kwargs,
166+
generate_sorting_kwargs,
167+
fixed_generate_templates_kwargs,
168+
fixed_generate_sorting_kwargs,
169+
seed,
170+
)
171+
)
148172

149-
# Start looping over parameters, creating recordings shifted
173+
# Start looping over parameters, creating recordings shifted
150174
# and scaled as required
151175
extra_outputs_dict = {
152176
"unit_locations": [],
@@ -174,7 +198,7 @@ def generate_session_displacement_recordings(
174198
num_units=num_units,
175199
sampling_frequency=sampling_frequency,
176200
durations=[duration],
177-
**generate_sorting_kwargs,
201+
**fixed_generate_sorting_kwargs,
178202
extra_outputs=True,
179203
seed=seed,
180204
)
@@ -195,7 +219,7 @@ def generate_session_displacement_recordings(
195219
unit_locations_moved,
196220
sampling_frequency=sampling_frequency,
197221
seed=seed,
198-
**generate_templates_kwargs,
222+
**fixed_generate_templates_kwargs,
199223
)
200224

201225
if recording_amplitude_scalings is not None:
@@ -210,7 +234,7 @@ def generate_session_displacement_recordings(
210234

211235
# Bring it all together in a `InjectTemplatesRecording` and
212236
# propagate all relevant metadata to the recording.
213-
ms_before = generate_templates_kwargs["ms_before"]
237+
ms_before = fixed_generate_templates_kwargs["ms_before"]
214238
nbefore = int(sampling_frequency * ms_before / 1000.0)
215239

216240
recording = InjectTemplatesRecording(
@@ -388,3 +412,89 @@ def _check_generate_session_displacement_arguments(
388412
"The entry for each recording in `recording_amplitude_scalings` "
389413
"must have the same length as the number of units."
390414
)
415+
416+
417+
def _update_kwargs_for_extended_units(
418+
num_units,
419+
channel_locations,
420+
unit_locations,
421+
generate_unit_locations_kwargs,
422+
generate_templates_kwargs,
423+
generate_sorting_kwargs,
424+
fixed_generate_templates_kwargs,
425+
fixed_generate_sorting_kwargs,
426+
seed,
427+
):
428+
"""
429+
In a real world situation, if the probe moves up / down
430+
not only will previously recorded units be shifted, but
431+
new units will be introduced into the recording.
432+
433+
This function extends the default num units, unit locations,
434+
and template / sorting kwargs to extend the unit of units
435+
one probe's height (y dimension) above and below the probe.
436+
Now, when the unit locations are shifted, new units will be
437+
introduced into the recording (from below or above).
438+
439+
It is important that the unit kwargs for the units are kept the
440+
same across runs when seeded (i.e. whether `shift_units_outside_probe`
441+
is `True` or `False`). To acheive this, the fixed unit kwargs
442+
are extended with new units located above and below these fixed
443+
units. The seeds are shifted slightly, so the introduced
444+
units do not duplicate the existing units.
445+
446+
"""
447+
seed_top = seed + 1 if seed is not None else None
448+
seed_bottom = seed - 1 if seed is not None else None
449+
450+
# Set unit locations above and below the probe and extend
451+
# the `unit_locations` array.
452+
channel_locations_extend_top = channel_locations.copy()
453+
channel_locations_extend_top[:, 1] -= np.max(channel_locations[:, 1])
454+
455+
extend_top_locations = generate_unit_locations(
456+
num_units,
457+
channel_locations_extend_top,
458+
seed=seed_top,
459+
**generate_unit_locations_kwargs,
460+
)
461+
462+
channel_locations_extend_bottom = channel_locations.copy()
463+
channel_locations_extend_bottom[:, 1] += np.max(channel_locations[:, 1])
464+
465+
extend_bottom_locations = generate_unit_locations(
466+
num_units,
467+
channel_locations_extend_bottom,
468+
seed=seed_bottom,
469+
**generate_unit_locations_kwargs,
470+
)
471+
472+
unit_locations = np.r_[extend_bottom_locations, unit_locations, extend_top_locations]
473+
474+
# For the new units located above and below the probe, generate a set of
475+
# firing rates and template kwargs.
476+
477+
# Extend the template kwargs
478+
template_kwargs_top = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed_top)
479+
template_kwargs_bottom = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed_bottom)
480+
481+
for key in fixed_generate_templates_kwargs["unit_params"].keys():
482+
fixed_generate_templates_kwargs["unit_params"][key] = np.r_[
483+
template_kwargs_top["unit_params"][key],
484+
fixed_generate_templates_kwargs["unit_params"][key],
485+
template_kwargs_bottom["unit_params"][key],
486+
]
487+
488+
# Extend the firing rates
489+
firing_rates_top = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed_top)
490+
firing_rates_bottom = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed_bottom)
491+
492+
fixed_generate_sorting_kwargs["firing_rates"] = np.r_[
493+
firing_rates_top, fixed_generate_sorting_kwargs["firing_rates"], firing_rates_bottom
494+
]
495+
496+
# Update the number of units (3x as a
497+
# new set above and below the existing units)
498+
num_units *= 3
499+
500+
return num_units, unit_locations, fixed_generate_templates_kwargs, fixed_generate_sorting_kwargs

src/spikeinterface/generation/tests/test_session_displacement_generator.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,48 @@ def test_metadata(self, options):
370370
)
371371
assert output_sortings[i].name == "InterSessionDisplacementSorting"
372372

373+
def test_shift_units_outside_probe(self, options):
374+
"""
375+
When `shift_units_outside_probe` is `True`, a new set of
376+
units above and below the probe (y dimension) are created,
377+
such that they may be shifted into the recording.
378+
379+
Here, check that these new units are created when `shift_units_outside_probe`
380+
is on and that the kwargs for the central set of units match those
381+
as when `shift_units_outside_probe` is `False`.
382+
"""
383+
num_sessions = len(options["kwargs"]["recording_durations"])
384+
_, _, baseline_outputs = generate_session_displacement_recordings(
385+
**options["kwargs"],
386+
)
387+
388+
_, _, outside_probe_outputs = generate_session_displacement_recordings(
389+
**options["kwargs"], shift_units_outside_probe=True
390+
)
391+
392+
num_units = options["kwargs"]["num_units"]
393+
num_extended_units = num_units * 3
394+
395+
for ses_idx in range(num_sessions):
396+
397+
# There are 3x the number of units when new units are created
398+
# (one new set above, and one new set below the probe).
399+
for key in ["unit_locations", "templates_array_moved", "firing_rates"]:
400+
assert outside_probe_outputs[key][ses_idx].shape[0] == num_extended_units
401+
402+
assert np.array_equal(
403+
baseline_outputs[key][ses_idx], outside_probe_outputs[key][ses_idx][num_units:-num_units]
404+
)
405+
406+
# The kwargs of the units in the central positions should be identical
407+
# to those when `shift_units_outside_probe` is `False`.
408+
lower_unit_pos = outside_probe_outputs["unit_locations"][ses_idx][-num_units:][:, 1]
409+
upper_unit_pos = outside_probe_outputs["unit_locations"][ses_idx][:num_units][:, 1]
410+
middle_unit_pos = baseline_outputs["unit_locations"][ses_idx][:, 1]
411+
412+
assert np.min(upper_unit_pos) > np.max(middle_unit_pos)
413+
assert np.max(lower_unit_pos) < np.min(middle_unit_pos)
414+
373415
def test_same_as_generate_ground_truth_recording(self):
374416
"""
375417
It is expected that inter-session displacement randomly

0 commit comments

Comments
 (0)