TP + FSDP distributed training (full finetuning) (#2330)
Co-authored-by: JessicaZhong <[email protected]>
Co-authored-by: joecummings <[email protected]>
3 people authored Feb 10, 2025
1 parent 8c9235e commit 9da35c7
Showing 7 changed files with 164 additions and 40 deletions.
7 changes: 6 additions & 1 deletion recipes/configs/llama3/70B_full.yaml
@@ -19,6 +19,11 @@

output_dir: /tmp/torchtune/llama3_70B/full # /tmp may be deleted by your system. Change it to your preference.

# Parallelism
tensor_parallel_dim: 1
parallelize_plan:
_component_: torchtune.models.llama3.base_llama_tp_plan

# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -54,7 +59,7 @@ epochs: 1
optimizer:
_component_: torch.optim.AdamW
lr: 2e-5
fused: True
fused: False

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
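The new `# Parallelism` block is the user-facing switch for this feature: `tensor_parallel_dim: 1` preserves the existing pure-FSDP behavior, while a larger value splits each data-parallel replica across that many GPUs using the named TP plan. A minimal sketch of how these two keys translate into the 2D device mesh the recipe builds later (the concrete numbers, 8 GPUs with TP=2, are illustrative assumptions; in practice this runs under `tune run`/torchrun):

```python
# Sketch only: how tensor_parallel_dim and parallelize_plan are consumed downstream.
from torch.distributed import init_device_mesh

world_size = 8               # illustrative: number of launched processes
tensor_parallel_dim = 2      # from the YAML or a CLI override
# parallelize_plan resolves to torchtune.models.llama3.base_llama_tp_plan

assert world_size % tensor_parallel_dim == 0
data_parallel_dim = world_size // tensor_parallel_dim   # 8 // 2 = 4 FSDP replicas

# FSDP shards parameters along "dp"; the TP plan shards weights within each "tp" group.
device_mesh = init_device_mesh(
    "cuda",
    mesh_shape=(data_parallel_dim, tensor_parallel_dim),
    mesh_dim_names=("dp", "tp"),
)
```

The same two keys are added to the llama3_1 and llama3_3 70B configs below.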
7 changes: 6 additions & 1 deletion recipes/configs/llama3_1/70B_full.yaml
@@ -18,6 +18,11 @@

output_dir: /tmp/torchtune/llama3_1_70B/full # /tmp may be deleted by your system. Change it to your preference.

# Parallelism
tensor_parallel_dim: 1
parallelize_plan:
_component_: torchtune.models.llama3.base_llama_tp_plan

# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -55,7 +60,7 @@ optimizer:
lr: 2e-5
# Note: highly recommended to use fused=True optimizer flag
# with CPU offload for faster optimizer step.
fused: True
fused: False

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
7 changes: 6 additions & 1 deletion recipes/configs/llama3_3/70B_full.yaml
@@ -18,6 +18,11 @@

output_dir: /tmp/torchtune/llama3_3_70B/full # /tmp may be deleted by your system. Change it to your preference.

# Parallelism
tensor_parallel_dim: 1
parallelize_plan:
_component_: torchtune.models.llama3.base_llama_tp_plan

# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -55,7 +60,7 @@ optimizer:
lr: 2e-5
# Note: highly recommended to use fused=True optimizer flag
# with CPU offload for faster optimizer step.
fused: True
fused: False

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
2 changes: 1 addition & 1 deletion recipes/dev/generate_v2_distributed.py
@@ -105,7 +105,7 @@ def setup(self, cfg: DictConfig) -> None:
tp_device_mesh = dist.init_device_mesh("cuda", tp_mesh_shape)

# Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel
training.prepare_mha_for_tp(model, tp_device_mesh)
model = training.prepare_mha_for_tp(model, tp_device_mesh)
parallelize_module(
model,
tp_device_mesh,
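This one-line change matters because `prepare_mha_for_tp` returns the adjusted module rather than promising an in-place update (see the `_distributed.py` hunk at the bottom of this commit); discarding the return value meant the generation recipe could apply the TP plan to a model whose `num_heads`/`num_kv_heads`/`embed_dim` had not been divided by the TP size. A hedged sketch of the corrected call order, wrapped in a hypothetical helper:

```python
from torch import nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module
from torchtune import training


def apply_tp(model: nn.Module, tp_device_mesh: DeviceMesh, tp_plan) -> nn.Module:
    """Sketch of the call order used by generate_v2_distributed.py after this fix."""
    # First rescale num_heads / num_kv_heads / embed_dim to per-TP-rank values...
    model = training.prepare_mha_for_tp(model, tp_device_mesh)
    # ...then shard the adjusted module according to the TP plan.
    parallelize_module(model, tp_device_mesh, parallelize_plan=tp_plan)
    return model
```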
93 changes: 64 additions & 29 deletions recipes/full_finetune_distributed.py
@@ -15,8 +15,13 @@
from omegaconf import DictConfig, ListConfig

from torch import nn
from torch.distributed import destroy_process_group, init_process_group

from torch.distributed import (
destroy_process_group,
init_device_mesh,
init_process_group,
)
from torch.distributed._tensor import DTensor
from torch.distributed.tensor.parallel import parallelize_module
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
@@ -136,14 +141,26 @@ def __init__(self, cfg: DictConfig) -> None:
or self._enable_async_checkpointing,
)
init_process_group(self.distributed_backend)
_, rank = utils.get_world_size_and_rank()
self._is_rank_zero = rank == 0

# Initialize distributed variables
self.world_size, self.rank = utils.get_world_size_and_rank()
self._is_rank_zero = self.rank == 0
self.parallelize_plan = config.instantiate(cfg.get("parallelize_plan", None))
self.tensor_parallel_dim = cfg.get("tensor_parallel_dim", 1)
if self.tensor_parallel_dim > 1 and self.parallelize_plan is None:
raise ValueError(
"Parallelism plan need to be provided when tensor parallel is enabled."
)
if self.world_size % self.tensor_parallel_dim != 0:
raise ValueError(
f"world_size {self.world_size} must be divisible by tensor_parallel_dim {self.tensor_parallel_dim}"
)
self.data_parallel_dim = self.world_size // self.tensor_parallel_dim

# Logging attributes
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)

if self._log_peak_memory_stats and device_type != "cuda":
log.info(
"log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
@@ -505,7 +522,7 @@ def _setup_model(

utils.log_rank_zero(
log,
"FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
"Distributed training is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
)
init_start = time.perf_counter()

@@ -515,6 +532,24 @@
if self._compile:
training.compile_model(model, verbose=self._is_rank_zero)

device_mesh = init_device_mesh(
self._device.type,
mesh_shape=(self.data_parallel_dim, self.tensor_parallel_dim),
mesh_dim_names=("dp", "tp"),
)
self.dp_size = device_mesh["dp"].size()
self.dp_rank = device_mesh["dp"].get_local_rank()

# Apply tensor parallelism to the model
if self.tensor_parallel_dim > 1:
# Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel
model = training.prepare_mha_for_tp(model, device_mesh["tp"])
parallelize_module(
model,
device_mesh["tp"],
parallelize_plan=self.parallelize_plan,
)

# We currently have two versions of activation checkpointing in this recipe
# for testing and BC purposes. ``enable_activation_checkpointing`` controls
# the older version of AC and this behavior is unchanged
@@ -534,19 +569,21 @@
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

# For FSDP sharding
fsdp_shard_conditions = [
partial(
training.get_shard_conditions,
names_to_match=custom_sharded_layers,
# Apply Fully Sharded Data Parallelism to the model
if self.data_parallel_dim > 1:
fsdp_shard_conditions = [
partial(
training.get_shard_conditions,
names_to_match=custom_sharded_layers,
)
]
training.shard_model(
model=model,
shard_conditions=fsdp_shard_conditions,
cpu_offload=fsdp_cpu_offload,
reshard_after_forward=reshard_after_forward,
dp_mesh=device_mesh["dp"],
)
]
training.shard_model(
model=model,
shard_conditions=fsdp_shard_conditions,
cpu_offload=fsdp_cpu_offload,
reshard_after_forward=reshard_after_forward,
)

with training.set_default_dtype(self._dtype), self._device:
for m in model.modules():
@@ -651,8 +688,6 @@ def _setup_data(
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
iterable datasets and streaming datasets are not supported.
"""
world_size, rank = utils.get_world_size_and_rank()

if isinstance(cfg_dataset, ListConfig):
datasets = [
config.instantiate(single_cfg_dataset, self._tokenizer)
@@ -670,7 +705,7 @@
collate_fn = _get_component_from_path(collate_fn)

sampler = DistributedSampler(
ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
ds, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle, seed=0
)
dataloader = DataLoader(
dataset=ds,
@@ -700,8 +735,6 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

world_size, rank = utils.get_world_size_and_rank()

# zero out the gradients before starting training
if not self._optimizer_in_bwd:
self._optimizer.zero_grad()
@@ -721,7 +754,7 @@
# in case shuffle is True
self._sampler.set_epoch(curr_epoch)

pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0))
pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
for idx, batch in enumerate(self._dataloader):
if (
self.max_steps_per_epoch is not None
@@ -739,7 +772,6 @@
and self._device.type == "cuda"
):
torch.cuda.memory._record_memory_history()

utils.batch_to_device(batch, self._device)

# Calculate the number of unmasked tokens in the current batch
@@ -782,7 +814,7 @@ def train(self) -> None:
torch.distributed.all_reduce(running_loss)

# We multiply by world_size to undo FSDP2 gradient normalization.
current_loss = current_loss * (world_size / num_tokens)
current_loss = current_loss * (self.world_size / num_tokens)

current_loss.backward()

@@ -795,12 +827,15 @@
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
# We multiply by world_size to undo FSDP2 gradient normalization.
training.scale_grads(self._model, world_size / num_tokens)
training.scale_grads(self._model, self.world_size / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
max_norm=float(self._clip_grad_norm),
).full_tensor()
)
# If sharded, collect the DTensor here
if isinstance(grad_norm, DTensor):
grad_norm = grad_norm.full_tensor()
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

@@ -833,7 +868,7 @@ def train(self) -> None:
),
),
"tokens_per_second_per_gpu": num_tokens
/ (time_per_step * world_size),
/ (time_per_step * self.world_size),
}
if self._log_peak_memory_stats:
log_dict.update(
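Taken together, the recipe changes compose TP and FSDP as an ordered pipeline: build the `(dp, tp)` mesh, rescale the attention modules for the local TP rank, apply the TP plan, shard only the data-parallel dimension with FSDP, and hand the `dp` rank/size to the sampler. A condensed, hedged sketch of that ordering (the helper name and defaults are illustrative; compile, activation checkpointing, checkpoint loading, and error handling are omitted):

```python
from functools import partial

from torch import nn
from torch.distributed import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
from torchtune import training


def setup_2d_parallel(
    model: nn.Module,
    parallelize_plan,
    world_size: int,
    tensor_parallel_dim: int,
    device_type: str = "cuda",
) -> nn.Module:
    """Sketch of the TP + FSDP ordering added to full_finetune_distributed.py."""
    data_parallel_dim = world_size // tensor_parallel_dim
    device_mesh = init_device_mesh(
        device_type,
        mesh_shape=(data_parallel_dim, tensor_parallel_dim),
        mesh_dim_names=("dp", "tp"),
    )

    if tensor_parallel_dim > 1:
        # Rescale num_heads / num_kv_heads / embed_dim to per-TP-rank values,
        # then shard weights inside each TP group according to the plan.
        model = training.prepare_mha_for_tp(model, device_mesh["tp"])
        parallelize_module(model, device_mesh["tp"], parallelize_plan=parallelize_plan)

    if data_parallel_dim > 1:
        # FSDP now shards only along the "dp" sub-mesh.
        training.shard_model(
            model=model,
            shard_conditions=[partial(training.get_shard_conditions, names_to_match=None)],
            cpu_offload=False,
            reshard_after_forward=True,
            dp_mesh=device_mesh["dp"],
        )

    # The DistributedSampler is then built from the dp sub-mesh coordinates:
    #   num_replicas=device_mesh["dp"].size(), rank=device_mesh["dp"].get_local_rank()
    return model
```

The loss bookkeeping in `train()` follows the same split: the `world_size / num_tokens` scaling still uses the full world size to undo FSDP2's gradient normalization, while `grad_norm` is gathered with `.full_tensor()` only when gradient clipping returns a `DTensor`.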
69 changes: 69 additions & 0 deletions tests/recipes/test_full_finetune_distributed.py
@@ -129,6 +129,75 @@ def test_loss(
loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
)

@pytest.mark.integration_test
@pytest.mark.parametrize(
"config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd, tensor_parallel_dim",
[
("llama3/8B_full", "llama3", "tune", 4, 1, True, 2),
("llama3/8B_full", "llama3", "tune", 4, 1, True, 4),
],
)
@gpu_test(gpu_count=4)
def test_loss_2d_parallel(
self,
micro_batch_size,
gradient_accumulation_steps,
config,
model_type,
ckpt_type,
optim_in_bwd,
tensor_parallel_dim,
tmpdir,
monkeypatch,
):
ckpt_component = CKPT_COMPONENT_MAP[ckpt_type]
ckpt = model_type + "_" + ckpt_type
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
tokenizer_path = Path(TOKENIZER_PATHS[model_type])
ckpt_dir = ckpt_path.parent
log_file = gen_log_file_name(tmpdir)
parallelize_plan = "torchtune.models.llama3.base_llama_tp_plan"

# Config file needed for model conversion.
write_hf_ckpt_config(ckpt_dir)

cmd = f"""
tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
--config {config} \
batch_size={micro_batch_size} \
gradient_accumulation_steps={gradient_accumulation_steps} \
output_dir={tmpdir} \
checkpointer._component_={ckpt_component} \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
tokenizer.prompt_template=null \
tensor_parallel_dim={tensor_parallel_dim} \
parallelize_plan._component_={parallelize_plan} \
metric_logger.filename={log_file} \
""".split()
model_config = MODEL_TEST_CONFIGS[model_type]
cmd = cmd + self._get_test_config_overrides() + model_config
# "optimizer_in_bwd=True" would free gradient info before clip_grad, causing
# wrong grad_norm, so we only test one of them each time. But loss values
# should be the same.
if not optim_in_bwd:
cmd.append("clip_grad_norm=100")
# Test that gradient clipping works with CPU offload
cmd.append("fsdp_cpu_offload=True")
else:
cmd.append("optimizer_in_bwd=True")

monkeypatch.setattr(sys, "argv", cmd)
runpy.run_path(TUNE_PATH, run_name="__main__")
loss_values = get_loss_values_from_metric_logger(log_file)
expected_loss_values = self._fetch_expected_loss_values_multi_rank(model_type)
torch.testing.assert_close(
loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
)

@pytest.mark.integration_test
@pytest.mark.parametrize(
"config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd",
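The two parametrizations cover both non-trivial mesh shapes available on 4 GPUs, so the new test exercises genuine 2D parallelism (TP inside FSDP replicas) as well as the pure-TP edge case where the FSDP branch is skipped; a quick sketch of the arithmetic:

```python
# Mesh shapes exercised by the 4-GPU test above (world_size == 4).
world_size = 4
for tensor_parallel_dim in (2, 4):
    data_parallel_dim = world_size // tensor_parallel_dim
    print(f"tp={tensor_parallel_dim} -> (dp={data_parallel_dim}, tp={tensor_parallel_dim})")
# tp=2 -> (dp=2, tp=2): two FSDP replicas, each split across two TP ranks
# tp=4 -> (dp=1, tp=4): pure tensor parallelism; FSDP sharding is skipped
```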
19 changes: 12 additions & 7 deletions torchtune/training/_distributed.py
@@ -31,8 +31,7 @@
from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4
from torchtune.modules import TransformerDecoder
from torchtune.modules.attention import MultiHeadAttention
from torchtune.modules.model_fusion import DeepFusionModel

from torchtune.modules.model_fusion import DeepFusionModel, EarlyFusionModel
from torchtune.modules.peft import get_adapter_state_dict
from torchtune.utils import get_device, get_logger
from torchtune.utils._logging import deprecated
@@ -523,6 +522,7 @@ def shard_model(
*,
cpu_offload: bool,
reshard_after_forward: bool = True,
dp_mesh: Optional[DeviceMesh] = None,
) -> None:
"""
Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.
@@ -541,11 +541,13 @@
reshard_after_forward (bool): Whether to reshard parameters and buffers after
the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy
from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.
dp_mesh (Optional[DeviceMesh]): Device mesh to use for FSDP sharding when multiple parallelism
strategies are combined. Defaults to None.
Raises:
ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered.
"""
fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
fsdp_kwargs = {"reshard_after_forward": reshard_after_forward, "mesh": dp_mesh}
if cpu_offload:
fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

@@ -599,11 +601,11 @@ def prepare_mha_for_tp(
>>> # num_kv_heads = 16 (32/2)
>>> # embed_dim = 2048 (4096/2)
"""
# Consider the case of Deep Fusion models
if isinstance(model, DeepFusionModel):
model = model.decoder
# Handle fusion models by extracting decoder
is_fusion_model = isinstance(model, (DeepFusionModel, EarlyFusionModel))
decoder = model.decoder if is_fusion_model else model
tp_size = tp_mesh.size()
for m in list(model.modules()):
for m in list(decoder.modules()):
if isinstance(m, MultiHeadAttention):
# Adjust attention module to use the local number of heads
if m.num_heads % tp_size != 0:
@@ -624,4 +626,7 @@
m.num_heads = m.num_heads // tp_size
m.num_kv_heads = m.num_kv_heads // tp_size
m.embed_dim = m.embed_dim // tp_size

if is_fusion_model:
model.decoder = decoder
return model
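For fusion models, the rescaling now targets the decoder (both `DeepFusionModel` and `EarlyFusionModel` are handled), and the decoder is reattached before the model is returned. A minimal hedged sketch of just that selection logic, extracted for illustration:

```python
from torch import nn
from torchtune.modules.model_fusion import DeepFusionModel, EarlyFusionModel


def tp_target(model: nn.Module) -> nn.Module:
    """Return the submodule whose MultiHeadAttention dims get divided by the TP size."""
    if isinstance(model, (DeepFusionModel, EarlyFusionModel)):
        return model.decoder
    return model
```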
