From 828d12056d31233f3ad5623b2d212f081e3f7dd5 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Sat, 7 May 2022 19:54:59 -0700 Subject: [PATCH] tests: Update fp16 test for new API Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- tests/accuracy/test_fp16_accuracy.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/accuracy/test_fp16_accuracy.cpp b/tests/accuracy/test_fp16_accuracy.cpp index dd68202312..f32f8c1df0 100644 --- a/tests/accuracy/test_fp16_accuracy.cpp +++ b/tests/accuracy/test_fp16_accuracy.cpp @@ -25,8 +25,10 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) { } torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100; - std::vector> input_shape = {{32, 3, 32, 32}}; - auto compile_spec = torch_tensorrt::ts::CompileSpec({input_shape}); + std::vector input_shape = {32, 3, 32, 32}; + auto input = torch_tensorrt::Input(input_shape); + input.dtype = torch::kF16; + auto compile_spec = torch_tensorrt::ts::CompileSpec({input}); compile_spec.enabled_precisions.insert(torch::kF16); auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);