diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index d18a82d17fef..db194ab91a1c 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -519,47 +519,40 @@ def _prepare_model_inputs(
         inputs_kwarg = model_kwargs.pop(input_name, None)
         if inputs_kwarg is not None and inputs is not None:
             raise ValueError(
-                f"`inputs`: {inputs}` were passed alongside "
-                f"{input_name} which is not allowed."
+                f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed."
                 f"Make sure to either pass {inputs} or {input_name}=..."
             )
         elif inputs_kwarg is not None:
             inputs = inputs_kwarg

-        # 3. models with `input_ids` can also make use of `inputs_embeds`
-        if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs):
-            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
-
-        # 4. Only encoder-decoder models can have non `input_ids` input format
-        if not self.config.is_encoder_decoder and input_name != "input_ids":
-            raise ValueError(
-                f"If {input_name} is passed as model-specific keyword "
-                "input then model has to be an encoder-decoder and not a "
-                f"{self.__class__.__name__}."
-            )
+        # 3. In the presence of `inputs_embeds` for text models:
+        # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
+        # doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
+        # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
+        # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
+        # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
+        if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
+            if not self.config.is_encoder_decoder:
+                has_inputs_embeds_forwarding = "inputs_embeds" in set(
+                    inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
+                )
+                if not has_inputs_embeds_forwarding:
+                    raise ValueError(
+                        f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
+                        "doesn't have its forwarding implemented. See the GPT2 implementation for an example "
+                        "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
+                    )
+            else:
+                if inputs is not None:
+                    raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
+                inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

-        # 5. if `inputs` is still None, try to create `input_ids` from BOS token
+        # 4. if `inputs` is still None, try to create `input_ids` from BOS token
         if inputs is None:
             inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))

         return inputs, input_name, model_kwargs

-    def _can_retrieve_inputs_from_name(
-        self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor]
-    ) -> torch.Tensor:
-        """
-        If `inputs` is None and `name` is in both forward function and keyword arguments, then inputs can be retrieved
-        from name
-        """
-        can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set(
-            inspect.signature(self.forward).parameters.keys()
-        )
-
-        if can_retrieve_inputs and inputs is not None:
-            raise ValueError(f"Cannot only pass one of {name} and {self.main_input_name}")
-
-        return can_retrieve_inputs
-
     def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         """
         Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
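The new branch in `_prepare_model_inputs` splits the `inputs_embeds` handling by architecture: decoder-only models keep `inputs_embeds` in `model_kwargs` (it is consumed in the first generation step, provided the model's `prepare_inputs_for_generation` forwards it), while encoder-decoder models promote it to the main input and refuse to take `input_ids` at the same time. A minimal sketch of the two call patterns, assuming this patch is applied and reusing the tiny test checkpoints from the diff purely for illustration:

from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

# Decoder-only (GPT-2): `inputs_embeds` may coexist with `input_ids`;
# the embeddings feed the first forward pass, the ids take over once a cache exists.
tok = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
gpt2 = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
input_ids = tok("Hello world", return_tensors="pt").input_ids
inputs_embeds = gpt2.get_input_embeddings()(input_ids)
gpt2.generate(input_ids, inputs_embeds=inputs_embeds, max_length=10)

# Encoder-decoder (BART): `inputs_embeds` replaces `input_ids` as the encoder input,
# and passing both raises a ValueError.
bart_tok = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_ids = bart_tok("Hello world", return_tensors="pt").input_ids
enc_embeds = bart.get_input_embeddings()(bart_ids)
bart.generate(inputs_embeds=enc_embeds, max_length=10)     # ok
# bart.generate(bart_ids, inputs_embeds=enc_embeds)        # ValueError: pick one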
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index 5fe33bbca509..1a7ba62c4146 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -981,7 +981,7 @@ def get_output_embeddings(self):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
         if past_key_values:
@@ -1000,14 +1000,23 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
             position_ids = position_ids[:, -1].unsqueeze(-1)
         else:
             position_ids = None
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-            "token_type_ids": token_type_ids,
-        }
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "position_ids": position_ids,
+                "attention_mask": attention_mask,
+                "token_type_ids": token_type_ids,
+            }
+        )
+        return model_inputs

     @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
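The effect of the new `inputs_embeds` argument in GPT-2's `prepare_inputs_for_generation` is easiest to see by calling it directly: with no `past_key_values` the returned dict carries `inputs_embeds`, and once a cache exists it falls back to `input_ids`. A rough sketch (not part of the patch, and again using the tiny test checkpoint only as an example):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")

input_ids = tokenizer("Hello world", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)

# Step 1: no cache yet -> the embeddings are fed to the model instead of the ids
step1 = model.prepare_inputs_for_generation(input_ids, inputs_embeds=inputs_embeds, use_cache=True)
assert "inputs_embeds" in step1 and "input_ids" not in step1

# Run the first forward pass to obtain `past_key_values`
outputs = model(**step1)

# Step 2: with a cache, only the last token id is used and the embeddings are ignored
next_token = outputs.logits[:, -1:].argmax(-1)
step2 = model.prepare_inputs_for_generation(
    torch.cat([input_ids, next_token], dim=-1),
    past_key_values=outputs.past_key_values,
    inputs_embeds=inputs_embeds,
)
assert "input_ids" in step2 and "inputs_embeds" not in step2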
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 1546d14b438a..7a57a06afe54 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -2359,17 +2359,6 @@ def test_encoder_decoder_generate_attention_mask(self):

         self.assertTrue(diff < 1e-4)

-    def test_decoder_generate_with_inputs_embeds(self):
-        article = """I need input_ids to generate"""
-        tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
-        model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=5).to(torch_device)
-        input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-        inputs_embeds = model.get_input_embeddings()(input_ids)
-
-        # cannot generate from `inputs_embeds` for decoder only
-        with self.assertRaises(ValueError):
-            model.generate(inputs_embeds=inputs_embeds)
-
     def test_generate_input_ids_as_kwarg(self):
         article = """I need input_ids to generate"""
         tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
@@ -2417,8 +2406,10 @@ def test_generate_inputs_and_encoder_kwargs(self):

     def test_generate_too_many_encoder_kwargs(self):
         article = """I need input_ids to generate"""
-        tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
-        model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device)
+        tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+        model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=10).to(
+            torch_device
+        )
         input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
         with self.assertRaises(ValueError):
             model.generate(input_ids=input_ids, inputs_embeds=input_ids)
@@ -3128,3 +3119,26 @@ def test_eos_token_id_int_and_list_beam_search(self):
         eos_token_id = [873]
         generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
         self.assertTrue(expectation == len(generated_tokens[0]))
+
+    def test_generate_from_input_embeds_decoder_only(self):
+        # Note: the model must support generation from input embeddings
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+
+        text = "Hello world"
+        input_ids = tokenizer.encode(text, return_tensors="pt")
+
+        # Traditional way of generating text
+        outputs_from_ids = model.generate(input_ids)
+
+        # Same thing, but from input embeddings
+        inputs_embeds = model.transformer.wte(input_ids)
+        outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
+        self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
+
+        # But if we pass different inputs_embeds, we should get different outputs
+        torch.manual_seed(0)
+        random_embeds = torch.rand_like(inputs_embeds)
+        outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
+        with self.assertRaises(AssertionError):
+            self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
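As a sanity check on the first assertion in the new test: `model.transformer.wte` is the same embedding table `generate` uses internally for GPT-2 (it is what `get_input_embeddings()` returns), so feeding those embeddings reproduces the model's own first forward pass. A small, purely illustrative sketch of that equivalence outside the test suite:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
input_ids = tokenizer.encode("Hello world", return_tensors="pt")

# `transformer.wte` and `get_input_embeddings()` are the same module for GPT-2
inputs_embeds = model.transformer.wte(input_ids)
assert torch.equal(inputs_embeds, model.get_input_embeddings()(input_ids))

# Feeding ids or their embeddings yields the same logits, hence the same greedy continuation
logits_from_ids = model(input_ids).logits
logits_from_embeds = model(inputs_embeds=inputs_embeds).logits
assert torch.allclose(logits_from_ids, logits_from_embeds, atol=1e-5)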