From 4d0f6b3ac670083d6605d3cd2a14933f36644814 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 22 Jul 2024 21:57:27 -0700 Subject: [PATCH] [misc] only tqdm for first rank (#6672) --- .../model_loader/weight_utils.py | 37 ++++++++++++++++--- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 831bdcd242d28..ee3b2530880d1 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -313,6 +313,13 @@ def filter_files_not_needed_for_inference( return hf_weights_files +# explicitly use pure text format, with a newline at the end +# this makes it impossible to see the animation in the progress bar +# but will avoid messing up with ray or multiprocessing, which wraps +# each line of output with some prefix. +_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 + + def np_cache_weights_iterator( model_name_or_path: str, cache_dir: Optional[str], hf_folder: str, hf_weights_files: List[str] @@ -321,6 +328,8 @@ def np_cache_weights_iterator( Will dump the model weights to numpy files if they are not already dumped. """ + enable_tqdm = not torch.distributed.is_initialized( + ) or torch.distributed.get_rank() == 0 # Convert the model weights from torch tensors to numpy arrays for # faster loading. np_folder = os.path.join(hf_folder, "np") @@ -331,8 +340,12 @@ def np_cache_weights_iterator( with get_lock(model_name_or_path, cache_dir): if not os.path.exists(weight_names_file): weight_names: List[str] = [] - for bin_file in tqdm(hf_weights_files, - desc="Loading np_cache checkpoint shards"): + for bin_file in tqdm( + hf_weights_files, + desc="Loading np_cache checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ): state = torch.load(bin_file, map_location="cpu") for name, param in state.items(): param_path = os.path.join(np_folder, name) @@ -356,8 +369,14 @@ def safetensors_weights_iterator( hf_weights_files: List[str] ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files.""" - for st_file in tqdm(hf_weights_files, - desc="Loading safetensors checkpoint shards"): + enable_tqdm = not torch.distributed.is_initialized( + ) or torch.distributed.get_rank() == 0 + for st_file in tqdm( + hf_weights_files, + desc="Loading safetensors checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ): with safe_open(st_file, framework="pt") as f: for name in f.keys(): # noqa: SIM118 param = f.get_tensor(name) @@ -368,8 +387,14 @@ def pt_weights_iterator( hf_weights_files: List[str] ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model bin/pt files.""" - for bin_file in tqdm(hf_weights_files, - desc="Loading pt checkpoint shards"): + enable_tqdm = not torch.distributed.is_initialized( + ) or torch.distributed.get_rank() == 0 + for bin_file in tqdm( + hf_weights_files, + desc="Loading pt checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ): state = torch.load(bin_file, map_location="cpu") for name, param in state.items(): yield name, param