1
- use super :: { args:: Meta , fill_pos, Args , Rope , Seq , SinCosTable } ;
1
+ use super :: { args:: Meta , args :: RopeType as R , fill_pos, Args , Rope , Seq , SinCosTable } ;
2
2
use crate :: {
3
3
common_cpu:: Cpu , get_static, strides_not_support, ByteOf , LaunchError , QueueAlloc , SchemeError ,
4
4
Unsigned ,
5
5
} ;
6
6
use digit_layout:: { types as ty, DigitLayout } ;
7
7
use half:: f16;
8
+ use std:: ptr:: null;
9
/// Selects which flavour of the "NTK by parts" angle computation is used.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum NtkPartsType {
    /// Plain NTK-by-parts: the position value is used unchanged.
    None,
    /// YaRN: the position value is additionally multiplied by `0.1 * ln(s) + 1`.
    Yarn,
}
8
14
15
+ #[ derive( Copy , Clone ) ]
16
+ enum SchemeType {
17
+ Rope {
18
+ s : f32 ,
19
+ } ,
20
+ Long {
21
+ long : * const u8 ,
22
+ short : * const u8 ,
23
+ s : f32 ,
24
+ origin_pos : u32 ,
25
+ } ,
26
+ NtkParts {
27
+ alpha : f32 ,
28
+ beta : f32 ,
29
+ l0 : f32 ,
30
+ s : f32 ,
31
+ ntktype : NtkPartsType ,
32
+ } ,
33
+ }
9
34
/// Unit type naming this operator; `Rope<Cpu>` is implemented for it in this file.
pub struct Operator ;
10
35
11
36
impl Rope < Cpu > for Operator {
@@ -78,6 +103,7 @@ impl crate::Operator for Operator {
78
103
p_layout,
79
104
p_base,
80
105
theta,
106
+ rope_type,
81
107
..
82
108
} = args;
83
109
let & [ _, nh, dh] = t_layout. shape ( ) else {
@@ -99,6 +125,50 @@ impl crate::Operator for Operator {
99
125
return Err ( strides_not_support ( "" ) . into ( ) ) ;
100
126
}
101
127
128
+ let ( theta, scheme_type) = match rope_type {
129
+ R :: Rope | R :: Dyn { .. } | R :: Ntk { .. } | R :: Pi { .. } => {
130
+ let ( theta, s) = match rope_type {
131
+ R :: Rope => ( * theta, 1. ) ,
132
+ R :: Dyn { s, a } => ( theta * ( a * s - a + 1. ) , 1. ) ,
133
+ R :: Ntk { s } => ( theta * s, 1. ) ,
134
+ R :: Pi { s } => ( * theta, * s) ,
135
+ _ => unreachable ! ( ) ,
136
+ } ;
137
+ ( theta, SchemeType :: Rope { s } )
138
+ }
139
+ R :: Long {
140
+ long,
141
+ short,
142
+ max_pos,
143
+ origin_pos,
144
+ } => {
145
+ let s = 1.0
146
+ + ( ( * max_pos as f32 / * origin_pos as f32 ) . ln ( ) / ( * origin_pos as f32 ) . ln ( ) )
147
+ . sqrt ( ) ;
148
+ let scheme_type = SchemeType :: Long {
149
+ long : long. cast ( ) ,
150
+ short : short. cast ( ) ,
151
+ s,
152
+ origin_pos : * origin_pos,
153
+ } ;
154
+ ( * theta, scheme_type)
155
+ }
156
+ R :: Yarn { alpha, beta, l0, s } | R :: NtkParts { alpha, beta, l0, s } => {
157
+ let ntktype = match rope_type {
158
+ R :: NtkParts { .. } => NtkPartsType :: None ,
159
+ R :: Yarn { .. } => NtkPartsType :: Yarn ,
160
+ _ => unreachable ! ( ) ,
161
+ } ;
162
+ let scheme_type = SchemeType :: NtkParts {
163
+ alpha : * alpha,
164
+ beta : * beta,
165
+ l0 : * l0,
166
+ s : * s,
167
+ ntktype,
168
+ } ;
169
+ ( * theta, scheme_type)
170
+ }
171
+ } ;
102
172
macro_rules! calculate {
103
173
( $t: ty, $p: ty) => {
104
174
Scheme :: <$t, $p> {
@@ -108,9 +178,10 @@ impl crate::Operator for Operator {
108
178
st,
109
179
sh,
110
180
sp,
111
- theta: * theta ,
181
+ theta,
112
182
t_base: t_base. cast( ) ,
113
183
p_base: p_base. cast( ) ,
184
+ scheme_type,
114
185
}
115
186
. calculate( )
116
187
} ;
@@ -142,15 +213,15 @@ struct Scheme<A, P> {
142
213
theta : f32 ,
143
214
t_base : * mut A ,
144
215
p_base : * const P ,
216
+ scheme_type : SchemeType ,
145
217
}
146
218
147
219
// `Scheme` stores raw pointers (`t_base`, `p_base`), so it is not `Send`/`Sync` automatically;
// these impls opt in by hand for every `A` and `P`.
// NOTE(review): presumably sound because the caller keeps the pointed-to buffers alive and
// free of conflicting access while the scheme is in use — confirm.
unsafe impl < A , P > Send for Scheme < A , P > { }
148
220
unsafe impl < A , P > Sync for Scheme < A , P > { }
149
-
150
221
/// An activation value.
trait Activation: Sized {
    /// The activation type determines the type used for the calculation.
    type Calculation: Copy;
    /// The calculation procedure, applied to one `pair` of values with the given `sin` and `cos`.
    fn calculate(pair: &mut [Self; 2], sin: Self::Calculation, cos: Self::Calculation);
}
@@ -187,15 +258,69 @@ impl Activation for f64 {
187
258
}
188
259
189
260
trait Position < Calculation > {
190
- fn freq_sin_cos ( self , k : isize , dh : isize , theta : f32 ) -> ( Calculation , Calculation ) ;
261
+ fn freq_sin_cos_rope (
262
+ self ,
263
+ k : isize ,
264
+ dh : isize ,
265
+ theta : f32 ,
266
+ s : f32 ,
267
+ ) -> ( Calculation , Calculation ) ;
268
+ fn freq_sin_cos_long (
269
+ self ,
270
+ k : isize ,
271
+ dh : isize ,
272
+ t : f32 ,
273
+ f : Calculation ,
274
+ s : f32 ,
275
+ ) -> ( Calculation , Calculation ) ;
276
+ #[ allow( clippy:: too_many_arguments) ]
277
+ fn freq_sin_cos_ntk_part (
278
+ self ,
279
+ k : isize ,
280
+ dh : isize ,
281
+ theta : f32 ,
282
+ alpha : f32 ,
283
+ beta : f32 ,
284
+ l0 : f32 ,
285
+ s : f32 ,
286
+ ntktype : NtkPartsType ,
287
+ ) -> ( Calculation , Calculation ) ;
191
288
}
192
289
193
290
/// Implements `Position<$a>` for every `Unsigned` position type, computing in `$a`.
macro_rules! impl_position {
    ($a:ty) => {
        impl<T: Unsigned> Position<$a> for T {
            // angle = pos * s / theta^(k / dh)
            #[inline]
            fn freq_sin_cos_rope(self, k: isize, dh: isize, theta: f32, s: f32) -> ($a, $a) {
                (self.val() as $a * s as $a * (theta as $a).powf(k as $a / dh as $a).recip())
                    .sin_cos()
            }

            // angle = pos * f / t^(k / dh); sine and cosine are then both scaled by `s`.
            #[inline]
            fn freq_sin_cos_long(self, k: isize, dh: isize, t: f32, f: $a, s: f32) -> ($a, $a) {
                let (sin, cos) =
                    (self.val() as $a * (t as $a).powf(k as $a / dh as $a).recip() * f).sin_cos();
                (sin * s as $a, cos * s as $a)
            }

            #[inline]
            fn freq_sin_cos_ntk_part(
                self,
                k: isize,
                dh: isize,
                theta: f32,
                alpha: f32,
                beta: f32,
                l0: f32,
                s: f32,
                ntktype: NtkPartsType,
            ) -> ($a, $a) {
                use std::f32::consts::PI;

                // The YaRN flavour stretches the position by `0.1 * ln(s) + 1`.
                let pos = match ntktype {
                    NtkPartsType::None => self.val() as $a,
                    NtkPartsType::Yarn => self.val() as $a * (0.1 * s.ln() + 1.) as $a,
                };
                // Frequency and ramp are evaluated in `f32` and widened to `$a` afterwards.
                let freq = theta.powf(k as f32 / dh as f32).recip();
                // `2π / freq` is the wavelength; the ramp is clamped to `[0, 1]`.
                let r = ((l0 / (2. * PI / freq) - alpha) / (beta - alpha)).clamp(0., 1.);
                // r == 0 divides the frequency by `s`; r == 1 leaves it unchanged.
                (pos * ((1. - r) / s + r) as $a * freq as $a).sin_cos()
            }
        }
    };
}
@@ -206,8 +331,8 @@ impl_position!(f64);
206
331
207
332
impl < A , P > Scheme < A , P >
208
333
where
209
- A : Activation ,
210
- P : Position < A :: Calculation > + Sync + Copy ,
334
+ A : Activation + Copy ,
335
+ P : Position < A :: Calculation > + Sync + Copy + Unsigned ,
211
336
{
212
337
fn calculate ( & self ) {
213
338
let & Self {
@@ -220,6 +345,7 @@ where
220
345
theta,
221
346
t_base,
222
347
p_base,
348
+ scheme_type,
223
349
} = self ;
224
350
let nt = nt as isize ;
225
351
let nh = nh as isize ;
@@ -229,10 +355,38 @@ where
229
355
for i in 0 ..nt {
230
356
let t = unsafe { t_base. byte_offset ( i * st) . cast :: < [ A ; 2 ] > ( ) } ;
231
357
let p = unsafe { * p_base. byte_offset ( i * sp) } ;
358
+ let factor = match scheme_type {
359
+ SchemeType :: Long {
360
+ long,
361
+ short,
362
+ origin_pos,
363
+ ..
364
+ } => unsafe {
365
+ if p. val ( ) < origin_pos as usize {
366
+ ( short as * const P ) . byte_offset ( i * st) . cast ( )
367
+ } else {
368
+ ( long as * const P ) . byte_offset ( i * st) . cast ( )
369
+ }
370
+ } ,
371
+ _ => null ( ) ,
372
+ } ;
232
373
for j in 0 ..nh {
233
374
for k in 0 ..dh {
234
375
let pair = unsafe { & mut * t. byte_offset ( j * sh + k * sd) } ;
235
- let ( sin, cos) = p. freq_sin_cos ( k, dh, theta) ;
376
+ let ( sin, cos) = match scheme_type {
377
+ SchemeType :: Rope { s } => p. freq_sin_cos_rope ( k, dh, theta, s) ,
378
+ SchemeType :: Long { s, .. } => {
379
+ let factor = unsafe { * factor } ;
380
+ p. freq_sin_cos_long ( k, dh, theta, factor, s)
381
+ }
382
+ SchemeType :: NtkParts {
383
+ alpha,
384
+ beta,
385
+ l0,
386
+ s,
387
+ ntktype,
388
+ } => p. freq_sin_cos_ntk_part ( k, dh, theta, alpha, beta, l0, s, ntktype) ,
389
+ } ;
236
390
A :: calculate ( pair, sin, cos)
237
391
}
238
392
}
0 commit comments