[PT FE] Support dynamic shapes torch.export #28295
Merged · 4 commits · Jan 8, 2025
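This PR extends the PyTorch frontend so that dynamic dimensions declared at torch.export time survive conversion. A minimal sketch of the intended workflow, assuming the public openvino.convert_model entry point accepts an ExportedProgram (the model and dimension names here are illustrative):

import torch
import openvino as ov

class Doubler(torch.nn.Module):
    def forward(self, x):
        return x * 2

# Declare the batch dimension as dynamic at export time.
batch = torch.export.Dim("batch")
exported = torch.export.export(
    Doubler(),
    (torch.randn(2, 3),),
    dynamic_shapes={"x": {0: batch}},
)

# With this change the dynamic dimension should be preserved instead
# of being frozen to the example input's size.
ov_model = ov.convert_model(exported)
print(ov_model.input(0).get_partial_shape())  # expected: [?,3]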
377 changes: 228 additions & 149 deletions src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py

Large diffs are not rendered by default.
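The bulk of the change lands in fx_decoder.py, whose diff is collapsed above. As a rough illustration of the decoder-side API it moves to — only the method names is_input_inlined and get_inlined_input_decoder come from the base-class stubs below; the class and attribute names are hypothetical:

import torch

class InlinedInputDecoder:
    # Hypothetical wrapper: hands a literal argument back to the
    # frontend so it can be converted through the normal decoder path.
    def __init__(self, value):
        self.value = value

class FxNodeDecoder:
    def __init__(self, fx_node):
        self.fx_node = fx_node

    def is_input_inlined(self, index):
        # An input counts as "inlined" when it is a literal embedded in
        # the FX node's args rather than an edge from another node.
        arg = self.fx_node.args[index]
        return not isinstance(arg, torch.fx.Node)

    def get_inlined_input_decoder(self, index):
        # Replaces the removed inlined_input(), which returned
        # ready-made nodes: returning a decoder instead lets the
        # frontend drive conversion itself (see NodeContext::get_input
        # further down).
        return InlinedInputDecoder(self.fx_node.args[index])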

@@ -515,12 +515,12 @@ def may_produce_alias(self, in_index: int, out_index: int) -> bool:
# Sometimes PyTorch fails with an IndexError even though these indices exist in the node
return False

def inlined_input(self, index):
return []

def is_input_inlined(self, index):
return False

def get_inlined_input_decoder(self, index):
return None

def get_attribute(self, name):
return OVAny(None)

@@ -114,14 +114,14 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
PYBIND11_OVERRIDE_PURE(bool, TorchDecoder, may_produce_alias, in_index, out_index);
}

ov::OutputVector inlined_input(size_t index) const override {
PYBIND11_OVERRIDE_PURE(ov::OutputVector, TorchDecoder, inlined_input, index);
}

bool is_input_inlined(size_t index) const override {
PYBIND11_OVERRIDE_PURE(bool, TorchDecoder, is_input_inlined, index);
}

std::shared_ptr<TorchDecoder> get_inlined_input_decoder(size_t index) const override {
PYBIND11_OVERRIDE_PURE(std::shared_ptr<TorchDecoder>, TorchDecoder, get_inlined_input_decoder, index);
}

ov::Any get_attribute(const std::string& name) const override {
PYBIND11_OVERRIDE_PURE(ov::Any, TorchDecoder, get_attribute, name);
}
@@ -113,14 +113,13 @@ class TorchDecoder : public IDecoder {
/// \brief Returns whether the output may contain an alias of the input in AliasDB
virtual bool may_produce_alias(size_t in_index, size_t out_index) const = 0;

/// Returns new nodes for inputs inlined in the op itself
// Used in Torch.FX decoder
virtual OutputVector inlined_input(size_t index) const = 0;

/// Returns whether the input is inlined
// Used in Torch.FX decoder
virtual bool is_input_inlined(size_t index) const = 0;

/// Returns the decoder for an inlined input
virtual std::shared_ptr<TorchDecoder> get_inlined_input_decoder(size_t index) const = 0;

/// Returns a named attribute as Any, e.g. the kwargs input of an FX graph
virtual ov::Any get_attribute(const std::string& name) const = 0;

@@ -51,44 +51,9 @@ class NodeContext : public frontend::NodeContext {

// Search for input in tensor map and return an output port for already converted op
// TODO: int because the base class uses it, but it should naturally be size_t for PT
Output<Node> get_input(int index) const override {
size_t index_ = static_cast<size_t>(index);
FRONT_END_GENERAL_CHECK(!m_decoder->input_is_none(index_),
"Input doesn't exist with index: ",
index,
" for operation ",
get_op_type());
auto input = m_decoder_inputs.at(index);
if (input == 0) {
// Case when input can be inlined (possible only for fx decoder)
if (m_decoder->is_input_inlined(index_)) {
auto inlined_input = m_decoder->inlined_input(index_);
FRONT_END_GENERAL_CHECK(inlined_input.size() == 1,
"Incorrect inlined input with index: ",
index,
" for operation ",
get_op_type());
return inlined_input[0];
}
}
FRONT_END_GENERAL_CHECK(m_tensor_map->count(input), "No tensor corresponding input: ", input, " exist.");
return m_tensor_map->at(input);
}
Output<Node> get_input(int index) const override;

Output<Node> get_input(const std::string& name) const override {
FRONT_END_GENERAL_CHECK(has_attribute(name), "Input with name ", name, " doesn't exist");
auto attr = get_attribute_as_any(name);
if (attr.is<Output<Node>>()) {
// Case when input is constant value
return attr.as<Output<Node>>();
} else if (attr.is<type::PyNone>()) {
// None means input is unknown type, most likely a Node
auto input = m_decoder->get_named_input(name);
FRONT_END_GENERAL_CHECK(m_tensor_map->count(input), "No tensor corresponding input: ", input, " exist.");
return m_tensor_map->at(input);
}
FRONT_END_GENERAL_CHECK(false, "Input has type which can't be converted to ov::Node.");
}
Output<Node> get_input(const std::string& name) const override;

Any get_values_from_const_input(int index) const override;

86 changes: 56 additions & 30 deletions src/frontends/pytorch/src/node_context.cpp
@@ -145,39 +145,65 @@ std::shared_ptr<ov::Model> NodeContext::convert_subgraph(size_t index) const {
return model;
}

Output<Node> NodeContext::get_input(int index) const {
size_t index_ = static_cast<size_t>(index);
auto input = m_decoder_inputs.at(index);
if (input == 0) {
// Case when input can be inlined (possible only for fx decoder)
if (m_decoder->is_input_inlined(index_)) {
if (m_decoder->input_is_none(index_)) {
// some operations like aten.index.Tensor can have None inputs
auto dummy_decoder = std::make_shared<InternalOpDecoder>("torch::None", 1);
auto fw_node = std::make_shared<PtFrameworkNode>(dummy_decoder, OutputVector{});
auto attrs = fw_node->get_attrs();
attrs["none_value"] = "";
attrs[PtFrameworkNode::failed_conversion_key] =
"None constant cannot be converted to OpenVINO opset and should be removed by consuming "
"operation.";
fw_node->set_attrs(attrs);
return fw_node->output(0);
} else {
auto inlined_decoder = m_decoder->get_inlined_input_decoder(index_);
auto inlined_ctx = NodeContext(inlined_decoder,
m_ext_tensor_map,
m_tensor_map,
m_external_parameters,
m_mutated_tensors,
m_translate_session);
auto inlined_input = m_translate_session->convert_node(inlined_ctx);
FRONT_END_GENERAL_CHECK(inlined_input.size() == 1,
"Incorrect inlined input with index: ",
index,
" for operation ",
get_op_type());
return inlined_input[0];
}
}
}
auto tensor_it = m_tensor_map->find(input);
FRONT_END_GENERAL_CHECK(tensor_it != m_tensor_map->end(), "No tensor corresponding input: ", input, " exist.");
return tensor_it->second;
}

Output<Node> NodeContext::get_input(const std::string& name) const {
FRONT_END_GENERAL_CHECK(has_attribute(name), "Input with name ", name, " doesn't exist");
auto attr = get_attribute_as_any(name);
if (attr.is<Output<Node>>()) {
// Case when input is constant value
return attr.as<Output<Node>>();
} else if (attr.is<type::PyNone>()) {
// None means input is unknown type, most likely a Node
auto input = m_decoder->get_named_input(name);
FRONT_END_GENERAL_CHECK(m_tensor_map->count(input), "No tensor corresponding input: ", input, " exist.");
return m_tensor_map->at(input);
}
FRONT_END_GENERAL_CHECK(false, "Input has type which can't be converted to ov::Node.");
}

OutputVector NodeContext::inputs() const {
OutputVector res;
for (size_t i = 0; i < m_decoder_inputs.size(); i++) {
auto input = m_decoder_inputs.at(i);
if (input == 0) {
// Case when input can be inlined (possible only for fx decoder)
if (m_decoder->is_input_inlined(i)) {
if (input_is_none(i)) {
// some operations like aten.index.Tensor can have None inputs
auto dummy_decoder = std::make_shared<InternalOpDecoder>("torch::None", 1);
auto fw_node = std::make_shared<PtFrameworkNode>(dummy_decoder, OutputVector{});
auto attrs = fw_node->get_attrs();
attrs["none_value"] = "";
attrs[PtFrameworkNode::failed_conversion_key] =
"None constant cannot be converted to OpenVINO opset and should be removed by consuming "
"operation.";
fw_node->set_attrs(attrs);
res.push_back(fw_node->output(0));
} else {
auto inlined_input = m_decoder->inlined_input(i);
FRONT_END_GENERAL_CHECK(inlined_input.size() == 1,
"Incorrect inlined input with index: ",
i,
" for operation ",
get_op_type());
res.push_back(inlined_input[0]);
}
continue;
}
}
auto tensor_it = m_tensor_map->find(input);
FRONT_END_GENERAL_CHECK(tensor_it != m_tensor_map->end(), "No tensor corresponding input: ", input, " exist.");
res.push_back(tensor_it->second);
res.push_back(get_input(static_cast<int>(i)));
}
return res;
}
77 changes: 45 additions & 32 deletions src/frontends/pytorch/src/op/as_strided.cpp
@@ -20,21 +20,20 @@ namespace pytorch {
namespace op {

using namespace ov::op;

namespace {
bool compare_strides(const std::tuple<size_t, size_t>& a, const std::tuple<size_t, size_t>& b) {
return std::get<0>(a) > std::get<0>(b);
}
OutputVector translate_as_strided(const NodeContext& context) {
// "aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)"
num_inputs_check(context, 3, 4);
auto decoder = context.get_decoder();
auto input = context.get_input(0);

OutputVector translate_as_strided_common(const NodeContext& context,
const Output<Node>& input,
const std::vector<size_t>& input_strides,
const std::deque<Output<Node>>& sizes,
const std::deque<Output<Node>>& strides) {
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
auto input_strides = decoder->get_input_strides(0);
PYTORCH_OP_CONVERSION_CHECK(input_strides.size() != 0,
"aten::as_strided: Couldn't retrieve input stride information from torchscript.");

std::vector<size_t> idxs(input_strides.size());
iota(idxs.begin(), idxs.end(), 0);
std::vector<std::tuple<size_t, size_t>> stride_idxs(idxs.size());
@@ -53,26 +52,6 @@ OutputVector translate_as_strided(const NodeContext& context) {
context.mark_node(v0::Constant::create(element::i32, Shape{transpose_idx.size()}, transpose_idx));
auto transposed_input = context.mark_node(std::make_shared<v1::Transpose>(input, transpose_idx_const));
auto flat_input = context.mark_node(std::make_shared<v1::Reshape>(transposed_input, const_neg_1, false));
std::deque<Output<Node>> sizes;
std::deque<Output<Node>> strides;
if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(1).get_node_shared_ptr())) {
auto input_vector = context.const_input<std::vector<int64_t>>(1);
std::for_each(input_vector.rbegin(), input_vector.rend(), [&](int64_t input_val) {
auto const_input = context.mark_node(v0::Constant::create(element::i32, Shape{}, {input_val}));
sizes.push_front(const_input);
});
} else {
sizes = get_list_as_outputs(context.get_input(1));
}
if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(2).get_node_shared_ptr())) {
auto input_vector = context.const_input<std::vector<int64_t>>(2);
std::for_each(input_vector.rbegin(), input_vector.rend(), [&](int64_t input_val) {
auto const_input = context.mark_node(v0::Constant::create(element::i32, Shape{}, {input_val}));
strides.push_front(const_input);
});
} else {
strides = get_list_as_outputs(context.get_input(2));
}
auto offset = const_0->output(0);
if (!context.input_is_none(3)) {
offset = get_input_as_i32(context, 3);
@@ -84,12 +63,12 @@
auto strides_length_const = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {strides.size()}));
auto ones_strides_len = context.mark_node(std::make_shared<v0::Tile>(const_1, strides_length_const));
auto indices = const_0;
std::for_each(strides.rbegin(), strides.rend(), [&](Output<Node>& stride) {
std::for_each(strides.rbegin(), strides.rend(), [&](const Output<Node>& stride) {
auto const_num_iter = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {strides_size - i}));
stride = context.mark_node(std::make_shared<v0::Convert>(stride, element::i32));
auto stride_conv = context.mark_node(std::make_shared<v0::Convert>(stride, element::i32));
auto size = sizes.at(strides_size - i);
auto range = context.mark_node(std::make_shared<v4::Range>(const_0, size, const_1, element::i32));
range = context.mark_node(std::make_shared<v1::Multiply>(range, stride));
range = context.mark_node(std::make_shared<v1::Multiply>(range, stride_conv));
auto iteration_shape = context.mark_node(
std::make_shared<v3::ScatterUpdate>(ones_strides_len, const_num_iter, const_neg_1, const_0));
range = context.mark_node(std::make_shared<v1::Reshape>(range, iteration_shape, false));
@@ -99,7 +78,41 @@
indices = context.mark_node(std::make_shared<v1::Add>(indices, offset));
auto gather = context.mark_node(std::make_shared<v8::Gather>(flat_input, indices, const_0));
return {gather};
}
} // namespace

OutputVector translate_as_strided(const NodeContext& context) {
// "aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)"
num_inputs_check(context, 3, 4);
auto decoder = context.get_decoder();
auto input = context.get_input(0);
auto input_strides = decoder->get_input_strides(0);
PYTORCH_OP_CONVERSION_CHECK(input_strides.size() != 0,
"aten::as_strided: Couldn't retrieve input stride information from torchscript.");

std::deque<Output<Node>> sizes;
std::deque<Output<Node>> strides;
if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(1).get_node_shared_ptr())) {
auto input_vector = context.const_input<std::vector<int64_t>>(1);
std::for_each(input_vector.rbegin(), input_vector.rend(), [&](int64_t input_val) {
auto const_input = context.mark_node(v0::Constant::create(element::i32, Shape{}, {input_val}));
sizes.push_front(const_input);
});
} else {
sizes = get_list_as_outputs(context.get_input(1));
}
if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(2).get_node_shared_ptr())) {
auto input_vector = context.const_input<std::vector<int64_t>>(2);
std::for_each(input_vector.rbegin(), input_vector.rend(), [&](int64_t input_val) {
auto const_input = context.mark_node(v0::Constant::create(element::i32, Shape{}, {input_val}));
strides.push_front(const_input);
});
} else {
strides = get_list_as_outputs(context.get_input(2));
}
return translate_as_strided_common(context, input, input_strides, sizes, strides);
};

} // namespace op
} // namespace pytorch
} // namespace frontend
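For intuition, the translator reproduces as_strided by building flat indices offset + sum_i(stride_i * range(size_i)) over the flattened, stride-sorted input and gathering them. The PyTorch semantics being emulated:

import torch

x = torch.arange(12.0)
# Element [i, j] of the view reads x[1 + 4*i + 1*j].
y = torch.as_strided(x, size=(3, 2), stride=(4, 1), storage_offset=1)
print(y)  # tensor([[1., 2.], [5., 6.], [9., 10.]])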
15 changes: 4 additions & 11 deletions src/frontends/pytorch/src/op/cat.cpp
@@ -104,18 +104,11 @@ OutputVector translate_cat(const NodeContext& context) {
};

OutputVector translate_cat_fx(const NodeContext& context) {
// This translator is only needed to get axis as constant from external scope
num_inputs_check(context, 1, context.get_input_size());
std::deque<Output<Node>> list_elems;
for (size_t i = 0; i < context.get_input_size() - 1; i++) {
list_elems.push_back(context.get_input(static_cast<int>(i)));
}
num_inputs_check(context, 1, 2);
const auto&& list_elems = get_list_as_outputs(context.get_input(0));
int64_t axis = 0;
if (!context.get_input_type(context.get_input_size() - 1).is<type::List>()) {
// axis can be not present and that means that last input will have List type
axis = context.const_input<int64_t>(context.get_input_size() - 1);
} else {
list_elems.push_back(context.get_input(static_cast<int>(context.get_input_size() - 1)));
if (!context.input_is_none(1)) {
axis = context.const_input<int64_t>(1);
}
return translate_cat_common(context, list_elems, axis, true);
};
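The simplification matches how aten.cat appears in exported graphs: the tensors arrive as a single list argument, with the dimension as an optional second argument — hence reading the list from input 0 and the axis from input 1. An illustrative way to see this (the model is hypothetical):

import torch

class CatModel(torch.nn.Module):
    def forward(self, x, y):
        return torch.cat([x, y], dim=1)

ep = torch.export.export(CatModel(), (torch.randn(2, 2), torch.randn(2, 2)))
# The printed graph contains a call of the form:
#   aten.cat.default([x, y], 1)
print(ep.graph)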
18 changes: 0 additions & 18 deletions src/frontends/pytorch/src/op/expand.cpp
@@ -43,24 +43,6 @@ OutputVector translate_expand_as(const NodeContext& context) {
return {context.mark_node(std::make_shared<v3::Broadcast>(x, shape, BroadcastType::BIDIRECTIONAL))};
};

OutputVector translate_expand_fx(const NodeContext& context) {
auto num_inputs = context.get_input_size();
num_inputs_check(context, 2, num_inputs);
auto x = context.get_input(0);
std::vector<int32_t> shape_vec;
if (context.get_input_type(1).is<type::List>()) {
auto concat = concat_list_from_inputs(context, 1, num_inputs);
return base_expand(context, x, concat);
} else {
auto x = context.get_input(0);
auto sizes = context.get_input(1);
// TODO: figure out what implicit means
PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(2) || context.const_input<bool>(2) == false,
"Unexpected value of implicit for expand operation");
return base_expand(context, x, sizes);
}
};

} // namespace op
} // namespace pytorch
} // namespace frontend
12 changes: 3 additions & 9 deletions src/frontends/pytorch/src/op/full.cpp
@@ -75,15 +75,9 @@ OutputVector translate_full(const NodeContext& context) {
OutputVector translate_full_fx(const NodeContext& context) {
// aten.full.default([16, 16], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'),
// pin_memory = False)
auto num_inputs = context.get_input_size();
num_inputs_check(context, 2, num_inputs);
ov::Output<ov::Node> sizes;
if (context.get_input_type(0).is<type::List>()) {
sizes = concat_list_from_inputs(context, 0, num_inputs - 1);
} else {
sizes = context.get_input(0);
}
auto value = context.get_input(static_cast<int>(num_inputs - 1));
num_inputs_check(context, 2, 2);
auto sizes = get_input_concat_if_list(context, 0);
auto value = context.get_input(1);

auto filled_tensor = base_translate_full(context, sizes, value);
if (context.has_attribute("dtype")) {
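With dynamic shapes, the size list handed to aten.full can mix constants with symbolic dimensions, which is presumably why input 0 is now funneled through get_input_concat_if_list. An illustrative export that produces such a mixed size list (model and names are hypothetical):

import torch

class FullModel(torch.nn.Module):
    def forward(self, x):
        # The first dimension follows the (dynamic) batch size of x.
        return torch.full((x.shape[0], 4), 0.5)

batch = torch.export.Dim("batch")
ep = torch.export.export(
    FullModel(),
    (torch.randn(3, 2),),
    dynamic_shapes={"x": {0: batch}},
)
# The graph contains full.default([sym_size, 4], 0.5, ...): the size
# argument is a list holding a symbolic value, not a plain constant.
print(ep.graph)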