From 290386a76d3268f3c42d5d1a5e8bd1186b26c1e6 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 27 Sep 2023 20:56:14 +0800 Subject: [PATCH] Async save state_dict to file (#171) --- bmtrain/store.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/bmtrain/store.py b/bmtrain/store.py index 254213bd..7279ac53 100644 --- a/bmtrain/store.py +++ b/bmtrain/store.py @@ -9,6 +9,7 @@ from . import nccl import io, pickle from typing import Mapping +import threading def _save_to_state_dict(model : torch.nn.Module, rank, destination, prefix): if isinstance(model, Block): @@ -81,8 +82,12 @@ def _save_to_infer_model(model : torch.nn.Module, infer_model, destination=None, infer_model.load_layer_state_dict(destination) +def async_save_to_file(state_dict, file_path): + torch.save(state_dict, file_path) + config['finish_save'] = True + print("finish save state_dict to ", file_path) -def save(model : torch.nn.Module, file_name : str): +def save(model : torch.nn.Module, file_name : str, non_blocking : bool=True): """Saves the model to the file. Similar to torch.save, but it used for distributed modules. @@ -90,6 +95,8 @@ def save(model : torch.nn.Module, file_name : str): Args: model (torch.nn.Module): The model to be saved. file_name (str): The file name of the checkpoint. + non_blocking (bool): Whether to asynchronously save state_dict to file + Examples: >>> bmtrain.save(model, "model.pt") @@ -97,7 +104,18 @@ def save(model : torch.nn.Module, file_name : str): torch.cuda.synchronize() state_dict = _save_to_rank0(model) if config["rank"] == 0: - torch.save(state_dict, file_name) + if non_blocking is False: + torch.save(state_dict, file_name) + else: + if 'finish_save' not in config: + config['finish_save'] = True + + if config['finish_save'] is False: + config['save_thread'].join() + + config['finish_save'] = False + config['save_thread'] = threading.Thread(target=async_save_to_file, args=(state_dict, file_name)) + config['save_thread'].start() DTYPE_LIST = [ torch.float64,