diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index fac53fcfbf94a..971b8eadec6f6 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -914,42 +914,6 @@ class KernelBodyTransform : public TreeTransform { Sema &SemaRef; }; -// Searches for a call to PFWG lambda function and captures it. -class FindPFWGLambdaFnVisitor - : public RecursiveASTVisitor { -public: - // LambdaObjTy - lambda type of the PFWG lambda object - FindPFWGLambdaFnVisitor(const CXXRecordDecl *LambdaObjTy) - : LambdaFn(nullptr), LambdaObjTy(LambdaObjTy) {} - - bool VisitCallExpr(CallExpr *Call) { - auto *M = dyn_cast(Call->getDirectCallee()); - if (!M || (M->getOverloadedOperator() != OO_Call)) - return true; - - unsigned int NumPFWGLambdaArgs = - M->getNumParams() + 1; // group, optional kernel_handler and lambda obj - if (Call->getNumArgs() != NumPFWGLambdaArgs) - return true; - if (!Util::isSyclType(Call->getArg(1)->getType(), "group", true /*Tmpl*/)) - return true; - if ((Call->getNumArgs() > 2) && - !Util::isSyclKernelHandlerType(Call->getArg(2)->getType())) - return true; - if (Call->getArg(0)->getType()->getAsCXXRecordDecl() != LambdaObjTy) - return true; - LambdaFn = M; // call to PFWG lambda found - record the lambda - return false; // ... and stop searching - } - - // Returns the captured lambda function or nullptr; - CXXMethodDecl *getLambdaFn() const { return LambdaFn; } - -private: - CXXMethodDecl *LambdaFn; - const CXXRecordDecl *LambdaObjTy; -}; - class MarkWIScopeFnVisitor : public RecursiveASTVisitor { public: MarkWIScopeFnVisitor(ASTContext &Ctx) : Ctx(Ctx) {} @@ -2541,10 +2505,16 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler { void markParallelWorkItemCalls() { if (getKernelInvocationKind(KernelCallerFunc) == InvokeParallelForWorkGroup) { - FindPFWGLambdaFnVisitor V(KernelObj); - V.TraverseStmt(KernelCallerFunc->getBody()); - CXXMethodDecl *WGLambdaFn = V.getLambdaFn(); - assert(WGLambdaFn && "PFWG lambda not found"); + // Fetch the kernel object and the associated call operator + // (of either the lambda or the function object). + CXXRecordDecl *KernelObj = + GetSYCLKernelObjectType(KernelCallerFunc)->getAsCXXRecordDecl(); + CXXMethodDecl *WGLambdaFn = nullptr; + if (KernelObj->isLambda()) + WGLambdaFn = KernelObj->getLambdaCallOperator(); + else + WGLambdaFn = getOperatorParens(KernelObj); + assert(WGLambdaFn && "non callable object is passed as kernel obj"); // Mark the function that it "works" in a work group scope: // NOTE: In case of parallel_for_work_item the marker call itself is // marked with work item scope attribute, here the '()' operator of the diff --git a/clang/test/SemaSYCL/sycl-pfwg-invalid-code.cpp b/clang/test/SemaSYCL/sycl-pfwg-invalid-code.cpp new file mode 100644 index 0000000000000..db8e77a678998 --- /dev/null +++ b/clang/test/SemaSYCL/sycl-pfwg-invalid-code.cpp @@ -0,0 +1,12 @@ +// RUN: %clang_cc1 -fsycl-is-device %s -verify + +// Tests that the compiler does not crash (due to a triggered assertion) +// if definition of kernel_parallel_for_work_group is invalid. +template +__attribute__((sycl_kernel)) void kernel_parallel_for_work_group(const K &) { + unknown(); // expected-error{{use of undeclared identifier 'unknown'}} +} +void foo() { + auto lambda = [] {}; + kernel_parallel_for_work_group(lambda); +}