diff --git a/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/codec/MessageEncoder.scala b/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/codec/MessageEncoder.scala index 5cf5d480..b00ba852 100644 --- a/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/codec/MessageEncoder.scala +++ b/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/codec/MessageEncoder.scala @@ -16,16 +16,15 @@ package com.github.mauricio.async.db.postgresql.codec +import java.nio.charset.Charset + import com.github.mauricio.async.db.column.ColumnEncoderRegistry import com.github.mauricio.async.db.exceptions.EncoderNotAvailableException import com.github.mauricio.async.db.postgresql.encoders._ -import com.github.mauricio.async.db.postgresql.messages.backend.ServerMessage import com.github.mauricio.async.db.postgresql.messages.frontend._ import com.github.mauricio.async.db.util.{BufferDumper, Log} -import java.nio.charset.Charset -import scala.annotation.switch -import io.netty.handler.codec.MessageToMessageEncoder import io.netty.channel.ChannelHandlerContext +import io.netty.handler.codec.MessageToMessageEncoder object MessageEncoder { val log = Log.get[MessageEncoder] @@ -44,22 +43,19 @@ class MessageEncoder(charset: Charset, encoderRegistry: ColumnEncoderRegistry) e override def encode(ctx: ChannelHandlerContext, msg: AnyRef, out: java.util.List[Object]) = { val buffer = msg match { - case message: ClientMessage => { - val encoder = (message.kind: @switch) match { - case ServerMessage.Close => CloseMessageEncoder - case ServerMessage.Execute => this.executeEncoder - case ServerMessage.Parse => this.openEncoder - case ServerMessage.Startup => this.startupEncoder - case ServerMessage.Query => this.queryEncoder - case ServerMessage.PasswordMessage => this.credentialEncoder + case message: ClientMessage => + val encoder = message match { + case CloseMessage => CloseMessageEncoder + case _ : PreparedStatementOpeningMessage => this.openEncoder + case _ : StartupMessage => this.startupEncoder + case _ : QueryMessage => this.queryEncoder + case _ : CredentialMessage => this.credentialEncoder + case _ : PreparedStatementExecuteMessage => this.executeEncoder case _ => throw new EncoderNotAvailableException(message) } - encoder.encode(message) - } - case _ => { + case _ => throw new IllegalArgumentException("Can not encode message %s".format(msg)) - } } if (log.isTraceEnabled) { diff --git a/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/encoders/PreparedStatementEncoderHelper.scala b/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/encoders/PreparedStatementEncoderHelper.scala index 4f0716b9..93d91884 100644 --- a/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/encoders/PreparedStatementEncoderHelper.scala +++ b/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/encoders/PreparedStatementEncoderHelper.scala @@ -40,8 +40,17 @@ trait PreparedStatementEncoderHelper { writeDescribe: Boolean = false ): ByteBuf = { + val bindBuffer: ByteBuf = bind(statementIdBytes, query, values, encoder, charset, writeDescribe) + val executeBuffer: ByteBuf = execute(statementIdBytes, 0) + val closeBuffer: ByteBuf = closePortal(statementIdBytes) + val syncBuffer: ByteBuf = sync + + Unpooled.wrappedBuffer(bindBuffer, executeBuffer, syncBuffer, closeBuffer) + } + + def bind(statementIdBytes: Array[Byte], query: String, values: Seq[Any], encoder: ColumnEncoderRegistry, charset: Charset, writeDescribe: Boolean): ByteBuf = { if (log.isDebugEnabled) { - log.debug(s"Preparing execute portal to statement ($query) - values (${values.mkString(", ")}) - ${charset}") + log.debug(s"Preparing execute portal to statement ($query) - values (${values.mkString(", ")}) - $charset") } val bindBuffer = Unpooled.buffer(1024) @@ -106,15 +115,28 @@ trait PreparedStatementEncoderHelper { describeBuffer.writeBytes(statementIdBytes) describeBuffer.writeByte(0) } + bindBuffer + } + def execute(statementIdBytes: Array[Byte], fetchSize: Int): ByteBuf = { val executeLength = 1 + 4 + statementIdBytes.length + 1 + 4 val executeBuffer = Unpooled.buffer(executeLength) executeBuffer.writeByte(ServerMessage.Execute) executeBuffer.writeInt(executeLength - 1) executeBuffer.writeBytes(statementIdBytes) executeBuffer.writeByte(0) - executeBuffer.writeInt(0) + executeBuffer.writeInt(fetchSize) + executeBuffer + } + def sync: ByteBuf = { + val syncBuffer = Unpooled.buffer(5) + syncBuffer.writeByte(ServerMessage.Sync) + syncBuffer.writeInt(4) + syncBuffer + } + + def closePortal(statementIdBytes: Array[Byte]): ByteBuf = { val closeLength = 1 + 4 + 1 + statementIdBytes.length + 1 val closeBuffer = Unpooled.buffer(closeLength) closeBuffer.writeByte(ServerMessage.CloseStatementOrPortal) @@ -122,15 +144,34 @@ trait PreparedStatementEncoderHelper { closeBuffer.writeByte('P') closeBuffer.writeBytes(statementIdBytes) closeBuffer.writeByte(0) + closeBuffer + } - val syncBuffer = Unpooled.buffer(5) - syncBuffer.writeByte(ServerMessage.Sync) - syncBuffer.writeInt(4) + def isNull(value: Any): Boolean = value == null || value == None - Unpooled.wrappedBuffer(bindBuffer, executeBuffer, syncBuffer, closeBuffer) + def parse(statementIdBytes: Array[Byte], query: String, valueTypes: Seq[Int], charset: Charset): ByteBuf = { + val columnCount = valueTypes.size - } + val parseBuffer = Unpooled.buffer(1024) + parseBuffer.writeByte(ServerMessage.Parse) + parseBuffer.writeInt(0) - def isNull(value: Any): Boolean = value == null || value == None + parseBuffer.writeBytes(statementIdBytes) + parseBuffer.writeByte(0) + parseBuffer.writeBytes(query.getBytes(charset)) + parseBuffer.writeByte(0) + + parseBuffer.writeShort(columnCount) + if (log.isDebugEnabled) { + log.debug(s"Opening query ($query) - statement id (${statementIdBytes.mkString("-")}) - selected types (${valueTypes.mkString(", ")}))") + } + + for (kind <- valueTypes) { + parseBuffer.writeInt(kind) + } + + ByteBufferUtils.writeLength(parseBuffer) + parseBuffer + } } diff --git a/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/encoders/PreparedStatementOpeningEncoder.scala b/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/encoders/PreparedStatementOpeningEncoder.scala index 41263bb1..d6515219 100644 --- a/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/encoders/PreparedStatementOpeningEncoder.scala +++ b/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/encoders/PreparedStatementOpeningEncoder.scala @@ -16,12 +16,12 @@ package com.github.mauricio.async.db.postgresql.encoders +import java.nio.charset.Charset + import com.github.mauricio.async.db.column.ColumnEncoderRegistry -import com.github.mauricio.async.db.postgresql.messages.backend.ServerMessage import com.github.mauricio.async.db.postgresql.messages.frontend.{ClientMessage, PreparedStatementOpeningMessage} -import com.github.mauricio.async.db.util.{Log, ByteBufferUtils} -import java.nio.charset.Charset -import io.netty.buffer.{Unpooled, ByteBuf} +import com.github.mauricio.async.db.util.Log +import io.netty.buffer.{ByteBuf, Unpooled} object PreparedStatementOpeningEncoder { val log = Log.get[PreparedStatementOpeningEncoder] @@ -32,40 +32,14 @@ class PreparedStatementOpeningEncoder(charset: Charset, encoder : ColumnEncoderR with PreparedStatementEncoderHelper { - import PreparedStatementOpeningEncoder.log - override def encode(message: ClientMessage): ByteBuf = { val m = message.asInstanceOf[PreparedStatementOpeningMessage] val statementIdBytes = m.statementId.toString.getBytes(charset) - val columnCount = m.valueTypes.size - - val parseBuffer = Unpooled.buffer(1024) - - parseBuffer.writeByte(ServerMessage.Parse) - parseBuffer.writeInt(0) - - parseBuffer.writeBytes(statementIdBytes) - parseBuffer.writeByte(0) - parseBuffer.writeBytes(m.query.getBytes(charset)) - parseBuffer.writeByte(0) - - parseBuffer.writeShort(columnCount) - - if ( log.isDebugEnabled ) { - log.debug(s"Opening query (${m.query}) - statement id (${statementIdBytes.mkString("-")}) - selected types (${m.valueTypes.mkString(", ")}) - values (${m.values.mkString(", ")})") - } - - for (kind <- m.valueTypes) { - parseBuffer.writeInt(kind) - } - - ByteBufferUtils.writeLength(parseBuffer) - - val executeBuffer = writeExecutePortal(statementIdBytes, m.query, m.values, encoder, charset, true) + val parseBuffer: ByteBuf = parse(statementIdBytes, m.query, m.valueTypes, charset) + val executeBuffer = writeExecutePortal(statementIdBytes, m.query, m.values, encoder, charset, writeDescribe = true) Unpooled.wrappedBuffer(parseBuffer, executeBuffer) } - } \ No newline at end of file