Skip to content

Commit

Permalink
feat: support aten::div.Tensor_mode
Browse files Browse the repository at this point in the history
Signed-off-by: Ruoqian Guo <[email protected]>
  • Loading branch information
ruoqianguo committed Feb 11, 2022
1 parent c77def0 commit bb3046a
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 2 deletions.
37 changes: 37 additions & 0 deletions core/conversion/converters/impl/element_wise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,43 @@ auto element_wise_registrations TORCHTRT_UNUSED =
div->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));

LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern({"aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // Implements self / other with an optional rounding mode:
            //   "floor" -> floor division, "trunc" -> round toward zero,
            //   None / absent -> true division.
            auto self = args[0].ITensorOrFreeze(ctx);
            auto other = args[1].ITensorOrFreeze(ctx);
            // rounding_mode is `str?`; a None IValue keeps the default (true division).
            std::string rounding_mode = "default";
            if (args[2].isIValue() && args[2].IValue()->isString()) {
              rounding_mode = args[2].unwrapToString();
            }
            nvinfer1::ILayer* div = nullptr;
            if (rounding_mode == "floor") {
              div = add_elementwise(
                  ctx, nvinfer1::ElementWiseOperation::kFLOOR_DIV, self, other, util::node_info(n));
            } else if (rounding_mode == "trunc") {
              // trunc(x / y) = floor(abs(x / y)) * sign(x / y)
              // Derive intermediate layer names from the node so that multiple
              // div.Tensor_mode nodes in one graph do not collide on the fixed
              // name "tmp_div" (TensorRT requires unique layer names).
              auto tmp_div = add_elementwise(
                  ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n) + "_tmp_div");
              TORCHTRT_CHECK(tmp_div, "Unable to create div layer from node: " << *n);
              auto abs = ctx->net->addUnary(*tmp_div->getOutput(0), nvinfer1::UnaryOperation::kABS);
              TORCHTRT_CHECK(abs, "Unable to create abs layer from node: " << *n);
              auto floor = ctx->net->addUnary(*abs->getOutput(0), nvinfer1::UnaryOperation::kFLOOR);
              TORCHTRT_CHECK(floor, "Unable to create floor layer from node: " << *n);
              auto sign = ctx->net->addUnary(*tmp_div->getOutput(0), nvinfer1::UnaryOperation::kSIGN);
              TORCHTRT_CHECK(sign, "Unable to create sign layer from node: " << *n);
              div = add_elementwise(
                  ctx,
                  nvinfer1::ElementWiseOperation::kPROD,
                  floor->getOutput(0),
                  sign->getOutput(0),
                  util::node_info(n));
            } else {
              // "default" (or None): plain true division.
              div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
            }

            TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);

            div->setName(util::node_info(n).c_str());
            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));

            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }})
Expand Down
48 changes: 46 additions & 2 deletions tests/core/conversion/converters/test_element_wise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,19 @@ void pointwise_test_helper(
bool singleInput,
bool dynamicInput = false,
std::vector<int64_t> shape1 = {5},
std::vector<int64_t> shape2 = {5}) {
std::vector<int64_t> shape2 = {5},
bool negative_input = false) {
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph_ir, g.get());

// singleInput case is enabled when elementwise operation is performed
// with an input and a constant embedded in graph
std::vector<at::Tensor> torch_inputs;
torch_inputs.push_back(at::randint(1, 5, shape1, {at::kCUDA}));
if (negative_input) {
torch_inputs.push_back(at::randint(-5, 5, shape1, {at::kCUDA}));
} else {
torch_inputs.push_back(at::randint(1, 5, shape1, {at::kCUDA}));
}
if (!singleInput) {
torch_inputs.push_back(at::randint(1, 5, shape2, {at::kCUDA}));
}
Expand Down Expand Up @@ -141,6 +146,45 @@ TEST(Converters, ATenDivWithScalarConvertsCorrectly) {
pointwise_test_helper(graph, true);
}

TEST(Converters, ATenDivRoundingFloorConvertsCorrectly) {
  const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
%3 : str = prim::Constant[value="floor"]()
%2 : Tensor = aten::div(%0, %1, %3)
return (%2))IR";
  // Cover same-shape, broadcast in both directions, and dynamic-shape inputs.
  // negative_input is enabled so floor rounding differs from truncation.
  struct ShapeCase {
    bool dynamic;
    std::vector<int64_t> lhs;
    std::vector<int64_t> rhs;
  };
  const std::vector<ShapeCase> shape_cases = {
      {false, {5}, {5}},
      {false, {3, 4}, {4}},
      {false, {4}, {3, 4}},
      {true, {3, 4, 3}, {4, 3}},
      {true, {4, 3}, {3, 4, 3}}};
  for (const auto& sc : shape_cases) {
    pointwise_test_helper(graph, false, sc.dynamic, sc.lhs, sc.rhs, true);
  }
}

TEST(Converters, ATenDivRoundingTruncConvertsCorrectly) {
  const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
%3 : str = prim::Constant[value="trunc"]()
%2 : Tensor = aten::div(%0, %1, %3)
return (%2))IR";
  // Cover same-shape, broadcast in both directions, and dynamic-shape inputs.
  // negative_input is enabled so truncation toward zero differs from floor.
  struct ShapeCase {
    bool dynamic;
    std::vector<int64_t> lhs;
    std::vector<int64_t> rhs;
  };
  const std::vector<ShapeCase> shape_cases = {
      {false, {5}, {5}},
      {false, {3, 4}, {4}},
      {false, {4}, {3, 4}},
      {true, {3, 4, 3}, {4, 3}},
      {true, {4, 3}, {3, 4, 3}}};
  for (const auto& sc : shape_cases) {
    pointwise_test_helper(graph, false, sc.dynamic, sc.lhs, sc.rhs, true);
  }
}

TEST(Converters, ATenDivRoundingNoneConvertsCorrectly) {
  const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
%3 : None = prim::Constant()
%2 : Tensor = aten::div(%0, %1, %3)
return (%2))IR";
  // rounding_mode=None must behave as plain true division; exercise same-shape,
  // broadcast in both directions, and dynamic-shape inputs with negatives.
  struct ShapeCase {
    bool dynamic;
    std::vector<int64_t> lhs;
    std::vector<int64_t> rhs;
  };
  const std::vector<ShapeCase> shape_cases = {
      {false, {5}, {5}},
      {false, {3, 4}, {4}},
      {false, {4}, {3, 4}},
      {true, {3, 4, 3}, {4, 3}},
      {true, {4, 3}, {3, 4, 3}}};
  for (const auto& sc : shape_cases) {
    pointwise_test_helper(graph, false, sc.dynamic, sc.lhs, sc.rhs, true);
  }
}

TEST(Converters, ATenPowTensorConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor, %x2.1 : Tensor):
Expand Down

0 comments on commit bb3046a

Please sign in to comment.