Skip to content

Commit 0ea4faf

Browse files
committed
feat: 成功
1 parent dccabc0 commit 0ea4faf

File tree

2 files changed

+21
-14
lines changed

2 files changed

+21
-14
lines changed

operators/src/rearrange/cuda/mod.rs

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ impl crate::Operator for Operator {
226226
let dst_num_per_block = dst_num_per_block.floor() as usize;
227227
let src_num_per_grid = src_current_dim_len.div_ceil(src_num_per_block);
228228
let dst_num_per_grid = dst_current_dim_len.div_ceil(dst_num_per_block);
229+
229230
if src_num_per_block > 1 {
230231
split_dims.push(SplitDim {
231232
choose_idx: src_idx,
@@ -504,22 +505,28 @@ mod test {
504505
cpu_op.scheme(&dyn_args(dt), 0).unwrap();
505506
gpu_op.scheme(&dyn_args(dt), 0).unwrap();
506507

507-
let nh = 100;
508-
let seq = 3343;
509-
let dh = 100;
510-
let mut src = vec![0u64; nh * seq * dh];
508+
const N: usize = 5;
509+
const TRANS_N: usize = 3;
510+
let shape: [usize; N] = [2232, 3, 7, 9, 4];
511+
let mut r_shape: [usize; N] = shape.clone();
512+
r_shape[0..TRANS_N].reverse();
513+
514+
let trans_param: [usize; TRANS_N] =
515+
(0..TRANS_N).rev().collect::<Vec<_>>().try_into().unwrap();
516+
517+
let mut src = vec![0u64; shape.iter().product::<usize>()];
511518
rand::rng().fill(&mut src[..]);
512519

513520
let ele = dt.nbytes();
514-
let s_src = ArrayLayout::<3>::new_contiguous(&[nh, seq, dh], BigEndian, ele);
521+
let s_src = ArrayLayout::<3>::new_contiguous(&shape, BigEndian, ele);
515522
let s_dst =
516-
ArrayLayout::<3>::new_contiguous(&[dh, seq, nh], BigEndian, ele).transpose(&[2, 1, 0]);
523+
ArrayLayout::<3>::new_contiguous(&r_shape, BigEndian, ele).transpose(&trans_param);
517524

518525
println!("s_src: {:?}", s_src.shape());
519526
println!("s_dst: {:?}", s_dst.shape());
520527
println!("s_src strides: {:?}", s_src.strides());
521-
522528
println!("s_dst strides: {:?}", s_dst.strides());
529+
523530
let dst_ans = gpu.apply(|ctx| {
524531
let stream = ctx.stream();
525532
#[cfg(use_nvidia)]
@@ -538,7 +545,7 @@ mod test {
538545
.launch(
539546
&args(
540547
dt,
541-
&[nh, seq, dh],
548+
&shape,
542549
s_src.strides(),
543550
s_dst.strides(),
544551
src.as_ptr().cast(),
@@ -556,7 +563,7 @@ mod test {
556563
.launch(
557564
&args(
558565
dt,
559-
&[nh, seq, dh],
566+
&shape,
560567
s_src.strides(),
561568
s_dst.strides(),
562569
src.as_ptr().cast(),
@@ -571,17 +578,17 @@ mod test {
571578
let time = end_event.elapse_from(&start_event);
572579
println!("time: {time:?}");
573580

574-
let mut host = vec![0u64; nh * seq * dh];
581+
let mut host = vec![0u64; shape.iter().product::<usize>()];
575582
memcpy_d2h(&mut host, &dst);
576583
host
577584
});
578585

579-
let mut dst_ref = vec![0u64; nh * seq * dh];
586+
let mut dst_ref = vec![0u64; shape.iter().product::<usize>()];
580587
cpu_op
581588
.launch(
582589
&args(
583590
dt,
584-
&[nh, seq, dh],
591+
&shape,
585592
s_src.strides(),
586593
s_dst.strides(),
587594
src.as_ptr().cast(),

operators/src/rearrange/cuda/rearrange.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,10 @@ static __device__ void rearrange_1(
223223
dst_offset += idx * dst_grid_stride.a[i];
224224

225225
if (i == constrains1.a[0]) {
226-
shared_constrains1_grid_idx_multiple = idx;
226+
shared_constrains1_grid_idx_multiple = idx * constrains1.a[2];
227227
}
228228
if (i == constrains2.a[0]) {
229-
shared_constrains2_grid_idx_multiple = idx;
229+
shared_constrains2_grid_idx_multiple = idx * constrains2.a[2];
230230
}
231231

232232
// 将结果存入共享内存

0 commit comments

Comments
 (0)