diff --git a/thunder/core/prims.py b/thunder/core/prims.py
index 790be71cb8..ac028a0ebb 100644
--- a/thunder/core/prims.py
+++ b/thunder/core/prims.py
@@ -293,6 +293,8 @@ class OpTags(Enum):
     CTX_MANAGER_ENTER_EXIT_OP = auto()
     # Label to explicitly disable an operation from recomputing in backward - see function `recompute_saved_for_backward`.
     DONT_RECOMPUTE_IN_BACKWARD = auto()
+    # Don't automatically tag operation to be recomputed in backward
+    DONT_AUTO_RECOMPUTE_IN_BACKWARD = auto()
 
 
 # TODO RC1 Document this function and describe the parts of a primitive
@@ -3780,7 +3782,9 @@ def linear_meta(a: TensorProxy, w: TensorProxy, bias: None | TensorProxy) -> Ten
     return TensorProxy(shape=out_shape, device=a.device, dtype=dtype, requires_grad=requires_grad)
 
 
-linear = make_prim(PrimIDs.LINEAR, "linear", meta=linear_meta, tags=(OpTags.MATMUL_OP,))
+linear = make_prim(
+    PrimIDs.LINEAR, "linear", meta=linear_meta, tags=(OpTags.MATMUL_OP, OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD)
+)
 
 
 def matmul_meta(a: TensorProxy, b: TensorProxy, /) -> TensorProxy:
@@ -3839,7 +3843,9 @@ def matmul_meta(a: TensorProxy, b: TensorProxy, /) -> TensorProxy:
     return TensorProxy(like=a, shape=shape)
 
 
-matmul = make_prim(PrimIDs.MATMUL, "matmul", meta=matmul_meta, tags=(OpTags.MATMUL_OP,))
+matmul = make_prim(
+    PrimIDs.MATMUL, "matmul", meta=matmul_meta, tags=(OpTags.MATMUL_OP, OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD)
+)
 
 #
 # NN prims
diff --git a/thunder/core/trace_interpreter.py b/thunder/core/trace_interpreter.py
index 491ba49233..4b795868d0 100644
--- a/thunder/core/trace_interpreter.py
+++ b/thunder/core/trace_interpreter.py
@@ -8,6 +8,7 @@
 from thunder.core.utils import safe_map_flat, sequencify
 from thunder.core.proxies import variableify, ProxyTag
 from thunder.core.transform_common import VJPDual
+from thunder.core.symbol import has_tags
 
 
 # TODO: Currently we use trace.args and trace.kwargs to get the arguments
@@ -183,12 +184,20 @@ def do_swap(v):
 
         for new_bsym in new_bsyms:
             # TODO: what to do with bsym header? Maybe have a combined from_bsym_swap_proxies and from_bsym?
-            for o in new_bsym.flat_proxy_outs:
-                if variableify(o) not in swap_map:
-                    # when we decompose to compute the forward/backward, we mark intermediates as to be recomputed in the backward.
-                    # Typically our decompositions are for things that will then be fused together.
-                    # We could refine this heuristic to exclude "expensive" operations.
-                    o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)
+            if not has_tags(
+                bsym,
+                {
+                    prims.OpTags.RANDOM_OP,
+                    prims.OpTags.DONT_RECOMPUTE_IN_BACKWARD,
+                    prims.OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD,
+                },
+            ):
+                # when we decompose to compute the forward/backward, we mark intermediates as to be recomputed in the backward.
+                # Typically our decompositions are for things that will then be fused together.
+                # We tag "expensive" operations (sdpa, matmul) as DONT_AUTO_RECOMPUTE_IN_BACKWARD to block this.
+                for o in new_bsym.flat_proxy_outs:
+                    if variableify(o) not in swap_map:
+                        o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)
             new_trace.bound_symbols.append(
                 new_bsym.from_bsym_swap_proxies(swap_map).from_bsym(
                     source_filename=bsym.source_filename, source_positions=bsym.source_positions
diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py
index 409c167284..e7f9f4505d 100644
--- a/thunder/executors/cudnnex.py
+++ b/thunder/executors/cudnnex.py
@@ -79,6 +79,7 @@ def _get_cudnn_handle(query_device):
 from thunder.torch import TensorLike
 from thunder.core.compile_data import get_compile_option
 from thunder.core.proxies import Proxy, TensorProxy
+from thunder.core.prims import OpTags
 
 
 from thunder.core.transforms import (
@@ -425,6 +426,7 @@ def _cudnn_sdpa_checker(
     "cudnn_sdpa_fwd",
     meta=_cudnn_sdpa_forward_meta,
     fn=_cudnn_sdpa_fwd_impl,
+    tags=(OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD,),
 )
 
 
diff --git a/thunder/executors/sdpaex.py b/thunder/executors/sdpaex.py
index 3fc6eab9e9..85b5145d30 100644
--- a/thunder/executors/sdpaex.py
+++ b/thunder/executors/sdpaex.py
@@ -7,6 +7,7 @@
 from thunder.core.proxies import Proxy, TensorProxy
 import thunder.core.utils as utils
 import thunder.core.devices as devices
+from thunder.core.prims import OpTags
 
 import thunder.torch as ltorch
 from thunder.torch import TensorLike
@@ -171,6 +172,7 @@ def _grad_forward_scaled_dot_product_efficient_attention_impl(
     "sdpaex_grad_forward_scaled_dot_product_efficient_attention",
     meta=_grad_forward_scaled_dot_product_efficient_attention_meta,
     fn=_grad_forward_scaled_dot_product_efficient_attention_impl,
+    tags=(OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD,),
 )
 
 
@@ -244,6 +246,7 @@ def _grad_forward_scaled_dot_product_flash_attention_impl(
     "sdpafx_grad_forward_scaled_dot_product_efficient_attention",
     meta=_grad_forward_scaled_dot_product_flash_attention_meta,
     fn=_grad_forward_scaled_dot_product_flash_attention_impl,
+    tags=(OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD,),
 )
 
 
diff --git a/thunder/tests/test_examine_memory.py b/thunder/tests/test_examine_memory.py
index 4a7b40d194..2f89f834ae 100644
--- a/thunder/tests/test_examine_memory.py
+++ b/thunder/tests/test_examine_memory.py
@@ -113,7 +113,7 @@ def test_nanogpt_block():
 
     # Actual memory usage may vary depending on hardware and cuBLAS settings.
     # We are checking the estimated memory against a fixed value for consistency.
-    assert max_mem_fw[0] == 262183936
-    assert sum(max_mem_fw[1].values()) == 135306240
-    assert max_mem_bw[0] == 375516160
+    assert max_mem_fw[0] == 293641216
+    assert sum(max_mem_fw[1].values()) == 249601024
+    assert max_mem_bw[0] == 399633408
     assert sum(max_mem_bw[1].values()) == 40934400
diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py
index daea2dbff6..265c46a4b0 100644
--- a/thunder/tests/test_networks.py
+++ b/thunder/tests/test_networks.py
@@ -581,4 +581,13 @@ def forward_backward_peak(m, inp):
     mem_thunder = forward_backward_peak(jm, inp)
     mem_eager = forward_backward_peak(m, inp)
 
+    # assert that attention is not automatically recomputed, see
+    # https://github.com/Lightning-AI/lightning-thunder/issues/1646
+    assert not {
+        bsym.sym.name
+        for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols
+        if ("attention" in bsym.sym.name or "sdpa" in bsym.sym.name)
+        and ("forward" in bsym.sym.name or "fwd" in bsym.sym.name)
+    }
+
     assert mem_thunder < mem_eager
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 0d10f03563..bb661df2ab 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -5206,7 +5206,7 @@ def mse_loss(
 
 # TODO Add annotations
 # NOTE The scale parameter is kwarg-only in PyTorch
-@torchsymbol(torch.nn.functional.scaled_dot_product_attention)
+@torchsymbol(torch.nn.functional.scaled_dot_product_attention, tags=(prims.OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD,))
 def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *, scale=None):
     for arg_name, arg in zip(("query", "key", "value"), (query, key, value)):
         utils.check(
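
Not part of the diff: a minimal sketch of how the new guard can be inspected from user code. It reuses the same blocked-tag set as the trace_interpreter.py hunk above; the traced function `f`, its shapes, and the printout are made up for illustration, and whether a given bound symbol reports the tag depends on how `has_tags` resolves subsymbols in the inspected trace.

    # Hypothetical illustration: list which bound symbols of a traced function
    # remain eligible for automatic recompute-in-backward tagging under the new guard.
    import torch
    import thunder
    from thunder.core import prims
    from thunder.core.symbol import has_tags


    def f(x, w):
        return torch.nn.functional.relu(x @ w)


    jf = thunder.jit(f)
    jf(torch.randn(4, 8, requires_grad=True), torch.randn(8, 8, requires_grad=True))

    # Same blocked set as in trace_interpreter.py after this change.
    blocked = {
        prims.OpTags.RANDOM_OP,
        prims.OpTags.DONT_RECOMPUTE_IN_BACKWARD,
        prims.OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD,
    }
    for bsym in thunder.last_traces(jf)[-1].bound_symbols:
        # Symbols carrying any blocked tag (e.g. matmul after this change) are
        # skipped by the auto-recompute heuristic.
        print(bsym.sym.name, "auto-recompute eligible:", not has_tags(bsym, blocked))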