
Commit 23f8506

feat: add absorbed MLA
1 parent ab65dfa commit 23f8506

File tree

9 files changed: +362 −2 lines changed

operators/src/attention_mla/args.rs

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
use crate::{
    dyn_,
    fuesd_softmax::AttnMask,
    utils::{dim_distinct, rank_error, type_distinct},
    ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout,
};
use digit_layout::DigitLayout;
use std::ptr::{null, null_mut};

pub struct Args<H: Hardware> {
    // the q passed in here is already absorbed
    pub q_layout: TensorLayout,
    pub q_base: MutPtr<H>,

    pub kv_layout: TensorLayout,
    pub kv_base: ConstPtr<H>,

    pub absorb_layout: TensorLayout,
    pub absorb_base: ConstPtr<H>,

    pub qr_layout: TensorLayout,
    pub qr_base: ConstPtr<H>,

    pub kr_layout: TensorLayout,
    pub kr_base: ConstPtr<H>,

    pub o_layout: TensorLayout,
    pub o_base: MutPtr<H>,

    pub mask: AttnMask,
}

pub(super) struct Meta {
    pub dt: DigitLayout,
    pub nh: MaybeDyn<usize>,
    pub seq: MaybeDyn<usize>,
    pub att: MaybeDyn<usize>,
    pub dkv: MaybeDyn<usize>,
    pub dv: MaybeDyn<usize>,
    pub dr: MaybeDyn<usize>,
}

impl<H: Hardware> Args<H> {
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new_null(
        mask: AttnMask,
        dt: DigitLayout,
        nh: MaybeDyn<usize>,
        dkv: MaybeDyn<usize>,
        seq: MaybeDyn<usize>,
        att: MaybeDyn<usize>,
        dv: MaybeDyn<usize>,
        dr: MaybeDyn<usize>,
    ) -> Self {
        let q_layout = TensorLayout::new_dyn(dt, &[nh, seq, dkv], &[dyn_(); 3]);
        let kv_layout = TensorLayout::new_dyn(dt, &[nh, att, dkv], &[dyn_(); 3]);
        let absorb_layout = TensorLayout::new_dyn(dt, &[nh, dv, dkv], &[dyn_(); 3]);
        let qr_layout = TensorLayout::new_dyn(dt, &[nh, seq, dr], &[dyn_(); 3]);
        let kr_layout = TensorLayout::new_dyn(dt, &[nh, att, dr], &[dyn_(); 3]);
        let o_layout = TensorLayout::new_dyn(dt, &[nh, seq, dv], &[dyn_(); 3]);
        Self {
            q_layout,
            q_base: null_mut(),
            kv_layout,
            kv_base: null(),
            absorb_layout,
            absorb_base: null(),
            qr_layout,
            qr_base: null(),
            kr_layout,
            kr_base: null(),
            o_layout,
            o_base: null_mut(),
            mask,
        }
    }

    pub(super) fn meta(&self) -> Result<Meta, SchemeError> {
        let Self {
            q_layout,
            kv_layout,
            absorb_layout,
            qr_layout,
            kr_layout,
            o_layout,
            ..
        } = self;

        let &[nh_q, seq_q, dkv_q] = q_layout.shape() else {
            return Err(rank_error("q", 3, q_layout.ndim()));
        };
        let &[nh_kv, attn_kv, dkv_kv] = kv_layout.shape() else {
            return Err(rank_error("kv", 3, kv_layout.ndim()));
        };
        let &[nh_a, dv_a, dkv_a] = absorb_layout.shape() else {
            return Err(rank_error("absorb", 3, absorb_layout.ndim()));
        };
        let &[nh_qr, seq_qr, dr_qr] = qr_layout.shape() else {
            return Err(rank_error("qr", 3, qr_layout.ndim()));
        };
        let &[nh_kr, att_kr, dr_kr] = kr_layout.shape() else {
            return Err(rank_error("kr", 3, kr_layout.ndim()));
        };
        let &[nh_o, seq_o, dv_o] = o_layout.shape() else {
            return Err(rank_error("o", 3, o_layout.ndim()));
        };

        Ok(Meta {
            dt: type_distinct(&[
                q_layout.dt(),
                kv_layout.dt(),
                qr_layout.dt(),
                kr_layout.dt(),
                o_layout.dt(),
            ])?,
            nh: dim_distinct(&[nh_q, nh_kv, nh_a, nh_qr, nh_kr, nh_o])?,
            seq: dim_distinct(&[seq_q, seq_o, seq_qr])?,
            att: dim_distinct(&[attn_kv, att_kr])?,
            dkv: dim_distinct(&[dkv_a, dkv_kv, dkv_q])?,
            dv: dim_distinct(&[dv_a, dv_o])?,
            dr: dim_distinct(&[dr_kr, dr_qr])?,
        })
    }
}
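For orientation, `new_null` pins down the shape contract that `meta()` later re-checks against concrete layouts (rank 3 for every tensor, one shared dtype, and matching dimensions across tensors). A brief sketch of that contract follows, with dimension names taken from the struct above; the per-tensor interpretations are my reading of absorbed MLA, not stated in the commit:

// Shape contract enforced by Args::meta():
//
//   q      : [nh, seq, dkv]   absorbed query, projected into the kv latent space
//   kv     : [nh, att, dkv]   compressed kv cache
//   absorb : [nh, dv,  dkv]   value up-projection applied after attention
//   qr     : [nh, seq, dr ]   decoupled rotary part of the query
//   kr     : [nh, att, dr ]   decoupled rotary part of the key
//   o      : [nh, seq, dv ]   attention output
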
operators/src/attention_mla/common_cpu.rs

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
impl_op!(common_cpu, Cpu);

operators/src/attention_mla/cuda.rs

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
impl_op!(cuda, Gpu);

operators/src/attention_mla/infini.rs

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
impl_op!(infini, Device);

operators/src/attention_mla/mod.rs

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
mod args;
mod operator;

pub use args::Args;

crate::op_trait!(AttentionMLA);

macro_rules! impl_op {
    ($dev:ident, $proc:ident) => {
        pub type Operator = super::operator::Operator<
            crate::$dev::$proc,
            crate::mat_mul::$dev::Operator,
            crate::fuesd_softmax::$dev::Operator,
            crate::rearrange::$dev::Operator,
        >;
    };
}

#[cfg(any(use_cpu, test))]
pub mod common_cpu;
#[cfg(use_cuda)]
pub mod cuda;
#[cfg(use_infini)]
pub mod infini;
#[cfg(use_cl)]
pub mod opencl;
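Each backend file simply invokes this macro with its module name and device type. As a reference for readers, the `impl_op!(common_cpu, Cpu)` call in common_cpu.rs expands to roughly the following (hand-expanded here, not part of the commit; the other backends follow the same pattern with their own device types):

pub type Operator = super::operator::Operator<
    crate::common_cpu::Cpu,
    crate::mat_mul::common_cpu::Operator,
    crate::fuesd_softmax::common_cpu::Operator,
    crate::rearrange::common_cpu::Operator,
>;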

operators/src/attention_mla/opencl.rs

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
impl_op!(opencl, ClDevice);
operators/src/attention_mla/operator.rs

Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
use super::{args::Meta, Args, AttentionMLA};
use crate::{
    dyn_, fuesd_softmax, get_static, mat_mul, rearrange, ByteOf, Hardware, LaunchError, QueueAlloc,
    SchemeError, TensorLayout, Workspace, WorkspaceCollector,
};
use ndarray_layout::ArrayLayout;
use std::marker::PhantomData;

pub struct Operator<Hardware, MatMul, Softmax, Rearrange> {
    mat_mul: MatMul,
    softmax: Softmax,
    rearrange: Rearrange,
    _phantom: PhantomData<Hardware>,
}

impl<H, M, S, R> AttentionMLA<H> for Operator<H, M, S, R>
where
    H: Hardware,
    M: mat_mul::MatMul<H>,
    S: fuesd_softmax::FusedSoftmax<H>,
    R: rearrange::Rearrange<H>,
{
}

impl<H, M, S, R> crate::Operator for Operator<H, M, S, R>
where
    H: Hardware,
    M: mat_mul::MatMul<H>,
    S: fuesd_softmax::FusedSoftmax<H>,
    R: rearrange::Rearrange<H>,
{
    type Hardware = H;
    type TopoNode = H;
    type Args = Args<H>;

    fn new(node: &Self::TopoNode) -> Self {
        Self {
            mat_mul: M::new(node),
            softmax: S::new(node),
            rearrange: R::new(node),
            _phantom: PhantomData,
        }
    }

    fn scheme(
        &mut self,
        args: &Self::Args,
        max_workspace_size: usize,
    ) -> Result<usize, SchemeError> {
        // TODO
        Ok(0)
    }

    fn launch<QA>(
        &self,
        args: &Self::Args,
        workspace: &mut [ByteOf<Self::Hardware>],
        queue_alloc: &QA,
    ) -> Result<(), LaunchError>
    where
        QA: QueueAlloc<Hardware = Self::Hardware>,
    {
        let Meta {
            dt,
            nh,
            seq,
            att,
            dkv,
            dv,
            dr,
        } = args.meta()?;
        let Args {
            q_layout,
            q_base,
            kv_layout,
            kv_base,
            absorb_layout,
            absorb_base,
            qr_layout,
            qr_base,
            kr_layout,
            kr_base,
            o_layout,
            o_base,
            mask,
        } = args;

        let &[nh_skv, att_skv, dkv_skv] = kv_layout.strides() else {
            unreachable!()
        };
        let &[nh_skr, att_skr, dr_skr] = kr_layout.strides() else {
            unreachable!()
        };
        let &[nh_sa, dv_sa, dkv_sa] = absorb_layout.strides() else {
            unreachable!()
        };
        let &[nh_so, seq_so, dv_so] = o_layout.strides() else {
            unreachable!()
        };
        let ele = dt.nbytes();
        get_static! {
            nh seq dkv dr
            nh_skv att_skv dkv_skv
            nh_skr att_skr dr_skr
            nh_sa dv_sa dkv_sa
            nh_so seq_so dv_so
            dv att
        };

        #[inline(always)]
        fn layout(shape: [usize; 3], strides: [isize; 3]) -> ArrayLayout<3> {
            ArrayLayout::new(&shape, &strides, 0)
        }
        let kv_first_layout = layout([nh, att, dkv], [nh_skv, att_skv, dkv_skv]).transpose(&[2, 1]);
        let kr_layout = layout([nh, att, dr], [nh_skr, att_skr, dr_skr]).transpose(&[2, 1]);
        let a_layout = layout([nh, dv, dkv], [nh_sa, dv_sa, dkv_sa]).transpose(&[2, 1]);
        let o_layout = layout([nh, seq, dv], [nh_so, seq_so, dv_so]).transpose(&[1, 0]);
        let att_w_layout = TensorLayout::new_contiguous(dt, &[nh, seq, att]);
        let attn_t_layout = TensorLayout::new_contiguous(dt, &[nh, seq, dkv]);
        let att_w_size = nh * seq * att * ele;
        let att_t_size = nh * seq * dkv * ele;
        let mut workspace = Workspace::new(queue_alloc, workspace, att_w_size + att_t_size);
        let (att_w_buf, workspace) = workspace.split_at_mut(att_w_size);
        let (attn_t_buf, workspace) = workspace.split_at_mut(att_t_size);

        let kv_first_layout =
            TensorLayout::new(dt, kv_first_layout.shape(), kv_first_layout.strides());
        let kr_layout = TensorLayout::new(dt, kr_layout.shape(), kr_layout.strides());
        let a_layout = TensorLayout::new(dt, a_layout.shape(), a_layout.strides());
        let o_layout = TensorLayout::new(dt, o_layout.shape(), o_layout.strides());
        // att_w = qr * kr^T + q * kv^T
        self.mat_mul.launch(
            &mat_mul::Args {
                c_layout: att_w_layout.clone(),
                c_base: att_w_buf.as_mut_ptr(),
                beta: 0.,
                a_layout: qr_layout.clone(),
                a_base: *qr_base,
                b_layout: kr_layout.clone(),
                b_base: *kr_base,
                alpha: ((dv + dr) as f32).sqrt().recip(),
            },
            workspace,
            queue_alloc,
        )?;
        self.mat_mul.launch(
            &mat_mul::Args {
                c_layout: att_w_layout.clone(),
                c_base: att_w_buf.as_mut_ptr(),
                beta: 1.,
                a_layout: q_layout.clone(),
                a_base: *q_base,
                b_layout: kv_first_layout.clone(),
                b_base: *kv_base,
                alpha: ((dv + dr) as f32).sqrt().recip(),
            },
            workspace,
            queue_alloc,
        )?;

        // att_w = softmax(att_w)
        self.softmax.launch(
            &fuesd_softmax::Args {
                att_mask: *mask,
                att_layout: att_w_layout.clone(),
                att_base: att_w_buf.as_mut_ptr(),
            },
            workspace,
            queue_alloc,
        )?;
        // attn_t = att_w * kv
        self.mat_mul.launch(
            &mat_mul::Args {
                c_layout: attn_t_layout.clone(),
                c_base: attn_t_buf.as_mut_ptr(),
                beta: 0.,
                a_layout: att_w_layout.clone(),
                a_base: att_w_buf.as_ptr(),
                b_layout: kv_layout.clone(),
                b_base: *kv_base,
                alpha: 1.,
            },
            workspace,
            queue_alloc,
        )?;

        // o = attn_t * absorb^T
        self.mat_mul.launch(
            &mat_mul::Args {
                c_layout: o_layout.clone(),
                c_base: *o_base,
                beta: 0.,
                a_layout: attn_t_layout.clone(),
                a_base: attn_t_buf.as_ptr(),
                b_layout: a_layout.clone(),
                b_base: *absorb_base,
                alpha: 1.,
            },
            workspace,
            queue_alloc,
        )?;

        Ok(())
    }
}
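Taken together, the four launches above implement one absorbed-MLA attention step. In the notation of the code comments (symbols chosen here for readability, not from the commit), with Q the absorbed query, C the compressed kv cache, Q_r/K_r the rotary parts, and A the absorb matrix, each head computes:

\[
W = \operatorname{softmax}\!\left(\frac{Q_r K_r^{\top} + Q\,C^{\top}}{\sqrt{d_v + d_r}}\right),
\qquad T = W\,C,
\qquad O = T\,A^{\top}
\]

The first two `mat_mul` calls accumulate both score terms into the same `att_w` workspace buffer (the second uses `beta: 1.`), the fused softmax applies `mask` in place, and the last two products keep the result in the `dkv`-dimensional latent space until the absorb matrix maps it to `dv`.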

operators/src/lib.rs

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ pub mod add_rows;
 pub mod all_reduce;
 pub mod attention;
 pub mod attention_kv_cached;
+pub mod attention_mla;
 pub mod broadcast;
 pub mod conv;
 pub mod fuesd_softmax;

operators/src/rope/common_cpu/mod.rs

Lines changed: 1 addition & 2 deletions
@@ -1,12 +1,11 @@
-use std::ptr::null;
-
 use super::{args::Meta, args::RopeType as R, fill_pos, Args, Rope, Seq, SinCosTable};
 use crate::{
     common_cpu::Cpu, get_static, strides_not_support, ByteOf, LaunchError, QueueAlloc, SchemeError,
     Unsigned,
 };
 use digit_layout::{types as ty, DigitLayout};
 use half::f16;
+use std::ptr::null;
 #[derive(Copy, Clone)]
 enum NtkPartsType {
     None,
