diff --git a/ml4ht/data/data_loader.py b/ml4ht/data/data_loader.py index d5f7bd1..c3bc339 100644 --- a/ml4ht/data/data_loader.py +++ b/ml4ht/data/data_loader.py @@ -282,7 +282,10 @@ def __getitem__(self, item: int) -> Batch: ) -def numpy_collate_fn(samples: List[Batch]) -> Batch: +def numpy_collate_fn( + samples: List[Batch], + auto_float: bool = True, +) -> Batch: """ Merges a list of ml4ht batch formatted data. Can be used as 'collate_fn` in torch.utils.data.DataLoader @@ -290,13 +293,21 @@ def numpy_collate_fn(samples: List[Batch]) -> Batch: """ # construct correctly-shaped empty arrays for input and output of model in_batch_keys = list(samples[0][0]) + if auto_float: + in_dtypes = {k: np.float32 for k in in_batch_keys} + else: + in_dtypes = {k: samples[0][0][k].dtype for k in in_batch_keys} in_batch = { - k: np.empty((len(samples),) + samples[0][0][k].shape, dtype=np.float32) + k: np.empty((len(samples),) + samples[0][0][k].shape, dtype=in_dtypes[k]) for k in in_batch_keys } out_batch_keys = list(samples[0][1]) + if auto_float: + out_dtypes = {k: np.float32 for k in out_batch_keys} + else: + out_dtypes = {k: samples[0][1][k].dtype for k in out_batch_keys} out_batch = { - k: np.empty((len(samples),) + samples[0][1][k].shape, dtype=np.float32) + k: np.empty((len(samples),) + samples[0][1][k].shape, dtype=out_dtypes[k]) for k in out_batch_keys } # fill in the values of the input and output arrays diff --git a/ml4ht/data/sample_getter.py b/ml4ht/data/sample_getter.py index fb89aff..4c324b0 100644 --- a/ml4ht/data/sample_getter.py +++ b/ml4ht/data/sample_getter.py @@ -29,10 +29,12 @@ def __init__( input_data_descriptions: List[DataDescription], output_data_descriptions: List[DataDescription], option_picker: OptionPicker = None, + restricted_sample_id_idx=None, ): self.input_data_descriptions = input_data_descriptions self.output_data_descriptions = output_data_descriptions self.option_picker = option_picker or self._default_option_picker + self.restricted_sample_id_idx = restricted_sample_id_idx @staticmethod def _default_option_picker( @@ -67,6 +69,8 @@ def __call__(self, sample_id: SampleID) -> Batch: sample_id, self.input_data_descriptions + self.output_data_descriptions, ) + if self.restricted_sample_id_idx is not None: + sample_id = sample_id[self.restricted_sample_id_idx] tensors_in = self._half_batch(sample_id, loading_options, True) tensors_out = self._half_batch(sample_id, loading_options, False) return tensors_in, tensors_out