@@ -226,6 +226,7 @@ impl crate::Operator for Operator {
226
226
let dst_num_per_block = dst_num_per_block. floor ( ) as usize ;
227
227
let src_num_per_grid = src_current_dim_len. div_ceil ( src_num_per_block) ;
228
228
let dst_num_per_grid = dst_current_dim_len. div_ceil ( dst_num_per_block) ;
229
+
229
230
if src_num_per_block > 1 {
230
231
split_dims. push ( SplitDim {
231
232
choose_idx : src_idx,
@@ -504,22 +505,28 @@ mod test {
504
505
cpu_op. scheme ( & dyn_args ( dt) , 0 ) . unwrap ( ) ;
505
506
gpu_op. scheme ( & dyn_args ( dt) , 0 ) . unwrap ( ) ;
506
507
507
- let nh = 100 ;
508
- let seq = 3343 ;
509
- let dh = 100 ;
510
- let mut src = vec ! [ 0u64 ; nh * seq * dh] ;
508
+ const N : usize = 5 ;
509
+ const TRANS_N : usize = 3 ;
510
+ let shape: [ usize ; N ] = [ 2232 , 3 , 7 , 9 , 4 ] ;
511
+ let mut r_shape: [ usize ; N ] = shape. clone ( ) ;
512
+ r_shape[ 0 ..TRANS_N ] . reverse ( ) ;
513
+
514
+ let trans_param: [ usize ; TRANS_N ] =
515
+ ( 0 ..TRANS_N ) . rev ( ) . collect :: < Vec < _ > > ( ) . try_into ( ) . unwrap ( ) ;
516
+
517
+ let mut src = vec ! [ 0u64 ; shape. iter( ) . product:: <usize >( ) ] ;
511
518
rand:: rng ( ) . fill ( & mut src[ ..] ) ;
512
519
513
520
let ele = dt. nbytes ( ) ;
514
- let s_src = ArrayLayout :: < 3 > :: new_contiguous ( & [ nh , seq , dh ] , BigEndian , ele) ;
521
+ let s_src = ArrayLayout :: < 3 > :: new_contiguous ( & shape , BigEndian , ele) ;
515
522
let s_dst =
516
- ArrayLayout :: < 3 > :: new_contiguous ( & [ dh , seq , nh ] , BigEndian , ele) . transpose ( & [ 2 , 1 , 0 ] ) ;
523
+ ArrayLayout :: < 3 > :: new_contiguous ( & r_shape , BigEndian , ele) . transpose ( & trans_param ) ;
517
524
518
525
println ! ( "s_src: {:?}" , s_src. shape( ) ) ;
519
526
println ! ( "s_dst: {:?}" , s_dst. shape( ) ) ;
520
527
println ! ( "s_src strides: {:?}" , s_src. strides( ) ) ;
521
-
522
528
println ! ( "s_dst strides: {:?}" , s_dst. strides( ) ) ;
529
+
523
530
let dst_ans = gpu. apply ( |ctx| {
524
531
let stream = ctx. stream ( ) ;
525
532
#[ cfg( use_nvidia) ]
@@ -538,7 +545,7 @@ mod test {
538
545
. launch (
539
546
& args (
540
547
dt,
541
- & [ nh , seq , dh ] ,
548
+ & shape ,
542
549
s_src. strides ( ) ,
543
550
s_dst. strides ( ) ,
544
551
src. as_ptr ( ) . cast ( ) ,
@@ -556,7 +563,7 @@ mod test {
556
563
. launch (
557
564
& args (
558
565
dt,
559
- & [ nh , seq , dh ] ,
566
+ & shape ,
560
567
s_src. strides ( ) ,
561
568
s_dst. strides ( ) ,
562
569
src. as_ptr ( ) . cast ( ) ,
@@ -571,17 +578,17 @@ mod test {
571
578
let time = end_event. elapse_from ( & start_event) ;
572
579
println ! ( "time: {time:?}" ) ;
573
580
574
- let mut host = vec ! [ 0u64 ; nh * seq * dh ] ;
581
+ let mut host = vec ! [ 0u64 ; shape . iter ( ) . product :: < usize > ( ) ] ;
575
582
memcpy_d2h ( & mut host, & dst) ;
576
583
host
577
584
} ) ;
578
585
579
- let mut dst_ref = vec ! [ 0u64 ; nh * seq * dh ] ;
586
+ let mut dst_ref = vec ! [ 0u64 ; shape . iter ( ) . product :: < usize > ( ) ] ;
580
587
cpu_op
581
588
. launch (
582
589
& args (
583
590
dt,
584
- & [ nh , seq , dh ] ,
591
+ & shape ,
585
592
s_src. strides ( ) ,
586
593
s_dst. strides ( ) ,
587
594
src. as_ptr ( ) . cast ( ) ,
0 commit comments