Skip to content

Commit

Permalink
llama : add Command-R support (ggml-org#6033)
Browse files Browse the repository at this point in the history
Information about the Command-R 35B model (128k context) can be found at:
	https://huggingface.co/CohereForAI/c4ai-command-r-v01

Based on the llama2 model with a few changes:

1) New hyper parameter to scale output logits (logit_scale)
2) Uses LayerNorm instead of RMSNorm
3) Transfomer layers have a single shared LayerNorm that feeds into both the
   self-attention and FFN layers in parallel. There is no post-attention LayerNorm.
4) No support for Rotary Position Embeddings (RoPE) scaling
5) No biases used

Find GGUF files here:
	https://huggingface.co/andrewcanis/c4ai-command-r-v01-GGUF

To convert model to GGUF format yourself:

1) Download Command-R Hugging Face safetensors:
	git lfs install
	git clone https://huggingface.co/CohereForAI/c4ai-command-r-v01

2) Run:
	python3 convert-hf-to-gguf.py --outtype f16 ./c4ai-command-r-v01
  • Loading branch information
acanis authored Mar 15, 2024
1 parent 4e9a7f7 commit 12247f4
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [CodeShell](https://github.com/WisdomShell/codeshell)
- [x] [Gemma](https://ai.google.dev/gemma)
- [x] [Mamba](https://github.com/state-spaces/mamba)
- [x] [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01)

**Multimodal models:**

Expand Down
17 changes: 17 additions & 0 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1965,6 +1965,23 @@ def write_tensors(self):
self.gguf_writer.add_tensor(new_name, data)


@Model.register("CohereForCausalLM")
class CommandR2Model(Model):
model_arch = gguf.MODEL_ARCH.COMMAND_R

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# max_position_embeddings = 8192 in config.json but model was actually
# trained on 128k context length
self.hparams["max_position_embeddings"] = self.hparams["model_max_length"]

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_logit_scale(self.hparams["logit_scale"])
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)


###### CONVERSION LOGIC ######


Expand Down
15 changes: 15 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class LLM:
EXPERT_COUNT = "{arch}.expert_count"
EXPERT_USED_COUNT = "{arch}.expert_used_count"
POOLING_TYPE = "{arch}.pooling_type"
LOGIT_SCALE = "{arch}.logit_scale"

class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
Expand Down Expand Up @@ -121,6 +122,7 @@ class MODEL_ARCH(IntEnum):
GEMMA = auto()
STARCODER2 = auto()
MAMBA = auto()
COMMAND_R = auto()


class MODEL_TENSOR(IntEnum):
Expand Down Expand Up @@ -187,6 +189,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.COMMAND_R: "command-r",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
Expand Down Expand Up @@ -579,6 +582,18 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_OUT,
],
MODEL_ARCH.COMMAND_R: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
# TODO
}

Expand Down
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,9 @@ def add_max_alibi_bias(self, bias: float) -> None:
def add_clamp_kqv(self, value: float) -> None:
self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)

def add_logit_scale(self, value: float) -> None:
self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)

def add_expert_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)

Expand Down
183 changes: 183 additions & 0 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ enum llm_arch {
LLM_ARCH_GEMMA,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_COMMAND_R,
LLM_ARCH_UNKNOWN,
};

Expand Down Expand Up @@ -243,6 +244,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

Expand All @@ -268,6 +270,7 @@ enum llm_kv {
LLM_KV_EXPERT_COUNT,
LLM_KV_EXPERT_USED_COUNT,
LLM_KV_POOLING_TYPE,
LLM_KV_LOGIT_SCALE,

LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
Expand Down Expand Up @@ -332,6 +335,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EXPERT_COUNT, "%s.expert_count" },
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
{ LLM_KV_POOLING_TYPE , "%s.pooling_type" },
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },

{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
Expand Down Expand Up @@ -838,6 +842,21 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{
LLM_ARCH_COMMAND_R,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_UNKNOWN,
{
Expand Down Expand Up @@ -1597,6 +1616,7 @@ enum e_model {
MODEL_20B,
MODEL_30B,
MODEL_34B,
MODEL_35B,
MODEL_40B,
MODEL_65B,
MODEL_70B,
Expand Down Expand Up @@ -1643,6 +1663,7 @@ struct llama_hparams {

float f_clamp_kqv = 0.0f;
float f_max_alibi_bias = 0.0f;
float f_logit_scale = 0.0f;

bool causal_attn = true;
bool need_kq_pos = false;
Expand Down Expand Up @@ -3231,6 +3252,7 @@ static const char * llama_model_type_name(e_model type) {
case MODEL_20B: return "20B";
case MODEL_30B: return "30B";
case MODEL_34B: return "34B";
case MODEL_35B: return "35B";
case MODEL_40B: return "40B";
case MODEL_65B: return "65B";
case MODEL_70B: return "70B";
Expand Down Expand Up @@ -3623,6 +3645,15 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_COMMAND_R:
{
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
switch (hparams.n_layer) {
case 40: model.type = e_model::MODEL_35B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
default: (void)0;
}

Expand Down Expand Up @@ -3944,6 +3975,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale);
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
Expand Down Expand Up @@ -4918,6 +4950,37 @@ static bool llm_load_tensors(
layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
}
} break;
case LLM_ARCH_COMMAND_R:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});

// output
{
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
// init output from the input tok embed
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
ml.n_created--; // artificial tensor
ml.size_data += ggml_nbytes(model.output);
}

for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);

auto & layer = model.layers[i];

layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});

layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});

layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
}
} break;
default:
throw std::runtime_error("unknown architecture");
}
Expand Down Expand Up @@ -8315,6 +8378,121 @@ struct llm_build_context {

return gf;
}

struct ggml_cgraph * build_command_r() {

struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
const float f_logit_scale = hparams.f_logit_scale;

struct ggml_tensor * cur;
struct ggml_tensor * inpL;

inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();

// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();

for (int il = 0; il < n_layer; ++il) {

// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM, cb, il);
cb(cur, "attn_norm", il);
struct ggml_tensor * ffn_inp = cur;

// self-attention
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}

struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}

struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}

Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);

Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);

cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}

struct ggml_tensor * attn_out = cur;

// feed-forward network
{
cur = llm_build_ffn(ctx0, ffn_inp,
model.layers[il].ffn_up, NULL,
model.layers[il].ffn_gate, NULL,
model.layers[il].ffn_down, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
}

// add together residual + FFN + self-attention
cur = ggml_add(ctx0, cur, inpL);
cur = ggml_add(ctx0, cur, attn_out);
cb(cur, "l_out", il);

// input for next layer
inpL = cur;
}

cur = inpL;

cur = llm_build_norm(ctx0, cur, hparams,
model.output_norm, NULL,
LLM_NORM, cb, -1);
cb(cur, "result_norm", -1);

// lm_head
cur = ggml_mul_mat(ctx0, model.output, cur);

if (f_logit_scale) {
cur = ggml_scale(ctx0, cur, f_logit_scale);
}

cb(cur, "result_output", -1);

ggml_build_forward_expand(gf, cur);

return gf;

}
};

static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
Expand Down Expand Up @@ -8497,6 +8675,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_mamba();
} break;
case LLM_ARCH_COMMAND_R:
{
result = llm.build_command_r();
} break;
default:
GGML_ASSERT(false);
}
Expand Down Expand Up @@ -13147,6 +13329,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_ORION:
case LLM_ARCH_INTERNLM2:
case LLM_ARCH_MINICPM:
case LLM_ARCH_COMMAND_R:
return LLAMA_ROPE_TYPE_NORM;

// the pairs of head values are offset by n_rot/2
Expand Down

0 comments on commit 12247f4

Please sign in to comment.