Skip to content

Commit

Permalink
[SFT] fix check for AutoLigerKernelForCausalLM (#2874)
Browse files Browse the repository at this point in the history
* fix check for AutoLigerKernelForCausalLM

* fix case where AutoLigerKernelForCausalLM is not defined

* update min liger version

* formatting

* fix win CI
  • Loading branch information
kashif authored and qgallouedec committed Feb 17, 2025
1 parent def8e48 commit 7596db8
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
"diffusers": ["diffusers>=0.18.0"],
"judges": ["openai>=1.23.2", "llm-blender>=0.0.2"],
# liger-kernel depends on triton, which is only available on Linux https://github.com/triton-lang/triton#compatibility
"liger": ["liger-kernel>=0.4.0; sys_platform != 'win32'"],
"liger": ["liger-kernel>=0.5.3; sys_platform != 'win32'"],
"mergekit": ["mergekit>=0.0.5.1"],
"peft": ["peft>=0.8.0"],
"quantization": ["bitsandbytes"],
Expand Down
7 changes: 6 additions & 1 deletion trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@

if is_liger_kernel_available():
from liger_kernel.transformers import AutoLigerKernelForCausalLM
else:
AutoLigerKernelForCausalLM = None

if is_wandb_available():
import wandb
Expand Down Expand Up @@ -449,7 +451,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
)

# Compute token accuracy if we have labels and if the model is not using Liger (no logits)
if "labels" in inputs and not self.args.use_liger:
use_liger = self.args.use_liger or (
AutoLigerKernelForCausalLM is not None and isinstance(model, AutoLigerKernelForCausalLM)
)
if "labels" in inputs and not use_liger:
shift_logits = outputs.logits[..., :-1, :].contiguous()
shift_labels = inputs["labels"][..., 1:].contiguous()

Expand Down

0 comments on commit 7596db8

Please sign in to comment.