diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst
index 59a6b99158..ba65e7dc61 100644
--- a/doc/modules/postprocessing.rst
+++ b/doc/modules/postprocessing.rst
@@ -180,6 +180,8 @@ and doesn't save anything. If we wanted to save the extension we should have sta
 Available postprocessing extensions
 -----------------------------------
 
+.. _postprocessing_noise_levels:
+
 noise_levels
 ^^^^^^^^^^^^
 
@@ -192,7 +194,7 @@ As an extension, this expects the :code:`Recording` as input and the computed va
 
-
+.. _postprocessing_principal_components:
 
 principal_components
 ^^^^^^^^^^^^^^^^^^^^
@@ -232,7 +234,7 @@ and is not well suited for high-density probes.
 
 For more information, see :py:func:`~spikeinterface.postprocessing.compute_template_similarity`
 
-
+.. _postprocessing_spike_amplitudes:
 
 spike_amplitudes
 ^^^^^^^^^^^^^^^^
@@ -249,6 +251,7 @@ each spike.
 
 For more information, see :py:func:`~spikeinterface.postprocessing.compute_spike_amplitudes`
 
+.. _postprocessing_spike_locations:
 
 spike_locations
 ^^^^^^^^^^^^^^^
diff --git a/doc/modules/qualitymetrics.rst b/doc/modules/qualitymetrics.rst
index f119693203..55244b0967 100644
--- a/doc/modules/qualitymetrics.rst
+++ b/doc/modules/qualitymetrics.rst
@@ -12,17 +12,99 @@ Completeness metrics (or 'false negative'/'type II' metrics) aim to identify whe
 Examples include: presence ratio, amplitude cutoff, NN-miss rate.
 Drift metrics aim to identify changes in waveforms which occur when spike sorters fail to successfully track neurons in the case of electrode drift.
 
-Some metrics make use of principal component analysis (PCA) to reduce the dimensionality of computations.
+The quality metrics are saved as an extension of a :doc:`SortingAnalyzer `. Some metrics can only be computed if certain extensions have been computed first. For example, the drift metrics can only be computed if the spike locations extension has been computed. By default, as many metrics as possible are computed. Which ones are computed depends on which other extensions have
+been computed.
+
+In detail, the default metrics are (click on each metric to find out more about them!):
+
+- :doc:`qualitymetrics/firing_rate`
+- :doc:`qualitymetrics/presence_ratio`
+- :doc:`qualitymetrics/isi_violations`
+- :doc:`qualitymetrics/sliding_rp_violations`
+- :doc:`qualitymetrics/synchrony`
+- :doc:`qualitymetrics/firing_range`
+
+If :ref:`postprocessing_spike_locations` are computed, add:
+
+- :doc:`qualitymetrics/drift`
+
+If :ref:`postprocessing_spike_amplitudes` and ``templates`` are computed, add:
+
+- :doc:`qualitymetrics/amplitude_cutoff`
+- :doc:`qualitymetrics/amplitude_median`
+- :doc:`qualitymetrics/amplitude_cv`
+- :doc:`qualitymetrics/noise_cutoff`
+
+If :ref:`postprocessing_noise_levels` and ``templates`` are computed, add:
+
+- :doc:`qualitymetrics/snr`
+
+If the recording, :ref:`postprocessing_spike_amplitudes` and ``templates`` are available, add:
+
+- :doc:`qualitymetrics/sd_ratio`
+
+If :ref:`postprocessing_principal_components` are computed, add:
+
+- :doc:`qualitymetrics/isolation_distance`
+- :doc:`qualitymetrics/l_ratio`
+- :doc:`qualitymetrics/d_prime`
+- :doc:`qualitymetrics/silhouette_score`
+- :doc:`qualitymetrics/nearest_neighbor` (note: excluding the ``nn_noise_overlap`` metric)
+
+You can compute the default metrics using the following code snippet:
+
+.. code-block:: python
+
+    # load or create a sorting analyzer
+    sorting_analyzer = si.load_sorting_analyzer(folder='my_sorting_analyzer')
+
+    # compute the metrics
+    sorting_analyzer.compute("quality_metrics")
+
+    # get the metrics as a pandas DataFrame
+    quality_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()
+
+    # print the metrics that have been computed
+    print(quality_metrics.columns)
+
+Some metrics are very slow to compute when the number of units is large. So by default, the following metrics are not computed:
+
+- :doc:`qualitymetrics/isolation_distance`
+- The ``nn_noise_overlap`` from :doc:`qualitymetrics/nearest_neighbor`
+
+Some metrics make use of :ref:`principal component analysis <postprocessing_principal_components>` (PCA) to reduce the dimensionality of computations.
 Various approaches to computing the principal components are possible, and choice should be carefully considered in relation to the recording equipment used.
-The following metrics make use of PCA: isolation distance, L-ratio, D-prime, Silhouette score and NN-metrics.
-By contrast, the following metrics are based on spike times only: firing rate, ISI violations, presence ratio.
-And amplitude cutoff and SNR are based on spike times as well as waveforms.
-For more details about each metric and it's availability and use within SpikeInterface, see the individual pages for each metrics.
+If you only want to compute a subset of metrics, you can use convenience functions to compute each one,
+
+.. code-block:: python
+
+    from spikeinterface.qualitymetrics import compute_isi_violations
+    compute_isi_violations(sorting_analyzer, isi_threshold_ms=3.0)
+
+or use the ``compute`` method:
+
+.. code-block:: python
+
+    sorting_analyzer.compute(
+        "quality_metrics",
+        metric_names=["isi_violation", "snr"],
+        extension_params={
+            "isi_violation": {"isi_threshold_ms": 3.0},
+        },
+    )
+
+Note that if you request a specific metric using ``metric_names`` and the required extension has not been computed, an error will be raised.
+
+For more information about quality metrics, check out this excellent
+`documentation `_
+from the Allen Institute.
+
 .. toctree::
    :maxdepth: 1
    :glob:
+   :hidden:
 
    qualitymetrics/amplitude_cutoff
    qualitymetrics/amplitude_cv
@@ -42,28 +124,3 @@ For more details about each metric and it's availability and use within SpikeInt
    qualitymetrics/sliding_rp_violations
    qualitymetrics/snr
    qualitymetrics/synchrony
-
-
-This code snippet shows how to compute quality metrics (with or without principal components) in SpikeInterface:
-
-.. code-block:: python
-
-    sorting_analyzer = si.load_sorting_analyzer(folder='waveforms') # start from a sorting_analyzer
-
-    # without PC (depends on "waveforms", "templates", and "noise_levels")
-    qm_ext = sorting_analyzer.compute(input="quality_metrics", metric_names=['snr'], skip_pc_metrics=True)
-    metrics = qm_ext.get_data()
-    assert 'snr' in metrics.columns
-
-    # with PCs (depends on "pca" in addition to the above metrics)
-
-    qm_ext = sorting_analyzer.compute(input={"principal_components": dict(n_components=5, mode="by_channel_local"),
-                                             "quality_metrics": dict(skip_pc_metrics=False)})
-    metrics = qm_ext.get_data()
-    assert 'isolation_distance' in metrics.columns
-
-
-
-For more information about quality metrics, check out this excellent
-`documentation `_
-from the Allen Institute.
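+As a quick illustration, requesting a metric whose dependencies are missing raises a
+``ValueError`` naming the required extensions. A minimal sketch, assuming ``sorting_analyzer``
+does not yet have the ``spike_locations`` extension:
+
+.. code-block:: python
+
+    try:
+        sorting_analyzer.compute("quality_metrics", metric_names=["drift"])
+    except ValueError as err:
+        print(err)  # reports that "drift" requires the "spike_locations" extension
+
+    # compute the missing extension, then retry
+    sorting_analyzer.compute("spike_locations")
+    sorting_analyzer.compute("quality_metrics", metric_names=["drift"])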
diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py
index f645f3416f..6837257f70 100644
--- a/src/spikeinterface/core/tests/test_sortinganalyzer.py
+++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py
@@ -618,14 +618,14 @@ def test_extensions_sorting():
     assert list(sorted_extensions_2.keys()) == list(extensions_in_order.keys())
 
     # doing two movements
-    extensions_qm_left = {"quality_metrics": {}, "waveforms": {}, "templates": {}}
-    extensions_qm_correct = {"waveforms": {}, "templates": {}, "quality_metrics": {}}
+    extensions_qm_left = {"template_metrics": {}, "waveforms": {}, "templates": {}}
+    extensions_qm_correct = {"waveforms": {}, "templates": {}, "template_metrics": {}}
     sorted_extensions_3 = _sort_extensions_by_dependency(extensions_qm_left)
     assert list(sorted_extensions_3.keys()) == list(extensions_qm_correct.keys())
 
-    # should move parent (waveforms) left of child (quality_metrics), and move grandparent (random_spikes) left of parent
-    extensions_qm_left = {"quality_metrics": {}, "waveforms": {}, "templates": {}, "random_spikes": {}}
-    extensions_qm_correct = {"random_spikes": {}, "waveforms": {}, "templates": {}, "quality_metrics": {}}
+    # should move parent (waveforms) left of child (template_metrics), and move grandparent (random_spikes) left of parent
+    extensions_qm_left = {"template_metrics": {}, "waveforms": {}, "templates": {}, "random_spikes": {}}
+    extensions_qm_correct = {"random_spikes": {}, "waveforms": {}, "templates": {}, "template_metrics": {}}
     sorted_extensions_4 = _sort_extensions_by_dependency(extensions_qm_left)
     assert list(sorted_extensions_4.keys()) == list(extensions_qm_correct.keys())
diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py
index 2f3393ccd7..e7b9dee2c7 100644
--- a/src/spikeinterface/qualitymetrics/misc_metrics.py
+++ b/src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -9,7 +9,7 @@
 
 from __future__ import annotations
 
-
+from .utils import _has_required_extensions
 from collections import namedtuple
 import math
 import warnings
@@ -26,7 +26,6 @@
     get_dense_templates_array,
 )
 
-
 numba_spec = importlib.util.find_spec("numba")
 if numba_spec is not None:
     HAVE_NUMBA = True
@@ -78,14 +77,8 @@ def compute_noise_cutoffs(sorting_analyzer, high_quantile=0.25, low_quantile=0.1
     noise_cutoff_dict = {}
     noise_ratio_dict = {}
-    if not sorting_analyzer.has_extension("spike_amplitudes"):
-        warnings.warn(
-            "`compute_noise_cutoffs` requires the 'spike_amplitudes` extension. Please run sorting_analyzer.compute('spike_amplitudes') to be able to compute `noise_cutoff`"
-        )
-        for unit_id in unit_ids:
-            noise_cutoff_dict[unit_id] = np.nan
-            noise_ratio_dict[unit_id] = np.nan
-        return res(noise_cutoff_dict, noise_ratio_dict)
+
+    _has_required_extensions(sorting_analyzer, metric_name="noise_cutoff")
 
     amplitude_extension = sorting_analyzer.get_extension("spike_amplitudes")
     peak_sign = amplitude_extension.params["peak_sign"]
@@ -374,14 +367,17 @@ def compute_snrs(
     snrs : dict
         Computed signal to noise ratio for each unit.
""" - assert sorting_analyzer.has_extension("noise_levels") + + if unit_ids is None: + unit_ids = sorting_analyzer.unit_ids + + _has_required_extensions(sorting_analyzer, metric_name="snr") + noise_levels = sorting_analyzer.get_extension("noise_levels").get_data() assert peak_sign in ("neg", "pos", "both") assert peak_mode in ("extremum", "at_index", "peak_to_peak") - if unit_ids is None: - unit_ids = sorting_analyzer.unit_ids channel_ids = sorting_analyzer.channel_ids extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign, mode=peak_mode) @@ -908,12 +904,9 @@ def compute_amplitude_cv_metrics( if unit_ids is None: unit_ids = sorting.unit_ids - if sorting_analyzer.has_extension(amplitude_extension): - amps = sorting_analyzer.get_extension(amplitude_extension).get_data() - else: - warnings.warn("compute_amplitude_cv_metrics() need 'spike_amplitudes' or 'amplitude_scalings'") - empty_dict = {unit_id: np.nan for unit_id in unit_ids} - return empty_dict + _has_required_extensions(sorting_analyzer, metric_name="amplitude_cv") + + amps = sorting_analyzer.get_extension(amplitude_extension).get_data() # precompute segment slice segment_slices = [] @@ -1040,35 +1033,30 @@ def compute_amplitude_cutoffs( unit_ids = sorting_analyzer.unit_ids all_fraction_missing = {} - if sorting_analyzer.has_extension("spike_amplitudes") or sorting_analyzer.has_extension("waveforms"): + _has_required_extensions(sorting_analyzer, metric_name="amplitude_cutoff") - invert_amplitudes = False - if ( - sorting_analyzer.has_extension("spike_amplitudes") - and sorting_analyzer.get_extension("spike_amplitudes").params["peak_sign"] == "pos" - ): - invert_amplitudes = True - elif sorting_analyzer.has_extension("waveforms") and peak_sign == "pos": - invert_amplitudes = True - - amplitudes_by_units = _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign) + invert_amplitudes = False + if ( + sorting_analyzer.has_extension("spike_amplitudes") + and sorting_analyzer.get_extension("spike_amplitudes").params["peak_sign"] == "pos" + ): + invert_amplitudes = True + elif sorting_analyzer.has_extension("waveforms") and peak_sign == "pos": + invert_amplitudes = True - for unit_id in unit_ids: - amplitudes = amplitudes_by_units[unit_id] - if invert_amplitudes: - amplitudes = -amplitudes + amplitudes_by_units = _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign) - all_fraction_missing[unit_id] = amplitude_cutoff( - amplitudes, num_histogram_bins, histogram_smoothing_value, amplitudes_bins_min_ratio - ) + for unit_id in unit_ids: + amplitudes = amplitudes_by_units[unit_id] + if invert_amplitudes: + amplitudes = -amplitudes - if np.any(np.isnan(list(all_fraction_missing.values()))): - warnings.warn(f"Some units have too few spikes : amplitude_cutoff is set to NaN") + all_fraction_missing[unit_id] = amplitude_cutoff( + amplitudes, num_histogram_bins, histogram_smoothing_value, amplitudes_bins_min_ratio + ) - else: - warnings.warn("compute_amplitude_cutoffs need 'spike_amplitudes' or 'waveforms' extension") - for unit_id in unit_ids: - all_fraction_missing[unit_id] = np.nan + if np.any(np.isnan(list(all_fraction_missing.values()))): + warnings.warn(f"Some units have too few spikes : amplitude_cutoff is set to NaN") return all_fraction_missing @@ -1106,15 +1094,12 @@ def compute_amplitude_medians(sorting_analyzer, peak_sign="neg", unit_ids=None): if unit_ids is None: unit_ids = sorting_analyzer.unit_ids + _has_required_extensions(sorting_analyzer, metric_name="amplitude_median") + 
     all_amplitude_medians = {}
-    if sorting_analyzer.has_extension("spike_amplitudes") or sorting_analyzer.has_extension("waveforms"):
-        amplitudes_by_units = _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign)
-        for unit_id in unit_ids:
-            all_amplitude_medians[unit_id] = np.median(amplitudes_by_units[unit_id])
-    else:
-        warnings.warn("compute_amplitude_medians need 'spike_amplitudes' or 'waveforms' extension")
-        for unit_id in unit_ids:
-            all_amplitude_medians[unit_id] = np.nan
+    amplitudes_by_units = _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign)
+    for unit_id in unit_ids:
+        all_amplitude_medians[unit_id] = np.median(amplitudes_by_units[unit_id])
 
     return all_amplitude_medians
 
@@ -1189,29 +1174,18 @@ def compute_drift_metrics(
     if unit_ids is None:
         unit_ids = sorting.unit_ids
 
-    if sorting_analyzer.has_extension("spike_locations"):
-        spike_locations_ext = sorting_analyzer.get_extension("spike_locations")
-        spike_locations = spike_locations_ext.get_data()
-        # spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit")
-        spikes = sorting.to_spike_vector()
-        spike_locations_by_unit = {}
-        for unit_id in unit_ids:
-            unit_index = sorting.id_to_index(unit_id)
-            # TODO @alessio this is very slow this sjould be done with spike_vector_to_indices() in code
-            spike_mask = spikes["unit_index"] == unit_index
-            spike_locations_by_unit[unit_id] = spike_locations[spike_mask]
+    _has_required_extensions(sorting_analyzer, metric_name="drift")
 
-    else:
-        warnings.warn(
-            "The drift metrics require the `spike_locations` waveform extension. "
-            "Use the `postprocessing.compute_spike_locations()` function. "
-            "Drift metrics will be set to NaN"
-        )
-        empty_dict = {unit_id: np.nan for unit_id in unit_ids}
-        if return_positions:
-            return res(empty_dict, empty_dict, empty_dict), np.nan
-        else:
-            return res(empty_dict, empty_dict, empty_dict)
+    spike_locations_ext = sorting_analyzer.get_extension("spike_locations")
+    spike_locations = spike_locations_ext.get_data()
+    # spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit")
+    spikes = sorting.to_spike_vector()
+    spike_locations_by_unit = {}
+    for unit_id in unit_ids:
+        unit_index = sorting.id_to_index(unit_id)
+        # TODO @alessio this is very slow; this should be done with spike_vector_to_indices() in code
+        spike_mask = spikes["unit_index"] == unit_index
+        spike_locations_by_unit[unit_id] = spike_locations[spike_mask]
 
     interval_samples = int(interval_s * sorting_analyzer.sampling_frequency)
     assert direction in spike_locations.dtype.names, (
@@ -1644,24 +1618,16 @@ def compute_sd_ratio(
         )
         return {unit_id: np.nan for unit_id in unit_ids}
 
+    _has_required_extensions(sorting_analyzer, metric_name="sd_ratio")
+
+    spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data()
+
     if not HAVE_NUMBA:
         warnings.warn(
             "'sd_ratio' metric computation requires numba. Install it with >>> pip install numba. "
             "SD ratio metric will be set to NaN"
         )
         return {unit_id: np.nan for unit_id in unit_ids}
-
-    if sorting_analyzer.has_extension("spike_amplitudes"):
-        amplitudes_ext = sorting_analyzer.get_extension("spike_amplitudes")
-        spike_amplitudes = amplitudes_ext.get_data()
-    else:
-        warnings.warn(
-            "The `sd_ratio` metric require the `spike_amplitudes` waveform extension. "
-            "Use the `postprocessing.compute_spike_amplitudes()` functions. "
-            "SD ratio metric will be set to NaN"
-        )
-        return {unit_id: np.nan for unit_id in unit_ids}
-
     noise_levels = get_noise_levels(
         sorting_analyzer.recording, return_in_uV=sorting_analyzer.return_in_uV, method="std", **job_kwargs
     )
diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
index 055fefc78c..5d338a990b 100644
--- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
+++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
@@ -4,7 +4,7 @@
 
 import warnings
 from itertools import chain
-from copy import deepcopy
+from copy import deepcopy, copy
 
 import numpy as np
 from warnings import warn
@@ -19,6 +19,7 @@
     _possible_pc_metric_names,
     qm_compute_name_to_column_names,
     column_name_to_column_dtype,
+    metric_extension_dependencies,
 )
 from .misc_metrics import _default_params as misc_metrics_params
 from .pca_metrics import _default_params as pca_metrics_params
@@ -53,7 +54,7 @@ class ComputeQualityMetrics(AnalyzerExtension):
     """
 
     extension_name = "quality_metrics"
-    depend_on = ["templates", "noise_levels"]
+    depend_on = []
     need_recording = False
     use_nodepipeline = False
     need_job_kwargs = True
@@ -83,7 +84,9 @@ def _set_params(
             metric_params = qm_params
             warn(deprecation_msg, category=DeprecationWarning, stacklevel=2)
 
+        metric_names_is_none = False
         if metric_names is None:
+            metric_names_is_none = True
             metric_names = list(_misc_metric_name_to_func.keys())
             # if PC is available, PC metrics are automatically added to the list
             if self.sorting_analyzer.has_extension("principal_components") and not skip_pc_metrics:
@@ -92,10 +95,6 @@ def _set_params(
                 pc_metrics.remove("nn_isolation")
                 pc_metrics.remove("nn_noise_overlap")
                 metric_names += pc_metrics
-            # if spike_locations are not available, drift is removed from the list
-            if not self.sorting_analyzer.has_extension("spike_locations"):
-                if "drift" in metric_names:
-                    metric_names.remove("drift")
 
         metric_params_ = get_default_qm_params()
         for k in metric_params_:
@@ -114,6 +113,26 @@ def _set_params(
         ]
         metric_names = metrics_to_compute + existing_metric_names_propagated
 
+        ## Deal with dependencies
+        computable_metrics_to_compute = copy(metrics_to_compute)
+        if metric_names_is_none:
+            need_more_extensions = False
+            warning_text = "Some metrics you are trying to compute depend on other extensions:\n"
+            for metric in metrics_to_compute:
+                metric_dependencies = metric_extension_dependencies.get(metric)
+                if metric_dependencies is not None:
+                    for extension_name in metric_dependencies:
+                        if all(
+                            self.sorting_analyzer.has_extension(name) is False for name in extension_name.split("|")
+                        ):
+                            need_more_extensions = True
+                            if metric in computable_metrics_to_compute:
+                                computable_metrics_to_compute.remove(metric)
+                            warning_text += f" {metric} requires {metric_dependencies}\n"
+            warning_text += "To include these metrics, compute the required extensions using `sorting_analyzer.compute('extension_name')`."
+            if need_more_extensions:
+                warnings.warn(warning_text)
+
         params = dict(
             metric_names=metric_names,
             peak_sign=peak_sign,
@@ -121,7 +140,7 @@ def _set_params(
             metric_params=metric_params_,
             skip_pc_metrics=skip_pc_metrics,
             delete_existing_metrics=delete_existing_metrics,
-            metrics_to_compute=metrics_to_compute,
+            metrics_to_compute=computable_metrics_to_compute,
         )
 
         return params
diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py
index 89fea6a25e..5e769ab8eb 100644
--- a/src/spikeinterface/qualitymetrics/quality_metric_list.py
+++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py
@@ -2,6 +2,17 @@
 
 from __future__ import annotations
 
+# a dict containing the extension dependencies for each metric
+metric_extension_dependencies = {
+    "snr": ["noise_levels", "templates"],
+    "amplitude_cutoff": ["spike_amplitudes|waveforms", "templates"],
+    "amplitude_median": ["spike_amplitudes|waveforms", "templates"],
+    "amplitude_cv": ["spike_amplitudes|amplitude_scalings", "templates"],
+    "drift": ["spike_locations"],
+    "sd_ratio": ["templates", "spike_amplitudes"],
+    "noise_cutoff": ["spike_amplitudes"],
+}
+
 from .misc_metrics import (
     compute_num_spikes,
@@ -55,6 +66,7 @@
     "noise_cutoff": compute_noise_cutoffs,
 }
 
+
 # a dict converting the name of the metric for computation to the output of that computation
 qm_compute_name_to_column_names = {
     "num_spikes": ["num_spikes"],
diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py
index 39bc62ae12..c2a6c6fe82 100644
--- a/src/spikeinterface/qualitymetrics/tests/conftest.py
+++ b/src/spikeinterface/qualitymetrics/tests/conftest.py
@@ -80,6 +80,6 @@ def sorting_analyzer_simple():
     sorting_analyzer.compute("noise_levels")
     sorting_analyzer.compute("waveforms", **job_kwargs)
     sorting_analyzer.compute("templates")
-    sorting_analyzer.compute("spike_amplitudes", **job_kwargs)
+    sorting_analyzer.compute(["spike_amplitudes", "spike_locations"], **job_kwargs)
 
     return sorting_analyzer
diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
index db3e55b629..79f25ac772 100644
--- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
+++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
@@ -120,19 +120,6 @@ def test_compute_new_quality_metrics(small_sorting_analyzer):
     assert np.all(old_snr_data != new_snr_data)
     assert new_quality_metric_extension.params["metric_params"]["snr"]["peak_mode"] == "peak_to_peak"
 
-    # check that all quality metrics are deleted when parents are recomputed, even after
-    # recomputation
-    extensions_to_compute = {
-        "templates": {"operators": ["average", "median"]},
-        "spike_amplitudes": {},
-        "spike_locations": {},
-        "principal_components": {},
-    }
-
-    small_sorting_analyzer.compute(extensions_to_compute)
-
-    assert small_sorting_analyzer.get_extension("quality_metrics") is None
-
 
 def test_metric_names_in_same_order(small_sorting_analyzer):
     """
diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py
index ea8939ebb4..36f2e0785a 100644
--- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py
+++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py
@@ -9,6 +9,8 @@
     aggregate_units,
 )
 
+from spikeinterface.qualitymetrics import compute_snrs
+
 from spikeinterface.qualitymetrics import (
     compute_quality_metrics,
@@ -17,6 +19,32 @@
 job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")
 
 
+def test_warnings_errors_when_missing_deps():
+    """
+    If the user requests to compute a quality metric which depends on an extension
+    that has not been computed, this should error. If the user uses the default
+    quality metrics (i.e. they do not explicitly request the specific metrics),
+    this should report a warning about which metrics could not be computed.
+    We check this behavior in this test.
+    """
+
+    recording, sorting = generate_ground_truth_recording()
+    analyzer = create_sorting_analyzer(sorting=sorting, recording=recording)
+
+    # user tries to use `compute_snrs` without templates. Should error
+    with pytest.raises(ValueError):
+        compute_snrs(analyzer)
+
+    # user asks for drift metrics without spike_locations. Should error
+    with pytest.raises(ValueError):
+        analyzer.compute("quality_metrics", metric_names=["drift"])
+
+    # user doesn't specify which metrics to compute. Should return a warning
+    # about which metrics have not been computed.
+    with pytest.warns(Warning):
+        analyzer.compute("quality_metrics")
+
+
 def test_compute_quality_metrics(sorting_analyzer_simple):
     sorting_analyzer = sorting_analyzer_simple
diff --git a/src/spikeinterface/qualitymetrics/utils.py b/src/spikeinterface/qualitymetrics/utils.py
index 844a7da7f5..90faf1a602 100644
--- a/src/spikeinterface/qualitymetrics/utils.py
+++ b/src/spikeinterface/qualitymetrics/utils.py
@@ -2,6 +2,29 @@
 
 import numpy as np
 
+from spikeinterface.qualitymetrics.quality_metric_list import metric_extension_dependencies
+
+
+def _has_required_extensions(sorting_analyzer, metric_name):
+
+    required_extensions = metric_extension_dependencies[metric_name]
+
+    not_computed_required_extensions = []
+    for ext in required_extensions:
+        if all(sorting_analyzer.has_extension(name) is False for name in ext.split("|")):
+            not_computed_required_extensions.append(ext)
+
+    if len(not_computed_required_extensions) > 0:
+        extensions_string = ", ".join(f"'{ext}'" for ext in not_computed_required_extensions)
+        error_string = (
+            f"The `{metric_name}` metric requires the {not_computed_required_extensions} extensions.\n"
+            f"Use the sorting_analyzer.compute([{extensions_string}]) method to compute them."
+        )
+        raise ValueError(error_string)
+
 
 def create_ground_truth_pc_distributions(center_locations, total_points):
     """
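A note on the alternative-dependency syntax introduced in ``metric_extension_dependencies``: a ``|`` inside an entry such as ``"spike_amplitudes|waveforms"`` separates alternatives, and the dependency is satisfied if any one of the named extensions has been computed. Below is a minimal sketch of how ``_has_required_extensions`` interprets this (a hypothetical session, not part of the diff; it assumes ``sorting_analyzer`` has only the ``waveforms`` extension computed):

.. code-block:: python

    from spikeinterface.qualitymetrics.quality_metric_list import metric_extension_dependencies
    from spikeinterface.qualitymetrics.utils import _has_required_extensions

    print(metric_extension_dependencies["amplitude_cutoff"])
    # ['spike_amplitudes|waveforms', 'templates']

    # "spike_amplitudes|waveforms" is satisfied by the computed "waveforms" extension,
    # but "templates" is missing, so this raises a ValueError naming 'templates'
    _has_required_extensions(sorting_analyzer, metric_name="amplitude_cutoff")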