feat: add the operators missing for the minicpm3 model #16

Draft · wants to merge 7 commits into main
Changes from all commits
20 changes: 13 additions & 7 deletions operators/src/attention/args.rs
@@ -30,9 +30,11 @@ pub(super) struct Meta {
pub seq: MaybeDyn<usize>,
pub att: MaybeDyn<usize>,
pub dh: MaybeDyn<usize>,
pub dv: MaybeDyn<usize>,
}

impl<H: Hardware> Args<H> {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new_null(
mask: AttnMask,
dt: DigitLayout,
@@ -41,17 +43,20 @@ impl<H: Hardware> Args<H> {
seq: MaybeDyn<usize>,
att: MaybeDyn<usize>,
dh: MaybeDyn<usize>,
dv: MaybeDyn<usize>,
) -> Self {
let qo_layout = TensorLayout::new_dyn(dt, &[nh, seq, dh], &[dyn_(); 3]);
let kv_layout = TensorLayout::new_dyn(dt, &[nkvh, att, dh], &[dyn_(); 3]);
let q_layout = TensorLayout::new_dyn(dt, &[nh, seq, dh], &[dyn_(); 3]);
let k_layout = TensorLayout::new_dyn(dt, &[nkvh, seq, dh], &[dyn_(); 3]);
let v_layout = TensorLayout::new_dyn(dt, &[nkvh, att, dv], &[dyn_(); 3]);
let o_layout = TensorLayout::new_dyn(dt, &[nkvh, att, dh], &[dyn_(); 3]);
Self {
q_layout: qo_layout.clone(),
q_layout: q_layout.clone(),
q_base: null_mut(),
k_layout: kv_layout.clone(),
k_layout: k_layout.clone(),
k_base: null(),
v_layout: kv_layout,
v_layout,
v_base: null(),
o_layout: qo_layout,
o_layout,
o_base: null_mut(),
mask,
}
@@ -85,7 +90,8 @@ impl<H: Hardware> Args<H> {
nkvh: dim_distinct(&[nkvh_k, nkvh_v])?,
seq: dim_distinct(&[seq_q, seq_o])?,
att: dim_distinct(&[att_k, att_v])?,
dh: dim_distinct(&[dh_q, dh_k, dh_v, dh_o])?,
dh: dim_distinct(&[dh_q, dh_k])?,
dv: dim_distinct(&[dh_v, dh_o])?,
})
}
}
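This hunk splits the former single head dimension into `dh` (shared by q and k) and a new `dv` (shared by v and o), so the value/output head dimension no longer has to match the query/key one. Below is a minimal, standalone sketch of the dimension agreement the updated `meta()` expresses; it uses plain `usize` values and a toy `dim_distinct`, not the crate's `MaybeDyn`-based helper from `utils`.

```rust
// Toy illustration of the new dh/dv split: q and k must agree on dh,
// while v and o must agree on dv, and the two may differ.
// Sketch only; the real check is utils::dim_distinct over MaybeDyn dims.
fn dim_distinct(dims: &[usize]) -> Result<usize, String> {
    let first = dims[0];
    if dims.iter().all(|&d| d == first) {
        Ok(first)
    } else {
        Err(format!("dimension mismatch: {dims:?}"))
    }
}

fn main() -> Result<(), String> {
    // Hypothetical head dimensions, not values taken from this PR.
    let (dh_q, dh_k) = (128, 128);
    let (dh_v, dh_o) = (96, 96);
    let dh = dim_distinct(&[dh_q, dh_k])?; // query/key head dim
    let dv = dim_distinct(&[dh_v, dh_o])?; // value/output head dim, now independent of dh
    println!("dh = {dh}, dv = {dv}");
    Ok(())
}
```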
1 change: 1 addition & 0 deletions operators/src/attention/cuda.rs
@@ -16,6 +16,7 @@ mod test {
seq.into(),
att.into(),
dyn_(),
dyn_(),
)
}

18 changes: 11 additions & 7 deletions operators/src/attention/operator.rs
@@ -1,4 +1,4 @@
use super::{args::Meta, Args, Attention};
use super::{args::Meta, Args, Attention};
use crate::{
dyn_, fuesd_softmax, get_static, mat_mul, rearrange, ByteOf, Hardware, LaunchError, QueueAlloc,
SchemeError, TensorLayout, Workspace, WorkspaceCollector,
@@ -53,6 +53,7 @@ where
seq,
att,
dh,
dv,
..
} = args.meta()?;
let Args {
@@ -64,11 +65,12 @@
} = args;

// If we cannot guarantee that nh, seq, att and dh are known, initialize the operator with arbitrary values
let (Some(&nh), Some(&seq), Some(&att), Some(&dh)) = (
let (Some(&nh), Some(&seq), Some(&att), Some(&dh), Some(&_dv)) = (
nh.get_static(),
seq.get_static(),
att.get_static(),
dh.get_static(),
dv.get_static(),
) else {
let mut wc = WorkspaceCollector::new();

@@ -149,6 +151,7 @@ where
seq,
att,
dh,
dv,
} = args.meta()?;
let Args {
mask,
@@ -172,8 +175,8 @@
let ele = dt.nbytes();
get_static! {
nh seq dh
nh_sq seq_sq dh_sq
nkvh att
dv seq_sq dh_sq
nkvh att nh_sq
nkvh_sk att_sk dh_sk
};

@@ -191,7 +194,7 @@
let (att_buf, workspace) = workspace.split_at_mut(att_size);

let head_group = nh / nkvh;
let (q_layout, qx_layout, q_base) = match qx {
let (_q_layout, qx_layout, q_base) = match qx {
None => {
let q_layout = TensorLayout::new_contiguous(dt, &[nh, seq, dh]);
let qx_layout = TensorLayout::new_contiguous(dt, &[nkvh, head_group * seq, dh]);
@@ -219,6 +222,7 @@ where
let k_layout = TensorLayout::new(dt, k_layout.shape(), k_layout.strides());
let att_mat_mul = TensorLayout::new_contiguous(dt, &[nkvh, head_group * seq, att]);
let att_softmax = TensorLayout::new_contiguous(dt, &[nh, seq, att]);
let att_result = TensorLayout::new_contiguous(dt, &[nkvh, head_group * seq, dv]);

// att = q . k^T
self.mat_mul.launch(
@@ -248,7 +252,7 @@
// q = att . v
self.mat_mul.launch(
&mat_mul::Args {
c_layout: qx_layout.clone(),
c_layout: att_result.clone(),
c_base: q_base,
beta: 0.,
a_layout: att_mat_mul,
@@ -266,7 +270,7 @@
&rearrange::Args {
dst_layout: o_layout.clone(),
dst_base: *o_base,
src_layout: q_layout.clone(),
src_layout: o_layout.clone(),
src_base: q_base,
},
workspace,
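With the split dimensions, the second GEMM (att . v) now writes into a buffer described by `att_result = [nkvh, head_group * seq, dv]` rather than reusing the q layout. The sketch below is my reading of the grouped-head bookkeeping around `head_group = nh / nkvh`, with made-up sizes; the real sizes come from the argument layouts at launch time.

```rust
// Rough sketch of the grouped-head sizes involved in att = q . k^T and out = att . v.
// All numbers are hypothetical; `ele` stands in for dt.nbytes().
fn main() {
    let (nh, nkvh, seq, att, dv) = (32usize, 8, 4, 256, 96);
    let ele = 2; // e.g. a 16-bit element type

    let head_group = nh / nkvh; // query heads sharing one kv head

    // attention scores: [nkvh, head_group * seq, att] == [nh, seq, att] elements
    let att_elems = nkvh * head_group * seq * att;
    // att . v result: [nkvh, head_group * seq, dv] == [nh, seq, dv] elements
    let out_elems = nkvh * head_group * seq * dv;

    println!(
        "head_group = {head_group}, att buffer = {} B, att.v result = {} B",
        att_elems * ele,
        out_elems * ele
    );
}
```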
4 changes: 2 additions & 2 deletions operators/src/attention_kv_cached/operator.rs
@@ -1,4 +1,4 @@
use super::{args::Meta, Args, AttnKVCached};
use super::{args::Meta, Args, AttnKVCached};
use crate::{
attention, dyn_, get_static, rearrange, shape_mismatch, ByteOf, Hardware, LaunchError,
MaybeDyn, QueueAlloc, TensorLayout, WorkspaceCollector,
@@ -66,7 +66,7 @@ where
};

wc.push_sub(self.attention.scheme(
&attention::Args::new_null(args.mask, dt, nh, nkvh, seq, att, dh),
&attention::Args::new_null(args.mask, dt, nh, nkvh, seq, att, dh, dh),
max_workspace_size,
)?);

125 changes: 125 additions & 0 deletions operators/src/attention_mla/args.rs
@@ -0,0 +1,125 @@
use crate::{
dyn_,
fuesd_softmax::AttnMask,
utils::{dim_distinct, rank_error, type_distinct},
ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout,
};
use digit_layout::DigitLayout;
use std::ptr::{null, null_mut};

pub struct Args<H: Hardware> {
// the q passed in is the absorbed q
pub q_layout: TensorLayout,
pub q_base: MutPtr<H>,

pub kv_layout: TensorLayout,
pub kv_base: ConstPtr<H>,

pub absorb_layout: TensorLayout,
pub absorb_base: ConstPtr<H>,

pub qr_layout: TensorLayout,
pub qr_base: ConstPtr<H>,

pub kr_layout: TensorLayout,
pub kr_base: ConstPtr<H>,

pub o_layout: TensorLayout,
pub o_base: MutPtr<H>,

pub mask: AttnMask,
}

pub(super) struct Meta {
pub dt: DigitLayout,
pub nh: MaybeDyn<usize>,
pub seq: MaybeDyn<usize>,
pub att: MaybeDyn<usize>,
pub dkv: MaybeDyn<usize>,
pub dv: MaybeDyn<usize>,
pub dr: MaybeDyn<usize>,
}

impl<H: Hardware> Args<H> {

Check warning (Code scanning / clippy): associated function `new_null` is never used
#[allow(clippy::too_many_arguments)]
pub(crate) fn new_null(

mask: AttnMask,
dt: DigitLayout,
nh: MaybeDyn<usize>,
dkv: MaybeDyn<usize>,
seq: MaybeDyn<usize>,
att: MaybeDyn<usize>,
dv: MaybeDyn<usize>,
dr: MaybeDyn<usize>,
) -> Self {
let q_layout = TensorLayout::new_dyn(dt, &[nh, seq, dkv], &[dyn_(); 3]);
let kv_layout = TensorLayout::new_dyn(dt, &[nh, att, dkv], &[dyn_(); 3]);
let absorb_layout = TensorLayout::new_dyn(dt, &[nh, dv, dkv], &[dyn_(); 3]);
let qr_layout = TensorLayout::new_dyn(dt, &[nh, seq, dr], &[dyn_(); 3]);
let kr_layout = TensorLayout::new_dyn(dt, &[nh, att, dr], &[dyn_(); 3]);
let o_layout = TensorLayout::new_dyn(dt, &[nh, seq, dv], &[dyn_(); 3]);
Self {
q_layout,
q_base: null_mut(),
kv_layout,
kv_base: null(),
absorb_layout,
absorb_base: null(),
qr_layout,
qr_base: null(),
kr_layout,
kr_base: null(),
o_layout,
o_base: null_mut(),
mask,
}
}

pub(super) fn meta(&self) -> Result<Meta, SchemeError> {
let Self {
q_layout,
kv_layout,
absorb_layout,
qr_layout,
kr_layout,
o_layout,
..
} = self;

let &[nh_q, seq_q, dkv_q] = q_layout.shape() else {
return Err(rank_error("q", 3, q_layout.ndim()));
};

let &[nh_kv, attn_kv, dkv_kv] = kv_layout.shape() else {
return Err(rank_error("kv", 3, kv_layout.ndim()));
};
let &[nh_a, dv_a, dkv_a] = absorb_layout.shape() else {
return Err(rank_error("absorb", 3, absorb_layout.ndim()));
};
let &[nh_qr, seq_qr, dr_qr] = qr_layout.shape() else {
return Err(rank_error("qr", 3, qr_layout.ndim()));
};
let &[nh_kr, att_kr, dr_kr] = kr_layout.shape() else {
return Err(rank_error("kr", 3, kr_layout.ndim()));
};
let &[nh_o, seq_o, dv_o] = o_layout.shape() else {
return Err(rank_error("o", 3, o_layout.ndim()));
};

Ok(Meta {
dt: type_distinct(&[
q_layout.dt(),
kv_layout.dt(),
qr_layout.dt(),
kr_layout.dt(),
o_layout.dt(),
])?,
nh: dim_distinct(&[nh_q, nh_kv, nh_a, nh_qr, nh_kr, nh_o])?,
seq: dim_distinct(&[seq_q, seq_o, seq_qr])?,
att: dim_distinct(&[attn_kv, att_kr])?,
dkv: dim_distinct(&[dkv_a, dkv_kv, dkv_q])?,
dv: dim_distinct(&[dv_a, dv_o])?,
dr: dim_distinct(&[dr_kr, dr_qr])?,
})
}
}
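The comment on `q_layout` says the q passed in is already absorbed (the MLA weight-absorption trick), which is why `meta()` validates everything against a shared latent dimension `dkv`, a separate rotary dimension `dr`, and an output dimension `dv`. As a reading aid, here is a standalone sketch of the shapes `new_null` builds, with hypothetical sizes rather than MiniCPM3's real configuration.

```rust
// Shapes set up by attention_mla::Args::new_null, written out with plain arrays.
// Dimension values are hypothetical and only illustrate which axes must agree.
fn main() {
    let (nh, seq, att, dkv, dv, dr) = (16usize, 4, 256, 512, 128, 64);

    let shapes = [
        ("q (absorbed)", [nh, seq, dkv]),
        ("kv", [nh, att, dkv]),
        ("absorb", [nh, dv, dkv]),
        ("qr", [nh, seq, dr]),
        ("kr", [nh, att, dr]),
        ("o", [nh, seq, dv]),
    ];
    for (name, dims) in shapes {
        println!("{name:>13}: {dims:?}");
    }
}
```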
1 change: 1 addition & 0 deletions operators/src/attention_mla/common_cpu.rs
@@ -0,0 +1 @@
impl_op!(common_cpu, Cpu);
1 change: 1 addition & 0 deletions operators/src/attention_mla/cuda.rs
@@ -0,0 +1 @@
impl_op!(cuda, Gpu);
1 change: 1 addition & 0 deletions operators/src/attention_mla/infini.rs
@@ -0,0 +1 @@
impl_op!(infini, Device);
26 changes: 26 additions & 0 deletions operators/src/attention_mla/mod.rs
@@ -0,0 +1,26 @@
mod args;
mod operator;

pub use args::Args;

crate::op_trait!(AttentionMLA);

macro_rules! impl_op {
($dev:ident, $proc:ident) => {
pub type Operator = super::operator::Operator<
crate::$dev::$proc,
crate::mat_mul::$dev::Operator,
crate::fuesd_softmax::$dev::Operator,
crate::rearrange::$dev::Operator,
>;
};
}

#[cfg(any(use_cpu, test))]
pub mod common_cpu;
#[cfg(use_cuda)]
pub mod cuda;
#[cfg(use_infini)]
pub mod infini;
#[cfg(use_cl)]
pub mod opencl;
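Each backend file is a single `impl_op!` invocation, and the macro just stamps out a type alias over the generic operator in `operator.rs`. For example, `impl_op!(cuda, Gpu)` from `cuda.rs` should expand to roughly the following (a manual expansion for illustration; in the real build it sits behind the `use_cuda` cfg):

```rust
// Approximate expansion of impl_op!(cuda, Gpu) inside operators/src/attention_mla/cuda.rs
pub type Operator = super::operator::Operator<
    crate::cuda::Gpu,
    crate::mat_mul::cuda::Operator,
    crate::fuesd_softmax::cuda::Operator,
    crate::rearrange::cuda::Operator,
>;
```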
1 change: 1 addition & 0 deletions operators/src/attention_mla/opencl.rs
@@ -0,0 +1 @@
impl_op!(opencl, ClDevice);