cudagraph.py (forked from facebookresearch/chameleon)
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import functools
from typing import Any, Callable, TypeVar

import torch

T = TypeVar("T")

FN = Callable[..., T]  # type: ignore


class CUDAGraphWrapper:
    """Call `fn` eagerly for the first `warmup_iter` calls, then capture one
    call into a CUDA graph and replay that graph on every call after."""

    def __init__(
        self,
        fn: FN[T],
        warmup_iter: int = 1,
        debug_dump_path: str | None = None,
    ):
        self.fn = fn
        self.warmup_iter = warmup_iter
        self.debug_dump_path = debug_dump_path
        self.graph: torch.cuda.CUDAGraph | None = None
        self.result: T | None = None

    def __call__(self, *args, **kwargs) -> Any:  # type: ignore
        # Run eagerly during warmup: capture is only safe after the kernels
        # and allocator state have been exercised at least once.
        if self.warmup_iter > 0:
            self.warmup_iter -= 1
            return self.fn(*args, **kwargs)

        if self.graph is None:
            self.graph = torch.cuda.CUDAGraph()
            if self.debug_dump_path is not None:
                self.graph.enable_debug_mode()
            recording_kwargs = {}
            if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
                # In PyTorch 2.1+ and nightlies from late Aug 2023,
                # we can do this to maybe avoid watchdog-related crashes
                recording_kwargs["capture_error_mode"] = "thread_local"
            with torch.cuda.graph(self.graph, **recording_kwargs):
                self.result = self.fn(*args, **kwargs)
            torch.cuda.synchronize()
            if self.debug_dump_path is not None:
                self.graph.debug_dump(self.debug_dump_path)

        assert self.graph is not None
        self.graph.replay()
        return self.result
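

# A minimal usage sketch, not part of the original file: `model`, `loader`,
# and `static_x` are hypothetical names. Graph replay reuses the device
# memory captured during recording, so pass the same tensors on every call
# and refresh their contents in place:
#
#     static_x = torch.randn(8, 64, device="cuda")
#     wrapped = CUDAGraphWrapper(lambda: model(static_x), warmup_iter=2)
#     for batch in loader:
#         static_x.copy_(batch)  # update the captured input buffer in place
#         out = wrapped()        # eager during warmup, graph replay after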


def cudagraph_wrap(
    *args,
    warmup_iter: int = 1,
    debug_dump_path: str | None = None,
) -> Callable[[FN[T]], FN[T]]:
    def wrapper(fn: FN[T]) -> FN[T]:
        graph_wrapper = CUDAGraphWrapper(
            fn, warmup_iter=warmup_iter, debug_dump_path=debug_dump_path
        )

        @functools.wraps(fn)
        def call_wrapper(*inner_args, **inner_kwargs):
            return graph_wrapper(*inner_args, **inner_kwargs)

        return call_wrapper

    # Bare decorator or direct call:
    #
    # @cudagraph_wrap
    # def fn(...):
    #     ...
    #
    # - or -
    #
    # fast_fn = cudagraph_wrap(slow_fn, warmup_iter=2)
    if len(args) == 1 and callable(args[0]):
        return wrapper(args[0])

    # Decorator with arguments:
    #
    # @cudagraph_wrap(warmup_iter=3)
    # def fn(...):
    #     ...
    def decorator(fn: FN[T]) -> FN[T]:
        return wrapper(fn)

    return decorator
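

# A hedged, self-contained demo, not part of the original file: `MODEL_DIM`,
# `weight`, `static_input`, and `step` are made-up names and the shapes are
# arbitrary. It runs only where a CUDA device is available.
if __name__ == "__main__":
    if torch.cuda.is_available():
        MODEL_DIM = 64
        weight = torch.randn(MODEL_DIM, MODEL_DIM, device="cuda")
        static_input = torch.randn(8, MODEL_DIM, device="cuda")

        @cudagraph_wrap(warmup_iter=2)
        def step() -> torch.Tensor:
            # Reads `static_input` from a fixed address, as graph replay requires.
            return static_input @ weight

        for _ in range(5):
            static_input.normal_()  # refresh the captured buffer in place
            out = step()            # two eager warmup calls, then graph replay
        torch.cuda.synchronize()
        print(out.shape)  # torch.Size([8, 64])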