diff --git a/CHANGELOG.md b/CHANGELOG.md index aaabaab6..1c66e5ab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,9 @@ # Overlord -## Version 1.0 +## Version 1.1.0 +1. add memcache binary protocol support. + +## Version 1.0.0 1. compitable consist hash with twemproxy. 2. reduce object alloc (by using Pool). 3. recycle using of request and response memory. diff --git a/lib/bufio/buffer.go b/lib/bufio/buffer.go index 374b9e0c..46c06c90 100644 --- a/lib/bufio/buffer.go +++ b/lib/bufio/buffer.go @@ -58,6 +58,7 @@ func (b *Buffer) grow() { nb := make([]byte, len(b.buf)*growFactor) copy(nb, b.buf[:b.w]) b.buf = nb + // NOTE: old buf cannot put into pool, maybe some slice point mem. Wait GC!!! } func (b *Buffer) len() int { diff --git a/lib/bufio/io.go b/lib/bufio/io.go index 13b4dc30..6470d5b2 100644 --- a/lib/bufio/io.go +++ b/lib/bufio/io.go @@ -245,11 +245,3 @@ func (w *Writer) Write(p []byte) (err error) { w.cursor = (w.cursor + 1) % maxBuffered return nil } - -// WriteString writes a string. -// It returns the number of bytes written. -// If the count is less than len(s), it also returns an error explaining -// why the write is short. -func (w *Writer) WriteString(s string) (err error) { - return w.Write([]byte(s)) -} diff --git a/lib/bufio/io_test.go b/lib/bufio/io_test.go index ef9f47df..f1f825f1 100644 --- a/lib/bufio/io_test.go +++ b/lib/bufio/io_test.go @@ -185,9 +185,6 @@ func TestWriterWriteOk(t *testing.T) { err = w.Write([]byte(data)) assert.NoError(t, err) - err = w.WriteString(data) - assert.NoError(t, err) - err = w.Flush() assert.NoError(t, err) diff --git a/proto/memcache/binary/node_conn.go b/proto/memcache/binary/node_conn.go new file mode 100644 index 00000000..10ad0092 --- /dev/null +++ b/proto/memcache/binary/node_conn.go @@ -0,0 +1,204 @@ +package binary + +import ( + "bytes" + "encoding/binary" + "io" + "sync/atomic" + "time" + + "overlord/lib/bufio" + libnet "overlord/lib/net" + "overlord/lib/prom" + "overlord/proto" + + "github.com/pkg/errors" +) + +const ( + handlerOpening = int32(0) + handlerClosed = int32(1) +) + +type nodeConn struct { + cluster string + addr string + conn *libnet.Conn + bw *bufio.Writer + br *bufio.Reader + closed int32 + + pinger *mcPinger +} + +// NewNodeConn returns node conn. +func NewNodeConn(cluster, addr string, dialTimeout, readTimeout, writeTimeout time.Duration) (nc proto.NodeConn) { + conn := libnet.DialWithTimeout(addr, dialTimeout, readTimeout, writeTimeout) + nc = &nodeConn{ + cluster: cluster, + addr: addr, + conn: conn, + bw: bufio.NewWriter(conn), + br: bufio.NewReader(conn, nil), + pinger: newMCPinger(conn), + } + return +} + +// Ping will send some special command by checking mc node is alive +func (n *nodeConn) Ping() (err error) { + if n.Closed() { + err = io.EOF + return + } + err = n.pinger.Ping() + return +} + +func (n *nodeConn) WriteBatch(mb *proto.MsgBatch) (err error) { + var ( + m *proto.Message + idx int + ) + for { + m = mb.Nth(idx) + if m == nil { + break + } + err = n.write(m) + if err != nil { + m.DoneWithError(err) + return err + } + m.MarkWrite() + idx++ + } + if err = n.bw.Flush(); err != nil { + err = errors.Wrap(err, "MC Writer flush message bytes") + } + return +} + +func (n *nodeConn) write(m *proto.Message) (err error) { + if n.Closed() { + err = errors.Wrap(ErrClosed, "MC Writer write") + return + } + mcr, ok := m.Request().(*MCRequest) + if !ok { + err = errors.Wrap(ErrAssertReq, "MC Writer assert request") + return + } + _ = n.bw.Write(magicReqBytes) + + cmd := mcr.rTp + if cmd == RequestTypeGetQ || cmd == RequestTypeGetKQ { + cmd = RequestTypeGetK + } + _ = n.bw.Write(cmd.Bytes()) + _ = n.bw.Write(mcr.keyLen) + _ = n.bw.Write(mcr.extraLen) + _ = n.bw.Write(zeroBytes) + _ = n.bw.Write(zeroTwoBytes) + _ = n.bw.Write(mcr.bodyLen) + _ = n.bw.Write(mcr.opaque) + _ = n.bw.Write(mcr.cas) + if !bytes.Equal(mcr.bodyLen, zeroFourBytes) { + _ = n.bw.Write(mcr.data) + } + return +} + +func (n *nodeConn) ReadBatch(mb *proto.MsgBatch) (err error) { + if n.Closed() { + err = errors.Wrap(ErrClosed, "MC Reader read batch message") + return + } + defer n.br.ResetBuffer(nil) + n.br.ResetBuffer(mb.Buffer()) + var ( + size int + cursor int + nth int + m *proto.Message + + mcr *MCRequest + ok bool + ) + m = mb.Nth(nth) + mcr, ok = m.Request().(*MCRequest) + if !ok { + err = errors.Wrap(ErrAssertReq, "MC Reader assert request") + return + } + for { + err = n.br.Read() + if err != nil { + err = errors.Wrap(err, "MC Reader while read") + return + } + for { + size, err = n.fillMCRequest(mcr, n.br.Buffer().Bytes()[cursor:]) + if err == bufio.ErrBufferFull { + err = nil + break + } else if err != nil { + return + } + m.MarkRead() + + cursor += size + nth++ + + m = mb.Nth(nth) + if m == nil { + return + } + mcr, ok = m.Request().(*MCRequest) + if !ok { + err = errors.Wrap(ErrAssertReq, "MC Reader assert request") + return + } + } + } +} + +func (n *nodeConn) fillMCRequest(mcr *MCRequest, data []byte) (size int, err error) { + if len(data) < requestHeaderLen { + return 0, bufio.ErrBufferFull + } + parseHeader(data[0:requestHeaderLen], mcr, false) + + bl := binary.BigEndian.Uint32(mcr.bodyLen) + if bl == 0 { + if mcr.rTp == RequestTypeGet || mcr.rTp == RequestTypeGetQ || mcr.rTp == RequestTypeGetK || mcr.rTp == RequestTypeGetKQ { + prom.Miss(n.cluster, n.addr) + } + size = requestHeaderLen + return + } + if len(data[requestHeaderLen:]) < int(bl) { + return 0, bufio.ErrBufferFull + } + size = requestHeaderLen + int(bl) + mcr.data = data[requestHeaderLen : requestHeaderLen+bl] + + if mcr.rTp == RequestTypeGet || mcr.rTp == RequestTypeGetQ || mcr.rTp == RequestTypeGetK || mcr.rTp == RequestTypeGetKQ { + prom.Hit(n.cluster, n.addr) + } + return +} + +func (n *nodeConn) Close() error { + if atomic.CompareAndSwapInt32(&n.closed, handlerOpening, handlerClosed) { + _ = n.pinger.Close() + n.pinger = nil + err := n.conn.Close() + return err + } + return nil +} + +func (n *nodeConn) Closed() bool { + return atomic.LoadInt32(&n.closed) == handlerClosed +} diff --git a/proto/memcache/binary/node_conn_test.go b/proto/memcache/binary/node_conn_test.go new file mode 100644 index 00000000..2eb42976 --- /dev/null +++ b/proto/memcache/binary/node_conn_test.go @@ -0,0 +1,262 @@ +package binary + +import ( + "encoding/binary" + "io" + "net" + "testing" + "time" + + "overlord/lib/bufio" + "overlord/proto" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func _createNodeConn(data []byte) *nodeConn { + conn := _createConn(data) + nc := &nodeConn{ + cluster: "clusterA", + addr: "127.0.0.1:5000", + bw: bufio.NewWriter(conn), + br: bufio.NewReader(conn, nil), + pinger: newMCPinger(conn), + conn: conn, + } + return nc +} + +func _createReqMsg(bin []byte) *proto.Message { + mc := &MCRequest{} + parseHeader(bin, mc, true) + + bl := int(binary.BigEndian.Uint32(mc.bodyLen)) + el := int(uint8(mc.extraLen[0])) + kl := int(binary.BigEndian.Uint16(mc.keyLen)) + if kl > 0 { + mc.key = bin[24+el : 24+el+kl] + } + if bl > 0 { + mc.data = bin[24:] + } + pm := proto.NewMessage() + pm.WithRequest(mc) + return pm +} + +func _causeEqual(t *testing.T, except, actual error) { + err := errors.Cause(actual) + assert.Equal(t, except, err) +} + +func TestNodeConnWriteOk(t *testing.T) { + ts := []struct { + name string + req []byte + except []byte + }{ + {name: "get", req: getTestData, except: getTestData}, + {name: "set", req: setTestData, except: setTestData}, + } + + for _, tt := range ts { + t.Run(tt.name, func(subt *testing.T) { + req := _createReqMsg(tt.req) + nc := _createNodeConn(nil) + batch := proto.NewMsgBatch() + batch.AddMsg(req) + err := nc.WriteBatch(batch) + assert.NoError(t, err) + + m, ok := nc.conn.Conn.(*mockConn) + assert.True(t, ok) + + buf := make([]byte, 1024) + size, err := m.wbuf.Read(buf) + assert.NoError(t, err) + assert.Equal(t, tt.except, buf[:size]) + }) + } +} + +func TestNodeConnBatchWriteOk(t *testing.T) { + ts := []struct { + name string + req []byte + except []byte + }{ + {name: "get", req: getTestData, except: getTestData}, + {name: "set", req: setTestData, except: setTestData}, + } + + for _, tt := range ts { + t.Run(tt.name, func(subt *testing.T) { + req := _createReqMsg(tt.req) + nc := _createNodeConn(nil) + batch := proto.NewMsgBatch() + batch.AddMsg(req) + + err := nc.WriteBatch(batch) + assert.NoError(t, err) + + m, ok := nc.conn.Conn.(*mockConn) + assert.True(t, ok) + + buf := make([]byte, 1024) + size, err := m.wbuf.Read(buf) + assert.NoError(t, err) + assert.Equal(t, tt.except, buf[:size]) + }) + } +} + +func TestNodeConnWriteClosed(t *testing.T) { + req := _createReqMsg(getTestData) + nc := _createNodeConn(nil) + err := nc.Close() + assert.NoError(t, err) + assert.True(t, nc.Closed()) + err = nc.write(req) + assert.Error(t, err) + _causeEqual(t, ErrClosed, err) + assert.NoError(t, nc.Close()) +} + +type mockReq struct { +} + +func (*mockReq) CmdString() string { + return "" +} + +func (*mockReq) Cmd() []byte { + return []byte("") +} + +func (*mockReq) Key() []byte { + return []byte{} +} + +func (*mockReq) Resp() []byte { + return nil +} + +func (*mockReq) Put() { + +} +func TestNodeConnWriteTypeAssertFail(t *testing.T) { + req := proto.NewMessage() + nc := _createNodeConn(nil) + req.WithRequest(&mockReq{}) + batch := proto.NewMsgBatch() + batch.AddMsg(req) + err := nc.WriteBatch(batch) + nc.bw.Flush() + assert.Error(t, err) + _causeEqual(t, ErrAssertReq, err) +} + +func TestNodeConnReadClosed(t *testing.T) { + req := _createReqMsg(getTestData) + nc := _createNodeConn(nil) + + err := nc.Close() + assert.NoError(t, err) + assert.True(t, nc.Closed()) + batch := proto.NewMsgBatch() + batch.AddMsg(req) + err = nc.ReadBatch(batch) + assert.Error(t, err) + _causeEqual(t, ErrClosed, err) +} + +func TestNodeConnReadOk(t *testing.T) { + ts := []struct { + name string + req []byte + cData []byte + except []byte + }{ + { + name: "getmiss", + req: getTestData, + cData: getMissRespTestData, except: getMissRespTestData, + }, + { + name: "get ok", + req: getTestData, + cData: getRespTestData, except: getRespTestData, + }, + { + name: "set ok", + req: setTestData, + cData: setRespTestData, except: setRespTestData, + }, + } + for _, tt := range ts { + + t.Run(tt.name, func(t *testing.T) { + req := _createReqMsg(tt.req) + nc := _createNodeConn([]byte(tt.cData)) + batch := proto.NewMsgBatch() + batch.AddMsg(req) + err := nc.ReadBatch(batch) + assert.NoError(t, err) + + mcr, ok := req.Request().(*MCRequest) + assert.Equal(t, true, ok) + + actual := append([]byte{mcr.magic}, mcr.rTp.Bytes()...) + actual = append(actual, mcr.keyLen...) + actual = append(actual, mcr.extraLen...) + actual = append(actual, zeroBytes...) // datatype + actual = append(actual, mcr.status...) // status + actual = append(actual, mcr.bodyLen...) + actual = append(actual, mcr.opaque...) + actual = append(actual, mcr.cas...) + bl := binary.BigEndian.Uint32(mcr.bodyLen) + if bl > 0 { + actual = append(actual, mcr.data...) + } + + assert.Equal(t, tt.except, actual) + }) + + } +} + +func TestNodeConnAssertError(t *testing.T) { + nc := _createNodeConn(nil) + req := proto.NewMessage() + req.WithRequest(&mockReq{}) + batch := proto.NewMsgBatch() + batch.AddMsg(req) + err := nc.ReadBatch(batch) + _causeEqual(t, ErrAssertReq, err) +} + +func TestNocdConnPingOk(t *testing.T) { + nc := _createNodeConn(pongBs) + err := nc.Ping() + assert.NoError(t, err) + assert.NoError(t, nc.Close()) + err = nc.Ping() + assert.Error(t, err) + _causeEqual(t, io.EOF, err) +} + +func TestNewNodeConnWithClosedBinder(t *testing.T) { + taddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0") + assert.NoError(t, err) + listener, err := net.ListenTCP("tcp", taddr) + assert.NoError(t, err) + addr := listener.Addr() + go func() { + defer listener.Close() + sock, _ := listener.Accept() + defer sock.Close() + }() + nc := NewNodeConn("anyName", addr.String(), time.Second, time.Second, time.Second) + assert.NotNil(t, nc) +} diff --git a/proto/memcache/binary/pinger.go b/proto/memcache/binary/pinger.go new file mode 100644 index 00000000..a85394a8 --- /dev/null +++ b/proto/memcache/binary/pinger.go @@ -0,0 +1,84 @@ +package binary + +import ( + "bytes" + "sync/atomic" + + "overlord/lib/bufio" + libnet "overlord/lib/net" + + "github.com/pkg/errors" +) + +const ( + pingBufferSize = 32 +) + +var ( + pingBs = []byte{ + 0x80, // magic + 0x0a, // cmd: noop + 0x00, 0x00, // key len + 0x00, // extra len + 0x00, // data type + 0x00, 0x00, // vbucket + 0x00, 0x00, 0x00, 0x00, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + } + pongBs = []byte{ + 0x81, // magic + 0x0a, // cmd: noop + 0x00, 0x00, // key len + 0x00, // extra len + 0x00, // data type + 0x00, 0x00, // status + 0x00, 0x00, 0x00, 0x00, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + } +) + +type mcPinger struct { + conn *libnet.Conn + bw *bufio.Writer + br *bufio.Reader + closed int32 +} + +func newMCPinger(nc *libnet.Conn) *mcPinger { + return &mcPinger{ + conn: nc, + bw: bufio.NewWriter(nc), + br: bufio.NewReader(nc, bufio.Get(pingBufferSize)), + } +} + +func (m *mcPinger) Ping() (err error) { + if atomic.LoadInt32(&m.closed) == handlerClosed { + err = ErrPingerPong + return + } + _ = m.bw.Write(pingBs) + if err = m.bw.Flush(); err != nil { + err = errors.Wrap(err, "MC ping flush") + return + } + err = m.br.Read() + head, err := m.br.ReadExact(requestHeaderLen) + if err != nil { + err = errors.Wrap(err, "MC ping read exact") + return + } + if !bytes.Equal(head, pongBs) { + err = ErrPingerPong + } + return +} + +func (m *mcPinger) Close() error { + if atomic.CompareAndSwapInt32(&m.closed, handlerOpening, handlerClosed) { + return m.conn.Close() + } + return nil +} diff --git a/proto/memcache/binary/pinger_test.go b/proto/memcache/binary/pinger_test.go new file mode 100644 index 00000000..9b2b8027 --- /dev/null +++ b/proto/memcache/binary/pinger_test.go @@ -0,0 +1,119 @@ +package binary + +import ( + "bytes" + "net" + "testing" + "time" + + "overlord/lib/bufio" + libnet "overlord/lib/net" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +type mockAddr string + +func (m mockAddr) Network() string { + return "tcp" +} +func (m mockAddr) String() string { + return string(m) +} + +type mockConn struct { + rbuf *bytes.Buffer + wbuf *bytes.Buffer + addr mockAddr +} + +func (m *mockConn) Read(b []byte) (n int, err error) { + return m.rbuf.Read(b) +} +func (m *mockConn) Write(b []byte) (n int, err error) { + return m.wbuf.Write(b) +} + +// writeBuffers impl the net.buffersWriter to support writev +func (m *mockConn) writeBuffers(buf *net.Buffers) (int64, error) { + return buf.WriteTo(m.wbuf) +} + +func (m *mockConn) Close() error { return nil } +func (m *mockConn) LocalAddr() net.Addr { return m.addr } +func (m *mockConn) RemoteAddr() net.Addr { return m.addr } + +func (m *mockConn) SetDeadline(t time.Time) error { return nil } +func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } + +// _createConn is useful tools for handler test +func _createConn(data []byte) *libnet.Conn { + return _createRepeatConn(data, 1) +} + +func _createRepeatConn(data []byte, r int) *libnet.Conn { + mconn := &mockConn{ + addr: "127.0.0.1:12345", + rbuf: bytes.NewBuffer(bytes.Repeat(data, r)), + wbuf: new(bytes.Buffer), + } + conn := libnet.NewConn(mconn, time.Second, time.Second) + return conn +} + +func TestPingerPingOk(t *testing.T) { + conn := _createConn(pongBs) + pinger := newMCPinger(conn) + + err := pinger.Ping() + assert.NoError(t, err) +} + +func TestPingerPingEOF(t *testing.T) { + conn := _createConn(pongBs) + pinger := newMCPinger(conn) + + err := pinger.Ping() + assert.NoError(t, err) + + err = pinger.Ping() + assert.Error(t, err) + + err = errors.Cause(err) + assert.Equal(t, bufio.ErrBufferFull, err) +} + +func TestPingerPing100Ok(t *testing.T) { + conn := _createRepeatConn(pongBs, 100) + pinger := newMCPinger(conn) + + for i := 0; i < 100; i++ { + err := pinger.Ping() + assert.NoError(t, err) + } + + err := pinger.Ping() + assert.Error(t, err) + _causeEqual(t, bufio.ErrBufferFull, err) +} + +func TestPingerClosed(t *testing.T) { + conn := _createRepeatConn(pongBs, 100) + pinger := newMCPinger(conn) + err := pinger.Close() + assert.NoError(t, err) + + err = pinger.Ping() + assert.Error(t, err) + assert.NoError(t, pinger.Close()) +} + +func TestPingerNotReturnPong(t *testing.T) { + conn := _createRepeatConn([]byte("emmmmm...."), 100) + pinger := newMCPinger(conn) + err := pinger.Ping() + assert.Error(t, err) + _causeEqual(t, ErrPingerPong, err) +} diff --git a/proto/memcache/binary/proxy_conn.go b/proto/memcache/binary/proxy_conn.go new file mode 100644 index 00000000..89fa4ad0 --- /dev/null +++ b/proto/memcache/binary/proxy_conn.go @@ -0,0 +1,171 @@ +package binary + +import ( + "bytes" + "encoding/binary" + + "overlord/lib/bufio" + libnet "overlord/lib/net" + "overlord/proto" + + "github.com/pkg/errors" +) + +// memcached binary protocol: https://github.com/memcached/memcached/wiki/BinaryProtocolRevamped +const ( + requestHeaderLen = 24 +) + +type proxyConn struct { + br *bufio.Reader + bw *bufio.Writer + completed bool +} + +// NewProxyConn new a memcache decoder and encode. +func NewProxyConn(rw *libnet.Conn) proto.ProxyConn { + p := &proxyConn{ + // TODO: optimus zero + br: bufio.NewReader(rw, bufio.Get(1024)), + bw: bufio.NewWriter(rw), + completed: true, + } + return p +} + +func (p *proxyConn) Decode(msgs []*proto.Message) ([]*proto.Message, error) { + var err error + // if completed, means that we have parsed all the buffered + // if not completed, we need only to parse the buffered message + if p.completed { + err = p.br.Read() + if err != nil { + return nil, err + } + } + for i := range msgs { + p.completed = false + // set msg type + msgs[i].Type = proto.CacheTypeMemcacheBinary + // decode + err = p.decode(msgs[i]) + if err == bufio.ErrBufferFull { + p.completed = true + msgs[i].Reset() + return msgs[:i], nil + } else if err != nil { + msgs[i].Reset() + return msgs[:i], err + } + msgs[i].MarkStart() + } + return msgs, nil +} + +func (p *proxyConn) decode(m *proto.Message) (err error) { +NEXTGET: + // bufio reset buffer + head, err := p.br.ReadExact(requestHeaderLen) + if err == bufio.ErrBufferFull { + return + } else if err != nil { + err = errors.Wrap(err, "MC decoder while reading text line") + return + } + req := p.request(m) + parseHeader(head, req, true) + if err != nil { + err = errors.Wrap(err, "MC decoder while parse header") + return + } + switch req.rTp { + case RequestTypeSet, RequestTypeAdd, RequestTypeReplace, RequestTypeGet, RequestTypeGetK, + RequestTypeDelete, RequestTypeIncr, RequestTypeDecr, RequestTypeAppend, RequestTypePrepend, RequestTypeTouch, RequestTypeGat: + if err = p.decodeCommon(m, req); err == bufio.ErrBufferFull { + p.br.Advance(-requestHeaderLen) + return + } + return + case RequestTypeGetQ, RequestTypeGetKQ: + if err = p.decodeCommon(m, req); err == bufio.ErrBufferFull { + p.br.Advance(-requestHeaderLen) + return + } + goto NEXTGET + } + err = errors.Wrap(ErrBadRequest, "MC decoder unsupport command") + return +} + +func (p *proxyConn) decodeCommon(m *proto.Message, req *MCRequest) (err error) { + bl := binary.BigEndian.Uint32(req.bodyLen) + body, err := p.br.ReadExact(int(bl)) + if err == bufio.ErrBufferFull { + return + } else if err != nil { + err = errors.Wrap(err, "MC decodeCommon read exact body") + return + } + el := uint8(req.extraLen[0]) + kl := binary.BigEndian.Uint16(req.keyLen) + req.key = body[int(el) : int(el)+int(kl)] + req.data = body + return +} + +func (p *proxyConn) request(m *proto.Message) *MCRequest { + req := m.NextReq() + if req == nil { + req = GetReq() + m.WithRequest(req) + } + return req.(*MCRequest) +} + +func parseHeader(bs []byte, req *MCRequest, isDecode bool) { + req.magic = bs[0] + if isDecode { + req.rTp = RequestType(bs[1]) + } + req.keyLen = bs[2:4] + req.extraLen = bs[4:5] + if !isDecode { + req.status = bs[6:8] + } + req.bodyLen = bs[8:12] + req.opaque = bs[12:16] + req.cas = bs[16:24] +} + +// Encode encode response and write into writer. +func (p *proxyConn) Encode(m *proto.Message) (err error) { + reqs := m.Requests() + for _, req := range reqs { + mcr, ok := req.(*MCRequest) + if !ok { + err = errors.Wrap(ErrAssertReq, "MC Encoder assert request") + return + } + _ = p.bw.Write(magicRespBytes) // NOTE: magic + _ = p.bw.Write(mcr.rTp.Bytes()) + _ = p.bw.Write(mcr.keyLen) + _ = p.bw.Write(mcr.extraLen) + _ = p.bw.Write(zeroBytes) + if err = m.Err(); err != nil { + _ = p.bw.Write(resopnseStatusInternalErrBytes) + } else { + _ = p.bw.Write(mcr.status) + } + _ = p.bw.Write(mcr.bodyLen) + _ = p.bw.Write(mcr.opaque) + _ = p.bw.Write(mcr.cas) + + if err == nil && !bytes.Equal(mcr.bodyLen, zeroFourBytes) { + _ = p.bw.Write(mcr.data) + } + } + if err = p.bw.Flush(); err != nil { + err = errors.Wrap(err, "MC Encoder encode response flush bytes") + } + return +} diff --git a/proto/memcache/binary/proxy_conn_test.go b/proto/memcache/binary/proxy_conn_test.go new file mode 100644 index 00000000..753cce45 --- /dev/null +++ b/proto/memcache/binary/proxy_conn_test.go @@ -0,0 +1,357 @@ +package binary + +import ( + "testing" + + "overlord/proto" + + "github.com/stretchr/testify/assert" +) + +var ( + getTestData = []byte{ + 0x80, // magic + 0x0c, // cmd + 0x00, 0x03, // key len + 0x00, // extra len + 0x00, // data type + 0x00, 0x00, // vbucket + 0x00, 0x00, 0x00, 0x03, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + 0x41, 0x42, 0x43, // key: ABC + } + getRespTestData = []byte{ + 0x81, // magic + 0x0c, // cmd + 0x00, 0x03, // key len + 0x04, // extra len + 0x00, // data type + 0x00, 0x00, // status + 0x00, 0x00, 0x00, 0x0c, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + 0x00, 0x00, 0x00, 0x00, // extra: flag + 0x41, 0x42, 0x43, // key: ABC + 0x41, 0x42, 0x43, 0x44, 0x45, // value: ABCDE + } + getMissRespTestData = []byte{ + 0x81, // magic + 0x0c, // cmd + 0x00, 0x00, // key len + 0x00, // extra len + 0x00, // data type + 0x00, 0x01, // status + 0x00, 0x00, 0x00, 0x00, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + } + setTestData = []byte{ + 0x80, // magic + 0x01, // cmd + 0x00, 0x03, // key len + 0x08, // extra len + 0x00, // data type + 0x00, 0x00, // vbucket + 0x00, 0x00, 0x00, 0x0f, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // extra: flags, expiration + 0x41, 0x42, 0x43, // key: ABC + 0x41, 0x42, 0x43, 0x44, 0x45, // value: ABCDE + } + setRespTestData = []byte{ + 0x81, // magic + 0x01, // cmd + 0x00, 0x00, // key len + 0x00, // extra len + 0x00, // data type + 0x00, 0x00, // status + 0x00, 0x00, 0x00, 0x00, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + } + delTestData = []byte{ + 0x80, // magic + 0x04, // cmd + 0x00, 0x03, // key len + 0x00, // extra len + 0x00, // data type + 0x00, 0x00, // vbucket + 0x00, 0x00, 0x00, 0x03, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + 0x41, 0x42, 0x43, // key: ABC + } + incrTestData = []byte{ + 0x80, // magic + 0x05, // cmd + 0x00, 0x03, // key len + 0x14, // extra len + 0x00, // data type + 0x00, 0x00, // vbucket + 0x00, 0x00, 0x00, 0x17, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // extra: Amount to add / subtract + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // extra: Initial value + 0x00, 0x00, 0x00, 0x00, // extra: Expiration + 0x41, 0x42, 0x43, // key: ABC + } + touchTestData = []byte{ + 0x80, // magic + 0x1c, // cmd + 0x00, 0x03, // key len + 0x00, 0x04, // extra len + 0x00, // data type + 0x00, 0x00, // vbucket + 0x00, 0x00, 0x00, 0x7, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + 0x00, 0x00, 0x00, 0x00, // extra: Expiration + 0x41, 0x42, 0x43, // key: ABC + } + getQTestData = []byte{ + 0x80, // magic + 0x0d, // cmd + 0x00, 0x03, // key len + 0x00, // extra len + 0x00, // data type + 0x00, 0x00, // vbucket + 0x00, 0x00, 0x00, 0x03, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + 0x41, 0x42, 0x43, // key: ABC + + 0x80, // magic + 0x0d, // cmd + 0x00, 0x03, // key len + 0x00, // extra len + 0x00, // data type + 0x00, 0x00, // vbucket + 0x00, 0x00, 0x00, 0x03, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + 0x58, 0x59, 0x5A, // key: XYZ + + 0x80, // magic + 0x0c, // cmd + 0x00, 0x03, // key len + 0x00, 0x00, // extra len + 0x00, // data type + 0x00, 0x00, // vbucket + 0x00, 0x00, 0x00, 0x03, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + 0x61, 0x62, 0x63, // key: abc + } + getQRespTestData = [][]byte{ + []byte{ + 0x81, // magic + 0x0d, // cmd + 0x00, 0x03, // key len + 0x04, // extra len + 0x00, // data type + 0x00, 0x00, // status + 0x00, 0x00, 0x00, 0x0c, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + 0x00, 0x00, 0x00, 0x00, // extra: flag + 0x41, 0x42, 0x43, // key: ABC + 0x41, 0x42, 0x43, 0x44, 0x45, // value: ABCDE + }, + []byte{ + 0x81, // magic + 0x0d, // cmd + 0x00, 0x03, // key len + 0x04, // extra len + 0x00, // data type + 0x00, 0x00, // status + 0x00, 0x00, 0x00, 0x0c, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + 0x00, 0x00, 0x00, 0x00, // extra: flag + 0x58, 0x59, 0x5A, // key: XYZ + 0x56, 0x57, 0x58, 0x59, 0x5a, // value: VWXYZ + }, + []byte{ + 0x81, // magic + 0x0c, // cmd + 0x00, 0x03, // key len + 0x04, // extra len + 0x00, // data type + 0x00, 0x00, // status + 0x00, 0x00, 0x00, 0x0c, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + 0x00, 0x00, 0x00, 0x00, // extra: flag + 0x61, 0x62, 0x63, // key: abc + 0x61, 0x62, 0x63, 0x64, 0x65, // value: abcde + }, + } + getQMissRespTestData = []byte{ + 0x81, // magic + 0x0d, // cmd + 0x00, 0x00, // key len + 0x00, // extra len + 0x00, // data type + 0x00, 0x01, // status + 0x00, 0x00, 0x00, 0x00, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + } + notTestData = []byte{ + 0x80, // magic + 0xff, // cmd + 0x00, 0x00, // key len + 0x00, // extra len + 0x00, // data type + 0x00, 0x00, // vbucket + 0x00, 0x00, 0x00, 0x0, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + } +) + +func TestParseHeader(t *testing.T) { + req := &MCRequest{} + parseHeader(getTestData, req, true) + assert.Equal(t, byte(0x80), req.magic) + assert.Equal(t, []byte{0xc}, req.rTp.Bytes()) + assert.Equal(t, []byte{0x00, 0x03}, req.keyLen) + assert.Equal(t, []byte{0x00}, req.extraLen) + assert.Equal(t, []byte{0x00, 0x00, 0x00, 0x03}, req.bodyLen) + assert.Equal(t, []byte{0x00, 0x00, 0x00, 0x00}, req.opaque) + assert.Equal(t, []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, req.cas) +} + +func TestProxyConnDecodeOk(t *testing.T) { + ts := []struct { + Name string + Data []byte + Err error + Key []byte + Cmd byte + }{ + // set cases + {"SetOk", setTestData, nil, []byte("ABC"), byte(RequestTypeSet)}, + + // Get Gets + {"GetOk", setTestData, nil, []byte("ABC"), byte(RequestTypeGet)}, + + // Delete + {"DeleteOk", delTestData, nil, []byte("ABC"), byte(RequestTypeDelete)}, + + // Incr/Decr + {"IncrOk", incrTestData, nil, []byte("ABC"), byte(RequestTypeIncr)}, + + // Touch + {"TouchOk", touchTestData, nil, []byte("ABC"), byte(RequestTypeTouch)}, + + // GetQ multi get + {"MGetOk", getQTestData, nil, []byte("ABC"), byte(RequestTypeGetQ)}, + + // Not support + {"NotSupportCmd", notTestData, ErrBadRequest, []byte{}, 0xff}, + // {"NotFullLine", "baka 10", ErrBadRequest, "", ""}, + } + + for _, tt := range ts { + t.Run(tt.Name, func(t *testing.T) { + conn := _createConn(tt.Data) + p := NewProxyConn(conn) + mlist := proto.GetMsgSlice(2) + + msgs, err := p.Decode(mlist) + + if tt.Err != nil { + _causeEqual(t, tt.Err, err) + } else { + assert.NoError(t, err) + if err != nil { + m := msgs[0] + assert.NotNil(t, m) + assert.NotNil(t, m.Request()) + assert.Equal(t, tt.Key, string(m.Request().Key())) + assert.Equal(t, tt.Cmd, m.Request().Cmd()) + } + } + }) + } +} + +func _createRespMsg(t *testing.T, req []byte, resps [][]byte) *proto.Message { + conn := _createConn([]byte(req)) + p := NewProxyConn(conn) + mlist := proto.GetMsgSlice(2) + + _, err := p.Decode(mlist) + assert.NoError(t, err) + m := mlist[0] + + if !m.IsBatch() { + nc := _createNodeConn(resps[0]) + batch := proto.NewMsgBatch() + batch.AddMsg(m) + err := nc.ReadBatch(batch) + assert.NoError(t, err) + } else { + subs := m.Batch() + for idx, resp := range resps { + nc := _createNodeConn(resp) + batch := proto.NewMsgBatch() + batch.AddMsg(subs[idx]) + err := nc.ReadBatch(batch) + assert.NoError(t, err) + } + } + + return m +} + +func TestProxyConnEncodeOk(t *testing.T) { + getqResp := append(getQRespTestData[0], getQRespTestData[1]...) + getqResp = append(getqResp, getQRespTestData[2]...) + + getqMissResp := append(getQRespTestData[0], getQMissRespTestData...) + getqMissResp = append(getqMissResp, getQRespTestData[2]...) + + getAllMissResp := append(getQMissRespTestData, getQMissRespTestData...) + getAllMissResp = append(getAllMissResp, getMissRespTestData...) + + ts := []struct { + Name string + Req []byte + Resp [][]byte + Except []byte + }{ + {Name: "SetOk", Req: setTestData, Resp: [][]byte{setRespTestData}, Except: setRespTestData}, + {Name: "GetOk", Req: getTestData, Resp: [][]byte{getRespTestData}, Except: getRespTestData}, + {Name: "GetMultiOk", Req: getQTestData, + Resp: getQRespTestData, + Except: getqResp}, + + {Name: "GetMultiMissOne", Req: getQTestData, + Resp: [][]byte{getQRespTestData[0], getQMissRespTestData, getQRespTestData[2]}, + Except: getqMissResp}, + + {Name: "GetMultiAllMiss", Req: getQTestData, + Resp: [][]byte{getQMissRespTestData, getQMissRespTestData, getMissRespTestData}, + Except: getAllMissResp}, + } + + for _, tt := range ts { + t.Run(tt.Name, func(t *testing.T) { + conn := _createConn(nil) + p := NewProxyConn(conn) + msg := _createRespMsg(t, []byte(tt.Req), tt.Resp) + err := p.Encode(msg) + assert.NoError(t, err) + c := conn.Conn.(*mockConn) + buf := make([]byte, 1024) + size, err := c.wbuf.Read(buf) + assert.NoError(t, err) + assert.Equal(t, tt.Except, buf[:size]) + }) + } +} diff --git a/proto/memcache/binary/request.go b/proto/memcache/binary/request.go new file mode 100644 index 00000000..dbc34aca --- /dev/null +++ b/proto/memcache/binary/request.go @@ -0,0 +1,327 @@ +package binary + +import ( + errs "errors" + "fmt" + "sync" +) + +const ( + magicReq = 0x80 + magicResp = 0x81 +) + +var ( + magicReqBytes = []byte{0x80} + magicRespBytes = []byte{0x81} + zeroBytes = []byte{0x00} + zeroTwoBytes = []byte{0x00, 0x00} + zeroFourBytes = []byte{0x00, 0x00, 0x00, 0x00} + zeroEightBytes = []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} +) + +// RequestType is the protocol-agnostic identifier for the command +type RequestType byte + +// all memcache request type +const ( + RequestTypeGet RequestType = 0x00 + RequestTypeSet RequestType = 0x01 + RequestTypeAdd RequestType = 0x02 + RequestTypeReplace RequestType = 0x03 + RequestTypeDelete RequestType = 0x04 + RequestTypeIncr RequestType = 0x05 + RequestTypeDecr RequestType = 0x06 + RequestTypeGetQ RequestType = 0x09 + RequestTypeNoop RequestType = 0x0a + RequestTypeGetK RequestType = 0x0c + RequestTypeGetKQ RequestType = 0x0d + RequestTypeAppend RequestType = 0x0e + RequestTypePrepend RequestType = 0x0f + // RequestTypeSetQ = 0x11 + // RequestTypeAddQ = 0x12 + // RequestTypeReplaceQ = 0x13 + // RequestTypeIncrQ = 0x15 + // RequestTypeDecrQ = 0x16 + // RequestTypeAppendQ = 0x19 + // RequestTypePrependQ = 0x1a + RequestTypeTouch RequestType = 0x1c + RequestTypeGat RequestType = 0x1d + // RequestTypeGatQ = 0x1e + RequestTypeUnknown RequestType = 0xff +) + +var ( + getBytes = []byte{byte(RequestTypeGet)} + setBytes = []byte{byte(RequestTypeSet)} + addBytes = []byte{byte(RequestTypeAdd)} + replaceBytes = []byte{byte(RequestTypeReplace)} + deleteBytes = []byte{byte(RequestTypeDelete)} + incrBytes = []byte{byte(RequestTypeIncr)} + decrBytes = []byte{byte(RequestTypeDecr)} + getQBytes = []byte{byte(RequestTypeGetQ)} + noopBytes = []byte{byte(RequestTypeNoop)} + getKBytes = []byte{byte(RequestTypeGetK)} + getKQBytes = []byte{byte(RequestTypeGetKQ)} + appendBytes = []byte{byte(RequestTypeAppend)} + prependBytes = []byte{byte(RequestTypePrepend)} + // setQBytes = []byte{byte(RequestTypeSetQ)} + // addQBytes = []byte{byte(RequestTypeAddQ)} + // replaceQBytes = []byte{byte(RequestTypeReplaceQ)} + // incrQBytes = []byte{byte(RequestTypeIncrQ)} + // decrQBytes = []byte{byte(RequestTypeDecrQ)} + // appendQBytes = []byte{byte(RequestTypeAppendQ)} + // prependQBytes = []byte{byte(RequestTypePrependQ)} + touchBytes = []byte{byte(RequestTypeTouch)} + gatBytes = []byte{byte(RequestTypeGat)} + // gatQBytes = []byte{byte(RequestTypeGatQ)} + unknownBytes = []byte{byte(RequestTypeUnknown)} +) + +const ( + getString = "get" + setString = "set" + addString = "add" + replaceString = "replace" + deleteString = "delete" + incrString = "incr" + decrString = "decr" + getQString = "getq" + noopString = "noop" + getKString = "getk" + getKQString = "getkq" + appendString = "append" + prependString = "prepend" + // setQString = "setq" + // addQString = "addq" + // replaceQString = "replaceq" + // incrQString = "incrq" + // decrQString = "decrq" + // appendQString = "appendq" + // prependQString = "prepend" + touchString = "touch" + gatString = "gat" + // gatQString = "gatQ" + unknownString = "unknown" +) + +// Bytes get reqtype bytes. +func (rt RequestType) Bytes() []byte { + switch rt { + case RequestTypeGet: + return getBytes + case RequestTypeSet: + return setBytes + case RequestTypeAdd: + return addBytes + case RequestTypeReplace: + return replaceBytes + case RequestTypeDelete: + return deleteBytes + case RequestTypeIncr: + return incrBytes + case RequestTypeDecr: + return decrBytes + case RequestTypeGetQ: + return getQBytes + case RequestTypeNoop: + return noopBytes + case RequestTypeGetK: + return getKBytes + case RequestTypeGetKQ: + return getKQBytes + case RequestTypeAppend: + return appendBytes + case RequestTypePrepend: + return prependBytes + // case RequestTypeSetQ: + // return setQBytes + // case RequestTypeAddQ: + // return addQBytes + // case RequestTypeReplaceQ: + // return replaceQBytes + // case RequestTypeIncrQ: + // return incrQBytes + // case RequestTypeDecrQ: + // return decrQBytes + // case RequestTypeAppendQ: + // return appendQBytes + // case RequestTypePrependQ: + // return prependQBytes + case RequestTypeTouch: + return touchBytes + case RequestTypeGat: + return gatBytes + // case RequestTypeGatQ: + // return gatQBytes + } + return unknownBytes +} + +// String get reqtype string. +func (rt RequestType) String() string { + switch rt { + case RequestTypeGet: + return getString + case RequestTypeSet: + return setString + case RequestTypeAdd: + return addString + case RequestTypeReplace: + return replaceString + case RequestTypeDelete: + return deleteString + case RequestTypeIncr: + return incrString + case RequestTypeDecr: + return decrString + case RequestTypeGetQ: + return getQString + case RequestTypeNoop: + return noopString + case RequestTypeGetK: + return getKString + case RequestTypeGetKQ: + return getKQString + case RequestTypeAppend: + return appendString + case RequestTypePrepend: + return prependString + // case RequestTypeSetQ: + // return setQString + // case RequestTypeAddQ: + // return addQString + // case RequestTypeReplaceQ: + // return replaceQString + // case RequestTypeIncrQ: + // return incrQString + // case RequestTypeDecrQ: + // return decrQString + // case RequestTypeAppendQ: + // return appendQString + // case RequestTypePrependQ: + // return prependQString + case RequestTypeTouch: + return touchString + case RequestTypeGat: + return gatString + // case RequestTypeGatQ: + // return gatQString + } + return unknownString +} + +// ResopnseStatus is the protocol-agnostic identifier for the response status +type ResopnseStatus byte + +// all memcache response status +const ( + ResopnseStatusNoErr = 0x0000 + ResopnseStatusKeyNotFound = 0x0001 + ResopnseStatusKeyExists = 0x0002 + ResopnseStatusValueTooLarge = 0x0003 + ResopnseStatusInvalidArg = 0x0004 + ResopnseStatusItemNotStored = 0x0005 + ResopnseStatusNonNumeric = 0x0006 + ResopnseStatusUnknownCmd = 0x0081 + ResopnseStatusOutOfMem = 0x0082 + ResopnseStatusNotSupported = 0x0083 + ResopnseStatusInternalErr = 0x0084 + ResopnseStatusBusy = 0x0085 + ResopnseStatusTemporary = 0x0086 +) + +var ( + resopnseStatusInternalErrBytes = []byte{0x00, 0x84} +) + +// errors +var ( + // ERROR means the client sent a nonexistent command name. + ErrError = errs.New("ERROR") + + // CLIENT_ERROR + // means some sort of client error in the input line, i.e. the input + // doesn't conform to the protocol in some way. is a + // human-readable error string. + ErrBadRequest = errs.New("CLIENT_ERROR bad request") + ErrBadLength = errs.New("CLIENT_ERROR length is not a valid integer") + + // SERVER_ERROR + // means some sort of server error prevents the server from carrying + // out the command. is a human-readable error string. In cases + // of severe server errors, which make it impossible to continue + // serving the client (this shouldn't normally happen), the server will + // close the connection after sending the error line. This is the only + // case in which the server closes a connection to a client. + ErrClosed = errs.New("SERVER_ERROR connection closed") + ErrPingerPong = errs.New("SERVER_ERROR Pinger pong unexpected") + ErrAssertReq = errs.New("SERVER_ERROR assert request not ok") + ErrBadResponse = errs.New("SERVER_ERROR bad response") +) + +// MCRequest is the mc client Msg type and data. +type MCRequest struct { + magic byte // Already known, since we're here + rTp RequestType + keyLen []byte + extraLen []byte + // dataType []byte // Always 0 + // vBucket []byte // Not used + status []byte // response status + bodyLen []byte + opaque []byte + cas []byte + + key []byte + data []byte +} + +var msgPool = &sync.Pool{ + New: func() interface{} { + return NewReq() + }, +} + +// GetReq get the msg from pool +func GetReq() *MCRequest { + return msgPool.Get().(*MCRequest) +} + +// NewReq return new mc req. +func NewReq() *MCRequest { + return &MCRequest{} +} + +// Put put req back to pool. +func (r *MCRequest) Put() { + r.rTp = RequestTypeUnknown + r.keyLen = nil + r.extraLen = nil + r.status = nil + r.bodyLen = nil + r.opaque = nil + r.cas = nil + r.key = nil + r.data = nil + msgPool.Put(r) +} + +// CmdString get cmd. +func (r *MCRequest) CmdString() string { + return r.rTp.String() +} + +// Cmd get Msg cmd. +func (r *MCRequest) Cmd() []byte { + return r.rTp.Bytes() +} + +// Key get Msg key. +func (r *MCRequest) Key() []byte { + return r.key +} + +func (r *MCRequest) String() string { + return fmt.Sprintf("type:%s key:%s data:%s", r.rTp.String(), r.key, r.data) +} diff --git a/proto/memcache/binary/request_test.go b/proto/memcache/binary/request_test.go new file mode 100644 index 00000000..571231f8 --- /dev/null +++ b/proto/memcache/binary/request_test.go @@ -0,0 +1,77 @@ +package binary + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +var _allReqTypes = []RequestType{ + RequestTypeGet, + RequestTypeSet, + RequestTypeAdd, + RequestTypeReplace, + RequestTypeDelete, + RequestTypeIncr, + RequestTypeDecr, + RequestTypeGetQ, + RequestTypeNoop, + RequestTypeGetK, + RequestTypeGetKQ, + RequestTypeAppend, + RequestTypePrepend, + RequestTypeTouch, + RequestTypeGat, + RequestTypeUnknown, +} + +func TestRequestTypeBytes(t *testing.T) { + for _, rtype := range _allReqTypes { + assert.Equal(t, []byte{byte(rtype)}, rtype.Bytes()) + } + assert.Equal(t, getString, RequestTypeGet.String()) + assert.Equal(t, setString, RequestTypeSet.String()) + assert.Equal(t, addString, RequestTypeAdd.String()) + assert.Equal(t, replaceString, RequestTypeReplace.String()) + assert.Equal(t, deleteString, RequestTypeDelete.String()) + assert.Equal(t, incrString, RequestTypeIncr.String()) + assert.Equal(t, decrString, RequestTypeDecr.String()) + assert.Equal(t, getQString, RequestTypeGetQ.String()) + assert.Equal(t, noopString, RequestTypeNoop.String()) + assert.Equal(t, getKString, RequestTypeGetK.String()) + assert.Equal(t, getKQString, RequestTypeGetKQ.String()) + assert.Equal(t, appendString, RequestTypeAppend.String()) + assert.Equal(t, prependString, RequestTypePrepend.String()) + assert.Equal(t, touchString, RequestTypeTouch.String()) + assert.Equal(t, gatString, RequestTypeGat.String()) + assert.Equal(t, unknownString, RequestTypeUnknown.String()) +} + +func TestMCRequestFuncsOk(t *testing.T) { + req := &MCRequest{ + rTp: RequestTypeGet, + keyLen: []byte("key"), + extraLen: []byte("extra"), + status: []byte("status"), + bodyLen: []byte("body"), + opaque: []byte("opaque"), + cas: []byte("cas"), + key: []byte("abc"), + data: []byte("\r\n"), + } + assert.Equal(t, []byte{byte(RequestTypeGet)}, req.Cmd()) + assert.Equal(t, "abc", string(req.Key())) + assert.Equal(t, "type:get key:abc data:\r\n", req.String()) + + req.Put() + + assert.Equal(t, RequestTypeUnknown, req.rTp) + assert.Nil(t, req.keyLen) + assert.Nil(t, req.extraLen) + assert.Nil(t, req.status) + assert.Nil(t, req.bodyLen) + assert.Nil(t, req.opaque) + assert.Nil(t, req.cas) + assert.Nil(t, req.key) + assert.Nil(t, req.data) +} diff --git a/proto/memcache/node_conn.go b/proto/memcache/node_conn.go index 66695f5e..ad2e359a 100644 --- a/proto/memcache/node_conn.go +++ b/proto/memcache/node_conn.go @@ -74,19 +74,19 @@ func (n *nodeConn) WriteBatch(mb *proto.MsgBatch) (err error) { } if err = n.bw.Flush(); err != nil { - err = errors.Wrap(err, "MC Handler handle flush Msg bytes") + err = errors.Wrap(err, "MC Writer handle flush Msg bytes") } return } func (n *nodeConn) write(m *proto.Message) (err error) { if n.Closed() { - err = errors.Wrap(ErrClosed, "MC Handler handle Msg") + err = errors.Wrap(ErrClosed, "MC Writer conn closed") return } mcr, ok := m.Request().(*MCRequest) if !ok { - err = errors.Wrap(ErrAssertMsg, "MC Handler handle assert MCMsg") + err = errors.Wrap(ErrAssertReq, "MC Writer assert request") return } _ = n.bw.Write(mcr.rTp.Bytes()) @@ -105,12 +105,11 @@ func (n *nodeConn) write(m *proto.Message) (err error) { func (n *nodeConn) ReadBatch(mb *proto.MsgBatch) (err error) { if n.Closed() { - err = errors.Wrap(ErrClosed, "MC Handler handle Msg") + err = errors.Wrap(ErrClosed, "MC Reader read batch message") return } defer n.br.ResetBuffer(nil) n.br.ResetBuffer(mb.Buffer()) - var ( size int cursor int @@ -121,19 +120,18 @@ func (n *nodeConn) ReadBatch(mb *proto.MsgBatch) (err error) { ok bool ) m = mb.Nth(nth) + mcr, ok = m.Request().(*MCRequest) if !ok { - err = errors.Wrap(ErrAssertMsg, "MC Handler handle assert MCMsg") + err = errors.Wrap(ErrAssertReq, "MC Writer assert request") return } - for { err = n.br.Read() if err != nil { - err = errors.Wrap(err, "node conn while read") + err = errors.Wrap(err, "MC Reader node conn while read") return } - for { size, err = n.fillMCRequest(mcr, n.br.Buffer().Bytes()[cursor:]) if err == bufio.ErrBufferFull { @@ -151,9 +149,10 @@ func (n *nodeConn) ReadBatch(mb *proto.MsgBatch) (err error) { if m == nil { return } + mcr, ok = m.Request().(*MCRequest) if !ok { - err = errors.Wrap(ErrAssertMsg, "MC Handler handle assert MCMsg") + err = errors.Wrap(ErrAssertReq, "MC Writer assert request") return } } diff --git a/proto/memcache/node_conn_test.go b/proto/memcache/node_conn_test.go index a699bad4..1d45c2f3 100644 --- a/proto/memcache/node_conn_test.go +++ b/proto/memcache/node_conn_test.go @@ -3,11 +3,10 @@ package memcache import ( "fmt" "io" + "net" "testing" "time" - "net" - "overlord/lib/bufio" "overlord/proto" @@ -75,6 +74,45 @@ func TestNodeConnWriteOk(t *testing.T) { } } +func TestNodeConnWriteBatchOk(t *testing.T) { + ts := []struct { + rtype RequestType + key string + data string + except string + }{ + { + rtype: RequestTypeGet, key: "mykey", data: "\r\n", + except: "get mykey\r\n", + }, + { + rtype: RequestTypeSet, key: "mykey", data: " 0 0 1\r\nb\r\n", + except: "set mykey 0 0 1\r\nb\r\n", + }, + } + for _, tt := range ts { + + t.Run(fmt.Sprintf("%v ok", tt.rtype), func(t *testing.T) { + req := _createReqMsg(tt.rtype, []byte(tt.key), []byte(tt.data)) + nc := _createNodeConn(nil) + batch := proto.NewMsgBatch() + batch.AddMsg(req) + + err := nc.WriteBatch(batch) + assert.NoError(t, err) + + c, ok := nc.conn.Conn.(*mockConn) + assert.True(t, ok) + + buf := make([]byte, 1024) + size, err := c.wbuf.Read(buf) + assert.NoError(t, err) + assert.Equal(t, tt.except, string(buf[:size])) + }) + + } +} + func TestNodeConnWriteClosed(t *testing.T) { req := _createReqMsg(RequestTypeGet, []byte("abc"), []byte(" \r\n")) nc := _createNodeConn(nil) @@ -116,7 +154,7 @@ func TestNodeConnWriteTypeAssertFail(t *testing.T) { err := nc.write(req) nc.bw.Flush() assert.Error(t, err) - _causeEqual(t, ErrAssertMsg, err) + _causeEqual(t, ErrAssertReq, err) } func TestNodeConnReadClosed(t *testing.T) { @@ -144,17 +182,17 @@ func TestNodeConnReadOk(t *testing.T) { }{ { suffix: "404", - rtype: RequestTypeGet, key: "mykey", data: " \r\n", + rtype: RequestTypeGet, key: "mykey", data: "\r\n", cData: "END\r\n", except: "END\r\n", }, { suffix: "Ok", - rtype: RequestTypeGet, key: "mykey", data: " \r\n", + rtype: RequestTypeGet, key: "mykey", data: "\r\n", cData: "VALUE mykey 0 1\r\na\r\nEND\r\n", except: "VALUE mykey 0 1\r\na\r\nEND\r\n", }, { suffix: "Ok", - rtype: RequestTypeSet, key: "mykey", data: "0 0 1\r\nb\r\n", + rtype: RequestTypeSet, key: "mykey", data: " 0 0 1\r\nb\r\n", cData: "STORED\r\n", except: "STORED\r\n", }, } @@ -167,7 +205,10 @@ func TestNodeConnReadOk(t *testing.T) { batch.AddMsg(req) err := nc.ReadBatch(batch) assert.NoError(t, err) - assert.Equal(t, tt.except, string(req.Request().Resp())) + + mcr, ok := req.Request().(*MCRequest) + assert.Equal(t, true, ok) + assert.Equal(t, tt.except, string(mcr.data)) }) } @@ -180,7 +221,7 @@ func TestNodeConnAssertError(t *testing.T) { batch := proto.NewMsgBatch() batch.AddMsg(req) err := nc.ReadBatch(batch) - _causeEqual(t, ErrAssertMsg, err) + _causeEqual(t, ErrAssertReq, err) } func TestNocdConnPingOk(t *testing.T) { diff --git a/proto/memcache/pinger.go b/proto/memcache/pinger.go index 6cb35916..869a79ee 100644 --- a/proto/memcache/pinger.go +++ b/proto/memcache/pinger.go @@ -12,10 +12,12 @@ import ( const ( pingBufferSize = 32 - ping = "set _ping 0 0 4\r\npong\r\n" ) -var pong = []byte("STORED\r\n") +var ( + ping = []byte("set _ping 0 0 4\r\npong\r\n") + pong = []byte("STORED\r\n") +) type mcPinger struct { conn *libnet.Conn @@ -37,7 +39,7 @@ func (m *mcPinger) Ping() (err error) { err = ErrPingerPong return } - if err = m.bw.WriteString(ping); err != nil { + if err = m.bw.Write(ping); err != nil { err = errors.Wrap(err, "MC ping write") return } diff --git a/proto/memcache/proxy_conn.go b/proto/memcache/proxy_conn.go index cae2c955..8c370b68 100644 --- a/proto/memcache/proxy_conn.go +++ b/proto/memcache/proxy_conn.go @@ -19,6 +19,10 @@ const ( serverErrorPrefix = "SERVER_ERROR " ) +var ( + serverErrorBytes = []byte(serverErrorPrefix) +) + type proxyConn struct { br *bufio.Reader bw *bufio.Writer @@ -340,32 +344,34 @@ func (p *proxyConn) Encode(m *proto.Message) (err error) { if err = m.Err(); err != nil { se := errors.Cause(err).Error() if !strings.HasPrefix(se, errorPrefix) && !strings.HasPrefix(se, clientErrorPrefix) && !strings.HasPrefix(se, serverErrorPrefix) { // NOTE: the mc error protocol - _ = p.bw.WriteString(serverErrorPrefix) + _ = p.bw.Write(serverErrorBytes) } - _ = p.bw.WriteString(se) + _ = p.bw.Write([]byte(se)) _ = p.bw.Write(crlfBytes) } else { - mcr, ok := m.Request().(*MCRequest) - if !ok { - _ = p.bw.WriteString(serverErrorPrefix) - _ = p.bw.WriteString(ErrAssertMsg.Error()) - _ = p.bw.Write(crlfBytes) - } else { - res := m.Response() - _, ok := withValueTypes[mcr.rTp] - trimEnd := ok && m.IsBatch() - for _, bs := range res { - if trimEnd { - bs = bytes.TrimSuffix(bs, endBytes) + var bs []byte + reqs := m.Requests() + for _, req := range reqs { + mcr, ok := req.(*MCRequest) + if !ok { + _ = p.bw.Write(serverErrorBytes) + _ = p.bw.Write([]byte(ErrAssertReq.Error())) + _ = p.bw.Write(crlfBytes) + } else { + _, ok := withValueTypes[mcr.rTp] + if ok && m.IsBatch() { + bs = bytes.TrimSuffix(mcr.data, endBytes) + } else { + bs = mcr.data } if len(bs) == 0 { continue } _ = p.bw.Write(bs) } - if trimEnd { - _ = p.bw.Write(endBytes) - } + } + if m.IsBatch() { + _ = p.bw.Write(endBytes) } } if err = p.bw.Flush(); err != nil { diff --git a/proto/memcache/request.go b/proto/memcache/request.go index ac50d349..248ca3f1 100644 --- a/proto/memcache/request.go +++ b/proto/memcache/request.go @@ -182,11 +182,10 @@ var ( // serving the client (this shouldn't normally happen), the server will // close the connection after sending the error line. This is the only // case in which the server closes a connection to a client. - ErrClosed = errs.New("SERVER_ERROR connection closed") - ErrPingerPong = errs.New("SERVER_ERROR Pinger pong unexpected") - ErrAssertMsg = errs.New("SERVER_ERROR assert MC Msg not ok") - ErrAssertResponse = errs.New("SERVER_ERROR assert MC response not ok") - ErrBadResponse = errs.New("SERVER_ERROR bad response") + ErrClosed = errs.New("SERVER_ERROR connection closed") + ErrPingerPong = errs.New("SERVER_ERROR Pinger pong unexpected") + ErrAssertReq = errs.New("SERVER_ERROR assert request not ok") + ErrBadResponse = errs.New("SERVER_ERROR bad response") ) // MCRequest is the mc client Msg type and data. @@ -248,11 +247,6 @@ func (r *MCRequest) Key() []byte { return r.key } -// Resp get response data. -func (r *MCRequest) Resp() []byte { - return r.data -} - func (r *MCRequest) String() string { return fmt.Sprintf("type:%s key:%s data:%s", r.rTp.Bytes(), r.key, r.data) } diff --git a/proto/memcache/request_test.go b/proto/memcache/request_test.go index 6e0af6d8..564ccd2c 100644 --- a/proto/memcache/request_test.go +++ b/proto/memcache/request_test.go @@ -41,4 +41,10 @@ func TestMCRequestFuncsOk(t *testing.T) { assert.Equal(t, []byte("get"), req.Cmd()) assert.Equal(t, "abc", string(req.Key())) assert.Equal(t, "type:get key:abc data:\r\n", req.String()) + + req.Put() + + assert.Equal(t, RequestTypeUnknown, req.rTp) + assert.Nil(t, req.key) + assert.Nil(t, req.data) } diff --git a/proto/message.go b/proto/message.go index 721c6437..aac336db 100644 --- a/proto/message.go +++ b/proto/message.go @@ -48,10 +48,9 @@ func PutMsg(msg *Message) { type Message struct { Type CacheType - req []Request - reqn int - subs []*Message - subResps [][]byte + req []Request + reqn int + subs []*Message // Start Time, Write Time, ReadTime, EndTime st, wt, rt, et time.Time @@ -68,7 +67,6 @@ func NewMessage() *Message { func (m *Message) Reset() { m.Type = CacheTypeUnknown m.reqn = 0 - m.subResps = m.subResps[:0] m.st, m.wt, m.rt, m.et = defaultTime, defaultTime, defaultTime, defaultTime m.err = nil } @@ -162,6 +160,14 @@ func (m *Message) Request() Request { return nil } +// Requests return all request. +func (m *Message) Requests() []Request { + if m.reqn == 0 { + return nil + } + return m.req[:m.reqn] +} + // IsBatch returns whether or not batch. func (m *Message) IsBatch() bool { return m.reqn > 1 @@ -188,14 +194,6 @@ func (m *Message) Batch() []*Message { return m.subs[:slen] } -// Response return all response bytes. -func (m *Message) Response() [][]byte { - for i := 0; i < m.reqn; i++ { - m.subResps = append(m.subResps, m.req[i].Resp()) - } - return m.subResps -} - // Err returns error. func (m *Message) Err() error { return m.err diff --git a/proto/types.go b/proto/types.go index ec7de46a..6261cb55 100644 --- a/proto/types.go +++ b/proto/types.go @@ -14,9 +14,10 @@ type CacheType string // Cache type: memcache or redis. const ( - CacheTypeUnknown CacheType = "unknown" - CacheTypeMemcache CacheType = "memcache" - CacheTypeRedis CacheType = "redis" + CacheTypeUnknown CacheType = "unknown" + CacheTypeMemcache CacheType = "memcache" + CacheTypeMemcacheBinary CacheType = "memcache_binary" + CacheTypeRedis CacheType = "redis" ) // Request request interface. @@ -24,7 +25,6 @@ type Request interface { CmdString() string Cmd() []byte Key() []byte - Resp() []byte Put() } diff --git a/proxy/cluster.go b/proxy/cluster.go index b1fa9f9d..01a76b9f 100644 --- a/proxy/cluster.go +++ b/proxy/cluster.go @@ -16,6 +16,8 @@ import ( "overlord/lib/log" "overlord/proto" "overlord/proto/memcache" + mcbin "overlord/proto/memcache/binary" + "github.com/pkg/errors" ) @@ -328,6 +330,8 @@ func newNodeConn(cc *ClusterConfig, addr string) proto.NodeConn { switch cc.CacheType { case proto.CacheTypeMemcache: return memcache.NewNodeConn(cc.Name, addr, dto, rto, wto) + case proto.CacheTypeMemcacheBinary: + return mcbin.NewNodeConn(cc.Name, addr, dto, rto, wto) case proto.CacheTypeRedis: // TODO(felix): support redis default: diff --git a/proxy/handler.go b/proxy/handler.go index 9e77b510..e071bbc9 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -12,6 +12,7 @@ import ( "overlord/lib/prom" "overlord/proto" "overlord/proto/memcache" + mcbin "overlord/proto/memcache/binary" ) const ( @@ -57,6 +58,8 @@ func NewHandler(ctx context.Context, c *Config, conn net.Conn, cluster *Cluster) switch cluster.cc.CacheType { case proto.CacheTypeMemcache: h.pc = memcache.NewProxyConn(h.conn) + case proto.CacheTypeMemcacheBinary: + h.pc = mcbin.NewProxyConn(h.conn) case proto.CacheTypeRedis: // TODO(felix): support redis. default: diff --git a/proxy/proxy.go b/proxy/proxy.go index 279e88b4..161e8525 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -11,6 +11,7 @@ import ( libnet "overlord/lib/net" "overlord/proto" "overlord/proto/memcache" + mcbin "overlord/proto/memcache/binary" "github.com/pkg/errors" ) @@ -93,6 +94,11 @@ func (p *Proxy) serve(cc *ClusterConfig) { m := proto.ErrMessage(ErrProxyMoreMaxConns) _ = encoder.Encode(m) _ = conn.Close() + case proto.CacheTypeMemcacheBinary: + encoder := mcbin.NewProxyConn(libnet.NewConn(conn, time.Second, time.Second)) + m := proto.ErrMessage(ErrProxyMoreMaxConns) + _ = encoder.Encode(m) + _ = conn.Close() case proto.CacheTypeRedis: // TODO(felix): support redis. default: diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index f4008493..9827706b 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -3,6 +3,7 @@ package proxy_test import ( "bufio" "bytes" + "encoding/binary" "net" "testing" "time" @@ -17,7 +18,7 @@ import ( var ( ccs = []*proxy.ClusterConfig{ &proxy.ClusterConfig{ - Name: "test-cluster", + Name: "mc-cluster", HashMethod: "sha1", HashDistribution: "ketama", HashTag: "", @@ -29,12 +30,29 @@ var ( ReadTimeout: 1000, NodeConnections: 10, WriteTimeout: 1000, - // PoolActive: 50, - // PoolIdle: 10, - // PoolIdleTimeout: 100000, - // PoolGetWait: true, - PingFailLimit: 3, - PingAutoEject: false, + PingFailLimit: 3, + PingAutoEject: false, + Servers: []string{ + "127.0.0.1:11211:10", + // "127.0.0.1:11212:10", + // "127.0.0.1:11213:10", + }, + }, + &proxy.ClusterConfig{ + Name: "mcbin-cluster", + HashMethod: "sha1", + HashDistribution: "ketama", + HashTag: "", + CacheType: proto.CacheType("memcache_binary"), + ListenProto: "tcp", + ListenAddr: "127.0.0.1:21212", + RedisAuth: "", + DialTimeout: 1000, + ReadTimeout: 1000, + NodeConnections: 10, + WriteTimeout: 1000, + PingFailLimit: 3, + PingAutoEject: false, Servers: []string{ "127.0.0.1:11211:10", // "127.0.0.1:11212:10", @@ -67,6 +85,35 @@ var ( []byte("gats 123456 a_11 a_22 a_33\r\n"), []byte("noexist a_11\r\n"), } + + cmdBins = [][]byte{ + []byte{ + 0x80, // magic + 0x01, // cmd + 0x00, 0x03, // key len + 0x08, // extra len + 0x00, // data type + 0x00, 0x00, // vbucket + 0x00, 0x00, 0x00, 0x10, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // extra: flags, expiration + 0x41, 0x42, 0x43, // key: ABC + 0x41, 0x42, 0x43, 0x44, 0x45, // value: ABCDE + }, + []byte{ + 0x80, // magic + 0x0c, // cmd + 0x00, 0x03, // key len + 0x00, // extra len + 0x00, // data type + 0x00, 0x00, // vbucket + 0x00, 0x00, 0x00, 0x03, // body len + 0x00, 0x00, 0x00, 0x00, // opaque + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cas + 0x41, 0x42, 0x43, // key: ABC + }, + } ) func init() { @@ -117,13 +164,52 @@ func testCmd(t testing.TB, cmds ...[]byte) { bs = append(bs, bs2...) } } - // t.Logf("read string:%s", bs) + } +} + +func testCmdBin(t testing.TB, cmds ...[]byte) { + conn, err := net.DialTimeout("tcp", "127.0.0.1:21212", time.Second) + if err != nil { + t.Errorf("net dial error:%v", err) + } + defer conn.Close() + br := bufio.NewReader(conn) + for _, cmd := range cmds { + conn.SetWriteDeadline(time.Now().Add(time.Second)) + if _, err := conn.Write(cmd); err != nil { + t.Errorf("conn write cmd:%s error:%v", cmd, err) + } + conn.SetReadDeadline(time.Now().Add(time.Second)) + bs := make([]byte, 24) + if n, err := br.Read(bs); err != nil || n != 24 { + t.Errorf("conn read cmd:%x error:%s resp:%x", cmd[1], err, bs) + continue + } + if bytes.Equal(bs[6:8], []byte{0x00, 0x01}) { + // key not found + continue + } + if !bytes.Equal(bs[6:8], []byte{0x00, 0x00}) { + t.Errorf("conn error:%s %s", bs, cmd) + continue + } + bl := binary.BigEndian.Uint32(bs[8:12]) + if bl > 0 { + body := make([]byte, bl) + n, err := br.Read(body) + if err != nil { + t.Errorf("conn read body error: %v", err) + } else if n != int(bl) { + t.Errorf("conn read body size(%d) not equal(%d)", n, bl) + } + } } } func TestProxyFull(t *testing.T) { for i := 0; i < 100; i++ { testCmd(t, cmds[0], cmds[1], cmds[2], cmds[10], cmds[11]) + testCmdBin(t, cmdBins[0], cmdBins[1]) } } @@ -178,6 +264,7 @@ func BenchmarkCmdSet(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { testCmd(b, cmds[0]) + testCmdBin(b, cmdBins[0]) } }) } @@ -186,6 +273,7 @@ func BenchmarkCmdGet(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { testCmd(b, cmds[1]) + testCmdBin(b, cmdBins[1]) } }) }