
Commit

Hopefully fixing dataloader for now
jakep-allenai committed Oct 7, 2024
1 parent 5d35461 commit ebd40f9
Showing 3 changed files with 57 additions and 24 deletions.
34 changes: 10 additions & 24 deletions pdelfin/train/train.py
@@ -47,6 +47,7 @@
log_trainable_parameters,
packing_collator,
setup_environment,
make_dataset
)


@@ -113,6 +114,15 @@ def run_train(config: TrainConfig):
run_name = RunName.get(config)

setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)
accelerator = accelerate.Accelerator()

# Build and download the dataset on process 0
if accelerator.is_main_process:
make_dataset(config)

accelerator.wait_for_everyone()

train_dataset, valid_dataset = make_dataset(config)

model = Qwen2VLForConditionalGeneration.from_pretrained(
config.model.name_or_path, torch_dtype=torch.bfloat16,
@@ -132,30 +142,6 @@
model = get_peft_model(model=model, peft_config=peft_config)
log_trainable_parameters(model=model, logger=logger)

random.seed(config.train_data.seed)

# Training sets get all concatenated and shuffled
train_dataset = (
concatenate_datasets(
[
build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
for source in config.train_data.sources
]
)
.filter(partial(filter_by_max_seq_len, processor=processor))
.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
)

# Validation sets get put into a datasetdict so each can report a loss separately
valid_dataset = DatasetDict(
**{
source.name: build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
.filter(partial(filter_by_max_seq_len, processor=processor))
.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
for source in config.valid_data.sources
}
)

save_path = join_path("", config.save.path, run_name.run)

save_config(config, join_path("", save_path, "config.yaml")) # pyright: ignore
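The new startup sequence above has process 0 build and download the dataset before the other ranks touch it, and relies on make_dataset hitting a shared cache on the second call. A minimal sketch of the same idiom using accelerate's main_process_first() context manager (assuming make_dataset caches its work, for example via the Hugging Face datasets cache):

    import accelerate

    from pdelfin.train.utils import make_dataset


    def load_datasets(config):
        accelerator = accelerate.Accelerator()
        # The main process runs the body first; the other processes enter the
        # block only after it exits, so they find a warm cache instead of all
        # downloading and building the dataset at once.
        with accelerator.main_process_first():
            train_dataset, valid_dataset = make_dataset(config)
        return train_dataset, valid_dataset

This is equivalent to the explicit is_main_process / wait_for_everyone sequence in the diff above, just expressed as a single context manager.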
32 changes: 32 additions & 0 deletions pdelfin/train/utils.py
@@ -1,6 +1,7 @@
import json
import multiprocessing
import os
import random
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
@@ -13,6 +14,7 @@
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import PrecisionType
from datasets import Dataset, concatenate_datasets, DatasetDict

from .core.cli import to_native_types
from .core.config import AwsConfig, TrainConfig, WandbConfig
@@ -23,6 +25,9 @@

T = TypeVar("T")

from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, filter_by_max_seq_len


def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
pt = PrecisionType(accelerator.mixed_precision)
@@ -34,6 +39,33 @@ def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
return torch.float8_e4m3fn
return torch.float32

def make_dataset(config: TrainConfig) -> tuple[Dataset, Dataset]:
random.seed(config.train_data.seed)

# Training sets get all concatenated and shuffled
train_dataset = (
concatenate_datasets(
[
build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
for source in config.train_data.sources
]
)
.filter(partial(filter_by_max_seq_len, processor=processor))
.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
)

# Validation sets get put into a datasetdict so each can report a loss separately
valid_dataset = DatasetDict(
**{
source.name: build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
.filter(partial(filter_by_max_seq_len, processor=processor))
.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
for source in config.valid_data.sources
}
)

return train_dataset, valid_dataset


def setup_environment(
aws_config: Optional[AwsConfig] = None, wandb_config: Optional[WandbConfig] = None, **kwargs: str
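As committed, make_dataset reads processor from an enclosing scope that does not exist in utils.py, and partial is not imported there. A self-contained sketch of the same function, with a processor parameter added here for illustration (it is not part of this commit):

    import random
    from functools import partial

    from datasets import Dataset, DatasetDict, concatenate_datasets

    from pdelfin.train.core.config import TrainConfig
    from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
    from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, filter_by_max_seq_len


    def make_dataset(config: TrainConfig, processor) -> tuple[Dataset, DatasetDict]:
        random.seed(config.train_data.seed)

        # Training sources are concatenated into a single dataset.
        train_dataset = (
            concatenate_datasets(
                [
                    build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
                    for source in config.train_data.sources
                ]
            )
            .filter(partial(filter_by_max_seq_len, processor=processor))
            .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
        )

        # Each validation source stays its own split so it can report a separate loss.
        valid_dataset = DatasetDict(
            {
                source.name: build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
                .filter(partial(filter_by_max_seq_len, processor=processor))
                .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
                for source in config.valid_data.sources
            }
        )

        return train_dataset, valid_dataset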
15 changes: 15 additions & 0 deletions tests/test_dataloader.py
@@ -37,6 +37,21 @@ def testCombinedQueryResponse(self):

print(ds[0])

def testLocalDS(self):
ds = build_batch_query_response_vision_dataset(
query_glob_path="/root/openai_batch_data_v5_1_train/*.jsonl",
response_glob_path="/root/openai_batch_data_v5_1_train_done/*.json",
)

print(ds)

ds.to_parquet("/root/trainds_parquet/bigds.parquet")

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
from pdelfin.train.dataprep import filter_by_max_seq_len
ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=1000))

print(ds[0])

def testPlotSequenceLengthHistogram(self):
import plotly.express as px
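The parquet file written by testLocalDS can be reloaded for inspection with the same library (path taken from the test above):

    from datasets import Dataset

    # Reload the exported dataset and spot-check the first row.
    ds = Dataset.from_parquet("/root/trainds_parquet/bigds.parquet")
    print(ds)
    print(ds[0])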
