diff --git a/CHANGELOG.md b/CHANGELOG.md
index ace08f4ee68..7fd0b2e4c58 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed a mispelling: the parameter `contructor_extras` in `Lazy()` is now correctly called `constructor_extras`.
 - Fixed broken links in `allennlp.nn.initializers` docs.
 - `TransformerTextField` can now take tensors of shape `(1, n)` like the tensors produced from a HuggingFace tokenizer.
+- `tqdm` lock is now set inside `MultiProcessDataLoader` when new workers are spawned to avoid contention when writing output.
 - `ConfigurationError` is now pickleable.
 
 ### Changed
diff --git a/allennlp/common/tqdm.py b/allennlp/common/tqdm.py
index bad9f35b574..dc9692a96ee 100644
--- a/allennlp/common/tqdm.py
+++ b/allennlp/common/tqdm.py
@@ -88,3 +88,11 @@ def tqdm(*args, **kwargs):
         }
 
         return _tqdm(*args, **new_kwargs)
+
+    @staticmethod
+    def set_lock(lock):
+        _tqdm.set_lock(lock)
+
+    @staticmethod
+    def get_lock():
+        return _tqdm.get_lock()
diff --git a/allennlp/data/data_loaders/multiprocess_data_loader.py b/allennlp/data/data_loaders/multiprocess_data_loader.py
index a46ea4f4f53..6243144f5aa 100644
--- a/allennlp/data/data_loaders/multiprocess_data_loader.py
+++ b/allennlp/data/data_loaders/multiprocess_data_loader.py
@@ -428,20 +428,22 @@ def _iter_batches(self) -> Iterator[TensorDict]:
             self._join_workers(workers, queue)
 
     def _start_instance_workers(self, queue: mp.JoinableQueue, ctx) -> List[BaseProcess]:
+        Tqdm.set_lock(mp.RLock())
         workers: List[BaseProcess] = []
         for worker_id in range(self.num_workers):
             worker: BaseProcess = ctx.Process(
-                target=self._instance_worker, args=(worker_id, queue), daemon=True
+                target=self._instance_worker, args=(worker_id, queue, Tqdm.get_lock()), daemon=True
             )
             worker.start()
             workers.append(worker)
         return workers
 
     def _start_batch_workers(self, queue: mp.JoinableQueue, ctx) -> List[BaseProcess]:
+        Tqdm.set_lock(mp.RLock())
         workers: List[BaseProcess] = []
         for worker_id in range(self.num_workers):
             worker: BaseProcess = ctx.Process(
-                target=self._batch_worker, args=(worker_id, queue), daemon=True
+                target=self._batch_worker, args=(worker_id, queue, Tqdm.get_lock()), daemon=True
             )
             worker.start()
             workers.append(worker)
@@ -463,7 +465,8 @@ def _join_workers(self, workers: List[BaseProcess], queue) -> None:
             if worker.is_alive():
                 worker.terminate()
 
-    def _instance_worker(self, worker_id: int, queue: mp.JoinableQueue) -> None:
+    def _instance_worker(self, worker_id: int, queue: mp.JoinableQueue, lock) -> None:
+        Tqdm.set_lock(lock)
         try:
             self.reader._set_worker_info(WorkerInfo(self.num_workers, worker_id))
             instances = self.reader.read(self.data_path)
@@ -495,7 +498,8 @@ def _instance_worker(self, worker_id: int, queue: mp.JoinableQueue) -> None:
         # Wait until this process can safely exit.
         queue.join()
 
-    def _batch_worker(self, worker_id: int, queue: mp.JoinableQueue) -> None:
+    def _batch_worker(self, worker_id: int, queue: mp.JoinableQueue, lock) -> None:
+        Tqdm.set_lock(lock)
         try:
             self.reader._set_worker_info(WorkerInfo(self.num_workers, worker_id))
             instances = self.reader.read(self.data_path)
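
For context, the patch above follows tqdm's documented lock-sharing pattern: the parent process creates a lock, passes it explicitly to each spawned worker, and each worker installs it with `tqdm.set_lock()` before drawing a progress bar. Below is a minimal standalone sketch of that same pattern (not part of the patch and not AllenNLP code); it uses only the public `tqdm` API and the standard library, and the `work` function and worker count of 4 are hypothetical.

# Standalone sketch: one shared tqdm lock across spawned worker processes,
# so concurrent bars don't interleave mid-line on the same terminal.
import multiprocessing as mp

from tqdm import tqdm


def work(worker_id: int, lock) -> None:
    # Install the parent's lock before creating a bar in this process,
    # mirroring what the patched _instance_worker/_batch_worker do.
    tqdm.set_lock(lock)
    for _ in tqdm(range(1000), desc=f"worker {worker_id}", position=worker_id):
        pass


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    # The parent creates the lock once and passes it explicitly: a spawned
    # child starts from a fresh interpreter and does not inherit the
    # parent's module state, so tqdm's default lock would differ per process.
    tqdm.set_lock(ctx.RLock())
    workers = [
        ctx.Process(target=work, args=(i, tqdm.get_lock()), daemon=True)
        for i in range(4)
    ]
    for w in workers:
        w.start()
    for w in workers:
        w.join()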