Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: akoumpa <[email protected]>
  • Loading branch information
akoumpa committed Feb 9, 2025
1 parent 894f72c commit ef2c811
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 65 deletions.
4 changes: 3 additions & 1 deletion examples/llm/peft/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
import tempfile

import fiddle as fdl
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger

from nemo import lightning as nl
from nemo.collections import llm
from nemo.lightning import NeMoLogger
from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform
import lightning.pytorch as pl


def make_squad_hf_dataset(tokenizer):
Expand Down Expand Up @@ -51,6 +51,7 @@ def formatting_prompts_func(example):
)
return datamodule


def make_strategy(strategy, model, devices, num_nodes, adapter_only=False):
if strategy == 'auto':
return pl.strategies.SingleDeviceStrategy(
Expand Down Expand Up @@ -88,6 +89,7 @@ def logger(ckpt_folder) -> nl.NeMoLogger:
wandb=None,
)


def main():
"""Example script to run PEFT with a HF transformers-instantiated model on squad."""
import argparse
Expand Down
3 changes: 3 additions & 0 deletions examples/llm/sft/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def squad(tokenizer, mbs=1, gbs=2) -> pl.LightningDataModule:
},
)


def make_strategy(strategy, model, devices, num_nodes, adapter_only=False):
if strategy == 'auto':
return pl.strategies.SingleDeviceStrategy(
Expand All @@ -80,6 +81,7 @@ def make_strategy(strategy, model, devices, num_nodes, adapter_only=False):
else:
raise NotImplementedError("Encountered unknown strategy")


def logger(ckpt_folder) -> nl.NeMoLogger:
ckpt = nl.ModelCheckpoint(
save_last=True,
Expand All @@ -98,6 +100,7 @@ def logger(ckpt_folder) -> nl.NeMoLogger:
wandb=None,
)


def main():
"""Example script to run SFT with a HF transformers-instantiated model on squad."""
import argparse
Expand Down
21 changes: 11 additions & 10 deletions nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from nemo.lightning.pytorch.strategies.utils import fsdp2_strategy_parallelize
from nemo.utils import logging


def masked_cross_entropy(logits, targets, mask=None):
"""
Compute the masked cross-entropy loss between logits and targets.
Expand All @@ -47,6 +48,7 @@ def masked_cross_entropy(logits, targets, mask=None):
loss = F.cross_entropy(logits, targets)
return loss


class HFAutoModelForCausalLM(pl.LightningModule, io.IOMixin, fn.FNMixin):
"""
A LightningModule wrapper for AutoModelForCausalLM.
Expand Down Expand Up @@ -269,9 +271,7 @@ def training_step(self, batch, batch_idx=None):

assert logits.shape[-2] == labels.shape[-1], "Expected logits & labels to have the same length"
loss = self.loss_fn(logits, labels, loss_mask)
self.log(
'reduced_train_loss', loss, prog_bar=True, rank_zero_only=True, batch_size=1, sync_dist=False
)
self.log('reduced_train_loss', loss, prog_bar=True, rank_zero_only=True, batch_size=1, sync_dist=False)
return loss

@torch.no_grad
Expand Down Expand Up @@ -329,16 +329,17 @@ def save_pretrained(self, path, state_dict):

def load_pretrained(self, path):
return AutoModelForCausalLM.from_pretrained(
path,
torch_dtype='auto',
device_map="cpu",
trust_remote_code=self.trust_remote_code,
load_in_4bit=self.load_in_4bit,
attn_implementation=self.attn_implementation,
).state_dict()
path,
torch_dtype='auto',
device_map="cpu",
trust_remote_code=self.trust_remote_code,
load_in_4bit=self.load_in_4bit,
attn_implementation=self.attn_implementation,
).state_dict()

def make_checkpoint_io(self, adapter_only=False):
    """Build the HF checkpoint IO object for this module.

    Args:
        adapter_only: Forwarded unchanged to ``HFCheckpointIO``.
            NOTE(review): presumably selects saving/loading of adapter (PEFT)
            weights only -- confirm against ``HFCheckpointIO``.

    Returns:
        An ``HFCheckpointIO`` constructed with ``model=self`` and the given
        ``adapter_only`` value.
    """
    # Imported at call time rather than at module scope.
    # NOTE(review): presumably to avoid a circular import with nemo.lightning.io -- confirm.
    from nemo.lightning.io.hf import HFCheckpointIO

    return HFCheckpointIO(model=self, adapter_only=adapter_only)

def _remove_extra_batch_keys(self, batch, reserved_keys=['labels', 'loss_mask']):
Expand Down
2 changes: 1 addition & 1 deletion nemo/lightning/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from nemo.lightning.io.api import export_ckpt, import_ckpt, load, load_context, model_exporter, model_importer
from nemo.lightning.io.capture import reinit
from nemo.lightning.io.connector import Connector, ModelConnector
from nemo.lightning.io.hf import HFCheckpointIO
from nemo.lightning.io.mixin import ConnectorMixin, IOMixin, drop_unexpected_params, track_io
from nemo.lightning.io.pl import TrainerContext, is_distributed_ckpt
from nemo.lightning.io.state import TransformCTX, apply_transforms, state_transform
from nemo.lightning.io.hf import HFCheckpointIO

__all__ = [
"apply_transforms",
Expand Down
30 changes: 14 additions & 16 deletions nemo/lightning/io/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
from lightning.fabric.utilities.types import _PATH
from torch import nn
from typing_extensions import override
from nemo.lightning.io.pl import ckpt_to_weights_subdir
from nemo.lightning.io.mixin import IOMixin

from nemo.lightning.io.mixin import IOMixin
from nemo.lightning.io.pl import ckpt_to_weights_subdir

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -88,8 +88,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
# └── adapter_model.safetensors
# Where the `trainer.pt` stores trainer's state (optimizer, dataloader, etc).
# The `weights` directory contains the adapter's state dict, in HF format.
self._save_adapter_weights_only(checkpoint.pop(
'state_dict'), checkpoint_dir, storage_options)
self._save_adapter_weights_only(checkpoint.pop('state_dict'), checkpoint_dir, storage_options)
torch.save(checkpoint, checkpoint_dir.parent / 'trainer.pt')
elif callable(getattr(self.model, 'save_pretrained', None)):
# In this case the output looks like the following:
Expand All @@ -106,14 +105,15 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio

# Where the `weights` directory contains the model's state dict, in HF format.
# The `trainer.pt` stores trainer's state (optimizer, dataloader, etc).
self.model.save_pretrained(
checkpoint_dir, state_dict=checkpoint.pop('state_dict'))
self.model.save_pretrained(checkpoint_dir, state_dict=checkpoint.pop('state_dict'))
torch.save(checkpoint, checkpoint_dir.parent / 'trainer.pt')
else:
super().save_checkpoint(checkpoint, path, storage_options)
raise NotImplementedError("Checkpoint was saved at: " + str(path))

def _save_adapter_weights_only(self, state_dict: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None) -> None:
def _save_adapter_weights_only(
self, state_dict: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None
) -> None:
"""
Saves only the adapter weights in a safetensors format.
Expand All @@ -131,10 +131,11 @@ def _save_adapter_weights_only(self, state_dict: Dict[str, Any], path: Union[str
module_names = list(state_dict.keys())
for name in module_names:
param = state_dict.pop(name)
name = name\
.replace("model.model", "base_model.model")\
.replace("lora_a.weight", "lora_A.weight")\
name = (
name.replace("model.model", "base_model.model")
.replace("lora_a.weight", "lora_A.weight")
.replace("lora_b.weight", "lora_B.weight")
)
state_dict[name] = param

# Save weights to safetensors format
Expand Down Expand Up @@ -165,8 +166,7 @@ def _load_adapter_weights_only(path: Union[str, Path]) -> Dict[str, Any]:
raise FileNotFoundError(f"Checkpoint file not found: {path}")

if not fs.isdir(path):
raise ValueError(
f"Checkpoints should be a directory. Found: {path}.")
raise ValueError(f"Checkpoints should be a directory. Found: {path}.")

state_dict = {}
adapter_file = Path(path) / "adapter_model.safetensors"
Expand All @@ -175,8 +175,7 @@ def _load_adapter_weights_only(path: Union[str, Path]) -> Dict[str, Any]:
from safetensors import safe_open

if not adapter_file.exists():
raise FileNotFoundError(
f"Adapter weights file not found: {adapter_file}")
raise FileNotFoundError(f"Adapter weights file not found: {adapter_file}")

try:
with safe_open(adapter_file, framework="pt", device=0) as f:
Expand Down Expand Up @@ -217,8 +216,7 @@ def load_checkpoint(
if self.adapter_only:
trainer_state |= HFCheckpointIO._load_adapter_weights_only(path)
elif callable(getattr(self.model, 'load_pretrained', None)):
trainer_state['state_dict'] = self.model.load_pretrained(
f'{path}/model/')
trainer_state['state_dict'] = self.model.load_pretrained(f'{path}/model/')
else:
raise ValueError("Badio")

Expand Down
57 changes: 21 additions & 36 deletions nemo/lightning/io/pl.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ def from_trainer(cls, trainer: pl.Trainer) -> Self:
ValueError: If the trainer or its LightningModule does not extend `IOMixin`.
"""
if not hasattr(trainer, "__io__"):
raise ValueError(
f"Trainer must be an instance of {IOProtocol}. Please use the Trainer from nemo.")
raise ValueError(f"Trainer must be an instance of {IOProtocol}. Please use the Trainer from nemo.")
if not hasattr(trainer.lightning_module, "__io__"):
raise ValueError("LightningModule must extend IOMixin.")

Expand Down Expand Up @@ -117,7 +116,7 @@ def construct_extra(cls, trainer: pl.Trainer) -> Dict[str, Any]:

def ckpt_to_weights_subdir(filepath: Union[str, Path], is_saving) -> Path:
"""Given an input checkpoint filepath, clean it using `ckpt_to_dir`
and then return the weights subdirectory, if it exists."""
and then return the weights subdirectory, if it exists."""
filepath = ckpt_to_dir(filepath=filepath)
base_dir = filepath
assert isinstance(base_dir, Path)
Expand Down Expand Up @@ -195,8 +194,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
fs = get_filesystem(checkpoint_dir)
fs.makedirs(checkpoint_dir, exist_ok=True)

validate_sharding_integrity = not (
self.validated_consistency and self.assume_constant_structure)
validate_sharding_integrity = not (self.validated_consistency and self.assume_constant_structure)
self.validated_consistency = True

return dist_checkpointing.save(
Expand Down Expand Up @@ -233,16 +231,14 @@ def load_checkpoint(
from megatron.core.dist_checkpointing.validation import StrictHandling

if map_location is not None:
raise ValueError(
"`map_location` argument is not supported for `MegatronCheckpointIO.load_checkpoint`.")
raise ValueError("`map_location` argument is not supported for `MegatronCheckpointIO.load_checkpoint`.")

# Validate the checkpoint at `path` before reading it: raise if it does not exist or is not a directory.
fs = get_filesystem(path)
if not fs.exists(path):
raise FileNotFoundError(f"Checkpoint file not found: {path}")
if not fs.isdir(path):
raise ValueError(
f"Distributed checkpoints should be a directory. Found: {path}.")
raise ValueError(f"Distributed checkpoints should be a directory. Found: {path}.")

# Load from ckpt_path/weights (new format) if it exists
path = ckpt_to_weights_subdir(path, is_saving=False)
Expand All @@ -252,17 +248,15 @@ def load_checkpoint(
if self.save_ckpt_format == 'zarr' and self.load_directly_on_device:
from megatron.core.dist_checkpointing.strategies.tensorstore import TensorStoreLoadShardedStrategy

sharded_strategy = TensorStoreLoadShardedStrategy(
load_directly_on_device=True)
sharded_strategy = TensorStoreLoadShardedStrategy(load_directly_on_device=True)
else:
sharded_strategy = None

if self.parallel_load:
if sharded_strategy is None:
sharded_strategy = get_default_load_sharded_strategy(path)
sharded_strategy = FullyParallelLoadStrategyWrapper(
sharded_strategy, get_data_parallel_group(
with_context_parallel=True)
sharded_strategy, get_data_parallel_group(with_context_parallel=True)
)

if sharded_strategy is not None:
Expand All @@ -272,8 +266,7 @@ def load_checkpoint(
# For backward-compatibility reasons and a bug in MCore (strict check not applied to factories)
# we must apply a simple strict check here.
if not strict:
sharded_state_dict = self.adjust_non_strict_load(
path, sharded_state_dict)
sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict)
strict = StrictHandling.ASSUME_OK_UNEXPECTED if strict else StrictHandling.LOG_ALL
if strict is None:
# Default behavior
Expand Down Expand Up @@ -317,26 +310,21 @@ def _determine_dist_ckpt_save_strategy(self):
)

if self.async_save and self.save_ckpt_format != 'torch_dist':
raise ValueError(
'Async dist-ckpt save supported only for torch_dist format')
raise ValueError('Async dist-ckpt save supported only for torch_dist format')

torch_dist_kwargs = {} if self.torch_dist_multiproc is None else dict(
thread_count=self.torch_dist_multiproc)
torch_dist_kwargs = {} if self.torch_dist_multiproc is None else dict(thread_count=self.torch_dist_multiproc)
if self.save_ckpt_format == 'torch_dist' and torch_dist_kwargs:
save_strategy = TorchDistSaveShardedStrategy(
self.save_ckpt_format, 1, **torch_dist_kwargs)
save_strategy = TorchDistSaveShardedStrategy(self.save_ckpt_format, 1, **torch_dist_kwargs)
else:
save_strategy = get_default_save_sharded_strategy(
self.save_ckpt_format, 1)
save_strategy = get_default_save_sharded_strategy(self.save_ckpt_format, 1)

# MCore v0.8 introduces `use_cached_ckpt_structure` attribute
if hasattr(save_strategy, 'use_cached_ckpt_structure'):
save_strategy.use_cached_ckpt_structure = self.assume_constant_structure

if self.parallel_save:
parallelization_group = (
get_data_parallel_group(
with_context_parallel=True) if self.parallel_save_within_dp else None
get_data_parallel_group(with_context_parallel=True) if self.parallel_save_within_dp else None
)
save_strategy = FullyParallelSaveStrategyWrapper(
save_strategy, parallelization_group, self.assume_constant_structure
Expand All @@ -357,8 +345,8 @@ def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]
"""
Adjusts the loading of a non-strict sharded checkpoint by filtering out missing keys.
This function loads the checkpoint's metadata and removes any `ShardedBase` keys from
`sharded_state_dict` that do not exist in the checkpoint. It also logs unexpected keys
This function loads the checkpoint's metadata and removes any `ShardedBase` keys from
`sharded_state_dict` that do not exist in the checkpoint. It also logs unexpected keys
that were not found in the checkpoint.
Args:
Expand All @@ -373,7 +361,7 @@ def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]
are considered "unexpected" and are logged.
- Missing keys are not computed yet. To fully determine missing keys:
1. Perform an `all_gather_object` operation on `loaded_keys`.
2. Compute `missing_keys` as the difference between `ckpt_sharded_metadata.keys()`
2. Compute `missing_keys` as the difference between `ckpt_sharded_metadata.keys()`
and `loaded_keys`.
"""
from megatron.core import dist_checkpointing
Expand Down Expand Up @@ -403,20 +391,18 @@ def should_remove_missing_sharded_base(x: Any):
return True
return False

_, sharded_state_dict = extract_matching_values(
sharded_state_dict, should_remove_missing_sharded_base)
logging.info(
f'The following keys are not in the checkpoint and will not be loaded: {unexpected_keys}')
_, sharded_state_dict = extract_matching_values(sharded_state_dict, should_remove_missing_sharded_base)
logging.info(f'The following keys are not in the checkpoint and will not be loaded: {unexpected_keys}')

# TODO: compute missing_keys by:
# 1. all_gather_object of loaded_keys
# 2. missing_keys = ckpt_sharded_metadata.keys() - loaded_keys
return sharded_state_dict


def _fix_tensors_device(ckpt: Dict) -> Dict:
"""Ensure checkpoint tensors are on the correct device."""
assert torch.cuda.is_initialized(), (torch.cuda.is_available(),
torch.cuda.is_initialized())
assert torch.cuda.is_initialized(), (torch.cuda.is_available(), torch.cuda.is_initialized())
cur_dev = torch.device("cuda", index=torch.cuda.current_device())
from megatron.core.dist_checkpointing.dict_utils import dict_list_map_outplace

Expand Down Expand Up @@ -446,5 +432,4 @@ def is_distributed_ckpt(path) -> bool:

checkpoint_dir = ckpt_to_dir(path)
fs = get_filesystem(checkpoint_dir)
return fs.isdir(checkpoint_dir) \
and dist_checkpointing.check_is_distributed_checkpoint(checkpoint_dir)
return fs.isdir(checkpoint_dir) and dist_checkpointing.check_is_distributed_checkpoint(checkpoint_dir)
3 changes: 2 additions & 1 deletion nemo/lightning/nemo_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,14 @@ def _setup_trainer_model_checkpoint(self, trainer, log_dir, ckpt=None):
)

from nemo.lightning import MegatronStrategy

for callback in trainer.callbacks:
if isinstance(callback, PTLModelCheckpoint):
if callback.dirpath is None:
callback.dirpath = Path(log_dir / "checkpoints")
if callback.filename is None:
if isinstance(trainer.strategy, MegatronStrategy):
callback.filename = f"{self.name}--{{{callback.monitor}:.4f}}-{{epoch}}-{{consumed_samples}}"
callback.filename = f"{self.name}--{{{callback.monitor}:.4f}}-{{epoch}}-{{consumed_samples}}"
else:
# For automodel we log global_step
callback.filename = f"{self.name}--{{{callback.monitor}:.4f}}-{{epoch}}-{{step}}"
Expand Down
1 change: 1 addition & 0 deletions nemo/lightning/pytorch/strategies/fsdp2_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
setup_parallel_ranks,
)


class FSDP2Strategy(PLModelParallelStrategy, io.IOMixin):
"""Megatron plugin for PyTorch Lightning implementing FSDP 2.
Expand Down

0 comments on commit ef2c811

Please sign in to comment.