LIT: Fix assertion error when generation stops short of max_length
PiperOrigin-RevId: 681096247
RyanMullins authored and LIT team committed Oct 2, 2024
1 parent 3dc61a2 commit 400a239
Showing 1 changed file with 6 additions and 6 deletions.
lit_nlp/examples/prompt_debugging/transformers_lms.py
@@ -279,12 +279,12 @@ def _get_batched_outputs(
     if self.framework == MLFramework.PT:
       encoded_inputs = encoded_inputs.to(self.device)
 
-    outputs = self.model.generate(
-        encoded_inputs["input_ids"],
-        attention_mask=encoded_inputs["attention_mask"],
-        max_length=self.max_length,
-    )
-    ntok_out = self.max_length - encoded_inputs["input_ids"].shape[1]
+    outputs = self.model.generate(**encoded_inputs, max_length=self.max_length)
+
+    if isinstance(outputs, transformers.utils.ModelOutput):
+      outputs = outputs.sequences
+
+    ntok_out = outputs.shape[1] - encoded_inputs["input_ids"].shape[1]
 
     responses = self.tokenizer.batch_decode(
         outputs[:, -ntok_out:], skip_special_tokens=True
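Why the old arithmetic breaks: generate() can stop before max_length (e.g. when every sequence in the batch emits an EOS token), so deriving the response length from max_length over-counts whenever generation ends early, and the slice outputs[:, -ntok_out:] reaches back into the prompt. The fix counts the tokens actually produced. A minimal toy sketch of the failure mode, using hypothetical token values rather than a real model:

import torch

max_length = 10
input_ids = torch.tensor([[11, 12, 13, 14]])           # 4 prompt tokens
outputs = torch.tensor([[11, 12, 13, 14, 21, 22, 2]])  # prompt + 3 generated tokens (EOS = 2)

# Old computation: assumes generation always reaches max_length.
ntok_out = max_length - input_ids.shape[1]             # 10 - 4 = 6
print(outputs[:, -ntok_out:])                          # tensor([[12, 13, 14, 21, 22, 2]])
# The slice reaches back into the prompt, so the decoded "response" would
# contain prompt tokens, and length-based assertions downstream fail.

# Fixed computation: count the tokens actually generated.
ntok_out = outputs.shape[1] - input_ids.shape[1]       # 7 - 4 = 3
print(outputs[:, -ntok_out:])                          # tensor([[21, 22, 2]])

Two related cleanups in the patch: passing **encoded_inputs forwards input_ids and attention_mask to generate() as before, and the new isinstance check unwraps outputs.sequences when generate() is configured to return a transformers.utils.ModelOutput (e.g. via return_dict_in_generate=True) rather than a plain tensor of token IDs.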
