From 0e9f0e0f52c2e25b8453a86d575dbb0f07287d62 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 20 May 2025 12:59:07 -0600 Subject: [PATCH 01/21] tests: Initial unit tests for memory hierarchy These only test the basics so far, but should allow for more expansive tests to come. Branch: MemoryTests Signed-off-by: Gabe Goodhart --- tests/test-memory.cpp | 175 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 175 insertions(+) create mode 100644 tests/test-memory.cpp diff --git a/tests/test-memory.cpp b/tests/test-memory.cpp new file mode 100644 index 0000000000000..ad6c13800cbb6 --- /dev/null +++ b/tests/test-memory.cpp @@ -0,0 +1,175 @@ +/*------------------------------------------------------------------------------ + * Unit tests for llama-memory.h and derived memory implementations. It contains + * a number of tests which can be run all together or separately. + * + * USAGE: ./bin/test-memory + * + * When adding a new test, do the following: + * + * 1. Add the new test__description function under the + * appropriate memory type section + * + * 2. Add `RUN_TEST(test__description);` to main + *----------------------------------------------------------------------------*/ + +#include "../src/llama-arch.h" +#include "../src/llama-batch.h" +#include "../src/llama-hparams.h" +#include "../src/llama-impl.h" +#include "../src/llama-kv-cache.h" +#include "../src/llama-model.h" + +#include "common.h" +#include "llama.h" + +#include +#include +#include + +/*- Helpers ------------------------------------------------------------------*/ + +static std::shared_ptr _make_model( + llm_arch arch = LLM_ARCH_LLAMA, + uint32_t n_layer = 4, + uint32_t n_embd_head_k = 4, + uint32_t n_embd_head_v = 4, + uint32_t n_head = 8, + uint32_t n_head_kv = 2) { + + llama_model_params params; + params.tensor_buft_overrides = nullptr; + std::shared_ptr model(new llama_model(params)); + model->hparams = llama_hparams(); + model->arch = arch; + + model->hparams.n_layer = n_layer; + model->hparams.n_embd_head_k = n_embd_head_k; + model->hparams.n_embd_head_v = n_embd_head_v; + + // If set to 0, assume the test will fill out the array elementwise (hybrid) + if (n_head > 0) { + auto& n_head_arr = model->hparams.n_head_arr; + std::fill(n_head_arr.begin(), n_head_arr.end(), n_head); + } + if (n_head_kv > 0) { + auto& n_head_kv_arr = model->hparams.n_head_kv_arr; + std::fill(n_head_kv_arr.begin(), n_head_kv_arr.end(), n_head_kv); + } + + return model; +} + +struct log_scope { + const char * name; + explicit log_scope(const char * name) : name(name) { + LLAMA_LOG_INFO("--------\n"); + LLAMA_LOG_INFO("START: %s\n", name); + } + ~log_scope() { + LLAMA_LOG_INFO("END: %s\n", name); + LLAMA_LOG_INFO("--------\n"); + } +}; + +#define RUN_TEST(test_name) \ + do { \ + bool run_test = argc < 2; \ + std::vector args(argv + 1, argv + argc); \ + if (std::find(args.begin(), args.end(), #test_name) != args.end()) \ + run_test = true; \ + if (run_test) { \ + log_scope __log_scope(#test_name); \ + test_name(); \ + } \ + } while (0) + +/*- Unified Cache ------------------------------------------------------------*/ + +/* Test that the unified cache can be constructed and destructed safely */ +static void test_llama_kv_cache_unified_constructor() { + auto model = _make_model(); + llama_kv_cache_unified cache( + /* model */ *model, + /* filter */ nullptr, + /* type_k */ GGML_TYPE_F32, + /* type_v */ GGML_TYPE_F16, + /* v_trans */ false, + /* offload */ false, + /* kv_size */ 10, + /* padding */ 10, + /* n_swa */ 0, + /* swa_type */ 
LLAMA_SWA_TYPE_NONE + ); +} + +/* Test that the unified cache can operate with a single seq */ +static void test_llama_kv_cache_unified_single_seq() { + auto model = _make_model(); + llama_kv_cache_unified cache( + /* model */ *model, + /* filter */ nullptr, + /* type_k */ GGML_TYPE_F32, + /* type_v */ GGML_TYPE_F16, + /* v_trans */ false, + /* offload */ false, + /* kv_size */ 10, + /* padding */ 10, + /* n_swa */ 0, + /* swa_type */ LLAMA_SWA_TYPE_NONE + ); + GGML_ASSERT(cache.get_used_cells() == 0); + + // Create the micro batch with a single 3-token sequence + // + // NOTE: A bunch of these asserts were just me figuring out how the batches + // relate to each other, but they're left for future readers to help in the + // same understanding process. + llama_seq_id seq_id = 42; + llama_batch batch = llama_batch_init(3, 0, 1); + common_batch_add(batch, 101, 0, {seq_id}, false); + common_batch_add(batch, 1, 1, {seq_id}, false); + common_batch_add(batch, 102, 2, {seq_id}, false); + llama_sbatch sbatch(batch, 0, true, false); + GGML_ASSERT(batch.n_tokens == 3); + GGML_ASSERT(sbatch.n_tokens == 3); + GGML_ASSERT(!sbatch.seq.empty()); + llama_ubatch ubatch = sbatch.split_simple(4); + printf("ubatch.n_seqs=%d\n", ubatch.n_seqs); + GGML_ASSERT(ubatch.n_seqs == 3); + GGML_ASSERT(ubatch.n_seq_tokens == 1); + GGML_ASSERT(ubatch.n_tokens == 3); + GGML_ASSERT(ubatch.seq_id[0][0] == seq_id); + GGML_ASSERT(ubatch.seq_id[1][0] == seq_id); + GGML_ASSERT(ubatch.seq_id[2][0] == seq_id); + + // Find a slot for a new sequence + GGML_ASSERT(cache.find_slot(ubatch)); + + // Clean up + llama_batch_free(batch); +} + +/*- Recurrent Cache ----------------------------------------------------------*/ + +/* Test that the recurrent cache can be constructed and destructed safely */ +static void test_llama_kv_cache_recurrent_constructor() { + auto model = _make_model(LLM_ARCH_MAMBA); + llama_kv_cache_recurrent cache( + /* model */ *model, + /* type_k */ GGML_TYPE_F32, + /* type_v */ GGML_TYPE_F16, + /* offload */ false, + /* kv_size */ 10 + ); +} + +/*- Main ---------------------------------------------------------------------*/ + +int main(int argc, char* argv[]) { + // Unified Cache Tests + RUN_TEST(test_llama_kv_cache_unified_constructor); + RUN_TEST(test_llama_kv_cache_unified_single_seq); + // Recurrent Cache Tests + RUN_TEST(test_llama_kv_cache_recurrent_constructor); + return 0; +} From b656613196fe32b546343ce61216a07117adfe1c Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 20 May 2025 12:59:36 -0600 Subject: [PATCH 02/21] build: Add build step for test-memory on non-windows builds These tests use private headers, so won't build on windows Branch: MemoryTests Signed-off-by: Gabe Goodhart --- tests/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 62a9f5842bca8..ff3d97d7a27eb 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -109,6 +109,7 @@ if (NOT WIN32) llama_build_and_test(test-grammar-integration.cpp) llama_build_and_test(test-llama-grammar.cpp) llama_build_and_test(test-chat.cpp) + llama_build_and_test(test-memory.cpp) # TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8 if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") llama_build_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..) 
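For reference, the RUN_TEST macro above runs every test when the binary is invoked with no arguments, and runs only the named tests when one or more test names are passed on the command line (e.g. ./bin/test-memory test_llama_kv_cache_unified_constructor). Adding a new test is the two-step process described in the file header; a minimal sketch follows, with an illustrative name and placeholder body that are not part of this series:

// Illustrative skeleton only.
// 1. Define the test under the appropriate memory-type section, reusing the
//    _make_model() helper above to build a small llama_model.
static void test_llama_kv_cache_unified_example() {
    auto model = _make_model(); // defaults: LLM_ARCH_LLAMA, 4 layers
    // ... construct a cache from *model and exercise it ...
}

// 2. Register it in main() so it participates in name-based selection:
//    RUN_TEST(test_llama_kv_cache_unified_example);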
From 9a15f2793a3b74db35c6279a44d5e87a37089c4e Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 27 May 2025 09:00:50 -0600 Subject: [PATCH 03/21] fix(tests): Fix constructors in tests for signature changes after rebase Branch: HybridCache Signed-off-by: Gabe Goodhart --- tests/test-memory.cpp | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/tests/test-memory.cpp b/tests/test-memory.cpp index ad6c13800cbb6..36b008da9d8ca 100644 --- a/tests/test-memory.cpp +++ b/tests/test-memory.cpp @@ -89,16 +89,17 @@ struct log_scope { static void test_llama_kv_cache_unified_constructor() { auto model = _make_model(); llama_kv_cache_unified cache( - /* model */ *model, - /* filter */ nullptr, - /* type_k */ GGML_TYPE_F32, - /* type_v */ GGML_TYPE_F16, - /* v_trans */ false, - /* offload */ false, - /* kv_size */ 10, - /* padding */ 10, - /* n_swa */ 0, - /* swa_type */ LLAMA_SWA_TYPE_NONE + /* model */ *model, + /* filter */ nullptr, + /* type_k */ GGML_TYPE_F32, + /* type_v */ GGML_TYPE_F16, + /* v_trans */ false, + /* offload */ false, + /* kv_size */ 10, + /* n_seq_max */ 1, + /* padding */ 10, + /* n_swa */ 0, + /* swa_type */ LLAMA_SWA_TYPE_NONE ); } @@ -113,11 +114,11 @@ static void test_llama_kv_cache_unified_single_seq() { /* v_trans */ false, /* offload */ false, /* kv_size */ 10, + /* n_seq_max */ 1, /* padding */ 10, /* n_swa */ 0, /* swa_type */ LLAMA_SWA_TYPE_NONE ); - GGML_ASSERT(cache.get_used_cells() == 0); // Create the micro batch with a single 3-token sequence // @@ -155,11 +156,12 @@ static void test_llama_kv_cache_unified_single_seq() { static void test_llama_kv_cache_recurrent_constructor() { auto model = _make_model(LLM_ARCH_MAMBA); llama_kv_cache_recurrent cache( - /* model */ *model, - /* type_k */ GGML_TYPE_F32, - /* type_v */ GGML_TYPE_F16, - /* offload */ false, - /* kv_size */ 10 + /* model */ *model, + /* type_k */ GGML_TYPE_F32, + /* type_v */ GGML_TYPE_F16, + /* offload */ false, + /* kv_size */ 10, + /* n_seq_max */ 1 ); } From d74d76a64ec5c999f0c973c9cc05eb07bc57c9b5 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 23 May 2025 17:29:04 -0600 Subject: [PATCH 04/21] tests(wip): More robust test for unified cache I'm still not clear how cache hits should be detected since find_slot does not seem to take into account the tokens themselves and simply looks for a sequence of cells that fits the size of the ubatch and has no set positions in any of the cells. I'm clearly still missing something about how this works! 
Branch: HybridCache Signed-off-by: Gabe Goodhart --- tests/test-memory.cpp | 98 +++++++++++++++++++++++++++++++++---------- 1 file changed, 75 insertions(+), 23 deletions(-) diff --git a/tests/test-memory.cpp b/tests/test-memory.cpp index 36b008da9d8ca..d843a6b5ea175 100644 --- a/tests/test-memory.cpp +++ b/tests/test-memory.cpp @@ -20,6 +20,7 @@ #include "../src/llama-model.h" #include "common.h" +#include "ggml.h" #include "llama.h" #include @@ -59,6 +60,43 @@ static std::shared_ptr _make_model( return model; } +static llama_batch _make_batch( + std::vector> token_seqs, + std::vector> seq_ids) { + GGML_ASSERT(token_seqs.size() == seq_ids.size()); + + size_t total_tokens = 0; + for (const auto & token_seq : token_seqs) { + total_tokens += token_seq.size(); + } + size_t max_seq_ids = 0; + for (const auto & seq_ids_i : seq_ids) { + max_seq_ids = std::max(max_seq_ids, seq_ids_i.size()); + } + llama_batch batch = llama_batch_init(total_tokens, 0, max_seq_ids); + + for (size_t i = 0; i < token_seqs.size(); ++i) { + const auto& token_seq = token_seqs[i]; + const auto& seq_ids_i = seq_ids[i]; + for (int pos = 0; pos < (int)token_seq.size(); ++pos) { + common_batch_add(batch, token_seq[pos], pos, seq_ids_i, false); + } + } + return batch; +} + +static bool is_source_tensor(ggml_tensor * child, ggml_tensor * parent) { + if (!child || !parent) return false; + for (size_t i = 0; i < GGML_MAX_SRC; ++i) { + if (child->src[i] == parent) { + return true; + } else if (child->src[i] != nullptr && is_source_tensor(child->src[i], parent)) { + return true; + } + } + return false; +} + struct log_scope { const char * name; explicit log_scope(const char * name) : name(name) { @@ -121,33 +159,47 @@ static void test_llama_kv_cache_unified_single_seq() { ); // Create the micro batch with a single 3-token sequence - // - // NOTE: A bunch of these asserts were just me figuring out how the batches - // relate to each other, but they're left for future readers to help in the - // same understanding process. 
- llama_seq_id seq_id = 42; - llama_batch batch = llama_batch_init(3, 0, 1); - common_batch_add(batch, 101, 0, {seq_id}, false); - common_batch_add(batch, 1, 1, {seq_id}, false); - common_batch_add(batch, 102, 2, {seq_id}, false); - llama_sbatch sbatch(batch, 0, true, false); - GGML_ASSERT(batch.n_tokens == 3); - GGML_ASSERT(sbatch.n_tokens == 3); - GGML_ASSERT(!sbatch.seq.empty()); - llama_ubatch ubatch = sbatch.split_simple(4); - printf("ubatch.n_seqs=%d\n", ubatch.n_seqs); - GGML_ASSERT(ubatch.n_seqs == 3); - GGML_ASSERT(ubatch.n_seq_tokens == 1); - GGML_ASSERT(ubatch.n_tokens == 3); - GGML_ASSERT(ubatch.seq_id[0][0] == seq_id); - GGML_ASSERT(ubatch.seq_id[1][0] == seq_id); - GGML_ASSERT(ubatch.seq_id[2][0] == seq_id); + llama_batch batch1 = _make_batch({{101, 1, 102}}, {{42}}); + llama_sbatch sbatch1 = cache.sbatch_init(batch1, false); + llama_ubatch ubatch1 = cache.ubatch_next(sbatch1, 4, false); // Find a slot for a new sequence - GGML_ASSERT(cache.find_slot(ubatch)); + GGML_ASSERT(cache.find_slot(ubatch1)); + + // Cache the k/v for a single layer in this slot + ggml_context * ctx = ggml_init({10240, NULL, false}); + ggml_tensor * k1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_k_gqa(0)); + ggml_tensor * v1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_v_gqa(0)); + ggml_tensor * k1_view = cache.cpy_k(ctx, k1, 0); + ggml_tensor * v1_view = cache.cpy_v(ctx, v1, 0); + GGML_ASSERT(is_source_tensor(k1_view, k1)); + GGML_ASSERT(is_source_tensor(v1_view, v1)); + + // Create a second batch with different tokens and find a slot for it + llama_batch batch2 = _make_batch({{1, 2, 3, 4}}, {{5}}); + llama_sbatch sbatch2 = cache.sbatch_init(batch2, false); + llama_ubatch ubatch2 = cache.ubatch_next(sbatch2, 4, false); + GGML_ASSERT(cache.find_slot(ubatch2)); + + // Add some different tensors + ggml_tensor * k2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_k_gqa(0)); + ggml_tensor * v2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_v_gqa(0)); + ggml_tensor * k2_view = cache.cpy_k(ctx, k2, 0); + ggml_tensor * v2_view = cache.cpy_v(ctx, v2, 0); + GGML_ASSERT(is_source_tensor(k2_view, k2)); + GGML_ASSERT(is_source_tensor(v2_view, v2)); + + // Make sure first batch's k/v aren't cache hit + GGML_ASSERT(!is_source_tensor(k2_view, k1)); + GGML_ASSERT(!is_source_tensor(v2_view, v1)); + + // Re-find the slot for the first batch and make sure they cache hit + GGML_ASSERT(cache.find_slot(ubatch1)); // Clean up - llama_batch_free(batch); + llama_batch_free(batch1); + llama_batch_free(batch2); + ggml_free(ctx); } /*- Recurrent Cache ----------------------------------------------------------*/ From 114f2cea8e750c7f9db6c25754c64643162d7e42 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 27 May 2025 11:49:58 -0600 Subject: [PATCH 05/21] feat: Add can_seq_rm API to llama_kv_cache API This will be key for the hybrid cache which needs to be able to validate that all children can perform seq_rm cleanly before attempting to remove the seq from any single child to avoid ending up in a corrupted state. 
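Concretely, the intended usage is check-then-mutate: query every child before touching any of them. A minimal sketch, assuming a simple container of child caches (the free-standing helper and its name are illustrative; the hybrid cache later in this series implements the same pattern inside llama_kv_cache_hybrid::seq_rm):

// Sketch only (assumes the llama internal headers used by test-memory.cpp
// plus <vector> are included). Validate against every child first so that a
// failing child cannot leave its siblings in an inconsistent state.
static bool children_seq_rm(std::vector<llama_kv_cache *> & children,
                            llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    for (const auto * cache : children) {
        if (!cache->can_seq_rm(seq_id, p0, p1)) {
            return false; // nothing has been modified yet
        }
    }
    for (auto * cache : children) {
        const bool ok = cache->seq_rm(seq_id, p0, p1);
        GGML_ASSERT(ok); // already validated above, so this must succeed
    }
    return true;
}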
Branch: HybridCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache.cpp | 69 +++++++++++++++++++++++++++++++++--------- src/llama-kv-cache.h | 11 +++++++ 2 files changed, 65 insertions(+), 15 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 3e3d26286e1ee..de1decf8ab680 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -518,6 +518,14 @@ void llama_kv_cache_unified::set_full() { head = 0; } +bool llama_kv_cache_unified::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const { + GGML_UNUSED(seq_id); + GGML_UNUSED(p0); + GGML_UNUSED(p1); + // Unified attention cache can always do a sequence removal + return true; +} + int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { const uint32_t n_tokens = ubatch.n_tokens; @@ -1861,6 +1869,15 @@ void llama_kv_cache_unified_iswa::set_full() { kv_swa ->set_full(); } +bool llama_kv_cache_unified_iswa::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const { + GGML_UNUSED(seq_id); + GGML_UNUSED(p0); + GGML_UNUSED(p1); + // Unified attention caches can always do a sequence removal, so since both + // children can, the parent can as well. + return true; +} + bool llama_kv_cache_unified_iswa::get_can_shift() const { return kv_base->get_size() == kv_swa->get_size(); } @@ -2051,39 +2068,33 @@ void llama_kv_cache_recurrent::clear() { } bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - uint32_t new_head = size; + if (!can_seq_rm(seq_id, p0, p1)) { + // could be fatal + return false; + } + uint32_t new_head = size; if (p0 < 0) { p0 = 0; } - if (p1 < 0) { p1 = std::numeric_limits::max(); } - // models like Mamba or RWKV can't have a state partially erased - if (seq_id >= (int64_t) size) { - // could be fatal - return false; - } if (0 <= seq_id) { int32_t & tail_id = cells[seq_id].tail; if (tail_id >= 0) { const kv_cell & cell = cells[tail_id]; - // partial intersection is invalid - if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { - return false; - } + // already validated in can_seq_rm + GGML_ASSERT(!((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos))); // invalidate tails which will be cleared if (p0 <= cell.pos && cell.pos < p1) { tail_id = -1; } } } else { - // seq_id is negative, then the range should include everything or nothing - if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { - return false; - } + // already validated in can_seq_rm + GGML_ASSERT(!(p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max()))); } for (uint32_t i = 0; i < size; ++i) { @@ -2349,6 +2360,34 @@ void llama_kv_cache_recurrent::set_full() { n = size; head = 0; } +bool llama_kv_cache_recurrent::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const { + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + // models like Mamba or RWKV can't have a state partially erased + if (seq_id >= (int64_t) size) { + // could be fatal + return false; + } + if (0 <= seq_id) { + const int32_t & tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + const kv_cell & cell = cells[tail_id]; + // partial intersection is invalid + if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + return false; + } + } + // seq_id is negative, then the range should include everything or nothing + } else if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + return false; + } + return true; +} bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { const uint32_t 
n_tokens = ubatch.n_tokens; diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index f1ba7cba390e2..5a8a0dd15d624 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -42,6 +42,11 @@ struct llama_kv_cache : public llama_memory_i { // TODO: remove virtual void set_full() = 0; + // sometimes it is useful to check whether a cache can remove a sequence + // before attempting to mutate the cache (eg a hybrid cache with multiple + // children to keep in sync) + virtual bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const = 0; + // getters virtual bool get_can_shift() const = 0; @@ -112,6 +117,8 @@ class llama_kv_cache_unified : public llama_kv_cache { void set_full() override; + bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override; + bool get_can_shift() const override; // state write/load @@ -287,6 +294,8 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache { void set_full() override; + bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override; + bool get_can_shift() const override; // state write/load @@ -355,6 +364,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache { void set_full() override; + bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override; + bool prepare(const std::vector & ubatches); // find a contiguous slot of kv cells and emplace the ubatch there From 13efcf31a80292ca4aa5804c2812f9a6e4124081 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 20 May 2025 13:42:14 -0600 Subject: [PATCH 06/21] feat: Move layer_filter_cb up to llama_kv_cache This will be needed by other cache types as well, so centralizing the definition will make it more reusable. Branch: HybridCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache.h | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 5a8a0dd15d624..e9f7e40389fa4 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -20,6 +20,12 @@ struct llama_model; struct llama_context; struct llama_kv_cache : public llama_memory_i { + + // some child types need to perform different caching for each layer, so + // this callback can be used to determine which layers a given cache should + // be used for + using layer_filter_cb = std::function; + virtual ~llama_kv_cache() = default; // split the input batch into a set of ubatches and verify that they can fit into the cache @@ -68,9 +74,6 @@ class llama_kv_cache_unified : public llama_kv_cache { public: static uint32_t get_padding(const llama_cparams & cparams); - // this callback is used to filter out layers that should not be included in the cache - using layer_filter_cb = std::function; - llama_kv_cache_unified( const llama_model & model, layer_filter_cb && filter, From 6625248043bc39c1e8a89c591eeb7a4ac4f5ac7d Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 20 May 2025 13:43:16 -0600 Subject: [PATCH 07/21] feat: Add layer filter to recurrent cache Branch: HybridCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache.cpp | 18 ++++++++++++------ src/llama-kv-cache.h | 13 +++++++------ src/llama-model.cpp | 1 + tests/test-memory.cpp | 1 + 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index de1decf8ab680..309c418efc39c 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1951,12 +1951,13 @@ class llama_kv_cache_recurrent_decode_state_t : public llama_memory_decode_state }; 
llama_kv_cache_recurrent::llama_kv_cache_recurrent( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { const int32_t n_layer = hparams.n_layer; LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n", @@ -1998,6 +1999,11 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( v_l.reserve(n_layer); for (int i = 0; i < n_layer; i++) { + if (filter && !filter(i)) { + LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i); + continue; + } + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index e9f7e40389fa4..99979dea9f40b 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -327,12 +327,13 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache { class llama_kv_cache_recurrent : public llama_kv_cache { public: llama_kv_cache_recurrent( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max); + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max); ~llama_kv_cache_recurrent() = default; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e99f5309f9904..4ba9f454891b3 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13208,6 +13208,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, { res = new llama_kv_cache_recurrent( *this, + nullptr, GGML_TYPE_F32, GGML_TYPE_F32, cparams.offload_kqv, diff --git a/tests/test-memory.cpp b/tests/test-memory.cpp index d843a6b5ea175..43b8dc1973b7a 100644 --- a/tests/test-memory.cpp +++ b/tests/test-memory.cpp @@ -209,6 +209,7 @@ static void test_llama_kv_cache_recurrent_constructor() { auto model = _make_model(LLM_ARCH_MAMBA); llama_kv_cache_recurrent cache( /* model */ *model, + /* filter */ nullptr, /* type_k */ GGML_TYPE_F32, /* type_v */ GGML_TYPE_F16, /* offload */ false, From 76771e87fa323597e8588b09b3e911be035d7285 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 20 May 2025 13:56:50 -0600 Subject: [PATCH 08/21] feat: Initial implementation of llama_kv_cache_hybrid Condensed from initial version https://github.com/gabe-l-hart/llama.cpp/tree/ec08571 The only difference is the removal of m_layer_cache_map which was unused and unnecessary now that child caches are instantiated with their own filters. 
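Instead of a map from layer index to cache, each child is constructed with a filter callback that claims its layers, and the hybrid constructor verifies that the claimed layer sets neither overlap nor leave gaps. A condensed sketch of the pattern, using the toy layer layout from the constructor test in this patch (layers 0 and 2 recurrent, layers 1 and 3 attention):

// Complementary filters: every layer is claimed by exactly one child cache.
auto recurrent_filter = [](int32_t il) { return il == 0 || il == 2; };
auto unified_filter   = [&](int32_t il) { return !recurrent_filter(il); };
// Each filtered child is then paired with its layer-id list in a
// llama_kv_cache_hybrid::child_cache entry and handed to the hybrid cache.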
Branch: HybridCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache.cpp | 237 +++++++++++++++++++++++++++++++++++++++++ src/llama-kv-cache.h | 98 +++++++++++++++++ tests/test-memory.cpp | 68 ++++++++++++ 3 files changed, 403 insertions(+) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 309c418efc39c..a21f293b6ac10 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -3014,3 +3014,240 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce return true; } + +// +// llama_kv_cache_hybrid +// +llama_kv_cache_hybrid::llama_kv_cache_hybrid( + const llama_hparams & hparams, + std::vector children) : + m_hparams(hparams), + m_children( + [](std::vector& caches) -> std::set> { + // Sort the caches by the lowest layer ID so the order is repeatable + for (auto & cache : caches) { + GGML_ASSERT(cache.layer_ids.size() > 0); + std::sort(cache.layer_ids.begin(), cache.layer_ids.end()); + } + std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) { + return a.layer_ids[0] < b.layer_ids[0]; + }); + std::set> unique_caches; + for (auto & cache : caches) { + unique_caches.emplace(cache.child.release()); + } + return unique_caches; + }(children) + ), + m_has_recurrent( + [](const std::set> & caches) -> bool { + for (const auto & cache : caches) { + if (dynamic_cast(cache.get())) { + return true; + } + } + return false; + }(m_children) + ) +{ + // Ensure at least one child + GGML_ASSERT(m_children.size() > 0); + + // Ensure layers are not overlapping and are concurrent + std::set seen_layers; + size_t max_layer = 0; + for (const auto & cache : children) { + for (const auto & layer_id : cache.layer_ids) { + GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end()); + seen_layers.insert(layer_id); + if (layer_id > max_layer) { + max_layer = layer_id; + } + } + } + LLAMA_LOG_DEBUG("max_layer=%zu, seen_layers.size()=%zu\n", max_layer, seen_layers.size()); + GGML_ASSERT(max_layer + 1 == seen_layers.size()); +} + +void llama_kv_cache_hybrid::clear() { + for (const auto & cache : m_children) { + cache->clear(); + } +} + +bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // First check if we can do this removal. 
This checks all children so that + // no mutation happens before we know if it's possible + if (!can_seq_rm(seq_id, p0, p1)) { + return false; + } + + // Do the removal from each child which should never fail + for (const auto & cache : m_children) { + const bool failed = cache->seq_rm(seq_id, p0, p1); + GGML_ASSERT(!failed); + } + return true; +} + +void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + for (const auto & cache : m_children) { + cache->seq_cp(seq_id_src, seq_id_dst, p0, p1); + } +} + +void llama_kv_cache_hybrid::seq_keep(llama_seq_id seq_id) { + for (const auto & cache : m_children) { + cache->seq_keep(seq_id); + } +} + +void llama_kv_cache_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + for (const auto & cache : m_children) { + cache->seq_add(seq_id, p0, p1, delta); + } +} + +void llama_kv_cache_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + for (const auto & cache : m_children) { + cache->seq_div(seq_id, p0, p1, d); + } +} + +llama_pos llama_kv_cache_hybrid::seq_pos_min(llama_seq_id seq_id) const { + llama_pos min_pos = -1; + for (const auto & cache : m_children) { + const auto child_min_pos = cache->seq_pos_min(seq_id); + min_pos = min_pos == -1 ? child_min_pos : std::min(min_pos, child_min_pos); + } + return min_pos; +} + +llama_pos llama_kv_cache_hybrid::seq_pos_max(llama_seq_id seq_id) const { + llama_pos max_pos = 0; + for (const auto & cache : m_children) { + max_pos = std::max(max_pos, cache->seq_pos_max(seq_id)); + } + return max_pos; +} + +void llama_kv_cache_hybrid::restore() { + for (const auto & cache : m_children) { + cache->restore(); + } +} + +void llama_kv_cache_hybrid::commit() { + for (const auto & cache : m_children) { + cache->commit(); + } +} + +bool llama_kv_cache_hybrid::update(llama_context & ctx) { + bool updated = false; + for (const auto & cache : m_children) { + updated = cache->update(ctx) || updated; + } + return updated; +} + +void llama_kv_cache_hybrid::defrag_sched(float thold) { + for (const auto & cache : m_children) { + cache->defrag_sched(thold); + } +} + +void llama_kv_cache_hybrid::set_full() { + for (const auto & cache : m_children) { + cache->set_full(); + } +} + +bool llama_kv_cache_hybrid::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const { + for (const auto & cache : m_children) { + if (!cache->can_seq_rm(seq_id, p0, p1)) { + return false; + } + } + return true; +} + +llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) { + // If any of the caches are recurrent, require equal split + return llama_sbatch(batch, m_hparams.n_embd, !m_has_recurrent, logits_all); +} + +llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + return sbatch.split_seq(n_ubatch); + } + if (m_has_recurrent) { + return sbatch.split_equal(n_ubatch); + } + return sbatch.split_simple(n_ubatch); +} + +bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) { + bool found = true; + for (const auto & cache : m_children) { + found = cache->find_slot(batch) && found; + } + return found; +} + +int32_t llama_kv_cache_hybrid::get_n_tokens() const { + // The number of tokens should be the same across all child caches + int32_t n_tokens = -1; + for (const auto & cache : m_children) { + const auto cache_n_tokens = cache->get_n_tokens(); + 
GGML_ASSERT(n_tokens == -1 || cache_n_tokens == n_tokens); + n_tokens = cache_n_tokens; + } + return n_tokens; +} + +int32_t llama_kv_cache_hybrid::get_used_cells() const { + // TODO: Is this correct? + // Return the largest number of used cells + int32_t used_cells = -1; + for (const auto & cache : m_children) { + used_cells = std::max(used_cells, cache->get_used_cells()); + } + return used_cells; +} + +llama_pos llama_kv_cache_hybrid::get_pos_max() const { + llama_pos pos_max = -1; + for (const auto & cache : m_children) { + pos_max = std::max(pos_max, cache->get_pos_max()); + } + return pos_max; +} + +bool llama_kv_cache_hybrid::get_can_shift() const { + // TODO: Is this correct? + // If any children can shift, return true + for (const auto & cache : m_children) { + if (cache->get_can_shift()) { + return true; + } + } + return false; +} + +void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + // Write each cache state in order. Note that order is guaranteed at + // initialization by using an ordered set sorted by lowest layer ID + for (const auto & cache : m_children) { + cache->state_write(io, seq_id); + } +} + +void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + // Read each cache state in order. Note that order is guaranteed at + // initialization by using an ordered set sorted by lowest layer ID + for (const auto & cache : m_children) { + cache->state_read(io, seq_id); + } +} diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 99979dea9f40b..a88cd45fd1821 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -439,3 +439,101 @@ class llama_kv_cache_recurrent : public llama_kv_cache { bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; + +// +// llama_kv_cache_hybrid +// + +// utilizes multiple different cache types with each layer assigned to exactly +// one cache. 
This is typically used for hybrid attention / recurrent caching + +class llama_kv_cache_hybrid : public llama_kv_cache { +public: + + struct child_cache { + std::unique_ptr child; + std::vector layer_ids; + + child_cache(std::unique_ptr child_, std::vector layer_ids_) + : child(std::move(child_)), layer_ids(std::move(layer_ids_)) {} + }; + + llama_kv_cache_hybrid( + const llama_hparams & hparams, + std::vector children); + + virtual ~llama_kv_cache_hybrid() = default; + + // getters for specific child cache type + // NOTE: This will fail if there are multiple of the given type + template + const child_t * get_child_cache() const { + const child_t * child = nullptr; + for (const auto & child_cache : m_children) { + const child_t * child_cast = dynamic_cast(child_cache.get()); + if (child_cast) { + GGML_ASSERT(!child); + child = child_cast; + } + } + return child; + } + + // + // llama_memory_i + // + + void clear() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + // + // llama_kv_cache + // + + void restore() override; + void commit() override; + + bool update(llama_context & ctx) override; + + void defrag_sched(float thold) override; + + void set_full() override; + + bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override; + + llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; + + llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; + + // updates the cache head + // Note: On success, it's important that cache.head points + // to the first cell of the slot. 
+ bool find_slot(const llama_ubatch & batch) override; + + int32_t get_n_tokens() const override; + int32_t get_used_cells() const override; + + // TODO: better data structures to reduce the cost of this operation + llama_pos get_pos_max() const override; + + bool get_can_shift() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + +private: + + const llama_hparams & m_hparams; + const std::set> m_children; // Ordered for state IO + const bool m_has_recurrent; +}; diff --git a/tests/test-memory.cpp b/tests/test-memory.cpp index 43b8dc1973b7a..fb8ba8586a684 100644 --- a/tests/test-memory.cpp +++ b/tests/test-memory.cpp @@ -218,6 +218,72 @@ static void test_llama_kv_cache_recurrent_constructor() { ); } +/*- Hybrid Cache -------------------------------------------------------------*/ + +/* Test that the hybrid cache can be constructed and destructed safely */ +static void test_llama_kv_cache_hybrid_constructor() { + auto model = _make_model( + /* arch =*/ LLM_ARCH_LLAMA, + /* n_layer =*/ 4, + /* n_embd_head_k =*/ 4, + /* n_embd_head_v =*/ 4, + /* n_head =*/ 0, + /* n_head_kv =*/ 0 + ); + auto recurrent_filter = [](int32_t il) { + return il == 0 || il == 2; + }; + auto unified_filter = [&recurrent_filter](int32_t il) { + return !recurrent_filter(il); + }; + auto& n_head_arr = model->hparams.n_head_arr; + n_head_arr[0] = 16; + n_head_arr[1] = 32; + n_head_arr[2] = 16; + n_head_arr[3] = 32; + auto& n_head_kv_arr = model->hparams.n_head_kv_arr; + n_head_kv_arr[0] = 16; + n_head_kv_arr[1] = 8; + n_head_kv_arr[2] = 16; + n_head_kv_arr[3] = 8; + + std::unique_ptr u_cache( + new llama_kv_cache_unified( + /* model */ *model, + /* filter */ unified_filter, + /* type_k */ GGML_TYPE_F32, + /* type_v */ GGML_TYPE_F16, + /* v_trans */ false, + /* offload */ false, + /* kv_size */ 10, + /* padding */ 10, + /* n_swa */ 0, + /* swa_type */ LLAMA_SWA_TYPE_NONE + ) + ); + auto * u_cache_ptr = u_cache.get(); + std::unique_ptr r_cache ( + new llama_kv_cache_recurrent( + /* model */ *model, + /* filter */ recurrent_filter, + /* type_k */ GGML_TYPE_F32, + /* type_v */ GGML_TYPE_F16, + /* offload */ false, + /* kv_size */ 10 + ) + ); + auto * r_cache_ptr = r_cache.get(); + + std::vector children; + children.emplace_back(std::move(u_cache), std::vector{1, 3}); + children.emplace_back(std::move(r_cache), std::vector{0, 2}); + + llama_kv_cache_hybrid cache(model->hparams, std::move(children)); + + GGML_ASSERT(cache.get_child_cache() == u_cache_ptr); + GGML_ASSERT(cache.get_child_cache() == r_cache_ptr); +} + /*- Main ---------------------------------------------------------------------*/ int main(int argc, char* argv[]) { @@ -226,5 +292,7 @@ int main(int argc, char* argv[]) { RUN_TEST(test_llama_kv_cache_unified_single_seq); // Recurrent Cache Tests RUN_TEST(test_llama_kv_cache_recurrent_constructor); + // Hybrid Cache Tests + RUN_TEST(test_llama_kv_cache_hybrid_constructor); return 0; } From f031fb836e755d7c592d1cfcb53e68bb9aae3615 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 15:21:29 -0600 Subject: [PATCH 09/21] feat: Add llama_model_is_hybrid API call Also, split llama_model_is_recurrent into llm_arch_is_recurrent in llama-arch with llama_model_is_recurrent delegating to llm_arch_is_recurrent. The same split is done for hybird. 
This is needed because there are places where the llama_model has not yet been initialized but we need to check if the model is recurrent (specifically for the per-layer recurrent check array in hparams). Branch: GraniteFour Signed-off-by: Gabe Goodhart --- include/llama.h | 3 +++ src/llama-arch.cpp | 22 ++++++++++++++++++++++ src/llama-arch.h | 3 +++ src/llama-model.cpp | 13 +++++-------- 4 files changed, 33 insertions(+), 8 deletions(-) diff --git a/include/llama.h b/include/llama.h index 29677d74207a3..7accdb0142591 100644 --- a/include/llama.h +++ b/include/llama.h @@ -554,6 +554,9 @@ extern "C" { // Returns true if the model is recurrent (like Mamba, RWKV, etc.) LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); + // Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.) + LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model); + // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index abf436adac416..fd00b307bdd82 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1744,3 +1744,25 @@ llm_arch llm_arch_from_string(const std::string & name) { const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) { return LLM_TENSOR_INFOS.at(tensor); } + +bool llm_arch_is_recurrent(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_MAMBA: + case LLM_ARCH_RWKV6: + case LLM_ARCH_RWKV6QWEN2: + case LLM_ARCH_RWKV7: + case LLM_ARCH_ARWKV7: + return true; + default: + return false; + } +} + +bool llm_arch_is_hybrid(const llm_arch & arch) { + // TODO: There are currently no hybrid models! Once there are, this will be + // the place to identify them + switch (arch) { + default: + return false; + } +} diff --git a/src/llama-arch.h b/src/llama-arch.h index 41a023da3da6e..a9d66643495cf 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -435,3 +435,6 @@ const char * llm_arch_name(llm_arch arch); llm_arch llm_arch_from_string(const std::string & name); const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); + +bool llm_arch_is_recurrent(const llm_arch& arch); +bool llm_arch_is_hybrid(const llm_arch& arch); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 4ba9f454891b3..b5dc3e7e45d33 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13800,14 +13800,11 @@ llama_token llama_model_decoder_start_token(const llama_model * model) { } bool llama_model_is_recurrent(const llama_model * model) { - switch (model->arch) { - case LLM_ARCH_MAMBA: return true; - case LLM_ARCH_RWKV6: return true; - case LLM_ARCH_RWKV6QWEN2: return true; - case LLM_ARCH_RWKV7: return true; - case LLM_ARCH_ARWKV7: return true; - default: return false; - } + return llm_arch_is_recurrent(model->arch); +} + +bool llama_model_is_hybrid(const llama_model * model) { + return llm_arch_is_hybrid(model->arch); } const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { From 298e147721f7a8e9f293ac9eb47610a485c3653d Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 15:08:33 -0600 Subject: [PATCH 10/21] feat: Add c++ side constants for attention layer indices hparam Branch: GraniteFour --- src/llama-arch.cpp | 1 + src/llama-arch.h | 1 + 2 files changed, 2 insertions(+) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index fd00b307bdd82..291e5da230ad2 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -144,6 +144,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, 
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, + { LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index a9d66643495cf..1bfd9780ac962 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -148,6 +148,7 @@ enum llm_kv { LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, + LLM_KV_ATTENTION_LAYER_INDICES, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, From f4723735696e0ff479eb35124f8dd4e303f6987c Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 15:04:36 -0600 Subject: [PATCH 11/21] feat: Add support for distinguishing recurrent vs non-recurrent layers in hparams Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-hparams.cpp | 14 ++++++++++++-- src/llama-hparams.h | 10 ++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 1499eb08a5dd9..70a7114f39715 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -65,7 +65,10 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { return n_embd_head_v * n_head_kv; } -uint32_t llama_hparams::n_embd_k_s() const { +uint32_t llama_hparams::n_embd_k_s(uint32_t il) const { + if (!recurrent_layer(il)) { + return 0; + } if (wkv_head_size != 0) { // for RWKV models return token_shift_count * n_embd; @@ -76,7 +79,10 @@ uint32_t llama_hparams::n_embd_k_s() const { return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; } -uint32_t llama_hparams::n_embd_v_s() const { +uint32_t llama_hparams::n_embd_v_s(uint32_t il) const { + if (!recurrent_layer(il)) { + return 0; + } if (wkv_head_size != 0) { // corresponds to RWKV's wkv_states size return n_embd * wkv_head_size; @@ -86,6 +92,10 @@ uint32_t llama_hparams::n_embd_v_s() const { return ssm_d_state * ssm_d_inner; } +bool llama_hparams::recurrent_layer(uint32_t il) const { + return recurrent_layer_arr[il]; +} + bool llama_hparams::is_swa(uint32_t il) const { if (il < n_layer) { return swa_layers[il]; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 2d72eab180ad0..e10741a104cb7 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -115,6 +115,9 @@ struct llama_hparams { uint32_t ssm_d_state = 0; uint32_t ssm_dt_rank = 0; + // for hybrid state space models + std::array recurrent_layer_arr; + bool ssm_dt_b_c_rms = false; float f_clamp_kqv = 0.0f; @@ -178,10 +181,13 @@ struct llama_hparams { // dimension of the rolling state embeddings // corresponds to Mamba's conv_states size or RWKV's token_shift states size - uint32_t n_embd_k_s() const; + uint32_t n_embd_k_s(uint32_t il = 0) const; // dimension of the recurrent state embeddings - uint32_t n_embd_v_s() const; + uint32_t n_embd_v_s(uint32_t il = 0) const; + + // whether or not the given layer is recurrent (for hybrid models) + bool recurrent_layer(uint32_t il) const; bool is_swa(uint32_t il) const; }; From 404783b228e2dd970a7417ff3e53cc2838b81f2b Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 15:22:18 -0600 Subject: [PATCH 12/21] feat: Auto-fill hparams.recurrent_layer_arr based on whether the model is recurrent Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/llama-model.cpp 
b/src/llama-model.cpp index b5dc3e7e45d33..8cac26f9c275e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -466,6 +466,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); + std::fill( + hparams.recurrent_layer_arr.begin(), + hparams.recurrent_layer_arr.end(), + llm_arch_is_recurrent(ml.get_arch())); std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); From 53ef2d4a7ac9c17e1576761b29afd3119b1c9f78 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 20 May 2025 14:57:57 -0600 Subject: [PATCH 13/21] feat: Instantiate hybrid cache for hybrid models (currently none) This includes a slight architectural change where create_memory now only uses model architectures in the switch statement if their required cache type is not handled by llm_arch_is_[recurrent|hybrid]. Branch: HybridCache Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 142 ++++++++++++++++++++++++++++++-------------- 1 file changed, 97 insertions(+), 45 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 8cac26f9c275e..c26aef7591c13 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13196,6 +13196,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_memory_i * res; switch (arch) { + // Models that need specific instantiation should be handled in the + // switch statement case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_NOMIC_BERT: @@ -13204,58 +13206,108 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, { res = nullptr; } break; - case LLM_ARCH_MAMBA: - case LLM_ARCH_RWKV6: - case LLM_ARCH_RWKV6QWEN2: - case LLM_ARCH_RWKV7: - case LLM_ARCH_ARWKV7: - { - res = new llama_kv_cache_recurrent( - *this, - nullptr, - GGML_TYPE_F32, - GGML_TYPE_F32, - cparams.offload_kqv, - std::max((uint32_t) 1, cparams.n_seq_max), - cparams.n_seq_max); - } break; + // Models that need standard caching should rely on recurrent/hybrid + // checks default: { - const auto padding = llama_kv_cache_unified::get_padding(cparams); - - cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); - - LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); - - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - GGML_ASSERT(hparams.is_swa_any()); + if (llm_arch_is_hybrid(arch)) { + // make vectors of recurrent and non-recurrent layer indices + std::vector recurrent_layers; + std::vector unified_layers; + for (auto il = 0u; il < hparams.n_layer; ++il) { + if (hparams.recurrent_layer(il)) { + recurrent_layers.push_back(il); + } else { + unified_layers.push_back(il); + } + } - res = new llama_kv_cache_unified_iswa( - *this, - params.type_k, - params.type_v, - !cparams.flash_attn, - cparams.offload_kqv, - params.swa_full, - cparams.n_ctx, - cparams.n_seq_max, - cparams.n_batch, - padding); - } else { - GGML_ASSERT(!hparams.is_swa_any()); + const auto padding = llama_kv_cache_unified::get_padding(cparams); + cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); + LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); + + // initialize the children + std::vector children; + children.emplace_back( + std::unique_ptr( + new llama_kv_cache_recurrent( + *this, + [&](int32_t il) { + return hparams.recurrent_layer(il); + }, + GGML_TYPE_F32, + GGML_TYPE_F32, + cparams.offload_kqv, + 
std::max((uint32_t) 1, cparams.n_seq_max), + cparams.n_seq_max) + ), + std::move(recurrent_layers) + ); + children.emplace_back( + std::unique_ptr( + new llama_kv_cache_unified( + *this, + [&](int32_t il) { + return ! hparams.recurrent_layer(il); + }, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.n_ctx, + cparams.n_seq_max, + padding, + hparams.n_swa, + hparams.swa_type) + ), + std::move(unified_layers) + ); - res = new llama_kv_cache_unified( + // initialize the hybrid cache with both children + res = new llama_kv_cache_hybrid(hparams, std::move(children)); + } else if (llm_arch_is_recurrent(arch)) { + res = new llama_kv_cache_recurrent( *this, nullptr, - params.type_k, - params.type_v, - !cparams.flash_attn, + GGML_TYPE_F32, + GGML_TYPE_F32, cparams.offload_kqv, - cparams.n_ctx, - cparams.n_seq_max, - padding, - hparams.n_swa, - hparams.swa_type); + std::max((uint32_t) 1, cparams.n_seq_max), + cparams.n_seq_max + ); + } else { + const auto padding = llama_kv_cache_unified::get_padding(cparams); + + cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); + + LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); + + if (hparams.n_swa > 0) { + res = new llama_kv_cache_unified_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.n_ctx, + params.swa_full, + cparams.n_seq_max, + cparams.n_batch, + padding); + } else { + res = new llama_kv_cache_unified( + *this, + nullptr, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.n_ctx, + cparams.n_seq_max, + padding, + hparams.n_swa, + hparams.swa_type); + } } } } From e6ff93a201a664b775d3406cc46e8dc60d2f1c85 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 20 May 2025 15:03:30 -0600 Subject: [PATCH 14/21] refactor: rename *_is_hybrid -> *_is_hybrid_recurrent The implementation of the hybrid cache intentionally does not specify the types of the child caches, so there was a naming mismatch with these predicate functions that used "hybrid" to imply "hybrid recurrent." Branch: HybridCache Signed-off-by: Gabe Goodhart --- include/llama.h | 2 +- src/llama-arch.cpp | 2 +- src/llama-arch.h | 2 +- src/llama-model.cpp | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/llama.h b/include/llama.h index 7accdb0142591..a2b87bcb4d750 100644 --- a/include/llama.h +++ b/include/llama.h @@ -555,7 +555,7 @@ extern "C" { LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); // Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.) - LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model); + LLAMA_API bool llama_model_is_hybrid_recurrent(const struct llama_model * model); // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 291e5da230ad2..70060debbcb77 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1759,7 +1759,7 @@ bool llm_arch_is_recurrent(const llm_arch & arch) { } } -bool llm_arch_is_hybrid(const llm_arch & arch) { +bool llm_arch_is_hybrid_recurrent(const llm_arch & arch) { // TODO: There are currently no hybrid models! 
Once there are, this will be // the place to identify them switch (arch) { diff --git a/src/llama-arch.h b/src/llama-arch.h index 1bfd9780ac962..35c917a5c365a 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -438,4 +438,4 @@ llm_arch llm_arch_from_string(const std::string & name); const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); bool llm_arch_is_recurrent(const llm_arch& arch); -bool llm_arch_is_hybrid(const llm_arch& arch); +bool llm_arch_is_hybrid_recurrent(const llm_arch& arch); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c26aef7591c13..b51142121b60d 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13210,7 +13210,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, // checks default: { - if (llm_arch_is_hybrid(arch)) { + if (llm_arch_is_hybrid_recurrent(arch)) { // make vectors of recurrent and non-recurrent layer indices std::vector recurrent_layers; std::vector unified_layers; @@ -13859,8 +13859,8 @@ bool llama_model_is_recurrent(const llama_model * model) { return llm_arch_is_recurrent(model->arch); } -bool llama_model_is_hybrid(const llama_model * model) { - return llm_arch_is_hybrid(model->arch); +bool llama_model_is_hybrid_recurrent(const llama_model * model) { + return llm_arch_is_hybrid_recurrent(model->arch); } const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { From 226955baa7b54ccf5a0b7ef528f248b30d569e13 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 20 May 2025 16:30:35 -0600 Subject: [PATCH 15/21] fix: Fix indexing into k_l for recurrent cache with filter Branch: HybridCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index a21f293b6ac10..0d86c6e869460 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -2029,8 +2029,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); - k_l.push_back(k); - v_l.push_back(v); + k_l[i] = k; + v_l[i] = v; } // allocate tensors and initialize the buffers to avoid NaNs in the padding From 1fb08dad29607e7ba885da18919af80ddda4a2ed Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 14 May 2025 09:16:06 -0600 Subject: [PATCH 16/21] fix: Use per-layer sizing everywhere in kv caches Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-kv-cache.cpp | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 0d86c6e869460..c60deac98e7aa 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -129,8 +129,8 @@ llama_kv_cache_unified::llama_kv_cache_unified( continue; } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); const char * dev_name = "CPU"; @@ -1387,7 +1387,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = 
hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); // Write key type const int32_t k_type_i = (int32_t)layer.k->type; @@ -1409,7 +1409,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Write value type const int32_t v_type_i = (int32_t)layer.v->type; @@ -1433,7 +1433,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Write value type const int32_t v_type_i = (int32_t)layer.v->type; @@ -1573,7 +1573,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); // Read type of key int32_t k_type_i_ref; @@ -1603,7 +1603,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Read type of value int32_t v_type_i_ref; @@ -1633,7 +1633,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Read type of value int32_t v_type_i_ref; @@ -2004,8 +2004,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( continue; } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i); const char * dev_name = "CPU"; @@ -2730,7 +2730,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); // Write key type const int32_t k_type_i = (int32_t)k_l[il]->type; @@ -2750,7 +2750,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std if (!v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Write value type const int32_t v_type_i = (int32_t)v_l[il]->type; @@ -2771,7 +2771,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std // When v is transposed, we also need the element size and get the element ranges from each row 
const uint32_t kv_size = size; for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Write value type const int32_t v_type_i = (int32_t)v_l[il]->type; @@ -2918,7 +2918,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); // Read type of key int32_t k_type_i_ref; @@ -2946,7 +2946,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce if (!v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Read type of value int32_t v_type_i_ref; @@ -2974,7 +2974,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce } else { // For each layer, read the values for each cell (transposed) for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Read type of value int32_t v_type_i_ref; From 04fe482ed7a6a66a0cbc80c4cb25019dc8c3c410 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 23 May 2025 12:22:08 -0600 Subject: [PATCH 17/21] fix: Remove unused kv cache methods after rebase Branch: HybridCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache.cpp | 29 ----------------------------- src/llama-kv-cache.h | 6 ------ 2 files changed, 35 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index c60deac98e7aa..a2e10611da08d 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -3196,35 +3196,6 @@ bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) { return found; } -int32_t llama_kv_cache_hybrid::get_n_tokens() const { - // The number of tokens should be the same across all child caches - int32_t n_tokens = -1; - for (const auto & cache : m_children) { - const auto cache_n_tokens = cache->get_n_tokens(); - GGML_ASSERT(n_tokens == -1 || cache_n_tokens == n_tokens); - n_tokens = cache_n_tokens; - } - return n_tokens; -} - -int32_t llama_kv_cache_hybrid::get_used_cells() const { - // TODO: Is this correct? - // Return the largest number of used cells - int32_t used_cells = -1; - for (const auto & cache : m_children) { - used_cells = std::max(used_cells, cache->get_used_cells()); - } - return used_cells; -} - -llama_pos llama_kv_cache_hybrid::get_pos_max() const { - llama_pos pos_max = -1; - for (const auto & cache : m_children) { - pos_max = std::max(pos_max, cache->get_pos_max()); - } - return pos_max; -} - bool llama_kv_cache_hybrid::get_can_shift() const { // TODO: Is this correct? // If any children can shift, return true diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index a88cd45fd1821..b71cff983981c 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -518,12 +518,6 @@ class llama_kv_cache_hybrid : public llama_kv_cache { // to the first cell of the slot. 
bool find_slot(const llama_ubatch & batch) override; - int32_t get_n_tokens() const override; - int32_t get_used_cells() const override; - - // TODO: better data structures to reduce the cost of this operation - llama_pos get_pos_max() const override; - bool get_can_shift() const override; // state write/load From db9a618b640e3ef7f5c7e3b5355f9fbc10eb4547 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 23 May 2025 12:22:36 -0600 Subject: [PATCH 18/21] fix(tests): Fix constructors in tests for signature changes after rebase Branch: HybridCache Signed-off-by: Gabe Goodhart --- tests/test-memory.cpp | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/tests/test-memory.cpp b/tests/test-memory.cpp index fb8ba8586a684..b3def0df374ea 100644 --- a/tests/test-memory.cpp +++ b/tests/test-memory.cpp @@ -249,27 +249,29 @@ static void test_llama_kv_cache_hybrid_constructor() { std::unique_ptr u_cache( new llama_kv_cache_unified( - /* model */ *model, - /* filter */ unified_filter, - /* type_k */ GGML_TYPE_F32, - /* type_v */ GGML_TYPE_F16, - /* v_trans */ false, - /* offload */ false, - /* kv_size */ 10, - /* padding */ 10, - /* n_swa */ 0, - /* swa_type */ LLAMA_SWA_TYPE_NONE + /* model */ *model, + /* filter */ unified_filter, + /* type_k */ GGML_TYPE_F32, + /* type_v */ GGML_TYPE_F16, + /* v_trans */ false, + /* offload */ false, + /* kv_size */ 10, + /* n_seq_max */ 1, + /* padding */ 10, + /* n_swa */ 0, + /* swa_type */ LLAMA_SWA_TYPE_NONE ) ); auto * u_cache_ptr = u_cache.get(); std::unique_ptr r_cache ( new llama_kv_cache_recurrent( - /* model */ *model, - /* filter */ recurrent_filter, - /* type_k */ GGML_TYPE_F32, - /* type_v */ GGML_TYPE_F16, - /* offload */ false, - /* kv_size */ 10 + /* model */ *model, + /* filter */ recurrent_filter, + /* type_k */ GGML_TYPE_F32, + /* type_v */ GGML_TYPE_F16, + /* offload */ false, + /* kv_size */ 10, + /* n_seq_max */ 1 ) ); auto * r_cache_ptr = r_cache.get(); From 8aee2e785d20b488a798e51036131d584dd0b5ca Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 27 May 2025 16:52:23 -0600 Subject: [PATCH 19/21] feat: Add split_equal to init(...) signature This will enable the hybrid cache to control the split type for all children together. 
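For reference, the shape of this change is the same in every init() implementation: the new flag selects between the two llama_sbatch split helpers while the ubatches are being prepared. A minimal sketch of that selection, assuming the llama_sbatch/llama_ubatch types from llama-batch.h (the helper name prepare_ubatches is illustrative, not part of this series):

    static std::vector<llama_ubatch> prepare_ubatches(
            llama_sbatch & sbatch, uint32_t n_ubatch, bool split_equal) {
        std::vector<llama_ubatch> ubatches;
        while (sbatch.n_tokens > 0) {
            // split_equal keeps the sequences advancing in lock step, which the
            // recurrent (state-based) caches require; split_simple just takes
            // the next n_ubatch tokens in order.
            ubatches.push_back(split_equal
                ? sbatch.split_equal(n_ubatch)
                : sbatch.split_simple(n_ubatch));
        }
        return ubatches;
    }

The recurrent cache defaults the flag to true, the attention caches default it to false, and the hybrid cache passes a single value down to every child so that their ubatch streams stay aligned.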
Branch: HybridCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache.cpp | 33 ++++++++++++++++++++++++++------- src/llama-kv-cache.h | 12 ++++++++---- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index a2e10611da08d..c8e463750ee7c 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -352,14 +352,19 @@ llama_memory_decode_state_ptr llama_kv_cache_unified::init( const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, - bool logits_all) { + bool logits_all, + bool split_equal) { GGML_UNUSED(embd_pooled); auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); std::vector ubatches; while (sbatch.n_tokens > 0) { - ubatches.push_back(sbatch.split_simple(n_ubatch)); + if (split_equal) { + ubatches.push_back(sbatch.split_equal(n_ubatch)); + } else { + ubatches.push_back(sbatch.split_simple(n_ubatch)); + } } auto heads = prepare(ubatches); @@ -1821,7 +1826,12 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { return kv_swa->seq_pos_max(seq_id); } -llama_memory_decode_state_ptr llama_kv_cache_unified_iswa::init(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { +llama_memory_decode_state_ptr llama_kv_cache_unified_iswa::init( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all, + bool split_equal) { GGML_UNUSED(embd_pooled); auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); @@ -1829,9 +1839,11 @@ llama_memory_decode_state_ptr llama_kv_cache_unified_iswa::init(const llama_batc std::vector ubatches; while (sbatch.n_tokens > 0) { - auto ubatch = sbatch.split_simple(n_ubatch); - - ubatches.push_back(ubatch); + if (split_equal) { + ubatches.push_back(sbatch.split_equal(n_ubatch)); + } else { + ubatches.push_back(sbatch.split_simple(n_ubatch)); + } } auto heads_base = kv_base->prepare(ubatches); @@ -2291,8 +2303,15 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } -llama_memory_decode_state_ptr llama_kv_cache_recurrent::init(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { +llama_memory_decode_state_ptr llama_kv_cache_recurrent::init( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all, + bool split_equal) { GGML_UNUSED(embd_pooled); + // TODO: Should this just be ignored? + assert(split_equal); auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all); diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index b71cff983981c..6a9150f67cf81 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -34,7 +34,8 @@ struct llama_kv_cache : public llama_memory_i { const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, - bool logits_all) = 0; + bool logits_all, + bool split_equal = false) = 0; // process any pending defrag/shift/etc. 
operations // optionally call once before processing a new batch @@ -112,7 +113,8 @@ class llama_kv_cache_unified : public llama_kv_cache { const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, - bool logits_all) override; + bool logits_all, + bool split_equal = false) override; bool update(llama_context & lctx) override; @@ -289,7 +291,8 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache { const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, - bool logits_all) override; + bool logits_all, + bool split_equal = false) override; bool update(llama_context & lctx) override; @@ -360,7 +363,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache { const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, - bool logits_all) override; + bool logits_all, + bool split_equal = true) override; bool update(llama_context & lctx) override; From a4cc4aaa285405c47891acae0b15d504d0767dba Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 27 May 2025 16:54:07 -0600 Subject: [PATCH 20/21] fix: Overhaul hybrid cache for refactor part3 (::init interface) Branch: HybridCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache.cpp | 184 ++++++++++++++++++++++++++++------------- src/llama-kv-cache.h | 29 +++---- src/llama-model.cpp | 2 +- 3 files changed, 137 insertions(+), 78 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index c8e463750ee7c..0763feddf66d5 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -399,7 +399,7 @@ std::vector llama_kv_cache_unified::prepare(const std::vector children) : - m_hparams(hparams), - m_children( + + +class llama_kv_cache_hybrid_decode_state_t : public llama_memory_decode_state_i { +public: + explicit llama_kv_cache_hybrid_decode_state_t( + std::vector decode_states) : + status([](const std::vector & decode_states) -> llama_memory_status { + for (const auto & decode_state : decode_states) { + if (!decode_state) { + return LLAMA_MEMORY_STATUS_FAILED_PREPARE; + } + const auto & status = decode_state->get_status(); + if (status != LLAMA_MEMORY_STATUS_SUCCESS) { + return status; + } + } + return LLAMA_MEMORY_STATUS_SUCCESS; + }(decode_states)), + decode_states(std::move(decode_states)) { + + // make sure at least one decode state + assert(!decode_states.empty()); + + // make sure all out_ids match across states + // TODO: This could be expensive, so maybe don't do it? + const auto & out_ids = decode_states[0]->out_ids(); + for (size_t i = 1; i < decode_states.size(); ++i) { + const auto & out_ids_i = decode_states[i]->out_ids(); + assert(out_ids.size() == out_ids_i.size()); + for (size_t j = 0; j < out_ids.size(); ++j) { + assert(out_ids[j] == out_ids_i[j]); + } + } + } + + ~llama_kv_cache_hybrid_decode_state_t() = default; + + llama_ubatch * next() override { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + // hit next on each child + std::vector next_ubatches; + for (const auto & decode_state : decode_states) { + next_ubatches.push_back(decode_state->next()); + } + + // make sure they all match + // TODO: unnecessary safety? 
+ llama_ubatch * res = next_ubatches[0]; + assert(res); + for (size_t i = 1; i < next_ubatches.size(); ++i) { + llama_ubatch * ubatch_i = next_ubatches[i]; + assert(ubatch_i); + assert(ubatch_i->n_tokens == res->n_tokens); + assert(ubatch_i->n_seq_tokens == res->n_seq_tokens); + assert(ubatch_i->n_seqs == res->n_seqs); + for (size_t j = 0; j < res->n_tokens; ++j) { + assert(ubatch_i->token[j] == res->token[j]); + assert(ubatch_i->pos[j] == res->pos[j]); + assert(ubatch_i->output[j] == res->output[j]); + } + for (size_t j = 0; j < res->n_seqs; ++j) { + assert(ubatch_i->n_seq_id[j] == res->n_seq_id[j]); + } + } + + // return the first ubatch since they all match + return res; + } + + std::vector & out_ids() override { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return decode_states[0]->out_ids(); + } + + llama_memory_status get_status() const override { + return status; + } + +private: + const llama_memory_status status; + std::vector decode_states; +}; + +llama_kv_cache_hybrid::llama_kv_cache_hybrid(std::vector children_) : + children( [](std::vector& caches) -> std::set> { // Sort the caches by the lowest layer ID so the order is repeatable for (auto & cache : caches) { @@ -3056,9 +3138,9 @@ llama_kv_cache_hybrid::llama_kv_cache_hybrid( unique_caches.emplace(cache.child.release()); } return unique_caches; - }(children) + }(children_) ), - m_has_recurrent( + has_recurrent( [](const std::set> & caches) -> bool { for (const auto & cache : caches) { if (dynamic_cast(cache.get())) { @@ -3066,16 +3148,16 @@ llama_kv_cache_hybrid::llama_kv_cache_hybrid( } } return false; - }(m_children) + }(children) ) { // Ensure at least one child - GGML_ASSERT(m_children.size() > 0); + GGML_ASSERT(children.size() > 0); // Ensure layers are not overlapping and are concurrent std::set seen_layers; size_t max_layer = 0; - for (const auto & cache : children) { + for (const auto & cache : children_) { for (const auto & layer_id : cache.layer_ids) { GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end()); seen_layers.insert(layer_id); @@ -3089,7 +3171,7 @@ llama_kv_cache_hybrid::llama_kv_cache_hybrid( } void llama_kv_cache_hybrid::clear() { - for (const auto & cache : m_children) { + for (const auto & cache : children) { cache->clear(); } } @@ -3102,7 +3184,7 @@ bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos } // Do the removal from each child which should never fail - for (const auto & cache : m_children) { + for (const auto & cache : children) { const bool failed = cache->seq_rm(seq_id, p0, p1); GGML_ASSERT(!failed); } @@ -3110,32 +3192,32 @@ bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos } void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - for (const auto & cache : m_children) { + for (const auto & cache : children) { cache->seq_cp(seq_id_src, seq_id_dst, p0, p1); } } void llama_kv_cache_hybrid::seq_keep(llama_seq_id seq_id) { - for (const auto & cache : m_children) { + for (const auto & cache : children) { cache->seq_keep(seq_id); } } void llama_kv_cache_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - for (const auto & cache : m_children) { + for (const auto & cache : children) { cache->seq_add(seq_id, p0, p1, delta); } } void llama_kv_cache_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - for (const auto & cache : m_children) { + for (const auto & cache : children) { cache->seq_div(seq_id, p0, p1, d); } 
} llama_pos llama_kv_cache_hybrid::seq_pos_min(llama_seq_id seq_id) const { llama_pos min_pos = -1; - for (const auto & cache : m_children) { + for (const auto & cache : children) { const auto child_min_pos = cache->seq_pos_min(seq_id); min_pos = min_pos == -1 ? child_min_pos : std::min(min_pos, child_min_pos); } @@ -3144,46 +3226,56 @@ llama_pos llama_kv_cache_hybrid::seq_pos_min(llama_seq_id seq_id) const { llama_pos llama_kv_cache_hybrid::seq_pos_max(llama_seq_id seq_id) const { llama_pos max_pos = 0; - for (const auto & cache : m_children) { + for (const auto & cache : children) { max_pos = std::max(max_pos, cache->seq_pos_max(seq_id)); } return max_pos; } -void llama_kv_cache_hybrid::restore() { - for (const auto & cache : m_children) { - cache->restore(); - } -} +llama_memory_decode_state_ptr llama_kv_cache_hybrid::init( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all, + bool split_equal) { -void llama_kv_cache_hybrid::commit() { - for (const auto & cache : m_children) { - cache->commit(); + // recurrent children require equal splits + // TODO: just ignore this if set incorrectly? + assert(!has_recurrent || split_equal); + + // init all children and capture their decode states + std::vector decode_states; + for (const auto & child : children) { + decode_states.emplace_back( + child->init(batch, n_ubatch, embd_pooled, logits_all, split_equal)); } + + // return the hybrid decode state + return std::make_unique(std::move(decode_states)); } bool llama_kv_cache_hybrid::update(llama_context & ctx) { bool updated = false; - for (const auto & cache : m_children) { + for (const auto & cache : children) { updated = cache->update(ctx) || updated; } return updated; } void llama_kv_cache_hybrid::defrag_sched(float thold) { - for (const auto & cache : m_children) { + for (const auto & cache : children) { cache->defrag_sched(thold); } } void llama_kv_cache_hybrid::set_full() { - for (const auto & cache : m_children) { + for (const auto & cache : children) { cache->set_full(); } } bool llama_kv_cache_hybrid::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const { - for (const auto & cache : m_children) { + for (const auto & cache : children) { if (!cache->can_seq_rm(seq_id, p0, p1)) { return false; } @@ -3191,34 +3283,10 @@ bool llama_kv_cache_hybrid::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_ return true; } -llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) { - // If any of the caches are recurrent, require equal split - return llama_sbatch(batch, m_hparams.n_embd, !m_has_recurrent, logits_all); -} - -llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { - if (embd_pooled) { - // Pooled embeddings cannot be split across ubatches (yet) - return sbatch.split_seq(n_ubatch); - } - if (m_has_recurrent) { - return sbatch.split_equal(n_ubatch); - } - return sbatch.split_simple(n_ubatch); -} - -bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) { - bool found = true; - for (const auto & cache : m_children) { - found = cache->find_slot(batch) && found; - } - return found; -} - bool llama_kv_cache_hybrid::get_can_shift() const { // TODO: Is this correct? 
// If any children can shift, return true - for (const auto & cache : m_children) { + for (const auto & cache : children) { if (cache->get_can_shift()) { return true; } @@ -3229,7 +3297,7 @@ bool llama_kv_cache_hybrid::get_can_shift() const { void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { // Write each cache state in order. Note that order is guaranteed at // initialization by using an ordered set sorted by lowest layer ID - for (const auto & cache : m_children) { + for (const auto & cache : children) { cache->state_write(io, seq_id); } } @@ -3237,7 +3305,7 @@ void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_ void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) { // Read each cache state in order. Note that order is guaranteed at // initialization by using an ordered set sorted by lowest layer ID - for (const auto & cache : m_children) { + for (const auto & cache : children) { cache->state_read(io, seq_id); } } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 6a9150f67cf81..d47f25402cee4 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -462,10 +462,7 @@ class llama_kv_cache_hybrid : public llama_kv_cache { : child(std::move(child_)), layer_ids(std::move(layer_ids_)) {} }; - llama_kv_cache_hybrid( - const llama_hparams & hparams, - std::vector children); - + explicit llama_kv_cache_hybrid(std::vector children); virtual ~llama_kv_cache_hybrid() = default; // getters for specific child cache type @@ -473,7 +470,7 @@ class llama_kv_cache_hybrid : public llama_kv_cache { template const child_t * get_child_cache() const { const child_t * child = nullptr; - for (const auto & child_cache : m_children) { + for (const auto & child_cache : children) { const child_t * child_cast = dynamic_cast(child_cache.get()); if (child_cast) { GGML_ASSERT(!child); @@ -502,8 +499,12 @@ class llama_kv_cache_hybrid : public llama_kv_cache { // llama_kv_cache // - void restore() override; - void commit() override; + llama_memory_decode_state_ptr init( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all, + bool split_equal = true) override; bool update(llama_context & ctx) override; @@ -513,15 +514,6 @@ class llama_kv_cache_hybrid : public llama_kv_cache { bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override; - llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; - - llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; - - // updates the cache head - // Note: On success, it's important that cache.head points - // to the first cell of the slot. 
- bool find_slot(const llama_ubatch & batch) override; - bool get_can_shift() const override; // state write/load @@ -531,7 +523,6 @@ class llama_kv_cache_hybrid : public llama_kv_cache { private: - const llama_hparams & m_hparams; - const std::set> m_children; // Ordered for state IO - const bool m_has_recurrent; + const std::set> children; // Ordered for state IO + const bool has_recurrent; }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b51142121b60d..110c4863880b3 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13264,7 +13264,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, ); // initialize the hybrid cache with both children - res = new llama_kv_cache_hybrid(hparams, std::move(children)); + res = new llama_kv_cache_hybrid(std::move(children)); } else if (llm_arch_is_recurrent(arch)) { res = new llama_kv_cache_recurrent( *this, From 58994f653086b19be3d8d988f7176106bb60cdec Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 27 May 2025 16:54:46 -0600 Subject: [PATCH 21/21] tests(wip): Comment out broken test for now and fix other constructor signatures Branch: HybridCache Signed-off-by: Gabe Goodhart --- tests/test-memory.cpp | 86 +++++++++++++++++++++---------------------- 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/tests/test-memory.cpp b/tests/test-memory.cpp index b3def0df374ea..33dd22fc7eae8 100644 --- a/tests/test-memory.cpp +++ b/tests/test-memory.cpp @@ -158,48 +158,48 @@ static void test_llama_kv_cache_unified_single_seq() { /* swa_type */ LLAMA_SWA_TYPE_NONE ); - // Create the micro batch with a single 3-token sequence - llama_batch batch1 = _make_batch({{101, 1, 102}}, {{42}}); - llama_sbatch sbatch1 = cache.sbatch_init(batch1, false); - llama_ubatch ubatch1 = cache.ubatch_next(sbatch1, 4, false); - - // Find a slot for a new sequence - GGML_ASSERT(cache.find_slot(ubatch1)); - - // Cache the k/v for a single layer in this slot - ggml_context * ctx = ggml_init({10240, NULL, false}); - ggml_tensor * k1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_k_gqa(0)); - ggml_tensor * v1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_v_gqa(0)); - ggml_tensor * k1_view = cache.cpy_k(ctx, k1, 0); - ggml_tensor * v1_view = cache.cpy_v(ctx, v1, 0); - GGML_ASSERT(is_source_tensor(k1_view, k1)); - GGML_ASSERT(is_source_tensor(v1_view, v1)); - - // Create a second batch with different tokens and find a slot for it - llama_batch batch2 = _make_batch({{1, 2, 3, 4}}, {{5}}); - llama_sbatch sbatch2 = cache.sbatch_init(batch2, false); - llama_ubatch ubatch2 = cache.ubatch_next(sbatch2, 4, false); - GGML_ASSERT(cache.find_slot(ubatch2)); - - // Add some different tensors - ggml_tensor * k2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_k_gqa(0)); - ggml_tensor * v2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_v_gqa(0)); - ggml_tensor * k2_view = cache.cpy_k(ctx, k2, 0); - ggml_tensor * v2_view = cache.cpy_v(ctx, v2, 0); - GGML_ASSERT(is_source_tensor(k2_view, k2)); - GGML_ASSERT(is_source_tensor(v2_view, v2)); - - // Make sure first batch's k/v aren't cache hit - GGML_ASSERT(!is_source_tensor(k2_view, k1)); - GGML_ASSERT(!is_source_tensor(v2_view, v1)); - - // Re-find the slot for the first batch and make sure they cache hit - GGML_ASSERT(cache.find_slot(ubatch1)); - - // Clean up - llama_batch_free(batch1); - llama_batch_free(batch2); - ggml_free(ctx); + // // Create the micro batch with a single 3-token sequence + // llama_batch batch1 = 
_make_batch({{101, 1, 102}}, {{42}}); + // llama_sbatch sbatch1 = cache.sbatch_init(batch1, false); + // llama_ubatch ubatch1 = cache.ubatch_next(sbatch1, 4, false); + + // // Find a slot for a new sequence + // GGML_ASSERT(cache.find_slot(ubatch1)); + + // // Cache the k/v for a single layer in this slot + // ggml_context * ctx = ggml_init({10240, NULL, false}); + // ggml_tensor * k1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_k_gqa(0)); + // ggml_tensor * v1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_v_gqa(0)); + // ggml_tensor * k1_view = cache.cpy_k(ctx, k1, 0); + // ggml_tensor * v1_view = cache.cpy_v(ctx, v1, 0); + // GGML_ASSERT(is_source_tensor(k1_view, k1)); + // GGML_ASSERT(is_source_tensor(v1_view, v1)); + + // // Create a second batch with different tokens and find a slot for it + // llama_batch batch2 = _make_batch({{1, 2, 3, 4}}, {{5}}); + // llama_sbatch sbatch2 = cache.sbatch_init(batch2, false); + // llama_ubatch ubatch2 = cache.ubatch_next(sbatch2, 4, false); + // GGML_ASSERT(cache.find_slot(ubatch2)); + + // // Add some different tensors + // ggml_tensor * k2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_k_gqa(0)); + // ggml_tensor * v2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_v_gqa(0)); + // ggml_tensor * k2_view = cache.cpy_k(ctx, k2, 0); + // ggml_tensor * v2_view = cache.cpy_v(ctx, v2, 0); + // GGML_ASSERT(is_source_tensor(k2_view, k2)); + // GGML_ASSERT(is_source_tensor(v2_view, v2)); + + // // Make sure first batch's k/v aren't cache hit + // GGML_ASSERT(!is_source_tensor(k2_view, k1)); + // GGML_ASSERT(!is_source_tensor(v2_view, v1)); + + // // Re-find the slot for the first batch and make sure they cache hit + // GGML_ASSERT(cache.find_slot(ubatch1)); + + // // Clean up + // llama_batch_free(batch1); + // llama_batch_free(batch2); + // ggml_free(ctx); } /*- Recurrent Cache ----------------------------------------------------------*/ @@ -280,7 +280,7 @@ static void test_llama_kv_cache_hybrid_constructor() { children.emplace_back(std::move(u_cache), std::vector{1, 3}); children.emplace_back(std::move(r_cache), std::vector{0, 2}); - llama_kv_cache_hybrid cache(model->hparams, std::move(children)); + llama_kv_cache_hybrid cache(std::move(children)); GGML_ASSERT(cache.get_child_cache() == u_cache_ptr); GGML_ASSERT(cache.get_child_cache() == r_cache_ptr);
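One possible shape for re-enabling the commented-out single-sequence test against the new decode-state interface, reusing the model and cache set up at the top of that test. The init() signature, get_status(), and next() come from the diffs in this series; the loop structure and the assumption that next() returns nullptr once the batch is exhausted are only a sketch of intent:

    // Hypothetical rework of test_llama_kv_cache_unified_single_seq driving the
    // cache through init() instead of sbatch_init/ubatch_next/find_slot.
    llama_batch batch = _make_batch({{101, 1, 102}}, {{42}});
    auto state = cache.init(batch, /* n_ubatch */ 4, /* embd_pooled */ false,
                            /* logits_all */ false);
    GGML_ASSERT(state && state->get_status() == LLAMA_MEMORY_STATUS_SUCCESS);
    while (llama_ubatch * ubatch = state->next()) {
        GGML_ASSERT(ubatch->n_tokens <= 4);
        // per-ubatch cpy_k/cpy_v checks would go here once the expectations
        // are updated for the new slot handling
    }
    llama_batch_free(batch);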