From f6ae9e015eab6230d0be27efd338fdba55fbb683 Mon Sep 17 00:00:00 2001 From: wangzhuowei Date: Thu, 30 May 2024 17:30:25 +0800 Subject: [PATCH] refactor: simplify shard queue --- mux/shard_queue.go | 196 +++++++++++++++++++--------------------- mux/shard_queue_test.go | 78 +++++++++++----- 2 files changed, 144 insertions(+), 130 deletions(-) diff --git a/mux/shard_queue.go b/mux/shard_queue.go index 364fabae..40f39656 100644 --- a/mux/shard_queue.go +++ b/mux/shard_queue.go @@ -17,93 +17,89 @@ package mux import ( "fmt" "runtime" - "sync" "sync/atomic" - "github.com/bytedance/gopkg/util/gopool" - "github.com/cloudwego/netpoll" ) -/* DOC: - * ShardQueue uses the netpoll's nocopy API to merge and send data. - * The Data Flush is passively triggered by ShardQueue.Add and does not require user operations. - * If there is an error in the data transmission, the connection will be closed. - * - * ShardQueue.Add: add the data to be sent. - * NewShardQueue: create a queue with netpoll.Connection. - * ShardSize: the recommended number of shards is 32. - */ -var ShardSize int - -func init() { - ShardSize = runtime.GOMAXPROCS(0) -} +// ShardQueue uses the netpoll nocopy Writer API to merge multi packets and send them at once. +// The Data Flush is passively triggered by ShardQueue.Add and does not require user operations. +// If there is an error in the data transmission, the connection will be closed. -// NewShardQueue . -func NewShardQueue(size int, conn netpoll.Connection) (queue *ShardQueue) { +// NewShardQueue create a queue with netpoll.Connection +func NewShardQueue(shardsize int, conn netpoll.Connection) (queue *ShardQueue) { queue = &ShardQueue{ - conn: conn, - size: int32(size), - getters: make([][]WriterGetter, size), - swap: make([]WriterGetter, 0, 64), - locks: make([]int32, size), + conn: conn, + shardsize: uint32(shardsize), + shards: make([][]WriterGetter, shardsize), + locks: make([]int32, shardsize), } - for i := range queue.getters { - queue.getters[i] = make([]WriterGetter, 0, 64) + for i := range queue.shards { + queue.shards[i] = make([]WriterGetter, 0, 64) } - queue.list = make([]int32, size) + queue.shard = make([]WriterGetter, 0, 64) return queue } // WriterGetter is used to get a netpoll.Writer. type WriterGetter func() (buf netpoll.Writer, isNil bool) -// ShardQueue uses the netpoll's nocopy API to merge and send data. +// ShardQueue uses the netpoll s nocopy API to merge and send data. // The Data Flush is passively triggered by ShardQueue.Add and does not require user operations. // If there is an error in the data transmission, the connection will be closed. // ShardQueue.Add: add the data to be sent. type ShardQueue struct { + // state definition: + // active : only active state can allow user Add new task + // closing : ShardQueue.Close is called and try to close gracefully, cannot Add new data + // closed : Gracefully shutdown finished + state int32 + conn netpoll.Connection - idx, size int32 - getters [][]WriterGetter // len(getters) = size - swap []WriterGetter // use for swap - locks []int32 // len(locks) = size - queueTrigger + size uint32 // the size of all getters in all shards + shardsize uint32 // the size of shards + shards [][]WriterGetter // the shards of getters, len(shards) = shardsize + shard []WriterGetter // the shard is dealing, use shard to swap + locks []int32 // the locks of shards, len(locks) = shardsize + // trigger used to avoid triggering function re-enter twice. + // trigger == 0: nothing to do + // trigger == 1: we should start a new triggering() + // trigger >= 2: triggering() already started + trigger int32 } const ( - // queueTrigger state + // ShardQueue state active = 0 closing = 1 closed = 2 ) -// here for trigger -type queueTrigger struct { - trigger int32 - state int32 // 0: active, 1: closing, 2: closed - runNum int32 - w, r int32 // ptr of list - list []int32 // record the triggered shard - listLock sync.Mutex // list total lock -} +var idgen uint32 -// Add adds to q.getters[shard] -func (q *ShardQueue) Add(gts ...WriterGetter) { - if atomic.LoadInt32(&q.state) != active { - return +// Add adds gts to ShardQueue +func (q *ShardQueue) Add(gts ...WriterGetter) bool { + size := uint32(len(gts)) + if size == 0 || atomic.LoadInt32(&q.state) != active { + return false } - shard := atomic.AddInt32(&q.idx, 1) % q.size - q.lock(shard) - trigger := len(q.getters[shard]) == 0 - q.getters[shard] = append(q.getters[shard], gts...) - q.unlock(shard) - if trigger { - q.triggering(shard) + + // get current shard id + shardid := atomic.AddUint32(&idgen, 1) % q.shardsize + // add new shards into shard + q.lock(shardid) + q.shards[shardid] = append(q.shards[shardid], gts...) + // size update should happen in lock, because we should make sure when q.shards unlock, worker can get the correct size + _ = atomic.AddUint32(&q.size, size) + q.unlock(shardid) + + if atomic.AddInt32(&q.trigger, 1) == 1 { + go q.triggering(shardid) } + return true } +// Close graceful shutdown the ShardQueue and will flush all data added first func (q *ShardQueue) Close() error { if !atomic.CompareAndSwapInt32(&q.state, active, closing) { return fmt.Errorf("shardQueue has been closed") @@ -120,65 +116,55 @@ func (q *ShardQueue) Close() error { } // triggering shard. -func (q *ShardQueue) triggering(shard int32) { - q.listLock.Lock() - q.w = (q.w + 1) % q.size - q.list[q.w] = shard - q.listLock.Unlock() - - if atomic.AddInt32(&q.trigger, 1) > 1 { - return +func (q *ShardQueue) triggering(shardid uint32) { +WORKER: + for atomic.LoadUint32(&q.size) > 0 { + // lock & shard + q.lock(shardid) + shard := q.shards[shardid] + q.shards[shardid] = q.shard[:0] + q.shard = shard[:0] // reuse current shard's space for next round + q.unlock(shardid) + + if len(shard) > 0 { + // flush shard + q.deal(shard) + // only decrease q.size when the shard dealt + atomic.AddUint32(&q.size, -uint32(len(shard))) + } + // if there have any new data, the next shard must not be empty + shardid = (shardid + 1) % q.shardsize } - q.foreach() -} - -// foreach swap r & w. It's not concurrency safe. -func (q *ShardQueue) foreach() { - if atomic.AddInt32(&q.runNum, 1) > 1 { - return + // flush connection + q.flush() + + // [IMPORTANT] Atomic Double Check: + // ShardQueue.Add will ensure it will always update 'size' and 'trigger'. + // - If CAS(q.trigger, oldTrigger, 0) = true, it means there is no triggering() call during size check, + // so it's safe to exit triggering(). And any new Add() call will start triggering() successfully. + // - If CAS failed, there may have a failed triggering() call during Load(q.trigger) and CAS(q.trigger), + // so we should re-check q.size again from beginning. + oldTrigger := atomic.LoadInt32(&q.trigger) + if atomic.LoadUint32(&q.size) > 0 { + goto WORKER + } + if !atomic.CompareAndSwapInt32(&q.trigger, oldTrigger, 0) { + goto WORKER } - gopool.CtxGo(nil, func() { - var negNum int32 // is negative number of triggerNum - for triggerNum := atomic.LoadInt32(&q.trigger); triggerNum > 0; { - q.r = (q.r + 1) % q.size - shared := q.list[q.r] - - // lock & swap - q.lock(shared) - tmp := q.getters[shared] - q.getters[shared] = q.swap[:0] - q.swap = tmp - q.unlock(shared) - - // deal - q.deal(q.swap) - negNum-- - if triggerNum+negNum == 0 { - triggerNum = atomic.AddInt32(&q.trigger, negNum) - negNum = 0 - } - } - q.flush() - // quit & check again - atomic.StoreInt32(&q.runNum, 0) - if atomic.LoadInt32(&q.trigger) > 0 { - q.foreach() - return - } - // if state is closing, change it to closed - atomic.CompareAndSwapInt32(&q.state, closing, closed) - }) + // if state is closing, change it to closed + atomic.CompareAndSwapInt32(&q.state, closing, closed) + return } -// deal is used to get deal of netpoll.Writer. +// deal append all getters into connection func (q *ShardQueue) deal(gts []WriterGetter) { writer := q.conn.Writer() for _, gt := range gts { buf, isNil := gt() if !isNil { err := writer.Append(buf) - if err != nil { + if err != nil { // never happen q.conn.Close() return } @@ -186,7 +172,7 @@ func (q *ShardQueue) deal(gts []WriterGetter) { } } -// flush is used to flush netpoll.Writer. +// flush the connection and send all appended data func (q *ShardQueue) flush() { err := q.conn.Writer().Flush() if err != nil { @@ -196,13 +182,13 @@ func (q *ShardQueue) flush() { } // lock shard. -func (q *ShardQueue) lock(shard int32) { +func (q *ShardQueue) lock(shard uint32) { for !atomic.CompareAndSwapInt32(&q.locks[shard], 0, 1) { runtime.Gosched() } } // unlock shard. -func (q *ShardQueue) unlock(shard int32) { +func (q *ShardQueue) unlock(shard uint32) { atomic.StoreInt32(&q.locks[shard], 0) } diff --git a/mux/shard_queue_test.go b/mux/shard_queue_test.go index 7a595d21..8681a67a 100644 --- a/mux/shard_queue_test.go +++ b/mux/shard_queue_test.go @@ -18,7 +18,10 @@ package mux import ( + "io" "net" + "runtime" + "sync/atomic" "testing" "time" @@ -27,18 +30,22 @@ import ( func TestShardQueue(t *testing.T) { var svrConn net.Conn + var cliConn netpoll.Connection accepted := make(chan struct{}) + stopped := make(chan struct{}) + streams, framesize := 128, 1024 + totalsize := int32(streams * framesize) + var send, read int32 + // create server connection network, address := "tcp", ":18888" ln, err := net.Listen("tcp", ":18888") MustNil(t, err) - stop := make(chan int, 1) - defer close(stop) go func() { var err error for { select { - case <-stop: + case <-stopped: err = ln.Close() MustNil(t, err) return @@ -47,35 +54,56 @@ func TestShardQueue(t *testing.T) { svrConn, err = ln.Accept() MustNil(t, err) accepted <- struct{}{} + go func() { + recv := make([]byte, 10240) + for { + n, err := svrConn.Read(recv) + atomic.AddInt32(&read, int32(n)) + for i := 0; i < n; i++ { + MustTrue(t, recv[i] == 'a') + } + if err == io.EOF { + return + } + MustNil(t, err) + } + }() } }() - conn, err := netpoll.DialConnection(network, address, time.Second) + // create client connection + cliConn, err = netpoll.DialConnection(network, address, time.Second) MustNil(t, err) - <-accepted + <-accepted // wait svrConn accepted - // test - queue := NewShardQueue(4, conn) - count, pkgsize := 16, 11 - for i := 0; i < int(count); i++ { - var getter WriterGetter = func() (buf netpoll.Writer, isNil bool) { - buf = netpoll.NewLinkBuffer(pkgsize) - buf.Malloc(pkgsize) - return buf, false - } - queue.Add(getter) + // cliConn flush packets to svrConn with ShardQueue + queue := NewShardQueue(4, cliConn) + for i := 0; i < streams; i++ { + go func() { + var getter WriterGetter = func() (buf netpoll.Writer, isNil bool) { + buf = netpoll.NewLinkBuffer(framesize) + data, err := buf.Malloc(framesize) + MustNil(t, err) + for b := 0; b < framesize; b++ { + data[b] = 'a' + } + return buf, false + } + if queue.Add(getter) { + atomic.AddInt32(&send, int32(framesize)) + } + }() } + //cliConn graceful close, shardQueue should flush all data correctly + for atomic.LoadInt32(&send) < totalsize/2 { + t.Logf("waiting send all packets: send=%d", atomic.LoadInt32(&send)) + runtime.Gosched() + } err = queue.Close() MustNil(t, err) - total := count * pkgsize - recv := make([]byte, total) - rn, err := svrConn.Read(recv) - MustNil(t, err) - Equal(t, rn, total) -} - -// TODO: need mock flush -func BenchmarkShardQueue(b *testing.B) { - b.Skip() + for atomic.LoadInt32(&read) != atomic.LoadInt32(&send) { + t.Logf("waiting read all packets: read=%d", atomic.LoadInt32(&read)) + runtime.Gosched() + } }