diff --git a/neo/rawio/spikeglxrawio.py b/neo/rawio/spikeglxrawio.py index 8225a22c6..8583bcdcd 100644 --- a/neo/rawio/spikeglxrawio.py +++ b/neo/rawio/spikeglxrawio.py @@ -53,6 +53,7 @@ from pathlib import Path import os import re +from warnings import warn import numpy as np @@ -76,7 +77,9 @@ class SpikeGLXRawIO(BaseRawWithBufferApiIO): dirname: str, default: '' The spikeglx folder containing meta/bin files load_sync_channel: bool, default: False - The last channel (SY0) of each stream is a fake channel used for synchronisation + Can be used to load the synch stream as the last channel of the neural data. + This option is deprecated and will be removed in version 0.15. + From versions higher than 0.14.1 the sync channel is always loaded as a separate stream. load_channel_location: bool, default: False If True probeinterface is used to load the channel locations from the directory @@ -109,6 +112,12 @@ def __init__(self, dirname="", load_sync_channel=False, load_channel_location=Fa BaseRawWithBufferApiIO.__init__(self) self.dirname = dirname self.load_sync_channel = load_sync_channel + if load_sync_channel: + warn( + "The load_sync_channel=True option is deprecated and will be removed in version 0.15 \n" + "The sync channel is now loaded as a separate stream by default and should be accessed as such. ", + DeprecationWarning, stacklevel=2 + ) self.load_channel_location = load_channel_location def _source_name(self): @@ -152,6 +161,8 @@ def _parse_header(self): signal_buffers = [] signal_streams = [] signal_channels = [] + sync_stream_id_to_buffer_id = {} + for stream_name in stream_names: # take first segment info = self.signals_info_dict[0, stream_name] @@ -168,6 +179,16 @@ def _parse_header(self): for local_chan in range(info["num_chan"]): chan_name = info["channel_names"][local_chan] chan_id = f"{stream_name}#{chan_name}" + + # Sync channel + if "nidq" not in stream_name and "SY0" in chan_name and not self.load_sync_channel and local_chan == info["num_chan"] - 1: + # This is a sync channel and should be added as its own stream + sync_stream_id = f"{stream_name}-SYNC" + sync_stream_id_to_buffer_id[sync_stream_id] = buffer_id + stream_id_for_chan = sync_stream_id + else: + stream_id_for_chan = stream_id + signal_channels.append( ( chan_name, @@ -177,25 +198,33 @@ def _parse_header(self): info["units"], info["channel_gains"][local_chan], info["channel_offsets"][local_chan], - stream_id, + stream_id_for_chan, buffer_id, ) ) - # all channel by dafult unless load_sync_channel=False + # all channel by default unless load_sync_channel=False self._stream_buffer_slice[stream_id] = None + # check sync channel validity if "nidq" not in stream_name: if not self.load_sync_channel and info["has_sync_trace"]: - # the last channel is remove from the stream but not from the buffer - last_chan = signal_channels[-1] - last_chan = last_chan[:-2] + ("", buffer_id) - signal_channels = signal_channels[:-1] + [last_chan] + # the last channel is removed from the stream but not from the buffer self._stream_buffer_slice[stream_id] = slice(0, -1) + + # Add a buffer slice for the sync channel + sync_stream_id = f"{stream_name}-SYNC" + self._stream_buffer_slice[sync_stream_id] = slice(-1, None) + if self.load_sync_channel and not info["has_sync_trace"]: raise ValueError("SYNC channel is not present in the recording. " "Set load_sync_channel to False") signal_buffers = np.array(signal_buffers, dtype=_signal_buffer_dtype) + + # Add sync channels as their own streams + for sync_stream_id, buffer_id in sync_stream_id_to_buffer_id.items(): + signal_streams.append((sync_stream_id, sync_stream_id, buffer_id)) + signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype) signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype) @@ -237,6 +266,14 @@ def _parse_header(self): t_start = frame_start / sampling_frequency self._t_starts[stream_name][seg_index] = t_start + + # This need special logic because sync not present in stream_names + if f"{stream_name}-SYNC" in signal_streams["name"]: + sync_stream_name = f"{stream_name}-SYNC" + if sync_stream_name not in self._t_starts: + self._t_starts[sync_stream_name] = {} + self._t_starts[sync_stream_name][seg_index] = t_start + t_stop = info["sample_length"] / info["sampling_rate"] self._t_stops[seg_index] = max(self._t_stops[seg_index], t_stop) @@ -265,7 +302,11 @@ def _parse_header(self): if self.load_channel_location: # need probeinterface to be installed import probeinterface - + + # Skip for sync streams + if "SYNC" in stream_name: + continue + info = self.signals_info_dict[seg_index, stream_name] if "imroTbl" in info["meta"] and info["stream_kind"] == "ap": # only for ap channel diff --git a/neo/test/rawiotest/test_spikeglxrawio.py b/neo/test/rawiotest/test_spikeglxrawio.py index dd00b1e9f..4111f403b 100644 --- a/neo/test/rawiotest/test_spikeglxrawio.py +++ b/neo/test/rawiotest/test_spikeglxrawio.py @@ -55,7 +55,7 @@ def test_loading_only_one_probe_in_multi_probe_scenario(self): rawio = SpikeGLXRawIO(probe_folder_path) rawio.parse_header() - expected_stream_names = ["imec1.ap", "imec1.lf"] + expected_stream_names = ["imec1.ap", "imec1.lf", "imec1.ap-SYNC", "imec1.lf-SYNC"] actual_stream_names = rawio.header["signal_streams"]["name"].tolist() assert ( actual_stream_names == expected_stream_names @@ -130,6 +130,30 @@ def test_nidq_digital_channel(self): atol = 0.001 assert np.allclose(on_diff, 1, atol=atol) + def test_sync_channel_as_separate_stream(self): + """Test that sync channel is added as its own stream when load_sync_channel=False.""" + import warnings + + # Test with load_sync_channel=False (default) + rawio_no_sync = SpikeGLXRawIO(self.get_local_path("spikeglx/NP2_with_sync"), load_sync_channel=False) + rawio_no_sync.parse_header() + + # Get stream names + stream_names = rawio_no_sync.header["signal_streams"]["name"].tolist() + + # Check if there's a sync channel stream (should contain "SY0" or "SYNC" in the name) + sync_streams = [name for name in stream_names if "SY0" in name or "SYNC" in name] + assert len(sync_streams) > 0, "No sync channel stream found when load_sync_channel=False" + + # Test deprecation warning when load_sync_channel=True + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + rawio_with_sync = SpikeGLXRawIO(self.get_local_path("spikeglx/NP2_with_sync"), load_sync_channel=True) + + # Check if deprecation warning was raised + assert any(issubclass(warning.category, DeprecationWarning) for warning in w), "No deprecation warning raised" + assert any("will be removed in version 0.15" in str(warning.message) for warning in w), "Deprecation warning message is incorrect" + def test_t_start_reading(self): """Test that t_start values are correctly read for all streams and segments."""