memory : rename interface to llama_memory_context_i #14296

Merged
3 commits merged on Jun 21, 2025
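This PR is a mechanical rename across the memory interface: `llama_memory_state_i` becomes `llama_memory_context_i`, the `llama_memory_state_ptr` alias becomes `llama_memory_context_ptr`, parameters and locals named `mstate` become `mctx`, and the related log messages now say "memory context". The hunks below contain no functional changes. As orientation, here is a condensed, illustrative sketch of the call pattern the renamed interface is driven through, assembled from the `decode()` and `opt_epoch_iter()` call sites in the diff; it is not standalone, compilable code, and `memory`, `balloc`, `cparams`, and `output_all` are members or locals of `llama_context` taken from the surrounding source.

```cpp
// Condensed from llama_context::decode() / opt_epoch_iter(); illustrative only.
llama_memory_context_ptr mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
    return -2; // the batch could not be prepared
}

do {
    const auto & ubatch = mctx->get_ubatch(); // micro-batch prepared by the memory module

    if (!mctx->apply()) {                     // apply the pending memory (KV cache) changes
        LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
        break;
    }

    // ... build and compute the graph for this ubatch ...
} while (mctx->next());                       // advance to the next micro-batch
```

The same `mstate` to `mctx` rename is applied to the worst-case reservation path (`init_full()`) and the update path (`init_update()`), so every call site that previously held a "memory state" now holds a "memory context".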
78 changes: 39 additions & 39 deletions src/llama-context.cpp
@@ -280,16 +280,16 @@ llama_context::llama_context(

// simulate full KV cache

const auto mstate = memory->init_full();
if (!mstate) {
const auto mctx = memory->init_full();
if (!mctx) {
throw std::runtime_error("failed to initialize KV cache");
}

cross.v_embd.clear();

// reserve pp graph first so that buffers are only allocated once
{
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
}
@@ -300,7 +300,7 @@ llama_context::llama_context(

// reserve with tg graph to get the number of splits and nodes
{
auto * gf = graph_reserve(1, 1, 1, mstate.get());
auto * gf = graph_reserve(1, 1, 1, mctx.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute tg buffers");
}
@@ -311,7 +311,7 @@ llama_context::llama_context(

// reserve again with pp graph to avoid ggml-alloc reallocations during inference
{
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
}
@@ -444,8 +444,8 @@ bool llama_context::kv_self_update(bool optimize) {
optimize |= memory_force_optimize;
memory_force_optimize = false;

const auto mstate = memory->init_update(this, optimize);
switch (mstate->get_status()) {
const auto mctx = memory->init_update(this, optimize);
switch (mctx->get_status()) {
case LLAMA_MEMORY_STATUS_SUCCESS:
{
// noop
@@ -463,22 +463,22 @@ bool llama_context::kv_self_update(bool optimize) {
}
}

if (!mstate->apply()) {
if (!mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
}
}

// if the memory module did any computation, we have to reserve a new worst-case graph
{
const auto mstate = memory->init_full();
if (!mstate) {
throw std::runtime_error("failed to initialize memory state");
const auto mctx = memory->init_full();
if (!mctx) {
throw std::runtime_error("failed to initialize memory context");
}

const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) {
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
}
@@ -678,9 +678,9 @@ bool llama_context::apply_adapter_cvec(
return cvec.apply(model, data, len, n_embd, il_start, il_end);
}

llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
if (mstate && !mstate->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
if (mctx && !mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
ret = GGML_STATUS_FAILED;
return nullptr;
}
@@ -692,7 +692,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
return nullptr;
}

auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
if (!res) {
LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
ret = GGML_STATUS_FAILED;
@@ -933,21 +933,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
// handle any pending defrags/shifts
kv_self_update(false);

llama_memory_state_ptr mstate;
llama_memory_context_ptr mctx;

while (true) {
mstate = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
if (!mstate) {
mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
if (!mctx) {
return -2;
}

switch (mstate->get_status()) {
switch (mctx->get_status()) {
case LLAMA_MEMORY_STATUS_SUCCESS:
{
} break;
case LLAMA_MEMORY_STATUS_NO_UPDATE:
{
LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());

return -2;
}
@@ -987,7 +987,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
int64_t n_outputs_prev = 0;

do {
const auto & ubatch = mstate->get_ubatch();
const auto & ubatch = mctx->get_ubatch();

// count the outputs in this ubatch
{
@@ -1009,7 +1009,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

ggml_status status;
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);

if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1126,7 +1126,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
}

n_outputs_prev += n_outputs;
} while (mstate->next());
} while (mctx->next());

// set to total number of outputs in the batch, for use in llama_get_logits_ith
n_outputs = n_outputs_all;
@@ -1292,7 +1292,7 @@ ggml_cgraph * llama_context::graph_init() {
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
}

ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);

if (n_tokens % n_seqs != 0) {
@@ -1312,7 +1312,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);

auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);

this->n_outputs = save_n_outputs;

@@ -1333,11 +1333,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
}

llm_graph_result_ptr llama_context::graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype,
const llama_memory_state_i * mstate) {
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype,
const llama_memory_context_i * mctx) {
return model.build_graph(
{
/*.ctx =*/ ctx,
Expand All @@ -1349,7 +1349,7 @@ llm_graph_result_ptr llama_context::graph_build(
/*.backend_cpu =*/ backend_cpu,
/*.cvec =*/ &cvec,
/*.loras =*/ &loras,
/*.mstate =*/ mstate,
/*.mctx =*/ mctx,
/*.cross =*/ &cross,
/*.n_outputs =*/ n_outputs,
/*.cb =*/ graph_get_cb(),
@@ -2042,8 +2042,8 @@ void llama_context::opt_epoch_iter(

uint32_t n_outputs_all = n_tokens_all;

auto mstate = memory->init_batch(*balloc, cparams.n_ubatch, true);
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
break;
}
@@ -2056,17 +2056,17 @@

uint32_t pos_batch = 0;
do {
const auto & ubatch = mstate->get_ubatch();
const auto & ubatch = mctx->get_ubatch();

n_outputs = ubatch.n_tokens;

if (!mstate->apply()) {
LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
if (!mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
break;
}

auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());

struct ggml_context * ctx_compute_opt;
{
Expand Down Expand Up @@ -2101,7 +2101,7 @@ void llama_context::opt_epoch_iter(
ggml_free(ctx_compute_opt);

pos_batch += ubatch.n_tokens;
} while (mstate->next());
} while (mctx->next());
}
}
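The same rename flows through the graph-reservation path above: `memory->init_full()` produces the memory context that `graph_reserve()` receives when the worst-case graphs are set up in the constructor and again after a memory update in `kv_self_update()`. A condensed fragment of that pattern, taken from the hunks above (illustrative, not standalone code):

```cpp
// Worst-case graph reservation, condensed from the constructor / kv_self_update() hunks.
const auto mctx = memory->init_full();          // simulate a full KV cache
if (!mctx) {
    throw std::runtime_error("failed to initialize memory context");
}

const uint32_t n_seqs   = cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

// reserve the prompt-processing graph first so that buffers are only allocated once
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) {
    throw std::runtime_error("failed to allocate compute pp buffers");
}
```

`graph_build()` likewise takes a `const llama_memory_context_i *` now, so the renamed pointer is threaded from the public entry points all the way into `model.build_graph()` via the `.mctx` graph parameter.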

24 changes: 12 additions & 12 deletions src/llama-context.h
@@ -18,7 +18,7 @@ class llama_io_read_i;
class llama_io_write_i;

struct llama_memory_i;
struct llama_memory_state_i;
struct llama_memory_context_i;

struct llama_context {
// init scheduler and compute buffers, reserve worst-case graphs
@@ -93,14 +93,14 @@ struct llama_context {
int32_t il_end);

// process a single ubatch with a specific graph type
// if memory_state is provided, it will be applied first to the context's memory
// if memory_context is provided, it will be applied first to the context's memory
// ret contains the status of the graph computation
// returns nullptr only if ret != GGML_STATUS_SUCCESS
llm_graph_result_ptr process_ubatch(
const llama_ubatch & ubatch,
llm_graph_type gtype,
llama_memory_state_i * mstate,
ggml_status & ret);
const llama_ubatch & ubatch,
llm_graph_type gtype,
llama_memory_context_i * mctx,
ggml_status & ret);

int encode(const llama_batch & batch_inp);
int decode(const llama_batch & batch_inp);
@@ -197,15 +197,15 @@ struct llama_context {
ggml_status graph_compute(ggml_cgraph * gf, bool batched);

// reserve a graph with a dummy ubatch of the specified size
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);

private:
llm_graph_result_ptr graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype,
const llama_memory_state_i * mstate);
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype,
const llama_memory_context_i * mctx);

llm_graph_cb graph_get_cb() const;

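The header mirrors the source changes: the forward declaration and the `process_ubatch`, `graph_reserve`, and `graph_build` signatures now name `llama_memory_context_i`. Per the (unchanged) comment on `process_ubatch`, the memory context is applied before the graph is built, `ret` carries the compute status, and the result is null only when something failed. A minimal caller-side check, condensed from `decode()` in the diff above (illustrative, not standalone code):

```cpp
ggml_status status;
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
if (!res) {
    // status != GGML_STATUS_SUCCESS here: the last ubatch failed or was aborted,
    // so the caller removes that ubatch's positions from the KV cache
}
```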