Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add context manager to allow main process first. #98

Merged
merged 3 commits into from
Jun 3, 2021
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import gc
import os
from contextlib import contextmanager
from typing import List, Optional, Union

import torch
Expand Down Expand Up @@ -183,6 +184,33 @@ def use_fp16(self):
use_fp16 = self.state.use_fp16
return use_fp16

@contextmanager
def local_main_process_first(self):
"""
Lets the local main process go inside a with block.

The other processes will enter the with block after the main process exits.
"""
yield from self._goes_first(self.is_local_main_process)

@contextmanager
def main_process_first(self):
"""
Lets the main process go first inside a with block.

The other processes will enter the with block after the main process exits.
"""
yield from self._goes_first(self.is_main_process)

def _goes_first(self, is_master):
if not is_master:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small nit: can we use is_main for consistency here (as well as at line 211 below)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes sure!

self.wait_for_everyone()

yield

if is_master:
self.wait_for_everyone()

def print(self, *args, **kwargs):
"""
Use in replacement of :obj:`print()` to only print once per server.
Expand Down