Skip to content

Commit

Permalink
fix test_grad::test_forward_and_backward_from_trace (#1623)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Jan 8, 2025
1 parent d19e063 commit c5c0b77
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions thunder/tests/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -1180,14 +1180,12 @@ def func(a, b, *, c):
a = make_tensor((2, 3), device=device, dtype=torch.float64, requires_grad=True)
b = make_tensor((2, 3), device=device, dtype=torch.float64, requires_grad=True)
c = make_tensor((3,), device=device, dtype=torch.float64, requires_grad=True)
jfn = thunder.jit(func)
cd, inps, _ = thunder.compile_data(jfn).get_computation_and_inputs(a, b, c=c)
initial_trace = cd.computation_traces[0]
initial_trace = trace(inline_trace=False)(func, a, b, c=c)
wrapped_trace = wrap_return_value_together_with_arguments(initial_trace)
fw_trace, bw_trace = forward_and_backward_from_trace(wrapped_trace)
fw = executor.make_callable(fw_trace)
bw = executor.make_callable(bw_trace)
fw_out, saved_for_backward = fw(*inps)
fw_out, saved_for_backward = fw(a, b, c=c)

initial_trace = trace()(value_and_grad(func), a, b, c=c)
expected_vjp_func = executor.make_callable(initial_trace.python_callable(), disable_torch_autograd=True)
Expand All @@ -1197,7 +1195,6 @@ def func(a, b, *, c):

output_grads = tree_map(lambda x: torch.ones_like(x), fw_out["output"])
bw_out = bw(saved_for_backward, output_grads)
expected_grads = (*expected_grads[:-1], expected_grads[-1]["c"])
torch.testing.assert_close(bw_out, expected_grads)


Expand Down

0 comments on commit c5c0b77

Please sign in to comment.