diff --git a/Release Notes/511.md b/Release Notes/511.md
index 7f751f7f8ed..7bf1ed9f395 100644
--- a/Release Notes/511.md
+++ b/Release Notes/511.md
@@ -35,6 +35,11 @@
- `String.isValidIdentifier(for:)`
- Description: `SwiftParser` adds an extension on `String` to check if it can be used as an identifier in a given context.
- Pull Request: https://github.com/apple/swift-syntax/pull/2434
+
+- `MacroDeclSyntax.expand`
+ - the `expand(argumentList:definition:replacements:)` method gains a new parameter 'genericReplacements:' that is defaulted to an empty array.
+ - The method's signature is now `expand(argumentList:definition:replacements:genericReplacements:)`
+ - Pull Request: https://github.com/apple/swift-syntax/pull/2450
- `SyntaxProtocol.asMacroLexicalContext()` and `allMacroLexicalContexts(enclosingSyntax:)`
- Description: Produce the lexical context for a given syntax node (if it has one), or the entire stack of lexical contexts enclosing a syntax node, for use in macro expansion.
@@ -67,6 +72,11 @@
## API-Incompatible Changes
+- `MacroDefinition` used for expanding macros:
+ - Description: The `MacroDefinition/expansion` enum case used to have two values (`(MacroExpansionExprSyntax, replacements: [Replacement])`), has now gained another value in order to support generic argument replacements in macro expansions: `(MacroExpansionExprSyntax, replacements: [Replacement], genericReplacements: [GenericArgumentReplacement])`
+ - Pull request: https://github.com/apple/swift-syntax/pull/2450
+ - Migration steps: Code which exhaustively checked over the enum should be changed to `case .expansion(let node, let replacements, let genericReplacements):`. Creating the `.extension` gained a compatibility shim, retaining the previous syntax source compatible (`return .expansion(node, replacements: [])`).
+
- Effect specifiers:
- Description: The `unexpectedAfterThrowsSpecifier` node of the various effect specifiers has been removed.
- Pull request: https://github.com/apple/swift-syntax/pull/2219
diff --git a/Sources/SwiftSyntaxMacroExpansion/MacroReplacement.swift b/Sources/SwiftSyntaxMacroExpansion/MacroReplacement.swift
index 5f89e7ece25..1aa33c632e9 100644
--- a/Sources/SwiftSyntaxMacroExpansion/MacroReplacement.swift
+++ b/Sources/SwiftSyntaxMacroExpansion/MacroReplacement.swift
@@ -18,6 +18,7 @@ enum MacroExpanderError: DiagnosticMessage {
case undefined
case definitionNotMacroExpansion
case nonParameterReference(TokenSyntax)
+ case nonTypeReference(TokenSyntax)
case nonLiteralOrParameter(ExprSyntax)
var message: String {
@@ -31,6 +32,9 @@ enum MacroExpanderError: DiagnosticMessage {
case .nonParameterReference(let name):
return "reference to value '\(name.text)' that is not a macro parameter in expansion"
+ case .nonTypeReference(let name):
+ return "reference to type '\(name)' that is not a macro type parameter in expansion"
+
case .nonLiteralOrParameter:
return "only literals and macro parameters are permitted in expansion"
}
@@ -58,7 +62,15 @@ public enum MacroDefinition {
/// defining macro. These subtrees will need to be replaced with the text of
/// the corresponding argument to the macro, which can be accomplished with
/// `MacroDeclSyntax.expandDefinition`.
- case expansion(MacroExpansionExprSyntax, replacements: [Replacement])
+ case expansion(MacroExpansionExprSyntax, replacements: [Replacement], genericReplacements: [GenericArgumentReplacement])
+}
+
+extension MacroDefinition {
+ /// Best effort compatibility shim, the case has gained additional parameters.
+ @available(*, deprecated, message: "Use the expansion case with three associated values instead")
+ public func expansion(_ node: MacroExpansionExprSyntax, replacements: [Replacement]) -> Self {
+ .expansion(node, replacements: replacements, genericReplacements: [])
+ }
}
extension MacroDefinition {
@@ -70,11 +82,21 @@ extension MacroDefinition {
/// The index of the parameter in the defining macro.
public let parameterIndex: Int
}
+
+ /// A replacement that occurs as part of an expanded macro definition.
+ public struct GenericArgumentReplacement {
+ /// A reference to a parameter as it occurs in the macro expansion expression.
+ public let reference: GenericArgumentSyntax
+
+ /// The index of the parameter in the defining macro.
+ public let parameterIndex: Int
+ }
}
fileprivate class ParameterReplacementVisitor: SyntaxAnyVisitor {
let macro: MacroDeclSyntax
var replacements: [MacroDefinition.Replacement] = []
+ var genericReplacements: [MacroDefinition.GenericArgumentReplacement] = []
var diagnostics: [Diagnostic] = []
init(macro: MacroDeclSyntax) {
@@ -156,6 +178,44 @@ fileprivate class ParameterReplacementVisitor: SyntaxAnyVisitor {
return .visitChildren
}
+ override func visit(_ node: GenericArgumentClauseSyntax) -> SyntaxVisitorContinueKind {
+ return .visitChildren
+ }
+
+ override func visit(_ node: GenericArgumentListSyntax) -> SyntaxVisitorContinueKind {
+ return .visitChildren
+ }
+
+ override func visit(_ node: GenericArgumentSyntax) -> SyntaxVisitorContinueKind {
+ guard let baseName = node.argument.as(IdentifierTypeSyntax.self)?.name else {
+ return .skipChildren
+ }
+
+ guard let genericParameterClause = macro.genericParameterClause else {
+ return .skipChildren
+ }
+
+ let matchedParameter = genericParameterClause.parameters.enumerated().first { (index, parameter) in
+ return parameter.name.text == baseName.text
+ }
+
+ guard let (parameterIndex, _) = matchedParameter else {
+ // We have a reference to something that isn't a parameter of the macro.
+ diagnostics.append(
+ Diagnostic(
+ node: Syntax(baseName),
+ message: MacroExpanderError.nonTypeReference(baseName)
+ )
+ )
+
+ return .visitChildren
+ }
+
+ genericReplacements.append(.init(reference: node, parameterIndex: parameterIndex))
+
+ return .visitChildren
+ }
+
override func visitAny(_ node: Syntax) -> SyntaxVisitorContinueKind {
if let expr = node.as(ExprSyntax.self) {
// We have an expression that is not one of the allowed forms, so
@@ -230,7 +290,7 @@ extension MacroDeclSyntax {
throw DiagnosticsError(diagnostics: visitor.diagnostics)
}
- return .expansion(definition, replacements: visitor.replacements)
+ return .expansion(definition, replacements: visitor.replacements, genericReplacements: visitor.genericReplacements)
}
}
@@ -239,10 +299,19 @@ extension MacroDeclSyntax {
private final class MacroExpansionRewriter: SyntaxRewriter {
let parameterReplacements: [DeclReferenceExprSyntax: Int]
let arguments: [ExprSyntax]
-
- init(parameterReplacements: [DeclReferenceExprSyntax: Int], arguments: [ExprSyntax]) {
+ let genericParameterReplacements: [GenericArgumentSyntax: Int]
+ let genericArguments: [TypeSyntax]
+
+ init(
+ parameterReplacements: [DeclReferenceExprSyntax: Int],
+ arguments: [ExprSyntax],
+ genericReplacements: [GenericArgumentSyntax: Int],
+ genericArguments: [TypeSyntax]
+ ) {
self.parameterReplacements = parameterReplacements
self.arguments = arguments
+ self.genericParameterReplacements = genericReplacements
+ self.genericArguments = genericArguments
super.init(viewMode: .sourceAccurate)
}
@@ -254,6 +323,21 @@ private final class MacroExpansionRewriter: SyntaxRewriter {
// Swap in the argument for this parameter
return arguments[parameterIndex].trimmed
}
+
+ override func visit(_ node: GenericArgumentSyntax) -> GenericArgumentSyntax {
+ guard let parameterIndex = genericParameterReplacements[node] else {
+ return super.visit(node)
+ }
+
+ guard parameterIndex < genericArguments.count else {
+ return super.visit(node)
+ }
+
+ // Swap in the argument for type parameter
+ var node = node
+ node.argument = genericArguments[parameterIndex].trimmed
+ return node
+ }
}
extension MacroDeclSyntax {
@@ -261,24 +345,40 @@ extension MacroDeclSyntax {
/// argument list.
private func expand(
argumentList: LabeledExprListSyntax?,
+ genericArgumentList: GenericArgumentClauseSyntax?,
definition: MacroExpansionExprSyntax,
- replacements: [MacroDefinition.Replacement]
+ replacements: [MacroDefinition.Replacement],
+ genericReplacements: [MacroDefinition.GenericArgumentReplacement] = []
) -> ExprSyntax {
// FIXME: Do real call-argument matching between the argument list and the
// macro parameter list, porting over from the compiler.
+ let parameterReplacements = Dictionary(
+ replacements.map { replacement in
+ (replacement.reference, replacement.parameterIndex)
+ },
+ uniquingKeysWith: { l, r in l }
+ )
let arguments: [ExprSyntax] =
argumentList?.map { element in
element.expression
} ?? []
- return MacroExpansionRewriter(
- parameterReplacements: Dictionary(
- uniqueKeysWithValues: replacements.map { replacement in
- (replacement.reference, replacement.parameterIndex)
- }
- ),
- arguments: arguments
- ).visit(definition)
+ let genericReplacements = Dictionary(
+ genericReplacements.map { replacement in
+ (replacement.reference, replacement.parameterIndex)
+ },
+ uniquingKeysWith: { l, r in l }
+ )
+ let genericArguments: [TypeSyntax] =
+ genericArgumentList?.arguments.map { $0.argument } ?? []
+
+ let rewriter = MacroExpansionRewriter(
+ parameterReplacements: parameterReplacements,
+ arguments: arguments,
+ genericReplacements: genericReplacements,
+ genericArguments: genericArguments
+ )
+ return rewriter.visit(definition)
}
/// Given a freestanding macro expansion syntax node that references this
@@ -287,12 +387,15 @@ extension MacroDeclSyntax {
public func expand(
_ node: some FreestandingMacroExpansionSyntax,
definition: MacroExpansionExprSyntax,
- replacements: [MacroDefinition.Replacement]
+ replacements: [MacroDefinition.Replacement],
+ genericReplacements: [MacroDefinition.GenericArgumentReplacement] = []
) -> ExprSyntax {
return expand(
argumentList: node.arguments,
+ genericArgumentList: node.genericArgumentClause,
definition: definition,
- replacements: replacements
+ replacements: replacements,
+ genericReplacements: genericReplacements
)
}
@@ -302,7 +405,8 @@ extension MacroDeclSyntax {
public func expand(
_ node: AttributeSyntax,
definition: MacroExpansionExprSyntax,
- replacements: [MacroDefinition.Replacement]
+ replacements: [MacroDefinition.Replacement],
+ genericReplacements: [MacroDefinition.GenericArgumentReplacement] = []
) -> ExprSyntax {
// Dig out the argument list.
let argumentList: LabeledExprListSyntax?
@@ -314,8 +418,10 @@ extension MacroDeclSyntax {
return expand(
argumentList: argumentList,
+ genericArgumentList: .init(arguments: []),
definition: definition,
- replacements: replacements
+ replacements: replacements,
+ genericReplacements: genericReplacements
)
}
}
diff --git a/Tests/SwiftSyntaxMacroExpansionTest/MacroReplacementTests.swift b/Tests/SwiftSyntaxMacroExpansionTest/MacroReplacementTests.swift
index d6a4d6e45a7..a564a2c3284 100644
--- a/Tests/SwiftSyntaxMacroExpansionTest/MacroReplacementTests.swift
+++ b/Tests/SwiftSyntaxMacroExpansionTest/MacroReplacementTests.swift
@@ -24,8 +24,8 @@ final class MacroReplacementTests: XCTestCase {
macro expand1(a: Int, b: Int) = #otherMacro(first: b, second: ["a": a], third: [3.14159, 2.71828], fourth: 4)
"""
- let definition = try macro.as(MacroDeclSyntax.self)!.checkDefinition()
- guard case let .expansion(_, replacements) = definition else {
+ let definition = try macro.cast(MacroDeclSyntax.self).checkDefinition()
+ guard case let .expansion(_, replacements, _) = definition else {
XCTFail("not an expansion definition")
fatalError()
}
@@ -43,7 +43,7 @@ final class MacroReplacementTests: XCTestCase {
let diags: [Diagnostic]
do {
- _ = try macro.as(MacroDeclSyntax.self)!.checkDefinition()
+ _ = try macro.cast(MacroDeclSyntax.self).checkDefinition()
XCTFail("should have failed with an error")
fatalError()
} catch let diagError as DiagnosticsError {
@@ -69,7 +69,7 @@ final class MacroReplacementTests: XCTestCase {
let diags: [Diagnostic]
do {
- _ = try macro.as(MacroDeclSyntax.self)!.checkDefinition()
+ _ = try macro.cast(MacroDeclSyntax.self).checkDefinition()
XCTFail("should have failed with an error")
fatalError()
} catch let diagError as DiagnosticsError {
@@ -94,9 +94,9 @@ final class MacroReplacementTests: XCTestCase {
#expand1(a: 5, b: 17)
"""
- let macroDecl = macro.as(MacroDeclSyntax.self)!
+ let macroDecl = macro.cast(MacroDeclSyntax.self)
let definition = try macroDecl.checkDefinition()
- guard case let .expansion(expansion, replacements) = definition else {
+ guard case let .expansion(expansion, replacements, genericReplacements) = definition else {
XCTFail("not a normal expansion")
fatalError()
}
@@ -104,7 +104,8 @@ final class MacroReplacementTests: XCTestCase {
let expandedSyntax = macroDecl.expand(
use.as(MacroExpansionExprSyntax.self)!,
definition: expansion,
- replacements: replacements
+ replacements: replacements,
+ genericReplacements: genericReplacements
)
assertStringsEqualWithDiff(
expandedSyntax.description,
@@ -113,4 +114,230 @@ final class MacroReplacementTests: XCTestCase {
"""
)
}
+
+ func testMacroGenericArgumentExpansionBase() throws {
+ let macro: DeclSyntax =
+ """
+ macro gen(a: A, b: B) = #otherMacro(first: a, second: b)
+ """
+
+ let use: ExprSyntax =
+ """
+ #gen(a: 5, b: "Hello")
+ """
+
+ let macroDecl = macro.cast(MacroDeclSyntax.self)
+ let definition = try macroDecl.checkDefinition()
+ guard case let .expansion(expansion, replacements, genericReplacements) = definition else {
+ XCTFail("not a normal expansion")
+ return
+ }
+
+ let replacementA = try XCTUnwrap(genericReplacements.first)
+ let replacementB = try XCTUnwrap(genericReplacements.dropFirst().first)
+
+ XCTAssertEqual(genericReplacements.count, 2)
+
+ XCTAssertEqual(replacementA.parameterIndex, 0)
+ XCTAssertEqual("\(replacementA.reference.argument)", "A")
+
+ XCTAssertEqual(replacementB.parameterIndex, 1)
+ XCTAssertEqual("\(replacementB.reference.argument)", "B")
+
+ let expandedSyntax = macroDecl.expand(
+ use.as(MacroExpansionExprSyntax.self)!,
+ definition: expansion,
+ replacements: replacements,
+ genericReplacements: genericReplacements
+ )
+ assertStringsEqualWithDiff(
+ expandedSyntax.description,
+ """
+ #otherMacro(first: 5, second: "Hello")
+ """
+ )
+ }
+
+ func testMacroGenericArgumentExpansionIgnoreTrivia() throws {
+ let macro: DeclSyntax =
+ """
+ macro gen(a: A, b: B) = #otherMacro(first: a, second: b)
+ """
+
+ let use: ExprSyntax =
+ """
+ #gen(a: 5, b: "Hello")
+ """
+
+ let macroDecl = macro.cast(MacroDeclSyntax.self)
+ let definition = try macroDecl.checkDefinition()
+ guard case let .expansion(expansion, replacements, genericReplacements) = definition else {
+ XCTFail("not a normal expansion")
+ return
+ }
+
+ let replacementA = try XCTUnwrap(genericReplacements.first)
+ let replacementB = try XCTUnwrap(genericReplacements.dropFirst().first)
+ XCTAssertEqual(genericReplacements.count, 2)
+
+ XCTAssertEqual(replacementA.parameterIndex, 0)
+ XCTAssertEqual("\(replacementA.reference.argument)", "A")
+
+ XCTAssertEqual(replacementB.parameterIndex, 1)
+ XCTAssertEqual("\(replacementB.reference.argument)", "B")
+
+ let expandedSyntax = macroDecl.expand(
+ use.as(MacroExpansionExprSyntax.self)!,
+ definition: expansion,
+ replacements: replacements,
+ genericReplacements: genericReplacements
+ )
+ assertStringsEqualWithDiff(
+ expandedSyntax.description,
+ """
+ #otherMacro(first: 5, second: "Hello")
+ """
+ )
+ }
+
+ func testMacroGenericArgumentExpansionNotVisitGenericParameterArguments() throws {
+ let macro: DeclSyntax =
+ """
+ macro gen(a: Array) = #otherMacro(first: a)
+ """
+
+ let use: ExprSyntax =
+ """
+ #gen(a: [1, 2, 3])
+ """
+
+ let macroDecl = macro.cast(MacroDeclSyntax.self)
+ let definition = try macroDecl.checkDefinition()
+ guard case let .expansion(expansion, replacements, genericReplacements) = definition else {
+ XCTFail("not a normal expansion")
+ return
+ }
+
+ XCTAssertEqual(genericReplacements.count, 0)
+
+ let expandedSyntax = macroDecl.expand(
+ use.as(MacroExpansionExprSyntax.self)!,
+ definition: expansion,
+ replacements: replacements,
+ genericReplacements: genericReplacements
+ )
+ assertStringsEqualWithDiff(
+ expandedSyntax.description,
+ """
+ #otherMacro(first: [1, 2, 3])
+ """
+ )
+ }
+
+ func testMacroGenericArgumentExpansionReplaceInner() throws {
+ let macro: DeclSyntax =
+ """
+ macro gen(a: Array) = #reduce(first: a)
+ """
+
+ let use: ExprSyntax =
+ """
+ #gen(a: [1, 2, 3])
+ """
+
+ let macroDecl = macro.cast(MacroDeclSyntax.self)
+ let definition = try macroDecl.checkDefinition()
+ guard case let .expansion(expansion, replacements, genericReplacements) = definition else {
+ XCTFail("not a normal expansion")
+ return
+ }
+
+ XCTAssertEqual(genericReplacements.count, 1)
+
+ let expandedSyntax = macroDecl.expand(
+ use.as(MacroExpansionExprSyntax.self)!,
+ definition: expansion,
+ replacements: replacements,
+ genericReplacements: genericReplacements
+ )
+ assertStringsEqualWithDiff(
+ expandedSyntax.description,
+ """
+ #reduce(first: [1, 2, 3])
+ """
+ )
+ }
+
+ func testMacroGenericArgumentExpansionArray() throws {
+ let macro: DeclSyntax =
+ """
+ macro gen(a: Array) = #other(first: a)
+ """
+
+ let use: ExprSyntax =
+ """
+ #gen(a: [1, 2, 3])
+ """
+
+ let macroDecl = macro.cast(MacroDeclSyntax.self)
+ let definition = try macroDecl.checkDefinition()
+ guard case let .expansion(expansion, replacements, genericReplacements) = definition else {
+ XCTFail("not a normal expansion")
+ return
+ }
+
+ XCTAssertEqual(genericReplacements.count, 0)
+
+ let expandedSyntax = macroDecl.expand(
+ use.as(MacroExpansionExprSyntax.self)!,
+ definition: expansion,
+ replacements: replacements,
+ genericReplacements: genericReplacements
+ )
+ assertStringsEqualWithDiff(
+ expandedSyntax.description,
+ """
+ #other(first: [1, 2, 3])
+ """
+ )
+ }
+
+ func testMacroExpansionDontCrashOnDuplicates() throws {
+ let macro: DeclSyntax =
+ """
+ macro gen(a: Array) = #other(first: a)
+ """
+
+ let use: ExprSyntax =
+ """
+ #gen(a: [1, 2, 3])
+ """
+
+ let macroDecl = macro.cast(MacroDeclSyntax.self)
+ let definition = try macroDecl.checkDefinition()
+ guard case let .expansion(expansion, replacements, genericReplacements) = definition else {
+ XCTFail("not a normal expansion")
+ return
+ }
+
+ var replacementsWithDupes = replacements
+ replacementsWithDupes.append(contentsOf: replacements)
+ var genericReplacementsWithDupes = genericReplacements
+ genericReplacementsWithDupes.append(contentsOf: genericReplacements)
+
+ XCTAssertEqual(genericReplacements.count, 0)
+
+ let expandedSyntax = macroDecl.expand(
+ use.as(MacroExpansionExprSyntax.self)!,
+ definition: expansion,
+ replacements: replacementsWithDupes,
+ genericReplacements: genericReplacementsWithDupes
+ )
+ assertStringsEqualWithDiff(
+ expandedSyntax.description,
+ """
+ #other(first: [1, 2, 3])
+ """
+ )
+ }
}