Skip to content

Commit 20a0f2a

Browse files
authored
Add message types to support COPY operations (#569)
This adds the infrastrucutre to decode messages needed for COPY operations. It does not implement the handling support itself yet. That will be added in a follow-up PR.
1 parent d50aade commit 20a0f2a

12 files changed

+305
-5
lines changed

Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,12 @@ struct ConnectionStateMachine {
752752
return self.modify(with: action)
753753
}
754754

755+
mutating func copyInResponseReceived(
756+
_ copyInResponse: PostgresBackendMessage.CopyInResponse
757+
) -> ConnectionAction {
758+
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse)))
759+
}
760+
755761
mutating func emptyQueryResponseReceived() -> ConnectionAction {
756762
guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else {
757763
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.emptyQueryResponse))

Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ struct ExtendedQueryStateMachine {
9191
mutating func cancel() -> Action {
9292
switch self.state {
9393
case .initialized:
94-
preconditionFailure("Start must be called immediatly after the query was created")
94+
preconditionFailure("Start must be called immediately after the query was created")
9595

9696
case .messagesSent(let queryContext),
9797
.parseCompleteReceived(let queryContext),
@@ -322,6 +322,12 @@ struct ExtendedQueryStateMachine {
322322
}
323323
}
324324

325+
mutating func copyInResponseReceived(
326+
_ copyInResponse: PostgresBackendMessage.CopyInResponse
327+
) -> Action {
328+
return self.setAndFireError(.unexpectedBackendMessage(.copyInResponse(copyInResponse)))
329+
}
330+
325331
mutating func emptyQueryResponseReceived() -> Action {
326332
guard case .bindCompleteReceived(let queryContext) = self.state else {
327333
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
extension PostgresBackendMessage {
2+
struct CopyInResponse: Hashable {
3+
enum Format: Int8 {
4+
case textual = 0
5+
case binary = 1
6+
}
7+
8+
let format: Format
9+
let columnFormats: [Format]
10+
11+
static func decode(from buffer: inout ByteBuffer) throws -> Self {
12+
guard let rawFormat = buffer.readInteger(endianness: .big, as: Int8.self) else {
13+
throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(1, actual: buffer.readableBytes)
14+
}
15+
guard let format = Format(rawValue: rawFormat) else {
16+
throw PSQLPartialDecodingError.unexpectedValue(value: rawFormat)
17+
}
18+
19+
guard let numColumns = buffer.readInteger(endianness: .big, as: Int16.self) else {
20+
throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(2, actual: buffer.readableBytes)
21+
}
22+
var columnFormatCodes: [Format] = []
23+
columnFormatCodes.reserveCapacity(Int(numColumns))
24+
25+
for _ in 0..<numColumns {
26+
guard let rawColumnFormat = buffer.readInteger(endianness: .big, as: Int16.self) else {
27+
throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(2, actual: buffer.readableBytes)
28+
}
29+
guard Int8.min <= rawColumnFormat, rawColumnFormat <= Int8.max, let columnFormat = Format(rawValue: Int8(rawColumnFormat)) else {
30+
throw PSQLPartialDecodingError.unexpectedValue(value: rawColumnFormat)
31+
}
32+
columnFormatCodes.append(columnFormat)
33+
}
34+
35+
return CopyInResponse(format: format, columnFormats: columnFormatCodes)
36+
}
37+
}
38+
}
39+
40+
extension PostgresBackendMessage.CopyInResponse: CustomDebugStringConvertible {
41+
var debugDescription: String {
42+
"format: \(format), columnFormats: \(columnFormats)"
43+
}
44+
}

Sources/PostgresNIO/New/PostgresBackendMessage.swift

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ enum PostgresBackendMessage: Hashable {
2929
case bindComplete
3030
case closeComplete
3131
case commandComplete(String)
32+
case copyInResponse(CopyInResponse)
3233
case dataRow(DataRow)
3334
case emptyQueryResponse
3435
case error(ErrorResponse)
@@ -96,6 +97,9 @@ extension PostgresBackendMessage {
9697
}
9798
return .commandComplete(commandTag)
9899

100+
case .copyInResponse:
101+
return try .copyInResponse(.decode(from: &buffer))
102+
99103
case .dataRow:
100104
return try .dataRow(.decode(from: &buffer))
101105

@@ -131,9 +135,9 @@ extension PostgresBackendMessage {
131135

132136
case .rowDescription:
133137
return try .rowDescription(.decode(from: &buffer))
134-
135-
case .copyData, .copyDone, .copyInResponse, .copyOutResponse, .copyBothResponse, .functionCallResponse, .negotiateProtocolVersion:
136-
preconditionFailure()
138+
139+
case .copyData, .copyDone, .copyOutResponse, .copyBothResponse, .functionCallResponse, .negotiateProtocolVersion:
140+
throw PSQLPartialDecodingError.unknownMessageKind(messageID)
137141
}
138142
}
139143
}
@@ -151,6 +155,8 @@ extension PostgresBackendMessage: CustomDebugStringConvertible {
151155
return ".closeComplete"
152156
case .commandComplete(let commandTag):
153157
return ".commandComplete(\(String(reflecting: commandTag)))"
158+
case .copyInResponse(let copyInResponse):
159+
return ".copyInResponse(\(String(reflecting: copyInResponse)))"
154160
case .dataRow(let dataRow):
155161
return ".dataRow(\(String(reflecting: dataRow)))"
156162
case .emptyQueryResponse:

Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,12 @@ struct PSQLPartialDecodingError: Error {
189189
description: "Expected the integer to be positive or null, but got \(actual).",
190190
file: file, line: line)
191191
}
192+
193+
static func unknownMessageKind(_ messageID: PostgresBackendMessage.ID, file: String = #fileID, line: Int = #line) -> Self {
194+
return PSQLPartialDecodingError(
195+
description: "Unknown message kind: \(messageID)",
196+
file: file, line: line)
197+
}
192198
}
193199

194200
extension ByteBuffer {

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
136136
action = self.state.closeCompletedReceived()
137137
case .commandComplete(let commandTag):
138138
action = self.state.commandCompletedReceived(commandTag)
139+
case .copyInResponse(let copyInResponse):
140+
action = self.state.copyInResponseReceived(copyInResponse)
139141
case .dataRow(let dataRow):
140142
action = self.state.dataRowReceived(dataRow)
141143
case .emptyQueryResponse:

Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,28 @@ struct PostgresFrontendMessageEncoder {
167167
self.buffer.writeMultipleIntegers(UInt32(8), Self.sslRequestCode)
168168
}
169169

170+
/// Adds the `CopyData` message ID and `dataLength` to the message buffer but not the actual data.
171+
///
172+
/// The caller of this function is expected to write the encoder's message buffer to the backend after calling this
173+
/// function, followed by sending the actual data to the backend.
174+
mutating func copyDataHeader(dataLength: UInt32) {
175+
self.clearIfNeeded()
176+
self.buffer.psqlWriteMultipleIntegers(id: .copyData, length: dataLength)
177+
}
178+
179+
mutating func copyDone() {
180+
self.clearIfNeeded()
181+
self.buffer.psqlWriteMultipleIntegers(id: .copyDone, length: 0)
182+
}
183+
184+
mutating func copyFail(message: String) {
185+
self.clearIfNeeded()
186+
var messageBuffer = ByteBuffer()
187+
messageBuffer.writeNullTerminatedString(message)
188+
self.buffer.psqlWriteMultipleIntegers(id: .copyFail, length: UInt32(messageBuffer.readableBytes))
189+
self.buffer.writeImmutableBuffer(messageBuffer)
190+
}
191+
170192
mutating func sync() {
171193
self.clearIfNeeded()
172194
self.buffer.psqlWriteMultipleIntegers(id: .sync, length: 0)
@@ -197,6 +219,9 @@ struct PostgresFrontendMessageEncoder {
197219
private enum FrontendMessageID: UInt8, Hashable, Sendable {
198220
case bind = 66 // B
199221
case close = 67 // C
222+
case copyData = 100 // d
223+
case copyDone = 99 // c
224+
case copyFail = 102 // f
200225
case describe = 68 // D
201226
case execute = 69 // E
202227
case flush = 72 // H

Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ class ExtendedQueryStateMachineTests: XCTestCase {
114114
.failQuery(promise, with: psqlError, cleanupContext: .init(action: .close, tasks: [], error: psqlError, closePromise: nil)))
115115
}
116116

117-
func testExtendedQueryIsCancelledImmediatly() {
117+
func testExtendedQueryIsCancelledImmediately() {
118118
var state = ConnectionStateMachine.readyForQuery()
119119

120120
let logger = Logger.psqlTest

Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ struct PSQLBackendMessageEncoder: MessageToByteEncoder {
2828
case .commandComplete(let string):
2929
self.encode(messageID: message.id, payload: StringPayload(string), into: &buffer)
3030

31+
case .copyInResponse(let copyInResponse):
32+
self.encode(messageID: message.id, payload: copyInResponse, into: &buffer)
3133
case .dataRow(let row):
3234
self.encode(messageID: message.id, payload: row, into: &buffer)
3335

@@ -99,6 +101,8 @@ extension PostgresBackendMessage {
99101
return .closeComplete
100102
case .commandComplete:
101103
return .commandComplete
104+
case .copyInResponse:
105+
return .copyInResponse
102106
case .dataRow:
103107
return .dataRow
104108
case .emptyQueryResponse:
@@ -184,6 +188,16 @@ extension PostgresBackendMessage.BackendKeyData: PSQLMessagePayloadEncodable {
184188
}
185189
}
186190

191+
extension PostgresBackendMessage.CopyInResponse: PSQLMessagePayloadEncodable {
192+
public func encode(into buffer: inout ByteBuffer) {
193+
buffer.writeInteger(Int8(self.format.rawValue))
194+
buffer.writeInteger(Int16(self.columnFormats.count))
195+
for columnFormat in columnFormats {
196+
buffer.writeInteger(Int16(columnFormat.rawValue))
197+
}
198+
}
199+
}
200+
187201
extension DataRow: PSQLMessagePayloadEncodable {
188202
public func encode(into buffer: inout ByteBuffer) {
189203
buffer.writeInteger(self.columnCount, as: Int16.self)

Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,18 @@ extension PostgresFrontendMessage {
168168
)
169169
)
170170

171+
case .copyData:
172+
return .copyData(CopyData(data: buffer))
173+
174+
case .copyDone:
175+
return .copyDone
176+
177+
case .copyFail:
178+
guard let message = buffer.readNullTerminatedString() else {
179+
throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self)
180+
}
181+
return .copyFail(CopyFail(message: message))
182+
171183
case .close:
172184
preconditionFailure("TODO: Unimplemented")
173185

0 commit comments

Comments
 (0)