Skip to content

Commit ab65dfa

Browse files
committed
fix: 添加更多rope
1 parent f4a83f7 commit ab65dfa

File tree

9 files changed

+296
-65
lines changed

9 files changed

+296
-65
lines changed

operators/src/attention/args.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ pub(super) struct Meta {
3434
}
3535

3636
impl<H: Hardware> Args<H> {
37+
#[allow(clippy::too_many_arguments)]
3738
pub(crate) fn new_null(
3839
mask: AttnMask,
3940
dt: DigitLayout,
@@ -53,9 +54,9 @@ impl<H: Hardware> Args<H> {
5354
q_base: null_mut(),
5455
k_layout: k_layout.clone(),
5556
k_base: null(),
56-
v_layout: v_layout,
57+
v_layout,
5758
v_base: null(),
58-
o_layout: o_layout,
59+
o_layout,
5960
o_base: null_mut(),
6061
mask,
6162
}

operators/src/attention/operator.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ where
6565
} = args;
6666

6767
// 如果不能保证 nh seq att dh 已知,用任意值初始化算子
68-
let (Some(&nh), Some(&seq), Some(&att), Some(&dh), Some(&dv)) = (
68+
let (Some(&nh), Some(&seq), Some(&att), Some(&dh), Some(&_dv)) = (
6969
nh.get_static(),
7070
seq.get_static(),
7171
att.get_static(),
@@ -194,7 +194,7 @@ where
194194
let (att_buf, workspace) = workspace.split_at_mut(att_size);
195195

196196
let head_group = nh / nkvh;
197-
let (q_layout, qx_layout, q_base) = match qx {
197+
let (_q_layout, qx_layout, q_base) = match qx {
198198
None => {
199199
let q_layout = TensorLayout::new_contiguous(dt, &[nh, seq, dh]);
200200
let qx_layout = TensorLayout::new_contiguous(dt, &[nkvh, head_group * seq, dh]);

operators/src/rearrange/args.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,13 +268,13 @@ fn test_scheme() {
268268
dst_layout: TensorLayout::new(
269269
F16,
270270
&shape,
271-
&[33554432 * 2, 16777216 * 2, 524288 * 2, 128 * 2, 1 * 2],
271+
&[33554432 * 2, 16777216 * 2, 524288 * 2, 128 * 2, 2],
272272
),
273273
dst_base: null_mut(),
274274
src_layout: TensorLayout::new(
275275
F16,
276276
&shape,
277-
&[33554432 * 2, 16777216 * 2, 524288 * 2, 128 * 2, 1 * 2],
277+
&[33554432 * 2, 16777216 * 2, 524288 * 2, 128 * 2, 2],
278278
),
279279
src_base: null(),
280280
};

operators/src/rope/args.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,45 @@
1-
use crate::{
1+
use crate::{
22
type_not_support,
33
utils::{dim_distinct, rank_error},
44
ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout,
55
};
66
use digit_layout::DigitLayout;
77

8+
pub enum RopeType<H: Hardware> {
9+
// 以下枚举通用一个 Scheme
10+
Rope,
11+
Pi {
12+
s: f32,
13+
},
14+
Ntk {
15+
s: f32,
16+
},
17+
Dyn {
18+
s: f32,
19+
a: f32,
20+
},
21+
22+
// 以下枚举通用一个 Scheme
23+
NtkParts {
24+
alpha: f32,
25+
beta: f32,
26+
l0: f32,
27+
s: f32,
28+
},
29+
Yarn {
30+
alpha: f32,
31+
beta: f32,
32+
l0: f32,
33+
s: f32,
34+
},
35+
Long {
36+
long: ConstPtr<H>,
37+
short: ConstPtr<H>,
38+
max_pos: u32,
39+
origin_pos: u32,
40+
},
41+
}
42+
843
pub struct Args<H: Hardware> {
944
pub t_layout: TensorLayout,
1045
pub t_base: MutPtr<H>,
@@ -15,6 +50,7 @@ pub struct Args<H: Hardware> {
1550
pub cos_layout: TensorLayout,
1651
pub cos_base: ConstPtr<H>,
1752
pub theta: f32,
53+
pub rope_type: RopeType<H>,
1854
}
1955

2056
pub(super) struct Meta {

0 commit comments

Comments
 (0)