
Commit 4d0c0ea

kv-cache : support non-FA case

ggml-ci

1 parent 28ee6d2


src/llama-kv-cache-unified.cpp

Lines changed: 12 additions & 3 deletions
@@ -777,18 +777,27 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     if (!v_trans) {
         if (kv_idxs) {
-            return ggml_set_rows(ctx, v, ggml_reshape_2d(ctx, v_cur, v->ne[0], n_tokens), kv_idxs);
+            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
         }
 
         v_view = ggml_view_1d(ctx, v,
                 n_tokens*hparams.n_embd_v_gqa(il),
                 ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
     } else {
+        // note: the V cache is transposed when not using flash attention
         if (kv_idxs) {
-            GGML_ABORT("TODO: implement kv_idxs for transposed V cache -- for now use flash attention");
+            // the row becomes a single element and we repeat the KV indices d_head times
+            // TODO: this seems not very optimal - can we do something better?
+            v_view = ggml_view_3d(ctx, v, 1, v->ne[1], hparams.n_embd_v_gqa(il),
+                    ggml_element_size(v),
+                    (v->ne[1])*ggml_element_size(v),
+                    0);
+
+            v_cur = ggml_reshape_3d(ctx, ggml_cont(ctx, ggml_transpose(ctx, v_cur)), 1, n_tokens, hparams.n_embd_v_gqa(il));
+
+            return ggml_set_rows(ctx, v_view, v_cur, ggml_repeat_4d(ctx, kv_idxs, n_tokens, hparams.n_embd_v_gqa(il), 1, 1));
         }
 
-        // note: the V cache is transposed when not using flash attention
         v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
                 (v->ne[1])*ggml_element_size(v),
                 (head_cur)*ggml_element_size(v));
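
For context, a minimal standalone sketch of why repeating the KV indices makes ggml_set_rows work on the transposed V cache. It simulates the scatter with plain arrays instead of ggml tensors; the sizes, buffers, and loop structure below are illustrative assumptions, not part of the commit. In the transposed layout, the element for slot j and dimension d lives at offset d*kv_size + j, so when the cache is viewed as rows of length 1 (the 3D view in the diff), the destination row index is the same kv_idxs[j] for every d, which is exactly the index pattern ggml_repeat_4d produces.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // illustrative sizes (assumptions, not from the commit)
        const int64_t kv_size  = 8; // number of cache slots (v->ne[1] in the diff)
        const int64_t n_embd   = 4; // n_embd_v_gqa(il) in the diff
        const int64_t n_tokens = 2;

        // transposed V cache: element (slot j, dim d) lives at d*kv_size + j
        std::vector<float> v(kv_size*n_embd, 0.0f);

        // v_cur as produced by the graph: n_tokens rows of n_embd values
        // (in the non-transposed path this would be scattered row-by-row as-is)
        std::vector<float> v_cur = {
            1.0f, 2.0f, 3.0f, 4.0f, // token 0
            5.0f, 6.0f, 7.0f, 8.0f, // token 1
        };

        // slots the two tokens should land in (kv_idxs in the diff)
        std::vector<int64_t> kv_idxs = { 3, 6 };

        // what ggml_set_rows does on the 3D view [1, kv_size, n_embd]:
        // for every dim plane d and token j, copy a 1-element "row" of the
        // transposed v_cur into row kv_idxs[j] of plane d -- the repeated
        // index tensor holds the same kv_idxs for every plane
        for (int64_t d = 0; d < n_embd; ++d) {
            for (int64_t j = 0; j < n_tokens; ++j) {
                v[d*kv_size + kv_idxs[j]] = v_cur[j*n_embd + d]; // transposed read
            }
        }

        // print the cache plane by plane: column kv_idxs[j] of plane d
        // now holds v_cur[j][d]
        for (int64_t d = 0; d < n_embd; ++d) {
            for (int64_t j = 0; j < kv_size; ++j) {
                printf("%4.1f ", v[d*kv_size + j]);
            }
            printf("\n");
        }

        return 0;
    }

Compared to the non-transposed path, where a single ggml_set_rows call scatters n_tokens contiguous rows of n_embd_v_gqa elements each, this writes n_tokens*n_embd_v_gqa one-element rows, which is likely what the "this seems not very optimal" TODO in the diff is alluding to.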
