Skip to content

Commit

Permalink
re-introduce packet pool
Browse files Browse the repository at this point in the history
  • Loading branch information
shogo82148 committed Oct 9, 2023
1 parent b6c7b5d commit 6eda7f8
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 40 deletions.
3 changes: 2 additions & 1 deletion auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,10 +368,11 @@ func (mc *mysqlConn) handleAuthResult(ctx context.Context, oldAuthData []byte, p
return err
}

data, err = mc.readPacket(ctx)
packet, err := mc.readPacket(ctx)
if err != nil {
return err
}
data = packet.data

if data[0] != iAuthMoreData {
return fmt.Errorf("unexpected resp from server for caching_sha2_password, perform full authentication")
Expand Down
2 changes: 1 addition & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ type mysqlConn struct {
// for context support (Go 1.8+)
closech chan struct{}
closed atomicBool // set when conn is closed, before closech is closed
readRes chan packet // channel for read result
readRes chan *packet // channel for read result
writeReq chan []byte // buffered channel for write packets
writeRes chan writeResult // channel for write result
}
Expand Down
2 changes: 2 additions & 0 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ func TestPingMarkBadConnection(t *testing.T) {
netConn: nc,
rbuf: newReadBuffer(nc),
maxAllowedPacket: defaultMaxAllowedPacket,
connector: &connector{},
}
ms.startGoroutines()
defer ms.cleanup()
Expand All @@ -180,6 +181,7 @@ func TestPingErrInvalidConn(t *testing.T) {
maxAllowedPacket: defaultMaxAllowedPacket,
closech: make(chan struct{}),
cfg: NewConfig(),
connector: &connector{},
}
ms.startGoroutines()
defer ms.cleanup()
Expand Down
16 changes: 16 additions & 0 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ import (
"os"
"strconv"
"strings"
"sync"
)

type connector struct {
cfg *Config // immutable private copy.
encodedAttributes string // Encoded connection attributes.
packetPool sync.Pool
}

func encodeConnectionAttributes(textAttributes string) string {
Expand Down Expand Up @@ -169,6 +171,20 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
return mc, nil
}

func (c *connector) getPacket() *packet {
p := c.packetPool.Get()
if p == nil {
return &packet{}
}
return p.(*packet)
}

func (c *connector) putPacket(p *packet) {
if p != nil && len(p.data) < maxPacketSize {
c.packetPool.Put(p)
}
}

// Driver implements driver.Connector interface.
// Driver returns &MySQLDriver{}.
func (c *connector) Driver() driver.Driver {
Expand Down
64 changes: 47 additions & 17 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ func (p *packet) readFrom(r *readBuffer) {
return
}

p.data = append([]byte(nil), data...) // TODO: reduce allocations
p.data = append(p.data[:0], data...)
}

// Read packet to buffer 'data'
func (mc *mysqlConn) readPacket(ctx context.Context) ([]byte, error) {
var prevData []byte
func (mc *mysqlConn) readPacket(ctx context.Context) (*packet, error) {
var prevData *packet
for {
var pkt packet
var pkt *packet
select {
case pkt = <-mc.readRes:
case <-mc.closech:
Expand Down Expand Up @@ -99,12 +99,24 @@ func (mc *mysqlConn) readPacket(ctx context.Context) ([]byte, error) {
return prevData, nil
}

prevData = append(prevData, pkt.data...)

// return data if this was the last packet
if pktLen < maxPacketSize {
// zero allocations for non-split packets
if prevData == nil {
return pkt, nil
}

prevData.data = append(prevData.data, pkt.data...)
mc.connector.putPacket(pkt)
return prevData, nil
}

if prevData != nil {
prevData.data = append(prevData.data, pkt.data...)
mc.connector.putPacket(pkt)
} else {
prevData = pkt
}
}
}

Expand Down Expand Up @@ -209,7 +221,7 @@ func (mc *mysqlConn) writePacket(ctx context.Context, data []byte) error {
// Handshake Initialization Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
func (mc *mysqlConn) readHandshakePacket(ctx context.Context) (data []byte, plugin string, err error) {
data, err = mc.readPacket(ctx)
packet, err := mc.readPacket(ctx)
if err != nil {
// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
// in connection initialization we don't risk retrying non-idempotent actions.
Expand All @@ -218,6 +230,8 @@ func (mc *mysqlConn) readHandshakePacket(ctx context.Context) (data []byte, plug
}
return
}
defer mc.connector.putPacket(packet)
data = packet.data

if data[0] == iERR {
return nil, "", mc.handleErrorPacket(data)
Expand Down Expand Up @@ -504,10 +518,11 @@ func (mc *mysqlConn) writeCommandPacketUint32(ctx context.Context, command byte,
******************************************************************************/

func (mc *mysqlConn) readAuthResult(ctx context.Context) ([]byte, string, error) {
data, err := mc.readPacket(ctx)
packet, err := mc.readPacket(ctx)
if err != nil {
return nil, "", err
}
data := packet.data

// packet indicator
switch data[0] {
Expand Down Expand Up @@ -540,10 +555,11 @@ func (mc *mysqlConn) readAuthResult(ctx context.Context) ([]byte, string, error)

// Returns error if Packet is not a 'Result OK'-Packet
func (mc *okHandler) readResultOK(ctx context.Context) error {
data, err := mc.conn().readPacket(ctx)
packet, err := mc.conn().readPacket(ctx)
if err != nil {
return err
}
data := packet.data

if data[0] == iOK {
return mc.handleOkPacket(data)
Expand All @@ -558,10 +574,12 @@ func (mc *okHandler) readResultSetHeaderPacket(ctx context.Context) (int, error)
mc.result.affectedRows = append(mc.result.affectedRows, 0)
mc.result.insertIds = append(mc.result.insertIds, 0)

data, err := mc.conn().readPacket(ctx)
packet, err := mc.conn().readPacket(ctx)
if err != nil {
return 0, err
}
defer mc.conn().connector.putPacket(packet)
data := packet.data
if err == nil {
switch data[0] {

Expand Down Expand Up @@ -704,10 +722,11 @@ func (mc *mysqlConn) readColumns(ctx context.Context, count int) ([]mysqlField,
columns := make([]mysqlField, count)

for i := 0; ; i++ {
data, err := mc.readPacket(ctx)
packet, err := mc.readPacket(ctx)
if err != nil {
return nil, err
}
data := packet.data

// EOF Packet
if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
Expand Down Expand Up @@ -808,10 +827,13 @@ func (rows *textRows) readRow(dest []driver.Value) error {
return io.EOF
}

data, err := mc.readPacket(ctx)
rows.mc.connector.putPacket(rows.pkt)
packet, err := mc.readPacket(ctx)
rows.pkt = packet
if err != nil {
return err
}
data := packet.data

// EOF Packet
if data[0] == iEOF && len(data) == 5 {
Expand Down Expand Up @@ -893,10 +915,11 @@ func (rows *textRows) readRow(dest []driver.Value) error {
// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
func (mc *mysqlConn) readUntilEOF(ctx context.Context) error {
for {
data, err := mc.readPacket(ctx)
packet, err := mc.readPacket(ctx)
if err != nil {
return err
}
data := packet.data

switch data[0] {
case iERR:
Expand All @@ -907,6 +930,7 @@ func (mc *mysqlConn) readUntilEOF(ctx context.Context) error {
}
return nil
}
mc.connector.putPacket(packet)
}
}

Expand All @@ -917,10 +941,12 @@ func (mc *mysqlConn) readUntilEOF(ctx context.Context) error {
// Prepare Result Packets
// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
func (stmt *mysqlStmt) readPrepareResultPacket(ctx context.Context) (uint16, error) {
data, err := stmt.mc.readPacket(ctx)
packet, err := stmt.mc.readPacket(ctx)
if err != nil {
return 0, err
}
defer stmt.mc.connector.putPacket(packet)
data := packet.data
if err == nil {
// packet indicator [1 byte]
if data[0] != iOK {
Expand Down Expand Up @@ -1253,10 +1279,14 @@ func (mc *okHandler) discardResults(ctx context.Context) error {
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
func (rows *binaryRows) readRow(dest []driver.Value) error {
ctx := rows.ctx
data, err := rows.mc.readPacket(ctx)

rows.mc.connector.putPacket(rows.pkt)
packet, err := rows.mc.readPacket(ctx)
rows.pkt = packet
if err != nil {
return err
}
data := packet.data

// packet indicator [1 byte]
if data[0] != iOK {
Expand Down Expand Up @@ -1432,7 +1462,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {

func (mc *mysqlConn) startGoroutines() {
mc.closech = make(chan struct{})
mc.readRes = make(chan packet)
mc.readRes = make(chan *packet)
mc.writeReq = make(chan []byte, 1)
mc.writeRes = make(chan writeResult)

Expand All @@ -1442,7 +1472,7 @@ func (mc *mysqlConn) startGoroutines() {

func (mc *mysqlConn) readLoop() {
for {
var pkt packet
pkt := mc.connector.getPacket()
mc.muRead.Lock()
pkt.readFrom(&mc.rbuf)
mc.muRead.Unlock()
Expand Down
42 changes: 21 additions & 21 deletions packets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ func TestReadPacketSingleByte(t *testing.T) {
conn.Write([]byte{0x01, 0x00, 0x00, 0x00, 0xff})
}()

data, err := mc.readPacket(context.Background())
packet, err := mc.readPacket(context.Background())
if err != nil {
t.Fatal(err)
}
if len(data) != 1 {
t.Fatalf("unexpected packet length: expected %d, got %d", 1, len(data))
if len(packet.data) != 1 {
t.Fatalf("unexpected packet length: expected %d, got %d", 1, len(packet.data))
}
if data[0] != 0xff {
t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, data[0])
if packet.data[0] != 0xff {
t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, packet.data[0])
}
}

Expand Down Expand Up @@ -124,11 +124,11 @@ func TestReadPacketSplit(t *testing.T) {
}()
// TODO: check read operation count

data, err := mc.readPacket(context.Background())
packet, err := mc.readPacket(context.Background())
if err != nil {
t.Fatal(err)
}
if len(data) != maxPacketSize {
if len(packet.data) != maxPacketSize {
t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize, len(data))
}
if data[0] != 0x11 {
Expand Down Expand Up @@ -173,18 +173,18 @@ func TestReadPacketSplit(t *testing.T) {
}()
// TODO: check read operation count

data, err := mc.readPacket(context.Background())
packet, err := mc.readPacket(context.Background())
if err != nil {
t.Fatal(err)
}
if len(data) != 2*maxPacketSize {
t.Fatalf("unexpected packet length: expected %d, got %d", 2*maxPacketSize, len(data))
if len(packet.data) != 2*maxPacketSize {
t.Fatalf("unexpected packet length: expected %d, got %d", 2*maxPacketSize, len(packet.data))
}
if data[0] != 0x11 {
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, data[0])
if packet.data[0] != 0x11 {
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet.data[0])
}
if data[2*maxPacketSize-1] != 0x44 {
t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, data[2*maxPacketSize-1])
if packet.data[2*maxPacketSize-1] != 0x44 {
t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet.data[2*maxPacketSize-1])
}
})

Expand Down Expand Up @@ -215,18 +215,18 @@ func TestReadPacketSplit(t *testing.T) {
}()
// TODO: check read operation count

data, err := mc.readPacket(context.Background())
packet, err := mc.readPacket(context.Background())
if err != nil {
t.Fatal(err)
}
if len(data) != maxPacketSize+42 {
t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize+42, len(data))
if len(packet.data) != maxPacketSize+42 {
t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize+42, len(packet.data))
}
if data[0] != 0x11 {
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, data[0])
if packet.data[0] != 0x11 {
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet.data[0])
}
if data[maxPacketSize+41] != 0x44 {
t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, data[maxPacketSize+41])
if packet.data[maxPacketSize+41] != 0x44 {
t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet.data[maxPacketSize+41])
}
})
}
Expand Down
4 changes: 4 additions & 0 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type mysqlRows struct {
mc *mysqlConn
ctx context.Context
rs resultSet
pkt *packet // current read packet
}

type binaryRows struct {
Expand Down Expand Up @@ -108,6 +109,9 @@ func (rows *mysqlRows) Close() (err error) {
return err
}

rows.mc.connector.putPacket(rows.pkt)
rows.pkt = nil

// Remove unread packets from stream
if !rows.rs.done {
err = mc.readUntilEOF(ctx)
Expand Down

0 comments on commit 6eda7f8

Please sign in to comment.