diff --git a/Release Notes/511.md b/Release Notes/511.md index 7f751f7f8ed..7bf1ed9f395 100644 --- a/Release Notes/511.md +++ b/Release Notes/511.md @@ -35,6 +35,11 @@ - `String.isValidIdentifier(for:)` - Description: `SwiftParser` adds an extension on `String` to check if it can be used as an identifier in a given context. - Pull Request: https://github.com/apple/swift-syntax/pull/2434 + +- `MacroDeclSyntax.expand` + - the `expand(argumentList:definition:replacements:)` method gains a new parameter 'genericReplacements:' that is defaulted to an empty array. + - The method's signature is now `expand(argumentList:definition:replacements:genericReplacements:)` + - Pull Request: https://github.com/apple/swift-syntax/pull/2450 - `SyntaxProtocol.asMacroLexicalContext()` and `allMacroLexicalContexts(enclosingSyntax:)` - Description: Produce the lexical context for a given syntax node (if it has one), or the entire stack of lexical contexts enclosing a syntax node, for use in macro expansion. @@ -67,6 +72,11 @@ ## API-Incompatible Changes +- `MacroDefinition` used for expanding macros: + - Description: The `MacroDefinition/expansion` enum case used to have two values (`(MacroExpansionExprSyntax, replacements: [Replacement])`), has now gained another value in order to support generic argument replacements in macro expansions: `(MacroExpansionExprSyntax, replacements: [Replacement], genericReplacements: [GenericArgumentReplacement])` + - Pull request: https://github.com/apple/swift-syntax/pull/2450 + - Migration steps: Code which exhaustively checked over the enum should be changed to `case .expansion(let node, let replacements, let genericReplacements):`. Creating the `.extension` gained a compatibility shim, retaining the previous syntax source compatible (`return .expansion(node, replacements: [])`). + - Effect specifiers: - Description: The `unexpectedAfterThrowsSpecifier` node of the various effect specifiers has been removed. - Pull request: https://github.com/apple/swift-syntax/pull/2219 diff --git a/Sources/SwiftSyntaxMacroExpansion/MacroReplacement.swift b/Sources/SwiftSyntaxMacroExpansion/MacroReplacement.swift index 5f89e7ece25..1aa33c632e9 100644 --- a/Sources/SwiftSyntaxMacroExpansion/MacroReplacement.swift +++ b/Sources/SwiftSyntaxMacroExpansion/MacroReplacement.swift @@ -18,6 +18,7 @@ enum MacroExpanderError: DiagnosticMessage { case undefined case definitionNotMacroExpansion case nonParameterReference(TokenSyntax) + case nonTypeReference(TokenSyntax) case nonLiteralOrParameter(ExprSyntax) var message: String { @@ -31,6 +32,9 @@ enum MacroExpanderError: DiagnosticMessage { case .nonParameterReference(let name): return "reference to value '\(name.text)' that is not a macro parameter in expansion" + case .nonTypeReference(let name): + return "reference to type '\(name)' that is not a macro type parameter in expansion" + case .nonLiteralOrParameter: return "only literals and macro parameters are permitted in expansion" } @@ -58,7 +62,15 @@ public enum MacroDefinition { /// defining macro. These subtrees will need to be replaced with the text of /// the corresponding argument to the macro, which can be accomplished with /// `MacroDeclSyntax.expandDefinition`. - case expansion(MacroExpansionExprSyntax, replacements: [Replacement]) + case expansion(MacroExpansionExprSyntax, replacements: [Replacement], genericReplacements: [GenericArgumentReplacement]) +} + +extension MacroDefinition { + /// Best effort compatibility shim, the case has gained additional parameters. + @available(*, deprecated, message: "Use the expansion case with three associated values instead") + public func expansion(_ node: MacroExpansionExprSyntax, replacements: [Replacement]) -> Self { + .expansion(node, replacements: replacements, genericReplacements: []) + } } extension MacroDefinition { @@ -70,11 +82,21 @@ extension MacroDefinition { /// The index of the parameter in the defining macro. public let parameterIndex: Int } + + /// A replacement that occurs as part of an expanded macro definition. + public struct GenericArgumentReplacement { + /// A reference to a parameter as it occurs in the macro expansion expression. + public let reference: GenericArgumentSyntax + + /// The index of the parameter in the defining macro. + public let parameterIndex: Int + } } fileprivate class ParameterReplacementVisitor: SyntaxAnyVisitor { let macro: MacroDeclSyntax var replacements: [MacroDefinition.Replacement] = [] + var genericReplacements: [MacroDefinition.GenericArgumentReplacement] = [] var diagnostics: [Diagnostic] = [] init(macro: MacroDeclSyntax) { @@ -156,6 +178,44 @@ fileprivate class ParameterReplacementVisitor: SyntaxAnyVisitor { return .visitChildren } + override func visit(_ node: GenericArgumentClauseSyntax) -> SyntaxVisitorContinueKind { + return .visitChildren + } + + override func visit(_ node: GenericArgumentListSyntax) -> SyntaxVisitorContinueKind { + return .visitChildren + } + + override func visit(_ node: GenericArgumentSyntax) -> SyntaxVisitorContinueKind { + guard let baseName = node.argument.as(IdentifierTypeSyntax.self)?.name else { + return .skipChildren + } + + guard let genericParameterClause = macro.genericParameterClause else { + return .skipChildren + } + + let matchedParameter = genericParameterClause.parameters.enumerated().first { (index, parameter) in + return parameter.name.text == baseName.text + } + + guard let (parameterIndex, _) = matchedParameter else { + // We have a reference to something that isn't a parameter of the macro. + diagnostics.append( + Diagnostic( + node: Syntax(baseName), + message: MacroExpanderError.nonTypeReference(baseName) + ) + ) + + return .visitChildren + } + + genericReplacements.append(.init(reference: node, parameterIndex: parameterIndex)) + + return .visitChildren + } + override func visitAny(_ node: Syntax) -> SyntaxVisitorContinueKind { if let expr = node.as(ExprSyntax.self) { // We have an expression that is not one of the allowed forms, so @@ -230,7 +290,7 @@ extension MacroDeclSyntax { throw DiagnosticsError(diagnostics: visitor.diagnostics) } - return .expansion(definition, replacements: visitor.replacements) + return .expansion(definition, replacements: visitor.replacements, genericReplacements: visitor.genericReplacements) } } @@ -239,10 +299,19 @@ extension MacroDeclSyntax { private final class MacroExpansionRewriter: SyntaxRewriter { let parameterReplacements: [DeclReferenceExprSyntax: Int] let arguments: [ExprSyntax] - - init(parameterReplacements: [DeclReferenceExprSyntax: Int], arguments: [ExprSyntax]) { + let genericParameterReplacements: [GenericArgumentSyntax: Int] + let genericArguments: [TypeSyntax] + + init( + parameterReplacements: [DeclReferenceExprSyntax: Int], + arguments: [ExprSyntax], + genericReplacements: [GenericArgumentSyntax: Int], + genericArguments: [TypeSyntax] + ) { self.parameterReplacements = parameterReplacements self.arguments = arguments + self.genericParameterReplacements = genericReplacements + self.genericArguments = genericArguments super.init(viewMode: .sourceAccurate) } @@ -254,6 +323,21 @@ private final class MacroExpansionRewriter: SyntaxRewriter { // Swap in the argument for this parameter return arguments[parameterIndex].trimmed } + + override func visit(_ node: GenericArgumentSyntax) -> GenericArgumentSyntax { + guard let parameterIndex = genericParameterReplacements[node] else { + return super.visit(node) + } + + guard parameterIndex < genericArguments.count else { + return super.visit(node) + } + + // Swap in the argument for type parameter + var node = node + node.argument = genericArguments[parameterIndex].trimmed + return node + } } extension MacroDeclSyntax { @@ -261,24 +345,40 @@ extension MacroDeclSyntax { /// argument list. private func expand( argumentList: LabeledExprListSyntax?, + genericArgumentList: GenericArgumentClauseSyntax?, definition: MacroExpansionExprSyntax, - replacements: [MacroDefinition.Replacement] + replacements: [MacroDefinition.Replacement], + genericReplacements: [MacroDefinition.GenericArgumentReplacement] = [] ) -> ExprSyntax { // FIXME: Do real call-argument matching between the argument list and the // macro parameter list, porting over from the compiler. + let parameterReplacements = Dictionary( + replacements.map { replacement in + (replacement.reference, replacement.parameterIndex) + }, + uniquingKeysWith: { l, r in l } + ) let arguments: [ExprSyntax] = argumentList?.map { element in element.expression } ?? [] - return MacroExpansionRewriter( - parameterReplacements: Dictionary( - uniqueKeysWithValues: replacements.map { replacement in - (replacement.reference, replacement.parameterIndex) - } - ), - arguments: arguments - ).visit(definition) + let genericReplacements = Dictionary( + genericReplacements.map { replacement in + (replacement.reference, replacement.parameterIndex) + }, + uniquingKeysWith: { l, r in l } + ) + let genericArguments: [TypeSyntax] = + genericArgumentList?.arguments.map { $0.argument } ?? [] + + let rewriter = MacroExpansionRewriter( + parameterReplacements: parameterReplacements, + arguments: arguments, + genericReplacements: genericReplacements, + genericArguments: genericArguments + ) + return rewriter.visit(definition) } /// Given a freestanding macro expansion syntax node that references this @@ -287,12 +387,15 @@ extension MacroDeclSyntax { public func expand( _ node: some FreestandingMacroExpansionSyntax, definition: MacroExpansionExprSyntax, - replacements: [MacroDefinition.Replacement] + replacements: [MacroDefinition.Replacement], + genericReplacements: [MacroDefinition.GenericArgumentReplacement] = [] ) -> ExprSyntax { return expand( argumentList: node.arguments, + genericArgumentList: node.genericArgumentClause, definition: definition, - replacements: replacements + replacements: replacements, + genericReplacements: genericReplacements ) } @@ -302,7 +405,8 @@ extension MacroDeclSyntax { public func expand( _ node: AttributeSyntax, definition: MacroExpansionExprSyntax, - replacements: [MacroDefinition.Replacement] + replacements: [MacroDefinition.Replacement], + genericReplacements: [MacroDefinition.GenericArgumentReplacement] = [] ) -> ExprSyntax { // Dig out the argument list. let argumentList: LabeledExprListSyntax? @@ -314,8 +418,10 @@ extension MacroDeclSyntax { return expand( argumentList: argumentList, + genericArgumentList: .init(arguments: []), definition: definition, - replacements: replacements + replacements: replacements, + genericReplacements: genericReplacements ) } } diff --git a/Tests/SwiftSyntaxMacroExpansionTest/MacroReplacementTests.swift b/Tests/SwiftSyntaxMacroExpansionTest/MacroReplacementTests.swift index d6a4d6e45a7..a564a2c3284 100644 --- a/Tests/SwiftSyntaxMacroExpansionTest/MacroReplacementTests.swift +++ b/Tests/SwiftSyntaxMacroExpansionTest/MacroReplacementTests.swift @@ -24,8 +24,8 @@ final class MacroReplacementTests: XCTestCase { macro expand1(a: Int, b: Int) = #otherMacro(first: b, second: ["a": a], third: [3.14159, 2.71828], fourth: 4) """ - let definition = try macro.as(MacroDeclSyntax.self)!.checkDefinition() - guard case let .expansion(_, replacements) = definition else { + let definition = try macro.cast(MacroDeclSyntax.self).checkDefinition() + guard case let .expansion(_, replacements, _) = definition else { XCTFail("not an expansion definition") fatalError() } @@ -43,7 +43,7 @@ final class MacroReplacementTests: XCTestCase { let diags: [Diagnostic] do { - _ = try macro.as(MacroDeclSyntax.self)!.checkDefinition() + _ = try macro.cast(MacroDeclSyntax.self).checkDefinition() XCTFail("should have failed with an error") fatalError() } catch let diagError as DiagnosticsError { @@ -69,7 +69,7 @@ final class MacroReplacementTests: XCTestCase { let diags: [Diagnostic] do { - _ = try macro.as(MacroDeclSyntax.self)!.checkDefinition() + _ = try macro.cast(MacroDeclSyntax.self).checkDefinition() XCTFail("should have failed with an error") fatalError() } catch let diagError as DiagnosticsError { @@ -94,9 +94,9 @@ final class MacroReplacementTests: XCTestCase { #expand1(a: 5, b: 17) """ - let macroDecl = macro.as(MacroDeclSyntax.self)! + let macroDecl = macro.cast(MacroDeclSyntax.self) let definition = try macroDecl.checkDefinition() - guard case let .expansion(expansion, replacements) = definition else { + guard case let .expansion(expansion, replacements, genericReplacements) = definition else { XCTFail("not a normal expansion") fatalError() } @@ -104,7 +104,8 @@ final class MacroReplacementTests: XCTestCase { let expandedSyntax = macroDecl.expand( use.as(MacroExpansionExprSyntax.self)!, definition: expansion, - replacements: replacements + replacements: replacements, + genericReplacements: genericReplacements ) assertStringsEqualWithDiff( expandedSyntax.description, @@ -113,4 +114,230 @@ final class MacroReplacementTests: XCTestCase { """ ) } + + func testMacroGenericArgumentExpansionBase() throws { + let macro: DeclSyntax = + """ + macro gen(a: A, b: B) = #otherMacro(first: a, second: b) + """ + + let use: ExprSyntax = + """ + #gen(a: 5, b: "Hello") + """ + + let macroDecl = macro.cast(MacroDeclSyntax.self) + let definition = try macroDecl.checkDefinition() + guard case let .expansion(expansion, replacements, genericReplacements) = definition else { + XCTFail("not a normal expansion") + return + } + + let replacementA = try XCTUnwrap(genericReplacements.first) + let replacementB = try XCTUnwrap(genericReplacements.dropFirst().first) + + XCTAssertEqual(genericReplacements.count, 2) + + XCTAssertEqual(replacementA.parameterIndex, 0) + XCTAssertEqual("\(replacementA.reference.argument)", "A") + + XCTAssertEqual(replacementB.parameterIndex, 1) + XCTAssertEqual("\(replacementB.reference.argument)", "B") + + let expandedSyntax = macroDecl.expand( + use.as(MacroExpansionExprSyntax.self)!, + definition: expansion, + replacements: replacements, + genericReplacements: genericReplacements + ) + assertStringsEqualWithDiff( + expandedSyntax.description, + """ + #otherMacro(first: 5, second: "Hello") + """ + ) + } + + func testMacroGenericArgumentExpansionIgnoreTrivia() throws { + let macro: DeclSyntax = + """ + macro gen(a: A, b: B) = #otherMacro(first: a, second: b) + """ + + let use: ExprSyntax = + """ + #gen(a: 5, b: "Hello") + """ + + let macroDecl = macro.cast(MacroDeclSyntax.self) + let definition = try macroDecl.checkDefinition() + guard case let .expansion(expansion, replacements, genericReplacements) = definition else { + XCTFail("not a normal expansion") + return + } + + let replacementA = try XCTUnwrap(genericReplacements.first) + let replacementB = try XCTUnwrap(genericReplacements.dropFirst().first) + XCTAssertEqual(genericReplacements.count, 2) + + XCTAssertEqual(replacementA.parameterIndex, 0) + XCTAssertEqual("\(replacementA.reference.argument)", "A") + + XCTAssertEqual(replacementB.parameterIndex, 1) + XCTAssertEqual("\(replacementB.reference.argument)", "B") + + let expandedSyntax = macroDecl.expand( + use.as(MacroExpansionExprSyntax.self)!, + definition: expansion, + replacements: replacements, + genericReplacements: genericReplacements + ) + assertStringsEqualWithDiff( + expandedSyntax.description, + """ + #otherMacro(first: 5, second: "Hello") + """ + ) + } + + func testMacroGenericArgumentExpansionNotVisitGenericParameterArguments() throws { + let macro: DeclSyntax = + """ + macro gen(a: Array) = #otherMacro(first: a) + """ + + let use: ExprSyntax = + """ + #gen(a: [1, 2, 3]) + """ + + let macroDecl = macro.cast(MacroDeclSyntax.self) + let definition = try macroDecl.checkDefinition() + guard case let .expansion(expansion, replacements, genericReplacements) = definition else { + XCTFail("not a normal expansion") + return + } + + XCTAssertEqual(genericReplacements.count, 0) + + let expandedSyntax = macroDecl.expand( + use.as(MacroExpansionExprSyntax.self)!, + definition: expansion, + replacements: replacements, + genericReplacements: genericReplacements + ) + assertStringsEqualWithDiff( + expandedSyntax.description, + """ + #otherMacro(first: [1, 2, 3]) + """ + ) + } + + func testMacroGenericArgumentExpansionReplaceInner() throws { + let macro: DeclSyntax = + """ + macro gen(a: Array) = #reduce(first: a) + """ + + let use: ExprSyntax = + """ + #gen(a: [1, 2, 3]) + """ + + let macroDecl = macro.cast(MacroDeclSyntax.self) + let definition = try macroDecl.checkDefinition() + guard case let .expansion(expansion, replacements, genericReplacements) = definition else { + XCTFail("not a normal expansion") + return + } + + XCTAssertEqual(genericReplacements.count, 1) + + let expandedSyntax = macroDecl.expand( + use.as(MacroExpansionExprSyntax.self)!, + definition: expansion, + replacements: replacements, + genericReplacements: genericReplacements + ) + assertStringsEqualWithDiff( + expandedSyntax.description, + """ + #reduce(first: [1, 2, 3]) + """ + ) + } + + func testMacroGenericArgumentExpansionArray() throws { + let macro: DeclSyntax = + """ + macro gen(a: Array) = #other(first: a) + """ + + let use: ExprSyntax = + """ + #gen(a: [1, 2, 3]) + """ + + let macroDecl = macro.cast(MacroDeclSyntax.self) + let definition = try macroDecl.checkDefinition() + guard case let .expansion(expansion, replacements, genericReplacements) = definition else { + XCTFail("not a normal expansion") + return + } + + XCTAssertEqual(genericReplacements.count, 0) + + let expandedSyntax = macroDecl.expand( + use.as(MacroExpansionExprSyntax.self)!, + definition: expansion, + replacements: replacements, + genericReplacements: genericReplacements + ) + assertStringsEqualWithDiff( + expandedSyntax.description, + """ + #other(first: [1, 2, 3]) + """ + ) + } + + func testMacroExpansionDontCrashOnDuplicates() throws { + let macro: DeclSyntax = + """ + macro gen(a: Array) = #other(first: a) + """ + + let use: ExprSyntax = + """ + #gen(a: [1, 2, 3]) + """ + + let macroDecl = macro.cast(MacroDeclSyntax.self) + let definition = try macroDecl.checkDefinition() + guard case let .expansion(expansion, replacements, genericReplacements) = definition else { + XCTFail("not a normal expansion") + return + } + + var replacementsWithDupes = replacements + replacementsWithDupes.append(contentsOf: replacements) + var genericReplacementsWithDupes = genericReplacements + genericReplacementsWithDupes.append(contentsOf: genericReplacements) + + XCTAssertEqual(genericReplacements.count, 0) + + let expandedSyntax = macroDecl.expand( + use.as(MacroExpansionExprSyntax.self)!, + definition: expansion, + replacements: replacementsWithDupes, + genericReplacements: genericReplacementsWithDupes + ) + assertStringsEqualWithDiff( + expandedSyntax.description, + """ + #other(first: [1, 2, 3]) + """ + ) + } }