use tagging checkpointing #1616

Merged
merged 1 commit on Jan 8, 2025
4 changes: 0 additions & 4 deletions thunder/benchmarks/benchmark_litgpt.py
@@ -541,10 +541,6 @@ def setup_distributed(self, model):
return model

def setup_activation_checkpointing(self):
if "thunder" in self.compile and "dynamo" not in self.compile:
# checkpointing is an option to thunder.jit
return

if any(isinstance(mod, CheckpointWrapper) for mod in self.model.modules()):
warnings.warn(
"FSDP checkpointing is configured, but the model already contains checkpointed layers."
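With the early return removed, the benchmark applies its activation-checkpoint wrapping even when compiling with Thunder, and the Thunder JIT then picks up the checkpointed regions via the lookaside changed below. As a rough sketch of that kind of wrapping using PyTorch's standard helpers (the block class and check function here are illustrative, not the benchmark's exact policy):

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

def wrap_blocks(model: torch.nn.Module, block_cls: type) -> None:
    # Wrap every module of type block_cls in a CheckpointWrapper, so its forward
    # runs under torch.utils.checkpoint and its activations are recomputed in
    # backward instead of being saved.
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=checkpoint_wrapper,
        check_fn=lambda m: isinstance(m, block_cls),
    )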
55 changes: 45 additions & 10 deletions thunder/core/jit_ext.py
@@ -899,7 +899,11 @@ def _general_jit_torch_finfo_lookaside(dtype: thunder.dtypes.dtype):
def _general_jit_torch_checkpoint_lookaside(
function: Callable,
*args,
**kwargs: Any,
context_fn: None | Callable[..., Any] = None,
debug: None | bool = None,
determinism_check: None | str = None,
preserve_rng_state: None | bool = None,
use_reentrant: bool = False,
):
"""
This function does preprocessing of the `function` argument before
@@ -917,17 +921,48 @@ def _general_jit_torch_checkpoint_lookaside(
The result of calling `thunder.torch.checkpoint` with the preprocessed
`function` and its arguments.
"""
from thunder.torch import checkpoint

# It should be possible to call the general_thunder_jit here to handle the
# conversion from torch to thunder but it doesn't work now
# See https://github.com/Lightning-AI/lightning-thunder/issues/1126
# TODO: Convert the function to a Thunder function
def thunder_function(*args, **kwargs):
return unwrap(function)(*args, **kwargs)
if unwrap(use_reentrant):
return do_raise(
"torch.checkpoint: use_reentrant=True is not supported in Thunder",
)
# NOTE: Thunder currently ignores the context_fn, debug, determinism_check, preserve_rng_state arguments
# Let's raise a warning if any of these arguments are passed
if unwrap(context_fn) is not None:
warnings.warn("torch.checkpoint: context_fn is not supported in Thunder and will be ignored")
if unwrap(debug) is not None:
warnings.warn("torch.checkpoint: debug is not supported in Thunder and will be ignored")
if unwrap(determinism_check) is not None:
warnings.warn("torch.checkpoint: determinism_check is not supported in Thunder and will be ignored")
if unwrap(preserve_rng_state) is not None:
warnings.warn("torch.checkpoint: preserve_rng_state is not supported in Thunder and will be ignored")

jit_ctx: JitCtx = get_jit_ctx()
jit_ctx.computation_trace.push_scope([])

input_output_proxy_names = set()

def add_input_output_proxy_name(p):
if isinstance(p, Proxy):
input_output_proxy_names.add(p.name)

tree_map(add_input_output_proxy_name, [unwrap(a) for a in args])

res = _interpret_call(function, *args)
if res is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return res

tree_map(add_input_output_proxy_name, unwrap(res))

wrapped_thunder_function = wrap_const(thunder_function)
return interpreter_needs_wrap(checkpoint)(wrapped_thunder_function, *args, **kwargs)
new_bsyms = jit_ctx.computation_trace.pop_scope()
jit_ctx.computation_trace.bound_symbols.extend(new_bsyms)

for bsym in new_bsyms:
for o in bsym.flat_proxy_outs:
if o.name not in input_output_proxy_names:
o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)

return res


# Adds proxy methods
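For context on the new lookaside above: instead of wrapping the call in thunder.torch.checkpoint, it interprets the checkpointed function inline, collects the bound symbols produced in that scope, and tags every intermediate proxy that is neither an input nor an output with ProxyTag.RECOMPUTE_IN_BACKWARD, so those values are recomputed in backward rather than saved. A minimal usage sketch of the user-facing behavior (module and shapes are made up for illustration; assumes thunder is installed):

import torch
import torch.utils.checkpoint
import thunder

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(16, 2048)
        self.fc2 = torch.nn.Linear(2048, 16)

    def _body(self, x):
        # The wide 2048-dim intermediate produced here is tagged for
        # recomputation in backward instead of being saved.
        return self.fc2(torch.relu(self.fc1(x)))

    def forward(self, x):
        return torch.utils.checkpoint.checkpoint(self._body, x, use_reentrant=False)

jm = thunder.jit(Block())
out = jm(torch.randn(512, 16, requires_grad=True))
out.sum().backward()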
57 changes: 53 additions & 4 deletions thunder/tests/test_grad.py
@@ -1180,12 +1180,14 @@ def func(a, b, *, c):
a = make_tensor((2, 3), device=device, dtype=torch.float64, requires_grad=True)
b = make_tensor((2, 3), device=device, dtype=torch.float64, requires_grad=True)
c = make_tensor((3,), device=device, dtype=torch.float64, requires_grad=True)
initial_trace = trace(inline_trace=False)(func, a, b, c=c)
jfn = thunder.jit(func)
cd, inps, _ = thunder.compile_data(jfn).get_computation_and_inputs(a, b, c=c)
initial_trace = cd.computation_traces[0]
wrapped_trace = wrap_return_value_together_with_arguments(initial_trace)
fw_trace, bw_trace = forward_and_backward_from_trace(wrapped_trace)
fw = executor.make_callable(fw_trace)
bw = executor.make_callable(bw_trace)
fw_out, saved_for_backward = fw(a, b, c=c)
fw_out, saved_for_backward = fw(*inps)

initial_trace = trace()(value_and_grad(func), a, b, c=c)
expected_vjp_func = executor.make_callable(initial_trace.python_callable(), disable_torch_autograd=True)
@@ -1195,6 +1197,7 @@ def func(a, b, *, c):

output_grads = tree_map(lambda x: torch.ones_like(x), fw_out["output"])
bw_out = bw(saved_for_backward, output_grads)
expected_grads = (*expected_grads[:-1], expected_grads[-1]["c"])
torch.testing.assert_close(bw_out, expected_grads)


@@ -1755,8 +1758,8 @@ def f(x, y):
# The intermediate values are recomputed during backward pass.
assert len(out.grad_fn.next_functions[0][0].saved_tensors) == 2
# We detach the saved tensors (which returns a new Python tensor backed by the same storage)
assert out.grad_fn.next_functions[0][0].saved_tensors[0].data_ptr() == x.data_ptr()
assert out.grad_fn.next_functions[0][0].saved_tensors[1].data_ptr() == y.data_ptr()
# The order of the saved tensors can be non-deterministic.
assert {t.data_ptr() for t in out.grad_fn.next_functions[0][0].saved_tensors} == {x.data_ptr(), y.data_ptr()}

g = torch.ones_like(out)
out.backward(g)
@@ -1769,6 +1772,52 @@ def f(x, y):
torch.testing.assert_close(y.grad, y_ref.grad)


@requiresCUDA
def test_checkpoint_max_memory():
import torch.utils.checkpoint

class Checkpoint(torch.nn.Module):
def __init__(self, module):
super().__init__()
self.module = module

def forward(self, *args):
return torch.utils.checkpoint.checkpoint(self.module, *args, use_reentrant=False)

with torch.device("cuda:0"):
m = torch.nn.Sequential(
torch.nn.Linear(1024, 16),
torch.nn.ReLU(),
*[
Checkpoint(
torch.nn.Sequential(
torch.nn.Linear(16, 2048),
torch.nn.Linear(2048, 16),
torch.nn.ReLU(),
)
)
for _ in range(10)
],
torch.nn.Linear(16, 1024),
)
inps = torch.randn(512, 1024, requires_grad=True)

jm = thunder.jit(m, executors=()) # no rematerialization
res = jm(inps)
res.sum().backward()

torch.cuda.reset_peak_memory_stats()
mem_base = torch.cuda.memory_allocated()
res = jm(inps)
res.sum().backward()
mem_max = torch.cuda.max_memory_allocated()
# Without checkpointing the peak memory is about 43MB.
# With checkpointing as coded in the model, and recomputation where the
# values are used, we get a little over 10MB, so we put the barrier at 16MB.
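# Rough accounting (assuming fp32): each checkpointed block's 512x2048 hidden
# activation is ~4 MiB, so saving all ten of them would already cost ~40 MiB,
# which matches the ~43MB figure above; with recomputation only the small
# 512x16 boundary activations are kept between forward and backward.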
mb_used = (mem_max - mem_base) / 2**20
assert mb_used < 16


def test_inconsistent_output_length_grad_transform():
from thunder.extend import OperatorExecutor
from thunder.core.proxies import AnyProxy, TensorProxy
2 changes: 0 additions & 2 deletions thunder/torch/__init__.py
@@ -63,7 +63,6 @@

# NOTE torch is a requirement
import torch
import torch.utils.checkpoint
import torch._higher_order_ops.wrap

import warnings
@@ -5325,7 +5324,6 @@ def _unwrap_if_dead(tensor):


@torchsymbol(
torch.utils.checkpoint.checkpoint,
torch.ops.higher_order.tag_activation_checkpoint,
id="activation_checkpoint",
)