
Commit 6989b92

feat: add the absorbed version of MLA
1 parent ab65dfa · commit 6989b92

File tree: 9 files changed, +360 −2 lines

operators/src/attention_mla/args.rs

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
use crate::{
    dyn_,
    fuesd_softmax::AttnMask,
    utils::{dim_distinct, rank_error, type_distinct},
    ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout,
};
use digit_layout::DigitLayout;
use std::ptr::{null, null_mut};

pub struct Args<H: Hardware> {
    // the q passed in here is already the absorbed q
    pub q_layout: TensorLayout,
    pub q_base: MutPtr<H>,

    pub kv_layout: TensorLayout,
    pub kv_base: ConstPtr<H>,

    pub absorb_layout: TensorLayout,
    pub absorb_base: ConstPtr<H>,

    pub qr_layout: TensorLayout,
    pub qr_base: ConstPtr<H>,

    pub kr_layout: TensorLayout,
    pub kr_base: ConstPtr<H>,

    pub o_layout: TensorLayout,
    pub o_base: MutPtr<H>,

    pub mask: AttnMask,
}

pub(super) struct Meta {
    pub dt: DigitLayout,
    pub nh: MaybeDyn<usize>,
    pub seq: MaybeDyn<usize>,
    pub att: MaybeDyn<usize>,
    pub dkv: MaybeDyn<usize>,
    pub dv: MaybeDyn<usize>,
    pub dr: MaybeDyn<usize>,
}

impl<H: Hardware> Args<H> {
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new_null(
        mask: AttnMask,
        dt: DigitLayout,
        nh: MaybeDyn<usize>,
        dkv: MaybeDyn<usize>,
        seq: MaybeDyn<usize>,
        att: MaybeDyn<usize>,
        dv: MaybeDyn<usize>,
        dr: MaybeDyn<usize>,
    ) -> Self {
        let q_layout = TensorLayout::new_dyn(dt, &[nh, seq, dkv], &[dyn_(); 3]);
        let kv_layout = TensorLayout::new_dyn(dt, &[nh, att, dkv], &[dyn_(); 3]);
        let absorb_layout = TensorLayout::new_dyn(dt, &[nh, dv, dkv], &[dyn_(); 3]);
        let qr_layout = TensorLayout::new_dyn(dt, &[nh, seq, dr], &[dyn_(); 3]);
        let kr_layout = TensorLayout::new_dyn(dt, &[nh, att, dr], &[dyn_(); 3]);
        let o_layout = TensorLayout::new_dyn(dt, &[nh, seq, dv], &[dyn_(); 3]);
        Self {
            q_layout,
            q_base: null_mut(),
            kv_layout,
            kv_base: null(),
            absorb_layout,
            absorb_base: null(),
            qr_layout,
            qr_base: null(),
            kr_layout,
            kr_base: null(),
            o_layout,
            o_base: null_mut(),
            mask,
        }
    }

    pub(super) fn meta(&self) -> Result<Meta, SchemeError> {
        let Self {
            q_layout,
            kv_layout,
            absorb_layout,
            qr_layout,
            kr_layout,
            o_layout,
            ..
        } = self;

        let &[nh_q, seq_q, dkv_q] = q_layout.shape() else {
            return Err(rank_error("q", 3, q_layout.ndim()));
        };

        let &[nh_kv, attn_kv, dkv_kv] = kv_layout.shape() else {
            return Err(rank_error("kv", 3, kv_layout.ndim()));
        };
        let &[nh_a, dv_a, dkv_a] = absorb_layout.shape() else {
            return Err(rank_error("absorb", 3, absorb_layout.ndim()));
        };
        let &[nh_qr, seq_qr, dr_qr] = qr_layout.shape() else {
            return Err(rank_error("qr", 3, qr_layout.ndim()));
        };
        let &[nh_kr, att_kr, dr_kr] = kr_layout.shape() else {
            return Err(rank_error("kr", 3, kr_layout.ndim()));
        };
        let &[nh_o, seq_o, dv_o] = o_layout.shape() else {
            return Err(rank_error("o", 3, o_layout.ndim()));
        };

        Ok(Meta {
            dt: type_distinct(&[
                q_layout.dt(),
                kv_layout.dt(),
                qr_layout.dt(),
                kr_layout.dt(),
                o_layout.dt(),
            ])?,
            nh: dim_distinct(&[nh_q, nh_kv, nh_a, nh_qr, nh_kr, nh_o])?,
            seq: dim_distinct(&[seq_q, seq_o, seq_qr])?,
            att: dim_distinct(&[attn_kv, att_kr])?,
            dkv: dim_distinct(&[dkv_a, dkv_kv, dkv_q])?,
            dv: dim_distinct(&[dv_a, dv_o])?,
            dr: dim_distinct(&[dr_kr, dr_qr])?,
        })
    }
}
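The layouts above encode the absorbed-MLA shape convention: q [nh, seq, dkv], kv [nh, att, dkv], absorb [nh, dv, dkv], qr [nh, seq, dr], kr [nh, att, dr], o [nh, seq, dv]. `meta()` then funnels every dimension that must agree through `dim_distinct` (and the element types through `type_distinct`). As a rough, self-contained sketch of that consistency-check pattern (the helper name and error handling below are mine, not the crate's, and the real helpers also handle `MaybeDyn` dimensions):

// Simplified stand-in for the dim_distinct / type_distinct pattern used by meta():
// every value describing the same logical dimension must agree, otherwise the
// argument bundle is rejected before any kernel is launched.
fn all_distinct_eq<T: PartialEq + Copy>(vals: &[T], what: &str) -> Result<T, String> {
    let first = vals[0];
    if vals.iter().all(|v| *v == first) {
        Ok(first)
    } else {
        Err(format!("inconsistent `{what}` across tensor layouts"))
    }
}

fn main() {
    // e.g. the head count read from q, kv, absorb, qr, kr and o must be identical
    let nh = all_distinct_eq(&[32, 32, 32, 32, 32, 32], "nh").unwrap();
    // a mismatch (here seq taken from q, qr and o) is reported, not silently accepted
    let seq = all_distinct_eq(&[7, 7, 8], "seq");
    println!("nh = {nh}, seq = {seq:?}");
}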
operators/src/attention_mla/common_cpu.rs

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
impl_op!(common_cpu, Cpu);

operators/src/attention_mla/cuda.rs

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
impl_op!(cuda, Gpu);

operators/src/attention_mla/infini.rs

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
impl_op!(infini, Device);

operators/src/attention_mla/mod.rs

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
mod args;
mod operator;

pub use args::Args;

crate::op_trait!(AttentionMLA);

macro_rules! impl_op {
    ($dev:ident, $proc:ident) => {
        pub type Operator = super::operator::Operator<
            crate::$dev::$proc,
            crate::mat_mul::$dev::Operator,
            crate::fuesd_softmax::$dev::Operator,
            crate::rearrange::$dev::Operator,
        >;
    };
}

#[cfg(any(use_cpu, test))]
pub mod common_cpu;
#[cfg(use_cuda)]
pub mod cuda;
#[cfg(use_infini)]
pub mod infini;
#[cfg(use_cl)]
pub mod opencl;
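For reference, here is a sketch of what `impl_op!(common_cpu, Cpu)` expands to inside the backend stubs below (expanded by hand from the macro above, so treat it as approximate):

// Approximate expansion of `impl_op!(common_cpu, Cpu)` in
// operators/src/attention_mla/common_cpu.rs: the generic operator is
// instantiated with the CPU hardware type and the CPU mat_mul,
// fused-softmax and rearrange operators.
pub type Operator = super::operator::Operator<
    crate::common_cpu::Cpu,
    crate::mat_mul::common_cpu::Operator,
    crate::fuesd_softmax::common_cpu::Operator,
    crate::rearrange::common_cpu::Operator,
>;

The other backends substitute cuda/Gpu, infini/Device and opencl/ClDevice in the same way.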

operators/src/attention_mla/opencl.rs

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
impl_op!(opencl, ClDevice);
operators/src/attention_mla/operator.rs

Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
use super::{args::Meta, Args, AttentionMLA};
use crate::{
    dyn_, fuesd_softmax, get_static, mat_mul, rearrange, ByteOf, Hardware, LaunchError, QueueAlloc,
    SchemeError, TensorLayout, Workspace, WorkspaceCollector,
};
use ndarray_layout::ArrayLayout;
use std::marker::PhantomData;

pub struct Operator<Hardware, MatMul, Softmax, Rearrange> {
    mat_mul: MatMul,
    softmax: Softmax,
    rearrange: Rearrange,
    _phantom: PhantomData<Hardware>,
}

impl<H, M, S, R> AttentionMLA<H> for Operator<H, M, S, R>
where
    H: Hardware,
    M: mat_mul::MatMul<H>,
    S: fuesd_softmax::FusedSoftmax<H>,
    R: rearrange::Rearrange<H>,
{
}

impl<H, M, S, R> crate::Operator for Operator<H, M, S, R>
where
    H: Hardware,
    M: mat_mul::MatMul<H>,
    S: fuesd_softmax::FusedSoftmax<H>,
    R: rearrange::Rearrange<H>,
{
    type Hardware = H;
    type TopoNode = H;
    type Args = Args<H>;

    fn new(node: &Self::TopoNode) -> Self {
        Self {
            mat_mul: M::new(node),
            softmax: S::new(node),
            rearrange: R::new(node),
            _phantom: PhantomData,
        }
    }

    fn scheme(
        &mut self,
        args: &Self::Args,
        max_workspace_size: usize,
    ) -> Result<usize, SchemeError> {
        // TODO
        Ok(0)
    }

    fn launch<QA>(
        &self,
        args: &Self::Args,
        workspace: &mut [ByteOf<Self::Hardware>],
        queue_alloc: &QA,
    ) -> Result<(), LaunchError>
    where
        QA: QueueAlloc<Hardware = Self::Hardware>,
    {
        let Meta {
            dt,
            nh,
            seq,
            att,
            dkv,
            dv,
            dr,
        } = args.meta()?;
        let Args {
            q_layout,
            q_base,
            kv_layout,
            kv_base,
            absorb_layout,
            absorb_base,
            qr_layout,
            qr_base,
            kr_layout,
            kr_base,
            o_layout,
            o_base,
            mask,
        } = args;

        let &[nh_skv, att_skv, dkv_skv] = kv_layout.strides() else {
            unreachable!()
        };
        let &[nh_skr, att_skr, dr_skr] = kr_layout.strides() else {
            unreachable!()
        };
        let &[nh_sa, dv_sa, dkv_sa] = absorb_layout.strides() else {
            unreachable!()
        };
        let &[nh_so, seq_so, dv_so] = o_layout.strides() else {
            unreachable!()
        };
        let ele = dt.nbytes();
        get_static! {
            nh seq dkv dr
            nh_skv att_skv dkv_skv
            nh_skr att_skr dr_skr
            nh_sa dv_sa dkv_sa
            nh_so seq_so dv_so
            dv att
        };

        #[inline(always)]
        fn layout(shape: [usize; 3], strides: [isize; 3]) -> ArrayLayout<3> {
            ArrayLayout::new(&shape, &strides, 0)
        }
        let kv_first_layout = layout([nh, att, dkv], [nh_skv, att_skv, dkv_skv]).transpose(&[2, 1]);
        let kr_layout = layout([nh, att, dr], [nh_skr, att_skr, dr_skr]).transpose(&[2, 1]);
        let a_layout = layout([nh, dv, dkv], [nh_sa, dv_sa, dkv_sa]).transpose(&[2, 1]);
        let att_w_layout = TensorLayout::new_contiguous(dt, &[nh, seq, att]);
        let attn_t_layout = TensorLayout::new_contiguous(dt, &[nh, seq, dkv]);
        let att_w_size = nh * seq * att * ele;
        let att_t_size = nh * seq * dkv * ele;
        let mut workspace = Workspace::new(queue_alloc, workspace, att_w_size + att_t_size);
        let (att_w_buf, workspace) = workspace.split_at_mut(att_w_size);
        let (attn_t_buf, workspace) = workspace.split_at_mut(att_t_size);

        let kv_first_layout =
            TensorLayout::new(dt, kv_first_layout.shape(), kv_first_layout.strides());
        let kr_layout = TensorLayout::new(dt, kr_layout.shape(), kr_layout.strides());
        let a_layout = TensorLayout::new(dt, a_layout.shape(), a_layout.strides());
        // att_w = qr * kr^T + q * kv^T
        self.mat_mul.launch(
            &mat_mul::Args {
                c_layout: att_w_layout.clone(),
                c_base: att_w_buf.as_mut_ptr(),
                beta: 0.,
                a_layout: qr_layout.clone(),
                a_base: *qr_base,
                b_layout: kr_layout.clone(),
                b_base: *kr_base,
                alpha: ((dv + dr) as f32).sqrt().recip(),
            },
            workspace,
            queue_alloc,
        )?;

        self.mat_mul.launch(
            &mat_mul::Args {
                c_layout: att_w_layout.clone(),
                c_base: att_w_buf.as_mut_ptr(),
                beta: 1.,
                a_layout: q_layout.clone(),
                a_base: *q_base,
                b_layout: kv_first_layout.clone(),
                b_base: *kv_base,
                alpha: ((dv + dr) as f32).sqrt().recip(),
            },
            workspace,
            queue_alloc,
        )?;
        // att_w = softmax(att_w)
        self.softmax.launch(
            &fuesd_softmax::Args {
                att_mask: *mask,
                att_layout: att_w_layout.clone(),
                att_base: att_w_buf.as_mut_ptr(),
            },
            workspace,
            queue_alloc,
        )?;
        // attn_t = att_w * kv
        self.mat_mul.launch(
            &mat_mul::Args {
                c_layout: attn_t_layout.clone(),
                c_base: attn_t_buf.as_mut_ptr(),
                beta: 0.,
                a_layout: att_w_layout.clone(),
                a_base: att_w_buf.as_ptr(),
                b_layout: kv_layout.clone(),
                b_base: *kv_base,
                alpha: 1.,
            },
            workspace,
            queue_alloc,
        )?;

        // attn = attn_t * absorb^T
        self.mat_mul.launch(
            &mat_mul::Args {
                c_layout: o_layout.clone(),
                c_base: *o_base,
                beta: 0.,
                a_layout: attn_t_layout.clone(),
                a_base: attn_t_buf.as_ptr(),
                b_layout: a_layout.clone(),
                b_base: *absorb_base,
                alpha: 1.,
            },
            workspace,
            queue_alloc,
        )?;

        Ok(())
    }
}
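Read together, the four launches above amount to the following per-head computation (notation mine: K_c is the compressed kv tensor, Q_r/K_r the rotary parts, A the absorb matrix; the attention mask is applied inside the fused softmax):

$$
W = \operatorname{softmax}\!\left(\frac{Q_r K_r^{\top} + Q\,K_c^{\top}}{\sqrt{d_v + d_r}}\right),\qquad
T = W\,K_c,\qquad
O = T\,A^{\top},
$$

with per-head shapes $Q \in \mathbb{R}^{seq \times d_{kv}}$, $K_c \in \mathbb{R}^{att \times d_{kv}}$, $Q_r \in \mathbb{R}^{seq \times d_r}$, $K_r \in \mathbb{R}^{att \times d_r}$, $A \in \mathbb{R}^{d_v \times d_{kv}}$ and $O \in \mathbb{R}^{seq \times d_v}$. Only two workspace buffers are needed for the intermediates: att_w with $nh \cdot seq \cdot att$ elements for $W$ and attn_t with $nh \cdot seq \cdot dkv$ elements for $T$.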

operators/src/lib.rs

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ pub mod add_rows;
 pub mod all_reduce;
 pub mod attention;
 pub mod attention_kv_cached;
+pub mod attention_mla;
 pub mod broadcast;
 pub mod conv;
 pub mod fuesd_softmax;

operators/src/rope/common_cpu/mod.rs

Lines changed: 1 addition & 2 deletions
@@ -1,12 +1,11 @@
-use std::ptr::null;
-
 use super::{args::Meta, args::RopeType as R, fill_pos, Args, Rope, Seq, SinCosTable};
 use crate::{
     common_cpu::Cpu, get_static, strides_not_support, ByteOf, LaunchError, QueueAlloc, SchemeError,
     Unsigned,
 };
 use digit_layout::{types as ty, DigitLayout};
 use half::f16;
+use std::ptr::null;
 #[derive(Copy, Clone)]
 enum NtkPartsType {
     None,
