@@ -777,18 +777,27 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     if (!v_trans) {
         if (kv_idxs) {
-            return ggml_set_rows(ctx, v, ggml_reshape_2d(ctx, v_cur, v->ne[0], n_tokens), kv_idxs);
+            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
         }
 
         v_view = ggml_view_1d(ctx, v,
                 n_tokens*hparams.n_embd_v_gqa(il),
                 ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
     } else {
+        // note: the V cache is transposed when not using flash attention
         if (kv_idxs) {
-            GGML_ABORT("TODO: implement kv_idxs for transposed V cache -- for now use flash attention");
+            // the row becomes a single element and we repeat the KV indices d_head times
+            // TODO: this seems not very optimal - can we do something better?
+            v_view = ggml_view_3d(ctx, v, 1, v->ne[1], hparams.n_embd_v_gqa(il),
+                    ggml_element_size(v),
+                    (v->ne[1])*ggml_element_size(v),
+                    0);
+
+            v_cur = ggml_reshape_3d(ctx, ggml_cont(ctx, ggml_transpose(ctx, v_cur)), 1, n_tokens, hparams.n_embd_v_gqa(il));
+
+            return ggml_set_rows(ctx, v_view, v_cur, ggml_repeat_4d(ctx, kv_idxs, n_tokens, hparams.n_embd_v_gqa(il), 1, 1));
         }
 
-        // note: the V cache is transposed when not using flash attention
         v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
                 (v->ne[1])*ggml_element_size(v),
                 (head_cur)*ggml_element_size(v));
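The new comments above outline the trick for writing into the transposed V cache: each scalar of `v_cur` becomes a 1-element row, so the same destination index has to be repeated once per embedding dimension. Below is a minimal plain-C++ sketch of that indexing pattern, independent of ggml; the sizes and the `kv_idxs` values are made up for illustration, and the real code performs the equivalent scatter with `ggml_set_rows` on the 3D views built above.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // illustrative sizes (assumptions, not taken from a real model):
    const int kv_size  = 8; // number of cache cells      (v->ne[1] in the diff)
    const int n_embd_v = 4; // V embedding size per layer (hparams.n_embd_v_gqa(il))
    const int n_tokens = 3; // tokens written in this batch

    // transposed V cache: element (d, cell) lives at d*kv_size + cell
    std::vector<float> v_cache(n_embd_v*kv_size, 0.0f);

    // incoming V values, row-major [n_tokens][n_embd_v]
    std::vector<float> v_cur(n_tokens*n_embd_v);
    for (int t = 0; t < n_tokens; ++t) {
        for (int d = 0; d < n_embd_v; ++d) {
            v_cur[t*n_embd_v + d] = 100.0f*t + d;
        }
    }

    // destination cell for each token (the kv_idxs tensor)
    const std::vector<int64_t> kv_idxs = { 2, 5, 6 };

    // since the cache is transposed, v_cur[t][d] is a single-element "row" that goes to
    // column kv_idxs[t] of cache row d -- the same index is reused for every d, which is
    // what repeating kv_idxs via ggml_repeat_4d expresses in the diff
    for (int d = 0; d < n_embd_v; ++d) {
        for (int t = 0; t < n_tokens; ++t) {
            v_cache[d*kv_size + kv_idxs[t]] = v_cur[t*n_embd_v + d];
        }
    }

    // print the cache: the tokens end up as columns 2, 5 and 6
    for (int d = 0; d < n_embd_v; ++d) {
        for (int c = 0; c < kv_size; ++c) {
            printf("%6.1f", v_cache[d*kv_size + c]);
        }
        printf("\n");
    }
    return 0;
}
```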