# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler

from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.data_pipeline.data_sampling.data_sampler import DeepSpeedDataSampler
from deepspeed.runtime.data_pipeline.constants import CURRICULUM_LEARNING, \
    DATA_EFFICIENCY, DATA_SAMPLING_NUM_WORKERS
from deepspeed.runtime.constants import GRADIENT_ACCUMULATION_STEPS, \
    DATA_PARALLEL_GROUP, GLOBAL_RANK


class RepeatingLoader:

    def __init__(self, loader):
        """Wraps an iterator to allow for infinite iteration. This is especially useful
        for DataLoader types that we wish to automatically restart upon completion.

        Args:
            loader (iterator): The data loader to repeat.
        """
        self.loader = loader
        self.data_iter = iter(self.loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            batch = next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.loader)
            batch = next(self.data_iter)
        return batch
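

# Usage sketch (illustrative only, not part of this module's API): wrapping a
# finite DataLoader so a training loop can draw more batches than one pass over
# the data provides. The toy dataset and batch_size below are assumptions.
#
#   loader = RepeatingLoader(
#       DataLoader([(torch.randn(3, 3), torch.tensor(i % 2)) for i in range(10)], batch_size=2))
#   it = iter(loader)
#   batches = [next(it) for _ in range(8)]  # 8 > 5 batches per epoch; iteration restarts silently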


class DeepSpeedDataLoader(object):

    def __init__(self,
                 dataset,
                 batch_size,
                 pin_memory,
                 local_rank,
                 tput_timer,
                 collate_fn=None,
                 num_local_io_workers=None,
                 data_sampler=None,
                 data_parallel_world_size=None,
                 data_parallel_rank=None,
                 dataloader_drop_last=False,
                 deepspeed_dataloader_config={}):
        self.deepspeed_dataloader_config = deepspeed_dataloader_config
        self.tput_timer = tput_timer
        self.batch_size = batch_size

        self.curriculum_learning_enabled = False
        if CURRICULUM_LEARNING in deepspeed_dataloader_config:
            self.curriculum_learning_enabled = deepspeed_dataloader_config[CURRICULUM_LEARNING]

        if self.curriculum_learning_enabled:
            # Curriculum learning drives batching through DeepSpeedDataSampler,
            # which yields index batches directly (used as a batch_sampler below).
            data_sampler = DeepSpeedDataSampler(self.deepspeed_dataloader_config[DATA_EFFICIENCY],
                                                len(dataset),
                                                self.batch_size,
                                                data_parallel_rank,
                                                data_parallel_world_size,
                                                self.deepspeed_dataloader_config[DATA_PARALLEL_GROUP],
                                                self.deepspeed_dataloader_config[GRADIENT_ACCUMULATION_STEPS],
                                                self.deepspeed_dataloader_config[GLOBAL_RANK],
                                                drop_last=dataloader_drop_last)
            device_count = get_accelerator().device_count()
            num_local_io_workers = self.deepspeed_dataloader_config[DATA_SAMPLING_NUM_WORKERS]
        else:
            if local_rank >= 0:
                # Distributed run: each data-parallel rank samples its own shard.
                if data_sampler is None:
                    data_sampler = DistributedSampler(dataset=dataset,
                                                      num_replicas=data_parallel_world_size,
                                                      rank=data_parallel_rank)
                device_count = 1
            else:
                # Single-process run: sample randomly and scale the batch size
                # by the number of local devices.
                if data_sampler is None:
                    data_sampler = RandomSampler(dataset)
                device_count = get_accelerator().device_count()
                batch_size *= device_count

            if num_local_io_workers is None:
                num_local_io_workers = 2 * device_count

        self.num_local_io_workers = num_local_io_workers
        self.data_sampler = data_sampler
        self.dataset = dataset
        self.collate_fn = collate_fn
        self.device_count = device_count
        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.data = None
        self.dataloader_drop_last = dataloader_drop_last
        self.post_process_func = None

        if self.dataloader_drop_last:
            self.len = len(self.data_sampler) // self.batch_size
        else:
            from math import ceil
            self.len = ceil(len(self.data_sampler) / self.batch_size)
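
    # For example (illustrative arithmetic): with 10 samples in the sampler and
    # an effective batch_size of 4, drop_last=True gives len == 10 // 4 == 2,
    # while drop_last=False gives len == ceil(10 / 4) == 3 (the final partial
    # batch is kept).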

    def __iter__(self):
        self._create_dataloader()
        return self

    def __len__(self):
        return self.len

    def __next__(self):
        if self.tput_timer:
            self.tput_timer.start()
        if self.curriculum_learning_enabled:
            data = next(self.data_iterator)
            if self.post_process_func is not None:
                data = self.post_process_func(data, self.data_sampler.state_dict())
            return data
        else:
            return next(self.data)

    def _create_dataloader(self):
        if self.curriculum_learning_enabled:
            if self.collate_fn is None:
                self.dataloader = DataLoader(self.dataset,
                                             pin_memory=self.pin_memory,
                                             batch_sampler=self.data_sampler,
                                             num_workers=self.num_local_io_workers)
            else:
                self.dataloader = DataLoader(self.dataset,
                                             pin_memory=self.pin_memory,
                                             batch_sampler=self.data_sampler,
                                             collate_fn=self.collate_fn,
                                             num_workers=self.num_local_io_workers)
            self.data_iterator = iter(self.dataloader)
            return self.dataloader
        else:
            if self.collate_fn is None:
                self.dataloader = DataLoader(self.dataset,
                                             batch_size=self.batch_size,
                                             pin_memory=self.pin_memory,
                                             sampler=self.data_sampler,
                                             num_workers=self.num_local_io_workers,
                                             drop_last=self.dataloader_drop_last)
            else:
                self.dataloader = DataLoader(self.dataset,
                                             batch_size=self.batch_size,
                                             pin_memory=self.pin_memory,
                                             sampler=self.data_sampler,
                                             collate_fn=self.collate_fn,
                                             num_workers=self.num_local_io_workers,
                                             drop_last=self.dataloader_drop_last)
            self.data = (x for x in self.dataloader)
            return self.dataloader

# DataLoader([(torch.randn(3, 3), torch.tensor(i % 2)) for i in range(10)], batch_size=2)
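
# Usage sketch (a minimal, non-distributed example; not how the DeepSpeed engine
# normally constructs this class, and the toy dataset and argument values below
# are assumptions for illustration only):
#
#   dataset = [(torch.randn(3, 3), torch.tensor(i % 2)) for i in range(10)]
#   loader = DeepSpeedDataLoader(dataset,
#                                batch_size=2,
#                                pin_memory=False,
#                                local_rank=-1,   # negative rank -> RandomSampler path
#                                tput_timer=None)
#   for features, labels in loader:
#       pass  # len(loader) reflects batch_size scaled by the local device count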