From 978f9f3a0b41c6214205d130da172f61634f2bd1 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Mon, 9 Oct 2023 09:37:55 +0900 Subject: [PATCH] introduce writeBuffer --- auth.go | 9 ++++----- buffer.go | 28 ++++++++++++++++++++++++++++ connection.go | 11 ++++------- connector.go | 32 -------------------------------- packets.go | 50 ++++++++++++++++---------------------------------- 5 files changed, 52 insertions(+), 78 deletions(-) create mode 100644 buffer.go diff --git a/auth.go b/auth.go index 09fbbbc1f..ef5394c62 100644 --- a/auth.go +++ b/auth.go @@ -361,8 +361,9 @@ func (mc *mysqlConn) handleAuthResult(ctx context.Context, oldAuthData []byte, p } else { pubKey := mc.cfg.pubKey if pubKey == nil { - mc.data[4] = cachingSha2PasswordRequestPublicKey - err = mc.writePacket(ctx, mc.data[:5]) + data := mc.wbuf.takeBuffer(5) + data[4] = cachingSha2PasswordRequestPublicKey + err = mc.writePacket(ctx, data) if err != nil { return err } @@ -372,7 +373,7 @@ func (mc *mysqlConn) handleAuthResult(ctx context.Context, oldAuthData []byte, p return err } - data := packet.data + data = packet.data if data[0] != iAuthMoreData { return fmt.Errorf("unexpected resp from server for caching_sha2_password, perform full authentication") } @@ -387,8 +388,6 @@ func (mc *mysqlConn) handleAuthResult(ctx context.Context, oldAuthData []byte, p return err } pubKey = pkix.(*rsa.PublicKey) - - mc.connector.putPacket(packet) } // send encrypted password diff --git a/buffer.go b/buffer.go new file mode 100644 index 000000000..30848c99c --- /dev/null +++ b/buffer.go @@ -0,0 +1,28 @@ +package mysql + +const defaultBufSize = 4096 + +// const maxCachedBufSize = 256 * 1024 + +type writeBuffer struct { + buf []byte +} + +// takeBuffer returns a buffer with the requested size. +// If possible, a slice from the existing buffer is returned. +// Otherwise a bigger buffer is made. +// Only one buffer (total) can be used at a time. +func (wb *writeBuffer) takeBuffer(length int) []byte { + if length <= cap(wb.buf) { + return wb.buf[:length] + } + if length <= defaultBufSize { + wb.buf = make([]byte, length, defaultBufSize) + return wb.buf + } + if length <= maxPacketSize { + wb.buf = make([]byte, length) + return wb.buf + } + return make([]byte, length) +} diff --git a/connection.go b/connection.go index d1f9226fe..1c238d821 100644 --- a/connection.go +++ b/connection.go @@ -46,12 +46,11 @@ type mysqlConn struct { sequence uint8 parseTime bool reset bool // set when the Go SQL package calls ResetSession + wbuf writeBuffer // for context support (Go 1.8+) - closech chan struct{} - closed atomicBool // set when conn is closed, before closech is closed - - data [16]byte // buffer for small writes + closech chan struct{} + closed atomicBool // set when conn is closed, before closech is closed readRes chan *packet // channel for read result writeReq chan []byte // buffered channel for write packets writeRes chan writeResult // channel for write result @@ -194,9 +193,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } var err error - packet := mc.connector.getPacket() - defer mc.connector.putPacket(packet) - buf := packet.data[:0] + buf := mc.wbuf.takeBuffer(0) argPos := 0 for i := 0; i < len(query); i++ { diff --git a/connector.go b/connector.go index 1fe65a4f4..62dc08376 100644 --- a/connector.go +++ b/connector.go @@ -16,16 +16,11 @@ import ( "os" "strconv" "strings" - "sync" ) -const defaultBufSize = 4096 -const maxCachedBufSize = 256 * 1024 - type connector struct { cfg *Config // immutable private copy. encodedAttributes string // Encoded connection attributes. - packetPool sync.Pool } func encodeConnectionAttributes(textAttributes string) string { @@ -65,33 +60,6 @@ func newConnector(cfg *Config) (*connector, error) { }, nil } -func (c *connector) getPacket() *packet { - if c == nil { - return &packet{data: make([]byte, defaultBufSize)} - } - pkt := c.packetPool.Get() - if pkt == nil { - return &packet{data: make([]byte, defaultBufSize)} - } - return pkt.(*packet) -} - -func (c *connector) getPacketWithSize(n int) *packet { - pkt := c.getPacket() - if cap(pkt.data) < n { - pkt.data = make([]byte, n) - } else { - pkt.data = pkt.data[:n] - } - return pkt -} - -func (c *connector) putPacket(pkt *packet) { - if c != nil && cap(pkt.data) <= maxCachedBufSize { - c.packetPool.Put(pkt) - } -} - // Connect implements driver.Connector interface. // Connect returns a connection to the database. func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { diff --git a/packets.go b/packets.go index 9fbd66c18..4ebd6d976 100644 --- a/packets.go +++ b/packets.go @@ -117,7 +117,6 @@ func (mc *mysqlConn) readPacket(ctx context.Context) (*packet, error) { prevData = pkt } else { prevData.data = append(prevData.data, pkt.data...) - mc.connector.putPacket(pkt) } } } @@ -454,9 +453,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(ctx context.Context, authResp // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse func (mc *mysqlConn) writeAuthSwitchPacket(ctx context.Context, authData []byte) error { pktLen := 4 + len(authData) - packet := mc.connector.getPacketWithSize(pktLen) - defer mc.connector.putPacket(packet) - data := packet.data + data := mc.wbuf.takeBuffer(pktLen) // Add the auth data [EOF] copy(data[4:], authData) @@ -472,10 +469,11 @@ func (mc *mysqlConn) writeCommandPacket(ctx context.Context, command byte) error mc.sequence = 0 // Add command byte - mc.data[4] = command + data := mc.wbuf.takeBuffer(4 + 1) + data[4] = command // Send CMD packet - return mc.writePacket(ctx, mc.data[:4+1]) + return mc.writePacket(ctx, data) } func (mc *mysqlConn) writeCommandPacketStr(ctx context.Context, command byte, arg string) error { @@ -484,9 +482,7 @@ func (mc *mysqlConn) writeCommandPacketStr(ctx context.Context, command byte, ar pktLen := 1 + len(arg) - packet := mc.connector.getPacketWithSize(4 + pktLen) - defer mc.connector.putPacket(packet) - data := packet.data + data := mc.wbuf.takeBuffer(4 + pktLen) // Add command byte data[4] = command @@ -503,16 +499,17 @@ func (mc *mysqlConn) writeCommandPacketUint32(ctx context.Context, command byte, mc.sequence = 0 // Add command byte - mc.data[4] = command + data := mc.wbuf.takeBuffer(4 + 1 + 4) + data[4] = command // Add arg [32 bit] - mc.data[5] = byte(arg) - mc.data[6] = byte(arg >> 8) - mc.data[7] = byte(arg >> 16) - mc.data[8] = byte(arg >> 24) + data[5] = byte(arg) + data[6] = byte(arg >> 8) + data[7] = byte(arg >> 16) + data[8] = byte(arg >> 24) // Send CMD packet - return mc.writePacket(ctx, mc.data[:4+1+4]) + return mc.writePacket(ctx, data) } /****************************************************************************** @@ -561,7 +558,6 @@ func (mc *okHandler) readResultOK(ctx context.Context) error { if err != nil { return err } - defer mc.connector.putPacket(packet) data := packet.data if data[0] == iOK { @@ -829,9 +825,6 @@ func (rows *textRows) readRow(dest []driver.Value) error { return io.EOF } - if pkt := rows.pkt; pkt != nil { - rows.mc.connector.putPacket(pkt) - } packet, err := mc.readPacket(ctx) if err != nil { return err @@ -934,7 +927,6 @@ func (mc *mysqlConn) readUntilEOF(ctx context.Context) error { } return nil } - mc.connector.putPacket(packet) } } @@ -971,7 +963,6 @@ func (stmt *mysqlStmt) readPrepareResultPacket(ctx context.Context) (uint16, err return columnCount, nil } - stmt.mc.connector.putPacket(packet) return 0, err } @@ -989,10 +980,7 @@ func (stmt *mysqlStmt) writeCommandLongData(ctx context.Context, paramID int, ar // Cannot use the write buffer since // a) the buffer is too small // b) it is in use - bufLen := 4 + 1 + 4 + 2 + len(arg) - packet := stmt.mc.connector.getPacketWithSize(bufLen) - defer stmt.mc.connector.putPacket(packet) - data := packet.data + data := make([]byte, 4+1+4+2+len(arg)) copy(data[4+dataOffset:], arg) @@ -1055,9 +1043,8 @@ func (stmt *mysqlStmt) writeExecutePacket(ctx context.Context, args []driver.Val var err error - packet := mc.connector.getPacket() - defer mc.connector.putPacket(packet) - data := packet.data[:cap(packet.data)] + data := mc.wbuf.takeBuffer(minPktLen) + data = data[:cap(data)] // command [1 byte] data[4] = comStmtExecute @@ -1259,7 +1246,6 @@ func (stmt *mysqlStmt) writeExecutePacket(ctx context.Context, args []driver.Val data = data[:pos] } - packet.data = data return mc.writePacket(ctx, data) } @@ -1288,9 +1274,6 @@ func (mc *okHandler) discardResults(ctx context.Context) error { // http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html func (rows *binaryRows) readRow(dest []driver.Value) error { ctx := rows.ctx - if pkt := rows.pkt; pkt != nil { - rows.mc.connector.putPacket(pkt) - } packet, err := rows.mc.readPacket(ctx) if err != nil { return err @@ -1482,8 +1465,7 @@ func (mc *mysqlConn) startGoroutines() { func (mc *mysqlConn) readLoop() { for { - pkt := mc.connector.getPacket() - + pkt := new(packet) mc.muRead.Lock() pkt.readFrom(mc.netConn) mc.muRead.Unlock()