
ggml : add ggml_set_rows #14274


Draft · wants to merge 2 commits into base: master

Conversation

rgerganov (Collaborator)

Add ggml_set_rows(a, b, c) which copies rows from 'b' into 'a' using indices from 'c'.

ref: #8366
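
For reference, the intended per-row semantics are roughly the following (a minimal sketch with hypothetical shorthand names n_src, n_dst, ne0; it matches the F32 -> F16 case asserted in the current draft):

// Sketch: 'a' is the F16 destination with n_dst rows, 'b' the F32 source
// with n_src rows, 'c' the I64 index vector (one destination row index per
// source row), ne0 the row width. GGML_FP32_TO_FP16 is ggml's scalar
// conversion macro.
for (int64_t i = 0; i < n_src; ++i) {
    const int64_t r = c[i];                    // destination row, 0 <= r < n_dst
    for (int64_t j = 0; j < ne0; ++j) {
        a[r*ne0 + j] = GGML_FP32_TO_FP16(b[i*ne0 + j]);
    }
}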

ggerganov (Member)

So far so good: #14285

I think ggml_set_rows() alone could be a very useful addition, since this mechanism would let llama_kv_cache_unified::find_slot() search not just for contiguous slots of KV cells, but effectively "scatter" the ubatch across the cache. This would be a useful improvement regardless of whether the graph reuse works out, so I think we should proceed with implementing this operator.
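
Concretely (a hypothetical sketch, not code from this PR; idx, k_all and k_cur are illustrative names): find_slot could collect arbitrary free cell indices into an I64 tensor and scatter the ubatch with ggml_set_rows, instead of requiring a contiguous view at head:

// Hypothetical: idx holds one free KV-cell index per ubatch token; the
// chosen cells no longer have to be contiguous.
struct ggml_tensor * idx = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
// ... find_slot fills idx with the (possibly scattered) cell indices ...
k_all = ggml_set_rows(ctx, k_all, k_cur, idx); // k_all: cache rows, k_cur: new K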

switch (src0->type) {
    case GGML_TYPE_F32:
        {
            ggml_compute_forward_set_rows_f32(params, dst);
Member
We should aim to reuse the existing cpy/dup code in order to support F32 -> any type, not just F16.

rgerganov (Collaborator, Author)
Do we need support for quantized data types, or just F16, F32, BF16?

Member
We should support all possible KV cache quantization types:

llama.cpp/common/arg.cpp, lines 819 to 830 in 28ee6d2:

const std::vector<ggml_type> kv_cache_types = {
    GGML_TYPE_F32,
    GGML_TYPE_F16,
    GGML_TYPE_BF16,
    GGML_TYPE_Q8_0,
    GGML_TYPE_Q4_0,
    GGML_TYPE_Q4_1,
    GGML_TYPE_IQ4_NL,
    GGML_TYPE_Q5_0,
    GGML_TYPE_Q5_1,
};

Ideally, it should work for all ggml types that also work with ggml_cpy.
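
One possible route (a sketch only, under the assumption that the per-type from_float row converter exposed by ggml's CPU type traits is available here; n_rows and row_idx are illustrative names):

// Sketch: a single generic kernel instead of one function per dst type.
// from_float converts/quantizes one F32 row into dst->type.
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(dst->type)->from_float;

for (int64_t i = 0; i < n_rows; ++i) {
    const int64_t r = row_idx[i]; // destination row taken from 'c'
    from_float(
        (const float *) ((const char *) src->data + i*src->nb[1]), // F32 source row
        (char *) dst->data + r*dst->nb[1],                         // converted dest row
        src->ne[0]);                                               // elements per row
}

This would mirror how the existing dup/cpy path dispatches conversions, so F32 -> any cache type could share one code path.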

Comment on lines +3400 to +3411
struct ggml_tensor * ggml_set_rows(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        struct ggml_tensor  * c) {
    GGML_ASSERT(b->ne[2] == c->ne[1]);
    GGML_ASSERT(c->ne[3] == 1);
    GGML_ASSERT(a->type == GGML_TYPE_F16);
    GGML_ASSERT(b->type == GGML_TYPE_F32);
    GGML_ASSERT(c->type == GGML_TYPE_I64);
Member
We might want to allow broadcasting c into b. It would avoid this ggml_repeat_4d here:

v_cur   = ggml_cont_3d  (ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);

return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
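
A possible shape contract for that broadcast (a sketch only, mirroring the modulo-style checks ggml_can_repeat uses elsewhere):

// Hypothetical relaxed asserts: 'c' broadcasts across b's dims 2 and 3,
// so one index vector can serve all heads without the ggml_repeat_4d.
GGML_ASSERT(b->ne[1] == c->ne[0]);      // one destination index per source row
GGML_ASSERT(b->ne[2] % c->ne[1] == 0);  // c broadcast into b along dim 2
GGML_ASSERT(b->ne[3] % c->ne[2] == 0);  // c broadcast into b along dim 3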
