Skip to content

Commit b6d2b42

Browse files
committed
test: 修复release 运行报错的问题
1 parent 3d6c463 commit b6d2b42

File tree

1 file changed

+12
-6
lines changed
  • operators/src/rearrange/cuda

1 file changed

+12
-6
lines changed

operators/src/rearrange/cuda/mod.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,9 @@ impl crate::Operator for Operator {
262262
src_cs,
263263
layout.r,
264264
layout.c,
265-
32u32, // sub_size_x
266-
32u32, // sub_size_y
267-
(unit as u32) // bytes_per_thread
265+
32u32, // sub_size_x
266+
32u32, // sub_size_y
267+
unit // bytes_per_thread
268268
];
269269

270270
let shared_memory_size = if use_shared_memory {
@@ -491,12 +491,13 @@ mod test {
491491
r#"
492492
493493
extern "C" __global__ void fill_src(
494-
unsigned char *src,
494+
void *__restrict__ src,
495495
unsigned int n
496496
){{
497497
int idx = threadIdx.x + blockIdx.x * blockDim.x;
498+
498499
if (idx < n) {{
499-
((unsigned char*)src)[idx] = threadIdx.x;
500+
reinterpret_cast<char *>(src)[idx] = 11;
500501
}}
501502
}}
502503
"#
@@ -515,10 +516,15 @@ extern "C" __global__ void fill_src(
515516
let block = block_size;
516517
let grid = grid_size;
517518

518-
let params = cuda::params![src.as_mut_ptr(), src.len() as u32];
519+
let src_ptr = src.as_mut_ptr();
520+
let src_len = src.len() as i32;
521+
522+
let params = cuda::params![src_ptr, src_len];
523+
519524
module
520525
.get_kernel(&name)
521526
.launch(grid as u32, block as u32, params.as_ptr(), 0, Some(queue));
527+
let _keep_alive = (src_ptr, src_len);
522528
}
523529

524530
use std::time::Duration;

0 commit comments

Comments
 (0)