Create dedicated build for training api #14136

Merged: 20 commits merged into main on Jan 11, 2023

Conversation

askhade (Contributor) commented Jan 5, 2023

Description

Enable creating a dedicated build for on-device training. With this PR we can build a lean binary for on-device training using the flag --enable_training_apis. This binary includes only the essentials such as training ops and optimizers, and NOT features like ATen fallback, strided tensors, gradient builders, etc. It also excludes deprecated components such as training::TrainingSession and OrtTrainer.

Motivation and Context

This enables our partners to create a lean binary for on-device training.
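
(For orientation, a minimal sketch of the CMake option that the --enable_training_apis build-script flag presumably maps to. The option name onnxruntime_ENABLE_TRAINING_APIS appears throughout the diffs below; the default and description text here are illustrative assumptions, not copied from the repository.)

    # Hypothetical declaration; the description string is illustrative only.
    # ON  -> lean on-device training binary (training ops, optimizers, no deprecated trainers)
    # OFF -> default; full training builds continue to use onnxruntime_ENABLE_TRAINING
    option(onnxruntime_ENABLE_TRAINING_APIS
           "Build the dedicated on-device training APIs (lean binary without ORTModule/gradient builders)"
           OFF)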

@askhade requested a review from a team as a code owner on January 5, 2023 00:57
@askhade changed the title from "[WIP] Creating dedicated build for training api" to "Create dedicated build for training api" on Jan 6, 2023
cmake/CMakeLists.txt (review thread resolved)
@@ -429,7 +429,7 @@ if(onnxruntime_ENABLE_ATEN)
   FetchContent_Populate(dlpack)
 endif()

-if(onnxruntime_ENABLE_TRAINING)
+if(onnxruntime_ENABLE_TRAINING OR (onnxruntime_ENABLE_TRAINING_APIS AND onnxruntime_BUILD_UNIT_TESTS))
Member:

so if we have ENABLE_TRAINING but not onnxruntime_BUILD_UNIT_TESTS, we still enable the following flag?

Contributor Author (askhade):

Yes. It is being used in onnxruntime_training_runner executable as well as onnxruntime_training_mnist and onnxruntime_training_gpt2... All this is deprecated code and we can simply remove onnxruntime_ENABLE_TRAINING from there once that code is removed. I will add this as a comment in the cmake.
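
(A rough sketch of what that cmake comment could look like; the wording is a suggestion, not text from this PR, and the body of the if block is elided.)

    # NOTE: onnxruntime_ENABLE_TRAINING is only still needed here for the deprecated
    # onnxruntime_training_runner / onnxruntime_training_mnist / onnxruntime_training_gpt2
    # targets; once that code is removed, the (training APIs + unit tests) condition alone
    # should be enough.
    if(onnxruntime_ENABLE_TRAINING OR (onnxruntime_ENABLE_TRAINING_APIS AND onnxruntime_BUILD_UNIT_TESTS))
      # ... (existing body unchanged)
    endif()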

@@ -72,6 +72,27 @@ if (onnxruntime_ENABLE_TRAINING_OPS AND NOT onnxruntime_ENABLE_TRAINING)
     "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.h"
   )
 endif()

+if (onnxruntime_ENABLE_TRAINING_APIS AND NOT onnxruntime_ENABLE_TRAINING)
Member:

It is a little confusing, as the definitions in the build-flag section indicate that ENABLE_TRAINING is a superset of ENABLE_TRAINING_APIS...

I guess the idea here is that the following list file is only used for the ORT training C++ API, and not for the ORT training Python API (ORTModule), right? If that is the case, could we give the Python API build and the C++ training API build explicit flags, instead of mixing them with ENABLE_TRAINING?

Contributor:

I believe onnxruntime_ENABLE_TRAINING_APIS is for on-device training, so it should not include the ORT trainer C++ code. But yes, judging by its name, onnxruntime_ENABLE_TRAINING_APIS is a subset of ENABLE_TRAINING; I recall it was called TRAINING_ON_DEVICE previously.
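
(One way to make the relationship explicit in cmake, sketching the idea from this thread; this is not the code in the PR, and the exact option/definition wiring is an assumption.)

    # Assumed relationships, based on this discussion:
    #   onnxruntime_ENABLE_TRAINING      -> full training build (Python API / ORTModule); the superset
    #   onnxruntime_ENABLE_TRAINING_APIS -> lean on-device training build (C/C++ training API)
    # Both variants share the training "core", so the ENABLE_TRAINING_CORE macro is defined for either.
    if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_APIS)
      add_compile_definitions(ENABLE_TRAINING_CORE)
    endif()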

cmake/onnxruntime_optimizer.cmake (outdated review thread, resolved)
cmake/onnxruntime_optimizer.cmake (outdated review thread, resolved)
onnxruntime/core/framework/session_state.h (outdated review thread, resolved)
onnxruntime/core/framework/session_state.h (review thread resolved)
@@ -1,7 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.

-#ifdef ENABLE_TRAINING
+#ifdef ENABLE_TRAINING_CORE
 #include <onnx/defs/attr_proto_util.h>
Member:

If the whole file is only for the training build, could we just put it under the training folder and only include it in the training build in the cmake, instead of having the ifdef in the code?

Contributor:

This was added by me; I put it here so it can easily be enabled for inferencing later.

@@ -34,13 +34,14 @@
#if defined(ENABLE_TRAINING_OPS)
#include "orttraining/core/graph/training_op_defs.h"
#endif
#ifdef ENABLE_TRAINING_CORE
#include "orttraining/core/graph/loss_function_registry.h"
Member:

Do you need these for the ORTModule build? I thought this is only useful for the training C++ API, but ENABLE_TRAINING_CORE seems to be shared by both sides.

Contributor Author (askhade) commented Jan 9, 2023:

Originally this was part of ENABLE_TRAINING, so I think they are needed (will check). BTW, we don't do any dedicated build for ORTModule; we simply do a full training build, so we will need these.

Contributor:

I believe "orttraining/core/graph/loss_function_registry.h" is not needed for training_api (on-device training)

@@ -471,7 +471,9 @@ if (onnxruntime_USE_CUDA)
   onnxruntime_add_include_to_target(onnxruntime_providers_cuda onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers)
   if (onnxruntime_ENABLE_TRAINING_OPS)
     onnxruntime_add_include_to_target(onnxruntime_providers_cuda onnxruntime_training)
-    target_link_libraries(onnxruntime_providers_cuda PRIVATE onnxruntime_training)
+    if (onnxruntime_ENABLE_TRAINING)
Contributor:

if "if (onnxruntime_ENABLE_TRAINING_OPS)" is true, then "if (onnxruntime_ENABLE_TRAINING)" is ture, right?

Contributor Author (askhade):

Not always...

  1. The on-device training build uses the enable_training_apis flag, and in that case enable_training will not be true.
  2. There is one more scenario where training ops are included in the inference build... I am not 100% sure how it is used, but the enable_training_ops macro was first added for that scenario.

cmake/onnxruntime_unittests.cmake (outdated review thread, resolved)
onnxruntime/core/framework/session_state.h (outdated review thread, resolved)
onnxruntime/core/session/environment.cc (outdated review thread, resolved)
@@ -3650,7 +3650,7 @@ TEST(AttentionTest, DISABLED_Attention_Mask1D_Fp16_B2_FusedNoPadding) {
   }
 }

-#ifndef ENABLE_TRAINING // Prepacking is enabled only on non-training builds
+#ifndef ENABLE_TRAINING_CORE // Prepacking is enabled only on non-training builds
Contributor:

I think we need to make the relationship between these training macros clear somewhere. The inferencing folks at least need to know which macro should be used to explicitly turn off a given code snippet.
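
(A sketch of the kind of cmake comment block that could document this; the hierarchy below is inferred from this conversation and should be verified against the actual build files.)

    # Training-related build flags / macros (relationships inferred from this PR discussion):
    #   onnxruntime_ENABLE_TRAINING      - full training build (ORTModule / Python training API);
    #                                      also enables the training core and training ops.
    #   onnxruntime_ENABLE_TRAINING_APIS - lean on-device training build; also enables the training core.
    #   ENABLE_TRAINING_CORE (macro)     - shared training core, defined for both builds above.
    #                                      Inference-only code that must stay out of every training
    #                                      build should be guarded with #ifndef ENABLE_TRAINING_CORE,
    #                                      as in the prepacking test change above.
    #   ENABLE_TRAINING_OPS (macro)      - training kernels / op schemas only; can also be set in an
    #                                      otherwise inference-only build, so it does not imply the others.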

@@ -471,7 +471,9 @@ if (onnxruntime_USE_CUDA)
   onnxruntime_add_include_to_target(onnxruntime_providers_cuda onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers)
   if (onnxruntime_ENABLE_TRAINING_OPS)
     onnxruntime_add_include_to_target(onnxruntime_providers_cuda onnxruntime_training)
-    target_link_libraries(onnxruntime_providers_cuda PRIVATE onnxruntime_training)
+    if (onnxruntime_ENABLE_TRAINING)
+      target_link_libraries(onnxruntime_providers_cuda PRIVATE onnxruntime_training)
Member:

Why does the CUDA EP need to be linked against the gradient builder / training agent? I thought that in a training build the only impact on the CUDA EP is to include the additional training kernels.

#ifdef ENABLE_TRAINING_CORE
// <training schemas>
// This can also be moved inside enable_training. Needs more investigation
training::GraphTransformerRegistry::GetInstance().RegisterExternalGraphTransformers();
Member:

If I remember correctly, this API was added for Apollo usage, which we don't need anymore. We can double-confirm whether we can remove it, but I believe you don't need it for on-device training.

Contributor Author (askhade):

Makes sense; I will cover this in a separate PR that I am working on right now... Will merge this PR.

@askhade merged commit d92c663 into main on Jan 11, 2023
@askhade deleted the askhade/dedicated_trainingapi_build branch on January 11, 2023 04:58