-
Notifications
You must be signed in to change notification settings - Fork 116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactoring training for readability #296
base: main
Are you sure you want to change the base?
Changes from 4 commits
6681f0e
cdf6a06
443c232
25cfec3
16626bf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,21 @@ | ||
import json | ||
import math | ||
import os | ||
import pathlib | ||
import sys | ||
from dataclasses import dataclass, field | ||
from typing import List, Optional | ||
from typing import Optional | ||
|
||
import numpy as np | ||
import torch | ||
import torch.distributed | ||
import transformers | ||
from aenum import extend_enum | ||
from torch.optim.lr_scheduler import LambdaLR | ||
from training_utils import print_rank0 | ||
from transformers import Trainer | ||
|
||
from functionary.prompt_template import get_prompt_template_by_version | ||
from functionary.train import training_utils | ||
from functionary.train.custom_datasets import read_dataset | ||
|
||
extend_enum( | ||
transformers.trainer_utils.SchedulerType, | ||
|
@@ -47,17 +51,8 @@ def lr_lambda(current_step): | |
get_scheduler | ||
) | ||
|
||
from torch.nn import CrossEntropyLoss | ||
from torch.utils.data import DataLoader | ||
from transformers import AutoConfig, AutoTokenizer, Trainer | ||
|
||
sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) | ||
from typing import Union | ||
|
||
from functionary.prompt_template import PromptTemplate, get_prompt_template_by_version | ||
from functionary.train.custom_datasets import read_dataset | ||
from functionary.train import training_utils | ||
from training_utils import print_rank0 | ||
|
||
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) | ||
|
||
|
@@ -139,6 +134,23 @@ def trainer_save_model_safe(trainer: transformers.Trainer): | |
trainer.save_model() | ||
|
||
|
||
""" | ||
Below is the updated train() function from LEVENT OZBEK. | ||
Most of the changes are identical to those in train_lora.py. I simply applied the changes to the utility code in training_utils.py | ||
I commented out the original train() function | ||
|
||
- training_utils.tokenize_and_cache() is used for both training and evaluation datasets to avoid repetition. | ||
- dynamic_batch_size() function auto adjusts batch sizes based on token counts. I did not implement this in train_lora.py since loras are trained on a smaller data so I felt that it wasn't too necessary there. | ||
- DataLoaders are constructed using BatchSampler to dynamically adjust the batch size per epoch. | ||
- distributed DataLoader is used if local_rank != -1. | ||
- updated to use the optimized preprocess_logits_for_metrics dynamically compute_metrics from training_utils.py. | ||
|
||
Advantages of These Changes: | ||
- handles datasets with varying sequence lengths dynamically | ||
- supports both single-GPU and distributed setups. | ||
""" | ||
|
||
|
||
def train(): | ||
Comment on lines
+149
to
+163
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updates would rather preferred to go into PR description, not the code |
||
"""Training loop""" | ||
|
||
|
@@ -186,10 +198,9 @@ def train(): | |
torch_dtype=compute_dtype, | ||
config=config, | ||
cache_dir=training_args.cache_dir, | ||
attn_implementation="flash_attention_2", # use_flash_attention_2 is replaced by this from version: 4.36.0 | ||
attn_implementation="flash_attention_2", | ||
) | ||
model.config.use_cache = False | ||
# Activate computing load balancing loss iin MixtralForCausalLM | ||
if hasattr(model.config, "output_router_logits"): | ||
setattr(model.config, "output_router_logits", True) | ||
print_rank0("Activate computing load balancing loss") | ||
|
@@ -213,7 +224,6 @@ def train(): | |
|
||
tokenizer.save_pretrained(training_args.output_dir) | ||
|
||
# get id of added tokens to compute the accuracy of predicing the token | ||
id2token = { | ||
tokenizer.encode(token)[-1]: token | ||
for token in prompt_template.get_additional_tokens() | ||
|
@@ -222,22 +232,29 @@ def train(): | |
|
||
assert data_args.train_data_path is not None, "Please provide a training data file." | ||
|
||
train_dataset = read_dataset( | ||
# Cache and tokenize training data | ||
raw_train_dataset = read_dataset( | ||
model_args.model_name_or_path, data_args, training_args, tokenizer, "train" | ||
) | ||
train_dataset = training_utils.tokenize_and_cache( | ||
raw_train_dataset, tokenizer, training_args.cache_dir | ||
) | ||
|
||
if torch.distributed.get_rank() == 0: | ||
print(f"Training Data Loaded: #{len(train_dataset)}") | ||
|
||
if training_args.do_eval: | ||
eval_dataset = read_dataset( | ||
# Cache and tokenize evaluation data | ||
raw_eval_dataset = read_dataset( | ||
model_args.model_name_or_path, | ||
data_args, | ||
training_args, | ||
tokenizer, | ||
"validation", | ||
) | ||
|
||
eval_dataset = training_utils.tokenize_and_cache( | ||
raw_eval_dataset, tokenizer, training_args.cache_dir | ||
) | ||
if torch.distributed.get_rank() == 0: | ||
print(f"Eval Data Loaded: #{len(eval_dataset)}") | ||
|
||
|
@@ -248,6 +265,55 @@ def train(): | |
print_rank0("***** HERE ARE SOME EXAMPLES FROM EVALUATION ***") | ||
training_utils.print_some_examples(eval_dataset, tokenizer) | ||
|
||
# Dynamic batch size based on max tokens per batch | ||
max_tokens_per_batch = 2048 # You can adjust this as needed | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While dynamic batch size might seem a good idea for dealing with memory issues, it would cause instabilities in training. Due to the difference in gradient updates per conversation. I would rather prefer stability over memory efficiency. Stable updates with higher cost of GPUs should be preferred over a cheaper&faster training |
||
train_batch_sizes = training_utils.dynamic_batch_size( | ||
train_dataset, max_tokens_per_batch, tokenizer | ||
) | ||
print_rank0(f"Dynamic train batch sizes: {train_batch_sizes}") | ||
|
||
if training_args.do_eval: | ||
eval_batch_sizes = training_utils.dynamic_batch_size( | ||
eval_dataset, max_tokens_per_batch, tokenizer | ||
) | ||
print_rank0(f"Dynamic eval batch sizes: {eval_batch_sizes}") | ||
|
||
# DataLoaders with dynamic batch sizes | ||
train_loader = ( | ||
DataLoader( | ||
train_dataset, | ||
batch_sampler=torch.utils.data.BatchSampler( | ||
sampler=torch.utils.data.SequentialSampler(train_dataset), | ||
batch_size=max(train_batch_sizes), # Adjust batch size dynamically | ||
drop_last=False, | ||
), | ||
num_workers=4, | ||
pin_memory=True, | ||
) | ||
if training_args.local_rank == -1 | ||
else training_utils.create_distributed_data_loader( | ||
train_dataset, batch_size=max(train_batch_sizes) | ||
) | ||
) | ||
|
||
if training_args.do_eval: | ||
eval_loader = ( | ||
DataLoader( | ||
eval_dataset, | ||
batch_sampler=torch.utils.data.BatchSampler( | ||
sampler=torch.utils.data.SequentialSampler(eval_dataset), | ||
batch_size=max(eval_batch_sizes), # Adjust batch size dynamically | ||
drop_last=False, | ||
), | ||
num_workers=4, | ||
pin_memory=True, | ||
) | ||
if training_args.local_rank == -1 | ||
else training_utils.create_distributed_data_loader( | ||
eval_dataset, batch_size=max(eval_batch_sizes) | ||
) | ||
) | ||
|
||
def preprocess_logits_for_metrics(logits, labels): | ||
return training_utils.preprocess_logits_for_metrics( | ||
logits, labels, len(tokenizer) | ||
|
@@ -256,23 +322,17 @@ def preprocess_logits_for_metrics(logits, labels): | |
def compute_metrics(eval_preds): | ||
return training_utils.compute_metrics(eval_preds, id2token, tokenizer) | ||
|
||
if training_args.do_eval: | ||
trainer = Trainer( | ||
model=model, | ||
tokenizer=tokenizer, | ||
args=training_args, | ||
train_dataset=train_dataset, | ||
eval_dataset=eval_dataset, | ||
compute_metrics=compute_metrics, | ||
preprocess_logits_for_metrics=preprocess_logits_for_metrics, | ||
) | ||
else: | ||
trainer = Trainer( | ||
model=model, | ||
tokenizer=tokenizer, | ||
args=training_args, | ||
train_dataset=train_dataset, | ||
) | ||
trainer = Trainer( | ||
model=model, | ||
tokenizer=tokenizer, | ||
args=training_args, | ||
train_dataset=train_loader.dataset, | ||
eval_dataset=eval_loader.dataset if training_args.do_eval else None, | ||
compute_metrics=compute_metrics if training_args.do_eval else None, | ||
preprocess_logits_for_metrics=( | ||
preprocess_logits_for_metrics if training_args.do_eval else None | ||
), | ||
) | ||
|
||
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): | ||
trainer.train(resume_from_checkpoint=True) | ||
|
@@ -281,7 +341,6 @@ def compute_metrics(eval_preds): | |
|
||
trainer.save_state() | ||
|
||
# FSDP requires state_dict_type=FULL_STATE_DICT in order to save the model weights in .bin format | ||
if trainer.is_fsdp_enabled: | ||
trainer_save_model_safe(trainer=trainer) | ||
else: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Authorship tracking is a responsibility of git. We should remove all authorship info from the code.