Skip to content

Commit

Permalink
nit: Correct compile_loss return type hint (#1940)
Browse files Browse the repository at this point in the history
  • Loading branch information
bradhilton authored Nov 1, 2024
1 parent eab21f0 commit f560cbb
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchtune/training/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def compile_model(
model.compile(backend=backend)


def compile_loss(loss: nn.Module, verbose: bool = True) -> None:
def compile_loss(loss: nn.Module, verbose: bool = True) -> nn.Module:
"""
Utility to compile and return loss function. If the loss function is chunked cross-entropy,
we only compile the upcast + cross-entropy calculation, not the chunking. For other losses
Expand Down

0 comments on commit f560cbb

Please sign in to comment.