
Commit 4956850

ezyang authored and pobin6 committed
Introduce CompiledAOTI (pytorch#141695)
Stacked on pytorch#141691

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#141695
Approved by: https://github.com/aorenste
ghstack dependencies: pytorch#141681, pytorch#141683, pytorch#141685, pytorch#141688, pytorch#141689, pytorch#141691
1 parent 2d10b81 commit 4956850

4 files changed, +85 -33 lines changed


torch/_dynamo/repro/after_aot.py

+2 -2

@@ -39,6 +39,7 @@
 )
 from torch._dynamo.trace_rules import is_fbcode
 from torch._dynamo.utils import clone_inputs, counters, same
+from torch._inductor.output_code import OutputCode
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import (
     fx_placeholder_targets,
@@ -51,7 +52,6 @@

 if TYPE_CHECKING:
     from torch._inductor.compile_fx import _CompileFxCallable, _CompileFxKwargs
-    from torch._inductor.output_code import CompiledFxGraph
     from torch._inductor.utils import InputType


@@ -83,7 +83,7 @@ def debug_wrapper(
         gm: torch.fx.GraphModule,
         example_inputs: Sequence["InputType"],
         **kwargs: Unpack["_CompileFxKwargs"],
-    ) -> Union["CompiledFxGraph", str]:
+    ) -> OutputCode:
         from torch._subclasses import FakeTensorMode

         compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs)

torch/_inductor/codecache.py

+8 -3

@@ -84,8 +84,8 @@
 if TYPE_CHECKING:
     from collections.abc import KeysView

-    from .compile_fx import _CompileFxKwargs
-    from .output_code import CompiledFxGraph
+    from .compile_fx import _CompileFxKwargs, CompiledFxGraph
+    from .output_code import OutputCode
     from .remote_cache import JsonDataTy, RemoteCache
     from .utils import InputType

@@ -1322,14 +1322,19 @@ def post_compile(
     @staticmethod
     def _save_graph(
         key: str,
-        compiled_graph: CompiledFxGraph,
+        compiled_graph: OutputCode,
         example_inputs: Sequence[InputType],
         local: bool,
         remote_cache: Optional[RemoteCache[JsonDataTy]],
     ) -> None:
         """
         Store a serialized CompiledFxGraph on disk.
         """
+        from .compile_fx import CompiledFxGraph
+
+        assert isinstance(
+            compiled_graph, CompiledFxGraph
+        ), f"serialization for {type(compiled_graph)} NYI"
         disk_compiled_graph = copy(compiled_graph)
         # We can't really serialize callables that may be C++/Triton/etc.,
         # so we serialize their PyCodeCache disk cache location instead.
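For orientation, here is a minimal standalone sketch of the narrowing pattern this codecache.py hunk adopts: the cache entry point accepts any OutputCode, but serialization is only implemented for the FX-graph flavor, so anything else trips an explicit "NYI" assertion rather than failing later during pickling. The names below (CompiledFxGraphLike, save_to_cache) are illustrative stand-ins, not the real codecache API.

# Hypothetical sketch of the narrowing pattern used by _save_graph above;
# SavedGraph-style details are omitted and the names are invented.
from dataclasses import dataclass
from typing import Any, Protocol


class OutputCode(Protocol):
    """Anything the Inductor compiler may hand back."""

    def __call__(self, inputs: Any) -> Any: ...


@dataclass
class CompiledFxGraphLike:
    source_path: str  # stand-in for the PyCodeCache location that gets pickled

    def __call__(self, inputs: Any) -> Any:
        return inputs  # placeholder behavior for the sketch


def save_to_cache(key: str, compiled: OutputCode) -> None:
    # Accept the broad type at the boundary, then narrow before serializing:
    # only the FX-graph flavor knows how to round-trip through the disk cache.
    assert isinstance(
        compiled, CompiledFxGraphLike
    ), f"serialization for {type(compiled)} NYI"
    print(f"would pickle {compiled.source_path!r} under key {key!r}")


save_to_cache("abc123", CompiledFxGraphLike("/tmp/py_code_cache/xyz.py"))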

torch/_inductor/compile_fx.py

+25 -28

@@ -56,9 +56,11 @@
 from torch._inductor.cudagraph_utils import BoxedDeviceIndex, PlaceholderInfo
 from torch._inductor.debug import save_args_for_compile_fx_inner
 from torch._inductor.output_code import (
+    CompiledAOTI,
     CompiledFxGraph,
     get_expanded_dims,
     index_expanded_dims,
+    OutputCode,
 )
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch._inductor.utils import (
@@ -509,15 +511,15 @@ def __call__(
         gm: GraphModule,
         example_inputs: Sequence[InputType],
         **kwargs: Unpack[_CompileFxKwargs],
-    ) -> Union[CompiledFxGraph, str]:
+    ) -> OutputCode:
         ...


 def compile_fx_inner(
     gm: GraphModule,
     example_inputs: Sequence[InputType],
     **kwargs: Unpack[_CompileFxKwargs],
-) -> Union[CompiledFxGraph, str]:
+) -> OutputCode:
     kwargs.setdefault("cudagraphs", None)
     kwargs.setdefault("static_input_idxs", ())
     kwargs.setdefault("is_backward", False)
@@ -570,7 +572,7 @@ def _compile_fx_inner(
     gm: GraphModule,
     example_inputs: Sequence[InputType],
     **graph_kwargs: Unpack[_CompileFxKwargs],
-) -> Union[CompiledFxGraph, str]:
+) -> OutputCode:
     """
     Inductor API that compiles a single graph.

@@ -630,11 +632,7 @@ def _compile_fx_inner(
     ):
         input._is_inductor_static = True  # type: ignore[attr-defined]

-    # TODO: Remove this short circuit once types are unified here
-    if aot_mode:
-        return fx_codegen_and_compile(gm, example_inputs, inputs_to_check, **graph_kwargs)  # type: ignore[assignment]
-
-    mb_compiled_graph: Optional[CompiledFxGraph] = None
+    mb_compiled_graph: Optional[OutputCode] = None
     key_info = None
     cache_info = None
     remote_cache = None
@@ -668,31 +666,28 @@ def _compile_fx_inner(
     # determined the input is uncacheable)
     if cache_info is None or cache_info["cache_state"] == "bypass":
         assert mb_compiled_graph is None
-        r = fx_codegen_and_compile(
+        mb_compiled_graph = fx_codegen_and_compile(
             gm, example_inputs, inputs_to_check, **graph_kwargs
         )
-        assert not isinstance(r, str)  # due to aot test
-        mb_compiled_graph = r

     # CACHE MISS: Compile the graph and save to cache
     elif cache_info["cache_state"] == "miss":
         assert mb_compiled_graph is None
         assert key_info is not None
         TritonBundler.begin_compile()
         try:
-            r = fx_codegen_and_compile(
+            mb_compiled_graph = fx_codegen_and_compile(
                 gm, example_inputs, inputs_to_check, **graph_kwargs
             )
-            assert not isinstance(r, str)  # due to aot test
-            mb_compiled_graph = r
             assert mb_compiled_graph is not None
             mb_compiled_graph._time_taken_ns = time.time_ns() - start_time
             cache_key = key_info[0]
             mb_compiled_graph._fx_graph_cache_key = cache_key
             (
-                mb_compiled_graph._triton_bundle,
+                triton_bundle,
                 triton_bundler_meta,
             ) = TritonBundler.collect()
+            mb_compiled_graph.set_triton_bundle(triton_bundle)
         finally:
             TritonBundler.end_compile()
             if triton_bundler_meta is not None:
@@ -782,7 +777,7 @@ def fx_codegen_and_compile(
     # in explicitly because it's nontrivial to compute
     inputs_to_check: Sequence[int],
     **graph_kwargs: Unpack[_CompileFxKwargs],
-) -> Union[CompiledFxGraph, str]:
+) -> OutputCode:
     # Sorry about the mess, we need graph_kwargs to continue to be able
     # to propagate it further on
     # TODO: _CompileFxKwargs actually has stronger types than in the
@@ -979,6 +974,10 @@ def log_graph_runnable() -> str:

         _check_triton_bf16_support(graph)

+        # TODO: The switching between AOT mode and not here is a bit
+        # messy, but it's localized to the block of code below so I'm
+        # not going to touch it for now
+
         compiled_fn: Any

         with dynamo_timed(
@@ -1058,8 +1057,10 @@ def log_graph_runnable() -> str:
         V.graph.disable_cudagraphs_reason = disable

         if V.aot_compilation is True:
-            return compiled_fn
+            assert isinstance(compiled_fn, (str, list))
+            return CompiledAOTI(compiled_fn)

+        # TODO: Hoist this above V.aot_compilation
         if cudagraphs and not V.graph.disable_cudagraphs_reason:
             from torch._inductor.cudagraph_utils import (
                 check_lowering_disable_cudagraph,
@@ -1069,7 +1070,7 @@ def log_graph_runnable() -> str:
                 check_lowering_disable_cudagraph(V.graph.device_node_mapping)
             )

-        compiled_graph = CompiledFxGraph(
+        return CompiledFxGraph(
             compiled_fn,
             graph,
             gm,
@@ -1085,8 +1086,6 @@ def log_graph_runnable() -> str:
             boxed_forward_device_index,
         )

-        return compiled_graph
-

 def get_input_idxs_to_check(
     inputs: Sequence[InputType],
@@ -1326,11 +1325,9 @@ def compile_fx_aot(
         config_patches=config_patches,
     )

-    assert isinstance(compiled_artifacts, str) or (
-        isinstance(compiled_artifacts, list)
-        and isinstance(compiled_artifacts[0], str)
-    )
-    return compiled_artifacts
+    assert isinstance(compiled_artifacts, CompiledAOTI)
+
+    return compiled_artifacts.filename


 _graph_counter = count(0)
@@ -1487,7 +1484,7 @@ def get_cuda_device_context(gm: torch.fx.GraphModule) -> ContextManager[None]:
 def compile_fx(
     model_: GraphModule,
     example_inputs_: Sequence[InputType],
-    inner_compile: Callable[..., Any] = compile_fx_inner,
+    inner_compile: Callable[..., OutputCode] = compile_fx_inner,
     config_patches: Optional[Dict[str, Any]] = None,
     decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
 ) -> Union[Callable[[List[object]], Sequence[torch.Tensor]], str, List[str]]:
@@ -1631,7 +1628,7 @@ def fw_compiler_base(
         model: GraphModule,
         example_inputs: List[InputType],
         is_inference: bool,
-    ) -> CompiledFxGraph:
+    ) -> OutputCode:
         with dynamo_utils.dynamo_timed("compile_fx.<locals>.fw_compiler_base"):
             if is_inference:
                 # partition_fn won't be called
@@ -1737,7 +1734,7 @@ def partition_fn(
     @compile_time_strobelight_meta(phase_name="backward")
     def bw_compiler(
         model: GraphModule, example_inputs: List[InputType]
-    ) -> Union[CompiledFxGraph, str]:
+    ) -> OutputCode:
         from torch._dynamo.convert_frame import compile_lock

         with dynamo_utils.dynamo_timed(
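To make the compile_fx_aot change above concrete, here is a hedged sketch of what the caller-side contract becomes: instead of asserting that the inner compiler returned a str (or list of str), the AOT path now asserts it received a CompiledAOTI and unwraps its filename. compile_inner below is an invented stub standing in for the real Inductor entry point; only CompiledAOTI mirrors the diff.

# Illustrative only: compile_inner is a stand-in for the inner Inductor
# compiler; CompiledAOTI mirrors the dataclass added in output_code.py.
import dataclasses
from typing import List, Union


@dataclasses.dataclass
class CompiledAOTI:
    filename: Union[str, List[str]]


def compile_inner() -> CompiledAOTI:
    # Pretend AOTInductor produced a shared library at this (made-up) path.
    return CompiledAOTI("/tmp/model.so")


def compile_aot() -> Union[str, List[str]]:
    compiled_artifacts = compile_inner()
    # Before this commit the assert inspected raw str/list values; now the
    # wrapper type carries that invariant and the caller simply unwraps it.
    assert isinstance(compiled_artifacts, CompiledAOTI)
    return compiled_artifacts.filename


print(compile_aot())  # -> /tmp/model.so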

torch/_inductor/output_code.py

+50

@@ -35,6 +35,7 @@
     Set,
     Tuple,
     TYPE_CHECKING,
+    Union,
 )
 from typing_extensions import TypeAlias

@@ -73,6 +74,21 @@ def post_compile(
     ) -> None:
         ...

+    # TODO: Not sure if I really want these to be properties, this is easy
+    # though
+    #
+    # TODO: Remove leading underscores
+
+    # None if the output is not remote cacheable
+    _fx_graph_cache_key: Optional[str]
+
+    # How long it took to compile this OutputCode, end to end
+    _time_taken_ns: Optional[int]
+
+    # TODO: Get rid of this
+    def set_triton_bundle(self, triton_bundle: Any) -> None:
+        ...
+

 _StrideExprStr: TypeAlias = str

@@ -300,6 +316,9 @@ def post_compile(
         # TODO: Not sure why this isn't just set by default on CompiledFxGraph
         self._boxed_call = True

+    def set_triton_bundle(self, triton_bundle: Any) -> None:
+        self._triton_bundle = triton_bundle
+
     def get_constants(
         self, gm: Optional[torch.fx.GraphModule]
     ) -> Dict[str, torch.Tensor]:
@@ -323,3 +342,34 @@

 def _typecheck_CompiledFxGraph(h: CompiledFxGraph) -> OutputCode:
     return h
+
+
+@dataclasses.dataclass
+class CompiledAOTI:
+    """
+    Class holding an AOTInductor compiled so.
+    """
+
+    filename: Union[str, List[str]]
+
+    # TODO: Figure out if these make sense or not here
+    _fx_graph_cache_key: Optional[str] = None
+    _time_taken_ns: Optional[int] = None
+
+    def __call__(self, inputs: Sequence[Any]) -> Any:
+        raise NotImplementedError("NYI")
+
+    def post_compile(
+        self,
+        example_inputs: Sequence[InputType],
+        cudagraphs: BoxedBool,
+        gm: GraphModule,
+    ) -> None:
+        pass
+
+    def set_triton_bundle(self, triton_bundle: Any) -> None:
+        pass
+
+
+def _typecheck_CompiledAOTI(h: CompiledAOTI) -> OutputCode:
+    return h
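Putting the new pieces together, the sketch below shows the shape of the OutputCode protocol and the CompiledAOTI implementation this file adds. It is a simplified illustration: the real protocol in torch/_inductor/output_code.py carries more members and PyTorch-specific types (InputType, BoxedBool, GraphModule), which are replaced with Any here.

# Simplified, illustrative sketch of the OutputCode / CompiledAOTI relationship;
# PyTorch-specific types from the real file are replaced with Any.
import dataclasses
from typing import Any, List, Optional, Protocol, Sequence, Union


class OutputCode(Protocol):
    # None if the output is not remote cacheable
    _fx_graph_cache_key: Optional[str]
    # How long it took to compile this OutputCode, end to end
    _time_taken_ns: Optional[int]

    def __call__(self, inputs: Sequence[Any]) -> Any: ...

    def post_compile(
        self, example_inputs: Sequence[Any], cudagraphs: Any, gm: Any
    ) -> None: ...

    def set_triton_bundle(self, triton_bundle: Any) -> None: ...


@dataclasses.dataclass
class CompiledAOTI:
    """Holds the path (or paths) of an AOTInductor-compiled shared library."""

    filename: Union[str, List[str]]
    _fx_graph_cache_key: Optional[str] = None
    _time_taken_ns: Optional[int] = None

    def __call__(self, inputs: Sequence[Any]) -> Any:
        # The AOTI artifact is loaded and run elsewhere; calling it directly
        # is not yet implemented, mirroring the diff above.
        raise NotImplementedError("NYI")

    def post_compile(
        self, example_inputs: Sequence[Any], cudagraphs: Any, gm: Any
    ) -> None:
        pass  # nothing to patch up after compilation for the AOT path

    def set_triton_bundle(self, triton_bundle: Any) -> None:
        pass  # AOTI output does not carry a Triton bundle


def _typecheck_CompiledAOTI(h: CompiledAOTI) -> OutputCode:
    # Static assertion that CompiledAOTI satisfies the protocol.
    return h

Keeping a single OutputCode return type lets the caching and cudagraph machinery in compile_fx.py treat cache hits, fresh compiles, and AOT compiles uniformly, with unsupported operations surfacing as explicit NYI assertions instead of Union type checks scattered across call sites.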
