
Commit
use tagging checkpointing
t-vi committed Jan 8, 2025
1 parent 2979681 commit 1953f60
Showing 4 changed files with 98 additions and 20 deletions.
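For context, a minimal sketch of the user-facing pattern this commit targets: `torch.utils.checkpoint.checkpoint(..., use_reentrant=False)` called inside a module compiled with `thunder.jit`. The module, sizes, and names below are illustrative (they mirror the test added further down in this commit) and are not part of the diff itself.

```python
import torch
import torch.utils.checkpoint
import thunder

class CheckpointedBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # illustrative sizes, mirroring the new test in thunder/tests/test_grad.py
        self.inner = torch.nn.Sequential(
            torch.nn.Linear(16, 2048),
            torch.nn.Linear(2048, 16),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        # use_reentrant=False is the variant the new lookaside supports;
        # use_reentrant=True is rejected under thunder.jit
        return torch.utils.checkpoint.checkpoint(self.inner, x, use_reentrant=False)

m = CheckpointedBlock()
jm = thunder.jit(m)
out = jm(torch.randn(8, 16, requires_grad=True))
out.sum().backward()  # the 8x2048 intermediate is recomputed in backward, not saved
```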
4 changes: 0 additions & 4 deletions thunder/benchmarks/benchmark_litgpt.py
@@ -541,10 +541,6 @@ def setup_distributed(self, model):
return model

def setup_activation_checkpointing(self):
if "thunder" in self.compile and "dynamo" not in self.compile:
# checkpointing is an option to thunder.jit
return

if any(isinstance(mod, CheckpointWrapper) for mod in self.model.modules()):
warnings.warn(
"FSDP checkpointing is configured, but the model already contains checkpointed layers."
55 changes: 45 additions & 10 deletions thunder/core/jit_ext.py
@@ -899,7 +899,11 @@ def _general_jit_torch_finfo_lookaside(dtype: thunder.dtypes.dtype):
def _general_jit_torch_checkpoint_lookaside(
function: Callable,
*args,
**kwargs: Any,
context_fn: None | Callable[..., Any] = None,
debug: None | bool = None,
determinism_check: None | str = None,
preserve_rng_state: None | bool = None,
use_reentrant: bool = False,
):
"""
This function does preprocessing of the `function` argument before
@@ -917,17 +921,48 @@ def _general_jit_torch_checkpoint_lookaside(
The result of interpreting `function` inline in the current computation trace; proxies produced
inside the checkpointed region that are neither its inputs nor its outputs are tagged with
`ProxyTag.RECOMPUTE_IN_BACKWARD` so they are recomputed in the backward pass instead of being saved.
"""
from thunder.torch import checkpoint

# It should be possible to call the general_thunder_jit here to handle the
# conversion from torch to thunder but it doesn't work now
# See https://github.com/Lightning-AI/lightning-thunder/issues/1126
# TODO: Convert the function to a Thunder function
def thunder_function(*args, **kwargs):
return unwrap(function)(*args, **kwargs)
if unwrap(use_reentrant):
return do_raise(
"torch.checkpoint: use_reentrant=True is not supported in Thunder",
)
# NOTE: Thunder currently ignores the context_fn, debug, determinism_check, preserve_rng_state arguments
# Let's raise a warning if any of these arguments are passed
if unwrap(context_fn) is not None:
warnings.warn("torch.checkpoint: context_fn is not supported in Thunder and will be ignored")
if unwrap(debug) is not None:
warnings.warn("torch.checkpoint: debug is not supported in Thunder and will be ignored")
if unwrap(determinism_check) is not None:
warnings.warn("torch.checkpoint: determinism_check is not supported in Thunder and will be ignored")
if unwrap(preserve_rng_state) is not None:
warnings.warn("torch.checkpoint: preserve_rng_state is not supported in Thunder and will be ignored")

jit_ctx: JitCtx = get_jit_ctx()
jit_ctx.computation_trace.push_scope([])

input_output_proxy_names = set()

def add_input_output_proxy_name(p):
if isinstance(p, Proxy):
input_output_proxy_names.add(p.name)

tree_map(add_input_output_proxy_name, [unwrap(a) for a in args])

res = _interpret_call(function, *args)
if res is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return res

tree_map(add_input_output_proxy_name, unwrap(res))

wrapped_thunder_function = wrap_const(thunder_function)
return interpreter_needs_wrap(checkpoint)(wrapped_thunder_function, *args, **kwargs)
new_bsyms = jit_ctx.computation_trace.pop_scope()
jit_ctx.computation_trace.bound_symbols.extend(new_bsyms)

for bsym in new_bsyms:
for o in bsym.flat_proxy_outs:
if o.name not in input_output_proxy_names:
o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)

return res


# Adds proxy methods
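For illustration, a hypothetical inspection sketch of what the new tagging produces: after executing a `thunder.jit`-compiled module that uses `torch.utils.checkpoint` (for example the `CheckpointedBlock` sketch above), one can look for proxies carrying `ProxyTag.RECOMPUTE_IN_BACKWARD` in the recorded trace. The use of `thunder.last_traces` and the import path of `ProxyTag` are assumptions based on the identifiers appearing in this diff, not something the commit itself adds.

```python
import thunder
from thunder.core.proxies import ProxyTag

# jm: a thunder.jit-compiled module that has already been run at least once
trc = thunder.last_traces(jm)[-1]
recomputed = [
    out.name
    for bsym in trc.bound_symbols
    for out in bsym.flat_proxy_outs
    if ProxyTag.RECOMPUTE_IN_BACKWARD in out.tags
]
print(recomputed)  # names of intermediates marked for recomputation in the backward pass
```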
57 changes: 53 additions & 4 deletions thunder/tests/test_grad.py
@@ -1180,12 +1180,14 @@ def func(a, b, *, c):
a = make_tensor((2, 3), device=device, dtype=torch.float64, requires_grad=True)
b = make_tensor((2, 3), device=device, dtype=torch.float64, requires_grad=True)
c = make_tensor((3,), device=device, dtype=torch.float64, requires_grad=True)
initial_trace = trace(inline_trace=False)(func, a, b, c=c)
jfn = thunder.jit(func)
cd, inps, _ = thunder.compile_data(jfn).get_computation_and_inputs(a, b, c=c)
initial_trace = cd.computation_traces[0]
wrapped_trace = wrap_return_value_together_with_arguments(initial_trace)
fw_trace, bw_trace = forward_and_backward_from_trace(wrapped_trace)
fw = executor.make_callable(fw_trace)
bw = executor.make_callable(bw_trace)
fw_out, saved_for_backward = fw(a, b, c=c)
fw_out, saved_for_backward = fw(*inps)

initial_trace = trace()(value_and_grad(func), a, b, c=c)
expected_vjp_func = executor.make_callable(initial_trace.python_callable(), disable_torch_autograd=True)
@@ -1195,6 +1197,7 @@ def func(a, b, *, c):

output_grads = tree_map(lambda x: torch.ones_like(x), fw_out["output"])
bw_out = bw(saved_for_backward, output_grads)
expected_grads = (*expected_grads[:-1], expected_grads[-1]["c"])
torch.testing.assert_close(bw_out, expected_grads)


@@ -1755,8 +1758,8 @@ def f(x, y):
# The intermediate values are recomputed during backward pass.
assert len(out.grad_fn.next_functions[0][0].saved_tensors) == 2
# We detach the saved tensors (which returns a new Python tensor backed by same storage)
assert out.grad_fn.next_functions[0][0].saved_tensors[0].data_ptr() == x.data_ptr()
assert out.grad_fn.next_functions[0][0].saved_tensors[1].data_ptr() == y.data_ptr()
# The order of the saved tensors can be non-deterministic, so compare the data pointers as a set.
assert {t.data_ptr() for t in out.grad_fn.next_functions[0][0].saved_tensors} == {x.data_ptr(), y.data_ptr()}

g = torch.ones_like(out)
out.backward(g)
@@ -1769,6 +1772,52 @@ def f(x, y):
torch.testing.assert_close(y.grad, y_ref.grad)


@requiresCUDA
def test_checkpoint_max_memory():
import torch.utils.checkpoint

class Checkpoint(torch.nn.Module):
def __init__(self, module):
super().__init__()
self.module = module

def forward(self, *args):
return torch.utils.checkpoint.checkpoint(self.module, *args, use_reentrant=False)

with torch.device("cuda:0"):
m = torch.nn.Sequential(
torch.nn.Linear(1024, 16),
torch.nn.ReLU(),
*[
Checkpoint(
torch.nn.Sequential(
torch.nn.Linear(16, 2048),
torch.nn.Linear(2048, 16),
torch.nn.ReLU(),
)
)
for _ in range(10)
],
torch.nn.Linear(16, 1024),
)
inps = torch.randn(512, 1024, requires_grad=True)

jm = thunder.jit(m, executors=()) # no rematerialization
res = jm(inps)
res.sum().backward()

torch.cuda.reset_peak_memory_stats()
mem_base = torch.cuda.memory_allocated()
res = jm(inps)
res.sum().backward()
mem_max = torch.cuda.max_memory_allocated()
# Without checkpointing, the peak memory is about 43MB.
# With checkpointing as coded in the model and recomputation where the
# values are used, we get a little over 10MB, so we put the barrier at 16MB.
mb_used = (mem_max - mem_base) / 2**20
assert mb_used < 16
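As a rough sanity check on the 16MB threshold in the test above, a back-of-the-envelope estimate, assuming float32 activations and ignoring gradient buffers, workspace, and allocator overhead:

```python
MiB = 2**20
# The large intermediate inside each checkpointed block is the 512 x 2048 activation.
block_intermediate = 512 * 2048 * 4 / MiB            # 4 MiB per block

# Without checkpointing, all 10 blocks keep that intermediate alive for backward.
no_checkpoint = 10 * block_intermediate              # ~40 MiB, consistent with the ~43MB noted above

# With checkpointing, only the small 512 x 16 block inputs are saved, and the large
# intermediate exists for roughly one block at a time while it is recomputed.
with_checkpoint = block_intermediate + 10 * (512 * 16 * 4) / MiB   # ~4.3 MiB lower bound

print(no_checkpoint, with_checkpoint)                 # measured peak is a little over 10MB, under the 16MB barrier
```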


def test_inconsistent_output_length_grad_transform():
from thunder.extend import OperatorExecutor
from thunder.core.proxies import AnyProxy, TensorProxy
2 changes: 0 additions & 2 deletions thunder/torch/__init__.py
@@ -63,7 +63,6 @@

# NOTE torch is a requirement
import torch
import torch.utils.checkpoint
import torch._higher_order_ops.wrap

import warnings
@@ -5325,7 +5324,6 @@ def _unwrap_if_dead(tensor):


@torchsymbol(
torch.utils.checkpoint.checkpoint,
torch.ops.higher_order.tag_activation_checkpoint,
id="activation_checkpoint",
)
