Skip to content

Commit

Permalink
C, C++, and Python API for Adapters (Multi-LoRA and others) (#911)
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani authored Oct 16, 2024
1 parent 7998f13 commit 47132b6
Show file tree
Hide file tree
Showing 23 changed files with 483 additions and 49 deletions.
1 change: 0 additions & 1 deletion cmake/options.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,3 @@ option(TEST_PHI2 "Enable tests for Phi2" OFF)

# performance
option(ENABLE_MODEL_BENCHMARK "Build model benchmark program" ON)

5 changes: 4 additions & 1 deletion cmake/ortlib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ if(ORT_HOME)
endif()
else()
# If ORT_HOME is not specified, download the onnxruntime headers and libraries from the nightly feed
set(ORT_VERSION "1.19.2")
set(ORT_VERSION "1.20.0-dev-20241007-1101-407c1ab2e2")
set(ORT_FEED_ORG_NAME "aiinfra")
set(ORT_FEED_PROJECT "2692857e-05ef-43b4-ba9c-ccf1c22c437c")
set(ORT_NIGHTLY_FEED_ID "7982ae20-ed19-4a35-a362-a96ac99897b7")

if(USE_CUDA)
set(ORT_VERSION "1.20.0-dev-20241007-0341-407c1ab2e2")
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
set(ORT_PACKAGE_NAME "Microsoft.ML.OnnxRuntime.Gpu.Linux")
elseif(WIN32)
Expand All @@ -32,8 +33,10 @@ else()
message(FATAL_ERROR "Unsupported platform for CUDA")
endif()
elseif(USE_DML)
set(ORT_VERSION "1.20.0-dev-20241007-1101-407c1ab2e2")
set(ORT_PACKAGE_NAME "Microsoft.ML.OnnxRuntime.DirectML")
elseif(USE_ROCM)
set(ORT_VERSION "1.19.2")
set(ORT_PACKAGE_NAME "Microsoft.ML.OnnxRuntime.Rocm")
else()
set(ORT_PACKAGE_NAME "Microsoft.ML.OnnxRuntime")
Expand Down
73 changes: 73 additions & 0 deletions src/models/adapters.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "../generators.h"
#include "model.h"

namespace Generators {

// Loads the LoRA adapter stored at adapter_file_path.
// NOTE(review): Adapters::LoadAdapter passes a null `allocator` for the CPU
// case, but this constructor unconditionally dereferences it to bind the
// OrtAllocator& parameter of OrtLoraAdapter::Create — dereferencing a null
// pointer is undefined behavior. Confirm whether Create should accept a
// pointer (the underlying OrtApi::CreateLoraAdapter takes a nullable pointer).
Adapter::Adapter(const char* adapter_file_path, Ort::Allocator* allocator)
    : adapter_{OrtLoraAdapter::Create(fs::path(adapter_file_path).c_str(), *allocator)} {}

// Hands out the underlying OrtLoraAdapter and records one more outstanding
// reference on it. Callers must balance each call with ReleaseRef().
const OrtLoraAdapter* Adapter::AcquireRef() {
  ++ref_count_;
  return adapter_.get();
}

// Releases one reference previously taken via AcquireRef.
// Throws std::runtime_error on an unbalanced release.
void Adapter::ReleaseRef() {
  // Validate before mutating: the original decremented first and then threw,
  // leaving ref_count_ at -1 and breaking the ref_count_ >= 0 invariant even
  // after the exception was reported.
  if (ref_count_ <= 0) {
    throw std::runtime_error("Adapter ref count went negative.");
  }
  --ref_count_;
}

// Number of references handed out by AcquireRef and not yet released.
int32_t Adapter::RefCount() const {
  return ref_count_;
}

Adapters::Adapters(const Model* model) : model_{model} {}

void Adapters::LoadAdapter(const char* adapter_file_path, const std::string& adapter_name) {
if (adapters_.find(adapter_name) != adapters_.end()) {
throw std::runtime_error("Adapter already loaded: " + std::string{adapter_name});
}

adapters_.emplace(adapter_name, std::make_unique<Adapter>(adapter_file_path,
model_->allocator_device_ == &model_->allocator_cpu_
? nullptr
: model_->allocator_device_));
}

void Adapters::UnloadAdapter(const std::string& adapter_name) {
auto adapter = adapters_.find(adapter_name);
if (adapter == adapters_.end()) {
throw std::runtime_error("Adapter not found: " + std::string{adapter_name});
}

if (adapter->second->RefCount() > 0) {
throw std::runtime_error("Adapter still in use: " + std::string{adapter_name});
}

adapters_.erase(adapter);
}

// Looks up adapter_name, takes one reference on it, and returns the raw
// OrtLoraAdapter. Throws std::runtime_error if the name is unknown.
const OrtLoraAdapter* Adapters::AcquireAdapter(const std::string& adapter_name) {
  const auto it = adapters_.find(adapter_name);
  if (it == adapters_.end()) {
    throw std::runtime_error("Adapter not found: " + std::string{adapter_name});
  }

  return it->second->AcquireRef();
}

void Adapters::ReleaseAdapter(const std::string& adapter_name) {
auto adapter = adapters_.find(adapter_name);
if (adapter == adapters_.end()) {
throw std::runtime_error("Adapter not found: " + std::string{adapter_name});
}

adapter->second->ReleaseRef();
}

} // namespace Generators
47 changes: 47 additions & 0 deletions src/models/adapters.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once

namespace Generators {

// Reference-counted wrapper around a single loaded OrtLoraAdapter.
// The Adapters container hands out raw pointers via AcquireRef and uses the
// count to refuse unloading an adapter that a generator state still uses.
struct Adapter {
  Adapter() = delete;
  Adapter(const Adapter&) = delete;
  Adapter& operator=(const Adapter&) = delete;

  // Loads the adapter from adapter_file_path. `allocator` is the device
  // allocator, or nullptr in the CPU case (see Adapters::LoadAdapter).
  Adapter(const char* adapter_file_path, Ort::Allocator* allocator);

  // Increments the reference count and returns the underlying adapter.
  const OrtLoraAdapter* AcquireRef();

  // Decrements the reference count; throws if the count goes negative.
  void ReleaseRef();

  // Number of outstanding references handed out by AcquireRef.
  int32_t RefCount() const;

 private:
  // Plain int — not synchronized; presumably callers serialize access. TODO confirm.
  int32_t ref_count_{};
  std::unique_ptr<OrtLoraAdapter> adapter_;
};

// Container that loads, tracks, and unloads named LoRA adapters for a model.
// Derives from enable_shared_from_this so generator states can keep the
// container alive while they hold active adapters (see State::SetActiveAdapter).
struct Adapters : std::enable_shared_from_this<Adapters> {
  Adapters() = delete;
  Adapters(const Adapters&) = delete;
  Adapters& operator=(const Adapters&) = delete;

  // `model` must outlive this container; it supplies the allocators used when
  // loading adapters.
  Adapters(const Model* model);

  // Loads the adapter file and registers it under adapter_name.
  // Throws if adapter_name is already loaded.
  void LoadAdapter(const char* adapter_file_path, const std::string& adapter_name);

  // Removes the named adapter. Throws if it is unknown or still referenced.
  void UnloadAdapter(const std::string& adapter_name);

  // Takes a reference on the named adapter and returns it. Throws if unknown.
  const OrtLoraAdapter* AcquireAdapter(const std::string& adapter_name);

  // Releases one reference on the named adapter. Throws if unknown.
  void ReleaseAdapter(const std::string& adapter_name);

  // Self-reference that keeps this object alive while external code owns it —
  // presumably set/cleared by the C API layer; confirm against the bindings.
  std::shared_ptr<Adapters> external_owner_;

 private:
  const Model* model_;
  std::unordered_map<std::string, std::unique_ptr<Adapter>> adapters_;
};

} // namespace Generators
2 changes: 1 addition & 1 deletion src/models/decoder_only.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ RoamingArray<float> DecoderOnly_State::Run(int current_length, RoamingArray<int3
}

int batch_size = static_cast<int>(input_ids_.GetShape()[0]);
State::Run(*model_.session_decoder_, *model_.run_options_, batch_size);
State::Run(*model_.session_decoder_, batch_size);

return logits_.Get();
}
Expand Down
2 changes: 1 addition & 1 deletion src/models/decoder_only_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ bool IntermediatePipelineState::SupportsPrimaryDevice() const {

RoamingArray<float> IntermediatePipelineState::Run(int current_length, RoamingArray<int32_t> next_tokens,
RoamingArray<int32_t> next_indices) {
State::Run(*model_.sessions_[id_], *model_.run_options_, params_->BatchBeamSize());
State::Run(*model_.sessions_[id_], params_->BatchBeamSize());

return RoamingArray<float>();
}
Expand Down
2 changes: 1 addition & 1 deletion src/models/gpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ RoamingArray<float> Gpt_State::Run(int current_length, RoamingArray<int32_t> nex
UpdateInputsOutputs(next_tokens, next_indices, current_length);
}

State::Run(*model_.session_decoder_, *model_.run_options_, batch_size);
State::Run(*model_.session_decoder_, batch_size);
return logits_.Get();
}

Expand Down
36 changes: 28 additions & 8 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,21 @@ namespace Generators {

State::State(const GeneratorParams& params, const Model& model)
: model_{model},
params_{params.shared_from_this()} {}
params_{params.shared_from_this()},
run_options_{OrtRunOptions::Create()} {}

void State::Run(OrtSession& session, OrtRunOptions& run_options, int new_batch_size) {
void State::Run(OrtSession& session, int new_batch_size) {
auto captured_graph_info = GetCapturedGraphInfo();

if (first_run_) {
if (captured_graph_info) {
model_.run_options_->AddConfigEntry("gpu_graph_id", "-1");
run_options_->AddConfigEntry("gpu_graph_id", "-1");
}
first_run_ = false;
} else if (captured_graph_info && new_batch_size != current_batch_size_) {
current_batch_size_ = new_batch_size;
auto annotation_id = std::to_string(captured_graph_info->GenerateUniqueAnnotationID(new_batch_size));
model_.run_options_->AddConfigEntry("gpu_graph_id", annotation_id.c_str());
run_options_->AddConfigEntry("gpu_graph_id", annotation_id.c_str());
}

if (g_log.enabled && g_log.model_input_values) {
Expand All @@ -67,7 +68,8 @@ void State::Run(OrtSession& session, OrtRunOptions& run_options, int new_batch_s
DumpTensors(model_, stream, outputs_.data(), output_names_.data(), output_names_.size(), false);
}

session.Run(&run_options, input_names_.data(), inputs_.data(), input_names_.size(), output_names_.data(), outputs_.data(), output_names_.size());
session.Run(run_options_.get(), input_names_.data(), inputs_.data(), input_names_.size(),
output_names_.data(), outputs_.data(), output_names_.size());

if (g_log.enabled && g_log.model_output_values) {
auto& stream = Log("model_output_values");
Expand Down Expand Up @@ -101,6 +103,27 @@ void State::ClearIO() {
outputs_.clear();
}

// Activates adapter_name from `adapters` for this state's subsequent runs.
// A state binds to exactly one Adapters container for its lifetime; passing a
// different container on a later call is an error.
void State::SetActiveAdapter(Adapters* adapters, const std::string& adapter_name) {
  if (adapters_ && adapters_.get() != adapters) {
    // A different Adapters container was registered earlier — the state cannot
    // track active adapters across two containers.
    throw std::runtime_error("Generator state can only register a single Adapters container.");
  }

  if (!adapters_) {
    adapters_ = adapters->shared_from_this();
  }

  // Take a reference so the adapter cannot be unloaded while active, then
  // record the name so the destructor can release it.
  run_options_->AddActiveLoraAdapter(*adapters_->AcquireAdapter(adapter_name));
  adapter_names_.push_back(adapter_name);
}

// Releases the references this state still holds on any active adapters so
// the owning Adapters container can later unload them.
State::~State() {
  if (!adapters_)
    return;

  for (const auto& name : adapter_names_)
    adapters_->ReleaseAdapter(name);
}

std::vector<int32_t> PadInputs(std::span<std::span<const int32_t>> sequences, int32_t pad_token_id) {
bool pad_right_{true};

Expand Down Expand Up @@ -261,9 +284,6 @@ ONNXTensorElementDataType SessionInfo::GetOutputDataType(const std::string& name
}

Model::Model(std::unique_ptr<Config> config) : config_{std::move(config)} {
// TODO: add function to create run options
run_options_ = OrtRunOptions::Create();

CreateSessionOptions();
}

Expand Down
13 changes: 10 additions & 3 deletions src/models/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "utils.h"
#include "prompt_image_processor.h"
#include "audio_processor.h"
#include "adapters.h"

#if USE_DML
#include "dml_provider_factory.h"
Expand All @@ -27,7 +28,7 @@ void CheckResult(extError_t error);

struct State {
State(const GeneratorParams& params, const Model& model_);
virtual ~State() = default;
virtual ~State();

virtual RoamingArray<float> Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices = {}) = 0;
virtual const CapturedGraphInfo* GetCapturedGraphInfo() const { return nullptr; }
Expand All @@ -39,18 +40,25 @@ struct State {

void ClearIO(); // Clear all inputs/outputs

void SetActiveAdapter(Adapters* adapters, const std::string& adapter_name);

const Model& model_;

std::shared_ptr<const GeneratorParams> params_;

std::vector<const char*> input_names_, output_names_;
std::vector<std::string> adapter_names_;
std::vector<OrtValue*> inputs_, outputs_;

protected:
void Run(OrtSession& session, OrtRunOptions& run_options, int new_batch_size); // Uses the inputs below to run
void Run(OrtSession& session, int new_batch_size); // Uses the inputs below to run
bool first_run_{true};

std::unique_ptr<OrtRunOptions> run_options_;

private:
int current_batch_size_{0};
std::shared_ptr<Adapters> adapters_;
};

struct TokenizerStream : LeakChecked<TokenizerStream> {
Expand Down Expand Up @@ -131,7 +139,6 @@ struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model> {

std::unique_ptr<Config> config_;
std::unique_ptr<OrtSessionOptions> session_options_;
std::unique_ptr<OrtRunOptions> run_options_;

cuda_stream_holder cuda_stream_;
DeviceType device_type_{DeviceType::CPU};
Expand Down
6 changes: 3 additions & 3 deletions src/models/multi_modal_vision_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ void EmbeddingState::UpdateInputsAndOutputs(RoamingArray<int32_t> next_tokens) {

RoamingArray<float> EmbeddingState::Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) {
int batch_size = static_cast<int>(input_ids_.GetShape()[0]);
State::Run(*model_.embedding_session_, *model_.run_options_, batch_size);
State::Run(*model_.embedding_session_, batch_size);

return MakeDummy();
}
Expand All @@ -110,7 +110,7 @@ VisionState::VisionState(const MultiModalVisionModel& model, const GeneratorPara

RoamingArray<float> VisionState::Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) {
const int num_images = static_cast<int>(inputs_[0]->GetTensorTypeAndShapeInfo()->GetShape()[0]);
State::Run(*model_.vision_session_, *model_.run_options_, num_images);
State::Run(*model_.vision_session_, num_images);

return MakeDummy();
}
Expand All @@ -128,7 +128,7 @@ DecoderState::DecoderState(const MultiModalVisionModel& model, RoamingArray<int3

RoamingArray<float> DecoderState::Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) {
int batch_size = static_cast<int>(inputs_embeds_.GetShape()[0]);
State::Run(*model_.decoder_session_, *model_.run_options_, batch_size);
State::Run(*model_.decoder_session_, batch_size);
return logits_.Get();
}

Expand Down
12 changes: 12 additions & 0 deletions src/models/onnxruntime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,8 @@ struct OrtRunOptions {
*/
OrtRunOptions& UnsetTerminate();

OrtRunOptions& AddActiveLoraAdapter(const OrtLoraAdapter& adapter); ///< Wraps OrtApi::RunOptionsSetActiveLoraAdapter

static void operator delete(void* p) { Ort::api->ReleaseRunOptions(reinterpret_cast<OrtRunOptions*>(p)); }
Ort::Abstract make_abstract;
};
Expand Down Expand Up @@ -1307,4 +1309,14 @@ struct OrtOp {
size_t output_count);
};

/** \brief LoraAdapter
 *
 * Owning wrapper for an ORT LoRA adapter. Create() loads the adapter from a
 * file; the custom operator delete forwards to OrtApi::ReleaseLoraAdapter.
 */
struct OrtLoraAdapter {
  static std::unique_ptr<OrtLoraAdapter> Create(const ORTCHAR_T* adapter_file_path, OrtAllocator& allocator);  ///< Wraps OrtApi::CreateLoraAdapter

  static void operator delete(void* p) { Ort::api->ReleaseLoraAdapter(reinterpret_cast<OrtLoraAdapter*>(p)); }
  Ort::Abstract make_abstract;
};

#include "onnxruntime_inline.h"
11 changes: 11 additions & 0 deletions src/models/onnxruntime_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,11 @@ inline OrtRunOptions& OrtRunOptions::UnsetTerminate() {
return *this;
}

// Marks `adapter` as active for sessions run with these options.
// Wraps OrtApi::RunOptionsAddActiveLoraAdapter; throws Ort::Exception on failure.
inline OrtRunOptions& OrtRunOptions::AddActiveLoraAdapter(const OrtLoraAdapter& adapter) {
  Ort::ThrowOnError(Ort::api->RunOptionsAddActiveLoraAdapter(this, &adapter));
  return *this;
}

inline std::unique_ptr<OrtCUDAProviderOptionsV2> OrtCUDAProviderOptionsV2::Create() {
OrtCUDAProviderOptionsV2* p;
Ort::ThrowOnError(Ort::api->CreateCUDAProviderOptions(&p));
Expand Down Expand Up @@ -1336,3 +1341,9 @@ inline void OrtOp::Invoke(const OrtKernelContext* context,
Ort::ThrowOnError(Ort::api->InvokeOp(context, this, input_values, static_cast<int>(input_count),
output_values, static_cast<int>(output_count)));
}

inline std::unique_ptr<OrtLoraAdapter> OrtLoraAdapter::Create(const ORTCHAR_T* adapter_file_path, OrtAllocator& allocator) {
OrtLoraAdapter* p;
Ort::ThrowOnError(Ort::api->CreateLoraAdapter(adapter_file_path, &allocator, &p));
return std::unique_ptr<OrtLoraAdapter>{p};
}
4 changes: 2 additions & 2 deletions src/models/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ RoamingArray<float> Whisper_State::Run(int current_length, RoamingArray<int32_t>

switch (run_state_) {
case RunState::Encoder_Decoder_Init:
State::Run(*model_.session_encoder_, *model_.run_options_, batch_size);
State::Run(*model_.session_encoder_, batch_size);

run_state_ = RunState::Decoder_First;
return logits_.Get();
Expand Down Expand Up @@ -308,7 +308,7 @@ RoamingArray<float> Whisper_State::Run(int current_length, RoamingArray<int32_t>
}
}

State::Run(*model_.session_decoder_, *model_.run_options_, batch_size);
State::Run(*model_.session_decoder_, batch_size);
return logits_.Get();
}

Expand Down
Loading

0 comments on commit 47132b6

Please sign in to comment.