Create dedicated build for training api #14136
Conversation
…p to true when enable_training is ON
@@ -429,7 +429,7 @@ if(onnxruntime_ENABLE_ATEN)
   FetchContent_Populate(dlpack)
 endif()

-if(onnxruntime_ENABLE_TRAINING)
+if(onnxruntime_ENABLE_TRAINING OR (onnxruntime_ENABLE_TRAINING_APIS AND onnxruntime_BUILD_UNIT_TESTS))
So if we have ENABLE_TRAINING but not onnxruntime_BUILD_UNIT_TESTS, we still enable the following flag?
Yes. It is used by the onnxruntime_training_runner executable as well as onnxruntime_training_mnist and onnxruntime_training_gpt2... All of this is deprecated code, and we can simply remove onnxruntime_ENABLE_TRAINING from there once that code is removed. I will add this as a comment in the cmake.
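For reference, a minimal sketch of what that cmake comment could look like; the condition comes from the diff above, while the comment wording and the placeholder body are mine:

```cmake
# dlpack is needed both by the deprecated full-training targets
# (onnxruntime_training_runner, onnxruntime_training_mnist, onnxruntime_training_gpt2)
# and by the training-API unit tests.
# TODO: drop the onnxruntime_ENABLE_TRAINING branch once the deprecated code is removed.
if(onnxruntime_ENABLE_TRAINING OR (onnxruntime_ENABLE_TRAINING_APIS AND onnxruntime_BUILD_UNIT_TESTS))
  # ... training-related external dependencies are declared/populated here ...
endif()
```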
cmake/onnxruntime_graph.cmake
Outdated
@@ -72,6 +72,27 @@ if (onnxruntime_ENABLE_TRAINING_OPS AND NOT onnxruntime_ENABLE_TRAINING)
     "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.h"
   )
 endif()

+if (onnxruntime_ENABLE_TRAINING_APIS AND NOT onnxruntime_ENABLE_TRAINING)
It is a little confusing, as the definitions in the build-flag section indicate that ENABLE_TRAINING is a superset of ENABLE_TRAINING_APIS...
I guess the idea here is that the following file list is only used for the ORT training C++ API, but not for the ORT training Python API (ORTModule), right? If that is the case, could we give the Python API build and the C++ training API build their own explicit flags, instead of mixing them with ENABLE_TRAINING?
I believe onnxruntime_ENABLE_TRAINING_APIS is for on-device training, so it should not pull in the old ORT trainer C++ code. But yes, by its name onnxruntime_ENABLE_TRAINING_APIS reads as a subset of ENABLE_TRAINING; I recall it was previously called TRAINING_ON_DEVICE.
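To make that concrete, here is a rough sketch of how the flags appear to stack up and one way the superset relationship could be stated explicitly; whether the build scripts already do something equivalent is an assumption on my part:

```cmake
# Rough mental model from this thread:
#   onnxruntime_ENABLE_TRAINING      - full training build (ORTModule / Python API plus the deprecated C++ trainer)
#   onnxruntime_ENABLE_TRAINING_APIS - lean on-device training build (previously called TRAINING_ON_DEVICE)
# Making the superset explicit instead of scattering "X OR Y" / "X AND NOT Y" checks:
if(onnxruntime_ENABLE_TRAINING AND NOT onnxruntime_ENABLE_TRAINING_APIS)
  message(STATUS "Full training build also enables the on-device training APIs")
  set(onnxruntime_ENABLE_TRAINING_APIS ON)
endif()
```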
@@ -1,7 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.

-#ifdef ENABLE_TRAINING
+#ifdef ENABLE_TRAINING_CORE
 #include <onnx/defs/attr_proto_util.h>
If the whole file is only for the training build, could we just put it under the training folder and include it only in the training build in the cmake, instead of having the ifdef in the code?
This was added by me; I put it here so it can easily be enabled for inferencing later.
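For reference, the alternative suggested above would look roughly like this in the cmake; the source-list variable and the file name are illustrative, not the real ones:

```cmake
# Instead of guarding the shared source file with #ifdef ENABLE_TRAINING_CORE,
# compile a training-only source file just for training builds:
if(onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_APIS)
  list(APPEND onnxruntime_graph_src                                      # hypothetical list name
    "${ORTTRAINING_SOURCE_DIR}/core/graph/attr_proto_util_training.cc")  # hypothetical file name
endif()
```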
@@ -34,13 +34,14 @@
 #if defined(ENABLE_TRAINING_OPS)
 #include "orttraining/core/graph/training_op_defs.h"
 #endif
+#ifdef ENABLE_TRAINING_CORE
 #include "orttraining/core/graph/loss_function_registry.h"
Do you need these for the ORTModule build? I thought this was only useful for the training C++ API, but ENABLE_TRAINING_CORE seems to be shared by both sides.
Originally this was part of ENABLE_TRAINING, so I think they are needed (will check). BTW, we don't do a dedicated build for ORTModule; we simply do a full training build, so we will need these.
I believe "orttraining/core/graph/loss_function_registry.h" is not needed for training_api (on-device training)
@@ -471,7 +471,9 @@ if (onnxruntime_USE_CUDA)
   onnxruntime_add_include_to_target(onnxruntime_providers_cuda onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers)
   if (onnxruntime_ENABLE_TRAINING_OPS)
     onnxruntime_add_include_to_target(onnxruntime_providers_cuda onnxruntime_training)
-    target_link_libraries(onnxruntime_providers_cuda PRIVATE onnxruntime_training)
+    if (onnxruntime_ENABLE_TRAINING)
if "if (onnxruntime_ENABLE_TRAINING_OPS)" is true, then "if (onnxruntime_ENABLE_TRAINING)" is ture, right?
Not always...
- The on-device training build uses the enable_training_apis flag, and in that case enable_training will not be true.
- There is one more scenario where training ops are included in the inference build... I am not 100% sure how it is used, but the enable_training_ops macro was first added for that scenario.
@@ -3650,7 +3650,7 @@ TEST(AttentionTest, DISABLED_Attention_Mask1D_Fp16_B2_FusedNoPadding) {
   }
 }

-#ifndef ENABLE_TRAINING // Prepacking is enabled only on non-training builds
+#ifndef ENABLE_TRAINING_CORE // Prepacking is enabled only on non-training builds
I think we need to make the relationship between these training macros clear somewhere. The inferencing folks at least need to know which macro to use to explicitly turn off a code snippet.
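For what it's worth, here is one reading of how the macros are layered after this PR; the mapping is inferred from the diffs above rather than from documentation:

```cmake
# C++ macros as they appear to be layered after this PR:
#   ENABLE_TRAINING      - only the full training build
#   ENABLE_TRAINING_CORE - full training build AND the on-device training-API build;
#                          this is the macro inference-only code should key off
#                          (e.g. the prepacking guard in the test above)
#   ENABLE_TRAINING_OPS  - any build that compiles the training kernels
# Illustrative cmake side (exact placement and calls are assumptions):
if(onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_APIS)
  add_compile_definitions(ENABLE_TRAINING_CORE)
endif()
```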
@@ -471,7 +471,9 @@ if (onnxruntime_USE_CUDA)
   onnxruntime_add_include_to_target(onnxruntime_providers_cuda onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers)
   if (onnxruntime_ENABLE_TRAINING_OPS)
     onnxruntime_add_include_to_target(onnxruntime_providers_cuda onnxruntime_training)
-    target_link_libraries(onnxruntime_providers_cuda PRIVATE onnxruntime_training)
+    if (onnxruntime_ENABLE_TRAINING)
+      target_link_libraries(onnxruntime_providers_cuda PRIVATE onnxruntime_training)
Why does the CUDA EP need to be linked against the gradient builder / training agent? I thought that in the training build the only impact on the CUDA EP is the inclusion of the additional training kernels.
#ifdef ENABLE_TRAINING_CORE
  // <training schemas>
  // This can also be moved inside enable_training. Needs more investigation
  training::GraphTransformerRegistry::GetInstance().RegisterExternalGraphTransformers();
If I remember correctly, this API was added for Apollo usage, which we don't need anymore. We can double-check whether we can remove it, but I believe you don't need it for on-device training.
Makes sense; I will cover this in a separate PR which I am working on right now... Will merge this PR.
Description
Enable creating a dedicated build for on-device training. With this PR we can build a lean binary for on-device training using the flag --enable_training_apis. This binary includes only the essentials, such as training ops and optimizers, and NOT features like ATen fallback, strided tensors, and gradient builders. It also removes deprecated components such as training::TrainingSession and OrtTrainer.
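For example, a lean on-device training build could be produced with an invocation along these lines; the --enable_training_apis flag is the one introduced here, while the build script name and the other options are just an illustrative sketch:

```
# illustrative only; exact script name and extra options may differ per platform
./build.sh --config MinSizeRel --parallel --enable_training_apis
```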
Motivation and Context
This enables our partners to create a lean binary for on-device training.