
Commit de841a0

committed
refactor: tidy up the code
1 parent 0ea4faf commit de841a0

File tree

2 files changed: +48 −567 lines changed


operators/src/rearrange/cuda/mod.rs

Lines changed: 48 additions & 265 deletions
@@ -17,7 +17,7 @@ struct SplitDim {
     array_struct_idx_grid: ArrayType,
 }
 
-const ARRAY_SIZE: usize = 7;
+const ARRAY_SIZE: usize = 5;
 
 type ArrayType = i32;
 #[derive(Debug)]
@@ -35,12 +35,15 @@ impl<const N: usize> ArrayStruct<N> {
         Some(Self(array))
     }
 }
-//TODO need to verify correctness
+
 impl<const N: usize> AsParam for ArrayStruct<N> {}
 
+//TODO should be computed from max_warps_block and warp_size
 pub struct Operator {
     _handle: Arc<Handle>,
+    #[allow(unused)]
     max_warps_block: usize,
+    #[allow(unused)]
     warp_size: usize,
     module: Arc<ModuleBox>,
 }
@@ -90,54 +93,6 @@ impl crate::Operator for Operator {
         QA: QueueAlloc<Hardware = Self::Hardware>,
     {
         let scheme = Scheme::new(args)?;
-        // if scheme.ndim() == 0 {
-        //     let unit = scheme.unit();
-        //     let dst = unsafe { from_raw_parts_mut(args.dst_base, unit) };
-        //     let src = unsafe { from_raw_parts(args.src_base, unit) };
-        //     queue_alloc.queue().memcpy_d2d(dst, src);
-        //     return Ok(());
-        // }
-
-        if scheme.ndim() == 0 {
-            let unit = unsafe { BARE_UNIT };
-            let len = scheme.unit();
-
-            let name = CString::new(NAME).unwrap();
-
-            // use a larger block size to improve parallelism
-            let block_size = 1024;
-
-            // total number of elements
-            let total_elements: u32 = (len / unit) as u32;
-
-            let grid_size = (total_elements + block_size - 1) / block_size;
-
-            let params = cuda::params![
-                args.dst_base,
-                0i32,           // rsa
-                0i32,           // csa
-                args.src_base,
-                0i32,           // rsb
-                0i32,           // csb
-                total_elements, // nrows
-                1u32,           // ncols
-                32u32,          // sub_size_x
-                32u32,          // sub_size_y
-                unit            // bytes_per_thread
-            ];
-
-            self.module.launch(
-                &name,
-                grid_size as u32,
-                block_size as u32,
-                params.as_ptr(),
-                0,
-                queue_alloc.queue(),
-            );
-            return Ok(());
-        }
-        //----------------------------------------------------------------------
-        // find the largest contiguous memory that can be read and written
 
         // find the largest amount of data a single thread can handle
         let scheme_update = scheme.distribute_unit((0..=5).rev().map(|n| (1 << n)));
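
Note on the branch removed above: it sized a one-dimensional launch by ceiling division, rounding the element count up to a whole number of blocks. A minimal, self-contained sketch of that arithmetic follows (illustrative values only; launch_dims is a hypothetical helper, not part of this crate):

// Ceiling-division launch sizing, as in the removed ndim == 0 branch.
// `launch_dims` is a hypothetical helper, shown for illustration only.
fn launch_dims(total_bytes: usize, unit: usize, block_size: u32) -> (u32, u32) {
    // each thread moves `unit` bytes, so the element count is total_bytes / unit
    let total_elements = (total_bytes / unit) as u32;
    // round up so a partially filled final block still covers the tail
    let grid_size = (total_elements + block_size - 1) / block_size;
    (grid_size, block_size)
}

fn main() {
    // e.g. 1 MiB moved 4 bytes per thread with 1024-thread blocks -> 256 blocks
    assert_eq!(launch_dims(1 << 20, 4, 1024), (256, 1024));
}
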
@@ -305,7 +260,6 @@ impl crate::Operator for Operator {
             }
         }
 
-        println!("split_dims: {:?}", split_dims);
         // prepare the cuda parameters
         let block_len_total = block_len.iter().product::<ArrayType>();
         let src_block_stride =
@@ -394,45 +348,59 @@ impl crate::Operator for Operator {
 fn format_code() -> String {
     format!(
         r#"#define ARRAY_SIZE {ARRAY_SIZE}
-    #define ARRAY_TYPE int
-    {CODE}
+#define ARRAY_TYPE int
+{CODE}
 
 extern "C" __global__ void {NAME}(
     void *__restrict__ dst,
     void const *__restrict__ src,
-    const int block_dim, // number of block dimensions
-    const int block_len_total, // product of the elements of block_len
-    const ArrayStruct<4, ARRAY_TYPE> constrains1, // constraint 1 on the split dimensions
-    const ArrayStruct<4, ARRAY_TYPE> constrains2, // constraint 2 on the split dimensions
-    const ArrayStruct<ARRAY_SIZE, ARRAY_TYPE> block_len, // length of each dimension
-    const ArrayStruct<ARRAY_SIZE, ARRAY_TYPE> src_block_stride, // strides of the source tensor in each dimension (bytes)
-    const ArrayStruct<ARRAY_SIZE, ARRAY_TYPE> dst_block_stride, // strides of the destination tensor in each dimension (bytes)
-    const ArrayStruct<ARRAY_SIZE, ARRAY_TYPE> grid_len, // length of each dimension
-    const ArrayStruct<ARRAY_SIZE, ARRAY_TYPE> src_grid_stride, // strides of the source tensor in each dimension (bytes)
-    const ArrayStruct<ARRAY_SIZE, ARRAY_TYPE> dst_grid_stride, // strides of the destination tensor in each dimension (bytes)
-    unsigned int const unit_size // bytes per element
+    const int block_dim,                                        // number of block dimensions
+    const int block_len_total,                                  // product of the elements of block_len
+    const ArrayStruct<4, ARRAY_TYPE> constrains1,               // constraint 1 on the split dimensions
+    const ArrayStruct<4, ARRAY_TYPE> constrains2,               // constraint 2 on the split dimensions
+    const ArrayStruct<ARRAY_SIZE, ARRAY_TYPE> block_len,        // length of each dimension
+    const ArrayStruct<ARRAY_SIZE, ARRAY_TYPE> src_block_stride, // strides of the source tensor in each dimension (bytes)
+    const ArrayStruct<ARRAY_SIZE, ARRAY_TYPE> dst_block_stride, // strides of the destination tensor in each dimension (bytes)
+    const ArrayStruct<ARRAY_SIZE, ARRAY_TYPE> grid_len,         // length of each dimension
+    const ArrayStruct<ARRAY_SIZE, ARRAY_TYPE> src_grid_stride,  // strides of the source tensor in each dimension (bytes)
+    const ArrayStruct<ARRAY_SIZE, ARRAY_TYPE> dst_grid_stride,  // strides of the destination tensor in each dimension (bytes)
+    unsigned int const unit_size                                // bytes per element
 ){{
     switch (unit_size) {{
-    case 1: rearrange_1<uchar1 ,ARRAY_SIZE, ARRAY_TYPE>(dst, src, block_dim, block_len_total, constrains1, constrains2, block_len, src_block_stride, dst_block_stride, grid_len, src_grid_stride, dst_grid_stride, unit_size); break;
-    case 2: rearrange_1<uchar2 ,ARRAY_SIZE, ARRAY_TYPE>(dst, src, block_dim, block_len_total, constrains1, constrains2, block_len, src_block_stride, dst_block_stride, grid_len, src_grid_stride, dst_grid_stride, unit_size); break;
-    case 4: rearrange_1<float1 ,ARRAY_SIZE, ARRAY_TYPE>(dst, src, block_dim, block_len_total, constrains1, constrains2, block_len, src_block_stride, dst_block_stride, grid_len, src_grid_stride, dst_grid_stride, unit_size); break;
-    case 8: rearrange_1<float2 ,ARRAY_SIZE, ARRAY_TYPE>(dst, src, block_dim, block_len_total, constrains1, constrains2, block_len, src_block_stride, dst_block_stride, grid_len, src_grid_stride, dst_grid_stride, unit_size); break;
-    case 16: rearrange_1<float4 ,ARRAY_SIZE, ARRAY_TYPE>(dst, src, block_dim, block_len_total, constrains1, constrains2, block_len, src_block_stride, dst_block_stride, grid_len, src_grid_stride, dst_grid_stride, unit_size); break;
-    case 32: rearrange_1<double4,ARRAY_SIZE, ARRAY_TYPE>(dst, src, block_dim, block_len_total, constrains1, constrains2, block_len, src_block_stride, dst_block_stride, grid_len, src_grid_stride, dst_grid_stride, unit_size); break;
+    case 1:
+        rearrange_1<uchar1 ,ARRAY_SIZE, ARRAY_TYPE>(dst, src, block_dim, block_len_total, constrains1, constrains2,
+            block_len, src_block_stride, dst_block_stride, grid_len, src_grid_stride, dst_grid_stride, unit_size);
+        break;
+    case 2:
+        rearrange_1<uchar2 ,ARRAY_SIZE, ARRAY_TYPE>(dst, src, block_dim, block_len_total, constrains1, constrains2,
+            block_len, src_block_stride, dst_block_stride, grid_len, src_grid_stride, dst_grid_stride, unit_size);
+        break;
+    case 4:
+        rearrange_1<float1 ,ARRAY_SIZE, ARRAY_TYPE>(dst, src, block_dim, block_len_total, constrains1, constrains2,
+            block_len, src_block_stride, dst_block_stride, grid_len, src_grid_stride, dst_grid_stride, unit_size);
+        break;
+    case 8:
+        rearrange_1<float2 ,ARRAY_SIZE, ARRAY_TYPE>(dst, src, block_dim, block_len_total, constrains1, constrains2,
+            block_len, src_block_stride, dst_block_stride, grid_len, src_grid_stride, dst_grid_stride, unit_size);
+        break;
+    case 16:
+        rearrange_1<float4 ,ARRAY_SIZE, ARRAY_TYPE>(dst, src, block_dim, block_len_total, constrains1, constrains2,
+            block_len, src_block_stride, dst_block_stride, grid_len, src_grid_stride, dst_grid_stride, unit_size);
+        break;
+    case 32:
+        rearrange_1<double4,ARRAY_SIZE, ARRAY_TYPE>(dst, src, block_dim, block_len_total, constrains1, constrains2,
+            block_len, src_block_stride, dst_block_stride, grid_len, src_grid_stride, dst_grid_stride, unit_size);
+        break;
     }}
 }}
 "#
     )
 }
 
-static mut IS_INVERSE: bool = false;
-static mut BARE_UNIT: usize = 4;
-
 #[cfg(test)]
 mod test {
     use super::{Args, Gpu, Operator};
     use crate::{ConstPtr, Hardware, MutPtr, Operator as _, TensorLayout};
-    use cuda::{DevMem, Ptx};
     use digit_layout::{types as ty, DigitLayout};
 
     fn dyn_args<H: Hardware>(dt: DigitLayout) -> Args<H> {
@@ -488,12 +456,13 @@ mod test {
     fn test_compute() {
         use super::super::common_cpu::Operator as RefOp;
         use crate::common_cpu::{Cpu, ThisThread};
-        use crate::rearrange::cuda::format_code;
+
         use cuda::memcpy_d2h;
         use ndarray_layout::{ArrayLayout, Endian::BigEndian};
         use rand::Rng;
-        let code = format_code();
-        std::fs::write("rearrange.cu", code).unwrap();
+        // use crate::rearrange::cuda::format_code;
+        // let code = format_code();
+        // std::fs::write("rearrange.cu", code).unwrap();
         let Some(gpu) = Gpu::init() else {
             return;
         };
@@ -522,8 +491,8 @@ mod test {
         let s_dst =
             ArrayLayout::<3>::new_contiguous(&r_shape, BigEndian, ele).transpose(&trans_param);
 
-        println!("s_src: {:?}", s_src.shape());
-        println!("s_dst: {:?}", s_dst.shape());
+        println!("s_src shape: {:?}", s_src.shape());
+        println!("s_dst shape: {:?}", s_dst.shape());
         println!("s_src strides: {:?}", s_src.strides());
         println!("s_dst strides: {:?}", s_dst.strides());
 
@@ -600,190 +569,4 @@ mod test {
             .unwrap();
         assert_eq!(dst_ans, dst_ref);
     }
-
-    use crate::cuda::CurrentCtx;
-    use crate::cuda::Stream;
-
-    use std::ffi::CString;
-    fn fill_src_code() -> String {
-        format!(
-            r#"
-
-extern "C" __global__ void fill_src(
-    void *__restrict__ src,
-    unsigned int n
-){{
-    int idx = threadIdx.x + blockIdx.x * blockDim.x;
-
-    if (idx < n) {{
-        reinterpret_cast<char *>(src)[idx] = 11;
-    }}
-}}
-"#
-        )
-    }
-    fn fill_src(src: &mut DevMem, ctx: &CurrentCtx, queue: &Stream) {
-        let (ptx, _) = Ptx::compile(fill_src_code(), ctx.dev().compute_capability());
-        let module = ctx.load(&ptx.unwrap());
-        let name = CString::new("fill_src").unwrap();
-
-        let block_size = 256; // use a smaller block size
-        let total_threads = src.len();
-
-        let grid_size = (total_threads + block_size - 1) / block_size;
-
-        let block = block_size;
-        let grid = grid_size;
-
-        let src_ptr = src.as_mut_ptr();
-        let src_len = src.len() as i32;
-
-        let params = cuda::params![src_ptr, src_len];
-
-        module
-            .get_kernel(&name)
-            .launch(grid as u32, block as u32, params.as_ptr(), 0, Some(queue));
-        let _keep_alive = (src_ptr, src_len);
-    }
-
-    use std::time::Duration;
-    fn time_cost(is_inverse: bool, total_exp: u32, dh_exp: u32) -> Duration {
-        use super::super::common_cpu::Operator as RefOp;
-        use crate::common_cpu::Cpu;
-        use ndarray_layout::{ArrayLayout, Endian::BigEndian};
-        let Some(gpu) = Gpu::init() else {
-            panic!("init gpu failed");
-        };
-        let dt = ty::U8;
-        let mut cpu_op = RefOp::new(&Cpu);
-        let mut gpu_op = Operator::new(&gpu);
-        cpu_op.scheme(&dyn_args(dt), 0).unwrap();
-        gpu_op.scheme(&dyn_args(dt), 0).unwrap();
-        let nh = 1 << ((total_exp + 1) / 2 - (dh_exp + 1) / 2);
-        let seq = 1 << (total_exp / 2 - dh_exp / 2);
-        let dh = 1 << dh_exp;
-        // println!("nh: {nh}, seq: {seq}, dh: {dh}");
-        let ele = dt.nbytes();
-        let s_src = ArrayLayout::<3>::new_contiguous(&[nh, seq, dh], BigEndian, ele);
-        let s_dst =
-            ArrayLayout::<3>::new_contiguous(&[seq, nh, dh], BigEndian, ele).transpose(&[1, 0]);
-        use super::IS_INVERSE;
-        unsafe {
-            IS_INVERSE = is_inverse;
-        }
-        gpu.apply(|ctx| {
-            let stream = ctx.stream();
-            #[cfg(use_nvidia)]
-            let rt = &stream;
-            #[cfg(use_iluvatar)]
-            let rt = ctx;
-            let mut src = rt.malloc::<u8>(nh * seq * dh);
-            let mut dst = rt.malloc::<u8>(nh * seq * dh);
-            fill_src(&mut src, ctx, &stream);
-            stream.bench(
-                |_, stream| {
-                    gpu_op
-                        .launch(
-                            &args(
-                                dt,
-                                &[nh, seq, dh],
-                                s_src.strides(),
-                                s_dst.strides(),
-                                src.as_ptr().cast(),
-                                dst.as_mut_ptr().cast(),
-                            ),
-                            &mut [],
-                            stream,
-                        )
-                        .unwrap();
-                },
-                20,
-                2,
-            )
-        })
-    }
-
-    fn time_cost_bare(total_exp: u32, dh_exp: u32) -> Duration {
-        use super::super::common_cpu::Operator as RefOp;
-        use crate::common_cpu::Cpu;
-        use ndarray_layout::{ArrayLayout, Endian::BigEndian};
-        let Some(gpu) = Gpu::init() else {
-            panic!("init gpu failed");
-        };
-        let dt = ty::U8;
-        let mut cpu_op = RefOp::new(&Cpu);
-        let mut gpu_op = Operator::new(&gpu);
-        cpu_op.scheme(&dyn_args(dt), 0).unwrap();
-        gpu_op.scheme(&dyn_args(dt), 0).unwrap();
-
-        let total_size = 1 << total_exp;
-        let unit = 1 << dh_exp;
-        use crate::rearrange::cuda::BARE_UNIT;
-        unsafe {
-            BARE_UNIT = unit;
-        }
-        let ele = dt.nbytes();
-        let s_src = ArrayLayout::<1>::new_contiguous(&[total_size], BigEndian, ele);
-
-        gpu.apply(|ctx| {
-            let stream = ctx.stream();
-            #[cfg(use_nvidia)]
-            let rt = &stream;
-            #[cfg(use_iluvatar)]
-            let rt = ctx;
-            let mut src = rt.malloc::<u8>(total_size);
-            let mut dst = rt.malloc::<u8>(total_size);
-            fill_src(&mut src, ctx, &stream);
-            stream.bench(
-                |_, stream| {
-                    gpu_op
-                        .launch(
-                            &args(
-                                dt,
-                                &[total_size],
-                                s_src.strides(),
-                                s_src.strides(),
-                                src.as_ptr().cast(),
-                                dst.as_mut_ptr().cast(),
-                            ),
-                            &mut [],
-                            stream,
-                        )
-                        .unwrap();
-                },
-                20,
-                2,
-            )
-        })
-    }
-
-    #[test]
-    fn test_time() {
-        for total_exp in [24, 26, 28, 30] {
-            println!("\nPerformance results (total_exp = {total_exp}):");
-            println!(
-                "Data size: {} ({:.2}GB)",
-                1u64 << total_exp,
-                (1u64 << total_exp) as f64 / (1024.0 * 1024.0 * 1024.0)
-            );
-            println!("----------------------------------------");
-            println!("dh_exp  dh size  forward time  inverse time  plain copy time");
-            println!("----------------------------------------");
-            for dh_exp in 1..=5 {
-                let dh_size = 1 << dh_exp;
-                let inverse_time = time_cost(true, total_exp, dh_exp);
-                let forward_time = time_cost(false, total_exp, dh_exp);
-                let bare_time = time_cost_bare(total_exp, dh_exp);
-                println!("{dh_exp:<7} {dh_size:<7} {forward_time:<16?} {inverse_time:<16?} {bare_time:<16?}");
-            }
-            println!("----------------------------------------");
-        }
-    }
-
-    #[test]
-    fn test_time_one() {
-        time_cost(true, 26, 4);
-        time_cost(false, 26, 4);
-        time_cost_bare(26, 8);
-    }
 }
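
With those ad-hoc benchmarks removed, unit selection rests on the line kept in the launch path above: scheme.distribute_unit((0..=5).rev().map(|n| (1 << n))), which offers per-thread unit widths of 32, 16, 8, 4, 2 and 1 bytes, widest first. A minimal sketch of that kind of widest-divisor selection, assuming the first candidate that evenly divides the contiguous byte count is kept (pick_unit is a hypothetical helper, not this crate's API):

// Hypothetical illustration of widest-divisor unit selection; not the crate's code.
fn pick_unit(contiguous_bytes: usize, mut candidates: impl Iterator<Item = usize>) -> Option<usize> {
    // candidates arrive widest-first, so the first width that divides evenly wins
    candidates.find(|&u| contiguous_bytes % u == 0)
}

fn main() {
    let candidates = (0..=5).rev().map(|n| 1usize << n); // 32, 16, 8, 4, 2, 1
    // 24 bytes is divisible by 8 but not by 16 or 32, so 8-byte units are picked
    assert_eq!(pick_unit(24, candidates), Some(8));
}
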
