@@ -22,6 +22,7 @@ def generate_session_displacement_recordings(
22
22
recording_shifts = ((0 , 0 ), (0 , 25 ), (0 , 50 )),
23
23
non_rigid_gradient = None ,
24
24
recording_amplitude_scalings = None ,
25
+ shift_units_outside_probe = False ,
25
26
sampling_frequency = 30000.0 ,
26
27
probe_name = "Neuropixel-128" ,
27
28
generate_probe_kwargs = None ,
@@ -87,8 +88,16 @@ def generate_session_displacement_recordings(
87
88
"scalings" - a list of numpy arrays, one for each recording, with
88
89
each entry an array of length num_units holding the unit scalings.
89
90
e.g. for 3 recordings, 2 units: ((1, 1), (1, 1), (0.5, 0.5)).
91
+ shift_units_outside_probe : bool
92
+ By default (`False`), when units are shifted across sessions, new units are
93
+ not introduced into the recording (e.g. the region in which units
94
+ have been shifted out of is left at baseline level). In reality,
95
+ when the probe shifts new units from outside the original recorded
96
+ region are shifted into the recording. When `True`, new units
97
+ are shifted into the generated recording.
90
98
generate_sorting_kwargs : dict
91
99
Only `firing_rates` and `refractory_period_ms` are expected if passed.
100
+
92
101
All other parameters are used as in from `generate_drifting_recording()`.
93
102
94
103
Returns
@@ -105,7 +114,6 @@ def generate_session_displacement_recordings(
105
114
A list (length num records) of (num_units, num_samples, num_channels)
106
115
arrays of templates that have been shifted.
107
116
108
-
109
117
Notes
110
118
-----
111
119
It is important to consider what unit properties are maintained
@@ -141,12 +149,28 @@ def generate_session_displacement_recordings(
141
149
142
150
# Fix generate template kwargs, so they are the same for every created recording.
143
151
# Also fix unit firing rates across recordings.
144
- generate_templates_kwargs = fix_generate_templates_kwargs (generate_templates_kwargs , num_units , seed )
152
+ fixed_generate_templates_kwargs = fix_generate_templates_kwargs (generate_templates_kwargs , num_units , seed )
145
153
146
154
fixed_firing_rates = _ensure_firing_rates (generate_sorting_kwargs ["firing_rates" ], num_units , seed )
147
- generate_sorting_kwargs ["firing_rates" ] = fixed_firing_rates
155
+ fixed_generate_sorting_kwargs = copy .deepcopy (generate_sorting_kwargs )
156
+ fixed_generate_sorting_kwargs ["firing_rates" ] = fixed_firing_rates
157
+
158
+ if shift_units_outside_probe :
159
+ num_units , unit_locations , fixed_generate_templates_kwargs , fixed_generate_sorting_kwargs = (
160
+ _update_kwargs_for_extended_units (
161
+ num_units ,
162
+ channel_locations ,
163
+ unit_locations ,
164
+ generate_unit_locations_kwargs ,
165
+ generate_templates_kwargs ,
166
+ generate_sorting_kwargs ,
167
+ fixed_generate_templates_kwargs ,
168
+ fixed_generate_sorting_kwargs ,
169
+ seed ,
170
+ )
171
+ )
148
172
149
- # Start looping over parameters, creating recordings shifted
173
+ # Start looping over parameters, creating recordings shifted
150
174
# and scaled as required
151
175
extra_outputs_dict = {
152
176
"unit_locations" : [],
@@ -174,7 +198,7 @@ def generate_session_displacement_recordings(
174
198
num_units = num_units ,
175
199
sampling_frequency = sampling_frequency ,
176
200
durations = [duration ],
177
- ** generate_sorting_kwargs ,
201
+ ** fixed_generate_sorting_kwargs ,
178
202
extra_outputs = True ,
179
203
seed = seed ,
180
204
)
@@ -195,7 +219,7 @@ def generate_session_displacement_recordings(
195
219
unit_locations_moved ,
196
220
sampling_frequency = sampling_frequency ,
197
221
seed = seed ,
198
- ** generate_templates_kwargs ,
222
+ ** fixed_generate_templates_kwargs ,
199
223
)
200
224
201
225
if recording_amplitude_scalings is not None :
@@ -210,7 +234,7 @@ def generate_session_displacement_recordings(
210
234
211
235
# Bring it all together in a `InjectTemplatesRecording` and
212
236
# propagate all relevant metadata to the recording.
213
- ms_before = generate_templates_kwargs ["ms_before" ]
237
+ ms_before = fixed_generate_templates_kwargs ["ms_before" ]
214
238
nbefore = int (sampling_frequency * ms_before / 1000.0 )
215
239
216
240
recording = InjectTemplatesRecording (
@@ -388,3 +412,89 @@ def _check_generate_session_displacement_arguments(
388
412
"The entry for each recording in `recording_amplitude_scalings` "
389
413
"must have the same length as the number of units."
390
414
)
415
+
416
+
417
+ def _update_kwargs_for_extended_units (
418
+ num_units ,
419
+ channel_locations ,
420
+ unit_locations ,
421
+ generate_unit_locations_kwargs ,
422
+ generate_templates_kwargs ,
423
+ generate_sorting_kwargs ,
424
+ fixed_generate_templates_kwargs ,
425
+ fixed_generate_sorting_kwargs ,
426
+ seed ,
427
+ ):
428
+ """
429
+ In a real world situation, if the probe moves up / down
430
+ not only will previously recorded units be shifted, but
431
+ new units will be introduced into the recording.
432
+
433
+ This function extends the default num units, unit locations,
434
+ and template / sorting kwargs to extend the unit of units
435
+ one probe's height (y dimension) above and below the probe.
436
+ Now, when the unit locations are shifted, new units will be
437
+ introduced into the recording (from below or above).
438
+
439
+ It is important that the unit kwargs for the units are kept the
440
+ same across runs when seeded (i.e. whether `shift_units_outside_probe`
441
+ is `True` or `False`). To acheive this, the fixed unit kwargs
442
+ are extended with new units located above and below these fixed
443
+ units. The seeds are shifted slightly, so the introduced
444
+ units do not duplicate the existing units.
445
+
446
+ """
447
+ seed_top = seed + 1 if seed is not None else None
448
+ seed_bottom = seed - 1 if seed is not None else None
449
+
450
+ # Set unit locations above and below the probe and extend
451
+ # the `unit_locations` array.
452
+ channel_locations_extend_top = channel_locations .copy ()
453
+ channel_locations_extend_top [:, 1 ] -= np .max (channel_locations [:, 1 ])
454
+
455
+ extend_top_locations = generate_unit_locations (
456
+ num_units ,
457
+ channel_locations_extend_top ,
458
+ seed = seed_top ,
459
+ ** generate_unit_locations_kwargs ,
460
+ )
461
+
462
+ channel_locations_extend_bottom = channel_locations .copy ()
463
+ channel_locations_extend_bottom [:, 1 ] += np .max (channel_locations [:, 1 ])
464
+
465
+ extend_bottom_locations = generate_unit_locations (
466
+ num_units ,
467
+ channel_locations_extend_bottom ,
468
+ seed = seed_bottom ,
469
+ ** generate_unit_locations_kwargs ,
470
+ )
471
+
472
+ unit_locations = np .r_ [extend_bottom_locations , unit_locations , extend_top_locations ]
473
+
474
+ # For the new units located above and below the probe, generate a set of
475
+ # firing rates and template kwargs.
476
+
477
+ # Extend the template kwargs
478
+ template_kwargs_top = fix_generate_templates_kwargs (generate_templates_kwargs , num_units , seed_top )
479
+ template_kwargs_bottom = fix_generate_templates_kwargs (generate_templates_kwargs , num_units , seed_bottom )
480
+
481
+ for key in fixed_generate_templates_kwargs ["unit_params" ].keys ():
482
+ fixed_generate_templates_kwargs ["unit_params" ][key ] = np .r_ [
483
+ template_kwargs_top ["unit_params" ][key ],
484
+ fixed_generate_templates_kwargs ["unit_params" ][key ],
485
+ template_kwargs_bottom ["unit_params" ][key ],
486
+ ]
487
+
488
+ # Extend the firing rates
489
+ firing_rates_top = _ensure_firing_rates (generate_sorting_kwargs ["firing_rates" ], num_units , seed_top )
490
+ firing_rates_bottom = _ensure_firing_rates (generate_sorting_kwargs ["firing_rates" ], num_units , seed_bottom )
491
+
492
+ fixed_generate_sorting_kwargs ["firing_rates" ] = np .r_ [
493
+ firing_rates_top , fixed_generate_sorting_kwargs ["firing_rates" ], firing_rates_bottom
494
+ ]
495
+
496
+ # Update the number of units (3x as a
497
+ # new set above and below the existing units)
498
+ num_units *= 3
499
+
500
+ return num_units , unit_locations , fixed_generate_templates_kwargs , fixed_generate_sorting_kwargs
0 commit comments