From d06ad9462b680163e4d83f8a294979ba23114bd6 Mon Sep 17 00:00:00 2001
From: Baiju Meswani
Date: Tue, 31 Jan 2023 17:17:26 -0800
Subject: [PATCH] [Bug Fix] Include python training apis when enable_training is enabled (#14485)

---
 setup.py                | 69 +++++++++++++++++++++++------------------
 tools/ci_build/build.py |  5 +++
 2 files changed, 43 insertions(+), 31 deletions(-)

diff --git a/setup.py b/setup.py
index b7cd7c753a2e9..0c10195dc3b62 100644
--- a/setup.py
+++ b/setup.py
@@ -520,41 +520,48 @@ def finalize_options(self):
 if not enable_training:
     classifiers.extend(["Operating System :: Microsoft :: Windows", "Operating System :: MacOS"])
 
-if enable_training:
+if enable_training or enable_training_apis:
+    packages.append("onnxruntime.training")
+    if enable_training:
+        packages.extend(
+            [
+                "onnxruntime.training.amp",
+                "onnxruntime.training.experimental",
+                "onnxruntime.training.experimental.gradient_graph",
+                "onnxruntime.training.optim",
+                "onnxruntime.training.torchdynamo",
+                "onnxruntime.training.ortmodule",
+                "onnxruntime.training.ortmodule.experimental",
+                "onnxruntime.training.ortmodule.experimental.json_config",
+                "onnxruntime.training.ortmodule.experimental.hierarchical_ortmodule",
+                "onnxruntime.training.ortmodule.torch_cpp_extensions",
+                "onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor",
+                "onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils",
+                "onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator",
+                "onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops",
+                "onnxruntime.training.utils.data",
+            ]
+        )
+
+        package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor"] = ["*.cc"]
+        package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils"] = ["*.cc"]
+        package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator"] = ["*.cc"]
+        package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops"] = [
+            "*.cpp",
+            "*.cu",
+            "*.cuh",
+            "*.h",
+        ]
+
     packages.extend(
         [
-            "onnxruntime.training",
-            "onnxruntime.training.amp",
-            "onnxruntime.training.experimental",
-            "onnxruntime.training.experimental.gradient_graph",
-            "onnxruntime.training.optim",
-            "onnxruntime.training.torchdynamo",
-            "onnxruntime.training.ortmodule",
-            "onnxruntime.training.ortmodule.experimental",
-            "onnxruntime.training.ortmodule.experimental.json_config",
-            "onnxruntime.training.ortmodule.experimental.hierarchical_ortmodule",
-            "onnxruntime.training.ortmodule.torch_cpp_extensions",
-            "onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor",
-            "onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils",
-            "onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator",
-            "onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops",
-            "onnxruntime.training.utils.data",
+            "onnxruntime.training.api",
+            "onnxruntime.training.onnxblock",
+            "onnxruntime.training.onnxblock.loss",
+            "onnxruntime.training.onnxblock.optim",
         ]
     )
-    if enable_training_apis:
-        packages.append("onnxruntime.training.api")
-        packages.append("onnxruntime.training.onnxblock")
-        packages.append("onnxruntime.training.onnxblock.loss")
-        packages.append("onnxruntime.training.onnxblock.optim")
-    package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor"] = ["*.cc"]
-    package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils"] = ["*.cc"]
-    package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator"] = ["*.cc"]
-    package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops"] = [
-        "*.cpp",
-        "*.cu",
-        "*.cuh",
-        "*.h",
-    ]
+
     requirements_file = "requirements-training.txt"
     # with training, we want to follow this naming convention:
     # stable:
diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py
index 432c93599668c..d552fb71b6547 100644
--- a/tools/ci_build/build.py
+++ b/tools/ci_build/build.py
@@ -2375,6 +2375,11 @@ def main():
     if args.use_gdk:
         args.test = False
 
+    # enable_training is a higher level flag that enables all training functionality.
+    if args.enable_training:
+        args.enable_training_apis = True
+        args.enable_training_ops = True
+
     configs = set(args.config)
 
     # setup paths and directories
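
Note on the packaging logic this patch introduces: `onnxruntime.training` plus the python training API packages (`onnxruntime.training.api` and the `onnxblock` subpackages) are now included whenever either `enable_training` or `enable_training_apis` is set, while the ORTModule and torch C++ extension packages remain gated on a full training build. The sketch below mirrors that selection logic for illustration only; the flag names and package names come from the diff, but the standalone `select_training_packages` helper is hypothetical and not part of setup.py, and the full-training package list is abbreviated.

```python
# Sketch only: mirrors the conditional packaging introduced by this patch.
# `select_training_packages` is a hypothetical helper, not part of setup.py.


def select_training_packages(enable_training, enable_training_apis):
    packages = []
    if enable_training or enable_training_apis:
        # The base training package ships whenever either flag is set.
        packages.append("onnxruntime.training")
        if enable_training:
            # Full training builds additionally ship ORTModule and related
            # packages (abbreviated here; see the diff for the complete list).
            packages.extend(
                [
                    "onnxruntime.training.amp",
                    "onnxruntime.training.ortmodule",
                    "onnxruntime.training.utils.data",
                ]
            )
        # The python training API packages are included in both cases;
        # previously they required enable_training_apis even when
        # enable_training was set, which is the bug this patch fixes.
        packages.extend(
            [
                "onnxruntime.training.api",
                "onnxruntime.training.onnxblock",
                "onnxruntime.training.onnxblock.loss",
                "onnxruntime.training.onnxblock.optim",
            ]
        )
    return packages


if __name__ == "__main__":
    # With the build.py change, --enable_training now implies
    # enable_training_apis, so both combinations below produce a wheel
    # that contains the training API packages.
    print(select_training_packages(enable_training=True, enable_training_apis=False))
    print(select_training_packages(enable_training=False, enable_training_apis=True))
```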