 from torch._inductor.cudagraph_utils import BoxedDeviceIndex, PlaceholderInfo
 from torch._inductor.debug import save_args_for_compile_fx_inner
 from torch._inductor.output_code import (
+    CompiledAOTI,
     CompiledFxGraph,
     get_expanded_dims,
     index_expanded_dims,
+    OutputCode,
 )
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch._inductor.utils import (
@@ -509,15 +511,15 @@ def __call__(
         gm: GraphModule,
         example_inputs: Sequence[InputType],
         **kwargs: Unpack[_CompileFxKwargs],
-    ) -> Union[CompiledFxGraph, str]:
+    ) -> OutputCode:
         ...


 def compile_fx_inner(
     gm: GraphModule,
     example_inputs: Sequence[InputType],
     **kwargs: Unpack[_CompileFxKwargs],
-) -> Union[CompiledFxGraph, str]:
+) -> OutputCode:
     kwargs.setdefault("cudagraphs", None)
     kwargs.setdefault("static_input_idxs", ())
     kwargs.setdefault("is_backward", False)
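The hunks above collapse the old `Union[CompiledFxGraph, str]` return type into the single `OutputCode` type imported from `torch._inductor.output_code`. The sketch below is a hypothetical reconstruction of the surface this diff relies on, inferred only from how the patch uses the types (the `set_triton_bundle` call, the cache-metadata attributes, and the `filename` payload unwrapped later in `compile_fx_aot`); it is not the real definition in `output_code.py`, which differs in detail.

```python
# Hypothetical sketch, NOT the real torch/_inductor/output_code.py:
# only the surface this diff touches is modeled here.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence, Union


class OutputCode(ABC):
    """Common result type that compile_fx_inner now returns."""

    # Cache bookkeeping written by _compile_fx_inner on a cache miss.
    _time_taken_ns: Optional[int] = None
    _fx_graph_cache_key: Optional[str] = None

    @abstractmethod
    def __call__(self, inputs: Sequence[Any]) -> Any:
        """Run the compiled artifact on the given inputs."""

    def set_triton_bundle(self, triton_bundle: Any) -> None:
        """Attach collected Triton kernels; this default quietly ignores them."""


@dataclass
class CompiledAOTI(OutputCode):
    """AOT Inductor output: the path (or paths) of the compiled artifact."""

    filename: Union[str, List[str]]

    def __call__(self, inputs: Sequence[Any]) -> Any:
        raise NotImplementedError("AOTI artifacts are loaded by a separate runtime")
```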
@@ -570,7 +572,7 @@ def _compile_fx_inner(
     gm: GraphModule,
     example_inputs: Sequence[InputType],
     **graph_kwargs: Unpack[_CompileFxKwargs],
-) -> Union[CompiledFxGraph, str]:
+) -> OutputCode:
     """
     Inductor API that compiles a single graph.

@@ -630,11 +632,7 @@ def _compile_fx_inner(
         ):
             input._is_inductor_static = True  # type: ignore[attr-defined]

-    # TODO: Remove this short circuit once types are unified here
-    if aot_mode:
-        return fx_codegen_and_compile(gm, example_inputs, inputs_to_check, **graph_kwargs)  # type: ignore[assignment]
-
-    mb_compiled_graph: Optional[CompiledFxGraph] = None
+    mb_compiled_graph: Optional[OutputCode] = None
     key_info = None
     cache_info = None
     remote_cache = None
@@ -668,31 +666,28 @@ def _compile_fx_inner(
     # determined the input is uncacheable)
     if cache_info is None or cache_info["cache_state"] == "bypass":
         assert mb_compiled_graph is None
-        r = fx_codegen_and_compile(
+        mb_compiled_graph = fx_codegen_and_compile(
             gm, example_inputs, inputs_to_check, **graph_kwargs
         )
-        assert not isinstance(r, str)  # due to aot test
-        mb_compiled_graph = r

     # CACHE MISS: Compile the graph and save to cache
     elif cache_info["cache_state"] == "miss":
         assert mb_compiled_graph is None
         assert key_info is not None
         TritonBundler.begin_compile()
         try:
-            r = fx_codegen_and_compile(
+            mb_compiled_graph = fx_codegen_and_compile(
                 gm, example_inputs, inputs_to_check, **graph_kwargs
             )
-            assert not isinstance(r, str)  # due to aot test
-            mb_compiled_graph = r
             assert mb_compiled_graph is not None
             mb_compiled_graph._time_taken_ns = time.time_ns() - start_time
             cache_key = key_info[0]
             mb_compiled_graph._fx_graph_cache_key = cache_key
             (
-                mb_compiled_graph._triton_bundle,
+                triton_bundle,
                 triton_bundler_meta,
             ) = TritonBundler.collect()
+            mb_compiled_graph.set_triton_bundle(triton_bundle)
         finally:
             TritonBundler.end_compile()
         if triton_bundler_meta is not None:
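One behavioral detail in the hunk above: the cache-miss path used to poke the private `_triton_bundle` attribute of a concrete `CompiledFxGraph`; with the wider `OutputCode` type it now goes through `set_triton_bundle()`, so `_compile_fx_inner` no longer needs to know which output class it is holding. A hedged illustration built on the hypothetical sketch earlier (the `Fake*` name and `attach_bundle` helper are invented for the example):

```python
# Illustration only; reuses the hypothetical OutputCode / CompiledAOTI sketch above.
from typing import Any, Optional, Sequence


class FakeCompiledFxGraph(OutputCode):  # invented stand-in for CompiledFxGraph
    def __init__(self) -> None:
        self._triton_bundle: Optional[Any] = None

    def __call__(self, inputs: Sequence[Any]) -> Any:
        return list(inputs)  # placeholder for invoking the real compiled callable

    def set_triton_bundle(self, triton_bundle: Any) -> None:
        # A concrete graph keeps the bundle so the FX graph cache can persist it.
        self._triton_bundle = triton_bundle


def attach_bundle(out: OutputCode, bundle: Any) -> None:
    # The cache-miss path stays generic over OutputCode: no isinstance narrowing.
    out.set_triton_bundle(bundle)


attach_bundle(FakeCompiledFxGraph(), ["<collected triton kernels>"])
attach_bundle(CompiledAOTI("model.so"), ["<collected triton kernels>"])  # no-op by default
```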
@@ -782,7 +777,7 @@ def fx_codegen_and_compile(
     # in explicitly because it's nontrivial to compute
     inputs_to_check: Sequence[int],
     **graph_kwargs: Unpack[_CompileFxKwargs],
-) -> Union[CompiledFxGraph, str]:
+) -> OutputCode:
     # Sorry about the mess, we need graph_kwargs to continue to be able
     # to propagate it further on
     # TODO: _CompileFxKwargs actually has stronger types than in the
@@ -979,6 +974,10 @@ def log_graph_runnable() -> str:

         _check_triton_bf16_support(graph)

+        # TODO: The switching between AOT mode and not here is a bit
+        # messy, but it's localized to the block of code below so I'm
+        # not going to touch it for now
+
         compiled_fn: Any

         with dynamo_timed(
@@ -1058,8 +1057,10 @@ def log_graph_runnable() -> str:
             V.graph.disable_cudagraphs_reason = disable

         if V.aot_compilation is True:
-            return compiled_fn
+            assert isinstance(compiled_fn, (str, list))
+            return CompiledAOTI(compiled_fn)

+        # TODO: Hoist this above V.aot_compilation
         if cudagraphs and not V.graph.disable_cudagraphs_reason:
             from torch._inductor.cudagraph_utils import (
                 check_lowering_disable_cudagraph,
@@ -1069,7 +1070,7 @@ def log_graph_runnable() -> str:
                 check_lowering_disable_cudagraph(V.graph.device_node_mapping)
             )

-        compiled_graph = CompiledFxGraph(
+        return CompiledFxGraph(
            compiled_fn,
            graph,
            gm,
@@ -1085,8 +1086,6 @@ def log_graph_runnable() -> str:
            boxed_forward_device_index,
         )

-        return compiled_graph
-

 def get_input_idxs_to_check(
     inputs: Sequence[InputType],
@@ -1326,11 +1325,9 @@ compile_fx_aot(
         config_patches=config_patches,
     )

-    assert isinstance(compiled_artifacts, str) or (
-        isinstance(compiled_artifacts, list)
-        and isinstance(compiled_artifacts[0], str)
-    )
-    return compiled_artifacts
+    assert isinstance(compiled_artifacts, CompiledAOTI)
+
+    return compiled_artifacts.filename


 _graph_counter = count(0)
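At the outermost AOT entry point the wrapper is unwrapped again: `compile_fx_aot` now asserts it received a `CompiledAOTI` and returns its `filename`, so external callers keep seeing a plain path (or list of paths). A tiny usage demo against the hypothetical sketch above, not the real PyTorch call chain:

```python
# Demo with the hypothetical classes sketched earlier.
artifact: OutputCode = CompiledAOTI("model.so")  # stand-in for the AOT-mode codegen result
assert isinstance(artifact, CompiledAOTI)
path = artifact.filename  # the plain path compile_fx_aot hands back to its caller
print(path)  # -> model.so
```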
@@ -1487,7 +1484,7 @@ def get_cuda_device_context(gm: torch.fx.GraphModule) -> ContextManager[None]:
 def compile_fx(
     model_: GraphModule,
     example_inputs_: Sequence[InputType],
-    inner_compile: Callable[..., Any] = compile_fx_inner,
+    inner_compile: Callable[..., OutputCode] = compile_fx_inner,
     config_patches: Optional[Dict[str, Any]] = None,
     decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
 ) -> Union[Callable[[List[object]], Sequence[torch.Tensor]], str, List[str]]:
@@ -1631,7 +1628,7 @@ def fw_compiler_base(
         model: GraphModule,
         example_inputs: List[InputType],
         is_inference: bool,
-    ) -> CompiledFxGraph:
+    ) -> OutputCode:
         with dynamo_utils.dynamo_timed("compile_fx.<locals>.fw_compiler_base"):
             if is_inference:
                 # partition_fn won't be called
@@ -1737,7 +1734,7 @@ def partition_fn(
     @compile_time_strobelight_meta(phase_name="backward")
     def bw_compiler(
         model: GraphModule, example_inputs: List[InputType]
-    ) -> Union[CompiledFxGraph, str]:
+    ) -> OutputCode:
         from torch._dynamo.convert_frame import compile_lock

         with dynamo_utils.dynamo_timed(