@@ -4,11 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import pickle
-import threading
 import time
 
 from itertools import zip_longest
+from typing import Dict, List
 
 import torch
 
@@ -32,36 +31,42 @@
     "DataPipeToQueuesLoop",
     "CreateProcessForDataPipeline",
     "CreateProcessForMultipleDataPipelines",
-    "CreateThreadForDataPipeline",
 ]
 
 
-class _ResetCounter:
+class _RequestCounter:
     exp_cnt: int
-    cnt: int
-    _reached: bool
+    _keys: List[str] = ["limit", "pause", "reset_epoch", "resume"]
+    _cnt: Dict[str, int]
+    _reached: Dict[str, bool]
 
     def __init__(self, exp_cnt: int):
         self.exp_cnt = exp_cnt
-        self.cnt = 0
-        self._reached = False
-
-    def increment(self) -> None:
-        self.cnt += 1
-        assert self.cnt <= self.exp_cnt
-
-    def is_reached(self) -> bool:
-        if self.cnt == self.exp_cnt:
-            self._reached = True
-        return self._reached
-
-    def reset(self) -> None:
-        if self._reached:
-            self._reached = False
-            self.cnt = 0
-
-
-def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, process_name, call_on_process_init=None):
+        self._cnt = {k: 0 for k in self._keys}
+        self._reached = {k: False for k in self._keys}
+
+    def increment(self, key: str) -> None:
+        assert key in self._reached
+        self._cnt[key] += 1
+        assert self._cnt[key] <= self.exp_cnt
+        if self._cnt[key] == self.exp_cnt:
+            self._reached[key] = True
+
+    def is_reached(self, key: str) -> bool:
+        assert key in self._reached
+        return self._reached[key]
+
+    def reset(self, key: str) -> None:
+        assert key in self._reached and self._reached[key]
+        assert self._cnt[key] >= 1
+        self._cnt[key] -= 1
+        if self._cnt[key] == 0:
+            self._reached[key] = False
+
+
+def MultipleDataPipesToQueuesLoop(
+    source_datapipes, req_queues, res_queues, process_name, worker_info, call_on_process_init=None, custom_reset_fn=None
+):
     r"""
     Set the appropriate pipes and protocol server type, and create a loop over multiple datapipes
     with the protocol server in a non-blocking manner.
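A quick illustration of the keyed-counter protocol that the new `_RequestCounter` above implements (a minimal sketch for this review, assuming three dispatching loops; not part of the diff): every loop must `increment` a key before `is_reached` reports True, and the flag stays up until every loop has called `reset`.

    # Sketch only: exercises _RequestCounter exactly as defined in the hunk above.
    counter = _RequestCounter(exp_cnt=3)
    for _ in range(3):
        counter.increment("pause")          # each loop acknowledges the request
    assert counter.is_reached("pause")      # True only after all 3 increments
    counter.reset("pause")                  # one loop moves on; count drops to 2
    assert counter.is_reached("pause")      # stays reached until every loop resets
    counter.reset("pause")
    counter.reset("pause")
    assert not counter.is_reached("pause")  # fully cleared for the next request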
@@ -71,7 +76,9 @@ def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, proc
         req_queue: Multiprocessing queue providing requests from the worker process
         res_queue: Multiprocessing queue sending results to the worker process
         process_name: The name of the process (used for logging and exception handling)
+        worker_info: Worker information (worker id and number of workers)
         call_on_process_init: Not allowed by dispatching process for now.
+        custom_reset_fn: Optional callable function to reset the DataPipe.
     """
     assert call_on_process_init is None, "``MultipleDataPipesToQueuesLoop`` does not support call_on_process_init"
     num_loops = len(source_datapipes)
@@ -82,21 +89,24 @@ def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, proc
     torch.set_num_threads(1)
 
     loops = []
-    reset_iterator_counter = _ResetCounter(num_loops)
+    request_counter = _RequestCounter(num_loops)
 
+    loop_id = 0
     for source_datapipe, req_queue, res_queue in zip(source_datapipes, req_queues, res_queues):
-        # Extract Serialization Wrapper
-        source_datapipe = extract_wrapper(source_datapipe)
         loops.append(
             _create_datapipe_queue_loop(
                 source_datapipe,
                 req_queue,
                 res_queue,
                 process_name,
+                loop_id,
+                worker_info,
+                custom_reset_fn,
                 blocking_request_get=False,
-                reset_iterator_counter=reset_iterator_counter,
+                request_counter=request_counter,
             )
         )  # Non-blocking request with reset counters
+        loop_id += 1
 
     # Using `zip_longest` to guarantee the process is terminated only when
     # all loops have received `TerminateRequest`
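The comment above describes the drain loop that follows (unchanged, so only its `time.sleep(0)` body shows in the next hunk): it zips all loop generators together with `zip_longest`, which pads exhausted generators with `None` rather than stopping at the shortest, so no loop is dropped before it has handled its `TerminateRequest`. A minimal sketch with toy generators standing in for the queue loops (illustration only, not part of the diff):

    from itertools import zip_longest

    def toy_loop(name: str, steps: int):
        for i in range(steps):
            yield f"{name}:{i}"

    pipes = [toy_loop("a", 2), toy_loop("b", 4)]
    for step in zip_longest(*pipes):
        print(step)
    # ('a:0', 'b:0'), ('a:1', 'b:1'), (None, 'b:2'), (None, 'b:3')
    # plain zip() would stop after two steps and abandon loop "b"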
@@ -107,7 +117,9 @@ def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, proc
         time.sleep(0)
 
 
-def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, process_name, call_on_process_init=None):
+def DataPipeToQueuesLoop(
+    source_datapipe, req_queue, res_queue, process_name, worker_info, call_on_process_init=None, custom_reset_fn=None
+):
     r"""
     Initialize with the given init function, set the appropriate pipe and protocol server type, and
     create a loop with the protocol server.
@@ -117,8 +129,10 @@ def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, process_name, ca
         req_queue: Multiprocessing queue providing requests from the main process
         res_queue: Multiprocessing queue sending results to the main process
         process_name: The name of the process (used for logging and exception handling)
+        worker_info: Worker information (worker id and number of workers)
         call_on_process_init: Callable function that will be called at the time of worker process initialization.
             Users can provide it to modify the DataPipe graph in the worker process.
+        custom_reset_fn: Optional callable function to reset the DataPipe.
     """
     # Extract Serialization Wrapper
     source_datapipe = extract_wrapper(source_datapipe)
@@ -128,7 +142,16 @@ def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, process_name, ca
 
     torch.set_num_threads(1)
 
-    loop = _create_datapipe_queue_loop(source_datapipe, req_queue, res_queue, process_name, blocking_request_get=True)
+    loop = _create_datapipe_queue_loop(
+        source_datapipe,
+        req_queue,
+        res_queue,
+        process_name,
+        worker_info.worker_id,
+        worker_info,
+        custom_reset_fn,
+        blocking_request_get=True,
+    )
 
     for _ in loop:
         pass
@@ -139,8 +162,11 @@ def _create_datapipe_queue_loop(
     req_queue,
     res_queue,
     process_name,
+    loop_id,
+    worker_info,
+    custom_reset_fn=None,
     blocking_request_get=True,
-    reset_iterator_counter=None,
+    request_counter=None,
 ):
     if isinstance(source_datapipe, IterDataPipe):
         pipe_type = communication.iter
@@ -155,51 +181,33 @@ def _create_datapipe_queue_loop(
         source_datapipe,
         protocol_type(req_queue, res_queue),
         process_name=process_name,
+        loop_id=loop_id,
+        worker_info=worker_info,
+        custom_reset_fn=custom_reset_fn,
         blocking_request_get=blocking_request_get,
-        reset_iterator_counter=reset_iterator_counter,
+        request_counter=request_counter,
     )
 
 
-def CreateProcessForDataPipeline(multiprocessing_ctx, datapipe, process_name, call_on_process_init=None):
+def CreateProcessForDataPipeline(
+    multiprocessing_ctx, datapipe, process_name, worker_info, call_on_process_init=None, custom_reset_fn=None
+):
     r"""
     Given a DataPipe, creates a new process with ``DataPipeToQueuesLoop`` as target,
     and returns ``(process, req_queue, res_queue)``.
     """
     req_queue = multiprocessing_ctx.Queue()
     res_queue = multiprocessing_ctx.Queue()
     process = multiprocessing_ctx.Process(
-        target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue, process_name, call_on_process_init)
+        target=DataPipeToQueuesLoop,
+        args=(datapipe, req_queue, res_queue, process_name, worker_info, call_on_process_init, custom_reset_fn),
     )
     return process, req_queue, res_queue
 
 
-def CreateThreadForDataPipeline(datapipe, thread_name):
-    r"""
-    Given a DataPipe, creates a copy of the DataPipe, starts a new Thread with ``DataPipeToQueuesLoop`` as target,
-    and returns ``(process, req_queue, res_queue, new_copied_datapipe)``.
-    """
-    req_queue = communication.queue.ThreadingQueue()
-    res_queue = communication.queue.ThreadingQueue()
-
-    try:
-        new_datapipe = pickle.loads(pickle.dumps(datapipe))
-    except Exception as pe:
-        if HAS_DILL:
-            try:
-                new_datapipe = dill.loads(dill.dumps(datapipe))
-            except Exception as de:
-                raise Exception("Unable to dill DataPipe to make thread local copy", de)
-
-        else:
-            raise Exception("Unable to pickle DataPipe to make thread local copy (consider installing `dill`)", pe)
-
-    process = threading.Thread(
-        target=DataPipeToQueuesLoop, args=(new_datapipe, req_queue, res_queue, thread_name), daemon=True
-    )
-    return process, req_queue, res_queue, new_datapipe
-
-
-def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes, process_name):
+def CreateProcessForMultipleDataPipelines(
+    multiprocessing_ctx, datapipes, process_name, worker_info, custom_reset_fn=None
+):
     r"""
     Given a list of DataPipes, creates a new process with ``MultipleDataPipesToQueuesLoop`` as target,
     and returns ``(process, [req_queue_0, ...], [res_queue_0, ...])``.
@@ -211,6 +219,7 @@ def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes, proces
         res_queues.append(multiprocessing_ctx.Queue())
 
     process = multiprocessing_ctx.Process(
-        target=MultipleDataPipesToQueuesLoop, args=(datapipes, req_queues, res_queues, process_name)
+        target=MultipleDataPipesToQueuesLoop,
+        args=(datapipes, req_queues, res_queues, process_name, worker_info, custom_reset_fn),
     )
     return process, req_queues, res_queues
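For context, a minimal usage sketch of the new `CreateProcessForDataPipeline` signature (illustration only, not part of the diff; `SimpleWorkerInfo` is a hypothetical stand-in for the worker-info object callers construct elsewhere in the library):

    import multiprocessing as mp
    from collections import namedtuple

    from torch.utils.data.datapipes.iter import IterableWrapper

    # Hypothetical stand-in: anything exposing worker_id and num_workers.
    SimpleWorkerInfo = namedtuple("SimpleWorkerInfo", ["worker_id", "num_workers"])

    ctx = mp.get_context("spawn")
    process, req_queue, res_queue = CreateProcessForDataPipeline(
        ctx,
        IterableWrapper(range(8)),
        "worker_0",
        SimpleWorkerInfo(worker_id=0, num_workers=1),
    )
    process.start()  # the child process now serves requests arriving on req_queue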