[WIP]fix: shard queue panic #248

Open
wants to merge 1 commit into base: develop
124 changes: 78 additions & 46 deletions mux/shard_queue.go
@@ -17,8 +17,8 @@ package mux
import (
"fmt"
"runtime"
"sync"
"sync/atomic"
"time"

"github.com/bytedance/gopkg/util/gopool"

@@ -43,16 +43,18 @@ func init() {
// NewShardQueue .
func NewShardQueue(size int, conn netpoll.Connection) (queue *ShardQueue) {
queue = &ShardQueue{
conn: conn,
size: int32(size),
getters: make([][]WriterGetter, size),
swap: make([]WriterGetter, 0, 64),
locks: make([]int32, size),
conn: conn,
size: int32(size),
getters: make([][]WriterGetter, size),
swap: make([]WriterGetter, 0, 64),
locks: make([]int32, size),
closeNotif: make(chan struct{}),
}
for i := range queue.getters {
queue.getters[i] = make([]WriterGetter, 0, 64)
}
queue.list = make([]int32, size)
// To avoid w becoming equal to r when the write index wraps around, make list one slot larger than size.
queue.list = make([]int32, size+1)
return queue
}

@@ -69,6 +71,8 @@ type ShardQueue struct {
getters [][]WriterGetter // len(getters) = size
swap []WriterGetter // use for swap
locks []int32 // len(locks) = size

closeNotif chan struct{}
queueTrigger
}

@@ -81,17 +85,29 @@ const (

// here for trigger
type queueTrigger struct {
trigger int32
state int32 // 0: active, 1: closing, 2: closed
runNum int32
w, r int32 // ptr of list
list []int32 // record the triggered shard
listLock sync.Mutex // list total lock
bufNum int32
state int32 // 0: active, 1: closing, 2: closed
runNum int32
w, r int32 // ptr of list
list []int32 // record the triggered shard
}

func (q *queueTrigger) length() int {
w := int(atomic.LoadInt32(&q.w))
r := int(atomic.LoadInt32(&q.r))
if w < r {
w += len(q.list)
}
return w - r
}
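
As an aside on the size+1 allocation and the length() helper above, here is a standalone ring-buffer sketch (illustrative only, local names are hypothetical): giving list one extra slot keeps w from ever catching up to r when the ring is full, so w == r always means empty and length() stays unambiguous.

```go
package main

import "fmt"

// Standalone model of the w/r ring used by queueTrigger: list has size+1
// slots so that w == r unambiguously means "empty" even when size entries
// are buffered.
func main() {
    const size = 4
    list := make([]int32, size+1) // one extra slot, as in the PR
    var w, r int32

    length := func() int {
        d := int(w) - int(r)
        if d < 0 {
            d += len(list) // same wraparound correction as length()
        }
        return d
    }

    // Push size shard indices; the ring is full but w never reaches r.
    for shard := int32(0); shard < size; shard++ {
        w = (w + 1) % int32(len(list))
        list[w] = shard
    }
    fmt.Println("length after filling:", length()) // 4

    // Pop everything back off.
    for length() > 0 {
        r = (r + 1) % int32(len(list))
        fmt.Println("popped shard", list[r])
    }
    fmt.Println("length after draining:", length()) // 0
}
```

This mirrors the PR's length(), which adds len(q.list) back when w has wrapped past r.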

// Add adds to q.getters[shard]
func (q *ShardQueue) Add(gts ...WriterGetter) {
atomic.AddInt32(&q.bufNum, 1)
if atomic.LoadInt32(&q.state) != active {
if atomic.AddInt32(&q.bufNum, -1) <= 0 {
close(q.closeNotif)
}
return
}
shard := atomic.AddInt32(&q.idx, 1) % q.size
@@ -109,90 +125,106 @@ func (q *ShardQueue) Close() error {
return fmt.Errorf("shardQueue has been closed")
}
// wait for all tasks finished
for atomic.LoadInt32(&q.state) != closed {
if atomic.LoadInt32(&q.trigger) == 0 {
if atomic.LoadInt32(&q.bufNum) == 0 {
atomic.StoreInt32(&q.state, closed)
} else {
timeout := time.NewTimer(3 * time.Second)
select {
case <-q.closeNotif:
atomic.StoreInt32(&q.state, closed)
return nil
timeout.Stop()
case <-timeout.C:
atomic.StoreInt32(&q.state, closed)
return fmt.Errorf("shardQueue close timeout")
}
runtime.Gosched()
}
return nil
}

// triggering shard.
func (q *ShardQueue) triggering(shard int32) {
q.listLock.Lock()
q.w = (q.w + 1) % q.size
q.list[q.w] = shard
q.listLock.Unlock()

if atomic.AddInt32(&q.trigger, 1) > 1 {
return
for {
ow := atomic.LoadInt32(&q.w)
nw := (ow + 1) % int32(len(q.list))
if atomic.CompareAndSwapInt32(&q.w, ow, nw) {
q.list[nw] = shard
break
}
}
q.foreach()
}

// foreach swap r & w. It's not concurrency safe.
// foreach swap r & w.
func (q *ShardQueue) foreach() {
if atomic.AddInt32(&q.runNum, 1) > 1 {
return
}
gopool.CtxGo(nil, func() {
var negNum int32 // is negative number of triggerNum
for triggerNum := atomic.LoadInt32(&q.trigger); triggerNum > 0; {
q.r = (q.r + 1) % q.size
shared := q.list[q.r]
var negBufNum int32 // negated count of drained getters, to be applied to bufNum below
for q.length() > 0 {
nr := (atomic.LoadInt32(&q.r) + 1) % int32(len(q.list))
atomic.StoreInt32(&q.r, nr)
shard := q.list[nr]

// lock & swap
q.lock(shared)
tmp := q.getters[shared]
q.getters[shared] = q.swap[:0]
q.lock(shard)
tmp := q.getters[shard]
q.getters[shard] = q.swap[:0]
q.swap = tmp
q.unlock(shared)
q.unlock(shard)

// deal
q.deal(q.swap)
negNum--
if triggerNum+negNum == 0 {
triggerNum = atomic.AddInt32(&q.trigger, negNum)
negNum = 0
if err := q.deal(q.swap); err != nil {
close(q.closeNotif)
return
}
negBufNum -= int32(len(q.swap))

Why is it -= len(q.swap) here? Each time bufNum is incremented, the corresponding shard does not necessarily have only one getter written into it, so it seems that negBufNum != -bufNum?

Also, may I ask whether this PR will continue to be maintained?
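
To make the question above concrete, here is a minimal counting sketch (standalone toy code with hypothetical names, not the PR's implementation): Add increments bufNum once per call, while the drain loop subtracts len(q.swap), which counts individual getters, so the two counters can diverge whenever a single Add carries more than one getter.

```go
package main

import "fmt"

// Toy model of the counting mismatch the reviewer is asking about: bufNum is
// bumped once per Add call, but the drain step subtracts the number of
// getters swapped out, which may be larger.
func main() {
    var bufNum int32

    // Two simulated Add calls: the first carries 1 getter, the second 3.
    addCalls := [][]string{{"g1"}, {"g2", "g3", "g4"}}
    var pending []string
    for _, gts := range addCalls {
        bufNum++ // Add does atomic.AddInt32(&q.bufNum, 1) once per call
        pending = append(pending, gts...)
    }

    // foreach drains all pending getters in one swap and subtracts len(swap).
    bufNum -= int32(len(pending))

    fmt.Printf("bufNum after drain = %d (expected 0)\n", bufNum) // prints -2
}
```

Running it prints bufNum after drain = -2, which is the imbalance the question points at; the two counts only agree if every Add call carries exactly one getter.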

}
if negBufNum < 0 {
if err := q.flush(); err != nil {
close(q.closeNotif)
return
}
}

// MUST decrease bufNum first.
if atomic.AddInt32(&q.bufNum, negBufNum) <= 0 && atomic.LoadInt32(&q.state) != active {
Contributor: Why does this also need to be checked here?
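
One plausible reading of why the state check accompanies the bufNum check (a hedged, standalone sketch with hypothetical names, not taken from the PR): during normal operation the queue regularly drains down to bufNum == 0, and closing closeNotif at that point would signal Close() prematurely and make any later close panic with "close of closed channel".

```go
package main

import "fmt"

// Toy model of the guard in foreach: closeNotif should only be closed once
// the queue has left the active state; otherwise an ordinary idle drain
// would already close it.
func main() {
    const (
        active  = 0
        closing = 1
    )
    closeNotif := make(chan struct{})
    bufNum := 0 // all buffered getters have been drained
    state := active

    maybeNotify := func() {
        // Mirrors: if bufNum <= 0 && state != active { close(closeNotif) }
        if bufNum <= 0 && state != active {
            close(closeNotif)
        }
    }

    maybeNotify() // queue is idle but still active: the channel stays open
    state = closing
    maybeNotify() // now Close() can be signalled, exactly once

    <-closeNotif
    fmt.Println("closeNotif closed only after the state left active")
}
```

With the guard in place the channel is closed exactly once, and only after Close() has already moved the state away from active.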

close(q.closeNotif)
return
}
q.flush()

// quit & check again
atomic.StoreInt32(&q.runNum, 0)
if atomic.LoadInt32(&q.trigger) > 0 {
if q.length() > 0 {
q.foreach()
return
}
// if state is closing, change it to closed
atomic.CompareAndSwapInt32(&q.state, closing, closed)
})
}

// deal is used to get deal of netpoll.Writer.
func (q *ShardQueue) deal(gts []WriterGetter) {
func (q *ShardQueue) deal(gts []WriterGetter) error {
writer := q.conn.Writer()
for _, gt := range gts {
buf, isNil := gt()
if !isNil {
err := writer.Append(buf)
if err != nil {
q.conn.Close()
return
return err
}
}
}
return nil
}

// flush is used to flush netpoll.Writer.
func (q *ShardQueue) flush() {
func (q *ShardQueue) flush() error {
err := q.conn.Writer().Flush()
if err != nil {
q.conn.Close()
return
}
return err
}

// lock shard.