Skip to content

Commit 919053f

Browse files
authored
Ensure predicate diagnostics contain source information when possible (#1025)
1 parent 895bc7e commit 919053f

File tree

3 files changed

+22
-6
lines changed

3 files changed

+22
-6
lines changed

Package.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ var dependencies: [Package.Dependency] {
6565
branch: "main"),
6666
.package(
6767
url: "https://github.com/swiftlang/swift-syntax",
68-
from: "600.0.0")
68+
branch: "main")
6969
]
7070
}
7171
}

Sources/FoundationMacros/PredicateMacro.swift

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,20 +229,25 @@ extension SyntaxProtocol {
229229

230230
private protocol PredicateSyntaxRewriter : SyntaxRewriter {
231231
var success: Bool { get }
232+
var ignorable: Bool { get }
232233
var diagnostics: [Diagnostic] { get }
233234
}
234235

235236
extension PredicateSyntaxRewriter {
236237
var success: Bool { true }
238+
var ignorable: Bool { false }
237239
var diagnostics: [Diagnostic] { [] }
238240
}
239241

240242
extension SyntaxProtocol {
241243
fileprivate func rewrite(with rewriter: some PredicateSyntaxRewriter) throws -> Syntax {
242-
let translated = rewriter.rewrite(Syntax(self))
244+
let translated = rewriter.rewrite(self)
243245
guard rewriter.success else {
244246
throw DiagnosticsError(diagnostics: rewriter.diagnostics)
245247
}
248+
guard !rewriter.ignorable else {
249+
return Syntax(self)
250+
}
246251
return translated
247252
}
248253
}
@@ -251,6 +256,7 @@ private class OptionalChainRewriter: SyntaxRewriter, PredicateSyntaxRewriter {
251256
var withinValidChainingTreeStart = true
252257
var withinChainingTree = false
253258
var optionalInput: ExprSyntax? = nil
259+
var ignorable = true
254260

255261
private func _prePossibleTopOfTree() -> Bool {
256262
if !withinChainingTree && withinValidChainingTreeStart {
@@ -265,6 +271,7 @@ private class OptionalChainRewriter: SyntaxRewriter, PredicateSyntaxRewriter {
265271
withinChainingTree = false
266272
if let input = optionalInput {
267273
optionalInput = nil
274+
ignorable = false
268275
let visited = self.visit(input)
269276
let closure = ClosureExprSyntax(statements: [CodeBlockItemSyntax(item: CodeBlockItemSyntax.Item(node))])
270277
let functionMember = MemberAccessExprSyntax(base: visited, name: "flatMap")
@@ -282,10 +289,14 @@ private class OptionalChainRewriter: SyntaxRewriter, PredicateSyntaxRewriter {
282289

283290
// We're in the middle of a potential tree, so rewrite the closure with a fresh state
284291
// This ensures potential chaining in the closure isn't rewritten outside of the closure
285-
guard let rewritten = (try? node.rewrite(with: OptionalChainRewriter()))?.as(ExprSyntax.self) else {
292+
let nestedRewriter = OptionalChainRewriter()
293+
guard let rewritten = (try? node.rewrite(with: nestedRewriter))?.as(ExprSyntax.self) else {
286294
// If rewriting the closure failed, just leave the closure as-is
287295
return ExprSyntax(node)
288296
}
297+
if ignorable {
298+
ignorable = nestedRewriter.ignorable
299+
}
289300
return rewritten
290301
}
291302

Tests/FoundationMacrosTests/MacroTestUtilities.swift

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,14 @@ extension DiagnosticTest {
121121
}
122122

123123
func AssertMacroExpansion(macros: [String : Macro.Type], testModuleName: String = "TestModule", testFileName: String = "test.swift", _ source: String, _ result: String = "", diagnostics: Set<DiagnosticTest> = [], file: StaticString = #filePath, line: UInt = #line) {
124-
let context = BasicMacroExpansionContext()
125124
let origSourceFile = Parser.parse(source: source)
126125
let expandedSourceFile: Syntax
126+
let context: BasicMacroExpansionContext
127127
do {
128-
expandedSourceFile = try OperatorTable.standardOperators.foldAll(origSourceFile).expand(macros: macros) { syntax in
129-
BasicMacroExpansionContext(sharingWith: context, lexicalContext: [syntax])
128+
let foldedSourceFile = try OperatorTable.standardOperators.foldAll(origSourceFile).cast(SourceFileSyntax.self)
129+
context = BasicMacroExpansionContext(sourceFiles: [foldedSourceFile : .init(moduleName: testModuleName, fullFilePath: testFileName)])
130+
expandedSourceFile = foldedSourceFile.expand(macros: macros) {
131+
BasicMacroExpansionContext(sharingWith: context, lexicalContext: [$0])
130132
}
131133
} catch {
132134
XCTFail("Operator folding on input source failed with error \(error)")
@@ -139,6 +141,9 @@ func AssertMacroExpansion(macros: [String : Macro.Type], testModuleName: String
139141
for diagnostic in context.diagnostics {
140142
if !diagnostics.contains(where: { $0.matches(diagnostic) }) {
141143
XCTFail("Produced extra diagnostic:\n\(diagnostic._assertionDescription)", file: file, line: line)
144+
} else {
145+
let location = context.location(of: diagnostic.node, at: .afterLeadingTrivia, filePathMode: .fileID)
146+
XCTAssertNotNil(location, "Produced diagnostic without attached source information:\n\(diagnostic._assertionDescription)", file: file, line: line)
142147
}
143148
}
144149
for diagnostic in diagnostics {

0 commit comments

Comments
 (0)