
don't auto-recompute attention or linear #1648

Merged: 6 commits, Jan 16, 2025
10 changes: 8 additions & 2 deletions thunder/core/prims.py
@@ -293,6 +293,8 @@ class OpTags(Enum):
CTX_MANAGER_ENTER_EXIT_OP = auto()
# Label to explicitly disable an operation from recomputing in backward - see function `recompute_saved_for_backward`.
DONT_RECOMPUTE_IN_BACKWARD = auto()
+# Don't automatically tag operation to be recomputed in backward
+DONT_AUTO_RECOMPUTE_IN_BACKWARD = auto()


# TODO RC1 Document this function and describe the parts of a primitive
@@ -3780,7 +3782,9 @@ def linear_meta(a: TensorProxy, w: TensorProxy, bias: None | TensorProxy) -> Ten
return TensorProxy(shape=out_shape, device=a.device, dtype=dtype, requires_grad=requires_grad)


-linear = make_prim(PrimIDs.LINEAR, "linear", meta=linear_meta, tags=(OpTags.MATMUL_OP,))
+linear = make_prim(
+    PrimIDs.LINEAR, "linear", meta=linear_meta, tags=(OpTags.MATMUL_OP, OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD)
+)


def matmul_meta(a: TensorProxy, b: TensorProxy, /) -> TensorProxy:
@@ -3839,7 +3843,9 @@ def matmul_meta(a: TensorProxy, b: TensorProxy, /) -> TensorProxy:
return TensorProxy(like=a, shape=shape)


-matmul = make_prim(PrimIDs.MATMUL, "matmul", meta=matmul_meta, tags=(OpTags.MATMUL_OP,))
+matmul = make_prim(
+    PrimIDs.MATMUL, "matmul", meta=matmul_meta, tags=(OpTags.MATMUL_OP, OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD)
+)

#
# NN prims
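With both prims tagged, anything that walks a trace can detect the opt-out without special-casing op names. A minimal sketch of checking the tag, assuming the Symbol objects returned by `make_prim` expose the tags they were created with as a `.tags` attribute (only `make_prim(..., tags=...)` is shown in the hunks above):

```python
from thunder.core.prims import linear, matmul, OpTags

# Both prims should now carry the opt-out tag in addition to MATMUL_OP.
# NOTE: the `.tags` attribute name is an assumption of this sketch.
for sym in (linear, matmul):
    has_opt_out = OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD in (sym.tags or ())
    print(f"{sym.name}: DONT_AUTO_RECOMPUTE_IN_BACKWARD={has_opt_out}")
```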
21 changes: 15 additions & 6 deletions thunder/core/trace_interpreter.py
@@ -8,6 +8,7 @@
from thunder.core.utils import safe_map_flat, sequencify
from thunder.core.proxies import variableify, ProxyTag
from thunder.core.transform_common import VJPDual
+from thunder.core.symbol import has_tags


# TODO: Currently we use trace.args and trace.kwargs to get the arguments
@@ -183,12 +184,20 @@ def do_swap(v):

for new_bsym in new_bsyms:
# TODO: what to do with bsym header? Maybe have a combined from_bsym_swap_proxies and from_bsym?
-for o in new_bsym.flat_proxy_outs:
-    if variableify(o) not in swap_map:
-        # when we decompose to compute the forward/backward, we mark intermediates as to be recomputed in the backward.
-        # Typically our decompositions are for things that will then be fused together.
-        # We could refine this heuristic to exclude "expensive" operations.
-        o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)
+if not has_tags(
+    bsym,
+    {
+        prims.OpTags.RANDOM_OP,
+        prims.OpTags.DONT_RECOMPUTE_IN_BACKWARD,
+        prims.OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD,
+    },
+):
+    # when we decompose to compute the forward/backward, we mark intermediates as to be recomputed in the backward.
+    # Typically our decompositions are for things that will then be fused together.
+    # We tag "expensive" operations (sdpa, matmul) as DONT_AUTO_RECOMPUTE_IN_BACKWARD to block this.
+    for o in new_bsym.flat_proxy_outs:
+        if variableify(o) not in swap_map:
+            o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)
new_trace.bound_symbols.append(
new_bsym.from_bsym_swap_proxies(swap_map).from_bsym(
source_filename=bsym.source_filename, source_positions=bsym.source_positions
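The change above narrows the existing heuristic: while forward/backward decompositions are interpreted, intermediate outputs get tagged RECOMPUTE_IN_BACKWARD so they need not be saved, except for symbols that are random or explicitly opted out. A minimal sketch isolating that decision (illustrative restatement only, not a public API):

```python
from thunder.core import prims
from thunder.core.symbol import has_tags

# Tags that block the automatic RECOMPUTE_IN_BACKWARD marking of a bound symbol's outputs.
_NO_AUTO_RECOMPUTE_TAGS = {
    prims.OpTags.RANDOM_OP,                        # recomputation would redraw random numbers
    prims.OpTags.DONT_RECOMPUTE_IN_BACKWARD,       # explicit, hard opt-out
    prims.OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD,  # new: opt out of the heuristic only
}


def may_auto_recompute(bsym) -> bool:
    """Mirror of the guard added in trace_interpreter.py: True if outputs may be auto-tagged."""
    return not has_tags(bsym, _NO_AUTO_RECOMPUTE_TAGS)
```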
2 changes: 2 additions & 0 deletions thunder/executors/cudnnex.py
@@ -79,6 +79,7 @@ def _get_cudnn_handle(query_device):
from thunder.torch import TensorLike
from thunder.core.compile_data import get_compile_option
from thunder.core.proxies import Proxy, TensorProxy
+from thunder.core.prims import OpTags


from thunder.core.transforms import (
@@ -425,6 +426,7 @@ def _cudnn_sdpa_checker(
"cudnn_sdpa_fwd",
meta=_cudnn_sdpa_forward_meta,
fn=_cudnn_sdpa_fwd_impl,
+tags=(OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD,),
)


3 changes: 3 additions & 0 deletions thunder/executors/sdpaex.py
@@ -7,6 +7,7 @@
from thunder.core.proxies import Proxy, TensorProxy
import thunder.core.utils as utils
import thunder.core.devices as devices
+from thunder.core.prims import OpTags

import thunder.torch as ltorch
from thunder.torch import TensorLike
@@ -171,6 +172,7 @@ def _grad_forward_scaled_dot_product_efficient_attention_impl(
"sdpaex_grad_forward_scaled_dot_product_efficient_attention",
meta=_grad_forward_scaled_dot_product_efficient_attention_meta,
fn=_grad_forward_scaled_dot_product_efficient_attention_impl,
+tags=(OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD,),
)


@@ -244,6 +246,7 @@ def _grad_forward_scaled_dot_product_flash_attention_impl(
"sdpafx_grad_forward_scaled_dot_product_efficient_attention",
meta=_grad_forward_scaled_dot_product_flash_attention_meta,
fn=_grad_forward_scaled_dot_product_flash_attention_impl,
+tags=(OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD,),
)


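Executor-provided attention forwards get the same treatment at registration time. A hedged sketch of what that pattern looks like for a custom executor, assuming `register_operator` forwards `tags=` to the Symbol it creates (as the cudnnex.py and sdpaex.py hunks above imply); the executor name and the meta/impl functions below are hypothetical placeholders:

```python
from thunder.core.prims import OpTags
from thunder.core.proxies import TensorProxy
from thunder.extend import OperatorExecutor, register_executor

my_ex = OperatorExecutor("my_attention_ex", version="0.1")
register_executor(my_ex)


def my_attention_fwd_meta(q: TensorProxy, k: TensorProxy, v: TensorProxy) -> TensorProxy:
    # Placeholder shape rule: output matches the query's shape/dtype/device.
    return TensorProxy(like=q)


def my_attention_fwd_impl(q, k, v):
    import torch

    return torch.nn.functional.scaled_dot_product_attention(q, k, v)


my_attention_fwd = my_ex.register_operator(
    "my_attention_fwd",
    meta=my_attention_fwd_meta,
    fn=my_attention_fwd_impl,
    # Assumed to be forwarded to the created Symbol, matching the registrations above.
    tags=(OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD,),
)
```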
6 changes: 3 additions & 3 deletions thunder/tests/test_examine_memory.py
@@ -113,7 +113,7 @@ def test_nanogpt_block():

# Actual memory usage may vary depending on hardware and cuBLAS settings.
# We are checking the estimated memory against a fixed value for consistency.
-assert max_mem_fw[0] == 262183936
-assert sum(max_mem_fw[1].values()) == 135306240
-assert max_mem_bw[0] == 375516160
+assert max_mem_fw[0] == 293641216
+assert sum(max_mem_fw[1].values()) == 249601024
+assert max_mem_bw[0] == 399633408
assert sum(max_mem_bw[1].values()) == 40934400
9 changes: 9 additions & 0 deletions thunder/tests/test_networks.py
@@ -581,4 +581,13 @@ def forward_backward_peak(m, inp):
mem_thunder = forward_backward_peak(jm, inp)
mem_eager = forward_backward_peak(m, inp)

+# assert that attention is not automatically recomputed, see
+# https://github.com/Lightning-AI/lightning-thunder/issues/1646
+assert not {
+    bsym.sym.name
+    for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols
+    if ("attention" in bsym.sym.name or "sdpa" in bsym.sym.name)
+    and ("forward" in bsym.sym.name or "fwd" in bsym.sym.name)
+}

assert mem_thunder < mem_eager
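The new test checks the backward trace by symbol name; the same inspection works on any jitted module if you want to confirm attention is not being recomputed. A minimal sketch (the module, shapes, and name-based filter are illustrative, mirroring the test above):

```python
import torch
import torch.nn.functional as F
import thunder


class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)


jm = thunder.jit(Attn())
q, k, v = (torch.randn(1, 8, 128, 64, requires_grad=True) for _ in range(3))
jm(q, k, v).sum().backward()

# Collect any attention forward symbols that ended up in the backward trace.
recomputed_attention = {
    bsym.sym.name
    for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols
    if ("attention" in bsym.sym.name or "sdpa" in bsym.sym.name)
    and ("forward" in bsym.sym.name or "fwd" in bsym.sym.name)
}
print(recomputed_attention)  # expected to be empty after this change
```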
2 changes: 1 addition & 1 deletion thunder/torch/__init__.py
@@ -5206,7 +5206,7 @@ def mse_loss(

# TODO Add annotations
# NOTE The scale parameter is kwarg-only in PyTorch
-@torchsymbol(torch.nn.functional.scaled_dot_product_attention)
+@torchsymbol(torch.nn.functional.scaled_dot_product_attention, tags=(prims.OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD,))
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *, scale=None):
for arg_name, arg in zip(("query", "key", "value"), (query, key, value)):
utils.check(
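At the thunder.torch level the same opt-out rides along via the torchsymbol decorator. A quick, hedged way to confirm the tag made it onto the symbol, assuming the decorated function is a Symbol that exposes its tags as `.tags`:

```python
import thunder.torch as ltorch
from thunder.core import prims

# `.tags` is assumed here to hold the tags passed to @torchsymbol above.
sym = ltorch.scaled_dot_product_attention
print(prims.OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD in (sym.tags or ()))  # expected: True
```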