Implement separate states
vshampor committed Apr 22, 2024
1 parent ad9dee2 commit 7a7af22
Showing 5 changed files with 29 additions and 20 deletions.
1 change: 1 addition & 0 deletions modules/llama_cpp_plugin/include/compiled_model.hpp
@@ -61,6 +61,7 @@ class LlamaCppModel : public ICompiledModel {
 private:
     gguf_context* m_gguf_ctx = nullptr;
     std::string m_gguf_fname;
+    size_t m_num_threads;
 
     llama_model* m_llama_model_ptr = nullptr;
     llama_context* m_llama_ctx = nullptr;
5 changes: 3 additions & 2 deletions modules/llama_cpp_plugin/include/infer_request.hpp
@@ -12,8 +12,8 @@ namespace llama_cpp_plugin {
 
 class LlamaCppSyncInferRequest : public ISyncInferRequest {
 public:
-    explicit LlamaCppSyncInferRequest(const std::shared_ptr<const LlamaCppModel>& compiled_model);
-    virtual ~LlamaCppSyncInferRequest(){};
+    explicit LlamaCppSyncInferRequest(const std::shared_ptr<const LlamaCppModel>& compiled_model, size_t num_threads);
+    virtual ~LlamaCppSyncInferRequest() override;
 
     virtual void set_tensors_impl(const ov::Output<const ov::Node> port,
                                   const std::vector<ov::SoPtr<ov::ITensor>>& tensors) override;
@@ -24,6 +24,7 @@ class LlamaCppSyncInferRequest : public ISyncInferRequest {
 
 private:
     std::shared_ptr<const LlamaCppModel> m_compiled_model_ptr;
+    llama_context* m_llama_ctx;
 };
 
 } // namespace llama_cpp_plugin
9 changes: 5 additions & 4 deletions modules/llama_cpp_plugin/include/state.hpp
@@ -12,15 +12,16 @@ namespace llama_cpp_plugin {
 class LlamaCppState : public IVariableState {
 public:
     LlamaCppState() = delete;
-    LlamaCppState(const std::shared_ptr<const LlamaCppModel>& model_ptr)
-        : m_model_ptr(model_ptr),
+    LlamaCppState(llama_context* llama_context_ptr)
+        : m_llama_ctx_ptr(llama_context_ptr),
           IVariableState("llama_cpp_state") {}
     void reset() override {
-        llama_kv_cache_clear(m_model_ptr->m_llama_ctx);
+        OPENVINO_ASSERT(m_llama_ctx_ptr != nullptr);
+        llama_kv_cache_clear(m_llama_ctx_ptr);
     }
 
 private:
-    const std::shared_ptr<const LlamaCppModel>& m_model_ptr;
+    llama_context* m_llama_ctx_ptr;
 };
 } // namespace llama_cpp_plugin
 } // namespace ov
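For context, the reworked LlamaCppState above is reached through the regular OpenVINO variable-state API rather than constructed by user code. Below is a minimal usage sketch; it assumes the plugin is exposed under a device name such as "LLAMA_CPP" and that compile_model() accepts a GGUF file path directly, neither of which is shown in this diff:

    #include <openvino/openvino.hpp>

    int main() {
        ov::Core core;
        // Assumption: the llama_cpp_plugin is registered as the "LLAMA_CPP" device
        // and loads the model straight from a GGUF file.
        ov::CompiledModel compiled = core.compile_model("model.gguf", "LLAMA_CPP");
        ov::InferRequest request = compiled.create_infer_request();

        // ... fill the input tensors and call request.infer() one or more times ...

        // Each returned state wraps this request's own llama_context, so reset()
        // ends up in LlamaCppState::reset() and clears only this request's KV cache.
        for (auto&& state : request.query_state()) {
            state.reset();
        }
        return 0;
    }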
13 changes: 4 additions & 9 deletions modules/llama_cpp_plugin/src/compiled_model.cpp
@@ -9,7 +9,6 @@
 #include <openvino/opsets/opset13.hpp>
 #include <openvino/runtime/properties.hpp>
 #include <openvino/util/log.hpp>
-#include <thread>
 
 #include "infer_request.hpp"
 #include "plugin.hpp"
@@ -18,7 +17,6 @@ namespace ov {
 namespace llama_cpp_plugin {
 
 LlamaCppModel::~LlamaCppModel() {
-    llama_free(m_llama_ctx);
     llama_free_model(m_llama_model_ptr);
     llama_backend_free();
 }
@@ -27,15 +25,12 @@ LlamaCppModel::LlamaCppModel(const std::string& gguf_fname,
                              const std::shared_ptr<const IPlugin>& plugin,
                              size_t num_threads)
     : ICompiledModel(nullptr, plugin),
-      m_gguf_fname(gguf_fname) {
+      m_gguf_fname(gguf_fname),
+      m_num_threads(num_threads) {
     OPENVINO_DEBUG << "llama_cpp_plugin: loading llama model directly from GGUF... " << std::endl;
     llama_model_params mparams = llama_model_default_params();
     mparams.n_gpu_layers = 99;
     m_llama_model_ptr = llama_load_model_from_file(gguf_fname.c_str(), mparams);
-    llama_context_params cparams = llama_context_default_params();
-    cparams.n_threads = num_threads ? num_threads : std::thread::hardware_concurrency();
-    cparams.n_ctx = 0; // this means that the actual n_ctx will be taken equal to the model's train-time value
-    m_llama_ctx = llama_new_context_with_model(m_llama_model_ptr, cparams);
     OPENVINO_DEBUG << "llama_cpp_plugin: llama model loaded successfully from GGUF..." << std::endl;
 
     auto input_ids = std::make_shared<ov::opset13::Parameter>(ov::element::Type_t::i64, ov::PartialShape({-1, -1}));
@@ -87,8 +82,8 @@ ov::Any LlamaCppModel::get_property(const std::string& name) const {
 }
 
 std::shared_ptr<ov::ISyncInferRequest> LlamaCppModel::create_sync_infer_request() const {
-    return std::make_shared<LlamaCppSyncInferRequest>(
-        std::static_pointer_cast<const LlamaCppModel>(shared_from_this()));
+    return std::make_shared<LlamaCppSyncInferRequest>(std::static_pointer_cast<const LlamaCppModel>(shared_from_this()),
+                                                      m_num_threads);
 }
 
 const std::vector<ov::Output<const ov::Node>>& LlamaCppModel::inputs() const {
21 changes: 16 additions & 5 deletions modules/llama_cpp_plugin/src/infer_request.cpp
@@ -5,6 +5,7 @@
 
 #include <memory>
 #include <openvino/runtime/ivariable_state.hpp>
+#include <thread>
 
 #include "llama.h"
 #include "openvino/runtime/make_tensor.hpp"
@@ -24,9 +25,14 @@ void allocate_tensor_impl(ov::SoPtr<ov::ITensor>& tensor,
     }
 }
 
-LlamaCppSyncInferRequest::LlamaCppSyncInferRequest(const std::shared_ptr<const LlamaCppModel>& compiled_model)
+LlamaCppSyncInferRequest::LlamaCppSyncInferRequest(const std::shared_ptr<const LlamaCppModel>& compiled_model,
+                                                   size_t num_threads)
     : ov::ISyncInferRequest(compiled_model) {
     OPENVINO_DEBUG << "llama_cpp_plugin: infer request ctor called\n";
+    llama_context_params cparams = llama_context_default_params();
+    cparams.n_threads = num_threads ? num_threads : std::thread::hardware_concurrency();
+    cparams.n_ctx = 0; // this means that the actual n_ctx will be taken equal to the model's train-time value
+    m_llama_ctx = llama_new_context_with_model(compiled_model->m_llama_model_ptr, cparams);
     m_compiled_model_ptr = compiled_model;
     for (const auto& input : get_inputs()) {
         allocate_tensor(input, [input](ov::SoPtr<ov::ITensor>& tensor) {
@@ -97,8 +103,7 @@ void LlamaCppSyncInferRequest::infer() {
         }
     }
 
-    llama_context* ctx = m_compiled_model_ptr->m_llama_ctx;
-    int32_t sts = llama_decode(ctx, batch);
+    int32_t sts = llama_decode(m_llama_ctx, batch);
 
     if (sts != 0) {
         OPENVINO_THROW("llama_decode failed with code ", sts);
@@ -112,7 +117,7 @@ void LlamaCppSyncInferRequest::infer() {
     for (size_t batch_idx = 0; batch_idx < batch_size; batch_idx++) {
         for (size_t seq_idx = 0; seq_idx < sequence_length; seq_idx++) {
             size_t pos = batch_idx * sequence_length + seq_idx;
-            float* logits_from_llama = llama_get_logits_ith(ctx, pos);
+            float* logits_from_llama = llama_get_logits_ith(m_llama_ctx, pos);
             std::copy(logits_from_llama, logits_from_llama + n_vocab, output_tensor_data_ptr + pos * n_vocab);
         }
     }
@@ -132,7 +137,13 @@ std::vector<ov::ProfilingInfo> LlamaCppSyncInferRequest::get_profiling_info() co
 
 std::vector<ov::SoPtr<ov::IVariableState>> LlamaCppSyncInferRequest::query_state() const {
     OPENVINO_DEBUG << "llama_cpp_plugin: query_state() called\n";
-    return {std::static_pointer_cast<ov::IVariableState>(std::make_shared<LlamaCppState>(m_compiled_model_ptr))};
+    return {std::static_pointer_cast<ov::IVariableState>(std::make_shared<LlamaCppState>(m_llama_ctx))};
+}
+
+LlamaCppSyncInferRequest::~LlamaCppSyncInferRequest() {
+    if (m_llama_ctx != nullptr) {
+        llama_free(m_llama_ctx);
+    }
 }
 } // namespace llama_cpp_plugin
 } // namespace ov
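Because each LlamaCppSyncInferRequest now creates its llama_context in its constructor and frees it in its destructor, requests created from the same compiled model no longer share a KV cache. A rough sketch of the intended effect, under the same assumptions as the earlier example:

    #include <openvino/openvino.hpp>

    int main() {
        ov::Core core;
        ov::CompiledModel compiled = core.compile_model("model.gguf", "LLAMA_CPP");

        // Two requests, two independent llama_context instances (and KV caches).
        ov::InferRequest request_a = compiled.create_infer_request();
        ov::InferRequest request_b = compiled.create_infer_request();

        // ... run generation steps on both requests ...

        // Resetting request_a's state calls llama_kv_cache_clear() only on its own
        // context; whatever request_b has accumulated stays untouched.
        request_a.query_state()[0].reset();
        return 0;
    }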
