
Commit

Loading dataset from config now
jakep-allenai committed Sep 23, 2024
1 parent ab9458b commit ea3af01
Showing 5 changed files with 74 additions and 58 deletions.
Empty file added pdelfin/buildsilver/__init__.py
56 changes: 7 additions & 49 deletions pdelfin/train/config/qwen2vl-2b.yaml
```diff
@@ -27,63 +27,21 @@ generate:
 train_data:
   seed: 1337
   sources:
-    - name: fw-edu-all
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/fw-edu-all/*.json.gz
-      backend:
-        - openai
-      size: 100_000
-    - name: dclm
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dclm/*.zstd
-      backend:
-        - openai
-      size: 100_000
-    - name: dolma-v17
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v17/*.zstd
-      backend:
-        - openai
-      size: 100_000
-    - name: dolma-v1-small
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v1-small/*.zstd
+    - name: openai_batch_data_v2_mini
+      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json
       backend:
         - openai
       size: 100_000
 
 valid_data:
   sources:
-    - name: fw-edu-10k
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/fw-edu-10k/valid/*.gz
+    - name: openai_batch_data_v2_mini
+      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json
       backend:
         - openai
-      size: 1500
-    - name: dolma-10k
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-10k/valid/*.gz
-      backend:
-        - openai
-      size: 1500
-    - name: dclm
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dclm/*.zstd
-      backend:
-        - openai
-      size: 1500
-    - name: dolma-v17
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v17/*.zstd
-      backend:
-        - openai
-      size: 1500
-    - name: dolma-v1-small
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v1-small/*.zstd
-      backend:
-        - openai
-      size: 3000
+      size: 100_000
 
 # Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
 hparams:
```
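Note, not part of the commit: the trimmed config points each source at S3 glob patterns rather than explicit file lists, so it can be useful to check what a pattern actually matches before launching a run. A minimal sketch using plain boto3; the helper name is illustrative, and the repo's own `load_jsonl_from_s3` presumably does its own listing internally:

```python
import fnmatch
from urllib.parse import urlparse

import boto3


def expand_s3_glob(glob_path: str) -> list[str]:
    """List the S3 keys matching a glob such as s3://bucket/prefix/*.jsonl."""
    parsed = urlparse(glob_path)
    bucket, pattern = parsed.netloc, parsed.path.lstrip("/")
    prefix = pattern.split("*", 1)[0]  # list only up to the first wildcard

    s3 = boto3.client("s3")
    paginator = s3.get_paginator("list_objects_v2")
    matches = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            if fnmatch.fnmatch(obj["Key"], pattern):
                matches.append(f"s3://{bucket}/{obj['Key']}")
    return matches


print(expand_s3_glob("s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl"))
```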
3 changes: 2 additions & 1 deletion pdelfin/train/core/config.py
```diff
@@ -76,7 +76,8 @@ class AwsConfig:
 class SourceConfig:
     name: str = field(help="The name of the source")
     size: int = field(help="Limit size for the source")
-    paths: List[str] = field(help="The paths to the data files")
+    query_glob_path: str = field(help="The S3 glob path pointing to the inputs sent to OpenAI to generate the silver data")
+    response_glob_path: str = field(help="The S3 glob path pointing to the batch API response JSONs sent back from OpenAI")
     backend: List[str] = field(help="The data generation backend to use to train the model")
 
 
```
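With the `paths` field removed, every source entry must now carry both glob paths alongside the unchanged `name`, `size`, and `backend` fields. A rough sketch of how one YAML entry from the config above lines up with the updated class, assuming `field(help=...)` comes from the project's own config helper and that SourceConfig constructs like an ordinary dataclass:

```python
# Illustrative only: keyword construction is an assumption, not shown in this diff.
from pdelfin.train.core.config import SourceConfig

source = SourceConfig(
    name="openai_batch_data_v2_mini",
    size=100_000,
    query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
    response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
    backend=["openai"],
)
print(source.query_glob_path)
```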
61 changes: 58 additions & 3 deletions pdelfin/train/dataloader.py
```diff
@@ -2,12 +2,17 @@
 import logging
 import multiprocessing
 import re
+import random
+
 from functools import partial
-from typing import Any, Dict
+from typing import Any, Dict, Optional
+from logging import Logger
 
 import boto3
 from datasets import Dataset, Features, Value, load_dataset
 
+from .core.config import DataConfig, SourceConfig
+
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
```
```diff
@@ -121,7 +126,7 @@ def merge_query_response(query_example, response_data: Dataset, response_map: dict):
     return {"response": response_row["response"], "finish_reason": response_row["finish_reason"]}
 
 
-def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str) -> Dataset:
+def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str, num_proc: int = 32) -> Dataset:
     logger.info("Loading query and response datasets")
     query_data = load_jsonl_from_s3(query_glob_path)
     response_data = load_jsonl_from_s3(response_glob_path)
```
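For orientation (inferred, not shown in this diff): the loader joins the query and response datasets on `custom_id`, copying `response` and `finish_reason` onto each query row, so individual JSONL records are assumed to look roughly like this:

```python
# Assumed shapes, inferred from the fields the code reads; the real batch files
# may carry additional keys. The two files are paired on "custom_id".
query_row = {
    "custom_id": "page-0001",  # hypothetical identifier shared by both files
    # ...plus the original request payload sent to the OpenAI Batch API...
}
response_row = {
    "custom_id": "page-0001",
    "response": "...model output text...",
    "finish_reason": "stop",   # rows without "stop" are dropped by the filter below
}
```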
```diff
@@ -145,8 +150,58 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str, num_proc: int = 32) -> Dataset:
     logger.info("Running merge map")
     final_dataset = query_data.map(
         partial(merge_query_response, response_data=response_data, response_map=custom_id_to_response_row),
-        num_proc=multiprocessing.cpu_count(),
+        num_proc=num_proc,
     )
     final_dataset = final_dataset.filter(lambda x: x["finish_reason"] == "stop")
 
     return final_dataset
+
+
+def make_dataset(
+    train_data_config: DataConfig,
+    valid_data_config: Optional[DataConfig] = None,
+    test_data_config: Optional[DataConfig] = None,
+    num_proc: int = 32,
+    logger: Optional[Logger] = None,
+):
+    logger = logger or get_logger(__name__)
+    random.seed(train_data_config.seed)
+
+    dataset_splits: Dict[str, datasets.Dataset] = {}
+    tmp_train_sets = []
+
+    logger.info("Loading training data from %s sources", len(train_data_config.sources))
+    for source in train_data_config.sources:
+        tmp_train_sets.append(
+            build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path, num_proc=num_proc)
+        )
+    dataset_splits["train"] = datasets.concatenate_datasets(tmp_train_sets)
+    logger.info(
+        f"Loaded {len(dataset_splits['train'])} training samples from {len(train_data_config.sources)} sources"
+    )
+
+    if valid_data_config:
+        tmp_validation_sets = []
+        logger.info("Loading validation data from %s sources", len(valid_data_config.sources))
+        for source in valid_data_config.sources:
+            tmp_validation_sets.append(
+                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path, num_proc=num_proc)
+            )
+        dataset_splits["validation"] = datasets.concatenate_datasets(tmp_validation_sets)
+        logger.info(
+            f"Loaded {len(dataset_splits['validation'])} validation samples from {len(valid_data_config.sources)} sources"
+        )
+
+    if test_data_config:
+        tmp_test_sets = []
+        logger.info("Loading test data from %s sources", len(test_data_config.sources))
+        for source in test_data_config.sources:
+            tmp_test_sets.append(
+                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path, num_proc=num_proc)
+            )
+        dataset_splits["test"] = datasets.concatenate_datasets(tmp_test_sets)
+        logger.info(
+            f"Loaded {len(dataset_splits['test'])} test samples from {len(test_data_config.sources)} sources"
+        )
+
+    return datasets.DatasetDict(**dataset_splits)
```
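An end-to-end usage sketch for the new entry point. The in-code config construction is an assumption (the real caller parses the YAML into a TrainConfig, as the train.py diff below shows); the glob paths are the ones from the config file above:

```python
from pdelfin.train.core.config import DataConfig, SourceConfig
from pdelfin.train.dataloader import make_dataset

# Hypothetical: build the config in code instead of YAML; assumes DataConfig
# holds `seed` plus a list of SourceConfig entries and constructs like a dataclass.
train_cfg = DataConfig(
    seed=1337,
    sources=[
        SourceConfig(
            name="openai_batch_data_v2_mini",
            size=100_000,
            query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
            response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
            backend=["openai"],
        )
    ],
)

dataset = make_dataset(train_data_config=train_cfg, num_proc=8)
print(dataset)                # DatasetDict; "validation"/"test" appear only when configured
print(len(dataset["train"]))  # merged rows whose finish_reason == "stop"
```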
12 changes: 7 additions & 5 deletions pdelfin/train/train.py
```diff
@@ -59,7 +59,7 @@
 )
 
 
-from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
+from pdelfin.train.dataloader import make_dataset
 from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training
 
 
```
```diff
@@ -125,10 +125,12 @@ def run_train(config: TrainConfig):
 
     setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)
 
-    train_ds = build_batch_query_response_vision_dataset(
-        query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
-        response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
-    )
+    dataset = make_dataset(
+        train_data_config=config.train_data,
+        valid_data_config=config.valid_data,
+        num_proc=config.num_proc,
+        logger=logger,
+    )
 
     model = Qwen2VLForConditionalGeneration.from_pretrained(
         "Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
```
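A possible follow-on, an assumption rather than part of this commit: with make_dataset returning a DatasetDict, the rest of run_train would select splits from it instead of using a single train_ds:

```python
# Sketch only: `dataset` is the DatasetDict built by make_dataset above.
train_ds = dataset["train"]
eval_ds = dataset.get("validation")  # None unless valid_data is configured
```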
