From 62b61ab0b2797707f2564e063a0afd97065282f8 Mon Sep 17 00:00:00 2001 From: Mario Sangiorgio Date: Wed, 9 Aug 2023 10:55:45 +0200 Subject: [PATCH 01/13] async/await prepared statements --- .../Connection/PostgresConnection.swift | 21 ++++ .../PreparedStatementStateMachine.swift | 68 +++++++++++++ Sources/PostgresNIO/New/PSQLTask.swift | 23 +++++ .../New/PostgresChannelHandler.swift | 96 +++++++++++++++++++ .../PostgresNIO/New/PreparedStatement.swift | 40 ++++++++ Tests/IntegrationTests/AsyncTests.swift | 42 ++++++++ 6 files changed, 290 insertions(+) create mode 100644 Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift create mode 100644 Sources/PostgresNIO/New/PreparedStatement.swift diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 7ac8ec57..79f558b1 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -460,6 +460,27 @@ extension PostgresConnection { self.channel.write(task, promise: nil) } } + + /// Execute a prepared statement, taking care of the preparation when necessary + public func execute( + _ preparedStatement: P, + logger: Logger + ) async throws -> AsyncThrowingMapSequence + { + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let task = HandlerTask.executePreparedStatement(.init( + name: String(reflecting: P.self), + sql: P.sql, + bindings: preparedStatement.makeBindings(), + logger: logger, + promise: promise + )) + self.channel.write(task, promise: nil) + return try await promise.futureResult + .map { $0.asyncSequence() } + .get() + .map { try preparedStatement.decodeRow($0) } + } } // MARK: EventLoopFuture interface diff --git a/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift new file mode 100644 index 00000000..62c36ebf --- /dev/null +++ b/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift @@ -0,0 +1,68 @@ +import NIOCore + +struct PreparedStatementStateMachine { + enum State { + case preparing([PreparedStatementContext]) + case prepared(RowDescription?) + case error(PSQLError) + } + + enum Action { + case prepareStatement + case waitForAlreadyInFlightPreparation + case executePendingStatements([PreparedStatementContext], RowDescription?) + case returnError([PreparedStatementContext], PSQLError) + } + + var preparedStatements: [String: State] + + init() { + self.preparedStatements = [:] + } + + mutating func lookup(name: String, context: PreparedStatementContext) -> Action { + if let state = self.preparedStatements[name] { + switch state { + case .preparing(var statements): + statements.append(context) + self.preparedStatements[name] = .preparing(statements) + return .waitForAlreadyInFlightPreparation + case .prepared(let rowDescription): + return .executePendingStatements([context], rowDescription) + case .error(let error): + return .returnError([context], error) + } + } else { + self.preparedStatements[name] = .preparing([context]) + return .prepareStatement + } + } + + mutating func preparationComplete( + name: String, + rowDescription: RowDescription? + ) -> Action { + guard case .preparing(let statements) = self.preparedStatements[name] else { + preconditionFailure("Preparation completed for an unexpected statement") + } + // When sending the bindings we are going to ask for binary data. + if var rowDescription { + for i in 0.. Action { + guard case .preparing(let statements) = self.preparedStatements[name] else { + preconditionFailure("Preparation completed for an unexpected statement") + } + self.preparedStatements[name] = .error(error) + return .returnError(statements, error) + } +} diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index f5de6561..9425c12b 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -6,6 +6,7 @@ enum HandlerTask { case closeCommand(CloseCommandContext) case startListening(NotificationListener) case cancelListening(String, Int) + case executePreparedStatement(PreparedStatementContext) } enum PSQLTask { @@ -69,6 +70,28 @@ final class ExtendedQueryContext { } } +final class PreparedStatementContext{ + let name: String + let sql: String + let bindings: PostgresBindings + let logger: Logger + let promise: EventLoopPromise + + init( + name: String, + sql: String, + bindings: PostgresBindings, + logger: Logger, + promise: EventLoopPromise + ) { + self.name = name + self.sql = sql + self.bindings = bindings + self.logger = logger + self.promise = promise + } +} + final class CloseCommandContext { let target: CloseTarget let logger: Logger diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 7801d4d6..dbe4d2f9 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -23,6 +23,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private let configureSSLCallback: ((Channel) throws -> Void)? private var listenState: ListenStateMachine + private var preparedStatementState: PreparedStatementStateMachine init( configuration: PostgresConnection.InternalConfiguration, @@ -33,6 +34,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData) self.eventLoop = eventLoop self.listenState = ListenStateMachine() + self.preparedStatementState = PreparedStatementStateMachine() self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger @@ -51,6 +53,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.state = state self.eventLoop = eventLoop self.listenState = ListenStateMachine() + self.preparedStatementState = PreparedStatementStateMachine() self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger @@ -233,6 +236,56 @@ final class PostgresChannelHandler: ChannelDuplexHandler { listener.failed(CancellationError()) return } + case .executePreparedStatement(let preparedStatement): + switch self.preparedStatementState.lookup( + name: preparedStatement.name, + context: preparedStatement + ) { + case .prepareStatement: + let promise = self.eventLoop.makePromise(of: RowDescription?.self) + promise.futureResult.whenSuccess { rowDescription in + self.prepareStatementComplete( + name: preparedStatement.name, + rowDescription: rowDescription, + context: context + ) + } + promise.futureResult.whenFailure { error in + self.prepareStatementFailed( + name: preparedStatement.name, + error: error as! PSQLError, + context: context + ) + } + psqlTask = .extendedQuery(.init( + name: preparedStatement.name, + query: preparedStatement.sql, + logger: preparedStatement.logger, + promise: promise + )) + case .waitForAlreadyInFlightPreparation: + // The state machine already keeps track of this + // and will execute the statement as soon as it's prepared + return + case .executePendingStatements(let pendingStatements, let rowDescription): + for statement in pendingStatements { + let action = self.state.enqueue(task: .extendedQuery(.init( + executeStatement: .init( + name: statement.name, + binds: statement.bindings, + rowDescription: rowDescription), + logger: statement.logger, + promise: statement.promise + ))) + self.run(action, with: context) + } + return + case .returnError(let pendingStatements, let error): + for statement in pendingStatements { + statement.promise.fail(error) + } + return + } } let action = self.state.enqueue(task: psqlTask) @@ -664,6 +717,49 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } } + private func prepareStatementComplete( + name: String, + rowDescription: RowDescription?, + context: ChannelHandlerContext + ) { + let action = self.preparedStatementState.preparationComplete( + name: name, + rowDescription: rowDescription + ) + guard case .executePendingStatements(let statements, let rowDescription) = action else { + preconditionFailure("Expected to have pending statements to execute") + } + for preparedStatement in statements { + let action = self.state.enqueue(task: .extendedQuery(.init( + executeStatement: .init( + name: preparedStatement.name, + binds: preparedStatement.bindings, + rowDescription: rowDescription + ), + logger: preparedStatement.logger, + promise: preparedStatement.promise + )) + ) + self.run(action, with: context) + } + } + + private func prepareStatementFailed( + name: String, + error: PSQLError, + context: ChannelHandlerContext + ) { + let action = self.preparedStatementState.errorHappened( + name: name, + error: error + ) + guard case .returnError(let statements, let error) = action else { + preconditionFailure("Expected to have pending statements to execute") + } + for statement in statements { + statement.promise.fail(error) + } + } } extension PostgresChannelHandler: PSQLRowsDataSource { diff --git a/Sources/PostgresNIO/New/PreparedStatement.swift b/Sources/PostgresNIO/New/PreparedStatement.swift new file mode 100644 index 00000000..78ffdf4b --- /dev/null +++ b/Sources/PostgresNIO/New/PreparedStatement.swift @@ -0,0 +1,40 @@ +/// A prepared statement. +/// +/// Structs conforming to this protocol will need to provide the SQL statement to +/// send to the server and a way of creating bindings are decoding the result. +/// +/// As an example, consider this struct: +/// ```swift +/// struct Example: PreparedStatement { +/// static var sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" +/// typealias Row = (Int, String) +/// +/// var state: String +/// +/// func makeBindings() -> PostgresBindings { +/// var bindings = PostgresBindings() +/// bindings.append(.init(string: self.state)) +/// return bindings +/// } +/// +/// func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { +/// try row.decode(Row.self) +/// } +/// } +/// ``` +/// +/// Structs conforming to this protocol can then be used with `PostgresConnection.execute(_ preparedStatement:, logger:)`, +/// which will take care of preparing the statement on the server side and executing it. +public protocol PreparedStatement { + /// The type rows returned by the statement will be decoded into + associatedtype Row + + /// The SQL statement to prepare on the database server. + static var sql: String { get } + + /// Make the bindings to provided concrete values to use when executing the prepared SQL statement + func makeBindings() -> PostgresBindings + + /// Decode a row returned by the database into an instance of `Row` + func decodeRow(_ row: PostgresRow) throws -> Row +} diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index f68ef1f3..6b374e07 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -315,6 +315,48 @@ final class AsyncPostgresConnectionTests: XCTestCase { try await connection.query("SELECT 1;", logger: .psqlTest) } } + + func testPreparedStatement() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + struct TestPreparedStatement: PreparedStatement { + static var sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" + typealias Row = (Int, String) + + var state: String + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(.init(string: self.state)) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode(Row.self) + } + } + let preparedStatement = TestPreparedStatement(state: "active") + try await withTestConnection(on: eventLoop) { connection in + var results = try await connection.execute(preparedStatement, logger: .psqlTest) + var counter = 0 + + for try await element in results { + XCTAssertEqual(element.1, env("POSTGRES_DB") ?? "test_database") + counter += 1 + } + + XCTAssertGreaterThanOrEqual(counter, 1) + + // Second execution, which reuses the existing prepared statement + results = try await connection.execute(preparedStatement, logger: .psqlTest) + for try await element in results { + XCTAssertEqual(element.1, env("POSTGRES_DB") ?? "test_database") + counter += 1 + } + } + } } extension XCTestCase { From 8283d8b774ad5aea7a30873f917bf0decedd2d49 Mon Sep 17 00:00:00 2001 From: Mario Sangiorgio Date: Thu, 10 Aug 2023 15:37:44 +0200 Subject: [PATCH 02/13] Address PR feedbacks --- .../Connection/PostgresConnection.swift | 31 +++++++++++++--- .../PreparedStatementStateMachine.swift | 35 ++++++++++++------- .../New/PostgresChannelHandler.swift | 14 ++++++-- .../PostgresNIO/New/PreparedStatement.swift | 4 +-- Tests/IntegrationTests/AsyncTests.swift | 2 +- 5 files changed, 63 insertions(+), 23 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 79f558b1..2e19c0d7 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -462,15 +462,15 @@ extension PostgresConnection { } /// Execute a prepared statement, taking care of the preparation when necessary - public func execute( - _ preparedStatement: P, + public func execute( + _ preparedStatement: Statement, logger: Logger - ) async throws -> AsyncThrowingMapSequence + ) async throws -> AsyncThrowingMapSequence { let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let task = HandlerTask.executePreparedStatement(.init( - name: String(reflecting: P.self), - sql: P.sql, + name: String(reflecting: Statement.self), + sql: Statement.sql, bindings: preparedStatement.makeBindings(), logger: logger, promise: promise @@ -481,6 +481,27 @@ extension PostgresConnection { .get() .map { try preparedStatement.decodeRow($0) } } + + /// Execute a prepared statement, taking care of the preparation when necessary + public func execute( + _ preparedStatement: Statement, + logger: Logger + ) async throws -> String + where Statement.Row == () + { + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let task = HandlerTask.executePreparedStatement(.init( + name: String(reflecting: Statement.self), + sql: Statement.sql, + bindings: preparedStatement.makeBindings(), + logger: logger, + promise: promise + )) + self.channel.write(task, promise: nil) + return try await promise.futureResult + .map { $0.commandTag } + .get() + } } // MARK: EventLoopFuture interface diff --git a/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift index 62c36ebf..4898a9c1 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift @@ -6,21 +6,22 @@ struct PreparedStatementStateMachine { case prepared(RowDescription?) case error(PSQLError) } + + var preparedStatements: [String: State] - enum Action { + init() { + self.preparedStatements = [:] + } + + enum LookupAction { case prepareStatement case waitForAlreadyInFlightPreparation + case executeStatement(RowDescription?) case executePendingStatements([PreparedStatementContext], RowDescription?) case returnError([PreparedStatementContext], PSQLError) } - - var preparedStatements: [String: State] - - init() { - self.preparedStatements = [:] - } - - mutating func lookup(name: String, context: PreparedStatementContext) -> Action { + + mutating func lookup(name: String, context: PreparedStatementContext) -> LookupAction { if let state = self.preparedStatements[name] { switch state { case .preparing(var statements): @@ -28,7 +29,7 @@ struct PreparedStatementStateMachine { self.preparedStatements[name] = .preparing(statements) return .waitForAlreadyInFlightPreparation case .prepared(let rowDescription): - return .executePendingStatements([context], rowDescription) + return .executeStatement(rowDescription) case .error(let error): return .returnError([context], error) } @@ -37,11 +38,15 @@ struct PreparedStatementStateMachine { return .prepareStatement } } - + + enum PreparationCompleteAction { + case executePendingStatements([PreparedStatementContext], RowDescription?) + } + mutating func preparationComplete( name: String, rowDescription: RowDescription? - ) -> Action { + ) -> PreparationCompleteAction { guard case .preparing(let statements) = self.preparedStatements[name] else { preconditionFailure("Preparation completed for an unexpected statement") } @@ -58,7 +63,11 @@ struct PreparedStatementStateMachine { } } - mutating func errorHappened(name: String, error: PSQLError) -> Action { + enum ErrorHappenedAction { + case returnError([PreparedStatementContext], PSQLError) + } + + mutating func errorHappened(name: String, error: PSQLError) -> ErrorHappenedAction { guard case .preparing(let statements) = self.preparedStatements[name] else { preconditionFailure("Preparation completed for an unexpected statement") } diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index dbe4d2f9..aa6ad3ab 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -237,10 +237,11 @@ final class PostgresChannelHandler: ChannelDuplexHandler { return } case .executePreparedStatement(let preparedStatement): - switch self.preparedStatementState.lookup( + let action = self.preparedStatementState.lookup( name: preparedStatement.name, context: preparedStatement - ) { + ) + switch action { case .prepareStatement: let promise = self.eventLoop.makePromise(of: RowDescription?.self) promise.futureResult.whenSuccess { rowDescription in @@ -267,6 +268,15 @@ final class PostgresChannelHandler: ChannelDuplexHandler { // The state machine already keeps track of this // and will execute the statement as soon as it's prepared return + case .executeStatement(let rowDescription): + psqlTask = .extendedQuery(.init( + executeStatement: .init( + name: preparedStatement.name, + binds: preparedStatement.bindings, + rowDescription: rowDescription), + logger: preparedStatement.logger, + promise: preparedStatement.promise + )) case .executePendingStatements(let pendingStatements, let rowDescription): for statement in pendingStatements { let action = self.state.enqueue(task: .extendedQuery(.init( diff --git a/Sources/PostgresNIO/New/PreparedStatement.swift b/Sources/PostgresNIO/New/PreparedStatement.swift index 78ffdf4b..2d80b86d 100644 --- a/Sources/PostgresNIO/New/PreparedStatement.swift +++ b/Sources/PostgresNIO/New/PreparedStatement.swift @@ -5,7 +5,7 @@ /// /// As an example, consider this struct: /// ```swift -/// struct Example: PreparedStatement { +/// struct Example: PostgresPreparedStatement { /// static var sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" /// typealias Row = (Int, String) /// @@ -25,7 +25,7 @@ /// /// Structs conforming to this protocol can then be used with `PostgresConnection.execute(_ preparedStatement:, logger:)`, /// which will take care of preparing the statement on the server side and executing it. -public protocol PreparedStatement { +public protocol PostgresPreparedStatement: Sendable { /// The type rows returned by the statement will be decoded into associatedtype Row diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 6b374e07..e5de5bc9 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -321,7 +321,7 @@ final class AsyncPostgresConnectionTests: XCTestCase { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - struct TestPreparedStatement: PreparedStatement { + struct TestPreparedStatement: PostgresPreparedStatement { static var sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" typealias Row = (Int, String) From 95747250519c12981bfd7878890852c457de0730 Mon Sep 17 00:00:00 2001 From: Mario Sangiorgio Date: Fri, 11 Aug 2023 09:49:31 +0200 Subject: [PATCH 03/13] Apply suggestions from code review Co-authored-by: Fabian Fett --- .../Connection/PostgresConnection.swift | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 2e19c0d7..ee6352b1 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -464,9 +464,10 @@ extension PostgresConnection { /// Execute a prepared statement, taking care of the preparation when necessary public func execute( _ preparedStatement: Statement, - logger: Logger - ) async throws -> AsyncThrowingMapSequence - { + logger: Logger, + file: String = #fileID, + line: Int = #line + ) async throws -> AsyncThrowingMapSequence { let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let task = HandlerTask.executePreparedStatement(.init( name: String(reflecting: Statement.self), @@ -485,10 +486,10 @@ extension PostgresConnection { /// Execute a prepared statement, taking care of the preparation when necessary public func execute( _ preparedStatement: Statement, - logger: Logger - ) async throws -> String - where Statement.Row == () - { + logger: Logger, + file: String = #fileID, + line: Int = #line + ) async throws -> String where Statement.Row == () { let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let task = HandlerTask.executePreparedStatement(.init( name: String(reflecting: Statement.self), From b03c2512de882a8cb75ca402403bb95afd406bc2 Mon Sep 17 00:00:00 2001 From: Mario Sangiorgio Date: Fri, 11 Aug 2023 12:02:33 +0200 Subject: [PATCH 04/13] PR feedback and state machine tests --- .../Connection/PostgresConnection.swift | 42 ++++- .../PreparedStatementStateMachine.swift | 46 ++--- .../New/PostgresChannelHandler.swift | 128 +++++++------- .../PreparedStatementStateMachineTests.swift | 159 ++++++++++++++++++ 4 files changed, 284 insertions(+), 91 deletions(-) create mode 100644 Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index ee6352b1..2111ec31 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -468,19 +468,31 @@ extension PostgresConnection { file: String = #fileID, line: Int = #line ) async throws -> AsyncThrowingMapSequence { + let bindings = preparedStatement.makeBindings() let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let task = HandlerTask.executePreparedStatement(.init( name: String(reflecting: Statement.self), sql: Statement.sql, - bindings: preparedStatement.makeBindings(), + bindings: bindings, logger: logger, promise: promise )) self.channel.write(task, promise: nil) - return try await promise.futureResult - .map { $0.asyncSequence() } - .get() - .map { try preparedStatement.decodeRow($0) } + do { + return try await promise.futureResult + .map { $0.asyncSequence() } + .get() + .map { try preparedStatement.decodeRow($0) } + } catch var error as PSQLError { + error.file = file + error.line = line + error.query = .init( + unsafeSQL: Statement.sql, + binds: bindings + ) + throw error // rethrow with more metadata + } + } /// Execute a prepared statement, taking care of the preparation when necessary @@ -490,18 +502,30 @@ extension PostgresConnection { file: String = #fileID, line: Int = #line ) async throws -> String where Statement.Row == () { + let bindings = preparedStatement.makeBindings() let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let task = HandlerTask.executePreparedStatement(.init( name: String(reflecting: Statement.self), sql: Statement.sql, - bindings: preparedStatement.makeBindings(), + bindings: bindings, logger: logger, promise: promise )) self.channel.write(task, promise: nil) - return try await promise.futureResult - .map { $0.commandTag } - .get() + do { + return try await promise.futureResult + .map { $0.commandTag } + .get() + } catch var error as PSQLError { + error.file = file + error.line = line + error.query = .init( + unsafeSQL: Statement.sql, + binds: bindings + ) + throw error // rethrow with more metadata + } + } } diff --git a/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift index 4898a9c1..9068ebbb 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift @@ -7,40 +7,36 @@ struct PreparedStatementStateMachine { case error(PSQLError) } - var preparedStatements: [String: State] + var preparedStatements: [String: State] = [:] - init() { - self.preparedStatements = [:] - } - enum LookupAction { case prepareStatement case waitForAlreadyInFlightPreparation case executeStatement(RowDescription?) - case executePendingStatements([PreparedStatementContext], RowDescription?) - case returnError([PreparedStatementContext], PSQLError) + case returnError(PSQLError) } - mutating func lookup(name: String, context: PreparedStatementContext) -> LookupAction { - if let state = self.preparedStatements[name] { + mutating func lookup(preparedStatement: PreparedStatementContext) -> LookupAction { + if let state = self.preparedStatements[preparedStatement.name] { switch state { case .preparing(var statements): - statements.append(context) - self.preparedStatements[name] = .preparing(statements) + statements.append(preparedStatement) + self.preparedStatements[preparedStatement.name] = .preparing(statements) return .waitForAlreadyInFlightPreparation case .prepared(let rowDescription): return .executeStatement(rowDescription) case .error(let error): - return .returnError([context], error) + return .returnError(error) } } else { - self.preparedStatements[name] = .preparing([context]) + self.preparedStatements[preparedStatement.name] = .preparing([preparedStatement]) return .prepareStatement } } - enum PreparationCompleteAction { - case executePendingStatements([PreparedStatementContext], RowDescription?) + struct PreparationCompleteAction { + var statements: [PreparedStatementContext] + var rowDescription: RowDescription? } mutating func preparationComplete( @@ -56,15 +52,22 @@ struct PreparedStatementStateMachine { rowDescription.columns[i].format = .binary } self.preparedStatements[name] = .prepared(rowDescription) - return .executePendingStatements(statements, rowDescription) + return PreparationCompleteAction( + statements: statements, + rowDescription: rowDescription + ) } else { self.preparedStatements[name] = .prepared(nil) - return .executePendingStatements(statements, nil) + return PreparationCompleteAction( + statements: statements, + rowDescription: nil + ) } } - enum ErrorHappenedAction { - case returnError([PreparedStatementContext], PSQLError) + struct ErrorHappenedAction { + var statements: [PreparedStatementContext] + var error: PSQLError } mutating func errorHappened(name: String, error: PSQLError) -> ErrorHappenedAction { @@ -72,6 +75,9 @@ struct PreparedStatementStateMachine { preconditionFailure("Preparation completed for an unexpected statement") } self.preparedStatements[name] = .error(error) - return .returnError(statements, error) + return ErrorHappenedAction( + statements: statements, + error: error + ) } } diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index aa6ad3ab..c7a219c0 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -22,8 +22,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private let configuration: PostgresConnection.InternalConfiguration private let configureSSLCallback: ((Channel) throws -> Void)? - private var listenState: ListenStateMachine - private var preparedStatementState: PreparedStatementStateMachine + private var listenState = ListenStateMachine() + private var preparedStatementState = PreparedStatementStateMachine() init( configuration: PostgresConnection.InternalConfiguration, @@ -33,8 +33,6 @@ final class PostgresChannelHandler: ChannelDuplexHandler { ) { self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData) self.eventLoop = eventLoop - self.listenState = ListenStateMachine() - self.preparedStatementState = PreparedStatementStateMachine() self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger @@ -238,62 +236,25 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } case .executePreparedStatement(let preparedStatement): let action = self.preparedStatementState.lookup( - name: preparedStatement.name, - context: preparedStatement + preparedStatement: preparedStatement ) switch action { case .prepareStatement: - let promise = self.eventLoop.makePromise(of: RowDescription?.self) - promise.futureResult.whenSuccess { rowDescription in - self.prepareStatementComplete( - name: preparedStatement.name, - rowDescription: rowDescription, - context: context - ) - } - promise.futureResult.whenFailure { error in - self.prepareStatementFailed( - name: preparedStatement.name, - error: error as! PSQLError, - context: context - ) - } - psqlTask = .extendedQuery(.init( - name: preparedStatement.name, - query: preparedStatement.sql, - logger: preparedStatement.logger, - promise: promise - )) + psqlTask = self.makePrepareStatementAction( + preparedStatement: preparedStatement, + context: context + ) case .waitForAlreadyInFlightPreparation: // The state machine already keeps track of this // and will execute the statement as soon as it's prepared return case .executeStatement(let rowDescription): - psqlTask = .extendedQuery(.init( - executeStatement: .init( - name: preparedStatement.name, - binds: preparedStatement.bindings, - rowDescription: rowDescription), - logger: preparedStatement.logger, - promise: preparedStatement.promise - )) - case .executePendingStatements(let pendingStatements, let rowDescription): - for statement in pendingStatements { - let action = self.state.enqueue(task: .extendedQuery(.init( - executeStatement: .init( - name: statement.name, - binds: statement.bindings, - rowDescription: rowDescription), - logger: statement.logger, - promise: statement.promise - ))) - self.run(action, with: context) - } - return - case .returnError(let pendingStatements, let error): - for statement in pendingStatements { - statement.promise.fail(error) - } + psqlTask = self.makeExecutPreparedStatementAction( + preparedStatement: preparedStatement, + rowDescription: rowDescription + ) + case .returnError(let error): + preparedStatement.promise.fail(error) return } } @@ -727,6 +688,55 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } } + private func makePrepareStatementAction( + preparedStatement: PreparedStatementContext, + context: ChannelHandlerContext + ) -> PSQLTask { + let promise = self.eventLoop.makePromise(of: RowDescription?.self) + promise.futureResult.whenComplete { result in + switch result { + case .success(let rowDescription): + self.prepareStatementComplete( + name: preparedStatement.name, + rowDescription: rowDescription, + context: context + ) + case .failure(let error): + let psqlError: PSQLError + if let error = error as? PSQLError { + psqlError = error + } else { + psqlError = .connectionError(underlying: error) + } + self.prepareStatementFailed( + name: preparedStatement.name, + error: psqlError, + context: context + ) + } + } + return .extendedQuery(.init( + name: preparedStatement.name, + query: preparedStatement.sql, + logger: preparedStatement.logger, + promise: promise + )) + } + + private func makeExecutPreparedStatementAction( + preparedStatement: PreparedStatementContext, + rowDescription: RowDescription? + ) -> PSQLTask { + return .extendedQuery(.init( + executeStatement: .init( + name: preparedStatement.name, + binds: preparedStatement.bindings, + rowDescription: rowDescription), + logger: preparedStatement.logger, + promise: preparedStatement.promise + )) + } + private func prepareStatementComplete( name: String, rowDescription: RowDescription?, @@ -736,15 +746,12 @@ final class PostgresChannelHandler: ChannelDuplexHandler { name: name, rowDescription: rowDescription ) - guard case .executePendingStatements(let statements, let rowDescription) = action else { - preconditionFailure("Expected to have pending statements to execute") - } - for preparedStatement in statements { + for preparedStatement in action.statements { let action = self.state.enqueue(task: .extendedQuery(.init( executeStatement: .init( name: preparedStatement.name, binds: preparedStatement.bindings, - rowDescription: rowDescription + rowDescription: action.rowDescription ), logger: preparedStatement.logger, promise: preparedStatement.promise @@ -763,11 +770,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { name: name, error: error ) - guard case .returnError(let statements, let error) = action else { - preconditionFailure("Expected to have pending statements to execute") - } - for statement in statements { - statement.promise.fail(error) + for statement in action.statements { + statement.promise.fail(action.error) } } } diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift new file mode 100644 index 00000000..ab77a57c --- /dev/null +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift @@ -0,0 +1,159 @@ +import XCTest +import NIOEmbedded +@testable import PostgresNIO + +class PreparedStatementStateMachineTests: XCTestCase { + func testPrepareAndExecuteStatement() { + let eventLoop = EmbeddedEventLoop() + var stateMachine = PreparedStatementStateMachine() + + let firstPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Initial lookup, the statement hasn't been prepared yet + let lookupAction = stateMachine.lookup(preparedStatement: firstPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .prepareStatement = lookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // Once preparation is complete we transition to a prepared state + let preparationCompleteAction = stateMachine.preparationComplete(name: "test", rowDescription: nil) + guard case .prepared(nil) = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + XCTAssertEqual(preparationCompleteAction.statements.count, 1) + XCTAssertNil(preparationCompleteAction.rowDescription) + firstPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success("tag")), + eventLoop: eventLoop, + logger: .psqlTest + )) + + // Create a new prepared statement + let secondPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // The statement is already preparead, lookups tell us to execute it + let secondLookupAction = stateMachine.lookup(preparedStatement: secondPreparedStatement) + guard case .prepared(nil) = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .executeStatement(nil) = secondLookupAction else { + XCTFail("State machine returned the wrong action") + return + } + secondPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success("tag")), + eventLoop: eventLoop, + logger: .psqlTest + )) + } + + func testPrepareAndExecuteStatementWithError() { + let eventLoop = EmbeddedEventLoop() + var stateMachine = PreparedStatementStateMachine() + + let firstPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Initial lookup, the statement hasn't been prepared yet + let lookupAction = stateMachine.lookup(preparedStatement: firstPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .prepareStatement = lookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // Simulate an error occurring during preparation + let error = PSQLError(code: .server) + let preparationCompleteAction = stateMachine.errorHappened( + name: "test", + error: error + ) + guard case .error = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + XCTAssertEqual(preparationCompleteAction.statements.count, 1) + firstPreparedStatement.promise.fail(error) + + // Create a new prepared statement + let secondPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Ensure that we don't try again to prepare a statement we know will fail + let secondLookupAction = stateMachine.lookup(preparedStatement: secondPreparedStatement) + guard case .error = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .returnError = secondLookupAction else { + XCTFail("State machine returned the wrong action") + return + } + secondPreparedStatement.promise.fail(error) + } + + func testBatchStatementPreparation() { + let eventLoop = EmbeddedEventLoop() + var stateMachine = PreparedStatementStateMachine() + + let firstPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Initial lookup, the statement hasn't been prepared yet + let lookupAction = stateMachine.lookup(preparedStatement: firstPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .prepareStatement = lookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // A new request comes in before the statement completes + let secondPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + let secondLookupAction = stateMachine.lookup(preparedStatement: secondPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .waitForAlreadyInFlightPreparation = secondLookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // Once preparation is complete we transition to a prepared state. + // The action tells us to execute both the pending statements. + let preparationCompleteAction = stateMachine.preparationComplete(name: "test", rowDescription: nil) + guard case .prepared(nil) = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + XCTAssertEqual(preparationCompleteAction.statements.count, 2) + XCTAssertNil(preparationCompleteAction.rowDescription) + + firstPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success("tag")), + eventLoop: eventLoop, + logger: .psqlTest + )) + secondPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success("tag")), + eventLoop: eventLoop, + logger: .psqlTest + )) + } + + private func makePreparedStatementContext(eventLoop: EmbeddedEventLoop) -> PreparedStatementContext { + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + return PreparedStatementContext( + name: "test", + sql: "INSERT INTO test_table (column1) VALUES (1)", + bindings: PostgresBindings(), + logger: .psqlTest, + promise: promise + ) + } +} From 800ceaf0e0657506900abd8d86e379ce36029095 Mon Sep 17 00:00:00 2001 From: Mario Sangiorgio Date: Fri, 11 Aug 2023 17:09:18 +0200 Subject: [PATCH 05/13] Add test to `PostgresConnectionTests` to test prepared statements --- .../PSQLFrontendMessageDecoder.swift | 2 +- .../New/PostgresConnectionTests.swift | 91 +++++++++++++++++++ 2 files changed, 92 insertions(+), 1 deletion(-) diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index b9677000..46c043b1 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -142,7 +142,7 @@ extension PostgresFrontendMessage { } let parameters = (0.. ByteBuffer? in - let length = buffer.readInteger(as: UInt16.self) + let length = buffer.readInteger(as: UInt32.self) switch length { case .some(..<0): return nil diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 46f864ce..500fee5b 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -275,6 +275,97 @@ class PostgresConnectionTests: XCTestCase { } } + struct TestPrepareStatement: PostgresPreparedStatement { + static var sql = "SELECT datname FROM pg_stat_activity WHERE state = $1" + typealias Row = String + + var state: String + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(.init(string: self.state)) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode(Row.self) + } + } + + func testPreparedStatement() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database", database) + } + XCTAssertEqual(rows, 1) + } + // Wait for the PREPARE request from the client + guard case .parse(let parse) = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else { + fatalError("Unexpected message") + } + XCTAssertEqual(parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(parse.parameters.count, 0) + guard case .describe(.preparedStatement(let name)) = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else { + fatalError("Unexpected message") + } + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + guard case .sync = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else { + fatalError("Unexpected message") + } + + // Respond to the PREPARE request + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: [ + PostgresDataType.text + ]))) + try await channel.testingEventLoop.executeInContext { channel.read() } + let rowDescription = RowDescription(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + try await channel.writeInbound(PostgresBackendMessage.rowDescription(rowDescription)) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await channel.testingEventLoop.executeInContext { channel.read() } + + // Wait for the EXECUTE request + guard case .bind = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else { + fatalError("Unexpected message") + } + guard case .execute = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else { + fatalError("Unexpected message") + } + guard case .sync = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else { + fatalError("Unexpected message") + } + // Respond to the EXECUTE request + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.testingEventLoop.executeInContext { channel.read() } + let dataRow = DataRow(arrayLiteral: "test_database") + try await channel.writeInbound(PostgresBackendMessage.dataRow(dataRow)) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.commandComplete(TestPrepareStatement.sql)) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await channel.testingEventLoop.executeInContext { channel.read() } + } + } + func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() let channel = await NIOAsyncTestingChannel(handlers: [ From a426205dca70ac2645085094fe553a2bb76e7f56 Mon Sep 17 00:00:00 2001 From: Mario Sangiorgio Date: Wed, 16 Aug 2023 13:35:53 +0100 Subject: [PATCH 06/13] Explicit switch over all the state machine states --- .../PreparedStatementStateMachine.swift | 58 +++++++++++-------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift index 9068ebbb..0a1ba7b4 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift @@ -43,25 +43,30 @@ struct PreparedStatementStateMachine { name: String, rowDescription: RowDescription? ) -> PreparationCompleteAction { - guard case .preparing(let statements) = self.preparedStatements[name] else { - preconditionFailure("Preparation completed for an unexpected statement") + guard let state = self.preparedStatements[name] else { + fatalError("Unknown prepared statement \(name)") } - // When sending the bindings we are going to ask for binary data. - if var rowDescription { - for i in 0.. ErrorHappenedAction { - guard case .preparing(let statements) = self.preparedStatements[name] else { - preconditionFailure("Preparation completed for an unexpected statement") + guard let state = self.preparedStatements[name] else { + fatalError("Unknown prepared statement \(name)") + } + switch state { + case .preparing(let statements): + self.preparedStatements[name] = .error(error) + return ErrorHappenedAction( + statements: statements, + error: error + ) + case .prepared, .error: + preconditionFailure("Error happened in an unexpected state \(state)") } - self.preparedStatements[name] = .error(error) - return ErrorHappenedAction( - statements: statements, - error: error - ) } } From abe1cd15d4b100320a3a48b5d38c0e420e8664b7 Mon Sep 17 00:00:00 2001 From: Mario Sangiorgio Date: Wed, 16 Aug 2023 14:33:29 +0100 Subject: [PATCH 07/13] PostgresBindings.append that use the default context --- .../PostgresNIO/Connection/PostgresConnection.swift | 4 ++-- Sources/PostgresNIO/New/PostgresQuery.swift | 10 ++++++++++ Sources/PostgresNIO/New/PreparedStatement.swift | 4 ++-- Tests/IntegrationTests/AsyncTests.swift | 2 +- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 2111ec31..ebf7ac52 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -468,7 +468,7 @@ extension PostgresConnection { file: String = #fileID, line: Int = #line ) async throws -> AsyncThrowingMapSequence { - let bindings = preparedStatement.makeBindings() + let bindings = try preparedStatement.makeBindings() let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let task = HandlerTask.executePreparedStatement(.init( name: String(reflecting: Statement.self), @@ -502,7 +502,7 @@ extension PostgresConnection { file: String = #fileID, line: Int = #line ) async throws -> String where Statement.Row == () { - let bindings = preparedStatement.makeBindings() + let bindings = try preparedStatement.makeBindings() let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let task = HandlerTask.executePreparedStatement(.init( name: String(reflecting: Statement.self), diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 2e06e1d9..4ca1e454 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -167,6 +167,11 @@ public struct PostgresBindings: Sendable, Hashable { self.metadata.append(.init(dataType: .null, format: .binary, protected: true)) } + @inlinable + public mutating func append(_ value: Value) throws { + try self.append(value, context: .default) + } + @inlinable public mutating func append( _ value: Value, @@ -176,6 +181,11 @@ public struct PostgresBindings: Sendable, Hashable { self.metadata.append(.init(value: value, protected: true)) } + @inlinable + public mutating func append(_ value: Value) { + self.append(value, context: .default) + } + @inlinable public mutating func append( _ value: Value, diff --git a/Sources/PostgresNIO/New/PreparedStatement.swift b/Sources/PostgresNIO/New/PreparedStatement.swift index 2d80b86d..28616921 100644 --- a/Sources/PostgresNIO/New/PreparedStatement.swift +++ b/Sources/PostgresNIO/New/PreparedStatement.swift @@ -13,7 +13,7 @@ /// /// func makeBindings() -> PostgresBindings { /// var bindings = PostgresBindings() -/// bindings.append(.init(string: self.state)) +/// bindings.append(self.state) /// return bindings /// } /// @@ -33,7 +33,7 @@ public protocol PostgresPreparedStatement: Sendable { static var sql: String { get } /// Make the bindings to provided concrete values to use when executing the prepared SQL statement - func makeBindings() -> PostgresBindings + func makeBindings() throws -> PostgresBindings /// Decode a row returned by the database into an instance of `Row` func decodeRow(_ row: PostgresRow) throws -> Row diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index e5de5bc9..bf945a67 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -329,7 +329,7 @@ final class AsyncPostgresConnectionTests: XCTestCase { func makeBindings() -> PostgresBindings { var bindings = PostgresBindings() - bindings.append(.init(string: self.state)) + bindings.append(self.state) return bindings } From ff0e139eb9c11628d1c90dc8fbe032033d886b83 Mon Sep 17 00:00:00 2001 From: Mario Sangiorgio Date: Wed, 16 Aug 2023 14:40:42 +0100 Subject: [PATCH 08/13] More assertions in `testPreparedStatement` --- Tests/PostgresNIOTests/New/PostgresConnectionTests.swift | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 500fee5b..95e61a9e 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -344,9 +344,12 @@ class PostgresConnectionTests: XCTestCase { try await channel.testingEventLoop.executeInContext { channel.read() } // Wait for the EXECUTE request - guard case .bind = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else { + guard case .bind(let bind) = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else { fatalError("Unexpected message") } + XCTAssertEqual(bind.preparedStatementName, String(reflecting: TestPrepareStatement.self)) + XCTAssertEqual(bind.parameters.count, 1) + XCTAssertEqual(bind.resultColumnFormats, [.binary]) guard case .execute = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else { fatalError("Unexpected message") } From 725e778c034ffdfd3b8c06c5243ec17ef4b5af3f Mon Sep 17 00:00:00 2001 From: Mario Sangiorgio Date: Wed, 16 Aug 2023 16:19:43 +0100 Subject: [PATCH 09/13] Fix a build issue with Swift 5.6 --- .../PreparedStatementStateMachine.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift index 0a1ba7b4..5afa4d0b 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift @@ -49,7 +49,7 @@ struct PreparedStatementStateMachine { switch state { case .preparing(let statements): // When sending the bindings we are going to ask for binary data. - if var rowDescription { + if var rowDescription = rowDescription { for i in 0.. Date: Thu, 17 Aug 2023 11:34:43 +0100 Subject: [PATCH 10/13] Apply suggestions from code review Co-authored-by: Fabian Fett --- .../PostgresNIO/Connection/PostgresConnection.swift | 7 ++++--- .../PostgresNIO/New/PostgresChannelHandler.swift | 13 ++++++------- Sources/PostgresNIO/New/PreparedStatement.swift | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index ebf7ac52..d758c808 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -462,12 +462,13 @@ extension PostgresConnection { } /// Execute a prepared statement, taking care of the preparation when necessary - public func execute( + @inlinable + public func execute( _ preparedStatement: Statement, logger: Logger, file: String = #fileID, line: Int = #line - ) async throws -> AsyncThrowingMapSequence { + ) async throws -> AsyncThrowingMapSequence where Row == Statement.Row { let bindings = try preparedStatement.makeBindings() let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let task = HandlerTask.executePreparedStatement(.init( @@ -496,6 +497,7 @@ extension PostgresConnection { } /// Execute a prepared statement, taking care of the preparation when necessary + @inlinable public func execute( _ preparedStatement: Statement, logger: Logger, @@ -525,7 +527,6 @@ extension PostgresConnection { ) throw error // rethrow with more metadata } - } } diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index c7a219c0..4e63c864 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -50,8 +50,6 @@ final class PostgresChannelHandler: ChannelDuplexHandler { ) { self.state = state self.eventLoop = eventLoop - self.listenState = ListenStateMachine() - self.preparedStatementState = PreparedStatementStateMachine() self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger @@ -240,7 +238,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { ) switch action { case .prepareStatement: - psqlTask = self.makePrepareStatementAction( + psqlTask = self.makePrepareStatementTask( preparedStatement: preparedStatement, context: context ) @@ -249,7 +247,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { // and will execute the statement as soon as it's prepared return case .executeStatement(let rowDescription): - psqlTask = self.makeExecutPreparedStatementAction( + psqlTask = self. makeExecutePreparedStatementTask( preparedStatement: preparedStatement, rowDescription: rowDescription ) @@ -688,7 +686,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } } - private func makePrepareStatementAction( + private func makePrepareStatementTask( preparedStatement: PreparedStatementContext, context: ChannelHandlerContext ) -> PSQLTask { @@ -723,7 +721,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { )) } - private func makeExecutPreparedStatementAction( + private func makeExecutePreparedStatementTask( preparedStatement: PreparedStatementContext, rowDescription: RowDescription? ) -> PSQLTask { @@ -731,7 +729,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { executeStatement: .init( name: preparedStatement.name, binds: preparedStatement.bindings, - rowDescription: rowDescription), + rowDescription: rowDescription + ), logger: preparedStatement.logger, promise: preparedStatement.promise )) diff --git a/Sources/PostgresNIO/New/PreparedStatement.swift b/Sources/PostgresNIO/New/PreparedStatement.swift index 28616921..1e0b5d5a 100644 --- a/Sources/PostgresNIO/New/PreparedStatement.swift +++ b/Sources/PostgresNIO/New/PreparedStatement.swift @@ -6,7 +6,7 @@ /// As an example, consider this struct: /// ```swift /// struct Example: PostgresPreparedStatement { -/// static var sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" +/// static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" /// typealias Row = (Int, String) /// /// var state: String From 0740293b832fbeeb1bfa1b363e057588d70c493d Mon Sep 17 00:00:00 2001 From: Mario Sangiorgio Date: Thu, 17 Aug 2023 14:53:28 +0100 Subject: [PATCH 11/13] Fix build issues --- Sources/PostgresNIO/Connection/PostgresConnection.swift | 2 -- Sources/PostgresNIO/New/PostgresChannelHandler.swift | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index d758c808..d3f51ca9 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -462,7 +462,6 @@ extension PostgresConnection { } /// Execute a prepared statement, taking care of the preparation when necessary - @inlinable public func execute( _ preparedStatement: Statement, logger: Logger, @@ -497,7 +496,6 @@ extension PostgresConnection { } /// Execute a prepared statement, taking care of the preparation when necessary - @inlinable public func execute( _ preparedStatement: Statement, logger: Logger, diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 4e63c864..bf56d6d1 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -247,7 +247,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { // and will execute the statement as soon as it's prepared return case .executeStatement(let rowDescription): - psqlTask = self. makeExecutePreparedStatementTask( + psqlTask = self.makeExecutePreparedStatementTask( preparedStatement: preparedStatement, rowDescription: rowDescription ) From 660e78d7843cd10ee8c443106a1c2eb75e40a10a Mon Sep 17 00:00:00 2001 From: Mario Sangiorgio Date: Thu, 17 Aug 2023 14:53:55 +0100 Subject: [PATCH 12/13] Add more `PostgresConnectionTests` cases --- .../New/PostgresConnectionTests.swift | 354 +++++++++++++++--- 1 file changed, 306 insertions(+), 48 deletions(-) diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 95e61a9e..10c2b1de 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -306,66 +306,254 @@ class PostgresConnectionTests: XCTestCase { } XCTAssertEqual(rows, 1) } - // Wait for the PREPARE request from the client - guard case .parse(let parse) = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else { - fatalError("Unexpected message") + + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") } - XCTAssertEqual(parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(parse.parameters.count, 0) - guard case .describe(.preparedStatement(let name)) = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else { - fatalError("Unexpected message") + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) + + let preparedRequest = try await channel.waitForPreparedRequest() + XCTAssertEqual(preparedRequest.bind.preparedStatementName, String(reflecting: TestPrepareStatement.self)) + XCTAssertEqual(preparedRequest.bind.parameters.count, 1) + XCTAssertEqual(preparedRequest.bind.resultColumnFormats, [.binary]) + + try await channel.sendPreparedResponse( + dataRows: [ + DataRow(arrayLiteral: "test_database") + ], + commandTag: TestPrepareStatement.sql + ) + } + } + + func testSerialExecutionOfSamePreparedStatement() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Send one first and wait to send the other request until preparation is complete + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database", database) + } + XCTAssertEqual(rows, 1) + } + + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") } XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) - guard case .sync = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else { - fatalError("Unexpected message") + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) + + let preparedRequest1 = try await channel.waitForPreparedRequest() + var buffer = preparedRequest1.bind.parameters[0]! + let parameter1 = buffer.readString(length: buffer.readableBytes)! + XCTAssertEqual(parameter1, "active") + try await channel.sendPreparedResponse( + dataRows: [ + DataRow(arrayLiteral: "test_database") + ], + commandTag: TestPrepareStatement.sql + ) + + // Now that the statement has been prepared and executed, send another request that will only get executed + // without preparation + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database", database) + } + XCTAssertEqual(rows, 1) } - // Respond to the PREPARE request - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.testingEventLoop.executeInContext { channel.read() } - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: [ - PostgresDataType.text - ]))) - try await channel.testingEventLoop.executeInContext { channel.read() } - let rowDescription = RowDescription(columns: [ - .init( - name: "datname", - tableOID: 12222, - columnAttributeNumber: 2, - dataType: .name, - dataTypeSize: 64, - dataTypeModifier: -1, - format: .text - ) - ]) - try await channel.writeInbound(PostgresBackendMessage.rowDescription(rowDescription)) - try await channel.testingEventLoop.executeInContext { channel.read() } - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - try await channel.testingEventLoop.executeInContext { channel.read() } + let preparedRequest2 = try await channel.waitForPreparedRequest() + buffer = preparedRequest2.bind.parameters[0]! + let parameter2 = buffer.readString(length: buffer.readableBytes)! + XCTAssertEqual(parameter2, "idle") + try await channel.sendPreparedResponse( + dataRows: [ + DataRow(arrayLiteral: "test_database") + ], + commandTag: TestPrepareStatement.sql + ) + // Ensure we received and responded to both the requests + let parameters = [parameter1, parameter2] + XCTAssert(parameters.contains("active")) + XCTAssert(parameters.contains("idle")) + } + } - // Wait for the EXECUTE request - guard case .bind(let bind) = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else { - fatalError("Unexpected message") + func testStatementPreparationOnlyHappensOnceWithConcurrentRequests() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Let them race to tests that requests and responses aren't mixed up + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database_active", database) + } + XCTAssertEqual(rows, 1) } - XCTAssertEqual(bind.preparedStatementName, String(reflecting: TestPrepareStatement.self)) - XCTAssertEqual(bind.parameters.count, 1) - XCTAssertEqual(bind.resultColumnFormats, [.binary]) - guard case .execute = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else { - fatalError("Unexpected message") + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database_idle", database) + } + XCTAssertEqual(rows, 1) } - guard case .sync = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else { - fatalError("Unexpected message") + + // The channel deduplicates prepare requests, we're going to see only one of them + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") } - // Respond to the EXECUTE request - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.testingEventLoop.executeInContext { channel.read() } - let dataRow = DataRow(arrayLiteral: "test_database") - try await channel.writeInbound(PostgresBackendMessage.dataRow(dataRow)) - try await channel.testingEventLoop.executeInContext { channel.read() } - try await channel.writeInbound(PostgresBackendMessage.commandComplete(TestPrepareStatement.sql)) + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) + + // Now both the tasks have their statements prepared. + // We should see both of their execute requests coming in, the order is nondeterministic + let preparedRequest1 = try await channel.waitForPreparedRequest() + var buffer = preparedRequest1.bind.parameters[0]! + let parameter1 = buffer.readString(length: buffer.readableBytes)! + try await channel.sendPreparedResponse( + dataRows: [ + DataRow(arrayLiteral: "test_database_\(parameter1)") + ], + commandTag: TestPrepareStatement.sql + ) + let preparedRequest2 = try await channel.waitForPreparedRequest() + buffer = preparedRequest2.bind.parameters[0]! + let parameter2 = buffer.readString(length: buffer.readableBytes)! + try await channel.sendPreparedResponse( + dataRows: [ + DataRow(arrayLiteral: "test_database_\(parameter2)") + ], + commandTag: TestPrepareStatement.sql + ) + // Ensure we received and responded to both the requests + let parameters = [parameter1, parameter2] + XCTAssert(parameters.contains("active")) + XCTAssert(parameters.contains("idle")) + } + } + + func testStatementPreparationFailure() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Send one first and wait to send the other request until preparation is complete + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + do { + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + XCTFail("Was supposed to fail") + } catch { + XCTAssert(error is PSQLError) + } + } + + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + + // Respond with an error taking care to return a SQLSTATE that isn't + // going to get the connection closed. + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .sqlState : "26000" // invalid_sql_statement_name + ]))) try await channel.testingEventLoop.executeInContext { channel.read() } try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) try await channel.testingEventLoop.executeInContext { channel.read() } + + + // Send another requests with the same prepared statement, which should fail straight + // away without any interaction with the server + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + do { + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + XCTFail("Was supposed to fail") + } catch { + XCTAssert(error is PSQLError) + } + } } } @@ -421,6 +609,66 @@ extension NIOAsyncTestingChannel { return UnpreparedRequest(parse: parse, describe: describe, bind: bind, execute: execute) } + + func waitForPrepareRequest() async throws -> PrepareRequest { + let parse = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let describe = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let sync = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + + guard case .parse(let parse) = parse, + case .describe(let describe) = describe, + case .sync = sync + else { + fatalError("Unexpected message") + } + + return PrepareRequest(parse: parse, describe: describe) + } + + func sendPrepareResponse( + parameterDescription: PostgresBackendMessage.ParameterDescription, + rowDescription: RowDescription + ) async throws { + try await self.writeInbound(PostgresBackendMessage.parseComplete) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.parameterDescription(parameterDescription)) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.rowDescription(rowDescription)) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await self.testingEventLoop.executeInContext { self.read() } + } + + func waitForPreparedRequest() async throws -> PreparedRequest { + let bind = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let execute = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let sync = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + + guard case .bind(let bind) = bind, + case .execute(let execute) = execute, + case .sync = sync + else { + fatalError() + } + + return PreparedRequest(bind: bind, execute: execute) + } + + func sendPreparedResponse( + dataRows: [DataRow], + commandTag: String + ) async throws { + try await self.writeInbound(PostgresBackendMessage.bindComplete) + try await self.testingEventLoop.executeInContext { self.read() } + for dataRow in dataRows { + try await self.writeInbound(PostgresBackendMessage.dataRow(dataRow)) + } + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.commandComplete(commandTag)) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await self.testingEventLoop.executeInContext { self.read() } + } } struct UnpreparedRequest { @@ -429,3 +677,13 @@ struct UnpreparedRequest { var bind: PostgresFrontendMessage.Bind var execute: PostgresFrontendMessage.Execute } + +struct PrepareRequest { + var parse: PostgresFrontendMessage.Parse + var describe: PostgresFrontendMessage.Describe +} + +struct PreparedRequest { + var bind: PostgresFrontendMessage.Bind + var execute: PostgresFrontendMessage.Execute +} From da87ce0a7d21647413fbffc64d08e768cda7b6d2 Mon Sep 17 00:00:00 2001 From: Mario Sangiorgio Date: Fri, 18 Aug 2023 09:39:49 +0100 Subject: [PATCH 13/13] Apply suggestions from code review Co-authored-by: Fabian Fett --- .../PostgresNIOTests/New/PostgresConnectionTests.swift | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 10c2b1de..9c4dc5cb 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -339,7 +339,7 @@ class PostgresConnectionTests: XCTestCase { try await channel.sendPreparedResponse( dataRows: [ - DataRow(arrayLiteral: "test_database") + ["test_database"] ], commandTag: TestPrepareStatement.sql ) @@ -394,7 +394,7 @@ class PostgresConnectionTests: XCTestCase { XCTAssertEqual(parameter1, "active") try await channel.sendPreparedResponse( dataRows: [ - DataRow(arrayLiteral: "test_database") + ["test_database"] ], commandTag: TestPrepareStatement.sql ) @@ -418,7 +418,7 @@ class PostgresConnectionTests: XCTestCase { XCTAssertEqual(parameter2, "idle") try await channel.sendPreparedResponse( dataRows: [ - DataRow(arrayLiteral: "test_database") + ["test_database"] ], commandTag: TestPrepareStatement.sql ) @@ -489,7 +489,7 @@ class PostgresConnectionTests: XCTestCase { let parameter1 = buffer.readString(length: buffer.readableBytes)! try await channel.sendPreparedResponse( dataRows: [ - DataRow(arrayLiteral: "test_database_\(parameter1)") + ["test_database_\(parameter1)"] ], commandTag: TestPrepareStatement.sql ) @@ -498,7 +498,7 @@ class PostgresConnectionTests: XCTestCase { let parameter2 = buffer.readString(length: buffer.readableBytes)! try await channel.sendPreparedResponse( dataRows: [ - DataRow(arrayLiteral: "test_database_\(parameter2)") + ["test_database_\(parameter2)"] ], commandTag: TestPrepareStatement.sql )