Merge pull request #367 from janimo/long-multiply
Do parameter count calculations in 64 bits to not overflow in case of…
karpathy authored Sep 1, 2023
2 parents 0776f86 + c5ec6e2 commit b9fb861
Showing 1 changed file with 12 additions and 10 deletions.
run.c
@@ -115,26 +115,28 @@ void free_run_state(RunState* s) {

 void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
     int head_size = p->dim / p->n_heads;
+    // make sure the multiplications below are done in 64bit to fit the parameter counts of 13B+ models
+    unsigned long long n_layers = p->n_layers;
     w->token_embedding_table = ptr;
     ptr += p->vocab_size * p->dim;
     w->rms_att_weight = ptr;
-    ptr += p->n_layers * p->dim;
+    ptr += n_layers * p->dim;
     w->wq = ptr;
-    ptr += p->n_layers * p->dim * (p->n_heads * head_size);
+    ptr += n_layers * p->dim * (p->n_heads * head_size);
     w->wk = ptr;
-    ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
+    ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
     w->wv = ptr;
-    ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
+    ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
     w->wo = ptr;
-    ptr += p->n_layers * (p->n_heads * head_size) * p->dim;
+    ptr += n_layers * (p->n_heads * head_size) * p->dim;
     w->rms_ffn_weight = ptr;
-    ptr += p->n_layers * p->dim;
+    ptr += n_layers * p->dim;
     w->w1 = ptr;
-    ptr += p->n_layers * p->dim * p->hidden_dim;
+    ptr += n_layers * p->dim * p->hidden_dim;
     w->w2 = ptr;
-    ptr += p->n_layers * p->hidden_dim * p->dim;
+    ptr += n_layers * p->hidden_dim * p->dim;
     w->w3 = ptr;
-    ptr += p->n_layers * p->dim * p->hidden_dim;
+    ptr += n_layers * p->dim * p->hidden_dim;
     w->rms_final_weight = ptr;
     ptr += p->dim;
     ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_real (for RoPE)
@@ -249,7 +251,7 @@ float* forward(Transformer* transformer, int token, int pos) {
     memcpy(x, content_row, dim*sizeof(*x));

     // forward all the layers
-    for(int l = 0; l < p->n_layers; l++) {
+    for(unsigned long long l = 0; l < p->n_layers; l++) {

         // attention rmsnorm
         rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
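Why the widening matters, with concrete numbers: for Llama 2 13B the config is roughly n_layers = 40, dim = 5120, hidden_dim = 13824, so the w1/w2/w3 offsets each come to 40 * 5120 * 13824 = 2,831,155,200 elements, which exceeds INT_MAX (2,147,483,647). With int operands that product is evaluated in 32-bit arithmetic and overflows (undefined behavior for signed int); widening one operand first forces the whole chain of multiplications into 64 bits. A minimal standalone sketch of the idea, assuming those 13B dimensions (illustrative values, not read from a real checkpoint):

    #include <limits.h>
    #include <stdio.h>

    int main(void) {
        int n_layers = 40;       // published Llama 2 13B config (assumed here)
        int dim = 5120;
        int hidden_dim = 13824;

        // n_layers * dim * hidden_dim with int operands would be evaluated in
        // 32-bit arithmetic and overflow. Casting the leftmost operand makes
        // every multiplication in the chain 64-bit, mirroring the commit's
        // `unsigned long long n_layers` local.
        unsigned long long w1_elems = (unsigned long long)n_layers * dim * hidden_dim;

        printf("w1 element count: %llu\n", w1_elems); // 2831155200
        printf("INT_MAX:          %d\n", INT_MAX);    // 2147483647
        return 0;
    }

The loop-index change in forward() serves the same purpose: with l declared unsigned long long, weight offsets computed from it later in the loop body (for example w->wq + l*dim*dim) are also evaluated in 64-bit arithmetic rather than int.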
