diff --git a/include/swift/AST/Stmt.h b/include/swift/AST/Stmt.h index 513be8e40d774..7c8e535efbb2d 100644 --- a/include/swift/AST/Stmt.h +++ b/include/swift/AST/Stmt.h @@ -39,6 +39,8 @@ class Pattern; class PatternBindingDecl; class VarDecl; class CaseStmt; +class DoCatchStmt; +class SwitchStmt; enum class StmtKind { #define STMT(ID, PARENT) ID, @@ -927,6 +929,7 @@ class CaseStmt final CaseLabelItem> { friend TrailingObjects; + Stmt *ParentStmt = nullptr; SourceLoc UnknownAttrLoc; SourceLoc ItemIntroducerLoc; SourceLoc ItemTerminatorLoc; @@ -954,6 +957,14 @@ class CaseStmt final CaseParentKind getParentKind() const { return ParentKind; } + Stmt *getParentStmt() const { return ParentStmt; } + void setParentStmt(Stmt *S) { + assert(S && "Parent statement must be SwitchStmt or DoCatchStmt"); + assert((ParentKind == CaseParentKind::Switch && isa(S)) || + (ParentKind == CaseParentKind::DoCatch && isa(S))); + ParentStmt = S; + } + ArrayRef getCaseLabelItems() const { return {getTrailingObjects(), Bits.CaseStmt.NumPatterns}; } @@ -1161,6 +1172,8 @@ class DoCatchStmt final Bits.DoCatchStmt.NumCatches = catches.size(); std::uninitialized_copy(catches.begin(), catches.end(), getTrailingObjects()); + for (auto *catchStmt : getCatches()) + catchStmt->setParentStmt(this); } public: diff --git a/lib/AST/Stmt.cpp b/lib/AST/Stmt.cpp index d2ee30cedbcdb..c070a55d3cd8a 100644 --- a/lib/AST/Stmt.cpp +++ b/lib/AST/Stmt.cpp @@ -479,6 +479,9 @@ SwitchStmt *SwitchStmt::create(LabeledStmtInfo LabelInfo, SourceLoc SwitchLoc, std::uninitialized_copy(Cases.begin(), Cases.end(), theSwitch->getTrailingObjects()); + for (auto *caseStmt : theSwitch->getCases()) + caseStmt->setParentStmt(theSwitch); + return theSwitch; }