Commit 930f7b1

WIP
cebtenzzre committed Dec 22, 2023
1 parent 55f3495 commit 930f7b1
Showing 2 changed files with 57 additions and 24 deletions.
examples/contrastive/contrastive.cpp (4 changes: 2 additions & 2 deletions)
@@ -375,6 +375,7 @@ int main(int argc, char ** argv) {
float * logits_exp = llama_get_logits_ith(ctx_exp, idx);
float * logits_ama = llama_get_logits_ith(ctx_ama, idx);

+#if 1
float max_logit_exp = *std::max_element(logits_exp, logits_exp + n_vocab);

for (int i = 0; i < n_vocab_exp; ++i) {
@@ -387,11 +388,10 @@ int main(int argc, char ** argv) {
// token not known to amateur
logits_exp[i] = mask;
} else {
-#if 1
logits_exp[i] = (1 + cd_beta) * logits_exp[i] - cd_beta * logits_ama[i];
-#endif
}
}
+#endif

const llama_token id = llama_sampling_sample(ctx_sampling, ctx_exp, NULL, idx);
llama_sampling_accept(ctx_sampling, ctx_exp, id, true);
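For reference outside the diff, the contrastive-decoding step that this `#if 1` block enables can be read as one standalone routine. The sketch below mirrors the logic committed here; the name `cd_combine` and its signature are illustrative, not part of this commit. As in the hellaswag hunk further down, the per-position max is taken over the vocabulary shared by both models.

```cpp
#include <algorithm>
#include <cmath>
#include <limits>

// Sketch: contrastive-decoding logit combination. Writes the combined scores
// into logits_exp in place, so an ordinary sampler can run on it afterwards.
static void cd_combine(float * logits_exp, const float * logits_ama,
                       int n_vocab_exp, int n_vocab_ama,
                       float cd_alpha, float cd_beta) {
    const int   n_vocab = std::min(n_vocab_exp, n_vocab_ama); // shared vocabulary
    const float mask    = std::numeric_limits<float>::lowest();
    const float max_logit_exp = *std::max_element(logits_exp, logits_exp + n_vocab);

    for (int i = 0; i < n_vocab_exp; ++i) {
        if (logits_exp[i] < max_logit_exp + std::log(cd_alpha)) {
            logits_exp[i] = mask; // not a plausible token according to the expert
        } else if (i >= n_vocab) {
            logits_exp[i] = mask; // token not known to the amateur
        } else {
            logits_exp[i] = (1 + cd_beta) * logits_exp[i] - cd_beta * logits_ama[i];
        }
    }
}
```

In contrastive.cpp this combination runs immediately before `llama_sampling_sample`, so the sampler only ever sees the masked, contrasted scores.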
examples/perplexity/perplexity.cpp (77 changes: 55 additions & 22 deletions)
@@ -188,10 +188,10 @@ static results_perplexity perplexity(llama_context * ctx_exp, llama_context * ct
auto * batch_logits_ama = llama_get_logits(ctx_ama);

const float cd_alpha = 0.1; // TODO: make CLI argument
-const float cd_beta = 0.5; // set to 0.5 to behave like original paper
+const float cd_beta = 0.5;
const float mask = std::numeric_limits<float>::lowest();

-#if 1
+#if 0
for (int t = 0; t < batch_size; ++t) {
auto * logits_exp = batch_logits_exp + t * n_vocab_exp;
auto * logits_ama = batch_logits_ama + t * n_vocab_ama;
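Two knobs are defined just above. `cd_alpha` is the plausibility cutoff: the original contrastive-decoding paper states it on probabilities (keep token $i$ only if $p_i \ge \alpha \max_j p_j$), while this code tests raw logits. As the comment in the hellaswag hunk below notes, the two forms have the same meaning, because the softmax normalizer cancels:

$$
p_i \ge \alpha \max_j p_j
\;\Longleftrightarrow\;
\Bigl(\ell_i - \log\textstyle\sum_k e^{\ell_k}\Bigr) \ge \log\alpha + \Bigl(\max_j \ell_j - \log\textstyle\sum_k e^{\ell_k}\Bigr)
\;\Longleftrightarrow\;
\ell_i \ge \max_j \ell_j + \log\alpha .
$$

`cd_beta` sets the contrast strength: $(1+\beta)\,\ell^{\mathrm{exp}}_i - \beta\,\ell^{\mathrm{ama}}_i = \ell^{\mathrm{exp}}_i + \beta\,(\ell^{\mathrm{exp}}_i - \ell^{\mathrm{ama}}_i)$, so $\beta = 0$ recovers the expert alone and larger $\beta$ amplifies the expert-amateur gap; per the comment removed above, $\beta = 0.5$ matches the original paper's setting.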
@@ -274,28 +274,64 @@ static results_perplexity perplexity(llama_context * ctx_exp, llama_context * ct
}

static std::vector<float> hellaswag_evaluate_tokens(
-llama_context * ctx, std::vector<int> & tokens, int n_past, int n_batch, int n_vocab
+llama_context * ctx_exp, llama_context * ctx_ama, std::vector<int> & tokens, int n_past, int n_batch, int n_vocab_exp
) {
+const int n_vocab_ama = llama_n_vocab(llama_get_model(ctx_ama));
+const int n_vocab = std::min(n_vocab_exp, n_vocab_ama);

std::vector<float> result;
-result.reserve(tokens.size() * n_vocab);
+result.reserve(tokens.size() * n_vocab_exp);
size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch;
for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
size_t n_tokens = tokens.size() - i_chunk * n_batch;
n_tokens = std::min(n_tokens, size_t(n_batch));
-if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0))) {
+if (llama_decode(ctx_exp, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return {};
}
+if (llama_decode(ctx_ama, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0))) {
+fprintf(stderr, "%s : failed to eval\n", __func__);
+return {};
+}
+
+auto * batch_logits_exp = llama_get_logits(ctx_exp);
+auto * batch_logits_ama = llama_get_logits(ctx_ama);
+
-const auto logits = llama_get_logits(ctx);
-result.insert(result.end(), logits, logits + n_tokens * n_vocab);
+const float cd_alpha = 0.1; // TODO: make CLI argument
+const float cd_beta = 0.5;
+const float mask = std::numeric_limits<float>::lowest();
+
+#if 1
+for (int t = 0; t < n_tokens; ++t) {
+auto * logits_exp = batch_logits_exp + t * n_vocab_exp;
+auto * logits_ama = batch_logits_ama + t * n_vocab_ama;
+float max_logit_exp = *std::max_element(logits_exp, logits_exp + n_vocab);
+
+for (int i = 0; i < n_vocab_exp; ++i) {
+// NB: original paper applies alpha to probabilities, further paper defines in terms of log probs
+// both have the same meaning
+if (logits_exp[i] < max_logit_exp + log(cd_alpha)) {
+// not a plausible token according to expert
+logits_exp[i] = mask;
+} else if (i >= n_vocab) {
+// token not known to amateur
+logits_exp[i] = mask;
+} else {
+logits_exp[i] = (1 + cd_beta) * logits_exp[i] - cd_beta * logits_ama[i];
+// logits_exp[i] = logits_exp[i] - logits_ama[i] / 0.75f;
+}
+}
+}
+#endif
+
+result.insert(result.end(), batch_logits_exp, batch_logits_exp + n_tokens * n_vocab_exp);

n_past += n_tokens;
}
return result;
}
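As a concrete check of the plausibility cutoff used in the loop above (all numbers invented for illustration): with `cd_alpha = 0.1`, a token is masked when its logit sits more than $-\ln(0.1) \approx 2.303$ below the per-position maximum.

```cpp
#include <cmath>
#include <cstdio>

int main() {
    const float cd_alpha      = 0.1f;
    const float max_logit_exp = 5.0f; // hypothetical per-position maximum
    const float threshold     = max_logit_exp + std::log(cd_alpha); // 5 + ln(0.1) ≈ 2.697

    const float logits[] = { 4.2f, 2.8f, 2.5f }; // hypothetical expert logits
    for (float l : logits) {
        std::printf("logit %.1f -> %s\n", l, l < threshold ? "masked" : "kept");
    }
    // Output: 4.2 and 2.8 are kept; 2.5 is masked (2.5 < 2.697).
    return 0;
}
```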

-static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
+static void hellaswag_score(llama_context * ctx_exp, llama_context * ctx_ama, const gpt_params & params) {
// Calculates hellaswag score (acc_norm) from prompt
//
// Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
@@ -329,11 +365,8 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
size_t hs_task_count = prompt_lines.size()/6;
fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);

-const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
-fprintf(stderr, "================================= is_spm = %d\n", is_spm);
-
// This is needed as usual for LLaMA models
-const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
+const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx_exp));

// Number of tasks to use when computing the score
if ( params.hellaswag_tasks < hs_task_count ) {
@@ -384,20 +417,20 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
printf("\ntask\tacc_norm\n");

double acc = 0.0f;
-const int n_vocab = llama_n_vocab(llama_get_model(ctx));
-const int n_ctx = llama_n_ctx(ctx);
+const int n_vocab = llama_n_vocab(llama_get_model(ctx_exp));
+const int n_ctx = llama_n_ctx(ctx_exp);

std::vector<std::vector<int>> ending_tokens(4);

std::vector<float> tok_logits(n_vocab);

for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
// Tokenize the context to count tokens
-std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, add_bos);
+std::vector<int> context_embd = ::llama_tokenize(ctx_exp, hs_data[task_idx].context, add_bos);
size_t context_size = context_embd.size();

for (int i = 0; i < 4; ++i) {
-ending_tokens[i] = ::llama_tokenize(ctx, hs_data[task_idx].context + " " + hs_data[task_idx].ending[i], add_bos);
+ending_tokens[i] = ::llama_tokenize(ctx_exp, hs_data[task_idx].context + " " + hs_data[task_idx].ending[i], add_bos);
for (int k = 0; k < int(context_size); ++k) {
if (ending_tokens[i][k] != context_embd[k]) {
fprintf(stderr, "Oops: ending %d of task %d differs from context at position %d\n",i,int(task_idx),k);
@@ -424,9 +457,10 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
}

// clear the KV cache
-llama_kv_cache_clear(ctx);
+llama_kv_cache_clear(ctx_exp);
+llama_kv_cache_clear(ctx_ama);

-auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
+auto logits = hellaswag_evaluate_tokens(ctx_exp, ctx_ama, query_embd, 0, params.n_batch, n_vocab);
if (logits.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
@@ -475,7 +509,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
//}

// Evaluate the query
-logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab);
+logits = hellaswag_evaluate_tokens(ctx_exp, ctx_ama, query_embd, context_size, params.n_batch, n_vocab);
if (logits.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
@@ -633,11 +667,10 @@ int main(int argc, char ** argv) {
return 1;
}

-struct results_perplexity results;
if (params.hellaswag) {
-hellaswag_score(ctx, params);
+hellaswag_score(ctx_exp, ctx_ama, params);
} else {
-results = perplexity(ctx_exp, ctx_ama, params);
+perplexity(ctx_exp, ctx_ama, params);
}

fprintf(stderr, "\namateur:\n");
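Not shown in this truncated hunk: `main()` must now create two models and two contexts before reaching this point. Below is a minimal sketch of that setup, assuming the llama.cpp C API as of late 2023 (`llama_backend_init(bool)`, `llama_load_model_from_file`, `llama_new_context_with_model`) and placeholder model paths; the commit's actual wiring through `gpt_params` may differ.

```cpp
#include "llama.h"

int main() {
    llama_backend_init(false); // numa = false; signature as of late 2023

    llama_model_params   mparams = llama_model_default_params();
    llama_context_params cparams = llama_context_default_params();

    // Placeholder paths: the expert is the full-size model, the amateur a smaller one.
    llama_model * model_exp = llama_load_model_from_file("expert.gguf",  mparams);
    llama_model * model_ama = llama_load_model_from_file("amateur.gguf", mparams);

    llama_context * ctx_exp = llama_new_context_with_model(model_exp, cparams);
    llama_context * ctx_ama = llama_new_context_with_model(model_ama, cparams);

    // ... tokenize, then call perplexity(ctx_exp, ctx_ama, params)
    //     or hellaswag_score(ctx_exp, ctx_ama, params) ...

    llama_free(ctx_exp);
    llama_free(ctx_ama);
    llama_free_model(model_exp);
    llama_free_model(model_ama);
    llama_backend_free();
    return 0;
}
```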
