Skip to content

[SYCL] Fix is_device_copyable with range rounding #4478

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 41 additions & 12 deletions sycl/include/CL/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,42 @@ checkValueRange(const T &V) {
#endif
}

template <typename TransformedArgType, int Dims, typename KernelType>
class RoundedRangeKernel {
public:
RoundedRangeKernel(range<Dims> NumWorkItems, KernelType KernelFunc)
: NumWorkItems(NumWorkItems), KernelFunc(KernelFunc) {}

void operator()(TransformedArgType Arg) const {
if (Arg[0] >= NumWorkItems[0])
return;
Arg.set_allowed_range(NumWorkItems);
KernelFunc(Arg);
}

private:
range<Dims> NumWorkItems;
KernelType KernelFunc;
};

template <typename TransformedArgType, int Dims, typename KernelType>
class RoundedRangeKernelWithKH {
public:
RoundedRangeKernelWithKH(range<Dims> NumWorkItems, KernelType KernelFunc)
: NumWorkItems(NumWorkItems), KernelFunc(KernelFunc) {}

void operator()(TransformedArgType Arg, kernel_handler KH) const {
if (Arg[0] >= NumWorkItems[0])
return;
Arg.set_allowed_range(NumWorkItems);
KernelFunc(Arg, KH);
}

private:
range<Dims> NumWorkItems;
KernelType KernelFunc;
};

} // namespace detail

namespace ext {
Expand Down Expand Up @@ -2447,19 +2483,12 @@ class __SYCL_EXPORT handler {
range<Dims> NumWorkItems) {
if constexpr (detail::isKernelLambdaCallableWithKernelHandler<
KernelType, TransformedArgType>()) {
return [=](TransformedArgType Arg, kernel_handler KH) {
if (Arg[0] >= NumWorkItems[0])
return;
Arg.set_allowed_range(NumWorkItems);
KernelFunc(Arg, KH);
};
return detail::RoundedRangeKernelWithKH<TransformedArgType, Dims,
KernelType>(NumWorkItems,
KernelFunc);
} else {
return [=](TransformedArgType Arg) {
if (Arg[0] >= NumWorkItems[0])
return;
Arg.set_allowed_range(NumWorkItems);
KernelFunc(Arg);
};
return detail::RoundedRangeKernel<TransformedArgType, Dims, KernelType>(
NumWorkItems, KernelFunc);
}
}
};
Expand Down
12 changes: 11 additions & 1 deletion sycl/include/CL/sycl/id.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
// Forward declarations
namespace detail {
template <typename TransformedArgType, int Dims, typename KernelType>
class RoundedRangeKernel;
template <typename TransformedArgType, int Dims, typename KernelType>
class RoundedRangeKernelWithKH;
} // namespace detail
template <int dimensions> class range;
template <int dimensions, bool with_offset> class item;

Expand Down Expand Up @@ -241,7 +248,10 @@ template <int dimensions = 1> class id : public detail::array<dimensions> {
#undef __SYCL_GEN_OPT

private:
friend class handler;
// Friend to get access to private method set_allowed_range().
template <typename, int, typename> friend class detail::RoundedRangeKernel;
template <typename, int, typename>
friend class detail::RoundedRangeKernelWithKH;
void set_allowed_range(range<dimensions> rnwi) { (void)rnwi[0]; }
};

Expand Down
9 changes: 8 additions & 1 deletion sycl/include/CL/sycl/item.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ __SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
namespace detail {
class Builder;
template <typename TransformedArgType, int Dims, typename KernelType>
class RoundedRangeKernel;
template <typename TransformedArgType, int Dims, typename KernelType>
class RoundedRangeKernelWithKH;
}
template <int dimensions> class id;
template <int dimensions> class range;
Expand Down Expand Up @@ -120,7 +124,10 @@ template <int dimensions = 1, bool with_offset = true> class item {
friend class detail::Builder;

private:
friend class handler;
// Friend to get access to private method set_allowed_range().
template <typename, int, typename> friend class detail::RoundedRangeKernel;
template <typename, int, typename>
friend class detail::RoundedRangeKernelWithKH;
void set_allowed_range(const range<dimensions> rnwi) { MImpl.MExtent = rnwi; }

detail::ItemBase<dimensions, with_offset> MImpl;
Expand Down
19 changes: 19 additions & 0 deletions sycl/include/CL/sycl/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,12 @@ convertImpl(T Value) {

#endif // __SYCL_DEVICE_ONLY__

// Forward declarations
template <typename TransformedArgType, int Dims, typename KernelType>
class RoundedRangeKernel;
template <typename TransformedArgType, int Dims, typename KernelType>
class RoundedRangeKernelWithKH;

} // namespace detail

#if defined(_WIN32) && (_MSC_VER)
Expand Down Expand Up @@ -2388,6 +2394,19 @@ template <typename FuncT>
struct CheckDeviceCopyable
: CheckFieldsAreDeviceCopyable<FuncT, __builtin_num_fields(FuncT)>,
CheckBasesAreDeviceCopyable<FuncT, __builtin_num_bases(FuncT)> {};

// Below are two specializations for CheckDeviceCopyable when a kernel lambda
// is wrapped after range rounding optimization.
template <typename TransformedArgType, int Dims, typename KernelType>
struct CheckDeviceCopyable<
RoundedRangeKernel<TransformedArgType, Dims, KernelType>>
: CheckDeviceCopyable<KernelType> {};

template <typename TransformedArgType, int Dims, typename KernelType>
struct CheckDeviceCopyable<
RoundedRangeKernelWithKH<TransformedArgType, Dims, KernelType>>
: CheckDeviceCopyable<KernelType> {};

#endif // __SYCL_DEVICE_ONLY__
} // namespace detail

Expand Down
6 changes: 6 additions & 0 deletions sycl/test/basic_tests/is_device_copyable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,10 @@ void test() {

Q.single_task<class TestB>(FunctorA{});
Q.single_task<class TestC>(FunctorB{});

Q.submit([=](sycl::handler &cgh) {
const sycl::range<2> range(1026, 1026);
cgh.parallel_for(range,
[=](sycl::item<2> item) { int A = IamBadButCopyable.i; });
});
}