diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index 98d34b5f94..b0f2752823 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -142,8 +142,6 @@ def __init__(self, cfg: DictConfig) -> None:
             )
             self._log_peak_memory_stats = False
 
-        # _is_rank_zero is used primarily for logging. In the future, the logger
-        # should directly take care of this
         _, rank = training.get_world_size_and_rank()
         self._is_rank_zero = rank == 0
 
@@ -302,8 +300,7 @@ def setup(self, cfg: DictConfig) -> None:
             # set num_output_chunks for model
             self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
 
-        if self._is_rank_zero:
-            log.info("Loss is initialized.")
+        utils.log_rank_zero(log, "Loss is initialized.")
 
         # sampler and dataloader depend on the tokenizer and loss_fn and should be
         # setup after both of these are initialized
@@ -397,9 +394,10 @@ def _setup_profiler(
 
         profiler, profiler_cfg = config.instantiate(cfg_profiler)
 
+        utils.log_rank_zero(
+            log, f" Profiler config after instantiation: {profiler_cfg}"
+        )
         if self._is_rank_zero:
-            log.info(f" Profiler config after instantiation: {profiler_cfg}")
-
             self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
             if profiler_cfg["enabled"]:
                 self.profiler_wait_steps = profiler_cfg["wait_steps"]
@@ -428,11 +426,11 @@ def _setup_model(
 
            full state dicts are loaded with ``torch.load(mmap=True)``
         """
-        if self._is_rank_zero:
-            log.info(
-                "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
-            )
-            init_start = time.perf_counter()
+        utils.log_rank_zero(
+            log,
+            "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
+        )
+        init_start = time.perf_counter()
 
         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(cfg_model)
@@ -498,10 +496,11 @@ def _setup_model(
 
         # Ensure no params and buffers are on meta device
         training.validate_no_params_on_meta_device(model)
 
+        utils.log_rank_zero(
+            log,
+            f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
+        )
         if self._is_rank_zero:
-            log.info(
-                f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
-            )
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(memory_stats)
@@ -547,8 +546,7 @@ def _setup_optimizer(
                     "Failed loading in-backward optimizer checkpoints."
                     "Please make sure run being restored from was using in-backward optimizer."
                 ) from e
-            if self._is_rank_zero:
-                log.info("In-backward optimizers are set up.")
+            utils.log_rank_zero(log, "In-backward optimizers are set up.")
             return None
         else:
             optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
@@ -559,8 +557,7 @@ def _setup_optimizer(
                     self._device,
                 )
 
-            if self._is_rank_zero:
-                log.info("Optimizer is initialized.")
+            utils.log_rank_zero(log, "Optimizer is initialized.")
             return optimizer
 
     def _setup_data(
@@ -613,8 +610,7 @@ def _setup_data(
             ),
         )
 
-        if self._is_rank_zero:
-            log.info("Dataset and Sampler are initialized.")
+        utils.log_rank_zero(log, "Dataset and Sampler are initialized.")
 
         return sampler, dataloader
 
@@ -637,11 +633,11 @@ def save_checkpoint(
 
         intermediate_checkpoint = epoch + 1 < self.total_epochs
 
-        if self._is_rank_zero:
-            log.info(
-                "Saving checkpoint. This may take some time. Retrieving full model state dict..."
-            )
-            start = time.perf_counter()
+        utils.log_rank_zero(
+            log,
+            "Saving checkpoint. This may take some time. Retrieving full model state dict...",
+        )
+        start = time.perf_counter()
 
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
@@ -651,15 +647,14 @@ def save_checkpoint(
             device=self._device,
         )
 
-        if self._is_rank_zero:
-            log.info(
-                f"Getting full model state dict took {time.perf_counter() - start:.2f} secs"
-            )
+        utils.log_rank_zero(
+            log,
+            f"Getting full model state dict took {time.perf_counter() - start:.2f} secs",
+        )
 
         if intermediate_checkpoint:
             start = time.perf_counter()
-            if self._is_rank_zero:
-                log.info("Getting optimizer state dict...")
+            utils.log_rank_zero(log, "Getting optimizer state dict...")
             if not self._optimizer_in_bwd:
                 opt_state_dict = training.get_full_optimizer_state_dict(
                     self._optimizer,
@@ -672,10 +667,10 @@ def save_checkpoint(
                 opt_state_dict[param] = training.get_full_optimizer_state_dict(
                     opt, self._is_rank_zero, device=self._device
                 )
-            if self._is_rank_zero:
-                log.info(
-                    f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs"
-                )
+            utils.log_rank_zero(
+                log,
+                f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs",
+            )
         else:
             opt_state_dict = None
 
diff --git a/recipes/knowledge_distillation_distributed.py b/recipes/knowledge_distillation_distributed.py
index c4dcf3b62f..c920f4b069 100644
--- a/recipes/knowledge_distillation_distributed.py
+++ b/recipes/knowledge_distillation_distributed.py
@@ -119,8 +119,6 @@ def __init__(self, cfg: DictConfig) -> None:
 
         _, rank = training.get_world_size_and_rank()
 
-        # _is_rank_zero is used primarily for logging. In the future, the logger
-        # should directly take care of this
         self._is_rank_zero = rank == 0
 
         # logging attributes
@@ -287,8 +285,7 @@ def setup(self, cfg: DictConfig) -> None:
             self._loss_fn.num_output_chunks == self._kd_loss_fn.num_output_chunks
         ), "Number of output chunks for loss_fn and kd_loss_fn must be the same."
 
-        if self._is_rank_zero:
-            log.info("Loss is initialized.")
+        utils.log_rank_zero(log, "Loss is initialized.")
 
         # Dataloader depends on the tokenizer and loss_fn and should be
         # setup after all of these are setup
@@ -389,9 +386,10 @@ def _setup_profiler(
 
         profiler, profiler_cfg = config.instantiate(cfg_profiler)
 
+        utils.log_rank_zero(
+            log, f" Profiler config after instantiation: {profiler_cfg}"
+        )
         if self._is_rank_zero:
-            log.info(f" Profiler config after instantiation: {profiler_cfg}")
-
             self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
             if profiler_cfg["enabled"]:
                 self.profiler_wait_steps = profiler_cfg["wait_steps"]
@@ -425,11 +423,11 @@ def _setup_model(
 
         self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
         self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
 
-        if self._is_rank_zero:
-            log.info(
-                "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
-            )
-            init_start = time.perf_counter()
+        utils.log_rank_zero(
+            log,
+            "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
+        )
+        init_start = time.perf_counter()
 
         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(cfg_model)
@@ -512,10 +510,11 @@ def _setup_model(
         # Ensure no params and buffers are on meta device
         training.validate_no_params_on_meta_device(model)
 
+        utils.log_rank_zero(
+            log,
+            f"Instantiating student model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
+        )
         if self._is_rank_zero:
-            log.info(
-                f"Instantiating student model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
-            )
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(
                 memory_stats, message="Memory stats after student model init:"
@@ -542,11 +541,11 @@ def _setup_teacher_model(
 
            full state dicts are loaded with ``torch.load(mmap=True)``
         """
-        if self._is_rank_zero:
-            log.info(
-                "FSDP enabled. Instantiating teacher model and loading checkpoint on Rank 0 ..."
-            )
-            init_start = time.perf_counter()
+        utils.log_rank_zero(
+            log,
+            "FSDP enabled. Instantiating teacher model and loading checkpoint on Rank 0 ...",
+        )
+        init_start = time.perf_counter()
 
         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(model_cfg)
@@ -596,10 +595,11 @@ def _setup_teacher_model(
         # Ensure no params and buffers are on meta device
         training.validate_no_params_on_meta_device(model)
 
+        utils.log_rank_zero(
+            log,
+            f"Instantiating teacher model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
+        )
         if self._is_rank_zero:
-            log.info(
-                f"Instantiating teacher model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
-            )
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(
                 memory_stats, message="Memory stats after teacher model init:"
@@ -621,8 +621,7 @@ def _setup_optimizer(
                 self._device,
             )
 
-        if self._is_rank_zero:
-            log.info("Optimizer is initialized.")
+        utils.log_rank_zero(log, "Optimizer is initialized.")
         return optimizer
 
     def _setup_lr_scheduler(
@@ -638,8 +637,7 @@ def _setup_lr_scheduler(
             last_epoch=last_epoch,
         )
 
-        if self._is_rank_zero:
-            log.info("Learning rate scheduler is initialized.")
+        utils.log_rank_zero(log, "Learning rate scheduler is initialized.")
         return lr_scheduler
 
     def _setup_data(
@@ -690,8 +688,7 @@ def _setup_data(
             ),
         )
 
-        if self._is_rank_zero:
-            log.info("Dataset and Sampler are initialized.")
+        utils.log_rank_zero(log, "Dataset and Sampler are initialized.")
 
         return sampler, dataloader
 
diff --git a/recipes/lora_dpo_distributed.py b/recipes/lora_dpo_distributed.py
index ee7ca5e729..ab37623cc1 100644
--- a/recipes/lora_dpo_distributed.py
+++ b/recipes/lora_dpo_distributed.py
@@ -122,8 +122,6 @@ def __init__(self, cfg: DictConfig) -> None:
 
         _, rank = training.get_world_size_and_rank()
 
-        # _is_rank_zero is used primarily for logging. In the future, the logger
-        # should directly take care of this
         self._is_rank_zero = rank == 0
 
         # logging attributes
@@ -226,7 +224,8 @@ def setup(self, cfg: DictConfig) -> None:
 
         # log config with parameter override
         self._metric_logger.log_config(cfg)
-        log.info("_metric_logger is initialized.")
+
+        utils.log_rank_zero(log, "metric logger is initialized.")
 
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
 
@@ -254,8 +253,8 @@ def setup(self, cfg: DictConfig) -> None:
         )
 
         self._loss_fn = config.instantiate(cfg.loss)
-        if self._is_rank_zero:
-            log.info("Loss is initialized.")
+
+        utils.log_rank_zero(log, "Loss is initialized.")
 
         # sampler and dataloader depend on the tokenizer and loss_fn and should be
         # setup after all of these are setup
@@ -314,11 +313,12 @@ def _setup_model(
 
         self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
         self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
 
-        if self._is_rank_zero:
-            log.info(
-                "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
-            )
-            init_start = time.perf_counter()
+        init_start = time.perf_counter()
+
+        utils.log_rank_zero(
+            log,
+            "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
+        )
 
         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(cfg_model)
@@ -397,10 +397,11 @@ def _setup_model(
         )
         # Ensure no params and buffers are on meta device
         training.validate_no_params_on_meta_device(model)
+        utils.log_rank_zero(
+            log,
+            f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
+        )
         if self._is_rank_zero:
-            log.info(
-                f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
-            )
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(memory_stats)
 
@@ -420,8 +421,7 @@ def _setup_optimizer(
             self._device,
         )
 
-        if self._is_rank_zero:
-            log.info("Optimizer and loss are initialized.")
+        utils.log_rank_zero(log, "Optimizer and loss are initialized.")
         return optimizer
 
     def _setup_lr_scheduler(
@@ -436,8 +436,8 @@ def _setup_lr_scheduler(
             num_training_steps=num_training_steps,
             last_epoch=last_epoch,
         )
-        if self._is_rank_zero:
-            log.info("Learning rate scheduler is initialized.")
+
+        utils.log_rank_zero(log, "Learning rate scheduler is initialized.")
         return lr_scheduler
 
     def _setup_data(
@@ -479,8 +479,7 @@ def _setup_data(
             ),
         )
 
-        if self._is_rank_zero:
-            log.info("Dataset and Sampler are initialized.")
+        utils.log_rank_zero(log, "Dataset and Sampler are initialized.")
 
         return sampler, dataloader
 
diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py
index 6bbecc3d91..45209814a0 100644
--- a/recipes/lora_finetune_distributed.py
+++ b/recipes/lora_finetune_distributed.py
@@ -138,8 +138,6 @@ def __init__(self, cfg: DictConfig) -> None:
 
         _, rank = training.get_world_size_and_rank()
 
-        # _is_rank_zero is used primarily for logging. In the future, the logger
-        # should directly take care of this
         self._is_rank_zero = rank == 0
 
         # logging attributes
@@ -304,8 +302,7 @@ def setup(self, cfg: DictConfig) -> None:
         if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
             # set num_output_chunks for model
             self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
-        if self._is_rank_zero:
-            log.info("Loss is initialized.")
+        utils.log_rank_zero(log, "Loss is initialized.")
 
         # sampler and dataloader depend on the tokenizer and loss_fn and should be
         # setup after all of these are setup
@@ -407,9 +404,10 @@ def _setup_profiler(
 
         profiler, profiler_cfg = config.instantiate(cfg_profiler)
 
+        utils.log_rank_zero(
+            log, f" Profiler config after instantiation: {profiler_cfg}"
+        )
         if self._is_rank_zero:
-            log.info(f" Profiler config after instantiation: {profiler_cfg}")
-
             self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
             if profiler_cfg["enabled"]:
                 self.profiler_wait_steps = profiler_cfg["wait_steps"]
@@ -444,11 +442,11 @@ def _setup_model(
 
         self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
         self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
 
-        if self._is_rank_zero:
-            log.info(
-                "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
-            )
-            init_start = time.perf_counter()
+        utils.log_rank_zero(
+            log,
+            "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
+        )
+        init_start = time.perf_counter()
 
         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(cfg_model)
@@ -536,10 +534,11 @@ def _setup_model(
         )
 
         # log
+        utils.log_rank_zero(
+            log,
+            f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
+        )
         if self._is_rank_zero:
-            log.info(
-                f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
-            )
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(memory_stats)
 
@@ -559,8 +558,7 @@ def _setup_optimizer(
             self._device,
         )
 
-        if self._is_rank_zero:
-            log.info("Optimizer is initialized.")
+        utils.log_rank_zero(log, "Optimizer is initialized.")
         return optimizer
 
     def _setup_lr_scheduler(
@@ -575,8 +573,7 @@ def _setup_lr_scheduler(
             num_training_steps=num_training_steps,
             last_epoch=last_epoch,
         )
-        if self._is_rank_zero:
-            log.info("Learning rate scheduler is initialized.")
+        utils.log_rank_zero(log, "Learning rate scheduler is initialized.")
         return lr_scheduler
 
     def _setup_data(
@@ -630,8 +627,7 @@ def _setup_data(
             ),
         )
 
-        if self._is_rank_zero:
-            log.info("Dataset and Sampler are initialized.")
+        utils.log_rank_zero(log, "Dataset and Sampler are initialized.")
 
         return sampler, dataloader
 
@@ -656,11 +652,11 @@ def save_checkpoint(
 
         intermediate_checkpoint = epoch + 1 < self.total_epochs
 
-        if self._is_rank_zero:
-            log.info(
-                "Saving checkpoint. This may take some time. Retrieving full model state dict..."
-            )
-            start = time.perf_counter()
+        utils.log_rank_zero(
+            log,
+            "Saving checkpoint. This may take some time. Retrieving full model state dict...",
+        )
+        start = time.perf_counter()
 
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
@@ -673,23 +669,22 @@ def save_checkpoint(
             self._is_rank_zero,
             device=self._device,
         )
-        if self._is_rank_zero:
-            log.info(
-                f"Getting full model state dict took {time.perf_counter() - start:.2f} secs"
-            )
+        utils.log_rank_zero(
+            log,
+            f"Getting full model state dict took {time.perf_counter() - start:.2f} secs",
+        )
 
         if intermediate_checkpoint:
-            if self._is_rank_zero:
-                log.info("Retrieving optimizer state dict...")
+            utils.log_rank_zero(log, "Retrieving optimizer state dict...")
             opt_state_dict = training.get_full_optimizer_state_dict(
                 self._optimizer,
                 self._is_rank_zero,
                 device=self._device,
             )
-            if self._is_rank_zero:
-                log.info(
-                    f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs"
-                )
+            utils.log_rank_zero(
+                log,
+                f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs",
+            )
         else:
             opt_state_dict = None
 
diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py
index 1aa622ba63..ab9cea3eda 100644
--- a/recipes/qat_distributed.py
+++ b/recipes/qat_distributed.py
@@ -153,8 +153,6 @@ def __init__(self, cfg: DictConfig) -> None:
             )
             self._log_peak_memory_stats = False
 
-        # _is_rank_zero is used primarily for logging. In the future, the logger
-        # should directly take care of this
        _, rank = training.get_world_size_and_rank()
         self._is_rank_zero = rank == 0
 
@@ -316,8 +314,7 @@ def setup(self, cfg: DictConfig) -> None:
             # set num_output_chunks for model
             self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
 
-        if self._is_rank_zero:
-            log.info("Loss is initialized.")
+        utils.log_rank_zero(log, "Loss is initialized.")
 
         # sampler and dataloader depend on the tokenizer and loss_fn and should be
         # setup after both of these are initialized
@@ -411,9 +408,10 @@ def _setup_profiler(
 
         profiler, profiler_cfg = config.instantiate(cfg_profiler)
 
+        utils.log_rank_zero(
+            log, f" Profiler config after instantiation: {profiler_cfg}"
+        )
         if self._is_rank_zero:
-            log.info(f" Profiler config after instantiation: {profiler_cfg}")
-
             self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
             if profiler_cfg["enabled"]:
                 self.profiler_wait_steps = profiler_cfg["wait_steps"]
@@ -443,11 +441,11 @@ def _setup_model(
 
            full state dicts are loaded with ``torch.load(mmap=True)``
         """
-        if self._is_rank_zero:
-            log.info(
-                "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
-            )
-            init_start = time.perf_counter()
+        utils.log_rank_zero(
+            log,
+            "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
+        )
+        init_start = time.perf_counter()
 
         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(cfg_model)
@@ -526,10 +524,11 @@ def _setup_model(
 
         # Ensure no params and buffers are on meta device
         training.validate_no_params_on_meta_device(model)
 
+        utils.log_rank_zero(
+            log,
+            f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
+        )
         if self._is_rank_zero:
-            log.info(
-                f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
-            )
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(memory_stats)
@@ -575,8 +574,7 @@ def _setup_optimizer(
                     "Failed loading in-backward optimizer checkpoints."
                     "Please make sure run being restored from was using in-backward optimizer."
                 ) from e
-            if self._is_rank_zero:
-                log.info("In-backward optimizers are set up.")
+            utils.log_rank_zero(log, "In-backward optimizers are set up.")
             return None
         else:
             optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
@@ -587,8 +585,7 @@ def _setup_optimizer(
                     self._device,
                 )
 
-            if self._is_rank_zero:
-                log.info("Optimizer is initialized.")
+            utils.log_rank_zero(log, "Optimizer is initialized.")
             return optimizer
 
     def _setup_data(
@@ -641,8 +638,7 @@ def _setup_data(
             ),
         )
 
-        if self._is_rank_zero:
-            log.info("Dataset and Sampler are initialized.")
+        utils.log_rank_zero(log, "Dataset and Sampler are initialized.")
 
         return sampler, dataloader
 
@@ -665,11 +661,11 @@ def save_checkpoint(
 
         intermediate_checkpoint = epoch + 1 < self.total_epochs
 
-        if self._is_rank_zero:
-            log.info(
-                "Saving checkpoint. This may take some time. Retrieving full model state dict..."
-            )
-            start = time.perf_counter()
+        utils.log_rank_zero(
+            log,
+            "Saving checkpoint. This may take some time. Retrieving full model state dict...",
+        )
+        start = time.perf_counter()
 
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
@@ -679,15 +675,14 @@ def save_checkpoint(
             device=self._device,
         )
 
-        if self._is_rank_zero:
-            log.info(
-                f"Getting full model state dict took {time.perf_counter() - start:.2f} secs"
-            )
+        utils.log_rank_zero(
+            log,
+            f"Getting full model state dict took {time.perf_counter() - start:.2f} secs",
+        )
 
         if intermediate_checkpoint:
             start = time.perf_counter()
-            if self._is_rank_zero:
-                log.info("Getting optimizer state dict...")
+            utils.log_rank_zero(log, "Getting optimizer state dict...")
             if not self._optimizer_in_bwd:
                 opt_state_dict = training.get_full_optimizer_state_dict(
                     self._optimizer,
@@ -700,10 +695,10 @@ def save_checkpoint(
                 opt_state_dict[param] = training.get_full_optimizer_state_dict(
                     opt, self._is_rank_zero, device=self._device
                 )
-            if self._is_rank_zero:
-                log.info(
-                    f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs"
-                )
+            utils.log_rank_zero(
+                log,
+                f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs",
+            )
         else:
             opt_state_dict = None
 
diff --git a/torchtune/utils/_logging.py b/torchtune/utils/_logging.py
index 075663c166..ec3912e317 100644
--- a/torchtune/utils/_logging.py
+++ b/torchtune/utils/_logging.py
@@ -98,4 +98,4 @@ def log_rank_zero(logger: logging.Logger, msg: str, level: int = logging.INFO) -> None:
     rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
     if rank != 0:
         return
-    logger.log(level, msg)
+    logger.log(level, msg, stacklevel=2)
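
Note on the stacklevel=2 change in torchtune/utils/_logging.py above: with the default stacklevel of 1, every record emitted through log_rank_zero is attributed to the logger.log call inside the helper itself, so %(filename)s and %(lineno)d in a log format always point at _logging.py. Passing stacklevel=2 makes the logging module walk one extra frame up the stack when resolving the caller, so records are attributed to the recipe line that invoked log_rank_zero. The sketch below is a minimal, self-contained illustration of that behavior, not the torchtune module; the standalone logger setup and the caller() function are hypothetical, and the distributed rank check is omitted.

import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(filename)s:%(lineno)d - %(message)s",
)
log = logging.getLogger(__name__)


def log_rank_zero(logger: logging.Logger, msg: str, level: int = logging.INFO) -> None:
    # torchtune's version also returns early on non-zero ranks before this call.
    # stacklevel=2 attributes the record to the caller of this wrapper,
    # not to the wrapper itself.
    logger.log(level, msg, stacklevel=2)


def caller() -> None:
    log_rank_zero(log, "Loss is initialized.")


caller()
# With stacklevel=2 the output names this script and the line inside caller(),
# e.g. "example.py:20 - Loss is initialized."; with the default stacklevel it
# would report the logger.log line inside log_rank_zero instead.

This is also why the recipes can drop the per-call "if self._is_rank_zero:" guards without losing caller information in the emitted log records.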