diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index 8d021f7e00d49..fcd345c6a6721 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -301,7 +301,8 @@ class SYCLIntegrationHeader { kind_accessor = kind_first, kind_std_layout, kind_sampler, - kind_last = kind_sampler + kind_pointer, + kind_last = kind_pointer }; public: diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 1af8ea665f191..5b75385f8ccb1 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -331,6 +331,11 @@ class MarkDeviceFunction : public RecursiveASTVisitor { private: bool CheckSYCLType(QualType Ty, SourceRange Loc) { + llvm::DenseSet visited; + return CheckSYCLType(Ty, Loc, visited); + } + + bool CheckSYCLType(QualType Ty, SourceRange Loc, llvm::DenseSet &Visited) { if (Ty->isVariableArrayType()) { SemaRef.Diag(Loc.getBegin(), diag::err_vla_unsupported); return false; @@ -339,6 +344,11 @@ class MarkDeviceFunction : public RecursiveASTVisitor { while (Ty->isAnyPointerType() || Ty->isArrayType()) Ty = QualType{Ty->getPointeeOrArrayElementType(), 0}; + // Pointers complicate recursion. Add this type to Visited. + // If already there, bail out. + if (!Visited.insert(Ty).second) + return true; + if (const auto *CRD = Ty->getAsCXXRecordDecl()) { if (CRD->isPolymorphic()) { SemaRef.Diag(CRD->getLocation(), diag::err_sycl_virtual_types); @@ -347,25 +357,25 @@ class MarkDeviceFunction : public RecursiveASTVisitor { } for (const auto &Field : CRD->fields()) { - if (!CheckSYCLType(Field->getType(), Field->getSourceRange())) { + if (!CheckSYCLType(Field->getType(), Field->getSourceRange(), Visited)) { SemaRef.Diag(Loc.getBegin(), diag::note_sycl_used_here); return false; } } } else if (const auto *RD = Ty->getAsRecordDecl()) { for (const auto &Field : RD->fields()) { - if (!CheckSYCLType(Field->getType(), Field->getSourceRange())) { + if (!CheckSYCLType(Field->getType(), Field->getSourceRange(), Visited)) { SemaRef.Diag(Loc.getBegin(), diag::note_sycl_used_here); return false; } } } else if (const auto *FPTy = dyn_cast(Ty)) { for (const auto &ParamTy : FPTy->param_types()) - if (!CheckSYCLType(ParamTy, Loc)) + if (!CheckSYCLType(ParamTy, Loc, Visited)) return false; - return CheckSYCLType(FPTy->getReturnType(), Loc); + return CheckSYCLType(FPTy->getReturnType(), Loc, Visited); } else if (const auto *FTy = dyn_cast(Ty)) { - return CheckSYCLType(FTy->getReturnType(), Loc); + return CheckSYCLType(FTy->getReturnType(), Loc, Visited); } return true; } @@ -766,6 +776,16 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj, // Create descriptors for each accessor field in the class or struct createParamDescForWrappedAccessors(Fld, ArgTy); + } else if (ArgTy->isPointerType()) { + // Pointer Arguments need to be in the global address space + QualType PointeeTy = ArgTy->getPointeeType(); + Qualifiers Quals = PointeeTy.getQualifiers(); + Quals.setAddressSpace(LangAS::opencl_global); + PointeeTy = Context.getQualifiedType(PointeeTy.getUnqualifiedType(), + Quals); + QualType ModTy = Context.getPointerType(PointeeTy); + + CreateAndAddPrmDsc(Fld, ModTy); } else if (ArgTy->isScalarType()) { CreateAndAddPrmDsc(Fld, ArgTy); } else { @@ -853,6 +873,10 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name, uint64_t Sz = Ctx.getTypeSizeInChars(SamplerArg->getType()).getQuantity(); H.addParamDesc(SYCLIntegrationHeader::kind_sampler, static_cast(Sz), static_cast(Offset)); + } else if (ArgTy->isPointerType()) { + uint64_t Sz = Ctx.getTypeSizeInChars(Fld->getType()).getQuantity(); + H.addParamDesc(SYCLIntegrationHeader::kind_pointer, + static_cast(Sz), static_cast(Offset)); } else if (ArgTy->isStructureOrClassType() || ArgTy->isScalarType()) { // the parameter is an object of standard layout type or scalar; // the check for standard layout is done elsewhere @@ -1017,6 +1041,7 @@ static const char *paramKind2Str(KernelParamKind K) { CASE(accessor); CASE(std_layout); CASE(sampler); + CASE(pointer); default: return ""; } diff --git a/clang/test/CodeGenSYCL/usm-int-header.cpp b/clang/test/CodeGenSYCL/usm-int-header.cpp new file mode 100644 index 0000000000000..2f49b58b5ee32 --- /dev/null +++ b/clang/test/CodeGenSYCL/usm-int-header.cpp @@ -0,0 +1,34 @@ +// RUN: %clang_cc1 -std=c++11 -I %S/Inputs -fsycl-is-device -ast-dump %s | FileCheck %s +// RUN: %clang -I %S/Inputs --sycl -Xclang -fsycl-int-header=%t.h %s -c -o kernel.spv +// RUN: FileCheck -input-file=%t.h %s --check-prefix=INT-HEADER + +// INT-HEADER:{ kernel_param_kind_t::kind_pointer, 8, 0 }, +// INT-HEADER:{ kernel_param_kind_t::kind_pointer, 8, 8 }, + +//==--usm-int-header.cpp - USM kernel param aspace and int header test -----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + + +#include + +template +__attribute__((sycl_kernel)) void kernel(Func kernelFunc) { + kernelFunc(); +} + +int main() { + int* x; + float* y; + + kernel([=]() { + *x = 42; + *y = 3.14; + }); +} + +// CHECK: FunctionDecl {{.*}}usm_test 'void (__global int *, __global float *)' diff --git a/sycl/include/CL/sycl/detail/kernel_desc.hpp b/sycl/include/CL/sycl/detail/kernel_desc.hpp index f5cecb8dfaa2a..ebf71ac8561c7 100644 --- a/sycl/include/CL/sycl/detail/kernel_desc.hpp +++ b/sycl/include/CL/sycl/detail/kernel_desc.hpp @@ -33,7 +33,8 @@ class half; enum class kernel_param_kind_t { kind_accessor, kind_std_layout, // standard layout object parameters - kind_sampler + kind_sampler, + kind_pointer }; // describes a kernel parameter diff --git a/sycl/include/CL/sycl/handler.hpp b/sycl/include/CL/sycl/handler.hpp index dcc3fb4ae2cfd..d793674080d05 100644 --- a/sycl/include/CL/sycl/handler.hpp +++ b/sycl/include/CL/sycl/handler.hpp @@ -239,9 +239,11 @@ class handler { const auto kind_std_layout = detail::kernel_param_kind_t::kind_std_layout; const auto kind_accessor = detail::kernel_param_kind_t::kind_accessor; const auto kind_sampler = detail::kernel_param_kind_t::kind_sampler; + const auto kind_pointer = detail::kernel_param_kind_t::kind_pointer; switch (Kind) { - case kind_std_layout: { + case kind_std_layout: + case kind_pointer: { MArgs.emplace_back(Kind, Ptr, Size, Index + IndexShift); break; }