diff --git a/client/obclient.go b/client/obclient.go index 977d555..a8bd721 100644 --- a/client/obclient.go +++ b/client/obclient.go @@ -216,6 +216,8 @@ func (c *obClient) init() error { func (c *obClient) initOdp() error { // 1. Init odp table t := NewObTable(c.odpIP, c.odpRpcPort, c.tenantName, c.fullUserName, c.password, c.database) + t.setMaxConnectionAge(c.config.MaxConnectionAge) + t.setEnableSLBLoadBalance(c.config.EnableSLBLoadBalance) err := t.init(c.config.ConnPoolMaxConnSize, c.config.ConnConnectTimeOut, c.config.ConnLoginTimeout) // 2. Init sql // ObVersion will be set when login in init() diff --git a/client/table.go b/client/table.go index 203fde6..2df0ef0 100644 --- a/client/table.go +++ b/client/table.go @@ -41,6 +41,9 @@ type ObTable struct { isClosed bool mutex sync.Mutex + + maxConnectionAge time.Duration + enableSLBLoadBalance bool } func NewObTable( @@ -51,13 +54,15 @@ func NewObTable( password string, database string) *ObTable { return &ObTable{ - ip: ip, - port: port, - tenantName: tenantName, - userName: userName, - password: password, - database: database, - isClosed: false, + ip: ip, + port: port, + tenantName: tenantName, + userName: userName, + password: password, + database: database, + isClosed: false, + maxConnectionAge: time.Duration(0), + enableSLBLoadBalance: false, } } @@ -72,6 +77,8 @@ func (t *ObTable) init(connPoolSize int, connectTimeout time.Duration, loginTime t.database, t.userName, t.password, + t.maxConnectionAge, + t.enableSLBLoadBalance, ) cli, err := obkvrpc.NewRpcClient(opt) if err != nil { @@ -81,6 +88,14 @@ func (t *ObTable) init(connPoolSize int, connectTimeout time.Duration, loginTime return nil } +func (t *ObTable) setMaxConnectionAge(duration time.Duration) { + t.maxConnectionAge = duration +} + +func (t *ObTable) setEnableSLBLoadBalance(b bool) { + t.enableSLBLoadBalance = b +} + func (t *ObTable) retry( ctx context.Context, request protocol.ObPayload, diff --git a/config/client_config.go b/config/client_config.go index 218b062..a984c8d 100644 --- a/config/client_config.go +++ b/config/client_config.go @@ -45,6 +45,10 @@ type ClientConfig struct { RsListHttpGetRetryInterval time.Duration EnableRerouting bool + + // connection rebalance in ODP mode + MaxConnectionAge time.Duration + EnableSLBLoadBalance bool } func NewDefaultClientConfig() *ClientConfig { @@ -65,6 +69,8 @@ func NewDefaultClientConfig() *ClientConfig { RsListHttpGetRetryTimes: 3, RsListHttpGetRetryInterval: time.Duration(100) * time.Millisecond, // 100ms, EnableRerouting: false, + MaxConnectionAge: time.Duration(0) * time.Second, // valid iff > 0 + EnableSLBLoadBalance: false, } } @@ -85,6 +91,8 @@ func (c *ClientConfig) String() string { "RsListHttpGetTimeout:" + c.RsListHttpGetTimeout.String() + ", " + "RsListHttpGetRetryTimes:" + strconv.Itoa(c.RsListHttpGetRetryTimes) + ", " + "RsListHttpGetRetryInterval:" + c.RsListHttpGetRetryInterval.String() + ", " + - "EnableRerouting:" + strconv.FormatBool(c.EnableRerouting) + + "EnableRerouting:" + strconv.FormatBool(c.EnableRerouting) + ", " + + "MaxConnectionAge:" + c.MaxConnectionAge.String() + ", " + + "EnableSLBLoadBalance:" + strconv.FormatBool(c.EnableSLBLoadBalance) + "}" } diff --git a/config/toml_config.go b/config/toml_config.go index a529f33..27c31d5 100644 --- a/config/toml_config.go +++ b/config/toml_config.go @@ -81,9 +81,11 @@ type RsListConfig struct { } type ExtraConfig struct { - OperationTimeOut int - LogLevel string - EnableRerouting bool + OperationTimeOut int + LogLevel string + EnableRerouting bool + MaxConnectionAge int + EnableSLBLoadBalance bool } func (c *ClientConfiguration) checkClientConfiguration() error { @@ -147,6 +149,8 @@ func (c *ClientConfiguration) GetClientConfig() *ClientConfig { RsListHttpGetRetryTimes: c.RsListConfig.HttpGetRetryTimes, RsListHttpGetRetryInterval: time.Duration(c.RsListConfig.HttpGetRetryInterval) * time.Millisecond, EnableRerouting: c.ExtraConfig.EnableRerouting, + MaxConnectionAge: time.Duration(c.ExtraConfig.MaxConnectionAge) * time.Millisecond, + EnableSLBLoadBalance: c.ExtraConfig.EnableSLBLoadBalance, } } diff --git a/configurations/obkv-table-default.toml b/configurations/obkv-table-default.toml index 13ba688..b202eb0 100644 --- a/configurations/obkv-table-default.toml +++ b/configurations/obkv-table-default.toml @@ -48,3 +48,5 @@ HttpGetRetryInterval = 100 OperationTimeOut = 10000 LogLevel = "info" EnableRerouting = false +MaxConnectionAge = 0 +EnableSLBLoadBalance = false diff --git a/go.mod b/go.mod index 86fc6da..304e090 100644 --- a/go.mod +++ b/go.mod @@ -4,15 +4,17 @@ go 1.19 require ( github.com/go-sql-driver/mysql v1.7.0 + github.com/naoina/toml v0.1.1 github.com/pkg/errors v0.9.1 + github.com/scylladb/go-set v1.0.2 github.com/stretchr/testify v1.8.0 go.uber.org/zap v1.24.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/naoina/go-stringutil v0.1.0 // indirect - github.com/naoina/toml v0.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect go.uber.org/atomic v1.10.0 // indirect go.uber.org/multierr v1.11.0 // indirect diff --git a/go.sum b/go.sum index 3050ebd..872b3a1 100644 --- a/go.sum +++ b/go.sum @@ -2,10 +2,12 @@ github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLj github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/set v0.2.1 h1:nn2CaJyknWE/6txyUDGwysr3G5QC6xWB/PtVjPBbeaA= +github.com/fatih/set v0.2.1/go.mod h1:+RKtMCH+favT2+3YecHGxcc0b4KyVWA1QWWJUs4E0CI= github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/naoina/go-stringutil v0.1.0 h1:rCUeRUHjBjGTSHl0VC00jUPLz8/F9dDzYI70Hzifhks= github.com/naoina/go-stringutil v0.1.0/go.mod h1:XJ2SJL9jCtBh+P9q5btrd/Ylo8XwT/h1USek5+NqSA0= github.com/naoina/toml v0.1.1 h1:PT/lllxVVN0gzzSqSlHEmP8MJB4MY2U7STGxiouV4X8= @@ -14,8 +16,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= -github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/scylladb/go-set v1.0.2 h1:SkvlMCKhP0wyyct6j+0IHJkBkSZL+TDzZ4E7f7BCcRE= +github.com/scylladb/go-set v1.0.2/go.mod h1:DkpGd78rljTxKAnTDPFqXSGxvETQnJyuSOQwsHycqfs= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -28,11 +30,6 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/obkvrpc/connection.go b/obkvrpc/connection.go index 4b6dc22..7de0546 100644 --- a/obkvrpc/connection.go +++ b/obkvrpc/connection.go @@ -105,6 +105,9 @@ type Connection struct { ezHeaderLength int rpcHeaderLength int + expireTime time.Time + isExpired atomic.Bool + slbLoader *SLBLoader } type packet struct { @@ -211,6 +214,7 @@ func (c *Connection) Execute( seq := c.seq.Add(1) totalBuf := c.encodePacket(seq, request) + trace := fmt.Sprintf("Y%X-%016X", request.UniqueId(), request.Sequence()) call := &call{ err: nil, @@ -234,21 +238,21 @@ func (c *Connection) Execute( c.mutex.Lock() delete(c.pending, seq) c.mutex.Unlock() - return errors.WithMessage(ctx.Err(), "wait send packet to channel") + return errors.WithMessage(ctx.Err(), "wait send packet to channel, trace: "+trace) } // wait call back select { case call = <-call.signal: if call.err != nil { // transport failed - return errors.WithMessage(call.err, "receive packet") + return errors.WithMessage(call.err, "receive packet, trace: "+trace) } case <-ctx.Done(): // timeout c.mutex.Lock() delete(c.pending, seq) c.mutex.Unlock() - return errors.WithMessage(ctx.Err(), "wait transport packet") + return errors.WithMessage(ctx.Err(), "wait transport packet, trace: "+trace) } // transport success @@ -402,6 +406,7 @@ func (c *Connection) writerWrite(packet packet) { } func (c *Connection) Close() { + log.Info(fmt.Sprintf("close connection start, remote addr:%s", c.conn.RemoteAddr().String())) c.active.Store(false) c.closeOnce.Do(func() { close(c.packetChannelClose) // close packet channel @@ -415,6 +420,7 @@ func (c *Connection) Close() { } c.mutex.Unlock() }) + log.Info(fmt.Sprintf("close connection success, remote addr:%s", c.conn.RemoteAddr().String())) } func (c *Connection) encodePacket(seq uint32, request protocol.ObPayload) []byte { diff --git a/obkvrpc/connection_lifecycle_mgr.go b/obkvrpc/connection_lifecycle_mgr.go new file mode 100644 index 0000000..cf25c27 --- /dev/null +++ b/obkvrpc/connection_lifecycle_mgr.go @@ -0,0 +1,122 @@ +/*- + * #%L + * OBKV Table Client Framework + * %% + * Copyright (C) 2023 OceanBase + * %% + * OBKV Table Client Framework is licensed under Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, + * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, + * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + * #L% + */ + +package obkvrpc + +import ( + "context" + "fmt" + "github.com/oceanbase/obkv-table-client-go/log" + "go.uber.org/zap" + "math" + "time" +) + +type ConnectionLifeCycleMgr struct { + connPool *ConnectionPool + maxConnectionAge time.Duration + lastExpireIdx int +} + +func (s *ConnectionLifeCycleMgr) String() string { + return fmt.Sprintf("ConnectionLifeCycleMgr{connPool: %p, maxConnectionAge: %d,lastExpireIdx: %d}", + s.connPool, s.maxConnectionAge, s.lastExpireIdx) +} + +func NewConnectionLifeCycleMgr(connPool *ConnectionPool, maxConnectionAge time.Duration) *ConnectionLifeCycleMgr { + connLifeCycleMgr := &ConnectionLifeCycleMgr{ + connPool: connPool, + maxConnectionAge: maxConnectionAge, + lastExpireIdx: 0, + } + return connLifeCycleMgr +} + +// check and reconnect timeout connections +func (c *ConnectionLifeCycleMgr) run() { + if c.connPool == nil { + log.Error("connection pool is null") + return + } + + // 1. get all timeout connections + expiredConnIds := make([]int, 0, len(c.connPool.connections)) + for i := 1; i <= len(c.connPool.connections); i++ { + connection := c.connPool.connections[(i+c.lastExpireIdx)%(len(c.connPool.connections))] + if !connection.expireTime.IsZero() && connection.expireTime.Before(time.Now()) { + expiredConnIds = append(expiredConnIds, (i+c.lastExpireIdx)%(len(c.connPool.connections))) + } + } + + if len(expiredConnIds) > 0 { + log.Info(fmt.Sprintf("Find %d expired connections", len(expiredConnIds))) + for idx, connIdx := range expiredConnIds { + log.Info(fmt.Sprintf("%d: ip=%s, port=%d", idx, c.connPool.connections[connIdx].option.ip, c.connPool.connections[connIdx].option.port)) + } + } + + // 2. mark 30% expired connections as expired + maxReconnIdx := int(math.Ceil(float64(len(expiredConnIds)) / 3)) + if maxReconnIdx > 0 { + c.lastExpireIdx = expiredConnIds[maxReconnIdx-1] + log.Info(fmt.Sprintf("Begin to refresh expired connections which idx less than %d", maxReconnIdx)) + } + for i := 0; i < maxReconnIdx; i++ { + // no one can get expired connection + c.connPool.connections[expiredConnIds[i]].isExpired.Store(true) + } + defer func() { + for i := 0; i < maxReconnIdx; i++ { + c.connPool.connections[expiredConnIds[i]].isExpired.Store(false) + } + }() + + // 3. wait all expired connection finished + time.Sleep(DefaultConnectWaitTime) + for i := 0; i < maxReconnIdx; i++ { + pool := c.connPool.connections + idx := expiredConnIds[i] + for j := 0; len(pool[idx].pending) > 0; j++ { + time.Sleep(time.Duration(10) * time.Millisecond) + if j > 0 && j%100 == 0 { + log.Info(fmt.Sprintf("Wait too long time for the connection to end,"+ + "connection idx: %d, ip:%s, port:%d, current connection pending size: %d", + idx, pool[idx].option.ip, pool[idx].option.port, len(pool[idx].pending))) + } + + if j > 3000 { + log.Warn("Wait too much time for the connection to end, stop ConnectionLifeCycleMgr") + return + } + } + } + + // 4. close and reconnect all expired connections + ctx, _ := context.WithTimeout(context.Background(), c.connPool.option.connectTimeout) + for i := 0; i < maxReconnIdx; i++ { + // close and reconnect + c.connPool.connections[expiredConnIds[i]].Close() + _, err := c.connPool.RecreateConnection(ctx, expiredConnIds[i]) + if err != nil { + log.Warn("reconnect failed", zap.Error(err)) + return + } + } + if maxReconnIdx > 0 { + log.Info(fmt.Sprintf("Finish to refresh expired connections which idx less than %d", maxReconnIdx)) + } +} diff --git a/obkvrpc/connection_mgr.go b/obkvrpc/connection_mgr.go new file mode 100644 index 0000000..416eea0 --- /dev/null +++ b/obkvrpc/connection_mgr.go @@ -0,0 +1,88 @@ +/*- + * #%L + * OBKV Table Client Framework + * %% + * Copyright (C) 2023 OceanBase + * %% + * OBKV Table Client Framework is licensed under Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, + * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, + * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + * #L% + */ + +package obkvrpc + +import ( + "fmt" + "github.com/oceanbase/obkv-table-client-go/log" + "sync" + "time" +) + +const DefaultConnectWaitTime = time.Duration(3) * time.Second +const ConnectionMgrTaskInterval = time.Duration(3) * time.Second + +type ConnectionMgr struct { + connLifeCycleMgr *ConnectionLifeCycleMgr + slbLoader *SLBLoader + needStop chan bool + wg sync.WaitGroup +} + +func NewConnectionMgr(p *ConnectionPool) *ConnectionMgr { + connMgr := &ConnectionMgr{ + needStop: make(chan bool), + } + if p.option.enableSLBLoadBalance { + connMgr.slbLoader = NewSLBLoader(p) + connMgr.slbLoader.refreshSLBList() + } + + if p.option.maxConnectionAge > 0 || p.option.enableSLBLoadBalance { + connMgr.connLifeCycleMgr = NewConnectionLifeCycleMgr(p, p.option.maxConnectionAge) + } + return connMgr +} + +func (c *ConnectionMgr) start() { + ticker := time.NewTicker(ConnectionMgrTaskInterval) + go func() { + c.wg.Add(1) + defer c.wg.Done() + for { + select { + case <-c.needStop: + ticker.Stop() + log.Info("Stop ConnectionMgr") + return + case <-ticker.C: + c.run() + } + } + }() + log.Info("start ConnectionMgr, " + c.String()) +} + +func (c *ConnectionMgr) run() { + if c.slbLoader != nil { + c.slbLoader.run() + } + if c.connLifeCycleMgr != nil { + c.connLifeCycleMgr.run() + } +} + +func (c *ConnectionMgr) close() { + c.needStop <- true + c.wg.Wait() +} + +func (c *ConnectionMgr) String() string { + return fmt.Sprintf("ConnectionMgr{connLifeCycleMgr: %s, slbLoader: %s, needStop: %d, wg: %v}", + c.connLifeCycleMgr, c.slbLoader, c.needStop, c.wg) +} diff --git a/obkvrpc/connection_pool.go b/obkvrpc/connection_pool.go index 0e00bd6..54e4ed7 100644 --- a/obkvrpc/connection_pool.go +++ b/obkvrpc/connection_pool.go @@ -19,6 +19,8 @@ package obkvrpc import ( "context" + "fmt" + "github.com/oceanbase/obkv-table-client-go/log" "math/rand" "sync" "time" @@ -37,6 +39,9 @@ type PoolOption struct { databaseName string userName string password string + + maxConnectionAge time.Duration + enableSLBLoadBalance bool } type ConnectionPool struct { @@ -44,20 +49,23 @@ type ConnectionPool struct { connections []*Connection rwMutexes []sync.RWMutex + connMgr *ConnectionMgr } func NewPoolOption(ip string, port int, connPoolMaxConnSize int, connectTimeout time.Duration, loginTimeout time.Duration, - tenantName string, databaseName string, userName string, password string) *PoolOption { + tenantName string, databaseName string, userName string, password string, maxConnectionAge time.Duration, enableSLBLoadBalance bool) *PoolOption { return &PoolOption{ - ip: ip, - port: port, - connPoolMaxConnSize: connPoolMaxConnSize, - connectTimeout: connectTimeout, - loginTimeout: loginTimeout, - tenantName: tenantName, - databaseName: databaseName, - userName: userName, - password: password, + ip: ip, + port: port, + connPoolMaxConnSize: connPoolMaxConnSize, + connectTimeout: connectTimeout, + loginTimeout: loginTimeout, + tenantName: tenantName, + databaseName: databaseName, + userName: userName, + password: password, + maxConnectionAge: maxConnectionAge, + enableSLBLoadBalance: enableSLBLoadBalance, } } @@ -68,23 +76,15 @@ func NewConnectionPool(option *PoolOption) (*ConnectionPool, error) { rwMutexes: make([]sync.RWMutex, 0, option.connPoolMaxConnSize), } - connectionOption := NewConnectionOption(pool.option.ip, pool.option.port, pool.option.connectTimeout, pool.option.loginTimeout, - pool.option.tenantName, pool.option.databaseName, pool.option.userName, pool.option.password) + if option.maxConnectionAge > 0 || option.enableSLBLoadBalance { + pool.connMgr = NewConnectionMgr(pool) + } for i := 0; i < pool.option.connPoolMaxConnSize; i++ { - - connection := NewConnection(connectionOption) - ctx, _ := context.WithTimeout(context.Background(), pool.option.connectTimeout) - err := connection.Connect(ctx) + connection, err := pool.CreateConnection(ctx) if err != nil { - return nil, errors.WithMessage(err, "connection connect") - } - - ctx, _ = context.WithTimeout(context.Background(), pool.option.loginTimeout) - err = connection.Login(ctx) - if err != nil { - return nil, errors.WithMessage(err, "connection login") + return nil, errors.WithMessage(err, "create connection") } pool.connections = append(pool.connections, connection) @@ -92,11 +92,25 @@ func NewConnectionPool(option *PoolOption) (*ConnectionPool, error) { } + if pool.connMgr != nil { + pool.connMgr.start() + } + return pool, nil } +// GetConnection Find an unexpired and active connection to use +// In theory, all connection won't expire at the same time func (p *ConnectionPool) GetConnection() (*Connection, int) { index := rand.Intn(p.option.connPoolMaxConnSize) + for i := 0; i < p.option.connPoolMaxConnSize; i++ { + if p.connections[(index+i)%p.option.connPoolMaxConnSize].isExpired.Load() == false { + index = (index + i) % p.option.connPoolMaxConnSize + break + } else if i == p.option.connPoolMaxConnSize-1 { + log.Warn("All connections is expired, will pick a expired connection") + } + } p.rwMutexes[index].RLock() defer p.rwMutexes[index].RUnlock() @@ -128,7 +142,8 @@ func (p *ConnectionPool) RecreateConnection(ctx context.Context, connectionIdx i } func (p *ConnectionPool) CreateConnection(ctx context.Context) (*Connection, error) { - connectionOption := NewConnectionOption(p.option.ip, p.option.port, p.option.connectTimeout, p.option.loginTimeout, + ip, port := p.getNextConnAddress() + connectionOption := NewConnectionOption(ip, port, p.option.connectTimeout, p.option.loginTimeout, p.option.tenantName, p.option.databaseName, p.option.userName, p.option.password) connection := NewConnection(connectionOption) err := connection.Connect(ctx) @@ -139,6 +154,12 @@ func (p *ConnectionPool) CreateConnection(ctx context.Context) (*Connection, err if err != nil { return nil, errors.WithMessage(err, "connection login") } + // put it to here to ensure connection should not expire during connect & login phase + if p.option.maxConnectionAge > 0 { + connection.expireTime = time.Now().Add(p.option.maxConnectionAge) + } + log.Info(fmt.Sprintf("connect success, remote addr:%s, expire time: %s", + connection.conn.RemoteAddr().String(), connection.expireTime.String())) return connection, nil } @@ -146,4 +167,19 @@ func (p *ConnectionPool) Close() { for _, connection := range p.connections { connection.Close() } + + if p.connMgr != nil { + p.connMgr.close() + } +} + +func (p *ConnectionPool) getNextConnAddress() (string, int) { + ip := p.option.ip + port := p.option.port + if p.connMgr != nil && p.connMgr.slbLoader != nil { + ip = p.connMgr.slbLoader.getNextSLBAddress() + log.Info(fmt.Sprintf("Get a SLB address %s:%d", ip, port)) + } + + return ip, port } diff --git a/obkvrpc/rpc_client.go b/obkvrpc/rpc_client.go index 0064bb8..6bea817 100644 --- a/obkvrpc/rpc_client.go +++ b/obkvrpc/rpc_client.go @@ -38,6 +38,9 @@ type RpcClientOption struct { databaseName string userName string password string + + maxConnectionAge time.Duration + enableSLBLoadBalance bool } func (o *RpcClientOption) ConnPoolMaxConnSize() int { @@ -53,17 +56,19 @@ func (o *RpcClientOption) LoginTimeout() time.Duration { } func NewRpcClientOption(ip string, port int, connPoolMaxConnSize int, connectTimeout time.Duration, loginTimeout time.Duration, - tenantName string, databaseName string, userName string, password string) *RpcClientOption { + tenantName string, databaseName string, userName string, password string, maxConnectionAge time.Duration, enableSLBLoadBalance bool) *RpcClientOption { return &RpcClientOption{ - ip: ip, - port: port, - connPoolMaxConnSize: connPoolMaxConnSize, - connectTimeout: connectTimeout, - loginTimeout: loginTimeout, - tenantName: tenantName, - databaseName: databaseName, - userName: userName, - password: password, + ip: ip, + port: port, + connPoolMaxConnSize: connPoolMaxConnSize, + connectTimeout: connectTimeout, + loginTimeout: loginTimeout, + tenantName: tenantName, + databaseName: databaseName, + userName: userName, + password: password, + maxConnectionAge: maxConnectionAge, + enableSLBLoadBalance: enableSLBLoadBalance, } } @@ -77,7 +82,9 @@ func (o *RpcClientOption) String() string { "tenantName:" + o.tenantName + ", " + "databaseName:" + o.databaseName + ", " + "userName:" + o.userName + ", " + - "password:" + o.password + + "password:" + o.password + ", " + + "maxConnectionAge:" + o.maxConnectionAge.String() + ", " + + "enableSLBLoadBalance:" + strconv.FormatBool(o.enableSLBLoadBalance) + "}" } @@ -95,7 +102,7 @@ func NewRpcClient(rpcClientOption *RpcClientOption) (*RpcClient, error) { client := &RpcClient{option: rpcClientOption} poolOption := NewPoolOption(client.option.ip, client.option.port, client.option.connPoolMaxConnSize, client.option.connectTimeout, client.option.loginTimeout, - client.option.tenantName, client.option.databaseName, client.option.userName, client.option.password) + client.option.tenantName, client.option.databaseName, client.option.userName, client.option.password, client.option.maxConnectionAge, client.option.enableSLBLoadBalance) connectionPool, err := NewConnectionPool(poolOption) if err != nil { return nil, errors.WithMessage(err, "create connection pool") diff --git a/obkvrpc/slb_loader.go b/obkvrpc/slb_loader.go new file mode 100644 index 0000000..7d74ea4 --- /dev/null +++ b/obkvrpc/slb_loader.go @@ -0,0 +1,103 @@ +/*- + * #%L + * OBKV Table Client Framework + * %% + * Copyright (C) 2023 OceanBase + * %% + * OBKV Table Client Framework is licensed under Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, + * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, + * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + * #L% + */ + +package obkvrpc + +import ( + "fmt" + "github.com/oceanbase/obkv-table-client-go/log" + "github.com/pkg/errors" + "github.com/scylladb/go-set/strset" + "go.uber.org/zap" + "net" + "sync" + "sync/atomic" + "time" +) + +type SLBLoader struct { + connPool *ConnectionPool + dnsAddress string + round atomic.Int64 + mutex sync.RWMutex + slbAddress []string +} + +// refresh SLB list from DNS address +func (s *SLBLoader) refreshSLBList() (bool, error) { + ips, err := net.LookupIP(s.dnsAddress) + if err != nil { + return false, errors.WithMessagef(err, "fail to look up slb address, dns addr: %s", s.dnsAddress) + } + slbAddress := strset.NewWithSize(len(ips)) + for _, ip := range ips { + slbAddress.Add(ip.String()) + } + changed := !slbAddress.IsEqual(strset.New(s.slbAddress...)) + s.mutex.Lock() + defer s.mutex.Unlock() + if changed { + log.Info(fmt.Sprint("SLB address changed, before: ", s.slbAddress, ", after: ", slbAddress)) + s.slbAddress = slbAddress.List() + } + return changed, nil +} + +// round-robin get next slb address from slb list +func (s *SLBLoader) getNextSLBAddress() string { + s.mutex.RLock() + defer s.mutex.RUnlock() + slbNum := len(s.slbAddress) + if slbNum > 0 { + slbAddr := s.slbAddress[(s.round.Add(1))%(int64(slbNum))] + return slbAddr + } + return s.dnsAddress +} + +// refresh SLBList and refresh connection expire time if SLBList changed +func (s *SLBLoader) run() { + changed, err := s.refreshSLBList() + if err != nil { + log.Warn("reconnect failed", zap.Error(err)) + return + } + if changed { + s.refreshConnectionLife() + } +} + +func (s *SLBLoader) refreshConnectionLife() { + for _, conn := range s.connPool.connections { + conn.expireTime = time.Now() + } +} + +func (s *SLBLoader) String() string { + return fmt.Sprintf("SLBLoader{connPool: %p, dnsAddress: %s, round: %d, slbAddress: %v}", + s.connPool, s.dnsAddress, s.round.Load(), s.slbAddress) +} + +func NewSLBLoader(p *ConnectionPool) *SLBLoader { + slbLoader := &SLBLoader{ + slbAddress: make([]string, 0, 10), + dnsAddress: p.option.ip, + connPool: p, + } + slbLoader.round.Store(-1) + return slbLoader +} diff --git a/test/connection_balance/all_test.go b/test/connection_balance/all_test.go new file mode 100644 index 0000000..4804da0 --- /dev/null +++ b/test/connection_balance/all_test.go @@ -0,0 +1,58 @@ +/*- + * #%L + * OBKV Table Client Framework + * %% + * Copyright (C) 2021 OceanBase + * %% + * OBKV Table Client Framework is licensed under Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, + * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, + * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + * #L% + */ + +package connection_balance + +import ( + "fmt" + "github.com/oceanbase/obkv-table-client-go/client" + "github.com/oceanbase/obkv-table-client-go/test" + "os" + "testing" + "time" +) + +var cli client.Client + +const ( + testConnectionBalanceTableName = "test_connection_balance" + // NOTE: cannot create table directly in obkv cluster + testConnectionBalanceCreateStatement = "create table if not exists `test_connection_balance`(`c1` varchar(1024) primary key,`c2` int);" + concurrencyNum = 300 + // NOTE: make sure test timeout is greatter than testDuration, e.g, go test -timeout 10m + testDuration = time.Duration(8) * time.Minute + maxConnectionAge = time.Duration(10) * time.Second + connectionPoolSize = 150 + enableSLBLoadBalance = true +) + +func setup() { + cli = test.CreateConnectionBalanceClient(maxConnectionAge, enableSLBLoadBalance, connectionPoolSize) + fmt.Println("connection balance setup") +} + +func teardown() { + cli.Close() + fmt.Println("connection balance teardown") +} + +func TestMain(m *testing.M) { + setup() + code := m.Run() + teardown() + os.Exit(code) +} diff --git a/test/connection_balance/connection_balance_test.go b/test/connection_balance/connection_balance_test.go new file mode 100644 index 0000000..66f1887 --- /dev/null +++ b/test/connection_balance/connection_balance_test.go @@ -0,0 +1,64 @@ +/*- + * #%L + * OBKV Table Client Framework + * %% + * Copyright (C) 2021 OceanBase + * %% + * OBKV Table Client Framework is licensed under Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, + * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, + * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + * #L% + */ + +package connection_balance + +import ( + "context" + "fmt" + "github.com/oceanbase/obkv-table-client-go/log" + "github.com/oceanbase/obkv-table-client-go/table" + "github.com/stretchr/testify/assert" + "sync" + "testing" + "time" +) + +func run(i int, done chan bool, wg *sync.WaitGroup, t *testing.T) { + defer wg.Done() + executeNum := 0 + for { + select { + case <-done: + log.Info(fmt.Sprintf("Finish %d worker, executeNum: %d", i, executeNum)) + return + default: + rowKey := []*table.Column{table.NewColumn("c1", fmt.Sprintf("key%d", i))} + mutateColumns := []*table.Column{table.NewColumn("c2", int32(1))} + ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) // 10s + affectRows, err := cli.InsertOrUpdate(ctx, testConnectionBalanceTableName, rowKey, mutateColumns) + assert.Equal(t, nil, err) + assert.EqualValues(t, 1, affectRows) + executeNum++ + } + } +} + +func TestMaxConnectionAge(t *testing.T) { + println("Test begin") + done := make(chan bool) + var wg sync.WaitGroup + for i := 0; i < concurrencyNum; i++ { + wg.Add(1) + go run(i, done, &wg, t) + } + time.Sleep(testDuration) + close(done) + println("Wait All Coroutine finish") + wg.Wait() + println("Test Finished") +} diff --git a/test/util.go b/test/util.go index 9290017..8661f5d 100644 --- a/test/util.go +++ b/test/util.go @@ -20,11 +20,10 @@ package test import ( "database/sql" "fmt" - _ "github.com/go-sql-driver/mysql" - "github.com/oceanbase/obkv-table-client-go/client" "github.com/oceanbase/obkv-table-client-go/config" + "time" ) const ( @@ -79,6 +78,20 @@ func CreateMoveClient() client.Client { return cli } +func CreateConnectionBalanceClient(maxConnectionAge time.Duration, enableSLBLoadBalance bool, connectionPoolSize int) client.Client { + cfg := config.NewDefaultClientConfig() + cfg.MaxConnectionAge = maxConnectionAge + cfg.ConnPoolMaxConnSize = connectionPoolSize + cfg.EnableSLBLoadBalance = enableSLBLoadBalance + + cli, err := client.NewOdpClient(odpFullUserName, odpPassWord, odpIP, odpRpcPort, database, cfg) + if err != nil { + panic(err.Error()) + } + println("connection Balance Client Created") + return cli +} + var GlobalDB *sql.DB func CreateDB() {