
Commit

rewrite training loop
haohongxiang committed Jan 25, 2024
1 parent 6a381c3 commit dee9d04
Showing 4 changed files with 479 additions and 116 deletions.
13 changes: 4 additions & 9 deletions llm/llama/auto_parallel/run_pretrain_3D_auto.py
@@ -37,11 +37,12 @@
     LinearAnnealingWithWarmupDecay,
     LlamaConfig,
     LlamaForCausalLM3DAuto,
+    LlamaPretrainingCriterion3DAuto,
 )
 from paddlenlp.utils.log import logger
 
 MODEL_CLASSES = {
-    "llama": (LlamaConfig, LlamaForCausalLM3DAuto),
+    "llama": (LlamaConfig, LlamaForCausalLM3DAuto, LlamaPretrainingCriterion3DAuto),
 }
 
 
@@ -487,7 +488,7 @@ def main():
             "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
         )
 
-    config_class, model_class = MODEL_CLASSES[model_args.model_type]
+    config_class, model_class, criterion_class = MODEL_CLASSES[model_args.model_type]
 
     tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
 
@@ -551,8 +552,7 @@
 
     with paddle.LazyGuard():
         model = model_class.from_config(config, dtype=dtype)
-
-    criterion = None
+    criterion = criterion_class(config)
 
     for param in model.parameters():
         assert not param._is_initialized()
@@ -598,11 +598,6 @@ def fn(layer):
         need_data=training_args.should_load_dataset,
     )
 
-    # total_train_batch_size_per_acc_step = (
-    #     training_args.per_device_train_batch_size * training_args.data_parallel_degree
-    # )
-    # total_train_batch_size = total_train_batch_size_per_acc_step * training_args.gradient_accumulation_steps
-
     trainer = PretrainingTrainer(
         model=model,
         criterion=criterion,
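For reference, the pattern this commit introduces can be summarized outside the diff. The following is a minimal, illustrative sketch rather than the script itself: it assumes the classes are importable from paddlenlp.transformers (the opening `from ... import (` line sits above the hunk shown), uses a default LlamaConfig and a literal dtype in place of the script's argument parsing, and omits the dataset and trainer setup.

import paddle
from paddlenlp.transformers import (
    LlamaConfig,
    LlamaForCausalLM3DAuto,
    LlamaPretrainingCriterion3DAuto,
)

# The registry now carries a criterion class alongside the config and model classes.
MODEL_CLASSES = {
    "llama": (LlamaConfig, LlamaForCausalLM3DAuto, LlamaPretrainingCriterion3DAuto),
}

config_class, model_class, criterion_class = MODEL_CLASSES["llama"]
config = config_class()  # default fields; the real script fills these from CLI arguments

with paddle.LazyGuard():
    # Parameters are created lazily and initialized later, as in the diff above.
    model = model_class.from_config(config, dtype="float32")

# Replaces the old `criterion = None`: the loss is now owned by an explicit
# criterion object built from the same config as the model.
criterion = criterion_class(config)

Both objects are then passed to PretrainingTrainer(model=model, criterion=criterion, ...) as in the last hunk above, which keeps the loss computation selectable per model type through MODEL_CLASSES.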
