Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
set tqdm lock when new workers are spawned (#5330)
Browse files Browse the repository at this point in the history
Co-authored-by: Dirk Groeneveld <[email protected]>
  • Loading branch information
epwalsh and dirkgr authored Jul 23, 2021
1 parent 67add9d commit 76f2487
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed a misspelling: the parameter `contructor_extras` in `Lazy()` is now correctly called `constructor_extras`.
- Fixed broken links in `allennlp.nn.initializers` docs.
- `TransformerTextField` can now take tensors of shape `(1, n)` like the tensors produced from a HuggingFace tokenizer.
- `tqdm` lock is now set inside `MultiProcessDataLoader` when new workers are spawned to avoid contention when writing output.
- `ConfigurationError` is now pickleable.

### Changed
Expand Down
8 changes: 8 additions & 0 deletions allennlp/common/tqdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,11 @@ def tqdm(*args, **kwargs):
}

return _tqdm(*args, **new_kwargs)

@staticmethod
def set_lock(lock):
    """Install ``lock`` as the global lock used by the underlying ``tqdm``.

    Delegates directly to ``_tqdm.set_lock`` so all progress bars created
    through this wrapper synchronize their terminal writes on ``lock``.
    """
    _tqdm.set_lock(lock)

@staticmethod
def get_lock():
    """Return the global lock currently used by the underlying ``tqdm``.

    Thin pass-through to ``_tqdm.get_lock``; the returned object can be
    handed to child processes so their progress output does not interleave.
    """
    return _tqdm.get_lock()
12 changes: 8 additions & 4 deletions allennlp/data/data_loaders/multiprocess_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,20 +428,22 @@ def _iter_batches(self) -> Iterator[TensorDict]:
self._join_workers(workers, queue)

def _start_instance_workers(self, queue: mp.JoinableQueue, ctx) -> List[BaseProcess]:
    """Spawn ``self.num_workers`` daemon processes that read instances.

    A fresh multiprocessing ``RLock`` is installed as tqdm's global lock
    before any worker starts, and the same lock is passed to every worker
    so progress-bar output from separate processes is serialized.

    # NOTE(review): the lock comes from the default `mp` context while the
    # workers come from `ctx` — presumably compatible here; confirm when
    # `ctx` is a spawn context.
    """
    Tqdm.set_lock(mp.RLock())
    # Hoisted: get_lock() returns the single global lock just installed.
    shared_lock = Tqdm.get_lock()
    workers: List[BaseProcess] = []
    for worker_id in range(self.num_workers):
        process: BaseProcess = ctx.Process(
            target=self._instance_worker,
            args=(worker_id, queue, shared_lock),
            daemon=True,
        )
        process.start()
        workers.append(process)
    return workers

def _start_batch_workers(self, queue: mp.JoinableQueue, ctx) -> List[BaseProcess]:
Tqdm.set_lock(mp.RLock())
workers: List[BaseProcess] = []
for worker_id in range(self.num_workers):
worker: BaseProcess = ctx.Process(
target=self._batch_worker, args=(worker_id, queue), daemon=True
target=self._batch_worker, args=(worker_id, queue, Tqdm.get_lock()), daemon=True
)
worker.start()
workers.append(worker)
Expand All @@ -463,7 +465,8 @@ def _join_workers(self, workers: List[BaseProcess], queue) -> None:
if worker.is_alive():
worker.terminate()

def _instance_worker(self, worker_id: int, queue: mp.JoinableQueue) -> None:
def _instance_worker(self, worker_id: int, queue: mp.JoinableQueue, lock) -> None:
Tqdm.set_lock(lock)
try:
self.reader._set_worker_info(WorkerInfo(self.num_workers, worker_id))
instances = self.reader.read(self.data_path)
Expand Down Expand Up @@ -495,7 +498,8 @@ def _instance_worker(self, worker_id: int, queue: mp.JoinableQueue) -> None:
# Wait until this process can safely exit.
queue.join()

def _batch_worker(self, worker_id: int, queue: mp.JoinableQueue) -> None:
def _batch_worker(self, worker_id: int, queue: mp.JoinableQueue, lock) -> None:
Tqdm.set_lock(lock)
try:
self.reader._set_worker_info(WorkerInfo(self.num_workers, worker_id))
instances = self.reader.read(self.data_path)
Expand Down

0 comments on commit 76f2487

Please sign in to comment.