Skip to content

Commit

Permalink
Async save state_dict to file (#171)
Browse files Browse the repository at this point in the history
  • Loading branch information
zkh2016 authored Sep 27, 2023
1 parent 25e3671 commit 290386a
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions bmtrain/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from . import nccl
import io, pickle
from typing import Mapping
import threading

def _save_to_state_dict(model : torch.nn.Module, rank, destination, prefix):
if isinstance(model, Block):
Expand Down Expand Up @@ -81,23 +82,40 @@ def _save_to_infer_model(model : torch.nn.Module, infer_model, destination=None,
infer_model.load_layer_state_dict(destination)


def async_save_to_file(state_dict, file_path):
torch.save(state_dict, file_path)
config['finish_save'] = True
print("finish save state_dict to ", file_path)

def save(model : torch.nn.Module, file_name : str):
def save(model : torch.nn.Module, file_name : str, non_blocking : bool=True):
"""Saves the model to the file.
Similar to torch.save, but it used for distributed modules.
Args:
model (torch.nn.Module): The model to be saved.
file_name (str): The file name of the checkpoint.
non_blocking (bool): Whether to asynchronously save state_dict to file
Examples:
>>> bmtrain.save(model, "model.pt")
"""
torch.cuda.synchronize()
state_dict = _save_to_rank0(model)
if config["rank"] == 0:
torch.save(state_dict, file_name)
if non_blocking is False:
torch.save(state_dict, file_name)
else:
if 'finish_save' not in config:
config['finish_save'] = True

if config['finish_save'] is False:
config['save_thread'].join()

config['finish_save'] = False
config['save_thread'] = threading.Thread(target=async_save_to_file, args=(state_dict, file_name))
config['save_thread'].start()

DTYPE_LIST = [
torch.float64,
Expand Down

0 comments on commit 290386a

Please sign in to comment.