Skip to content

Commit

Permalink
don't pass input_pos_maxp1 to ThunderModules (#1912)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Jan 20, 2025
1 parent 5cd6f04 commit 9d6cfe6
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions litgpt/generate/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,10 @@ def generate_fn(
token = prompt
prefill_token = True
input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
- input_pos_maxp1 = torch.tensor(prompt_size, device=device)
+ if model.__class__.__name__ != 'ThunderModule':
+     input_pos_maxp1 = torch.tensor(prompt_size, device=device)
+ else:
+     input_pos_maxp1 = None
for current_idx in range(max_returned_tokens - prompt_size):

# Generate the token
Expand Down Expand Up @@ -222,7 +225,8 @@ def generate_fn(
input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64)
else:
input_pos.add_(1)
- input_pos_maxp1.add_(1)
+ if input_pos_maxp1 is not None:
+     input_pos_maxp1.add_(1)

# Yield any remaining tokens
if yielded_idx < len(tokens):
Expand Down

0 comments on commit 9d6cfe6

Please sign in to comment.