Skip to content

Commit

Permalink
Fix code to work with latest GGML.
Browse files (browse the repository at this point in the history)
  • Loading branch information
dranger003 authored and skeskinen committed Aug 26, 2023
1 parent bdc0c5d commit a97ec01
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions bert.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -458,7 +458,7 @@ struct bert_ctx * bert_load_from_file(const char *fname)
model_mem_req += n_layer * (n_intermediate * ggml_type_sizef(GGML_TYPE_F32)); // ff_i_b
model_mem_req += n_layer * (n_embd * ggml_type_sizef(GGML_TYPE_F32)); // ff_o_b

- model_mem_req += (5 + 16 * n_layer) * 256; // object overhead
+ model_mem_req += (5 + 16 * n_layer) * 512; // object overhead

printf("%s: ggml ctx size = %6.2f MB\n", __func__, model_mem_req / (1024.0 * 1024.0));
}
Expand Down Expand Up @@ -678,8 +678,8 @@ struct bert_ctx * bert_load_from_file(const char *fname)
// Calculate space requirements for setting up context buffers later
{
bert_vocab_id tokens[] = {0, 1, 2, 3};
- // TODO: We set the initial buffer size to 16MB and hope it's enough. Maybe there is a better way to do this?
- new_bert->buf_compute.resize(16 * 1024 * 1024);
+ // TODO: We set the initial buffer size to 32MB and hope it's enough. Maybe there is a better way to do this?
+ new_bert->buf_compute.resize(32 * 1024 * 1024);
bert_eval(new_bert, 1, tokens, 4, nullptr);
new_bert->max_batch_n = 0;

Expand All @@ -688,7 +688,7 @@ struct bert_ctx * bert_load_from_file(const char *fname)
new_bert->mem_per_input = 1.1 * (new_bert->mem_per_token * N); // add 10% to account for ggml object overhead

}
- printf("%s: mem_per_token %zd KB, mem_per_input %lld MB\n", __func__, new_bert->mem_per_token / (1 << 10), new_bert->mem_per_input / (1 << 20));
+ printf("%s: mem_per_token %d KB, mem_per_input %lld MB\n", __func__, new_bert->mem_per_token / (1 << 10), new_bert->mem_per_input / (1 << 20));

return new_bert;
}
Expand Down Expand Up @@ -779,7 +779,6 @@ void bert_eval_batch(

struct ggml_context *ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
- gf.n_threads = n_threads;

// Embeddings. word_embeddings + token_type_embeddings + position_embeddings
struct ggml_tensor *token_layer = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
Expand Down Expand Up @@ -916,7 +915,7 @@ void bert_eval_batch(
ggml_tensor *output = inpL;
// run the computation
ggml_build_forward_expand(&gf, output);
- ggml_graph_compute(ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);


// float *dat = ggml_get_data_f32(output);
Expand Down

0 comments on commit a97ec01

Please sign in to comment.