Support constrained decoding #1038

Open

wants to merge 86 commits into base: main

Changes from 1 commit
bf38535
add llguidance based logits processor
Taka152 Oct 31, 2024
c151d52
add unit test
Taka152 Oct 31, 2024
9d5a8a0
constrained decoding fixes (#1023)
mmoskal Nov 1, 2024
48c3e96
add test grammars
Taka152 Nov 1, 2024
d70b849
support cuda
Taka152 Nov 1, 2024
6b90c1c
use tokenize.json to generate token_bytes
Taka152 Nov 4, 2024
bdb9ca4
fix win build
Taka152 Nov 5, 2024
a25de8e
async compute mask
Taka152 Nov 6, 2024
edc0bae
add llguidance build in cmake
Taka152 Nov 6, 2024
09861d7
update windows build
Taka152 Nov 6, 2024
ee94df8
clean cmake
Taka152 Nov 6, 2024
4d077cf
add install rust to GHA
Taka152 Nov 6, 2024
15b20b8
test action
Taka152 Nov 6, 2024
6029510
test win cpu build action
Taka152 Nov 6, 2024
4d8d8a6
update win build action
Taka152 Nov 6, 2024
346f88c
update win build action
Taka152 Nov 6, 2024
c00d8fa
update win build action
Taka152 Nov 6, 2024
39fb7ed
update win build action
Taka152 Nov 6, 2024
8038723
update win build action
Taka152 Nov 6, 2024
324f550
update win build action
Taka152 Nov 6, 2024
8722727
update win build action
Taka152 Nov 6, 2024
d620422
add rust install to workflows
Taka152 Nov 7, 2024
c1ede01
support batch infer
Taka152 Nov 7, 2024
d2f47e2
add corrosion to deps.txt
Taka152 Nov 8, 2024
8deba60
Merge branch 'main' into yingxiong/constrained_decoding
Taka152 Nov 8, 2024
b256e6d
fix merge
Taka152 Nov 8, 2024
e5d6dad
fix bugs
Taka152 Nov 8, 2024
2fd52d2
update linux gpu workflow
Taka152 Nov 8, 2024
a11684b
update linux gpu workflow
Taka152 Nov 8, 2024
56663c0
update linux gpu workflow
Taka152 Nov 8, 2024
8997064
update workflow
Taka152 Nov 8, 2024
ddda727
update workflow
Taka152 Nov 8, 2024
cb55778
update workflows
Taka152 Nov 8, 2024
4cf5b5f
add shared lib of llguidance
Taka152 Nov 15, 2024
18c2f6c
add disable_guidance option
Taka152 Nov 15, 2024
65340f2
fix format
Taka152 Nov 15, 2024
eca06f5
fix win error
Taka152 Nov 15, 2024
b306428
fix segfault
Taka152 Nov 18, 2024
df34b1e
fix segfault and move test
Taka152 Nov 18, 2024
2a5efe1
Merge remote-tracking branch 'origin/main' into yingxiong/constrained…
Taka152 Nov 18, 2024
92251d4
minor fixes
Taka152 Nov 18, 2024
56ab9ee
fix bug when is_stop
Taka152 Nov 20, 2024
03a6bb7
fixes for reviews
Taka152 Nov 20, 2024
29fc868
fix
Taka152 Nov 20, 2024
3b046c3
fix win error
Taka152 Nov 20, 2024
e9c818e
add rust env to dockerfile
Taka152 Nov 20, 2024
e06fb0a
fix dockerfile env
Taka152 Nov 20, 2024
ef141f2
update workflows
Taka152 Nov 20, 2024
13056b6
Update Rust environment in Dockerfiles
Taka152 Nov 20, 2024
22c7c37
Update Rust environment permissions in Dockerfiles
Taka152 Nov 20, 2024
a1186a5
Update Rust installation in Dockerfiles
Taka152 Nov 20, 2024
2c9b02c
revert linux arm workflow
Taka152 Nov 20, 2024
899edf9
Update Rust installation with specific version
Taka152 Nov 22, 2024
2d47c20
fix android error
Taka152 Nov 22, 2024
9a15385
fix for review
Taka152 Nov 25, 2024
ec09868
Merge remote-tracking branch 'origin/main' into yingxiong/constrained…
Taka152 Nov 27, 2024
ff94fe1
fix SetGuidance unit test
Taka152 Dec 6, 2024
88f8ef4
Merge remote-tracking branch 'origin/main' into yingxiong/constrained…
Taka152 Dec 6, 2024
4ca4075
fix format
Taka152 Dec 6, 2024
9849f65
fix to new continuous decoding api
Taka152 Dec 11, 2024
13e5100
remove comments
Taka152 Dec 12, 2024
aea5323
Merge remote-tracking branch 'origin/main' into yingxiong/constrained…
Taka152 Dec 12, 2024
9b5a6ce
fix
Taka152 Dec 12, 2024
a9390e3
fix segfault
Taka152 Dec 12, 2024
fc4b7e9
fix win build
Taka152 Dec 13, 2024
a0710d5
fix win error
Taka152 Dec 13, 2024
7d4d6bb
fix win error
Taka152 Dec 16, 2024
5d175a7
add comments
Taka152 Dec 16, 2024
52ccc8b
fix format
Taka152 Dec 16, 2024
161fcfc
fix bug
Taka152 Dec 17, 2024
ca6d86c
Merge branch 'main' into yingxiong/constrained_decoding
Taka152 Dec 17, 2024
8e03735
support build in iOS GHA
Taka152 Dec 17, 2024
db3062b
update win azure ci
Taka152 Dec 17, 2024
0cdb4ac
update linux ci
Taka152 Dec 17, 2024
991a012
fix win ci
Taka152 Dec 17, 2024
4fdbea6
fix win ci
Taka152 Dec 17, 2024
e99ae65
fix macos arm
Taka152 Dec 17, 2024
ed96504
fix macos azure ci
Taka152 Dec 17, 2024
4cb9e55
fix for review
Taka152 Dec 18, 2024
e75166a
fix
Taka152 Dec 18, 2024
82d33e5
fix ios ci
Taka152 Dec 18, 2024
3498e0e
disable on ios ci
Taka152 Dec 18, 2024
2dbc6ee
Merge remote-tracking branch 'origin/main' into yingxiong/constrained…
Taka152 Dec 19, 2024
ce846c9
disable by default
Taka152 Dec 19, 2024
c644b26
remove azure ci code
Taka152 Dec 19, 2024
422af28
build and test with use_guidance
Taka152 Dec 19, 2024
fix to new continuous decoding api
Taka152 committed Dec 11, 2024
commit 9849f65970af32a6b247b4a2d9afa40756ba05d7
26 changes: 15 additions & 11 deletions CMakeLists.txt
Expand Up @@ -111,17 +111,6 @@ target_include_directories(onnxruntime-genai-static PUBLIC ${onnxruntime_extensi
target_link_libraries(onnxruntime-genai PRIVATE onnxruntime_extensions)
target_link_libraries(onnxruntime-genai-static PUBLIC onnxruntime_extensions)

if(USE_GUIDANCE)
target_include_directories(onnxruntime-genai PUBLIC ${llguidance_SOURCE_DIR}/parser/)
target_include_directories(onnxruntime-genai-static PUBLIC ${llguidance_SOURCE_DIR}/parser/)
target_link_libraries(onnxruntime-genai PRIVATE llguidance_parser)
target_link_libraries(onnxruntime-genai-static PUBLIC llguidance_parser)
if (WIN32)
# bcrypt is needed for the rust std lib
target_link_libraries(onnxruntime-genai PRIVATE bcrypt)
target_link_libraries(onnxruntime-genai-static PRIVATE bcrypt)
endif()
endif()
target_link_directories(onnxruntime-genai PRIVATE ${ORT_LIB_DIR})

# we keep the shared libraries disconnected on Android as they will come from separate AARs and we don't want to force
@@ -142,6 +131,8 @@ if(USE_CUDA AND CMAKE_CUDA_COMPILER)
add_library(onnxruntime-genai-cuda SHARED ${generator_cudalib_srcs})
target_include_directories(onnxruntime-genai-cuda PRIVATE ${ORT_HEADER_DIR})
target_include_directories(onnxruntime-genai-cuda PRIVATE ${GENERATORS_ROOT})
# target_include_directories(onnxruntime-genai-cuda PRIVATE ${onnxruntime_extensions_SOURCE_DIR}/include)
# target_include_directories(onnxruntime-genai-cuda PRIVATE ${onnxruntime_extensions_SOURCE_DIR}/shared/api/)
target_link_libraries(onnxruntime-genai-cuda PRIVATE cublasLt cublas curand cufft cudart)
set_target_properties(onnxruntime-genai-cuda PROPERTIES LINKER_LANGUAGE CUDA)
add_dependencies(onnxruntime-genai onnxruntime-genai-cuda)
@@ -158,6 +149,19 @@ if(USE_CUDA AND CMAKE_CUDA_COMPILER)
endif()
endif()


if(USE_GUIDANCE)
target_include_directories(onnxruntime-genai PUBLIC ${llguidance_SOURCE_DIR}/parser/)
target_include_directories(onnxruntime-genai-static PUBLIC ${llguidance_SOURCE_DIR}/parser/)
target_link_libraries(onnxruntime-genai PRIVATE llguidance_parser)
target_link_libraries(onnxruntime-genai-static PUBLIC llguidance_parser)
if (WIN32)
# bcrypt is needed for the rust std lib
target_link_libraries(onnxruntime-genai PRIVATE bcrypt)
target_link_libraries(onnxruntime-genai-static PRIVATE bcrypt)
endif()
endif()

if(CMAKE_GENERATOR_TOOLSET MATCHES "Visual Studio")
target_link_options(onnxruntime-genai PRIVATE "/CETCOMPAT")
target_compile_options(onnxruntime-genai PRIVATE "/sdl")
12 changes: 11 additions & 1 deletion src/generators.cpp
@@ -5,6 +5,7 @@
#include "sequences.h"
#include "models/model.h"
#include "models/decoder_only.h"
#include "logits_processor.h"
#include "search.h"
#include "cpu/interface.h"
#include "cuda/interface.h"
@@ -266,6 +267,7 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_
search_ = CreateSearch(params);
state_ = model.CreateState(search_->GetSequenceLengths(), params); // Search sequence lengths set when creating state

logits_processor_ = CreateLogitsProcessor(*state_);
// Temporary solution for multimodal and whisper models
if (!params.aux_input_ids.empty() && params.aux_input_ids.data() != nullptr) {
AppendTokens(params.aux_input_ids);
@@ -302,7 +304,10 @@ void Generator::AppendTokens(const cpu_span<int32_t> input_ids) {
void Generator::ComputeLogits(DeviceSpan<int32_t> next_tokens) {
if (computed_logits_)
throw std::runtime_error("ComputeLogits called again without calling AppendTokens or GenerateNextToken first");

if (last_action_ == Action::generated && logits_processor_) {
auto next_tokens_span = next_tokens.CopyDeviceToCpu();
logits_processor_->CommitTokens(next_tokens_span);
}
auto logits = state_->Run(search_->GetSequenceLength(), next_tokens, search_->GetNextIndices());
if (g_log.enabled && g_log.model_logits) {
auto& stream = Log("model_logits");
@@ -364,6 +369,10 @@ void Generator::GenerateNextToken() {
search_->AppendTokens(next_tokens);
ComputeLogits(next_tokens);
}
if (logits_processor_) {
auto logits = GetLogits();
logits_processor_->ProcessLogits(logits);
}
computed_logits_ = false;
auto& search = search_->params_->search;
search_->ApplyMinLength(search.min_length);
@@ -417,6 +426,7 @@ void Generator::RewindToLength(size_t new_length) {
throw std::runtime_error("RewindToLength must be called with new_length=0 when batch_size > 1");
search_->RewindTo(new_length);
state_->RewindTo(new_length);
logits_processor_->Reset();
computed_logits_ = false;
last_action_ = Action::rewound;
}
2 changes: 2 additions & 0 deletions src/generators.h
@@ -46,6 +46,7 @@ struct Model;
struct State;
struct Search;
struct Tokenizer;
struct LogitsProcessor;

template <typename T>
DeviceSpan<T> WrapTensor(DeviceInterface& device, OrtValue& value) {
Expand Down Expand Up @@ -128,6 +129,7 @@ struct Generator : LeakChecked<Generator> {
std::shared_ptr<const Model> model_;
std::unique_ptr<State> state_;
std::unique_ptr<Search> search_;
std::unique_ptr<LogitsProcessor> logits_processor_;
bool computed_logits_{}; // Set to true in ComputeLogits() and false after appending a token to ensure a 1 to 1 call ratio

private:
240 changes: 240 additions & 0 deletions src/logits_processor.cpp
@@ -0,0 +1,240 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <memory>
#include <sstream>
#include <string>
#include <sys/types.h>

#include "generators.h"
#if USE_GUIDANCE
#include "llguidance.h"
#endif

#if USE_CUDA
#include "cuda/cuda_common.h"
#include "models/kernels.h"
#endif

#include "logits_processor.h"

namespace Generators {

#if USE_GUIDANCE
GuidanceLogitsProcessor::GuidanceLogitsProcessor(const State& state)
: vocab_size_(state.params_->config.model.vocab_size), eos_token_(state.params_->config.model.eos_token_id), device_type_(state.params_->device_type), batch_size_(state.params_->search.batch_size) {
guidance_type_ = state.params_->guidance_type;
guidance_data_ = state.params_->guidance_data;
if (guidance_type_.empty() || guidance_data_.empty()) {
throw std::runtime_error("Guidance type and data must be provided");
}

if (guidance_type_ != "json_schema" && guidance_type_ != "regex" && guidance_type_ != "grammar") {
throw std::runtime_error("Unsupported guidance type: " + std::string(guidance_type_));
}

auto tokenize_fn = (LlgTokenizeFn) + [](const void* user_data, const uint8_t* bytes,
size_t bytes_len, uint32_t* output_tokens, size_t output_tokens_len) -> unsigned long {
const TokenizeData* tokenize_data = reinterpret_cast<const TokenizeData*>(user_data);
auto output_ids = tokenize_partial(reinterpret_cast<const Tokenizer*>(tokenize_data->tokenizer), tokenize_data->prefix_len, bytes, bytes_len);
size_t output_size = std::min(output_tokens_len, output_ids.size());
for (size_t i = 0; i < output_size; i++) {
output_tokens[i] = output_ids[i];
}
return static_cast<unsigned long>(output_ids.size());
};

auto tokenizer_path = state.params_->config.config_path.string();
fs::path tokenizer_path_fs(tokenizer_path);
fs::path json_path(tokenizer_path_fs / kDefaultVocabFile);
std::ifstream json_file(json_path.string());
std::stringstream json_buffer;
json_buffer << json_file.rdbuf();
std::string json_data = json_buffer.str();
tokenizer_ = state.model_.CreateTokenizer();
auto prefix_len = tokenizer_->Encode(kTokenizePrefixStr).size();
tokenize_data_ = {tokenizer_.get(), prefix_len};
LlgTokenizerInit tokenizer_init = {
static_cast<uint32_t>(vocab_size_), // vocab_size
eos_token_, // eos_token
nullptr, // token_lens
nullptr, // token_bytes
json_data.c_str(), // tokenizer_json config data
false, // tokenize_assumes_string
tokenize_fn, // tokenize_fn
false, // use_approximate_greedy_tokenize_fn
&tokenize_data_, // user_data
};

char error_buf[128];
llg_tokenizer_ = std::unique_ptr<LlgTokenizer, LlgTokenizerDeleter>(llg_new_tokenizer(&tokenizer_init, error_buf, sizeof(error_buf)));
if (!llg_tokenizer_) {
throw std::runtime_error("Error creating llg_tokenizer: " + std::string(error_buf));
}

llg_constraints_.resize(batch_size_);
for (int i = 0; i < batch_size_; i++) {
LlgConstraintInit constraint_init;
llg_constraint_init_set_defaults(&constraint_init, llg_tokenizer_.get());
LlgConstraint* constraint_ptr;
if (guidance_type_ == "json_schema") {
constraint_ptr = llg_new_constraint_json(&constraint_init, guidance_data_.data());
} else if (guidance_type_ == "regex") {
constraint_ptr = llg_new_constraint_regex(&constraint_init, guidance_data_.data());
} else {
constraint_ptr = llg_new_constraint(&constraint_init, guidance_data_.data());
}
if (llg_get_error(constraint_ptr) != nullptr) {
std::string error_message = llg_get_error(constraint_ptr);
llg_free_constraint(constraint_ptr);
throw std::runtime_error("Error creating grammar: " + error_message);
}
llg_constraints_[i] = std::unique_ptr<LlgConstraint, LlgConstraintDeleter>(constraint_ptr);
}

mask_future_ = std::async(std::launch::async, [&]() {
return ComputeMask();
});

#if USE_CUDA
if (state.params_->device_type == DeviceType::CUDA) {
cuda_logits_mask_ptr_ = state.params_->p_device->Allocate<uint32_t>(batch_size_ * vocab_size_ / 32);
}
cuda_stream_ = state.params_->cuda_stream;
#endif
}

std::vector<std::vector<uint32_t>> GuidanceLogitsProcessor::ComputeMask() {
std::vector<std::vector<uint32_t>> masks;
for (int i = 0; i < batch_size_; i++) {
LlgMaskResult mask_result;
auto error = llg_compute_mask(llg_constraints_[i].get(), &mask_result);
if (error != 0) {
std::string error_message = llg_get_error(llg_constraints_[i].get());
throw std::runtime_error("Error computing mask: " + error_message);
}

std::vector<uint32_t> mask;
if (mask_result.is_stop) {
std::cout << "should stop" << std::endl;
mask = std::vector<uint32_t>((vocab_size_ - 1) / 32 + 1, 0);
uint32_t eos_mask32 = 1 << (eos_token_ % 32);
mask[eos_token_ / 32] = eos_mask32;
} else {
mask.reserve((vocab_size_ - 1) / 32 + 1);
for (int i = 0; i < (vocab_size_ - 1) / 32 + 1; i++) {
mask.push_back(mask_result.sample_mask[i]);
}
}
masks.push_back(mask);
}
return masks;
}

void GuidanceLogitsProcessor::CommitTokens(std::span<int32_t> tokens) {
for (int i = 0; i < batch_size_; i++) {
LlgCommitResult commit_result;
auto error = llg_commit_token(llg_constraints_[i].get(), static_cast<uint32_t>(tokens[i]), &commit_result);
if (error != 0) {
std::string error_message = llg_get_error(llg_constraints_[i].get());
throw std::runtime_error("Error committing tokens: " + error_message);
}
}
mask_future_ = std::async(std::launch::async, [&]() {
return ComputeMask();
});
masks_.clear();
}

std::vector<std::vector<uint32_t>> GuidanceLogitsProcessor::GetMask() {
if (masks_.empty()) {
masks_ = mask_future_.get();
}
return masks_;
}

void GuidanceLogitsProcessor::ProcessLogits(DeviceSpan<float> logits) {
auto masks = GetMask();

#if USE_CUDA
if (device_type_ == DeviceType::CUDA) {
for (int i = 0; i < masks.size(); i++) {
cudaMemcpyAsync(cuda_logits_mask_ptr_.Span().data() + (i * vocab_size_ / 32), masks.at(i).data(),
masks.at(i).size() * sizeof(uint32_t), ::cudaMemcpyHostToDevice, cuda_stream_);
}
cuda::LaunchAddLogitsMask(logits.Span().data(), batch_size_, vocab_size_, cuda_logits_mask_ptr_.Span().data(), cuda_stream_);
return;
}
#else
size_t vocab_index = 0;

auto logits_span = logits.Span();
for (int index = 0; index < batch_size_; index++) {
auto subspan = logits_span.subspan(vocab_index, vocab_size_);
auto& mask = masks[index];
for (size_t i = 0; i < vocab_size_; i++) {
// mask is a 32-bit integer, where each bit corresponds to a token in the vocabulary.
// If the bit is set, the corresponding token is masked (i.e., its logit is set to the lowest possible value).
subspan[i] = mask[i / 32] & (1 << (i % 32)) ? subspan[i] : std::numeric_limits<float>::lowest();
}
vocab_index += vocab_size_;
}
#endif
}

void GuidanceLogitsProcessor::Reset() {
masks_.clear();
llg_constraints_.clear();
llg_constraints_.resize(batch_size_);
for (int i = 0; i < batch_size_; i++) {
LlgConstraintInit constraint_init;
llg_constraint_init_set_defaults(&constraint_init, llg_tokenizer_.get());
LlgConstraint* constraint_ptr;
if (guidance_type_ == "json_schema") {
constraint_ptr = llg_new_constraint_json(&constraint_init, guidance_data_.data());
} else if (guidance_type_ == "regex") {
constraint_ptr = llg_new_constraint_regex(&constraint_init, guidance_data_.data());
} else {
constraint_ptr = llg_new_constraint(&constraint_init, guidance_data_.data());
}
if (llg_get_error(constraint_ptr) != nullptr) {
std::string error_message = llg_get_error(constraint_ptr);
llg_free_constraint(constraint_ptr);
throw std::runtime_error("Error creating grammar: " + error_message);
}
llg_constraints_[i] = std::unique_ptr<LlgConstraint, LlgConstraintDeleter>(constraint_ptr);
}

mask_future_ = std::async(std::launch::async, [&]() {
return ComputeMask();
});
}

std::vector<int32_t> GuidanceLogitsProcessor::tokenize_partial(const Tokenizer* tokenizer, const size_t prefix_len,
const uint8_t* bytes, size_t bytes_len) {
// prepend a fixed prefix before tokenizing; partial tokenization of raw bytes
// produces more stable ids this way (the prefix tokens are stripped below)
std::string input_string = kTokenizePrefixStr;
input_string.reserve(bytes_len + 2);
for (size_t i = 0; i < bytes_len; i++) {
input_string.push_back(bytes[i]);
}
std::vector<int32_t> output_ids = tokenizer->Encode(input_string.c_str());
return std::vector<int32_t>(output_ids.begin() + prefix_len, output_ids.end());
}

#endif

std::unique_ptr<LogitsProcessor> CreateLogitsProcessor(const State& state) {
#if USE_GUIDANCE
if (!state.params_->guidance_type.empty() && !state.params_->guidance_data.empty()) {
return std::make_unique<GuidanceLogitsProcessor>(state);
}

#endif

Log("warning", "No supported LogitsProcessor found. e.g. to use guidance, build with use_guidance=true");
return nullptr;
}
} // namespace Generators