log rank zero everywhere (#2030)
RdoubleA authored Nov 20, 2024
1 parent abdb5a4 commit a4a74a0
Showing 6 changed files with 134 additions and 153 deletions.
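The pattern applied throughout this commit: each `if self._is_rank_zero:` guard wrapped around a `log.info(...)` call is collapsed into a single `utils.log_rank_zero(log, ...)` call, so the rank gating lives in one helper instead of being repeated at every call site. As a rough sketch only (this is an assumption based on the call sites in the diff, not the actual torchtune implementation; the `level` keyword and the fallback when no process group is initialized are guesses), such a helper could look like:

    # Hedged sketch of a rank-zero logging helper; torchtune's real
    # utils.log_rank_zero may differ in signature and behavior.
    import logging

    import torch.distributed as dist


    def log_rank_zero(logger: logging.Logger, msg: str, level: int = logging.INFO) -> None:
        """Log ``msg`` only on global rank 0 so a distributed job prints each message once."""
        # Fall back to rank 0 when torch.distributed is unavailable or uninitialized,
        # e.g. single-process runs.
        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
        if rank != 0:
            return
        logger.log(level, msg)

Call sites then shrink from the two-line guarded form, e.g. `if self._is_rank_zero: log.info("Loss is initialized.")`, to `utils.log_rank_zero(log, "Loss is initialized.")`, as the hunks below show.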
recipes/full_finetune_distributed.py (65 changes: 30 additions & 35 deletions)
@@ -142,8 +142,6 @@ def __init__(self, cfg: DictConfig) -> None:
             )
             self._log_peak_memory_stats = False
 
-        # _is_rank_zero is used primarily for logging. In the future, the logger
-        # should directly take care of this
         _, rank = training.get_world_size_and_rank()
         self._is_rank_zero = rank == 0
 
@@ -302,8 +300,7 @@ def setup(self, cfg: DictConfig) -> None:
             # set num_output_chunks for model
             self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
 
-        if self._is_rank_zero:
-            log.info("Loss is initialized.")
+        utils.log_rank_zero(log, "Loss is initialized.")
 
         # sampler and dataloader depend on the tokenizer and loss_fn and should be
         # setup after both of these are initialized
@@ -397,9 +394,10 @@ def _setup_profiler(
 
         profiler, profiler_cfg = config.instantiate(cfg_profiler)
 
-        if self._is_rank_zero:
-            log.info(f" Profiler config after instantiation: {profiler_cfg}")
+        utils.log_rank_zero(
+            log, f" Profiler config after instantiation: {profiler_cfg}"
+        )
 
         self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
         if profiler_cfg["enabled"]:
             self.profiler_wait_steps = profiler_cfg["wait_steps"]
@@ -428,11 +426,11 @@ def _setup_model(
            full state dicts are loaded with ``torch.load(mmap=True)``
         """
 
-        if self._is_rank_zero:
-            log.info(
-                "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
-            )
-            init_start = time.perf_counter()
+        utils.log_rank_zero(
+            log,
+            "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
+        )
+        init_start = time.perf_counter()
 
         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(cfg_model)
@@ -498,10 +496,11 @@ def _setup_model(
         # Ensure no params and buffers are on meta device
         training.validate_no_params_on_meta_device(model)
 
+        utils.log_rank_zero(
+            log,
+            f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
+        )
         if self._is_rank_zero:
-            log.info(
-                f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
-            )
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(memory_stats)
 
@@ -547,8 +546,7 @@ def _setup_optimizer(
                     "Failed loading in-backward optimizer checkpoints."
                     "Please make sure run being restored from was using in-backward optimizer."
                 ) from e
-            if self._is_rank_zero:
-                log.info("In-backward optimizers are set up.")
+            utils.log_rank_zero(log, "In-backward optimizers are set up.")
             return None
         else:
             optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
@@ -559,8 +557,7 @@ def _setup_optimizer(
                     self._device,
                 )
 
-            if self._is_rank_zero:
-                log.info("Optimizer is initialized.")
+            utils.log_rank_zero(log, "Optimizer is initialized.")
             return optimizer
 
     def _setup_data(
@@ -613,8 +610,7 @@ def _setup_data(
             ),
         )
 
-        if self._is_rank_zero:
-            log.info("Dataset and Sampler are initialized.")
+        utils.log_rank_zero(log, "Dataset and Sampler are initialized.")
 
         return sampler, dataloader
 
@@ -637,11 +633,11 @@ def save_checkpoint(
 
         intermediate_checkpoint = epoch + 1 < self.total_epochs
 
-        if self._is_rank_zero:
-            log.info(
-                "Saving checkpoint. This may take some time. Retrieving full model state dict..."
-            )
-            start = time.perf_counter()
+        utils.log_rank_zero(
+            log,
+            "Saving checkpoint. This may take some time. Retrieving full model state dict...",
+        )
+        start = time.perf_counter()
 
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
@@ ... @@ def save_checkpoint(
             device=self._device,
         )
 
-        if self._is_rank_zero:
-            log.info(
-                f"Getting full model state dict took {time.perf_counter() - start:.2f} secs"
-            )
+        utils.log_rank_zero(
+            log,
+            f"Getting full model state dict took {time.perf_counter() - start:.2f} secs",
+        )
 
         if intermediate_checkpoint:
             start = time.perf_counter()
-            if self._is_rank_zero:
-                log.info("Getting optimizer state dict...")
+            utils.log_rank_zero(log, "Getting optimizer state dict...")
             if not self._optimizer_in_bwd:
                 opt_state_dict = training.get_full_optimizer_state_dict(
                     self._optimizer,
@@ ... @@
                     opt_state_dict[param] = training.get_full_optimizer_state_dict(
                         opt, self._is_rank_zero, device=self._device
                     )
-            if self._is_rank_zero:
-                log.info(
-                    f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs"
-                )
+            utils.log_rank_zero(
+                log,
+                f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs",
+            )
         else:
             opt_state_dict = None
 
recipes/knowledge_distillation_distributed.py (53 changes: 25 additions & 28 deletions)
@@ -119,8 +119,6 @@ def __init__(self, cfg: DictConfig) -> None:
 
         _, rank = training.get_world_size_and_rank()
 
-        # _is_rank_zero is used primarily for logging. In the future, the logger
-        # should directly take care of this
         self._is_rank_zero = rank == 0
 
         # logging attributes
@@ -287,8 +285,7 @@ def setup(self, cfg: DictConfig) -> None:
                 self._loss_fn.num_output_chunks == self._kd_loss_fn.num_output_chunks
             ), "Number of output chunks for loss_fn and kd_loss_fn must be the same."
 
-        if self._is_rank_zero:
-            log.info("Loss is initialized.")
+        utils.log_rank_zero(log, "Loss is initialized.")
 
         # Dataloader depends on the tokenizer and loss_fn and should be
         # setup after all of these are setup
@@ -389,9 +386,10 @@ def _setup_profiler(
 
         profiler, profiler_cfg = config.instantiate(cfg_profiler)
 
-        if self._is_rank_zero:
-            log.info(f" Profiler config after instantiation: {profiler_cfg}")
+        utils.log_rank_zero(
+            log, f" Profiler config after instantiation: {profiler_cfg}"
+        )
 
         self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
         if profiler_cfg["enabled"]:
             self.profiler_wait_steps = profiler_cfg["wait_steps"]
@@ -425,11 +423,11 @@ def _setup_model(
         self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
         self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
 
-        if self._is_rank_zero:
-            log.info(
-                "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
-            )
-            init_start = time.perf_counter()
+        utils.log_rank_zero(
+            log,
+            "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
+        )
+        init_start = time.perf_counter()
 
         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(cfg_model)
@@ -512,10 +510,11 @@ def _setup_model(
         # Ensure no params and buffers are on meta device
         training.validate_no_params_on_meta_device(model)
 
+        utils.log_rank_zero(
+            log,
+            f"Instantiating student model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
+        )
         if self._is_rank_zero:
-            log.info(
-                f"Instantiating student model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
-            )
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(
                 memory_stats, message="Memory stats after student model init:"
@@ ... @@ def _setup_teacher_model(
            full state dicts are loaded with ``torch.load(mmap=True)``
         """
 
-        if self._is_rank_zero:
-            log.info(
-                "FSDP enabled. Instantiating teacher model and loading checkpoint on Rank 0 ..."
-            )
-            init_start = time.perf_counter()
+        utils.log_rank_zero(
+            log,
+            "FSDP enabled. Instantiating teacher model and loading checkpoint on Rank 0 ...",
+        )
+        init_start = time.perf_counter()
 
         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(model_cfg)
@@ -596,10 +595,11 @@ def _setup_teacher_model(
         # Ensure no params and buffers are on meta device
         training.validate_no_params_on_meta_device(model)
 
+        utils.log_rank_zero(
+            log,
+            f"Instantiating teacher model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
+        )
         if self._is_rank_zero:
-            log.info(
-                f"Instantiating teacher model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
-            )
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(
                 memory_stats, message="Memory stats after teacher model init:"
@@ ... @@ def _setup_optimizer(
                 self._device,
             )
 
-        if self._is_rank_zero:
-            log.info("Optimizer is initialized.")
+        utils.log_rank_zero(log, "Optimizer is initialized.")
         return optimizer
 
     def _setup_lr_scheduler(
@@ ... @@ def _setup_lr_scheduler(
             last_epoch=last_epoch,
         )
 
-        if self._is_rank_zero:
-            log.info("Learning rate scheduler is initialized.")
+        utils.log_rank_zero(log, "Learning rate scheduler is initialized.")
         return lr_scheduler
 
     def _setup_data(
@@ -690,8 +688,7 @@ def _setup_data(
             ),
         )
 
-        if self._is_rank_zero:
-            log.info("Dataset and Sampler are initialized.")
+        utils.log_rank_zero(log, "Dataset and Sampler are initialized.")
 
         return sampler, dataloader
 
recipes/lora_dpo_distributed.py (37 changes: 18 additions & 19 deletions)
@@ -122,8 +122,6 @@ def __init__(self, cfg: DictConfig) -> None:
 
         _, rank = training.get_world_size_and_rank()
 
-        # _is_rank_zero is used primarily for logging. In the future, the logger
-        # should directly take care of this
         self._is_rank_zero = rank == 0
 
         # logging attributes
@@ -226,7 +224,8 @@ def setup(self, cfg: DictConfig) -> None:
 
         # log config with parameter override
         self._metric_logger.log_config(cfg)
-        log.info("_metric_logger is initialized.")
+
+        utils.log_rank_zero(log, "metric logger is initialized.")
 
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
 
@@ -254,8 +253,7 @@ def setup(self, cfg: DictConfig) -> None:
         )
 
         self._loss_fn = config.instantiate(cfg.loss)
-        if self._is_rank_zero:
-            log.info("Loss is initialized.")
+
+        utils.log_rank_zero(log, "Loss is initialized.")
 
         # sampler and dataloader depend on the tokenizer and loss_fn and should be
         # setup after all of these are setup
@@ -314,11 +313,12 @@ def _setup_model(
         self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
         self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
 
-        if self._is_rank_zero:
-            log.info(
-                "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
-            )
-            init_start = time.perf_counter()
+        init_start = time.perf_counter()
+
+        utils.log_rank_zero(
+            log,
+            "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
+        )
 
         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(cfg_model)
@@ -397,10 +397,11 @@ def _setup_model(
             )
         # Ensure no params and buffers are on meta device
         training.validate_no_params_on_meta_device(model)
+        utils.log_rank_zero(
+            log,
+            f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
+        )
         if self._is_rank_zero:
-            log.info(
-                f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
-            )
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(memory_stats)
 
@@ ... @@ def _setup_optimizer(
                 self._device,
             )
 
-        if self._is_rank_zero:
-            log.info("Optimizer and loss are initialized.")
+        utils.log_rank_zero(log, "Optimizer and loss are initialized.")
         return optimizer
 
     def _setup_lr_scheduler(
@@ ... @@ def _setup_lr_scheduler(
             num_training_steps=num_training_steps,
             last_epoch=last_epoch,
         )
-        if self._is_rank_zero:
-            log.info("Learning rate scheduler is initialized.")
+
+        utils.log_rank_zero(log, "Learning rate scheduler is initialized.")
         return lr_scheduler
 
     def _setup_data(
@@ -479,8 +479,7 @@ def _setup_data(
             ),
         )
 
-        if self._is_rank_zero:
-            log.info("Dataset and Sampler are initialized.")
+        utils.log_rank_zero(log, "Dataset and Sampler are initialized.")
 
         return sampler, dataloader
 