Commit 90e3a3d
Revert "[ca] trace saved variable unpacking (pytorch#147242)"
This reverts commit 68ddca9. Reverted pytorch#147242 on behalf of https://github.com/wdvr due to failing tests in the slow workflow, see below ([comment](pytorch#147242 (comment)))
1 parent 4d614ba commit 90e3a3d

10 files changed (+120 -511 lines)
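For context, the "saved variable unpacking" in the reverted PR refers to the pack/unpack saved-tensor hooks that autograd invokes around values it stashes for backward. Below is a minimal sketch of those user-facing hooks using the public torch.autograd.graph API; the hook bodies are illustrative placeholders, not code from this commit.

import torch

def pack_hook(t):
    # Runs when autograd saves `t` for backward; the return value is
    # stored in place of the tensor.
    return t.detach().clone()

def unpack_hook(packed):
    # Runs when backward needs the saved value again.
    return packed

x = torch.randn(4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = (x * x).sum()  # x is saved for backward through the hooks
y.backward()
print(x.grad)

The reverted change had taught compiled autograd to trace the unpack side of these hooks into its FX graph; the diffs below remove that tracing.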

test/inductor/test_compiled_autograd.py (+24 -285)

Large diffs are not rendered by default.

torch/_dynamo/compiled_autograd.py (+8 -47)

@@ -134,7 +134,7 @@ def __repr__(self):
 ops = OpNamespace()


-_graph_placeholders = ["inputs", "sizes", "scalars", "hooks", "packed_data"]
+_graph_placeholders = ["inputs", "sizes", "scalars", "hooks"]
 _impure_targets = OrderedSet(
     [
         call_hook,
@@ -206,13 +206,7 @@ def begin_capture(
         self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
         self.fx_tracer.tensor_attrs = {}
         self.symnode_proxy_lookup = {}
-        (
-            args_proxy,
-            self.sizes_proxy,
-            self.scalars_proxy,
-            self.hooks_proxy,
-            self.packed_data_proxy,
-        ) = (
+        args_proxy, self.sizes_proxy, self.scalars_proxy, self.hooks_proxy = (
             self.fx_tracer.create_proxy("placeholder", name, (), {})
             for name in _graph_placeholders
         )
@@ -274,12 +268,7 @@ def begin_capture(
         self.stack.enter_context(
             torch.fx.experimental.symbolic_shapes._suppress_guards(env)
         )
-        return (
-            str(CompileContext.current_compile_id()),
-            inputs,
-            sizes,
-            scalars,
-        )
+        return str(CompileContext.current_compile_id()), inputs, sizes, scalars

     def log_compile_reasons(
         self,
@@ -578,19 +567,6 @@ def proxy_call_hook(self, hook, *args, **kwargs):
             kwargs,
         )

-    def unpack_hook(self, hook_id, data_id):
-        assert self.hooks_proxy is not None
-        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
-        data = self.packed_data_proxy[data_id]  # type: ignore[index]
-        proxy = self.proxy_call_hook(
-            hook,
-            data,
-            hook_type="unpack_hook",
-        )
-        out = self.allocate_dummy()
-        self.bind_objects_to_proxies([out], [proxy])
-        return out
-
     def tensor_pre_hook(self, inputs, hook_id, i: int):
         assert self.hooks_proxy is not None
         hook = self.hooks_proxy[hook_id]  # type: ignore[index]
@@ -730,9 +706,6 @@ def is_impure(node):
         after = len(self.fx_tracer.graph.nodes)
         verbose_log.debug("DCE removed %d nodes", before - after)

-    def create_graph_module(self, id):
-        return GraphModule(self.fx_tracer.root, self.fx_tracer.graph, id)
-
     def end_capture(self, outputs):
         self.fx_tracer.create_proxy(
             "call_function",
@@ -772,7 +745,6 @@ def end_capture(self, outputs):
             ).print_readable(print_output=False),
         )
         self.rename_aot_dispatcher_nodes()
-        self.delay_unpack_hook_nodes()
         self.reorder_tensor_pre_hook_nodes()
         self.reorder_pre_hook_nodes_to_schedule_asap()
         self.reorder_accumulate_grad_nodes()
@@ -791,7 +763,9 @@ def end_capture(self, outputs):
         # should prevent these ops from going into the CA graph.
         self.dce()

-        graph = self.create_graph_module(f"CompiledAutograd{self.id}")
+        graph = GraphModule(
+            self.fx_tracer.root, self.fx_tracer.graph, f"CompiledAutograd{self.id}"
+        )
         set_locals_to_steal(graph, ["inputs"])
         lazy_graph_code = lazy_format_graph_code(
             "Compiled autograd graph",
@@ -807,15 +781,15 @@ def end_capture(self, outputs):
             payload_fn=lambda: graph.print_readable(print_output=False),
         )

-        def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks, packed_inputs):
+        def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks):
             global in_compiled_autograd_region
             try:
                 in_compiled_autograd_region = True
                 for i in runtime_inputs_to_move:
                     inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)

                 with _disable(), make_compile_context(self.id):
-                    return compiled_fn(inputs, sizes, scalars, hooks, packed_inputs)
+                    return compiled_fn(inputs, sizes, scalars, hooks)
             finally:
                 in_compiled_autograd_region = False

@@ -964,19 +938,6 @@ def reorder_accumulate_grad_nodes(self):
             if getitem_node is not None:
                 arg.append(getitem_node)

-    def delay_unpack_hook_nodes(self):
-        """
-        We can delay unpack hooks until they are needed, even later than in the eager autograd engine.
-        """
-        for node in self.fx_tracer.graph.find_nodes(
-            op="call_function", target=call_hook
-        ):
-            if node.kwargs.get("hook_type", None) != "unpack_hook":
-                continue
-
-            first_user = min(node.users)
-            first_user.prepend(node)
-
     def reorder_tensor_pre_hook_nodes(self):
         """
         Usage of AOTAutograd causes all the tensor_pre_hook nodes to get pushed
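The removed delay_unpack_hook_nodes pass moved each traced unpack-hook call down to just before its first user, so unpacking happens as late as possible. Below is a standalone sketch of that FX reordering idea on a toy graph; it is not the compiled autograd code itself, and the function f and node selection are made up for illustration.

import operator

import torch
import torch.fx as fx

def f(x):
    unpacked = x + 1  # stand-in for a value produced by an unpack hook
    other = x * 2
    return other + unpacked  # `unpacked` is first needed only here

gm = fx.symbolic_trace(f)

# Move the node that produces `unpacked` so it sits immediately before its
# first user, mirroring the "run unpack hooks as late as possible" idea.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is operator.add:
        first_user = min(node.users)  # fx.Node compares by position in the graph
        first_user.prepend(node)  # re-insert `node` right before that user
        break

gm.graph.lint()
gm.recompile()
print(gm.graph)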

torch/csrc/autograd/engine.cpp (-1)

@@ -1334,7 +1334,6 @@ auto Engine::execute(
       !AnomalyMode::is_enabled(),
       "compiled_autograd does not support AnomalyMode")
   GraphTaskGuard guard(graph_task);
-  CheckpointValidGuard cpvguard(graph_task);
   return (*compiled_autograd)(
       graph_root, *graph_task, accumulate_grad, outputs);
 }

torch/csrc/autograd/python_saved_variable_hooks.cpp (-9)

@@ -46,15 +46,6 @@ at::Tensor PySavedVariableHooks::call_unpack_hook() {
   // unpack_hook_ will be manually decrefed when the saved variable is released
 }

-std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
-PySavedVariableHooks::retrieve_unpack_hook_data() const {
-  Py_INCREF(unpack_hook_);
-  Py_INCREF(data_);
-  return std::make_pair(
-      c10::SafePyObject(unpack_hook_, getPyInterpreter()),
-      c10::SafePyObject(data_, getPyInterpreter()));
-}
-
 // NOLINTNEXTLINE(bugprone-exception-escape)
 PySavedVariableHooks::~PySavedVariableHooks() {
   // If python is already dead, leak the wrapped python objects

torch/csrc/autograd/python_saved_variable_hooks.h (-3)

@@ -1,7 +1,6 @@
 #pragma once

 #include <ATen/ATen.h>
-#include <c10/core/SafePyObject.h>
 #include <pybind11/pybind11.h>
 #include <torch/csrc/Export.h>
 #include <torch/csrc/autograd/python_variable.h>
@@ -18,8 +17,6 @@ struct PySavedVariableHooks : public SavedVariableHooks {
   void call_pack_hook(const at::Tensor& tensor) override;
   at::Tensor call_unpack_hook() override;
   ~PySavedVariableHooks() override;
-  std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
-  retrieve_unpack_hook_data() const override;

  private:
   PyObject* pack_hook_;

torch/csrc/autograd/saved_variable.cpp (+3 -9)

@@ -59,7 +59,6 @@ SavedVariable::SavedVariable(
   if (maybe_hooks && !variable.unsafeGetTensorImpl()->is_wrapped_number()) {
     save_metadata(variable);
     set_hooks_and_pack_data(std::move(maybe_hooks), variable);
-    TORCH_INTERNAL_ASSERT(!data_.defined());
     return;
   }

@@ -135,14 +134,9 @@ Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
   // We want grad_fn here to provide the most helpful debug message to the user
   // if versions don't match

-  std::shared_ptr<Node> grad_fn;
-  if (is_inplace_on_view_) {
-    grad_fn = weak_grad_fn_.lock();
-  } else if (!hooks_) {
-    grad_fn = saved_original_ ? data_.grad_fn() : nullptr;
-  } else {
-    grad_fn = grad_fn_;
-  }
+  auto grad_fn = is_inplace_on_view_ ? weak_grad_fn_.lock()
+      : !hooks_ ? saved_original_ ? data_.grad_fn() : nullptr
+                : grad_fn_;

   if (!is_leaf_ && !grad_fn) {
     // This issue was introduced when we added logic to save the original

torch/csrc/autograd/saved_variable.h (-10)

@@ -1,6 +1,5 @@
 #pragma once

-#include <c10/core/SafePyObject.h>
 #include <torch/csrc/Export.h>
 #include <torch/csrc/autograd/forward_grad.h>
 #include <torch/csrc/autograd/saved_variable_hooks.h>
@@ -54,15 +53,6 @@ class TORCH_API SavedVariable {
     return (bool)hooks_;
   }

-  // Used by compiled autograd
-  std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
-  retrieve_unpack_hook_data() const {
-    if (!hooks_) {
-      return std::nullopt;
-    }
-    return hooks_->retrieve_unpack_hook_data();
-  }
-
  private:
   // This field contains either:
   // 1. the variable to save
torch/csrc/autograd/saved_variable_hooks.h (-6)

@@ -1,19 +1,13 @@
 #pragma once

 #include <ATen/core/Tensor.h>
-#include <c10/core/SafePyObject.h>

 namespace torch::autograd {

 struct TORCH_API SavedVariableHooks {
   virtual void call_pack_hook(const at::Tensor& tensor) = 0;
   virtual at::Tensor call_unpack_hook() = 0;
   virtual ~SavedVariableHooks() = default;
-  virtual std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
-  retrieve_unpack_hook_data() const {
-    throw std::runtime_error(
-        "Compiled Autograd only supports python saved tensor hooks ");
-  }
 };

 } // namespace torch::autograd
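For completeness, compiled autograd (the machinery all of the files above feed into) is typically driven with a compiler callback roughly as below. This is a hedged sketch assuming the torch._dynamo.compiled_autograd.enable context manager available in this era of the codebase; whether saved-tensor unpack hooks can be traced inside it is exactly what the reverted PR changed.

import torch
from torch._dynamo import compiled_autograd

def compiler_fn(gm):
    # Compile the captured backward graph; "eager" keeps the sketch dependency-free.
    return torch.compile(gm, backend="eager")

model = torch.nn.Linear(4, 4)
loss = model(torch.randn(2, 4)).sum()

with compiled_autograd.enable(compiler_fn):
    loss.backward()  # the backward pass is captured as an FX graph and compiled

print(model.weight.grad.shape)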
