Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

table session and table session pool #111

Merged
merged 13 commits into from
Dec 9, 2024
3 changes: 2 additions & 1 deletion client/bitmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ var UnmarkBitUtil = []byte{
}

func NewBitMap(size int) *BitMap {
// Need to maintain consistency with the calculation method on the IoTDB side.
bitMap := &BitMap{
size: size,
bits: make([]byte, (size+7)/8),
bits: make([]byte, size/8+1),
}
return bitMap
}
Expand Down
4 changes: 4 additions & 0 deletions client/rpcdataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ func (s *IoTDBRpcDataSet) getColumnType(columnName string) TSDataType {
return s.columnTypeDeduplicatedList[s.getColumnIndex(columnName)]
}

func (s *IoTDBRpcDataSet) isNullWithColumnName(columnName string) bool {
return s.isNull(int(s.getColumnIndex(columnName)), s.rowsIndex-1)
}

func (s *IoTDBRpcDataSet) isNull(columnIndex int, rowIndex int) bool {
if s.closed {
return true
Expand Down
147 changes: 107 additions & 40 deletions client/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ const (
DefaultTimeZone = "Asia/Shanghai"
DefaultFetchSize = 1024
DefaultConnectRetryMax = 3
TreeSqlDialect = "tree"
TableSqlDialect = "table"
)

type Version string

const (
V_0_12 = Version("V_0_12")
V_0_13 = Version("V_0_13")
V_1_0 = Version("V_1_0")
DEFAULT_VERSION = V_1_0
)

var errLength = errors.New("deviceIds, times, measurementsList and valuesList's size should be equal")
Expand All @@ -54,6 +65,9 @@ type Config struct {
FetchSize int32
TimeZone string
ConnectRetryMax int
sqlDialect string
Version Version
Database string
}

type Session struct {
Expand All @@ -62,6 +76,7 @@ type Session struct {
sessionId int64
trans thrift.TTransport
requestStatementId int64
protocolFactory thrift.TProtocolFactory
}

type endPoint struct {
Expand All @@ -83,7 +98,6 @@ func (s *Session) Open(enableRPCCompression bool, connectionTimeoutInMs int) err
s.config.ConnectRetryMax = DefaultConnectRetryMax
}

var protocolFactory thrift.TProtocolFactory
var err error

// in thrift 0.14.1, this func returns two values; in thrift 0.15.0, it returns one.
Expand All @@ -99,16 +113,22 @@ func (s *Session) Open(enableRPCCompression bool, connectionTimeoutInMs int) err
return err
}
}
if enableRPCCompression {
protocolFactory = thrift.NewTCompactProtocolFactory()
} else {
protocolFactory = thrift.NewTBinaryProtocolFactoryDefault()
}
iprot := protocolFactory.GetProtocol(s.trans)
oprot := protocolFactory.GetProtocol(s.trans)
s.protocolFactory = getProtocolFactory(enableRPCCompression)
iprot := s.protocolFactory.GetProtocol(s.trans)
oprot := s.protocolFactory.GetProtocol(s.trans)
s.client = rpc.NewIClientRPCServiceClient(thrift.NewTStandardClient(iprot, oprot))
req := rpc.TSOpenSessionReq{ClientProtocol: rpc.TSProtocolVersion_IOTDB_SERVICE_PROTOCOL_V3, ZoneId: s.config.TimeZone, Username: s.config.UserName,
Password: &s.config.Password}
req.Configuration = make(map[string]string)
req.Configuration["sql_dialect"] = s.config.sqlDialect
if s.config.Version == "" {
req.Configuration["version"] = string(DEFAULT_VERSION)
} else {
req.Configuration["version"] = string(s.config.Version)
}
if s.config.Database != "" {
req.Configuration["db"] = s.config.Database
}
resp, err := s.client.OpenSession(context.Background(), &req)
if err != nil {
return err
Expand All @@ -125,14 +145,8 @@ type ClusterConfig struct {
FetchSize int32
TimeZone string
ConnectRetryMax int
}

type ClusterSession struct {
config *ClusterConfig
client *rpc.IClientRPCServiceClient
sessionId int64
trans thrift.TTransport
requestStatementId int64
sqlDialect string
Database string
}

func (s *Session) OpenCluster(enableRPCCompression bool) error {
Expand All @@ -147,19 +161,24 @@ func (s *Session) OpenCluster(enableRPCCompression bool) error {
s.config.ConnectRetryMax = DefaultConnectRetryMax
}

var protocolFactory thrift.TProtocolFactory
var err error

if enableRPCCompression {
protocolFactory = thrift.NewTCompactProtocolFactory()
} else {
protocolFactory = thrift.NewTBinaryProtocolFactoryDefault()
}
iprot := protocolFactory.GetProtocol(s.trans)
oprot := protocolFactory.GetProtocol(s.trans)
s.protocolFactory = getProtocolFactory(enableRPCCompression)
iprot := s.protocolFactory.GetProtocol(s.trans)
oprot := s.protocolFactory.GetProtocol(s.trans)
s.client = rpc.NewIClientRPCServiceClient(thrift.NewTStandardClient(iprot, oprot))
req := rpc.TSOpenSessionReq{ClientProtocol: rpc.TSProtocolVersion_IOTDB_SERVICE_PROTOCOL_V3, ZoneId: s.config.TimeZone, Username: s.config.UserName,
Password: &s.config.Password}
req.Configuration = make(map[string]string)
req.Configuration["sql_dialect"] = s.config.sqlDialect
if s.config.Version == "" {
req.Configuration["version"] = string(DEFAULT_VERSION)
} else {
req.Configuration["version"] = string(s.config.Version)
}
if s.config.Database != "" {
req.Configuration["db"] = s.config.Database
}

resp, err := s.client.OpenSession(context.Background(), &req)
if err != nil {
Expand All @@ -170,14 +189,22 @@ func (s *Session) OpenCluster(enableRPCCompression bool) error {
return err
}

func (s *Session) Close() (r *common.TSStatus, err error) {
func getProtocolFactory(enableRPCCompression bool) thrift.TProtocolFactory {
if enableRPCCompression {
return thrift.NewTCompactProtocolFactoryConf(&thrift.TConfiguration{})
} else {
return thrift.NewTBinaryProtocolFactoryConf(&thrift.TConfiguration{})
}
}

func (s *Session) Close() error {
req := rpc.NewTSCloseSessionReq()
req.SessionId = s.sessionId
_, err = s.client.CloseSession(context.Background(), req)
_, err := s.client.CloseSession(context.Background(), req)
if err != nil {
return nil, err
return err
}
return nil, s.trans.Close()
return s.trans.Close()
}

/*
Expand Down Expand Up @@ -453,10 +480,17 @@ func (s *Session) ExecuteNonQueryStatement(sql string) (r *common.TSStatus, err
resp, err = s.client.ExecuteStatement(context.Background(), &request)
}
}
if resp.IsSetDatabase() {
s.changeDatabase(*resp.Database)
}

return resp.Status, err
}

func (s *Session) changeDatabase(database string) {
s.config.Database = database
}

func (s *Session) ExecuteQueryStatement(sql string, timeoutMs *int64) (*SessionDataSet, error) {
request := rpc.TSExecuteStatementReq{SessionId: s.sessionId, Statement: sql, StatementId: s.requestStatementId,
FetchSize: &s.config.FetchSize, Timeout: timeoutMs}
Expand Down Expand Up @@ -613,7 +647,7 @@ func (d *deviceData) Swap(i, j int) {
// InsertRecordsOfOneDevice Insert multiple rows, which can reduce the overhead of network. This method is just like jdbc
// executeBatch, we pack some insert request in batch and send them to server. If you want improve
// your performance, please see insertTablet method
// Each row is independent, which could have different deviceId, time, number of measurements
// Each row is independent, which could have different insertTargetName, time, number of measurements
func (s *Session) InsertRecordsOfOneDevice(deviceId string, timestamps []int64, measurementsSlice [][]string, dataTypesSlice [][]TSDataType, valuesSlice [][]interface{}, sorted bool) (r *common.TSStatus, err error) {
length := len(timestamps)
if len(measurementsSlice) != length || len(dataTypesSlice) != length || len(valuesSlice) != length {
Expand Down Expand Up @@ -873,7 +907,7 @@ func (s *Session) genInsertTabletsReq(tablets []*Tablet, isAligned bool) (*rpc.T
sizeList = make([]int32, length)
)
for index, tablet := range tablets {
deviceIds[index] = tablet.deviceId
deviceIds[index] = tablet.insertTargetName
measurementsList[index] = tablet.GetMeasurements()

values, err := tablet.getValuesBytes()
Expand Down Expand Up @@ -1009,13 +1043,35 @@ func valuesToBytes(dataTypes []TSDataType, values []interface{}) ([]byte, error)
return buff.Bytes(), nil
}

func (s *Session) insertRelationalTablet(tablet *Tablet) (r *common.TSStatus, err error) {
if tablet.Len() == 0 {
return nil, nil
}
jt2594838 marked this conversation as resolved.
Show resolved Hide resolved
request, err := s.genTSInsertTabletReq(tablet, true, true)
if err != nil {
return nil, err
}
request.ColumnCategories = tablet.getColumnCategories()

r, err = s.client.InsertTablet(context.Background(), request)

if err != nil && r == nil {
if s.reconnect() {
request.SessionId = s.sessionId
r, err = s.client.InsertTablet(context.Background(), request)
}
}

return r, err
}

func (s *Session) InsertTablet(tablet *Tablet, sorted bool) (r *common.TSStatus, err error) {
if !sorted {
if err := tablet.Sort(); err != nil {
return nil, err
}
}
request, err := s.genTSInsertTabletReq(tablet, false)
request, err := s.genTSInsertTabletReq(tablet, false, false)
if err != nil {
return nil, err
}
Expand All @@ -1038,7 +1094,7 @@ func (s *Session) InsertAlignedTablet(tablet *Tablet, sorted bool) (r *common.TS
return nil, err
}
}
request, err := s.genTSInsertTabletReq(tablet, true)
request, err := s.genTSInsertTabletReq(tablet, true, false)
if err != nil {
return nil, err
}
Expand All @@ -1055,17 +1111,18 @@ func (s *Session) InsertAlignedTablet(tablet *Tablet, sorted bool) (r *common.TS
return r, err
}

func (s *Session) genTSInsertTabletReq(tablet *Tablet, isAligned bool) (*rpc.TSInsertTabletReq, error) {
func (s *Session) genTSInsertTabletReq(tablet *Tablet, isAligned bool, writeToTable bool) (*rpc.TSInsertTabletReq, error) {
if values, err := tablet.getValuesBytes(); err == nil {
request := &rpc.TSInsertTabletReq{
SessionId: s.sessionId,
PrefixPath: tablet.deviceId,
PrefixPath: tablet.insertTargetName,
Measurements: tablet.GetMeasurements(),
Values: values,
Timestamps: tablet.GetTimestampBytes(),
Types: tablet.getDataTypes(),
Size: int32(tablet.RowSize),
IsAligned: &isAligned,
WriteToTable: &writeToTable,
}
return request, nil
} else {
Expand All @@ -1078,6 +1135,11 @@ func (s *Session) GetSessionId() int64 {
}

func NewSession(config *Config) Session {
config.sqlDialect = TreeSqlDialect
return newSessionWithSpecifiedSqlDialect(config)
}

func newSessionWithSpecifiedSqlDialect(config *Config) Session {
endPoint := endPoint{}
endPoint.Host = config.Host
endPoint.Port = config.Port
Expand All @@ -1086,6 +1148,11 @@ func NewSession(config *Config) Session {
}

func NewClusterSession(clusterConfig *ClusterConfig) Session {
clusterConfig.sqlDialect = TreeSqlDialect
return newClusterSessionWithSqlDialect(clusterConfig)
}

func newClusterSessionWithSqlDialect(clusterConfig *ClusterConfig) Session {
session := Session{}
node := endPoint{}
for i := 0; i < len(clusterConfig.NodeUrls); i++ {
Expand All @@ -1107,7 +1174,7 @@ func NewClusterSession(clusterConfig *ClusterConfig) Session {
log.Println(err)
} else {
session.config = getConfig(e.Value.(endPoint).Host, e.Value.(endPoint).Port,
clusterConfig.UserName, clusterConfig.Password, clusterConfig.FetchSize, clusterConfig.TimeZone, clusterConfig.ConnectRetryMax)
clusterConfig.UserName, clusterConfig.Password, clusterConfig.FetchSize, clusterConfig.TimeZone, clusterConfig.ConnectRetryMax, clusterConfig.Database, clusterConfig.sqlDialect)
break
}
}
Expand Down Expand Up @@ -1148,10 +1215,8 @@ func (s *Session) initClusterConn(node endPoint) error {
s.config.ConnectRetryMax = DefaultConnectRetryMax
}

var protocolFactory thrift.TProtocolFactory
protocolFactory = thrift.NewTBinaryProtocolFactoryDefault()
iprot := protocolFactory.GetProtocol(s.trans)
oprot := protocolFactory.GetProtocol(s.trans)
iprot := s.protocolFactory.GetProtocol(s.trans)
oprot := s.protocolFactory.GetProtocol(s.trans)
s.client = rpc.NewIClientRPCServiceClient(thrift.NewTStandardClient(iprot, oprot))
req := rpc.TSOpenSessionReq{ClientProtocol: rpc.TSProtocolVersion_IOTDB_SERVICE_PROTOCOL_V3, ZoneId: s.config.TimeZone, Username: s.config.UserName,
Password: &s.config.Password}
Expand All @@ -1166,7 +1231,7 @@ func (s *Session) initClusterConn(node endPoint) error {

}

func getConfig(host string, port string, userName string, passWord string, fetchSize int32, timeZone string, connectRetryMax int) *Config {
func getConfig(host string, port string, userName string, passWord string, fetchSize int32, timeZone string, connectRetryMax int, database string, sqlDialect string) *Config {
return &Config{
Host: host,
Port: port,
Expand All @@ -1175,6 +1240,8 @@ func getConfig(host string, port string, userName string, passWord string, fetch
FetchSize: fetchSize,
TimeZone: timeZone,
ConnectRetryMax: connectRetryMax,
sqlDialect: sqlDialect,
Database: database,
}
}

Expand Down
4 changes: 4 additions & 0 deletions client/sessiondataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ func (s *SessionDataSet) GetText(columnName string) string {
return s.ioTDBRpcDataSet.getText(columnName)
}

func (s *SessionDataSet) IsNull(columnName string) bool {
return s.ioTDBRpcDataSet.isNullWithColumnName(columnName)
}

func (s *SessionDataSet) GetBool(columnName string) bool {
return s.ioTDBRpcDataSet.getBool(columnName)
}
Expand Down
Loading
Loading