Refactoring training for readability #296

Open · wants to merge 5 commits into main · Changes from 4 commits
131 changes: 95 additions & 36 deletions functionary/train/train.py
@@ -1,17 +1,21 @@
import json
import math
import os
import pathlib
import sys
from dataclasses import dataclass, field
from typing import List, Optional
from typing import Optional

import numpy as np
import torch
import torch.distributed
import transformers
from aenum import extend_enum
from torch.optim.lr_scheduler import LambdaLR
from training_utils import print_rank0
from transformers import Trainer

from functionary.prompt_template import get_prompt_template_by_version
from functionary.train import training_utils
from functionary.train.custom_datasets import read_dataset

extend_enum(
transformers.trainer_utils.SchedulerType,
@@ -47,17 +51,8 @@ def lr_lambda(current_step):
get_scheduler
)

from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoTokenizer, Trainer

sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
from typing import Union

from functionary.prompt_template import PromptTemplate, get_prompt_template_by_version
from functionary.train.custom_datasets import read_dataset
from functionary.train import training_utils
from training_utils import print_rank0

LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))

@@ -139,6 +134,23 @@ def trainer_save_model_safe(trainer: transformers.Trainer):
trainer.save_model()


"""
Below is the updated train() function from LEVENT OZBEK.
Most of the changes are identical to those in train_lora.py. I simply applied the changes to the utility code in training_utils.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Authorship tracking is a responsibility of git. We should remove all authorship info from the code.

I commented out the original train() function

- training_utils.tokenize_and_cache() is used for both training and evaluation datasets to avoid repetition.
- dynamic_batch_size() function auto adjusts batch sizes based on token counts. I did not implement this in train_lora.py since loras are trained on a smaller data so I felt that it wasn't too necessary there.
- DataLoaders are constructed using BatchSampler to dynamically adjust the batch size per epoch.
- distributed DataLoader is used if local_rank != -1.
- updated to use the optimized preprocess_logits_for_metrics dynamically compute_metrics from training_utils.py.

Advantages of These Changes:
- handles datasets with varying sequence lengths dynamically
- supports both single-GPU and distributed setups.
"""


def train():
Comment on lines +149 to +163
Collaborator
These updates would be better placed in the PR description, not in the code.

"""Training loop"""

@@ -186,10 +198,9 @@ def train():
torch_dtype=compute_dtype,
config=config,
cache_dir=training_args.cache_dir,
attn_implementation="flash_attention_2", # use_flash_attention_2 is replaced by this from version: 4.36.0
attn_implementation="flash_attention_2",
)
model.config.use_cache = False
# Activate computing load balancing loss iin MixtralForCausalLM
if hasattr(model.config, "output_router_logits"):
setattr(model.config, "output_router_logits", True)
print_rank0("Activate computing load balancing loss")
@@ -213,7 +224,6 @@ def train():

tokenizer.save_pretrained(training_args.output_dir)

# get id of added tokens to compute the accuracy of predicing the token
id2token = {
tokenizer.encode(token)[-1]: token
for token in prompt_template.get_additional_tokens()
@@ -222,22 +232,29 @@ def train():

assert data_args.train_data_path is not None, "Please provide a training data file."

train_dataset = read_dataset(
# Cache and tokenize training data
raw_train_dataset = read_dataset(
model_args.model_name_or_path, data_args, training_args, tokenizer, "train"
)
train_dataset = training_utils.tokenize_and_cache(
raw_train_dataset, tokenizer, training_args.cache_dir
)

if torch.distributed.get_rank() == 0:
print(f"Training Data Loaded: #{len(train_dataset)}")

if training_args.do_eval:
eval_dataset = read_dataset(
# Cache and tokenize evaluation data
raw_eval_dataset = read_dataset(
model_args.model_name_or_path,
data_args,
training_args,
tokenizer,
"validation",
)

eval_dataset = training_utils.tokenize_and_cache(
raw_eval_dataset, tokenizer, training_args.cache_dir
)
if torch.distributed.get_rank() == 0:
print(f"Eval Data Loaded: #{len(eval_dataset)}")

@@ -248,6 +265,55 @@ def train():
print_rank0("***** HERE ARE SOME EXAMPLES FROM EVALUATION ***")
training_utils.print_some_examples(eval_dataset, tokenizer)

# Dynamic batch size based on max tokens per batch
max_tokens_per_batch = 2048 # You can adjust this as needed
Collaborator
While dynamic batch size might seem like a good idea for dealing with memory issues, it would cause instabilities in training due to the difference in gradient updates per conversation. I would rather prefer stability over memory efficiency: stable updates at a higher GPU cost are preferable to cheaper and faster training.
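A small illustration of this concern (not part of the PR): with the default mean-reduced loss, each conversation's weight in an optimizer step is 1/batch_size, so steps built from different batch sizes weight conversations unevenly.

# Illustration only: hypothetical per-step batch sizes produced by dynamic_batch_size().
batch_sizes = [4, 16, 7]
for step, bs in enumerate(batch_sizes):
    # with mean reduction, each conversation contributes 1/bs of that step's gradient
    print(f"step {step}: each conversation carries weight {1 / bs:.3f}")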

train_batch_sizes = training_utils.dynamic_batch_size(
train_dataset, max_tokens_per_batch, tokenizer
)
print_rank0(f"Dynamic train batch sizes: {train_batch_sizes}")

if training_args.do_eval:
eval_batch_sizes = training_utils.dynamic_batch_size(
eval_dataset, max_tokens_per_batch, tokenizer
)
print_rank0(f"Dynamic eval batch sizes: {eval_batch_sizes}")

# DataLoaders with dynamic batch sizes
train_loader = (
DataLoader(
train_dataset,
batch_sampler=torch.utils.data.BatchSampler(
sampler=torch.utils.data.SequentialSampler(train_dataset),
batch_size=max(train_batch_sizes), # Adjust batch size dynamically
drop_last=False,
),
num_workers=4,
pin_memory=True,
)
if training_args.local_rank == -1
else training_utils.create_distributed_data_loader(
train_dataset, batch_size=max(train_batch_sizes)
)
)

if training_args.do_eval:
eval_loader = (
DataLoader(
eval_dataset,
batch_sampler=torch.utils.data.BatchSampler(
sampler=torch.utils.data.SequentialSampler(eval_dataset),
batch_size=max(eval_batch_sizes), # Adjust batch size dynamically
drop_last=False,
),
num_workers=4,
pin_memory=True,
)
if training_args.local_rank == -1
else training_utils.create_distributed_data_loader(
eval_dataset, batch_size=max(eval_batch_sizes)
)
)
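
training_utils.create_distributed_data_loader() is called above but not shown in this diff. Below is a minimal sketch of what such a helper might look like, assuming a standard DistributedSampler; only the name and call signature are taken from the PR, the body is an assumption.

# Sketch only (not from this PR): assumed implementation of create_distributed_data_loader.
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def create_distributed_data_loader(dataset, batch_size):
    """Build a DataLoader whose sampler shards the dataset across distributed ranks."""
    sampler = DistributedSampler(
        dataset,
        num_replicas=torch.distributed.get_world_size(),
        rank=torch.distributed.get_rank(),
        shuffle=True,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
    )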

def preprocess_logits_for_metrics(logits, labels):
return training_utils.preprocess_logits_for_metrics(
logits, labels, len(tokenizer)
@@ -256,23 +322,17 @@ def preprocess_logits_for_metrics(logits, labels):
def compute_metrics(eval_preds):
return training_utils.compute_metrics(eval_preds, id2token, tokenizer)
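
The wrapped training_utils helpers are likewise not shown in this diff. A plausible sketch of what the "optimized" versions do, under the common pattern of reducing logits to argmax token ids during evaluation to save memory; the signatures follow the wrappers above, the bodies are assumptions.

# Sketch only (not from this PR): assumed behaviour of the training_utils helpers wrapped above.
def preprocess_logits_for_metrics(logits, labels, vocab_size):
    """Reduce full logits to predicted token ids so evaluation never stores [batch, seq, vocab] tensors."""
    if isinstance(logits, tuple):  # some models return (logits, ...) tuples
        logits = logits[0]
    # vocab_size is accepted to match the wrapper above; a plain argmax does not need it
    return logits.argmax(dim=-1)


def compute_metrics(eval_preds, id2token, tokenizer):
    """Token-level accuracy, overall and per added special token (tokenizer kept for parity with the wrapper)."""
    preds = eval_preds.predictions[:, :-1]  # prediction at position i targets the label at i + 1
    refs = eval_preds.label_ids[:, 1:]
    mask = refs != -100  # ignore masked / padding positions
    metrics = {"accuracy": float((preds[mask] == refs[mask]).mean())}
    for token_id, token in id2token.items():
        token_mask = mask & (refs == token_id)
        if token_mask.sum() > 0:
            metrics[f"accuracy_{token}"] = float((preds[token_mask] == refs[token_mask]).mean())
    return metrics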

if training_args.do_eval:
trainer = Trainer(
model=model,
tokenizer=tokenizer,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=compute_metrics,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
else:
trainer = Trainer(
model=model,
tokenizer=tokenizer,
args=training_args,
train_dataset=train_dataset,
)
trainer = Trainer(
model=model,
tokenizer=tokenizer,
args=training_args,
train_dataset=train_loader.dataset,
eval_dataset=eval_loader.dataset if training_args.do_eval else None,
compute_metrics=compute_metrics if training_args.do_eval else None,
preprocess_logits_for_metrics=(
preprocess_logits_for_metrics if training_args.do_eval else None
),
)

if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
@@ -281,7 +341,6 @@ def compute_metrics(eval_preds):

trainer.save_state()

# FSDP requires state_dict_type=FULL_STATE_DICT in order to save the model weights in .bin format
if trainer.is_fsdp_enabled:
trainer_save_model_safe(trainer=trainer)
else: