Skip to content

Commit

Permalink
table session and table session pool (#111)
Browse files Browse the repository at this point in the history
* table session and table session pool

* fix cluster config

* fix enableRPCCompression in reconnect

* add ut

* add session pool tests

* fix atomic usage to 1.13 version

* update common.go

* add license

* fix tablet test

* add doc

* add comment

* add comment

* return success code when tablet is empty
  • Loading branch information
shuwenwei authored Dec 9, 2024
1 parent 223b9e0 commit 8fc3fd3
Show file tree
Hide file tree
Showing 14 changed files with 3,082 additions and 859 deletions.
3 changes: 2 additions & 1 deletion client/bitmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ var UnmarkBitUtil = []byte{
}

func NewBitMap(size int) *BitMap {
// Need to maintain consistency with the calculation method on the IoTDB side.
bitMap := &BitMap{
size: size,
bits: make([]byte, (size+7)/8),
bits: make([]byte, size/8+1),
}
return bitMap
}
Expand Down
4 changes: 4 additions & 0 deletions client/rpcdataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ func (s *IoTDBRpcDataSet) getColumnType(columnName string) TSDataType {
return s.columnTypeDeduplicatedList[s.getColumnIndex(columnName)]
}

func (s *IoTDBRpcDataSet) isNullWithColumnName(columnName string) bool {
return s.isNull(int(s.getColumnIndex(columnName)), s.rowsIndex-1)
}

func (s *IoTDBRpcDataSet) isNull(columnIndex int, rowIndex int) bool {
if s.closed {
return true
Expand Down
147 changes: 107 additions & 40 deletions client/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ const (
DefaultTimeZone = "Asia/Shanghai"
DefaultFetchSize = 1024
DefaultConnectRetryMax = 3
TreeSqlDialect = "tree"
TableSqlDialect = "table"
)

type Version string

const (
V_0_12 = Version("V_0_12")
V_0_13 = Version("V_0_13")
V_1_0 = Version("V_1_0")
DEFAULT_VERSION = V_1_0
)

var errLength = errors.New("deviceIds, times, measurementsList and valuesList's size should be equal")
Expand All @@ -54,6 +65,9 @@ type Config struct {
FetchSize int32
TimeZone string
ConnectRetryMax int
sqlDialect string
Version Version
Database string
}

type Session struct {
Expand All @@ -62,6 +76,7 @@ type Session struct {
sessionId int64
trans thrift.TTransport
requestStatementId int64
protocolFactory thrift.TProtocolFactory
}

type endPoint struct {
Expand All @@ -83,7 +98,6 @@ func (s *Session) Open(enableRPCCompression bool, connectionTimeoutInMs int) err
s.config.ConnectRetryMax = DefaultConnectRetryMax
}

var protocolFactory thrift.TProtocolFactory
var err error

// in thrift 0.14.1, this func returns two values; in thrift 0.15.0, it returns one.
Expand All @@ -99,16 +113,22 @@ func (s *Session) Open(enableRPCCompression bool, connectionTimeoutInMs int) err
return err
}
}
if enableRPCCompression {
protocolFactory = thrift.NewTCompactProtocolFactory()
} else {
protocolFactory = thrift.NewTBinaryProtocolFactoryDefault()
}
iprot := protocolFactory.GetProtocol(s.trans)
oprot := protocolFactory.GetProtocol(s.trans)
s.protocolFactory = getProtocolFactory(enableRPCCompression)
iprot := s.protocolFactory.GetProtocol(s.trans)
oprot := s.protocolFactory.GetProtocol(s.trans)
s.client = rpc.NewIClientRPCServiceClient(thrift.NewTStandardClient(iprot, oprot))
req := rpc.TSOpenSessionReq{ClientProtocol: rpc.TSProtocolVersion_IOTDB_SERVICE_PROTOCOL_V3, ZoneId: s.config.TimeZone, Username: s.config.UserName,
Password: &s.config.Password}
req.Configuration = make(map[string]string)
req.Configuration["sql_dialect"] = s.config.sqlDialect
if s.config.Version == "" {
req.Configuration["version"] = string(DEFAULT_VERSION)
} else {
req.Configuration["version"] = string(s.config.Version)
}
if s.config.Database != "" {
req.Configuration["db"] = s.config.Database
}
resp, err := s.client.OpenSession(context.Background(), &req)
if err != nil {
return err
Expand All @@ -125,14 +145,8 @@ type ClusterConfig struct {
FetchSize int32
TimeZone string
ConnectRetryMax int
}

type ClusterSession struct {
config *ClusterConfig
client *rpc.IClientRPCServiceClient
sessionId int64
trans thrift.TTransport
requestStatementId int64
sqlDialect string
Database string
}

func (s *Session) OpenCluster(enableRPCCompression bool) error {
Expand All @@ -147,19 +161,24 @@ func (s *Session) OpenCluster(enableRPCCompression bool) error {
s.config.ConnectRetryMax = DefaultConnectRetryMax
}

var protocolFactory thrift.TProtocolFactory
var err error

if enableRPCCompression {
protocolFactory = thrift.NewTCompactProtocolFactory()
} else {
protocolFactory = thrift.NewTBinaryProtocolFactoryDefault()
}
iprot := protocolFactory.GetProtocol(s.trans)
oprot := protocolFactory.GetProtocol(s.trans)
s.protocolFactory = getProtocolFactory(enableRPCCompression)
iprot := s.protocolFactory.GetProtocol(s.trans)
oprot := s.protocolFactory.GetProtocol(s.trans)
s.client = rpc.NewIClientRPCServiceClient(thrift.NewTStandardClient(iprot, oprot))
req := rpc.TSOpenSessionReq{ClientProtocol: rpc.TSProtocolVersion_IOTDB_SERVICE_PROTOCOL_V3, ZoneId: s.config.TimeZone, Username: s.config.UserName,
Password: &s.config.Password}
req.Configuration = make(map[string]string)
req.Configuration["sql_dialect"] = s.config.sqlDialect
if s.config.Version == "" {
req.Configuration["version"] = string(DEFAULT_VERSION)
} else {
req.Configuration["version"] = string(s.config.Version)
}
if s.config.Database != "" {
req.Configuration["db"] = s.config.Database
}

resp, err := s.client.OpenSession(context.Background(), &req)
if err != nil {
Expand All @@ -170,14 +189,22 @@ func (s *Session) OpenCluster(enableRPCCompression bool) error {
return err
}

func (s *Session) Close() (r *common.TSStatus, err error) {
func getProtocolFactory(enableRPCCompression bool) thrift.TProtocolFactory {
if enableRPCCompression {
return thrift.NewTCompactProtocolFactoryConf(&thrift.TConfiguration{})
} else {
return thrift.NewTBinaryProtocolFactoryConf(&thrift.TConfiguration{})
}
}

func (s *Session) Close() error {
req := rpc.NewTSCloseSessionReq()
req.SessionId = s.sessionId
_, err = s.client.CloseSession(context.Background(), req)
_, err := s.client.CloseSession(context.Background(), req)
if err != nil {
return nil, err
return err
}
return nil, s.trans.Close()
return s.trans.Close()
}

/*
Expand Down Expand Up @@ -453,10 +480,17 @@ func (s *Session) ExecuteNonQueryStatement(sql string) (r *common.TSStatus, err
resp, err = s.client.ExecuteStatement(context.Background(), &request)
}
}
if resp.IsSetDatabase() {
s.changeDatabase(*resp.Database)
}

return resp.Status, err
}

func (s *Session) changeDatabase(database string) {
s.config.Database = database
}

func (s *Session) ExecuteQueryStatement(sql string, timeoutMs *int64) (*SessionDataSet, error) {
request := rpc.TSExecuteStatementReq{SessionId: s.sessionId, Statement: sql, StatementId: s.requestStatementId,
FetchSize: &s.config.FetchSize, Timeout: timeoutMs}
Expand Down Expand Up @@ -613,7 +647,7 @@ func (d *deviceData) Swap(i, j int) {
// InsertRecordsOfOneDevice Insert multiple rows, which can reduce the overhead of network. This method is just like jdbc
// executeBatch, we pack some insert request in batch and send them to server. If you want improve
// your performance, please see insertTablet method
// Each row is independent, which could have different deviceId, time, number of measurements
// Each row is independent, which could have different insertTargetName, time, number of measurements
func (s *Session) InsertRecordsOfOneDevice(deviceId string, timestamps []int64, measurementsSlice [][]string, dataTypesSlice [][]TSDataType, valuesSlice [][]interface{}, sorted bool) (r *common.TSStatus, err error) {
length := len(timestamps)
if len(measurementsSlice) != length || len(dataTypesSlice) != length || len(valuesSlice) != length {
Expand Down Expand Up @@ -873,7 +907,7 @@ func (s *Session) genInsertTabletsReq(tablets []*Tablet, isAligned bool) (*rpc.T
sizeList = make([]int32, length)
)
for index, tablet := range tablets {
deviceIds[index] = tablet.deviceId
deviceIds[index] = tablet.insertTargetName
measurementsList[index] = tablet.GetMeasurements()

values, err := tablet.getValuesBytes()
Expand Down Expand Up @@ -1009,13 +1043,35 @@ func valuesToBytes(dataTypes []TSDataType, values []interface{}) ([]byte, error)
return buff.Bytes(), nil
}

func (s *Session) insertRelationalTablet(tablet *Tablet) (r *common.TSStatus, err error) {
if tablet.Len() == 0 {
return &common.TSStatus{Code: SuccessStatus}, nil
}
request, err := s.genTSInsertTabletReq(tablet, true, true)
if err != nil {
return nil, err
}
request.ColumnCategories = tablet.getColumnCategories()

r, err = s.client.InsertTablet(context.Background(), request)

if err != nil && r == nil {
if s.reconnect() {
request.SessionId = s.sessionId
r, err = s.client.InsertTablet(context.Background(), request)
}
}

return r, err
}

func (s *Session) InsertTablet(tablet *Tablet, sorted bool) (r *common.TSStatus, err error) {
if !sorted {
if err := tablet.Sort(); err != nil {
return nil, err
}
}
request, err := s.genTSInsertTabletReq(tablet, false)
request, err := s.genTSInsertTabletReq(tablet, false, false)
if err != nil {
return nil, err
}
Expand All @@ -1038,7 +1094,7 @@ func (s *Session) InsertAlignedTablet(tablet *Tablet, sorted bool) (r *common.TS
return nil, err
}
}
request, err := s.genTSInsertTabletReq(tablet, true)
request, err := s.genTSInsertTabletReq(tablet, true, false)
if err != nil {
return nil, err
}
Expand All @@ -1055,17 +1111,18 @@ func (s *Session) InsertAlignedTablet(tablet *Tablet, sorted bool) (r *common.TS
return r, err
}

func (s *Session) genTSInsertTabletReq(tablet *Tablet, isAligned bool) (*rpc.TSInsertTabletReq, error) {
func (s *Session) genTSInsertTabletReq(tablet *Tablet, isAligned bool, writeToTable bool) (*rpc.TSInsertTabletReq, error) {
if values, err := tablet.getValuesBytes(); err == nil {
request := &rpc.TSInsertTabletReq{
SessionId: s.sessionId,
PrefixPath: tablet.deviceId,
PrefixPath: tablet.insertTargetName,
Measurements: tablet.GetMeasurements(),
Values: values,
Timestamps: tablet.GetTimestampBytes(),
Types: tablet.getDataTypes(),
Size: int32(tablet.RowSize),
IsAligned: &isAligned,
WriteToTable: &writeToTable,
}
return request, nil
} else {
Expand All @@ -1078,6 +1135,11 @@ func (s *Session) GetSessionId() int64 {
}

func NewSession(config *Config) Session {
config.sqlDialect = TreeSqlDialect
return newSessionWithSpecifiedSqlDialect(config)
}

func newSessionWithSpecifiedSqlDialect(config *Config) Session {
endPoint := endPoint{}
endPoint.Host = config.Host
endPoint.Port = config.Port
Expand All @@ -1086,6 +1148,11 @@ func NewSession(config *Config) Session {
}

func NewClusterSession(clusterConfig *ClusterConfig) Session {
clusterConfig.sqlDialect = TreeSqlDialect
return newClusterSessionWithSqlDialect(clusterConfig)
}

func newClusterSessionWithSqlDialect(clusterConfig *ClusterConfig) Session {
session := Session{}
node := endPoint{}
for i := 0; i < len(clusterConfig.NodeUrls); i++ {
Expand All @@ -1107,7 +1174,7 @@ func NewClusterSession(clusterConfig *ClusterConfig) Session {
log.Println(err)
} else {
session.config = getConfig(e.Value.(endPoint).Host, e.Value.(endPoint).Port,
clusterConfig.UserName, clusterConfig.Password, clusterConfig.FetchSize, clusterConfig.TimeZone, clusterConfig.ConnectRetryMax)
clusterConfig.UserName, clusterConfig.Password, clusterConfig.FetchSize, clusterConfig.TimeZone, clusterConfig.ConnectRetryMax, clusterConfig.Database, clusterConfig.sqlDialect)
break
}
}
Expand Down Expand Up @@ -1148,10 +1215,8 @@ func (s *Session) initClusterConn(node endPoint) error {
s.config.ConnectRetryMax = DefaultConnectRetryMax
}

var protocolFactory thrift.TProtocolFactory
protocolFactory = thrift.NewTBinaryProtocolFactoryDefault()
iprot := protocolFactory.GetProtocol(s.trans)
oprot := protocolFactory.GetProtocol(s.trans)
iprot := s.protocolFactory.GetProtocol(s.trans)
oprot := s.protocolFactory.GetProtocol(s.trans)
s.client = rpc.NewIClientRPCServiceClient(thrift.NewTStandardClient(iprot, oprot))
req := rpc.TSOpenSessionReq{ClientProtocol: rpc.TSProtocolVersion_IOTDB_SERVICE_PROTOCOL_V3, ZoneId: s.config.TimeZone, Username: s.config.UserName,
Password: &s.config.Password}
Expand All @@ -1166,7 +1231,7 @@ func (s *Session) initClusterConn(node endPoint) error {

}

func getConfig(host string, port string, userName string, passWord string, fetchSize int32, timeZone string, connectRetryMax int) *Config {
func getConfig(host string, port string, userName string, passWord string, fetchSize int32, timeZone string, connectRetryMax int, database string, sqlDialect string) *Config {
return &Config{
Host: host,
Port: port,
Expand All @@ -1175,6 +1240,8 @@ func getConfig(host string, port string, userName string, passWord string, fetch
FetchSize: fetchSize,
TimeZone: timeZone,
ConnectRetryMax: connectRetryMax,
sqlDialect: sqlDialect,
Database: database,
}
}

Expand Down
4 changes: 4 additions & 0 deletions client/sessiondataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ func (s *SessionDataSet) GetText(columnName string) string {
return s.ioTDBRpcDataSet.getText(columnName)
}

func (s *SessionDataSet) IsNull(columnName string) bool {
return s.ioTDBRpcDataSet.isNullWithColumnName(columnName)
}

func (s *SessionDataSet) GetBool(columnName string) bool {
return s.ioTDBRpcDataSet.getBool(columnName)
}
Expand Down
Loading

0 comments on commit 8fc3fd3

Please sign in to comment.