diff --git a/src/arcsf/eval/evaluate.py b/src/arcsf/eval/evaluate.py
index f32adf8..8510f55 100644
--- a/src/arcsf/eval/evaluate.py
+++ b/src/arcsf/eval/evaluate.py
@@ -212,7 +212,12 @@ def compute_dataset_metrics(
         batch_gen_kwargs = {key: value for key, value in generate_kwargs.items()}
         if batch_gen_kwargs["max_new_tokens"] == "adaptive":
-            batch_gen_kwargs["max_new_tokens"] = answers["input_ids"].shape[-1]
+            # +10 tokens to account for slight rephrasings in generated answers that
+            # may make them longer than the target answers
+            # and always generate at least 50 tokens
+            batch_gen_kwargs["max_new_tokens"] = max(
+                answers["input_ids"].shape[-1] + 10, 50
+            )
 
         with torch.no_grad():
             # DO NOT pass position_ids into this method -> see collator for info
@@ -229,13 +234,15 @@
         target_answers = tokenizer.batch_decode(
             answers["input_ids"], skip_special_tokens=True
         )
-        len_target_answers = answers["attention_mask"].sum(axis=1)
+        # as above, vary generated tokens to keep based on target answer length
+        # but allow for some rephrasing/extra tokens
+        len_target_answers = answers["attention_mask"].sum(axis=1) + 10
        generated_answers = [
            # start at len(q): only want the tokens for the answers
            # end at len(q) + targ_len: evaluate only as many tokens as in the
            # target answer
            tokenizer.decode(
-                gen_a[len(q) : (len(q) + targ_len)],
+                gen_a[len(q) : (len(q) + max(targ_len, 50))],
                skip_special_tokens=True,
            )
            for q, gen_a, targ_len in zip(
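
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the diff above): how the adaptive token
# budget introduced by this change behaves. The same max(length + 10, 50)
# rule appears twice -- once per batch (via the padded
# answers["input_ids"].shape[-1]) to set max_new_tokens, and once per example
# (via the attention-mask sum) to decide how many generated tokens to keep.
# The helper name and the example lengths below are hypothetical, chosen only
# to show the arithmetic.
def adaptive_budget(target_len: int) -> int:
    # +10 tokens of slack for rephrasing, with a floor of 50 tokens
    return max(target_len + 10, 50)


for target_len in (5, 40, 120):
    print(target_len, "->", adaptive_budget(target_len))
# 5   -> 50   short answers still get at least 50 tokens
# 40  -> 50   40 + 10 exactly hits the 50-token floor
# 120 -> 130  long answers get target length + 10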