Skip to content

Commit 0c8fbdd

Browse files
authored
update trainer (#466)
1 parent fc88b44 commit 0c8fbdd

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

pina/trainer.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __init__(
2222
automatic_batching=None,
2323
num_workers=None,
2424
pin_memory=None,
25+
shuffle=None,
2526
**kwargs,
2627
):
2728
"""
@@ -34,13 +35,13 @@ def __init__(
3435
If ``batch_size=None`` all
3536
samples are loaded and data are not batched, defaults to None.
3637
:type batch_size: int | None
37-
:param train_size: percentage of elements in the train dataset
38+
:param train_size: Percentage of elements in the train dataset.
3839
:type train_size: float
39-
:param test_size: percentage of elements in the test dataset
40+
:param test_size: Percentage of elements in the test dataset.
4041
:type test_size: float
41-
:param val_size: percentage of elements in the val dataset
42+
:param val_size: Percentage of elements in the val dataset.
4243
:type val_size: float
43-
:param predict_size: percentage of elements in the predict dataset
44+
:param predict_size: Percentage of elements in the predict dataset.
4445
:type predict_size: float
4546
:param compile: if True model is compiled before training,
4647
default False. For Windows users compilation is always disabled.
@@ -49,9 +50,13 @@ def __init__(
4950
performed. Please avoid using automatic batching when batch_size is
5051
large, default False.
5152
:type automatic_batching: bool
52-
:param num_workers: Number of worker threads for data loading. Default 0 (serial loading)
53+
:param num_workers: Number of worker threads for data loading.
54+
Default 0 (serial loading).
5355
:type num_workers: int
54-
:param pin_memory: Whether to use pinned memory for faster data transfer to GPU. (Default False)
56+
:param pin_memory: Whether to use pinned memory for faster data
57+
transfer to GPU. Default False.
58+
:type pin_memory: bool
59+
:param shuffle: Whether to shuffle the data for training. Default False.
5560
:type pin_memory: bool
5661
5762
:Keyword Arguments:
@@ -77,6 +82,10 @@ def __init__(
7782
check_consistency(pin_memory, int)
7883
else:
7984
num_workers = 0
85+
if shuffle is not None:
86+
check_consistency(shuffle, bool)
87+
else:
88+
shuffle = False
8089
if train_size + test_size + val_size + predict_size > 1:
8190
raise ValueError(
8291
"train_size, test_size, val_size and predict_size "
@@ -131,6 +140,7 @@ def __init__(
131140
automatic_batching,
132141
pin_memory,
133142
num_workers,
143+
shuffle,
134144
)
135145

136146
# logging
@@ -166,6 +176,7 @@ def _create_datamodule(
166176
automatic_batching,
167177
pin_memory,
168178
num_workers,
179+
shuffle,
169180
):
170181
"""
171182
This method is used here because is resampling is needed
@@ -196,6 +207,7 @@ def _create_datamodule(
196207
automatic_batching=automatic_batching,
197208
num_workers=num_workers,
198209
pin_memory=pin_memory,
210+
shuffle=shuffle,
199211
)
200212

201213
def train(self, **kwargs):

0 commit comments

Comments
 (0)