First part of new dataloader
jakep-allenai committed Oct 16, 2024
1 parent 202d81c commit 446773d
Showing 4 changed files with 14 additions and 19 deletions.
pdelfin/train/config/qwen2vl-7b-lora.yaml (6 additions & 2 deletions)

@@ -12,17 +12,21 @@ generate:

 train_data:
   seed: 1337
+  cache_location: /root/pdelfin_cache
   sources:
     - name: openai_batch_data_v5_1_train
-      response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
+      #response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
+      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
       target_longest_image_dim: 1024
       target_anchor_text_len: 6000
     - name: openai_batch_data_v5_1_iabooks_train
-      response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
+      #response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
+      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
       target_longest_image_dim: 1024
       target_anchor_text_len: 6000

 valid_data:
+  cache_location: /root/pdelfin_cache
   metric_for_best_model: openai_batch_data_v5_1_eval_loss
   sources:
     # These tend to be small, so you can load from s3 it's no big deal
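Both training sources now point at s3 globs rather than local directories, with a cache_location added alongside them (see the config.py change below). Purely as an illustration, and not code from this commit, expanding a glob such as s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json into concrete files could look like the sketch below, assuming boto3 is available; the real logic lives in pdelfin.train.dataloader.list_dataset_files and may differ.

# Hypothetical helper, not part of this commit: expand an s3 glob into a
# list of s3:// object paths. pdelfin's actual list_dataset_files may differ.
import fnmatch
import posixpath
from urllib.parse import urlparse

import boto3


def expand_s3_glob(glob_path: str) -> list[str]:
    parsed = urlparse(glob_path)           # s3://bucket/prefix/*.json
    bucket = parsed.netloc
    pattern = parsed.path.lstrip("/")      # prefix/*.json
    prefix = posixpath.dirname(pattern)    # prefix

    s3 = boto3.client("s3")
    matches = []
    for page in s3.get_paginator("list_objects_v2").paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            if fnmatch.fnmatch(obj["Key"], pattern):
                matches.append(f"s3://{bucket}/{obj['Key']}")
    return matches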
pdelfin/train/core/config.py (1 addition & 0 deletions)

@@ -61,6 +61,7 @@ class SourceConfig:
 @dataclass
 class DataConfig:
     seed: int = field(default=42, help="The seed to use for data loading")
+    cache_location: str = field(help="Location to store s3 pdfs that need to be used to compute page images")
     metric_for_best_model: Optional[str] = field(help="metric to pass to trainer args to use for picking best model checkpoint at end", default=None)
     sources: List[SourceConfig] = field(help="The source configurations")
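The help text states what the new field is for: PDFs referenced from s3 need a local home before page images can be rendered from them. Below is a minimal sketch of that idea, assuming boto3 and a simple mirror-the-key layout under cache_location; the dataloader introduced by this commit may organize its cache differently.

# Hypothetical sketch, not pdelfin's actual code: download an s3 PDF into the
# configured cache_location once, then reuse the local copy for page rendering.
import os
from urllib.parse import urlparse

import boto3


def get_local_pdf(s3_path: str, cache_location: str) -> str:
    parsed = urlparse(s3_path)                                    # s3://bucket/some/doc.pdf
    key = parsed.path.lstrip("/")
    local_path = os.path.join(cache_location, parsed.netloc, key)
    if not os.path.exists(local_path):
        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        boto3.client("s3").download_file(parsed.netloc, key, local_path)
    return local_path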
pdelfin/train/train.py (0 additions & 5 deletions)

@@ -50,11 +50,6 @@
     make_dataset
 )

-
-from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
-from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, filter_by_max_seq_len
-
-
 class CheckpointUploadCallback(TrainerCallback):
     def __init__(self, save_path: str, logger: Optional[Logger] = None):
         self.save_path = save_path
pdelfin/train/utils.py (7 additions & 12 deletions)

@@ -20,16 +20,16 @@
 from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset

 from .core.cli import to_native_types
-from .core.config import AwsConfig, TrainConfig, WandbConfig
+from .core.config import AwsConfig, TrainConfig, WandbConfig, DataConfig, SourceConfig
 from .core.loggers import get_logger
 from .core.paths import copy_dir, is_local
 from .core.state import BeakerState
 #from .tokenization import ModelTokenizer

 T = TypeVar("T")

-from pdelfin.train.dataloader import build_batch_query_response_vision_dataset, list_dataset_files
-from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, filter_by_max_seq_len
+from pdelfin.train.dataloader import build_finetuning_dataset, list_dataset_files
+from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training


 def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:

@@ -42,11 +42,8 @@ def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
         return torch.float8_e4m3fn
     return torch.float32

-def get_rawdataset_from_source(source) -> Dataset:
-    if source.parquet_path is not None:
-        return load_dataset("parquet", data_files=list_dataset_files(source.parquet_path))["train"]
-    else:
-        return build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
+def get_rawdataset_from_source(data_config: DataConfig, source: SourceConfig) -> Dataset:
+    return build_finetuning_dataset(source.response_glob_path, pdf_cache_location=data_config.cache_location)

 def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset, Dataset]:
     random.seed(config.train_data.seed)

@@ -55,19 +52,17 @@ def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset, Dataset]:
     train_dataset = (
         concatenate_datasets(
             [
-                get_rawdataset_from_source(source)
+                get_rawdataset_from_source(config.train_data, source)
                 for source in config.train_data.sources
             ]
         )
-        .filter(partial(filter_by_max_seq_len, processor=processor), num_proc=multiprocessing.cpu_count())
         .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
     )

     # Validation sets get put into a datasetdict so each can report a loss separately
     valid_dataset = DatasetDict(
         **{
-            source.name: get_rawdataset_from_source(source)
-            .filter(partial(filter_by_max_seq_len, processor=processor))
+            source.name: get_rawdataset_from_source(config.valid_data, source)
             .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
             for source in config.valid_data.sources
         }
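The reworked make_dataset keeps its existing shape: training sources are concatenated into one Dataset, while each validation source remains its own entry in a DatasetDict so the trainer can report a loss per source (which is what metric_for_best_model: openai_batch_data_v5_1_eval_loss keys off). Here is a self-contained toy version of that pattern using only the Hugging Face datasets API; prepare below is a made-up stand-in for batch_prepare_data_for_qwen2_training.

# Toy illustration of the concatenate-for-train / DatasetDict-for-valid pattern.
from functools import partial

from datasets import Dataset, DatasetDict, concatenate_datasets


def prepare(batch, suffix=""):
    # Stand-in for batch_prepare_data_for_qwen2_training; with_transform applies
    # it lazily, each time examples are accessed.
    return {"text": [t + suffix for t in batch["text"]]}


train_sources = [
    Dataset.from_list([{"text": "page one"}, {"text": "page two"}]),
    Dataset.from_list([{"text": "page three"}]),
]
valid_sources = {
    "openai_batch_data_v5_1_eval": Dataset.from_list([{"text": "held-out page"}]),
}

train_dataset = concatenate_datasets(train_sources).with_transform(partial(prepare, suffix="!"))
valid_dataset = DatasetDict(
    **{name: ds.with_transform(partial(prepare, suffix="!")) for name, ds in valid_sources.items()}
)

print(len(train_dataset), list(valid_dataset.keys()))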
