diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index b7d3c93ad90d2..da49b8ea61a0a 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -783,12 +783,14 @@ static void VisitField(CXXRecordDecl *Owner, RangeTy &&Item, QualType ItemTy, Handlers &... handlers) { if (Util::isSyclAccessorType(ItemTy)) KF_FOR_EACH(handleSyclAccessorType, Item, ItemTy); - if (Util::isSyclStreamType(ItemTy)) + else if (Util::isSyclStreamType(ItemTy)) KF_FOR_EACH(handleSyclStreamType, Item, ItemTy); - if (ItemTy->isStructureOrClassType()) + else if (Util::isSyclSamplerType(ItemTy)) + KF_FOR_EACH(handleSyclSamplerType, Item, ItemTy); + else if (ItemTy->isStructureOrClassType()) VisitAccessorWrapper(Owner, Item, ItemTy->getAsCXXRecordDecl(), handlers...); - if (ItemTy->isArrayType()) + else if (ItemTy->isArrayType()) VisitArrayElements(Item, ItemTy, handlers...); } @@ -891,6 +893,9 @@ template class SyclKernelFieldHandler { return true; } virtual bool handleSyclAccessorType(FieldDecl *, QualType) { return true; } + virtual bool handleSyclSamplerType(const CXXBaseSpecifier &, QualType) { + return true; + } virtual bool handleSyclSamplerType(FieldDecl *, QualType) { return true; } virtual bool handleSyclSpecConstantType(FieldDecl *, QualType) { return true; @@ -1203,6 +1208,7 @@ class SyclKernelDeclCreator return ArrayRef(std::begin(Params) + LastParamIndex, std::end(Params)); } + using SyclKernelFieldHandler::handleSyclSamplerType; }; class SyclKernelBodyCreator @@ -1457,6 +1463,7 @@ class SyclKernelBodyCreator } using SyclKernelFieldHandler::enterStruct; + using SyclKernelFieldHandler::handleSyclSamplerType; using SyclKernelFieldHandler::leaveStruct; }; @@ -1606,6 +1613,7 @@ class SyclKernelIntHeaderCreator CurOffset -= Layout.getBaseClassOffset(BS.getType()->getAsCXXRecordDecl()) .getQuantity(); } + using SyclKernelFieldHandler::handleSyclSamplerType; }; } // namespace diff --git a/clang/test/CodeGenSYCL/sampler.cpp b/clang/test/CodeGenSYCL/sampler.cpp index 947a650afea12..749b5a0bfdaa9 100644 --- a/clang/test/CodeGenSYCL/sampler.cpp +++ b/clang/test/CodeGenSYCL/sampler.cpp @@ -11,8 +11,20 @@ // CHECK-NEXT: [[GEPCAST:%[0-9]+]] = addrspacecast %"class{{.*}}.cl::sycl::sampler"* [[GEP]] to %"class{{.*}}.cl::sycl::sampler" addrspace(4)* // CHECK-NEXT: call spir_func void @{{[a-zA-Z0-9_]+}}(%"class.{{.*}}.cl::sycl::sampler" addrspace(4)* [[GEPCAST]], %opencl.sampler_t addrspace(2)* [[LOAD_SAMPLER_ARG]]) // + +// CHECK: define spir_kernel void @{{[a-zA-Z0-9_]+}}(%struct{{.*}}sampler_wrapper{{.*}} %opencl.sampler_t addrspace(2)* [[SAMPLER_ARG_WRAPPED:%[a-zA-Z0-9_]+]]) +// CHECK: [[SAMPLER_ARG_WRAPPED]].addr = alloca %opencl.sampler_t addrspace(2)*, align 8 +// CHECK: store %opencl.sampler_t addrspace(2)* [[SAMPLER_ARG_WRAPPED]], %opencl.sampler_t addrspace(2)** [[SAMPLER_ARG_WRAPPED]].addr, align 8 +// CHECK: [[LOAD_SAMPLER_ARG_WRAPPED:%[0-9]+]] = load %opencl.sampler_t addrspace(2)*, %opencl.sampler_t addrspace(2)** [[SAMPLER_ARG_WRAPPED]].addr, align 8 +// CHECK: call spir_func void @{{[a-zA-Z0-9_]+}}(%"class.{{.*}}.cl::sycl::sampler" addrspace(4)* {{.*}}, %opencl.sampler_t addrspace(2)* [[LOAD_SAMPLER_ARG_WRAPPED]]) +// #include "sycl.hpp" +struct sampler_wrapper { + cl::sycl::sampler smpl; + int a; +}; + template __attribute__((sycl_kernel)) void kernel_single_task(KernelType kernelFunc) { kernelFunc(); @@ -24,5 +36,10 @@ int main() { smplr.use(); }); + sampler_wrapper wrappedSampler = {smplr, 1}; + kernel_single_task([=]() { + wrappedSampler.smpl.use(); + }); + return 0; } diff --git a/sycl/test/basic_tests/sampler/sampler.cpp b/sycl/test/basic_tests/sampler/sampler.cpp index 352235fef391f..4ea4c287790d2 100644 --- a/sycl/test/basic_tests/sampler/sampler.cpp +++ b/sycl/test/basic_tests/sampler/sampler.cpp @@ -22,6 +22,15 @@ namespace sycl { using namespace cl::sycl; } +struct SamplerWrapper { + SamplerWrapper(sycl::coordinate_normalization_mode Norm, + sycl::addressing_mode Addr, sycl::filtering_mode Filter) + : Smpl(Norm, Addr, Filter), A(0) {} + + sycl::sampler Smpl; + int A; +}; + int main() { // Check constructor from enums sycl::sampler A(sycl::coordinate_normalization_mode::unnormalized, @@ -88,6 +97,10 @@ int main() { assert(C == A); assert(Hasher(C) != Hasher(B)); + SamplerWrapper WrappedSmplr( + sycl::coordinate_normalization_mode::normalized, + sycl::addressing_mode::repeat, sycl::filtering_mode::linear); + // Device sampler. { sycl::queue Queue; @@ -95,6 +108,7 @@ int main() { cgh.single_task([=]() { sycl::sampler C = A; sycl::sampler D(C); + sycl::sampler E(WrappedSmplr.Smpl); }); }); }