Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: refactor: simplify shard queue #333

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 91 additions & 105 deletions mux/shard_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,93 +17,89 @@ package mux
import (
"fmt"
"runtime"
"sync"
"sync/atomic"

"github.com/bytedance/gopkg/util/gopool"

"github.com/cloudwego/netpoll"
)

/* DOC:
* ShardQueue uses the netpoll's nocopy API to merge and send data.
* The Data Flush is passively triggered by ShardQueue.Add and does not require user operations.
* If there is an error in the data transmission, the connection will be closed.
*
* ShardQueue.Add: add the data to be sent.
* NewShardQueue: create a queue with netpoll.Connection.
* ShardSize: the recommended number of shards is 32.
*/
var ShardSize int

func init() {
ShardSize = runtime.GOMAXPROCS(0)
}
// ShardQueue uses the netpoll nocopy Writer API to merge multi packets and send them at once.
// The Data Flush is passively triggered by ShardQueue.Add and does not require user operations.
// If there is an error in the data transmission, the connection will be closed.

// NewShardQueue .
func NewShardQueue(size int, conn netpoll.Connection) (queue *ShardQueue) {
// NewShardQueue create a queue with netpoll.Connection
func NewShardQueue(shardsize int, conn netpoll.Connection) (queue *ShardQueue) {
queue = &ShardQueue{
conn: conn,
size: int32(size),
getters: make([][]WriterGetter, size),
swap: make([]WriterGetter, 0, 64),
locks: make([]int32, size),
conn: conn,
shardsize: uint32(shardsize),
shards: make([][]WriterGetter, shardsize),
locks: make([]int32, shardsize),
}
for i := range queue.getters {
queue.getters[i] = make([]WriterGetter, 0, 64)
for i := range queue.shards {
queue.shards[i] = make([]WriterGetter, 0, 64)
}
queue.list = make([]int32, size)
queue.shard = make([]WriterGetter, 0, 64)
return queue
}

// WriterGetter is used to get a netpoll.Writer.
// Returning isNil == true signals that there is no buffer to write,
// in which case the getter's result is skipped by the queue.
type WriterGetter func() (buf netpoll.Writer, isNil bool)

// ShardQueue uses the netpoll's nocopy API to merge and send data.
// ShardQueue uses the netpoll's nocopy API to merge and send data.
// The Data Flush is passively triggered by ShardQueue.Add and does not require user operations.
// If there is an error in the data transmission, the connection will be closed.
// ShardQueue.Add: add the data to be sent.
type ShardQueue struct {
// state definition:
// active : only active state can allow user Add new task
// closing : ShardQueue.Close is called and try to close gracefully, cannot Add new data
// closed : Gracefully shutdown finished
state int32

conn netpoll.Connection
idx, size int32
getters [][]WriterGetter // len(getters) = size
swap []WriterGetter // use for swap
locks []int32 // len(locks) = size
queueTrigger
size uint32 // the size of all getters in all shards
shardsize uint32 // the size of shards
shards [][]WriterGetter // the shards of getters, len(shards) = shardsize
shard []WriterGetter // the shard is dealing, use shard to swap
locks []int32 // the locks of shards, len(locks) = shardsize
// trigger is used to avoid re-entering the triggering function twice.
// trigger == 0: nothing to do
// trigger == 1: we should start a new triggering()
// trigger >= 2: triggering() already started
trigger int32
}

const (
// queueTrigger state
// ShardQueue state
active = 0
closing = 1
closed = 2
)

// here for trigger
type queueTrigger struct {
trigger int32
state int32 // 0: active, 1: closing, 2: closed
runNum int32
w, r int32 // ptr of list
list []int32 // record the triggered shard
listLock sync.Mutex // list total lock
}
var idgen uint32

// Add adds to q.getters[shard]
func (q *ShardQueue) Add(gts ...WriterGetter) {
if atomic.LoadInt32(&q.state) != active {
return
// Add adds gts to ShardQueue
func (q *ShardQueue) Add(gts ...WriterGetter) bool {
size := uint32(len(gts))
if size == 0 || atomic.LoadInt32(&q.state) != active {
return false
}
shard := atomic.AddInt32(&q.idx, 1) % q.size
q.lock(shard)
trigger := len(q.getters[shard]) == 0
q.getters[shard] = append(q.getters[shard], gts...)
q.unlock(shard)
if trigger {
q.triggering(shard)

// get current shard id
shardid := atomic.AddUint32(&idgen, 1) % q.shardsize
// add new shards into shard
q.lock(shardid)
q.shards[shardid] = append(q.shards[shardid], gts...)
// size update must happen under the lock, so that once q.shards is unlocked the worker can observe the correct size
_ = atomic.AddUint32(&q.size, size)
q.unlock(shardid)

if atomic.AddInt32(&q.trigger, 1) == 1 {
go q.triggering(shardid)
}
return true
}

// Close graceful shutdown the ShardQueue and will flush all data added first
func (q *ShardQueue) Close() error {
if !atomic.CompareAndSwapInt32(&q.state, active, closing) {
return fmt.Errorf("shardQueue has been closed")
Expand All @@ -120,73 +116,63 @@ func (q *ShardQueue) Close() error {
}

// triggering shard.
func (q *ShardQueue) triggering(shard int32) {
q.listLock.Lock()
q.w = (q.w + 1) % q.size
q.list[q.w] = shard
q.listLock.Unlock()

if atomic.AddInt32(&q.trigger, 1) > 1 {
return
func (q *ShardQueue) triggering(shardid uint32) {
WORKER:
for atomic.LoadUint32(&q.size) > 0 {
// lock & shard
q.lock(shardid)
shard := q.shards[shardid]
q.shards[shardid] = q.shard[:0]
q.shard = shard[:0] // reuse current shard's space for next round
q.unlock(shardid)

if len(shard) > 0 {
// flush shard
q.deal(shard)
// only decrease q.size when the shard dealt
atomic.AddUint32(&q.size, -uint32(len(shard)))
}
// if there is any new data, the next shard must not be empty
shardid = (shardid + 1) % q.shardsize
}
q.foreach()
}

// foreach swap r & w. It's not concurrency safe.
func (q *ShardQueue) foreach() {
if atomic.AddInt32(&q.runNum, 1) > 1 {
return
// flush connection
q.flush()

// [IMPORTANT] Atomic Double Check:
// ShardQueue.Add ensures it always updates 'size' and 'trigger'.
// - If CAS(q.trigger, oldTrigger, 0) = true, it means there is no triggering() call during size check,
// so it's safe to exit triggering(). And any new Add() call will start triggering() successfully.
// - If CAS failed, there may have been a failed triggering() call between Load(q.trigger) and CAS(q.trigger),
// so we should re-check q.size again from beginning.
oldTrigger := atomic.LoadInt32(&q.trigger)
if atomic.LoadUint32(&q.size) > 0 {
goto WORKER
}
if !atomic.CompareAndSwapInt32(&q.trigger, oldTrigger, 0) {
goto WORKER
}
gopool.CtxGo(nil, func() {
var negNum int32 // is negative number of triggerNum
for triggerNum := atomic.LoadInt32(&q.trigger); triggerNum > 0; {
q.r = (q.r + 1) % q.size
shared := q.list[q.r]

// lock & swap
q.lock(shared)
tmp := q.getters[shared]
q.getters[shared] = q.swap[:0]
q.swap = tmp
q.unlock(shared)

// deal
q.deal(q.swap)
negNum--
if triggerNum+negNum == 0 {
triggerNum = atomic.AddInt32(&q.trigger, negNum)
negNum = 0
}
}
q.flush()

// quit & check again
atomic.StoreInt32(&q.runNum, 0)
if atomic.LoadInt32(&q.trigger) > 0 {
q.foreach()
return
}
// if state is closing, change it to closed
atomic.CompareAndSwapInt32(&q.state, closing, closed)
})
// if state is closing, change it to closed
atomic.CompareAndSwapInt32(&q.state, closing, closed)
return
}

// deal is used to get deal of netpoll.Writer.
// deal append all getters into connection
func (q *ShardQueue) deal(gts []WriterGetter) {
writer := q.conn.Writer()
for _, gt := range gts {
buf, isNil := gt()
if !isNil {
err := writer.Append(buf)
if err != nil {
if err != nil { // never happens
q.conn.Close()
return
}
}
}
}

// flush is used to flush netpoll.Writer.
// flush the connection and send all appended data
func (q *ShardQueue) flush() {
err := q.conn.Writer().Flush()
if err != nil {
Expand All @@ -196,13 +182,13 @@ func (q *ShardQueue) flush() {
}

// lock shard.
func (q *ShardQueue) lock(shard int32) {
func (q *ShardQueue) lock(shard uint32) {
for !atomic.CompareAndSwapInt32(&q.locks[shard], 0, 1) {
runtime.Gosched()
}
}

// unlock shard.
func (q *ShardQueue) unlock(shard int32) {
func (q *ShardQueue) unlock(shard uint32) {
atomic.StoreInt32(&q.locks[shard], 0)
}
78 changes: 53 additions & 25 deletions mux/shard_queue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
package mux

import (
"io"
"net"
"runtime"
"sync/atomic"
"testing"
"time"

Expand All @@ -27,18 +30,22 @@ import (

func TestShardQueue(t *testing.T) {
var svrConn net.Conn
var cliConn netpoll.Connection
accepted := make(chan struct{})
stopped := make(chan struct{})
streams, framesize := 128, 1024
totalsize := int32(streams * framesize)
var send, read int32

// create server connection
network, address := "tcp", ":18888"
ln, err := net.Listen("tcp", ":18888")
MustNil(t, err)
stop := make(chan int, 1)
defer close(stop)
go func() {
var err error
for {
select {
case <-stop:
case <-stopped:
err = ln.Close()
MustNil(t, err)
return
Expand All @@ -47,35 +54,56 @@ func TestShardQueue(t *testing.T) {
svrConn, err = ln.Accept()
MustNil(t, err)
accepted <- struct{}{}
go func() {
recv := make([]byte, 10240)
for {
n, err := svrConn.Read(recv)
atomic.AddInt32(&read, int32(n))
for i := 0; i < n; i++ {
MustTrue(t, recv[i] == 'a')
}
if err == io.EOF {
return
}
MustNil(t, err)
}
}()
}
}()

conn, err := netpoll.DialConnection(network, address, time.Second)
// create client connection
cliConn, err = netpoll.DialConnection(network, address, time.Second)
MustNil(t, err)
<-accepted
<-accepted // wait svrConn accepted

// test
queue := NewShardQueue(4, conn)
count, pkgsize := 16, 11
for i := 0; i < int(count); i++ {
var getter WriterGetter = func() (buf netpoll.Writer, isNil bool) {
buf = netpoll.NewLinkBuffer(pkgsize)
buf.Malloc(pkgsize)
return buf, false
}
queue.Add(getter)
// cliConn flush packets to svrConn with ShardQueue
queue := NewShardQueue(4, cliConn)
for i := 0; i < streams; i++ {
go func() {
var getter WriterGetter = func() (buf netpoll.Writer, isNil bool) {
buf = netpoll.NewLinkBuffer(framesize)
data, err := buf.Malloc(framesize)
MustNil(t, err)
for b := 0; b < framesize; b++ {
data[b] = 'a'
}
return buf, false
}
if queue.Add(getter) {
atomic.AddInt32(&send, int32(framesize))
}
}()
}

// cliConn closes gracefully; ShardQueue should flush all data correctly
for atomic.LoadInt32(&send) < totalsize/2 {
t.Logf("waiting send all packets: send=%d", atomic.LoadInt32(&send))
runtime.Gosched()
}
err = queue.Close()
MustNil(t, err)
total := count * pkgsize
recv := make([]byte, total)
rn, err := svrConn.Read(recv)
MustNil(t, err)
Equal(t, rn, total)
}

// TODO: need mock flush
func BenchmarkShardQueue(b *testing.B) {
b.Skip()
for atomic.LoadInt32(&read) != atomic.LoadInt32(&send) {
t.Logf("waiting read all packets: read=%d", atomic.LoadInt32(&read))
runtime.Gosched()
}
}
Loading