
Commit d1da992

kv-cache : use ggml_set_rows
ggml-ci
1 parent bec0ceb commit d1da992
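
This change replaces the contiguous ggml_cpy of the current ubatch's K/V into a view of the cache with ggml_set_rows driven by a per-token index tensor (kv_idxs). As a minimal reference for the scatter semantics relied on below (a plain C++ model, not the ggml API; row i of src overwrites row idxs[i] of dst):

#include <algorithm>
#include <cstdint>
#include <vector>

// Reference scatter: dst holds rows of row_size floats, src holds
// idxs.size() rows; row i of src is written to row idxs[i] of dst.
static void set_rows_ref(std::vector<float> & dst, size_t row_size,
                         const std::vector<float> & src,
                         const std::vector<int64_t> & idxs) {
    for (size_t i = 0; i < idxs.size(); ++i) {
        std::copy(src.begin() + i*row_size, src.begin() + (i + 1)*row_size,
                  dst.begin() + idxs[i]*row_size);
    }
}

With idxs[i] = head_cur + i (how set_input_kv_idxs fills the tensor below), this reproduces the old contiguous copy, but through an index indirection that the graph receives as an input.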

File tree: 4 files changed, +89 −18 lines

src/llama-graph.cpp

Lines changed: 31 additions & 6 deletions
@@ -281,12 +281,24 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }

 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
+    if (self_kv_idxs) {
+        mctx->set_input_kv_idxs(self_kv_idxs, ubatch);
+    }
+
     if (self_kq_mask) {
         mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 }

 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+    if (self_kv_idxs) {
+        mctx->get_base()->set_input_kv_idxs(self_kv_idxs, ubatch);
+    }
+
+    if (self_kv_idxs_swa) {
+        mctx->get_swa()->set_input_kv_idxs(self_kv_idxs_swa, ubatch);
+    }
+
     if (self_kq_mask) {
         mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
@@ -1192,6 +1204,9 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()

     const auto n_kv = mctx_cur->get_n_kv();

+    inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
+    ggml_set_input(inp->self_kv_idxs);
+
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
@@ -1224,8 +1239,10 @@ ggml_tensor * llm_graph_context::build_attn(

     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & kv_idxs = inp->get_kv_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
     }

     const auto & kq_mask = inp->get_kq_mask();
@@ -1278,8 +1295,10 @@ ggml_tensor * llm_graph_context::build_attn(

     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
     }

     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1383,8 +1402,8 @@ ggml_tensor * llm_graph_context::build_attn(

     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, nullptr, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, nullptr, il));
     }

     const auto & kq_mask = inp->get_kq_mask();
@@ -1419,6 +1438,9 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();

+        inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
+        ggml_set_input(inp->self_kv_idxs);
+
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);
@@ -1431,6 +1453,9 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif

         const auto n_kv = mctx_cur->get_swa()->get_n_kv();

+        inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
+        ggml_set_input(inp->self_kv_idxs_swa);
+
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
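
The new self_kv_idxs tensors follow the usual graph-input pattern: a 1-D GGML_TYPE_I64 tensor of length n_tokens is created, marked with ggml_set_input, and filled by the memory context in set_input() right before evaluation. A standalone sketch of that pattern (hypothetical helper names; only core ggml / ggml-backend calls, error handling omitted):

#include "ggml.h"
#include "ggml-backend.h"

#include <cstdint>
#include <vector>

// Hypothetical: declare a 1-D I64 index input for n_tokens cache slots.
static ggml_tensor * make_kv_idxs_input(ggml_context * ctx, int64_t n_tokens) {
    ggml_tensor * idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
    ggml_set_input(idxs); // same role as ggml_set_input(inp->self_kv_idxs) above
    return idxs;
}

// Hypothetical: fill the input with consecutive slots [head_cur, head_cur + n_tokens).
static void fill_kv_idxs(ggml_tensor * idxs, int64_t head_cur) {
    std::vector<int64_t> data(idxs->ne[0]);
    for (int64_t i = 0; i < idxs->ne[0]; ++i) {
        data[i] = head_cur + i;
    }
    ggml_backend_tensor_set(idxs, data.data(), 0, ggml_nbytes(idxs));
}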

src/llama-graph.h

Lines changed: 9 additions & 0 deletions
@@ -248,8 +248,12 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {

     void set_input(const llama_ubatch * ubatch) override;

+    ggml_tensor * get_kv_idxs() const { return self_kv_idxs; }
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }

+    // TODO: should this be I64?
+    ggml_tensor * self_kv_idxs = nullptr; // I32 [n_batch]
+
     ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]

@@ -273,9 +277,14 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {

     void set_input(const llama_ubatch * ubatch) override;

+    ggml_tensor * get_kv_idxs()     const { return self_kv_idxs; }
+    ggml_tensor * get_kv_idxs_swa() const { return self_kv_idxs_swa; }
     ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }

+    ggml_tensor * self_kv_idxs     = nullptr; // I32 [n_batch]
+    ggml_tensor * self_kv_idxs_swa = nullptr; // I32 [n_batch]
+
     ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
     ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]

src/llama-kv-cache-unified.cpp

Lines changed: 43 additions & 8 deletions
@@ -746,21 +746,25 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }

-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
     const int32_t ikv = map_layer_ids.at(il);

     auto * k = layers[ikv].k;

     const int64_t n_tokens = k_cur->ne[2];

+    if (kv_idxs) {
+        return ggml_set_rows(ctx, k, ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens), kv_idxs);
+    }
+
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
             n_tokens*hparams.n_embd_k_gqa(il),
             ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);

     return ggml_cpy(ctx, k_cur, k_view);
 }

-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
     const int32_t ikv = map_layer_ids.at(il);

     auto * v = layers[ikv].v;
@@ -772,21 +776,48 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     ggml_tensor * v_view = nullptr;

     if (!v_trans) {
+        if (kv_idxs) {
+            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
+        }
+
         v_view = ggml_view_1d(ctx, v,
                 n_tokens*hparams.n_embd_v_gqa(il),
                 ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
     } else {
+        v_cur = ggml_transpose(ctx, v_cur);
+
         // note: the V cache is transposed when not using flash attention
+        if (kv_idxs) {
+            // the row becomes a single element and we repeat the KV indices d_head times
+            // TODO: this seems not very optimal - can we do something better?
+            v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+
+            v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+
+            kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
+
+            return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
+        }
+
         v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
                 (v->ne[1])*ggml_element_size(v),
                 (head_cur)*ggml_element_size(v));
-
-        v_cur = ggml_transpose(ctx, v_cur);
     }

     return ggml_cpy(ctx, v_cur, v_view);
 }

+void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const {
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = head_cur + i;
+    }
+}
+
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const uint32_t n_tokens = ubatch->n_tokens;
@@ -1789,18 +1820,22 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
     return kv->get_v(ctx, il, n_kv);
 }

-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, kv_idxs, il, head);
 }

-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, kv_idxs, il, head);
 }

 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }

+void llama_kv_cache_unified_context::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_kv_idxs(dst, ubatch, head);
+}
+
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     kv->set_input_kq_mask(dst, ubatch, causal_attn);
 }
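
cpy_v now has two scatter paths. With !v_trans (the flash-attention layout) each cache row already holds one token's V vector, so ggml_set_rows applies directly; with the transposed layout a token's values are spread across rows, so the cache is reshaped to one element per row and the token indices are repeated once per embedding dimension. A plain-C++ sketch of what the transposed-path scatter computes (layout and names are illustrative, not the ggml kernels):

#include <cstdint>
#include <vector>

// Transposed V layout: element (slot, d) of the cache lives at v_cache[d*kv_size + slot],
// i.e. each embedding dimension d is contiguous across cache slots.
// Writing token i of v_cur (token-major, [n_tokens x n_embd]) into slot head_cur + i
// touches one element per dimension, which is what the "1-element rows, repeated
// indices" trick above expresses in ggml ops.
static void scatter_v_trans(std::vector<float> & v_cache, int64_t kv_size,
                            const std::vector<float> & v_cur,
                            int64_t n_tokens, int64_t n_embd, int64_t head_cur) {
    for (int64_t d = 0; d < n_embd; ++d) {
        for (int64_t i = 0; i < n_tokens; ++i) {
            v_cache[d*kv_size + head_cur + i] = v_cur[i*n_embd + d];
        }
    }
}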

src/llama-kv-cache-unified.h

Lines changed: 6 additions & 4 deletions
@@ -102,8 +102,8 @@ class llama_kv_cache_unified : public llama_memory_i {
     ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;

     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const;

     //
     // preparation API
@@ -126,6 +126,7 @@ class llama_kv_cache_unified : public llama_memory_i {
     // set_input API
     //

+    void set_input_kv_idxs   (ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const;
     void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_k_shift   (ggml_tensor * dst) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
@@ -257,11 +258,12 @@ class llama_kv_cache_unified_context : public llama_memory_context_i {
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const;

     void set_input_k_shift(ggml_tensor * dst) const;

+    void set_input_kv_idxs   (ggml_tensor * dst, const llama_ubatch * ubatch) const;
     void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
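
At the call sites the extra kv_idxs argument keeps both write modes available: passing the index tensor selects the ggml_set_rows scatter, while passing nullptr keeps the previous ggml_cpy-into-view behavior (as one of the build_attn overloads in llama-graph.cpp still does). Roughly:

// with an index tensor: scatter rows into the cache via ggml_set_rows
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));

// without one: contiguous copy into a view starting at the current head
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, nullptr, il));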
