diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 5a18a4fb3939a..e352d81e4ed7c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -280,8 +280,8 @@ llama_context::llama_context( // simulate full KV cache - const auto mstate = memory->init_full(); - if (!mstate) { + const auto mctx = memory->init_full(); + if (!mctx) { throw std::runtime_error("failed to initialize KV cache"); } @@ -289,7 +289,7 @@ llama_context::llama_context( // reserve pp graph first so that buffers are only allocated once { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -300,7 +300,7 @@ llama_context::llama_context( // reserve with tg graph to get the number of splits and nodes { - auto * gf = graph_reserve(1, 1, 1, mstate.get()); + auto * gf = graph_reserve(1, 1, 1, mctx.get()); if (!gf) { throw std::runtime_error("failed to allocate compute tg buffers"); } @@ -311,7 +311,7 @@ llama_context::llama_context( // reserve again with pp graph to avoid ggml-alloc reallocations during inference { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -444,8 +444,8 @@ bool llama_context::kv_self_update(bool optimize) { optimize |= memory_force_optimize; memory_force_optimize = false; - const auto mstate = memory->init_update(this, optimize); - switch (mstate->get_status()) { + const auto mctx = memory->init_update(this, optimize); + switch (mctx->get_status()) { case LLAMA_MEMORY_STATUS_SUCCESS: { // noop @@ -463,22 +463,22 @@ bool llama_context::kv_self_update(bool optimize) { } } - if (!mstate->apply()) { + if (!mctx->apply()) { LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__); } } // if the memory module did any computation, we have to reserve a new worst-case graph { - const auto mstate = memory->init_full(); - if (!mstate) { - throw std::runtime_error("failed to initialize memory state"); + const auto mctx = memory->init_full(); + if (!mctx) { + throw std::runtime_error("failed to initialize memory context"); } const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); if (!gf) { LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); } @@ -678,9 +678,9 @@ bool llama_context::apply_adapter_cvec( return cvec.apply(model, data, len, n_embd, il_start, il_end); } -llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) { - if (mstate && !mstate->apply()) { - LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__); +llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { + if (mctx && !mctx->apply()) { + LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); ret = GGML_STATUS_FAILED; return nullptr; } @@ -692,7 +692,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, return nullptr; } - auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate); + auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx); if (!res) { LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__); ret = GGML_STATUS_FAILED; @@ -933,21 +933,21 @@ int llama_context::decode(const llama_batch & batch_inp) { // handle any pending defrags/shifts kv_self_update(false); - llama_memory_state_ptr mstate; + llama_memory_context_ptr mctx; while (true) { - mstate = memory->init_batch(*balloc, cparams.n_ubatch, output_all); - if (!mstate) { + mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); + if (!mctx) { return -2; } - switch (mstate->get_status()) { + switch (mctx->get_status()) { case LLAMA_MEMORY_STATUS_SUCCESS: { } break; case LLAMA_MEMORY_STATUS_NO_UPDATE: { - LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status()); + LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status()); return -2; } @@ -987,7 +987,7 @@ int llama_context::decode(const llama_batch & batch_inp) { int64_t n_outputs_prev = 0; do { - const auto & ubatch = mstate->get_ubatch(); + const auto & ubatch = mctx->get_ubatch(); // count the outputs in this ubatch { @@ -1009,7 +1009,7 @@ int llama_context::decode(const llama_batch & batch_inp) { ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); ggml_status status; - const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status); + const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache @@ -1126,7 +1126,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } n_outputs_prev += n_outputs; - } while (mstate->next()); + } while (mctx->next()); // set to total number of outputs in the batch, for use in llama_get_logits_ith n_outputs = n_outputs_all; @@ -1292,7 +1292,7 @@ ggml_cgraph * llama_context::graph_init() { return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false); } -ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) { +ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) { LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); if (n_tokens % n_seqs != 0) { @@ -1312,7 +1312,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs); auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx); this->n_outputs = save_n_outputs; @@ -1333,11 +1333,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u } llm_graph_result_ptr llama_context::graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - llm_graph_type gtype, - const llama_memory_state_i * mstate) { + ggml_context * ctx, + ggml_cgraph * gf, + const llama_ubatch & ubatch, + llm_graph_type gtype, + const llama_memory_context_i * mctx) { return model.build_graph( { /*.ctx =*/ ctx, @@ -1349,7 +1349,7 @@ llm_graph_result_ptr llama_context::graph_build( /*.backend_cpu =*/ backend_cpu, /*.cvec =*/ &cvec, /*.loras =*/ &loras, - /*.mstate =*/ mstate, + /*.mctx =*/ mctx, /*.cross =*/ &cross, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), @@ -2042,8 +2042,8 @@ void llama_context::opt_epoch_iter( uint32_t n_outputs_all = n_tokens_all; - auto mstate = memory->init_batch(*balloc, cparams.n_ubatch, true); - if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { + auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true); + if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__); break; } @@ -2056,17 +2056,17 @@ void llama_context::opt_epoch_iter( uint32_t pos_batch = 0; do { - const auto & ubatch = mstate->get_ubatch(); + const auto & ubatch = mctx->get_ubatch(); n_outputs = ubatch.n_tokens; - if (!mstate->apply()) { - LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__); + if (!mctx->apply()) { + LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__); break; } auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get()); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get()); struct ggml_context * ctx_compute_opt; { @@ -2101,7 +2101,7 @@ void llama_context::opt_epoch_iter( ggml_free(ctx_compute_opt); pos_batch += ubatch.n_tokens; - } while (mstate->next()); + } while (mctx->next()); } } diff --git a/src/llama-context.h b/src/llama-context.h index 7d300c14572e9..9ce05715a8c03 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -18,7 +18,7 @@ class llama_io_read_i; class llama_io_write_i; struct llama_memory_i; -struct llama_memory_state_i; +struct llama_memory_context_i; struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs @@ -93,14 +93,14 @@ struct llama_context { int32_t il_end); // process a single ubatch with a specific graph type - // if memory_state is provided, it will be applied first to the context's memory + // if memory_context is provided, it will be applied first to the context's memory // ret contains the status of the graph computation // returns nullptr only if ret != GGML_STATUS_SUCCESS llm_graph_result_ptr process_ubatch( - const llama_ubatch & ubatch, - llm_graph_type gtype, - llama_memory_state_i * mstate, - ggml_status & ret); + const llama_ubatch & ubatch, + llm_graph_type gtype, + llama_memory_context_i * mctx, + ggml_status & ret); int encode(const llama_batch & batch_inp); int decode(const llama_batch & batch_inp); @@ -197,15 +197,15 @@ struct llama_context { ggml_status graph_compute(ggml_cgraph * gf, bool batched); // reserve a graph with a dummy ubatch of the specified size - ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate); + ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx); private: llm_graph_result_ptr graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - llm_graph_type gtype, - const llama_memory_state_i * mstate); + ggml_context * ctx, + ggml_cgraph * gf, + const llama_ubatch & ubatch, + llm_graph_type gtype, + const llama_memory_context_i * mctx); llm_graph_cb graph_get_cb() const; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 7e162c5552204..48589a50ab24d 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -87,7 +87,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) { void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) { if (pos_bucket) { - kv_state->set_input_pos_bucket(pos_bucket, ubatch); + mctx->set_input_pos_bucket(pos_bucket, ubatch); } } @@ -221,7 +221,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); - const int64_t n_rs = mem_state->get_n_rs(); + const int64_t n_rs = mctx->get_n_rs(); if (s_copy) { GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer)); @@ -229,7 +229,7 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n for (uint32_t i = 0; i < n_rs; ++i) { - data[i] = mem_state->s_copy(i); + data[i] = mctx->s_copy(i); } } } @@ -282,17 +282,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { if (self_kq_mask) { - kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } } void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { if (self_kq_mask) { - kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } if (self_kq_mask_swa) { - kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } } @@ -334,10 +334,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { if (self_kq_mask) { - mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } - const int64_t n_rs = mem_state->get_state_recr()->get_n_rs(); + const int64_t n_rs = mctx->get_recr()->get_n_rs(); if (s_copy) { GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer)); @@ -345,7 +345,7 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n for (uint32_t i = 0; i < n_rs; ++i) { - data[i] = mem_state->get_state_recr()->s_copy(i); + data[i] = mctx->get_recr()->s_copy(i); } } } @@ -389,7 +389,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : backend_cpu (params.backend_cpu), cvec (params.cvec), loras (params.loras), - mstate (params.mstate), + mctx (params.mctx), cross (params.cross), cb_func (params.cb), res (std::make_unique()) { @@ -950,11 +950,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const { } ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); - auto inp = std::make_unique(hparams, kv_state); + auto inp = std::make_unique(hparams, mctx_cur); - const auto n_kv = kv_state->get_n_kv(); + const auto n_kv = mctx_cur->get_n_kv(); auto & cur = inp->pos_bucket; @@ -982,14 +982,14 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t } llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { - const auto * mem_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); - auto inp = std::make_unique(hparams, cparams, mem_state); + auto inp = std::make_unique(hparams, cparams, mctx_cur); { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers"); - const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv(); + const auto n_kv = inp->mctx->get_attn()->get_n_kv(); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask, "KQ_mask", -1); @@ -999,7 +999,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { } { - const auto n_rs = mem_state->get_state_recr()->get_n_rs(); + const auto n_rs = mctx_cur->get_recr()->get_n_rs(); inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs); ggml_set_input(inp->s_copy); @@ -1183,14 +1183,14 @@ ggml_tensor * llm_graph_context::build_attn( } llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); - auto inp = std::make_unique(hparams, cparams, kv_state); + auto inp = std::make_unique(hparams, cparams, mctx_cur); { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA"); - const auto n_kv = kv_state->get_n_kv(); + const auto n_kv = mctx_cur->get_n_kv(); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask, "KQ_mask", -1); @@ -1220,19 +1220,19 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); // store to KV cache { - ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); - ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il)); } const auto & kq_mask = inp->get_kq_mask(); ggml_tensor * q = q_cur; - ggml_tensor * k = kv_state->get_k(ctx0, il); - ggml_tensor * v = kv_state->get_v(ctx0, il); + ggml_tensor * k = mctx_cur->get_k(ctx0, il); + ggml_tensor * v = mctx_cur->get_v(ctx0, il); ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); @@ -1270,23 +1270,23 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const auto * kv_state_iswa = static_cast(mstate); + const auto * mctx_iswa = static_cast(mctx); const bool is_swa = hparams.is_swa(il); - const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base(); + const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base(); // store to KV cache { - ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); - ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il)); } const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); ggml_tensor * q = q_cur; - ggml_tensor * k = kv_state->get_k(ctx0, il); - ggml_tensor * v = kv_state->get_v(ctx0, il); + ggml_tensor * k = mctx_cur->get_k(ctx0, il); + ggml_tensor * v = mctx_cur->get_v(ctx0, il); ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); @@ -1379,19 +1379,19 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const auto * kv_state = static_cast(mstate)->get_state_attn(); + const auto * mctx_cur = static_cast(mctx)->get_attn(); // store to KV cache { - ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); - ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il)); } const auto & kq_mask = inp->get_kq_mask(); ggml_tensor * q = q_cur; - ggml_tensor * k = kv_state->get_k(ctx0, il); - ggml_tensor * v = kv_state->get_v(ctx0, il); + ggml_tensor * k = mctx_cur->get_k(ctx0, il); + ggml_tensor * v = mctx_cur->get_v(ctx0, il); ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); @@ -1412,12 +1412,12 @@ ggml_tensor * llm_graph_context::build_attn( } llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); - auto inp = std::make_unique(hparams, cparams, kv_state); + auto inp = std::make_unique(hparams, cparams, mctx_cur); { - const auto n_kv = kv_state->get_base()->get_n_kv(); + const auto n_kv = mctx_cur->get_base()->get_n_kv(); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask, "KQ_mask", -1); @@ -1429,7 +1429,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif { GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA"); - const auto n_kv = kv_state->get_swa()->get_n_kv(); + const auto n_kv = mctx_cur->get_swa()->get_n_kv(); inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1); @@ -1485,11 +1485,11 @@ ggml_tensor * llm_graph_context::build_rs( } llm_graph_input_rs * llm_graph_context::build_rs_inp() const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); - auto inp = std::make_unique(kv_state); + auto inp = std::make_unique(mctx_cur); - const auto n_rs = kv_state->get_n_rs(); + const auto n_rs = mctx_cur->get_n_rs(); inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs); ggml_set_input(inp->s_copy); @@ -1504,9 +1504,9 @@ ggml_tensor * llm_graph_context::build_rs( int32_t state_size, int32_t n_seqs, bool avoid_copies) const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); - return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies); + return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies); } ggml_tensor * llm_graph_context::build_rs( @@ -1516,9 +1516,9 @@ ggml_tensor * llm_graph_context::build_rs( int32_t state_size, int32_t n_seqs, bool avoid_copies) const { - const auto * kv_state = static_cast(mstate)->get_state_recr(); + const auto * mctx_cur = static_cast(mctx)->get_recr(); - return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies); + return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies); } ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( @@ -1526,13 +1526,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_cgraph * gf, const llama_ubatch & ubatch, int il) const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); const auto token_shift_count = hparams.token_shift_count; const int64_t n_seqs = ubatch.n_seqs; - ggml_tensor * token_shift_all = kv_state->get_r_l(il); + ggml_tensor * token_shift_all = mctx_cur->get_r_l(il); ggml_tensor * token_shift = build_rs( inp, gf, token_shift_all, @@ -1547,19 +1547,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( ggml_tensor * token_shift, const llama_ubatch & ubatch, int il) const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); const auto token_shift_count = hparams.token_shift_count; const auto n_embd = hparams.n_embd; const int64_t n_seqs = ubatch.n_seqs; - const auto kv_head = kv_state->get_head(); + const auto kv_head = mctx_cur->get_head(); return ggml_cpy( ctx0, ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0), - ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il))) + ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il))) ); } diff --git a/src/llama-graph.h b/src/llama-graph.h index 9e62fa60720d7..b433f266d1b29 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -17,12 +17,12 @@ struct ggml_tensor; struct llama_ubatch; struct llama_cparams; -struct llama_memory_state_i; +struct llama_memory_context_i; -class llama_kv_cache_unified_state; -class llama_kv_cache_unified_iswa_state; -class llama_memory_recurrent_state; -class llama_memory_hybrid_state; +class llama_kv_cache_unified_context; +class llama_kv_cache_unified_iswa_context; +class llama_memory_recurrent_context; +class llama_memory_hybrid_context; // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { @@ -136,7 +136,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i { public: llm_graph_input_pos_bucket_kv( const llama_hparams & hparams, - const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {} + const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {} virtual ~llm_graph_input_pos_bucket_kv() = default; void set_input(const llama_ubatch * ubatch) override; @@ -144,7 +144,8 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i { ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch] const llama_hparams & hparams; - const llama_kv_cache_unified_state * kv_state; + + const llama_kv_cache_unified_context * mctx; }; class llm_graph_input_out_ids : public llm_graph_input_i { @@ -191,14 +192,14 @@ class llm_graph_input_cls : public llm_graph_input_i { class llm_graph_input_rs : public llm_graph_input_i { public: - llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {} + llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {} virtual ~llm_graph_input_rs() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * s_copy; // I32 [kv_size] - const llama_memory_recurrent_state * mem_state; + const llama_memory_recurrent_context * mctx; }; class llm_graph_input_cross_embd : public llm_graph_input_i { @@ -238,10 +239,10 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i { llm_graph_input_attn_kv_unified( const llama_hparams & hparams, const llama_cparams & cparams, - const llama_kv_cache_unified_state * kv_state) : + const llama_kv_cache_unified_context * mctx) : hparams(hparams), cparams(cparams), - kv_state(kv_state) { + mctx(mctx) { } ~llm_graph_input_attn_kv_unified() = default; @@ -255,7 +256,7 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i { const llama_hparams & hparams; const llama_cparams & cparams; - const llama_kv_cache_unified_state * kv_state; + const llama_kv_cache_unified_context * mctx; }; class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { @@ -263,10 +264,10 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { llm_graph_input_attn_kv_unified_iswa( const llama_hparams & hparams, const llama_cparams & cparams, - const llama_kv_cache_unified_iswa_state * kv_state) : + const llama_kv_cache_unified_iswa_context * mctx) : hparams(hparams), cparams(cparams), - kv_state(kv_state) { + mctx(mctx) { } ~llm_graph_input_attn_kv_unified_iswa() = default; @@ -283,7 +284,7 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { const llama_hparams & hparams; const llama_cparams & cparams; - const llama_kv_cache_unified_iswa_state * kv_state; + const llama_kv_cache_unified_iswa_context * mctx; }; class llm_graph_input_attn_cross : public llm_graph_input_i { @@ -306,10 +307,10 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i { llm_graph_input_mem_hybrid( const llama_hparams & hparams, const llama_cparams & cparams, - const llama_memory_hybrid_state * mem_state) : + const llama_memory_hybrid_context * mctx) : hparams(hparams), cparams(cparams), - mem_state(mem_state) { + mctx(mctx) { } virtual ~llm_graph_input_mem_hybrid() = default; @@ -325,7 +326,7 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i { const llama_hparams & hparams; const llama_cparams & cparams; - const llama_memory_hybrid_state * mem_state; + const llama_memory_hybrid_context * mctx; }; // @@ -401,10 +402,10 @@ struct llm_graph_params { ggml_backend_sched_t sched; ggml_backend_t backend_cpu; - const llama_adapter_cvec * cvec; - const llama_adapter_loras * loras; - const llama_memory_state_i * mstate; - const llama_cross * cross; + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_context_i * mctx; + const llama_cross * cross; uint32_t n_outputs; @@ -453,10 +454,10 @@ struct llm_graph_context { ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? - const llama_adapter_cvec * cvec; - const llama_adapter_loras * loras; - const llama_memory_state_i * mstate; - const llama_cross * cross; + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_context_i * mctx; + const llama_cross * cross; const llm_graph_cb & cb_func; diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp index 0ced340dec6c5..b9169299c0760 100644 --- a/src/llama-kv-cache-unified-iswa.cpp +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -95,7 +95,7 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { return kv_swa->seq_pos_max(seq_id); } -llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { +llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { GGML_UNUSED(embd_all); // first try simple split @@ -125,7 +125,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc assert(heads_base.size() == heads_swa.size()); - return std::make_unique( + return std::make_unique( this, std::move(heads_base), std::move(heads_swa), std::move(ubatches)); } while (false); @@ -156,22 +156,22 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc assert(heads_base.size() == heads_swa.size()); - return std::make_unique( + return std::make_unique( this, std::move(heads_base), std::move(heads_swa), std::move(ubatches)); } while (false); // TODO: if we fail again, we should attempt different splitting strategies // but to do that properly, we first have to refactor the batches to be more flexible - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } -llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() { - return std::make_unique(this); +llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() { + return std::make_unique(this); } -llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) { - return std::make_unique(this, lctx, optimize); +llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique(this, lctx, optimize); } bool llama_kv_cache_unified_iswa::get_can_shift() const { @@ -197,46 +197,46 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const { } // -// llama_kv_cache_unified_iswa_state +// llama_kv_cache_unified_iswa_context // -llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {} +llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {} -llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( +llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa * kv) : - state_base(kv->get_base()->init_full()), - state_swa (kv->get_swa ()->init_full()), - status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) { + ctx_base(kv->get_base()->init_full()), + ctx_swa (kv->get_swa ()->init_full()), + status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { } -llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( +llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa * kv, llama_context * lctx, bool optimize) : - state_base(kv->get_base()->init_update(lctx, optimize)), - state_swa (kv->get_swa ()->init_update(lctx, optimize)), - status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) { + ctx_base(kv->get_base()->init_update(lctx, optimize)), + ctx_swa (kv->get_swa ()->init_update(lctx, optimize)), + status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { } -llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( +llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa * kv, std::vector heads_base, std::vector heads_swa, std::vector ubatches) : ubatches(std::move(ubatches)), // note: here we copy the ubatches. not sure if this is ideal - state_base(new llama_kv_cache_unified_state(kv->get_base(), std::move(heads_base), this->ubatches)), - state_swa (new llama_kv_cache_unified_state(kv->get_swa (), std::move(heads_swa), this->ubatches)), - status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) { + ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)), + ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)), + status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { } -llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default; +llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default; -bool llama_kv_cache_unified_iswa_state::next() { +bool llama_kv_cache_unified_iswa_context::next() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - state_base->next(); - state_swa ->next(); + ctx_base->next(); + ctx_swa ->next(); if (++i_next >= ubatches.size()) { return false; @@ -245,35 +245,35 @@ bool llama_kv_cache_unified_iswa_state::next() { return true; } -bool llama_kv_cache_unified_iswa_state::apply() { +bool llama_kv_cache_unified_iswa_context::apply() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); bool res = true; - res = res & state_base->apply(); - res = res & state_swa ->apply(); + res = res & ctx_base->apply(); + res = res & ctx_swa ->apply(); return res; } -llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const { +llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const { return status; } -const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const { +const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); return ubatches[i_next]; } -const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const { +const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - return static_cast(state_base.get()); + return static_cast(ctx_base.get()); } -const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const { +const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - return static_cast(state_swa.get()); + return static_cast(ctx_swa.get()); } diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h index 071041585db38..46c1ed614f2f0 100644 --- a/src/llama-kv-cache-unified-iswa.h +++ b/src/llama-kv-cache-unified-iswa.h @@ -31,14 +31,14 @@ class llama_kv_cache_unified_iswa : public llama_memory_i { // llama_memory_i // - llama_memory_state_ptr init_batch( + llama_memory_context_ptr init_batch( llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) override; - llama_memory_state_ptr init_full() override; + llama_memory_context_ptr init_full() override; - llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; bool get_can_shift() const override; @@ -72,32 +72,32 @@ class llama_kv_cache_unified_iswa : public llama_memory_i { std::unique_ptr kv_swa; }; -class llama_kv_cache_unified_iswa_state : public llama_memory_state_i { +class llama_kv_cache_unified_iswa_context : public llama_memory_context_i { public: // used for errors - llama_kv_cache_unified_iswa_state(llama_memory_status status); + llama_kv_cache_unified_iswa_context(llama_memory_status status); - // used to create a full-cache state - llama_kv_cache_unified_iswa_state( + // used to create a full-cache context + llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa * kv); - // used to create an update state - llama_kv_cache_unified_iswa_state( + // used to create an update context + llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa * kv, llama_context * lctx, bool optimize); - // used to create a state from a batch - llama_kv_cache_unified_iswa_state( + // used to create a batch processing context from a batch + llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa * kv, std::vector heads_base, std::vector heads_swa, std::vector ubatches); - virtual ~llama_kv_cache_unified_iswa_state(); + virtual ~llama_kv_cache_unified_iswa_context(); // - // llama_memory_state_i + // llama_memory_context_i // bool next() override; @@ -107,11 +107,11 @@ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i { const llama_ubatch & get_ubatch() const override; // - // llama_kv_cache_unified_iswa_state specific API + // llama_kv_cache_unified_iswa_context specific API // - const llama_kv_cache_unified_state * get_base() const; - const llama_kv_cache_unified_state * get_swa() const; + const llama_kv_cache_unified_context * get_base() const; + const llama_kv_cache_unified_context * get_swa() const; private: //llama_kv_cache_unified_iswa * kv; @@ -121,8 +121,8 @@ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i { std::vector ubatches; - const llama_memory_state_ptr state_base; - const llama_memory_state_ptr state_swa; + const llama_memory_context_ptr ctx_base; + const llama_memory_context_ptr ctx_swa; const llama_memory_status status; }; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 6897b797153db..b506d32ed4d06 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -307,7 +307,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { return cells.seq_pos_max(seq_id); } -llama_memory_state_ptr llama_kv_cache_unified::init_batch( +llama_memory_context_ptr llama_kv_cache_unified::init_batch( llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { @@ -332,18 +332,18 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch( break; } - return std::make_unique( + return std::make_unique( this, std::move(heads), std::move(ubatches)); } while (false); - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } -llama_memory_state_ptr llama_kv_cache_unified::init_full() { - return std::make_unique(this); +llama_memory_context_ptr llama_kv_cache_unified::init_full() { + return std::make_unique(this); } -llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) { +llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) { bool do_shift = get_has_shift(); defrag_info dinfo; @@ -373,7 +373,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, } } - return std::make_unique(this, lctx, do_shift, std::move(dinfo)); + return std::make_unique(this, lctx, do_shift, std::move(dinfo)); } llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector & ubatches) { @@ -1710,18 +1710,18 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell } // -// llama_kv_cache_unified_state +// llama_kv_cache_unified_context // -llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {} +llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {} -llama_kv_cache_unified_state::llama_kv_cache_unified_state( +llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) { n_kv = kv->get_size(); head = 0; } -llama_kv_cache_unified_state::llama_kv_cache_unified_state( +llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv, llama_context * lctx, bool do_shift, @@ -1731,15 +1731,15 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state( } } -llama_kv_cache_unified_state::llama_kv_cache_unified_state( +llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv, llama_kv_cache_unified::ubatch_heads heads, std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) { } -llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default; +llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default; -bool llama_kv_cache_unified_state::next() { +bool llama_kv_cache_unified_context::next() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); if (++i_next >= ubatches.size()) { @@ -1749,7 +1749,7 @@ bool llama_kv_cache_unified_state::next() { return true; } -bool llama_kv_cache_unified_state::apply() { +bool llama_kv_cache_unified_context::apply() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); // no ubatches -> this is a KV cache update @@ -1767,45 +1767,45 @@ bool llama_kv_cache_unified_state::apply() { return true; } -llama_memory_status llama_kv_cache_unified_state::get_status() const { +llama_memory_status llama_kv_cache_unified_context::get_status() const { return status; } -const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const { +const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); return ubatches[i_next]; } -uint32_t llama_kv_cache_unified_state::get_n_kv() const { +uint32_t llama_kv_cache_unified_context::get_n_kv() const { return n_kv; } -ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const { +ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const { return kv->get_k(ctx, il, n_kv); } -ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const { +ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const { return kv->get_v(ctx, il, n_kv); } -ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { +ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { return kv->cpy_k(ctx, k_cur, il, head); } -ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { +ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { return kv->cpy_v(ctx, v_cur, il, head); } -void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const { +void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const { kv->set_input_k_shift(dst); } -void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { +void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { kv->set_input_kq_mask(dst, ubatch, causal_attn); } -void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { +void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { kv->set_input_pos_bucket(dst, ubatch); } diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index 1560640045c82..4c53f1273ab88 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -56,14 +56,14 @@ class llama_kv_cache_unified : public llama_memory_i { // llama_memory_i // - llama_memory_state_ptr init_batch( + llama_memory_context_ptr init_batch( llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) override; - llama_memory_state_ptr init_full() override; + llama_memory_context_ptr init_full() override; - llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; bool get_can_shift() const override; @@ -208,36 +208,36 @@ class llama_kv_cache_unified : public llama_memory_i { bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; -class llama_kv_cache_unified_state : public llama_memory_state_i { +class llama_kv_cache_unified_context : public llama_memory_context_i { public: // some shorthands using ubatch_heads = llama_kv_cache_unified::ubatch_heads; using defrag_info = llama_kv_cache_unified::defrag_info; // used for errors - llama_kv_cache_unified_state(llama_memory_status status); + llama_kv_cache_unified_context(llama_memory_status status); - // used to create a full-cache state - llama_kv_cache_unified_state( + // used to create a full-cache context + llama_kv_cache_unified_context( llama_kv_cache_unified * kv); - // used to create an update state - llama_kv_cache_unified_state( + // used to create an update context + llama_kv_cache_unified_context( llama_kv_cache_unified * kv, llama_context * lctx, bool do_shift, defrag_info dinfo); - // used to create a decode state from a batch - llama_kv_cache_unified_state( + // used to create a batch procesing context from a batch + llama_kv_cache_unified_context( llama_kv_cache_unified * kv, ubatch_heads heads, std::vector ubatches); - virtual ~llama_kv_cache_unified_state(); + virtual ~llama_kv_cache_unified_context(); // - // llama_memory_state_i + // llama_memory_context_i // bool next() override; @@ -247,7 +247,7 @@ class llama_kv_cache_unified_state : public llama_memory_state_i { const llama_ubatch & get_ubatch() const override; // - // llama_kv_cache_unified_state specific API + // llama_kv_cache_unified_context specific API // uint32_t get_n_kv() const; @@ -272,7 +272,7 @@ class llama_kv_cache_unified_state : public llama_memory_state_i { llama_context * lctx; // - // update state + // update context // bool do_shift = false; @@ -280,7 +280,7 @@ class llama_kv_cache_unified_state : public llama_memory_state_i { defrag_info dinfo; // - // batch processing state + // batch processing context // // the index of the next ubatch to process diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 1b16686819eff..15cde98d138a8 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -56,7 +56,7 @@ llama_memory_hybrid::llama_memory_hybrid( n_seq_max )) {} -llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { +llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { do { balloc.split_reset(); @@ -82,31 +82,31 @@ llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ball // prepare the recurrent batches first if (!mem_recr->prepare(ubatches)) { - // TODO: will the recurrent cache be in an undefined state at this point? + // TODO: will the recurrent cache be in an undefined context at this point? LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__); - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } // prepare the attention cache auto heads_attn = mem_attn->prepare(ubatches); if (heads_attn.empty()) { LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__); - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } - return std::make_unique( + return std::make_unique( this, std::move(heads_attn), std::move(ubatches)); } while(false); - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } -llama_memory_state_ptr llama_memory_hybrid::init_full() { - return std::make_unique(this); +llama_memory_context_ptr llama_memory_hybrid::init_full() { + return std::make_unique(this); } -llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) { - return std::make_unique(this, lctx, optimize); +llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) { + return std::make_unique(this, lctx, optimize); } bool llama_memory_hybrid::get_can_shift() const { @@ -176,39 +176,39 @@ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const { return mem_recr.get(); } -llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {} +llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {} -llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) : - state_attn(mem->get_mem_attn()->init_full()), - state_recr(mem->get_mem_recr()->init_full()), - status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) { +llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) : + ctx_attn(mem->get_mem_attn()->init_full()), + ctx_recr(mem->get_mem_recr()->init_full()), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { } -llama_memory_hybrid_state::llama_memory_hybrid_state( +llama_memory_hybrid_context::llama_memory_hybrid_context( llama_memory_hybrid * mem, llama_context * lctx, bool optimize) : - state_attn(mem->get_mem_attn()->init_update(lctx, optimize)), - state_recr(mem->get_mem_recr()->init_update(lctx, optimize)), - status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) { + ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)), + ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { } -llama_memory_hybrid_state::llama_memory_hybrid_state( +llama_memory_hybrid_context::llama_memory_hybrid_context( llama_memory_hybrid * mem, std::vector heads_attn, std::vector ubatches) : ubatches(std::move(ubatches)), // note: here we copy the ubatches. not sure if this is ideal - state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)), - state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), this->ubatches)), - status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) { + ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)), + ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { } -bool llama_memory_hybrid_state::next() { +bool llama_memory_hybrid_context::next() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - state_attn->next(); - state_recr->next(); + ctx_attn->next(); + ctx_recr->next(); if (++i_next >= ubatches.size()) { return false; @@ -217,30 +217,30 @@ bool llama_memory_hybrid_state::next() { return true; } -bool llama_memory_hybrid_state::apply() { +bool llama_memory_hybrid_context::apply() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); bool res = true; - res = res & state_attn->apply(); - res = res & state_recr->apply(); + res = res & ctx_attn->apply(); + res = res & ctx_recr->apply(); return res; } -llama_memory_status llama_memory_hybrid_state::get_status() const { +llama_memory_status llama_memory_hybrid_context::get_status() const { return status; } -const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const { +const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); return ubatches[i_next]; } -const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const { - return static_cast(state_attn.get()); +const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const { + return static_cast(ctx_attn.get()); } -const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const { - return static_cast(state_recr.get()); +const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const { + return static_cast(ctx_recr.get()); } diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h index 4d27ab896aa05..f0c2420e9a2df 100644 --- a/src/llama-memory-hybrid.h +++ b/src/llama-memory-hybrid.h @@ -49,14 +49,14 @@ class llama_memory_hybrid : public llama_memory_i { // llama_memory_i // - llama_memory_state_ptr init_batch( + llama_memory_context_ptr init_batch( llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) override; - llama_memory_state_ptr init_full() override; + llama_memory_context_ptr init_full() override; - llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; bool get_can_shift() const override; @@ -90,27 +90,27 @@ class llama_memory_hybrid : public llama_memory_i { const std::unique_ptr mem_recr; }; -class llama_memory_hybrid_state : public llama_memory_state_i { +class llama_memory_hybrid_context : public llama_memory_context_i { public: // init failure - explicit llama_memory_hybrid_state(llama_memory_status status); + explicit llama_memory_hybrid_context(llama_memory_status status); // init full - explicit llama_memory_hybrid_state(llama_memory_hybrid * mem); + explicit llama_memory_hybrid_context(llama_memory_hybrid * mem); // init update - explicit llama_memory_hybrid_state( + explicit llama_memory_hybrid_context( llama_memory_hybrid * mem, llama_context * lctx, bool optimize); // init success - llama_memory_hybrid_state( + llama_memory_hybrid_context( llama_memory_hybrid * mem, std::vector heads_attn, std::vector ubatches); - ~llama_memory_hybrid_state() = default; + ~llama_memory_hybrid_context() = default; bool next() override; bool apply() override; @@ -119,11 +119,11 @@ class llama_memory_hybrid_state : public llama_memory_state_i { const llama_ubatch & get_ubatch() const override; // - // llama_memory_hybrid_state + // llama_memory_hybrid_context // - const llama_kv_cache_unified_state * get_state_attn() const; - const llama_memory_recurrent_state * get_state_recr() const; + const llama_kv_cache_unified_context * get_attn() const; + const llama_memory_recurrent_context * get_recr() const; private: // the index of the next ubatch to process @@ -131,8 +131,8 @@ class llama_memory_hybrid_state : public llama_memory_state_i { std::vector ubatches; - const llama_memory_state_ptr state_attn; - const llama_memory_state_ptr state_recr; + const llama_memory_context_ptr ctx_attn; + const llama_memory_context_ptr ctx_recr; const llama_memory_status status; }; diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index b064da0084c52..1b1e95d567a6c 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -362,7 +362,7 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } -llama_memory_state_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { +llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { std::vector ubatches; while (true) { @@ -383,21 +383,21 @@ llama_memory_state_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & b } if (!prepare(ubatches)) { - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } - return std::make_unique(this, std::move(ubatches)); + return std::make_unique(this, std::move(ubatches)); } -llama_memory_state_ptr llama_memory_recurrent::init_full() { - return std::make_unique(this); +llama_memory_context_ptr llama_memory_recurrent::init_full() { + return std::make_unique(this); } -llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) { +llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) { GGML_UNUSED(lctx); GGML_UNUSED(optimize); - return std::make_unique(LLAMA_MEMORY_STATUS_NO_UPDATE); + return std::make_unique(LLAMA_MEMORY_STATUS_NO_UPDATE); } bool llama_memory_recurrent::prepare(const std::vector & ubatches) { @@ -1040,22 +1040,22 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell } // -// llama_memory_recurrent_state +// llama_memory_recurrent_context // -llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {} +llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {} -llama_memory_recurrent_state::llama_memory_recurrent_state( +llama_memory_recurrent_context::llama_memory_recurrent_context( llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) { } -llama_memory_recurrent_state::llama_memory_recurrent_state( +llama_memory_recurrent_context::llama_memory_recurrent_context( llama_memory_recurrent * mem, std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {} -llama_memory_recurrent_state::~llama_memory_recurrent_state() = default; +llama_memory_recurrent_context::~llama_memory_recurrent_context() = default; -bool llama_memory_recurrent_state::next() { +bool llama_memory_recurrent_context::next() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); if (++i_next >= ubatches.size()) { @@ -1065,7 +1065,7 @@ bool llama_memory_recurrent_state::next() { return true; } -bool llama_memory_recurrent_state::apply() { +bool llama_memory_recurrent_context::apply() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); mem->find_slot(ubatches[i_next]); @@ -1073,40 +1073,40 @@ bool llama_memory_recurrent_state::apply() { return true; } -llama_memory_status llama_memory_recurrent_state::get_status() const { +llama_memory_status llama_memory_recurrent_context::get_status() const { return status; } -const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const { +const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); return ubatches[i_next]; } -uint32_t llama_memory_recurrent_state::get_n_rs() const { +uint32_t llama_memory_recurrent_context::get_n_rs() const { return is_full ? mem->size : mem->n; } -uint32_t llama_memory_recurrent_state::get_head() const { +uint32_t llama_memory_recurrent_context::get_head() const { return is_full ? 0 : mem->head; } -int32_t llama_memory_recurrent_state::get_rs_z() const { +int32_t llama_memory_recurrent_context::get_rs_z() const { return is_full ? 0 : mem->rs_z; } -uint32_t llama_memory_recurrent_state::get_size() const { +uint32_t llama_memory_recurrent_context::get_size() const { return mem->size; } -ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const { +ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const { return mem->r_l[il]; } -ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const { +ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const { return mem->s_l[il]; } -int32_t llama_memory_recurrent_state::s_copy(int i) const { +int32_t llama_memory_recurrent_context::s_copy(int i) const { return mem->cells[i + mem->head].src0; } diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index be58dae7cfe33..4d094f9a05788 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -11,8 +11,8 @@ // llama_memory_recurrent // -// TODO: extract the cache state used for graph computation into llama_memory_recurrent_state_i -// see the implementation of llama_kv_cache_unified_state_i for an example how to do it +// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i +// see the implementation of llama_kv_cache_unified_context_i for an example how to do it class llama_memory_recurrent : public llama_memory_i { public: @@ -34,14 +34,14 @@ class llama_memory_recurrent : public llama_memory_i { // llama_memory_i // - llama_memory_state_ptr init_batch( + llama_memory_context_ptr init_batch( llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) override; - llama_memory_state_ptr init_full() override; + llama_memory_context_ptr init_full() override; - llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; void clear(bool data) override; @@ -125,24 +125,24 @@ class llama_memory_recurrent : public llama_memory_i { bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; -class llama_memory_recurrent_state : public llama_memory_state_i { +class llama_memory_recurrent_context : public llama_memory_context_i { public: // used for errors - llama_memory_recurrent_state(llama_memory_status status); + llama_memory_recurrent_context(llama_memory_status status); - // used to create a full-cache state - llama_memory_recurrent_state( + // used to create a full-cache or update context + llama_memory_recurrent_context( llama_memory_recurrent * mem); - // used to create a state from a batch - llama_memory_recurrent_state( + // used to create a batch processing context from a batch + llama_memory_recurrent_context( llama_memory_recurrent * mem, std::vector ubatches); - virtual ~llama_memory_recurrent_state(); + virtual ~llama_memory_recurrent_context(); // - // llama_memory_state_i + // llama_memory_context_i // bool next() override; @@ -152,7 +152,7 @@ class llama_memory_recurrent_state : public llama_memory_state_i { const llama_ubatch & get_ubatch() const override; // - // llama_memory_recurrent_state specific API + // llama_memory_recurrent_context specific API // uint32_t get_n_rs() const; diff --git a/src/llama-memory.h b/src/llama-memory.h index d2ef0c2a3b4aa..16b7e5ee2484a 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -3,7 +3,6 @@ #include "llama.h" #include -#include struct llama_ubatch; @@ -28,23 +27,21 @@ enum llama_memory_status { LLAMA_MEMORY_STATUS_FAILED_COMPUTE, }; -// helper function for combining the status of two memory states +// helper function for combining the status of two memory contexts // useful for implementing hybrid memory types (e.g. iSWA) llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1); -// the interface for managing the memory state during batch processing +// the interface for managing the memory context during batch processing // this interface is implemented per memory type. see: -// - llama_kv_cache_unified_state -// - llama_kv_cache_unified_iswa_state +// - llama_kv_cache_unified_context +// - llama_kv_cache_unified_iswa_context // ... // -// the only method that can mutate the memory and the memory state is llama_memory_i::apply() -// -// TODO: rename to llama_memory_context_i ? -struct llama_memory_state_i { - virtual ~llama_memory_state_i() = default; +// the only method that should mutate the memory and the memory context is llama_memory_i::apply() +struct llama_memory_context_i { + virtual ~llama_memory_context_i() = default; - // consume the current ubatch from the state and proceed to the next one + // consume the current ubatch from the context and proceed to the next one // return false if we are done virtual bool next() = 0; @@ -55,11 +52,11 @@ struct llama_memory_state_i { // get the current ubatch virtual const llama_ubatch & get_ubatch() const = 0; - // get the status of the memory state - used for error handling and checking if any updates would be applied + // get the status of the memory context - used for error handling and checking if any updates would be applied virtual llama_memory_status get_status() const = 0; }; -using llama_memory_state_ptr = std::unique_ptr; +using llama_memory_context_ptr = std::unique_ptr; // general concept of LLM memory // the KV cache is a type of LLM memory, but there can be other types @@ -67,19 +64,19 @@ struct llama_memory_i { virtual ~llama_memory_i() = default; // split the input batch into a set of ubatches and verify that they can fit into the cache - // return a state object containing the ubatches and KV cache state required to process them - // check the llama_memory_state_i::get_status() for the result - virtual llama_memory_state_ptr init_batch( + // return a context object containing the ubatches and memory state required to process them + // check the llama_memory_context_i::get_status() for the result + virtual llama_memory_context_ptr init_batch( llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) = 0; // simulate full cache, used for allocating worst-case compute buffers - virtual llama_memory_state_ptr init_full() = 0; + virtual llama_memory_context_ptr init_full() = 0; // prepare for any pending memory updates, such as shifts, defrags, etc. // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update - virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0; + virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0; // getters virtual bool get_can_shift() const = 0; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e2c82017f6890..9b19da984081e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9171,9 +9171,9 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * cur, const llama_ubatch & ubatch, int il) const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); - const auto kv_head = kv_state->get_head(); + const auto kv_head = mctx_cur->get_head(); const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_inner = hparams.ssm_d_inner; @@ -9191,8 +9191,8 @@ struct llm_build_mamba : public llm_graph_context { GGML_ASSERT(ubatch.equal_seqs); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - ggml_tensor * conv_states_all = kv_state->get_r_l(il); - ggml_tensor * ssm_states_all = kv_state->get_s_l(il); + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); // (ab)using the KV cache to store the states ggml_tensor * conv = build_rs( @@ -11916,7 +11916,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { ggml_tensor * x_prev, const llama_ubatch & ubatch, int il) const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -11926,7 +11926,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { const auto n_head = n_embd / head_size; const auto n_head_kv = hparams.n_head_kv(il); - const auto kv_head = kv_state->get_head(); + const auto kv_head = mctx_cur->get_head(); const auto & layer = model.layers[il]; @@ -12038,7 +12038,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { } ggml_tensor * wkv_state = build_rs( - inp, gf, kv_state->get_s_l(il), + inp, gf, mctx_cur->get_s_l(il), hparams.n_embd_s(), n_seqs); ggml_tensor * wkv_output; @@ -12057,9 +12057,9 @@ struct llm_build_rwkv6_base : public llm_graph_context { wkv_state, ggml_view_1d( ctx0, - kv_state->get_s_l(il), + mctx_cur->get_s_l(il), hparams.n_embd_s() * n_seqs, - hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il)) + hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il)) ) ) ); @@ -12313,7 +12313,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { ggml_tensor *& first_layer_value, const llama_ubatch & ubatch, int il) const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -12322,7 +12322,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { const auto head_count = n_embd / head_size; const auto n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = kv_state->get_head(); + const auto kv_head = mctx_cur->get_head(); const auto & layer = model.layers[il]; @@ -12393,7 +12393,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens); ggml_tensor * wkv_state = build_rs( - inp, gf, kv_state->get_s_l(il), + inp, gf, mctx_cur->get_s_l(il), hparams.n_embd_s(), n_seqs); ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state); @@ -12407,9 +12407,9 @@ struct llm_build_rwkv7_base : public llm_graph_context { wkv_state, ggml_view_1d( ctx0, - kv_state->get_s_l(il), + mctx_cur->get_s_l(il), hparams.n_embd_s() * n_seqs, - hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il)) + hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il)) ) ) );