diff --git a/sycl/include/CL/sycl/accessor.hpp b/sycl/include/CL/sycl/accessor.hpp index 4584a8c3c955c..9bdc641371b14 100644 --- a/sycl/include/CL/sycl/accessor.hpp +++ b/sycl/include/CL/sycl/accessor.hpp @@ -150,9 +150,9 @@ SYCL_ACCESSOR_IMPL(!isTargetHostAccess(accessTarget) && // reinterpret casting while setting kernel arguments in order to get cl_mem // value from the buffer regardless of the accessor's dimensionality. #ifndef __SYCL_DEVICE_ONLY__ - detail::buffer_impl> *m_Buf = nullptr; + detail::buffer_impl *m_Buf = nullptr; #else - char padding[sizeof(detail::buffer_impl> *)]; + char padding[sizeof(detail::buffer_impl *)]; #endif // __SYCL_DEVICE_ONLY__ dataT *Data; @@ -185,9 +185,9 @@ SYCL_ACCESSOR_IMPL(!isTargetHostAccess(accessTarget) && // reinterpret casting while setting kernel arguments in order to get cl_mem // value from the buffer regardless of the accessor's dimensionality. #ifndef __SYCL_DEVICE_ONLY__ - detail::buffer_impl> *m_Buf = nullptr; + detail::buffer_impl *m_Buf = nullptr; #else - char padding[sizeof(detail::buffer_impl> *)]; + char padding[sizeof(detail::buffer_impl *)]; #endif // __SYCL_DEVICE_ONLY__ dataT *Data; @@ -763,7 +763,7 @@ class accessor #ifdef __SYCL_DEVICE_ONLY__ ; // This ctor can't be used in device code, so no need to define it. #else // !__SYCL_DEVICE_ONLY__ - : __impl(detail::getSyclObjImpl(bufferRef)->BufPtr, Range, + : __impl((dataT *)detail::getSyclObjImpl(bufferRef)->BufPtr, Range, bufferRef.get_range(), Offset) { auto BufImpl = detail::getSyclObjImpl(bufferRef); if (AccessTarget == access::target::host_buffer) { diff --git a/sycl/include/CL/sycl/buffer.hpp b/sycl/include/CL/sycl/buffer.hpp index eac3d1fee317e..9a43504d5de72 100644 --- a/sycl/include/CL/sycl/buffer.hpp +++ b/sycl/include/CL/sycl/buffer.hpp @@ -21,7 +21,7 @@ class queue; template class range; template > + typename AllocatorT = cl::sycl::buffer_allocator> class buffer { public: using value_type = T; @@ -36,11 +36,11 @@ class buffer { get_count() * sizeof(T), propList); } - // buffer(const range &bufferRange, AllocatorT allocator, - // const property_list &propList = {}) { - // impl = std::make_shared(bufferRange, allocator, - // propList); - // } + buffer(const range &bufferRange, AllocatorT allocator, + const property_list &propList = {}) { + impl = std::make_shared>( + get_count() * sizeof(T), propList, allocator); + } buffer(T *hostData, const range &bufferRange, const property_list &propList = {}) @@ -49,11 +49,11 @@ class buffer { hostData, get_count() * sizeof(T), propList); } - // buffer(T *hostData, const range &bufferRange, - // AllocatorT allocator, const property_list &propList = {}) { - // impl = std::make_shared(hostData, bufferRange, - // allocator, propList); - // } + buffer(T *hostData, const range &bufferRange, + AllocatorT allocator, const property_list &propList = {}) { + impl = std::make_shared>( + hostData, get_count() * sizeof(T), propList, allocator); + } buffer(const T *hostData, const range &bufferRange, const property_list &propList = {}) @@ -62,18 +62,18 @@ class buffer { hostData, get_count() * sizeof(T), propList); } - // buffer(const T *hostData, const range &bufferRange, - // AllocatorT allocator, const property_list &propList = {}) { - // impl = std::make_shared(hostData, bufferRange, - // allocator, propList); - // } + buffer(const T *hostData, const range &bufferRange, + AllocatorT allocator, const property_list &propList = {}) { + impl = std::make_shared>( + hostData, get_count() * sizeof(T), propList, allocator); + } - // buffer(const shared_ptr_class &hostData, - // const range &bufferRange, AllocatorT allocator, - // const property_list &propList = {}) { - // impl = std::make_shared(hostData, bufferRange, - // allocator, propList); - // } + buffer(const shared_ptr_class &hostData, + const range &bufferRange, AllocatorT allocator, + const property_list &propList = {}) { + impl = std::make_shared>( + hostData, get_count() * sizeof(T), propList, allocator); + } buffer(const shared_ptr_class &hostData, const range &bufferRange, @@ -83,12 +83,13 @@ class buffer { hostData, get_count() * sizeof(T), propList); } - // template - // buffer(InputIterator first, InputIterator last, AllocatorT allocator, - // const property_list &propList = {}) { - // impl = std::make_shared(first, last, allocator, - // propList); - // } + template + buffer(InputIterator first, InputIterator last, AllocatorT allocator, + const property_list &propList = {}) + : Range(range<1>(std::distance(first, last))) { + impl = std::make_shared>( + first, last, get_count() * sizeof(T), propList, allocator); + } template > @@ -135,7 +136,7 @@ class buffer { size_t get_size() const { return impl->get_size(); } - // AllocatorT get_allocator() const { return impl->get_allocator(); } + AllocatorT get_allocator() const { return impl->get_allocator(); } template @@ -152,28 +153,29 @@ class buffer { return impl->template get_access(*this); } - // template accessor get_access( handler &commandGroupHandler, - // range accessRange, id accessOffset = {}) { - // return impl->get_access(commandGroupHandler, accessRange, - // accessOffset); - // } + template + accessor + get_access(handler &commandGroupHandler, range accessRange, + id accessOffset = {}) { + return impl->template get_access( + *this, commandGroupHandler, accessRange, accessOffset); + } - // template - // accessor get_access( range accessRange, - // id accessOffset = {}) { - // return impl->get_access(accessRange, accessOffset); - // } + template + accessor + get_access(range accessRange, id accessOffset = {}) { + return impl->template get_access(*this, accessRange, + accessOffset); + } template void set_final_data(Destination finalData = nullptr) { impl->set_final_data(finalData); } - // void set_write_back(bool flag = true) { return impl->set_write_back(flag); - // } + void set_write_back(bool flag = true) { return impl->set_write_back(flag); } // bool is_sub_buffer() const { return impl->is_sub_buffer(); } @@ -189,6 +191,14 @@ class buffer { reinterpretRange); } + template bool has_property() const { + return impl->template has_property(); + } + + template propertyT get_property() const { + return impl->template get_property(); + } + private: shared_ptr_class> impl; template diff --git a/sycl/include/CL/sycl/detail/buffer_impl.hpp b/sycl/include/CL/sycl/detail/buffer_impl.hpp index 4921c203f533f..321bacb2f5aae 100644 --- a/sycl/include/CL/sycl/detail/buffer_impl.hpp +++ b/sycl/include/CL/sycl/detail/buffer_impl.hpp @@ -37,16 +37,18 @@ class handler; class queue; template class id; template class range; -template using buffer_allocator = std::allocator; +using buffer_allocator = std::allocator; namespace detail { template class buffer_impl { public: - buffer_impl(const size_t sizeInBytes, const property_list &propList) - : buffer_impl((void *)nullptr, sizeInBytes, propList) {} + buffer_impl(const size_t sizeInBytes, const property_list &propList, + AllocatorT allocator = AllocatorT()) + : buffer_impl((void *)nullptr, sizeInBytes, propList, allocator) {} buffer_impl(void *hostData, const size_t sizeInBytes, - const property_list &propList) - : SizeInBytes(sizeInBytes), Props(propList) { + const property_list &propList, + AllocatorT allocator = AllocatorT()) + : SizeInBytes(sizeInBytes), Props(propList), MAllocator(allocator) { if (Props.has_property()) { BufPtr = hostData; } else { @@ -62,8 +64,9 @@ template class buffer_impl { // TODO temporary solution for allowing initialisation with const data buffer_impl(const void *hostData, const size_t sizeInBytes, - const property_list &propList) - : SizeInBytes(sizeInBytes), Props(propList) { + const property_list &propList, + AllocatorT allocator = AllocatorT()) + : SizeInBytes(sizeInBytes), Props(propList), MAllocator(allocator) { if (Props.has_property()) { // TODO make this buffer read only BufPtr = const_cast(hostData); @@ -79,8 +82,9 @@ template class buffer_impl { template buffer_impl(const shared_ptr_class &hostData, const size_t sizeInBytes, - const property_list &propList) - : SizeInBytes(sizeInBytes), Props(propList) { + const property_list &propList, + AllocatorT allocator = AllocatorT()) + : SizeInBytes(sizeInBytes), Props(propList), MAllocator(allocator) { if (Props.has_property()) { BufPtr = hostData.get(); } else { @@ -97,8 +101,9 @@ template class buffer_impl { template buffer_impl(InputIterator first, InputIterator last, const size_t sizeInBytes, - const property_list &propList) - : SizeInBytes(sizeInBytes), Props(propList) { + const property_list &propList, + AllocatorT allocator = AllocatorT()) + : SizeInBytes(sizeInBytes), Props(propList), MAllocator(allocator) { if (Props.has_property()) { // TODO next line looks unsafe BufPtr = &*first; @@ -140,7 +145,7 @@ template class buffer_impl { .copyBack( *this); - if (uploadData != nullptr) { + if (uploadData != nullptr && NeedWriteBack) { uploadData(); } @@ -170,7 +175,7 @@ template class buffer_impl { throw cl::sycl::runtime_error( "set_final_data could not be used with interoperability buffer"); static_assert(!std::is_const::value, - "Сan not write in a constant Destination. Destination should " + "Can not write in a constant Destination. Destination should " "not be const."); uploadData = [this, final_data]() mutable { auto *Ptr = @@ -182,6 +187,10 @@ template class buffer_impl { }; } + void set_write_back(bool flag) { NeedWriteBack = flag; } + + AllocatorT get_allocator() const { return MAllocator; } + template accessor @@ -199,6 +208,34 @@ template class buffer_impl { access::placeholder::false_t>(Buffer); } + template + accessor + get_access(buffer &Buffer, + handler &commandGroupHandler, range accessRange, + id accessOffset) { + return accessor( + Buffer, commandGroupHandler, accessRange, accessOffset); + } + + template + accessor + get_access(buffer &Buffer, + range accessRange, id accessOffset) { + return accessor(Buffer, accessRange, + accessOffset); + } + + template bool has_property() const { + return Props.has_property(); + } + + template propertyT get_property() const { + return Props.get_property(); + } + public: void moveMemoryTo(QueueImplPtr Queue, std::vector DepEvents, EventImplPtr Event); @@ -243,8 +280,10 @@ template class buffer_impl { // This field must be the first to guarantee that it's safe to use // reinterpret casting while setting kernel arguments in order to get cl_mem // value from the buffer regardless of its dimensionality. + AllocatorT MAllocator; OpenCLMemState OCLState; bool OpenCLInterop = false; + bool NeedWriteBack = true; event AvailableEvent; cl_context OpenCLContext = nullptr; void *BufPtr = nullptr; diff --git a/sycl/include/CL/sycl/detail/scheduler/scheduler.cpp b/sycl/include/CL/sycl/detail/scheduler/scheduler.cpp index 19c49a1c1e6e4..b724fac653cd8 100644 --- a/sycl/include/CL/sycl/detail/scheduler/scheduler.cpp +++ b/sycl/include/CL/sycl/detail/scheduler/scheduler.cpp @@ -61,8 +61,7 @@ template &&Acc, int argIndex) { - detail::buffer_impl> *buf = - Acc.__get_impl()->m_Buf; + detail::buffer_impl *buf = Acc.__get_impl()->m_Buf; addBufRequirement(*buf); addInteropArg(nullptr, buf->get_size(), argIndex, getReqForBuffer(m_Bufs, *buf)); @@ -134,7 +133,7 @@ void Node::addExplicitMemOp( auto *DestBase = Dest.__get_impl(); assert(DestBase != nullptr && "Accessor should have an initialized accessor_base"); - detail::buffer_impl> *Buf = DestBase->m_Buf; + detail::buffer_impl *Buf = DestBase->m_Buf; range Range = DestBase->AccessRange; id Offset = DestBase->Offset; @@ -162,10 +161,10 @@ void Node::addExplicitMemOp( assert(DestBase != nullptr && "Accessor should have an initialized accessor_base"); - detail::buffer_impl> *SrcBuf = SrcBase->m_Buf; + detail::buffer_impl *SrcBuf = SrcBase->m_Buf; assert(SrcBuf != nullptr && "Accessor should have an initialized buffer_impl"); - detail::buffer_impl> *DestBuf = DestBase->m_Buf; + detail::buffer_impl *DestBuf = DestBase->m_Buf; assert(DestBuf != nullptr && "Accessor should have an initialized buffer_impl"); @@ -195,7 +194,7 @@ void Scheduler::updateHost( auto *AccBase = Acc.__get_impl(); assert(AccBase != nullptr && "Accessor should have an initialized accessor_base"); - detail::buffer_impl> *Buf = AccBase->m_Buf; + detail::buffer_impl *Buf = AccBase->m_Buf; updateHost(*Buf, Event); } diff --git a/sycl/test/basic_tests/buffer/buffer.cpp b/sycl/test/basic_tests/buffer/buffer.cpp index 02450b91abbfc..574d938fee1d5 100644 --- a/sycl/test/basic_tests/buffer/buffer.cpp +++ b/sycl/test/basic_tests/buffer/buffer.cpp @@ -19,6 +19,7 @@ using namespace cl::sycl; int main() { int data = 5; + bool failed = false; buffer buf(&data, range<1>(1)); { int data1[10] = {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1}; @@ -453,5 +454,55 @@ int main() { for (int i = 5; i < 10; i++) assert(data1[i] == -1); } + + // Check that data is copied back after forcing write-back using + // set_write_back + { + std::vector data1(10, -1); + { + buffer b(range<1>(10)); + b.set_final_data(data1.data()); + b.set_write_back(true); + queue myQueue; + myQueue.submit([&](handler &cgh) { + auto B = b.get_access(cgh); + cgh.parallel_for(range<1>{10}, + [=](id<1> index) { B[index] = 0; }); + }); + + } + // Data is copied back because there is a user side ptr and write-back is + // enabled + for (int i = 0; i < 10; i++) + if (data1[i] != 0) { + assert(false); + failed = true; + } + } + + // Check that data is not copied back after canceling write-back using + // set_write_back + { + std::vector data1(10, -1); + { + buffer b(range<1>(10)); + b.set_final_data(data1.data()); + b.set_write_back(false); + queue myQueue; + myQueue.submit([&](handler &cgh) { + auto B = b.get_access(cgh); + cgh.parallel_for(range<1>{10}, + [=](id<1> index) { B[index] = 0; }); + }); + + } + // Data is not copied back because write-back is canceled + for (int i = 0; i < 10; i++) + if (data1[i] != -1) { + assert(false); + failed = true; + } + } // TODO tests with mutex property + return failed; }