diff --git a/Package.swift b/Package.swift index 12b28d0bd..9a6fcbd5f 100644 --- a/Package.swift +++ b/Package.swift @@ -65,7 +65,7 @@ var dependencies: [Package.Dependency] { branch: "main"), .package( url: "https://github.com/swiftlang/swift-syntax", - from: "600.0.0") + branch: "main") ] } } diff --git a/Sources/FoundationMacros/PredicateMacro.swift b/Sources/FoundationMacros/PredicateMacro.swift index bd48a6aec..a4b8bb4ab 100644 --- a/Sources/FoundationMacros/PredicateMacro.swift +++ b/Sources/FoundationMacros/PredicateMacro.swift @@ -229,20 +229,25 @@ extension SyntaxProtocol { private protocol PredicateSyntaxRewriter : SyntaxRewriter { var success: Bool { get } + var ignorable: Bool { get } var diagnostics: [Diagnostic] { get } } extension PredicateSyntaxRewriter { var success: Bool { true } + var ignorable: Bool { false } var diagnostics: [Diagnostic] { [] } } extension SyntaxProtocol { fileprivate func rewrite(with rewriter: some PredicateSyntaxRewriter) throws -> Syntax { - let translated = rewriter.rewrite(Syntax(self)) + let translated = rewriter.rewrite(self) guard rewriter.success else { throw DiagnosticsError(diagnostics: rewriter.diagnostics) } + guard !rewriter.ignorable else { + return Syntax(self) + } return translated } } @@ -251,6 +256,7 @@ private class OptionalChainRewriter: SyntaxRewriter, PredicateSyntaxRewriter { var withinValidChainingTreeStart = true var withinChainingTree = false var optionalInput: ExprSyntax? = nil + var ignorable = true private func _prePossibleTopOfTree() -> Bool { if !withinChainingTree && withinValidChainingTreeStart { @@ -265,6 +271,7 @@ private class OptionalChainRewriter: SyntaxRewriter, PredicateSyntaxRewriter { withinChainingTree = false if let input = optionalInput { optionalInput = nil + ignorable = false let visited = self.visit(input) let closure = ClosureExprSyntax(statements: [CodeBlockItemSyntax(item: CodeBlockItemSyntax.Item(node))]) let functionMember = MemberAccessExprSyntax(base: visited, name: "flatMap") @@ -282,10 +289,14 @@ private class OptionalChainRewriter: SyntaxRewriter, PredicateSyntaxRewriter { // We're in the middle of a potential tree, so rewrite the closure with a fresh state // This ensures potential chaining in the closure isn't rewritten outside of the closure - guard let rewritten = (try? node.rewrite(with: OptionalChainRewriter()))?.as(ExprSyntax.self) else { + let nestedRewriter = OptionalChainRewriter() + guard let rewritten = (try? node.rewrite(with: nestedRewriter))?.as(ExprSyntax.self) else { // If rewriting the closure failed, just leave the closure as-is return ExprSyntax(node) } + if ignorable { + ignorable = nestedRewriter.ignorable + } return rewritten } diff --git a/Tests/FoundationMacrosTests/MacroTestUtilities.swift b/Tests/FoundationMacrosTests/MacroTestUtilities.swift index 285e70ce6..275e255c8 100644 --- a/Tests/FoundationMacrosTests/MacroTestUtilities.swift +++ b/Tests/FoundationMacrosTests/MacroTestUtilities.swift @@ -121,12 +121,14 @@ extension DiagnosticTest { } func AssertMacroExpansion(macros: [String : Macro.Type], testModuleName: String = "TestModule", testFileName: String = "test.swift", _ source: String, _ result: String = "", diagnostics: Set = [], file: StaticString = #filePath, line: UInt = #line) { - let context = BasicMacroExpansionContext() let origSourceFile = Parser.parse(source: source) let expandedSourceFile: Syntax + let context: BasicMacroExpansionContext do { - expandedSourceFile = try OperatorTable.standardOperators.foldAll(origSourceFile).expand(macros: macros) { syntax in - BasicMacroExpansionContext(sharingWith: context, lexicalContext: [syntax]) + let foldedSourceFile = try OperatorTable.standardOperators.foldAll(origSourceFile).cast(SourceFileSyntax.self) + context = BasicMacroExpansionContext(sourceFiles: [foldedSourceFile : .init(moduleName: testModuleName, fullFilePath: testFileName)]) + expandedSourceFile = foldedSourceFile.expand(macros: macros) { + BasicMacroExpansionContext(sharingWith: context, lexicalContext: [$0]) } } catch { XCTFail("Operator folding on input source failed with error \(error)") @@ -139,6 +141,9 @@ func AssertMacroExpansion(macros: [String : Macro.Type], testModuleName: String for diagnostic in context.diagnostics { if !diagnostics.contains(where: { $0.matches(diagnostic) }) { XCTFail("Produced extra diagnostic:\n\(diagnostic._assertionDescription)", file: file, line: line) + } else { + let location = context.location(of: diagnostic.node, at: .afterLeadingTrivia, filePathMode: .fileID) + XCTAssertNotNil(location, "Produced diagnostic without attached source information:\n\(diagnostic._assertionDescription)", file: file, line: line) } } for diagnostic in diagnostics {