Skip to content

Commit

Permalink
introduce writeBuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
shogo82148 committed Oct 9, 2023
1 parent 07a417f commit 978f9f3
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 78 deletions.
9 changes: 4 additions & 5 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,9 @@ func (mc *mysqlConn) handleAuthResult(ctx context.Context, oldAuthData []byte, p
} else {
pubKey := mc.cfg.pubKey
if pubKey == nil {
mc.data[4] = cachingSha2PasswordRequestPublicKey
err = mc.writePacket(ctx, mc.data[:5])
data := mc.wbuf.takeBuffer(5)
data[4] = cachingSha2PasswordRequestPublicKey
err = mc.writePacket(ctx, data)
if err != nil {
return err
}
Expand All @@ -372,7 +373,7 @@ func (mc *mysqlConn) handleAuthResult(ctx context.Context, oldAuthData []byte, p
return err
}

data := packet.data
data = packet.data
if data[0] != iAuthMoreData {
return fmt.Errorf("unexpected resp from server for caching_sha2_password, perform full authentication")
}
Expand All @@ -387,8 +388,6 @@ func (mc *mysqlConn) handleAuthResult(ctx context.Context, oldAuthData []byte, p
return err
}
pubKey = pkix.(*rsa.PublicKey)

mc.connector.putPacket(packet)
}

// send encrypted password
Expand Down
28 changes: 28 additions & 0 deletions buffer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package mysql

const defaultBufSize = 4096

// const maxCachedBufSize = 256 * 1024

type writeBuffer struct {
buf []byte
}

// takeBuffer returns a buffer with the requested size.
// If possible, a slice from the existing buffer is returned.
// Otherwise a bigger buffer is made.
// Only one buffer (total) can be used at a time.
func (wb *writeBuffer) takeBuffer(length int) []byte {
if length <= cap(wb.buf) {
return wb.buf[:length]
}
if length <= defaultBufSize {
wb.buf = make([]byte, length, defaultBufSize)
return wb.buf
}
if length <= maxPacketSize {
wb.buf = make([]byte, length)
return wb.buf
}
return make([]byte, length)
}
11 changes: 4 additions & 7 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,11 @@ type mysqlConn struct {
sequence uint8
parseTime bool
reset bool // set when the Go SQL package calls ResetSession
wbuf writeBuffer

// for context support (Go 1.8+)
closech chan struct{}
closed atomicBool // set when conn is closed, before closech is closed

data [16]byte // buffer for small writes
closech chan struct{}
closed atomicBool // set when conn is closed, before closech is closed
readRes chan *packet // channel for read result
writeReq chan []byte // buffered channel for write packets
writeRes chan writeResult // channel for write result
Expand Down Expand Up @@ -194,9 +193,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
}

var err error
packet := mc.connector.getPacket()
defer mc.connector.putPacket(packet)
buf := packet.data[:0]
buf := mc.wbuf.takeBuffer(0)
argPos := 0

for i := 0; i < len(query); i++ {
Expand Down
32 changes: 0 additions & 32 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,11 @@ import (
"os"
"strconv"
"strings"
"sync"
)

const defaultBufSize = 4096
const maxCachedBufSize = 256 * 1024

type connector struct {
cfg *Config // immutable private copy.
encodedAttributes string // Encoded connection attributes.
packetPool sync.Pool
}

func encodeConnectionAttributes(textAttributes string) string {
Expand Down Expand Up @@ -65,33 +60,6 @@ func newConnector(cfg *Config) (*connector, error) {
}, nil
}

func (c *connector) getPacket() *packet {
if c == nil {
return &packet{data: make([]byte, defaultBufSize)}
}
pkt := c.packetPool.Get()
if pkt == nil {
return &packet{data: make([]byte, defaultBufSize)}
}
return pkt.(*packet)
}

func (c *connector) getPacketWithSize(n int) *packet {
pkt := c.getPacket()
if cap(pkt.data) < n {
pkt.data = make([]byte, n)
} else {
pkt.data = pkt.data[:n]
}
return pkt
}

func (c *connector) putPacket(pkt *packet) {
if c != nil && cap(pkt.data) <= maxCachedBufSize {
c.packetPool.Put(pkt)
}
}

// Connect implements driver.Connector interface.
// Connect returns a connection to the database.
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
Expand Down
50 changes: 16 additions & 34 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ func (mc *mysqlConn) readPacket(ctx context.Context) (*packet, error) {
prevData = pkt
} else {
prevData.data = append(prevData.data, pkt.data...)
mc.connector.putPacket(pkt)
}
}
}
Expand Down Expand Up @@ -454,9 +453,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(ctx context.Context, authResp
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeAuthSwitchPacket(ctx context.Context, authData []byte) error {
pktLen := 4 + len(authData)
packet := mc.connector.getPacketWithSize(pktLen)
defer mc.connector.putPacket(packet)
data := packet.data
data := mc.wbuf.takeBuffer(pktLen)

// Add the auth data [EOF]
copy(data[4:], authData)
Expand All @@ -472,10 +469,11 @@ func (mc *mysqlConn) writeCommandPacket(ctx context.Context, command byte) error
mc.sequence = 0

// Add command byte
mc.data[4] = command
data := mc.wbuf.takeBuffer(4 + 1)
data[4] = command

// Send CMD packet
return mc.writePacket(ctx, mc.data[:4+1])
return mc.writePacket(ctx, data)
}

func (mc *mysqlConn) writeCommandPacketStr(ctx context.Context, command byte, arg string) error {
Expand All @@ -484,9 +482,7 @@ func (mc *mysqlConn) writeCommandPacketStr(ctx context.Context, command byte, ar

pktLen := 1 + len(arg)

packet := mc.connector.getPacketWithSize(4 + pktLen)
defer mc.connector.putPacket(packet)
data := packet.data
data := mc.wbuf.takeBuffer(4 + pktLen)

// Add command byte
data[4] = command
Expand All @@ -503,16 +499,17 @@ func (mc *mysqlConn) writeCommandPacketUint32(ctx context.Context, command byte,
mc.sequence = 0

// Add command byte
mc.data[4] = command
data := mc.wbuf.takeBuffer(4 + 1 + 4)
data[4] = command

// Add arg [32 bit]
mc.data[5] = byte(arg)
mc.data[6] = byte(arg >> 8)
mc.data[7] = byte(arg >> 16)
mc.data[8] = byte(arg >> 24)
data[5] = byte(arg)
data[6] = byte(arg >> 8)
data[7] = byte(arg >> 16)
data[8] = byte(arg >> 24)

// Send CMD packet
return mc.writePacket(ctx, mc.data[:4+1+4])
return mc.writePacket(ctx, data)
}

/******************************************************************************
Expand Down Expand Up @@ -561,7 +558,6 @@ func (mc *okHandler) readResultOK(ctx context.Context) error {
if err != nil {
return err
}
defer mc.connector.putPacket(packet)
data := packet.data

if data[0] == iOK {
Expand Down Expand Up @@ -829,9 +825,6 @@ func (rows *textRows) readRow(dest []driver.Value) error {
return io.EOF
}

if pkt := rows.pkt; pkt != nil {
rows.mc.connector.putPacket(pkt)
}
packet, err := mc.readPacket(ctx)
if err != nil {
return err
Expand Down Expand Up @@ -934,7 +927,6 @@ func (mc *mysqlConn) readUntilEOF(ctx context.Context) error {
}
return nil
}
mc.connector.putPacket(packet)
}
}

Expand Down Expand Up @@ -971,7 +963,6 @@ func (stmt *mysqlStmt) readPrepareResultPacket(ctx context.Context) (uint16, err

return columnCount, nil
}
stmt.mc.connector.putPacket(packet)
return 0, err
}

Expand All @@ -989,10 +980,7 @@ func (stmt *mysqlStmt) writeCommandLongData(ctx context.Context, paramID int, ar
// Cannot use the write buffer since
// a) the buffer is too small
// b) it is in use
bufLen := 4 + 1 + 4 + 2 + len(arg)
packet := stmt.mc.connector.getPacketWithSize(bufLen)
defer stmt.mc.connector.putPacket(packet)
data := packet.data
data := make([]byte, 4+1+4+2+len(arg))

copy(data[4+dataOffset:], arg)

Expand Down Expand Up @@ -1055,9 +1043,8 @@ func (stmt *mysqlStmt) writeExecutePacket(ctx context.Context, args []driver.Val

var err error

packet := mc.connector.getPacket()
defer mc.connector.putPacket(packet)
data := packet.data[:cap(packet.data)]
data := mc.wbuf.takeBuffer(minPktLen)
data = data[:cap(data)]

// command [1 byte]
data[4] = comStmtExecute
Expand Down Expand Up @@ -1259,7 +1246,6 @@ func (stmt *mysqlStmt) writeExecutePacket(ctx context.Context, args []driver.Val
data = data[:pos]
}

packet.data = data
return mc.writePacket(ctx, data)
}

Expand Down Expand Up @@ -1288,9 +1274,6 @@ func (mc *okHandler) discardResults(ctx context.Context) error {
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
func (rows *binaryRows) readRow(dest []driver.Value) error {
ctx := rows.ctx
if pkt := rows.pkt; pkt != nil {
rows.mc.connector.putPacket(pkt)
}
packet, err := rows.mc.readPacket(ctx)
if err != nil {
return err
Expand Down Expand Up @@ -1482,8 +1465,7 @@ func (mc *mysqlConn) startGoroutines() {

func (mc *mysqlConn) readLoop() {
for {
pkt := mc.connector.getPacket()

pkt := new(packet)
mc.muRead.Lock()
pkt.readFrom(mc.netConn)
mc.muRead.Unlock()
Expand Down

0 comments on commit 978f9f3

Please sign in to comment.