diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 00000000..ab5eeaad --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,4 @@ +coverage: + status: + project: on + patch: off diff --git a/.gitignore b/.gitignore index 303a81ea..cb7b05b3 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ .glide/ cmd/proxy/proxy +coverage.txt diff --git a/.travis.yml b/.travis.yml index 37bd95e5..b4c7162b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,6 @@ language: go go: - - 1.9.x - 1.10.x go_import_path: overlord diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c006d1d..98b458ef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # Overlord +## Version 1.2.0 +1. add redis protocol support. + ## Version 1.1.0 1. add memcache binary protocol support. 2. add conf file check diff --git a/README.md b/README.md index f5997d35..7b23b325 100644 --- a/README.md +++ b/README.md @@ -35,12 +35,16 @@ go build #### Test ```shell +# test memcache echo -e "set a_11 0 0 5\r\nhello\r\n" | nc 127.0.0.1 21211 # STORED echo -e "get a_11\r\n" | nc 127.0.0.1 21211 # VALUE a_11 0 5 # hello # END + +# test redis +python ./scripts/validate_redis_features.py # require fakeredis==0.11.0 redis==2.10.6 gevent==1.3.5 ``` Congratulations! You've just run the overlord proxy. @@ -48,7 +52,7 @@ Congratulations! You've just run the overlord proxy. 
## Features - [x] support memcache protocol -- [ ] support redis protocol +- [x] support redis protocol - [x] connection pool for reduce number to backend caching servers - [x] keepalive & failover - [x] hash tag: specify the part of the key used for hashing diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index e6acbd5b..83a172b0 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -82,6 +82,8 @@ func main() { go http.ListenAndServe(c.Pprof, nil) if c.Proxy.UseMetrics { prom.Init() + } else { + prom.On = false } } // new proxy diff --git a/cmd/proxy/proxy-cluster-example.toml b/cmd/proxy/proxy-cluster-example.toml index 6ca2c199..936b1cd3 100644 --- a/cmd/proxy/proxy-cluster-example.toml +++ b/cmd/proxy/proxy-cluster-example.toml @@ -1,13 +1,13 @@ [[clusters]] # This be used to specify the name of cache cluster. -name = "test-cluster" +name = "test-mc" # The name of the hash function. Possible values are: sha1. -hash_method = "sha1" +hash_method = "fnv1a_64" # The key distribution mode. Possible values are: ketama. hash_distribution = "ketama" # A two character string that specifies the part of the key used for hashing. Eg "{}". hash_tag = "" -# cache type: memcache | redis +# cache type: memcache | memcache_binary | redis cache_type = "memcache" # proxy listen proto: tcp | unix listen_proto = "tcp" @@ -29,5 +29,39 @@ ping_fail_limit = 3 ping_auto_eject = true # A list of server address, port and weight (name:port:weight or ip:port:weight) for this server pool. Also you can use alias name like: ip:port:weight alias. servers = [ - "127.0.0.1:11211:10", + "127.0.0.1:11211:1", ] + +[[clusters]] +# This be used to specify the name of cache cluster. +name = "test-redis" +# The name of the hash function. Possible values are: sha1. +hash_method = "fnv1a_64" +# The key distribution mode. Possible values are: ketama. +hash_distribution = "ketama" +# A two character string that specifies the part of the key used for hashing. Eg "{}". 
+hash_tag = "" +# cache type: memcache | redis +cache_type = "redis" +# proxy listen proto: tcp | unix +listen_proto = "tcp" +# proxy listen addr: tcp addr | unix sock path +listen_addr = "0.0.0.0:26379" +# Authenticate to the Redis server on connect. +redis_auth = "" +# The dial timeout value in msec that we wait for to establish a connection to the server. By default, we wait indefinitely. +dial_timeout = 1000 +# The read timeout value in msec that we wait for to receive a response from a server. By default, we wait indefinitely. +read_timeout = 1000 +# The write timeout value in msec that we wait for to write a response to a server. By default, we wait indefinitely. +write_timeout = 1000 +# The number of connections that can be opened to each server. By default, we open at most 1 server connection. +node_connections = 2 +# The number of consecutive failures on a server that would lead to it being temporarily ejected when auto_eject is set to true. Defaults to 3. +ping_fail_limit = 3 +# A boolean value that controls if server should be ejected temporarily when it fails consecutively ping_fail_limit times. +ping_auto_eject = false +# A list of server address, port and weight (name:port:weight or ip:port:weight) for this server pool. Also you can use alias name like: ip:port:weight alias. +servers = [ + "127.0.0.1:6379:1", +] \ No newline at end of file diff --git a/codecov.sh b/codecov.sh index 74ad8b01..834b0d1f 100755 --- a/codecov.sh +++ b/codecov.sh @@ -4,9 +4,10 @@ set -e echo "" > coverage.txt for d in $(go list ./... | grep -v vendor | grep -v cmd); do + echo "testing for $d ..." 
go test -coverprofile=profile.out -covermode=atomic $d if [ -f profile.out ]; then cat profile.out >> coverage.txt rm profile.out fi -done \ No newline at end of file +done diff --git a/lib/bufio/buffer.go b/lib/bufio/buffer.go index 46c06c90..658b004e 100644 --- a/lib/bufio/buffer.go +++ b/lib/bufio/buffer.go @@ -36,9 +36,7 @@ func init() { func initBufPool(idx int) { pools[idx] = &sync.Pool{ New: func() interface{} { - return &Buffer{ - buf: make([]byte, sizes[idx]), - } + return NewBuffer(sizes[idx]) }, } } @@ -49,6 +47,11 @@ type Buffer struct { r, w int } +// NewBuffer new buffer. +func NewBuffer(size int) *Buffer { + return &Buffer{buf: make([]byte, size)} +} + // Bytes return the bytes readed func (b *Buffer) Bytes() []byte { return b.buf[b.r:b.w] @@ -97,8 +100,7 @@ func Get(size int) *Buffer { } i := sort.SearchInts(sizes, size) if i >= len(pools) { - b := &Buffer{buf: make([]byte, size)} - return b + return NewBuffer(size) } b := pools[i].Get().(*Buffer) b.Reset() diff --git a/lib/bufio/io.go b/lib/bufio/io.go index 6470d5b2..281e8724 100644 --- a/lib/bufio/io.go +++ b/lib/bufio/io.go @@ -9,13 +9,13 @@ import ( libnet "overlord/lib/net" ) -const ( - maxBuffered = 64 +var ( + // ErrBufferFull err buffer full + ErrBufferFull = bufio.ErrBufferFull ) -// ErrProxy var ( - ErrBufferFull = bufio.ErrBufferFull + crlfBytes = []byte("\r\n") ) // Reader implements buffering for an io.Reader object. @@ -47,6 +47,16 @@ func (r *Reader) Advance(n int) { r.b.Advance(n) } +// Mark return buf read pos. +func (r *Reader) Mark() int { + return r.b.r +} + +// AdvanceTo reset buffer read pos. 
+func (r *Reader) AdvanceTo(mark int) { + r.Advance(mark - r.b.r) +} + // Buffer will return the reference of local buffer func (r *Reader) Buffer() *Buffer { return r.b @@ -57,21 +67,31 @@ func (r *Reader) Read() error { if r.err != nil { return r.err } - if r.b.buffered() == r.b.len() { r.b.grow() } - if r.b.w == r.b.len() { r.b.shrink() } - if err := r.fill(); err != io.EOF { return err } return nil } +// ReadLine will read until meet the first crlf bytes. +func (r *Reader) ReadLine() (line []byte, err error) { + idx := bytes.Index(r.b.buf[r.b.r:r.b.w], crlfBytes) + if idx == -1 { + line = nil + err = ErrBufferFull + return + } + line = r.b.buf[r.b.r : r.b.r+idx+2] + r.b.r += idx + 2 + return +} + // ReadSlice will read until the delim or return ErrBufferFull. // It never contains any I/O operation func (r *Reader) ReadSlice(delim byte) (data []byte, err error) { @@ -121,73 +141,9 @@ func (r *Reader) ResetBuffer(b *Buffer) { r.b.w = b.w } -// ReadUntil reads until the first occurrence of delim in the input, -// returning a slice pointing at the bytes in the buffer. -// The bytes stop being valid at the next read. -// If ReadUntil encounters an error before finding a delimiter, -// it returns all the data in the buffer and the error itself (often io.EOF). -// ReadUntil returns err != nil if and only if line does not end in delim. -func (r *Reader) ReadUntil(delim byte) ([]byte, error) { - if r.err != nil { - return nil, r.err - } - for { - var index = bytes.IndexByte(r.b.buf[r.b.r:r.b.w], delim) - if index >= 0 { - limit := r.b.r + index + 1 - slice := r.b.buf[r.b.r:limit] - r.b.r = limit - return slice, nil - } - if r.b.w >= r.b.len() { - r.b.grow() - } - err := r.fill() - if err == io.EOF && r.b.buffered() > 0 { - data := r.b.buf[r.b.r:r.b.w] - r.b.r = r.b.w - return data, nil - } else if err != nil { - r.err = err - return nil, err - } - } -} - -// ReadFull reads exactly n bytes from r into buf. 
-// It returns the number of bytes copied and an error if fewer bytes were read. -// The error is EOF only if no bytes were read. -// If an EOF happens after reading some but not all the bytes, -// ReadFull returns ErrUnexpectedEOF. -// On return, n == len(buf) if and only if err == nil. -func (r *Reader) ReadFull(n int) ([]byte, error) { - if n <= 0 { - return nil, nil - } - if r.err != nil { - return nil, r.err - } - for { - if r.b.buffered() >= n { - bs := r.b.buf[r.b.r : r.b.r+n] - r.b.r += n - return bs, nil - } - maxCanRead := r.b.len() - r.b.w + r.b.buffered() - if maxCanRead < n { - r.b.grow() - } - err := r.fill() - if err == io.EOF && r.b.buffered() > 0 { - data := r.b.buf[r.b.r:r.b.w] - r.b.r = r.b.w - return data, nil - } else if err != nil { - r.err = err - return nil, err - } - } -} +const ( + maxWritevSize = 1024 +) // Writer implements buffering for an io.Writer object. // If an error occurs writing to a Writer, no more data will be @@ -196,17 +152,17 @@ func (r *Reader) ReadFull(n int) ([]byte, error) { // Flush method to guarantee all data has been forwarded to // the underlying io.Writer. type Writer struct { - wr *libnet.Conn - bufsp net.Buffers - bufs [][]byte - cursor int + wr *libnet.Conn + bufsp net.Buffers + bufs [][]byte + cnt int err error } // NewWriter returns a new Writer whose buffer has the default size. func NewWriter(wr *libnet.Conn) *Writer { - return &Writer{wr: wr, bufs: make([][]byte, maxBuffered)} + return &Writer{wr: wr, bufs: make([][]byte, 0, maxWritevSize)} } // Flush writes any buffered data to the underlying io.Writer. 
@@ -217,12 +173,13 @@ func (w *Writer) Flush() error { if len(w.bufs) == 0 { return nil } - w.bufsp = net.Buffers(w.bufs[:w.cursor]) + w.bufsp = net.Buffers(w.bufs[:w.cnt]) _, err := w.wr.Writev(&w.bufsp) if err != nil { w.err = err } - w.cursor = 0 + w.bufs = w.bufs[:0] + w.cnt = 0 return w.err } @@ -237,11 +194,10 @@ func (w *Writer) Write(p []byte) (err error) { if p == nil { return nil } - - if w.cursor+1 == maxBuffered { - w.Flush() + w.bufs = append(w.bufs, p) + w.cnt++ + if len(w.bufs) == maxWritevSize { + err = w.Flush() } - w.bufs[w.cursor] = p - w.cursor = (w.cursor + 1) % maxBuffered - return nil + return } diff --git a/lib/bufio/io_test.go b/lib/bufio/io_test.go index f1f825f1..51a0a3a4 100644 --- a/lib/bufio/io_test.go +++ b/lib/bufio/io_test.go @@ -3,12 +3,12 @@ package bufio import ( "bytes" "errors" - "io" "net" - libnet "overlord/lib/net" "testing" "time" + libnet "overlord/lib/net" + "github.com/stretchr/testify/assert" ) @@ -34,6 +34,15 @@ func TestReaderAdvance(t *testing.T) { assert.Len(t, buf.Bytes(), 502) b.Advance(-10) assert.Len(t, buf.Bytes(), 512) + + b.ReadExact(10) + m := b.Mark() + assert.Equal(t, 10, m) + + b.AdvanceTo(5) + m = b.Mark() + assert.Equal(t, 5, m) + assert.Equal(t, 5, b.Buffer().r) } func TestReaderRead(t *testing.T) { @@ -48,18 +57,6 @@ func TestReaderRead(t *testing.T) { assert.EqualError(t, err, "some error") } -func TestReaderReadUntil(t *testing.T) { - bts := _genData() - b := NewReader(bytes.NewBuffer(bts), Get(defaultBufferSize)) - data, err := b.ReadUntil(fbyte) - assert.NoError(t, err) - assert.Len(t, data, 5*3*100) - - b.err = errors.New("some error") - _, err = b.ReadUntil(fbyte) - assert.EqualError(t, err, "some error") -} - func TestReaderReadSlice(t *testing.T) { bts := _genData() @@ -70,20 +67,15 @@ func TestReaderReadSlice(t *testing.T) { assert.Len(t, data, 3) _, err = b.ReadSlice('\n') - assert.EqualError(t, err, "bufio: buffer full") + assert.Equal(t, ErrBufferFull, err) } -func TestReaderReadFull(t 
*testing.T) { - bts := _genData() - - b := NewReader(bytes.NewBuffer(bts), Get(defaultBufferSize)) - data, err := b.ReadFull(1200) +func TestReaderReadLine(t *testing.T) { + b := NewReader(bytes.NewBuffer([]byte("abcd\r\nabc")), Get(defaultBufferSize)) + b.Read() + data, err := b.ReadLine() assert.NoError(t, err) - assert.Len(t, data, 1200) - - b.err = errors.New("some error") - _, err = b.ReadFull(1) - assert.EqualError(t, err, "some error") + assert.Len(t, data, 6) } func TestReaderReadExact(t *testing.T) { @@ -96,24 +88,30 @@ func TestReaderReadExact(t *testing.T) { assert.Len(t, data, 5) _, err = b.ReadExact(5 * 3 * 100) - assert.EqualError(t, err, "bufio: buffer full") + assert.Equal(t, ErrBufferFull, err) } func TestReaderResetBuffer(t *testing.T) { bts := _genData() b := NewReader(bytes.NewBuffer(bts), Get(defaultBufferSize)) - _, err := b.ReadFull(1200) + err := b.Read() + assert.NoError(t, err) + + _, err = b.ReadExact(512) assert.NoError(t, err) b.ResetBuffer(Get(defaultBufferSize)) - data, err := b.ReadFull(300) + err = b.Read() + assert.NoError(t, err) + + data, err := b.ReadExact(300) assert.NoError(t, err) assert.Len(t, data, 300) - _, err = b.ReadFull(300) + _, err = b.ReadExact(300) assert.Error(t, err) - assert.Equal(t, io.EOF, err) + assert.Equal(t, ErrBufferFull, err) b.ResetBuffer(nil) buf := b.Buffer() diff --git a/lib/conv/conv.go b/lib/conv/conv.go index 799b8e84..2fd72417 100644 --- a/lib/conv/conv.go +++ b/lib/conv/conv.go @@ -4,27 +4,7 @@ import ( "strconv" ) -const ( - maxCmdLen = 64 -) - -var ( - charmap [256]byte -) - -func init() { - for i := range charmap { - c := byte(i) - switch { - case c >= 'A' && c <= 'Z': - charmap[i] = c + 'a' - 'A' - case c >= 'a' && c <= 'z': - charmap[i] = c - } - } -} - -// btoi returns the corresponding value i. +// Btoi returns the corresponding value i. 
func Btoi(b []byte) (int64, error) { if len(b) != 0 && len(b) < 10 { var neg, i = false, 0 @@ -57,20 +37,20 @@ func Btoi(b []byte) (int64, error) { // UpdateToLower will convert to lower case func UpdateToLower(src []byte) { + const step = byte('a') - byte('A') for i := range src { - if c := charmap[src[i]]; c != 0 { - src[i] = c + if src[i] >= 'A' && src[i] <= 'Z' { + src[i] += step } } } -// ToLower returns a copy of the string s with all Unicode letters mapped to their lower case. -func ToLower(src []byte) []byte { - var lower [maxCmdLen]byte +// UpdateToUpper will convert to upper case +func UpdateToUpper(src []byte) { + const step = byte('a') - byte('A') for i := range src { - if c := charmap[src[i]]; c != 0 { - lower[i] = c + if src[i] >= 'a' && src[i] <= 'z' { + src[i] -= step } } - return lower[:len(src)] } diff --git a/lib/hashkit/hash.go b/lib/hashkit/hash.go new file mode 100644 index 00000000..6fbd2e7d --- /dev/null +++ b/lib/hashkit/hash.go @@ -0,0 +1,18 @@ +package hashkit + +// constants defines +const ( + HashMethodFnv1a = "fnv1a_64" +) + +// NewRing will create new and need init method. 
+func NewRing(des, method string) *HashRing { + var hash func([]byte) uint + switch method { + case HashMethodFnv1a: + fallthrough + default: + hash = NewFnv1a64().fnv1a64 + } + return newRingWithHash(hash) +} diff --git a/lib/hashkit/hash_test.go b/lib/hashkit/hash_test.go new file mode 100644 index 00000000..205ea895 --- /dev/null +++ b/lib/hashkit/hash_test.go @@ -0,0 +1,15 @@ +package hashkit + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewRingOk(t *testing.T) { + ring := NewRing("redis_cluster", "crc16") + assert.NotNil(t, ring) + + ring = NewRing("ketama", "fnv1a_64") + assert.NotNil(t, ring) +} diff --git a/lib/hashkit/ketama.go b/lib/hashkit/ketama.go index 303cdd6a..3839d109 100644 --- a/lib/hashkit/ketama.go +++ b/lib/hashkit/ketama.go @@ -40,15 +40,15 @@ type HashRing struct { } // Ketama new a hash ring with ketama consistency. -// Default hash: sha1 +// Default hash: fnv1a64 func Ketama() (h *HashRing) { h = new(HashRing) h.hash = NewFnv1a64().fnv1a64 return } -// NewRingWithHash new a hash ring with a hash func. -func NewRingWithHash(hash func([]byte) uint) (h *HashRing) { +// newRingWithHash new a hash ring with a hash func. +func newRingWithHash(hash func([]byte) uint) (h *HashRing) { h = Ketama() h.hash = hash return @@ -161,8 +161,8 @@ func (h *HashRing) DelNode(n string) { } } -// Hash returns result node. -func (h *HashRing) Hash(key []byte) (string, bool) { +// GetNode returns result node by given key. 
+func (h *HashRing) GetNode(key []byte) (string, bool) { ts, ok := h.ticks.Load().(*tickArray) if !ok || ts.length == 0 { return "", false diff --git a/lib/hashkit/ketama_test.go b/lib/hashkit/ketama_test.go index 4407bc0d..b9c50476 100644 --- a/lib/hashkit/ketama_test.go +++ b/lib/hashkit/ketama_test.go @@ -58,7 +58,7 @@ func testHash(t *testing.T) { for i := 0; i < 1e6; i++ { s := "test value" + strconv.FormatUint(uint64(i), 10) bs := []byte(s) - n, ok := ring.Hash(bs) + n, ok := ring.GetNode(bs) if !ok { if !delAll { t.Error("unexpected not ok???") diff --git a/lib/net/conn.go b/lib/net/conn.go index 00b94ea1..5717de02 100644 --- a/lib/net/conn.go +++ b/lib/net/conn.go @@ -1,10 +1,16 @@ package net import ( + "errors" "net" "time" ) +var ( + // ErrConnClosed error connection closed. + ErrConnClosed = errors.New("connection is closed") +) + // Conn is a net.Conn self implement // Add auto timeout setting. type Conn struct { @@ -52,8 +58,8 @@ func (c *Conn) ReConnect() (err error) { } func (c *Conn) Read(b []byte) (n int, err error) { - if c.closed { - return + if c.closed || c.Conn == nil { + return 0, ErrConnClosed } if c.err != nil && c.addr != "" { if re := c.ReConnect(); re != nil { @@ -68,12 +74,16 @@ func (c *Conn) Read(b []byte) (n int, err error) { return } } + n, err = c.Conn.Read(b) c.err = err return } func (c *Conn) Write(b []byte) (n int, err error) { + if c.closed || c.Conn == nil { + return 0, ErrConnClosed + } if c.err != nil && c.addr != "" { if re := c.ReConnect(); re != nil { err = c.err @@ -103,5 +113,8 @@ func (c *Conn) Close() error { // Writev impl the net.buffersWriter to support writev func (c *Conn) Writev(buf *net.Buffers) (int64, error) { + if c.closed || c.Conn == nil { + return 0, ErrConnClosed + } return buf.WriteTo(c.Conn) } diff --git a/lib/net/net_test.go b/lib/net/conn_test.go similarity index 84% rename from lib/net/net_test.go rename to lib/net/conn_test.go index c5a1a9b0..f5cab8ab 100644 --- a/lib/net/net_test.go +++ 
b/lib/net/conn_test.go @@ -194,3 +194,38 @@ func TestConnWriteBuffersOk(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 7, int(n)) } + +func TestConnNoConn(t *testing.T) { + addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0") + assert.NoError(t, err) + l, err := net.ListenTCP("tcp", addr) + assert.NoError(t, err) + laddr := l.Addr() + go func() { + defer l.Close() + buf := make([]byte, 1024) + for { + sock, err := l.Accept() + assert.NoError(t, err) + n, err := sock.Read(buf) + assert.NoError(t, err) + assert.NotZero(t, n) + } + }() + conn := DialWithTimeout(laddr.String(), time.Second, time.Second, time.Second) + conn.Conn = nil + + bs := make([]byte, 1) + n, err := conn.Read(bs) + assert.Equal(t, 0, n) + assert.Equal(t, ErrConnClosed, err) + + n, err = conn.Write(bs) + assert.Equal(t, 0, n) + assert.Equal(t, ErrConnClosed, err) + + buffers := net.Buffers([][]byte{[]byte("baka"), []byte("qiu")}) + n64, err := conn.Writev(&buffers) + assert.Equal(t, int64(0), n64) + assert.Equal(t, ErrConnClosed, err) +} diff --git a/lib/prom/prom.go b/lib/prom/prom.go index d0fffb9f..3a47fe98 100644 --- a/lib/prom/prom.go +++ b/lib/prom/prom.go @@ -30,6 +30,8 @@ var ( clusterNodeErrLabels = []string{"cluster", "node", "cmd", "error"} clusterCmdLabels = []string{"cluster", "cmd"} clusterNodeCmdLabels = []string{"cluster", "node", "cmd"} + // On Prom switch + On = true ) // Init init prometheus. 
diff --git a/proto/batch.go b/proto/batch.go index 91cb7c12..ff0078a4 100644 --- a/proto/batch.go +++ b/proto/batch.go @@ -12,7 +12,7 @@ import ( ) const ( - defaultRespBufSize = 4096 + defaultRespBufSize = 1024 defaultMsgBatchSize = 2 ) @@ -41,13 +41,6 @@ func NewMsgBatch() *MsgBatch { return msgBatchPool.Get().(*MsgBatch) } -// PutMsgBatch will release the batch object -func PutMsgBatch(b *MsgBatch) { - b.Reset() - b.wg = nil - msgBatchPool.Put(b) -} - // MsgBatch is a single execute unit type MsgBatch struct { buf *bufio.Buffer @@ -115,10 +108,12 @@ func (m *MsgBatch) Msgs() []*Message { // BatchDone will set done and report prom HandleTime. func (m *MsgBatch) BatchDone(cluster, addr string) { - for _, msg := range m.Msgs() { - prom.HandleTime(cluster, addr, msg.Request().CmdString(), int64(msg.RemoteDur()/time.Microsecond)) - } m.Done() + if prom.On { + for _, msg := range m.Msgs() { + prom.HandleTime(cluster, addr, msg.Request().CmdString(), int64(msg.RemoteDur()/time.Microsecond)) + } + } } // BatchDoneWithError will set done with error and report prom ErrIncr. @@ -128,7 +123,19 @@ func (m *MsgBatch) BatchDoneWithError(cluster, addr string, err error) { if log.V(1) { log.Errorf("cluster(%s) Msg(%s) cluster process handle error:%+v", cluster, msg.Request().Key(), err) } - prom.ErrIncr(cluster, addr, msg.Request().CmdString(), errors.Cause(err).Error()) + if prom.On { + prom.ErrIncr(cluster, addr, msg.Request().CmdString(), errors.Cause(err).Error()) + } } m.Done() } + +// DropMsgBatch put MsgBatch into recycle using pool. 
+func DropMsgBatch(m *MsgBatch) { + m.buf.Reset() + m.msgs = m.msgs[:0] + m.count = 0 + m.wg = nil + msgBatchPool.Put(m) + m = nil +} diff --git a/proto/memcache/binary/mcbin_test.go b/proto/memcache/binary/mcbin_test.go new file mode 100644 index 00000000..7d123bbb --- /dev/null +++ b/proto/memcache/binary/mcbin_test.go @@ -0,0 +1,67 @@ +package binary + +import ( + "bytes" + "net" + "time" + + libnet "overlord/lib/net" +) + +type mockAddr string + +func (m mockAddr) Network() string { + return "tcp" +} +func (m mockAddr) String() string { + return string(m) +} + +type mockConn struct { + addr mockAddr + rbuf *bytes.Buffer + wbuf *bytes.Buffer + data []byte + repeat int +} + +func (m *mockConn) Read(b []byte) (n int, err error) { + if m.repeat > 0 { + m.rbuf.Write(m.data) + m.repeat-- + } + return m.rbuf.Read(b) +} +func (m *mockConn) Write(b []byte) (n int, err error) { + return m.wbuf.Write(b) +} + +// writeBuffers impl the net.buffersWriter to support writev +func (m *mockConn) writeBuffers(buf *net.Buffers) (int64, error) { + return buf.WriteTo(m.wbuf) +} + +func (m *mockConn) Close() error { return nil } +func (m *mockConn) LocalAddr() net.Addr { return m.addr } +func (m *mockConn) RemoteAddr() net.Addr { return m.addr } + +func (m *mockConn) SetDeadline(t time.Time) error { return nil } +func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } + +// _createConn is useful tools for handler test +func _createConn(data []byte) *libnet.Conn { + return _createRepeatConn(data, 1) +} + +func _createRepeatConn(data []byte, r int) *libnet.Conn { + mconn := &mockConn{ + addr: "127.0.0.1:12345", + rbuf: bytes.NewBuffer(nil), + wbuf: new(bytes.Buffer), + data: data, + repeat: r, + } + conn := libnet.NewConn(mconn, time.Second, time.Second) + return conn +} diff --git a/proto/memcache/binary/node_conn.go b/proto/memcache/binary/node_conn.go index 10ad0092..8dab7354 100644 --- 
a/proto/memcache/binary/node_conn.go +++ b/proto/memcache/binary/node_conn.go @@ -172,7 +172,9 @@ func (n *nodeConn) fillMCRequest(mcr *MCRequest, data []byte) (size int, err err bl := binary.BigEndian.Uint32(mcr.bodyLen) if bl == 0 { if mcr.rTp == RequestTypeGet || mcr.rTp == RequestTypeGetQ || mcr.rTp == RequestTypeGetK || mcr.rTp == RequestTypeGetKQ { - prom.Miss(n.cluster, n.addr) + if prom.On { + prom.Miss(n.cluster, n.addr) + } } size = requestHeaderLen return @@ -184,7 +186,9 @@ func (n *nodeConn) fillMCRequest(mcr *MCRequest, data []byte) (size int, err err mcr.data = data[requestHeaderLen : requestHeaderLen+bl] if mcr.rTp == RequestTypeGet || mcr.rTp == RequestTypeGetQ || mcr.rTp == RequestTypeGetK || mcr.rTp == RequestTypeGetKQ { - prom.Hit(n.cluster, n.addr) + if prom.On { + prom.Hit(n.cluster, n.addr) + } } return } diff --git a/proto/memcache/binary/pinger.go b/proto/memcache/binary/pinger.go index a85394a8..5262753d 100644 --- a/proto/memcache/binary/pinger.go +++ b/proto/memcache/binary/pinger.go @@ -11,7 +11,7 @@ import ( ) const ( - pingBufferSize = 32 + pingBufferSize = 24 ) var ( @@ -50,7 +50,7 @@ func newMCPinger(nc *libnet.Conn) *mcPinger { return &mcPinger{ conn: nc, bw: bufio.NewWriter(nc), - br: bufio.NewReader(nc, bufio.Get(pingBufferSize)), + br: bufio.NewReader(nc, bufio.NewBuffer(pingBufferSize)), } } @@ -64,7 +64,7 @@ func (m *mcPinger) Ping() (err error) { err = errors.Wrap(err, "MC ping flush") return } - err = m.br.Read() + _ = m.br.Read() head, err := m.br.ReadExact(requestHeaderLen) if err != nil { err = errors.Wrap(err, "MC ping read exact") diff --git a/proto/memcache/binary/pinger_test.go b/proto/memcache/binary/pinger_test.go index 9b2b8027..bf271d08 100644 --- a/proto/memcache/binary/pinger_test.go +++ b/proto/memcache/binary/pinger_test.go @@ -1,68 +1,14 @@ package binary import ( - "bytes" - "net" "testing" - "time" "overlord/lib/bufio" - libnet "overlord/lib/net" "github.com/pkg/errors" "github.com/stretchr/testify/assert" 
) -type mockAddr string - -func (m mockAddr) Network() string { - return "tcp" -} -func (m mockAddr) String() string { - return string(m) -} - -type mockConn struct { - rbuf *bytes.Buffer - wbuf *bytes.Buffer - addr mockAddr -} - -func (m *mockConn) Read(b []byte) (n int, err error) { - return m.rbuf.Read(b) -} -func (m *mockConn) Write(b []byte) (n int, err error) { - return m.wbuf.Write(b) -} - -// writeBuffers impl the net.buffersWriter to support writev -func (m *mockConn) writeBuffers(buf *net.Buffers) (int64, error) { - return buf.WriteTo(m.wbuf) -} - -func (m *mockConn) Close() error { return nil } -func (m *mockConn) LocalAddr() net.Addr { return m.addr } -func (m *mockConn) RemoteAddr() net.Addr { return m.addr } - -func (m *mockConn) SetDeadline(t time.Time) error { return nil } -func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } -func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } - -// _createConn is useful tools for handler test -func _createConn(data []byte) *libnet.Conn { - return _createRepeatConn(data, 1) -} - -func _createRepeatConn(data []byte, r int) *libnet.Conn { - mconn := &mockConn{ - addr: "127.0.0.1:12345", - rbuf: bytes.NewBuffer(bytes.Repeat(data, r)), - wbuf: new(bytes.Buffer), - } - conn := libnet.NewConn(mconn, time.Second, time.Second) - return conn -} - func TestPingerPingOk(t *testing.T) { conn := _createConn(pongBs) pinger := newMCPinger(conn) @@ -91,7 +37,7 @@ func TestPingerPing100Ok(t *testing.T) { for i := 0; i < 100; i++ { err := pinger.Ping() - assert.NoError(t, err) + assert.NoError(t, err, "error iter: %d", i) } err := pinger.Ping() @@ -111,9 +57,15 @@ func TestPingerClosed(t *testing.T) { } func TestPingerNotReturnPong(t *testing.T) { - conn := _createRepeatConn([]byte("emmmmm...."), 100) + conn := _createConn([]byte("iam test bytes 24 length")) pinger := newMCPinger(conn) err := pinger.Ping() assert.Error(t, err) _causeEqual(t, ErrPingerPong, err) + + conn = _createConn([]byte("less 
than 24 length")) + pinger = newMCPinger(conn) + err = pinger.Ping() + assert.Error(t, err) + _causeEqual(t, bufio.ErrBufferFull, err) } diff --git a/proto/memcache/binary/proxy_conn.go b/proto/memcache/binary/proxy_conn.go index 89fa4ad0..d14b4a40 100644 --- a/proto/memcache/binary/proxy_conn.go +++ b/proto/memcache/binary/proxy_conn.go @@ -164,6 +164,10 @@ func (p *proxyConn) Encode(m *proto.Message) (err error) { _ = p.bw.Write(mcr.data) } } + return +} + +func (p *proxyConn) Flush() (err error) { if err = p.bw.Flush(); err != nil { err = errors.Wrap(err, "MC Encoder encode response flush bytes") } diff --git a/proto/memcache/binary/proxy_conn_test.go b/proto/memcache/binary/proxy_conn_test.go index 753cce45..02068ab4 100644 --- a/proto/memcache/binary/proxy_conn_test.go +++ b/proto/memcache/binary/proxy_conn_test.go @@ -347,6 +347,7 @@ func TestProxyConnEncodeOk(t *testing.T) { msg := _createRespMsg(t, []byte(tt.Req), tt.Resp) err := p.Encode(msg) assert.NoError(t, err) + assert.NoError(t, p.Flush()) c := conn.Conn.(*mockConn) buf := make([]byte, 1024) size, err := c.wbuf.Read(buf) diff --git a/proto/memcache/mc_test.go b/proto/memcache/mc_test.go new file mode 100644 index 00000000..37f2d8b2 --- /dev/null +++ b/proto/memcache/mc_test.go @@ -0,0 +1,67 @@ +package memcache + +import ( + "bytes" + "net" + "time" + + libnet "overlord/lib/net" +) + +type mockAddr string + +func (m mockAddr) Network() string { + return "tcp" +} +func (m mockAddr) String() string { + return string(m) +} + +type mockConn struct { + addr mockAddr + rbuf *bytes.Buffer + wbuf *bytes.Buffer + data []byte + repeat int +} + +func (m *mockConn) Read(b []byte) (n int, err error) { + if m.repeat > 0 { + m.rbuf.Write(m.data) + m.repeat-- + } + return m.rbuf.Read(b) +} +func (m *mockConn) Write(b []byte) (n int, err error) { + return m.wbuf.Write(b) +} + +// writeBuffers impl the net.buffersWriter to support writev +func (m *mockConn) writeBuffers(buf *net.Buffers) (int64, error) { + return 
buf.WriteTo(m.wbuf) +} + +func (m *mockConn) Close() error { return nil } +func (m *mockConn) LocalAddr() net.Addr { return m.addr } +func (m *mockConn) RemoteAddr() net.Addr { return m.addr } + +func (m *mockConn) SetDeadline(t time.Time) error { return nil } +func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } + +// _createConn is useful tools for handler test +func _createConn(data []byte) *libnet.Conn { + return _createRepeatConn(data, 1) +} + +func _createRepeatConn(data []byte, r int) *libnet.Conn { + mconn := &mockConn{ + addr: "127.0.0.1:12345", + rbuf: bytes.NewBuffer(nil), + wbuf: new(bytes.Buffer), + data: data, + repeat: r, + } + conn := libnet.NewConn(mconn, time.Second, time.Second) + return conn +} diff --git a/proto/memcache/node_conn.go b/proto/memcache/node_conn.go index ad2e359a..65ded7e5 100644 --- a/proto/memcache/node_conn.go +++ b/proto/memcache/node_conn.go @@ -135,7 +135,6 @@ func (n *nodeConn) ReadBatch(mb *proto.MsgBatch) (err error) { for { size, err = n.fillMCRequest(mcr, n.br.Buffer().Bytes()[cursor:]) if err == bufio.ErrBufferFull { - err = nil break } else if err != nil { return @@ -173,11 +172,14 @@ func (n *nodeConn) fillMCRequest(mcr *MCRequest, data []byte) (size int, err err } if bytes.Equal(bs, endBytes) { - prom.Miss(n.cluster, n.addr) + if prom.On { + prom.Miss(n.cluster, n.addr) + } return } - prom.Hit(n.cluster, n.addr) - + if prom.On { + prom.Hit(n.cluster, n.addr) + } length, err := findLength(bs, mcr.rTp == RequestTypeGets || mcr.rTp == RequestTypeGats) if err != nil { err = errors.Wrap(err, "MC Handler while parse length") @@ -188,7 +190,6 @@ func (n *nodeConn) fillMCRequest(mcr *MCRequest, data []byte) (size int, err err if len(data) < size { return 0, bufio.ErrBufferFull } - mcr.data = data[:size] return } diff --git a/proto/memcache/node_conn_test.go b/proto/memcache/node_conn_test.go index 1d45c2f3..972acc98 100644 --- 
a/proto/memcache/node_conn_test.go +++ b/proto/memcache/node_conn_test.go @@ -140,10 +140,6 @@ func (*mockReq) Key() []byte { return []byte{} } -func (*mockReq) Resp() []byte { - return nil -} - func (*mockReq) Put() { } @@ -205,9 +201,9 @@ func TestNodeConnReadOk(t *testing.T) { batch.AddMsg(req) err := nc.ReadBatch(batch) assert.NoError(t, err) - mcr, ok := req.Request().(*MCRequest) - assert.Equal(t, true, ok) + assert.True(t, ok) + assert.NotNil(t, mcr) assert.Equal(t, tt.except, string(mcr.data)) }) @@ -225,7 +221,7 @@ func TestNodeConnAssertError(t *testing.T) { } func TestNocdConnPingOk(t *testing.T) { - nc := _createNodeConn(pong) + nc := _createNodeConn(pongBytes) err := nc.Ping() assert.NoError(t, err) assert.NoError(t, nc.Close()) diff --git a/proto/memcache/pinger.go b/proto/memcache/pinger.go index 869a79ee..14ffb60d 100644 --- a/proto/memcache/pinger.go +++ b/proto/memcache/pinger.go @@ -11,12 +11,12 @@ import ( ) const ( - pingBufferSize = 32 + pingBufferSize = 8 ) var ( - ping = []byte("set _ping 0 0 4\r\npong\r\n") - pong = []byte("STORED\r\n") + pingBytes = []byte("set _ping 0 0 4\r\npong\r\n") + pongBytes = []byte("STORED\r\n") ) type mcPinger struct { @@ -30,7 +30,7 @@ func newMCPinger(nc *libnet.Conn) *mcPinger { return &mcPinger{ conn: nc, bw: bufio.NewWriter(nc), - br: bufio.NewReader(nc, bufio.Get(pingBufferSize)), + br: bufio.NewReader(nc, bufio.NewBuffer(pingBufferSize)), } } @@ -39,7 +39,7 @@ func (m *mcPinger) Ping() (err error) { err = ErrPingerPong return } - if err = m.bw.Write(ping); err != nil { + if err = m.bw.Write(pingBytes); err != nil { err = errors.Wrap(err, "MC ping write") return } @@ -47,12 +47,13 @@ func (m *mcPinger) Ping() (err error) { err = errors.Wrap(err, "MC ping flush") return } + _ = m.br.Read() var b []byte - if b, err = m.br.ReadUntil(delim); err != nil { + if b, err = m.br.ReadLine(); err != nil { err = errors.Wrap(err, "MC ping read response") return } - if !bytes.Equal(b, pong) { + if !bytes.Equal(b, 
pongBytes) { err = ErrPingerPong } return diff --git a/proto/memcache/pinger_test.go b/proto/memcache/pinger_test.go index 0206d577..7f54da5d 100644 --- a/proto/memcache/pinger_test.go +++ b/proto/memcache/pinger_test.go @@ -1,70 +1,16 @@ package memcache import ( - "bytes" - "io" - "net" "testing" - "time" - libnet "overlord/lib/net" + "overlord/lib/bufio" "github.com/pkg/errors" "github.com/stretchr/testify/assert" ) -type mockAddr string - -func (m mockAddr) Network() string { - return "tcp" -} -func (m mockAddr) String() string { - return string(m) -} - -type mockConn struct { - rbuf *bytes.Buffer - wbuf *bytes.Buffer - addr mockAddr -} - -func (m *mockConn) Read(b []byte) (n int, err error) { - return m.rbuf.Read(b) -} -func (m *mockConn) Write(b []byte) (n int, err error) { - return m.wbuf.Write(b) -} - -// writeBuffers impl the net.buffersWriter to support writev -func (m *mockConn) writeBuffers(buf *net.Buffers) (int64, error) { - return buf.WriteTo(m.wbuf) -} - -func (m *mockConn) Close() error { return nil } -func (m *mockConn) LocalAddr() net.Addr { return m.addr } -func (m *mockConn) RemoteAddr() net.Addr { return m.addr } - -func (m *mockConn) SetDeadline(t time.Time) error { return nil } -func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } -func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } - -// _createConn is useful tools for handler test -func _createConn(data []byte) *libnet.Conn { - return _createRepeatConn(data, 1) -} - -func _createRepeatConn(data []byte, r int) *libnet.Conn { - mconn := &mockConn{ - addr: "127.0.0.1:12345", - rbuf: bytes.NewBuffer(bytes.Repeat(data, r)), - wbuf: new(bytes.Buffer), - } - conn := libnet.NewConn(mconn, time.Second, time.Second) - return conn -} - func TestPingerPingOk(t *testing.T) { - conn := _createConn(pong) + conn := _createConn(pongBytes) pinger := newMCPinger(conn) err := pinger.Ping() @@ -72,7 +18,7 @@ func TestPingerPingOk(t *testing.T) { } func TestPingerPingEOF(t 
*testing.T) { - conn := _createConn(pong) + conn := _createConn(pongBytes) pinger := newMCPinger(conn) err := pinger.Ping() @@ -82,25 +28,25 @@ func TestPingerPingEOF(t *testing.T) { assert.Error(t, err) err = errors.Cause(err) - assert.Equal(t, io.EOF, err) + assert.Equal(t, bufio.ErrBufferFull, err) } func TestPingerPing100Ok(t *testing.T) { - conn := _createRepeatConn(pong, 100) + conn := _createRepeatConn(pongBytes, 100) pinger := newMCPinger(conn) for i := 0; i < 100; i++ { err := pinger.Ping() - assert.NoError(t, err) + assert.NoError(t, err, "error iter: %d", i) } err := pinger.Ping() assert.Error(t, err) - _causeEqual(t, io.EOF, err) + _causeEqual(t, bufio.ErrBufferFull, err) } func TestPingerClosed(t *testing.T) { - conn := _createRepeatConn(pong, 100) + conn := _createRepeatConn(pongBytes, 100) pinger := newMCPinger(conn) err := pinger.Close() assert.NoError(t, err) diff --git a/proto/memcache/proxy_conn.go b/proto/memcache/proxy_conn.go index 8c370b68..92c3151e 100644 --- a/proto/memcache/proxy_conn.go +++ b/proto/memcache/proxy_conn.go @@ -374,6 +374,10 @@ func (p *proxyConn) Encode(m *proto.Message) (err error) { _ = p.bw.Write(endBytes) } } + return +} + +func (p *proxyConn) Flush() (err error) { if err = p.bw.Flush(); err != nil { err = errors.Wrap(err, "MC Encoder encode response flush bytes") } diff --git a/proto/memcache/proxy_conn_test.go b/proto/memcache/proxy_conn_test.go index 81ac6daa..a4b8b497 100644 --- a/proto/memcache/proxy_conn_test.go +++ b/proto/memcache/proxy_conn_test.go @@ -106,7 +106,9 @@ func TestProxyConnDecodeOk(t *testing.T) { conn := _createConn([]byte(tt.Data)) p := NewProxyConn(conn) mlist := proto.GetMsgSlice(2) - + // test req reuse. 
+ mlist[0].WithRequest(NewReq()) + mlist[0].Reset() msgs, err := p.Decode(mlist) if tt.Err != nil { @@ -183,6 +185,8 @@ func TestProxyConnEncodeOk(t *testing.T) { msg := _createRespMsg(t, []byte(tt.Req), tt.Resp) err := p.Encode(msg) assert.NoError(t, err) + err = p.Flush() + assert.NoError(t, err) c := conn.Conn.(*mockConn) buf := make([]byte, 1024) size, err := c.wbuf.Read(buf) diff --git a/proto/memcache/request_test.go b/proto/memcache/request_test.go index 564ccd2c..f0b741a3 100644 --- a/proto/memcache/request_test.go +++ b/proto/memcache/request_test.go @@ -29,6 +29,7 @@ func TestRequestTypeString(t *testing.T) { reg := regexp.MustCompile(`[a-z]+`) for _, rtype := range _allReqTypes { assert.True(t, reg.Match(rtype.Bytes())) + assert.True(t, reg.MatchString(rtype.String())) } } diff --git a/proto/message.go b/proto/message.go index aac336db..a03a71d9 100644 --- a/proto/message.go +++ b/proto/message.go @@ -51,7 +51,6 @@ type Message struct { req []Request reqn int subs []*Message - // Start Time, Write Time, ReadTime, EndTime st, wt, rt, et time.Time err error diff --git a/proto/redis/node_conn.go b/proto/redis/node_conn.go new file mode 100644 index 00000000..9301c7da --- /dev/null +++ b/proto/redis/node_conn.go @@ -0,0 +1,105 @@ +package redis + +import ( + "sync/atomic" + "time" + + "overlord/lib/bufio" + libnet "overlord/lib/net" + "overlord/proto" +) + +const ( + opened = uint32(0) + closed = uint32(1) +) + +type nodeConn struct { + cluster string + addr string + conn *libnet.Conn + bw *bufio.Writer + br *bufio.Reader + state uint32 + + p *pinger +} + +// NewNodeConn create the node conn from proxy to redis +func NewNodeConn(cluster, addr string, dialTimeout, readTimeout, writeTimeout time.Duration) (nc proto.NodeConn) { + conn := libnet.DialWithTimeout(addr, dialTimeout, readTimeout, writeTimeout) + return newNodeConn(cluster, addr, conn) +} + +func newNodeConn(cluster, addr string, conn *libnet.Conn) proto.NodeConn { + return &nodeConn{ + cluster: 
cluster, + addr: addr, + br: bufio.NewReader(conn, nil), + bw: bufio.NewWriter(conn), + conn: conn, + p: newPinger(conn), + } +} + +func (nc *nodeConn) WriteBatch(mb *proto.MsgBatch) (err error) { + for _, m := range mb.Msgs() { + req, ok := m.Request().(*Request) + if !ok { + m.DoneWithError(ErrBadAssert) + return ErrBadAssert + } + if !req.isSupport() || req.isCtl() { + continue + } + if err = req.resp.encode(nc.bw); err != nil { + m.DoneWithError(err) + return err + } + m.MarkWrite() + } + return nc.bw.Flush() +} + +func (nc *nodeConn) ReadBatch(mb *proto.MsgBatch) (err error) { + nc.br.ResetBuffer(mb.Buffer()) + defer nc.br.ResetBuffer(nil) + begin := nc.br.Mark() + now := nc.br.Mark() + for i := 0; i < mb.Count(); { + m := mb.Nth(i) + req, ok := m.Request().(*Request) + if !ok { + return ErrBadAssert + } + if !req.isSupport() || req.isCtl() { + i++ + continue + } + if err = req.reply.decode(nc.br); err == bufio.ErrBufferFull { + nc.br.AdvanceTo(begin) + if err = nc.br.Read(); err != nil { + return + } + nc.br.AdvanceTo(now) + continue + } else if err != nil { + return + } + m.MarkRead() + now = nc.br.Mark() + i++ + } + return +} + +func (nc *nodeConn) Ping() (err error) { + return nc.p.ping() +} + +func (nc *nodeConn) Close() (err error) { + if atomic.CompareAndSwapUint32(&nc.state, opened, closed) { + return nc.conn.Close() + } + return +} diff --git a/proto/redis/node_conn_test.go b/proto/redis/node_conn_test.go new file mode 100644 index 00000000..bb197150 --- /dev/null +++ b/proto/redis/node_conn_test.go @@ -0,0 +1,161 @@ +package redis + +import ( + "bytes" + "fmt" + "io" + "strconv" + "testing" + + "overlord/proto" + + "github.com/stretchr/testify/assert" +) + +type mockCmd struct { +} + +func (*mockCmd) CmdString() string { + return "" +} + +func (*mockCmd) Cmd() []byte { + return []byte("") +} + +func (*mockCmd) Key() []byte { + return []byte{} +} + +func (*mockCmd) Put() { +} + +func TestNodeConnWriteBatchOk(t *testing.T) { + nc := newNodeConn("baka", 
"127.0.0.1:12345", _createConn(nil)) + mb := proto.NewMsgBatch() + msg := proto.GetMsg() + req := newRequest("GET", "AA") + msg.WithRequest(req) + mb.AddMsg(msg) + msg = proto.NewMessage() + req = newRequest("unsupport") + msg.WithRequest(req) + mb.AddMsg(msg) + err := nc.WriteBatch(mb) + assert.NoError(t, err) + nc.Close() +} + +func TestNodeConnWriteBadAssert(t *testing.T) { + nc := newNodeConn("baka", "127.0.0.1:12345", _createConn(nil)) + mb := proto.NewMsgBatch() + msg := proto.GetMsg() + msg.WithRequest(&mockCmd{}) + mb.AddMsg(msg) + + err := nc.WriteBatch(mb) + assert.Error(t, err) + assert.Equal(t, ErrBadAssert, err) +} + +func TestReadBatchOk(t *testing.T) { + data := ":1\r\n" + nc := newNodeConn("baka", "127.0.0.1:12345", _createConn([]byte(data))) + mb := proto.NewMsgBatch() + msg := proto.GetMsg() + req := newRequest("unsportcmd", "a") + msg.WithRequest(req) + mb.AddMsg(msg) + msg = proto.GetMsg() + req = newRequest("GET", "a") + msg.WithRequest(req) + mb.AddMsg(msg) + err := nc.ReadBatch(mb) + assert.NoError(t, err) +} + +func TestReadBatchWithBadAssert(t *testing.T) { + nc := newNodeConn("baka", "127.0.0.1:12345", _createConn([]byte(":123\r\n"))) + mb := proto.NewMsgBatch() + msg := proto.GetMsg() + msg.WithRequest(&mockCmd{}) + mb.AddMsg(msg) + err := nc.ReadBatch(mb) + assert.Error(t, err) + assert.Equal(t, ErrBadAssert, err) +} + +func TestReadBatchWithNilError(t *testing.T) { + nc := newNodeConn("baka", "127.0.0.1:12345", _createConn(nil)) + mb := proto.NewMsgBatch() + msg := proto.GetMsg() + req := getReq() + req.mType = mergeTypeJoin + req.reply = &resp{} + req.resp = newresp(respArray, []byte("2")) + req.resp.array = append(req.resp.array, newresp(respBulk, []byte("GET"))) + req.resp.arrayn++ + msg.WithRequest(req) + mb.AddMsg(msg) + err := nc.ReadBatch(mb) + assert.Error(t, err) + assert.Equal(t, io.EOF, err) +} + +func TestPingOk(t *testing.T) { + nc := newNodeConn("baka", "127.0.0.1:12345", _createRepeatConn(pongBytes, 1)) + err := nc.Ping() 
+ assert.NoError(t, err) +} + +func newRequest(cmd string, args ...string) *Request { + respObj := &resp{} + respObj.array = append(respObj.array, newresp(respBulk, []byte(fmt.Sprintf("%d\r\n%s", len(cmd), cmd)))) + respObj.arrayn++ + maxLen := len(args) + 1 + for i := 1; i < maxLen; i++ { + data := args[i-1] + line := fmt.Sprintf("%d\r\n%s", len(data), data) + respObj.array = append(respObj.array, newresp(respBulk, []byte(line))) + respObj.arrayn++ + } + respObj.data = []byte(strconv.Itoa(len(args) + 1)) + return &Request{ + resp: respObj, + mType: getMergeType(respObj.array[0].data), + reply: &resp{}, + } +} +func getMergeType(cmd []byte) mergeType { + // fmt.Println("mtype :", strconv.Quote(string(cmd))) + // TODO: impl with tire tree to search quickly + if bytes.Equal(cmd, cmdMGetBytes) || bytes.Equal(cmd, cmdGetBytes) { + return mergeTypeJoin + } + + if bytes.Equal(cmd, cmdMSetBytes) { + return mergeTypeOK + } + + if bytes.Equal(cmd, cmdExistsBytes) || bytes.Equal(cmd, cmdDelBytes) { + return mergeTypeCount + } + + return mergeTypeNo +} + +func newresp(rtype respType, data []byte) (robj *resp) { + robj = &resp{} + robj.rTp = rtype + robj.data = data + return +} + +func newrespArray(resps []*resp) (robj *resp) { + robj = &resp{} + robj.rTp = respArray + robj.data = []byte((strconv.Itoa(len(resps)))) + robj.array = resps + robj.arrayn = len(resps) + return +} diff --git a/proto/redis/pinger.go b/proto/redis/pinger.go new file mode 100644 index 00000000..21e934e5 --- /dev/null +++ b/proto/redis/pinger.go @@ -0,0 +1,66 @@ +package redis + +import ( + "bytes" + "errors" + "sync/atomic" + + "overlord/lib/bufio" + libnet "overlord/lib/net" +) + +// errors +var ( + ErrPingClosed = errors.New("ping interface has been closed") + ErrBadPong = errors.New("pong response payload is bad") +) + +var ( + pingBytes = []byte("*1\r\n$4\r\nPING\r\n") + pongBytes = []byte("+PONG\r\n") +) + +type pinger struct { + conn *libnet.Conn + + br *bufio.Reader + bw *bufio.Writer + + state 
uint32 +} + +func newPinger(conn *libnet.Conn) *pinger { + return &pinger{ + conn: conn, + br: bufio.NewReader(conn, bufio.NewBuffer(7)), + bw: bufio.NewWriter(conn), + state: opened, + } +} + +func (p *pinger) ping() (err error) { + if atomic.LoadUint32(&p.state) == closed { + err = ErrPingClosed + return + } + _ = p.bw.Write(pingBytes) + if err = p.bw.Flush(); err != nil { + return err + } + _ = p.br.Read() + data, err := p.br.ReadLine() + if err != nil { + return + } + if !bytes.Equal(data, pongBytes) { + err = ErrBadPong + } + return +} + +func (p *pinger) Close() error { + if atomic.CompareAndSwapUint32(&p.state, opened, closed) { + return p.conn.Close() + } + return nil +} diff --git a/proto/redis/pinger_test.go b/proto/redis/pinger_test.go new file mode 100644 index 00000000..42929148 --- /dev/null +++ b/proto/redis/pinger_test.go @@ -0,0 +1,37 @@ +package redis + +import ( + "testing" + + "overlord/lib/bufio" + + "github.com/stretchr/testify/assert" +) + +func TestPingerPingOk(t *testing.T) { + conn := _createConn(pongBytes) + p := newPinger(conn) + err := p.ping() + assert.NoError(t, err) +} + +func TestPingerClosed(t *testing.T) { + conn := _createRepeatConn(pongBytes, 10) + p := newPinger(conn) + assert.NoError(t, p.Close()) + err := p.ping() + assert.Equal(t, ErrPingClosed, err) + assert.NoError(t, p.Close()) +} + +func TestPingerWrongResp(t *testing.T) { + conn := _createConn([]byte("-Error: iam more than 7 bytes\r\n")) + p := newPinger(conn) + err := p.ping() + assert.Equal(t, bufio.ErrBufferFull, err) + + conn = _createConn([]byte("-Err\r\n")) + p = newPinger(conn) + err = p.ping() + assert.Equal(t, ErrBadPong, err) +} diff --git a/proto/redis/proxy_conn.go b/proto/redis/proxy_conn.go new file mode 100644 index 00000000..e0d7168a --- /dev/null +++ b/proto/redis/proxy_conn.go @@ -0,0 +1,240 @@ +package redis + +import ( + "bytes" + "strconv" + + "overlord/lib/bufio" + "overlord/lib/conv" + libnet "overlord/lib/net" + "overlord/proto" + + 
"github.com/pkg/errors" +) + +var ( + nullBytes = []byte("-1\r\n") + okBytes = []byte("OK\r\n") + pongDataBytes = []byte("+PONG") + notSupportDataBytes = []byte("Error: command not support") +) + +type proxyConn struct { + br *bufio.Reader + bw *bufio.Writer + completed bool + + resp *resp +} + +// NewProxyConn creates new redis Encoder and Decoder. +func NewProxyConn(conn *libnet.Conn) proto.ProxyConn { + r := &proxyConn{ + br: bufio.NewReader(conn, bufio.Get(1024)), + bw: bufio.NewWriter(conn), + completed: true, + resp: &resp{}, + } + return r +} + +func (pc *proxyConn) Decode(msgs []*proto.Message) ([]*proto.Message, error) { + var err error + if pc.completed { + if err = pc.br.Read(); err != nil { + return nil, err + } + pc.completed = false + } + for i := range msgs { + msgs[i].Type = proto.CacheTypeRedis + // decode + if err = pc.decode(msgs[i]); err == bufio.ErrBufferFull { + pc.completed = true + return msgs[:i], nil + } else if err != nil { + return nil, err + } + msgs[i].MarkStart() + } + return msgs, nil +} + +func (pc *proxyConn) decode(m *proto.Message) (err error) { + mark := pc.br.Mark() + if err = pc.resp.decode(pc.br); err != nil { + if err == bufio.ErrBufferFull { + pc.br.AdvanceTo(mark) + } + return + } + if pc.resp.arrayn < 1 { + r := nextReq(m) + r.resp.copy(pc.resp) + return + } + conv.UpdateToUpper(pc.resp.array[0].data) + cmd := pc.resp.array[0].data // NOTE: when array, first is command + if bytes.Equal(cmd, cmdMSetBytes) { + if pc.resp.arrayn%2 == 0 { + err = ErrBadRequest + return + } + mid := pc.resp.arrayn / 2 + for i := 0; i < mid; i++ { + r := nextReq(m) + r.mType = mergeTypeOK + r.resp.reset() // NOTE: *3\r\n + r.resp.rTp = respArray + r.resp.data = arrayLenThree + // array resp: mset + nre1 := r.resp.next() // NOTE: $4\r\nMSET\r\n + nre1.reset() + nre1.rTp = respBulk + nre1.data = cmdMSetBytes + // array resp: key + nre2 := r.resp.next() // NOTE: $klen\r\nkey\r\n + nre2.copy(pc.resp.array[i*2+1]) + // array resp: value + nre3 := 
r.resp.next() // NOTE: $vlen\r\nvalue\r\n + nre3.copy(pc.resp.array[i*2+2]) + } + } else if bytes.Equal(cmd, cmdMGetBytes) { + for i := 1; i < pc.resp.arrayn; i++ { + r := nextReq(m) + r.mType = mergeTypeJoin + r.resp.reset() // NOTE: *2\r\n + r.resp.rTp = respArray + r.resp.data = arrayLenTwo + // array resp: get + nre1 := r.resp.next() // NOTE: $3\r\nGET\r\n + nre1.reset() + nre1.rTp = respBulk + nre1.data = cmdGetBytes + // array resp: key + nre2 := r.resp.next() // NOTE: $klen\r\nkey\r\n + nre2.copy(pc.resp.array[i]) + } + } else if bytes.Equal(cmd, cmdDelBytes) || bytes.Equal(cmd, cmdExistsBytes) { + for i := 1; i < pc.resp.arrayn; i++ { + r := nextReq(m) + r.mType = mergeTypeCount + r.resp.reset() // NOTE: *2\r\n + r.resp.rTp = respArray + r.resp.data = arrayLenTwo + // array resp: get + nre1 := r.resp.next() // NOTE: $3\r\nDEL\r\n | $6\r\nEXISTS\r\n + nre1.copy(pc.resp.array[0]) + // array resp: key + nre2 := r.resp.next() // NOTE: $klen\r\nkey\r\n + nre2.copy(pc.resp.array[i]) + } + } else { + r := nextReq(m) + r.resp.copy(pc.resp) + } + return +} + +func nextReq(m *proto.Message) *Request { + req := m.NextReq() + if req == nil { + r := getReq() + m.WithRequest(r) + return r + } + r := req.(*Request) + return r +} + +func (pc *proxyConn) Encode(m *proto.Message) (err error) { + if err = m.Err(); err != nil { + return + } + req, ok := m.Request().(*Request) + if !ok { + return ErrBadAssert + } + if !m.IsBatch() { + if !req.isSupport() { + req.reply.rTp = respError + req.reply.data = notSupportDataBytes + } else if req.isCtl() { + if bytes.Equal(req.Cmd(), pingBytes) { + req.reply.rTp = respString + req.reply.data = pongDataBytes + } + } + err = req.reply.encode(pc.bw) + } else { + switch req.mType { + case mergeTypeOK: + err = pc.mergeOK(m) + case mergeTypeJoin: + err = pc.mergeJoin(m) + case mergeTypeCount: + err = pc.mergeCount(m) + default: + panic("unreachable merge path") + } + } + if err != nil { + err = errors.Wrap(err, "Redis Encoder before flush 
response") + } + return +} + +func (pc *proxyConn) mergeOK(m *proto.Message) (err error) { + _ = pc.bw.Write(respStringBytes) + err = pc.bw.Write(okBytes) + return +} + +func (pc *proxyConn) mergeCount(m *proto.Message) (err error) { + var sum = 0 + for _, mreq := range m.Requests() { + req, ok := mreq.(*Request) + if !ok { + return ErrBadAssert + } + ival, err := conv.Btoi(req.reply.data) + if err != nil { + return ErrBadCount + } + sum += int(ival) + } + _ = pc.bw.Write(respIntBytes) + _ = pc.bw.Write([]byte(strconv.Itoa(sum))) + err = pc.bw.Write(crlfBytes) + return +} + +func (pc *proxyConn) mergeJoin(m *proto.Message) (err error) { + reqs := m.Requests() + _ = pc.bw.Write(respArrayBytes) + if len(reqs) == 0 { + err = pc.bw.Write(nullBytes) + return + } + _ = pc.bw.Write([]byte(strconv.Itoa(len(reqs)))) + if err = pc.bw.Write(crlfBytes); err != nil { + return + } + for _, mreq := range reqs { + req, ok := mreq.(*Request) + if !ok { + return ErrBadAssert + } + if err = req.reply.encode(pc.bw); err != nil { + return + } + } + return +} + +func (pc *proxyConn) Flush() (err error) { + if err = pc.bw.Flush(); err != nil { + err = errors.Wrap(err, "Redis Encoder flush response") + } + return +} diff --git a/proto/redis/proxy_conn_test.go b/proto/redis/proxy_conn_test.go new file mode 100644 index 00000000..52df85d4 --- /dev/null +++ b/proto/redis/proxy_conn_test.go @@ -0,0 +1,252 @@ +package redis + +import ( + "testing" + + "overlord/proto" + + "github.com/stretchr/testify/assert" +) + +func TestDecodeBasicOk(t *testing.T) { + data := "*2\r\n$3\r\nGET\r\n$4\r\nbaka\r\n" + conn := _createConn([]byte(data)) + pc := NewProxyConn(conn) + + msgs := proto.GetMsgSlice(1) + nmsgs, err := pc.Decode(msgs) + assert.NoError(t, err) + assert.Len(t, nmsgs, 1) + + req := msgs[0].Request().(*Request) + assert.Equal(t, mergeTypeNo, req.mType) + assert.Equal(t, 2, req.resp.arrayn) + assert.Equal(t, "GET", req.CmdString()) + assert.Equal(t, []byte("GET"), req.Cmd()) + assert.Equal(t, 
"baka", string(req.Key())) + assert.Equal(t, []byte("2"), req.resp.data) + assert.Equal(t, []byte("3\r\nGET"), req.resp.array[0].data) + assert.Equal(t, []byte("4\r\nbaka"), req.resp.array[1].data) +} + +func TestDecodeComplexOk(t *testing.T) { + data := "*3\r\n$4\r\nMGET\r\n$4\r\nbaka\r\n$4\r\nkaba\r\n*5\r\n$4\r\nMSET\r\n$1\r\na\r\n$1\r\nb\r\n$3\r\neee\r\n$5\r\n12345\r\n*3\r\n$4\r\nMGET\r\n$4\r\nenen\r\n$4\r\nnime\r\n*2\r\n$3\r\nGET\r\n$5\r\nabcde\r\n*3\r\n$3\r\nDEL\r\n$1\r\na\r\n$1\r\nb\r\n" + conn := _createConn([]byte(data)) + pc := NewProxyConn(conn) + // test reuse command + msgs := proto.GetMsgSlice(16) + msgs[1].WithRequest(getReq()) + msgs[1].WithRequest(getReq()) + msgs[1].Reset() + msgs[2].WithRequest(getReq()) + msgs[2].WithRequest(getReq()) + msgs[2].WithRequest(getReq()) + msgs[2].Reset() + // decode + nmsgs, err := pc.Decode(msgs) + assert.NoError(t, err) + assert.Len(t, nmsgs, 5) + // MGET baka + assert.Len(t, nmsgs[0].Batch(), 2) + req := msgs[0].Requests()[0].(*Request) + assert.Equal(t, mergeTypeJoin, req.mType) + assert.Equal(t, 2, req.resp.arrayn) + assert.Equal(t, "GET", req.CmdString()) + assert.Equal(t, []byte("GET"), req.Cmd()) + assert.Equal(t, "baka", string(req.Key())) + assert.Equal(t, []byte("2"), req.resp.data) + assert.Equal(t, []byte("3\r\nGET"), req.resp.array[0].data) + assert.Equal(t, []byte("4\r\nbaka"), req.resp.array[1].data) + // MGET kaba + req = msgs[0].Requests()[1].(*Request) + assert.Equal(t, mergeTypeJoin, req.mType) + assert.Equal(t, 2, req.resp.arrayn) + assert.Equal(t, "GET", req.CmdString()) + assert.Equal(t, []byte("GET"), req.Cmd()) + assert.Equal(t, "kaba", string(req.Key())) + assert.Equal(t, []byte("2"), req.resp.data) + assert.Equal(t, []byte("3\r\nGET"), req.resp.array[0].data) + assert.Equal(t, []byte("4\r\nkaba"), req.resp.array[1].data) + // MSET a b + assert.Len(t, nmsgs[1].Batch(), 2) + req = msgs[1].Requests()[0].(*Request) + assert.Equal(t, mergeTypeOK, req.mType) + assert.Equal(t, 3, req.resp.arrayn) 
+ assert.Equal(t, "MSET", req.CmdString()) + assert.Equal(t, []byte("MSET"), req.Cmd()) + assert.Equal(t, "a", string(req.Key())) + assert.Equal(t, []byte("3"), req.resp.data) + assert.Equal(t, []byte("4\r\nMSET"), req.resp.array[0].data) + assert.Equal(t, []byte("1\r\na"), req.resp.array[1].data) + assert.Equal(t, []byte("1\r\nb"), req.resp.array[2].data) + // MSET eee 12345 + req = msgs[1].Requests()[1].(*Request) + assert.Equal(t, mergeTypeOK, req.mType) + assert.Equal(t, 3, req.resp.arrayn) + assert.Equal(t, "MSET", req.CmdString()) + assert.Equal(t, []byte("MSET"), req.Cmd()) + assert.Equal(t, "eee", string(req.Key())) + assert.Equal(t, []byte("3"), req.resp.data) + assert.Equal(t, []byte("4\r\nMSET"), req.resp.array[0].data) + assert.Equal(t, []byte("3\r\neee"), req.resp.array[1].data) + assert.Equal(t, []byte("5\r\n12345"), req.resp.array[2].data) + // MGET enen + assert.Len(t, nmsgs[0].Batch(), 2) + req = msgs[2].Requests()[0].(*Request) + assert.Equal(t, mergeTypeJoin, req.mType) + assert.Equal(t, 2, req.resp.arrayn) + assert.Equal(t, "GET", req.CmdString()) + assert.Equal(t, []byte("GET"), req.Cmd()) + assert.Equal(t, "enen", string(req.Key())) + assert.Equal(t, []byte("2"), req.resp.data) + assert.Equal(t, []byte("3\r\nGET"), req.resp.array[0].data) + assert.Equal(t, []byte("4\r\nenen"), req.resp.array[1].data) + // MGET nime + req = msgs[2].Requests()[1].(*Request) + assert.Equal(t, mergeTypeJoin, req.mType) + assert.Equal(t, 2, req.resp.arrayn) + assert.Equal(t, "GET", req.CmdString()) + assert.Equal(t, []byte("GET"), req.Cmd()) + assert.Equal(t, "nime", string(req.Key())) + assert.Equal(t, []byte("2"), req.resp.data) + assert.Equal(t, []byte("3\r\nGET"), req.resp.array[0].data) + assert.Equal(t, []byte("4\r\nnime"), req.resp.array[1].data) + // GET abcde + req = msgs[3].Requests()[0].(*Request) + assert.Equal(t, mergeTypeNo, req.mType) + assert.Equal(t, 2, req.resp.arrayn) + assert.Equal(t, "GET", req.CmdString()) + assert.Equal(t, []byte("GET"), 
req.Cmd()) + assert.Equal(t, "abcde", string(req.Key())) + assert.Equal(t, []byte("2"), req.resp.data) + assert.Equal(t, []byte("3\r\nGET"), req.resp.array[0].data) + assert.Equal(t, []byte("5\r\nabcde"), req.resp.array[1].data) + + req = msgs[4].Requests()[0].(*Request) + assert.Equal(t, mergeTypeCount, req.mType) + assert.Equal(t, 2, req.resp.arrayn) + assert.Equal(t, "DEL", req.CmdString()) + assert.Equal(t, "a", string(req.Key())) + assert.Equal(t, []byte("2"), req.resp.data) +} + +func TestEncodeCmdOk(t *testing.T) { + ts := []struct { + Name string + MType mergeType + Reply []*resp + Expect string + }{ + { + Name: "mergeNotSupport", + MType: mergeTypeNo, + Reply: []*resp{ + &resp{ + rTp: respString, + data: []byte("123456789"), + }, + }, + Expect: "-Error: command not support\r\n", + }, + // { + // Name: "mergeCtl", + // MType: mergeTypeNo, + // Reply: []*resp{ + // &resp{ + // rTp: respInt, + // data: []byte("12"), + // }, + // }, + // Expect: ":12\r\n", + // }, + // { + // Name: "mergeError", + // MType: mergeTypeNo, + // Reply: []*resp{ + // &resp{ + // rTp: respError, + // data: []byte("i am error"), + // }, + // }, + // Expect: "-i am error\r\n", + // }, + { + Name: "mergeOK", + MType: mergeTypeOK, + Reply: []*resp{ + &resp{ + rTp: respString, + data: []byte("OK"), + }, + &resp{ + rTp: respString, + data: []byte("OK"), + }, + }, + Expect: "+OK\r\n", + }, + { + Name: "mergeCount", + MType: mergeTypeCount, + Reply: []*resp{ + &resp{ + rTp: respInt, + data: []byte("1"), + }, + &resp{ + rTp: respInt, + data: []byte("1"), + }, + }, + Expect: ":2\r\n", + }, + { + Name: "mergeJoin", + MType: mergeTypeJoin, + Reply: []*resp{ + &resp{ + rTp: respString, + data: []byte("abc"), + }, + &resp{ + rTp: respString, + data: []byte("ooo"), + }, + &resp{ + rTp: respString, + data: []byte("mmm"), + }, + }, + Expect: "*3\r\n+abc\r\n+ooo\r\n+mmm\r\n", + }, + } + for _, tt := range ts { + t.Run(tt.Name, func(t *testing.T) { + msg := proto.NewMessage() + for _, rpl := range 
tt.Reply { + req := getReq() + req.mType = tt.MType + req.reply = rpl + msg.WithRequest(req) + } + if msg.IsBatch() { + msg.Batch() + } + conn, buf := _createDownStreamConn() + pc := NewProxyConn(conn) + err := pc.Encode(msg) + if !assert.NoError(t, err) { + return + } + err = pc.Flush() + if !assert.NoError(t, err) { + return + } + data := make([]byte, 2048) + size, err := buf.Read(data) + assert.NoError(t, err) + assert.Equal(t, tt.Expect, string(data[:size])) + + }) + } +} diff --git a/proto/redis/redis_test.go b/proto/redis/redis_test.go new file mode 100644 index 00000000..ee4dc455 --- /dev/null +++ b/proto/redis/redis_test.go @@ -0,0 +1,76 @@ +package redis + +import ( + "bytes" + "net" + "time" + + libnet "overlord/lib/net" +) + +type mockAddr string + +func (m mockAddr) Network() string { + return "tcp" +} +func (m mockAddr) String() string { + return string(m) +} + +type mockConn struct { + addr mockAddr + rbuf *bytes.Buffer + wbuf *bytes.Buffer + data []byte + repeat int +} + +func (m *mockConn) Read(b []byte) (n int, err error) { + if m.repeat > 0 { + m.rbuf.Write(m.data) + m.repeat-- + } + return m.rbuf.Read(b) +} +func (m *mockConn) Write(b []byte) (n int, err error) { + return m.wbuf.Write(b) +} + +// writeBuffers impl the net.buffersWriter to support writev +func (m *mockConn) writeBuffers(buf *net.Buffers) (int64, error) { + return buf.WriteTo(m.wbuf) +} + +func (m *mockConn) Close() error { return nil } +func (m *mockConn) LocalAddr() net.Addr { return m.addr } +func (m *mockConn) RemoteAddr() net.Addr { return m.addr } + +func (m *mockConn) SetDeadline(t time.Time) error { return nil } +func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } + +// _createConn is useful tools for handler test +func _createConn(data []byte) *libnet.Conn { + return _createRepeatConn(data, 1) +} + +func _createRepeatConn(data []byte, r int) *libnet.Conn { + mconn := &mockConn{ + addr: 
"127.0.0.1:12345", + rbuf: bytes.NewBuffer(nil), + wbuf: new(bytes.Buffer), + data: data, + repeat: r, + } + conn := libnet.NewConn(mconn, time.Second, time.Second) + return conn +} + +func _createDownStreamConn() (*libnet.Conn, *bytes.Buffer) { + buf := new(bytes.Buffer) + mconn := &mockConn{ + addr: "127.0.0.1:12345", + wbuf: buf, + } + return libnet.NewConn(mconn, time.Second, time.Second), buf +} diff --git a/proto/redis/request.go b/proto/redis/request.go new file mode 100644 index 00000000..adf4c058 --- /dev/null +++ b/proto/redis/request.go @@ -0,0 +1,259 @@ +package redis + +import ( + "bytes" + errs "errors" + "sync" +) + +var ( + emptyBytes = []byte("") + crlfBytes = []byte("\r\n") + + arrayLenTwo = []byte("2") + arrayLenThree = []byte("3") + + cmdPingBytes = []byte("4\r\nPING") + cmdMSetBytes = []byte("4\r\nMSET") + cmdMGetBytes = []byte("4\r\nMGET") + cmdGetBytes = []byte("3\r\nGET") + cmdDelBytes = []byte("3\r\nDEL") + cmdExistsBytes = []byte("6\r\nEXISTS") + + reqReadCmdsBytes = []byte("" + + "4\r\nDUMP" + + "6\r\nEXISTS" + + "4\r\nPTTL" + + "3\r\nTTL" + + "4\r\nTYPE" + + "8\r\nBITCOUNT" + + "6\r\nBITPOS" + + "3\r\nGET" + + "6\r\nGETBIT" + + "8\r\nGETRANGE" + + "4\r\nMGET" + + "6\r\nSTRLEN" + + "7\r\nHEXISTS" + + "4\r\nHGET" + + "7\r\nHGETALL" + + "5\r\nHKEYS" + + "4\r\nHLEN" + + "5\r\nHMGET" + + "7\r\nHSTRLEN" + + "5\r\nHVALS" + + "5\r\nHSCAN" + + "5\r\nSCARD" + + "5\r\nSDIFF" + + "6\r\nSINTER" + + "9\r\nSISMEMBER" + + "8\r\nSMEMBERS" + + "11\r\nSRANDMEMBER" + + "6\r\nSUNION" + + "5\r\nSSCAN" + + "5\r\nZCARD" + + "6\r\nZCOUNT" + + "9\r\nZLEXCOUNT" + + "6\r\nZRANGE" + + "11\r\nZRANGEBYLEX" + + "13\r\nZRANGEBYSCORE" + + "5\r\nZRANK" + + "9\r\nZREVRANGE" + + "14\r\nZREVRANGEBYLEX" + + "16\r\nZREVRANGEBYSCORE" + + "8\r\nZREVRANK" + + "6\r\nZSCORE" + + "5\r\nZSCAN" + + "6\r\nLINDEX" + + "4\r\nLLEN" + + "6\r\nLRANGE" + + "7\r\nPFCOUNT") + + reqWriteCmdsBytes = []byte("" + + "3\r\nDEL" + + "6\r\nEXPIRE" + + "8\r\nEXPIREAT" + + "7\r\nPERSIST" + + 
"7\r\nPEXPIRE" + + "9\r\nPEXPIREAT" + + "7\r\nRESTORE" + + "4\r\nSORT" + + "6\r\nAPPEND" + + "4\r\nDECR" + + "6\r\nDECRBY" + + "6\r\nGETSET" + + "4\r\nINCR" + + "6\r\nINCRBY" + + "11\r\nINCRBYFLOAT" + + "4\r\nMSET" + + "6\r\nPSETEX" + + "3\r\nSET" + + "6\r\nSETBIT" + + "5\r\nSETEX" + + "5\r\nSETNX" + + "8\r\nSETRANGE" + + "4\r\nHDEL" + + "7\r\nHINCRBY" + + "12\r\nHINCRBYFLOAT" + + "5\r\nHMSET" + + "4\r\nHSET" + + "6\r\nHSETNX" + + "7\r\nLINSERT" + + "4\r\nLPOP" + + "5\r\nLPUSH" + + "6\r\nLPUSHX" + + "4\r\nLREM" + + "4\r\nLSET" + + "5\r\nLTRIM" + + "4\r\nRPOP" + + "9\r\nRPOPLPUSH" + + "5\r\nRPUSH" + + "6\r\nRPUSHX" + + "4\r\nSADD" + + "5\r\nSMOVE" + + "4\r\nSPOP" + + "4\r\nSREM" + + "4\r\nZADD" + + "7\r\nZINCRBY" + + "11\r\nZINTERSTORE" + + "4\r\nZREM" + + "14\r\nZREMRANGEBYLEX" + + "15\r\nZREMRANGEBYRANK" + + "16\r\nZREMRANGEBYSCORE" + + "5\r\nPFADD" + + "7\r\nPFMERGE") + + reqNotSupportCmdsBytes = []byte("" + + "6\r\nMSETNX" + + "10\r\nSDIFFSTORE" + + "11\r\nSINTERSTORE" + + "11\r\nSUNIONSTORE" + + "11\r\nZUNIONSTORE" + + "5\r\nBLPOP" + + "5\r\nBRPOP" + + "10\r\nBRPOPLPUSH" + + "4\r\nKEYS" + + "7\r\nMIGRATE" + + "4\r\nMOVE" + + "6\r\nOBJECT" + + "9\r\nRANDOMKEY" + + "6\r\nRENAME" + + "8\r\nRENAMENX" + + "4\r\nSCAN" + + "4\r\nWAIT" + + "5\r\nBITOP" + + "4\r\nEVAL" + + "7\r\nEVALSHA" + + "4\r\nAUTH" + + "4\r\nECHO" + + "4\r\nINFO" + + "5\r\nPROXY" + + "7\r\nSLOWLOG" + + "4\r\nQUIT" + + "6\r\nSELECT" + + "4\r\nTIME" + + "6\r\nCONFIG" + + "8\r\nCOMMANDS") + + reqCtlCmdsBytes = []byte("4\r\nPING") +) + +// errors +var ( + ErrBadAssert = errs.New("bad assert for redis") + ErrBadCount = errs.New("bad count number") + ErrBadRequest = errs.New("bad request") +) + +// mergeType is used to decript the merge operation. 
+type mergeType = uint8 + +// merge types +const ( + mergeTypeNo mergeType = iota + mergeTypeCount + mergeTypeOK + mergeTypeJoin +) + +// Request is the type of a complete redis command +type Request struct { + resp *resp + reply *resp + mType mergeType +} + +var reqPool = &sync.Pool{ + New: func() interface{} { + return newReq() + }, +} + +// getReq get the msg from pool +func getReq() *Request { + return reqPool.Get().(*Request) +} + +func newReq() *Request { + r := &Request{} + r.resp = &resp{} + r.reply = &resp{} + return r +} + +// CmdString get the cmd +func (r *Request) CmdString() string { + return string(r.Cmd()) +} + +// Cmd get the cmd +func (r *Request) Cmd() []byte { + if r.resp.arrayn < 1 { + return emptyBytes + } + cmd := r.resp.array[0] + var pos int + if cmd.rTp == respBulk { + pos = bytes.Index(cmd.data, crlfBytes) + 2 + } + return cmd.data[pos:] +} + +// Key impl the proto.protoRequest and get the Key of redis +func (r *Request) Key() []byte { + if r.resp.arrayn < 1 { + return emptyBytes + } + if r.resp.arrayn == 1 { + return r.resp.array[0].data + } + k := r.resp.array[1] + var pos int + if k.rTp == respBulk { + pos = bytes.Index(k.data, crlfBytes) + 2 + } + return k.data[pos:] +} + +// Put the resource back to pool +func (r *Request) Put() { + r.resp.reset() + r.reply.reset() + r.mType = mergeTypeNo + reqPool.Put(r) +} + +// isSupport check command support. +func (r *Request) isSupport() bool { + if r.resp.arrayn < 1 { + return false + } + return bytes.Index(reqReadCmdsBytes, r.resp.array[0].data) > -1 || bytes.Index(reqWriteCmdsBytes, r.resp.array[0].data) > -1 +} + +// isCtl is control command. 
+func (r *Request) isCtl() bool { + if r.resp.arrayn < 1 { + return false + } + return bytes.Index(reqCtlCmdsBytes, r.resp.array[0].data) > -1 +} diff --git a/proto/redis/request_test.go b/proto/redis/request_test.go new file mode 100644 index 00000000..1d667fcc --- /dev/null +++ b/proto/redis/request_test.go @@ -0,0 +1,25 @@ +package redis + +import ( + "testing" + + "overlord/lib/bufio" + + "github.com/stretchr/testify/assert" +) + +func TestRequestNewRequest(t *testing.T) { + var bs = []byte("*2\r\n$4\r\nLLEN\r\n$6\r\nmylist\r\n") + // conn + conn := _createConn(bs) + br := bufio.NewReader(conn, bufio.Get(1024)) + br.Read() + req := getReq() + err := req.resp.decode(br) + assert.Nil(t, err) + assert.Equal(t, mergeTypeNo, req.mType) + assert.Equal(t, 2, req.resp.arrayn) + assert.Equal(t, "LLEN", req.CmdString()) + assert.Equal(t, []byte("LLEN"), req.Cmd()) + assert.Equal(t, "mylist", string(req.Key())) +} diff --git a/proto/redis/resp.go b/proto/redis/resp.go new file mode 100644 index 00000000..c8ed56a7 --- /dev/null +++ b/proto/redis/resp.go @@ -0,0 +1,221 @@ +package redis + +import ( + "overlord/lib/bufio" + "overlord/lib/conv" +) + +// respType is the type of redis resp +type respType = byte + +// resp type define +const ( + respUnknown respType = '0' + respString respType = '+' + respError respType = '-' + respInt respType = ':' + respBulk respType = '$' + respArray respType = '*' +) + +var ( + respStringBytes = []byte("+") + respErrorBytes = []byte("-") + respIntBytes = []byte(":") + respBulkBytes = []byte("$") + respArrayBytes = []byte("*") + + nullDataBytes = []byte("-1") +) + +// resp is a redis server protocol item. +type resp struct { + rTp respType + // in Bulk this is the size field + // in array this is the count field + data []byte + array []*resp + // in order to reuse array.use arrayn to mark current obj. 
+ arrayn int +} + +func (r *resp) reset() { + r.rTp = respUnknown + r.data = nil + r.arrayn = 0 + for _, ar := range r.array { + ar.reset() + } +} + +func (r *resp) copy(re *resp) { + r.reset() + r.rTp = re.rTp + r.data = re.data + for i := 0; i < re.arrayn; i++ { + nre := r.next() + nre.copy(re.array[i]) + } +} + +func (r *resp) next() *resp { + if r.arrayn < len(r.array) { + nr := r.array[r.arrayn] + nr.reset() + r.arrayn++ + return nr + } + nr := &resp{} + nr.reset() + r.array = append(r.array, nr) + r.arrayn++ + return nr +} + +func (r *resp) decode(br *bufio.Reader) (err error) { + r.reset() + // start read + line, err := br.ReadLine() + if err != nil { + return err + } + rTp := line[0] + r.rTp = rTp + switch rTp { + case respString, respInt, respError: + r.data = line[1 : len(line)-2] + case respBulk: + err = r.decodeBulk(line, br) + case respArray: + err = r.decodeArray(line, br) + default: + err = ErrBadRequest + } + return +} + +func (r *resp) decodeBulk(line []byte, br *bufio.Reader) (err error) { + ls := len(line) + sBs := line[1 : ls-2] + size, err := conv.Btoi(sBs) + if err != nil { + return + } + if size == -1 { + r.data = nil + return + } + br.Advance(-(ls - 1)) + all := ls - 1 + int(size) + 2 + data, err := br.ReadExact(all) + if err == bufio.ErrBufferFull { + br.Advance(-1) + return err + } else if err != nil { + return + } + r.data = data[:len(data)-2] + return +} + +func (r *resp) decodeArray(line []byte, br *bufio.Reader) (err error) { + ls := len(line) + sBs := line[1 : ls-2] + size, err := conv.Btoi(sBs) + if err != nil { + return + } + if size == -1 { + r.data = nil + return + } + r.data = sBs + mark := br.Mark() + for i := 0; i < int(size); i++ { + nre := r.next() + if err = nre.decode(br); err != nil { + br.AdvanceTo(mark) + br.Advance(-ls) + return + } + } + return +} + +func (r *resp) encode(w *bufio.Writer) (err error) { + switch r.rTp { + case respInt, respString, respError: + err = r.encodePlain(w) + case respBulk: + err = 
r.encodeBulk(w) + case respArray: + err = r.encodeArray(w) + } + return +} + +func (r *resp) encodePlain(w *bufio.Writer) (err error) { + switch r.rTp { + case respInt: + _ = w.Write(respIntBytes) + case respError: + _ = w.Write(respErrorBytes) + case respString: + _ = w.Write(respStringBytes) + } + if len(r.data) > 0 { + _ = w.Write(r.data) + } + err = w.Write(crlfBytes) + return +} + +func (r *resp) encodeBulk(w *bufio.Writer) (err error) { + _ = w.Write(respBulkBytes) + if len(r.data) > 0 { + _ = w.Write(r.data) + } else { + _ = w.Write(nullDataBytes) + } + err = w.Write(crlfBytes) + return +} + +func (r *resp) encodeArray(w *bufio.Writer) (err error) { + _ = w.Write(respArrayBytes) + if len(r.data) > 0 { + _ = w.Write(r.data) + } else { + _ = w.Write(nullDataBytes) + } + _ = w.Write(crlfBytes) + for i := 0; i < r.arrayn; i++ { + if err = r.array[i].encode(w); err != nil { + return + } + } + return +} + +// func (r *resp) String() string { +// var sb strings.Builder +// sb.Write([]byte{r.rTp}) +// switch r.rTp { +// case respString, respInt, respError: +// sb.Write(r.data) +// sb.Write(crlfBytes) +// case respBulk: +// sb.Write(r.data) +// sb.Write(crlfBytes) +// case respArray: +// sb.Write([]byte(strconv.Itoa(r.arrayn))) +// sb.Write(crlfBytes) + +// for i := 0; i < r.arrayn; i++ { +// sb.WriteString(r.array[i].String()) +// } +// default: +// panic(fmt.Sprintf("not support robj:%s", sb.String())) +// } +// return sb.String() +// } diff --git a/proto/redis/resp_test.go b/proto/redis/resp_test.go new file mode 100644 index 00000000..924b9e6d --- /dev/null +++ b/proto/redis/resp_test.go @@ -0,0 +1,246 @@ +package redis + +import ( + "testing" + + "overlord/lib/bufio" + + "github.com/stretchr/testify/assert" +) + +func TestRespDecode(t *testing.T) { + ts := []struct { + Name string + Bytes []byte + ExpectTp respType + ExpectLen int + ExpectData []byte + ExpectArr [][]byte + }{ + { + Name: "ok", + Bytes: []byte("+OK\r\n"), + ExpectTp: respString, + ExpectLen: 0, + 
ExpectData: []byte("OK"), + }, + { + Name: "error", + Bytes: []byte("-Error message\r\n"), + ExpectTp: respError, + ExpectLen: 0, + ExpectData: []byte("Error message"), + }, + { + Name: "int", + Bytes: []byte(":1000\r\n"), + ExpectTp: respInt, + ExpectLen: 0, + ExpectData: []byte("1000"), + }, + { + Name: "bulk", + Bytes: []byte("$6\r\nfoobar\r\n"), + ExpectTp: respBulk, + ExpectLen: 0, + ExpectData: []byte("6\r\nfoobar"), + }, + { + Name: "array1", + Bytes: []byte("*2\r\n$3\r\nfoo\r\n$4\r\nbara\r\n"), + ExpectTp: respArray, + ExpectLen: 2, + ExpectData: []byte("2"), + ExpectArr: [][]byte{ + []byte("3\r\nfoo"), + []byte("4\r\nbara"), + }, + }, + { + Name: "array2", + Bytes: []byte("*3\r\n:1\r\n:2\r\n:3\r\n"), + ExpectTp: respArray, + ExpectLen: 3, + ExpectData: []byte("3"), + ExpectArr: [][]byte{ + []byte("1"), + []byte("2"), + []byte("3"), + }, + }, + { + Name: "array3", + Bytes: []byte("*2\r\n*3\r\n:1\r\n:2\r\n:3\r\n*2\r\n+Foo\r\n-Bar\r\n"), + ExpectTp: respArray, + ExpectLen: 2, + ExpectData: []byte("2"), + ExpectArr: [][]byte{ + []byte("3"), + []byte("2"), + }, + }, + } + for _, tt := range ts { + t.Run(tt.Name, func(t *testing.T) { + conn := _createConn(tt.Bytes) + r := &resp{} + r.reset() + br := bufio.NewReader(conn, bufio.Get(1024)) + br.Read() + if err := r.decode(br); err != nil { + t.Fatalf("decode error:%v", err) + } + assert.Equal(t, tt.ExpectTp, r.rTp) + assert.Equal(t, tt.ExpectLen, r.arrayn) + assert.Equal(t, tt.ExpectData, r.data) + if len(tt.ExpectArr) > 0 { + for i, ea := range tt.ExpectArr { + assert.Equal(t, ea, r.array[i].data) + } + } + }) + } +} + +func TestRespEncode(t *testing.T) { + ts := []struct { + Name string + Resp *resp + Expect []byte + }{ + { + Name: "ok", + Resp: &resp{ + rTp: respString, + data: []byte("OK"), + }, + Expect: []byte("+OK\r\n"), + }, + { + Name: "error", + Resp: &resp{ + rTp: respError, + data: []byte("Error message"), + }, + Expect: []byte("-Error message\r\n"), + }, + { + Name: "int", + Resp: &resp{ + rTp: 
respInt, + data: []byte("1000"), + }, + Expect: []byte(":1000\r\n"), + }, + { + Name: "bulk", + Resp: &resp{ + rTp: respBulk, + data: []byte("6\r\nfoobar"), + }, + Expect: []byte("$6\r\nfoobar\r\n"), + }, + { + Name: "array1", + Resp: &resp{ + rTp: respArray, + data: []byte("2"), + array: []*resp{ + &resp{ + rTp: respBulk, + data: []byte("3\r\nfoo"), + }, + &resp{ + rTp: respBulk, + data: []byte("4\r\nbara"), + }, + }, + arrayn: 2, + }, + Expect: []byte("*2\r\n$3\r\nfoo\r\n$4\r\nbara\r\n"), + }, + { + Name: "array2", + Resp: &resp{ + rTp: respArray, + data: []byte("3"), + array: []*resp{ + &resp{ + rTp: respInt, + data: []byte("1"), + }, + &resp{ + rTp: respInt, + data: []byte("2"), + }, + &resp{ + rTp: respInt, + data: []byte("3"), + }, + }, + arrayn: 3, + }, + Expect: []byte("*3\r\n:1\r\n:2\r\n:3\r\n"), + }, + { + Name: "array3", + Resp: &resp{ + rTp: respArray, + data: []byte("2"), + array: []*resp{ + &resp{ + rTp: respArray, + data: []byte("3"), + array: []*resp{ + &resp{ + rTp: respInt, + data: []byte("1"), + }, + &resp{ + rTp: respInt, + data: []byte("2"), + }, + &resp{ + rTp: respInt, + data: []byte("3"), + }, + }, + arrayn: 3, + }, + &resp{ + rTp: respArray, + data: []byte("2"), + array: []*resp{ + &resp{ + rTp: respString, + data: []byte("Foo"), + }, + &resp{ + rTp: respError, + data: []byte("Bar"), + }, + }, + arrayn: 2, + }, + }, + arrayn: 2, + }, + Expect: []byte("*2\r\n*3\r\n:1\r\n:2\r\n:3\r\n*2\r\n+Foo\r\n-Bar\r\n"), + }, + } + for _, tt := range ts { + t.Run(tt.Name, func(t *testing.T) { + conn := _createConn(nil) + bw := bufio.NewWriter(conn) + + err := tt.Resp.encode(bw) + bw.Flush() + assert.Nil(t, err) + + buf := make([]byte, 1024) + n, err := conn.Conn.(*mockConn).wbuf.Read(buf) + assert.Nil(t, err) + assert.Equal(t, tt.Expect, buf[:n]) + }) + } +} diff --git a/proto/types.go b/proto/types.go index 6261cb55..a46b7a8c 100644 --- a/proto/types.go +++ b/proto/types.go @@ -32,6 +32,7 @@ type Request interface { type ProxyConn interface { 
Decode([]*Message) ([]*Message, error) Encode(msg *Message) error + Flush() error } // NodeConn handle Msg to backend cache server and read response. diff --git a/proxy/cluster.go b/proxy/cluster.go index 01a76b9f..703ea71d 100644 --- a/proxy/cluster.go +++ b/proxy/cluster.go @@ -17,6 +17,7 @@ import ( "overlord/proto" "overlord/proto/memcache" mcbin "overlord/proto/memcache/binary" + "overlord/proto/redis" "github.com/pkg/errors" ) @@ -87,18 +88,21 @@ func NewCluster(ctx context.Context, cc *ClusterConfig) (c *Cluster) { c.hashTag = []byte{cc.HashTag[0], cc.HashTag[1]} } c.alias = alias - - ring := hashkit.Ketama() - if c.alias { - ring.Init(ans, ws) + // hash ring + ring := hashkit.NewRing(cc.HashDistribution, cc.HashMethod) + if cc.CacheType == proto.CacheTypeMemcache || cc.CacheType == proto.CacheTypeMemcacheBinary || cc.CacheType == proto.CacheTypeRedis { + if c.alias { + ring.Init(ans, ws) + } else { + ring.Init(addrs, ws) + } } else { - ring.Init(addrs, ws) + panic("unsupported protocol") } nodeChan := make(map[int]*batchChanel) nodeMap := make(map[string]int) aliasMap := make(map[string]int) - for i := range addrs { nbc := newBatchChanel(cc.NodeConnections) go c.processBatch(nbc, addrs[i]) @@ -136,10 +140,7 @@ func (c *Cluster) calculateBatchIndex(key []byte) int { // DispatchBatch delivers all the messages to batch execute by hash func (c *Cluster) DispatchBatch(mbs []*proto.MsgBatch, slice []*proto.Message) { // TODO: dynamic update mbs by add more than configrured nodes - var ( - bidx int - ) - + var bidx int for _, msg := range slice { if msg.IsBatch() { for _, sub := range msg.Batch() { @@ -151,7 +152,6 @@ func (c *Cluster) DispatchBatch(mbs []*proto.MsgBatch, slice []*proto.Message) { mbs[bidx].AddMsg(msg) } } - c.deliver(mbs) } @@ -188,16 +188,12 @@ func (c *Cluster) processBatchIO(addr string, ch <-chan *proto.MsgBatch, nc prot } return } - - err := c.processWriteBatch(nc, mb) - if err != nil { + if err := nc.WriteBatch(mb); err != nil { err = 
errors.Wrap(err, "Cluster batch write") mb.BatchDoneWithError(c.cc.Name, addr, err) continue } - - err = c.processReadBatch(nc, mb) - if err != nil { + if err := nc.ReadBatch(mb); err != nil { err = errors.Wrap(err, "Cluster batch read") mb.BatchDoneWithError(c.cc.Name, addr, err) continue @@ -206,15 +202,6 @@ func (c *Cluster) processBatchIO(addr string, ch <-chan *proto.MsgBatch, nc prot } } -func (c *Cluster) processWriteBatch(w proto.NodeConn, mb *proto.MsgBatch) error { - return w.WriteBatch(mb) -} - -func (c *Cluster) processReadBatch(r proto.NodeConn, mb *proto.MsgBatch) error { - err := r.ReadBatch(mb) - return err -} - func (c *Cluster) startPinger(cc *ClusterConfig, addrs []string, ws []int) { for idx, addr := range addrs { w := ws[idx] @@ -235,6 +222,7 @@ func (c *Cluster) processPing(p *pinger) { if err := p.ping.Ping(); err != nil { p.failure++ p.retries = 0 + log.Warnf("node ping fail:%d times with err:%v", p.failure, err) } else { p.failure = 0 if del { @@ -269,7 +257,7 @@ func (c *Cluster) hash(key []byte) (node string, ok bool) { if len(realKey) == 0 { realKey = key } - node, ok = c.ring.Hash(realKey) + node, ok = c.ring.GetNode(realKey) return } @@ -333,9 +321,8 @@ func newNodeConn(cc *ClusterConfig, addr string) proto.NodeConn { case proto.CacheTypeMemcacheBinary: return mcbin.NewNodeConn(cc.Name, addr, dto, rto, wto) case proto.CacheTypeRedis: - // TODO(felix): support redis + return redis.NewNodeConn(cc.Name, addr, dto, rto, wto) default: panic(proto.ErrNoSupportCacheType) } - return nil } diff --git a/proxy/handler.go b/proxy/handler.go index e071bbc9..76230d88 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -4,6 +4,7 @@ import ( "context" "io" "net" + "strings" "sync/atomic" "time" @@ -13,6 +14,7 @@ import ( "overlord/proto" "overlord/proto/memcache" mcbin "overlord/proto/memcache/binary" + "overlord/proto/redis" ) const ( @@ -44,6 +46,7 @@ type Handler struct { closed int32 // wg sync.WaitGroup err error + str strings.Builder } // 
NewHandler new a conn handler. @@ -61,7 +64,7 @@ func NewHandler(ctx context.Context, c *Config, conn net.Conn, cluster *Cluster) case proto.CacheTypeMemcacheBinary: h.pc = mcbin.NewProxyConn(h.conn) case proto.CacheTypeRedis: - // TODO(felix): support redis. + h.pc = redis.NewProxyConn(h.conn) default: panic(proto.ErrNoSupportCacheType) } @@ -76,10 +79,16 @@ func (h *Handler) Handle() { go h.handle() } +func (h *Handler) toStr(p []byte) string { + h.str.Reset() + h.str.Write(p) + return h.str.String() +} + func (h *Handler) handle() { var ( messages = proto.GetMsgSlice(defaultConcurrent) - mbatch = proto.NewMsgBatchSlice(len(h.cluster.cc.Servers)) + mbatch = proto.NewMsgBatchSlice(len(h.cluster.nodeMap)) msgs []*proto.Message err error ) @@ -88,6 +97,9 @@ func (h *Handler) handle() { for _, msg := range msgs { msg.Clear() } + for _, mb := range mbatch { + proto.DropMsgBatch(mb) + } h.closeWithError(err) }() @@ -113,7 +125,12 @@ func (h *Handler) handle() { } msg.MarkEnd() msg.ReleaseSubs() - prom.ProxyTime(h.cluster.cc.Name, msg.Request().CmdString(), int64(msg.TotalDur()/time.Microsecond)) + if prom.On { + prom.ProxyTime(h.cluster.cc.Name, h.toStr(msg.Request().Cmd()), int64(msg.TotalDur()/time.Microsecond)) + } + } + if err = h.pc.Flush(); err != nil { + return } // 4. 
release resource @@ -151,7 +168,9 @@ func (h *Handler) closeWithError(err error) { h.cancel() h.msgCh.Close() _ = h.conn.Close() - prom.ConnDecr(h.cluster.cc.Name) + if prom.On { + prom.ConnDecr(h.cluster.cc.Name) + } if log.V(3) { if err != io.EOF { log.Warnf("cluster(%s) addr(%s) remoteAddr(%s) handler close error:%+v", h.cluster.cc.Name, h.cluster.cc.ListenAddr, h.conn.RemoteAddr(), err) diff --git a/proxy/proxy.go b/proxy/proxy.go index 161e8525..7ba1b70f 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -12,6 +12,7 @@ import ( "overlord/proto" "overlord/proto/memcache" mcbin "overlord/proto/memcache/binary" + "overlord/proto/redis" "github.com/pkg/errors" ) @@ -88,22 +89,19 @@ func (p *Proxy) serve(cc *ClusterConfig) { if p.c.Proxy.MaxConnections > 0 { if conns := atomic.AddInt32(&p.conns, 1); conns > p.c.Proxy.MaxConnections { // cache type + var encoder proto.ProxyConn switch cc.CacheType { case proto.CacheTypeMemcache: - encoder := memcache.NewProxyConn(libnet.NewConn(conn, time.Second, time.Second)) - m := proto.ErrMessage(ErrProxyMoreMaxConns) - _ = encoder.Encode(m) - _ = conn.Close() + encoder = memcache.NewProxyConn(libnet.NewConn(conn, time.Second, time.Second)) case proto.CacheTypeMemcacheBinary: - encoder := mcbin.NewProxyConn(libnet.NewConn(conn, time.Second, time.Second)) - m := proto.ErrMessage(ErrProxyMoreMaxConns) - _ = encoder.Encode(m) - _ = conn.Close() + encoder = mcbin.NewProxyConn(libnet.NewConn(conn, time.Second, time.Second)) case proto.CacheTypeRedis: - // TODO(felix): support redis. 
- default: - _ = conn.Close() + encoder = redis.NewProxyConn(libnet.NewConn(conn, time.Second, time.Second)) } + if encoder != nil { + _ = encoder.Encode(proto.ErrMessage(ErrProxyMoreMaxConns)) + } + _ = conn.Close() if log.V(3) { log.Warnf("proxy reject connection count(%d) due to more than max(%d)", conns, p.c.Proxy.MaxConnections) } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 9827706b..0935d364 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -26,8 +26,8 @@ var ( ListenProto: "tcp", ListenAddr: "127.0.0.1:21211", RedisAuth: "", - DialTimeout: 1000, - ReadTimeout: 1000, + DialTimeout: 100, + ReadTimeout: 100, NodeConnections: 10, WriteTimeout: 1000, PingFailLimit: 3, @@ -47,8 +47,8 @@ var ( ListenProto: "tcp", ListenAddr: "127.0.0.1:21212", RedisAuth: "", - DialTimeout: 1000, - ReadTimeout: 1000, + DialTimeout: 100, + ReadTimeout: 100, NodeConnections: 10, WriteTimeout: 1000, PingFailLimit: 3, @@ -207,10 +207,10 @@ func testCmdBin(t testing.TB, cmds ...[]byte) { } func TestProxyFull(t *testing.T) { - for i := 0; i < 100; i++ { - testCmd(t, cmds[0], cmds[1], cmds[2], cmds[10], cmds[11]) - testCmdBin(t, cmdBins[0], cmdBins[1]) - } + // for i := 0; i < 10; i++ { + testCmd(t, cmds[0], cmds[1], cmds[2], cmds[10], cmds[11]) + // testCmdBin(t, cmdBins[0], cmdBins[1]) + // } } func TestProxyWithAssert(t *testing.T) { @@ -231,7 +231,7 @@ func TestProxyWithAssert(t *testing.T) { {Name: "MultiCmdGetOk", Line: 6, Cmd: "gets a_11\r\ngets a_11\r\n", Except: []string{"VALUE a_11 0 1", "\r\n1\r\n", "END\r\n"}}, } - for i := 0; i < 100; i++ { + for i := 0; i < 10; i++ { conn, err := net.DialTimeout("tcp", "127.0.0.1:21211", time.Second) if err != nil { t.Errorf("net dial error:%v", err) @@ -251,8 +251,12 @@ func TestProxyWithAssert(t *testing.T) { buf = append(buf, data...) 
} sb := string(buf) - for _, except := range tt.Except { - assert.Contains(t, sb, except, "CMD:%s", tt.Cmd) + if len(tt.Except) == 1 { + assert.Equal(t, sb, tt.Except[0], "CMD:%s", tt.Cmd) + } else { + for _, except := range tt.Except { + assert.Contains(t, sb, except, "CMD:%s", tt.Cmd) + } } }) } diff --git a/scripts/validate_keys_dist.py b/scripts/validate_keys_dist.py new file mode 100755 index 00000000..7dc99004 --- /dev/null +++ b/scripts/validate_keys_dist.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python + +import redis +import argparse +import os + +from contextlib import contextmanager + +def gen_str(n): + return ''.join(map(lambda xx: (hex(ord(xx))[2:]), os.urandom(n))) + + +def gen_items(prefix, n): + return [ + "_overlord-%s-%s-%010d" % (prefix, gen_str(8), num) for num in range(n) + ] + + +def parse_ip_port(addr): + asp = addr.split(":") + return (asp[0], int(asp[1])) + + +def dial(expect): + ip, port = parse_ip_port(expect) + rc = redis.StrictRedis(host=ip, port=port) + return rc + +@contextmanager +def del_keys(rc, keys): + try: + yield + finally: + epipe = rc.pipeline(transaction=False) + for key in keys: + epipe.delete(key) + epipe.execute() + + +def check_vals(expect_rc, check_rc, keys, vals): + epipe = expect_rc.pipeline(transaction=False) + for key, val in zip(keys, vals): + epipe.set(key, val, ex=10) + epipe.execute() + + cpipe = check_rc.pipeline(transaction=False) + for key in keys: + epipe.get(key) + for i,val in enumerate(epipe.execute()): + assert vals[i] == val + + +def run_check(check, expect, n=1024): + keys = gen_items("keys", n) + vals = gen_items("vals", n) + + expect_rc = dial(expect) + check_rc = dial(check) + + with del_keys(expect_rc, keys): + check_vals(expect_rc, check_rc, keys, vals) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("check", help="address need to be checked.") + parser.add_argument( + "expect", + help= + "expect validate address. 
command will be sent to this address first.")
+    parser.add_argument("-k", "--keys", type=int, default=1024, help="default 1024. It's recommended to be 10 times the count of backends.")
+    opt = parser.parse_args()
+    check = opt.check
+    expect = opt.expect
+    run_check(check, expect, n=opt.keys)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/validate_redis_features.py b/scripts/validate_redis_features.py
new file mode 100755
index 00000000..a0a48668
--- /dev/null
+++ b/scripts/validate_redis_features.py
@@ -0,0 +1,231 @@
+#!/usr/bin/env python
+
+try:
+    from gevent.monkey import patch_all
+    patch_all()
+    from gevent.pool import Pool
+
+    from fakeredis import FakeStrictRedis
+    from redis import StrictRedis
+except ImportError:
+    print("""ERROR: you are running within a bad dependencies environment.
+You may run the following commands to fix it:
+
+    pip install fakeredis==0.11.0 redis==2.10.6 gevent==1.3.5
+
+""")
+    raise
+
+import random
+import sys
+
+host = "127.0.0.1"
+port = 26379
+timeout = 3
+
+
+def gen_items(prefix, n):
+    return ["%s%010d" % (prefix, num) for num in range(n)]
+
+
+class Cmd(object):
+    def __init__(self, fn, *args, **kwargs):
+        self.fn = fn
+        self.args = args
+        self.kwargs = kwargs
+
+    def execute(self):
+        return self.fn(*self.args, **self.kwargs)
+
+    def __str__(self):
+        return "Cmd<fn:%s, args:%s, kwargs:%s>" % (self.fn.__name__, self.args,
+                                                   self.kwargs)
+
+    def __repr__(self):
+        return self.__str__()
+
+
+def append_cmd(l, is_write, rcfn, fakefn, *args, **kwargs):
+    l.append((is_write, Cmd(rcfn, *args, **kwargs), Cmd(
+        fakefn, *args, **kwargs)))
+
+
+def gen_cmds(fake, rc, keys, vals):
+    """ Checked Commands list:
+    STRING Commands: GET(1) SET(2) MGET(n) MSET(2n)
+    HASH Commands: HGET(2) HGETALL(1) HMGET(2n) HSET(3) HMSET(1+2n)
+    SET Commands: SCARD(1) SMEMBERS(1) SISMEMBER(2) SADD(2)
+    ZSET Commands: ZCOUNT(1) ZCARD(1) ZRANGE(3) ZADD(3)
+    LIST Commands: LLEN(1) LPOP(1) LPUSH(2) RPOP(1)
+    HYPERLOGLOG Commands: PFCOUNT(1) PFADD
+    """
+    # string
+    
string_cmd_pairs = [] + append_cmd(string_cmd_pairs, False, rc.get, fake.get, keys[0]) + append_cmd(string_cmd_pairs, True, rc.set, fake.set, keys[1], vals[1]) + append_cmd(string_cmd_pairs, False, rc.mget, fake.mget, keys[2], keys[3]) + append_cmd(string_cmd_pairs, True, rc.mset, fake.mset, { + keys[4]: vals[4], + keys[5]: vals[5] + }) + yield string_cmd_pairs + + # hash + hash_cmd_pairs = [] + append_cmd(hash_cmd_pairs, True, rc.hset, fake.hset, keys[100], keys[10], + vals[10]) + append_cmd(hash_cmd_pairs, True, rc.hset, fake.hset, keys[100], keys[11], + vals[11]) + append_cmd(hash_cmd_pairs, True, rc.hset, fake.hset, keys[100], keys[12], + vals[12]) + + append_cmd(hash_cmd_pairs, True, rc.hset, fake.hset, keys[101], keys[10], + vals[10]) + append_cmd(hash_cmd_pairs, True, rc.hset, fake.hset, keys[101], keys[11], + vals[11]) + append_cmd(hash_cmd_pairs, True, rc.hset, fake.hset, keys[101], keys[12], + vals[12]) + + append_cmd(hash_cmd_pairs, False, rc.hget, fake.hget, keys[100], keys[10]) + append_cmd(hash_cmd_pairs, False, rc.hget, fake.hget, keys[100], keys[11]) + append_cmd(hash_cmd_pairs, False, rc.hget, fake.hget, keys[100], keys[12]) + append_cmd(hash_cmd_pairs, False, rc.hget, fake.hget, keys[101], keys[10]) + append_cmd(hash_cmd_pairs, False, rc.hget, fake.hget, keys[101], keys[11]) + append_cmd(hash_cmd_pairs, False, rc.hget, fake.hget, keys[101], keys[12]) + + append_cmd(hash_cmd_pairs, False, rc.hmget, fake.hmget, keys[101], + [keys[10], keys[11]]) + append_cmd(hash_cmd_pairs, False, rc.hmget, fake.hmget, keys[101], + [keys[11], keys[12]]) + append_cmd(hash_cmd_pairs, False, rc.hmget, fake.hmget, keys[101], + [keys[10], keys[12]]) + append_cmd(hash_cmd_pairs, False, rc.hmget, fake.hmget, keys[101], + [keys[13]]) + + append_cmd(hash_cmd_pairs, True, rc.hmset, fake.hmset, keys[100], { + keys[14]: vals[14], + keys[15]: vals[15] + }) + append_cmd(hash_cmd_pairs, True, rc.hmset, fake.hmset, keys[100], { + keys[14]: vals[15], + keys[15]: vals[14] + }) + 
append_cmd(hash_cmd_pairs, True, rc.hmset, fake.hmset, keys[101], { + keys[14]: vals[14], + keys[15]: vals[15] + }) + append_cmd(hash_cmd_pairs, True, rc.hmset, fake.hmset, keys[101], { + keys[14]: vals[15], + keys[15]: vals[14] + }) + + append_cmd(hash_cmd_pairs, False, rc.hgetall, fake.hgetall, keys[100]) + append_cmd(hash_cmd_pairs, False, rc.hgetall, fake.hgetall, keys[101]) + append_cmd(hash_cmd_pairs, False, rc.hgetall, fake.hgetall, keys[102]) + + yield hash_cmd_pairs + + # set + set_cmd_pairs = [] + append_cmd(set_cmd_pairs, True, rc.sadd, fake.sadd, keys[200], vals[20]) + append_cmd(set_cmd_pairs, True, rc.sadd, fake.sadd, keys[200], vals[21]) + append_cmd(set_cmd_pairs, True, rc.sadd, fake.sadd, keys[200], vals[22]) + + append_cmd(set_cmd_pairs, True, rc.sadd, fake.sadd, keys[201], vals[20]) + append_cmd(set_cmd_pairs, True, rc.sadd, fake.sadd, keys[201], vals[21]) + + append_cmd(set_cmd_pairs, False, rc.smembers, fake.smembers, keys[201]) + append_cmd(set_cmd_pairs, False, rc.smembers, fake.smembers, keys[200]) + append_cmd(set_cmd_pairs, False, rc.smembers, fake.smembers, keys[202]) + + append_cmd(set_cmd_pairs, False, rc.sismember, fake.sismember, keys[200], + vals[20]) + append_cmd(set_cmd_pairs, False, rc.sismember, fake.sismember, keys[200], + vals[21]) + append_cmd(set_cmd_pairs, False, rc.sismember, fake.sismember, keys[200], + vals[23]) + + append_cmd(set_cmd_pairs, False, rc.scard, fake.scard, keys[200]) + append_cmd(set_cmd_pairs, False, rc.scard, fake.scard, keys[201]) + yield set_cmd_pairs + + # zset + zset_cmd_pairs = [] + append_cmd(zset_cmd_pairs, True, rc.zadd, fake.zadd, keys[300], 1.1, + vals[30], 2.2, vals[40]) + append_cmd(zset_cmd_pairs, False, rc.zcount, fake.zcount, keys[300], 0.1, + 1.5) + append_cmd(zset_cmd_pairs, False, rc.zrange, fake.zrange, keys[300], 0, 20) + append_cmd(zset_cmd_pairs, False, rc.zcard, fake.zcard, keys[300]) + yield zset_cmd_pairs + + list_cmd_pairs = [] + append_cmd(list_cmd_pairs, True, rc.lpush, 
fake.lpush, keys[400], vals[40])
+    append_cmd(list_cmd_pairs, False, rc.llen, fake.llen, keys[400])
+    append_cmd(list_cmd_pairs, False, rc.rpop, fake.rpop, keys[400])
+    append_cmd(list_cmd_pairs, False, rc.lpop, fake.lpop, keys[400])
+    yield list_cmd_pairs
+
+    hyperloglog_cmd_pairs = []
+    append_cmd(hyperloglog_cmd_pairs, True, rc.pfadd, fake.pfadd, keys[500], vals[50])
+    append_cmd(hyperloglog_cmd_pairs, True, rc.pfadd, fake.pfadd, keys[500], vals[51])
+    append_cmd(hyperloglog_cmd_pairs, True, rc.pfadd, fake.pfadd, keys[500], vals[52])
+
+    append_cmd(hyperloglog_cmd_pairs, True, rc.pfcount, fake.pfcount, keys[500])
+    append_cmd(hyperloglog_cmd_pairs, True, rc.pfcount, fake.pfcount, keys[500])
+    yield hyperloglog_cmd_pairs
+
+def run_check(is_write, cmd1, cmd2):
+    if is_write:
+        cmd1.execute()
+        cmd2.execute()
+    else:
+        rslt1 = None
+        rslt2 = None
+        try:
+            rslt1 = cmd1.execute()
+            rslt2 = cmd2.execute()
+            assert rslt1 == rslt2
+        except AssertionError as e:
+            print("assert execute %s == %s" % (cmd1, cmd2))
+            print("\tassert result %s == %s" % (rslt1, rslt2))
+            raise
+
+
+def check(cmds_list_all, iround=100):
+    for cmds_list in cmds_list_all:
+        for _ in range(iround):
+            if cmds_list:
+                is_write, cmd1, cmd2 = random.choice(cmds_list)
+                run_check(is_write, cmd1, cmd2)
+
+def delete_all(rc, keys):
+    pipe = rc.pipeline(transaction=False)
+    for key in keys:
+        pipe.delete(key)
+    pipe.execute()
+
+def run(i):
+    keys = gen_items("keys-%04d" % (i, ), 501)
+    vals = gen_items("vals-%04d" % (i, ), 60)
+    rc = StrictRedis(host=host, port=port)
+    fake = FakeStrictRedis()
+    cmds_list = list(gen_cmds(fake, rc, keys, vals))
+    check(cmds_list)
+    delete_all(rc, keys)
+
+
+def main():
+    global host, port
+    if len(sys.argv) == 3:
+        host = sys.argv[1].strip()
+        port = int(sys.argv[2].strip())
+    elif len(sys.argv) == 2:
+        port = int(sys.argv[1].strip())
+
+    pool = Pool(20)
+    list(pool.map(run, range(100)))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vendor/vendor.json 
b/vendor/vendor.json index 7f20161d..93d06e02 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -111,5 +111,5 @@ "revisionTime": "2018-05-06T18:05:49Z" } ], - "rootPath": "github.com/felixhao/overlord" + "rootPath": "overlord" }