-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
Copy pathdata_loading.py
543 lines (465 loc) · 26.3 KB
/
data_loading.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import multiprocessing
import os
from abc import ABC
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler
from torch.utils.data.dataset import IterableDataset
from torch.utils.data.distributed import DistributedSampler
import pytorch_lightning as pl
from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import (
_capture_metadata_collate,
CaptureIterableDataset,
CaptureMapDataset,
FastForwardSampler,
)
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import pl_worker_init_function
class TrainerDataLoadingMixin(ABC):
    # Mixin grouping the dataloader handling of the trainer: worker checks, sampler injection and
    # replacement, and the (re)creation of the train/val/test/predict dataloaders.
    #
    # this is just a summary on variables used in this abstract class,
    # the proper values/initialisation should be done in child class
    # an int is used as "validate every k batches", any other value as a fraction of the training epoch
    val_check_interval: float
    tpu_local_core_rank: int
    # assigned in `reset_train_dataloader`, where it ends up wrapped in a `CombinedLoader`
    train_dataloader: DataLoader
    # `float("inf")` when the train dataloader has no length
    num_training_batches: Union[int, float]
    # derived from `val_check_interval` in `reset_train_dataloader`
    val_check_batch: float
    # assigned together with `num_val_batches` in `reset_val_dataloader`
    val_dataloaders: Optional[List[DataLoader]]
    num_val_batches: List[Union[int, float]]
    # assigned together with `num_test_batches` in `reset_test_dataloader`
    test_dataloaders: Optional[List[DataLoader]]
    num_test_batches: List[Union[int, float]]
    # an int (or 0.0) caps the number of train batches, any other value scales it
    limit_train_batches: Union[int, float]
    # compared against `num_training_batches` to warn when no log would be produced in an epoch
    log_every_n_steps: int
    # when > 0, random samplers are replaced by sequential ones and eval loaders are copies of the train loader
    overfit_batches: Union[int, float]
    # keyword arguments forwarded to the distributed sampler constructor
    distributed_sampler_kwargs: dict
    # its `barrier` is called after the dataloader hooks ran
    accelerator: Accelerator
    # read for `distributed_backend`, `replace_sampler_ddp` and `is_distributed`
    accelerator_connector: AcceleratorConnector
    # receives `track_load_dataloader_call` whenever dataloaders are requested
    dev_debugger: InternalDebugger
    # used to invoke the `*_dataloader` and `on_*_dataloader` hooks
    call_hook: Callable
def _worker_check(self, dataloader: DataLoader, name: str) -> None:
if not isinstance(dataloader, DataLoader):
return
using_spawn = self.accelerator_connector.distributed_backend == "ddp_spawn"
num_cpus = multiprocessing.cpu_count()
# ddp_spawn + num_workers > 0 don't mix! tell the user
if dataloader.num_workers > 0 and using_spawn:
# checks for the attr persistent_workers available in pytorch >= 1.7
if hasattr(dataloader, "persistent_workers"):
if not dataloader.persistent_workers:
rank_zero_warn(
"num_workers>0, persistent_workers=False, and accelerator=ddp_spawn"
" may result in data loading bottlenecks."
" Consider setting persistent_workers=True"
" (this is a limitation of Python .spawn() and PyTorch)"
)
else:
rank_zero_warn(
"num_workers>0 and accelerator=ddp_spawn do not mix well"
" and may result in data loading bottlenecks."
" Consider setting accelerator=ddp to use num_workers>0"
" (this is a limitation of Python .spawn() and PyTorch)"
)
elif dataloader.num_workers == 0 and using_spawn:
# checks for the attr persistent_workers available in pytorch >= 1.7
if hasattr(dataloader, "persistent_workers"):
if not dataloader.persistent_workers:
rank_zero_warn(
"accelerator=ddp_spawn and num_workers=0 may result in data loading bottlenecks."
" Consider setting num_workers>0 and persistent_workers=True"
)
else:
rank_zero_warn(
"accelerator=ddp_spawn and num_workers=0 may result in data loading bottlenecks."
" Consider setting accelerator=ddp and set num_workers>0"
)
elif dataloader.num_workers <= 2 < num_cpus and not using_spawn:
rank_zero_warn(
f"The dataloader, {name}, does not have many workers which may be a bottleneck."
" Consider increasing the value of the `num_workers` argument`"
f" (try {num_cpus} which is the number of cpus on this machine)"
" in the `DataLoader` init to improve performance."
)
def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None:
    """Install ``pl_worker_init_function`` on ``dataloader`` when worker seeding is requested.

    Nothing happens unless the ``PL_SEED_WORKERS`` environment variable holds a non-zero integer and the
    dataloader has no ``worker_init_fn`` of its own. The installed function is bound to ``self.global_rank``.

    Args:
        dataloader: The dataloader whose ``worker_init_fn`` may be set in place.
    """
    seed_workers = int(os.environ.get("PL_SEED_WORKERS", 0))
    if not seed_workers:
        return
    if dataloader.worker_init_fn is not None:
        # never override a user-provided init function
        return
    dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank)
def auto_add_sampler(self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None) -> Any:
    """Re-create ``dataloader`` with the sampler that fits the current (distributed) setup.

    A ``CombinedLoader`` has this method applied to each ``DataLoader`` it holds and is returned itself.
    Objects that are not a ``DataLoader`` are returned untouched. Every other dataloader is rebuilt through
    ``replace_sampler``: with a distributed sampler when ``replace_sampler_ddp`` is set, training is
    distributed, the current sampler is not already a ``DistributedSampler`` and the dataset is not iterable;
    with its current sampler otherwise.

    Args:
        dataloader: The dataloader (or collection wrapper) to process.
        shuffle: Whether an injected distributed sampler should shuffle.
        mode: The running stage, forwarded to the sampler helpers.

    Raises:
        MisconfigurationException: If a distributed sampler has to be injected but the dataloader already
            uses a sampler other than ``SequentialSampler`` or ``RandomSampler``.
    """
    if isinstance(dataloader, CombinedLoader):
        # apply `auto_add_sampler` on all the collection of loaders
        dataloader.loaders = apply_to_collection(
            dataloader.loaders, DataLoader, self.auto_add_sampler, shuffle, mode=mode
        )
        return dataloader

    # don't do anything if it's not a dataloader
    if not isinstance(dataloader, DataLoader):
        return dataloader

    inject_distributed_sampler = (
        self.accelerator_connector.replace_sampler_ddp
        and self.accelerator_connector.is_distributed
        and not isinstance(dataloader.sampler, DistributedSampler)
        and not has_iterable_dataset(dataloader)
    )

    if not inject_distributed_sampler:
        # keep the sampler the dataloader already has
        return self.replace_sampler(dataloader, dataloader.sampler, mode=mode)

    if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
        raise MisconfigurationException(
            "You seem to have configured a sampler in your DataLoader. This will be replaced "
            " by `DistributedSampler` since `replace_sampler_ddp` is True and you are using"
            " distributed training. Either remove the sampler from your DataLoader or set"
            " `replace_sampler_ddp`=False if you want to use your custom sampler."
        )

    distributed_sampler = self._get_distributed_sampler(dataloader, shuffle, mode=mode)
    return self.replace_sampler(dataloader, distributed_sampler, mode=mode)
@staticmethod
def _resolve_batch_sampler(
    dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
) -> Dict[str, Any]:
    """Build the sampler-related keyword arguments used to re-instantiate ``dataloader``.

    When the dataloader has a batch sampler that is a custom class (not exactly ``BatchSampler``), or when
    predicting with any batch sampler, a new batch sampler of the same class is created around ``sampler``.
    It reuses the original ``batch_size``; ``drop_last`` is forced to ``False`` when predicting. While
    predicting it is wrapped in ``IndexBatchSamplerWrapper``, and with fault-tolerant training enabled it is
    wrapped in ``FastForwardSampler``.

    In every other case ``sampler`` itself is used, wrapped in ``FastForwardSampler`` when fault-tolerant
    training is enabled.

    Args:
        dataloader: The dataloader whose batch sampler configuration is inspected.
        sampler: The sampler the new dataloader should use.
        mode: The running stage; only ``RunningStage.PREDICTING`` changes the behaviour.

    Returns:
        A dict with ``sampler``, ``shuffle`` and ``batch_sampler`` entries, plus ``batch_size`` and
        ``drop_last`` when a batch sampler is returned.
    """
    batch_sampler = dataloader.batch_sampler
    is_predicting = mode == RunningStage.PREDICTING
    # checking the batch sampler type is different than PyTorch default.
    # `batch_sampler` can be `None`; in that case there is nothing to re-instantiate, even when
    # predicting, so the `None` check has to guard both conditions.
    if batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting):
        batch_sampler = type(batch_sampler)(
            sampler,
            batch_size=batch_sampler.batch_size,
            drop_last=(False if is_predicting else batch_sampler.drop_last),
        )
        if is_predicting:
            batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

        if _fault_tolerant_training():
            fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
            fast_forward_sampler.setup(dataloader_batch_size=1)

        return {
            "sampler": None,
            "shuffle": False,
            "batch_sampler": batch_sampler,
            "batch_size": 1,
            "drop_last": False,
        }

    if _fault_tolerant_training():
        fast_forward_sampler = sampler = FastForwardSampler(sampler)
        fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size)

    return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
@staticmethod
def _get_dataloader_init_kwargs(
dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
) -> Dict[str, Any]:
if not isinstance(dataloader, DataLoader):
raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")
# get the dataloader instance attributes
attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
# not part of `vars`
attrs["multiprocessing_context"] = dataloader.multiprocessing_context
# get the dataloader instance `__init__` parameters
params = dict(inspect.signature(dataloader.__init__).parameters)
has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
if has_variadic_kwargs:
# if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
params.update(inspect.signature(DataLoader.__init__).parameters)
del params["self"]
# keep only the params whose default is different to the current attr value
non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
# add `dataset` as it might have been replaced with `*args`
non_defaults.add("dataset")
# kwargs to re-construct the dataloader
dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
dl_kwargs.update(TrainerDataLoadingMixin._resolve_batch_sampler(dataloader, sampler, mode=mode))
required_args = {
p.name
for p in params.values()
if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
and p.default is p.empty
and p.name not in dl_kwargs
}
# the dataloader has required args which we could not extract from the existing attributes
if required_args:
required_args = sorted(required_args)
dataloader_cls_name = dataloader.__class__.__name__
raise MisconfigurationException(
f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
"This would fail as some of the `__init__` arguments are not available as instance attributes. "
f"The missing attributes are {required_args}. "
f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
"manually add the `DistributedSampler` as: "
f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
)
if not has_variadic_kwargs:
# the dataloader signature does not allow keyword arguments that need to be passed
missing_kwargs = dl_kwargs.keys() - params.keys()
if missing_kwargs:
missing_kwargs = sorted(missing_kwargs)
dataloader_cls_name = dataloader.__class__.__name__
raise MisconfigurationException(
f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
"This would fail as it doesn't expose all its attributes in the `__init__` signature. "
f"The missing arguments are {missing_kwargs}. "
f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
"manually add the `DistributedSampler` as: "
f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
)
if isinstance(dl_kwargs["dataset"], IterableDataset):
dl_kwargs["batch_sampler"] = None
dl_kwargs["sampler"] = None
if _fault_tolerant_training():
if isinstance(dl_kwargs["dataset"], IterableDataset):
# wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
elif len(dl_kwargs["dataset"]):
dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"])
else:
raise MisconfigurationException(
"This shouldn't happen, please open an issue on Lightning Github repository."
)
return dl_kwargs
@staticmethod
def replace_sampler(dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None) -> DataLoader:
    """Return a new instance of the dataloader's class that uses ``sampler``.

    Args:
        dataloader: The dataloader to re-create.
        sampler: The sampler for the new dataloader.
        mode: The running stage, forwarded to ``_get_dataloader_init_kwargs``.

    Returns:
        A freshly constructed dataloader of the same type as ``dataloader``.
    """
    init_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler, mode=mode)
    return type(dataloader)(**init_kwargs)
def _get_distributed_sampler(
    self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None
) -> DistributedSampler:
    """Create the distributed sampler for the dataset of ``dataloader``.

    The sampler is built from ``self.distributed_sampler_kwargs``. Shuffling is enabled only when
    ``shuffle`` is set and ``self.overfit_batches`` is falsy. When no ``seed`` is provided, the value of the
    ``PL_GLOBAL_SEED`` environment variable is used (``0`` if unset).

    Args:
        dataloader: The dataloader whose ``dataset`` is sampled.
        shuffle: Whether shuffling is requested.
        mode: The running stage; ``RunningStage.PREDICTING`` selects ``UnrepeatedDistributedSampler``,
            anything else ``DistributedSampler``.

    Returns:
        The new sampler instance.
    """
    # work on a copy so that `shuffle`/`seed` are not written back into the shared kwargs
    kwargs = dict(self.distributed_sampler_kwargs)
    kwargs["shuffle"] = shuffle and not self.overfit_batches
    kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))
    cls = UnrepeatedDistributedSampler if mode == RunningStage.PREDICTING else DistributedSampler
    return cls(dataloader.dataset, **kwargs)
def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
    """Resets the train dataloader and initialises required variables (number of batches, when to validate,
    etc.).

    The dataloader is requested through the hooks, gets its samplers, worker init function and (with
    fault-tolerant training) metadata collate function attached, and is wrapped in a ``CombinedLoader``.
    ``num_training_batches`` and ``val_check_batch`` are then derived from ``limit_train_batches`` and
    ``val_check_interval``.

    Args:
        model: The `LightningModule` if calling this outside of the trainer scope.

    Raises:
        MisconfigurationException: If the dataloader has no length and ``limit_train_batches`` is neither
            an int, ``0.0`` nor ``1.0``, or ``val_check_interval`` is a non-int other than ``1.0``.
        ValueError: If an int ``val_check_interval`` is larger than the number of training batches.
    """
    self.train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model)

    if self.overfit_batches > 0:
        if hasattr(self.train_dataloader, "sampler") and isinstance(self.train_dataloader.sampler, RandomSampler):
            rank_zero_warn(
                "You requested to overfit but enabled training dataloader shuffling."
                " We are turning off the training dataloader shuffling for you."
            )
            self.train_dataloader = self.replace_sampler(
                self.train_dataloader, SequentialSampler(self.train_dataloader.dataset), mode=RunningStage.TRAINING
            )

    # debugging
    self.dev_debugger.track_load_dataloader_call("train_dataloader", dataloaders=[self.train_dataloader])

    # automatically add samplers
    self.train_dataloader = apply_to_collection(
        self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True, mode=RunningStage.TRAINING
    )

    # check the workers recursively
    apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, "train_dataloader")

    # add worker_init_fn for correct seeding in worker processes
    apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)

    # add collate_fn to collect metadata for fault tolerant training
    if _fault_tolerant_training():
        apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate)

    # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
    self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)

    self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float("inf")

    if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
        # an int (or 0.0) is an absolute cap on the number of batches
        self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
    elif self.num_training_batches != float("inf"):
        # a float is a fraction of the available batches
        self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
    elif self.limit_train_batches != 1.0:
        raise MisconfigurationException(
            "When using an IterableDataset for `limit_train_batches`,"
            " `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
            " `num_training_batches` to use."
        )

    # determine when to check validation
    # if int passed in, val checks that often
    # otherwise, it checks in [0, 1.0] % range of a training epoch
    if isinstance(self.val_check_interval, int):
        self.val_check_batch = self.val_check_interval
        if self.val_check_batch > self.num_training_batches:
            raise ValueError(
                f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
                f"to the number of the training batches ({self.num_training_batches}). "
                "If you want to disable validation set `limit_val_batches` to 0.0 instead."
            )
    else:
        if not has_len(self.train_dataloader):
            if self.val_check_interval == 1.0:
                self.val_check_batch = float("inf")
            else:
                raise MisconfigurationException(
                    "When using an IterableDataset for `train_dataloader`,"
                    " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
                    " checking validation every k training batches."
                )
        else:
            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
            self.val_check_batch = max(1, self.val_check_batch)

    if self.logger and self.num_training_batches < self.log_every_n_steps:
        # the reported quantity is a number of batches, not samples
        rank_zero_warn(
            f"The number of training batches ({self.num_training_batches}) is smaller than the logging interval"
            f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
            " you want to see logs for the training epoch."
        )
def _reset_eval_dataloader(
    self, mode: RunningStage, model: Optional["pl.LightningModule"] = None
) -> Tuple[List[Union[int, float]], List[DataLoader]]:
    """Generic method to reset a dataloader for evaluation.

    The dataloaders are requested through the hooks (and replaced by copies of the train dataloader when
    overfitting), ``None`` entries are dropped, samplers and worker init functions are attached, and the
    number of batches per dataloader is derived from ``limit_<prefix>_batches``.

    Args:
        mode: The running stage of the ``Trainer``
        model: The ``LightningModule`` if calling this outside of the trainer scope.

    Returns:
        Tuple (num_batches, dataloaders)

    Raises:
        MisconfigurationException: If a dataloader has no length and the limit is neither an int, ``0.0``
            nor ``1.0``, or if a positive float limit results in zero batches.
    """
    assert mode.evaluating or mode == RunningStage.PREDICTING

    # always get the loaders first so we can count how many there are
    loader_name = f"{mode.dataloader_prefix}_dataloader"
    dataloaders = self.request_dataloader(mode, model=model)

    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    # when overfitting, use the training loader as val and test
    # duplicate it the numb of times needed to match the train loaders
    if self.overfit_batches > 0:
        train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model)
        dataloaders = [deepcopy(train_dataloader) for _ in range(len(dataloaders))]

    self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders)

    for loader_i, loader in enumerate(dataloaders):
        if hasattr(loader, "sampler") and isinstance(loader.sampler, RandomSampler):
            # when overfitting, the dataloader should not have sampler
            if self.overfit_batches > 0 and mode.evaluating:
                rank_zero_warn(
                    "You requested to overfit but enabled val/test dataloader shuffling."
                    " We are turning it off for you."
                )
                dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset), mode=mode)
            else:
                rank_zero_warn(
                    f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
                    " it is strongly recommended that you turn this off for val/test/predict dataloaders."
                )

    if any(dl is None for dl in dataloaders):
        rank_zero_warn("One of given dataloaders is None and it will be skipped.")

    # add samplers
    dataloaders = [self.auto_add_sampler(dl, False, mode=mode) for dl in dataloaders if dl is not None]

    # add worker_init_fn for correct seeding in worker processes
    apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn)

    loader_num_batches = []

    # determine number of batches
    # datasets could be none, 1 or 2+
    for i, dataloader in enumerate(dataloaders):
        # keep the unlimited count around for the error message below
        orig_num_batches = num_batches = len(dataloader) if has_len(dataloader) else float("inf")
        self._worker_check(dataloader, f"{mode.dataloader_prefix}_dataloader {i}")

        # percent or num_steps
        limit_eval_batches = getattr(self, f"limit_{mode.dataloader_prefix}_batches")

        # limit num batches either as a percent or num steps
        if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
            num_batches = min(num_batches, int(limit_eval_batches))
        elif num_batches != float("inf"):
            num_batches = int(num_batches * limit_eval_batches)
        elif limit_eval_batches != 1.0:
            raise MisconfigurationException(
                f"When using an IterableDataset for `limit_{mode.dataloader_prefix}_batches`,"
                f" `Trainer(limit_{mode.dataloader_prefix}_batches)` must be `0.0`, `1.0` or an int. An int k"
                f" specifies `num_{mode.dataloader_prefix}_batches` to use."
            )

        if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
            # only reachable through the fraction branch above, so the original count is finite here
            min_pct = 1.0 / orig_num_batches
            raise MisconfigurationException(
                f"you requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but"
                f" {limit_eval_batches}*{orig_num_batches} < 1. Please increase the"
                f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least"
                f" `limit_{mode.dataloader_prefix}_batches={min_pct}`"
            )

        loader_num_batches.append(num_batches)

    return loader_num_batches, dataloaders
def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
    """Re-create the validation dataloaders and store how many batches each one provides.

    Nothing is changed unless the module overrides both ``val_dataloader`` and ``validation_step``.

    Args:
        model: The `LightningModule` if called outside of the trainer scope. It is only used when the
            trainer has no ``lightning_module``.
    """
    pl_module = self.lightning_module or model
    loader_overridden = is_overridden("val_dataloader", pl_module)
    step_overridden = is_overridden("validation_step", pl_module)
    if not (loader_overridden and step_overridden):
        return
    num_batches, dataloaders = self._reset_eval_dataloader(RunningStage.VALIDATING, model=pl_module)
    self.num_val_batches = num_batches
    self.val_dataloaders = dataloaders
def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
    """Re-create the test dataloaders and store how many batches each one provides.

    Nothing is changed unless the module overrides both ``test_dataloader`` and ``test_step``.

    Args:
        model: The `LightningModule` if called outside of the trainer scope. It is only used when the
            trainer has no ``lightning_module``.
    """
    pl_module = self.lightning_module or model
    loader_overridden = is_overridden("test_dataloader", pl_module)
    step_overridden = is_overridden("test_step", pl_module)
    if not (loader_overridden and step_overridden):
        return
    num_batches, dataloaders = self._reset_eval_dataloader(RunningStage.TESTING, model=pl_module)
    self.num_test_batches = num_batches
    self.test_dataloaders = dataloaders
def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
    """Re-create the predict dataloaders and store how many batches each one provides.

    Nothing is changed unless the module overrides ``predict_dataloader``.

    Args:
        model: The `LightningModule` if called outside of the trainer scope. It is only used when the
            trainer has no ``lightning_module``.
    """
    pl_module = self.lightning_module or model
    if not is_overridden("predict_dataloader", pl_module):
        return
    num_batches, dataloaders = self._reset_eval_dataloader(RunningStage.PREDICTING, model=pl_module)
    self.num_predict_batches = num_batches
    self.predict_dataloaders = dataloaders
def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = None) -> None:
    """Create the train and val dataloaders that are not attached to the trainer yet.

    The val dataloader must be initialized before training loop starts, as the training loop
    inspects the val dataloader to determine whether to run the evaluation loop.

    Args:
        model: The `LightningModule` if called outside of the trainer scope.
    """
    train_missing = self.train_dataloader is None
    if train_missing:
        self.reset_train_dataloader(model=model)

    # checked only after the train dataloader has been handled
    val_missing = self.val_dataloaders is None
    if val_missing:
        self.reset_val_dataloader(model=model)
def request_dataloader(
    self, stage: RunningStage, model: Optional["pl.LightningModule"] = None
) -> Union[DataLoader, List[DataLoader]]:
    """Fetch the dataloader(s) of ``stage`` through the trainer hooks.

    The ``on_<prefix>_dataloader`` hook is called first, then ``<prefix>_dataloader`` whose result is
    returned. A tuple result is converted to a list. The accelerator barrier is reached before returning.

    Args:
        stage: The running stage; its ``dataloader_prefix`` selects the hooks.
        model: Passed to the hooks as ``pl_module``.

    Returns:
        The dataloader
    """
    hook_name = f"{stage.dataloader_prefix}_dataloader"
    self.call_hook(f"on_{hook_name}", pl_module=model)
    loaders = self.call_hook(hook_name, pl_module=model)
    loaders = list(loaders) if isinstance(loaders, tuple) else loaders
    self.accelerator.barrier("get_dataloaders")
    return loaders
@staticmethod
def _add_sampler_metadata_collate(dataloader: DataLoader) -> None:
    """Wrap the collate function of ``dataloader`` to retrieve the ``FastForwardSampler`` state dict when
    fault tolerant is enabled.

    The current ``collate_fn`` is kept as the ``default_collate`` of the wrapper and the dataloader is
    modified in place.
    """
    wrapped_collate = partial(
        _capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn
    )
    dataloader.collate_fn = wrapped_collate