Skip to content

Commit

Permalink
add extra generate tokens in evaluate (merge)
Browse files Browse the repository at this point in the history
  • Loading branch information
jack89roberts committed Aug 16, 2024
1 parent b118080 commit 613beef
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/arcsf/eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,12 @@ def compute_dataset_metrics(

batch_gen_kwargs = {key: value for key, value in generate_kwargs.items()}
if batch_gen_kwargs["max_new_tokens"] == "adaptive":
batch_gen_kwargs["max_new_tokens"] = answers["input_ids"].shape[-1]
# +10 tokens to account for slight rephrasings in generated answers that
# may make them longer than the target answers; in addition, always
# generate at least 50 tokens
batch_gen_kwargs["max_new_tokens"] = max(
answers["input_ids"].shape[-1] + 10, 50
)

with torch.no_grad():
# DO NOT pass position_ids into this method -> see collator for info
Expand All @@ -229,13 +234,15 @@ def compute_dataset_metrics(
target_answers = tokenizer.batch_decode(
answers["input_ids"], skip_special_tokens=True
)
len_target_answers = answers["attention_mask"].sum(axis=1)
# as above, vary generated tokens to keep based on target answer length
# but allow for some rephrasing/extra tokens
len_target_answers = answers["attention_mask"].sum(axis=1) + 10
generated_answers = [
# start at len(q): only want the tokens for the answers
# end at len(q) + max(targ_len, 50): keep as many tokens as the target
# answer length (which already includes the +10 margin), but at least 50
tokenizer.decode(
gen_a[len(q) : (len(q) + targ_len)],
gen_a[len(q) : (len(q) + max(targ_len, 50))],
skip_special_tokens=True,
)
for q, gen_a, targ_len in zip(
Expand Down

0 comments on commit 613beef

Please sign in to comment.