From 7f0ea8e406245ab8061726c4a6ad57067e7e709b Mon Sep 17 00:00:00 2001
From: szhengac
Date: Wed, 1 Feb 2023 06:47:23 +0000
Subject: [PATCH] fix gpt

---
 examples/gpt/schedule.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/examples/gpt/schedule.py b/examples/gpt/schedule.py
index 3c0e0325..e48071cf 100644
--- a/examples/gpt/schedule.py
+++ b/examples/gpt/schedule.py
@@ -269,6 +269,7 @@ def fwd_post_hook(_module, _input, output):
     # Shard output embedding.
     if head_sch is not None:
         head_sch.shard("weight", axis=0)
+        head_sch.sync(mode="bwd_post", sync_op_or_fn="all_reduce")
 
 
 def shard_qkv(
@@ -363,15 +364,9 @@ def replace_and_shard_mlp(
 
             sub_sch["fc_out"].sync(mode="fwd_post", sync_op_or_fn="all_reduce")
     elif sch.world_size > 1:
-        sch[f"{prefix}.{fc_names[0]}"].sync(
-            mode="fwd_pre", sync_op_or_fn="all_gather", axis=1
-        )
         sch[f"{prefix}.{fc_names[0]}"].shard("weight", axis=0)
         sch[f"{prefix}.{fc_names[0]}"].shard("bias", axis=0)
         sch[f"{prefix}.{fc_names[1]}"].shard("weight", axis=1)
-        sch[f"{prefix}.{fc_names[1]}"].sync(
-            mode="fwd_post", sync_op_or_fn="reduce_scatter", axis=1
-        )
 
         if sequence_parallel:
             sch[f"{prefix}.{fc_names[0]}"].sync(
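
The two schedule primitives touched in the first hunk are part of Slapo's tensor-parallel API: shard("weight", axis=0) splits the output-embedding weight row-wise (along the vocabulary dimension) across ranks, and the added sync(mode="bwd_post", sync_op_or_fn="all_reduce") registers a hook that all-reduces the gradient flowing back through the module. Because each rank holds only a slice of the vocabulary rows, each rank computes only a partial gradient for the hidden states in backward, and the all-reduce restores the full gradient on every rank. Below is a minimal sketch of that pattern in plain PyTorch, assuming torch.distributed is already initialized with one process per shard; it is not Slapo's implementation, and the ShardedLMHead class and its argument names are hypothetical.

import torch
import torch.distributed as dist


class _AllReduceBackward(torch.autograd.Function):
    """Identity in forward; all-reduces the gradient in backward."""

    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        if dist.is_initialized() and dist.get_world_size() > 1:
            dist.all_reduce(grad_output)
        return grad_output


class ShardedLMHead(torch.nn.Module):
    """Output embedding whose weight is sharded along axis 0 (the vocab
    dimension), mirroring head_sch.shard("weight", axis=0) in the patch."""

    def __init__(self, full_head: torch.nn.Linear, world_size: int, rank: int):
        super().__init__()
        # Each rank keeps one contiguous slice of the vocabulary rows.
        shard = full_head.weight.detach().chunk(world_size, dim=0)[rank]
        self.weight = torch.nn.Parameter(shard.clone())

    def forward(self, hidden_states):
        # The identity-forward / all-reduce-backward op plays the role of the
        # added sync(mode="bwd_post", ...): without it, each rank would see
        # only the gradient contribution of its own vocabulary slice.
        hidden_states = _AllReduceBackward.apply(hidden_states)
        return torch.nn.functional.linear(hidden_states, self.weight)

Under these assumptions a dense head would be wrapped once at setup, e.g. head = ShardedLMHead(model.lm_head, dist.get_world_size(), dist.get_rank()); each rank then emits logits for its own vocabulary slice while still receiving the full hidden-state gradient in backward.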