diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index b3ecda585cf3e..2aa65f5f60044 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -677,73 +677,92 @@ constructKernelName(Sema &S, FunctionDecl *KernelCallerFunc, // anonymous namespace so these don't get linkage. namespace { -QualType getItemType(const FieldDecl *FD) { return FD->getType(); } -QualType getItemType(const CXXBaseSpecifier &BS) { return BS.getType(); } - // Implements the 'for-each-visitor' pattern. template -static void VisitAccessorWrapper(CXXRecordDecl *Owner, ParentTy &Parent, - CXXRecordDecl *Wrapper, - Handlers &... handlers); - -template -static void VisitAccessorWrapperHelper(CXXRecordDecl *Owner, RangeTy Range, - Handlers &... handlers) { - for (const auto &Item : Range) { - QualType ItemTy = getItemType(Item); - if (Util::isSyclAccessorType(ItemTy)) +static void VisitRecord(CXXRecordDecl *Owner, ParentTy &Parent, + CXXRecordDecl *Wrapper, Handlers &... handlers); + +template +static void VisitRecordHelper(CXXRecordDecl *Owner, + clang::CXXRecordDecl::base_class_range Range, + Handlers &... handlers) { + for (const auto &Base : Range) { + QualType BaseTy = Base.getType(); + if (Util::isSyclAccessorType(BaseTy)) (void)std::initializer_list{ - (handlers.handleSyclAccessorType(Item, ItemTy), 0)...}; - else if (ItemTy->isStructureOrClassType()) { - VisitAccessorWrapper(Owner, Item, ItemTy->getAsCXXRecordDecl(), - handlers...); - if (Util::isSyclStreamType(ItemTy)) - (void)std::initializer_list{ - (handlers.handleSyclStreamType(Item, ItemTy), 0)...}; - } + (handlers.handleSyclAccessorType(Base, BaseTy), 0)...}; + else if (Util::isSyclStreamType(BaseTy)) + (void)std::initializer_list{ + (handlers.handleSyclStreamType(Base, BaseTy), 0)...}; + else + VisitRecord(Owner, Base, BaseTy->getAsCXXRecordDecl(), handlers...); } } +template +static void VisitRecordHelper(CXXRecordDecl *Owner, + clang::RecordDecl::field_range Range, + Handlers &... handlers) { + VisitRecordFields(Owner, handlers...); +} + // Parent contains the FieldDecl or CXXBaseSpecifier that was used to enter // the Wrapper structure that we're currently visiting. Owner is the parent type // (which doesn't exist in cases where it is a FieldDecl in the 'root'), and // Wrapper is the current struct being unwrapped. template -static void VisitAccessorWrapper(CXXRecordDecl *Owner, ParentTy &Parent, - CXXRecordDecl *Wrapper, - Handlers &... handlers) { +static void VisitRecord(CXXRecordDecl *Owner, ParentTy &Parent, + CXXRecordDecl *Wrapper, Handlers &... handlers) { (void)std::initializer_list{(handlers.enterStruct(Owner, Parent), 0)...}; - VisitAccessorWrapperHelper(Wrapper, Wrapper->bases(), handlers...); - VisitAccessorWrapperHelper(Wrapper, Wrapper->fields(), handlers...); + VisitRecordHelper(Wrapper, Wrapper->bases(), handlers...); + VisitRecordHelper(Wrapper, Wrapper->fields(), handlers...); (void)std::initializer_list{(handlers.leaveStruct(Owner, Parent), 0)...}; } +int getFieldNumber(const CXXRecordDecl *BaseDecl) { + int Members = 0; + for (const auto *Field : BaseDecl->fields()) + ++Members; + + return Members; +} + +template +static void VisitFunctorBases(CXXRecordDecl *KernelFunctor, + Handlers &... handlers) { + VisitRecordHelper(KernelFunctor, KernelFunctor->bases(), handlers...); +} + // A visitor function that dispatches to functions as defined in // SyclKernelFieldHandler for the purposes of kernel generation. template -static void VisitRecordFields(RecordDecl::field_range Fields, - Handlers &... handlers) { +static void VisitRecordFields(CXXRecordDecl *Owner, Handlers &... handlers) { #define KF_FOR_EACH(FUNC) \ (void)std::initializer_list { (handlers.FUNC(Field, FieldTy), 0)... } - for (const auto &Field : Fields) { + for (const auto &Field : Owner->fields()) { QualType FieldTy = Field->getType(); if (Util::isSyclAccessorType(FieldTy)) + // FIXME: Does this still work? Check KF_FOR_EACH(handleSyclAccessorType); else if (Util::isSyclSamplerType(FieldTy)) + // FIXME: Does this still work? Check KF_FOR_EACH(handleSyclSamplerType); else if (Util::isSyclSpecConstantType(FieldTy)) + // FIXME: Does this still work? Check KF_FOR_EACH(handleSyclSpecConstantType); else if (Util::isSyclStreamType(FieldTy)) { // Stream actually wraps accessors, so do recursion + // FIXME: Does this still work? Check CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl(); - VisitAccessorWrapper(nullptr, Field, RD, handlers...); + VisitRecord(nullptr, Field, RD, handlers...); KF_FOR_EACH(handleSyclStreamType); } else if (FieldTy->isStructureOrClassType()) { + // Handle diagnostics for non-standard layout KF_FOR_EACH(handleStructType); CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl(); - VisitAccessorWrapper(nullptr, Field, RD, handlers...); + VisitRecord(Owner, Field, RD, handlers...); } else if (FieldTy->isReferenceType()) KF_FOR_EACH(handleReferenceType); else if (FieldTy->isPointerType()) @@ -990,7 +1009,7 @@ class SyclKernelDeclCreator } void handleStructType(FieldDecl *FD, QualType FieldTy) final { - addParam(FD, FieldTy); + // addParam(FD, FieldTy); } void handleSyclStreamType(FieldDecl *FD, QualType FieldTy) final { @@ -1169,16 +1188,16 @@ class SyclKernelBodyCreator void handleSpecialType(FieldDecl *FD, QualType Ty) { const auto *RecordDecl = Ty->getAsCXXRecordDecl(); // Perform initialization only if it is field of kernel object - if (MemberExprBases.size() == 1) { - InitializedEntity Entity = - InitializedEntity::InitializeMember(FD, &VarEntity); - // Initialize with the default constructor. - InitializationKind InitKind = - InitializationKind::CreateDefault(SourceLocation()); - InitializationSequence InitSeq(SemaRef, Entity, InitKind, None); - ExprResult MemberInit = InitSeq.Perform(SemaRef, Entity, InitKind, None); - InitExprs.push_back(MemberInit.get()); - } + // if (MemberExprBases.size() == 1) { + InitializedEntity Entity = + InitializedEntity::InitializeMember(FD, &VarEntity); + // Initialize with the default constructor. + InitializationKind InitKind = + InitializationKind::CreateDefault(SourceLocation()); + InitializationSequence InitSeq(SemaRef, Entity, InitKind, None); + ExprResult MemberInit = InitSeq.Perform(SemaRef, Entity, InitKind, None); + InitExprs.push_back(MemberInit.get()); + // } createSpecialMethodCall(RecordDecl, MemberExprBases.back(), InitMethodName, FD); } @@ -1241,7 +1260,7 @@ class SyclKernelBodyCreator } void handleStructType(FieldDecl *FD, QualType FieldTy) final { - createExprForStructOrScalar(FD); + // createExprForStructOrScalar(FD); } void handleScalarType(FieldDecl *FD, QualType FieldTy) final { @@ -1252,8 +1271,41 @@ class SyclKernelBodyCreator MemberExprBases.push_back(BuildMemberExpr(MemberExprBases.back(), FD)); } + void addStructInit(const CXXRecordDecl *RD) { + if (!RD) + return; + + int NumberOfFields = getFieldNumber(RD); + int popOut = NumberOfFields + RD->getNumBases(); + llvm::SmallVector BaseInitExprs; + for (int I = 0; I < popOut; I++) { + BaseInitExprs.push_back(InitExprs.back()); + InitExprs.pop_back(); + } + std::reverse(BaseInitExprs.begin(), BaseInitExprs.end()); + + Expr *ILE = new (SemaRef.getASTContext()) + InitListExpr(SemaRef.getASTContext(), SourceLocation(), BaseInitExprs, + SourceLocation()); + ILE->setType(QualType(RD->getTypeForDecl(), 0)); + InitExprs.push_back(ILE); + + //MemberExprBases.pop_back(); + + } + void leaveStruct(const CXXRecordDecl *, FieldDecl *FD) final { - MemberExprBases.pop_back(); + + const CXXRecordDecl *RD = FD->getType()->getAsCXXRecordDecl(); + + addStructInit(RD); + + } + + void leaveStruct(const CXXRecordDecl *RD, const CXXBaseSpecifier &BS) final { + + const CXXRecordDecl *BaseClass = BS.getType()->getAsCXXRecordDecl(); + addStructInit(BaseClass); } using SyclKernelFieldHandler::enterStruct; @@ -1356,7 +1408,7 @@ class SyclKernelIntHeaderCreator addParam(FD, FieldTy, SYCLIntegrationHeader::kind_pointer); } void handleStructType(FieldDecl *FD, QualType FieldTy) final { - addParam(FD, FieldTy, SYCLIntegrationHeader::kind_std_layout); + // addParam(FD, FieldTy, SYCLIntegrationHeader::kind_std_layout); } void handleScalarType(FieldDecl *FD, QualType FieldTy) final { addParam(FD, FieldTy, SYCLIntegrationHeader::kind_std_layout); @@ -1447,7 +1499,9 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc, StableName); ConstructingOpenCLKernel = true; - VisitRecordFields(KernelLambda->fields(), checker, kernel_decl, kernel_body, + VisitFunctorBases(KernelLambda, checker, kernel_decl, kernel_body, + int_header); + VisitRecordFields(KernelLambda, checker, kernel_decl, kernel_body, int_header); ConstructingOpenCLKernel = false; }