diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index 571c46daf796b..f4b71cb554bc2 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -10999,11 +10999,9 @@ def err_sycl_restrict : Error< "|use a const static or global variable that is neither zero-initialized " "nor constant-initialized" "}0">; -def warn_sycl_kernel_too_many_args : Warning< - "kernel argument count (%0) exceeds supported maximum of %1 on GPU">, - InGroup; -def note_sycl_kernel_args_count : Note<"array elements and fields of a " - "class/struct may be counted separately">; +def warn_sycl_kernel_too_big_args : Warning< + "size of kernel arguments (%0 bytes) exceeds supported maximum of %1 bytes " + "on GPU">, InGroup; def err_sycl_virtual_types : Error< "No class with a vtable can be used in a SYCL kernel or any code included in the kernel">; def note_sycl_recursive_function_declared_here: Note<"function implemented using recursion declared here">; diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index b2d1c362616ca..1a701d3775afb 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -56,7 +56,7 @@ enum KernelInvocationKind { const static std::string InitMethodName = "__init"; const static std::string FinalizeMethodName = "__finalize"; -constexpr unsigned GPUMaxKernelArgsNum = 2000; +constexpr unsigned GPUMaxKernelArgsSize = 2048; namespace { @@ -1656,32 +1656,35 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { using SyclKernelFieldHandler::leaveStruct; }; -class SyclKernelNumArgsChecker : public SyclKernelFieldHandler { +class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler { SourceLocation KernelLoc; - unsigned NumOfParams = 0; + unsigned SizeOfParams = 0; + + void addParam(QualType ArgTy) { + SizeOfParams += + SemaRef.getASTContext().getTypeSizeInChars(ArgTy).getQuantity(); + } bool handleSpecialType(QualType FieldTy) { const CXXRecordDecl *RecordDecl = FieldTy->getAsCXXRecordDecl(); assert(RecordDecl && "The accessor/sampler must be a RecordDecl"); CXXMethodDecl *InitMethod = getMethodByName(RecordDecl, InitMethodName); assert(InitMethod && "The accessor/sampler must have the __init method"); - NumOfParams += InitMethod->getNumParams(); + for (const ParmVarDecl *Param : InitMethod->parameters()) + addParam(Param->getType()); return true; } public: - SyclKernelNumArgsChecker(Sema &S, SourceLocation Loc) + SyclKernelArgsSizeChecker(Sema &S, SourceLocation Loc) : SyclKernelFieldHandler(S), KernelLoc(Loc) {} - ~SyclKernelNumArgsChecker() { + ~SyclKernelArgsSizeChecker() { if (SemaRef.Context.getTargetInfo().getTriple().getSubArch() == - llvm::Triple::SPIRSubArch_gen) { - if (NumOfParams > GPUMaxKernelArgsNum) { - SemaRef.Diag(KernelLoc, diag::warn_sycl_kernel_too_many_args) - << NumOfParams << GPUMaxKernelArgsNum; - SemaRef.Diag(KernelLoc, diag::note_sycl_kernel_args_count); - } - } + llvm::Triple::SPIRSubArch_gen) + if (SizeOfParams > GPUMaxKernelArgsSize) + SemaRef.Diag(KernelLoc, diag::warn_sycl_kernel_too_big_args) + << SizeOfParams << GPUMaxKernelArgsSize; } bool handleSyclAccessorType(FieldDecl *FD, QualType FieldTy) final { @@ -1703,12 +1706,12 @@ class SyclKernelNumArgsChecker : public SyclKernelFieldHandler { } bool handlePointerType(FieldDecl *FD, QualType FieldTy) final { - NumOfParams++; + addParam(FieldTy); return true; } bool handleScalarType(FieldDecl *FD, QualType FieldTy) final { - NumOfParams++; + addParam(FieldTy); return true; } @@ -1717,17 +1720,17 @@ class SyclKernelNumArgsChecker : public SyclKernelFieldHandler { } bool handleSyclHalfType(FieldDecl *FD, QualType FieldTy) final { - NumOfParams++; + addParam(FieldTy); return true; } bool handleSyclStreamType(FieldDecl *FD, QualType FieldTy) final { - NumOfParams++; + addParam(FieldTy); return true; } bool handleSyclStreamType(const CXXRecordDecl *, const CXXBaseSpecifier &, QualType FieldTy) final { - NumOfParams++; + addParam(FieldTy); return true; } using SyclKernelFieldHandler::handleSyclHalfType; @@ -2468,7 +2471,7 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc, SyclKernelFieldChecker FieldChecker(*this); SyclKernelUnionChecker UnionChecker(*this); - SyclKernelNumArgsChecker NumArgsChecker(*this, Args[0]->getExprLoc()); + SyclKernelArgsSizeChecker ArgsSizeChecker(*this, Args[0]->getExprLoc()); // check that calling kernel conforms to spec QualType KernelParamTy = KernelFunc->getParamDecl(0)->getType(); if (KernelParamTy->isReferenceType()) { @@ -2488,9 +2491,9 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc, KernelObjVisitor Visitor{*this}; DiagnosingSYCLKernel = true; Visitor.VisitRecordBases(KernelObj, FieldChecker, UnionChecker, - NumArgsChecker); + ArgsSizeChecker); Visitor.VisitRecordFields(KernelObj, FieldChecker, UnionChecker, - NumArgsChecker); + ArgsSizeChecker); DiagnosingSYCLKernel = false; if (!FieldChecker.isValid() || !UnionChecker.isValid()) KernelFunc->setInvalidDecl(); diff --git a/clang/test/SemaSYCL/num-args-overflow.cpp b/clang/test/SemaSYCL/args-size-overflow.cpp similarity index 75% rename from clang/test/SemaSYCL/num-args-overflow.cpp rename to clang/test/SemaSYCL/args-size-overflow.cpp index e4fcc1c2f2d71..27fa7fdb01b20 100644 --- a/clang/test/SemaSYCL/num-args-overflow.cpp +++ b/clang/test/SemaSYCL/args-size-overflow.cpp @@ -13,11 +13,9 @@ __attribute__((sycl_kernel)) void kernel(F KernelFunc) { template void parallel_for(F KernelFunc) { #ifdef GPU - // expected-warning@+8 {{kernel argument count (2001) exceeds supported maximum of 2000 on GPU}} - // expected-note@+7 {{array elements and fields of a class/struct may be counted separately}} + // expected-warning@+6 {{size of kernel arguments (7994 bytes) exceeds supported maximum of 2048 bytes on GPU}} #elif ERROR - // expected-error@+5 {{kernel argument count (2001) exceeds supported maximum of 2000 on GPU}} - // expected-note@+4 {{array elements and fields of a class/struct may be counted separately}} + // expected-error@+4 {{size of kernel arguments (7994 bytes) exceeds supported maximum of 2048 bytes on GPU}} #else // expected-no-diagnostics #endif