refactor(//tests): Adding more specific tests and restructuring module fallback tests

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
narendasan committed Aug 20, 2021
1 parent 532efed commit a5bc3b0
Showing 7 changed files with 241 additions and 146 deletions.
6 changes: 3 additions & 3 deletions tests/core/lowering/BUILD
@@ -12,8 +12,8 @@ lowering_test(
)

cc_test(
-    name = "test_module_level_fallback",
-    srcs = ["test_module_level_fallback.cpp"],
+    name = "test_module_fallback_passes",
+    srcs = ["test_module_fallback_passes.cpp"],
    deps = [
        "//tests/util",
        "//core",
@@ -63,7 +63,7 @@ test_suite(
name = "lowering_tests",
tests = [
":test_linear_to_addmm",
":test_module_level_fallback",
":test_module_fallback_passes",
":test_operator_aliasing_pass",
":test_remove_contiguous_pass",
":test_remove_detach_pass",
127 changes: 127 additions & 0 deletions tests/core/lowering/test_module_fallback_passes.cpp
@@ -0,0 +1,127 @@
#include <string>
#include <unordered_set>
#include "core/compiler.h"
#include "core/lowering/lowering.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/script.h"
#include "core/lowering/passes/passes.h"
#include "torch/csrc/jit/passes/freeze_module.h"

TEST(Lowering, NotateModuleForFallbackWorksCorrectly) {
  torch::jit::script::Module mod;
  try {
    mod = torch::jit::load("tests/modules/module_fallback_scripted.jit.pt");
  } catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    ASSERT_TRUE(false);
  }

  std::unordered_set<std::string> mods_to_mark;
  mods_to_mark.insert("ModuleFallbackSub");

  trtorch::core::lowering::passes::NotateModuleForFallback(mod, "", "forward", mods_to_mark);

  auto g = mod.get_method("forward").graph();
  auto nodes = g->block()->nodes();

  bool seen_enter = false;
  int64_t enter_count = 0;
  int64_t exit_count = 0;
  int64_t intermediate_nodes = 0;
  for (auto it = nodes.begin(); it != nodes.end(); it++) {
    auto n = *it;
    if (n->kind() == torch::jit::prim::Enter) {
      enter_count++;
      auto internal_n = *(++it);
      ASSERT_TRUE(internal_n->kind() != torch::jit::prim::Exit);
      intermediate_nodes++;
      auto end = *(++it);
      ASSERT_TRUE(end->kind() == torch::jit::prim::Exit);
      exit_count++;
      seen_enter = true;
    }
  }
  ASSERT_TRUE(seen_enter);
  ASSERT_TRUE(enter_count == 1);
  ASSERT_TRUE(intermediate_nodes == 1);
  ASSERT_TRUE(exit_count == 1);
}

TEST(Lowering, MarkNodesForFallbackWorksCorrectly) {
  torch::jit::script::Module mod;
  try {
    mod = torch::jit::load("tests/modules/module_fallback_scripted.jit.pt");
  } catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    ASSERT_TRUE(false);
  }

  std::unordered_set<std::string> mods_to_mark;
  mods_to_mark.insert("ModuleFallbackSub");

  trtorch::core::lowering::passes::NotateModuleForFallback(mod, "", "forward", mods_to_mark);
  auto mod_ = torch::jit::freeze_module(mod);
  auto g = mod_.get_method("forward").graph();
  trtorch::core::lowering::passes::MarkNodesForFallback(g, true);
  auto nodes = g->block()->nodes();

  int64_t num_marked_nodes = 0;

  for (auto n : nodes) {
    auto has_compile_attribute = n->hasAttribute(c10::Symbol::attr("to_compile"));
    if (has_compile_attribute && n->i(c10::Symbol::attr("to_compile")) == (int64_t) false) {
      num_marked_nodes++;
    }
  }

  ASSERT_TRUE(num_marked_nodes == 2);
}

TEST(Lowering, LowerAndPartitionSimpleModuleFallbackCorrectly) {
  torch::jit::script::Module mod;
  try {
    mod = torch::jit::load("tests/modules/module_fallback_scripted.jit.pt");
  } catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    ASSERT_TRUE(false);
  }

  const std::vector<std::vector<int64_t>> input_shapes = {{1, 1, 16, 16}};
  std::vector<torch::jit::IValue> jit_inputs_ivalues;
  std::vector<torch::jit::IValue> trt_inputs_ivalues;
  for (auto in_shape : input_shapes) {
    auto in = at::randint(5, in_shape, {at::kCUDA});
    jit_inputs_ivalues.push_back(in.clone());
    trt_inputs_ivalues.push_back(in.clone());
  }

  std::vector<trtorch::core::ir::Input> input_ranges{trtorch::core::ir::Input({1, 1, 16, 16})};
  trtorch::core::CompileSpec cfg(input_ranges);
  cfg.partition_info.enabled = true;
  cfg.lower_info.forced_fallback_modules.push_back("ModuleFallbackSub");

  auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
  auto trt_mod = trtorch::core::CompileGraph(mod, cfg);

  auto g = trt_mod.get_method("forward").graph();
  auto nodes = g->block()->nodes();
  std::size_t curr_node = 0;
  for (const auto n : nodes) {
    if (curr_node == 5) {
      ASSERT_TRUE(n->kind() == torch::jit::aten::conv2d);
      ASSERT_TRUE(n->i(c10::Symbol::attr("to_compile")) == (int64_t) false);
    } else if (curr_node == 6) {
      ASSERT_TRUE(n->kind() == torch::jit::aten::relu);
      ASSERT_TRUE(n->i(c10::Symbol::attr("to_compile")) == (int64_t) false);
    } else if (curr_node == 7) {
      ASSERT_TRUE(n->kind() == torch::jit::prim::GetAttr);
      ASSERT_TRUE(n->s(c10::Symbol::attr("name")).find("trt_engine") != std::string::npos);
    }
    curr_node++;
  }

  auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
}
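For orientation, here is a condensed sketch (not part of the diff) of the core-API flow that the third test above drives. It reuses only identifiers that appear in the new test file; the wrapper function name is illustrative.

// Sketch only: mirrors LowerAndPartitionSimpleModuleFallbackCorrectly above.
#include "core/compiler.h"
#include "torch/script.h"

torch::jit::script::Module compile_with_module_fallback(torch::jit::script::Module& mod) {
  // Describe the single static {1, 1, 16, 16} input the test feeds the module.
  std::vector<trtorch::core::ir::Input> input_ranges{trtorch::core::ir::Input({1, 1, 16, 16})};
  trtorch::core::CompileSpec cfg(input_ranges);

  // Enable partitioning and force the named submodule to keep running in Torch.
  cfg.partition_info.enabled = true;
  cfg.lower_info.forced_fallback_modules.push_back("ModuleFallbackSub");

  // Lower, partition, and compile: supported segments become TRT engines,
  // while the forced-fallback segment stays as Torch ops marked to_compile = false.
  return trtorch::core::CompileGraph(mod, cfg);
}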

142 changes: 0 additions & 142 deletions tests/core/lowering/test_module_level_fallback.cpp

This file was deleted.

18 changes: 18 additions & 0 deletions tests/cpp/BUILD
@@ -16,6 +16,7 @@ test_suite(
":test_modules_as_engines",
":test_multiple_registered_engines",
":test_serialization",
":test_module_fallback"
],
)

@@ -27,6 +28,7 @@ test_suite(
":test_modules_as_engines",
":test_multiple_registered_engines",
":test_serialization",
":test_module_fallback"
],
)

@@ -79,6 +81,22 @@ cc_test(
    ],
)

+cc_test(
+    name = "test_module_fallback",
+    srcs = ["test_module_fallback.cpp"],
+    data = [
+        "//tests/modules:jit_models",
+    ],
+    deps = [
+        "//cpp/api:trtorch",
+        "//tests/util",
+        "@googletest//:gtest_main",
+    ] + select({
+        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
+        "//conditions:default": ["@libtorch//:libtorch"],
+    })
+)

cc_test(
    name = "test_compiled_modules",
    srcs = ["test_compiled_modules.cpp"],
2 changes: 1 addition & 1 deletion tests/cpp/cpp_api_test.h
@@ -18,7 +18,7 @@ class CppAPITests : public testing::TestWithParam<PathAndInSize> {
      mod = torch::jit::load(path);
    } catch (const c10::Error& e) {
      std::cerr << "error loading the model\n";
-      return;
+      ASSERT_TRUE(false);
    }
    input_shapes = std::get<1>(params);
    threshold = std::get<2>(params);
(Diff for the remaining 2 changed files not shown.)
