From aec9e9339f14dc6dd3ba01cdcfc24628b5ebd2d2 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Wed, 1 May 2024 10:45:38 -0400 Subject: [PATCH 1/2] ENH: Add options to consolidate with ecg2x notebooks --- ml4ht/data/data_loader.py | 17 ++++++++++++++--- ml4ht/data/sample_getter.py | 6 +++++- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/ml4ht/data/data_loader.py b/ml4ht/data/data_loader.py index d5f7bd1..877a662 100644 --- a/ml4ht/data/data_loader.py +++ b/ml4ht/data/data_loader.py @@ -282,7 +282,10 @@ def __getitem__(self, item: int) -> Batch: ) -def numpy_collate_fn(samples: List[Batch]) -> Batch: +def numpy_collate_fn( + samples: List[Batch], + auto_float: bool = True +) -> Batch: """ Merges a list of ml4ht batch formatted data. Can be used as 'collate_fn` in torch.utils.data.DataLoader @@ -290,13 +293,21 @@ def numpy_collate_fn(samples: List[Batch]) -> Batch: """ # construct correctly-shaped empty arrays for input and output of model in_batch_keys = list(samples[0][0]) + if auto_float: + in_dtypes = {k: np.float32 for k in in_batch_keys} + else: + in_dtypes = {k: samples[0][0][k].dtype for k in in_batch_keys} in_batch = { - k: np.empty((len(samples),) + samples[0][0][k].shape, dtype=np.float32) + k: np.empty((len(samples),) + samples[0][0][k].shape, dtype=in_dtypes[k]) for k in in_batch_keys } out_batch_keys = list(samples[0][1]) + if auto_float: + out_dtypes = {k: np.float32 for k in out_batch_keys} + else: + out_dtypes = {k: samples[0][1][k].dtype for k in out_batch_keys} out_batch = { - k: np.empty((len(samples),) + samples[0][1][k].shape, dtype=np.float32) + k: np.empty((len(samples),) + samples[0][1][k].shape, dtype=out_dtypes[k]) for k in out_batch_keys } # fill in the values of the input and output arrays diff --git a/ml4ht/data/sample_getter.py b/ml4ht/data/sample_getter.py index fb89aff..4a20e98 100644 --- a/ml4ht/data/sample_getter.py +++ b/ml4ht/data/sample_getter.py @@ -28,11 +28,13 @@ def __init__( self, input_data_descriptions: List[DataDescription], output_data_descriptions: List[DataDescription], - option_picker: OptionPicker = None, + option_picker: OptionPicker=None, + restricted_sample_id_idx=None, ): self.input_data_descriptions = input_data_descriptions self.output_data_descriptions = output_data_descriptions self.option_picker = option_picker or self._default_option_picker + self.restricted_sample_id_idx = restricted_sample_id_idx @staticmethod def _default_option_picker( @@ -67,6 +69,8 @@ def __call__(self, sample_id: SampleID) -> Batch: sample_id, self.input_data_descriptions + self.output_data_descriptions, ) + if self.restricted_sample_id_idx is not None: + sample_id = sample_id[self.restricted_sample_id_idx] tensors_in = self._half_batch(sample_id, loading_options, True) tensors_out = self._half_batch(sample_id, loading_options, False) return tensors_in, tensors_out From 90d369e1a3491e1ba4fdd8735c05d0f8c75138c1 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Wed, 14 Aug 2024 09:12:34 -0400 Subject: [PATCH 2/2] STYLE: Pass style checks --- ml4ht/data/data_loader.py | 2 +- ml4ht/data/sample_getter.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ml4ht/data/data_loader.py b/ml4ht/data/data_loader.py index 877a662..c3bc339 100644 --- a/ml4ht/data/data_loader.py +++ b/ml4ht/data/data_loader.py @@ -284,7 +284,7 @@ def __getitem__(self, item: int) -> Batch: def numpy_collate_fn( samples: List[Batch], - auto_float: bool = True + auto_float: bool = True, ) -> Batch: """ Merges a list of ml4ht batch formatted data. diff --git a/ml4ht/data/sample_getter.py b/ml4ht/data/sample_getter.py index 4a20e98..4c324b0 100644 --- a/ml4ht/data/sample_getter.py +++ b/ml4ht/data/sample_getter.py @@ -28,7 +28,7 @@ def __init__( self, input_data_descriptions: List[DataDescription], output_data_descriptions: List[DataDescription], - option_picker: OptionPicker=None, + option_picker: OptionPicker = None, restricted_sample_id_idx=None, ): self.input_data_descriptions = input_data_descriptions