From 63506496f012ed01aa6c8bc4eb34be804a33299b Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 23 Jan 2023 10:39:39 +0000 Subject: [PATCH 1/2] improve API - add kwargs check on `PPOTrainer` - add tests --- tests/test_ppo_trainer.py | 31 +++++++++++++++++++++++++++++++ trl/trainer/ppo_trainer.py | 17 +++++++++++------ 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py index 6c2ac0da07..a5004b6653 100644 --- a/tests/test_ppo_trainer.py +++ b/tests/test_ppo_trainer.py @@ -98,6 +98,37 @@ def setUp(self): return super().setUp() + def test_raise_unvalid_args(self): + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + + with self.assertRaises(ValueError): + _ = PPOTrainer( + config={}, + model=self.gpt2_model, + ref_model=self.gpt2_model_ref, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + with self.assertRaises(ValueError): + _ = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=self.gpt2_model_ref, + tokenizer=None, + dataset=dummy_dataset, + ) + + with self.assertRaises(ValueError): + _ = PPOTrainer( + config=self.ppo_config, + model=None, + ref_model=self.gpt2_model_ref, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + def test_ppo_step(self): # initialize dataset dummy_dataset = self._init_dummy_dataset() diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 5c28aa3efc..6801b33fc3 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -99,16 +99,21 @@ def __init__( used only if `ref_model` is `None`. """ super().__init__(config) - - # Step 1: Initialize Accelerator - self.accelerator = Accelerator(log_with=config.log_with, **config.accelerator_kwargs) - self.accelerator.init_trackers(config.tracker_project_name, config=config.to_dict(), **config.tracker_kwargs) - - # Step 2: Initialize model, tokenizer, and dataloader + # Step 0: check positional arguments validity + if not isinstance(config, PPOConfig): + raise ValueError(f"config must be a PPOConfig, got {type(config)}") + if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): + raise ValueError( + f"tokenizer must be a PreTrainedTokenizer or PreTrainedTokenizerFast, got {type(tokenizer)}" + ) if not isinstance(model, PreTrainedModelWrapper): raise ValueError( f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}" ) + # Step 1: Initialize Accelerator + self.accelerator = Accelerator(log_with=config.log_with, **config.accelerator_kwargs) + self.accelerator.init_trackers(config.tracker_project_name, config=config.to_dict(), **config.tracker_kwargs) + self.model = model self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder") From 2ddce83d220533667b1bd7c489a60e351427f691 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 23 Jan 2023 14:38:19 +0000 Subject: [PATCH 2/2] make all args kwargs --- trl/trainer/ppo_trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 6801b33fc3..8de10df95b 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -65,10 +65,10 @@ class PPOTrainer(BaseTrainer): def __init__( self, - config: PPOConfig, - model: PreTrainedModelWrapper, - ref_model: PreTrainedModelWrapper, - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + config: PPOConfig = None, + model: PreTrainedModelWrapper = None, + ref_model: PreTrainedModelWrapper = None, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None, dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None, optimizer: Optional[torch.optim.Optimizer] = None, data_collator=None,