From 7c517e1722e05dd042ca40a264e4c94f5ededa65 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 24 Nov 2023 16:47:21 +0200 Subject: [PATCH 1/9] lookahead : init --- examples/CMakeLists.txt | 1 + examples/lookahead/CMakeLists.txt | 5 + examples/lookahead/lookahead.cpp | 236 ++++++++++++++++++++++++++++++ 3 files changed, 242 insertions(+) create mode 100644 examples/lookahead/CMakeLists.txt create mode 100644 examples/lookahead/lookahead.cpp diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 71bcb6893e20d..6744944fd8b99 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -32,6 +32,7 @@ else() add_subdirectory(save-load-state) add_subdirectory(simple) add_subdirectory(speculative) + add_subdirectory(lookahead) add_subdirectory(train-text-from-scratch) if (LLAMA_METAL) add_subdirectory(metal) diff --git a/examples/lookahead/CMakeLists.txt b/examples/lookahead/CMakeLists.txt new file mode 100644 index 0000000000000..8827e3f11ecd6 --- /dev/null +++ b/examples/lookahead/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET lookahead) +add_executable(${TARGET} lookahead.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp new file mode 100644 index 0000000000000..23b9cd02fa062 --- /dev/null +++ b/examples/lookahead/lookahead.cpp @@ -0,0 +1,236 @@ +#include "common.h" +#include "llama.h" + +#include +#include +#include +#include + +struct seq_ngram { + bool active = false; + + std::vector tokens; +}; + +int main(int argc, char ** argv) { + gpt_params params; + + if (gpt_params_parse(argc, argv, params) == false) { + return 1; + } + + const int W = 5; // lookahead window + const int N = 4; // n-gram size + const int G = 5; // max verification n-grams + + const bool dump_kv_cache = params.dump_kv_cache; + +#ifndef LOG_DISABLE_LOGS + log_set_target(log_filename_generator("lookahead", "log")); + LOG_TEE("Log start\n"); + log_dump_cmdline(argc, argv); +#endif // LOG_DISABLE_LOGS + + // init llama.cpp + llama_backend_init(params.numa); + + llama_model * model = NULL; + llama_context * ctx = NULL; + + // load the target model + std::tie(model, ctx) = llama_init_from_gpt_params(params); + + // Tokenize the prompt + const bool add_bos = llama_should_add_bos_token(model); + LOG("add_bos tgt: %d\n", add_bos); + + std::vector inp; + std::vector all; + + inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); + all = inp; + + const int max_context_size = llama_n_ctx(ctx); + const int max_tokens_list_size = max_context_size - 4; + + if ((int) inp.size() > max_tokens_list_size) { + fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size); + return 1; + } + + fprintf(stderr, "\n\n"); + + for (auto id : inp) { + fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str()); + } + + fflush(stderr); + + const int n_input = inp.size(); + + const auto t_enc_start = ggml_time_us(); + + // eval the prompt + llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0)); + llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); + + for (int s = 0; s < W + G + 1; ++s) { + llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); + } + + const auto t_enc_end = ggml_time_us(); + + int n_predict = 0; + int n_accept = 0; + + int n_past = inp.size(); + + llama_token id = 0; + + // used to determine end of generation + bool has_eos = false; + + // seq_id == 0 : the current input token + // seq_id [1, W] : tokens from the past N - 1 Jacobi iterations + // seq_id [W + 1, W + G] : verification n-grams + llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); + + // target model sampling context + struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); + + // verification n-grams + std::vector drafts(G); + + // tokens for the past N - 1 Jacobi iterations + // TODO: how to initialize? + std::vector> tokens_j(N - 1); + for (int j = 0; j < N - 1; j++) { + tokens_j[j].resize(W); + for (int i = 0; i < W; i++) { + tokens_j[j][i] = all[1 + rand() % (all.size() - 1)]; + } + } + + std::vector seq_id_look(W + 1); + for (int i = 0; i < W + 1; i++) { + seq_id_look[i] = i; + } + + std::vector seq_id_all(W + G + 1); + for (int i = 0; i < W + G + 1; i++) { + seq_id_all[i] = i; + } + + // debug + struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1); + + const auto t_dec_start = ggml_time_us(); + + // sample first token + { + id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0); + + llama_sampling_accept(ctx_sampling, ctx, id, true); + + { + const std::string token_str = llama_token_to_piece(ctx, id); + + printf("%s", token_str.c_str()); + fflush(stdout); + } + } + + while (true) { + // debug + if (dump_kv_cache) { + llama_kv_cache_view_update(ctx, &kvc_view); + dump_kv_cache_view_seqs(kvc_view, 40); + } + + // build the mask from https://lmsys.org/blog/2023-11-21-lookahead-decoding/ + { + llama_batch_clear(batch); + + llama_batch_add(batch, id, n_past, seq_id_all, true); + for (int i = 1; i < W; i++) { + llama_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false); + } + for (int j = 1; j < N - 1; j++) { + for (int i = 0; i < W; i++) { + llama_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2); + } + } + + // TODO: add verification n-grams + } + + llama_decode(ctx, batch); + + id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0); + + llama_sampling_accept(ctx_sampling, ctx, id, true); + + { + const std::string token_str = llama_token_to_piece(ctx, id); + + printf("%s", token_str.c_str()); + fflush(stdout); + + if (id == llama_token_eos(model)) { + has_eos = true; + } + } + + ++n_predict; + ++n_past; + + if (n_predict > params.n_predict || has_eos) { + break; + } + + // update Jacobi tokens (or whatever these are called) + { + for (int j = 0; j < N - 2; j++) { + tokens_j[j] = tokens_j[j + 1]; + } + + for (int i = 0; i < W; i++) { + tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, W*(N - 2) + i); + } + } + + // verification + // TODO + { + } + + llama_kv_cache_seq_rm(ctx, -1, n_past, -1); + } + + auto t_dec_end = ggml_time_us(); + + LOG_TEE("\n\n"); + + LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); + LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); + + LOG_TEE("\n"); + LOG_TEE("n_predict = %d\n", n_predict); + LOG_TEE("n_accept = %d\n", n_accept); + + llama_print_timings(ctx); + + llama_kv_cache_view_free(&kvc_view); + llama_sampling_free(ctx_sampling); + + llama_batch_free(batch); + + llama_free(ctx); + llama_free_model(model); + + llama_backend_free(); + + fprintf(stderr, "\n\n"); + + return 0; +} From eb03b9ad6991d6b029eb7a6e738816d00caaaca5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 25 Nov 2023 13:54:07 +0200 Subject: [PATCH 2/9] lookahead : generate and store n-grams --- examples/lookahead/lookahead.cpp | 80 +++++++++++++++++++++++++++++++- 1 file changed, 78 insertions(+), 2 deletions(-) diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 23b9cd02fa062..e5fa37b811a29 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -12,6 +12,21 @@ struct seq_ngram { std::vector tokens; }; +struct ngram_container { + ngram_container(int n_vocab, int N, int G) { + cnt.resize(n_vocab); + head.resize(n_vocab); + tokens.resize(n_vocab * (N - 1)*G); + } + + int n_total = 0; + + std::vector cnt; + std::vector head; + + std::vector tokens; +}; + int main(int argc, char ** argv) { gpt_params params; @@ -99,10 +114,10 @@ int main(int argc, char ** argv) { struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); // verification n-grams - std::vector drafts(G); + std::vector ngrams(G); // tokens for the past N - 1 Jacobi iterations - // TODO: how to initialize? + std::vector tokens_j_prev(W); std::vector> tokens_j(N - 1); for (int j = 0; j < N - 1; j++) { tokens_j[j].resize(W); @@ -121,6 +136,8 @@ int main(int argc, char ** argv) { seq_id_all[i] = i; } + ngram_container ngrams_observed(llama_n_vocab(model), N, G); + // debug struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1); @@ -188,8 +205,33 @@ int main(int argc, char ** argv) { break; } + // print known n-grams starting with token id + if (1) { + if (ngrams_observed.cnt[id] > 0) { + printf("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], llama_token_to_piece(ctx, id).c_str()); + } + + for (int i = 0; i < ngrams_observed.cnt[id]; i++) { + printf(" - ngram %2d: ", i); + + const int idx = id*(N - 1)*G + i*(N - 1); + + for (int j = 0; j < N - 1; j++) { + const std::string token_str = llama_token_to_piece(ctx, ngrams_observed.tokens[idx + j]); + + printf("%s", token_str.c_str()); + } + + printf("\n"); + } + } + // update Jacobi tokens (or whatever these are called) { + for (int i = 0; i < W; i++) { + tokens_j_prev[i] = tokens_j[0][i]; + } + for (int j = 0; j < N - 2; j++) { tokens_j[j] = tokens_j[j + 1]; } @@ -199,6 +241,40 @@ int main(int argc, char ** argv) { } } + // update observed ngrams + { + // the first token of the n-gram is determined by the index in the container so it is not stored + std::vector ngram(N - 1); + + // n-gram generation + for (int f = 0; f < W; ++f) { + std::function rec = [&](int j) { + if (j == N - 1) { + const int ft = tokens_j_prev[f]; // first token of the n-gram + const int head = ngrams_observed.head[ft]; + const int idx = ft*(N - 1)*G + head*(N - 1); + + for (int i = 0; i < N - 1; i++) { + ngrams_observed.tokens[idx + i] = ngram[i]; + } + + ngrams_observed.cnt[ft] = std::min(G, ngrams_observed.cnt[ft] + 1); + ngrams_observed.head[ft] = (head + 1) % G; + + ngrams_observed.n_total++; + + return; + } + + ngram[j] = tokens_j[j][f]; + + rec(j + 1); + }; + + rec(0); + } + } + // verification // TODO { From 1b2e0bc3e6c48b9cda28e595ee6237dbd814efb7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 25 Nov 2023 13:58:41 +0200 Subject: [PATCH 3/9] lookahead : use loop instead recursion to generate n-grams --- examples/lookahead/lookahead.cpp | 34 +++++++++++++------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index e5fa37b811a29..c45184b146f57 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -169,9 +169,11 @@ int main(int argc, char ** argv) { llama_batch_clear(batch); llama_batch_add(batch, id, n_past, seq_id_all, true); + for (int i = 1; i < W; i++) { llama_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false); } + for (int j = 1; j < N - 1; j++) { for (int i = 0; i < W; i++) { llama_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2); @@ -248,30 +250,22 @@ int main(int argc, char ** argv) { // n-gram generation for (int f = 0; f < W; ++f) { - std::function rec = [&](int j) { - if (j == N - 1) { - const int ft = tokens_j_prev[f]; // first token of the n-gram - const int head = ngrams_observed.head[ft]; - const int idx = ft*(N - 1)*G + head*(N - 1); - - for (int i = 0; i < N - 1; i++) { - ngrams_observed.tokens[idx + i] = ngram[i]; - } - - ngrams_observed.cnt[ft] = std::min(G, ngrams_observed.cnt[ft] + 1); - ngrams_observed.head[ft] = (head + 1) % G; - - ngrams_observed.n_total++; + for (int j = 0; j < N - 1; ++j) { + ngram[j] = tokens_j[j][f]; + }; - return; - } + const int ft = tokens_j_prev[f]; // first token of the n-gram + const int head = ngrams_observed.head[ft]; + const int idx = ft*(N - 1)*G + head*(N - 1); - ngram[j] = tokens_j[j][f]; + for (int i = 0; i < N - 1; i++) { + ngrams_observed.tokens[idx + i] = ngram[i]; + } - rec(j + 1); - }; + ngrams_observed.cnt[ft] = std::min(G, ngrams_observed.cnt[ft] + 1); + ngrams_observed.head[ft] = (head + 1) % G; - rec(0); + ngrams_observed.n_total++; } } From 61d039727a8460e369f41efb30a3bd9243555ff6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 25 Nov 2023 16:25:38 +0200 Subject: [PATCH 4/9] lookahead : initial working implementation --- examples/lookahead/lookahead.cpp | 236 ++++++++++++++++++++++--------- 1 file changed, 166 insertions(+), 70 deletions(-) diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index c45184b146f57..33af03a3e6d74 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -7,7 +7,11 @@ #include struct seq_ngram { - bool active = false; + bool active = false; + + llama_seq_id seq_id = -1; + + std::vector i_batch; std::vector tokens; }; @@ -34,9 +38,9 @@ int main(int argc, char ** argv) { return 1; } - const int W = 5; // lookahead window - const int N = 4; // n-gram size - const int G = 5; // max verification n-grams + const int W = 10; // lookahead window + const int N = 8; // n-gram size + const int G = 10; // max verification n-grams const bool dump_kv_cache = params.dump_kv_cache; @@ -89,7 +93,7 @@ int main(int argc, char ** argv) { llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0)); llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); - for (int s = 0; s < W + G + 1; ++s) { + for (int s = 1; s < W + G + 1; ++s) { llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); } @@ -114,15 +118,18 @@ int main(int argc, char ** argv) { struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); // verification n-grams - std::vector ngrams(G); + std::vector ngrams_cur(G); // tokens for the past N - 1 Jacobi iterations std::vector tokens_j_prev(W); std::vector> tokens_j(N - 1); for (int j = 0; j < N - 1; j++) { tokens_j[j].resize(W); + for (int i = 0; i < W; i++) { - tokens_j[j][i] = all[1 + rand() % (all.size() - 1)]; + // initialize randomly from the prompt tokens + //tokens_j[j][i] = all[1 + rand() % (all.size() - 1)]; + tokens_j[j][i] = 100 + i; } } @@ -168,113 +175,202 @@ int main(int argc, char ** argv) { { llama_batch_clear(batch); + // current token - first token of the first level llama_batch_add(batch, id, n_past, seq_id_all, true); + // verification n-grams - queue this here for less KV cache fragmentation + { + const int g_cur = ngrams_observed.cnt[id]; + + ngrams_cur.resize(g_cur); + for (int g = 0; g < g_cur; g++) { + ngrams_cur[g].active = true; + ngrams_cur[g].tokens.resize(N); + ngrams_cur[g].i_batch.resize(N); + ngrams_cur[g].seq_id = W + 1 + g; + ngrams_cur[g].i_batch[0] = 0; + ngrams_cur[g].tokens [0] = id; + } + + for (int j = 0; j < N - 1; j++) { + for (int g = 0; g < g_cur; g++) { + const int idx = id*(N - 1)*G + g*(N - 1); + + const llama_token t = ngrams_observed.tokens[idx + j]; + + ngrams_cur[g].tokens [j + 1] = t; + ngrams_cur[g].i_batch[j + 1] = batch.n_tokens; + + llama_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true); + } + } + } + + // fill the remaining W - 1 tokens for the first level for (int i = 1; i < W; i++) { llama_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false); } + // fill the rest of the levels for (int j = 1; j < N - 1; j++) { for (int i = 0; i < W; i++) { llama_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2); } } + } - // TODO: add verification n-grams + if (llama_decode(ctx, batch) != 0) { + fprintf(stderr, "\n\n%s: error: llama_decode failed - increase KV cache size\n", __func__); + return 1; } - llama_decode(ctx, batch); + int seq_id_best = 0; - id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0); + for (int v = 0; v < N; ++v) { + int i_batch = 0; - llama_sampling_accept(ctx_sampling, ctx, id, true); + if (v > 0) { + for (int g = 0; g < (int) ngrams_cur.size(); g++) { + if (ngrams_cur[g].active) { + i_batch = ngrams_cur[g].i_batch[v]; + seq_id_best = ngrams_cur[g].seq_id; + break; + } + } - { - const std::string token_str = llama_token_to_piece(ctx, id); + // no more matches + if (i_batch == 0) { + break; + } + } - printf("%s", token_str.c_str()); - fflush(stdout); + id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch); - if (id == llama_token_eos(model)) { - has_eos = true; - } - } + llama_sampling_accept(ctx_sampling, ctx, id, true); - ++n_predict; - ++n_past; + { + const std::string token_str = llama_token_to_piece(ctx, id); - if (n_predict > params.n_predict || has_eos) { - break; - } + if (v == 0) { + printf("%s", token_str.c_str()); + } else { + // print light cyan + printf("\033[0;96m%s\033[0m", token_str.c_str()); + } + fflush(stdout); + + if (id == llama_token_eos(model)) { + has_eos = true; + } - // print known n-grams starting with token id - if (1) { - if (ngrams_observed.cnt[id] > 0) { - printf("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], llama_token_to_piece(ctx, id).c_str()); + all.push_back(id); } - for (int i = 0; i < ngrams_observed.cnt[id]; i++) { - printf(" - ngram %2d: ", i); + ++n_predict; + ++n_past; - const int idx = id*(N - 1)*G + i*(N - 1); + if (n_predict > params.n_predict || has_eos) { + break; + } - for (int j = 0; j < N - 1; j++) { - const std::string token_str = llama_token_to_piece(ctx, ngrams_observed.tokens[idx + j]); + // verify across active n-grams + for (int g = 0; g < (int) ngrams_cur.size(); g++) { + if (ngrams_cur[g].active) { + if (v == N - 1) { + ngrams_cur[g].active = false; + } else { + if (id != ngrams_cur[g].tokens[v + 1]) { + ngrams_cur[g].active = false; + } else { + } + } + } + } - printf("%s", token_str.c_str()); + // print known n-grams starting with token id + if (0) { + if (ngrams_observed.cnt[id] > 0) { + printf("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], llama_token_to_piece(ctx, id).c_str()); } - printf("\n"); - } - } + for (int i = 0; i < ngrams_observed.cnt[id]; i++) { + printf(" - ngram %2d: ", i); - // update Jacobi tokens (or whatever these are called) - { - for (int i = 0; i < W; i++) { - tokens_j_prev[i] = tokens_j[0][i]; - } + const int idx = id*(N - 1)*G + i*(N - 1); + + for (int j = 0; j < N - 1; j++) { + const std::string token_str = llama_token_to_piece(ctx, ngrams_observed.tokens[idx + j]); - for (int j = 0; j < N - 2; j++) { - tokens_j[j] = tokens_j[j + 1]; + printf("%s", token_str.c_str()); + } + + printf("\n"); + } } - for (int i = 0; i < W; i++) { - tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, W*(N - 2) + i); + // update Jacobi tokens (or whatever these are called) + { + for (int i = 0; i < W; i++) { + tokens_j_prev[i] = tokens_j[0][i]; + } + + for (int j = 0; j < N - 2; j++) { + tokens_j[j] = tokens_j[j + 1]; + } + + if (v == 0) { + // sample from the last level + for (int i = 0; i < W; i++) { + tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i); + } + } else { + for (int i = 0; i < W; i++) { + // random init + //tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)]; + tokens_j[N - 2][i] = tokens_j[0][i]; + } + } } - } - // update observed ngrams - { - // the first token of the n-gram is determined by the index in the container so it is not stored - std::vector ngram(N - 1); + // update observed ngrams + { + // the first token of the n-gram is determined by the index in the container so it is not stored + std::vector ngram(N - 1); - // n-gram generation - for (int f = 0; f < W; ++f) { - for (int j = 0; j < N - 1; ++j) { - ngram[j] = tokens_j[j][f]; - }; + // n-gram generation + // ref: https://github.com/hao-ai-lab/LookaheadDecoding/issues/14#issuecomment-1826198518 + for (int f = 0; f < W; ++f) { + for (int j = 0; j < N - 1; ++j) { + ngram[j] = tokens_j[j][f]; + }; - const int ft = tokens_j_prev[f]; // first token of the n-gram - const int head = ngrams_observed.head[ft]; - const int idx = ft*(N - 1)*G + head*(N - 1); + const int ft = tokens_j_prev[f]; // first token of the n-gram + const int head = ngrams_observed.head[ft]; + const int idx = ft*(N - 1)*G + head*(N - 1); - for (int i = 0; i < N - 1; i++) { - ngrams_observed.tokens[idx + i] = ngram[i]; - } + for (int i = 0; i < N - 1; i++) { + ngrams_observed.tokens[idx + i] = ngram[i]; + } - ngrams_observed.cnt[ft] = std::min(G, ngrams_observed.cnt[ft] + 1); - ngrams_observed.head[ft] = (head + 1) % G; + ngrams_observed.cnt[ft] = std::min(G, ngrams_observed.cnt[ft] + 1); + ngrams_observed.head[ft] = (head + 1) % G; - ngrams_observed.n_total++; + ngrams_observed.n_total++; + } } } - // verification - // TODO - { - } - llama_kv_cache_seq_rm(ctx, -1, n_past, -1); + + if (seq_id_best != 0) { + llama_kv_cache_seq_keep(ctx, seq_id_best); + llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1); + llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1); + + for (int s = 1; s < W + G + 1; ++s) { + llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); + } + } } auto t_dec_end = ggml_time_us(); From 6eb5166e5ac132328fbebe1152c60a75875755ae Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 25 Nov 2023 17:02:56 +0200 Subject: [PATCH 5/9] lookahead : filter repeating n-grams --- examples/lookahead/lookahead.cpp | 56 +++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 12 deletions(-) diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 33af03a3e6d74..6f841fff4d6b1 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -38,9 +38,9 @@ int main(int argc, char ** argv) { return 1; } - const int W = 10; // lookahead window - const int N = 8; // n-gram size - const int G = 10; // max verification n-grams + const int W = 15; // lookahead window + const int N = 5; // n-gram size + const int G = 15; // max verification n-grams const bool dump_kv_cache = params.dump_kv_cache; @@ -128,8 +128,8 @@ int main(int argc, char ** argv) { for (int i = 0; i < W; i++) { // initialize randomly from the prompt tokens - //tokens_j[j][i] = all[1 + rand() % (all.size() - 1)]; - tokens_j[j][i] = 100 + i; + tokens_j[j][i] = all[1 + rand() % (all.size() - 1)]; + //tokens_j[j][i] = 100 + i; } } @@ -234,6 +234,8 @@ int main(int argc, char ** argv) { if (ngrams_cur[g].active) { i_batch = ngrams_cur[g].i_batch[v]; seq_id_best = ngrams_cur[g].seq_id; + + ++n_accept; break; } } @@ -281,14 +283,13 @@ int main(int argc, char ** argv) { } else { if (id != ngrams_cur[g].tokens[v + 1]) { ngrams_cur[g].active = false; - } else { } } } } - // print known n-grams starting with token id - if (0) { + // print known n-grams starting with token id (debug) + if (0 && v == 0) { if (ngrams_observed.cnt[id] > 0) { printf("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], llama_token_to_piece(ctx, id).c_str()); } @@ -326,8 +327,8 @@ int main(int argc, char ** argv) { } else { for (int i = 0; i < W; i++) { // random init - //tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)]; - tokens_j[N - 2][i] = tokens_j[0][i]; + tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)]; + //tokens_j[N - 2][i] = tokens_j[0][i]; } } } @@ -340,11 +341,38 @@ int main(int argc, char ** argv) { // n-gram generation // ref: https://github.com/hao-ai-lab/LookaheadDecoding/issues/14#issuecomment-1826198518 for (int f = 0; f < W; ++f) { + const int ft = tokens_j_prev[f]; // first token of the n-gram + for (int j = 0; j < N - 1; ++j) { ngram[j] = tokens_j[j][f]; - }; + } + + // filter-out repeating n-grams + { + bool is_unique = true; + + for (int k = 0; k < ngrams_observed.cnt[ft]; ++k) { + const int idx = ft*(N - 1)*G + k*(N - 1); + + bool is_match = true; + for (int j = 0; j < N - 1; ++j) { + if (ngrams_observed.tokens[idx + j] != ngram[j]) { + is_match = false; + break; + } + } + + if (is_match) { + is_unique = false; + break; + } + } + + if (!is_unique) { + continue; + } + } - const int ft = tokens_j_prev[f]; // first token of the n-gram const int head = ngrams_observed.head[ft]; const int idx = ft*(N - 1)*G + head*(N - 1); @@ -360,6 +388,10 @@ int main(int argc, char ** argv) { } } + if (n_predict > params.n_predict || has_eos) { + break; + } + llama_kv_cache_seq_rm(ctx, -1, n_past, -1); if (seq_id_best != 0) { From 7bd1cd7ef48db34146491c19aa8570f994d59671 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 25 Nov 2023 17:12:16 +0200 Subject: [PATCH 6/9] lookahead : use deterministic init --- examples/lookahead/lookahead.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 6f841fff4d6b1..9a39f0b5ec692 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -129,7 +129,9 @@ int main(int argc, char ** argv) { for (int i = 0; i < W; i++) { // initialize randomly from the prompt tokens tokens_j[j][i] = all[1 + rand() % (all.size() - 1)]; - //tokens_j[j][i] = 100 + i; + + // initialize with a sequence of increasing numbers + tokens_j[j][i] = 100 + i; } } @@ -327,14 +329,16 @@ int main(int argc, char ** argv) { } else { for (int i = 0; i < W; i++) { // random init - tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)]; - //tokens_j[N - 2][i] = tokens_j[0][i]; + //tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)]; + + // init from the previous level + tokens_j[N - 2][i] = tokens_j[0][i]; } } } // update observed ngrams - { + if (v == 0) { // the first token of the n-gram is determined by the index in the container so it is not stored std::vector ngram(N - 1); From 7d50de2de146256cb434cc3c9e7e5aa582e53c8c Mon Sep 17 00:00:00 2001 From: slaren Date: Sun, 26 Nov 2023 08:33:11 +0100 Subject: [PATCH 7/9] lookahead : add to Makefile --- .gitignore | 1 + Makefile | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 41259a12f50cb..3806e05ddcc12 100644 --- a/.gitignore +++ b/.gitignore @@ -47,6 +47,7 @@ models-mnt /libllama.so /llama-bench /llava-cli +/lookahead /main /metal /perplexity diff --git a/Makefile b/Makefile index a6d2c2ec0f380..95d85236f8f24 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ BUILD_TARGETS = \ main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \ simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \ - speculative infill tokenize benchmark-matmult parallel finetune export-lora tests/test-c.o + speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead tests/test-c.o # Binaries only useful for tests TEST_TARGETS = \ @@ -657,6 +657,9 @@ speculative: examples/speculative/speculative.cpp ggml.o llama.o $(COMMON_DEPS) parallel: examples/parallel/parallel.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) +lookahead: examples/lookahead/lookahead.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) + ifdef LLAMA_METAL metal: examples/metal/metal.cpp ggml.o $(OBJS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) From 1a07a33939c05ad6252f75ab8a371d96794247cd Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 26 Nov 2023 11:26:43 +0200 Subject: [PATCH 8/9] lookahead : fix a bug in the seq_id of the lookahead tokens --- examples/lookahead/lookahead.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 9a39f0b5ec692..ff17f06da146b 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -135,10 +135,7 @@ int main(int argc, char ** argv) { } } - std::vector seq_id_look(W + 1); - for (int i = 0; i < W + 1; i++) { - seq_id_look[i] = i; - } + std::vector seq_id_look; std::vector seq_id_all(W + G + 1); for (int i = 0; i < W + G + 1; i++) { @@ -210,6 +207,11 @@ int main(int argc, char ** argv) { // fill the remaining W - 1 tokens for the first level for (int i = 1; i < W; i++) { + seq_id_look.resize(W - i); + for (int j = 0; j < W - i; j++) { + seq_id_look[j] = i + j + 1; + } + llama_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false); } From 8d8b76d469763fca498d55e04c8b10a18a545c3b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 26 Nov 2023 11:26:55 +0200 Subject: [PATCH 9/9] lookahead : add comments --- examples/lookahead/lookahead.cpp | 79 +++++++++++++++++++++++++------- 1 file changed, 63 insertions(+), 16 deletions(-) diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index ff17f06da146b..4c49a85ebcde7 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -6,7 +6,7 @@ #include #include -struct seq_ngram { +struct ngram_data { bool active = false; llama_seq_id seq_id = -1; @@ -16,11 +16,12 @@ struct seq_ngram { std::vector tokens; }; +// n-gram container struct ngram_container { ngram_container(int n_vocab, int N, int G) { cnt.resize(n_vocab); head.resize(n_vocab); - tokens.resize(n_vocab * (N - 1)*G); + tokens.resize(n_vocab * G * (N - 1)); } int n_total = 0; @@ -28,6 +29,8 @@ struct ngram_container { std::vector cnt; std::vector head; + // [n_vocab][G][N - 1] + // for each token of the vocab, keep a ring-buffer of capacity G of n-grams of size N - 1 std::vector tokens; }; @@ -109,6 +112,7 @@ int main(int argc, char ** argv) { // used to determine end of generation bool has_eos = false; + // for each decoded batch, we have at most W + G + 1 distinct sequences: // seq_id == 0 : the current input token // seq_id [1, W] : tokens from the past N - 1 Jacobi iterations // seq_id [W + 1, W + G] : verification n-grams @@ -118,7 +122,7 @@ int main(int argc, char ** argv) { struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); // verification n-grams - std::vector ngrams_cur(G); + std::vector ngrams_cur(G); // tokens for the past N - 1 Jacobi iterations std::vector tokens_j_prev(W); @@ -127,21 +131,26 @@ int main(int argc, char ** argv) { tokens_j[j].resize(W); for (int i = 0; i < W; i++) { - // initialize randomly from the prompt tokens - tokens_j[j][i] = all[1 + rand() % (all.size() - 1)]; - - // initialize with a sequence of increasing numbers - tokens_j[j][i] = 100 + i; + // there are different ways to init these tokens + if (0) { + // initialize randomly from the prompt tokens + tokens_j[j][i] = all[1 + rand() % (all.size() - 1)]; + } else { + // initialize with a sequence of increasing numbers + tokens_j[j][i] = 100 + i; + } } } std::vector seq_id_look; + // the input token belongs both to all sequences std::vector seq_id_all(W + G + 1); for (int i = 0; i < W + G + 1; i++) { seq_id_all[i] = i; } + // here we keep adding new n-grams as we go ngram_container ngrams_observed(llama_n_vocab(model), N, G); // debug @@ -171,13 +180,37 @@ int main(int argc, char ** argv) { } // build the mask from https://lmsys.org/blog/2023-11-21-lookahead-decoding/ + // + // Example for W = 5, N = 4, G = 2: + // (I = input, L = lookahead, V = verification) + // + // Batch: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 + // T: -2 -2 -2 -2 -1 -1 -1 -1 -1 0 0 0 0 0 0 + // Info: I L L L L L L L L L L L L L L V V V V V V + // Pos: 0 1 2 3 4 1 2 3 4 5 2 3 4 5 6 1 2 3 1 2 3 (+ n_past) + // Logits: 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 + // --------------------------------------------------------------------- + // Seq: 0 + // 1 1 1 + // 2 2 2 2 + // 3 3 3 3 3 + // 4 4 4 4 4 4 + // 5 5 5 5 5 5 5 + // 6 6 6 6 + // 7 7 7 7 + // --------------------------------------------------------------------- + // | | | | | | | | | | | + // V V V V V | | | | | | + // j_tokens | | | | | | + // V V V V V V + // id { llama_batch_clear(batch); // current token - first token of the first level llama_batch_add(batch, id, n_past, seq_id_all, true); - // verification n-grams - queue this here for less KV cache fragmentation + // verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation { const int g_cur = ngrams_observed.cnt[id]; @@ -233,6 +266,7 @@ int main(int argc, char ** argv) { for (int v = 0; v < N; ++v) { int i_batch = 0; + // if no active ngrams are left, it means the sampled token does not pass the verification if (v > 0) { for (int g = 0; g < (int) ngrams_cur.size(); g++) { if (ngrams_cur[g].active) { @@ -244,16 +278,18 @@ int main(int argc, char ** argv) { } } - // no more matches + // no more matches -> create a new batch if (i_batch == 0) { break; } } + // sample the next token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch); llama_sampling_accept(ctx_sampling, ctx, id, true); + // print { const std::string token_str = llama_token_to_piece(ctx, id); @@ -313,7 +349,7 @@ int main(int argc, char ** argv) { } } - // update Jacobi tokens (or whatever these are called) + // update lookahead tokens { for (int i = 0; i < W; i++) { tokens_j_prev[i] = tokens_j[0][i]; @@ -330,11 +366,14 @@ int main(int argc, char ** argv) { } } else { for (int i = 0; i < W; i++) { - // random init - //tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)]; - - // init from the previous level - tokens_j[N - 2][i] = tokens_j[0][i]; + // there are different ways to init these tokens + if (0) { + // random init + tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)]; + } else { + // init from the previous level + tokens_j[N - 2][i] = tokens_j[0][i]; + } } } } @@ -398,9 +437,13 @@ int main(int argc, char ** argv) { break; } + // KV cache management + // if no verification token matched, we simply remove all cells from this batch -> no fragmentation llama_kv_cache_seq_rm(ctx, -1, n_past, -1); if (seq_id_best != 0) { + // if a verification token matched, we keep the best sequence and remove the rest + // this leads to some KV cache fragmentation llama_kv_cache_seq_keep(ctx, seq_id_best); llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1); llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1); @@ -418,6 +461,10 @@ int main(int argc, char ** argv) { LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); + LOG_TEE("\n"); + LOG_TEE("W = %2d\n", W); + LOG_TEE("N = %2d\n", N); + LOG_TEE("G = %2d\n", G); LOG_TEE("\n"); LOG_TEE("n_predict = %d\n", n_predict); LOG_TEE("n_accept = %d\n", n_accept);