Skip to content

Commit

Permalink
fix dataloader issue (#154)
Browse files Browse the repository at this point in the history
  • Loading branch information
younesbelkada authored Feb 16, 2023
1 parent 032676a commit 2fde77e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
16 changes: 16 additions & 0 deletions tests/trainer/test_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,22 @@ def _init_dummy_dataset(self):

return dummy_dataset

def test_drop_last_dataloader(self):
    """
    Check the length of the trainer's dataloader when `batch_size=3`.

    The trainer is built from the dummy dataset and the test expects its
    dataloader to yield zero batches.
    NOTE(review): this presumably relies on the dummy dataset holding fewer
    than 3 samples, so the only (incomplete) batch is dropped - confirm
    against `_init_dummy_dataset`.
    """
    # Override the config on the test instance with a batch size of 3.
    self.ppo_config = PPOConfig(batch_size=3, forward_batch_size=1, log_with=None)

    # Build the dataset first, then hand everything to the trainer.
    dataset = self._init_dummy_dataset()
    trainer = PPOTrainer(
        config=self.ppo_config,
        model=self.gpt2_model,
        ref_model=self.gpt2_model_ref,
        tokenizer=self.gpt2_tokenizer,
        dataset=dataset,
    )

    # No batch should be produced by the trainer's dataloader.
    self.assertEqual(len(trainer.dataloader), 0)

def test_ppo_step(self):
# initialize dataset
dummy_dataset = self._init_dummy_dataset()
Expand Down
1 change: 1 addition & 0 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def prepare_dataloader(self, dataset: Union[torch.utils.data.Dataset, Dataset],
batch_size=self.config.batch_size,
collate_fn=data_collator,
shuffle=True,
drop_last=True,
)
return dataloader

Expand Down

0 comments on commit 2fde77e

Please sign in to comment.