fix and reenable forward backward rematerialization (#1622)
t-vi authored and riccardofelluga committed Jan 27, 2025
1 parent 1fd6f57 commit 0f29768
Showing 7 changed files with 78 additions and 15 deletions.
43 changes: 40 additions & 3 deletions thunder/core/rematerialization.py
@@ -591,21 +591,57 @@ def rematerialize_forward_and_backward(fw_trace: TraceCtx, bw_trace: TraceCtx) -
_update_backward_with_new_saved_for_backward,
_update_forward_with_new_saved_for_backward,
)
from thunder.core.trace import tracectx

def joint_fn(args, kwargs, cotangents):
pass

joint_extrace = TraceCtx(joint_fn)
joint_extrace.names = set.union(fw_trace.names, bw_trace.names)
# name clash in args?
joint_extrace.args = (fw_trace.args, fw_trace.kwargs, bw_trace.args[1])
assert fw_trace.bound_symbols[-1].sym.id == PrimIDs.RETURN
assert bw_trace.bound_symbols[-1].sym.id == PrimIDs.RETURN
# Omit the last RETURN symbol
joint_extrace.bound_symbols = fw_trace.bound_symbols[:-1] + bw_trace.bound_symbols[:-1]
joint_extrace.bound_symbols = fw_trace.bound_symbols[:-1]

def apply_to_proxy_outputs_and_subsymbols(bsym, fn):
for p in bsym.flat_proxy_outs:
fn(p)
for subsym in bsym.subsymbols:
apply_to_proxy_outputs_and_subsymbols(subsym, fn)

swapmap = {}
skipmap = set()

def add_to_swapmap(p):
v = variableify(p)
if isinstance(p, TensorProxy) and v not in swapmap and p.name not in skipmap:
with tracectx(joint_extrace):
swapmap[v] = p.replace(name=f"bw_{p.name}")

for bsym in bw_trace.bound_symbols[:-1]:
if bsym.sym.id != PrimIDs.UNPACK_SEQUENCE:
# we want to rename everything except the saved-for-backward tensors
apply_to_proxy_outputs_and_subsymbols(bsym, add_to_swapmap)
else:
for p in bsym.flat_proxy_outs:
skipmap.add(p.name)
joint_extrace.bound_symbols.append(bsym.from_bsym_swap_proxies(swapmap))

# Add a new RETURN symbol
joint_extrace.bound_symbols.append(
replace(fw_trace.bound_symbols[-1], args=(fw_trace.bound_symbols[-1].args[0], bw_trace.bound_symbols[-1].args))
replace(
fw_trace.bound_symbols[-1],
args=(
fw_trace.bound_symbols[-1].args[0],
tuple(
bw_trace.bound_symbols[-1].from_bsym_swap_proxies(swapmap).args,
),
),
)
)

joint_extrace = rematerialize(joint_extrace)

# We need to update "save_for_backward" sequence
@@ -641,7 +677,7 @@ def joint_fn(args, kwargs, cotangents):
new_bw_trace = from_trace(bw_trace)
new_bw_trace.set_provenance(TraceProvenance("Rematerialization"))
new_bw_trace.bound_symbols = new_bw_bsyms
new_bw_trace.bound_symbols.append(replace(bw_trace.bound_symbols[-1], args=bw_trace.bound_symbols[-1].args))
new_bw_trace.bound_symbols.append(bw_trace.bound_symbols[-1].from_bsym_swap_proxies(swapmap))
_update_backward_with_new_saved_for_backward(new_bw_trace, new_required_for_backward)

new_fw_trace = from_trace(fw_trace)
@@ -660,6 +696,7 @@ def joint_fn(args, kwargs, cotangents):
# Update the call context
new_fw_trace = update_fusion_call_ctx(new_fw_trace)
new_bw_trace = update_fusion_call_ctx(new_bw_trace)

return new_fw_trace, new_bw_trace


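The core of this change is building a joint forward+backward trace without proxy-name collisions: every proxy produced in the backward trace gets a `bw_` prefix via a swap map, except the tensors unpacked from saved-for-backward, which must keep their forward names so the joint trace stays stitched together. Below is a minimal, self-contained sketch of that rename-with-skip-set pattern; it uses a toy `Op` dataclass, not thunder's `BoundSymbol`/`TraceCtx` API, so all names in it are illustrative only.

```python
# Toy illustration (hypothetical Op class, not thunder's BoundSymbol):
# rename the outputs of backward ops with a bw_ prefix, skipping names that
# were unpacked as saved-for-backward, and apply the accumulated swap map to
# each op -- mirroring the swapmap/skipmap logic in the hunk above.
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Op:
    name: str            # operation name, e.g. "unpack" or "mul"
    outputs: tuple       # output variable names
    inputs: tuple = ()   # input variable names


def rename_backward_ops(backward_ops):
    swapmap = {}     # old output name -> prefixed name
    skipset = set()  # names produced by the unpack op stay as-is

    renamed = []
    for op in backward_ops:
        if op.name == "unpack":
            skipset.update(op.outputs)
        else:
            for out in op.outputs:
                if out not in swapmap and out not in skipset:
                    swapmap[out] = f"bw_{out}"
        # apply the renames collected so far to this op's inputs and outputs
        renamed.append(
            replace(
                op,
                inputs=tuple(swapmap.get(i, i) for i in op.inputs),
                outputs=tuple(swapmap.get(o, o) for o in op.outputs),
            )
        )
    return renamed, swapmap


backward = [
    Op("unpack", outputs=("t0", "t1")),                      # saved-for-backward tensors
    Op("mul", outputs=("g0",), inputs=("t0", "cotangent")),
    Op("add", outputs=("g1",), inputs=("g0", "t1")),
]
renamed, swapmap = rename_backward_ops(backward)
assert swapmap == {"g0": "bw_g0", "g1": "bw_g1"}
assert renamed[2].inputs == ("bw_g0", "t1")
```

After `rematerialize()` runs on the joint trace, the remainder of the hunk splits the bound symbols back into separate forward and backward traces and recomputes the saved-for-backward set.
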
6 changes: 3 additions & 3 deletions thunder/core/transforms.py
@@ -3266,18 +3266,18 @@ def compute_proxy_from_producer(p):
compute_proxy_from_producer(p)
for o in producer_bsym.flat_proxy_outs:
have_in_backward.add(variableify(o))
new_bwd_trace.bound_symbols.append(producer_bsym)
new_bwd_trace.bound_symbols.append(producer_bsym.from_bsym())

for idx, bsym in enumerate(bwd_trace.bound_symbols):
if idx in {4, 5}:
# handled later
new_bwd_trace.bound_symbols.append(bsym)
new_bwd_trace.bound_symbols.append(bsym.from_bsym())
else:
for p in bsym.flat_proxy_args:
compute_proxy_from_producer(p)
for o in bsym.flat_proxy_outs:
have_in_backward.add(variableify(o))
new_bwd_trace.bound_symbols.append(bsym)
new_bwd_trace.bound_symbols.append(bsym.from_bsym())

new_fwd_trace = from_trace(fwd_trace)
new_fwd_trace.bound_symbols = fwd_trace.bound_symbols.copy()
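The transforms.py change replaces direct appends of `producer_bsym`/`bsym` with `from_bsym()` copies, presumably so the new backward trace owns fresh bound-symbol objects instead of aliasing the ones still held by the original trace, which matters once later passes (such as the joint rematerialization above) rewrite proxies on them. A generic sketch of the aliasing pitfall this avoids, using a hypothetical `Node` class rather than thunder's `BoundSymbol`:

```python
# Sharing mutable objects between two lists means an in-place rewrite in the
# "new" trace leaks back into the original; appending copies keeps them isolated.
class Node:
    def __init__(self, name, args):
        self.name = name
        self.args = list(args)

    def clone(self):  # analogous in spirit to BoundSymbol.from_bsym()
        return Node(self.name, self.args)


original_trace = [Node("mul", ["a", "b"])]

# Aliasing: both lists reference the same Node object.
aliased_trace = list(original_trace)
aliased_trace[0].args[0] = "bw_a"           # rewrite in the new trace...
assert original_trace[0].args[0] == "bw_a"  # ...shows up in the original too

# Copying: rewrites stay local to the new trace.
original_trace = [Node("mul", ["a", "b"])]
copied_trace = [node.clone() for node in original_trace]
copied_trace[0].args[0] = "bw_a"
assert original_trace[0].args[0] == "a"
```
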
7 changes: 1 addition & 6 deletions thunder/executors/torch_autograd.py
@@ -4,7 +4,6 @@
import torch

import thunder.core.utils as utils
from thunder.core.compile_data import get_compile_option
from thunder.core.prims import PrimIDs
from thunder.core.proxies import TensorProxy, variableify
from thunder.core.pytree import tree_flatten
@@ -354,11 +353,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
)
bw_traces.append(bw_extrace)

use_rematerialization: None | bool = get_compile_option(
"use_forward_backward_rematerialization", "use rematerialization of saved for backward values in fusions"
)
if use_rematerialization:
fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace)
fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace)
fw_traces.append(fw_extrace)
bw_traces.append(bw_extrace)

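With the `use_forward_backward_rematerialization` compile option removed, `rematerialize_forward_and_backward` now runs on every forward/backward split rather than being opt-in. A rough way to observe its effect from user code, using the same inspection pattern as the tests in this commit (`thunder.jit` plus `thunder.last_traces`); the tiny function is just an example:

```python
import torch
import thunder


def fn(a):
    return torch.nn.functional.gelu(a)


jfn = thunder.jit(fn)
a = torch.randn(2, 2, requires_grad=True)
jfn(a).sum().backward()

# The final forward execution trace now reflects rematerialization decisions
# made jointly over the forward and backward traces.
fw_trace = thunder.last_traces(jfn)[-1]
print([bsym.sym.name for bsym in fw_trace.bound_symbols])
```
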
4 changes: 2 additions & 2 deletions thunder/tests/test_examine_memory.py
@@ -115,5 +115,5 @@ def test_nanogpt_block():
# We are checking the estimated memory against a fixed value for consistency.
assert max_mem_fw[0] == 262183936
assert sum(max_mem_fw[1].values()) == 135306240
assert max_mem_bw[0] == 484833280
assert sum(max_mem_bw[1].values()) == 169915392
assert max_mem_bw[0] == 375516160
assert sum(max_mem_bw[1].values()) == 40934400
3 changes: 2 additions & 1 deletion thunder/tests/test_grad.py
@@ -1943,7 +1943,8 @@ def test_backward_recomputation_decomposed_ops(device):
def fn(a):
return torch.nn.functional.gelu(a)

jfn = thunder.jit(fn, enable_saved_for_backward_recomputation=False)
# rematerialization will also trigger recomputation here.
jfn = thunder.jit(fn, executors=(), enable_saved_for_backward_recomputation=False)
jfn2 = thunder.jit(fn, enable_saved_for_backward_recomputation=True)
a = torch.randn(2, 2, device=device, requires_grad=True)
res = jfn(a)
29 changes: 29 additions & 0 deletions thunder/tests/test_networks.py
@@ -555,3 +555,32 @@ def test_hf_llama():
top_level_symbol_names = {bsym.sym.name for bsym in thunder.last_traces(jm)[-1].bound_symbols}
# change this to fewer as needed; the goal is to not have too many fusions
assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == 7


@requiresCUDA
def test_memory_litgpt_llama3():
from thunder.tests import litgpt_model

def forward_backward_peak(m, inp):
torch.cuda.reset_peak_memory_stats(device=None)
mem_before = torch.cuda.max_memory_allocated()
res = m(inp)
res.sum().backward()
mem_after = torch.cuda.max_memory_allocated()
return (mem_after - mem_before) / 2**20

with torch.device("cuda"):
m = litgpt_model.GPT.from_name("llama2-like").bfloat16()
inp = torch.ones((1, 2048), dtype=torch.int64)

# warmup, allocate grads etc.
forward_backward_peak(m, inp)
forward_backward_peak(m, inp)
jm = thunder.jit(m)
forward_backward_peak(jm, inp)
forward_backward_peak(jm, inp)

mem_thunder = forward_backward_peak(jm, inp)
mem_eager = forward_backward_peak(m, inp)

assert mem_thunder < mem_eager
1 change: 1 addition & 0 deletions thunder/tests/test_torch_compile_executor.py
@@ -38,6 +38,7 @@ def test_torch_compile_litgpt():
# https://github.com/Lightning-AI/lightning-thunder/issues/292 The issue was
# that the CSE pass was not being run correctly on the TorchCompile region.
# Here we test that everything works as expected.
@pytest.mark.skip(reason="https://github.com/NVIDIA/Fuser/issues/3688")
@pytest.mark.skipif(not is_inductor_supported(), reason="inductor unsupported")
@requiresCUDA
@pytest.mark.skipif(not device_supports_bf16(torch.device("cuda")), reason="bf16 is not supported")
