@@ -22,6 +22,7 @@ def __init__(
22
22
automatic_batching = None ,
23
23
num_workers = None ,
24
24
pin_memory = None ,
25
+ shuffle = None ,
25
26
** kwargs ,
26
27
):
27
28
"""
@@ -34,13 +35,13 @@ def __init__(
34
35
If ``batch_size=None`` all
35
36
samples are loaded and data are not batched, defaults to None.
36
37
:type batch_size: int | None
37
- :param train_size: percentage of elements in the train dataset
38
+ :param train_size: Percentage of elements in the train dataset.
38
39
:type train_size: float
39
- :param test_size: percentage of elements in the test dataset
40
+ :param test_size: Percentage of elements in the test dataset.
40
41
:type test_size: float
41
- :param val_size: percentage of elements in the val dataset
42
+ :param val_size: Percentage of elements in the val dataset.
42
43
:type val_size: float
43
- :param predict_size: percentage of elements in the predict dataset
44
+ :param predict_size: Percentage of elements in the predict dataset.
44
45
:type predict_size: float
45
46
:param compile: if True model is compiled before training,
46
47
default False. For Windows users compilation is always disabled.
@@ -49,9 +50,13 @@ def __init__(
49
50
performed. Please avoid using automatic batching when batch_size is
50
51
large, default False.
51
52
:type automatic_batching: bool
52
- :param num_workers: Number of worker threads for data loading. Default 0 (serial loading)
53
+ :param num_workers: Number of worker threads for data loading.
54
+ Default 0 (serial loading).
53
55
:type num_workers: int
54
- :param pin_memory: Whether to use pinned memory for faster data transfer to GPU. (Default False)
56
+ :param pin_memory: Whether to use pinned memory for faster data
57
+ transfer to GPU. Default False.
58
+ :type pin_memory: bool
59
+ :param shuffle: Whether to shuffle the data for training. Default False.
55
60
:type pin_memory: bool
56
61
57
62
:Keyword Arguments:
@@ -77,6 +82,10 @@ def __init__(
77
82
check_consistency (pin_memory , int )
78
83
else :
79
84
num_workers = 0
85
+ if shuffle is not None :
86
+ check_consistency (shuffle , bool )
87
+ else :
88
+ shuffle = False
80
89
if train_size + test_size + val_size + predict_size > 1 :
81
90
raise ValueError (
82
91
"train_size, test_size, val_size and predict_size "
@@ -131,6 +140,7 @@ def __init__(
131
140
automatic_batching ,
132
141
pin_memory ,
133
142
num_workers ,
143
+ shuffle ,
134
144
)
135
145
136
146
# logging
@@ -166,6 +176,7 @@ def _create_datamodule(
166
176
automatic_batching ,
167
177
pin_memory ,
168
178
num_workers ,
179
+ shuffle ,
169
180
):
170
181
"""
171
182
This method is used here because is resampling is needed
@@ -196,6 +207,7 @@ def _create_datamodule(
196
207
automatic_batching = automatic_batching ,
197
208
num_workers = num_workers ,
198
209
pin_memory = pin_memory ,
210
+ shuffle = shuffle ,
199
211
)
200
212
201
213
def train (self , ** kwargs ):
0 commit comments