From f4d54fe6337746fda3f9312c8360595748b3a1d0 Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 13 Sep 2023 13:55:48 -0500 Subject: [PATCH 1/2] Add fp16 flag to test runenr to check models quantized to fp16 --- tools/test_runner.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tools/test_runner.py b/tools/test_runner.py index 8fb16fc17e8..75ed126b61e 100644 --- a/tools/test_runner.py +++ b/tools/test_runner.py @@ -39,6 +39,17 @@ def parse_args(): type=str, default='gpu', help='Specify where the tests execute (ref, gpu)') + parser.add_argument('--fp16', + action='store_true', + help='Quantize to fp16') + parser.add_argument('--atol', + type=float, + default=1e-3, + help='The absolute tolerance parameter') + parser.add_argument('--rtol', + type=float, + default=1e-3, + help='The relative tolerance parameter') args = parser.parse_args() return args @@ -257,6 +268,8 @@ def main(): # read and compile model model = migraphx.parse_onnx(model_path_name, map_input_dims=param_shapes) + if args.fp16: + migraphx.quantize_fp16(model) model.compile(migraphx.get_target(target)) # get test cases @@ -279,7 +292,7 @@ def main(): output_data = run_one_case(model, input_data) # check output correctness - ret = check_correctness(gold_outputs, output_data) + ret = check_correctness(gold_outputs, output_data, atol=args.atol, rtol=args.rtol) if ret: correct_num += 1 From e24f190418e7803029e4e7506a73321d54739d24 Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 13 Sep 2023 13:56:00 -0500 Subject: [PATCH 2/2] Format --- tools/test_runner.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tools/test_runner.py b/tools/test_runner.py index 75ed126b61e..8bcd9fbc7b5 100644 --- a/tools/test_runner.py +++ b/tools/test_runner.py @@ -39,9 +39,7 @@ def parse_args(): type=str, default='gpu', help='Specify where the tests execute (ref, gpu)') - parser.add_argument('--fp16', - action='store_true', - help='Quantize to fp16') + parser.add_argument('--fp16', action='store_true', help='Quantize to fp16') parser.add_argument('--atol', type=float, default=1e-3, @@ -292,7 +290,10 @@ def main(): output_data = run_one_case(model, input_data) # check output correctness - ret = check_correctness(gold_outputs, output_data, atol=args.atol, rtol=args.rtol) + ret = check_correctness(gold_outputs, + output_data, + atol=args.atol, + rtol=args.rtol) if ret: correct_num += 1