diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index add69d1e..bff66b1b 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -39,23 +39,23 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.16 + go-version: "1.20" - - uses: actions/cache@v2 - with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go- +# - uses: actions/cache@v2 +# with: +# path: ~/go/pkg/mod +# key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} +# restore-keys: | +# ${{ runner.os }}-go- # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v1 + uses: github/codeql-action/init@v2 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -66,7 +66,7 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild - uses: github/codeql-action/autobuild@v1 + uses: github/codeql-action/autobuild@v2 # ℹ️ Command-line programs to run using the OS shell. # 📚 https://git.io/JvXDl @@ -80,4 +80,4 @@ jobs: # make release - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v1 + uses: github/codeql-action/analyze@v2 diff --git a/.github/workflows/pr-check.yml b/.github/workflows/pr-check.yml index abd6cf13..a923b418 100644 --- a/.github/workflows/pr-check.yml +++ b/.github/workflows/pr-check.yml @@ -6,7 +6,7 @@ jobs: compatibility-test: strategy: matrix: - go: [ 1.15, "1.20" ] + go: [ 1.15, "1.21" ] os: [ X64, ARM64 ] runs-on: ${{ matrix.os }} steps: @@ -15,12 +15,12 @@ jobs: uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - - uses: actions/cache@v2 - with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go- +# - uses: actions/cache@v2 +# with: +# path: ~/go/pkg/mod +# key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} +# restore-keys: | +# ${{ runner.os }}-go- - name: Unit Test run: go test -v -race -covermode=atomic -coverprofile=coverage.out ./... - name: Benchmark @@ -33,12 +33,12 @@ jobs: uses: actions/setup-go@v3 with: go-version: "1.20" - - uses: actions/cache@v2 - with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go- +# - uses: actions/cache@v2 +# with: +# path: ~/go/pkg/mod +# key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} +# restore-keys: | +# ${{ runner.os }}-go- - name: Build Test run: go vet -v ./... style-test: diff --git a/README.md b/README.md index 5a8662d5..340cad09 100644 --- a/README.md +++ b/README.md @@ -54,10 +54,8 @@ For more information, please refer to [Document](#document). 
- Linux, macOS (operating system) * **Future** - - [multisyscall][multisyscall] supports batch system calls - [io_uring][io_uring] - Shared Memory IPC - - Serial scheduling I/O, suitable for pure computing - TLS - UDP @@ -102,5 +100,4 @@ More benchmarks reference [kitex-benchmark][kitex-benchmark] and [hertz-benchmar [LinkBuffer]: nocopy_linkbuffer.go [gopool]: https://github.com/bytedance/gopkg/tree/develop/util/gopool [mcache]: https://github.com/bytedance/gopkg/tree/develop/lang/mcache -[multisyscall]: https://github.com/cloudwego/multisyscall [io_uring]: https://github.com/axboe/liburing diff --git a/README_CN.md b/README_CN.md index 8e8d0bac..376691ec 100644 --- a/README_CN.md +++ b/README_CN.md @@ -49,10 +49,8 @@ goroutine,大幅增加调度开销。此外,[net.Conn][net.Conn] 没有提 - 支持 Linux,macOS(操作系统) * **即将开源** - - [multisyscall][multisyscall] 支持批量系统调用 - [io_uring][io_uring] - Shared Memory IPC - - 串行调度 I/O,适用于纯计算 - 支持 TLS - 支持 UDP @@ -95,5 +93,4 @@ goroutine,大幅增加调度开销。此外,[net.Conn][net.Conn] 没有提 [LinkBuffer]: nocopy_linkbuffer.go [gopool]: https://github.com/bytedance/gopkg/tree/develop/util/gopool [mcache]: https://github.com/bytedance/gopkg/tree/develop/lang/mcache -[multisyscall]: https://github.com/cloudwego/multisyscall [io_uring]: https://github.com/axboe/liburing diff --git a/connection_impl.go b/connection_impl.go index def1d97c..1fa1a8e4 100644 --- a/connection_impl.go +++ b/connection_impl.go @@ -36,7 +36,7 @@ type connection struct { operator *FDOperator readTimeout time.Duration readTimer *time.Timer - readTrigger chan struct{} + readTrigger chan error waitReadSize int64 writeTimeout time.Duration writeTimer *time.Timer @@ -319,9 +319,9 @@ var barrierPool = sync.Pool{ // init initialize the connection with options func (c *connection) init(conn Conn, opts *options) (err error) { // init buffer, barrier, finalizer - c.readTrigger = make(chan struct{}, 1) + c.readTrigger = make(chan error, 1) c.writeTrigger = make(chan error, 1) - c.bookSize, c.maxSize = block1k/2, pagesize + c.bookSize, c.maxSize = pagesize, pagesize c.inputBuffer, c.outputBuffer = NewLinkBuffer(pagesize), NewLinkBuffer() c.inputBarrier, c.outputBarrier = barrierPool.Get().(*barrier), barrierPool.Get().(*barrier) @@ -357,19 +357,12 @@ func (c *connection) initNetFD(conn Conn) { } func (c *connection) initFDOperator() { - var op *FDOperator - if c.pd != nil && c.pd.operator != nil { - // reuse operator created at connect step - op = c.pd.operator - } else { - poll := pollmanager.Pick() - op = poll.Alloc() - } + poll := pollmanager.Pick() + op := poll.Alloc() op.FD = c.fd op.OnRead, op.OnWrite, op.OnHup = nil, nil, c.onHup op.Inputs, op.InputAck = c.inputs, c.inputAck op.Outputs, op.OutputAck = c.outputs, c.outputAck - c.operator = op } @@ -385,9 +378,9 @@ func (c *connection) initFinalizer() { }) } -func (c *connection) triggerRead() { +func (c *connection) triggerRead(err error) { select { - case c.readTrigger <- struct{}{}: + case c.readTrigger <- err: default: } } @@ -411,10 +404,17 @@ func (c *connection) waitRead(n int) (err error) { } // wait full n for c.inputBuffer.Len() < n { - if !c.IsActive() { + switch c.status(closing) { + case poller: + return Exception(ErrEOF, "wait read") + case user: return Exception(ErrConnClosed, "wait read") + default: + err = <-c.readTrigger + if err != nil { + return err + } } - <-c.readTrigger } return nil } @@ -429,24 +429,32 @@ func (c *connection) waitReadWithTimeout(n int) (err error) { } for c.inputBuffer.Len() < n { - if !c.IsActive() { - // cannot return directly, stop timer 
before ! + switch c.status(closing) { + case poller: + // cannot return directly, stop timer first! + err = Exception(ErrEOF, "wait read") + goto RET + case user: + // cannot return directly, stop timer first! err = Exception(ErrConnClosed, "wait read") - break - } - - select { - case <-c.readTimer.C: - // double check if there is enough data to be read - if c.inputBuffer.Len() >= n { - return nil + goto RET + default: + select { + case <-c.readTimer.C: + // double check if there is enough data to be read + if c.inputBuffer.Len() >= n { + return nil + } + return Exception(ErrReadTimeout, c.remoteAddr.String()) + case err = <-c.readTrigger: + if err != nil { + return err + } + continue } - return Exception(ErrReadTimeout, c.remoteAddr.String()) - case <-c.readTrigger: - continue } } - +RET: // clean timer.C if !c.readTimer.Stop() { <-c.readTimer.C diff --git a/connection_lock.go b/connection_lock.go index 2dce6622..4b0f7360 100644 --- a/connection_lock.go +++ b/connection_lock.go @@ -19,7 +19,7 @@ import ( "sync/atomic" ) -type who int32 +type who = int32 const ( none who = iota @@ -65,6 +65,14 @@ func (l *locker) isCloseBy(w who) (yes bool) { return atomic.LoadInt32(&l.keychain[closing]) == int32(w) } +func (l *locker) status(k key) int32 { + return atomic.LoadInt32(&l.keychain[k]) +} + +func (l *locker) force(k key, v int32) { + atomic.StoreInt32(&l.keychain[k], v) +} + func (l *locker) lock(k key) (success bool) { return atomic.CompareAndSwapInt32(&l.keychain[k], 0, 1) } diff --git a/connection_onevent.go b/connection_onevent.go index f8351f32..9b87f01b 100644 --- a/connection_onevent.go +++ b/connection_onevent.go @@ -177,18 +177,45 @@ func (c *connection) onProcess(isProcessable func(c *connection) bool, process f } // add new task var task = func() { + panicked := true + defer func() { + // cannot use recover() here, since we don't want to break the panic stack + if panicked { + c.unlock(processing) + if c.IsActive() { + c.Close() + } else { + c.closeCallback(false, false) + } + } + }() START: // `process` must be executed at least once if `isProcessable` in order to cover the `send & close by peer` case. // Then the loop processing must ensure that the connection `IsActive`. if isProcessable(c) { process(c) } - for c.IsActive() && isProcessable(c) { + // `process` must either eventually read all the input data or actively Close the connection, + // otherwise the goroutine will fall into a dead loop. + var closedBy who + for { + closedBy = c.status(closing) + // close by user or no processable + if closedBy == user || !isProcessable(c) { + break + } process(c) } // Handling callback if connection has been closed. - if !c.IsActive() { - c.closeCallback(false) + if closedBy != none { + // if closed by user when processing, it "may" needs detach + needDetach := closedBy == user + // Here is a conor case that operator will be detached twice: + // If server closed the connection(client OnHup will detach op first and closeBy=poller), + // and then client's OnRequest function also closed the connection(closeBy=user). + // But operator already prevent that detach twice will not cause any problem + c.closeCallback(false, needDetach) + panicked = false return } c.unlock(processing) @@ -197,6 +224,7 @@ func (c *connection) onProcess(isProcessable func(c *connection) bool, process f goto START } // task exits + panicked = false return } @@ -207,14 +235,14 @@ func (c *connection) onProcess(isProcessable func(c *connection) bool, process f // closeCallback . 
// It can be confirmed that closeCallback and onRequest will not be executed concurrently. // If onRequest is still running, it will trigger closeCallback on exit. -func (c *connection) closeCallback(needLock bool) (err error) { +func (c *connection) closeCallback(needLock bool, needDetach bool) (err error) { if needLock && !c.lock(processing) { return nil } - // If Close is called during OnPrepare, poll is not registered. - if c.isCloseBy(user) && c.operator.poll != nil { - if err = c.operator.Control(PollDetach); err != nil { - logger.Printf("NETPOLL: closeCallback detach operator failed: %v", err) + if needDetach && c.operator.poll != nil { // If Close is called during OnPrepare, poll is not registered. + // PollDetach only happen when user call conn.Close() or poller detect error + if err := c.operator.Control(PollDetach); err != nil { + logger.Printf("NETPOLL: closeCallback[%v,%v] detach operator failed: %v", needLock, needDetach, err) } } var latest = c.closeCallbacks.Load() @@ -229,14 +257,7 @@ func (c *connection) closeCallback(needLock bool) (err error) { // register only use for connection register into poll. func (c *connection) register() (err error) { - if c.operator.isUnused() { - // operator is not registered - err = c.operator.Control(PollReadable) - } else { - // operator is already registered - // change event to wait read new data - err = c.operator.Control(PollModReadable) - } + err = c.operator.Control(PollReadable) if err != nil { logger.Printf("NETPOLL: connection register failed: %v", err) c.Close() diff --git a/connection_reactor.go b/connection_reactor.go index 65621b93..fa485be1 100644 --- a/connection_reactor.go +++ b/connection_reactor.go @@ -25,35 +25,42 @@ import ( // onHup means close by poller. func (c *connection) onHup(p Poll) error { - if c.closeBy(poller) { - c.triggerRead() - c.triggerWrite(ErrConnClosed) - // It depends on closing by user if OnConnect and OnRequest is nil, otherwise it needs to be released actively. - // It can be confirmed that the OnRequest goroutine has been exited before closecallback executing, - // and it is safe to close the buffer at this time. - var onConnect, _ = c.onConnectCallback.Load().(OnConnect) - var onRequest, _ = c.onRequestCallback.Load().(OnRequest) - if onConnect != nil || onRequest != nil { - c.closeCallback(true) - } + if !c.closeBy(poller) { + return nil + } + c.triggerRead(Exception(ErrEOF, "peer close")) + c.triggerWrite(Exception(ErrConnClosed, "peer close")) + // It depends on closing by user if OnConnect and OnRequest is nil, otherwise it needs to be released actively. + // It can be confirmed that the OnRequest goroutine has been exited before closeCallback executing, + // and it is safe to close the buffer at this time. + var onConnect = c.onConnectCallback.Load() + var onRequest = c.onRequestCallback.Load() + var needCloseByUser = onConnect == nil && onRequest == nil + if !needCloseByUser { + // already PollDetach when call OnHup + c.closeCallback(true, false) } return nil } // onClose means close by user. 
func (c *connection) onClose() error { + // user code close the connection if c.closeBy(user) { - c.triggerRead() - c.triggerWrite(ErrConnClosed) - c.closeCallback(true) + c.triggerRead(Exception(ErrConnClosed, "self close")) + c.triggerWrite(Exception(ErrConnClosed, "self close")) + // Detach from poller when processing finished, otherwise it will cause race + c.closeCallback(true, true) return nil } - if c.isCloseBy(poller) { - // Connection with OnRequest of nil - // relies on the user to actively close the connection to recycle resources. - c.closeCallback(true) - } - return nil + + // closed by poller + // still need to change closing status to `user` since OnProcess should not be processed again + c.force(closing, user) + + // user code should actively close the connection to recycle resources. + // poller already detached operator + return c.closeCallback(true, false) } // closeBuffer recycle input & output LinkBuffer. @@ -103,7 +110,7 @@ func (c *connection) inputAck(n int) (err error) { needTrigger = c.onRequest() } if needTrigger && length >= int(atomic.LoadInt64(&c.waitReadSize)) { - c.triggerRead() + c.triggerRead(nil) } return nil } diff --git a/connection_test.go b/connection_test.go index 3d8fe160..6de6f017 100644 --- a/connection_test.go +++ b/connection_test.go @@ -211,7 +211,7 @@ func writeAll(fd int, buf []byte) error { // Large packet write test. The socket buffer is 2MB by default, here to verify // whether Connection.Close can be executed normally after socket output buffer is full. func TestLargeBufferWrite(t *testing.T) { - ln, err := CreateListener("tcp", ":1234") + ln, err := createTestListener("tcp", ":12345") MustNil(t, err) trigger := make(chan int) @@ -230,40 +230,43 @@ func TestLargeBufferWrite(t *testing.T) { } }() - conn, err := DialConnection("tcp", ":1234", time.Second) + conn, err := DialConnection("tcp", ":12345", time.Second) MustNil(t, err) rfd := <-trigger var wg sync.WaitGroup wg.Add(1) - bufferSize := 2 * 1024 * 1024 + bufferSize := 2 * 1024 * 1024 // 2MB + round := 128 //start large buffer writing go func() { defer wg.Done() - for i := 0; i < 129; i++ { + for i := 1; i <= round+1; i++ { _, err := conn.Writer().Malloc(bufferSize) MustNil(t, err) err = conn.Writer().Flush() - if i < 128 { + if i <= round { MustNil(t, err) } } }() - time.Sleep(time.Millisecond * 50) + // wait socket buffer full + time.Sleep(time.Millisecond * 100) buf := make([]byte, 1024) - for i := 0; i < 128*bufferSize/1024; i++ { - _, err := syscall.Read(rfd, buf) - MustNil(t, err) + for received := 0; received < round*bufferSize; { + n, _ := syscall.Read(rfd, buf) + received += n } // close success err = conn.Close() MustNil(t, err) wg.Wait() + trigger <- 1 } func TestWriteTimeout(t *testing.T) { - ln, err := CreateListener("tcp", ":1234") + ln, err := createTestListener("tcp", ":1234") MustNil(t, err) interval := time.Millisecond * 100 @@ -397,7 +400,7 @@ func TestConnectionUntil(t *testing.T) { buf, err := rconn.Reader().Until('\n') Equal(t, len(buf), 100) - Assert(t, errors.Is(err, ErrConnClosed), err) + Assert(t, errors.Is(err, ErrEOF), err) } func TestBookSizeLargerThanMaxSize(t *testing.T) { @@ -436,7 +439,7 @@ func TestBookSizeLargerThanMaxSize(t *testing.T) { } func TestConnDetach(t *testing.T) { - ln, err := CreateListener("tcp", ":1234") + ln, err := createTestListener("tcp", ":1234") MustNil(t, err) go func() { @@ -491,3 +494,147 @@ func TestConnDetach(t *testing.T) { err = ln.Close() MustNil(t, err) } + +func TestParallelShortConnection(t *testing.T) { + ln, err := 
createTestListener("tcp", ":12345") + MustNil(t, err) + defer ln.Close() + + var received int64 + el, err := NewEventLoop(func(ctx context.Context, connection Connection) error { + data, err := connection.Reader().Next(connection.Reader().Len()) + if err != nil { + return err + } + atomic.AddInt64(&received, int64(len(data))) + //t.Logf("conn[%s] received: %d, active: %v", connection.RemoteAddr(), len(data), connection.IsActive()) + return nil + }) + go func() { + el.Serve(ln) + }() + + conns := 100 + sizePerConn := 1024 * 100 + totalSize := conns * sizePerConn + var wg sync.WaitGroup + for i := 0; i < conns; i++ { + wg.Add(1) + go func() { + defer wg.Done() + conn, err := DialConnection("tcp", ":12345", time.Second) + MustNil(t, err) + n, err := conn.Writer().WriteBinary(make([]byte, sizePerConn)) + MustNil(t, err) + MustTrue(t, n == sizePerConn) + err = conn.Writer().Flush() + MustNil(t, err) + err = conn.Close() + MustNil(t, err) + }() + } + wg.Wait() + + for atomic.LoadInt64(&received) < int64(totalSize) { + t.Logf("received: %d, except: %d", atomic.LoadInt64(&received), totalSize) + time.Sleep(time.Millisecond * 100) + } +} + +func TestConnectionServerClose(t *testing.T) { + ln, err := createTestListener("tcp", ":12345") + MustNil(t, err) + defer ln.Close() + + /* + Client Server + - Client --- connect --> Server + - Client <-- [ping] --- Server + - Client --- [pong] --> Server + - Client <-- close --- Server + - Client --- close --> Server + */ + const PING, PONG = "ping", "pong" + var wg sync.WaitGroup + el, err := NewEventLoop( + func(ctx context.Context, connection Connection) error { + t.Logf("server.OnRequest: addr=%s", connection.RemoteAddr()) + defer wg.Done() + buf, err := connection.Reader().Next(len(PONG)) // pong + Equal(t, string(buf), PONG) + MustNil(t, err) + err = connection.Reader().Release() + MustNil(t, err) + err = connection.Close() + MustNil(t, err) + return err + }, + WithOnConnect(func(ctx context.Context, connection Connection) context.Context { + t.Logf("server.OnConnect: addr=%s", connection.RemoteAddr()) + defer wg.Done() + // check OnPrepare + v := ctx.Value("prepare").(string) + Equal(t, v, "true") + + _, err := connection.Writer().WriteBinary([]byte(PING)) + MustNil(t, err) + err = connection.Writer().Flush() + MustNil(t, err) + connection.AddCloseCallback(func(connection Connection) error { + t.Logf("server.CloseCallback: addr=%s", connection.RemoteAddr()) + wg.Done() + return nil + }) + return ctx + }), + WithOnPrepare(func(connection Connection) context.Context { + t.Logf("server.OnPrepare: addr=%s", connection.RemoteAddr()) + defer wg.Done() + return context.WithValue(context.Background(), "prepare", "true") + }), + ) + defer el.Shutdown(context.Background()) + go func() { + err := el.Serve(ln) + if err != nil { + t.Logf("servce end with error: %v", err) + } + }() + + var clientOnRequest OnRequest = func(ctx context.Context, connection Connection) error { + t.Logf("client.OnRequest: addr=%s", connection.LocalAddr()) + defer wg.Done() + buf, err := connection.Reader().Next(len(PING)) + MustNil(t, err) + Equal(t, string(buf), PING) + + _, err = connection.Writer().WriteBinary([]byte(PONG)) + MustNil(t, err) + err = connection.Writer().Flush() + MustNil(t, err) + + _, err = connection.Reader().Next(1) // server will not send any data, just wait for server close + MustTrue(t, errors.Is(err, ErrEOF)) // should get EOF when server close + + return connection.Close() + } + conns := 100 + // server: OnPrepare, OnConnect, OnRequest, CloseCallback + // client: 
OnRequest, CloseCallback + wg.Add(conns * 6) + for i := 0; i < conns; i++ { + go func() { + conn, err := DialConnection("tcp", ":12345", time.Second) + MustNil(t, err) + err = conn.SetOnRequest(clientOnRequest) + MustNil(t, err) + conn.AddCloseCallback(func(connection Connection) error { + t.Logf("client.CloseCallback: addr=%s", connection.LocalAddr()) + defer wg.Done() + return nil + }) + }() + } + //time.Sleep(time.Second) + wg.Wait() +} diff --git a/fd_operator.go b/fd_operator.go index 4132fe9c..1ac843a9 100644 --- a/fd_operator.go +++ b/fd_operator.go @@ -42,6 +42,9 @@ type FDOperator struct { // poll is the registered location of the file descriptor. poll Poll + // protect only detach once + detached int32 + // private, used by operatorCache next *FDOperator state int32 // CAS: 0(unused) 1(inuse) 2(do-done) @@ -49,6 +52,9 @@ type FDOperator struct { } func (op *FDOperator) Control(event PollEvent) error { + if event == PollDetach && atomic.AddInt32(&op.detached, 1) > 1 { + return nil + } return op.poll.Control(op, event) } @@ -92,4 +98,5 @@ func (op *FDOperator) reset() { op.Inputs, op.InputAck = nil, nil op.Outputs, op.OutputAck = nil, nil op.poll = nil + op.detached = 0 } diff --git a/net_dialer_test.go b/net_dialer_test.go index 3d08ed89..7383fd0d 100644 --- a/net_dialer_test.go +++ b/net_dialer_test.go @@ -167,7 +167,7 @@ func TestFDClose(t *testing.T) { // fd data package race test, use two servers and two dialers. func TestDialerThenClose(t *testing.T) { // server 1 - ln1, _ := CreateListener("tcp", ":1231") + ln1, _ := createTestListener("tcp", ":1231") el1 := mockDialerEventLoop(1) go func() { el1.Serve(ln1) @@ -177,7 +177,7 @@ func TestDialerThenClose(t *testing.T) { defer el1.Shutdown(ctx1) // server 2 - ln2, _ := CreateListener("tcp", ":1232") + ln2, _ := createTestListener("tcp", ":1232") el2 := mockDialerEventLoop(2) go func() { el2.Serve(ln2) diff --git a/net_polldesc.go b/net_polldesc.go index 0b78c653..dfd95de1 100644 --- a/net_polldesc.go +++ b/net_polldesc.go @@ -21,16 +21,15 @@ import ( "context" ) -// TODO: recycle *pollDesc func newPollDesc(fd int) *pollDesc { pd := &pollDesc{} poll := pollmanager.Pick() - op := poll.Alloc() - op.FD = fd - op.OnWrite = pd.onwrite - op.OnHup = pd.onhup - - pd.operator = op + pd.operator = &FDOperator{ + poll: poll, + FD: fd, + OnWrite: pd.onwrite, + OnHup: pd.onhup, + } pd.writeTrigger = make(chan struct{}) pd.closeTrigger = make(chan struct{}) return pd @@ -45,13 +44,6 @@ type pollDesc struct { // WaitWrite . 
func (pd *pollDesc) WaitWrite(ctx context.Context) (err error) { - defer func() { - // if return err != nil, upper caller function will close the connection - if err != nil { - pd.operator.Free() - } - }() - if pd.operator.isUnused() { // add ET|Write|Hup if err = pd.operator.Control(PollWritable); err != nil { @@ -84,6 +76,7 @@ func (pd *pollDesc) onwrite(p Poll) error { select { case <-pd.writeTrigger: default: + pd.detach() close(pd.writeTrigger) } return nil diff --git a/netpoll_test.go b/netpoll_test.go index cedf6226..0467e879 100644 --- a/netpoll_test.go +++ b/netpoll_test.go @@ -251,6 +251,41 @@ func TestCloseCallbackWhenOnConnect(t *testing.T) { MustNil(t, err) } +func TestCloseConnWhenOnConnect(t *testing.T) { + var network, address = "tcp", ":8888" + conns := 10 + var wg sync.WaitGroup + wg.Add(conns) + var loop = newTestEventLoop(network, address, + nil, + WithOnConnect(func(ctx context.Context, connection Connection) context.Context { + defer wg.Done() + err := connection.Close() + MustNil(t, err) + return ctx + }), + ) + + for i := 0; i < conns; i++ { + wg.Add(1) + go func() { + defer wg.Done() + var conn, err = DialConnection(network, address, time.Second) + if err != nil { + return + } + _, err = conn.Reader().Next(1) + Assert(t, errors.Is(err, ErrEOF)) + err = conn.Close() + MustNil(t, err) + }() + } + + wg.Wait() + err := loop.Shutdown(context.Background()) + MustNil(t, err) +} + func TestServerReadAndClose(t *testing.T) { var network, address = "tcp", ":18888" var sendMsg = []byte("hello") @@ -287,6 +322,37 @@ func TestServerReadAndClose(t *testing.T) { MustNil(t, err) } +func TestServerPanicAndClose(t *testing.T) { + var network, address = "tcp", ":18888" + var sendMsg = []byte("hello") + var paniced int32 + var loop = newTestEventLoop(network, address, + func(ctx context.Context, connection Connection) error { + _, err := connection.Reader().Next(len(sendMsg)) + MustNil(t, err) + atomic.StoreInt32(&paniced, 1) + panic("test") + }, + ) + + var conn, err = DialConnection(network, address, time.Second) + MustNil(t, err) + _, err = conn.Writer().WriteBinary(sendMsg) + MustNil(t, err) + err = conn.Writer().Flush() + MustNil(t, err) + + for atomic.LoadInt32(&paniced) == 0 { + runtime.Gosched() // wait for poller close connection + } + for conn.IsActive() { + runtime.Gosched() // wait for poller close connection + } + + err = loop.Shutdown(context.Background()) + MustNil(t, err) +} + func TestClientWriteAndClose(t *testing.T) { var ( network, address = "tcp", ":18889" @@ -331,8 +397,18 @@ func TestClientWriteAndClose(t *testing.T) { MustNil(t, err) } +func createTestListener(network, address string) (Listener, error) { + for { + ln, err := CreateListener(network, address) + if err == nil { + return ln, nil + } + time.Sleep(time.Millisecond * 100) + } +} + func newTestEventLoop(network, address string, onRequest OnRequest, opts ...Option) EventLoop { - ln, err := CreateListener(network, address) + ln, err := createTestListener(network, address) if err != nil { panic(err) } diff --git a/nocopy_linkbuffer.go b/nocopy_linkbuffer.go index 59cc6530..555ba5ce 100644 --- a/nocopy_linkbuffer.go +++ b/nocopy_linkbuffer.go @@ -461,9 +461,8 @@ func (b *LinkBuffer) WriteBinary(p []byte) (n int, err error) { } // here will copy b.growth(n) - malloc := b.write.malloc - b.write.malloc += n - return copy(b.write.buf[malloc:b.write.malloc], p), nil + buf := b.write.Malloc(n) + return copy(buf, p), nil } // WriteDirect cannot be mixed with WriteString or WriteBinary functions. 
@@ -578,7 +577,8 @@ func (b *LinkBuffer) GetBytes(p [][]byte) (vs [][]byte) { // // bookSize: The size of data that can be read at once. // maxSize: The maximum size of data between two Release(). In some cases, this can -// guarantee all data allocated in one node to reduce copy. +// +// guarantee all data allocated in one node to reduce copy. func (b *LinkBuffer) book(bookSize, maxSize int) (p []byte) { l := cap(b.write.buf) - b.write.malloc // grow linkBuffer diff --git a/nocopy_linkbuffer_race.go b/nocopy_linkbuffer_race.go index a785aa15..4b3635d0 100644 --- a/nocopy_linkbuffer_race.go +++ b/nocopy_linkbuffer_race.go @@ -497,9 +497,8 @@ func (b *LinkBuffer) WriteBinary(p []byte) (n int, err error) { } // here will copy b.growth(n) - malloc := b.write.malloc - b.write.malloc += n - return copy(b.write.buf[malloc:b.write.malloc], p), nil + buf := b.write.Malloc(n) + return copy(buf, p), nil } // WriteDirect cannot be mixed with WriteString or WriteBinary functions. @@ -622,7 +621,8 @@ func (b *LinkBuffer) GetBytes(p [][]byte) (vs [][]byte) { // // bookSize: The size of data that can be read at once. // maxSize: The maximum size of data between two Release(). In some cases, this can -// guarantee all data allocated in one node to reduce copy. +// +// guarantee all data allocated in one node to reduce copy. func (b *LinkBuffer) book(bookSize, maxSize int) (p []byte) { b.Lock() defer b.Unlock() diff --git a/poll.go b/poll.go index 1d5c42fb..c494ffd6 100644 --- a/poll.go +++ b/poll.go @@ -57,10 +57,6 @@ const ( // PollDetach is used to remove the FDOperator from poll. PollDetach PollEvent = 0x3 - // PollModReadable is used to re-register the readable monitor for the FDOperator created by the dialer. - // It is only used when calling the dialer's conn init. - PollModReadable PollEvent = 0x4 - // PollR2RW is used to monitor writable for FDOperator, // which is only called when the socket write buffer is full. 
PollR2RW PollEvent = 0x5 diff --git a/poll_default.go b/poll_default.go index e926311b..b35ff5a6 100644 --- a/poll_default.go +++ b/poll_default.go @@ -55,21 +55,22 @@ func (p *defaultPoll) onhups() { } // readall read all left data before close connection -func readall(op *FDOperator, br barrier) (err error) { +func readall(op *FDOperator, br barrier) (total int, err error) { var bs = br.bs var ivs = br.ivs var n int for { bs = op.Inputs(br.bs) if len(bs) == 0 { - return nil + return total, nil } TryRead: n, err = ioread(op.FD, bs, ivs) op.InputAck(n) + total += n if err != nil { - return err + return total, err } if n == 0 { goto TryRead diff --git a/poll_default_bsd.go b/poll_default_bsd.go index 3312e435..9c8aa8c9 100644 --- a/poll_default_bsd.go +++ b/poll_default_bsd.go @@ -18,6 +18,7 @@ package netpoll import ( + "errors" "sync" "sync/atomic" "syscall" @@ -89,6 +90,7 @@ func (p *defaultPoll) Wait() error { continue } + var totalRead int evt := events[i] triggerRead = evt.Filter == syscall.EVFILT_READ && evt.Flags&syscall.EV_ENABLE != 0 triggerWrite = evt.Filter == syscall.EVFILT_WRITE && evt.Flags&syscall.EV_ENABLE != 0 @@ -104,6 +106,7 @@ func (p *defaultPoll) Wait() error { if len(bs) > 0 { var n, err = ioread(operator.FD, bs, barriers[i].ivs) operator.InputAck(n) + totalRead += n if err != nil { p.appendHup(operator) continue @@ -111,14 +114,20 @@ func (p *defaultPoll) Wait() error { } } } - if triggerHup && triggerRead && operator.Inputs != nil { // read all left data if peer send and close - if err = readall(operator, barriers[i]); err != nil { - logger.Printf("NETPOLL: readall(fd=%d) before close: %s", operator.FD, err.Error()) - } - } if triggerHup { - p.appendHup(operator) - continue + if triggerRead && operator.Inputs != nil { + var leftRead int + // read all left data if peer send and close + if leftRead, err = readall(operator, barriers[i]); err != nil && !errors.Is(err, ErrEOF) { + logger.Printf("NETPOLL: readall(fd=%d)=%d before close: %s", operator.FD, totalRead, err.Error()) + } + totalRead += leftRead + } + // only close connection if no further read bytes + if totalRead == 0 { + p.appendHup(operator) + continue + } } if triggerWrite { if operator.OnWrite != nil { @@ -171,19 +180,23 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { evs[0].Ident = uint64(operator.FD) p.setOperator(unsafe.Pointer(&evs[0].Udata), operator) switch event { - case PollReadable, PollModReadable: + case PollReadable: operator.inuse() evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE case PollWritable: operator.inuse() - evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE|syscall.EV_ONESHOT + evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE case PollDetach: + if operator.OnWrite != nil { // means WaitWrite finished + evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE + } else { + evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE + } p.delOperator(operator) - evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE|syscall.EV_ONESHOT case PollR2RW: evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE case PollRW2R: - evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE|syscall.EV_ONESHOT + evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE } _, err := syscall.Kevent(p.fd, evs, nil, nil) return err diff --git a/poll_default_linux.go 
b/poll_default_linux.go index 8da7d55b..a0087ee0 100644 --- a/poll_default_linux.go +++ b/poll_default_linux.go @@ -15,6 +15,7 @@ package netpoll import ( + "errors" "runtime" "sync" "sync/atomic" @@ -116,12 +117,14 @@ func (p *defaultPoll) Wait() (err error) { func (p *defaultPoll) handler(events []epollevent) (closed bool) { var triggerRead, triggerWrite, triggerHup, triggerError bool + var err error for i := range events { operator := p.getOperator(0, unsafe.Pointer(&events[i].data)) if operator == nil || !operator.do() { continue } + var totalRead int evt := events[i].events triggerRead = evt&syscall.EPOLLIN != 0 triggerWrite = evt&syscall.EPOLLOUT != 0 @@ -154,6 +157,7 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) { if len(bs) > 0 { var n, err = ioread(operator.FD, bs, p.barriers[i].ivs) operator.InputAck(n) + totalRead += n if err != nil { p.appendHup(operator) continue @@ -163,14 +167,21 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) { logger.Printf("NETPOLL: operator has critical problem! event=%d operator=%v", evt, operator) } } - if triggerHup && triggerRead && operator.Inputs != nil { // read all left data if peer send and close - if err := readall(operator, p.barriers[i]); err != nil { - logger.Printf("NETPOLL: readall(fd=%d) before close: %s", operator.FD, err.Error()) - } - } if triggerHup { - p.appendHup(operator) - continue + if triggerRead && operator.Inputs != nil { + // read all left data if peer send and close + var leftRead int + if leftRead, err = readall(operator, p.barriers[i]); err != nil && !errors.Is(err, ErrEOF) { + logger.Printf("NETPOLL: readall(fd=%d)=%d before close: %s", operator.FD, totalRead, err.Error()) + } + totalRead += leftRead + } + // only close connection if no further read bytes + if totalRead == 0 { + p.appendHup(operator) + continue + } } if triggerError { // Under block-zerocopy, the kernel may give an error callback, which is not a real error, just an EAGAIN. 
@@ -237,8 +248,6 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { case PollWritable: // client create a new connection and wait connect finished operator.inuse() op, evt.events = syscall.EPOLL_CTL_ADD, EPOLLET|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR - case PollModReadable: // client wait read/write - op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollDetach: // deregister p.delOperator(operator) op, evt.events = syscall.EPOLL_CTL_DEL, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR diff --git a/poll_default_linux_test.go b/poll_default_linux_test.go index acd0afc9..072963d7 100644 --- a/poll_default_linux_test.go +++ b/poll_default_linux_test.go @@ -62,7 +62,7 @@ func TestEpollEvent(t *testing.T) { MustNil(t, err) _, err = syscall.Write(wfd, send) MustNil(t, err) - n, err := EpollWait(epollfd, events, -1) + n, err := epollWaitUntil(epollfd, events, -1) MustNil(t, err) Equal(t, n, 1) Equal(t, events[0].data, eventdata2) @@ -80,7 +80,7 @@ func TestEpollEvent(t *testing.T) { MustNil(t, err) _, err = syscall.Write(wfd, send) MustNil(t, err) - n, err = EpollWait(epollfd, events, -1) + n, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Equal(t, events[0].data, eventdata3) _, err = syscall.Read(rfd, recv) @@ -112,7 +112,7 @@ func TestEpollWait(t *testing.T) { } err = EpollCtl(epollfd, unix.EPOLL_CTL_ADD, rfd, event) MustNil(t, err) - _, err = EpollWait(epollfd, events, -1) + _, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLIN == 0) Assert(t, events[0].events&syscall.EPOLLOUT != 0) @@ -120,7 +120,7 @@ func TestEpollWait(t *testing.T) { // EPOLL: readable _, err = syscall.Write(wfd, send) MustNil(t, err) - _, err = EpollWait(epollfd, events, -1) + _, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLIN != 0) Assert(t, events[0].events&syscall.EPOLLOUT != 0) @@ -128,7 +128,7 @@ func TestEpollWait(t *testing.T) { MustTrue(t, err == nil && string(recv) == string(send)) // EPOLL: read finished - _, err = EpollWait(epollfd, events, -1) + _, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLIN == 0) Assert(t, events[0].events&syscall.EPOLLOUT != 0) @@ -136,7 +136,7 @@ func TestEpollWait(t *testing.T) { // EPOLL: close peer fd err = syscall.Close(wfd) MustNil(t, err) - _, err = EpollWait(epollfd, events, -1) + _, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLIN != 0) Assert(t, events[0].events&syscall.EPOLLOUT != 0) @@ -149,7 +149,7 @@ func TestEpollWait(t *testing.T) { err = EpollCtl(epollfd, unix.EPOLL_CTL_ADD, rfd2, event) err = syscall.Close(rfd2) MustNil(t, err) - _, err = EpollWait(epollfd, events, -1) + _, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLIN != 0) Assert(t, events[0].events&syscall.EPOLLOUT != 0) @@ -174,7 +174,7 @@ func TestEpollETClose(t *testing.T) { // EPOLL: init state err = EpollCtl(epollfd, unix.EPOLL_CTL_ADD, rfd, event) - _, err = EpollWait(epollfd, events, -1) + _, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLIN == 0) Assert(t, events[0].events&syscall.EPOLLOUT != 0) @@ -185,7 +185,7 @@ func TestEpollETClose(t *testing.T) { // nothing will happen err = syscall.Close(rfd) MustNil(t, err) - n, err := EpollWait(epollfd, events, 100) + n, err := 
epollWaitUntil(epollfd, events, 100) MustNil(t, err) Assert(t, n == 0, n) err = syscall.Close(wfd) @@ -197,7 +197,7 @@ func TestEpollETClose(t *testing.T) { err = EpollCtl(epollfd, unix.EPOLL_CTL_ADD, rfd, event) err = syscall.Close(wfd) MustNil(t, err) - n, err = EpollWait(epollfd, events, 100) + n, err = epollWaitUntil(epollfd, events, 100) MustNil(t, err) Assert(t, n == 1, n) Assert(t, events[0].events&syscall.EPOLLIN != 0) @@ -231,7 +231,7 @@ func TestEpollETDel(t *testing.T) { MustNil(t, err) _, err = syscall.Write(wfd, send) MustNil(t, err) - _, err = EpollWait(epollfd, events, 100) + _, err = epollWaitUntil(epollfd, events, 100) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLIN == 0) Assert(t, events[0].events&syscall.EPOLLRDHUP == 0) @@ -272,11 +272,11 @@ func TestEpollConnectSameFD(t *testing.T) { MustNil(t, err) err = syscall.Connect(fd1, &addr) t.Log(err) - _, err = EpollWait(epollfd, events, -1) + _, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLOUT != 0) - Assert(t, events[0].events&syscall.EPOLLRDHUP == 0) - Assert(t, events[0].events&syscall.EPOLLERR == 0) + //Assert(t, events[0].events&syscall.EPOLLRDHUP == 0) + //Assert(t, events[0].events&syscall.EPOLLERR == 0) // forget to del fd //err = EpollCtl(epollfd, unix.EPOLL_CTL_DEL, fd1, event1) //MustNil(t, err) @@ -293,7 +293,7 @@ func TestEpollConnectSameFD(t *testing.T) { MustNil(t, err) err = syscall.Connect(fd2, &addr) t.Log(err) - _, err = EpollWait(epollfd, events, -1) + _, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLOUT != 0) Assert(t, events[0].events&syscall.EPOLLRDHUP == 0) @@ -314,7 +314,7 @@ func TestEpollConnectSameFD(t *testing.T) { MustNil(t, err) err = syscall.Connect(fd3, &addr) t.Log(err) - _, err = EpollWait(epollfd, events, -1) + _, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLOUT != 0) Assert(t, events[0].events&syscall.EPOLLRDHUP == 0) @@ -324,7 +324,16 @@ func TestEpollConnectSameFD(t *testing.T) { MustNil(t, err) err = syscall.Close(fd3) // close fd3 MustNil(t, err) - n, err := EpollWait(epollfd, events, 100) + n, err := epollWaitUntil(epollfd, events, 100) MustNil(t, err) Assert(t, n == 0) } + +func epollWaitUntil(epfd int, events []epollevent, msec int) (n int, err error) { +WAIT: + n, err = EpollWait(epfd, events, msec) + if err == syscall.EINTR { + goto WAIT + } + return n, err +}
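
Note on the central pattern of this patch: connection.readTrigger changes from chan struct{} to chan error, so a reader blocked in waitRead/waitReadWithTimeout is woken up together with the reason, and the closing status decides whether it sees ErrEOF (closed by the poller/peer) or ErrConnClosed (closed by the user) instead of looping on IsActive. The following is a minimal, self-contained sketch of that pattern only; the names and types are simplified stand-ins, not the real netpoll API.

package main

import (
	"errors"
	"fmt"
	"sync/atomic"
)

// Stand-ins for netpoll's ErrEOF / ErrConnClosed (assumption: simplified for the sketch).
var (
	errEOF        = errors.New("EOF")
	errConnClosed = errors.New("connection closed")
)

// Who closed the connection, mirroring the none/user/poller closing status.
const (
	closedByNone int32 = iota
	closedByUser
	closedByPoller
)

type conn struct {
	closing     int32      // close reason, written by Close()/onHup, read by waitRead
	buffered    int64      // bytes currently readable in the input buffer
	readTrigger chan error // buffered(1); wakes a blocked reader and carries the reason
}

// triggerRead mirrors connection.triggerRead(err): a non-blocking notify.
func (c *conn) triggerRead(err error) {
	select {
	case c.readTrigger <- err:
	default:
	}
}

// waitRead mirrors the new waitRead loop: check the close reason first,
// then block on the trigger and propagate any error it carries.
func (c *conn) waitRead(n int64) error {
	for atomic.LoadInt64(&c.buffered) < n {
		switch atomic.LoadInt32(&c.closing) {
		case closedByPoller:
			return errEOF
		case closedByUser:
			return errConnClosed
		default:
			if err := <-c.readTrigger; err != nil {
				return err
			}
		}
	}
	return nil
}

func main() {
	c := &conn{readTrigger: make(chan error, 1)}
	// Peer closes before n bytes arrived: the poller records the reason and wakes the reader.
	atomic.StoreInt32(&c.closing, closedByPoller)
	c.triggerRead(errEOF)
	fmt.Println(c.waitRead(1)) // reports the EOF stand-in instead of hanging
}

Propagating the close reason through the trigger is what the new tests above assert, e.g. TestConnectionServerClose expecting errors.Is(err, ErrEOF) on the client side once the server closes.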