Skip to content

Ensure predicate diagnostics contain source information when possible #1025

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ var dependencies: [Package.Dependency] {
branch: "main"),
.package(
url: "https://github.com/swiftlang/swift-syntax",
from: "600.0.0")
branch: "main")
]
}
}
Expand Down
15 changes: 13 additions & 2 deletions Sources/FoundationMacros/PredicateMacro.swift
Original file line number Diff line number Diff line change
Expand Up @@ -229,20 +229,25 @@ extension SyntaxProtocol {

private protocol PredicateSyntaxRewriter : SyntaxRewriter {
var success: Bool { get }
var ignorable: Bool { get }
var diagnostics: [Diagnostic] { get }
}

extension PredicateSyntaxRewriter {
var success: Bool { true }
var ignorable: Bool { false }
var diagnostics: [Diagnostic] { [] }
}

extension SyntaxProtocol {
fileprivate func rewrite(with rewriter: some PredicateSyntaxRewriter) throws -> Syntax {
let translated = rewriter.rewrite(Syntax(self))
let translated = rewriter.rewrite(self)
guard rewriter.success else {
throw DiagnosticsError(diagnostics: rewriter.diagnostics)
}
guard !rewriter.ignorable else {
return Syntax(self)
}
return translated
}
}
Expand All @@ -251,6 +256,7 @@ private class OptionalChainRewriter: SyntaxRewriter, PredicateSyntaxRewriter {
var withinValidChainingTreeStart = true
var withinChainingTree = false
var optionalInput: ExprSyntax? = nil
var ignorable = true

private func _prePossibleTopOfTree() -> Bool {
if !withinChainingTree && withinValidChainingTreeStart {
Expand All @@ -265,6 +271,7 @@ private class OptionalChainRewriter: SyntaxRewriter, PredicateSyntaxRewriter {
withinChainingTree = false
if let input = optionalInput {
optionalInput = nil
ignorable = false
let visited = self.visit(input)
let closure = ClosureExprSyntax(statements: [CodeBlockItemSyntax(item: CodeBlockItemSyntax.Item(node))])
let functionMember = MemberAccessExprSyntax(base: visited, name: "flatMap")
Expand All @@ -282,10 +289,14 @@ private class OptionalChainRewriter: SyntaxRewriter, PredicateSyntaxRewriter {

// We're in the middle of a potential tree, so rewrite the closure with a fresh state
// This ensures potential chaining in the closure isn't rewritten outside of the closure
guard let rewritten = (try? node.rewrite(with: OptionalChainRewriter()))?.as(ExprSyntax.self) else {
let nestedRewriter = OptionalChainRewriter()
guard let rewritten = (try? node.rewrite(with: nestedRewriter))?.as(ExprSyntax.self) else {
// If rewriting the closure failed, just leave the closure as-is
return ExprSyntax(node)
}
if ignorable {
ignorable = nestedRewriter.ignorable
}
return rewritten
}

Expand Down
11 changes: 8 additions & 3 deletions Tests/FoundationMacrosTests/MacroTestUtilities.swift
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,14 @@ extension DiagnosticTest {
}

func AssertMacroExpansion(macros: [String : Macro.Type], testModuleName: String = "TestModule", testFileName: String = "test.swift", _ source: String, _ result: String = "", diagnostics: Set<DiagnosticTest> = [], file: StaticString = #filePath, line: UInt = #line) {
let context = BasicMacroExpansionContext()
let origSourceFile = Parser.parse(source: source)
let expandedSourceFile: Syntax
let context: BasicMacroExpansionContext
do {
expandedSourceFile = try OperatorTable.standardOperators.foldAll(origSourceFile).expand(macros: macros) { syntax in
BasicMacroExpansionContext(sharingWith: context, lexicalContext: [syntax])
let foldedSourceFile = try OperatorTable.standardOperators.foldAll(origSourceFile).cast(SourceFileSyntax.self)
context = BasicMacroExpansionContext(sourceFiles: [foldedSourceFile : .init(moduleName: testModuleName, fullFilePath: testFileName)])
expandedSourceFile = foldedSourceFile.expand(macros: macros) {
BasicMacroExpansionContext(sharingWith: context, lexicalContext: [$0])
}
} catch {
XCTFail("Operator folding on input source failed with error \(error)")
Expand All @@ -139,6 +141,9 @@ func AssertMacroExpansion(macros: [String : Macro.Type], testModuleName: String
for diagnostic in context.diagnostics {
if !diagnostics.contains(where: { $0.matches(diagnostic) }) {
XCTFail("Produced extra diagnostic:\n\(diagnostic._assertionDescription)", file: file, line: line)
} else {
let location = context.location(of: diagnostic.node, at: .afterLeadingTrivia, filePathMode: .fileID)
XCTAssertNotNil(location, "Produced diagnostic without attached source information:\n\(diagnostic._assertionDescription)", file: file, line: line)
}
}
for diagnostic in diagnostics {
Expand Down