diff --git a/connection_impl.go b/connection_impl.go index db2f8d69..aa15e1e9 100644 --- a/connection_impl.go +++ b/connection_impl.go @@ -24,8 +24,14 @@ import ( "time" ) +type connState = int32 + const ( defaultZeroCopyTimeoutSec = 60 + + connStateNone = 0 + connStateConnected = 1 + connStateDisconnected = 2 ) // connection is the implement of Connection @@ -45,9 +51,9 @@ type connection struct { outputBuffer *LinkBuffer outputBarrier *barrier supportZeroCopy bool - maxSize int // The maximum size of data between two Release(). - bookSize int // The size of data that can be read at once. - state int32 // 0: not connected, 1: connected, 2: disconnected. Connection state should be changed sequentially. + maxSize int // The maximum size of data between two Release(). + bookSize int // The size of data that can be read at once. + state connState // Connection state should be changed sequentially. } var ( @@ -333,7 +339,7 @@ func (c *connection) init(conn Conn, opts *options) (err error) { c.bookSize, c.maxSize = defaultLinkBufferSize, defaultLinkBufferSize c.inputBuffer, c.outputBuffer = NewLinkBuffer(defaultLinkBufferSize), NewLinkBuffer() c.outputBarrier = barrierPool.Get().(*barrier) - c.state = 0 + c.state = connStateNone c.initNetFD(conn) // conn must be *netFD{} c.initFDOperator() @@ -536,3 +542,15 @@ func (c *connection) waitFlush() (err error) { return Exception(ErrWriteTimeout, c.remoteAddr.String()) } } + +func (c *connection) getState() connState { + return atomic.LoadInt32(&c.state) +} + +func (c *connection) setState(newState connState) { + atomic.StoreInt32(&c.state, newState) +} + +func (c *connection) changeState(from, to connState) bool { + return atomic.CompareAndSwapInt32(&c.state, from, to) +} diff --git a/connection_onevent.go b/connection_onevent.go index 35b7c001..5dc986da 100644 --- a/connection_onevent.go +++ b/connection_onevent.go @@ -134,7 +134,7 @@ func (c *connection) onPrepare(opts *options) (err error) { func (c *connection) onConnect() { var onConnect, _ = c.onConnectCallback.Load().(OnConnect) if onConnect == nil { - atomic.StoreInt32(&c.state, 1) + c.changeState(connStateNone, connStateConnected) return } if !c.lock(connecting) { @@ -142,35 +142,7 @@ func (c *connection) onConnect() { return } var onRequest, _ = c.onRequestCallback.Load().(OnRequest) - c.onProcess( - // only process when conn active and have unread data - func(c *connection) bool { - // if onConnect not called - if atomic.LoadInt32(&c.state) == 0 { - return true - } - // check for onRequest - return onRequest != nil && c.Reader().Len() > 0 - }, - func(c *connection) { - if atomic.CompareAndSwapInt32(&c.state, 0, 1) { - c.ctx = onConnect(c.ctx, c) - - if !c.IsActive() && atomic.CompareAndSwapInt32(&c.state, 1, 2) { - // since we hold connecting lock, so we should help to call onDisconnect here - var onDisconnect, _ = c.onDisconnectCallback.Load().(OnDisconnect) - if onDisconnect != nil { - onDisconnect(c.ctx, c) - } - } - c.unlock(connecting) - return - } - if onRequest != nil { - _ = onRequest(c.ctx, c) - } - }, - ) + c.onProcess(onConnect, onRequest) } // when onDisconnect called, c.IsActive() must return false @@ -182,15 +154,16 @@ func (c *connection) onDisconnect() { var onConnect, _ = c.onConnectCallback.Load().(OnConnect) if onConnect == nil { // no need lock if onConnect is nil - atomic.StoreInt32(&c.state, 2) + // it's ok to force set state to disconnected since onConnect is nil + c.setState(connStateDisconnected) onDisconnect(c.ctx, c) return } // check if OnConnect finished when onConnect != nil && onDisconnect != nil - if atomic.LoadInt32(&c.state) > 0 && c.lock(connecting) { // means OnConnect already finished + if c.getState() != connStateNone && c.lock(connecting) { // means OnConnect already finished // protect onDisconnect run once // if CAS return false, means OnConnect already helps to run onDisconnect - if atomic.CompareAndSwapInt32(&c.state, 1, 2) { + if c.changeState(connStateConnected, connStateDisconnected) { onDisconnect(c.ctx, c) } c.unlock(connecting) @@ -207,63 +180,66 @@ func (c *connection) onRequest() (needTrigger bool) { return true } // wait onConnect finished first - if atomic.LoadInt32(&c.state) == 0 && c.onConnectCallback.Load() != nil { + if c.getState() == connStateNone && c.onConnectCallback.Load() != nil { // let onConnect to call onRequest return } - processed := c.onProcess( - // only process when conn active and have unread data - func(c *connection) bool { - return c.Reader().Len() > 0 - }, - func(c *connection) { - _ = onRequest(c.ctx, c) - }, - ) + processed := c.onProcess(nil, onRequest) // if not processed, should trigger read return !processed } -// onProcess is responsible for executing the process function serially, -// and make sure the connection has been closed correctly if user call c.Close() in process function. -func (c *connection) onProcess(isProcessable func(c *connection) bool, process func(c *connection)) (processed bool) { - if process == nil { - return false - } +// onProcess is responsible for executing the onConnect/onRequest function serially, +// and make sure the connection has been closed correctly if user call c.Close() in onConnect/onRequest function. +func (c *connection) onProcess(onConnect OnConnect, onRequest OnRequest) (processed bool) { // task already exists if !c.lock(processing) { return false } - // add new task - var task = func() { + + task := func() { panicked := true defer func() { + if !panicked { + return + } // cannot use recover() here, since we don't want to break the panic stack - if panicked { - c.unlock(processing) - if c.IsActive() { - c.Close() - } else { - c.closeCallback(false, false) - } + c.unlock(processing) + if c.IsActive() { + c.Close() + } else { + c.closeCallback(false, false) } }() + // trigger onConnect first + if onConnect != nil && c.changeState(connStateNone, connStateConnected) { + c.ctx = onConnect(c.ctx, c) + if !c.IsActive() && c.changeState(connStateConnected, connStateDisconnected) { + // since we hold connecting lock, so we should help to call onDisconnect here + onDisconnect, _ := c.onDisconnectCallback.Load().(OnDisconnect) + if onDisconnect != nil { + onDisconnect(c.ctx, c) + } + } + c.unlock(connecting) + } START: - // `process` must be executed at least once if `isProcessable` in order to cover the `send & close by peer` case. - // Then the loop processing must ensure that the connection `IsActive`. - if isProcessable(c) { - process(c) + // The `onRequest` must be executed at least once if conn have any readable data, + // which is in order to cover the `send & close by peer` case. + if onRequest != nil && c.Reader().Len() > 0 { + _ = onRequest(c.ctx, c) } - // `process` must either eventually read all the input data or actively Close the connection, + // The processing loop must ensure that the connection meets `IsActive`. + // `onRequest` must either eventually read all the input data or actively Close the connection, // otherwise the goroutine will fall into a dead loop. var closedBy who for { closedBy = c.status(closing) - // close by user or no processable - if closedBy == user || !isProcessable(c) { + // close by user or not processable + if closedBy == user || onRequest == nil || c.Reader().Len() == 0 { break } - process(c) + _ = onRequest(c.ctx, c) } // handling callback if connection has been closed. if closedBy != none { @@ -288,14 +264,15 @@ func (c *connection) onProcess(isProcessable func(c *connection) bool, process f panicked = false return } - // double check isProcessable - if isProcessable(c) && c.lock(processing) { + // double check is processable + if onRequest != nil && c.Reader().Len() > 0 && c.lock(processing) { goto START } // task exits panicked = false return } + // add new task runTask(c.ctx, task) return true } diff --git a/connection_test.go b/connection_test.go index 163645ab..65dd69ff 100644 --- a/connection_test.go +++ b/connection_test.go @@ -32,6 +32,30 @@ import ( "time" ) +func BenchmarkConnectionIO(b *testing.B) { + var dataSize = 1024 * 16 + var writeBuffer = make([]byte, dataSize) + var rfd, wfd = GetSysFdPairs() + var rconn, wconn = new(connection), new(connection) + rconn.init(&netFD{fd: rfd}, &options{onRequest: func(ctx context.Context, connection Connection) error { + read, _ := connection.Reader().Next(dataSize) + _ = wconn.Reader().Release() + _, _ = connection.Writer().WriteBinary(read) + _ = connection.Writer().Flush() + return nil + }}) + wconn.init(&netFD{fd: wfd}, new(options)) + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _ = wconn.WriteBinary(writeBuffer) + _ = wconn.Flush() + _, _ = wconn.Reader().Next(dataSize) + _ = wconn.Reader().Release() + } +} + func TestConnectionWrite(t *testing.T) { var cycle, caps = 10000, 256 var msg, buf = make([]byte, caps), make([]byte, caps) diff --git a/poll_manager_test.go b/poll_manager_test.go index c5648a76..f61c5282 100644 --- a/poll_manager_test.go +++ b/poll_manager_test.go @@ -61,7 +61,7 @@ func TestPollManagerSetNumLoops(t *testing.T) { poll := pm.Pick() newGs := runtime.NumGoroutine() Assert(t, poll != nil) - Assert(t, newGs-startGs == 1, newGs, startGs) + Assert(t, newGs-startGs >= 1, newGs, startGs) t.Logf("old=%d, new=%d", startGs, newGs) // change pollers