support shard_dataloader in dynamic semi-auto
haohongxiang committed Jan 24, 2024
1 parent 16bca68 commit 624abd7
Showing 6 changed files with 144 additions and 94 deletions.
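
In short, this commit lets the semi-auto trainer drive training from a sharded distributed dataloader in dynamic mode, instead of only wrapping the loader inside dist.to_static. The sketch below shows the paddle.distributed.shard_dataloader pattern the trainer now relies on; the toy dataset, batch size, and dp=2/pp=1/mp=1 mesh are illustrative, and the script is assumed to run under paddle.distributed.launch on two cards.

    # Illustrative standalone use of dist.shard_dataloader (not code from this
    # diff). Every mesh dim is created even at degree 1, matching the
    # training_args.py change below, and the batch is sharded along "dp".
    import numpy as np
    import paddle.distributed as dist
    from paddle.distributed import fleet
    from paddle.io import DataLoader, Dataset

    class RandomTokens(Dataset):
        def __init__(self, n=64, seq_len=16):
            self.ids = np.random.randint(0, 1000, (n, seq_len)).astype("int64")

        def __getitem__(self, i):
            return {"input_ids": self.ids[i], "labels": self.ids[i]}

        def __len__(self):
            return len(self.ids)

    fleet.auto.create_mesh([("dp", 2), ("pp", 1), ("mp", 1)])
    mesh = fleet.auto.get_mesh()

    loader = DataLoader(RandomTokens(), batch_size=8)
    dist_loader = dist.shard_dataloader(
        dataloader=loader,
        meshes=[mesh.get_mesh_with_dim("pp")[0]],  # one mesh per pipeline stage
        input_keys=["input_ids", "labels"],
        shard_dims="dp",                           # split the batch across "dp"
    )

    for batch in dist_loader:
        input_ids, labels = batch["input_ids"], batch["labels"]
        # each rank now sees only its own shard of the global batch
        break
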
11 changes: 4 additions & 7 deletions llm/llama/auto_parallel/run_pretrain_3D_auto.py
@@ -366,13 +366,10 @@ class PretrainingTrainer(SemiAutoTrainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def _wrap_dist_loader(self, train_dataloader):
return dist.shard_dataloader(
dataloader=train_dataloader,
meshes=self._get_meshes_for_loader(),
input_keys=["input_ids", "labels"],
shard_dims="dp",
)
def get_train_dataloader(self):
dist_loader = super().get_train_dataloader()
dist_loader._input_keys = ["input_ids", "labels"]
return dist_loader


def print_config(args, key=""):
40 changes: 14 additions & 26 deletions paddlenlp/trainer/auto_trainer.py
@@ -40,6 +40,8 @@ def loss_func(loss, outputs):
super().__init__(*args, **kwargs)
assert self.args.use_auto_parallel

self.global_mesh = fleet.auto.get_mesh()

def _nested_gather(self, tensors):
"""
Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
@@ -54,31 +56,16 @@ def _wrap_model(self, model, training=True):

def _get_meshes_for_loader(self):
def _get_mesh(pp_idx=0):
mesh = fleet.auto.get_mesh()
if "pp" in mesh.dim_names:
mesh = mesh.get_mesh_with_dim("pp")[pp_idx]
return mesh
return self.global_mesh.get_mesh_with_dim("pp")[pp_idx]

meshes = []
for pp_idx in range(self.args.pipeline_parallel_degree):
meshes.append(_get_mesh(pp_idx))
return meshes

def _wrap_dist_loader(self, train_dataloader):
return dist.shard_dataloader(
dataloader=train_dataloader,
meshes=self._get_meshes_for_loader(),
shard_dims="dp",
)

def _wrap_for_static(self, model, train_dataloader):
# TODO: convert fleet.auto.Strategy to dist.Strategy
# TODO: fix bugs in paddle/distributed/auto_parallel/api.py#L981 about sample_split of engine._prepare_data_spec

dist_loader = self._wrap_dist_loader(train_dataloader)

model = dist.to_static(model, dist_loader, self.criterion, self.optimizer, strategy=self.args.strategy)
return model, dist_loader
model = dist.to_static(model, train_dataloader, self.criterion, self.optimizer, strategy=self.args.strategy)
return model

def _wrap_for_amp_training(self):
pass
@@ -125,14 +112,15 @@ def _get_train_sampler(self) -> Optional[paddle.io.Sampler]:
drop_last=self.args.dataloader_drop_last,
)

# return DistributedBatchSampler(
# self.train_dataset,
# batch_size=self.args.per_device_train_batch_size,
# shuffle=True,
# num_replicas=self.args.dataset_world_size,
# rank=self.args.dataset_rank,
# drop_last=self.args.dataloader_drop_last,
# )
def get_train_dataloader(self):
train_dataloader = super().get_train_dataloader()
dist_loader = dist.shard_dataloader(
dataloader=train_dataloader,
meshes=self._get_meshes_for_loader(),
shard_dims="dp",
)

return dist_loader

def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
model.train()
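Read together, the auto_trainer.py hunks split the work in two: get_train_dataloader() wraps the plain loader with dist.shard_dataloader once, and _wrap_for_static() receives that already-sharded loader and returns only the static model. A condensed, assumed simplification of that flow (pp_degree, criterion, optimizer and strategy stand in for the corresponding trainer/args attributes):

    # Condensed sketch of the new division of labor (not the verbatim class).
    import paddle.distributed as dist
    from paddle.distributed import fleet

    def build_train_dataloader(plain_loader, pp_degree):
        # get_train_dataloader(): build one sub-mesh per pipeline stage from
        # the global mesh and shard the loader along "dp" in dynamic mode.
        global_mesh = fleet.auto.get_mesh()
        meshes = [global_mesh.get_mesh_with_dim("pp")[i] for i in range(pp_degree)]
        return dist.shard_dataloader(
            dataloader=plain_loader, meshes=meshes, shard_dims="dp"
        )

    def wrap_for_static(model, dist_loader, criterion, optimizer, strategy):
        # _wrap_for_static(): the loader is already sharded, so to_static only
        # has to produce the static model.
        return dist.to_static(model, dist_loader, criterion, optimizer, strategy=strategy)

Because the sharded loader is created in get_train_dataloader() rather than inside _wrap_for_static(), the same loader serves both the dynamic semi-auto path and the run_static_semi_auto path in trainer.py below.
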
7 changes: 6 additions & 1 deletion paddlenlp/trainer/trainer.py
@@ -721,7 +721,7 @@ def train(
self._load_optimizer_and_scheduler(resume_from_checkpoint)

if self.args.use_auto_parallel and self.args.run_static_semi_auto:
model, train_dataloader = self._wrap_for_static(model, train_dataloader)
model = self._wrap_for_static(model, train_dataloader)

self.model = model

@@ -2275,6 +2275,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
self.model.save_pretrained(
output_dir,
variant=self.args.weight_name_suffix,
save_function=self._save_ckpt_func,
merge_tensor_parallel=merge_tensor_parallel,
is_main_process=self.args.should_save,
max_shard_size="1024GB",
@@ -2293,6 +2294,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
config_to_save=config_to_save,
merge_tensor_parallel=merge_tensor_parallel,
variant=weight_name_suffix,
save_function=self._save_ckpt_func,
is_main_process=self.args.should_save,
max_shard_size="1024GB",
)
@@ -2301,6 +2303,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
output_dir,
merge_tensor_parallel=merge_tensor_parallel,
variant=self.args.weight_name_suffix,
save_function=self._save_ckpt_func,
is_main_process=self.args.should_save,
max_shard_size="1024GB",
)
@@ -2327,6 +2330,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
config_to_save=config_to_save,
merge_tensor_parallel=merge_tensor_parallel,
variant=weight_name_suffix,
save_function=self._save_ckpt_func,
is_main_process=self.args.should_save,
max_shard_size="1024GB",
)
@@ -2335,6 +2339,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
output_dir,
merge_tensor_parallel=merge_tensor_parallel,
variant=self.args.weight_name_suffix,
save_function=self._save_ckpt_func,
is_main_process=self.args.should_save,
max_shard_size="1024GB",
)
16 changes: 11 additions & 5 deletions paddlenlp/trainer/training_args.py
@@ -27,6 +27,7 @@
from typing import Any, Dict, List, Optional

import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

from ..utils.log import logger
@@ -1278,10 +1279,7 @@ def is_segment_parallel_supported():
self.strategy = strategy
order = ["dp", "pp", "mp"]
degree = [self.data_parallel_degree, self.pipeline_parallel_degree, self.tensor_parallel_degree]
mesh_dims = list(filter(lambda x: x[1] > 1, list(zip(order, degree))))
if not mesh_dims:
mesh_dims = [("dp", 1)]

mesh_dims = list(zip(order, degree))
fleet.auto.create_mesh(mesh_dims)
else:
world_size = paddle.distributed.get_world_size()
@@ -1400,7 +1398,7 @@ def data_parallel_rank(self):
return dp_group.rank
elif self.use_auto_parallel:
mesh = fleet.auto.get_mesh()
return mesh.get_dim_size("dp")
return mesh.get_rank_by_dim_and_process_id("dp", dist.get_rank())
else:
return paddle.distributed.get_rank()

@@ -1437,6 +1435,9 @@ def tensor_parallel_rank(self):
hcg = fleet.get_hybrid_communicate_group()
tp_group = hcg.get_model_parallel_group()
return max(tp_group.rank, 0)
elif self.use_auto_parallel:
mesh = fleet.auto.get_mesh()
return mesh.get_rank_by_dim_and_process_id("mp", dist.get_rank())
else:
return 0

@@ -1446,6 +1447,9 @@ def pipeline_parallel_rank(self):
hcg = fleet.get_hybrid_communicate_group()
rank = hcg.get_stage_id()
return max(rank, 0)
elif self.use_auto_parallel:
mesh = fleet.auto.get_mesh()
return mesh.get_rank_by_dim_and_process_id("pp", dist.get_rank())
else:
return 0

@@ -1588,6 +1592,8 @@ def should_save_model_state(self):
else:
if self.should_save_sharding_stage1_model:
return True
elif self.use_auto_parallel:
return True
elif self.use_hybrid_parallel:
# save on dataset rank 0
return self.sharding_parallel_rank == 0 and self.data_parallel_rank == 0
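The training_args.py hunks keep every mesh axis even when its degree is 1 and read dp/mp/pp ranks off the global mesh; previously data_parallel_rank returned mesh.get_dim_size("dp"), i.e. the dp degree rather than the rank. A small sketch of the new bookkeeping, with illustrative degrees for a two-card data-parallel run:

    # Sketch of the mesh/rank bookkeeping (dp=2, pp=1, mp=1 are illustrative).
    # Keeping degree-1 dims means "dp", "pp" and "mp" always exist on the mesh,
    # so rank lookups need no dim_names checks.
    import paddle.distributed as dist
    from paddle.distributed import fleet

    order = ["dp", "pp", "mp"]
    degree = [2, 1, 1]
    fleet.auto.create_mesh(list(zip(order, degree)))

    mesh = fleet.auto.get_mesh()
    cur = dist.get_rank()
    dp_rank = mesh.get_rank_by_dim_and_process_id("dp", cur)
    mp_rank = mesh.get_rank_by_dim_and_process_id("mp", cur)
    pp_rank = mesh.get_rank_by_dim_and_process_id("pp", cur)
    # e.g. on the second of two cards: dp_rank == 1, mp_rank == 0, pp_rank == 0

Always creating the "pp" axis is also what lets get_mesh_with_dim("pp") and these lookups work unconditionally, where the old code had to guard with an "if pp in mesh.dim_names" check.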
