Skip to content

Commit

Permalink
llama: support dbrx
Browse files Browse the repository at this point in the history
  • Loading branch information
phymbert committed Apr 6, 2024
1 parent 1d8de31 commit 0722ad1
Show file tree
Hide file tree
Showing 3 changed files with 259 additions and 2 deletions.
2 changes: 1 addition & 1 deletion convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,7 +1438,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_head_count(self.hparams["n_heads"])
self.gguf_writer.add_head_count_kv(attn_config["kv_n_heads"])
self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"])
self.gguf_writer.add_clip_kqv(attn_config["clip_qkv"])
self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"])
self.gguf_writer.add_file_type(self.ftype)

self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"])
Expand Down
1 change: 0 additions & 1 deletion gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ class Attention:
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
CAUSAL = "{arch}.attention.causal"
CLIP_KQV = "{arch}.attention.clip_kqv"

class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
Expand Down
258 changes: 258 additions & 0 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ enum llm_arch {
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,
LLM_ARCH_COMMAND_R,
LLM_ARCH_DBRX,
LLM_ARCH_UNKNOWN,
};

Expand Down Expand Up @@ -252,6 +253,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
{ LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_DBRX, "dbrx" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

Expand Down Expand Up @@ -926,6 +928,23 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_DBRX,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
},
},
{
LLM_ARCH_UNKNOWN,
{
Expand Down Expand Up @@ -1692,6 +1711,7 @@ enum e_model {
MODEL_40B,
MODEL_65B,
MODEL_70B,
MODEL_132B,
MODEL_314B,
MODEL_SMALL,
MODEL_MEDIUM,
Expand Down Expand Up @@ -3961,6 +3981,15 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_DBRX:
{
ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv);

switch (hparams.n_layer) {
case 40: model.type = e_model::MODEL_132B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
default: (void)0;
}

Expand Down Expand Up @@ -4635,6 +4664,46 @@ static bool llm_load_tensors(
layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
}
} break;
case LLM_ARCH_DBRX:
{
if (n_expert == 0) {
throw std::runtime_error("DBRX model cannot have zero experts");
}

model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});

// output
{
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
// if output is NULL, init from the input tok embed
if (model.output == NULL) {
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
ml.n_created--; // artificial tensor
ml.size_data += ggml_nbytes(model.output);
}
}

for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);

auto & layer = model.layers[i];

layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2,"weight", i), {n_embd});

layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});

layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false);
layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert});
layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert});

layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
}
} break;
case LLM_ARCH_BAICHUAN:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
Expand Down Expand Up @@ -7030,6 +7099,190 @@ struct llm_build_context {
return gf;
}

struct ggml_cgraph * build_dbrx() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

// mutable variable, needed during the last layer of the computation to skip unused tokens
int32_t n_tokens = this->n_tokens;

const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);

struct ggml_tensor * cur;
struct ggml_tensor * inpL;

inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

// multiply by embedding_multiplier_scale of 78.38367176906169
inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);

// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();

// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();

for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;

// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);


// self-attention
{
if (model.layers[il].attn_norm_2) {
// DBRX
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm_2,
NULL,
LLM_NORM, cb, il);
cb(cur, "attn_norm_2", il);
}

cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);

if (hparams.f_clamp_kqv > 0.0f) {
cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
cb(cur, "wqkv_clamped", il);
}

struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);

struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);

struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
cb(Vcur, "Vcur", il);

cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}

if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
n_tokens = n_outputs;
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}

struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);

// feed-forward network
// MoE branch
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);

ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
cb(logits, "ffn_moe_logits", il);

ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
cb(probs, "ffn_moe_probs", il);

// select experts
ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
cb(selected_experts->src[0], "ffn_moe_argsort", il);

ggml_tensor * weights = ggml_get_rows(ctx0,
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
cb(weights, "ffn_moe_weights", il);

weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]

ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
cb(weights_sum, "ffn_moe_weights_sum", il);

weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
cb(weights, "ffn_moe_weights_norm", il);

// compute expert outputs
ggml_tensor * moe_out = nullptr;

for (int i = 0; i < n_expert_used; ++i) {
ggml_tensor * cur_expert;

ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
cb(cur_up, "ffn_moe_up", il);

ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
cb(cur_gate, "ffn_moe_gate", il);

//GeLU
cur_gate = ggml_gelu(ctx0, cur_gate);
cb(cur_gate, "ffn_moe_gelu", il);

cur_expert = ggml_mul(ctx0, cur_up, cur_gate);
cb(cur_expert, "ffn_moe_gate_par", il);

cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
cb(cur_expert, "ffn_moe_down", il);

cur_expert = ggml_mul(ctx0, cur_expert,
ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
cb(cur_expert, "ffn_moe_weighted", il);

if (i == 0) {
moe_out = cur_expert;
} else {
moe_out = ggml_add(ctx0, moe_out, cur_expert);
cb(moe_out, "ffn_moe_out", il);
}
}

cur = moe_out;

cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);

ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
if (layer_dir != nullptr) {
cur = ggml_add(ctx0, cur, layer_dir);
}
cb(cur, "l_out", il);

// input for next layer
inpL = cur;
}

cur = inpL;

cur = llm_build_norm(ctx0, cur, hparams,
model.output_norm, NULL,
LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);

// lm_head
cur = ggml_mul_mat(ctx0, model.output, cur);

cb(cur, "result_output", -1);

ggml_build_forward_expand(gf, cur);

return gf;
}

struct ggml_cgraph * build_starcoder() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

Expand Down Expand Up @@ -9719,6 +9972,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_command_r();
} break;
case LLM_ARCH_DBRX:
{
result = llm.build_dbrx();
} break;
default:
GGML_ASSERT(false);
}
Expand Down Expand Up @@ -14525,6 +14782,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_MINICPM:
case LLM_ARCH_XVERSE:
case LLM_ARCH_COMMAND_R:
case LLM_ARCH_DBRX: // FIXME REVIEW @ggerganov I am not sure what to put here
return LLAMA_ROPE_TYPE_NORM;

// the pairs of head values are offset by n_rot/2
Expand Down

0 comments on commit 0722ad1

Please sign in to comment.