Skip to content

Commit

Permalink
tests: Update fp16 test for new API
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed May 8, 2022
1 parent c275dd0 commit 828d120
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tests/accuracy/test_fp16_accuracy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) {
}
torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;

std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
auto compile_spec = torch_tensorrt::ts::CompileSpec({input_shape});
std::vector<int64_t> input_shape = {32, 3, 32, 32};
auto input = torch_tensorrt::Input(input_shape);
input.dtype = torch::kF16;
auto compile_spec = torch_tensorrt::ts::CompileSpec({input});
compile_spec.enabled_precisions.insert(torch::kF16);

auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);
Expand Down

0 comments on commit 828d120

Please sign in to comment.