Commit 4ac241c

update

li-plus committed Jul 30, 2024
1 parent 8501fb2 commit 4ac241c
Showing 2 changed files with 27 additions and 29 deletions.
29 changes: 26 additions & 3 deletions chatglm.cpp
@@ -48,6 +48,10 @@
 #include <ggml-cuda.h>
 #endif
 
+#ifdef GGML_USE_METAL
+#include <ggml-metal.h>
+#endif
+
 namespace chatglm {
 
 static std::string shape_to_string(ggml_tensor *tensor) {
@@ -1114,8 +1118,27 @@ ggml_tensor *GLMBlock::forward(ModelContext *mctx, ggml_tensor *hidden_states, g
     return output;
 }
 
+static void alloc_weight_context(ModelContext *mctx, const ggml_backend_buffer_t sd_buf) {
+    void *sd_buf_base = ggml_backend_buffer_get_base(sd_buf);
+    const size_t sd_buf_size = ggml_backend_buffer_get_size(sd_buf);
+    if (ggml_backend_is_cpu(mctx->backend.get())) {
+        mctx->buf_w = unique_ggml_backend_buffer_t(ggml_backend_cpu_buffer_from_ptr(sd_buf_base, sd_buf_size));
+    }
+#ifdef GGML_USE_METAL
+    else if (ggml_backend_is_metal(mctx->backend.get())) {
+        const size_t max_size = ggml_get_max_tensor_size(mctx->ctx_w.get());
+        mctx->buf_w =
+            unique_ggml_backend_buffer_t(ggml_backend_metal_buffer_from_ptr(sd_buf_base, sd_buf_size, max_size));
+    }
+#endif
+    else {
+        mctx->buf_w =
+            unique_ggml_backend_buffer_t(ggml_backend_alloc_ctx_tensors(mctx->ctx_w.get(), mctx->backend.get()));
+    }
+}
+
 void ChatGLMForCausalLM::load_state_dict(const StateDict &sd) {
-    alloc_weight_context(sd.buf.get());
+    alloc_weight_context(mctx_.get(), sd.buf.get());
 
     StateDict self_sd = state_dict();
     for (auto &item : self_sd.kv) {
@@ -1255,7 +1278,7 @@ bool ChatGLM2Tokenizer::is_special_id(int id) const {
 }
 
 void ChatGLM2ForCausalLM::load_state_dict(const StateDict &sd) {
-    alloc_weight_context(sd.buf.get());
+    alloc_weight_context(mctx_.get(), sd.buf.get());
 
     if (config.num_virtual_tokens > 0) {
         ggml_tensor *past_key_values = sd.kv.at("past_key_values");
@@ -1955,7 +1978,7 @@ int ChatGLM4VForCausalLM::count_tokens(const std::vector<int> &input_ids, const
 }
 
 void ChatGLM4VForCausalLM::load_state_dict(const StateDict &sd) {
-    alloc_weight_context(sd.buf.get());
+    alloc_weight_context(mctx_.get(), sd.buf.get());
 
     auto self_sd = state_dict();
     ChatGLM2ForCausalLM::load_state_dict(mctx_.get(), self_sd, sd);
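For readers unfamiliar with the ggml-backend calls above: the CPU (and Metal) branch of the new alloc_weight_context wraps the memory that already backs the state dict instead of allocating a fresh buffer and copying weights into it. Below is a minimal standalone sketch of that zero-copy wrapping; it is not part of the commit, assumes the CPU backend, and the blob/buf_w names are placeholders.

// Standalone sketch (not part of this commit): the zero-copy idea behind the
// CPU branch of alloc_weight_context — wrap memory that already holds the
// weights in a ggml backend buffer instead of allocating a new one and copying.
#include <ggml.h>
#include <ggml-backend.h>
#include <vector>

int main() {
    ggml_backend_t backend = ggml_backend_cpu_init();

    // Stand-in for the loaded state-dict blob (sd.buf in the diff above).
    std::vector<float> blob(1024, 0.0f);

    if (ggml_backend_is_cpu(backend)) {
        // CPU path: reuse the existing memory; no extra allocation, no copy.
        ggml_backend_buffer_t buf_w =
            ggml_backend_cpu_buffer_from_ptr(blob.data(), blob.size() * sizeof(float));
        ggml_backend_buffer_free(buf_w);
    }

    ggml_backend_free(backend);
}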
27 changes: 1 addition & 26 deletions chatglm.h
@@ -9,10 +9,6 @@
 #include <sstream>
 #include <unordered_map>
 
-#ifdef GGML_USE_METAL
-#include <ggml-metal.h>
-#endif
-
 namespace chatglm {
 
 // ===== common =====
@@ -451,8 +447,7 @@ using unique_ggml_backend_buffer_t = std::unique_ptr<ggml_backend_buffer, ggml_b
 template <typename T>
 struct no_init {
     T value;
-    no_init() { /* do nothing */
-    }
+    no_init() { /* do nothing */ }
 };
 
 struct ModelContext {
@@ -1003,26 +998,6 @@ class BasicModelForCausalLM : public BaseModelForCausalLM {
 
     void load_prefix_cache(ggml_tensor *past_key_values) { transformer.load_prefix_cache(config, past_key_values); }
 
-  protected:
-    void alloc_weight_context(const ggml_backend_buffer_t sd_buf) const {
-        void *sd_buf_base = ggml_backend_buffer_get_base(sd_buf);
-        const size_t sd_buf_size = ggml_backend_buffer_get_size(sd_buf);
-        if (ggml_backend_is_cpu(mctx_->backend.get())) {
-            mctx_->buf_w = unique_ggml_backend_buffer_t(ggml_backend_cpu_buffer_from_ptr(sd_buf_base, sd_buf_size));
-        }
-#ifdef GGML_USE_METAL
-        else if (ggml_backend_is_metal(mctx_->backend.get())) {
-            const size_t max_size = ggml_get_max_tensor_size(mctx_->ctx_w.get());
-            mctx_->buf_w =
-                unique_ggml_backend_buffer_t(ggml_backend_metal_buffer_from_ptr(sd_buf_base, sd_buf_size, max_size));
-        }
-#endif
-        else {
-            mctx_->buf_w =
-                unique_ggml_backend_buffer_t(ggml_backend_alloc_ctx_tensors(mctx_->ctx_w.get(), mctx_->backend.get()));
-        }
-    }
-
   public:
     Model transformer;
     Linear lm_head;
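The final else branch, present in both the removed member function above and its replacement in chatglm.cpp, covers backends that cannot wrap host memory: ggml_backend_alloc_ctx_tensors allocates backing storage for every tensor recorded in the no_alloc weight context. A minimal standalone sketch of that call follows; it is not part of the commit, the CPU backend and the 16x16 tensor stand in for a real device backend and model weights, and uploading the actual weight data is a separate step not covered by this helper.

// Standalone sketch (not part of this commit): the fallback branch of
// alloc_weight_context for backends that cannot wrap host memory directly.
// A no_alloc ggml context (like mctx->ctx_w in the diff) holds the tensor
// metadata, and ggml_backend_alloc_ctx_tensors gives every tensor backing
// storage in a single backend buffer.
#include <ggml.h>
#include <ggml-alloc.h>
#include <ggml-backend.h>

int main() {
    ggml_init_params params = {
        /*.mem_size   =*/ 8 * ggml_tensor_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true, // tensor metadata only, no data yet
    };
    ggml_context *ctx_w = ggml_init(params);
    ggml_tensor *w = ggml_new_tensor_2d(ctx_w, GGML_TYPE_F32, 16, 16);
    (void)w;

    // The CPU backend stands in for a device backend here.
    ggml_backend_t backend = ggml_backend_cpu_init();
    ggml_backend_buffer_t buf_w = ggml_backend_alloc_ctx_tensors(ctx_w, backend);

    ggml_backend_buffer_free(buf_w);
    ggml_backend_free(backend);
    ggml_free(ctx_w);
}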
