From ec2232f92b1be795434df440c5099e58029b2eba Mon Sep 17 00:00:00 2001 From: Anurag Dixit Date: Tue, 1 Mar 2022 19:36:21 -0800 Subject: [PATCH] fix: Fixed failures for host deps sessions Signed-off-by: Anurag Dixit --- noxfile.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/noxfile.py b/noxfile.py index 5cc1b06b17..f7e16a7107 100644 --- a/noxfile.py +++ b/noxfile.py @@ -58,7 +58,8 @@ def train_model(session, use_host_env=False): session.run_always('python', 'export_ckpt.py', - 'vgg16_ckpts/ckpt_epoch25.pth') + 'vgg16_ckpts/ckpt_epoch25.pth', + env={'PYTHONPATH': PYT_PATH}) else: session.run_always('python', 'main.py', @@ -146,13 +147,27 @@ def run_accuracy_tests(session, use_host_env=False): else: session.run_always("python", test) +def copy_model(session): + model_files = [ 'trained_vgg16.jit.pt', + 'trained_vgg16_qat.jit.pt'] + + for file_name in model_files: + src_file = os.path.join(TOP_DIR, str('examples/int8/training/vgg16/') + file_name) + if os.path.exists(src_file): + session.run_always('cp', + '-rpf', + os.path.join(TOP_DIR, src_file), + os.path.join(TOP_DIR, str('tests/py/') + file_name), + external=True) + def run_int8_accuracy_tests(session, use_host_env=False): print("Running accuracy tests") + copy_model(session) session.chdir(os.path.join(TOP_DIR, 'tests/py')) tests = [ - "test_ptq_dataloader.py", + "test_ptq_dataloader_calibrator.py", "test_ptq_to_backend.py", - "test_qat_trt_accuracy", + "test_qat_trt_accuracy.py", ] for test in tests: if use_host_env: @@ -162,9 +177,10 @@ def run_int8_accuracy_tests(session, use_host_env=False): def run_trt_compatibility_tests(session, use_host_env=False): print("Running TensorRT compatibility tests") + copy_model(session) session.chdir(os.path.join(TOP_DIR, 'tests/py')) tests = [ - "test_trt_intercompatibilty.py", + "test_trt_intercompatability.py", "test_ptq_trt_calibrator.py", ] for test in tests: @@ -218,7 +234,7 @@ def run_l1_accuracy_tests(session, use_host_env=False): install_deps(session) install_torch_trt(session) download_models(session, use_host_env) - download_datasets(session, use_host_env) + download_datasets(session) train_model(session, use_host_env) run_accuracy_tests(session, use_host_env) cleanup(session) @@ -228,7 +244,7 @@ def run_l1_int8_accuracy_tests(session, use_host_env=False): install_deps(session) install_torch_trt(session) download_models(session, use_host_env) - download_datasets(session, use_host_env) + download_datasets(session) train_model(session, use_host_env) finetune_model(session, use_host_env) run_int8_accuracy_tests(session, use_host_env) @@ -239,6 +255,8 @@ def run_l2_trt_compatibility_tests(session, use_host_env=False): install_deps(session) install_torch_trt(session) download_models(session, use_host_env) + download_datasets(session) + train_model(session, use_host_env) run_trt_compatibility_tests(session, use_host_env) cleanup(session)