Commit c8a52da

ejguan authored and facebook-github-bot committed
Enable miniepoch for MultiProcessingReadingService (pytorch#1170)
Enable miniepoch for MultiProcessingReadingService (pytorch#1170)

Summary:
A retry of pytorch#1147

- Add pause, resume, limit API to MultiProcessingReadingService to support miniepoch API.
- When limit is requested, `IterQueueWrapper` will make sure only `limit` number of requests are sent to the worker process and the dispatching process.
- Consolidate a few APIs/messages:
  - reset and reset_epoch are merged into a single message
  - the process reset function is shared by both the worker process and the dispatching process
- For pause/resume/limit/reset_epoch, add a counter to sync between loops to make sure the DataPipe graph is in-place modified once.
- Remove unused thread eventloop

Pull Request resolved: pytorch#1170
Differential Revision: D45666912
Pulled By: ejguan
fbshipit-source-id: 1e88dde3af0f406c062ca523e7b96a3c79a74b6d
1 parent ba31745 commit c8a52da
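
As a quick orientation for this change, here is a sketch of a miniepoch-style loop (an illustration, not code from this commit). The `_limit`/`_pause`/`_resume` hook names are assumptions inferred from the summary; exact names and visibility may differ by torchdata version.

# Hedged sketch of a miniepoch-style run; hook names below are assumed, not confirmed by this page.
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(1000)).shuffle().batch(16)
rs = MultiProcessingReadingService(num_workers=2)
dl = DataLoader2(dp, reading_service=rs)

it = iter(dl)
dl._limit(10)          # assumed hook: at most 10 requests reach the worker/dispatching processes
miniepoch = list(it)   # iteration stops once the limit is exhausted

dl._pause()            # assumed hook: halt the worker loops without tearing them down
# ... evaluate or checkpoint between miniepochs ...
dl._resume()           # assumed hook: continue serving from the paused state
dl.shutdown()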

File tree

10 files changed: +456 -345 lines changed

test/dataloader2/test_dataloader2.py (-69 lines)
@@ -23,7 +23,6 @@
 from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

 from torchdata.dataloader2 import (
-    communication,
     DataLoader2,
     DistributedReadingService,
     MultiProcessingReadingService,
@@ -35,7 +34,6 @@
 from torchdata.dataloader2.graph import DataPipe, list_dps, replace_dp, set_datapipes_seed, traverse_dps
 from torchdata.dataloader2.random import SeedGenerator
 from torchdata.datapipes.iter import IterableWrapper, IterDataPipe, ShardingRoundRobinDispatcher
-from torchdata.datapipes.map import SequenceWrapper

 try:
     import dill
@@ -259,73 +257,6 @@ def test_dataloader2_shuffle(self) -> None:
         pass


-@unittest.skipIf(
-    TEST_WITH_TSAN,
-    "Fails with TSAN with the following error: starting new threads after multi-threaded "
-    "fork is not supported. Dying (set die_after_fork=0 to override)",
-)
-class TestDataLoader2EventLoop(TestCase):
-    # TODO: This needs fixing, see issue 624
-    # @skipIfNoDill
-    # def test_basic_threading(self):
-    #     def clean_me(process, req_queue, res_queue):
-    #         req_queue.put(communication.messages.TerminateRequest())
-    #         _ = res_queue.get()
-    #         process.join()
-    #
-    #     it = list(range(100))
-    #     numbers_dp = IterableWrapper(it)
-    #     (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.CreateThreadForDataPipeline(numbers_dp)
-    #
-    #     process.start()
-    #     local_datapipe = communication.iter.QueueWrapper(
-    #         communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue))
-    #
-    #     actual = list(local_datapipe)
-    #     clean_me(process, req_queue, res_queue)
-    #
-    #     self.assertEqual(list(range(100)), actual)
-
-    @skipIfNoDill
-    def test_basic_mapdatapipe_threading(self):
-        def clean_me(process, req_queue, res_queue):
-            req_queue.put(communication.messages.TerminateRequest())
-            _ = res_queue.get()
-            process.join()
-
-        input_len = 100
-        it = list(range(input_len))
-        numbers_dp = SequenceWrapper(it)
-        (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.CreateThreadForDataPipeline(
-            numbers_dp,
-            thread_name="worker thread",
-        )
-
-        process.start()
-
-        # Functional Test: Ensure that you can retrieve every element from the Queue and DataPipe
-        local_datapipe = communication.map.QueueWrapperForMap(
-            communication.protocol.MapDataPipeQueueProtocolClient(req_queue, res_queue)
-        )
-        actual = list(local_datapipe)
-        self.assertEqual([(x, x) for x in range(100)], actual)
-
-        # Functional Test: raise Error when input
-        local_datapipe = communication.map.QueueWrapperForMap(
-            communication.protocol.MapDataPipeQueueProtocolClient(req_queue, res_queue)
-        )
-        with self.assertRaisesRegex(IndexError, "out of bound"):
-            local_datapipe[1000]
-
-        # __len__ Test: Ensure that the correct length is returned
-        local_datapipe = communication.map.QueueWrapperForMap(
-            communication.protocol.MapDataPipeQueueProtocolClient(req_queue, res_queue)
-        )
-        self.assertEqual(input_len, len(local_datapipe))
-
-        clean_me(process, req_queue, res_queue)
-
-
 def _x_mult_2(d):
     return d * 2


test/dataloader2/test_mprs.py (+1 line)
@@ -40,6 +40,7 @@ def _non_dispatching_dp(n_elements=1000):

 def _dispatching_dp(n_elements=1000):
     dp = IterableWrapper(list(range(n_elements))).shuffle()
+    dp = dp.prefetch(20)
     dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
     dp = dp.map(_add_one).batch(16)
     return dp
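
For context, a sketch of how a dispatching graph like `_dispatching_dp` above is consumed (my illustration using public torchdata APIs, not part of this diff); the `prefetch(20)` added here keeps the single dispatching process running ahead of the round-robin handoff to workers:

# Everything before sharding_round_robin_dispatch runs in one dispatching process;
# everything after it is replicated across the worker processes.
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(list(range(1000))).shuffle()
dp = dp.prefetch(20)
dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
dp = dp.batch(16)

dl = DataLoader2(dp, reading_service=MultiProcessingReadingService(num_workers=2))
for batch in dl:
    pass  # batches are produced by workers fed from the dispatching process
dl.shutdown()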

torchdata/dataloader2/communication/eventloop.py (+72 -63 lines)
@@ -4,11 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import pickle
-import threading
 import time

 from itertools import zip_longest
+from typing import Dict, List

 import torch

@@ -32,36 +31,42 @@
     "DataPipeToQueuesLoop",
     "CreateProcessForDataPipeline",
     "CreateProcessForMultipleDataPipelines",
-    "CreateThreadForDataPipeline",
 ]


-class _ResetCounter:
+class _RequestCounter:
     exp_cnt: int
-    cnt: int
-    _reached: bool
+    _keys: List[str] = ["limit", "pause", "reset_epoch", "resume"]
+    _cnt: Dict[str, int]
+    _reached: Dict[str, bool]

     def __init__(self, exp_cnt: int):
         self.exp_cnt = exp_cnt
-        self.cnt = 0
-        self._reached = False
-
-    def increment(self) -> None:
-        self.cnt += 1
-        assert self.cnt <= self.exp_cnt
-
-    def is_reached(self) -> bool:
-        if self.cnt == self.exp_cnt:
-            self._reached = True
-        return self._reached
-
-    def reset(self) -> None:
-        if self._reached:
-            self._reached = False
-            self.cnt = 0
-
-
-def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, process_name, call_on_process_init=None):
+        self._cnt = {k: 0 for k in self._keys}
+        self._reached = {k: False for k in self._keys}
+
+    def increment(self, key: str) -> None:
+        assert key in self._reached
+        self._cnt[key] += 1
+        assert self._cnt[key] <= self.exp_cnt
+        if self._cnt[key] == self.exp_cnt:
+            self._reached[key] = True
+
+    def is_reached(self, key: str) -> bool:
+        assert key in self._reached
+        return self._reached[key]
+
+    def reset(self, key: str) -> None:
+        assert key in self._reached and self._reached[key]
+        assert self._cnt[key] >= 1
+        self._cnt[key] -= 1
+        if self._cnt[key] == 0:
+            self._reached[key] = False
+
+
+def MultipleDataPipesToQueuesLoop(
+    source_datapipes, req_queues, res_queues, process_name, worker_info, call_on_process_init=None, custom_reset_fn=None
+):
     r"""
     Set the appropriate pipes and protocol server type, and create a loop over multiple datapipes
     with the protocol server in a non-blocking manner.
@@ -71,7 +76,9 @@ def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, proc
         req_queue: Multiprocessing queue providing requests from the worker process
         res_queue: Multiprocessing queue sending results to the worker process
         process_name: The name of process (used for logging and exception handling)
+        worker_info: Worker information (worker id and number of workers)
         call_on_process_init: Not allowed by dispatching process for now.
+        custom_reset_fn: Optional callable function to reset the DataPipe.
     """
     assert call_on_process_init is None, "``MultipleDataPipesToQueuesLoop`` does not support call_on_process_init"
     num_loops = len(source_datapipes)
@@ -82,21 +89,24 @@ def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, proc
     torch.set_num_threads(1)

     loops = []
-    reset_iterator_counter = _ResetCounter(num_loops)
+    request_counter = _RequestCounter(num_loops)

+    loop_id = 0
     for source_datapipe, req_queue, res_queue in zip(source_datapipes, req_queues, res_queues):
-        # Extract Serialization Wrapper
-        source_datapipe = extract_wrapper(source_datapipe)
         loops.append(
             _create_datapipe_queue_loop(
                 source_datapipe,
                 req_queue,
                 res_queue,
                 process_name,
+                loop_id,
+                worker_info,
+                custom_reset_fn,
                 blocking_request_get=False,
-                reset_iterator_counter=reset_iterator_counter,
+                request_counter=request_counter,
             )
         )  # Non-blocking request with reset counters
+        loop_id += 1

     # Using `zip_longest` to guarantee the process is terminated only when
     # all loops have received `TerminateRequest`
@@ -107,7 +117,9 @@ def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, proc
         time.sleep(0)


-def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, process_name, call_on_process_init=None):
+def DataPipeToQueuesLoop(
+    source_datapipe, req_queue, res_queue, process_name, worker_info, call_on_process_init=None, custom_reset_fn=None
+):
     r"""
     Initialize with the given init function, set the appropriate pipe and protocol server type, and
     create a loop with the protocol server.
@@ -117,8 +129,10 @@ def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, process_name, ca
         req_queue: Multiprocessing queue providing requests from the main process
         res_queue: Multiprocessing queue sending results to the main process
         process_name: The name of process (used for logging and exception handling)
+        worker_info: Worker information (worker id and number of workers)
         call_on_process_init: Callable function that will be called at the time of worker process initialization.
            Users can provide it to modify the DataPipe graph in the worker process.
+        custom_reset_fn: Optional callable function to reset the DataPipe.
     """
     # Extract Serialization Wrapper
     source_datapipe = extract_wrapper(source_datapipe)
@@ -128,7 +142,16 @@ def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, process_name, ca

     torch.set_num_threads(1)

-    loop = _create_datapipe_queue_loop(source_datapipe, req_queue, res_queue, process_name, blocking_request_get=True)
+    loop = _create_datapipe_queue_loop(
+        source_datapipe,
+        req_queue,
+        res_queue,
+        process_name,
+        worker_info.worker_id,
+        worker_info,
+        custom_reset_fn,
+        blocking_request_get=True,
+    )

     for _ in loop:
         pass
@@ -139,8 +162,11 @@ def _create_datapipe_queue_loop(
     req_queue,
     res_queue,
     process_name,
+    loop_id,
+    worker_info,
+    custom_reset_fn=None,
     blocking_request_get=True,
-    reset_iterator_counter=None,
+    request_counter=None,
 ):
     if isinstance(source_datapipe, IterDataPipe):
         pipe_type = communication.iter
@@ -155,51 +181,33 @@
         source_datapipe,
         protocol_type(req_queue, res_queue),
         process_name=process_name,
+        loop_id=loop_id,
+        worker_info=worker_info,
+        custom_reset_fn=custom_reset_fn,
         blocking_request_get=blocking_request_get,
-        reset_iterator_counter=reset_iterator_counter,
+        request_counter=request_counter,
     )


-def CreateProcessForDataPipeline(multiprocessing_ctx, datapipe, process_name, call_on_process_init=None):
+def CreateProcessForDataPipeline(
+    multiprocessing_ctx, datapipe, process_name, worker_info, call_on_process_init=None, custom_reset_fn=None
+):
     r"""
     Given a DataPipe, creates a new process with ``DataPipeToQueuesLoop`` as target,
     and returns ``(process, req_queue, res_queue)``.
     """
     req_queue = multiprocessing_ctx.Queue()
     res_queue = multiprocessing_ctx.Queue()
     process = multiprocessing_ctx.Process(
-        target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue, process_name, call_on_process_init)
+        target=DataPipeToQueuesLoop,
+        args=(datapipe, req_queue, res_queue, process_name, worker_info, call_on_process_init, custom_reset_fn),
     )
     return process, req_queue, res_queue


-def CreateThreadForDataPipeline(datapipe, thread_name):
-    r"""
-    Given a DataPipe, creates a copy of the DataPipe, starts a new Thread with ``DataPipeToQueuesLoop`` as target,
-    and returns ``(process, req_queue, res_queue, new_copied_datapipe)``.
-    """
-    req_queue = communication.queue.ThreadingQueue()
-    res_queue = communication.queue.ThreadingQueue()
-
-    try:
-        new_datapipe = pickle.loads(pickle.dumps(datapipe))
-    except Exception as pe:
-        if HAS_DILL:
-            try:
-                new_datapipe = dill.loads(dill.dumps(datapipe))
-            except Exception as de:
-                raise Exception("Unable to dill DataPipe to make thread local copy", de)
-
-        else:
-            raise Exception("Unable to pickle DataPipe to make thread local copy (consider installing `dill`)", pe)
-
-    process = threading.Thread(
-        target=DataPipeToQueuesLoop, args=(new_datapipe, req_queue, res_queue, thread_name), daemon=True
-    )
-    return process, req_queue, res_queue, new_datapipe
-
-
-def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes, process_name):
+def CreateProcessForMultipleDataPipelines(
+    multiprocessing_ctx, datapipes, process_name, worker_info, custom_reset_fn=None
+):
     r"""
     Given a DataPipe, creates a new process with ``MultipleDataPipesToQueuesLoop`` as target,
     and returns ``(process, [req_queue_0, ...], [res_queue_0, ...])``.
@@ -211,6 +219,7 @@ def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes, proces
         res_queues.append(multiprocessing_ctx.Queue())

     process = multiprocessing_ctx.Process(
-        target=MultipleDataPipesToQueuesLoop, args=(datapipes, req_queues, res_queues, process_name)
+        target=MultipleDataPipesToQueuesLoop,
+        args=(datapipes, req_queues, res_queues, process_name, worker_info, custom_reset_fn),
     )
     return process, req_queues, res_queues
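
To make the loop-synchronization contract concrete, here is a small usage sketch of the `_RequestCounter` added above (my illustration; the class is internal to this module, so importing it directly is only for demonstration):

# Each of the N loops in a process acknowledges a control request under a key.
# is_reached(key) flips only after all N loops have incremented, so the shared
# in-place graph modification runs exactly once; reset(key) unwinds symmetrically.
from torchdata.dataloader2.communication.eventloop import _RequestCounter

counter = _RequestCounter(exp_cnt=3)

for loop_id in range(3):
    counter.increment("pause")
    # Only after the last loop checks in does the counter report reached.
    assert counter.is_reached("pause") == (loop_id == 2)

for _ in range(3):
    counter.reset("pause")  # stays reached until every loop has reset

assert not counter.is_reached("pause")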
