Skip to content

Commit b280e10

Browse files
Clean up uses of Strided1DIndexer
Provide overloaded constructors to avoid uses of static_cast at constructor call sites. Provide shortcut constructors (for zero offset). Constructor call sites use comments to specify the meaning of constructor parameters.
1 parent 63ffaba commit b280e10

File tree

8 files changed

+213
-207
lines changed

8 files changed

+213
-207
lines changed

dpctl/tensor/libtensor/include/kernels/accumulators.hpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -592,9 +592,8 @@ sycl::event inclusive_scan_iter(sycl::queue &exec_q,
592592
size_t src_size = acc_groups - 1;
593593
using LocalScanIndexerT =
594594
dpctl::tensor::offset_utils::Strided1DIndexer;
595-
const LocalScanIndexerT scan_iter_indexer{
596-
0, static_cast<ssize_t>(iter_nelems),
597-
static_cast<ssize_t>(src_size)};
595+
const LocalScanIndexerT scan_iter_indexer{/* size */ iter_nelems,
596+
/* step */ src_size};
598597

599598
using IterIndexerT =
600599
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
@@ -623,11 +622,10 @@ sycl::event inclusive_scan_iter(sycl::queue &exec_q,
623622
using LocalScanIndexerT =
624623
dpctl::tensor::offset_utils::Strided1DIndexer;
625624
const LocalScanIndexerT scan1_iter_indexer{
626-
0, static_cast<ssize_t>(iter_nelems),
627-
static_cast<ssize_t>(size_to_update)};
628-
const LocalScanIndexerT scan2_iter_indexer{
629-
0, static_cast<ssize_t>(iter_nelems),
630-
static_cast<ssize_t>(src_size)};
625+
/* size */ iter_nelems,
626+
/* step */ size_to_update};
627+
const LocalScanIndexerT scan2_iter_indexer{/* size */ iter_nelems,
628+
/* step */ src_size};
631629

632630
using IterIndexerT =
633631
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<

dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,8 @@ sycl::event masked_extract_all_slices_strided_impl(
233233
/* StridedIndexer(int _nd, ssize_t _offset, ssize_t const
234234
* *_packed_shape_strides) */
235235
const StridedIndexer masked_src_indexer(nd, 0, packed_src_shape_strides);
236-
const Strided1DIndexer masked_dst_indexer(0, dst_size, dst_stride);
236+
const Strided1DIndexer masked_dst_indexer(/* size */ dst_size,
237+
/* step */ dst_stride);
237238

238239
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
239240
cgh.depends_on(depends);
@@ -309,8 +310,8 @@ sycl::event masked_extract_some_slices_strided_impl(
309310

310311
const StridedIndexer masked_src_indexer{masked_nd, 0,
311312
packed_masked_src_shape_strides};
312-
const Strided1DIndexer masked_dst_indexer{0, masked_dst_size,
313-
masked_dst_stride};
313+
const Strided1DIndexer masked_dst_indexer{/* size */ masked_dst_size,
314+
/* step */ masked_dst_stride};
314315

315316
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
316317
cgh.depends_on(depends);

dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -576,9 +576,8 @@ dot_product_contig_impl(sycl::queue &exec_q,
576576
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
577577
NoOpIndexerT, NoOpIndexerT>;
578578

579-
const InputBatchIndexerT inp_batch_indexer{
580-
0, static_cast<ssize_t>(batches),
581-
static_cast<ssize_t>(reduction_nelems)};
579+
const InputBatchIndexerT inp_batch_indexer{/* size */ batches,
580+
/* step */ reduction_nelems};
582581
const InputOutputBatchIndexerT inp_out_batch_indexer{
583582
inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}};
584583
constexpr ReductionIndexerT reduction_indexer{NoOpIndexerT{},
@@ -612,9 +611,8 @@ dot_product_contig_impl(sycl::queue &exec_q,
612611
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
613612
NoOpIndexerT, NoOpIndexerT>;
614613

615-
const InputBatchIndexerT inp_batch_indexer{
616-
0, static_cast<ssize_t>(batches),
617-
static_cast<ssize_t>(reduction_nelems)};
614+
const InputBatchIndexerT inp_batch_indexer{/* size */ batches,
615+
/* step */ reduction_nelems};
618616
const InputOutputBatchIndexerT inp_out_batch_indexer{
619617
inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}};
620618
constexpr ReductionIndexerT reduction_indexer{NoOpIndexerT{},
@@ -1089,9 +1087,8 @@ sycl::event dot_product_tree_impl(sycl::queue &exec_q,
10891087
InputIndexerT, ResIndexerT>;
10901088
using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
10911089

1092-
const InputIndexerT inp_indexer{
1093-
0, static_cast<ssize_t>(batches),
1094-
static_cast<ssize_t>(reduction_groups_)};
1090+
const InputIndexerT inp_indexer{/* size */ batches,
1091+
/* step */ reduction_groups_};
10951092
constexpr ResIndexerT res_iter_indexer{};
10961093

10971094
const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer,
@@ -1120,9 +1117,8 @@ sycl::event dot_product_tree_impl(sycl::queue &exec_q,
11201117
InputIndexerT, ResIndexerT>;
11211118
using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
11221119

1123-
const InputIndexerT inp_indexer{
1124-
0, static_cast<ssize_t>(batches),
1125-
static_cast<ssize_t>(remaining_reduction_nelems)};
1120+
const InputIndexerT inp_indexer{/* size */ batches,
1121+
/* step */ remaining_reduction_nelems};
11261122
const ResIndexerT res_iter_indexer{
11271123
batch_nd, batch_res_offset,
11281124
/* shape */ batch_shape_and_strides,
@@ -1200,9 +1196,8 @@ dot_product_contig_tree_impl(sycl::queue &exec_q,
12001196
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
12011197
NoOpIndexerT, NoOpIndexerT>;
12021198

1203-
const InputBatchIndexerT inp_batch_indexer{
1204-
0, static_cast<ssize_t>(batches),
1205-
static_cast<ssize_t>(reduction_nelems)};
1199+
const InputBatchIndexerT inp_batch_indexer{/* size */ batches,
1200+
/* step */ reduction_nelems};
12061201
const InputOutputBatchIndexerT inp_out_batch_indexer{
12071202
inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}};
12081203
constexpr ReductionIndexerT reduction_indexer{NoOpIndexerT{},
@@ -1238,9 +1233,8 @@ dot_product_contig_tree_impl(sycl::queue &exec_q,
12381233
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
12391234
NoOpIndexerT, NoOpIndexerT>;
12401235

1241-
const InputBatchIndexerT inp_batch_indexer{
1242-
0, static_cast<ssize_t>(batches),
1243-
static_cast<ssize_t>(reduction_nelems)};
1236+
const InputBatchIndexerT inp_batch_indexer{/* size */ batches,
1237+
/* step */ reduction_nelems};
12441238
const InputOutputBatchIndexerT inp_out_batch_indexer{
12451239
inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}};
12461240
constexpr ReductionIndexerT reduction_indexer{NoOpIndexerT{},
@@ -1307,8 +1301,8 @@ dot_product_contig_tree_impl(sycl::queue &exec_q,
13071301
NoOpIndexerT, NoOpIndexerT>;
13081302

13091303
const InputBatchIndexerT inp_batch_indexer{
1310-
0, static_cast<ssize_t>(batches),
1311-
static_cast<ssize_t>(reduction_nelems)};
1304+
/* size */ batches,
1305+
/* step */ reduction_nelems};
13121306
const InputOutputBatchIndexerT inp_out_batch_indexer{
13131307
inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}};
13141308
constexpr ReductionIndexerT reduction_indexer{NoOpIndexerT{},
@@ -1343,9 +1337,8 @@ dot_product_contig_tree_impl(sycl::queue &exec_q,
13431337
InputIndexerT, ResIndexerT>;
13441338
using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
13451339

1346-
const InputIndexerT inp_indexer{
1347-
0, static_cast<ssize_t>(batches),
1348-
static_cast<ssize_t>(reduction_groups_)};
1340+
const InputIndexerT inp_indexer{/* size */ batches,
1341+
/* step */ reduction_groups_};
13491342
constexpr ResIndexerT res_iter_indexer{};
13501343

13511344
const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer,
@@ -1374,9 +1367,8 @@ dot_product_contig_tree_impl(sycl::queue &exec_q,
13741367
InputIndexerT, ResIndexerT>;
13751368
using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
13761369

1377-
const InputIndexerT inp_indexer{
1378-
0, static_cast<ssize_t>(batches),
1379-
static_cast<ssize_t>(remaining_reduction_nelems)};
1370+
const InputIndexerT inp_indexer{/* size */ batches,
1371+
/* step */ remaining_reduction_nelems};
13801372
constexpr ResIndexerT res_iter_indexer{};
13811373

13821374
const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer,

0 commit comments

Comments
 (0)