Generate: decoder-only models can generate with inputs_embeds #21405

Merged · 6 commits · Feb 1, 2023
53 changes: 23 additions & 30 deletions src/transformers/generation/utils.py
@@ -519,47 +519,40 @@ def _prepare_model_inputs(
inputs_kwarg = model_kwargs.pop(input_name, None)
if inputs_kwarg is not None and inputs is not None:
raise ValueError(
f"`inputs`: {inputs}` were passed alongside "
f"{input_name} which is not allowed."
f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed."
f"Make sure to either pass {inputs} or {input_name}=..."
)
elif inputs_kwarg is not None:
inputs = inputs_kwarg

# 3. models with `input_ids` can also make use of `inputs_embeds`
if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs):
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

# 4. Only encoder-decoder models can have non `input_ids` input format
if not self.config.is_encoder_decoder and input_name != "input_ids":
raise ValueError(
f"If {input_name} is passed as model-specific keyword "
"input then model has to be an encoder-decoder and not a "
f"{self.__class__.__name__}."
)
# 3. In the presence of `inputs_embeds` for text models:
# - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
# doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
# input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
# - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
# pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
if not self.config.is_encoder_decoder:
has_inputs_embeds_forwarding = "inputs_embeds" in set(
inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
)
if not has_inputs_embeds_forwarding:
raise ValueError(
f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
"doesn't have its forwarding implemented. See the GPT2 implementation for an example "
"(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
)
else:
if inputs is not None:
raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

# 5. if `inputs` is still None, try to create `input_ids` from BOS token
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
if inputs is None:
inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))

return inputs, input_name, model_kwargs

def _can_retrieve_inputs_from_name(
self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor]
) -> torch.Tensor:
"""
If `inputs` is None and `name` is in both forward function and keyword arguments, then inputs can be retrieved
from name
"""
can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set(
inspect.signature(self.forward).parameters.keys()
)

if can_retrieve_inputs and inputs is not None:
raise ValueError(f"Cannot only pass one of {name} and {self.main_input_name}")

return can_retrieve_inputs

def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
"""
Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
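A minimal usage sketch of the decoder-only branch above, assuming a GPT-2 checkpoint (the checkpoint name and prompt are illustrative); per the new comment in `_prepare_model_inputs`, `input_ids` and `inputs_embeds` can coexist, and the embeddings are consumed on the first generation step:

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

input_ids = tokenizer("Hello world", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)  # (batch, seq_len, hidden_size)

# Decoder-only models may receive both: `inputs_embeds` drives the 1st step,
# `input_ids` seeds the returned sequence
outputs = model.generate(input_ids, inputs_embeds=inputs_embeds, max_new_tokens=5)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))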
27 changes: 18 additions & 9 deletions src/transformers/models/gpt2/modeling_gpt2.py
@@ -981,7 +981,7 @@ def get_output_embeddings(self):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past_key_values:
@@ -1000,14 +1000,23 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}

model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
)
return model_inputs

@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
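For a decoder-only model that does not yet forward `inputs_embeds`, the new ValueError in `_prepare_model_inputs` points at this GPT-2 change as the reference. A sketch of the same pattern for a hypothetical model class — only the `inputs_embeds`/`past_key_values` handling is taken from the diff, the remaining keys are illustrative:

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
    # once the cache is populated, only the last token needs to be fed to the model
    if past_key_values:
        input_ids = input_ids[:, -1].unsqueeze(-1)

    # `inputs_embeds` are only usable on the very first forward pass (no cache yet)
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update(
        {
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": kwargs.get("attention_mask"),
        }
    )
    return model_inputs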
40 changes: 27 additions & 13 deletions tests/generation/test_utils.py
@@ -2359,17 +2359,6 @@ def test_encoder_decoder_generate_attention_mask(self):

self.assertTrue(diff < 1e-4)

def test_decoder_generate_with_inputs_embeds(self):
article = """I need input_ids to generate"""
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=5).to(torch_device)
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
inputs_embeds = model.get_input_embeddings()(input_ids)

# cannot generate from `inputs_embeds` for decoder only
with self.assertRaises(ValueError):
model.generate(inputs_embeds=inputs_embeds)

def test_generate_input_ids_as_kwarg(self):
article = """I need input_ids to generate"""
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
@@ -2417,8 +2406,10 @@ def test_generate_inputs_and_encoder_kwargs(self):

def test_generate_too_many_encoder_kwargs(self):
article = """I need input_ids to generate"""
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device)
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=10).to(
torch_device
)
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
with self.assertRaises(ValueError):
model.generate(input_ids=input_ids, inputs_embeds=input_ids)
Comment on lines 2407 to 2415 (Member Author):
After this PR, this test only applies to encoder-decoder models (and, looking at the title of the test, it seems like it was the original goal :P )

@@ -3128,3 +3119,26 @@ def test_eos_token_id_int_and_list_beam_search(self):
eos_token_id = [873]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))

def test_generate_from_input_embeds_decoder_only(self):
# Note: the model must support generation from input embeddings
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")

text = "Hello world"
input_ids = tokenizer.encode(text, return_tensors="pt")

# Traditional way of generating text
outputs_from_ids = model.generate(input_ids)

# Same thing, but from input embeddings
inputs_embeds = model.transformer.wte(input_ids)
outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())

# But if we pass different inputs_embeds, we should get different outputs
torch.manual_seed(0)
random_embeds = torch.rand_like(inputs_embeds)
outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
with self.assertRaises(AssertionError):
self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
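For contrast, a sketch of the encoder-decoder branch exercised by `test_generate_too_many_encoder_kwargs` above, assuming the same tiny BART checkpoint: passing `input_ids` and `inputs_embeds` together raises, while `inputs_embeds` alone replaces `input_ids` when computing the encoder hidden states.

from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart")

input_ids = tokenizer("I need input_ids to generate", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)

# encoder-decoder models must receive only one of the two
try:
    model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds)
except ValueError as err:
    print(err)  # "You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one."

# `inputs_embeds` alone feeds the encoder in place of `input_ids`
outputs = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=5)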