diff --git a/example/auto_compression/nlp/run.py b/example/auto_compression/nlp/run.py index e09a92244..013b58262 100644 --- a/example/auto_compression/nlp/run.py +++ b/example/auto_compression/nlp/run.py @@ -4,6 +4,7 @@ import functools from functools import partial import numpy as np +import shutil import paddle import paddle.nn as nn from paddle.io import Dataset, BatchSampler, DataLoader @@ -305,13 +306,17 @@ def main(): if 'HyperParameterOptimization' not in all_config else eval_dataloader, eval_dataloader=eval_dataloader) - ac.compress() + if not os.path.exists(args.save_dir): + os.makedirs(args.save_dir) + for file_name in os.listdir(global_config['model_dir']): if 'json' in file_name or 'txt' in file_name: shutil.copy( os.path.join(global_config['model_dir'], file_name), args.save_dir) + ac.compress() + if __name__ == '__main__': paddle.enable_static() diff --git a/example/auto_compression/pytorch_huggingface/run.py b/example/auto_compression/pytorch_huggingface/run.py index b723be4d3..4da4e703f 100644 --- a/example/auto_compression/pytorch_huggingface/run.py +++ b/example/auto_compression/pytorch_huggingface/run.py @@ -363,7 +363,8 @@ def main(): 'HyperParameterOptimization' not in all_config else eval_dataloader, eval_dataloader=eval_dataloader) - ac.compress() + if not os.path.exists(args.save_dir): + os.makedirs(args.save_dir) for file_name in os.listdir(global_config['model_dir']): if 'json' in file_name or 'txt' in file_name: @@ -371,6 +372,8 @@ def main(): os.path.join(global_config['model_dir'], file_name), args.save_dir) + ac.compress() + if __name__ == '__main__': paddle.enable_static()