
Commit b97a786

ezyang authored and pytorchmergebot committed
Inline compile_to_fn at its only call site (pytorch#141691)
Stacked on pytorch#141689

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#141691
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#141681, pytorch#141683, pytorch#141685, pytorch#141688, pytorch#141689
1 parent 9e4723c commit b97a786
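
The commit folds a single-use method into the one place that calls it. The sketch below is not taken from the commit; it is a minimal illustration of the inline-at-the-only-call-site pattern, using hypothetical names (Before, After, make_adder) that do not appear in the PyTorch sources.

# Hypothetical sketch of the "inline at the only call site" refactor; the
# classes and helper here are invented for illustration only.

class Before:
    def make_adder(self, n):
        # Helper with exactly one caller.
        return lambda x: x + n

    def run(self, x):
        add_one = self.make_adder(1)  # the only call site
        return add_one(x)


class After:
    def run(self, x):
        # The helper body is moved to its call site and the method is deleted,
        # mirroring how the commit moves compile_to_fn's body into
        # compile_fx.py and removes it from GraphLowering.
        add_one = lambda x: x + 1
        return add_one(x)


assert Before().run(41) == After().run(41) == 42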

File tree

test/inductor/test_codecache.py
torch/_inductor/compile_fx.py
torch/_inductor/graph.py

3 files changed: +53 -72 lines changed

test/inductor/test_codecache.py (+13 -34)

@@ -673,45 +673,24 @@ def test_inductor_counters(self):
         """
         Test that we bump the inductor counters on a cache hit.
         """
-        compile_to_fn = GraphLowering.compile_to_fn

-        counter_name = "a_test_counter"
-        counter_incr = 7
-
-        def bump_counter(self):
-            # Mock that bumps some arbitrary test counter by a set amount, then calls
-            # the original GraphLowering.compile_to_fn.
-            counters["inductor"][counter_name] += counter_incr
-            return compile_to_fn(self)
-
-        with mock.patch.object(GraphLowering, "compile_to_fn", bump_counter):
-
-            def fn(a, b):
-                return torch.mm(a, b)
+        def fn(a, b):
+            return torch.mm(a, b)

-            a = torch.rand(8, 32, device="cpu")
-            b = torch.rand(32, 8, device="cpu")
+        a = torch.rand(8, 32, device="cpu")
+        b = torch.rand(32, 8, device="cpu")

-            compiled_fn = torch.compile(fn)
+        compiled_fn = torch.compile(fn)

-            # Verify the "miss" case.
-            counter_val = 2
-            counters["inductor"][counter_name] = counter_val
-            self.assertEqual(fn(a, b), compiled_fn(a, b))
-            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-            self.assertEqual(
-                counters["inductor"][counter_name], counter_val + counter_incr
-            )
+        # Verify the "miss" case.
+        self.assertEqual(fn(a, b), compiled_fn(a, b))
+        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

-            # Verify the "hit" case.
-            self.reset()
-            counter_val = 5
-            counters["inductor"][counter_name] = counter_val
-            self.assertEqual(fn(a, b), compiled_fn(a, b))
-            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-            self.assertEqual(
-                counters["inductor"][counter_name], counter_val + counter_incr
-            )
+        # Verify the "hit" case.
+        self.reset()
+        counter_val = 5
+        self.assertEqual(fn(a, b), compiled_fn(a, b))
+        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)

     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})

torch/_inductor/compile_fx.py (+40 -2)

@@ -52,7 +52,7 @@
 )
 from torch._functorch import config as functorch_config
 from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
-from torch._inductor.codecache import code_hash, FxGraphCache
+from torch._inductor.codecache import code_hash, FxGraphCache, output_code_log
 from torch._inductor.cudagraph_utils import BoxedDeviceIndex, PlaceholderInfo
 from torch._inductor.debug import save_args_for_compile_fx_inner
 from torch._inductor.output_code import (

@@ -978,7 +978,45 @@ def log_graph_runnable() -> str:
                 output_strides.append(None)

         _check_triton_bf16_support(graph)
-        compiled_fn = graph.compile_to_fn()
+
+        compiled_fn: Any
+
+        with dynamo_timed(
+            "GraphLowering.compile_to_fn", log_pt2_compile_event=True
+        ):
+            if graph.aot_mode:
+                from .codecache import AotCodeCompiler
+
+                assert graph.cpp_wrapper, "AOT mode only supports C++ wrapper"
+                code, linemap = graph.codegen_with_cpp_wrapper()
+                output_code_log.debug("Output code: \n%s", code)
+
+                serialized_extern_kernel_nodes = None
+                if graph.extern_kernel_nodes:
+                    serialized_extern_kernel_nodes = (
+                        graph.extern_node_serializer(graph.extern_kernel_nodes)
+                    )
+                    output_code_log.debug(
+                        "Serialized Extern Kernel Nodes: \n%s",
+                        serialized_extern_kernel_nodes,
+                    )
+
+                additional_files = graph.wrapper_code.additional_files
+
+                with dynamo_timed(
+                    "AotCodeCompiler.compile", log_pt2_compile_event=True
+                ):
+                    # Directly return the file path with the compiled code
+                    compiled_fn = AotCodeCompiler.compile(
+                        graph,
+                        code,
+                        serialized_extern_kernel_nodes,
+                        device_type=graph.device_type,
+                        additional_files=additional_files,
+                    )
+            else:
+                compiled_fn = graph.compile_to_module().call
+
         num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
         metrics.num_bytes_accessed += num_bytes
         metrics.node_runtimes += node_runtimes
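
The inlined block keeps the original dynamo_timed("GraphLowering.compile_to_fn", ...) span, so compile-time profiling still attributes this phase under the same name. As a hedged sketch of that instrumentation pattern (assuming dynamo_timed is importable from torch._dynamo.utils and usable as a context manager, as the diff above does), wrapping a custom phase looks roughly like this; the phase name and function are invented for illustration.

# Minimal sketch of the timing pattern used around the inlined code; not part
# of the commit. "my_custom_phase" is a made-up phase name.
import time

from torch._dynamo.utils import dynamo_timed


def timed_phase():
    # Wrapping a region in dynamo_timed attributes its wall time to the given
    # phase in PT2 compile-time reporting, mirroring how compile_fx.py wraps
    # the inlined compile_to_fn body (and, nested inside it, AotCodeCompiler.compile).
    with dynamo_timed("my_custom_phase", log_pt2_compile_event=True):
        time.sleep(0.01)  # stand-in for real lowering/compilation work


timed_phase()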

torch/_inductor/graph.py (-36)

@@ -2059,42 +2059,6 @@ def _compile_to_module(self) -> ModuleType:
         V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
         return mod

-    def compile_to_fn(self) -> Any:
-        with dynamo_timed("GraphLowering.compile_to_fn", log_pt2_compile_event=True):
-            return self._compile_to_fn()
-
-    def _compile_to_fn(self) -> Any:
-        if self.aot_mode:
-            from .codecache import AotCodeCompiler
-
-            assert self.cpp_wrapper, "AOT mode only supports C++ wrapper"
-            code, linemap = self.codegen_with_cpp_wrapper()
-            output_code_log.debug("Output code: \n%s", code)
-
-            serialized_extern_kernel_nodes = None
-            if self.extern_kernel_nodes:
-                serialized_extern_kernel_nodes = self.extern_node_serializer(
-                    self.extern_kernel_nodes
-                )
-                output_code_log.debug(
-                    "Serialized Extern Kernel Nodes: \n%s",
-                    serialized_extern_kernel_nodes,
-                )
-
-            additional_files = self.wrapper_code.additional_files
-
-            with dynamo_timed("AotCodeCompiler.compile", log_pt2_compile_event=True):
-                # Directly return the file path with the compiled code
-                return AotCodeCompiler.compile(
-                    self,
-                    code,
-                    serialized_extern_kernel_nodes,
-                    device_type=self.device_type,
-                    additional_files=additional_files,
-                )
-        else:
-            return self.compile_to_module().call
-
     def get_output_names(self) -> List[str]:
         return [
             node.get_name()
