# NOTE(review): this patch had been collapsed onto a single line. The line breaks
# and indentation below are reconstructed from the hunk headers' line counts;
# verify against the target files before applying.
#
# Summary of the change, as shown by the hunks below:
#   * src/evaluate/module.py: the `__init__(self, evaluation_modules, force_prefix=False)`
#     named in the hunk header (presumably CombinedEvaluations.__init__ -- confirm)
#     now accepts a tuple as well as a list for `evaluation_modules`, and raises
#     ValueError when the argument is not a list, tuple or dict.
#   * tests/test_metric.py: imports CombinedEvaluations and adds a parametrized test
#     asserting that combine() returns a CombinedEvaluations instance when given a
#     list, a tuple, or a dict of DummyMetric objects.
diff --git a/src/evaluate/module.py b/src/evaluate/module.py
index b4d73e5e..af1c5d37 100644
--- a/src/evaluate/module.py
+++ b/src/evaluate/module.py
@@ -874,11 +874,14 @@ def __init__(self, evaluation_modules, force_prefix=False):
         from .loading import load  # avoid circular imports
 
         self.evaluation_module_names = None
-        if isinstance(evaluation_modules, list):
+        if isinstance(evaluation_modules, (list, tuple)):
             self.evaluation_modules = evaluation_modules
         elif isinstance(evaluation_modules, dict):
             self.evaluation_modules = list(evaluation_modules.values())
             self.evaluation_module_names = list(evaluation_modules.keys())
+        else:
+            raise ValueError("`evaluation_modules` should be a list, tuple or dict")
+
         loaded_modules = []
 
         for module in self.evaluation_modules:
diff --git a/tests/test_metric.py b/tests/test_metric.py
index 598b0f92..dadf5595 100644
--- a/tests/test_metric.py
+++ b/tests/test_metric.py
@@ -8,7 +8,7 @@
 import pytest
 from datasets.features import Features, Sequence, Value
 
-from evaluate.module import EvaluationModule, EvaluationModuleInfo, combine
+from evaluate.module import CombinedEvaluations, EvaluationModule, EvaluationModuleInfo, combine
 
 from .utils import require_tf, require_torch
 
@@ -757,3 +757,16 @@ def test_modules_from_string_poslabel(self):
         self.assertDictEqual(
             expected_result, combined_evaluation.compute(predictions=predictions, references=references, pos_label=0)
         )
+
+
+@pytest.mark.parametrize(
+    "evaluations,",
+    (
+        [DummyMetric(), DummyMetric()],
+        (DummyMetric(), DummyMetric()),
+        {"metric1": DummyMetric(), "metric2": DummyMetric()},
+    ),
+)
+def test_combine_evaluations_in_different_forms(evaluations):
+    combined_evaluation = combine(evaluations)
+    assert isinstance(combined_evaluation, CombinedEvaluations)