Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Oct 16, 2024
1 parent 446773d commit 9d647b1
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions pdelfin/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .core.loggers import get_logger
from .core.paths import copy_dir, is_local
from .core.state import BeakerState
#from .tokenization import ModelTokenizer
# from .tokenization import ModelTokenizer

T = TypeVar("T")

Expand All @@ -49,21 +49,31 @@ def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset
random.seed(config.train_data.seed)

# Training sets get all concatenated and shuffled
train_dataset = (
concatenate_datasets(
[
get_rawdataset_from_source(config.train_data, source)
for source in config.train_data.sources
]
)
.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
train_dataset = concatenate_datasets(
[
get_rawdataset_from_source(config.train_data, source).with_transform(
partial(
batch_prepare_data_for_qwen2_training,
processor=processor,
target_longest_image_dim=source.target_longest_image_dim,
target_anchor_text_len=source.target_anchor_text_len,
)
)
for source in config.train_data.sources
]
)

# Validation sets get put into a datasetdict so each can report a loss separately
valid_dataset = DatasetDict(
**{
source.name: get_rawdataset_from_source(config.valid_data, source)
.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
source.name: get_rawdataset_from_source(config.valid_data, source).with_transform(
partial(
batch_prepare_data_for_qwen2_training,
processor=processor,
target_longest_image_dim=source.target_longest_image_dim,
target_anchor_text_len=source.target_anchor_text_len,
)
)
for source in config.valid_data.sources
}
)
Expand Down Expand Up @@ -186,5 +196,3 @@ def get_local_dir(output_dir: str):
else:
yield tmp_dir
copy_dir(tmp_dir, output_dir)


0 comments on commit 9d647b1

Please sign in to comment.