diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst index f7c3194c91fa3..a300829cf0e32 100644 --- a/clang/docs/ReleaseNotes.rst +++ b/clang/docs/ReleaseNotes.rst @@ -246,6 +246,9 @@ Attribute Changes in Clang instantiation by accidentally allowing it in C++ in some circumstances. (#GH106864) +- Introduced a new attribute ``[[clang::coro_await_elidable]]`` on coroutine return types + to express elideability at call sites where the coroutine is co_awaited as a prvalue. + Improvements to Clang's diagnostics ----------------------------------- diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h index 65104acda9382..66c746cc25040 100644 --- a/clang/include/clang/AST/Expr.h +++ b/clang/include/clang/AST/Expr.h @@ -2991,6 +2991,9 @@ class CallExpr : public Expr { bool hasStoredFPFeatures() const { return CallExprBits.HasFPFeatures; } + bool isCoroElideSafe() const { return CallExprBits.IsCoroElideSafe; } + void setCoroElideSafe(bool V = true) { CallExprBits.IsCoroElideSafe = V; } + Decl *getCalleeDecl() { return getCallee()->getReferencedDeclOfCallee(); } const Decl *getCalleeDecl() const { return getCallee()->getReferencedDeclOfCallee(); diff --git a/clang/include/clang/AST/Stmt.h b/clang/include/clang/AST/Stmt.h index f1a2aac0a8b2f..7aed83e9c68bb 100644 --- a/clang/include/clang/AST/Stmt.h +++ b/clang/include/clang/AST/Stmt.h @@ -561,8 +561,11 @@ class alignas(void *) Stmt { LLVM_PREFERRED_TYPE(bool) unsigned HasFPFeatures : 1; + /// True if the call expression is a must-elide call to a coroutine. + unsigned IsCoroElideSafe : 1; + /// Padding used to align OffsetToTrailingObjects to a byte multiple. - unsigned : 24 - 3 - NumExprBits; + unsigned : 24 - 4 - NumExprBits; /// The offset in bytes from the this pointer to the start of the /// trailing objects belonging to CallExpr. Intentionally byte sized diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index 0c98f8e25a6fb..9a7b163b2c6da 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -1250,6 +1250,14 @@ def CoroDisableLifetimeBound : InheritableAttr { let SimpleHandler = 1; } +def CoroAwaitElidable : InheritableAttr { + let Spellings = [Clang<"coro_await_elidable">]; + let Subjects = SubjectList<[CXXRecord]>; + let LangOpts = [CPlusPlus]; + let Documentation = [CoroAwaitElidableDoc]; + let SimpleHandler = 1; +} + // OSObject-based attributes. def OSConsumed : InheritableParamAttr { let Spellings = [Clang<"os_consumed">]; diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td index ef077db298831..546e5100b79dd 100644 --- a/clang/include/clang/Basic/AttrDocs.td +++ b/clang/include/clang/Basic/AttrDocs.td @@ -8255,6 +8255,38 @@ but do not pass them to the underlying coroutine or pass them by value. }]; } +def CoroAwaitElidableDoc : Documentation { + let Category = DocCatDecl; + let Content = [{ +The ``[[clang::coro_await_elidable]]`` is a class attribute which can be applied +to a coroutine return type. + +When a coroutine function that returns such a type calls another coroutine function, +the compiler performs heap allocation elision when the call to the coroutine function +is immediately co_awaited as a prvalue. In this case, the coroutine frame for the +callee will be a local variable within the enclosing braces in the caller's stack +frame. And the local variable, like other variables in coroutines, may be collected +into the coroutine frame, which may be allocated on the heap. + +Example: + +.. code-block:: c++ + + class [[clang::coro_await_elidable]] Task { ... }; + + Task foo(); + Task bar() { + co_await foo(); // foo()'s coroutine frame on this line is elidable + auto t = foo(); // foo()'s coroutine frame on this line is NOT elidable + co_await t; + } + +The behavior is undefined if the caller coroutine is destroyed earlier than the +callee coroutine. + +}]; +} + def CountedByDocs : Documentation { let Category = DocCatField; let Content = [{ @@ -8414,4 +8446,3 @@ Declares that a function potentially allocates heap memory, and prevents any pot of ``nonallocating`` by the compiler. }]; } - diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp index 27930db019a17..6545912ed160d 100644 --- a/clang/lib/AST/Expr.cpp +++ b/clang/lib/AST/Expr.cpp @@ -1475,6 +1475,7 @@ CallExpr::CallExpr(StmtClass SC, Expr *Fn, ArrayRef PreArgs, this->computeDependence(); CallExprBits.HasFPFeatures = FPFeatures.requiresTrailingStorage(); + CallExprBits.IsCoroElideSafe = false; if (hasStoredFPFeatures()) setStoredFPFeatures(FPFeatures); } @@ -1490,6 +1491,7 @@ CallExpr::CallExpr(StmtClass SC, unsigned NumPreArgs, unsigned NumArgs, assert((CallExprBits.OffsetToTrailingObjects == OffsetToTrailingObjects) && "OffsetToTrailingObjects overflow!"); CallExprBits.HasFPFeatures = HasFPFeatures; + CallExprBits.IsCoroElideSafe = false; } CallExpr *CallExpr::Create(const ASTContext &Ctx, Expr *Fn, diff --git a/clang/lib/CodeGen/CGBlocks.cpp b/clang/lib/CodeGen/CGBlocks.cpp index 066139b1c78c7..684fda7440731 100644 --- a/clang/lib/CodeGen/CGBlocks.cpp +++ b/clang/lib/CodeGen/CGBlocks.cpp @@ -1163,7 +1163,8 @@ llvm::Type *CodeGenModule::getGenericBlockLiteralType() { } RValue CodeGenFunction::EmitBlockCallExpr(const CallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { const auto *BPT = E->getCallee()->getType()->castAs(); llvm::Value *BlockPtr = EmitScalarExpr(E->getCallee()); llvm::Type *GenBlockTy = CGM.getGenericBlockLiteralType(); @@ -1220,7 +1221,7 @@ RValue CodeGenFunction::EmitBlockCallExpr(const CallExpr *E, CGCallee Callee(CGCalleeInfo(), Func); // And call the block. - return EmitCall(FnInfo, Callee, ReturnValue, Args); + return EmitCall(FnInfo, Callee, ReturnValue, Args, CallOrInvoke); } Address CodeGenFunction::GetAddrOfBlockDecl(const VarDecl *variable) { diff --git a/clang/lib/CodeGen/CGCUDARuntime.cpp b/clang/lib/CodeGen/CGCUDARuntime.cpp index c14a9d3f2bbbc..1e1da1e2411a7 100644 --- a/clang/lib/CodeGen/CGCUDARuntime.cpp +++ b/clang/lib/CodeGen/CGCUDARuntime.cpp @@ -25,7 +25,8 @@ CGCUDARuntime::~CGCUDARuntime() {} RValue CGCUDARuntime::EmitCUDAKernelCallExpr(CodeGenFunction &CGF, const CUDAKernelCallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { llvm::BasicBlock *ConfigOKBlock = CGF.createBasicBlock("kcall.configok"); llvm::BasicBlock *ContBlock = CGF.createBasicBlock("kcall.end"); @@ -35,7 +36,7 @@ RValue CGCUDARuntime::EmitCUDAKernelCallExpr(CodeGenFunction &CGF, eval.begin(CGF); CGF.EmitBlock(ConfigOKBlock); - CGF.EmitSimpleCallExpr(E, ReturnValue); + CGF.EmitSimpleCallExpr(E, ReturnValue, CallOrInvoke); CGF.EmitBranch(ContBlock); CGF.EmitBlock(ContBlock); diff --git a/clang/lib/CodeGen/CGCUDARuntime.h b/clang/lib/CodeGen/CGCUDARuntime.h index 8030d632cc3d2..86f776004ee7c 100644 --- a/clang/lib/CodeGen/CGCUDARuntime.h +++ b/clang/lib/CodeGen/CGCUDARuntime.h @@ -21,6 +21,7 @@ #include "llvm/IR/GlobalValue.h" namespace llvm { +class CallBase; class Function; class GlobalVariable; } @@ -82,9 +83,10 @@ class CGCUDARuntime { CGCUDARuntime(CodeGenModule &CGM) : CGM(CGM) {} virtual ~CGCUDARuntime(); - virtual RValue EmitCUDAKernelCallExpr(CodeGenFunction &CGF, - const CUDAKernelCallExpr *E, - ReturnValueSlot ReturnValue); + virtual RValue + EmitCUDAKernelCallExpr(CodeGenFunction &CGF, const CUDAKernelCallExpr *E, + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke = nullptr); /// Emits a kernel launch stub. virtual void emitDeviceStub(CodeGenFunction &CGF, FunctionArgList &Args) = 0; diff --git a/clang/lib/CodeGen/CGCXXABI.h b/clang/lib/CodeGen/CGCXXABI.h index 7dcc539111996..687ff7fb84444 100644 --- a/clang/lib/CodeGen/CGCXXABI.h +++ b/clang/lib/CodeGen/CGCXXABI.h @@ -485,11 +485,11 @@ class CGCXXABI { llvm::PointerUnion; /// Emit the ABI-specific virtual destructor call. - virtual llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF, - const CXXDestructorDecl *Dtor, - CXXDtorType DtorType, - Address This, - DeleteOrMemberCallExpr E) = 0; + virtual llvm::Value * + EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, + CXXDtorType DtorType, Address This, + DeleteOrMemberCallExpr E, + llvm::CallBase **CallOrInvoke) = 0; virtual void adjustCallArgsForDestructorThunk(CodeGenFunction &CGF, GlobalDecl GD, diff --git a/clang/lib/CodeGen/CGClass.cpp b/clang/lib/CodeGen/CGClass.cpp index e5ba50de3462d..352955749a633 100644 --- a/clang/lib/CodeGen/CGClass.cpp +++ b/clang/lib/CodeGen/CGClass.cpp @@ -2192,15 +2192,11 @@ static bool canEmitDelegateCallArgs(CodeGenFunction &CGF, return true; } -void CodeGenFunction::EmitCXXConstructorCall(const CXXConstructorDecl *D, - CXXCtorType Type, - bool ForVirtualBase, - bool Delegating, - Address This, - CallArgList &Args, - AggValueSlot::Overlap_t Overlap, - SourceLocation Loc, - bool NewPointerIsChecked) { +void CodeGenFunction::EmitCXXConstructorCall( + const CXXConstructorDecl *D, CXXCtorType Type, bool ForVirtualBase, + bool Delegating, Address This, CallArgList &Args, + AggValueSlot::Overlap_t Overlap, SourceLocation Loc, + bool NewPointerIsChecked, llvm::CallBase **CallOrInvoke) { const CXXRecordDecl *ClassDecl = D->getParent(); if (!NewPointerIsChecked) @@ -2248,7 +2244,7 @@ void CodeGenFunction::EmitCXXConstructorCall(const CXXConstructorDecl *D, const CGFunctionInfo &Info = CGM.getTypes().arrangeCXXConstructorCall( Args, D, Type, ExtraArgs.Prefix, ExtraArgs.Suffix, PassPrototypeArgs); CGCallee Callee = CGCallee::forDirect(CalleePtr, GlobalDecl(D, Type)); - EmitCall(Info, Callee, ReturnValueSlot(), Args, nullptr, false, Loc); + EmitCall(Info, Callee, ReturnValueSlot(), Args, CallOrInvoke, false, Loc); // Generate vtable assumptions if we're constructing a complete object // with a vtable. We don't do this for base subobjects for two reasons: diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index 99cd61b9e7895..35b5daaf6d4b5 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -33,6 +33,7 @@ #include "clang/Basic/SourceManager.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/StringExtras.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Intrinsics.h" @@ -5544,16 +5545,30 @@ RValue CodeGenFunction::EmitRValueForField(LValue LV, //===--------------------------------------------------------------------===// RValue CodeGenFunction::EmitCallExpr(const CallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { + llvm::CallBase *CallOrInvokeStorage; + if (!CallOrInvoke) { + CallOrInvoke = &CallOrInvokeStorage; + } + + auto AddCoroElideSafeOnExit = llvm::make_scope_exit([&] { + if (E->isCoroElideSafe()) { + auto *I = *CallOrInvoke; + if (I) + I->addFnAttr(llvm::Attribute::CoroElideSafe); + } + }); + // Builtins never have block type. if (E->getCallee()->getType()->isBlockPointerType()) - return EmitBlockCallExpr(E, ReturnValue); + return EmitBlockCallExpr(E, ReturnValue, CallOrInvoke); if (const auto *CE = dyn_cast(E)) - return EmitCXXMemberCallExpr(CE, ReturnValue); + return EmitCXXMemberCallExpr(CE, ReturnValue, CallOrInvoke); if (const auto *CE = dyn_cast(E)) - return EmitCUDAKernelCallExpr(CE, ReturnValue); + return EmitCUDAKernelCallExpr(CE, ReturnValue, CallOrInvoke); // A CXXOperatorCallExpr is created even for explicit object methods, but // these should be treated like static function call. @@ -5561,7 +5576,7 @@ RValue CodeGenFunction::EmitCallExpr(const CallExpr *E, if (const auto *MD = dyn_cast_if_present(CE->getCalleeDecl()); MD && MD->isImplicitObjectMemberFunction()) - return EmitCXXOperatorMemberCallExpr(CE, MD, ReturnValue); + return EmitCXXOperatorMemberCallExpr(CE, MD, ReturnValue, CallOrInvoke); CGCallee callee = EmitCallee(E->getCallee()); @@ -5574,14 +5589,17 @@ RValue CodeGenFunction::EmitCallExpr(const CallExpr *E, return EmitCXXPseudoDestructorExpr(callee.getPseudoDestructorExpr()); } - return EmitCall(E->getCallee()->getType(), callee, E, ReturnValue); + return EmitCall(E->getCallee()->getType(), callee, E, ReturnValue, + /*Chain=*/nullptr, CallOrInvoke); } /// Emit a CallExpr without considering whether it might be a subclass. RValue CodeGenFunction::EmitSimpleCallExpr(const CallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { CGCallee Callee = EmitCallee(E->getCallee()); - return EmitCall(E->getCallee()->getType(), Callee, E, ReturnValue); + return EmitCall(E->getCallee()->getType(), Callee, E, ReturnValue, + /*Chain=*/nullptr, CallOrInvoke); } // Detect the unusual situation where an inline version is shadowed by a @@ -5785,8 +5803,9 @@ LValue CodeGenFunction::EmitBinaryOperatorLValue(const BinaryOperator *E) { llvm_unreachable("bad evaluation kind"); } -LValue CodeGenFunction::EmitCallExprLValue(const CallExpr *E) { - RValue RV = EmitCallExpr(E); +LValue CodeGenFunction::EmitCallExprLValue(const CallExpr *E, + llvm::CallBase **CallOrInvoke) { + RValue RV = EmitCallExpr(E, ReturnValueSlot(), CallOrInvoke); if (!RV.isScalar()) return MakeAddrLValue(RV.getAggregateAddress(), E->getType(), @@ -5909,9 +5928,11 @@ LValue CodeGenFunction::EmitStmtExprLValue(const StmtExpr *E) { AlignmentSource::Decl); } -RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee, - const CallExpr *E, ReturnValueSlot ReturnValue, - llvm::Value *Chain) { +RValue CodeGenFunction::EmitCall(QualType CalleeType, + const CGCallee &OrigCallee, const CallExpr *E, + ReturnValueSlot ReturnValue, + llvm::Value *Chain, + llvm::CallBase **CallOrInvoke) { // Get the actual function type. The callee type will always be a pointer to // function type or a block pointer type. assert(CalleeType->isFunctionPointerType() && @@ -6131,8 +6152,8 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee Address(Handle, Handle->getType(), CGM.getPointerAlign())); Callee.setFunctionPointer(Stub); } - llvm::CallBase *CallOrInvoke = nullptr; - RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &CallOrInvoke, + llvm::CallBase *LocalCallOrInvoke = nullptr; + RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &LocalCallOrInvoke, E == MustTailCall, E->getExprLoc()); // Generate function declaration DISuprogram in order to be used @@ -6141,11 +6162,13 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee if (auto *CalleeDecl = dyn_cast_or_null(TargetDecl)) { FunctionArgList Args; QualType ResTy = BuildFunctionArgList(CalleeDecl, Args); - DI->EmitFuncDeclForCallSite(CallOrInvoke, + DI->EmitFuncDeclForCallSite(LocalCallOrInvoke, DI->getFunctionType(CalleeDecl, ResTy, Args), CalleeDecl); } } + if (CallOrInvoke) + *CallOrInvoke = LocalCallOrInvoke; return Call; } diff --git a/clang/lib/CodeGen/CGExprCXX.cpp b/clang/lib/CodeGen/CGExprCXX.cpp index 8eb6ab7381acb..1214bb054fb8d 100644 --- a/clang/lib/CodeGen/CGExprCXX.cpp +++ b/clang/lib/CodeGen/CGExprCXX.cpp @@ -84,23 +84,24 @@ commonEmitCXXMemberOrOperatorCall(CodeGenFunction &CGF, GlobalDecl GD, RValue CodeGenFunction::EmitCXXMemberOrOperatorCall( const CXXMethodDecl *MD, const CGCallee &Callee, - ReturnValueSlot ReturnValue, - llvm::Value *This, llvm::Value *ImplicitParam, QualType ImplicitParamTy, - const CallExpr *CE, CallArgList *RtlArgs) { + ReturnValueSlot ReturnValue, llvm::Value *This, llvm::Value *ImplicitParam, + QualType ImplicitParamTy, const CallExpr *CE, CallArgList *RtlArgs, + llvm::CallBase **CallOrInvoke) { const FunctionProtoType *FPT = MD->getType()->castAs(); CallArgList Args; MemberCallInfo CallInfo = commonEmitCXXMemberOrOperatorCall( *this, MD, This, ImplicitParam, ImplicitParamTy, CE, Args, RtlArgs); auto &FnInfo = CGM.getTypes().arrangeCXXMethodCall( Args, FPT, CallInfo.ReqArgs, CallInfo.PrefixSize); - return EmitCall(FnInfo, Callee, ReturnValue, Args, nullptr, + return EmitCall(FnInfo, Callee, ReturnValue, Args, CallOrInvoke, CE && CE == MustTailCall, CE ? CE->getExprLoc() : SourceLocation()); } RValue CodeGenFunction::EmitCXXDestructorCall( GlobalDecl Dtor, const CGCallee &Callee, llvm::Value *This, QualType ThisTy, - llvm::Value *ImplicitParam, QualType ImplicitParamTy, const CallExpr *CE) { + llvm::Value *ImplicitParam, QualType ImplicitParamTy, const CallExpr *CE, + llvm::CallBase **CallOrInvoke) { const CXXMethodDecl *DtorDecl = cast(Dtor.getDecl()); assert(!ThisTy.isNull()); @@ -120,7 +121,8 @@ RValue CodeGenFunction::EmitCXXDestructorCall( commonEmitCXXMemberOrOperatorCall(*this, Dtor, This, ImplicitParam, ImplicitParamTy, CE, Args, nullptr); return EmitCall(CGM.getTypes().arrangeCXXStructorDeclaration(Dtor), Callee, - ReturnValueSlot(), Args, nullptr, CE && CE == MustTailCall, + ReturnValueSlot(), Args, CallOrInvoke, + CE && CE == MustTailCall, CE ? CE->getExprLoc() : SourceLocation{}); } @@ -186,11 +188,12 @@ static CXXRecordDecl *getCXXRecord(const Expr *E) { // Note: This function also emit constructor calls to support a MSVC // extensions allowing explicit constructor function call. RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { const Expr *callee = CE->getCallee()->IgnoreParens(); if (isa(callee)) - return EmitCXXMemberPointerCallExpr(CE, ReturnValue); + return EmitCXXMemberPointerCallExpr(CE, ReturnValue, CallOrInvoke); const MemberExpr *ME = cast(callee); const CXXMethodDecl *MD = cast(ME->getMemberDecl()); @@ -200,7 +203,7 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE, CGCallee callee = CGCallee::forDirect(CGM.GetAddrOfFunction(MD), GlobalDecl(MD)); return EmitCall(getContext().getPointerType(MD->getType()), callee, CE, - ReturnValue); + ReturnValue, /*Chain=*/nullptr, CallOrInvoke); } bool HasQualifier = ME->hasQualifier(); @@ -208,14 +211,15 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE, bool IsArrow = ME->isArrow(); const Expr *Base = ME->getBase(); - return EmitCXXMemberOrOperatorMemberCallExpr( - CE, MD, ReturnValue, HasQualifier, Qualifier, IsArrow, Base); + return EmitCXXMemberOrOperatorMemberCallExpr(CE, MD, ReturnValue, + HasQualifier, Qualifier, IsArrow, + Base, CallOrInvoke); } RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( const CallExpr *CE, const CXXMethodDecl *MD, ReturnValueSlot ReturnValue, bool HasQualifier, NestedNameSpecifier *Qualifier, bool IsArrow, - const Expr *Base) { + const Expr *Base, llvm::CallBase **CallOrInvoke) { assert(isa(CE) || isa(CE)); // Compute the object pointer. @@ -300,7 +304,7 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( EmitCXXConstructorCall(Ctor, Ctor_Complete, /*ForVirtualBase=*/false, /*Delegating=*/false, This.getAddress(), Args, AggValueSlot::DoesNotOverlap, CE->getExprLoc(), - /*NewPointerIsChecked=*/false); + /*NewPointerIsChecked=*/false, CallOrInvoke); return RValue::get(nullptr); } @@ -374,9 +378,9 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( "Destructor shouldn't have explicit parameters"); assert(ReturnValue.isNull() && "Destructor shouldn't have return value"); if (UseVirtualCall) { - CGM.getCXXABI().EmitVirtualDestructorCall(*this, Dtor, Dtor_Complete, - This.getAddress(), - cast(CE)); + CGM.getCXXABI().EmitVirtualDestructorCall( + *this, Dtor, Dtor_Complete, This.getAddress(), + cast(CE), CallOrInvoke); } else { GlobalDecl GD(Dtor, Dtor_Complete); CGCallee Callee; @@ -393,7 +397,7 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( IsArrow ? Base->getType()->getPointeeType() : Base->getType(); EmitCXXDestructorCall(GD, Callee, This.getPointer(*this), ThisTy, /*ImplicitParam=*/nullptr, - /*ImplicitParamTy=*/QualType(), CE); + /*ImplicitParamTy=*/QualType(), CE, CallOrInvoke); } return RValue::get(nullptr); } @@ -435,12 +439,13 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( return EmitCXXMemberOrOperatorCall( CalleeDecl, Callee, ReturnValue, This.getPointer(*this), - /*ImplicitParam=*/nullptr, QualType(), CE, RtlArgs); + /*ImplicitParam=*/nullptr, QualType(), CE, RtlArgs, CallOrInvoke); } RValue CodeGenFunction::EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { const BinaryOperator *BO = cast(E->getCallee()->IgnoreParens()); const Expr *BaseExpr = BO->getLHS(); @@ -484,24 +489,25 @@ CodeGenFunction::EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E, EmitCallArgs(Args, FPT, E->arguments()); return EmitCall(CGM.getTypes().arrangeCXXMethodCall(Args, FPT, required, /*PrefixSize=*/0), - Callee, ReturnValue, Args, nullptr, E == MustTailCall, + Callee, ReturnValue, Args, CallOrInvoke, E == MustTailCall, E->getExprLoc()); } -RValue -CodeGenFunction::EmitCXXOperatorMemberCallExpr(const CXXOperatorCallExpr *E, - const CXXMethodDecl *MD, - ReturnValueSlot ReturnValue) { +RValue CodeGenFunction::EmitCXXOperatorMemberCallExpr( + const CXXOperatorCallExpr *E, const CXXMethodDecl *MD, + ReturnValueSlot ReturnValue, llvm::CallBase **CallOrInvoke) { assert(MD->isImplicitObjectMemberFunction() && "Trying to emit a member call expr on a static method!"); return EmitCXXMemberOrOperatorMemberCallExpr( E, MD, ReturnValue, /*HasQualifier=*/false, /*Qualifier=*/nullptr, - /*IsArrow=*/false, E->getArg(0)); + /*IsArrow=*/false, E->getArg(0), CallOrInvoke); } RValue CodeGenFunction::EmitCUDAKernelCallExpr(const CUDAKernelCallExpr *E, - ReturnValueSlot ReturnValue) { - return CGM.getCUDARuntime().EmitCUDAKernelCallExpr(*this, E, ReturnValue); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { + return CGM.getCUDARuntime().EmitCUDAKernelCallExpr(*this, E, ReturnValue, + CallOrInvoke); } static void EmitNullBaseClassInitialization(CodeGenFunction &CGF, diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index 9b93e9673ec5f..5892d6ac6f88a 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -3149,7 +3149,8 @@ class CodeGenFunction : public CodeGenTypeCache { bool ForVirtualBase, bool Delegating, Address This, CallArgList &Args, AggValueSlot::Overlap_t Overlap, - SourceLocation Loc, bool NewPointerIsChecked); + SourceLocation Loc, bool NewPointerIsChecked, + llvm::CallBase **CallOrInvoke = nullptr); /// Emit assumption load for all bases. Requires to be called only on /// most-derived class and not under construction of the object. @@ -4269,7 +4270,8 @@ class CodeGenFunction : public CodeGenTypeCache { LValue EmitBinaryOperatorLValue(const BinaryOperator *E); LValue EmitCompoundAssignmentLValue(const CompoundAssignOperator *E); // Note: only available for agg return types - LValue EmitCallExprLValue(const CallExpr *E); + LValue EmitCallExprLValue(const CallExpr *E, + llvm::CallBase **CallOrInvoke = nullptr); // Note: only available for agg return types LValue EmitVAArgExprLValue(const VAArgExpr *E); LValue EmitDeclRefLValue(const DeclRefExpr *E); @@ -4382,21 +4384,27 @@ class CodeGenFunction : public CodeGenTypeCache { /// LLVM arguments and the types they were derived from. RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee, ReturnValueSlot ReturnValue, const CallArgList &Args, - llvm::CallBase **callOrInvoke, bool IsMustTail, + llvm::CallBase **CallOrInvoke, bool IsMustTail, SourceLocation Loc, bool IsVirtualFunctionPointerThunk = false); RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee, ReturnValueSlot ReturnValue, const CallArgList &Args, - llvm::CallBase **callOrInvoke = nullptr, + llvm::CallBase **CallOrInvoke = nullptr, bool IsMustTail = false) { - return EmitCall(CallInfo, Callee, ReturnValue, Args, callOrInvoke, + return EmitCall(CallInfo, Callee, ReturnValue, Args, CallOrInvoke, IsMustTail, SourceLocation()); } RValue EmitCall(QualType FnType, const CGCallee &Callee, const CallExpr *E, - ReturnValueSlot ReturnValue, llvm::Value *Chain = nullptr); + ReturnValueSlot ReturnValue, llvm::Value *Chain = nullptr, + llvm::CallBase **CallOrInvoke = nullptr); + + // If a Call or Invoke instruction was emitted for this CallExpr, this method + // writes the pointer to `CallOrInvoke` if it's not null. RValue EmitCallExpr(const CallExpr *E, - ReturnValueSlot ReturnValue = ReturnValueSlot()); - RValue EmitSimpleCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue); + ReturnValueSlot ReturnValue = ReturnValueSlot(), + llvm::CallBase **CallOrInvoke = nullptr); + RValue EmitSimpleCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke = nullptr); CGCallee EmitCallee(const Expr *E); void checkTargetFeatures(const CallExpr *E, const FunctionDecl *TargetDecl); @@ -4500,25 +4508,23 @@ class CodeGenFunction : public CodeGenTypeCache { void callCStructCopyAssignmentOperator(LValue Dst, LValue Src); void callCStructMoveAssignmentOperator(LValue Dst, LValue Src); - RValue - EmitCXXMemberOrOperatorCall(const CXXMethodDecl *Method, - const CGCallee &Callee, - ReturnValueSlot ReturnValue, llvm::Value *This, - llvm::Value *ImplicitParam, - QualType ImplicitParamTy, const CallExpr *E, - CallArgList *RtlArgs); + RValue EmitCXXMemberOrOperatorCall( + const CXXMethodDecl *Method, const CGCallee &Callee, + ReturnValueSlot ReturnValue, llvm::Value *This, + llvm::Value *ImplicitParam, QualType ImplicitParamTy, const CallExpr *E, + CallArgList *RtlArgs, llvm::CallBase **CallOrInvoke); RValue EmitCXXDestructorCall(GlobalDecl Dtor, const CGCallee &Callee, llvm::Value *This, QualType ThisTy, llvm::Value *ImplicitParam, - QualType ImplicitParamTy, const CallExpr *E); + QualType ImplicitParamTy, const CallExpr *E, + llvm::CallBase **CallOrInvoke = nullptr); RValue EmitCXXMemberCallExpr(const CXXMemberCallExpr *E, - ReturnValueSlot ReturnValue); - RValue EmitCXXMemberOrOperatorMemberCallExpr(const CallExpr *CE, - const CXXMethodDecl *MD, - ReturnValueSlot ReturnValue, - bool HasQualifier, - NestedNameSpecifier *Qualifier, - bool IsArrow, const Expr *Base); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke = nullptr); + RValue EmitCXXMemberOrOperatorMemberCallExpr( + const CallExpr *CE, const CXXMethodDecl *MD, ReturnValueSlot ReturnValue, + bool HasQualifier, NestedNameSpecifier *Qualifier, bool IsArrow, + const Expr *Base, llvm::CallBase **CallOrInvoke); // Compute the object pointer. Address EmitCXXMemberDataPointerAddress(const Expr *E, Address base, llvm::Value *memberPtr, @@ -4526,15 +4532,18 @@ class CodeGenFunction : public CodeGenTypeCache { LValueBaseInfo *BaseInfo = nullptr, TBAAAccessInfo *TBAAInfo = nullptr); RValue EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E, - ReturnValueSlot ReturnValue); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke); RValue EmitCXXOperatorMemberCallExpr(const CXXOperatorCallExpr *E, const CXXMethodDecl *MD, - ReturnValueSlot ReturnValue); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke); RValue EmitCXXPseudoDestructorExpr(const CXXPseudoDestructorExpr *E); RValue EmitCUDAKernelCallExpr(const CUDAKernelCallExpr *E, - ReturnValueSlot ReturnValue); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke); RValue EmitNVPTXDevicePrintfCallExpr(const CallExpr *E); RValue EmitAMDGPUDevicePrintfCallExpr(const CallExpr *E); @@ -4556,7 +4565,8 @@ class CodeGenFunction : public CodeGenTypeCache { const analyze_os_log::OSLogBufferLayout &Layout, CharUnits BufferAlignment); - RValue EmitBlockCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue); + RValue EmitBlockCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke); /// EmitTargetBuiltinExpr - Emit the given builtin call. Returns 0 if the call /// is unhandled by the current target. diff --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp index fb1eb72d9f340..dcc35d5689831 100644 --- a/clang/lib/CodeGen/ItaniumCXXABI.cpp +++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp @@ -315,10 +315,11 @@ class ItaniumCXXABI : public CodeGen::CGCXXABI { Address This, llvm::Type *Ty, SourceLocation Loc) override; - llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF, - const CXXDestructorDecl *Dtor, - CXXDtorType DtorType, Address This, - DeleteOrMemberCallExpr E) override; + llvm::Value * + EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, + CXXDtorType DtorType, Address This, + DeleteOrMemberCallExpr E, + llvm::CallBase **CallOrInvoke) override; void emitVirtualInheritanceTables(const CXXRecordDecl *RD) override; @@ -1399,7 +1400,8 @@ void ItaniumCXXABI::emitVirtualObjectDelete(CodeGenFunction &CGF, // FIXME: Provide a source location here even though there's no // CXXMemberCallExpr for dtor call. CXXDtorType DtorType = UseGlobalDelete ? Dtor_Complete : Dtor_Deleting; - EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE); + EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE, + /*CallOrInvoke=*/nullptr); if (UseGlobalDelete) CGF.PopCleanupBlock(); @@ -2236,7 +2238,7 @@ CGCallee ItaniumCXXABI::getVirtualFunctionPointer(CodeGenFunction &CGF, llvm::Value *ItaniumCXXABI::EmitVirtualDestructorCall( CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, CXXDtorType DtorType, - Address This, DeleteOrMemberCallExpr E) { + Address This, DeleteOrMemberCallExpr E, llvm::CallBase **CallOrInvoke) { auto *CE = E.dyn_cast(); auto *D = E.dyn_cast(); assert((CE != nullptr) ^ (D != nullptr)); @@ -2257,7 +2259,7 @@ llvm::Value *ItaniumCXXABI::EmitVirtualDestructorCall( } CGF.EmitCXXDestructorCall(GD, Callee, This.emitRawPointer(CGF), ThisTy, - nullptr, QualType(), nullptr); + nullptr, QualType(), nullptr, CallOrInvoke); return nullptr; } diff --git a/clang/lib/CodeGen/MicrosoftCXXABI.cpp b/clang/lib/CodeGen/MicrosoftCXXABI.cpp index 76d0191a7e63a..79dcdc04b0996 100644 --- a/clang/lib/CodeGen/MicrosoftCXXABI.cpp +++ b/clang/lib/CodeGen/MicrosoftCXXABI.cpp @@ -334,10 +334,11 @@ class MicrosoftCXXABI : public CGCXXABI { Address This, llvm::Type *Ty, SourceLocation Loc) override; - llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF, - const CXXDestructorDecl *Dtor, - CXXDtorType DtorType, Address This, - DeleteOrMemberCallExpr E) override; + llvm::Value * + EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, + CXXDtorType DtorType, Address This, + DeleteOrMemberCallExpr E, + llvm::CallBase **CallOrInvoke) override; void adjustCallArgsForDestructorThunk(CodeGenFunction &CGF, GlobalDecl GD, CallArgList &CallArgs) override { @@ -901,7 +902,8 @@ void MicrosoftCXXABI::emitVirtualObjectDelete(CodeGenFunction &CGF, // CXXMemberCallExpr for dtor call. bool UseGlobalDelete = DE->isGlobalDelete(); CXXDtorType DtorType = UseGlobalDelete ? Dtor_Complete : Dtor_Deleting; - llvm::Value *MDThis = EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE); + llvm::Value *MDThis = EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE, + /*CallOrInvoke=*/nullptr); if (UseGlobalDelete) CGF.EmitDeleteCall(DE->getOperatorDelete(), MDThis, ElementType); } @@ -1685,7 +1687,7 @@ void MicrosoftCXXABI::EmitDestructorCall(CodeGenFunction &CGF, CGF.EmitCXXDestructorCall(GD, Callee, CGF.getAsNaturalPointerTo(This, ThisTy), ThisTy, /*ImplicitParam=*/Implicit, - /*ImplicitParamTy=*/QualType(), nullptr); + /*ImplicitParamTy=*/QualType(), /*E=*/nullptr); if (BaseDtorEndBB) { // Complete object handler should continue to be the remaining CGF.Builder.CreateBr(BaseDtorEndBB); @@ -2001,7 +2003,7 @@ CGCallee MicrosoftCXXABI::getVirtualFunctionPointer(CodeGenFunction &CGF, llvm::Value *MicrosoftCXXABI::EmitVirtualDestructorCall( CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, CXXDtorType DtorType, - Address This, DeleteOrMemberCallExpr E) { + Address This, DeleteOrMemberCallExpr E, llvm::CallBase **CallOrInvoke) { auto *CE = E.dyn_cast(); auto *D = E.dyn_cast(); assert((CE != nullptr) ^ (D != nullptr)); @@ -2031,7 +2033,7 @@ llvm::Value *MicrosoftCXXABI::EmitVirtualDestructorCall( This = adjustThisArgumentForVirtualFunctionCall(CGF, GD, This, true); RValue RV = CGF.EmitCXXDestructorCall(GD, Callee, This.emitRawPointer(CGF), ThisTy, - ImplicitParam, Context.IntTy, CE); + ImplicitParam, Context.IntTy, CE, CallOrInvoke); return RV.getScalarVal(); } diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp index 1bb8955f6f879..a574d56646f3a 100644 --- a/clang/lib/Sema/SemaCoroutine.cpp +++ b/clang/lib/Sema/SemaCoroutine.cpp @@ -844,6 +844,19 @@ ExprResult Sema::BuildOperatorCoawaitLookupExpr(Scope *S, SourceLocation Loc) { return CoawaitOp; } +static bool isAttributedCoroAwaitElidable(const QualType &QT) { + auto *Record = QT->getAsCXXRecordDecl(); + return Record && Record->hasAttr(); +} + +static bool isCoroAwaitElidableCall(Expr *Operand) { + if (!Operand->isPRValue()) { + return false; + } + + return isAttributedCoroAwaitElidable(Operand->getType()); +} + // Attempts to resolve and build a CoawaitExpr from "raw" inputs, bailing out to // DependentCoawaitExpr if needed. ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand, @@ -867,7 +880,16 @@ ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand, } auto *RD = Promise->getType()->getAsCXXRecordDecl(); - auto *Transformed = Operand; + bool AwaitElidable = + isCoroAwaitElidableCall(Operand) && + isAttributedCoroAwaitElidable( + getCurFunctionDecl(/*AllowLambda=*/true)->getReturnType()); + + if (AwaitElidable) + if (auto *Call = dyn_cast(Operand->IgnoreImplicit())) + Call->setCoroElideSafe(); + + Expr *Transformed = Operand; if (lookupMember(*this, "await_transform", RD, Loc)) { ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", Operand); diff --git a/clang/test/CodeGenCoroutines/Inputs/utility.h b/clang/test/CodeGenCoroutines/Inputs/utility.h new file mode 100644 index 0000000000000..43c6d27823bd4 --- /dev/null +++ b/clang/test/CodeGenCoroutines/Inputs/utility.h @@ -0,0 +1,13 @@ +// This is a mock file for + +namespace std { + +template struct remove_reference { using type = T; }; +template struct remove_reference { using type = T; }; +template struct remove_reference { using type = T; }; + +template +constexpr typename std::remove_reference::type&& move(T &&t) noexcept { + return static_cast::type &&>(t); +} +} diff --git a/clang/test/CodeGenCoroutines/coro-await-elidable.cpp b/clang/test/CodeGenCoroutines/coro-await-elidable.cpp new file mode 100644 index 0000000000000..8512995dfad45 --- /dev/null +++ b/clang/test/CodeGenCoroutines/coro-await-elidable.cpp @@ -0,0 +1,87 @@ +// This file tests the coro_await_elidable attribute semantics. +// RUN: %clang_cc1 -triple=x86_64-unknown-linux-gnu -std=c++20 -disable-llvm-passes -emit-llvm %s -o - | FileCheck %s + +#include "Inputs/coroutine.h" +#include "Inputs/utility.h" + +template +struct [[clang::coro_await_elidable]] Task { + struct promise_type { + struct FinalAwaiter { + bool await_ready() const noexcept { return false; } + + template + std::coroutine_handle<> await_suspend(std::coroutine_handle

coro) noexcept { + if (!coro) + return std::noop_coroutine(); + return coro.promise().continuation; + } + void await_resume() noexcept {} + }; + + Task get_return_object() noexcept { + return std::coroutine_handle::from_promise(*this); + } + + std::suspend_always initial_suspend() noexcept { return {}; } + FinalAwaiter final_suspend() noexcept { return {}; } + void unhandled_exception() noexcept {} + void return_value(T x) noexcept { + value = x; + } + + std::coroutine_handle<> continuation; + T value; + }; + + Task(std::coroutine_handle handle) : handle(handle) {} + ~Task() { + if (handle) + handle.destroy(); + } + + struct Awaiter { + Awaiter(Task *t) : task(t) {} + bool await_ready() const noexcept { return false; } + void await_suspend(std::coroutine_handle continuation) noexcept {} + T await_resume() noexcept { + return task->handle.promise().value; + } + + Task *task; + }; + + auto operator co_await() { + return Awaiter{this}; + } + +private: + std::coroutine_handle handle; +}; + +// CHECK-LABEL: define{{.*}} @_Z6calleev{{.*}} { +Task callee() { + co_return 1; +} + +// CHECK-LABEL: define{{.*}} @_Z8elidablev{{.*}} { +Task elidable() { + // CHECK: %[[TASK_OBJ:.+]] = alloca %struct.Task + // CHECK: call void @_Z6calleev(ptr dead_on_unwind writable sret(%struct.Task) align 8 %[[TASK_OBJ]]) #[[ELIDE_SAFE:.+]] + co_return co_await callee(); +} + +// CHECK-LABEL: define{{.*}} @_Z11nonelidablev{{.*}} { +Task nonelidable() { + // CHECK: %[[TASK_OBJ:.+]] = alloca %struct.Task + auto t = callee(); + // Because we aren't co_awaiting a prvalue, we cannot elide here. + // CHECK: call void @_Z6calleev(ptr dead_on_unwind writable sret(%struct.Task) align 8 %[[TASK_OBJ]]) + // CHECK-NOT: #[[ELIDE_SAFE]] + co_await t; + co_await std::move(t); + + co_return 1; +} + +// CHECK: attributes #[[ELIDE_SAFE]] = { coro_elide_safe } diff --git a/clang/test/Misc/pragma-attribute-supported-attributes-list.test b/clang/test/Misc/pragma-attribute-supported-attributes-list.test index eca8633114902..baa1816358b15 100644 --- a/clang/test/Misc/pragma-attribute-supported-attributes-list.test +++ b/clang/test/Misc/pragma-attribute-supported-attributes-list.test @@ -59,6 +59,7 @@ // CHECK-NEXT: ConsumableAutoCast (SubjectMatchRule_record) // CHECK-NEXT: ConsumableSetOnRead (SubjectMatchRule_record) // CHECK-NEXT: Convergent (SubjectMatchRule_function) +// CHECK-NEXT: CoroAwaitElidable (SubjectMatchRule_record) // CHECK-NEXT: CoroDisableLifetimeBound (SubjectMatchRule_function) // CHECK-NEXT: CoroLifetimeBound (SubjectMatchRule_record) // CHECK-NEXT: CoroOnlyDestroyWhenComplete (SubjectMatchRule_record) diff --git a/llvm/docs/Coroutines.rst b/llvm/docs/Coroutines.rst index 36092325e536f..5679aefcb421d 100644 --- a/llvm/docs/Coroutines.rst +++ b/llvm/docs/Coroutines.rst @@ -2022,6 +2022,12 @@ The pass CoroSplit builds coroutine frame and outlines resume and destroy parts into separate functions. This pass also lowers `coro.await.suspend.void`_, `coro.await.suspend.bool`_ and `coro.await.suspend.handle`_ intrinsics. +CoroAnnotationElide +------------------- +This pass finds all usages of coroutines that are "must elide" and replaces +`coro.begin` intrinsic with an address of a coroutine frame placed on its caller +and replaces `coro.alloc` and `coro.free` intrinsics with `false` and `null` +respectively to remove the deallocation code. CoroElide --------- @@ -2049,6 +2055,18 @@ the coroutine must reach the final suspend point when it get destroyed. This attribute only works for switched-resume coroutines now. +coro_elide_safe +--------------- + +When a Call or Invoke instruction to switch ABI coroutine `f` is marked with +`coro_elide_safe`, CoroSplitPass generates a `f.noalloc` ramp function. +`f.noalloc` has one more argument than its original ramp function `f`, which is +the pointer to the allocated frame. `f.noalloc` also suppressed any allocations +or deallocations that may be guarded by `@llvm.coro.alloc` and `@llvm.coro.free`. + +CoroAnnotationElidePass performs the heap elision when possible. Note that for +recursive or mutually recursive functions this elision is usually not possible. + Metadata ======== diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h index 49a48f1c1510c..05ed148148d7c 100644 --- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h +++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h @@ -761,6 +761,8 @@ enum AttributeKindCodes { ATTR_KIND_INITIALIZES = 94, ATTR_KIND_HYBRID_PATCHABLE = 95, ATTR_KIND_SANITIZE_REALTIME = 96, + ATTR_KIND_NO_SANITIZE_REALTIME = 97, + ATTR_KIND_CORO_ELIDE_SAFE = 98, }; enum ComdatSelectionKindCodes { diff --git a/llvm/include/llvm/IR/Attributes.td b/llvm/include/llvm/IR/Attributes.td index 891e34fec0c79..f3ef1e707675e 100644 --- a/llvm/include/llvm/IR/Attributes.td +++ b/llvm/include/llvm/IR/Attributes.td @@ -345,6 +345,10 @@ def PresplitCoroutine : EnumAttr<"presplitcoroutine", [FnAttr]>; /// The coroutine would only be destroyed when it is complete. def CoroDestroyOnlyWhenComplete : EnumAttr<"coro_only_destroy_when_complete", [FnAttr]>; +/// The coroutine call meets the elide requirement. Hint the optimization +/// pipeline to perform elide on the call or invoke instruction. +def CoroElideSafe : EnumAttr<"coro_elide_safe", [FnAttr]>; + /// Target-independent string attributes. def LessPreciseFPMAD : StrBoolAttr<"less-precise-fpmad">; def NoInfsFPMath : StrBoolAttr<"no-infs-fp-math">; diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h b/llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h new file mode 100644 index 0000000000000..352c9e1452669 --- /dev/null +++ b/llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h @@ -0,0 +1,36 @@ +//===- CoroAnnotationElide.h - Elide attributed safe coroutine calls ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// \file +// This pass transforms all Call or Invoke instructions that are annotated +// "coro_elide_safe" to call the `.noalloc` variant of coroutine instead. +// The frame of the callee coroutine is allocated inside the caller. A pointer +// to the allocated frame will be passed into the `.noalloc` ramp function. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H +#define LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H + +#include "llvm/Analysis/CGSCCPassManager.h" +#include "llvm/Analysis/LazyCallGraph.h" +#include "llvm/IR/PassManager.h" + +namespace llvm { + +struct CoroAnnotationElidePass : PassInfoMixin { + CoroAnnotationElidePass() {} + + PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, + LazyCallGraph &CG, CGSCCUpdateResult &UR); + + static bool isRequired() { return false; } +}; +} // end namespace llvm + +#endif // LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp index 4faee1fb9ada2..6b8edbca19a01 100644 --- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp +++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp @@ -2190,6 +2190,8 @@ static Attribute::AttrKind getAttrFromCode(uint64_t Code) { return Attribute::Range; case bitc::ATTR_KIND_INITIALIZES: return Attribute::Initializes; + case bitc::ATTR_KIND_CORO_ELIDE_SAFE: + return Attribute::CoroElideSafe; } } diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp index 1aeaf0955fbbf..a5942153dc2d6 100644 --- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp +++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp @@ -885,6 +885,8 @@ static uint64_t getAttrKindEncoding(Attribute::AttrKind Kind) { return bitc::ATTR_KIND_WRITABLE; case Attribute::CoroDestroyOnlyWhenComplete: return bitc::ATTR_KIND_CORO_ONLY_DESTROY_WHEN_COMPLETE; + case Attribute::CoroElideSafe: + return bitc::ATTR_KIND_CORO_ELIDE_SAFE; case Attribute::DeadOnUnwind: return bitc::ATTR_KIND_DEAD_ON_UNWIND; case Attribute::Range: diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp index 83c1a6712bf4d..c34f9148cce58 100644 --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -139,6 +139,7 @@ #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h" #include "llvm/Transforms/CFGuard.h" +#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h" #include "llvm/Transforms/Coroutines/CoroCleanup.h" #include "llvm/Transforms/Coroutines/CoroConditionalWrapper.h" #include "llvm/Transforms/Coroutines/CoroEarly.h" diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp index 7f9e1362e7ef2..4e8e3dcdff442 100644 --- a/llvm/lib/Passes/PassBuilderPipelines.cpp +++ b/llvm/lib/Passes/PassBuilderPipelines.cpp @@ -33,6 +33,7 @@ #include "llvm/Support/VirtualFileSystem.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h" +#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h" #include "llvm/Transforms/Coroutines/CoroCleanup.h" #include "llvm/Transforms/Coroutines/CoroConditionalWrapper.h" #include "llvm/Transforms/Coroutines/CoroEarly.h" @@ -973,8 +974,10 @@ PassBuilder::buildInlinerPipeline(OptimizationLevel Level, MainCGPipeline.addPass(createCGSCCToFunctionPassAdaptor( RequireAnalysisPass())); - if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink) + if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink) { MainCGPipeline.addPass(CoroSplitPass(Level != OptimizationLevel::O0)); + MainCGPipeline.addPass(CoroAnnotationElidePass()); + } // Make sure we don't affect potential future NoRerun CGSCC adaptors. MIWP.addLateModulePass(createModuleToFunctionPassAdaptor( @@ -1016,9 +1019,12 @@ PassBuilder::buildModuleInlinerPipeline(OptimizationLevel Level, buildFunctionSimplificationPipeline(Level, Phase), PTO.EagerlyInvalidateAnalyses)); - if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink) + if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink) { MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor( CoroSplitPass(Level != OptimizationLevel::O0))); + MPM.addPass( + createModuleToPostOrderCGSCCPassAdaptor(CoroAnnotationElidePass())); + } return MPM; } diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def index 1769c2496a7b8..4f5f680a6e953 100644 --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -243,6 +243,7 @@ CGSCC_PASS("attributor-light-cgscc", AttributorLightCGSCCPass()) CGSCC_PASS("invalidate", InvalidateAllAnalysesPass()) CGSCC_PASS("no-op-cgscc", NoOpCGSCCPass()) CGSCC_PASS("openmp-opt-cgscc", OpenMPOptCGSCCPass()) +CGSCC_PASS("coro-annotation-elide", CoroAnnotationElidePass()) #undef CGSCC_PASS #ifndef CGSCC_PASS_WITH_PARAMS diff --git a/llvm/lib/Transforms/Coroutines/CMakeLists.txt b/llvm/lib/Transforms/Coroutines/CMakeLists.txt index 2139446e5ff95..b4b5812d97d89 100644 --- a/llvm/lib/Transforms/Coroutines/CMakeLists.txt +++ b/llvm/lib/Transforms/Coroutines/CMakeLists.txt @@ -1,5 +1,6 @@ add_llvm_component_library(LLVMCoroutines Coroutines.cpp + CoroAnnotationElide.cpp CoroCleanup.cpp CoroConditionalWrapper.cpp CoroEarly.cpp diff --git a/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp b/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp new file mode 100644 index 0000000000000..27a370a5d8fbf --- /dev/null +++ b/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp @@ -0,0 +1,155 @@ +//===- CoroAnnotationElide.cpp - Elide attributed safe coroutine calls ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// \file +// This pass transforms all Call or Invoke instructions that are annotated +// "coro_elide_safe" to call the `.noalloc` variant of coroutine instead. +// The frame of the callee coroutine is allocated inside the caller. A pointer +// to the allocated frame will be passed into the `.noalloc` ramp function. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h" + +#include "llvm/Analysis/CGSCCPassManager.h" +#include "llvm/Analysis/LazyCallGraph.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/IR/Analysis.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Transforms/Utils/CallGraphUpdater.h" + +#include + +using namespace llvm; + +#define DEBUG_TYPE "coro-annotation-elide" + +static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) { + for (Instruction &I : F->getEntryBlock()) + if (!isa(&I)) + return &I; + llvm_unreachable("no terminator in the entry block"); +} + +// Create an alloca in the caller, using FrameSize and FrameAlign as the callee +// coroutine's activation frame. +static Value *allocateFrameInCaller(Function *Caller, uint64_t FrameSize, + Align FrameAlign) { + LLVMContext &C = Caller->getContext(); + BasicBlock::iterator InsertPt = + getFirstNonAllocaInTheEntryBlock(Caller)->getIterator(); + const DataLayout &DL = Caller->getDataLayout(); + auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize); + auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt); + Frame->setAlignment(FrameAlign); + return Frame; +} + +// Given a call or invoke instruction to the elide safe coroutine, this function +// does the following: +// - Allocate a frame for the callee coroutine in the caller using alloca. +// - Replace the old CB with a new Call or Invoke to `NewCallee`, with the +// pointer to the frame as an additional argument to NewCallee. +static void processCall(CallBase *CB, Function *Caller, Function *NewCallee, + uint64_t FrameSize, Align FrameAlign) { + // TODO: generate the lifetime intrinsics for the new frame. This will require + // introduction of two pesudo lifetime intrinsics in the frontend around the + // `co_await` expression and convert them to real lifetime intrinsics here. + auto *FramePtr = allocateFrameInCaller(Caller, FrameSize, FrameAlign); + auto NewCBInsertPt = CB->getIterator(); + llvm::CallBase *NewCB = nullptr; + SmallVector NewArgs; + NewArgs.append(CB->arg_begin(), CB->arg_end()); + NewArgs.push_back(FramePtr); + + if (auto *CI = dyn_cast(CB)) { + auto *NewCI = CallInst::Create(NewCallee->getFunctionType(), NewCallee, + NewArgs, "", NewCBInsertPt); + NewCI->setTailCallKind(CI->getTailCallKind()); + NewCB = NewCI; + } else if (auto *II = dyn_cast(CB)) { + NewCB = InvokeInst::Create(NewCallee->getFunctionType(), NewCallee, + II->getNormalDest(), II->getUnwindDest(), + NewArgs, std::nullopt, "", NewCBInsertPt); + } else { + llvm_unreachable("CallBase should either be Call or Invoke!"); + } + + NewCB->setCalledFunction(NewCallee->getFunctionType(), NewCallee); + NewCB->setCallingConv(CB->getCallingConv()); + NewCB->setAttributes(CB->getAttributes()); + NewCB->setDebugLoc(CB->getDebugLoc()); + std::copy(CB->bundle_op_info_begin(), CB->bundle_op_info_end(), + NewCB->bundle_op_info_begin()); + + NewCB->removeFnAttr(llvm::Attribute::CoroElideSafe); + CB->replaceAllUsesWith(NewCB); + CB->eraseFromParent(); +} + +PreservedAnalyses CoroAnnotationElidePass::run(LazyCallGraph::SCC &C, + CGSCCAnalysisManager &AM, + LazyCallGraph &CG, + CGSCCUpdateResult &UR) { + bool Changed = false; + CallGraphUpdater CGUpdater; + CGUpdater.initialize(CG, C, AM, UR); + + auto &FAM = + AM.getResult(C, CG).getManager(); + + for (LazyCallGraph::Node &N : C) { + Function *Callee = &N.getFunction(); + Function *NewCallee = Callee->getParent()->getFunction( + (Callee->getName() + ".noalloc").str()); + if (!NewCallee) { + continue; + } + + auto FramePtrArgPosition = NewCallee->arg_size() - 1; + auto FrameSize = + NewCallee->getParamDereferenceableBytes(FramePtrArgPosition); + auto FrameAlign = + NewCallee->getParamAlign(FramePtrArgPosition).valueOrOne(); + + SmallVector Users; + for (auto *U : Callee->users()) { + if (auto *CB = dyn_cast(U)) { + if (CB->getCalledFunction() == Callee) + Users.push_back(CB); + } + } + + auto &ORE = FAM.getResult(*Callee); + + for (auto *CB : Users) { + auto *Caller = CB->getFunction(); + if (Caller && Caller->isPresplitCoroutine() && + CB->hasFnAttr(llvm::Attribute::CoroElideSafe)) { + + auto *CallerN = CG.lookup(*Caller); + auto *CallerC = CG.lookupSCC(*CallerN); + processCall(CB, Caller, NewCallee, FrameSize, FrameAlign); + + ORE.emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "CoroAnnotationElide", Caller) + << "'" << ore::NV("callee", Callee->getName()) + << "' elided in '" << ore::NV("caller", Caller->getName()); + }); + Changed = true; + updateCGAndAnalysisManagerForCGSCCPass(CG, *CallerC, *CallerN, AM, UR, + FAM); + } + } + } + return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h index d535ad7f85d74..be86f96525b67 100644 --- a/llvm/lib/Transforms/Coroutines/CoroInternal.h +++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h @@ -26,6 +26,13 @@ bool declaresIntrinsics(const Module &M, const std::initializer_list); void replaceCoroFree(CoroIdInst *CoroId, bool Elide); +/// Replaces all @llvm.coro.alloc intrinsics calls associated with a given +/// call @llvm.coro.id instruction with boolean value false. +void suppressCoroAllocs(CoroIdInst *CoroId); +/// Replaces CoroAllocs with boolean value false. +void suppressCoroAllocs(LLVMContext &Context, + ArrayRef CoroAllocs); + /// Attempts to rewrite the location operand of debug intrinsics in terms of /// the coroutine frame pointer, folding pointer offsets into the DIExpression /// of the intrinsic. diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp index 6bf3c75b95113..494c4d632de95 100644 --- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -25,6 +25,7 @@ #include "llvm/ADT/PriorityWorklist.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/CFG.h" @@ -1177,6 +1178,14 @@ static void updateAsyncFuncPointerContextSize(coro::Shape &Shape) { Shape.AsyncLowering.AsyncFuncPointer->setInitializer(NewFuncPtrStruct); } +static TypeSize getFrameSizeForShape(coro::Shape &Shape) { + // In the same function all coro.sizes should have the same result type. + auto *SizeIntrin = Shape.CoroSizes.back(); + Module *M = SizeIntrin->getModule(); + const DataLayout &DL = M->getDataLayout(); + return DL.getTypeAllocSize(Shape.FrameTy); +} + static void replaceFrameSizeAndAlignment(coro::Shape &Shape) { if (Shape.ABI == coro::ABI::Async) updateAsyncFuncPointerContextSize(Shape); @@ -1192,10 +1201,8 @@ static void replaceFrameSizeAndAlignment(coro::Shape &Shape) { // In the same function all coro.sizes should have the same result type. auto *SizeIntrin = Shape.CoroSizes.back(); - Module *M = SizeIntrin->getModule(); - const DataLayout &DL = M->getDataLayout(); - auto Size = DL.getTypeAllocSize(Shape.FrameTy); - auto *SizeConstant = ConstantInt::get(SizeIntrin->getType(), Size); + auto *SizeConstant = + ConstantInt::get(SizeIntrin->getType(), getFrameSizeForShape(Shape)); for (CoroSizeInst *CS : Shape.CoroSizes) { CS->replaceAllUsesWith(SizeConstant); @@ -1452,6 +1459,75 @@ struct SwitchCoroutineSplitter { setCoroInfo(F, Shape, Clones); } + // Create a variant of ramp function that does not perform heap allocation + // for a switch ABI coroutine. + // + // The newly split `.noalloc` ramp function has the following differences: + // - Has one additional frame pointer parameter in lieu of dynamic + // allocation. + // - Suppressed allocations by replacing coro.alloc and coro.free. + static Function *createNoAllocVariant(Function &F, coro::Shape &Shape, + SmallVectorImpl &Clones) { + assert(Shape.ABI == coro::ABI::Switch); + auto *OrigFnTy = F.getFunctionType(); + auto OldParams = OrigFnTy->params(); + + SmallVector NewParams; + NewParams.reserve(OldParams.size() + 1); + NewParams.append(OldParams.begin(), OldParams.end()); + NewParams.push_back(PointerType::getUnqual(Shape.FrameTy)); + + auto *NewFnTy = FunctionType::get(OrigFnTy->getReturnType(), NewParams, + OrigFnTy->isVarArg()); + Function *NoAllocF = + Function::Create(NewFnTy, F.getLinkage(), F.getName() + ".noalloc"); + + ValueToValueMapTy VMap; + unsigned int Idx = 0; + for (const auto &I : F.args()) { + VMap[&I] = NoAllocF->getArg(Idx++); + } + // We just appended the frame pointer as the last argument of the new + // function. + auto FrameIdx = NoAllocF->arg_size() - 1; + SmallVector Returns; + CloneFunctionInto(NoAllocF, &F, VMap, + CloneFunctionChangeType::LocalChangesOnly, Returns); + + if (Shape.CoroBegin) { + auto *NewCoroBegin = + cast_if_present(VMap[Shape.CoroBegin]); + auto *NewCoroId = cast(NewCoroBegin->getId()); + coro::replaceCoroFree(NewCoroId, /*Elide=*/true); + coro::suppressCoroAllocs(NewCoroId); + NewCoroBegin->replaceAllUsesWith(NoAllocF->getArg(FrameIdx)); + NewCoroBegin->eraseFromParent(); + } + + Module *M = F.getParent(); + M->getFunctionList().insert(M->end(), NoAllocF); + + removeUnreachableBlocks(*NoAllocF); + auto NewAttrs = NoAllocF->getAttributes(); + // When we elide allocation, we read these attributes to determine the + // frame size and alignment. + addFramePointerAttrs(NewAttrs, NoAllocF->getContext(), FrameIdx, + Shape.FrameSize, Shape.FrameAlign, + /*NoAlias=*/false); + + NoAllocF->setAttributes(NewAttrs); + + Clones.push_back(NoAllocF); + // Reset the original function's coro info, make the new noalloc variant + // connected to the original ramp function. + setCoroInfo(F, Shape, Clones); + // After copying, set the linkage to internal linkage. Original function + // may have different linkage, but optimization dependent on this function + // generally relies on LTO. + NoAllocF->setLinkage(llvm::GlobalValue::InternalLinkage); + return NoAllocF; + } + private: // Create a resume clone by cloning the body of the original function, setting // new entry block and replacing coro.suspend an appropriate value to force @@ -1910,6 +1986,33 @@ class PrettyStackTraceFunction : public PrettyStackTraceEntry { }; } // namespace +/// Remove calls to llvm.coro.end in the original function. +static void removeCoroEndsFromRampFunction(const coro::Shape &Shape) { + if (Shape.ABI != coro::ABI::Switch) { + for (auto *End : Shape.CoroEnds) { + replaceCoroEnd(End, Shape, Shape.FramePtr, /*in resume*/ false, nullptr); + } + } else { + for (llvm::AnyCoroEndInst *End : Shape.CoroEnds) { + auto &Context = End->getContext(); + End->replaceAllUsesWith(ConstantInt::getFalse(Context)); + End->eraseFromParent(); + } + } +} + +static bool hasSafeElideCaller(Function &F) { + for (auto *U : F.users()) { + if (auto *CB = dyn_cast(U)) { + auto *Caller = CB->getFunction(); + if (Caller && Caller->isPresplitCoroutine() && + CB->hasFnAttr(llvm::Attribute::CoroElideSafe)) + return true; + } + } + return false; +} + static coro::Shape splitCoroutine(Function &F, SmallVectorImpl &Clones, TargetTransformInfo &TTI, bool OptimizeFrame, @@ -1929,10 +2032,15 @@ splitCoroutine(Function &F, SmallVectorImpl &Clones, simplifySuspendPoints(Shape); buildCoroutineFrame(F, Shape, TTI, MaterializableCallback); replaceFrameSizeAndAlignment(Shape); + bool isNoSuspendCoroutine = Shape.CoroSuspends.empty(); + + bool shouldCreateNoAllocVariant = !isNoSuspendCoroutine && + Shape.ABI == coro::ABI::Switch && + hasSafeElideCaller(F); // If there are no suspend points, no split required, just remove // the allocation and deallocation blocks, they are not needed. - if (Shape.CoroSuspends.empty()) { + if (isNoSuspendCoroutine) { handleNoSuspendCoroutine(Shape); } else { switch (Shape.ABI) { @@ -1962,22 +2070,13 @@ splitCoroutine(Function &F, SmallVectorImpl &Clones, coro::salvageDebugInfo(ArgToAllocaMap, *DDI, false /*UseEntryValue*/); for (DbgVariableRecord *DVR : DbgVariableRecords) coro::salvageDebugInfo(ArgToAllocaMap, *DVR, false /*UseEntryValue*/); - return Shape; -} -/// Remove calls to llvm.coro.end in the original function. -static void removeCoroEndsFromRampFunction(const coro::Shape &Shape) { - if (Shape.ABI != coro::ABI::Switch) { - for (auto *End : Shape.CoroEnds) { - replaceCoroEnd(End, Shape, Shape.FramePtr, /*in resume*/ false, nullptr); - } - } else { - for (llvm::AnyCoroEndInst *End : Shape.CoroEnds) { - auto &Context = End->getContext(); - End->replaceAllUsesWith(ConstantInt::getFalse(Context)); - End->eraseFromParent(); - } - } + removeCoroEndsFromRampFunction(Shape); + + if (shouldCreateNoAllocVariant) + SwitchCoroutineSplitter::createNoAllocVariant(F, Shape, Clones); + + return Shape; } static void updateCallGraphAfterCoroutineSplit( @@ -2108,13 +2207,12 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C, F.setSplittedCoroutine(); SmallVector Clones; - auto &ORE = FAM.getResult(F); - const coro::Shape Shape = + coro::Shape Shape = splitCoroutine(F, Clones, FAM.getResult(F), OptimizeFrame, MaterializableCallback); - removeCoroEndsFromRampFunction(Shape); updateCallGraphAfterCoroutineSplit(*N, Shape, Clones, C, CG, AM, UR, FAM); + auto &ORE = FAM.getResult(F); ORE.emit([&]() { return OptimizationRemark(DEBUG_TYPE, "CoroSplit", &F) << "Split '" << ore::NV("function", F.getName()) @@ -2130,9 +2228,9 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C, } } - for (auto *PrepareFn : PrepareFns) { - replaceAllPrepares(PrepareFn, CG, C); - } + for (auto *PrepareFn : PrepareFns) { + replaceAllPrepares(PrepareFn, CG, C); + } return PreservedAnalyses::none(); } diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp index 1a92bc1636257..be257339e0ac4 100644 --- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp +++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp @@ -145,6 +145,33 @@ void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) { } } +void coro::suppressCoroAllocs(CoroIdInst *CoroId) { + SmallVector CoroAllocs; + for (User *U : CoroId->users()) + if (auto *CA = dyn_cast(U)) + CoroAllocs.push_back(CA); + + if (CoroAllocs.empty()) + return; + + coro::suppressCoroAllocs(CoroId->getContext(), CoroAllocs); +} + +// Replacing llvm.coro.alloc with false will suppress dynamic +// allocation as it is expected for the frontend to generate the code that +// looks like: +// id = coro.id(...) +// mem = coro.alloc(id) ? malloc(coro.size()) : 0; +// coro.begin(id, mem) +void coro::suppressCoroAllocs(LLVMContext &Context, + ArrayRef CoroAllocs) { + auto *False = ConstantInt::getFalse(Context); + for (auto *CA : CoroAllocs) { + CA->replaceAllUsesWith(False); + CA->eraseFromParent(); + } +} + static void clear(coro::Shape &Shape) { Shape.CoroBegin = nullptr; Shape.CoroEnds.clear(); diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp index d378c6c3a4b01..895b588a9e5ac 100644 --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -916,6 +916,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::Memory: case Attribute::NoFPClass: case Attribute::CoroDestroyOnlyWhenComplete: + case Attribute::CoroElideSafe: continue; // Those attributes should be safe to propagate to the extracted function. case Attribute::AlwaysInline: diff --git a/llvm/test/Other/new-pm-defaults.ll b/llvm/test/Other/new-pm-defaults.ll index 588337c15625e..55dbdb1b8366d 100644 --- a/llvm/test/Other/new-pm-defaults.ll +++ b/llvm/test/Other/new-pm-defaults.ll @@ -226,6 +226,7 @@ ; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis ; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis ; CHECK-O-NEXT: Running pass: CoroSplitPass +; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass ; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis ; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis ; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis diff --git a/llvm/test/Other/new-pm-thinlto-postlink-defaults.ll b/llvm/test/Other/new-pm-thinlto-postlink-defaults.ll index 064362eabbf83..fcf84dc5e1105 100644 --- a/llvm/test/Other/new-pm-thinlto-postlink-defaults.ll +++ b/llvm/test/Other/new-pm-thinlto-postlink-defaults.ll @@ -153,6 +153,7 @@ ; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis ; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis ; CHECK-O-NEXT: Running pass: CoroSplitPass +; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass ; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis ; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis ; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis diff --git a/llvm/test/Other/new-pm-thinlto-postlink-pgo-defaults.ll b/llvm/test/Other/new-pm-thinlto-postlink-pgo-defaults.ll index 19a44867e434a..4d5b5e733a87c 100644 --- a/llvm/test/Other/new-pm-thinlto-postlink-pgo-defaults.ll +++ b/llvm/test/Other/new-pm-thinlto-postlink-pgo-defaults.ll @@ -137,6 +137,7 @@ ; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis ; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis ; CHECK-O-NEXT: Running pass: CoroSplitPass +; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass ; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis ; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis ; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis diff --git a/llvm/test/Other/new-pm-thinlto-postlink-samplepgo-defaults.ll b/llvm/test/Other/new-pm-thinlto-postlink-samplepgo-defaults.ll index 9c2025f7d1ec3..62b81ac7cad03 100644 --- a/llvm/test/Other/new-pm-thinlto-postlink-samplepgo-defaults.ll +++ b/llvm/test/Other/new-pm-thinlto-postlink-samplepgo-defaults.ll @@ -146,6 +146,7 @@ ; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis ; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis ; CHECK-O-NEXT: Running pass: CoroSplitPass +; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass ; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis ; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis ; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis diff --git a/llvm/test/Transforms/Coroutines/coro-split-00.ll b/llvm/test/Transforms/Coroutines/coro-split-00.ll index b35bd720b86f9..9909627e60597 100644 --- a/llvm/test/Transforms/Coroutines/coro-split-00.ll +++ b/llvm/test/Transforms/Coroutines/coro-split-00.ll @@ -32,6 +32,13 @@ suspend: ret ptr %hdl } +; Make a safe_elide call to f and CoroSplit should generate the .noalloc variant +define void @caller() presplitcoroutine { +entry: + %ptr = call ptr @f() #1 + ret void +} + ; CHECK-LABEL: @f() !func_sanitize !0 { ; CHECK: call ptr @malloc ; CHECK: @llvm.coro.begin(token %id, ptr %phi) @@ -63,6 +70,13 @@ suspend: ; CHECK-NOT: call void @free( ; CHECK: ret void +; CHECK-LABEL: @f.noalloc(ptr noundef nonnull align 8 dereferenceable(24) %{{.*}}) +; CHECK-NOT: call ptr @malloc +; CHECK: call void @print(i32 0) +; CHECK-NOT: call void @print(i32 1) +; CHECK-NOT: call void @free( +; CHECK: ret ptr %{{.*}} + declare ptr @llvm.coro.free(token, ptr) declare i32 @llvm.coro.size.i32() declare i8 @llvm.coro.suspend(token, i1) @@ -79,3 +93,4 @@ declare void @print(i32) declare void @free(ptr) willreturn allockind("free") "alloc-family"="malloc" !0 = !{i32 846595819, ptr null} +attributes #1 = { coro_elide_safe } diff --git a/llvm/test/Transforms/Coroutines/coro-transform-must-elide.ll b/llvm/test/Transforms/Coroutines/coro-transform-must-elide.ll new file mode 100644 index 0000000000000..a4e575f6c0381 --- /dev/null +++ b/llvm/test/Transforms/Coroutines/coro-transform-must-elide.ll @@ -0,0 +1,75 @@ +; Testing elide performed its job for calls to coroutines marked safe. +; RUN: opt < %s -S -passes='cgscc(coro-annotation-elide)' | FileCheck %s + +%struct.Task = type { ptr } + +declare void @print(i32) nounwind + +; resume part of the coroutine +define fastcc void @callee.resume(ptr dereferenceable(1)) { + tail call void @print(i32 0) + ret void +} + +; destroy part of the coroutine +define fastcc void @callee.destroy(ptr) { + tail call void @print(i32 1) + ret void +} + +; cleanup part of the coroutine +define fastcc void @callee.cleanup(ptr) { + tail call void @print(i32 2) + ret void +} + +@callee.resumers = internal constant [3 x ptr] [ + ptr @callee.resume, ptr @callee.destroy, ptr @callee.cleanup] + +declare void @alloc(i1) nounwind + +; CHECK-LABEL: define ptr @callee +define ptr @callee(i8 %arg) { +entry: + %task = alloca %struct.Task, align 8 + %id = call token @llvm.coro.id(i32 0, ptr null, + ptr @callee, + ptr @callee.resumers) + %alloc = call i1 @llvm.coro.alloc(token %id) + %hdl = call ptr @llvm.coro.begin(token %id, ptr null) + store ptr %hdl, ptr %task + ret ptr %task +} + +; CHECK-LABEL: define ptr @callee.noalloc +define ptr @callee.noalloc(i8 %arg, ptr dereferenceable(32) align(8) %frame) { + entry: + %task = alloca %struct.Task, align 8 + %id = call token @llvm.coro.id(i32 0, ptr null, + ptr @callee, + ptr @callee.resumers) + %hdl = call ptr @llvm.coro.begin(token %id, ptr null) + store ptr %hdl, ptr %task + ret ptr %task +} + +; CHECK-LABEL: define ptr @caller() +; Function Attrs: presplitcoroutine +define ptr @caller() #0 { +entry: + %task = call ptr @callee(i8 0) #1 + ret ptr %task + + ; CHECK: %[[FRAME:.+]] = alloca [32 x i8], align 8 + ; CHECK-NEXT: %[[TASK:.+]] = call ptr @callee.noalloc(i8 0, ptr %[[FRAME]]) + ; CHECK-NEXT: ret ptr %[[TASK]] +} + +declare token @llvm.coro.id(i32, ptr, ptr, ptr) +declare ptr @llvm.coro.begin(token, ptr) +declare ptr @llvm.coro.frame() +declare ptr @llvm.coro.subfn.addr(ptr, i8) +declare i1 @llvm.coro.alloc(token) + +attributes #0 = { presplitcoroutine } +attributes #1 = { coro_elide_safe }