Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upload the model with push_to_hub in examples #297

Merged
merged 17 commits into from
Jul 7, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/albert/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
transformers>=4.5.1
transformers>=4.6.0
datasets>=1.5.0
torch_optimizer>=0.1.0
wandb>=0.10.26
Expand Down
20 changes: 11 additions & 9 deletions examples/albert/run_first_peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class CoordinatorArguments(BaseTrainingArguments):
default=None,
metadata={"help": "Path to HuggingFace repo in which coordinator will upload the model and optimizer states"}
)
repo_url: Optional[str] = field(
default=None,
metadata={"help": "Url to HuggingFace repo in which coordinator will upload the model and optimizer states"}
)
upload_interval: Optional[float] = field(
default=None,
metadata={"help": "Coordinator will upload model once in this many seconds"}
Expand All @@ -63,6 +67,7 @@ def __init__(self, coordinator_args: CoordinatorArguments, collab_optimizer_args
averager_args: AveragerArguments, dht: hivemind.DHT):
self.save_checkpoint_step_interval = coordinator_args.save_checkpoint_step_interval
self.repo_path = coordinator_args.repo_path
self.repo_url = coordinator_args.repo_url
self.upload_interval = coordinator_args.upload_interval
self.previous_step = -1

Expand Down Expand Up @@ -106,6 +111,7 @@ def is_time_to_save_state(self, cur_step):
return False

def save_state(self, cur_step):
logger.info("Saving state from peers")
self.collaborative_optimizer.load_state_from_peers()
self.previous_step = cur_step

Expand All @@ -118,17 +124,13 @@ def is_time_to_upload(self):
return False

def upload_checkpoint(self, current_loss):
self.model.save_pretrained(self.repo_path)
logger.info("Saving optimizer")
torch.save(self.collaborative_optimizer.opt.state_dict(), f"{self.repo_path}/optimizer_state.pt")
self.previous_timestamp = time.time()
try:
subprocess.run("git add --all", shell=True, check=True, cwd=self.repo_path)
current_step = self.collaborative_optimizer.collaboration_state.optimizer_step
subprocess.run(f"git commit -m 'Step {current_step}, loss {current_loss:.3f}'",
shell=True, check=True, cwd=self.repo_path)
subprocess.run("git push", shell=True, check=True, cwd=self.repo_path)
except subprocess.CalledProcessError as e:
logger.warning("Error while uploading model:", e.output)
logger.info('Start uploading model to hub')
self.model.push_to_hub(repo_name=self.repo_path, repo_url=self.repo_url,
commit_message=f'Step {current_step}, loss {current_loss:.3f}')
logger.info('Finish uploading model to hub')


if __name__ == '__main__':
Expand Down