Skip to content

Commit 573573b

Browse files
committed
Fix tests
1 parent 4365c6b commit 573573b

File tree

5 files changed

+40
-62
lines changed

5 files changed

+40
-62
lines changed

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
306306
case .wait:
307307
break
308308
case .sendStartupMessage(let authContext):
309-
self.encoder.startup(authContext.toStartupParameters())
309+
self.encoder.startup(user: authContext.username, database: authContext.database, options: authContext.additionalParameters)
310310
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil)
311311
case .sendSSLRequest:
312312
self.encoder.ssl()
@@ -684,17 +684,6 @@ extension PostgresChannelHandler: PSQLRowsDataSource {
684684
}
685685
}
686686

687-
extension AuthContext {
688-
func toStartupParameters() -> PostgresFrontendMessage.Startup.Parameters {
689-
PostgresFrontendMessage.Startup.Parameters(
690-
user: self.username,
691-
database: self.database,
692-
replication: .false,
693-
options: self.additionalParameters
694-
)
695-
}
696-
}
697-
698687
private extension Insecure.MD5.Digest {
699688

700689
private static let lowercaseLookup: [UInt8] = [

Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,22 @@ struct PostgresFrontendMessageEncoder {
1313
self.buffer = buffer
1414
}
1515

16-
mutating func startup(_ parameters: PostgresFrontendMessage.Startup.Parameters) {
16+
mutating func startup(user: String, database: String?, options: [(String, String)]) {
1717
self.clearIfNeeded()
1818
self.encodeLengthPrefixed { buffer in
1919
buffer.writeInteger(PostgresFrontendMessage.Startup.versionThree)
2020
buffer.writeNullTerminatedString("user")
21-
buffer.writeNullTerminatedString(parameters.user)
21+
buffer.writeNullTerminatedString(user)
2222

23-
if let database = parameters.database {
23+
if let database = database {
2424
buffer.writeNullTerminatedString("database")
2525
buffer.writeNullTerminatedString(database)
2626
}
2727

28-
switch parameters.replication {
29-
case .database:
30-
buffer.writeNullTerminatedString("replication")
31-
buffer.writeNullTerminatedString("replication")
32-
case .true:
33-
buffer.writeNullTerminatedString("replication")
34-
buffer.writeNullTerminatedString("true")
35-
case .false:
36-
break
37-
}
28+
// we don't send replication parameters, as the default is false and this is what we
29+
// need for a client
3830

39-
for (key, value) in parameters.options {
31+
for (key, value) in options {
4032
buffer.writeNullTerminatedString(key)
4133
buffer.writeNullTerminatedString(value)
4234
}

Tests/PostgresNIOTests/New/Messages/StartupTests.swift

Lines changed: 19 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,41 +7,26 @@ class StartupTests: XCTestCase {
77
func testStartupMessage() {
88
var encoder = PostgresFrontendMessageEncoder(buffer: .init())
99
var byteBuffer = ByteBuffer()
10-
11-
let replicationValues: [PostgresFrontendMessage.Startup.Parameters.Replication] = [
12-
.`true`,
13-
.`false`,
14-
.database
15-
]
16-
17-
for replication in replicationValues {
18-
let parameters = PostgresFrontendMessage.Startup.Parameters(
19-
user: "test",
20-
database: "abc123",
21-
replication: replication,
22-
options: [("some", "options")]
23-
)
24-
25-
encoder.startup(parameters)
26-
byteBuffer = encoder.flushBuffer()
2710

28-
let byteBufferLength = Int32(byteBuffer.readableBytes)
29-
XCTAssertEqual(byteBufferLength, byteBuffer.readInteger())
30-
XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger())
31-
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user")
32-
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test")
33-
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database")
34-
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123")
35-
if replication != .false {
36-
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "replication")
37-
XCTAssertEqual(byteBuffer.readNullTerminatedString(), replication.stringValue)
38-
}
39-
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "some")
40-
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "options")
41-
XCTAssertEqual(byteBuffer.readInteger(), UInt8(0))
42-
43-
XCTAssertEqual(byteBuffer.readableBytes, 0)
44-
}
11+
let user = "test"
12+
let database = "abc123"
13+
let options = [("some", "options")]
14+
15+
encoder.startup(user: user, database: database, options: options)
16+
byteBuffer = encoder.flushBuffer()
17+
18+
let byteBufferLength = Int32(byteBuffer.readableBytes)
19+
XCTAssertEqual(byteBufferLength, byteBuffer.readInteger())
20+
XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger())
21+
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user")
22+
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test")
23+
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database")
24+
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123")
25+
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "some")
26+
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "options")
27+
XCTAssertEqual(byteBuffer.readInteger(), UInt8(0))
28+
29+
XCTAssertEqual(byteBuffer.readableBytes, 0)
4530
}
4631
}
4732

Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import NIOSSL
55
import NIOEmbedded
66
@testable import PostgresNIO
77

8+
89
class PostgresChannelHandlerTests: XCTestCase {
910

1011
var eventLoop: EmbeddedEventLoop!
@@ -207,7 +208,7 @@ class PostgresChannelHandlerTests: XCTestCase {
207208

208209
XCTAssertEqual(startup.parameters.user, config.username)
209210
XCTAssertEqual(startup.parameters.database, config.database)
210-
XCTAssertEqual(startup.parameters.options, nil)
211+
XCTAssert(startup.parameters.options.isEmpty)
211212
XCTAssertEqual(startup.parameters.replication, .false)
212213

213214
var buffer = ByteBuffer()
@@ -274,3 +275,14 @@ class TestEventHandler: ChannelInboundHandler {
274275
self.events.append(psqlEvent)
275276
}
276277
}
278+
279+
extension AuthContext {
280+
func toStartupParameters() -> PostgresFrontendMessage.Startup.Parameters {
281+
PostgresFrontendMessage.Startup.Parameters(
282+
user: self.username,
283+
database: self.database,
284+
replication: .false,
285+
options: self.additionalParameters
286+
)
287+
}
288+
}

Tests/PostgresNIOTests/New/PostgresConnectionTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ class PostgresConnectionTests: XCTestCase {
292292

293293
async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: self.logger)
294294
let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self)
295-
XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", replication: .false))))
295+
XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", replication: .false, options: []))))
296296
try await channel.writeInbound(PostgresBackendMessage.authentication(.ok))
297297
try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678)))
298298
try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle))

0 commit comments

Comments
 (0)