Optimized in-place operators for rows and matrices #1244

Merged 3 commits on Jun 14, 2023. Changes from all commits are shown below.
@@ -218,7 +218,7 @@ template <typename fnT, typename T1, typename T2> struct AddTypeMapFactory
};

template <typename T1, typename T2, typename resT, typename IndexerT>
-class add_strided_strided_kernel;
+class add_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event add_strided_impl(sycl::queue exec_q,
@@ -235,8 +235,7 @@ sycl::event add_strided_impl(sycl::queue exec_q,
const std::vector<sycl::event> &additional_depends)
{
return elementwise_common::binary_strided_impl<
-argTy1, argTy2, AddOutputType, AddStridedFunctor,
-add_strided_strided_kernel>(
+argTy1, argTy2, AddOutputType, AddStridedFunctor, add_strided_kernel>(
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
arg2_offset, res_p, res_offset, depends, additional_depends);
}
@@ -480,6 +479,60 @@ struct AddInplaceStridedFactory
}
};

template <typename argT, typename resT>
class add_inplace_row_matrix_broadcast_sg_krn;

template <typename argT, typename resT>
using AddInplaceRowMatrixBroadcastingFunctor =
elementwise_common::BinaryInplaceRowMatrixBroadcastingFunctor<
argT,
resT,
AddInplaceFunctor<argT, resT>>;

template <typename argT, typename resT>
sycl::event add_inplace_row_matrix_broadcast_impl(
sycl::queue exec_q,
std::vector<sycl::event> &host_tasks,
size_t n0,
size_t n1,
const char *vec_p, // typeless pointer to (n1,) contiguous row
py::ssize_t vec_offset,
char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix
py::ssize_t mat_offset,
const std::vector<sycl::event> &depends = {})
{
return elementwise_common::binary_inplace_row_matrix_broadcast_impl<
argT, resT, AddInplaceRowMatrixBroadcastingFunctor,
add_inplace_row_matrix_broadcast_sg_krn>(exec_q, host_tasks, n0, n1,
vec_p, vec_offset, mat_p,
mat_offset, depends);
}

template <typename fnT, typename T1, typename T2>
struct AddInplaceRowMatrixBroadcastFactory
{
fnT get()
{
using resT = typename AddOutputType<T1, T2>::value_type;
if constexpr (!std::is_same_v<resT, T2>) {
fnT fn = nullptr;
return fn;
}
else {
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
dpctl::tensor::type_utils::is_complex<T2>::value)
{
fnT fn = nullptr;
return fn;
}
else {
fnT fn = add_inplace_row_matrix_broadcast_impl<T1, T2>;
return fn;
}
}
}
};

} // namespace add
} // namespace kernels
} // namespace tensor
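For context on how AddInplaceRowMatrixBroadcastFactory is consumed: following the pattern used by the other elementwise factories in dpctl, a types-by-types table of function pointers is populated at module initialization, and a nullptr entry signals that the specialized row/matrix path is unavailable, so the caller falls back to the generic strided kernel. A minimal sketch of that wiring, assuming dpctl's DispatchTableBuilder utility; the table and function names here are illustrative, not part of this PR:

// Illustrative sketch: registering the factory in a dispatch table.
namespace td_ns = dpctl::tensor::type_dispatch;

using fn_ptrT = dpctl::tensor::kernels::elementwise_common::
    binary_inplace_row_matrix_broadcast_impl_fn_ptr_t;

// One entry per (lhs type id, rhs type id) pair; nullptr means no fast path.
static fn_ptrT add_inplace_row_matrix_dispatch_table[td_ns::num_types]
                                                    [td_ns::num_types];

void populate_add_inplace_row_matrix_dispatch_table(void)
{
    using dpctl::tensor::kernels::add::AddInplaceRowMatrixBroadcastFactory;
    td_ns::DispatchTableBuilder<fn_ptrT, AddInplaceRowMatrixBroadcastFactory,
                                td_ns::num_types>
        dtb;
    dtb.populate_dispatch_table(add_inplace_row_matrix_dispatch_table);
}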
@@ -191,6 +191,60 @@ struct BinaryInplaceStridedFunctor
}
};

template <typename argT, typename resT, typename BinaryOperatorT>
struct BinaryInplaceRowMatrixBroadcastingFunctor
{
private:
const argT *padded_vec;
resT *mat;
size_t n_elems;
size_t n1;

public:
BinaryInplaceRowMatrixBroadcastingFunctor(const argT *row_tp,
resT *mat_tp,
size_t n_elems_in_mat,
size_t n_elems_in_row)
: padded_vec(row_tp), mat(mat_tp), n_elems(n_elems_in_mat),
n1(n_elems_in_row)
{
}

void operator()(sycl::nd_item<1> ndit) const
{
BinaryOperatorT op{};
static_assert(BinaryOperatorT::supports_sg_loadstore::value);

auto sg = ndit.get_sub_group();
size_t gid = ndit.get_global_linear_id();

std::uint8_t sgSize = sg.get_local_range()[0];
size_t base = gid - sg.get_local_id()[0];

if (base + sgSize < n_elems) {
using in_ptrT =
sycl::multi_ptr<const argT,
sycl::access::address_space::global_space>;
using res_ptrT =
sycl::multi_ptr<resT,
sycl::access::address_space::global_space>;

const argT vec_el = sg.load(in_ptrT(&padded_vec[base % n1]));
resT mat_el = sg.load(res_ptrT(&mat[base]));

op(mat_el, vec_el);

sg.store(res_ptrT(&mat[base]), mat_el);
}
else {
for (size_t k = base + sg.get_local_id()[0]; k < n_elems;
k += sgSize) {
op(mat[k], padded_vec[k % n1]);
}
}
}
};
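The operator above takes a sub-group-cooperative fast path (vectorized sg.load/sg.store) when the whole sub-group lies within the matrix, and a scalar tail loop otherwise; both branches compute the same result. A plain scalar reference of those semantics, as a hypothetical helper useful for reasoning about correctness:

// Reference semantics: combine each element of the (n0, n1) C-contiguous
// matrix in place with the row element in the same column.
template <typename argT, typename resT, typename BinaryOperatorT>
void row_matrix_broadcast_reference(const argT *vec, resT *mat,
                                    std::size_t n_elems, std::size_t n1)
{
    BinaryOperatorT op{};
    for (std::size_t k = 0; k < n_elems; ++k) {
        // e.g. mat[k] += vec[k % n1] when BinaryOperatorT is an
        // in-place add functor
        op(mat[k], vec[k % n1]);
    }
}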

// Typedefs for function pointers

typedef sycl::event (*binary_inplace_contig_impl_fn_ptr_t)(
@@ -214,6 +268,17 @@ typedef sycl::event (*binary_inplace_strided_impl_fn_ptr_t)(
const std::vector<sycl::event> &,
const std::vector<sycl::event> &);

typedef sycl::event (*binary_inplace_row_matrix_broadcast_impl_fn_ptr_t)(
sycl::queue,
std::vector<sycl::event> &,
size_t,
size_t,
const char *,
py::ssize_t,
char *,
py::ssize_t,
const std::vector<sycl::event> &);

template <typename argTy,
typename resTy,
template <typename T1, typename T2, unsigned int vs, unsigned int nv>
@@ -289,6 +354,79 @@ binary_inplace_strided_impl(sycl::queue exec_q,
return comp_ev;
}

template <typename argT,
typename resT,
template <typename T1, typename T3>
class BinaryInplaceRowMatrixBroadcastFunctorT,
template <typename T1, typename T3>
class kernel_name>
sycl::event binary_inplace_row_matrix_broadcast_impl(
sycl::queue exec_q,
std::vector<sycl::event> &host_tasks,
size_t n0,
size_t n1,
const char *vec_p, // typeless pointer to (n1,) contiguous row
py::ssize_t vec_offset,
char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix
py::ssize_t mat_offset,
const std::vector<sycl::event> &depends = {})
{
const argT *vec = reinterpret_cast<const argT *>(vec_p) + vec_offset;
resT *mat = reinterpret_cast<resT *>(mat_p) + mat_offset;

const auto &dev = exec_q.get_device();
const auto &sg_sizes = dev.get_info<sycl::info::device::sub_group_sizes>();
// Query the maximal sub-group size supported by the device
size_t max_sgSize =
*(std::max_element(std::begin(sg_sizes), std::end(sg_sizes)));

size_t n1_padded = n1 + max_sgSize;
argT *padded_vec = sycl::malloc_device<argT>(n1_padded, exec_q);

if (padded_vec == nullptr) {
throw std::runtime_error("Could not allocate memory on the device");
}
sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends); // ensure vec contains actual data
cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) {
auto i = id[0];
padded_vec[i] = vec[i % n1];
});
});

// A sub-group spans work-items [base, base + sgSize), where
// base = ndit.get_global_linear_id() - sg.get_local_id()[0].
// Generically, sg.load(&mat[base]) may cross a row boundary of mat:
// the load starts in row (base / n1), column (base % n1), and the
// matching vector elements are read with sg.load(&padded_vec[base % n1]).
// Since (base % n1) + sgSize - 1 < n1 + max_sgSize, padding the vector
// to n1 + max_sgSize elements keeps this contiguous read in bounds.

size_t lws = 64;

sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(make_padded_vec_ev);

auto lwsRange = sycl::range<1>(lws);
size_t n_elems = n0 * n1;
size_t n_groups = (n_elems + lws - 1) / lws;
auto gwsRange = sycl::range<1>(n_groups * lws);

cgh.parallel_for<class kernel_name<argT, resT>>(
sycl::nd_range<1>(gwsRange, lwsRange),
BinaryInplaceRowMatrixBroadcastFunctorT<argT, resT>(padded_vec, mat,
n_elems, n1));
});

sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(comp_ev);
sycl::context ctx = exec_q.get_context();
cgh.host_task([ctx, padded_vec]() { sycl::free(padded_vec, ctx); });
});
host_tasks.push_back(tmp_cleanup_ev);

return comp_ev;
}
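A usage sketch for this routine, called through the add_inplace_row_matrix_broadcast_impl wrapper from the add kernels header; the queue construction, sizes, and fill step are illustrative assumptions:

// Illustrative sketch: mat (n0 x n1, C-contiguous) += vec (n1,) row-wise.
sycl::queue q{sycl::default_selector_v};
constexpr std::size_t n0 = 1024, n1 = 256;

float *vec = sycl::malloc_device<float>(n1, q);
float *mat = sycl::malloc_device<float>(n0 * n1, q);
// ... populate vec and mat ...

std::vector<sycl::event> host_tasks{};
sycl::event comp_ev =
    dpctl::tensor::kernels::add::add_inplace_row_matrix_broadcast_impl<
        float, float>(q, host_tasks, n0, n1,
                      reinterpret_cast<const char *>(vec), 0,
                      reinterpret_cast<char *>(mat), 0);
comp_ev.wait();
sycl::event::wait(host_tasks);

Note that the temporary padded vector is freed by a host task that depends on the compute event, which is why host_tasks must outlive the computation and be waited on before resources are released.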

} // namespace elementwise_common
} // namespace kernels
} // namespace tensor
@@ -201,7 +201,7 @@ template <typename fnT, typename T1, typename T2> struct EqualTypeMapFactory
};

template <typename T1, typename T2, typename resT, typename IndexerT>
-class equal_strided_strided_kernel;
+class equal_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event
@@ -220,9 +220,9 @@ equal_strided_impl(sycl::queue exec_q,
{
return elementwise_common::binary_strided_impl<
argTy1, argTy2, EqualOutputType, EqualStridedFunctor,
-equal_strided_strided_kernel>(
-exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
-arg2_offset, res_p, res_offset, depends, additional_depends);
+equal_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p,
+arg1_offset, arg2_p, arg2_offset, res_p,
+res_offset, depends, additional_depends);
}

template <typename fnT, typename T1, typename T2> struct EqualStridedFactory
@@ -235,7 +235,7 @@ struct FloorDivideTypeMapFactory
};

template <typename T1, typename T2, typename resT, typename IndexerT>
-class floor_divide_strided_strided_kernel;
+class floor_divide_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event
@@ -254,7 +254,7 @@ floor_divide_strided_impl(sycl::queue exec_q,
{
return elementwise_common::binary_strided_impl<
argTy1, argTy2, FloorDivideOutputType, FloorDivideStridedFunctor,
-floor_divide_strided_strided_kernel>(
+floor_divide_strided_kernel>(
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
arg2_offset, res_p, res_offset, depends, additional_depends);
}
@@ -255,7 +255,7 @@ template <typename fnT, typename T1, typename T2> struct GreaterTypeMapFactory
};

template <typename T1, typename T2, typename resT, typename IndexerT>
-class greater_strided_strided_kernel;
+class greater_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event
@@ -289,7 +289,7 @@ greater_strided_impl(sycl::queue exec_q,
resTy *res_tp = reinterpret_cast<resTy *>(res_p);

cgh.parallel_for<
-greater_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
+greater_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
{nelems}, GreaterStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
arg1_tp, arg2_tp, res_tp, indexer));
});
@@ -261,7 +261,7 @@ struct GreaterEqualTypeMapFactory
};

template <typename T1, typename T2, typename resT, typename IndexerT>
-class greater_equal_strided_strided_kernel;
+class greater_equal_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event
@@ -295,8 +295,8 @@ greater_equal_strided_impl(sycl::queue exec_q,
const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
resTy *res_tp = reinterpret_cast<resTy *>(res_p);

-cgh.parallel_for<greater_equal_strided_strided_kernel<argTy1, argTy2,
-resTy, IndexerT>>(
+cgh.parallel_for<
+greater_equal_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
{nelems},
GreaterEqualStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
arg1_tp, arg2_tp, res_tp, indexer));
@@ -253,7 +253,7 @@ template <typename fnT, typename T1, typename T2> struct LessTypeMapFactory
};

template <typename T1, typename T2, typename resT, typename IndexerT>
-class less_strided_strided_kernel;
+class less_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event
@@ -286,8 +286,7 @@ less_strided_impl(sycl::queue exec_q,
const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
resTy *res_tp = reinterpret_cast<resTy *>(res_p);

-cgh.parallel_for<
-less_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
+cgh.parallel_for<less_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
{nelems}, LessStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
arg1_tp, arg2_tp, res_tp, indexer));
});
@@ -256,7 +256,7 @@ template <typename fnT, typename T1, typename T2> struct LessEqualTypeMapFactory
};

template <typename T1, typename T2, typename resT, typename IndexerT>
-class less_equal_strided_strided_kernel;
+class less_equal_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event
@@ -290,7 +290,7 @@ less_equal_strided_impl(sycl::queue exec_q,
resTy *res_tp = reinterpret_cast<resTy *>(res_p);

cgh.parallel_for<
-less_equal_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
+less_equal_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
{nelems}, LessEqualStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
arg1_tp, arg2_tp, res_tp, indexer));
});