typos
jakep-allenai committed Sep 23, 2024
1 parent ea3af01 commit 5916239
Showing 2 changed files with 10 additions and 10 deletions.
12 changes: 6 additions & 6 deletions pdelfin/train/dataloader.py
@@ -9,7 +9,7 @@
from logging import Logger

import boto3
-from datasets import Dataset, Features, Value, load_dataset
+from datasets import Dataset, Features, Value, load_dataset, concatenate_datasets, DatasetDict

from .core.config import DataConfig, SourceConfig

@@ -167,15 +167,15 @@ def make_dataset(
    logger = logger or get_logger(__name__)
    random.seed(train_data_config.seed)

-    dataset_splits: Dict[str, datasets.Dataset] = {}
+    dataset_splits: Dict[str, Dataset] = {}
    tmp_train_sets = []

    logger.info("Loading training data from %s sources", len(train_data_config.sources))
    for source in train_data_config.sources:
        tmp_train_sets.append(
            build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
        )
-    dataset_splits["train"] = datasets.concatenate_datasets(tmp_train_sets)
+    dataset_splits["train"] = concatenate_datasets(tmp_train_sets)
    logger.info(
        f"Loaded {len(dataset_splits['train'])} training samples from {len(train_data_config.sources)} sources"
    )
@@ -187,7 +187,7 @@ def make_dataset(
        tmp_validation_sets.append(
            build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
        )
-    dataset_splits["validation"] = datasets.concatenate_datasets(tmp_validation_sets)
+    dataset_splits["validation"] = concatenate_datasets(tmp_validation_sets)
    logger.info(
        f"Loaded {len(dataset_splits['validation'])} validation samples from {len(valid_data_config.sources)} sources"
    )
@@ -199,9 +199,9 @@ def make_dataset(
        tmp_test_sets.append(
            build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
        )
-    dataset_splits["test"] = datasets.concatenate_datasets(tmp_test_sets)
+    dataset_splits["test"] = concatenate_datasets(tmp_test_sets)
    logger.info(
        f"Loaded {len(dataset_splits['test'])} test samples from {len(test_data_config.sources)} sources"
    )

-    return datasets.DatasetDict(**dataset_splits)
+    return DatasetDict(**dataset_splits)
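
The dataloader change amounts to importing concatenate_datasets and DatasetDict directly instead of reaching through the datasets module namespace. A minimal sketch of the resulting pattern, with toy in-memory shards standing in for the project's real query/response sources:

from datasets import Dataset, DatasetDict, concatenate_datasets

# Toy shards standing in for the per-source datasets built in make_dataset.
shard_a = Dataset.from_dict({"query": ["q1"], "response": ["r1"]})
shard_b = Dataset.from_dict({"query": ["q2"], "response": ["r2"]})

# Merge the shards and wrap them in a named split, mirroring the return value above.
splits = DatasetDict(train=concatenate_datasets([shard_a, shard_b]))
print(len(splits["train"]))  # 2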
8 changes: 4 additions & 4 deletions pdelfin/train/train.py
@@ -149,8 +149,8 @@ def run_train(config: TrainConfig):
        model = get_peft_model(model=model, peft_config=peft_config)
        log_trainable_parameters(model=model, logger=logger)

-    train_ds = train_ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
-    print(train_ds)
+    formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+    print(formatted_dataset)
    print("---------------")

    save_path = join_path("", config.save.path, run_name.run)
@@ -202,8 +202,8 @@ def run_train(config: TrainConfig):
    trainer = Trainer(
        model=model,
        args=training_args,
-        train_dataset=train_ds,
-        #eval_dataset=formatted_dataset["validation"], # pyright: ignore
+        train_dataset=formatted_dataset["train"],
+        eval_dataset=formatted_dataset["validation"], # pyright: ignore
        tokenizer=processor.tokenizer,
        #data_collator=collator,
        #callbacks=[checkpoint_callback],
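For context on the train.py change: with_transform called on a DatasetDict returns a new dict whose splits all apply the transform lazily on access, which is what lets the Trainer index formatted_dataset["train"] and formatted_dataset["validation"] afterwards. A minimal sketch of that behaviour, with a toy transform in place of batch_prepare_data_for_qwen2_training and illustrative data:

from datasets import Dataset, DatasetDict

dataset = DatasetDict(
    train=Dataset.from_dict({"text": ["ab", "c"]}),
    validation=Dataset.from_dict({"text": ["de"]}),
)

def add_length(batch):
    # Stand-in for batch_prepare_data_for_qwen2_training: operates on batched columns.
    batch["length"] = [len(t) for t in batch["text"]]
    return batch

formatted_dataset = dataset.with_transform(add_length)

# The transform runs on access and its batched output is returned as-is.
print(formatted_dataset["train"][0])       # {'text': ['ab'], 'length': [2]}
print(formatted_dataset["validation"][0])  # {'text': ['de'], 'length': [2]}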
