Skip to content

Commit e754707

Browse files
committed
fix: 添加更多rope
1 parent f4a83f7 commit e754707

File tree

9 files changed

+219
-42
lines changed

9 files changed

+219
-42
lines changed

operators/src/attention/args.rs

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@ pub(super) struct Meta {
3434
}
3535

3636
impl<H: Hardware> Args<H> {
37+
#[allow(clippy::too_many_arguments)]
3738
pub(crate) fn new_null(
3839
mask: AttnMask,
3940
dt: DigitLayout,
@@ -53,9 +54,9 @@ impl<H: Hardware> Args<H> {
5354
q_base: null_mut(),
5455
k_layout: k_layout.clone(),
5556
k_base: null(),
56-
v_layout: v_layout,
57+
v_layout,
5758
v_base: null(),
58-
o_layout: o_layout,
59+
o_layout,
5960
o_base: null_mut(),
6061
mask,
6162
}

operators/src/attention/operator.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,7 @@ where
6565
} = args;
6666

6767
// 如果不能保证 nh seq att dh 已知,用任意值初始化算子
68-
let (Some(&nh), Some(&seq), Some(&att), Some(&dh), Some(&dv)) = (
68+
let (Some(&nh), Some(&seq), Some(&att), Some(&dh), Some(&_dv)) = (
6969
nh.get_static(),
7070
seq.get_static(),
7171
att.get_static(),
@@ -194,7 +194,7 @@ where
194194
let (att_buf, workspace) = workspace.split_at_mut(att_size);
195195

196196
let head_group = nh / nkvh;
197-
let (q_layout, qx_layout, q_base) = match qx {
197+
let (_q_layout, qx_layout, q_base) = match qx {
198198
None => {
199199
let q_layout = TensorLayout::new_contiguous(dt, &[nh, seq, dh]);
200200
let qx_layout = TensorLayout::new_contiguous(dt, &[nkvh, head_group * seq, dh]);

operators/src/rearrange/args.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -268,13 +268,13 @@ fn test_scheme() {
268268
dst_layout: TensorLayout::new(
269269
F16,
270270
&shape,
271-
&[33554432 * 2, 16777216 * 2, 524288 * 2, 128 * 2, 1 * 2],
271+
&[33554432 * 2, 16777216 * 2, 524288 * 2, 128 * 2, 2],
272272
),
273273
dst_base: null_mut(),
274274
src_layout: TensorLayout::new(
275275
F16,
276276
&shape,
277-
&[33554432 * 2, 16777216 * 2, 524288 * 2, 128 * 2, 1 * 2],
277+
&[33554432 * 2, 16777216 * 2, 524288 * 2, 128 * 2, 2],
278278
),
279279
src_base: null(),
280280
};

operators/src/rope/args.rs

Lines changed: 37 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,45 @@
1-
use crate::{
1+
use crate::{
22
type_not_support,
33
utils::{dim_distinct, rank_error},
44
ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout,
55
};
66
use digit_layout::DigitLayout;
77

8+
/// Selects which rotary-position-embedding variant an operator should apply.
///
/// The per-variant notes below describe how the CPU backend
/// (`rope/common_cpu/mod.rs`) interprets each variant; other backends are not
/// visible here and may differ — TODO confirm.
pub enum RopeType<H: Hardware> {
9+
// The variants below share a single Scheme.
10+
/// Plain RoPE: the CPU backend uses `theta` unchanged and a position
/// multiplier of 1.
Rope,
11+
/// The CPU backend keeps `theta` unchanged and multiplies the position by `s`.
Pi {
12+
s: f32,
13+
},
14+
/// The CPU backend multiplies `theta` by `s`; the position multiplier stays 1.
Ntk {
15+
s: f32,
16+
},
17+
/// The CPU backend multiplies `theta` by `a * s - a + 1`; the position
/// multiplier stays 1.
Dyn {
18+
s: f32,
19+
a: f32,
20+
},
21+
22+
// The variants below share a single Scheme.
/// Per-frequency blending: the CPU backend computes a ramp value clamped to
/// `[0, 1]` from `l0`, `alpha` and `beta`, and scales the angle by
/// `(1 - r) / s + r`.
23+
NtkParts {
24+
alpha: f32,
25+
beta: f32,
26+
l0: f32,
27+
s: f32,
28+
},
29+
/// Same computation as `NtkParts` in the CPU backend, except the position is
/// additionally multiplied by `0.1 * ln(s) + 1`.
Yarn {
30+
alpha: f32,
31+
beta: f32,
32+
l0: f32,
33+
s: f32,
34+
},
35+
/// Uses two externally supplied factor buffers. The CPU backend reads from
/// `short` when the position is below `origin_pos` and from `long` otherwise,
/// and scales sin/cos by
/// `1 + sqrt(ln(max_pos / origin_pos) / ln(origin_pos))`.
/// NOTE(review): element type and length of the `long`/`short` buffers are
/// not visible in this definition — confirm against the callers.
Long {
36+
long: ConstPtr<H>,
37+
short: ConstPtr<H>,
38+
max_pos: u32,
39+
origin_pos: u32,
40+
},
41+
}
42+
843
pub struct Args<H: Hardware> {
944
pub t_layout: TensorLayout,
1045
pub t_base: MutPtr<H>,
@@ -15,6 +50,7 @@ pub struct Args<H: Hardware> {
1550
pub cos_layout: TensorLayout,
1651
pub cos_base: ConstPtr<H>,
1752
pub theta: f32,
53+
pub rope_type: RopeType<H>,
1854
}
1955

2056
pub(super) struct Meta {

operators/src/rope/common_cpu/mod.rs

Lines changed: 164 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,36 @@
1-
use super::{args::Meta, fill_pos, Args, Rope, Seq, SinCosTable};
1+
use super::{args::Meta, args::RopeType as R, fill_pos, Args, Rope, Seq, SinCosTable};
22
use crate::{
33
common_cpu::Cpu, get_static, strides_not_support, ByteOf, LaunchError, QueueAlloc, SchemeError,
44
Unsigned,
55
};
66
use digit_layout::{types as ty, DigitLayout};
77
use half::f16;
8+
use std::ptr::null;
9+
/// Distinguishes the two rope types that are lowered to `SchemeType::NtkParts`.
#[derive(Copy, Clone)]
10+
enum NtkPartsType {
11+
/// Produced for `RopeType::NtkParts`: the position is used as-is.
None,
12+
/// Produced for `RopeType::Yarn`: the position is multiplied by
/// `0.1 * ln(s) + 1` before the angle is computed.
Yarn,
13+
}
814

15+
/// CPU-side description of how rotation angles are computed; built from the
/// caller's `RopeType` when the operator is launched and stored in `Scheme`.
#[derive(Copy, Clone)]
16+
enum SchemeType {
17+
/// Used for `Rope`, `Pi`, `Ntk` and `Dyn`; `s` is the position multiplier
/// passed to `freq_sin_cos_rope` (1 for every type except `Pi`).
Rope {
18+
s: f32,
19+
},
20+
/// Used for `RopeType::Long`.
/// NOTE(review): the consumer casts `long`/`short` through the position type
/// and offsets them by `i * st` (the token stride of `t`) before reading a
/// factor — confirm this is the intended element type and indexing.
Long {
21+
// Factor buffer selected when the position is >= `origin_pos`.
long: *const u8,
22+
// Factor buffer selected when the position is < `origin_pos`.
short: *const u8,
23+
// Multiplier applied to both sin and cos in `freq_sin_cos_long`.
s: f32,
24+
// Position threshold that switches between `short` and `long`.
origin_pos: u32,
25+
},
26+
/// Used for `NtkParts` and `Yarn`; the fields are forwarded unchanged to
/// `freq_sin_cos_ntk_part`.
NtkParts {
27+
alpha: f32,
28+
beta: f32,
29+
l0: f32,
30+
s: f32,
31+
// Tells `freq_sin_cos_ntk_part` whether to apply the Yarn position scaling.
ntktype: NtkPartsType,
32+
},
33+
}
934
pub struct Operator;
1035

1136
impl Rope<Cpu> for Operator {
@@ -78,6 +103,7 @@ impl crate::Operator for Operator {
78103
p_layout,
79104
p_base,
80105
theta,
106+
rope_type,
81107
..
82108
} = args;
83109
let &[_, nh, dh] = t_layout.shape() else {
@@ -99,6 +125,50 @@ impl crate::Operator for Operator {
99125
return Err(strides_not_support("").into());
100126
}
101127

128+
let (theta, scheme_type) = match rope_type {
129+
R::Rope | R::Dyn { .. } | R::Ntk { .. } | R::Pi { .. } => {
130+
let (theta, s) = match rope_type {
131+
R::Rope => (*theta, 1.),
132+
R::Dyn { s, a } => (theta * (a * s - a + 1.), 1.),
133+
R::Ntk { s } => (theta * s, 1.),
134+
R::Pi { s } => (*theta, *s),
135+
_ => unreachable!(),
136+
};
137+
(theta, SchemeType::Rope { s })
138+
}
139+
R::Long {
140+
long,
141+
short,
142+
max_pos,
143+
origin_pos,
144+
} => {
145+
let s = 1.0
146+
+ ((*max_pos as f32 / *origin_pos as f32).ln() / (*origin_pos as f32).ln())
147+
.sqrt();
148+
let scheme_type = SchemeType::Long {
149+
long: long.cast(),
150+
short: short.cast(),
151+
s,
152+
origin_pos: *origin_pos,
153+
};
154+
(*theta, scheme_type)
155+
}
156+
R::Yarn { alpha, beta, l0, s } | R::NtkParts { alpha, beta, l0, s } => {
157+
let ntktype = match rope_type {
158+
R::NtkParts { .. } => NtkPartsType::None,
159+
R::Yarn { .. } => NtkPartsType::Yarn,
160+
_ => unreachable!(),
161+
};
162+
let scheme_type = SchemeType::NtkParts {
163+
alpha: *alpha,
164+
beta: *beta,
165+
l0: *l0,
166+
s: *s,
167+
ntktype,
168+
};
169+
(*theta, scheme_type)
170+
}
171+
};
102172
macro_rules! calculate {
103173
($t:ty, $p:ty) => {
104174
Scheme::<$t, $p> {
@@ -108,9 +178,10 @@ impl crate::Operator for Operator {
108178
st,
109179
sh,
110180
sp,
111-
theta: *theta,
181+
theta,
112182
t_base: t_base.cast(),
113183
p_base: p_base.cast(),
184+
scheme_type,
114185
}
115186
.calculate()
116187
};
@@ -142,15 +213,15 @@ struct Scheme<A, P> {
142213
theta: f32,
143214
t_base: *mut A,
144215
p_base: *const P,
216+
scheme_type: SchemeType,
145217
}
146218

147219
unsafe impl<A, P> Send for Scheme<A, P> {}
148220
unsafe impl<A, P> Sync for Scheme<A, P> {}
149-
150221
/// 激活值。
151222
trait Activation: Sized {
152223
/// 激活值类型决定计算类型。
153-
type Calculation;
224+
type Calculation: Copy;
154225
/// 计算流程。
155226
fn calculate(pair: &mut [Self; 2], sin: Self::Calculation, cos: Self::Calculation);
156227
}
@@ -187,15 +258,69 @@ impl Activation for f64 {
187258
}
188259

189260
trait Position<Calculation> {
190-
fn freq_sin_cos(self, k: isize, dh: isize, theta: f32) -> (Calculation, Calculation);
261+
fn freq_sin_cos_rope(
262+
self,
263+
k: isize,
264+
dh: isize,
265+
theta: f32,
266+
s: f32,
267+
) -> (Calculation, Calculation);
268+
fn freq_sin_cos_long(
269+
self,
270+
k: isize,
271+
dh: isize,
272+
t: f32,
273+
f: Calculation,
274+
s: f32,
275+
) -> (Calculation, Calculation);
276+
#[allow(clippy::too_many_arguments)]
277+
fn freq_sin_cos_ntk_part(
278+
self,
279+
k: isize,
280+
dh: isize,
281+
theta: f32,
282+
alpha: f32,
283+
beta: f32,
284+
l0: f32,
285+
s: f32,
286+
ntktype: NtkPartsType,
287+
) -> (Calculation, Calculation);
191288
}
192289

193290
macro_rules! impl_position {
194291
($a:ty) => {
195292
impl<T: Unsigned> Position<$a> for T {
196293
#[inline]
197-
fn freq_sin_cos(self, k: isize, dh: isize, theta: f32) -> ($a, $a) {
198-
(self.val() as $a / (theta as $a).powf(k as $a / dh as $a)).sin_cos()
294+
fn freq_sin_cos_rope(self, k: isize, dh: isize, theta: f32, s: f32) -> ($a, $a) {
295+
(self.val() as $a * s as $a * (theta as $a).powf(k as $a / dh as $a).recip())
296+
.sin_cos()
297+
}
298+
#[inline]
299+
fn freq_sin_cos_long(self, k: isize, dh: isize, t: f32, f: $a, s: f32) -> ($a, $a) {
300+
let (sin, cos) =
301+
(self.val() as $a * (t as $a).powf(k as $a / dh as $a).recip() * f).sin_cos();
302+
(sin * s as $a, cos * s as $a)
303+
}
304+
#[inline]
305+
fn freq_sin_cos_ntk_part(
306+
self,
307+
k: isize,
308+
dh: isize,
309+
theta: f32,
310+
alpha: f32,
311+
beta: f32,
312+
l0: f32,
313+
s: f32,
314+
ntktype: NtkPartsType,
315+
) -> ($a, $a) {
316+
use std::f32::consts::PI;
317+
let pos = match ntktype {
318+
NtkPartsType::None => self.val() as $a,
319+
NtkPartsType::Yarn => self.val() as $a * (0.1 * s.ln() + 1.) as $a,
320+
};
321+
let theta = theta.powf(k as f32 / dh as f32).recip();
322+
let r = ((l0 / (2. * PI / theta) - alpha) / (beta - alpha)).clamp(0., 1.);
323+
(pos * ((1. - r) / s + r) as $a * theta as $a).sin_cos()
199324
}
200325
}
201326
};
@@ -206,8 +331,8 @@ impl_position!(f64);
206331

207332
impl<A, P> Scheme<A, P>
208333
where
209-
A: Activation,
210-
P: Position<A::Calculation> + Sync + Copy,
334+
A: Activation + Copy,
335+
P: Position<A::Calculation> + Sync + Copy + Unsigned,
211336
{
212337
fn calculate(&self) {
213338
let &Self {
@@ -220,6 +345,7 @@ where
220345
theta,
221346
t_base,
222347
p_base,
348+
scheme_type,
223349
} = self;
224350
let nt = nt as isize;
225351
let nh = nh as isize;
@@ -229,10 +355,38 @@ where
229355
for i in 0..nt {
230356
let t = unsafe { t_base.byte_offset(i * st).cast::<[A; 2]>() };
231357
let p = unsafe { *p_base.byte_offset(i * sp) };
358+
let factor = match scheme_type {
359+
SchemeType::Long {
360+
long,
361+
short,
362+
origin_pos,
363+
..
364+
} => unsafe {
365+
if p.val() < origin_pos as usize {
366+
(short as *const P).byte_offset(i * st).cast()
367+
} else {
368+
(long as *const P).byte_offset(i * st).cast()
369+
}
370+
},
371+
_ => null(),
372+
};
232373
for j in 0..nh {
233374
for k in 0..dh {
234375
let pair = unsafe { &mut *t.byte_offset(j * sh + k * sd) };
235-
let (sin, cos) = p.freq_sin_cos(k, dh, theta);
376+
let (sin, cos) = match scheme_type {
377+
SchemeType::Rope { s } => p.freq_sin_cos_rope(k, dh, theta, s),
378+
SchemeType::Long { s, .. } => {
379+
let factor = unsafe { *factor };
380+
p.freq_sin_cos_long(k, dh, theta, factor, s)
381+
}
382+
SchemeType::NtkParts {
383+
alpha,
384+
beta,
385+
l0,
386+
s,
387+
ntktype,
388+
} => p.freq_sin_cos_ntk_part(k, dh, theta, alpha, beta, l0, s, ntktype),
389+
};
236390
A::calculate(pair, sin, cos)
237391
}
238392
}

operators/src/rope/cuda/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -184,7 +184,7 @@ extern "C" __global__ void {POS_U64}(
184184
#[cfg(test)]
185185
mod test {
186186
use super::{Args, Gpu, Operator, POS_U32, POS_U64};
187-
use crate::{Hardware, Operator as _, TensorLayout};
187+
use crate::{rope::args, Hardware, Operator as _, TensorLayout};
188188
use digit_layout::{
189189
types::{F16, F64, U32},
190190
DigitLayout,
@@ -203,6 +203,7 @@ mod test {
203203
cos_layout: TensorLayout::new_dyn(dt_t, &[dyn_(); 2], &[dyn_(); 2]),
204204
cos_base: null(),
205205
theta: 0.,
206+
rope_type: args::RopeType::Rope,
206207
}
207208
}
208209

@@ -227,6 +228,7 @@ mod test {
227228
cos_layout: TensorLayout::new_contiguous(dt_t, &[0, dh]),
228229
cos_base: null(),
229230
theta,
231+
rope_type: args::RopeType::Rope,
230232
}
231233
}
232234

0 commit comments

Comments
 (0)