diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py index 6c2ac0da07..a5004b6653 100644 --- a/tests/test_ppo_trainer.py +++ b/tests/test_ppo_trainer.py @@ -98,6 +98,37 @@ def setUp(self): return super().setUp() + def test_raise_invalid_args(self): + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + + with self.assertRaises(ValueError): + _ = PPOTrainer( + config={}, + model=self.gpt2_model, + ref_model=self.gpt2_model_ref, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + with self.assertRaises(ValueError): + _ = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=self.gpt2_model_ref, + tokenizer=None, + dataset=dummy_dataset, + ) + + with self.assertRaises(ValueError): + _ = PPOTrainer( + config=self.ppo_config, + model=None, + ref_model=self.gpt2_model_ref, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + def test_ppo_step(self): # initialize dataset dummy_dataset = self._init_dummy_dataset() diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 5c28aa3efc..8de10df95b 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -65,10 +65,10 @@ class PPOTrainer(BaseTrainer): def __init__( self, - config: PPOConfig, - model: PreTrainedModelWrapper, - ref_model: PreTrainedModelWrapper, - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + config: PPOConfig = None, + model: PreTrainedModelWrapper = None, + ref_model: PreTrainedModelWrapper = None, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None, dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None, optimizer: Optional[torch.optim.Optimizer] = None, data_collator=None, @@ -99,16 +99,21 @@ def __init__( used only if `ref_model` is `None`. 
""" super().__init__(config) - - # Step 1: Initialize Accelerator - self.accelerator = Accelerator(log_with=config.log_with, **config.accelerator_kwargs) - self.accelerator.init_trackers(config.tracker_project_name, config=config.to_dict(), **config.tracker_kwargs) - - # Step 2: Initialize model, tokenizer, and dataloader + # Step 0: check positional arguments validity + if not isinstance(config, PPOConfig): + raise ValueError(f"config must be a PPOConfig, got {type(config)}") + if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): + raise ValueError( + f"tokenizer must be a PreTrainedTokenizer or PreTrainedTokenizerFast, got {type(tokenizer)}" + ) if not isinstance(model, PreTrainedModelWrapper): raise ValueError( f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}" ) + # Step 1: Initialize Accelerator + self.accelerator = Accelerator(log_with=config.log_with, **config.accelerator_kwargs) + self.accelerator.init_trackers(config.tracker_project_name, config=config.to_dict(), **config.tracker_kwargs) + self.model = model self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")