From e269bad2d6591dcfd624f7ecbf64393f92b03461 Mon Sep 17 00:00:00 2001 From: shuwenwei <55970239+shuwenwei@users.noreply.github.com> Date: Wed, 11 Dec 2024 18:34:23 +0800 Subject: [PATCH] Fix bugs for rc/1.3.3 (#113) * Return error when constructing cluster session failed * fix bugs --- client/bitmap.go | 3 ++- client/rpcdataset.go | 4 +++ client/session.go | 52 ++++++++++++++++++-------------------- client/sessiondataset.go | 4 +++ client/sessionpool.go | 7 +++-- client/tablet.go | 2 ++ example/session_example.go | 7 +++-- test/e2e/e2e_test.go | 6 +++-- 8 files changed, 51 insertions(+), 34 deletions(-) diff --git a/client/bitmap.go b/client/bitmap.go index b82d121..8054eba 100644 --- a/client/bitmap.go +++ b/client/bitmap.go @@ -37,9 +37,10 @@ var UnmarkBitUtil = []byte{ } func NewBitMap(size int) *BitMap { + // Need to maintain consistency with the calculation method on the IoTDB side. bitMap := &BitMap{ size: size, - bits: make([]byte, (size+7)/8), + bits: make([]byte, size/8+1), } return bitMap } diff --git a/client/rpcdataset.go b/client/rpcdataset.go index 8477950..11ec81a 100644 --- a/client/rpcdataset.go +++ b/client/rpcdataset.go @@ -89,6 +89,10 @@ func (s *IoTDBRpcDataSet) getColumnType(columnName string) TSDataType { return s.columnTypeDeduplicatedList[s.getColumnIndex(columnName)] } +func (s *IoTDBRpcDataSet) isNullWithColumnName(columnName string) bool { + return s.isNull(int(s.getColumnIndex(columnName)), s.rowsIndex-1) +} + func (s *IoTDBRpcDataSet) isNull(columnIndex int, rowIndex int) bool { if s.closed { return true diff --git a/client/session.go b/client/session.go index 2859d8c..e5f2184 100644 --- a/client/session.go +++ b/client/session.go @@ -62,6 +62,7 @@ type Session struct { sessionId int64 trans thrift.TTransport requestStatementId int64 + protocolFactory thrift.TProtocolFactory } type endPoint struct { @@ -83,7 +84,6 @@ func (s *Session) Open(enableRPCCompression bool, connectionTimeoutInMs int) err s.config.ConnectRetryMax = DefaultConnectRetryMax } - var protocolFactory thrift.TProtocolFactory var err error // in thrift 0.14.1, this func returns two values; in thrift 0.15.0, it returns one. @@ -99,13 +99,10 @@ func (s *Session) Open(enableRPCCompression bool, connectionTimeoutInMs int) err return err } } - if enableRPCCompression { - protocolFactory = thrift.NewTCompactProtocolFactory() - } else { - protocolFactory = thrift.NewTBinaryProtocolFactoryDefault() - } - iprot := protocolFactory.GetProtocol(s.trans) - oprot := protocolFactory.GetProtocol(s.trans) + s.protocolFactory = getProtocolFactory(enableRPCCompression) + iprot := s.protocolFactory.GetProtocol(s.trans) + oprot := s.protocolFactory.GetProtocol(s.trans) + s.client = rpc.NewIClientRPCServiceClient(thrift.NewTStandardClient(iprot, oprot)) req := rpc.TSOpenSessionReq{ClientProtocol: rpc.TSProtocolVersion_IOTDB_SERVICE_PROTOCOL_V3, ZoneId: s.config.TimeZone, Username: s.config.UserName, Password: &s.config.Password} @@ -147,16 +144,11 @@ func (s *Session) OpenCluster(enableRPCCompression bool) error { s.config.ConnectRetryMax = DefaultConnectRetryMax } - var protocolFactory thrift.TProtocolFactory var err error - if enableRPCCompression { - protocolFactory = thrift.NewTCompactProtocolFactory() - } else { - protocolFactory = thrift.NewTBinaryProtocolFactoryDefault() - } - iprot := protocolFactory.GetProtocol(s.trans) - oprot := protocolFactory.GetProtocol(s.trans) + s.protocolFactory = getProtocolFactory(enableRPCCompression) + iprot := s.protocolFactory.GetProtocol(s.trans) + oprot := s.protocolFactory.GetProtocol(s.trans) s.client = rpc.NewIClientRPCServiceClient(thrift.NewTStandardClient(iprot, oprot)) req := rpc.TSOpenSessionReq{ClientProtocol: rpc.TSProtocolVersion_IOTDB_SERVICE_PROTOCOL_V3, ZoneId: s.config.TimeZone, Username: s.config.UserName, Password: &s.config.Password} @@ -170,14 +162,22 @@ func (s *Session) OpenCluster(enableRPCCompression bool) error { return err } -func (s *Session) Close() (r *common.TSStatus, err error) { +func getProtocolFactory(enableRPCCompression bool) thrift.TProtocolFactory { + if enableRPCCompression { + return thrift.NewTCompactProtocolFactoryConf(&thrift.TConfiguration{}) + } else { + return thrift.NewTBinaryProtocolFactoryConf(&thrift.TConfiguration{}) + } +} + +func (s *Session) Close() error { req := rpc.NewTSCloseSessionReq() req.SessionId = s.sessionId - _, err = s.client.CloseSession(context.Background(), req) + _, err := s.client.CloseSession(context.Background(), req) if err != nil { - return nil, err + return err } - return nil, s.trans.Close() + return s.trans.Close() } /* @@ -1085,7 +1085,7 @@ func NewSession(config *Config) Session { return Session{config: config} } -func NewClusterSession(clusterConfig *ClusterConfig) Session { +func NewClusterSession(clusterConfig *ClusterConfig) (Session, error) { session := Session{} node := endPoint{} for i := 0; i < len(clusterConfig.NodeUrls); i++ { @@ -1113,9 +1113,9 @@ func NewClusterSession(clusterConfig *ClusterConfig) Session { } } if !session.trans.IsOpen() { - log.Fatal("No Server Can Connect") + return session, fmt.Errorf("no server can connect") } - return session + return session, nil } func (s *Session) initClusterConn(node endPoint) error { @@ -1148,10 +1148,8 @@ func (s *Session) initClusterConn(node endPoint) error { s.config.ConnectRetryMax = DefaultConnectRetryMax } - var protocolFactory thrift.TProtocolFactory - protocolFactory = thrift.NewTBinaryProtocolFactoryDefault() - iprot := protocolFactory.GetProtocol(s.trans) - oprot := protocolFactory.GetProtocol(s.trans) + iprot := s.protocolFactory.GetProtocol(s.trans) + oprot := s.protocolFactory.GetProtocol(s.trans) s.client = rpc.NewIClientRPCServiceClient(thrift.NewTStandardClient(iprot, oprot)) req := rpc.TSOpenSessionReq{ClientProtocol: rpc.TSProtocolVersion_IOTDB_SERVICE_PROTOCOL_V3, ZoneId: s.config.TimeZone, Username: s.config.UserName, Password: &s.config.Password} diff --git a/client/sessiondataset.go b/client/sessiondataset.go index 2bce854..b1fafd4 100644 --- a/client/sessiondataset.go +++ b/client/sessiondataset.go @@ -46,6 +46,10 @@ func (s *SessionDataSet) GetText(columnName string) string { return s.ioTDBRpcDataSet.getText(columnName) } +func (s *SessionDataSet) IsNull(columnName string) bool { + return s.ioTDBRpcDataSet.isNullWithColumnName(columnName) +} + func (s *SessionDataSet) GetBool(columnName string) bool { return s.ioTDBRpcDataSet.getBool(columnName) } diff --git a/client/sessionpool.go b/client/sessionpool.go index 156ce2a..48b322a 100644 --- a/client/sessionpool.go +++ b/client/sessionpool.go @@ -95,8 +95,11 @@ func (spool *SessionPool) GetSession() (session Session, err error) { func (spool *SessionPool) ConstructSession(config *PoolConfig) (session Session, err error) { if len(config.NodeUrls) > 0 { - session = NewClusterSession(getClusterSessionConfig(config)) - if err := session.OpenCluster(spool.enableCompression); err != nil { + session, err = NewClusterSession(getClusterSessionConfig(config)) + if err != nil { + return session, err + } + if err = session.OpenCluster(spool.enableCompression); err != nil { log.Print(err) return session, err } diff --git a/client/tablet.go b/client/tablet.go index e5d7daa..7b62f89 100644 --- a/client/tablet.go +++ b/client/tablet.go @@ -119,6 +119,7 @@ func (t *Tablet) SetValueAt(value interface{}, columnIndex, rowIndex int) error } // Mark the nil value position t.bitMaps[columnIndex].Mark(rowIndex) + return nil } switch t.measurementSchemas[columnIndex].DataType { @@ -296,6 +297,7 @@ func (t *Tablet) getValuesBytes() ([]byte, error) { columnHasNil := bitMap != nil && !bitMap.IsAllUnmarked() binary.Write(buff, binary.BigEndian, columnHasNil) if columnHasNil { + // Need to maintain consistency with the calculation method on the IoTDB side. binary.Write(buff, binary.BigEndian, bitMap.GetBits()[0:t.RowSize/8+1]) } } diff --git a/example/session_example.go b/example/session_example.go index 4777486..32026c1 100644 --- a/example/session_example.go +++ b/example/session_example.go @@ -152,8 +152,11 @@ func connectCluster() { UserName: "root", Password: "root", } - session = client.NewClusterSession(config) - if err := session.OpenCluster(false); err != nil { + session, err := client.NewClusterSession(config) + if err != nil { + log.Fatal(err) + } + if err = session.OpenCluster(false); err != nil { log.Fatal(err) } } diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index fd50245..e4b91ea 100644 --- a/test/e2e/e2e_test.go +++ b/test/e2e/e2e_test.go @@ -47,8 +47,10 @@ func (s *e2eTestSuite) SetupSuite() { UserName: "root", Password: "root", } - s.session = client.NewClusterSession(&clusterConfig) - err := s.session.Open(false, 0) + session, err := client.NewClusterSession(&clusterConfig) + s.Require().NoError(err) + s.session = session + err = s.session.Open(false, 0) s.Require().NoError(err) }