From c4269e020002107181ccc2c7837e20abf6191c1d Mon Sep 17 00:00:00 2001
From: Matt Pulver
Date: Tue, 18 Jul 2023 14:33:34 -0400
Subject: [PATCH 1/8] Add llama_beam_search().

---
 common/common.h                      |   1 +
 examples/CMakeLists.txt              |   1 +
 examples/beam_search/CMakeLists.txt  |   8 +
 examples/beam_search/beam_search.cpp | 186 ++++++++++++++++++++
 examples/server/server.cpp           |  91 ++++++++--
 llama.cpp                            | 247 +++++++++++++++++++++++++++
 llama.h                              |  33 ++++
 7 files changed, 554 insertions(+), 13 deletions(-)
 create mode 100644 examples/beam_search/CMakeLists.txt
 create mode 100644 examples/beam_search/beam_search.cpp

diff --git a/common/common.h b/common/common.h
index 17d271e6750e2..ce61265f8c124 100644
--- a/common/common.h
+++ b/common/common.h
@@ -28,6 +28,7 @@ struct gpt_params {
     int32_t main_gpu                        = 0;    // the GPU that is used for scratch and small tensors
     float   tensor_split[LLAMA_MAX_DEVICES] = {0};  // how split tensors should be distributed across GPUs
     int32_t n_probs                         = 0;    // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t n_beams                         = 0;    // if non-zero then use beam search of given width.
     float   rope_freq_base                  = 10000.0f; // RoPE base frequency
     float   rope_freq_scale                 = 1.0f;     // RoPE frequency scaling factor
 
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index d2176c910c299..94b7852248748 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -25,6 +25,7 @@ else()
     add_subdirectory(simple)
     add_subdirectory(embd-input)
    add_subdirectory(llama-bench)
+    add_subdirectory(beam_search)
     if (LLAMA_METAL)
         add_subdirectory(metal)
     endif()
diff --git a/examples/beam_search/CMakeLists.txt b/examples/beam_search/CMakeLists.txt
new file mode 100644
index 0000000000000..b29e01092feb5
--- /dev/null
+++ b/examples/beam_search/CMakeLists.txt
@@ -0,0 +1,8 @@
+set(TARGET beam_search)
+add_executable(${TARGET} beam_search.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+if(TARGET BUILD_INFO)
+  add_dependencies(${TARGET} BUILD_INFO)
+endif()
diff --git a/examples/beam_search/beam_search.cpp b/examples/beam_search/beam_search.cpp
new file mode 100644
index 0000000000000..2bc0a378b77aa
--- /dev/null
+++ b/examples/beam_search/beam_search.cpp
@@ -0,0 +1,186 @@
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+
+#include "common.h"
+#include "llama.h"
+#include "build-info.h"
+
+#include <cassert>
+#include <cinttypes>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <ctime>
+#include <fstream>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+#include <signal.h>
+#include <unistd.h>
+#elif defined (_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#define NOMINMAX
+#include <windows.h>
+#include <signal.h>
+#endif
+
+// Used for debugging to print out beam tokens.
+struct ostream_beam_view {
+    llama_context* ctx;
+    llama_beam_view beam_view;
+};
+std::ostream& operator<<(std::ostream& os, ostream_beam_view const& obv) {
+    os << "p(" << obv.beam_view.p << ") eos(" << std::boolalpha << obv.beam_view.eos << ") tokens(";
+    for (size_t i=0 ; i<obv.beam_view.n_tokens ; ++i) {
+        os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
+    }
+    os << ')';
+    return os;
+}
+
+struct beam_search_callback_data {
+    llama_context* ctx;
+    std::vector<llama_token> response;
+};
+
+bool is_at_eos(beam_search_callback_data const& callback_data, llama_token const* tokens, size_t const n_tokens) {
+    return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx);
+}
+
+// Function matching type llama_beam_search_callback_fn_t.
+// Custom callback example is called each time the beam lengths increase:
+// * Show progress by printing ',' followed by number of convergent beam tokens if any.
+// * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
+//   This is also called when the stop condition is met.
+//   Collect tokens into std::vector<llama_token> response which is pointed to by callback_data.
+void beam_search_callback(void* callback_data_ptr, llama_beams_state beams_state) {
+    auto& callback_data = *static_cast<beam_search_callback_data*>(callback_data_ptr);
+    // Mark beams as EOS as needed.
+    for (size_t i=0 ; i<beams_state.n_beams ; ++i) {
+        llama_beam_view& beam_view = beams_state.beam_views[i];
+        if (!beam_view.eos && is_at_eos(callback_data, beam_view.tokens, beam_view.n_tokens)) {
+            beam_view.eos = true;
+        }
+    }
+    printf(",");  // Show progress
+    if (size_t const n = beams_state.common_prefix_length) {
+        callback_data.response.resize(callback_data.response.size() + n);
+        assert(0u < beams_state.n_beams);
+        llama_token const* tokens = beams_state.beam_views[0].tokens;
+        std::copy(tokens, tokens + n, callback_data.response.end() - n);
+        printf("%lu", n);
+    }
+    fflush(stdout);
+#if 1 // DEBUG: print current beams for this iteration
+    std::cout << "\n\nCurrent beams:\n";
+    for (size_t i=0 ; i < beams_state.n_beams ; ++i) {
+        std::cout << "beams["<<i<<"]: " << ostream_beam_view{callback_data.ctx,beams_state.beam_views[i]} << std::endl;
+    }
+#endif
+}
+
+int main(int argc, char ** argv)
+{
+    gpt_params params;
+
+    //---------------------------------
+    // Print help :
+    //---------------------------------
+
+    if ( argc < 2 )
+    {
+        printf( "usage: %s MODEL_PATH [BEAM_WIDTH=2] [PROMPT]\n" , argv[0] );
+        return 1 ;
+    }
+
+    //---------------------------------
+    // Load parameters :
+    //---------------------------------
+
+    params.model = argv[1];
+
+    params.n_beams = 2 < argc ? std::stoi(argv[2]) : 2;
+
+    if ( argc > 3 )
+    {
+        params.prompt = argv[3];
+    }
+
+    if ( params.prompt.empty() )
+    {
+        params.prompt = "### Request:\nHow many countries are there?\n\n### Response:\n";
+    }
+
+    //---------------------------------
+    // Init LLM :
+    //---------------------------------
+
+    llama_backend_init(params.numa);
+
+    llama_model * model;
+    llama_context * ctx;
+
+    std::tie(model, ctx) = llama_init_from_gpt_params( params );
+
+    if ( model == NULL )
+    {
+        fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
+        return 1;
+    }
+
+    //---------------------------------
+    // Tokenize the prompt :
+    //---------------------------------
+
+    std::vector<llama_token> tokens_list = llama_tokenize(ctx, params.prompt, true);
+
+    const size_t max_context_size     = llama_n_ctx( ctx );
+    const size_t max_tokens_list_size = max_context_size - 4 ;
+
+    if (tokens_list.size() > max_tokens_list_size)
+    {
+        fprintf( stderr , "%s: error: prompt too long (%lu tokens, max %lu)\n" ,
+             __func__ , tokens_list.size() , max_tokens_list_size );
+        return 1;
+    }
+
+    fprintf( stderr, "\n\n" );
+
+    // Print the tokens from the prompt :
+
+    for( auto id : tokens_list )
+    {
+        std::cout << llama_token_to_str(ctx, id);
+    }
+    std::cout << std::flush;
+
+    int n_past = llama_get_kv_cache_token_count(ctx);
+    if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
+    {
+        fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
+        return 1;
+    }
+    n_past += tokens_list.size();
+
+    beam_search_callback_data callback_data{ctx, {}};
+    size_t const beam_width = static_cast<size_t>(params.n_beams);
+    int const n_predict = 256;
+    llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict, params.n_threads);
+
+    std::cout << "\n\n";
+    for (llama_token const token_id : callback_data.response) {
+        std::cout << llama_token_to_str(ctx,token_id);
+    }
+    std::cout << std::endl;
+
+    llama_free( ctx );
+    llama_free_model( model );
+
+    llama_backend_free();
+
+    return 0;
+}
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 025b385cc8b1e..7985392fef172 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1209,6 +1209,63 @@ static void log_server_request(const Request &req, const Response &res)
     });
 }
 
+bool is_at_eos(llama_server_context& server_context, llama_token const* tokens, size_t const n_tokens) {
+    return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
+}
+
+// Function matching type llama_beam_search_callback_fn_t.
+// Custom callback example is called each time the beam lengths increase:
+// * Show progress by printing ',' followed by number of convergent beam tokens if any.
+// * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
+//   This is also called when the stop condition is met.
+//   Collect tokens into std::vector<completion_token_output> response which is pointed to by callback_data.
+void beam_search_callback(void* callback_data, llama_beams_state beams_state) {
+    auto& llama = *static_cast<llama_server_context*>(callback_data);
+    // Mark beams as EOS as needed.
+    for (size_t i=0 ; i<beams_state.n_beams ; ++i) {
+        llama_beam_view& beam_view = beams_state.beam_views[i];
+        if (!beam_view.eos && is_at_eos(llama, beam_view.tokens, beam_view.n_tokens)) {
+            beam_view.eos = true;
+        }
+    }
+    printf(",");  // Show progress
+    if (size_t const n = beams_state.common_prefix_length) {
+        llama.generated_token_probs.resize(llama.generated_token_probs.size() + n);
+        assert(0u < beams_state.n_beams);
+        llama_token const* tokens = beams_state.beam_views[0].tokens;
+        auto const map = [](llama_token tok) { return completion_token_output{{},tok}; };
+        std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map);
+        printf("%lu", n);
+    }
+    fflush(stdout);
+#if 0 // DEBUG: print current beams for this iteration
+    std::cout << "\n\nCurrent beams:\n";
+    for (size_t i=0 ; i < beams_state.n_beams ; ++i) {
+        std::cout << "beams["<<i<<"]: " << ostream_beam_view{llama.ctx,beams_state.beam_views[i]} << std::endl;
+    }
+#endif
+}
+
+struct token_translator {
+    llama_context* ctx;
+    std::string operator()(llama_token tok) const { return llama_token_to_str(ctx, tok); }
+    std::string operator()(completion_token_output cto) const { return (*this)(cto.tok); }
+};
+
+void append_to_generated_text_from_generated_token_probs(llama_server_context& llama) {
+    auto& gtps = llama.generated_token_probs;
+    auto translator = token_translator{llama.ctx};
+    auto add_strlen = [=](size_t sum, completion_token_output const& cto) { return sum + translator(cto).size(); };
+    size_t const len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
+    if (llama.generated_text.capacity() < llama.generated_text.size() + len) {
+        llama.generated_text.reserve(llama.generated_text.size() + len);
+    }
+    for (completion_token_output const& cto : gtps) {
+        llama.generated_text += translator(cto);
+    }
+}
+
diff --git a/llama.cpp b/llama.cpp
--- a/llama.cpp
+++ b/llama.cpp
@@ -4323,6 +4323,253 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
 }
 
+struct llama_beam {
+    std::vector<llama_token> tokens;
+    float p;  // Cumulative beam probability (renormalized relative to all beams)
+    bool eos; // Initialize end-of-sentence to false. Callback sets this to true.
+    // Sort beams by probability. In case of ties, prefer beams at eos.
+    bool operator<(llama_beam const& rhs) const {
+        return std::make_tuple(p, eos) < std::make_tuple(rhs.p, rhs.eos);
+    }
+    // Shift off first n tokens and discard them.
+    void shift_tokens(size_t const n) {
+        if (n) {
+            std::copy(tokens.begin() + n, tokens.end(), tokens.begin());
+            tokens.resize(tokens.size() - n);
+        }
+    }
+    llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eos}; }
+};
+
+// A struct for calculating logit-related info.
+struct logit_info {
+    float const* const logits;
+    int const n_vocab;
+    float const max_l;
+    float const normalizer;
+    struct sum_exp {
+        float max_l;
+        float operator()(float sum, float l) const { return sum + std::exp(l - max_l); }
+    };
+    logit_info(llama_context* ctx)
+      : logits(llama_get_logits(ctx))
+      , n_vocab(llama_n_vocab(ctx))
+      , max_l(*std::max_element(logits, logits + n_vocab))
+      , normalizer(1.0f / std::accumulate(logits, logits + n_vocab, 0.0f, sum_exp{max_l}))
+      { }
+    llama_token_data get_token_data(llama_token const token_id) const {
+        constexpr auto p = std::numeric_limits<float>::quiet_NaN();  // never used
+        return {token_id, logits[token_id], p};
+    }
+    // Return top k token_data by logit.
+    std::vector<llama_token_data> top_k(size_t k) {
+        std::vector<llama_token_data> min_heap;  // min-heap by logit
+        llama_token const k_min = std::min(static_cast<llama_token>(k), n_vocab);
+        min_heap.reserve(k_min);
+        for (llama_token token_id=0 ; token_id<k_min ; ++token_id) {
+            min_heap.push_back(get_token_data(token_id));
+        }
+        auto comp = [](llama_token_data const& a, llama_token_data const& b) { return a.logit > b.logit; };
+        std::make_heap(min_heap.begin(), min_heap.end(), comp);
+        for (llama_token token_id=k_min ; token_id<n_vocab ; ++token_id) {
+            if (min_heap.front().logit < logits[token_id]) {
+                std::pop_heap(min_heap.begin(), min_heap.end(), comp);
+                min_heap.back().id = token_id;
+                min_heap.back().logit = logits[token_id];
+                std::push_heap(min_heap.begin(), min_heap.end(), comp);
+            }
+        }
+        return min_heap;
+    }
+    float probability_from_logit(float logit) {
+        return normalizer * std::exp(logit - max_l);
+    }
+};
+
+struct beam_search {
+    llama_context * ctx;
+    size_t n_beams;
+    int n_past;
+    int n_predict;
+    int n_threads;
+    std::vector<llama_beam> beams;
+    std::vector<llama_beam> next_beams;
+
+    // Re-calculated on each loop iteration
+    size_t common_prefix_length;
+
+    // Used to communicate to/from callback on beams state.
+    std::vector<llama_beam_view> beam_views;
+
+    beam_search(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads)
+      : ctx(ctx)
+      , n_beams(n_beams)
+      , n_past(n_past)
+      , n_predict(n_predict)
+      , n_threads(n_threads)
+      , beam_views(n_beams) {
+        beams.reserve(n_beams);
+        next_beams.reserve(n_beams);
+    }
+
+    // Collapse beams to a single beam given by index.
+    void collapse_beams(size_t const beam_idx) {
+        if (0u < beam_idx) {
+            std::swap(beams[0], beams[beam_idx]);
+        }
+        beams.resize(1);
+    }
+
+    // Min-heaps are used to efficiently collect the top-k elements (k=n_beams).
+    // The repetitive patterns below reflect the 2 stages of heaps:
+    //  * Gather elements until the vector is full, then call std::make_heap() on it.
+    //  * If the heap is full and a new element is found that should be included, pop the
+    //    least element to the back(), replace it with the new, then push it into the heap.
+    void fill_next_beams_by_top_probabilities(llama_beam& beam) {
+        // Min-heaps use a greater-than comparator.
+        auto const comp = [](llama_beam const& a, llama_beam const& b) { return a.p > b.p; };
+        if (beam.eos) {
+            // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
+            if (next_beams.size() < n_beams) {
+                next_beams.push_back(std::move(beam));
+                if (next_beams.size() == n_beams) {
+                    std::make_heap(next_beams.begin(), next_beams.end(), comp);
+                }
+            } else if (next_beams.front().p < beam.p) {
+                std::pop_heap(next_beams.begin(), next_beams.end(), comp);
+                next_beams.back() = std::move(beam);
+                std::push_heap(next_beams.begin(), next_beams.end(), comp);
+            }
+        } else {
+            // beam is not at end-of-sentence, so branch with next top_k tokens.
+            if (!beam.tokens.empty()) {
+                llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads);
+            }
+            logit_info logit_info(ctx);
+            std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams);
+            size_t i=0;
+            if (next_beams.size() < n_beams) {
+                for (; next_beams.size() < n_beams ; ++i) {
+                    llama_beam next_beam = beam;
+                    next_beam.tokens.push_back(next_tokens[i].id);
+                    next_beam.p *= logit_info.probability_from_logit(next_tokens[i].logit);
+                    next_beams.push_back(std::move(next_beam));
+                }
+                std::make_heap(next_beams.begin(), next_beams.end(), comp);
+            } else {
+                for (; next_beams.front().p == 0.0f ; ++i) {
+                    std::pop_heap(next_beams.begin(), next_beams.end(), comp);
+                    next_beams.back() = beam;
+                    next_beams.back().tokens.push_back(next_tokens[i].id);
+                    next_beams.back().p *= logit_info.probability_from_logit(next_tokens[i].logit);
+                    std::push_heap(next_beams.begin(), next_beams.end(), comp);
+                }
+            }
+            for (; i < n_beams ; ++i) {
+                float const next_p = beam.p * logit_info.probability_from_logit(next_tokens[i].logit);
+                if (next_beams.front().p < next_p) {
+                    std::pop_heap(next_beams.begin(), next_beams.end(), comp);
+                    next_beams.back() = beam;
+                    next_beams.back().tokens.push_back(next_tokens[i].id);
+                    next_beams.back().p = next_p;
+                    std::push_heap(next_beams.begin(), next_beams.end(), comp);
+                }
+            }
+        }
+    }
+
+    // Find common_prefix_length based on beams.
+    // Requires beams is not empty.
+    size_t find_common_prefix_length() {
+        size_t common_prefix_length = beams[0].tokens.size();
+        for (size_t i=1 ; i<beams.size() ; ++i) {
+            common_prefix_length = std::min(common_prefix_length, beams[i].tokens.size());
+            for (size_t j=0 ; j<common_prefix_length ; ++j) {
+                if (beams[0].tokens[j] != beams[i].tokens[j]) {
+                    common_prefix_length = j;
+                    break;
+                }
+            }
+        }
+        return common_prefix_length;
+    }
+
+    // Construct beams_state to send back to caller via the callback function.
+    // Side effect: set common_prefix_length = find_common_prefix_length();
+    llama_beams_state get_beams_state(bool const last_call) {
+        for (size_t i=0 ; i<beams.size() ; ++i) {
+            beam_views[i] = beams[i].view();
+        }
+        common_prefix_length = find_common_prefix_length();
+        return {beam_views.data(), beams.size(), common_prefix_length, last_call};
+    }
+
+    // Loop:
+    // * while i < n_predict, AND
+    // * any of the beams have not yet reached end-of-sentence, AND
+    // * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence
+    //   (since all other beam probabilities can only decrease)
+    void loop(llama_beam_search_callback_fn_t const callback, void* const callback_data) {
+        beams.push_back({{}, 1.0f, false});  // Start with one empty beam w/ probability = 1.0 and !eos.
+        auto const not_eos = [](llama_beam const& beam) { return !beam.eos; };
+        for (int i=0 ; i<n_predict && std::any_of(beams.begin(),beams.end(),not_eos) &&
+                       !beams[top_beam_index()].eos ; ++i) {
+            callback(callback_data, get_beams_state(false));  // Sets common_prefix_length
+            update_beams_from_beam_views();  // Update values (p,eos) that callback may have changed.
+            if (common_prefix_length) {
+                llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads);
+                n_past += common_prefix_length;
+            }
+            // Zero-out next_beam probabilities to place them last in following min-heap.
+            std::for_each(next_beams.begin(), next_beams.end(), [](llama_beam& beam) { beam.p = 0.0f; });
+            for (llama_beam& beam : beams) {
+                beam.shift_tokens(common_prefix_length);
+                fill_next_beams_by_top_probabilities(beam);
+            }
+            // next_beams become the beams of next/final iteration. Swap them to re-use memory.
+            beams.swap(next_beams);
+            renormalize_beam_probabilities(beams);
+        }
+        collapse_beams(top_beam_index());
+        callback(callback_data, get_beams_state(true));
+    }
+
+    // As beams grow, the cumulative probabilities decrease.
+    // Renormalize them to avoid floating point underflow.
+    static void renormalize_beam_probabilities(std::vector<llama_beam>& beams) {
+        auto const sum_p = [](float sum, llama_beam& beam) { return sum + beam.p; };
+        float const inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p);
+        std::for_each(beams.begin(), beams.end(), [=](llama_beam& beam) { beam.p *= inv_sum; });
+    }
+
+    // Assumes beams is non-empty. Uses llama_beam::operator<() for ordering.
+    size_t top_beam_index() {
+        return std::max_element(beams.begin(), beams.end()) - beams.begin();
+    }
+
+    // Copy (p,eos) for each beam which may have been changed by the callback.
+    void update_beams_from_beam_views() {
+        for (size_t i=0 ; i<beams.size() ; ++i) {
+            beams[i].p = beam_views[i].p;
+            beams[i].eos = beam_views[i].eos;
+        }
+    }
+};
+
+void llama_beam_search(llama_context * ctx,
+                       llama_beam_search_callback_fn_t callback, void* callback_data,
+                       size_t n_beams, int n_past, int n_predict, int n_threads) {
+    assert(ctx);
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    beam_search beam_search(ctx, n_beams, n_past, n_predict, n_threads);
+
+    beam_search.loop(callback, callback_data);
+
+    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    ctx->n_sample++;
+}
+
 //
 // quantization
 //
diff --git a/llama.h b/llama.h
index 2bcf94e0f3fd2..e88a45078e382 100644
--- a/llama.h
+++ b/llama.h
@@ -465,6 +465,39 @@ extern "C" {
     /// @details Accepts the sampled token into the grammar
     LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
 
+    struct llama_beam_view {
+        llama_token const* tokens;
+        size_t n_tokens;
+        float p;  // Cumulative beam probability (renormalized relative to all beams)
+        bool eos; // Callback should set this to true when a beam is at end-of-sentence.
+    };
+
+    // Passed to beam_search_callback function.
+    // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
+    // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
+    // These pointers are valid only during the synchronous callback, so should not be saved.
+    struct llama_beams_state {
+        llama_beam_view* beam_views;
+        size_t n_beams;              // Number of elements in beam_views[].
+        size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
+        bool last_call;              // True iff this is the last callback invocation.
+    };
+    // Type of pointer to the beam_search_callback function.
+    // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
+    // passed back to beam_search_callback. This avoids having to use global variables in the callback.
+    typedef void (*llama_beam_search_callback_fn_t)(void* callback_data, llama_beams_state);
+
+    /// @details Deterministically returns entire sentence constructed by a beam search.
+    /// @param ctx Pointer to the llama_context.
+    /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
+    /// The return beam_search_control can be used to control the beam_search execution.
+    /// @param callback_data A pointer that is simply passed back to callback.
+    /// @param n_beams Number of beams to use.
+    /// @param n_past Number of tokens already evaluated.
+    /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
+    /// @param n_threads Number of threads as passed to llama_eval().
+    LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void* callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
+
     // Performance information
     LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
     LLAMA_API void llama_print_timings(struct llama_context * ctx);

From abe0829984d0a4473628b90cf4418b752470de49 Mon Sep 17 00:00:00 2001
From: Matt Pulver
Date: Fri, 25 Aug 2023 09:18:24 -0400
Subject: [PATCH 2/8] Add '// Beam search' heading to llama.{h,cpp} after
 llama_grammar_accept_token().

---
 llama.cpp | 4 ++++
 llama.h   | 5 +++++
 2 files changed, 9 insertions(+)

diff --git a/llama.cpp b/llama.cpp
index af950d3d58db2..1e4cf4055dd1d 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4326,6 +4326,10 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
 }
 
+//
+// Beam search
+//
+
 struct llama_beam {
     std::vector<llama_token> tokens;
     float p;  // Cumulative beam probability (renormalized relative to all beams)
diff --git a/llama.h b/llama.h
index e88a45078e382..81a27d438bfac 100644
--- a/llama.h
+++ b/llama.h
@@ -465,6 +465,10 @@ extern "C" {
     /// @details Accepts the sampled token into the grammar
     LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
 
+    //
+    // Beam search
+    //
+
     struct llama_beam_view {
         llama_token const* tokens;
         size_t n_tokens;
@@ -482,6 +486,7 @@ extern "C" {
         size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
         bool last_call;              // True iff this is the last callback invocation.
     };
+
     // Type of pointer to the beam_search_callback function.
     // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
     // passed back to beam_search_callback. This avoids having to use global variables in the callback.
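With patch 2 the public API is settled, so it is worth pausing on how a consumer drives it before the remaining style/rename patches. The following is a minimal sketch of a callback written against the declarations above (llama_beams_state, llama_beam_view, llama_beam_search_callback_fn_t); the my_callback_state and my_beam_search_callback names are illustrative only and not part of the patch:

    // Minimal sketch of a beam search callback: collect converged tokens.
    struct my_callback_state {
        std::vector<llama_token> response; // tokens accepted so far
    };

    static void my_beam_search_callback(void* callback_data, llama_beams_state beams_state) {
        auto& state = *static_cast<my_callback_state*>(callback_data);
        // Whenever a common prefix exists, it is about to be shifted off of every beam,
        // so copy it out of beams_state.beam_views[0] before returning.
        if (const size_t n = beams_state.common_prefix_length) {
            const llama_token* tokens = beams_state.beam_views[0].tokens;
            state.response.insert(state.response.end(), tokens, tokens + n);
        }
    }

The callback is also the place to set llama_beam_view::eos when a beam should stop growing, which is exactly what the beam_search.cpp example and the server integration in patch 1 do.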
From 9bedaf4c7191ec0644f832b37e693a2ed5e2c714 Mon Sep 17 00:00:00 2001
From: Matt Pulver
Date: Fri, 25 Aug 2023 09:22:14 -0400
Subject: [PATCH 3/8] Add space around * pointers and & references.

---
 examples/beam_search/beam_search.cpp | 12 ++++++------
 examples/server/server.cpp           | 21 ++++++++++-----------
 llama.cpp                            | 28 ++++++++++++++--------------
 llama.h                              |  8 ++++----
 4 files changed, 34 insertions(+), 35 deletions(-)

diff --git a/examples/beam_search/beam_search.cpp b/examples/beam_search/beam_search.cpp
index 2bc0a378b77aa..95a8e5179b67b 100644
--- a/examples/beam_search/beam_search.cpp
+++ b/examples/beam_search/beam_search.cpp
@@ -29,10 +29,10 @@
 // Used for debugging to print out beam tokens.
 struct ostream_beam_view {
-    llama_context* ctx;
+    llama_context * ctx;
     llama_beam_view beam_view;
 };
-std::ostream& operator<<(std::ostream& os, ostream_beam_view const& obv) {
+std::ostream& operator<<(std::ostream& os, ostream_beam_view const & obv) {
     os << "p(" << obv.beam_view.p << ") eos(" << std::boolalpha << obv.beam_view.eos << ") tokens(";
     for (size_t i=0 ; i<obv.beam_view.n_tokens ; ++i) {
         os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
     }
@@ -44,9 +44,9 @@ std::ostream& operator<<(std::ostream& os, ostream_beam_view const & obv) {
 
 struct beam_search_callback_data {
-    llama_context* ctx;
+    llama_context * ctx;
     std::vector<llama_token> response;
 };
 
-bool is_at_eos(beam_search_callback_data const& callback_data, llama_token const* tokens, size_t const n_tokens) {
+bool is_at_eos(beam_search_callback_data const & callback_data, llama_token const * tokens, size_t const n_tokens) {
     return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx);
 }
 
@@ -56,7 +56,7 @@ bool is_at_eos(beam_search_callback_data const & callback_data, llama_token const
 // * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
 //   This is also called when the stop condition is met.
 //   Collect tokens into std::vector<llama_token> response which is pointed to by callback_data.
-void beam_search_callback(void* callback_data_ptr, llama_beams_state beams_state) {
+void beam_search_callback(void * callback_data_ptr, llama_beams_state beams_state) {
     auto& callback_data = *static_cast<beam_search_callback_data*>(callback_data_ptr);
     // Mark beams as EOS as needed.
     for (size_t i=0 ; i<beams_state.n_beams ; ++i) {
@@ -73,7 +73,7 @@ void beam_search_callback(void * callback_data_ptr, llama_beams_state beams_stat
     if (size_t const n = beams_state.common_prefix_length) {
         callback_data.response.resize(callback_data.response.size() + n);
         assert(0u < beams_state.n_beams);
-        llama_token const* tokens = beams_state.beam_views[0].tokens;
+        llama_token const * tokens = beams_state.beam_views[0].tokens;
         std::copy(tokens, tokens + n, callback_data.response.end() - n);
         printf("%lu", n);
     }
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1212,7 +1212,7 @@ static void log_server_request(const Request &req, const Response &res)
 
-bool is_at_eos(llama_server_context& server_context, llama_token const* tokens, size_t const n_tokens) {
+bool is_at_eos(llama_server_context & server_context, llama_token const * tokens, size_t const n_tokens) {
     return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
 }
 
@@ -1220,8 +1220,8 @@ bool is_at_eos(llama_server_context & server_context, llama_token const * tokens
 //   This is also called when the stop condition is met.
 //   Collect tokens into std::vector<completion_token_output> response which is pointed to by callback_data.
-void beam_search_callback(void* callback_data, llama_beams_state beams_state) {
-    auto& llama = *static_cast<llama_server_context*>(callback_data);
+void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
+    auto & llama = *static_cast<llama_server_context*>(callback_data);
     // Mark beams as EOS as needed.
     for (size_t i=0 ; i<beams_state.n_beams ; ++i) {
         llama_beam_view& beam_view = beams_state.beam_views[i];
         if (!beam_view.eos && is_at_eos(llama, beam_view.tokens, beam_view.n_tokens)) {
             beam_view.eos = true;
@@ -1233,7 +1233,7 @@ void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
     if (size_t const n = beams_state.common_prefix_length) {
         llama.generated_token_probs.resize(llama.generated_token_probs.size() + n);
         assert(0u < beams_state.n_beams);
-        llama_token const* tokens = beams_state.beam_views[0].tokens;
+        llama_token const * tokens = beams_state.beam_views[0].tokens;
         auto const map = [](llama_token tok) { return completion_token_output{{},tok}; };
         std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map);
         printf("%lu", n);
@@ -1248,20 +1247,20 @@ void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
 }
 
 struct token_translator {
-    llama_context* ctx;
+    llama_context * ctx;
     std::string operator()(llama_token tok) const { return llama_token_to_str(ctx, tok); }
     std::string operator()(completion_token_output cto) const { return (*this)(cto.tok); }
 };
 
-void append_to_generated_text_from_generated_token_probs(llama_server_context& llama) {
-    auto& gtps = llama.generated_token_probs;
+void append_to_generated_text_from_generated_token_probs(llama_server_context & llama) {
+    auto & gtps = llama.generated_token_probs;
     auto translator = token_translator{llama.ctx};
-    auto add_strlen = [=](size_t sum, completion_token_output const& cto) { return sum + translator(cto).size(); };
+    auto add_strlen = [=](size_t sum, completion_token_output const & cto) { return sum + translator(cto).size(); };
     size_t const len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
     if (llama.generated_text.capacity() < llama.generated_text.size() + len) {
         llama.generated_text.reserve(llama.generated_text.size() + len);
     }
-    for (completion_token_output const& cto : gtps) {
+    for (completion_token_output const & cto : gtps) {
         llama.generated_text += translator(cto);
     }
 }
diff --git a/llama.cpp b/llama.cpp
index 1e4cf4055dd1d..f13c0aa6af536 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4335,7 +4335,7 @@ struct llama_beam {
     float p;  // Cumulative beam probability (renormalized relative to all beams)
     bool eos; // Initialize end-of-sentence to false. Callback sets this to true.
     // Sort beams by probability. In case of ties, prefer beams at eos.
-    bool operator<(llama_beam const& rhs) const {
+    bool operator<(llama_beam const & rhs) const {
         return std::make_tuple(p, eos) < std::make_tuple(rhs.p, rhs.eos);
     }
     // Shift off first n tokens and discard them.
@@ -4350,7 +4350,7 @@ struct llama_beam {
 
 // A struct for calculating logit-related info.
 struct logit_info {
-    float const* const logits;
+    float const * const logits;
     int const n_vocab;
     float const max_l;
     float const normalizer;
@@ -4358,7 +4358,7 @@ struct logit_info {
         float max_l;
         float operator()(float sum, float l) const { return sum + std::exp(l - max_l); }
     };
-    logit_info(llama_context* ctx)
+    logit_info(llama_context * ctx)
       : logits(llama_get_logits(ctx))
       , n_vocab(llama_n_vocab(ctx))
      , max_l(*std::max_element(logits, logits + n_vocab))
@@ -4376,7 +4376,7 @@ struct logit_info {
         for (llama_token token_id=0 ; token_id<k_min ; ++token_id) {
             min_heap.push_back(get_token_data(token_id));
         }
-        auto comp = [](llama_token_data const& a, llama_token_data const& b) { return a.logit > b.logit; };
+        auto comp = [](llama_token_data const & a, llama_token_data const & b) { return a.logit > b.logit; };
         std::make_heap(min_heap.begin(), min_heap.end(), comp);
         for (llama_token token_id=k_min ; token_id<n_vocab ; ++token_id) {
             if (min_heap.front().logit < logits[token_id]) {
@@ -4434,7 +4434,7 @@ struct beam_search {
     //    least element to the back(), replace it with the new, then push it into the heap.
-    void fill_next_beams_by_top_probabilities(llama_beam& beam) {
+    void fill_next_beams_by_top_probabilities(llama_beam & beam) {
         // Min-heaps use a greater-than comparator.
-        auto const comp = [](llama_beam const& a, llama_beam const& b) { return a.p > b.p; };
+        auto const comp = [](llama_beam const & a, llama_beam const & b) { return a.p > b.p; };
         if (beam.eos) {
             // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
             if (next_beams.size() < n_beams) {
@@ -4516,9 +4516,9 @@ struct beam_search {
     // * any of the beams have not yet reached end-of-sentence, AND
     // * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence
     //   (since all other beam probabilities can only decrease)
-    void loop(llama_beam_search_callback_fn_t const callback, void* const callback_data) {
+    void loop(llama_beam_search_callback_fn_t const callback, void * const callback_data) {
         beams.push_back({{}, 1.0f, false});  // Start with one empty beam w/ probability = 1.0 and !eos.
-        auto const not_eos = [](llama_beam const& beam) { return !beam.eos; };
+        auto const not_eos = [](llama_beam const & beam) { return !beam.eos; };
         for (int i=0 ; i<n_predict && std::any_of(beams.begin(),beams.end(),not_eos) &&
                        !beams[top_beam_index()].eos ; ++i) {
@@ -4544,9 +4544,9 @@ struct beam_search {
 
     // As beams grow, the cumulative probabilities decrease.
     // Renormalize them to avoid floating point underflow.
-    static void renormalize_beam_probabilities(std::vector<llama_beam>& beams) {
-        auto const sum_p = [](float sum, llama_beam& beam) { return sum + beam.p; };
+    static void renormalize_beam_probabilities(std::vector<llama_beam> & beams) {
+        auto const sum_p = [](float sum, llama_beam & beam) { return sum + beam.p; };
         float const inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p);
-        std::for_each(beams.begin(), beams.end(), [=](llama_beam& beam) { beam.p *= inv_sum; });
+        std::for_each(beams.begin(), beams.end(), [=](llama_beam & beam) { beam.p *= inv_sum; });
     }
 
     // Assumes beams is non-empty. Uses llama_beam::operator<() for ordering.
@@ -4564,7 +4564,7 @@ struct beam_search {
 };
 
 void llama_beam_search(llama_context * ctx,
-                       llama_beam_search_callback_fn_t callback, void* callback_data,
+                       llama_beam_search_callback_fn_t callback, void * callback_data,
                        size_t n_beams, int n_past, int n_predict, int n_threads) {
     assert(ctx);
     const int64_t t_start_sample_us = ggml_time_us();
diff --git a/llama.h b/llama.h
index 81a27d438bfac..c19a60a5dc7fb 100644
--- a/llama.h
+++ b/llama.h
@@ -470,7 +470,7 @@ extern "C" {
     //
 
     struct llama_beam_view {
-        llama_token const* tokens;
+        const llama_token * tokens;
         size_t n_tokens;
         float p;  // Cumulative beam probability (renormalized relative to all beams)
         bool eos; // Callback should set this to true when a beam is at end-of-sentence.
@@ -481,7 +481,7 @@ extern "C" {
     // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
     // These pointers are valid only during the synchronous callback, so should not be saved.
     struct llama_beams_state {
-        llama_beam_view* beam_views;
+        llama_beam_view * beam_views;
         size_t n_beams;              // Number of elements in beam_views[].
         size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
         bool last_call;              // True iff this is the last callback invocation.
@@ -490,7 +490,7 @@ extern "C" {
     // Type of pointer to the beam_search_callback function.
     // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
     // passed back to beam_search_callback. This avoids having to use global variables in the callback.
-    typedef void (*llama_beam_search_callback_fn_t)(void* callback_data, llama_beams_state);
+    typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, llama_beams_state);
 
     /// @details Deterministically returns entire sentence constructed by a beam search.
     /// @param ctx Pointer to the llama_context.
     /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
     /// The return beam_search_control can be used to control the beam_search execution.
     /// @param callback_data A pointer that is simply passed back to callback.
     /// @param n_beams Number of beams to use.
     /// @param n_past Number of tokens already evaluated.
     /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
     /// @param n_threads Number of threads as passed to llama_eval().
-    LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void* callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
+    LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
 
     // Performance information
     LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
     LLAMA_API void llama_print_timings(struct llama_context * ctx);

From e46a8b517fb176829fe8cf3ddc6444cab1e33ce7 Mon Sep 17 00:00:00 2001
From: Matt Pulver
Date: Fri, 25 Aug 2023 09:31:19 -0400
Subject: [PATCH 4/8] Add spaces around comparison and assignment operators.

---
 examples/beam_search/beam_search.cpp |  6 +++---
 examples/server/server.cpp           |  2 +-
 llama.cpp                            | 14 +++++++-------
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/examples/beam_search/beam_search.cpp b/examples/beam_search/beam_search.cpp
index 95a8e5179b67b..29b69b76df40c 100644
--- a/examples/beam_search/beam_search.cpp
+++ b/examples/beam_search/beam_search.cpp
@@ -34,7 +34,7 @@ struct ostream_beam_view {
 };
 std::ostream& operator<<(std::ostream& os, ostream_beam_view const & obv) {
     os << "p(" << obv.beam_view.p << ") eos(" << std::boolalpha << obv.beam_view.eos << ") tokens(";
-    for (size_t i=0 ; i<obv.beam_view.n_tokens ; ++i) {
+    for (size_t i = 0 ; i < obv.beam_view.n_tokens ; ++i) {
         os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
     }
     os << ')';
@@ -65,7 +65,7 @@ void beam_search_callback(void * callback_data_ptr, llama_beams_stat
     auto& callback_data = *static_cast<beam_search_callback_data*>(callback_data_ptr);
     // Mark beams as EOS as needed.
-    for (size_t i=0 ; i<beams_state.n_beams ; ++i) {
+    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
         llama_beam_view& beam_view = beams_state.beam_views[i];
@@ -81,7 +81,7 @@ void beam_search_callback(void * callback_data_ptr, llama_beams_stat
     fflush(stdout);
 #if 1 // DEBUG: print current beams for this iteration
     std::cout << "\n\nCurrent beams:\n";
-    for (size_t i=0 ; i < beams_state.n_beams ; ++i) {
+    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
         std::cout << "beams["<<i<<"]: " << ostream_beam_view{callback_data.ctx,beams_state.beam_views[i]} << std::endl;
     }
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1225,7 +1225,7 @@ void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
     auto & llama = *static_cast<llama_server_context*>(callback_data);
     // Mark beams as EOS as needed.
-    for (size_t i=0 ; i<beams_state.n_beams ; ++i) {
+    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
         llama_beam_view& beam_view = beams_state.beam_views[i];
diff --git a/llama.cpp b/llama.cpp
index f13c0aa6af536..434acc3278eeb 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4371,11 +4371,11 @@ struct logit_info {
     // Return top k token_data by logit.
     std::vector<llama_token_data> top_k(size_t k) {
         std::vector<llama_token_data> min_heap;  // min-heap by logit
         llama_token const k_min = std::min(static_cast<llama_token>(k), n_vocab);
         min_heap.reserve(k_min);
-        for (llama_token token_id=0 ; token_id<k_min ; ++token_id) {
+        for (llama_token token_id = 0 ; token_id < k_min ; ++token_id) {
             min_heap.push_back(get_token_data(token_id));
         }
         auto comp = [](llama_token_data const & a, llama_token_data const & b) { return a.logit > b.logit; };
         std::make_heap(min_heap.begin(), min_heap.end(), comp);
-        for (llama_token token_id=k_min ; token_id<n_vocab ; ++token_id) {
+        for (llama_token token_id = k_min ; token_id < n_vocab ; ++token_id) {
             if (min_heap.front().logit < logits[token_id]) {
@@ -4492,9 +4492,9 @@ struct beam_search {
     size_t find_common_prefix_length() {
         size_t common_prefix_length = beams[0].tokens.size();
-        for (size_t i=1 ; i<beams.size() ; ++i) {
+        for (size_t i = 1 ; i < beams.size() ; ++i) {
             common_prefix_length = std::min(common_prefix_length, beams[i].tokens.size());
-            for (size_t j=0 ; j<common_prefix_length ; ++j) {
+            for (size_t j = 0 ; j < common_prefix_length ; ++j) {
                 if (beams[0].tokens[j] != beams[i].tokens[j]) {
                     common_prefix_length = j;
                     break;
@@ -4507,7 +4507,7 @@ struct beam_search {
     llama_beams_state get_beams_state(bool const last_call) {
-        for (size_t i=0 ; i<beams.size() ; ++i) {
+        for (size_t i = 0 ; i < beams.size() ; ++i) {
             beam_views[i] = beams[i].view();
         }
@@ -4521,7 +4521,7 @@ struct beam_search {
     void loop(llama_beam_search_callback_fn_t const callback, void * const callback_data) {
         beams.push_back({{}, 1.0f, false});  // Start with one empty beam w/ probability = 1.0 and !eos.
         auto const not_eos = [](llama_beam const & beam) { return !beam.eos; };
-        for (int i=0 ; i<n_predict && std::any_of(beams.begin(),beams.end(),not_eos) &&
+        for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eos) &&
                        !beams[top_beam_index()].eos ; ++i) {
             callback(callback_data, get_beams_state(false));  // Sets common_prefix_length
@@ -4548,7 +4548,7 @@ struct beam_search {
     // Copy (p,eos) for each beam which may have been changed by the callback.
     void update_beams_from_beam_views() {
-        for (size_t i=0 ; i<beams.size() ; ++i) {
+        for (size_t i = 0 ; i < beams.size() ; ++i) {
             beams[i].p = beam_views[i].p;
             beams[i].eos = beam_views[i].eos;
         }

From: Matt Pulver
Date: Fri, 25 Aug 2023 09:34:13 -0400
Subject: [PATCH 5/8] Prefer west const.

---
 examples/beam_search/beam_search.cpp |  8 +++----
 examples/server/server.cpp           | 14 ++++++------
 llama.cpp                            | 34 ++++++++++++++--------------
 3 files changed, 28 insertions(+), 28 deletions(-)

diff --git a/examples/beam_search/beam_search.cpp b/examples/beam_search/beam_search.cpp
index 29b69b76df40c..1d0d077d152c3 100644
--- a/examples/beam_search/beam_search.cpp
+++ b/examples/beam_search/beam_search.cpp
@@ -32,7 +32,7 @@ struct ostream_beam_view {
     llama_context * ctx;
     llama_beam_view beam_view;
 };
-std::ostream& operator<<(std::ostream& os, ostream_beam_view const & obv) {
+std::ostream& operator<<(std::ostream& os, const ostream_beam_view & obv) {
     os << "p(" << obv.beam_view.p << ") eos(" << std::boolalpha << obv.beam_view.eos << ") tokens(";
     for (size_t i = 0 ; i < obv.beam_view.n_tokens ; ++i) {
         os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
@@ -46,7 +46,7 @@ struct beam_search_callback_data {
     std::vector<llama_token> response;
 };
 
-bool is_at_eos(beam_search_callback_data const & callback_data, llama_token const * tokens, size_t const n_tokens) {
+bool is_at_eos(const beam_search_callback_data & callback_data, const llama_token * tokens, const size_t n_tokens) {
     return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx);
 }
 
@@ -66,10 +66,10 @@ void beam_search_callback(void * callback_data_ptr, llama_beams_stat
     }
     printf(",");  // Show progress
-    if (size_t const n = beams_state.common_prefix_length) {
+    if (const size_t n = beams_state.common_prefix_length) {
         callback_data.response.resize(callback_data.response.size() + n);
         assert(0u < beams_state.n_beams);
-        llama_token const * tokens = beams_state.beam_views[0].tokens;
+        const llama_token * tokens = beams_state.beam_views[0].tokens;
         std::copy(tokens, tokens + n, callback_data.response.end() - n);
         printf("%lu", n);
     }
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 33d8575e0148a..94a029bbf21fd 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1209,7 +1209,7 @@ static void log_server_request(const Request &req, const Response &res)
     });
 }
 
-bool is_at_eos(llama_server_context & server_context, llama_token const * tokens, size_t const n_tokens) {
+bool is_at_eos(llama_server_context & server_context, const llama_token * tokens, const size_t n_tokens) {
     return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
 }
 
@@ -1229,11 +1229,11 @@ void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
     }
     printf(",");  // Show progress
-    if (size_t const n = beams_state.common_prefix_length) {
+    if (const size_t n = beams_state.common_prefix_length) {
         llama.generated_token_probs.resize(llama.generated_token_probs.size() + n);
         assert(0u < beams_state.n_beams);
-        llama_token const * tokens = beams_state.beam_views[0].tokens;
-        auto const map = [](llama_token tok) { return completion_token_output{{},tok}; };
+        const llama_token * tokens = beams_state.beam_views[0].tokens;
+        const auto map = [](llama_token tok) { return completion_token_output{{},tok}; };
         std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map);
         printf("%lu", n);
     }
@@ -1255,12 +1255,12 @@ struct token_translator {
 void append_to_generated_text_from_generated_token_probs(llama_server_context & llama) {
     auto & gtps = llama.generated_token_probs;
     auto translator = token_translator{llama.ctx};
-    auto add_strlen = [=](size_t sum, completion_token_output const & cto) { return sum + translator(cto).size(); };
-    size_t const len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
+    auto add_strlen = [=](size_t sum, const completion_token_output & cto) { return sum + translator(cto).size(); };
+    const size_t len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
     if (llama.generated_text.capacity() < llama.generated_text.size() + len) {
         llama.generated_text.reserve(llama.generated_text.size() + len);
     }
-    for (completion_token_output const & cto : gtps) {
+    for (const completion_token_output & cto : gtps) {
         llama.generated_text += translator(cto);
     }
 }
diff --git a/llama.cpp b/llama.cpp
index 434acc3278eeb..15b617fc9f0cc 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4335,11 +4335,11 @@ struct llama_beam {
     float p;  // Cumulative beam probability (renormalized relative to all beams)
     bool eos; // Initialize end-of-sentence to false. Callback sets this to true.
     // Sort beams by probability. In case of ties, prefer beams at eos.
-    bool operator<(llama_beam const & rhs) const {
+    bool operator<(const llama_beam & rhs) const {
         return std::make_tuple(p, eos) < std::make_tuple(rhs.p, rhs.eos);
     }
     // Shift off first n tokens and discard them.
-    void shift_tokens(size_t const n) {
+    void shift_tokens(const size_t n) {
         if (n) {
             std::copy(tokens.begin() + n, tokens.end(), tokens.begin());
             tokens.resize(tokens.size() - n);
@@ -4350,10 +4350,10 @@ struct llama_beam {
 
 // A struct for calculating logit-related info.
 struct logit_info {
-    float const * const logits;
-    int const n_vocab;
-    float const max_l;
-    float const normalizer;
+    const float * const logits;
+    const int n_vocab;
+    const float max_l;
+    const float normalizer;
     struct sum_exp {
         float max_l;
         float operator()(float sum, float l) const { return sum + std::exp(l - max_l); }
     };
@@ -4364,19 +4364,19 @@ struct logit_info {
       , max_l(*std::max_element(logits, logits + n_vocab))
       , normalizer(1.0f / std::accumulate(logits, logits + n_vocab, 0.0f, sum_exp{max_l}))
       { }
-    llama_token_data get_token_data(llama_token const token_id) const {
+    llama_token_data get_token_data(const llama_token token_id) const {
         constexpr auto p = std::numeric_limits<float>::quiet_NaN();  // never used
         return {token_id, logits[token_id], p};
     }
     // Return top k token_data by logit.
     std::vector<llama_token_data> top_k(size_t k) {
         std::vector<llama_token_data> min_heap;  // min-heap by logit
-        llama_token const k_min = std::min(static_cast<llama_token>(k), n_vocab);
+        const llama_token k_min = std::min(static_cast<llama_token>(k), n_vocab);
         min_heap.reserve(k_min);
         for (llama_token token_id = 0 ; token_id < k_min ; ++token_id) {
             min_heap.push_back(get_token_data(token_id));
         }
-        auto comp = [](llama_token_data const & a, llama_token_data const & b) { return a.logit > b.logit; };
+        auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; };
         std::make_heap(min_heap.begin(), min_heap.end(), comp);
         for (llama_token token_id = k_min ; token_id < n_vocab ; ++token_id) {
             if (min_heap.front().logit < logits[token_id]) {
@@ -4420,7 +4420,7 @@ struct beam_search {
     }
 
     // Collapse beams to a single beam given by index.
-    void collapse_beams(size_t const beam_idx) {
+    void collapse_beams(const size_t beam_idx) {
         if (0u < beam_idx) {
             std::swap(beams[0], beams[beam_idx]);
         }
@@ -4434,7 +4434,7 @@ struct beam_search {
     //    least element to the back(), replace it with the new, then push it into the heap.
     void fill_next_beams_by_top_probabilities(llama_beam & beam) {
         // Min-heaps use a greater-than comparator.
-        auto const comp = [](llama_beam const & a, llama_beam const & b) { return a.p > b.p; };
+        const auto comp = [](const llama_beam & a, const llama_beam & b) { return a.p > b.p; };
         if (beam.eos) {
             // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
             if (next_beams.size() < n_beams) {
@@ -4473,7 +4473,7 @@ struct beam_search {
             }
         }
         for (; i < n_beams ; ++i) {
-            float const next_p = beam.p * logit_info.probability_from_logit(next_tokens[i].logit);
+            const float next_p = beam.p * logit_info.probability_from_logit(next_tokens[i].logit);
             if (next_beams.front().p < next_p) {
                 std::pop_heap(next_beams.begin(), next_beams.end(), comp);
                 next_beams.back() = beam;
@@ -4503,7 +4503,7 @@ struct beam_search {
 
     // Construct beams_state to send back to caller via the callback function.
     // Side effect: set common_prefix_length = find_common_prefix_length();
-    llama_beams_state get_beams_state(bool const last_call) {
+    llama_beams_state get_beams_state(const bool last_call) {
         for (size_t i = 0 ; i < beams.size() ; ++i) {
             beam_views[i] = beams[i].view();
         }
@@ -4516,9 +4516,9 @@ struct beam_search {
     // * any of the beams have not yet reached end-of-sentence, AND
     // * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence
     //   (since all other beam probabilities can only decrease)
-    void loop(llama_beam_search_callback_fn_t const callback, void * const callback_data) {
+    void loop(const llama_beam_search_callback_fn_t callback, void * const callback_data) {
         beams.push_back({{}, 1.0f, false});  // Start with one empty beam w/ probability = 1.0 and !eos.
-        auto const not_eos = [](llama_beam const & beam) { return !beam.eos; };
+        const auto not_eos = [](const llama_beam & beam) { return !beam.eos; };
         for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eos) &&
                        !beams[top_beam_index()].eos ; ++i) {
             callback(callback_data, get_beams_state(false));  // Sets common_prefix_length
@@ -4544,8 +4544,8 @@ struct beam_search {
     // As beams grow, the cumulative probabilities decrease.
     // Renormalize them to avoid floating point underflow.
     static void renormalize_beam_probabilities(std::vector<llama_beam> & beams) {
-        auto const sum_p = [](float sum, llama_beam & beam) { return sum + beam.p; };
-        float const inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p);
+        const auto sum_p = [](float sum, llama_beam & beam) { return sum + beam.p; };
+        const float inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p);
         std::for_each(beams.begin(), beams.end(), [=](llama_beam & beam) { beam.p *= inv_sum; });
     }
 

From fa33614b4d55813e066769ba2aa9ea4c0db81f55 Mon Sep 17 00:00:00 2001
From: Matt Pulver
Date: Fri, 25 Aug 2023 09:36:34 -0400
Subject: [PATCH 6/8] Use llama_ prefix for structs in global namespace.

---
 llama.cpp | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 15b617fc9f0cc..9f23a6a9d9d05 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4349,7 +4349,7 @@ struct llama_beam {
 };
 
 // A struct for calculating logit-related info.
-struct logit_info {
+struct llama_logit_info {
     const float * const logits;
     const int n_vocab;
     const float max_l;
     const float normalizer;
@@ -4358,7 +4358,7 @@ struct logit_info {
         float max_l;
         float operator()(float sum, float l) const { return sum + std::exp(l - max_l); }
     };
-    logit_info(llama_context * ctx)
+    llama_logit_info(llama_context * ctx)
       : logits(llama_get_logits(ctx))
       , n_vocab(llama_n_vocab(ctx))
       , max_l(*std::max_element(logits, logits + n_vocab))
@@ -4393,7 +4393,7 @@ struct logit_info {
     }
 };
 
-struct beam_search {
+struct llama_beam_search_data {
     llama_context * ctx;
     size_t n_beams;
     int n_past;
@@ -4408,7 +4408,7 @@ struct beam_search {
     // Used to communicate to/from callback on beams state.
     std::vector<llama_beam_view> beam_views;
 
-    beam_search(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads)
+    llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads)
       : ctx(ctx)
       , n_beams(n_beams)
       , n_past(n_past)
@@ -4452,7 +4452,7 @@ struct beam_search {
         if (!beam.tokens.empty()) {
             llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads);
         }
-        logit_info logit_info(ctx);
+        llama_logit_info logit_info(ctx);
         std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams);
         size_t i=0;
         if (next_beams.size() < n_beams) {
@@ -4569,9 +4569,9 @@ void llama_beam_search(llama_context * ctx,
     assert(ctx);
     const int64_t t_start_sample_us = ggml_time_us();
 
-    beam_search beam_search(ctx, n_beams, n_past, n_predict, n_threads);
+    llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict, n_threads);
 
-    beam_search.loop(callback, callback_data);
+    beam_search_data.loop(callback, callback_data);
 
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
     ctx->n_sample++;

From b619cfc059f093803341930e1a57dff8f554d082 Mon Sep 17 00:00:00 2001
From: Matt Pulver
Date: Fri, 25 Aug 2023 09:38:15 -0400
Subject: [PATCH 7/8] Delete obsolete comment from an earlier revision.

---
 llama.h | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llama.h b/llama.h
index c19a60a5dc7fb..47e7a2ebe3cbc 100644
--- a/llama.h
+++ b/llama.h
@@ -495,7 +495,6 @@ extern "C" {
     /// @details Deterministically returns entire sentence constructed by a beam search.
     /// @param ctx Pointer to the llama_context.
     /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
-    /// The return beam_search_control can be used to control the beam_search execution.
     /// @param callback_data A pointer that is simply passed back to callback.
     /// @param n_beams Number of beams to use.
     /// @param n_past Number of tokens already evaluated.

From 5fa1ea2c38c208ad9f0c523ffb3c8033ec9648af Mon Sep 17 00:00:00 2001
From: Matt Pulver
Date: Fri, 25 Aug 2023 09:47:52 -0400
Subject: [PATCH 8/8] Change eos to eob in llama_beam and llama_beam_view
 structs.

---
 examples/beam_search/beam_search.cpp | 10 ++++++----
 examples/server/server.cpp           |  8 ++++----
 llama.cpp                            | 26 +++++++++++++-------------
 llama.h                              |  2 +-
 4 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/examples/beam_search/beam_search.cpp b/examples/beam_search/beam_search.cpp
index 1d0d077d152c3..1c04fabc21b3d 100644
--- a/examples/beam_search/beam_search.cpp
+++ b/examples/beam_search/beam_search.cpp
@@ -33,7 +33,7 @@ struct ostream_beam_view {
     llama_beam_view beam_view;
 };
 std::ostream& operator<<(std::ostream& os, const ostream_beam_view & obv) {
-    os << "p(" << obv.beam_view.p << ") eos(" << std::boolalpha << obv.beam_view.eos << ") tokens(";
+    os << "p(" << obv.beam_view.p << ") eob(" << std::boolalpha << obv.beam_view.eob << ") tokens(";
     for (size_t i = 0 ; i < obv.beam_view.n_tokens ; ++i) {
         os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
     }
@@ -46,7 +46,9 @@ struct beam_search_callback_data {
     std::vector<llama_token> response;
 };
 
-bool is_at_eos(const beam_search_callback_data & callback_data, const llama_token * tokens, const size_t n_tokens) {
+// In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the same.
+// For example, eob can be flagged due to maximum token length, stop words, etc.
+bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, const size_t n_tokens) {
     return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx);
 }
 
@@ -61,8 +63,8 @@ void beam_search_callback(void * callback_data_ptr, llama_beams_stat
     // Mark beams as EOS as needed.
     for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
         llama_beam_view& beam_view = beams_state.beam_views[i];
-        if (!beam_view.eos && is_at_eos(callback_data, beam_view.tokens, beam_view.n_tokens)) {
-            beam_view.eos = true;
+        if (!beam_view.eob && is_at_eob(callback_data, beam_view.tokens, beam_view.n_tokens)) {
+            beam_view.eob = true;
         }
     }
     printf(",");  // Show progress
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 94a029bbf21fd..3300553f9b397 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1209,7 +1209,7 @@ static void log_server_request(const Request &req, const Response &res)
     });
 }
 
-bool is_at_eos(llama_server_context & server_context, const llama_token * tokens, const size_t n_tokens) {
+bool is_at_eob(llama_server_context & server_context, const llama_token * tokens, const size_t n_tokens) {
     return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
 }
 
@@ -1223,9 +1223,9 @@ void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
     auto & llama = *static_cast<llama_server_context*>(callback_data);
     // Mark beams as EOS as needed.
     for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
-        llama_beam_view & beam_view = beams_state.beam_views[i];
-        if (!beam_view.eos && is_at_eos(llama, beam_view.tokens, beam_view.n_tokens)) {
-            beam_view.eos = true;
+        llama_beam_view& beam_view = beams_state.beam_views[i];
+        if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) {
+            beam_view.eob = true;
         }
     }
     printf(",");  // Show progress
diff --git a/llama.cpp b/llama.cpp
index 9f23a6a9d9d05..5b8e3bbaae4d0 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4333,10 +4333,10 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
 struct llama_beam {
     std::vector<llama_token> tokens;
     float p;  // Cumulative beam probability (renormalized relative to all beams)
-    bool eos; // Initialize end-of-sentence to false. Callback sets this to true.
-    // Sort beams by probability. In case of ties, prefer beams at eos.
+    bool eob; // Initialize end-of-beam to false. Callback sets this to true.
+    // Sort beams by probability. In case of ties, prefer beams at eob.
     bool operator<(const llama_beam & rhs) const {
-        return std::make_tuple(p, eos) < std::make_tuple(rhs.p, rhs.eos);
+        return std::make_pair(p, eob) < std::make_pair(rhs.p, rhs.eob);
     }
     // Shift off first n tokens and discard them.
     void shift_tokens(const size_t n) {
@@ -4345,7 +4345,7 @@ struct llama_beam {
             tokens.resize(tokens.size() - n);
         }
     }
-    llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eos}; }
+    llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eob}; }
 };
 
 // A struct for calculating logit-related info.
@@ -4435,7 +4435,7 @@ struct llama_beam_search_data {
     void fill_next_beams_by_top_probabilities(llama_beam & beam) {
         // Min-heaps use a greater-than comparator.
         const auto comp = [](const llama_beam & a, const llama_beam & b) { return a.p > b.p; };
-        if (beam.eos) {
+        if (beam.eob) {
             // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
             if (next_beams.size() < n_beams) {
                 next_beams.push_back(std::move(beam));
@@ -4513,16 +4513,16 @@ struct llama_beam_search_data {
 
     // Loop:
     // * while i < n_predict, AND
-    // * any of the beams have not yet reached end-of-sentence, AND
+    // * any of the beams have not yet reached end-of-beam (eob), AND
     // * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence
     //   (since all other beam probabilities can only decrease)
     void loop(const llama_beam_search_callback_fn_t callback, void * const callback_data) {
-        beams.push_back({{}, 1.0f, false});  // Start with one empty beam w/ probability = 1.0 and !eos.
-        const auto not_eos = [](const llama_beam & beam) { return !beam.eos; };
-        for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eos) &&
-                       !beams[top_beam_index()].eos ; ++i) {
+        beams.push_back({{}, 1.0f, false});  // Start with one empty beam w/ probability = 1.0 and !eob.
+        const auto not_eob = [](const llama_beam & beam) { return !beam.eob; };
+        for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eob) &&
+                       !beams[top_beam_index()].eob ; ++i) {
             callback(callback_data, get_beams_state(false));  // Sets common_prefix_length
-            update_beams_from_beam_views();  // Update values (p,eos) that callback may have changed.
+            update_beams_from_beam_views();  // Update values (p,eob) that callback may have changed.
             if (common_prefix_length) {
                 llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads);
                 n_past += common_prefix_length;
@@ -4554,11 +4554,11 @@ struct llama_beam_search_data {
         return std::max_element(beams.begin(), beams.end()) - beams.begin();
     }
 
-    // Copy (p,eos) for each beam which may have been changed by the callback.
+    // Copy (p,eob) for each beam which may have been changed by the callback.
     void update_beams_from_beam_views() {
         for (size_t i = 0 ; i < beams.size() ; ++i) {
             beams[i].p = beam_views[i].p;
-            beams[i].eos = beam_views[i].eos;
+            beams[i].eob = beam_views[i].eob;
         }
     }
 };
diff --git a/llama.h b/llama.h
index 47e7a2ebe3cbc..cca803181b49c 100644
--- a/llama.h
+++ b/llama.h
@@ -473,7 +473,7 @@ extern "C" {
         const llama_token * tokens;
         size_t n_tokens;
         float p;  // Cumulative beam probability (renormalized relative to all beams)
-        bool eos; // Callback should set this to true when a beam is at end-of-sentence.
+ bool eob; // Callback should set this to true when a beam is at end-of-beam. }; // Passed to beam_search_callback function.
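
Taken together, the series leaves callers with a small surface area: initialize a context, evaluate the prompt, then hand control to llama_beam_search. A minimal sketch of that flow, assuming a ctx already primed with n_past prompt tokens and the my_callback_state/my_beam_search_callback sketch shown after patch 2 (both illustrative names, not part of the patch):

    my_callback_state state;
    const size_t n_beams   = 2;   // beam width; the example program defaults to 2
    const int    n_predict = 256; // maximum number of tokens to generate
    llama_beam_search(ctx, my_beam_search_callback, &state,
                      n_beams, n_past, n_predict, n_threads);
    // state.response now holds the tokens of the single highest-probability beam,
    // which collapse_beams() selects before the final (last_call == true) callback.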