@@ -134,7 +134,7 @@ def __repr__(self):
 ops = OpNamespace()
 
 
-_graph_placeholders = ["inputs", "sizes", "scalars", "hooks", "packed_data"]
+_graph_placeholders = ["inputs", "sizes", "scalars", "hooks"]
 _impure_targets = OrderedSet(
     [
         call_hook,
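Note (not part of the patch): after this change the four remaining entries of `_graph_placeholders` map one-to-one onto the placeholder nodes of the captured FX graph, as the `begin_capture` hunk below shows. A minimal, self-contained sketch of that mapping using only public `torch.fx` APIs; the placeholder names come from the list above, everything else is illustrative:

```python
# Illustrative sketch only: builds a bare FX graph with the same four
# placeholder names that compiled autograd's tracer creates via create_proxy.
import torch.fx as fx

graph = fx.Graph()
for name in ["inputs", "sizes", "scalars", "hooks"]:
    graph.placeholder(name)

print([node.name for node in graph.nodes])
# expected: ['inputs', 'sizes', 'scalars', 'hooks']
```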
@@ -206,13 +206,7 @@ def begin_capture(
         self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
         self.fx_tracer.tensor_attrs = {}
         self.symnode_proxy_lookup = {}
-        (
-            args_proxy,
-            self.sizes_proxy,
-            self.scalars_proxy,
-            self.hooks_proxy,
-            self.packed_data_proxy,
-        ) = (
+        args_proxy, self.sizes_proxy, self.scalars_proxy, self.hooks_proxy = (
             self.fx_tracer.create_proxy("placeholder", name, (), {})
             for name in _graph_placeholders
         )
@@ -274,12 +268,7 @@ def begin_capture(
         self.stack.enter_context(
             torch.fx.experimental.symbolic_shapes._suppress_guards(env)
         )
-        return (
-            str(CompileContext.current_compile_id()),
-            inputs,
-            sizes,
-            scalars,
-        )
+        return str(CompileContext.current_compile_id()), inputs, sizes, scalars
 
     def log_compile_reasons(
         self,
@@ -578,19 +567,6 @@ def proxy_call_hook(self, hook, *args, **kwargs):
             kwargs,
         )
 
-    def unpack_hook(self, hook_id, data_id):
-        assert self.hooks_proxy is not None
-        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
-        data = self.packed_data_proxy[data_id]  # type: ignore[index]
-        proxy = self.proxy_call_hook(
-            hook,
-            data,
-            hook_type="unpack_hook",
-        )
-        out = self.allocate_dummy()
-        self.bind_objects_to_proxies([out], [proxy])
-        return out
-
     def tensor_pre_hook(self, inputs, hook_id, i: int):
         assert self.hooks_proxy is not None
         hook = self.hooks_proxy[hook_id]  # type: ignore[index]
@@ -730,9 +706,6 @@ def is_impure(node):
         after = len(self.fx_tracer.graph.nodes)
         verbose_log.debug("DCE removed %d nodes", before - after)
 
-    def create_graph_module(self, id):
-        return GraphModule(self.fx_tracer.root, self.fx_tracer.graph, id)
-
     def end_capture(self, outputs):
         self.fx_tracer.create_proxy(
             "call_function",
@@ -772,7 +745,6 @@ def end_capture(self, outputs):
             ).print_readable(print_output=False),
         )
         self.rename_aot_dispatcher_nodes()
-        self.delay_unpack_hook_nodes()
         self.reorder_tensor_pre_hook_nodes()
         self.reorder_pre_hook_nodes_to_schedule_asap()
         self.reorder_accumulate_grad_nodes()
@@ -791,7 +763,9 @@ def end_capture(self, outputs):
         # should prevent these ops from going into the CA graph.
         self.dce()
 
-        graph = self.create_graph_module(f"CompiledAutograd{self.id}")
+        graph = GraphModule(
+            self.fx_tracer.root, self.fx_tracer.graph, f"CompiledAutograd{self.id}"
+        )
         set_locals_to_steal(graph, ["inputs"])
         lazy_graph_code = lazy_format_graph_code(
             "Compiled autograd graph",
@@ -807,15 +781,15 @@ def end_capture(self, outputs):
             payload_fn=lambda: graph.print_readable(print_output=False),
         )
 
-        def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks, packed_inputs):
+        def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks):
             global in_compiled_autograd_region
             try:
                 in_compiled_autograd_region = True
                 for i in runtime_inputs_to_move:
                     inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)
 
                 with _disable(), make_compile_context(self.id):
-                    return compiled_fn(inputs, sizes, scalars, hooks, packed_inputs)
+                    return compiled_fn(inputs, sizes, scalars, hooks)
             finally:
                 in_compiled_autograd_region = False
 
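Note (not part of the patch): after dropping `packed_inputs`, the runtime wrapper forwards exactly four argument lists to the compiled function. A self-contained sketch of the new call shape; `example_wrapper` and the lambda standing in for the compiled function are assumptions for illustration, not the PR's code:

```python
# Illustrative only: mirrors the reduced (inputs, sizes, scalars, hooks)
# calling convention; no packed_inputs argument is threaded through anymore.
import torch

def example_wrapper(compiled_fn, inputs, sizes, scalars, hooks):
    return compiled_fn(inputs, sizes, scalars, hooks)

out = example_wrapper(
    lambda inputs, sizes, scalars, hooks: [t + 1 for t in inputs],
    [torch.zeros(3)], [], [], [],
)
print(out)  # [tensor([1., 1., 1.])]
```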
@@ -964,19 +938,6 @@ def reorder_accumulate_grad_nodes(self):
                 if getitem_node is not None:
                     arg.append(getitem_node)
 
-    def delay_unpack_hook_nodes(self):
-        """
-        We can delay unpack hooks until they are needed, even later than in the eager autograd engine.
-        """
-        for node in self.fx_tracer.graph.find_nodes(
-            op="call_function", target=call_hook
-        ):
-            if node.kwargs.get("hook_type", None) != "unpack_hook":
-                continue
-
-            first_user = min(node.users)
-            first_user.prepend(node)
-
     def reorder_tensor_pre_hook_nodes(self):
         """
         Usage of AOTAutograd causes all the tensor_pre_hook nodes to get pushed