Skip to content

Commit

Permalink
Add input_embeds functionality to gpt_neo Causal LM (#25659)
Browse files Browse the repository at this point in the history
* Updated gpt_neo causalLM to support using input embeddings for generation

* added indentation

* Did make fixup
  • Loading branch information
gaasher authored Aug 22, 2023
1 parent 908f853 commit 977b2f0
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions src/transformers/models/gpt_neo/modeling_gpt_neo.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def get_output_embeddings(self):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past_key_values:
Expand All @@ -698,14 +698,24 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)

return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}

model_inputs.update(
{
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
)

return model_inputs

@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
Expand Down

0 comments on commit 977b2f0

Please sign in to comment.