Skip to content

Add AST representation for coroutines #78508

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 2 additions & 20 deletions include/swift/AST/AnyFunctionRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class AnyFunctionRef {
Type getBodyResultType() const {
if (auto *AFD = TheFunction.dyn_cast<AbstractFunctionDecl *>()) {
if (auto *FD = dyn_cast<FuncDecl>(AFD))
return FD->mapTypeIntoContext(FD->getResultInterfaceType());
return FD->mapTypeIntoContext(FD->getResultInterfaceTypeWithoutYields());
return TupleType::getEmpty(AFD->getASTContext());
}
return TheFunction.get<AbstractClosureExpr *>()->getResultType();
Expand Down Expand Up @@ -263,25 +263,7 @@ class AnyFunctionRef {
private:
ArrayRef<AnyFunctionType::Yield>
getYieldResultsImpl(SmallVectorImpl<AnyFunctionType::Yield> &buffer,
bool mapIntoContext) const {
assert(buffer.empty());
if (auto *AFD = TheFunction.dyn_cast<AbstractFunctionDecl *>()) {
if (auto *AD = dyn_cast<AccessorDecl>(AFD)) {
if (AD->isCoroutine()) {
auto valueTy = AD->getStorage()->getValueInterfaceType()
->getReferenceStorageReferent();
if (mapIntoContext)
valueTy = AD->mapTypeIntoContext(valueTy);
YieldTypeFlags flags(isYieldingMutableAccessor(AD->getAccessorKind())
? ParamSpecifier::InOut
: ParamSpecifier::LegacyShared);
buffer.push_back(AnyFunctionType::Yield(valueTy, flags));
return buffer;
}
}
}
return {};
}
bool mapIntoContext) const;
};
#if SWIFT_COMPILER_IS_MSVC
#pragma warning(pop)
Expand Down
10 changes: 9 additions & 1 deletion include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7713,6 +7713,8 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
/// attribute.
bool isTransparent() const;

bool isCoroutine() const;

// Expose our import as member status
ImportAsMemberStatus getImportAsMemberStatus() const {
return ImportAsMemberStatus(Bits.AbstractFunctionDecl.IAMStatus);
Expand Down Expand Up @@ -8331,9 +8333,15 @@ class FuncDecl : public AbstractFunctionDecl {
return FnRetType.getSourceRange();
}

/// Retrieve the result interface type of this function.
/// Retrieve the full result interface type of this function, including yields
Type getResultInterfaceType() const;

/// Same as above, but without @yields
Type getResultInterfaceTypeWithoutYields() const;

/// Same as above, but only yields
Type getYieldsInterfaceType() const;

/// isUnaryOperator - Determine whether this is a unary operator
/// implementation. This check is a syntactic rather than type-based check,
/// which looks at the number of parameters specified, in order to allow
Expand Down
6 changes: 5 additions & 1 deletion include/swift/AST/DeclAttr.def
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,11 @@ DECL_ATTR(abi, ABI,
165)
DECL_ATTR_FEATURE_REQUIREMENT(ABI, ABIAttribute)

LAST_DECL_ATTR(ABI)
SIMPLE_DECL_ATTR(yield_once, Coroutine,
OnFunc | UserInaccessible | ABIBreakingToAdd | ABIBreakingToRemove | APIBreakingToAdd | APIBreakingToRemove,
166)

LAST_DECL_ATTR(Coroutine)

#undef DECL_ATTR_ALIAS
#undef CONTEXTUAL_DECL_ATTR_ALIAS
Expand Down
28 changes: 24 additions & 4 deletions include/swift/AST/ExtInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ enum class SILFunctionTypeRepresentation : uint8_t {
CFunctionPointer = uint8_t(FunctionTypeRepresentation::CFunctionPointer),

/// The value of the greatest AST function representation.
LastAST = CFunctionPointer,
LastAST = uint8_t(FunctionTypeRepresentation::Last),

/// The value of the least SIL-only function representation.
FirstSIL = 8,
Expand Down Expand Up @@ -438,8 +438,8 @@ class ASTExtInfoBuilder {
// If bits are added or removed, then TypeBase::NumAFTExtInfoBits
// and NumMaskBits must be updated, and they must match.
//
// |representation|noEscape|concurrent|async|throws|isolation|differentiability| SendingResult |
// | 0 .. 3 | 4 | 5 | 6 | 7 | 8 .. 10 | 11 .. 13 | 14 |
// |representation|noEscape|concurrent|async|throws|isolation|differentiability| SendingResult | coroutine |
// | 0 .. 3 | 4 | 5 | 6 | 7 | 8 .. 10 | 11 .. 13 | 14 | 15 |
//
enum : unsigned {
RepresentationMask = 0xF << 0,
Expand All @@ -452,7 +452,8 @@ class ASTExtInfoBuilder {
DifferentiabilityMaskOffset = 11,
DifferentiabilityMask = 0x7 << DifferentiabilityMaskOffset,
SendingResultMask = 1 << 14,
NumMaskBits = 15
CoroutineMask = 1 << 15,
NumMaskBits = 16
};

static_assert(FunctionTypeIsolation::Mask == 0x7, "update mask manually");
Expand Down Expand Up @@ -531,6 +532,8 @@ class ASTExtInfoBuilder {

constexpr bool hasSendingResult() const { return bits & SendingResultMask; }

constexpr bool isCoroutine() const { return bits & CoroutineMask; }

constexpr DifferentiabilityKind getDifferentiabilityKind() const {
return DifferentiabilityKind((bits & DifferentiabilityMask) >>
DifferentiabilityMaskOffset);
Expand Down Expand Up @@ -647,6 +650,13 @@ class ASTExtInfoBuilder {
clangTypeInfo, globalActor, thrownError, lifetimeDependencies);
}

[[nodiscard]]
ASTExtInfoBuilder withCoroutine(bool coroutine = true) const {
return ASTExtInfoBuilder(
coroutine ? (bits | CoroutineMask) : (bits & ~CoroutineMask),
clangTypeInfo, globalActor, thrownError, lifetimeDependencies);
}

[[nodiscard]]
ASTExtInfoBuilder
withDifferentiabilityKind(DifferentiabilityKind differentiability) const {
Expand Down Expand Up @@ -762,6 +772,8 @@ class ASTExtInfo {

constexpr bool isThrowing() const { return builder.isThrowing(); }

constexpr bool isCoroutine() const { return builder.isCoroutine(); }

constexpr bool hasSendingResult() const { return builder.hasSendingResult(); }

constexpr DifferentiabilityKind getDifferentiabilityKind() const {
Expand Down Expand Up @@ -825,6 +837,14 @@ class ASTExtInfo {
return builder.withThrows(true, Type()).build();
}

/// Helper method for changing only the coroutine field.
///
/// Prefer using \c ASTExtInfoBuilder::withCoroutine for chaining.
[[nodiscard]]
ASTExtInfo withCoroutine(bool coroutine = true) const {
return builder.withCoroutine(coroutine).build();
}

/// Helper method for changing only the async field.
///
/// Prefer using \c ASTExtInfoBuilder::withAsync for chaining.
Expand Down
6 changes: 3 additions & 3 deletions include/swift/AST/TypeAttr.def
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ SIMPLE_TYPE_ATTR(_noMetadata, NoMetadata)
TYPE_ATTR(_opaqueReturnTypeOf, OpaqueReturnTypeOf)
TYPE_ATTR(isolated, Isolated)
SIMPLE_TYPE_ATTR(_addressable, Addressable)
SIMPLE_TYPE_ATTR(yields, Yields)
SIMPLE_TYPE_ATTR(yield_once, YieldOnce)
SIMPLE_TYPE_ATTR(yield_once_2, YieldOnce2)

// SIL-specific attributes
SIMPLE_SIL_TYPE_ATTR(async, Async)
Expand Down Expand Up @@ -102,9 +105,6 @@ SIL_TYPE_ATTR(opened, Opened)
SIL_TYPE_ATTR(pack_element, PackElement)
SIMPLE_SIL_TYPE_ATTR(pseudogeneric, Pseudogeneric)
SIMPLE_SIL_TYPE_ATTR(unimplementable, Unimplementable)
SIMPLE_SIL_TYPE_ATTR(yields, Yields)
SIMPLE_SIL_TYPE_ATTR(yield_once, YieldOnce)
SIMPLE_SIL_TYPE_ATTR(yield_once_2, YieldOnce2)
SIMPLE_SIL_TYPE_ATTR(yield_many, YieldMany)
SIMPLE_SIL_TYPE_ATTR(captures_generics, CapturesGenerics)
// Used at the SIL level to mark a type as moveOnly.
Expand Down
4 changes: 4 additions & 0 deletions include/swift/AST/TypeDifferenceVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ class CanTypeDifferenceVisitor : public CanTypePairVisitor<Impl, bool> {
type1->getElements(), type2->getElements());
}

bool visitYieldResultType(CanYieldResultType type1, CanYieldResultType type2) {
return asImpl().visit(type1.getResultType(), type2.getResultType());
}

bool visitComponent(CanType type1, CanType type2,
const TupleTypeElt &elt1, const TupleTypeElt &elt2) {
if (elt1.getName() != elt2.getName())
Expand Down
11 changes: 11 additions & 0 deletions include/swift/AST/TypeMatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,17 @@ class TypeMatcher {
return mismatch(firstTuple.getPointer(), secondType, sugaredFirstType);
}

bool visitYieldResultType(CanYieldResultType firstType, Type secondType,
Type sugaredFirstType) {
if (auto secondYieldType = secondType->getAs<YieldResultType>())
if (!this->visit(firstType.getResultType(),
secondYieldType->getResultType(),
sugaredFirstType->getAs<YieldResultType>()->getResultType()))
return false;

return mismatch(firstType.getPointer(), secondType, sugaredFirstType);
}

bool visitSILPackType(CanSILPackType firstPack, Type secondType,
Type sugaredFirstType) {
if (auto secondPack = secondType->getAs<SILPackType>()) {
Expand Down
1 change: 1 addition & 0 deletions include/swift/AST/TypeNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ TYPE(InOut, Type)
TYPE(Pack, Type)
TYPE(PackExpansion, Type)
TYPE(PackElement, Type)
TYPE(YieldResult, Type)
UNCHECKED_TYPE(TypeVariable, Type)
UNCHECKED_TYPE(ErrorUnion, Type)
ALWAYS_CANONICAL_TYPE(Integer, Type)
Expand Down
11 changes: 11 additions & 0 deletions include/swift/AST/TypeTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "swift/AST/GenericEnvironment.h"
#include "swift/AST/SILLayout.h"
#include "swift/AST/Types.h"

namespace swift {

Expand Down Expand Up @@ -955,6 +956,16 @@ case TypeKind::Id:
t : InOutType::get(objectTy);
}

case TypeKind::YieldResult: {
auto yield = cast<YieldResultType>(base);
auto objectTy = doIt(yield->getResultType(), TypePosition::Invariant);
if (!objectTy || objectTy->hasError())
return objectTy;

return objectTy.getPointer() == yield->getResultType().getPointer() ?
t : YieldResultType::get(objectTy, yield->isInOut());
}

case TypeKind::Existential: {
auto *existential = cast<ExistentialType>(base);
auto constraint = doIt(existential->getConstraintType(), pos);
Expand Down
41 changes: 38 additions & 3 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ class alignas(1 << TypeAlignInBits) TypeBase
}

protected:
enum { NumAFTExtInfoBits = 15 };
enum { NumAFTExtInfoBits = 16 };
enum { NumSILExtInfoBits = 14 };

// clang-format off
Expand All @@ -428,7 +428,7 @@ class alignas(1 << TypeAlignInBits) TypeBase
HasCachedType : 1
);

SWIFT_INLINE_BITFIELD_FULL(AnyFunctionType, TypeBase, NumAFTExtInfoBits+1+1+1+1+16,
SWIFT_INLINE_BITFIELD_FULL(AnyFunctionType, TypeBase, NumAFTExtInfoBits+1+1+1+1+15,
/// Extra information which affects how the function is called, like
/// regparm and the calling convention.
ExtInfoBits : NumAFTExtInfoBits,
Expand All @@ -437,7 +437,7 @@ class alignas(1 << TypeAlignInBits) TypeBase
HasThrownError : 1,
HasLifetimeDependencies : 1,
: NumPadBits,
NumParams : 16
NumParams : 15
);

SWIFT_INLINE_BITFIELD_FULL(ArchetypeType, TypeBase, 1+1+16,
Expand Down Expand Up @@ -1651,6 +1651,36 @@ class UnresolvedType : public TypeBase {
};
DEFINE_EMPTY_CAN_TYPE_WRAPPER(UnresolvedType, Type)

class YieldResultType : public TypeBase {
Type ResultType;
bool InOut = false;

YieldResultType(Type objectTy, bool InOut, const ASTContext *canonicalContext,
RecursiveTypeProperties properties)
: TypeBase(TypeKind::YieldResult, canonicalContext, properties),
ResultType(objectTy), InOut(InOut) {}

public:
static YieldResultType *get(Type originalType, bool InOut);

Type getResultType() const { return ResultType; }
bool isInOut() const { return InOut; }

// Implement isa/cast/dyncast/etc.
static bool classof(const TypeBase *T) {
return T->getKind() == TypeKind::YieldResult;
}
};

BEGIN_CAN_TYPE_WRAPPER(YieldResultType, Type)
PROXY_CAN_TYPE_SIMPLE_GETTER(getResultType)
bool isInOut() const {
return getPointer()->isInOut();
}
static CanYieldResultType get(CanType type, bool InOut) {
return CanYieldResultType(YieldResultType::get(type, InOut));
}
END_CAN_TYPE_WRAPPER(YieldResultType, Type)

/// BuiltinType - An abstract class for all the builtin types.
class BuiltinType : public TypeBase {
Expand Down Expand Up @@ -3798,6 +3828,9 @@ class AnyFunctionType : public TypeBase {
/// Return the function type without the throwing.
AnyFunctionType *getWithoutThrowing() const;

/// Return the function type without yields (and coroutine flag)
AnyFunctionType *getWithoutYields() const;

/// True if the parameter declaration it is attached to is guaranteed
/// to not persist the closure for longer than the duration of the call.
bool isNoEscape() const {
Expand All @@ -3810,6 +3843,8 @@ class AnyFunctionType : public TypeBase {

bool isThrowing() const { return getExtInfo().isThrowing(); }

bool isCoroutine() const { return getExtInfo().isCoroutine(); }

bool hasSendingResult() const { return getExtInfo().hasSendingResult(); }

bool hasEffect(EffectKind kind) const;
Expand Down
2 changes: 2 additions & 0 deletions include/swift/Demangling/DemangleNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,8 @@ NODE(DependentGenericInverseConformanceRequirement)
NODE(Integer)
NODE(NegativeInteger)
NODE(DependentGenericParamValueMarker)
NODE(YieldResult)
NODE(Coroutine)

#undef CONTEXT_NODE
#undef NODE
4 changes: 2 additions & 2 deletions include/swift/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -605,8 +605,8 @@ class Parser {

bool isContextualYieldKeyword() {
return (Tok.isContextualKeyword("yield") &&
isa<AccessorDecl>(CurDeclContext) &&
cast<AccessorDecl>(CurDeclContext)->isCoroutine());
(isa<AbstractFunctionDecl>(CurDeclContext) &&
cast<AbstractFunctionDecl>(CurDeclContext)->isCoroutine()));
}

/// Whether the current token is the contextual keyword for a \c then
Expand Down
2 changes: 1 addition & 1 deletion include/swift/SIL/AbstractionPattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -1526,7 +1526,7 @@ class AbstractionPattern {

/// Given that the value being abstracted is a function, return the
/// abstraction pattern for its result type.
AbstractionPattern getFunctionResultType() const;
AbstractionPattern getFunctionResultType(bool withoutYields = false) const;

/// Given that the value being abstracted is a function, return the
/// abstraction pattern for its thrown error type.
Expand Down
26 changes: 26 additions & 0 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,8 @@ struct ASTContext::Implementation {
llvm::DenseMap<uintptr_t, ReferenceStorageType*> ReferenceStorageTypes;
llvm::DenseMap<Type, LValueType*> LValueTypes;
llvm::DenseMap<Type, InOutType*> InOutTypes;
llvm::DenseMap<llvm::PointerIntPair<TypeBase*, 1, bool>,
YieldResultType*> YieldResultTypes;
llvm::DenseMap<std::pair<Type, void*>, DependentMemberType *>
DependentMemberTypes;
llvm::FoldingSet<ErrorUnionType> ErrorUnionTypes;
Expand Down Expand Up @@ -3247,6 +3249,7 @@ size_t ASTContext::Implementation::Arena::getTotalMemory() const {
llvm::capacity_in_bytes(ReferenceStorageTypes) +
llvm::capacity_in_bytes(LValueTypes) +
llvm::capacity_in_bytes(InOutTypes) +
llvm::capacity_in_bytes(YieldResultTypes) +
llvm::capacity_in_bytes(DependentMemberTypes) +
llvm::capacity_in_bytes(EnumTypes) +
llvm::capacity_in_bytes(StructTypes) +
Expand Down Expand Up @@ -3286,6 +3289,7 @@ void ASTContext::Implementation::Arena::dump(llvm::raw_ostream &os) const {
SIZE_AND_BYTES(ReferenceStorageTypes);
SIZE_AND_BYTES(LValueTypes);
SIZE_AND_BYTES(InOutTypes);
SIZE_AND_BYTES(YieldResultTypes);
SIZE_AND_BYTES(DependentMemberTypes);
SIZE(ErrorUnionTypes);
SIZE_AND_BYTES(PlaceholderTypes);
Expand Down Expand Up @@ -5477,6 +5481,28 @@ InOutType *InOutType::get(Type objectTy) {
properties);
}

YieldResultType *YieldResultType::get(Type objectTy, bool InOut) {
auto properties = objectTy->getRecursiveProperties();
if (InOut) {
assert(!objectTy->is<LValueType>() && !objectTy->is<InOutType>() &&
"cannot have 'inout' or @lvalue wrapped inside an 'inout yield'");
properties &= ~RecursiveTypeProperties::IsLValue;
}

auto arena = getArena(properties);

auto &C = objectTy->getASTContext();
auto pair = llvm::PointerIntPair<TypeBase*, 1, bool>(objectTy.getPointer(),
InOut);
auto &entry = C.getImpl().getArena(arena).YieldResultTypes[pair];
if (entry)
return entry;

const ASTContext *canonicalContext = objectTy->isCanonical() ? &C : nullptr;
return entry = new (C, arena) YieldResultType(objectTy, InOut, canonicalContext,
properties);
}

DependentMemberType *DependentMemberType::get(Type base, Identifier name) {
auto properties = base->getRecursiveProperties();
properties |= RecursiveTypeProperties::HasDependentMember;
Expand Down
Loading