feat(hardtanh): Adds support for the hard tanh operator
Signed-off-by: Naren Dasan <[email protected]>
narendasan committed Mar 30, 2020
1 parent 73bfd4c commit 391af52
Showing 2 changed files with 98 additions and 7 deletions.
53 changes: 49 additions & 4 deletions core/conversion/converters/impl/activation.cpp
@@ -28,11 +28,11 @@ namespace {
auto act##_registrations TRTORCH_UNUSED = \
RegisterNodeConversionPatterns() \
.pattern({"aten::" #act "(Tensor input) -> (Tensor)", \
[](ConversionCtx *ctx, const torch::jit::Node *n, \
args &args) -> bool { return act(ctx, n, args); }}) \
[](ConversionCtx* ctx, const torch::jit::Node* n, \
args& args) -> bool { return act(ctx, n, args); }}) \
.pattern({"aten::" #act "_(Tensor(a!) self) -> (Tensor(a!))", \
[](ConversionCtx *ctx, const torch::jit::Node *n, \
args &args) -> bool { return act(ctx, n, args); }});
[](ConversionCtx* ctx, const torch::jit::Node* n, \
args& args) -> bool { return act(ctx, n, args); }});

//TODO: remove support for conversion of in-place operators and move to the functionalization pass

@@ -41,6 +41,51 @@ convert(sigmoid, kSIGMOID);
convert(tanh, kTANH);

#undef convert

auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
auto min = args[1].unwrapToDouble();
auto max = args[2].unwrapToDouble();

auto new_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kCLIP);
TRTORCH_CHECK(new_layer, "Unable to create layer for aten::hardtanh");

new_layer->setAlpha(min);
new_layer->setBeta(max);

new_layer->setName(util::node_info(n).c_str());
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));

LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
return true;
}
}).pattern({
//TODO: Remove after functionalization
"aten::hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor(a!))",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
auto min = args[1].unwrapToDouble();
auto max = args[2].unwrapToDouble();

auto new_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kCLIP);
TRTORCH_CHECK(new_layer, "Unable to create layer for aten::hardtanh");

new_layer->setAlpha(min);
new_layer->setBeta(max);

new_layer->setName(util::node_info(n).c_str());
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));

LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
return true;
}
});



} // namespace
} // namespace impl
} // namespace converters
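Note (editorial context, not part of the commit): aten::hardtanh(x, min_val, max_val) is an element-wise clamp, which is why TensorRT's kCLIP activation, which computes clip(x, alpha, beta), maps onto it directly with alpha = min_val and beta = max_val, as the converter above does. Below is a minimal standalone sketch of that equivalence; the torch::hardtanh and torch::clamp calls are assumptions about the libtorch C++ API, not code from this commit.

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::tensor({-3.0f, -0.5f, 0.5f, 3.0f});
  // hardtanh with its default range [-1, 1] ...
  auto hard = torch::hardtanh(x, /*min_val=*/-1.0, /*max_val=*/1.0);
  // ... performs the same element-wise clamp that kCLIP(alpha=-1, beta=1) does.
  auto clamped = torch::clamp(x, -1.0, 1.0);
  std::cout << hard.equal(clamped) << std::endl;  // prints 1
  return 0;
}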
52 changes: 49 additions & 3 deletions tests/core/converters/test_activation.cpp
@@ -21,7 +21,7 @@ TEST(Converters, ATenReLUConvertsCorrectly) {
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenSigmoidConvertsCorrectly) {
@@ -41,7 +41,7 @@ TEST(Converters, ATenSigmoidConvertsCorrectly) {
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenTanhConvertsCorrectly) {
@@ -61,5 +61,51 @@ TEST(Converters, ATenTanhConvertsCorrectly) {
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

//TODO: Seems like the IR parser is not handling negative numbers well, need to follow up with the PyTorch Team
// TEST(Converters, ATenHardTanhConvertsCorrectly) {
// const auto graph = R"IR(
// graph(%0 : Tensor):
// %1 : float = prim::Constant[value=-1.0]()
// %2 : float = prim::Constant[value=1.0]()
// %3 : Tensor = aten::hardtanh(%0, %1, %2)
// return (%3))IR";

// auto g = std::make_shared<torch::jit::Graph>();
// torch::jit::script::parseIR(graph, &*g);

// auto in = at::randint(-5, 5, {5}, {at::kCUDA});
// auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
// auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

// in = at::clone(in);
// params = trtorch::core::conversion::get_named_params(g->inputs(), {});
// auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

// ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
// }

TEST(Converters, ATenHardTanhCustomRangeConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%1 : float = prim::Constant[value=0.0]()
%2 : float = prim::Constant[value=6.0]()
%3 : Tensor = aten::hardtanh(%0, %1, %2)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::script::parseIR(graph, &*g);

auto in = at::randint(-5, 5, {5}, {at::kCUDA});
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
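Side observation (not part of the commit): with min_val=0 and max_val=6 the graph above performs the same clamp as ReLU6, so this test exercises a range that appears in real models while sidestepping the negative-constant IR parsing issue noted above. A small illustrative sketch of what that clamp does to values like the test's randint input follows; torch::hardtanh here is an assumption about the libtorch C++ API and the values are made up for illustration.

#include <torch/torch.h>
#include <iostream>

int main() {
  // Values in the same range as the test input, at::randint(-5, 5, {5}).
  auto in = torch::tensor({-5.0f, -2.0f, 0.0f, 3.0f, 4.0f});
  auto out = torch::hardtanh(in, /*min_val=*/0.0, /*max_val=*/6.0);
  // Negatives clamp to 0; values already inside [0, 6] pass through unchanged.
  std::cout << out << std::endl;  // 0, 0, 0, 3, 4
  return 0;
}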
