Skip to content

Commit f56bd7a

Browse files
Merge pull request #912 from IntelPython/use-usm-host-allocator
2 parents 0f0c926 + eef632e commit f56bd7a

File tree

1 file changed

+56
-29
lines changed

1 file changed

+56
-29
lines changed

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -487,13 +487,20 @@ sycl::event _populate_packed_shape_strides_for_copycast_kernel(
487487
const std::vector<py::ssize_t> &src_strides,
488488
const std::vector<py::ssize_t> &dst_strides)
489489
{
490-
using shT = std::vector<py::ssize_t>;
490+
// memory transfer optimization: using USM-host for the temporary speeds up
491+
// transfer to device, especially on dGPUs
492+
using usm_host_allocatorT =
493+
sycl::usm_allocator<py::ssize_t, sycl::usm::alloc::host>;
494+
using shT = std::vector<py::ssize_t, usm_host_allocatorT>;
491495
size_t nd = common_shape.size();
492496

497+
usm_host_allocatorT allocator(exec_q);
498+
493499
// create host temporary for packed shape and strides managed by shared
494500
// pointer. Packed vector is concatenation of common_shape, src_strides and
495501
// dst_strides
496-
std::shared_ptr<shT> shp_host_shape_strides = std::make_shared<shT>(3 * nd);
502+
std::shared_ptr<shT> shp_host_shape_strides =
503+
std::make_shared<shT>(3 * nd, allocator);
497504
std::copy(common_shape.begin(), common_shape.end(),
498505
shp_host_shape_strides->begin());
499506

@@ -943,9 +950,12 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
943950
throw std::runtime_error("Unabled to allocate device memory");
944951
}
945952

946-
using shT = std::vector<py::ssize_t>;
953+
using usm_host_allocatorT =
954+
sycl::usm_allocator<py::ssize_t, sycl::usm::alloc::host>;
955+
using shT = std::vector<py::ssize_t, usm_host_allocatorT>;
956+
usm_host_allocatorT allocator(exec_q);
947957
std::shared_ptr<shT> packed_host_shapes_strides_shp =
948-
std::make_shared<shT>(2 * (src_nd + dst_nd));
958+
std::make_shared<shT>(2 * (src_nd + dst_nd), allocator);
949959

950960
std::copy(src_shape, src_shape + src_nd,
951961
packed_host_shapes_strides_shp->begin());
@@ -956,13 +966,13 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
956966
if (src_strides == nullptr) {
957967
int src_flags = src.get_flags();
958968
if (src_flags & USM_ARRAY_C_CONTIGUOUS) {
959-
const shT &src_contig_strides =
969+
const auto &src_contig_strides =
960970
c_contiguous_strides(src_nd, src_shape);
961971
std::copy(src_contig_strides.begin(), src_contig_strides.end(),
962972
packed_host_shapes_strides_shp->begin() + src_nd);
963973
}
964974
else if (src_flags & USM_ARRAY_F_CONTIGUOUS) {
965-
const shT &src_contig_strides =
975+
const auto &src_contig_strides =
966976
c_contiguous_strides(src_nd, src_shape);
967977
std::copy(src_contig_strides.begin(), src_contig_strides.end(),
968978
packed_host_shapes_strides_shp->begin() + src_nd);
@@ -982,14 +992,14 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
982992
if (dst_strides == nullptr) {
983993
int dst_flags = dst.get_flags();
984994
if (dst_flags & USM_ARRAY_C_CONTIGUOUS) {
985-
const shT &dst_contig_strides =
995+
const auto &dst_contig_strides =
986996
c_contiguous_strides(dst_nd, dst_shape);
987997
std::copy(dst_contig_strides.begin(), dst_contig_strides.end(),
988998
packed_host_shapes_strides_shp->begin() + 2 * src_nd +
989999
dst_nd);
9901000
}
9911001
else if (dst_flags & USM_ARRAY_F_CONTIGUOUS) {
992-
const shT &dst_contig_strides =
1002+
const auto &dst_contig_strides =
9931003
f_contiguous_strides(dst_nd, dst_shape);
9941004
std::copy(dst_contig_strides.begin(), dst_contig_strides.end(),
9951005
packed_host_shapes_strides_shp->begin() + 2 * src_nd +
@@ -1349,7 +1359,12 @@ void copy_numpy_ndarray_into_usm_ndarray(
13491359
throw std::runtime_error("Unabled to allocate device memory");
13501360
}
13511361

1352-
std::shared_ptr<shT> host_shape_strides_shp = std::make_shared<shT>(3 * nd);
1362+
using usm_host_allocatorT =
1363+
sycl::usm_allocator<py::ssize_t, sycl::usm::alloc::host>;
1364+
using usmshT = std::vector<py::ssize_t, usm_host_allocatorT>;
1365+
usm_host_allocatorT alloc(exec_q);
1366+
1367+
auto host_shape_strides_shp = std::make_shared<usmshT>(3 * nd, alloc);
13531368
std::copy(simplified_shape.begin(), simplified_shape.end(),
13541369
host_shape_strides_shp->begin());
13551370
std::copy(simplified_src_strides.begin(), simplified_src_strides.end(),
@@ -2023,9 +2038,10 @@ tri(sycl::queue &exec_q,
20232038
return std::make_pair(sycl::event(), sycl::event());
20242039
}
20252040

2026-
// check that arrays do not overlap, and concurrent copying is safe.
20272041
char *src_data = src.get_data();
20282042
char *dst_data = dst.get_data();
2043+
2044+
// check that arrays do not overlap, and concurrent copying is safe.
20292045
auto src_offsets = src.get_minmax_offsets();
20302046
auto dst_offsets = dst.get_minmax_offsets();
20312047
int src_elem_size = src.get_elemsize();
@@ -2045,6 +2061,7 @@ tri(sycl::queue &exec_q,
20452061
int dst_typenum = dst.get_typenum();
20462062
int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
20472063
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
2064+
20482065
if (dst_typeid != src_typeid) {
20492066
throw py::value_error("Array dtype are not the same.");
20502067
}
@@ -2059,11 +2076,13 @@ tri(sycl::queue &exec_q,
20592076
}
20602077

20612078
using shT = std::vector<py::ssize_t>;
2062-
int src_flags = src.get_flags();
2063-
const py::ssize_t *src_strides_raw = src.get_strides_raw();
20642079
shT src_strides(src_nd);
2080+
2081+
int src_flags = src.get_flags();
20652082
bool is_src_c_contig = ((src_flags & USM_ARRAY_C_CONTIGUOUS) != 0);
20662083
bool is_src_f_contig = ((src_flags & USM_ARRAY_F_CONTIGUOUS) != 0);
2084+
2085+
const py::ssize_t *src_strides_raw = src.get_strides_raw();
20672086
if (src_strides_raw == nullptr) {
20682087
if (is_src_c_contig) {
20692088
src_strides = c_contiguous_strides(src_nd, src_shape);
@@ -2081,11 +2100,13 @@ tri(sycl::queue &exec_q,
20812100
src_strides.begin());
20822101
}
20832102

2084-
int dst_flags = dst.get_flags();
2085-
const py::ssize_t *dst_strides_raw = dst.get_strides_raw();
20862103
shT dst_strides(src_nd);
2104+
2105+
int dst_flags = dst.get_flags();
20872106
bool is_dst_c_contig = ((dst_flags & USM_ARRAY_C_CONTIGUOUS) != 0);
20882107
bool is_dst_f_contig = ((dst_flags & USM_ARRAY_F_CONTIGUOUS) != 0);
2108+
2109+
const py::ssize_t *dst_strides_raw = dst.get_strides_raw();
20892110
if (dst_strides_raw == nullptr) {
20902111
if (is_dst_c_contig) {
20912112
dst_strides = c_contiguous_strides(src_nd, src_shape);
@@ -2128,23 +2149,29 @@ tri(sycl::queue &exec_q,
21282149
}
21292150

21302151
nd += 2;
2131-
std::vector<py::ssize_t> shape_and_strides(3 * nd);
2152+
2153+
using usm_host_allocatorT =
2154+
sycl::usm_allocator<py::ssize_t, sycl::usm::alloc::host>;
2155+
using usmshT = std::vector<py::ssize_t, usm_host_allocatorT>;
2156+
2157+
usm_host_allocatorT allocator(exec_q);
2158+
auto shp_host_shape_and_strides =
2159+
std::make_shared<usmshT>(3 * nd, allocator);
21322160

21332161
std::copy(simplified_shape.begin(), simplified_shape.end(),
2134-
shape_and_strides.begin());
2135-
shape_and_strides[nd - 2] = src_shape[src_nd - 2];
2136-
shape_and_strides[nd - 1] = src_shape[src_nd - 1];
2162+
shp_host_shape_and_strides->begin());
2163+
(*shp_host_shape_and_strides)[nd - 2] = src_shape[src_nd - 2];
2164+
(*shp_host_shape_and_strides)[nd - 1] = src_shape[src_nd - 1];
2165+
21372166
std::copy(simplified_src_strides.begin(), simplified_src_strides.end(),
2138-
shape_and_strides.begin() + nd);
2139-
shape_and_strides[2 * nd - 2] = src_strides[src_nd - 2];
2140-
shape_and_strides[2 * nd - 1] = src_strides[src_nd - 1];
2141-
std::copy(simplified_dst_strides.begin(), simplified_dst_strides.end(),
2142-
shape_and_strides.begin() + 2 * nd);
2143-
shape_and_strides[3 * nd - 2] = dst_strides[src_nd - 2];
2144-
shape_and_strides[3 * nd - 1] = dst_strides[src_nd - 1];
2167+
shp_host_shape_and_strides->begin() + nd);
2168+
(*shp_host_shape_and_strides)[2 * nd - 2] = src_strides[src_nd - 2];
2169+
(*shp_host_shape_and_strides)[2 * nd - 1] = src_strides[src_nd - 1];
21452170

2146-
std::shared_ptr<shT> shp_host_shape_and_strides =
2147-
std::make_shared<shT>(shape_and_strides);
2171+
std::copy(simplified_dst_strides.begin(), simplified_dst_strides.end(),
2172+
shp_host_shape_and_strides->begin() + 2 * nd);
2173+
(*shp_host_shape_and_strides)[3 * nd - 2] = dst_strides[src_nd - 2];
2174+
(*shp_host_shape_and_strides)[3 * nd - 1] = dst_strides[src_nd - 1];
21482175

21492176
py::ssize_t *dev_shape_and_strides =
21502177
sycl::malloc_device<ssize_t>(3 * nd, exec_q);
@@ -2154,8 +2181,7 @@ tri(sycl::queue &exec_q,
21542181
sycl::event copy_shape_and_strides = exec_q.copy<ssize_t>(
21552182
shp_host_shape_and_strides->data(), dev_shape_and_strides, 3 * nd);
21562183

2157-
py::ssize_t inner_range =
2158-
shape_and_strides[nd - 1] * shape_and_strides[nd - 2];
2184+
py::ssize_t inner_range = src_shape[src_nd - 1] * src_shape[src_nd - 2];
21592185
py::ssize_t outer_range = src_nelems / inner_range;
21602186

21612187
sycl::event tri_ev;
@@ -2182,6 +2208,7 @@ tri(sycl::queue &exec_q,
21822208
sycl::free(dev_shape_and_strides, ctx);
21832209
});
21842210
});
2211+
21852212
return std::make_pair(keep_args_alive(exec_q, {src, dst}, {tri_ev}),
21862213
tri_ev);
21872214
}

0 commit comments

Comments
 (0)