diff --git a/common/param/column.go b/common/param/column.go index de5dedc..1542f70 100644 --- a/common/param/column.go +++ b/common/param/column.go @@ -16,6 +16,10 @@ func NewColumnType(size int) *ColumnType { return &ColumnType{size: size, value: make([]*types.ColumnType, size)} } +func NewColumnTypeWithValue(value []*types.ColumnType) *ColumnType { + return &ColumnType{size: len(value), value: value, column: len(value)} +} + func (c *ColumnType) AddBool() *ColumnType { if c.column >= c.size { return c diff --git a/common/param/param.go b/common/param/param.go index a14854b..cce09d0 100644 --- a/common/param/param.go +++ b/common/param/param.go @@ -20,6 +20,15 @@ func NewParam(size int) *Param { } } +func NewParamsWithRowValue(value []driver.Value) []*Param { + params := make([]*Param, len(value)) + for i, d := range value { + params[i] = NewParam(1) + params[i].AddValue(d) + } + return params +} + func (p *Param) SetBool(offset int, value bool) { if offset >= p.size { return diff --git a/common/serializer/block.go b/common/serializer/block.go index 03a53f2..50d7a6e 100644 --- a/common/serializer/block.go +++ b/common/serializer/block.go @@ -37,7 +37,7 @@ func BMSetNull(c byte, n int) byte { return c + (1 << (7 - BitPos(n))) } -var ColumnNumerNotMatch = errors.New("number of columns does not match") +var ColumnNumberNotMatch = errors.New("number of columns does not match") var DataTypeWrong = errors.New("wrong data type") func SerializeRawBlock(params []*param.Param, colType *param.ColumnType) ([]byte, error) { @@ -48,7 +48,7 @@ func SerializeRawBlock(params []*param.Param, colType *param.ColumnType) ([]byte return nil, err } if len(colTypes) != columns { - return nil, ColumnNumerNotMatch + return nil, ColumnNumberNotMatch } var block []byte //version int32 diff --git a/taosSql/statement.go b/taosSql/statement.go index e103a0e..9513e37 100644 --- a/taosSql/statement.go +++ b/taosSql/statement.go @@ -138,11 +138,11 @@ func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { case reflect.Bool: v.Value = types.TaosBool(rv.Bool()) case reflect.Float32, reflect.Float64: - v.Value = types.TaosBool(rv.Float() == 1) + v.Value = types.TaosBool(rv.Float() > 0) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - v.Value = types.TaosBool(rv.Int() == 1) + v.Value = types.TaosBool(rv.Int() > 0) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - v.Value = types.TaosBool(rv.Uint() == 1) + v.Value = types.TaosBool(rv.Uint() > 0) case reflect.String: vv, err := strconv.ParseBool(rv.String()) if err != nil { diff --git a/taosWS/connection.go b/taosWS/connection.go index 71697e6..f97d6d0 100644 --- a/taosWS/connection.go +++ b/taosWS/connection.go @@ -15,6 +15,7 @@ import ( "github.com/gorilla/websocket" jsoniter "github.com/json-iterator/go" "github.com/taosdata/driver-go/v3/common" + stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" taosErrors "github.com/taosdata/driver-go/v3/errors" ) @@ -26,6 +27,14 @@ const ( WSFetch = "fetch" WSFetchBlock = "fetch_block" WSFreeResult = "free_result" + + STMTInit = "init" + STMTPrepare = "prepare" + STMTAddBatch = "add_batch" + STMTExec = "exec" + STMTClose = "close" + STMTGetColFields = "get_col_fields" + STMTUseResult = "use_result" ) var ( @@ -51,7 +60,7 @@ func newTaosConn(cfg *config) (*taosConn, error) { endpointUrl := &url.URL{ Scheme: cfg.net, Host: fmt.Sprintf("%s:%d", cfg.addr, cfg.port), - Path: "/rest/ws", + Path: "/ws", } if cfg.token != "" { endpointUrl.RawQuery = fmt.Sprintf("token=%s", cfg.token) @@ -99,9 +108,297 @@ func (tc *taosConn) Close() (err error) { } func (tc *taosConn) Prepare(query string) (driver.Stmt, error) { - return nil, &taosErrors.TaosError{Code: 0xffff, ErrStr: "websocket does not support stmt"} + stmtID, err := tc.stmtInit() + if err != nil { + return nil, err + } + isInsert, err := tc.stmtPrepare(stmtID, query) + if err != nil { + tc.stmtClose(stmtID) + return nil, err + } + stmt := &Stmt{ + conn: tc, + stmtID: stmtID, + isInsert: isInsert, + pSql: query, + } + return stmt, nil +} + +func (tc *taosConn) stmtInit() (uint64, error) { + reqID := tc.generateReqID() + req := &StmtInitReq{ + ReqID: reqID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return 0, err + } + action := &WSAction{ + Action: STMTInit, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return 0, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return 0, err + } + var resp StmtInitResp + err = tc.readTo(&resp) + if err != nil { + return 0, err + } + if resp.Code != 0 { + return 0, taosErrors.NewError(resp.Code, resp.Message) + } + return resp.StmtID, nil +} + +func (tc *taosConn) stmtPrepare(stmtID uint64, sql string) (bool, error) { + reqID := tc.generateReqID() + req := &StmtPrepareRequest{ + ReqID: reqID, + StmtID: stmtID, + SQL: sql, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return false, err + } + action := &WSAction{ + Action: STMTPrepare, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return false, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return false, err + } + var resp StmtPrepareResponse + err = tc.readTo(&resp) + if err != nil { + return false, err + } + if resp.Code != 0 { + return false, taosErrors.NewError(resp.Code, resp.Message) + } + return resp.IsInsert, nil +} + +func (tc *taosConn) stmtClose(stmtID uint64) error { + reqID := tc.generateReqID() + req := &StmtCloseRequest{ + ReqID: reqID, + StmtID: stmtID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return err + } + action := &WSAction{ + Action: STMTClose, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return err + } + return nil +} + +func (tc *taosConn) stmtGetColFields(stmtID uint64) ([]*stmtCommon.StmtField, error) { + reqID := tc.generateReqID() + req := &StmtGetColFieldsRequest{ + ReqID: reqID, + StmtID: stmtID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return nil, err + } + action := &WSAction{ + Action: STMTGetColFields, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return nil, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return nil, err + } + var resp StmtGetColFieldsResponse + err = tc.readTo(&resp) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + return resp.Fields, nil } +func (tc *taosConn) stmtBindParam(stmtID uint64, block []byte) error { + reqID := tc.generateReqID() + tc.buf.Reset() + WriteUint64(tc.buf, reqID) + WriteUint64(tc.buf, stmtID) + WriteUint64(tc.buf, BindMessage) + tc.buf.Write(block) + err := tc.writeBinary(tc.buf.Bytes()) + if err != nil { + return err + } + var resp StmtBindResponse + err = tc.readTo(&resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + return nil +} + +func WriteUint64(buffer *bytes.Buffer, v uint64) { + buffer.WriteByte(byte(v)) + buffer.WriteByte(byte(v >> 8)) + buffer.WriteByte(byte(v >> 16)) + buffer.WriteByte(byte(v >> 24)) + buffer.WriteByte(byte(v >> 32)) + buffer.WriteByte(byte(v >> 40)) + buffer.WriteByte(byte(v >> 48)) + buffer.WriteByte(byte(v >> 56)) +} + +func (tc *taosConn) stmtAddBatch(stmtID uint64) error { + reqID := tc.generateReqID() + req := &StmtAddBatchRequest{ + ReqID: reqID, + StmtID: stmtID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return err + } + action := &WSAction{ + Action: STMTAddBatch, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return err + } + var resp StmtAddBatchResponse + err = tc.readTo(&resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + return nil +} + +func (tc *taosConn) stmtExec(stmtID uint64) (int, error) { + reqID := tc.generateReqID() + req := &StmtExecRequest{ + ReqID: reqID, + StmtID: stmtID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return 0, err + } + action := &WSAction{ + Action: STMTExec, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return 0, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return 0, err + } + var resp StmtExecResponse + err = tc.readTo(&resp) + if err != nil { + return 0, err + } + if resp.Code != 0 { + return 0, taosErrors.NewError(resp.Code, resp.Message) + } + return resp.Affected, nil +} + +func (tc *taosConn) stmtUseResult(stmtID uint64) (*rows, error) { + reqID := tc.generateReqID() + req := &StmtUseResultRequest{ + ReqID: reqID, + StmtID: stmtID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return nil, err + } + action := &WSAction{ + Action: STMTUseResult, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return nil, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return nil, err + } + var resp StmtUseResultResponse + err = tc.readTo(&resp) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + rs := &rows{ + buf: &bytes.Buffer{}, + conn: tc, + resultID: resp.ResultID, + fieldsCount: resp.FieldsCount, + fieldsNames: resp.FieldsNames, + fieldsTypes: resp.FieldsTypes, + fieldsLengths: resp.FieldsLengths, + precision: resp.Precision, + isStmt: true, + } + return rs, nil +} func (tc *taosConn) Exec(query string, args []driver.Value) (driver.Result, error) { return tc.execCtx(context.Background(), query, common.ValueArgsToNamedValueArgs(args)) } @@ -261,8 +558,16 @@ func (tc *taosConn) connect() error { } func (tc *taosConn) writeText(data []byte) error { + return tc.write(data, websocket.TextMessage) +} + +func (tc *taosConn) writeBinary(data []byte) error { + return tc.write(data, websocket.BinaryMessage) +} + +func (tc *taosConn) write(data []byte, messageType int) error { tc.client.SetWriteDeadline(time.Now().Add(tc.writeTimeout)) - err := tc.client.WriteMessage(websocket.TextMessage, data) + err := tc.client.WriteMessage(messageType, data) if err != nil { return NewBadConnErrorWithCtx(err, string(data)) } diff --git a/taosWS/proto.go b/taosWS/proto.go index fd2fb39..2731eec 100644 --- a/taosWS/proto.go +++ b/taosWS/proto.go @@ -1,6 +1,10 @@ package taosWS -import "encoding/json" +import ( + "encoding/json" + + stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" +) type WSConnectReq struct { ReqID uint64 `json:"req_id"` @@ -69,3 +73,122 @@ type WSAction struct { Action string `json:"action"` Args json.RawMessage `json:"args"` } + +type StmtPrepareRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` + SQL string `json:"sql"` +} + +type StmtPrepareResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + IsInsert bool `json:"is_insert"` +} + +type StmtInitReq struct { + ReqID uint64 `json:"req_id"` +} + +type StmtInitResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} +type StmtCloseRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtCloseResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id,omitempty"` +} + +type StmtGetColFieldsRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtGetColFieldsResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + Fields []*stmtCommon.StmtField `json:"fields"` +} + +const ( + BindMessage = 2 +) + +type StmtBindResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtAddBatchRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtAddBatchResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtExecRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtExecResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + Affected int `json:"affected"` +} + +type StmtUseResultRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtUseResultResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + ResultID uint64 `json:"result_id"` + FieldsCount int `json:"fields_count"` + FieldsNames []string `json:"fields_names"` + FieldsTypes []uint8 `json:"fields_types"` + FieldsLengths []int64 `json:"fields_lengths"` + Precision int `json:"precision"` +} diff --git a/taosWS/rows.go b/taosWS/rows.go index b462f4e..9b8f0be 100644 --- a/taosWS/rows.go +++ b/taosWS/rows.go @@ -27,6 +27,7 @@ type rows struct { fieldsTypes []uint8 fieldsLengths []int64 precision int + isStmt bool } func (rs *rows) Columns() []string { @@ -158,6 +159,9 @@ func (rs *rows) fetchBlock() error { } func (rs *rows) freeResult() error { + if rs.isStmt { + return nil + } tc := rs.conn reqID := tc.generateReqID() req := &WSFreeResultReq{ diff --git a/taosWS/statement.go b/taosWS/statement.go new file mode 100644 index 0000000..d313820 --- /dev/null +++ b/taosWS/statement.go @@ -0,0 +1,517 @@ +package taosWS + +import ( + "bytes" + "database/sql/driver" + "errors" + "fmt" + "reflect" + "strconv" + "time" + + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/param" + "github.com/taosdata/driver-go/v3/common/serializer" + stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" + "github.com/taosdata/driver-go/v3/types" +) + +type Stmt struct { + stmtID uint64 + conn *taosConn + buffer bytes.Buffer + pSql string + isInsert bool + cols []*stmtCommon.StmtField + colTypes *param.ColumnType + queryColTypes []*types.ColumnType +} + +func (stmt *Stmt) Close() error { + err := stmt.conn.stmtClose(stmt.stmtID) + stmt.buffer.Reset() + stmt.conn = nil + return err +} + +func (stmt *Stmt) NumInput() int { + if stmt.colTypes != nil { + return len(stmt.cols) + } + return -1 +} + +func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) { + if stmt.conn == nil { + return nil, driver.ErrBadConn + } + if len(args) != len(stmt.cols) { + return nil, fmt.Errorf("stmt exec error: wrong number of parameters") + } + block, err := serializer.SerializeRawBlock(param.NewParamsWithRowValue(args), stmt.colTypes) + if err != nil { + return nil, err + } + err = stmt.conn.stmtBindParam(stmt.stmtID, block) + if err != nil { + return nil, err + } + err = stmt.conn.stmtAddBatch(stmt.stmtID) + if err != nil { + return nil, err + } + affected, err := stmt.conn.stmtExec(stmt.stmtID) + if err != nil { + return nil, err + } + return driver.RowsAffected(affected), nil +} + +func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) { + if stmt.conn == nil { + return nil, driver.ErrBadConn + } + block, err := serializer.SerializeRawBlock(param.NewParamsWithRowValue(args), param.NewColumnTypeWithValue(stmt.queryColTypes)) + if err != nil { + return nil, err + } + err = stmt.conn.stmtBindParam(stmt.stmtID, block) + if err != nil { + return nil, err + } + err = stmt.conn.stmtAddBatch(stmt.stmtID) + if err != nil { + return nil, err + } + _, err = stmt.conn.stmtExec(stmt.stmtID) + if err != nil { + return nil, err + } + return stmt.conn.stmtUseResult(stmt.stmtID) +} + +func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { + if stmt.isInsert { + if stmt.cols == nil { + cols, err := stmt.conn.stmtGetColFields(stmt.stmtID) + if err != nil { + return err + } + colTypes := make([]*types.ColumnType, len(cols)) + for i, col := range cols { + t, err := col.GetType() + if err != nil { + return err + } + colTypes[i] = t + } + stmt.cols = cols + stmt.colTypes = param.NewColumnTypeWithValue(colTypes) + } + if v.Ordinal > len(stmt.cols) { + return nil + } + if v.Value == nil { + return nil + } + switch stmt.cols[v.Ordinal-1].FieldType { + case common.TSDB_DATA_TYPE_NULL: + v.Value = nil + case common.TSDB_DATA_TYPE_BOOL: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + v.Value = types.TaosBool(rv.Bool()) + case reflect.Float32, reflect.Float64: + v.Value = types.TaosBool(rv.Float() > 0) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosBool(rv.Int() > 0) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosBool(rv.Uint() > 0) + case reflect.String: + vv, err := strconv.ParseBool(rv.String()) + if err != nil { + return err + } + v.Value = types.TaosBool(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to bool", v) + } + case common.TSDB_DATA_TYPE_TINYINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosTinyint(1) + } else { + v.Value = types.TaosTinyint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosTinyint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosTinyint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosTinyint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseInt(rv.String(), 0, 8) + if err != nil { + return err + } + v.Value = types.TaosTinyint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to tinyint", v) + } + case common.TSDB_DATA_TYPE_SMALLINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosSmallint(1) + } else { + v.Value = types.TaosSmallint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosSmallint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosSmallint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosSmallint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseInt(rv.String(), 0, 16) + if err != nil { + return err + } + v.Value = types.TaosSmallint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to smallint", v) + } + case common.TSDB_DATA_TYPE_INT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosInt(1) + } else { + v.Value = types.TaosInt(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosInt(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosInt(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosInt(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseInt(rv.String(), 0, 32) + if err != nil { + return err + } + v.Value = types.TaosInt(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to int", v) + } + case common.TSDB_DATA_TYPE_BIGINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosBigint(1) + } else { + v.Value = types.TaosBigint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosBigint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosBigint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosBigint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseInt(rv.String(), 0, 64) + if err != nil { + return err + } + v.Value = types.TaosBigint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to bigint", v) + } + case common.TSDB_DATA_TYPE_FLOAT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosFloat(1) + } else { + v.Value = types.TaosFloat(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosFloat(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosFloat(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosFloat(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseFloat(rv.String(), 32) + if err != nil { + return err + } + v.Value = types.TaosFloat(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to float", v) + } + case common.TSDB_DATA_TYPE_DOUBLE: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosDouble(1) + } else { + v.Value = types.TaosDouble(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosDouble(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosDouble(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosDouble(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseFloat(rv.String(), 64) + if err != nil { + return err + } + v.Value = types.TaosDouble(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to double", v) + } + case common.TSDB_DATA_TYPE_BINARY: + switch v.Value.(type) { + case string: + v.Value = types.TaosBinary(v.Value.(string)) + case []byte: + v.Value = types.TaosBinary(v.Value.([]byte)) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to binary", v) + } + case common.TSDB_DATA_TYPE_VARBINARY: + switch v.Value.(type) { + case string: + v.Value = types.TaosVarBinary(v.Value.(string)) + case []byte: + v.Value = types.TaosVarBinary(v.Value.([]byte)) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to varbinary", v) + } + + case common.TSDB_DATA_TYPE_GEOMETRY: + switch v.Value.(type) { + case string: + v.Value = types.TaosGeometry(v.Value.(string)) + case []byte: + v.Value = types.TaosGeometry(v.Value.([]byte)) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to geometry", v) + } + + case common.TSDB_DATA_TYPE_TIMESTAMP: + t, is := v.Value.(time.Time) + if is { + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + return nil + } + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Float32, reflect.Float64: + t := common.TimestampConvertToTime(int64(rv.Float()), int(stmt.cols[v.Ordinal-1].Precision)) + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + t := common.TimestampConvertToTime(rv.Int(), int(stmt.cols[v.Ordinal-1].Precision)) + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + t := common.TimestampConvertToTime(int64(rv.Uint()), int(stmt.cols[v.Ordinal-1].Precision)) + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + case reflect.String: + t, err := time.Parse(time.RFC3339Nano, rv.String()) + if err != nil { + return err + } + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to timestamp", v) + } + case common.TSDB_DATA_TYPE_NCHAR: + switch v.Value.(type) { + case string: + v.Value = types.TaosNchar(v.Value.(string)) + case []byte: + v.Value = types.TaosNchar(v.Value.([]byte)) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to nchar", v) + } + case common.TSDB_DATA_TYPE_UTINYINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosUTinyint(1) + } else { + v.Value = types.TaosUTinyint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosUTinyint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosUTinyint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUTinyint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseUint(rv.String(), 0, 8) + if err != nil { + return err + } + v.Value = types.TaosUTinyint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to tinyint unsigned", v) + } + case common.TSDB_DATA_TYPE_USMALLINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosUSmallint(1) + } else { + v.Value = types.TaosUSmallint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosUSmallint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosUSmallint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUSmallint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseUint(rv.String(), 0, 16) + if err != nil { + return err + } + v.Value = types.TaosUSmallint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to smallint unsigned", v) + } + case common.TSDB_DATA_TYPE_UINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosUInt(1) + } else { + v.Value = types.TaosUInt(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosUInt(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosUInt(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUInt(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseUint(rv.String(), 0, 32) + if err != nil { + return err + } + v.Value = types.TaosUInt(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to int unsigned", v) + } + case common.TSDB_DATA_TYPE_UBIGINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosUBigint(1) + } else { + v.Value = types.TaosUBigint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosUBigint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosUBigint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUBigint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseUint(rv.String(), 0, 64) + if err != nil { + return err + } + v.Value = types.TaosUBigint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to bigint unsigned", v) + } + } + return nil + } else { + if v.Value == nil { + return errors.New("CheckNamedValue: value is nil") + } + if v.Ordinal == 1 { + stmt.queryColTypes = nil + } + if len(stmt.queryColTypes) < v.Ordinal { + tmp := stmt.queryColTypes + stmt.queryColTypes = make([]*types.ColumnType, v.Ordinal) + copy(stmt.queryColTypes, tmp) + } + t, is := v.Value.(time.Time) + if is { + v.Value = types.TaosBinary(t.Format(time.RFC3339Nano)) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosBinaryType} + return nil + } + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + v.Value = types.TaosBool(rv.Bool()) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosBoolType} + case reflect.Float32, reflect.Float64: + v.Value = types.TaosDouble(rv.Float()) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosDoubleType} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosBigint(rv.Int()) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosBigintType} + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUBigint(rv.Uint()) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosUBigintType} + case reflect.String: + strVal := rv.String() + v.Value = types.TaosBinary(strVal) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{ + Type: types.TaosBinaryType, + MaxLen: len(strVal), + } + case reflect.Slice: + ek := rv.Type().Elem().Kind() + if ek == reflect.Uint8 { + bsVal := rv.Bytes() + v.Value = types.TaosBinary(bsVal) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{ + Type: types.TaosBinaryType, + MaxLen: len(bsVal), + } + } else { + return fmt.Errorf("CheckNamedValue: can not convert query value %v", v) + } + default: + return fmt.Errorf("CheckNamedValue: can not convert query value %v", v) + } + return nil + } +} diff --git a/taosWS/statement_test.go b/taosWS/statement_test.go new file mode 100644 index 0000000..1ab008c --- /dev/null +++ b/taosWS/statement_test.go @@ -0,0 +1,2159 @@ +package taosWS + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestStmtExec(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Error(err) + return + } + defer func() { + t.Log("start3") + db.Close() + t.Log("done3") + }() + defer func() { + _, err = db.Exec("drop database if exists test_stmt_driver_ws") + if err != nil { + t.Error(err) + return + } + t.Log("done2") + }() + _, err = db.Exec("create database if not exists test_stmt_driver_ws") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("create table if not exists test_stmt_driver_ws.ct(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")") + if err != nil { + t.Error(err) + return + } + stmt, err := db.Prepare("insert into test_stmt_driver_ws.ct values (?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + + if err != nil { + t.Error(err) + return + } + result, err := stmt.Exec(time.Now(), 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, "binary", "nchar") + if err != nil { + t.Error(err) + return + } + affected, err := result.RowsAffected() + assert.NoError(t, err) + assert.Equal(t, int64(1), affected) + t.Log("done") +} + +func TestStmtQuery(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Error(err) + return + } + defer db.Close() + defer func() { + db.Exec("drop database if exists test_stmt_driver_ws_q") + }() + _, err = db.Exec("create database if not exists test_stmt_driver_ws_q") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("create table if not exists test_stmt_driver_ws_q.ct(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")") + if err != nil { + t.Error(err) + return + } + stmt, err := db.Prepare("insert into test_stmt_driver_ws_q.ct values (?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + if err != nil { + t.Error(err) + return + } + now := time.Now() + result, err := stmt.Exec(now, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, "binary", "nchar") + if err != nil { + t.Error(err) + return + } + affected, err := result.RowsAffected() + if err != nil { + t.Error(err) + return + } + assert.Equal(t, int64(1), affected) + stmt.Close() + stmt, err = db.Prepare("select * from test_stmt_driver_ws_q.ct where ts = ?") + if err != nil { + t.Error(err) + return + } + rows, err := stmt.Query(now) + if err != nil { + t.Error(err) + return + } + columns, err := rows.Columns() + if err != nil { + t.Error(err) + return + } + assert.Equal(t, []string{"ts", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11", "c12", "c13"}, columns) + count := 0 + for rows.Next() { + count += 1 + var ( + ts time.Time + c1 bool + c2 int8 + c3 int16 + c4 int32 + c5 int64 + c6 uint8 + c7 uint16 + c8 uint32 + c9 uint64 + c10 float32 + c11 float64 + c12 string + c13 string + ) + err = rows.Scan(&ts, + &c1, + &c2, + &c3, + &c4, + &c5, + &c6, + &c7, + &c8, + &c9, + &c10, + &c11, + &c12, + &c13) + assert.NoError(t, err) + assert.Equal(t, now.UnixNano()/1e6, ts.UnixNano()/1e6) + assert.Equal(t, true, c1) + assert.Equal(t, int8(2), c2) + assert.Equal(t, int16(3), c3) + assert.Equal(t, int32(4), c4) + assert.Equal(t, int64(5), c5) + assert.Equal(t, uint8(6), c6) + assert.Equal(t, uint16(7), c7) + assert.Equal(t, uint32(8), c8) + assert.Equal(t, uint64(9), c9) + assert.Equal(t, float32(10), c10) + assert.Equal(t, float64(11), c11) + assert.Equal(t, "binary", c12) + assert.Equal(t, "nchar", c13) + } + assert.Equal(t, 1, count) +} + +func TestStmtConvertExec(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Error(err) + return + } + defer db.Close() + _, err = db.Exec("drop database if exists test_stmt_driver_ws_convert") + if err != nil { + t.Error(err) + return + } + defer func() { + _, err = db.Exec("drop database if exists test_stmt_driver_ws_convert") + if err != nil { + t.Error(err) + return + } + }() + _, err = db.Exec("create database test_stmt_driver_ws_convert") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("use test_stmt_driver_ws_convert") + if err != nil { + t.Error(err) + return + } + now := time.Now().Format(time.RFC3339Nano) + tests := []struct { + name string + tbType string + pos string + bind []interface{} + expectValue interface{} + expectError bool + }{ + //bool + { + name: "bool_null", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "bool_err", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, []int{123}}, + expectValue: nil, + expectError: true, + }, + { + name: "bool_bool_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: true, + }, + { + name: "bool_bool_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: false, + }, + { + name: "bool_float_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: true, + }, + { + name: "bool_float_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, float32(0)}, + expectValue: false, + }, + { + name: "bool_int_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, int32(1)}, + expectValue: true, + }, + { + name: "bool_int_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, int32(0)}, + expectValue: false, + }, + { + name: "bool_uint_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, uint32(1)}, + expectValue: true, + }, + { + name: "bool_uint_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, uint32(0)}, + expectValue: false, + }, + { + name: "bool_string_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, "true"}, + expectValue: true, + }, + { + name: "bool_string_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, "false"}, + expectValue: false, + }, + //tiny int + { + name: "tiny_nil", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "tiny_err", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "tiny_bool_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: int8(1), + }, + { + name: "tiny_bool_0", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: int8(0), + }, + { + name: "tiny_float_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: int8(1), + }, + { + name: "tiny_int_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: int8(1), + }, + { + name: "tiny_uint_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: int8(1), + }, + { + name: "tiny_string_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: int8(1), + }, + // small int + { + name: "small_nil", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "small_err", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "small_bool_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: int16(1), + }, + { + name: "small_bool_0", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: int16(0), + }, + { + name: "small_float_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: int16(1), + }, + { + name: "small_int_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: int16(1), + }, + { + name: "small_uint_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: int16(1), + }, + { + name: "small_string_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: int16(1), + }, + // int + { + name: "int_nil", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "int_err", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "int_bool_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: int32(1), + }, + { + name: "int_bool_0", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: int32(0), + }, + { + name: "int_float_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: int32(1), + }, + { + name: "int_int_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: int32(1), + }, + { + name: "int_uint_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: int32(1), + }, + { + name: "int_string_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: int32(1), + }, + // big int + { + name: "big_nil", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "big_err", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "big_bool_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: int64(1), + }, + { + name: "big_bool_0", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: int64(0), + }, + { + name: "big_float_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: int64(1), + }, + { + name: "big_int_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: int64(1), + }, + { + name: "big_uint_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: int64(1), + }, + { + name: "big_string_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: int64(1), + }, + // float + { + name: "float_nil", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "float_err", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "float_bool_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: float32(1), + }, + { + name: "float_bool_0", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: float32(0), + }, + { + name: "float_float_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: float32(1), + }, + { + name: "float_int_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: float32(1), + }, + { + name: "float_uint_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: float32(1), + }, + { + name: "float_string_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: float32(1), + }, + //double + { + name: "double_nil", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "double_err", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "double_bool_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: float64(1), + }, + { + name: "double_bool_0", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: float64(0), + }, + { + name: "double_double_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: float64(1), + }, + { + name: "double_int_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: float64(1), + }, + { + name: "double_uint_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: float64(1), + }, + { + name: "double_string_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: float64(1), + }, + + //tiny int unsigned + { + name: "utiny_nil", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "utiny_err", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "utiny_bool_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: uint8(1), + }, + { + name: "utiny_bool_0", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: uint8(0), + }, + { + name: "utiny_float_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: uint8(1), + }, + { + name: "utiny_int_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: uint8(1), + }, + { + name: "utiny_uint_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: uint8(1), + }, + { + name: "utiny_string_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: uint8(1), + }, + // small int unsigned + { + name: "usmall_nil", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "usmall_err", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "usmall_bool_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: uint16(1), + }, + { + name: "usmall_bool_0", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: uint16(0), + }, + { + name: "usmall_float_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: uint16(1), + }, + { + name: "usmall_int_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: uint16(1), + }, + { + name: "usmall_uint_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: uint16(1), + }, + { + name: "usmall_string_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: uint16(1), + }, + // int unsigned + { + name: "uint_nil", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "uint_err", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "uint_bool_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: uint32(1), + }, + { + name: "uint_bool_0", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: uint32(0), + }, + { + name: "uint_float_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: uint32(1), + }, + { + name: "uint_int_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: uint32(1), + }, + { + name: "uint_uint_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: uint32(1), + }, + { + name: "uint_string_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: uint32(1), + }, + // big int unsigned + { + name: "ubig_nil", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "ubig_err", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "ubig_bool_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: uint64(1), + }, + { + name: "ubig_bool_0", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: uint64(0), + }, + { + name: "ubig_float_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: uint64(1), + }, + { + name: "ubig_int_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: uint64(1), + }, + { + name: "ubig_uint_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: uint64(1), + }, + { + name: "ubig_string_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: uint64(1), + }, + //binary + { + name: "binary_nil", + tbType: "ts timestamp,v binary(24)", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "binary_err", + tbType: "ts timestamp,v binary(24)", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "binary_string_chinese", + tbType: "ts timestamp,v binary(24)", + pos: "?,?", + bind: []interface{}{now, "中文"}, + expectValue: "中文", + }, + { + name: "binary_bytes_chinese", + tbType: "ts timestamp,v binary(24)", + pos: "?,?", + bind: []interface{}{now, []byte("中文")}, + expectValue: "中文", + }, + //nchar + { + name: "nchar_nil", + tbType: "ts timestamp,v nchar(24)", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "nchar_err", + tbType: "ts timestamp,v nchar(24)", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "binary_string_chinese", + tbType: "ts timestamp,v nchar(24)", + pos: "?,?", + bind: []interface{}{now, "中文"}, + expectValue: "中文", + }, + { + name: "binary_bytes_chinese", + tbType: "ts timestamp,v nchar(24)", + pos: "?,?", + bind: []interface{}{now, []byte("中文")}, + expectValue: "中文", + }, + // timestamp + { + name: "ts_nil", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "ts_err", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "ts_time_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, time.Unix(0, 1e6)}, + expectValue: time.Unix(0, 1e6), + }, + { + name: "ts_float_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: time.Unix(0, 1e6), + }, + { + name: "ts_int_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: time.Unix(0, 1e6), + }, + { + name: "ts_uint_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: time.Unix(0, 1e6), + }, + { + name: "ts_string_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, "1970-01-01T00:00:00.001Z"}, + expectValue: time.Unix(0, 1e6), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tbName := fmt.Sprintf("test_%s", tt.name) + tbType := tt.tbType + drop := fmt.Sprintf("drop table if exists %s", tbName) + create := fmt.Sprintf("create table if not exists %s(%s)", tbName, tbType) + pos := tt.pos + sql := fmt.Sprintf("insert into %s values(%s)", tbName, pos) + var err error + if _, err = db.Exec(drop); err != nil { + t.Error(err) + return + } + if _, err = db.Exec(create); err != nil { + t.Error(err) + return + } + stmt, err := db.Prepare(sql) + if err != nil { + t.Error(err) + return + } + result, err := stmt.Exec(tt.bind...) + if tt.expectError { + assert.NotNil(t, err) + stmt.Close() + return + } + if err != nil { + t.Error(err) + return + } + affected, err := result.RowsAffected() + if err != nil { + t.Error(err) + return + } + assert.Equal(t, int64(1), affected) + rows, err := db.Query(fmt.Sprintf("select v from %s", tbName)) + if err != nil { + t.Error(err) + return + } + var data []driver.Value + tts, err := rows.ColumnTypes() + if err != nil { + t.Error(err) + return + } + typesL := make([]reflect.Type, 1) + for i, tp := range tts { + st := tp.ScanType() + if st == nil { + t.Errorf("scantype is null for column %q", tp.Name()) + continue + } + typesL[i] = st + } + for rows.Next() { + values := make([]interface{}, 1) + for i := range values { + values[i] = reflect.New(typesL[i]).Interface() + } + err = rows.Scan(values...) + if err != nil { + t.Error(err) + return + } + v, err := values[0].(driver.Valuer).Value() + if err != nil { + t.Error(err) + } + data = append(data, v) + } + if len(data) != 1 { + t.Errorf("expect %d got %d", 1, len(data)) + return + } + if data[0] != tt.expectValue { + t.Errorf("expect %v got %v", tt.expectValue, data[0]) + return + } + }) + } +} + +func TestStmtConvertQuery(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Error(err) + return + } + defer db.Close() + _, err = db.Exec("drop database if exists test_stmt_driver_ws_convert_q") + if err != nil { + t.Error(err) + return + } + defer func() { + _, err = db.Exec("drop database if exists test_stmt_driver_ws_convert_q") + if err != nil { + t.Error(err) + return + } + }() + _, err = db.Exec("create database test_stmt_driver_ws_convert_q") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("use test_stmt_driver_ws_convert_q") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("create table t0 (ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")") + if err != nil { + t.Error(err) + return + } + now := time.Now() + after1s := now.Add(time.Second) + _, err = db.Exec(fmt.Sprintf("insert into t0 values('%s',true,2,3,4,5,6,7,8,9,10,11,'binary','nchar')", now.Format(time.RFC3339Nano))) + if err != nil { + t.Error(err) + return + } + _, err = db.Exec(fmt.Sprintf("insert into t0 values('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", after1s.Format(time.RFC3339Nano))) + if err != nil { + t.Error(err) + return + } + tests := []struct { + name string + field string + where string + bind interface{} + expectNoValue bool + expectValue driver.Value + expectError bool + }{ + //ts + { + name: "ts", + field: "ts", + where: "ts = ?", + bind: now, + expectValue: time.Unix(now.Unix(), int64((now.Nanosecond()/1e6)*1e6)).Local(), + }, + + //bool + { + name: "bool_true", + field: "c1", + where: "c1 = ?", + bind: true, + expectValue: true, + }, + { + name: "bool_false", + field: "c1", + where: "c1 = ?", + bind: false, + expectNoValue: true, + }, + { + name: "tinyint_int8", + field: "c2", + where: "c2 = ?", + bind: int8(2), + expectValue: int8(2), + }, + { + name: "tinyint_iny16", + field: "c2", + where: "c2 = ?", + bind: int16(2), + expectValue: int8(2), + }, + { + name: "tinyint_int32", + field: "c2", + where: "c2 = ?", + bind: int32(2), + expectValue: int8(2), + }, + { + name: "tinyint_int64", + field: "c2", + where: "c2 = ?", + bind: int64(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint8", + field: "c2", + where: "c2 = ?", + bind: uint8(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint16", + field: "c2", + where: "c2 = ?", + bind: uint16(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint32", + field: "c2", + where: "c2 = ?", + bind: uint32(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint64", + field: "c2", + where: "c2 = ?", + bind: uint64(2), + expectValue: int8(2), + }, + { + name: "tinyint_float32", + field: "c2", + where: "c2 = ?", + bind: float32(2), + expectValue: int8(2), + }, + { + name: "tinyint_float64", + field: "c2", + where: "c2 = ?", + bind: float64(2), + expectValue: int8(2), + }, + { + name: "tinyint_int", + field: "c2", + where: "c2 = ?", + bind: int(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint", + field: "c2", + where: "c2 = ?", + bind: uint(2), + expectValue: int8(2), + }, + + // smallint + { + name: "smallint_int8", + field: "c3", + where: "c3 = ?", + bind: int8(3), + expectValue: int16(3), + }, + { + name: "smallint_iny16", + field: "c3", + where: "c3 = ?", + bind: int16(3), + expectValue: int16(3), + }, + { + name: "smallint_int32", + field: "c3", + where: "c3 = ?", + bind: int32(3), + expectValue: int16(3), + }, + { + name: "smallint_int64", + field: "c3", + where: "c3 = ?", + bind: int64(3), + expectValue: int16(3), + }, + { + name: "smallint_uint8", + field: "c3", + where: "c3 = ?", + bind: uint8(3), + expectValue: int16(3), + }, + { + name: "smallint_uint16", + field: "c3", + where: "c3 = ?", + bind: uint16(3), + expectValue: int16(3), + }, + { + name: "smallint_uint32", + field: "c3", + where: "c3 = ?", + bind: uint32(3), + expectValue: int16(3), + }, + { + name: "smallint_uint64", + field: "c3", + where: "c3 = ?", + bind: uint64(3), + expectValue: int16(3), + }, + { + name: "smallint_float32", + field: "c3", + where: "c3 = ?", + bind: float32(3), + expectValue: int16(3), + }, + { + name: "smallint_float64", + field: "c3", + where: "c3 = ?", + bind: float64(3), + expectValue: int16(3), + }, + { + name: "smallint_int", + field: "c3", + where: "c3 = ?", + bind: int(3), + expectValue: int16(3), + }, + { + name: "smallint_uint", + field: "c3", + where: "c3 = ?", + bind: uint(3), + expectValue: int16(3), + }, + + //int + { + name: "int_int8", + field: "c4", + where: "c4 = ?", + bind: int8(4), + expectValue: int32(4), + }, + { + name: "int_iny16", + field: "c4", + where: "c4 = ?", + bind: int16(4), + expectValue: int32(4), + }, + { + name: "int_int32", + field: "c4", + where: "c4 = ?", + bind: int32(4), + expectValue: int32(4), + }, + { + name: "int_int64", + field: "c4", + where: "c4 = ?", + bind: int64(4), + expectValue: int32(4), + }, + { + name: "int_uint8", + field: "c4", + where: "c4 = ?", + bind: uint8(4), + expectValue: int32(4), + }, + { + name: "int_uint16", + field: "c4", + where: "c4 = ?", + bind: uint16(4), + expectValue: int32(4), + }, + { + name: "int_uint32", + field: "c4", + where: "c4 = ?", + bind: uint32(4), + expectValue: int32(4), + }, + { + name: "int_uint64", + field: "c4", + where: "c4 = ?", + bind: uint64(4), + expectValue: int32(4), + }, + { + name: "int_float32", + field: "c4", + where: "c4 = ?", + bind: float32(4), + expectValue: int32(4), + }, + { + name: "int_float64", + field: "c4", + where: "c4 = ?", + bind: float64(4), + expectValue: int32(4), + }, + { + name: "int_int", + field: "c4", + where: "c4 = ?", + bind: int(4), + expectValue: int32(4), + }, + { + name: "int_uint", + field: "c4", + where: "c4 = ?", + bind: uint(4), + expectValue: int32(4), + }, + + //bigint + { + name: "bigint_int8", + field: "c5", + where: "c5 = ?", + bind: int8(5), + expectValue: int64(5), + }, + { + name: "bigint_iny16", + field: "c5", + where: "c5 = ?", + bind: int16(5), + expectValue: int64(5), + }, + { + name: "bigint_int32", + field: "c5", + where: "c5 = ?", + bind: int32(5), + expectValue: int64(5), + }, + { + name: "bigint_int64", + field: "c5", + where: "c5 = ?", + bind: int64(5), + expectValue: int64(5), + }, + { + name: "bigint_uint8", + field: "c5", + where: "c5 = ?", + bind: uint8(5), + expectValue: int64(5), + }, + { + name: "bigint_uint16", + field: "c5", + where: "c5 = ?", + bind: uint16(5), + expectValue: int64(5), + }, + { + name: "bigint_uint32", + field: "c5", + where: "c5 = ?", + bind: uint32(5), + expectValue: int64(5), + }, + { + name: "bigint_uint64", + field: "c5", + where: "c5 = ?", + bind: uint64(5), + expectValue: int64(5), + }, + { + name: "bigint_float32", + field: "c5", + where: "c5 = ?", + bind: float32(5), + expectValue: int64(5), + }, + { + name: "bigint_float64", + field: "c5", + where: "c5 = ?", + bind: float64(5), + expectValue: int64(5), + }, + { + name: "bigint_int", + field: "c5", + where: "c5 = ?", + bind: int(5), + expectValue: int64(5), + }, + { + name: "bigint_uint", + field: "c5", + where: "c5 = ?", + bind: uint(5), + expectValue: int64(5), + }, + + //utinyint + { + name: "utinyint_int8", + field: "c6", + where: "c6 = ?", + bind: int8(6), + expectValue: uint8(6), + }, + { + name: "utinyint_iny16", + field: "c6", + where: "c6 = ?", + bind: int16(6), + expectValue: uint8(6), + }, + { + name: "utinyint_int32", + field: "c6", + where: "c6 = ?", + bind: int32(6), + expectValue: uint8(6), + }, + { + name: "utinyint_int64", + field: "c6", + where: "c6 = ?", + bind: int64(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint8", + field: "c6", + where: "c6 = ?", + bind: uint8(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint16", + field: "c6", + where: "c6 = ?", + bind: uint16(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint32", + field: "c6", + where: "c6 = ?", + bind: uint32(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint64", + field: "c6", + where: "c6 = ?", + bind: uint64(6), + expectValue: uint8(6), + }, + { + name: "utinyint_float32", + field: "c6", + where: "c6 = ?", + bind: float32(6), + expectValue: uint8(6), + }, + { + name: "utinyint_float64", + field: "c6", + where: "c6 = ?", + bind: float64(6), + expectValue: uint8(6), + }, + { + name: "utinyint_int", + field: "c6", + where: "c6 = ?", + bind: int(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint", + field: "c6", + where: "c6 = ?", + bind: uint(6), + expectValue: uint8(6), + }, + + //usmallint + { + name: "usmallint_int8", + field: "c7", + where: "c7 = ?", + bind: int8(7), + expectValue: uint16(7), + }, + { + name: "usmallint_iny16", + field: "c7", + where: "c7 = ?", + bind: int16(7), + expectValue: uint16(7), + }, + { + name: "usmallint_int32", + field: "c7", + where: "c7 = ?", + bind: int32(7), + expectValue: uint16(7), + }, + { + name: "usmallint_int64", + field: "c7", + where: "c7 = ?", + bind: int64(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint8", + field: "c7", + where: "c7 = ?", + bind: uint8(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint16", + field: "c7", + where: "c7 = ?", + bind: uint16(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint32", + field: "c7", + where: "c7 = ?", + bind: uint32(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint64", + field: "c7", + where: "c7 = ?", + bind: uint64(7), + expectValue: uint16(7), + }, + { + name: "usmallint_float32", + field: "c7", + where: "c7 = ?", + bind: float32(7), + expectValue: uint16(7), + }, + { + name: "usmallint_float64", + field: "c7", + where: "c7 = ?", + bind: float64(7), + expectValue: uint16(7), + }, + { + name: "usmallint_int", + field: "c7", + where: "c7 = ?", + bind: int(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint", + field: "c7", + where: "c7 = ?", + bind: uint(7), + expectValue: uint16(7), + }, + + //uint + { + name: "uint_int8", + field: "c8", + where: "c8 = ?", + bind: int8(8), + expectValue: uint32(8), + }, + { + name: "uint_iny16", + field: "c8", + where: "c8 = ?", + bind: int16(8), + expectValue: uint32(8), + }, + { + name: "uint_int32", + field: "c8", + where: "c8 = ?", + bind: int32(8), + expectValue: uint32(8), + }, + { + name: "uint_int64", + field: "c8", + where: "c8 = ?", + bind: int64(8), + expectValue: uint32(8), + }, + { + name: "uint_uint8", + field: "c8", + where: "c8 = ?", + bind: uint8(8), + expectValue: uint32(8), + }, + { + name: "uint_uint16", + field: "c8", + where: "c8 = ?", + bind: uint16(8), + expectValue: uint32(8), + }, + { + name: "uint_uint32", + field: "c8", + where: "c8 = ?", + bind: uint32(8), + expectValue: uint32(8), + }, + { + name: "uint_uint64", + field: "c8", + where: "c8 = ?", + bind: uint64(8), + expectValue: uint32(8), + }, + { + name: "uint_float32", + field: "c8", + where: "c8 = ?", + bind: float32(8), + expectValue: uint32(8), + }, + { + name: "uint_float64", + field: "c8", + where: "c8 = ?", + bind: float64(8), + expectValue: uint32(8), + }, + { + name: "uint_int", + field: "c8", + where: "c8 = ?", + bind: int(8), + expectValue: uint32(8), + }, + { + name: "uint_uint", + field: "c8", + where: "c8 = ?", + bind: uint(8), + expectValue: uint32(8), + }, + + //ubigint + { + name: "ubigint_int8", + field: "c9", + where: "c9 = ?", + bind: int8(9), + expectValue: uint64(9), + }, + { + name: "ubigint_iny16", + field: "c9", + where: "c9 = ?", + bind: int16(9), + expectValue: uint64(9), + }, + { + name: "ubigint_int32", + field: "c9", + where: "c9 = ?", + bind: int32(9), + expectValue: uint64(9), + }, + { + name: "ubigint_int64", + field: "c9", + where: "c9 = ?", + bind: int64(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint8", + field: "c9", + where: "c9 = ?", + bind: uint8(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint16", + field: "c9", + where: "c9 = ?", + bind: uint16(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint32", + field: "c9", + where: "c9 = ?", + bind: uint32(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint64", + field: "c9", + where: "c9 = ?", + bind: uint64(9), + expectValue: uint64(9), + }, + { + name: "ubigint_float32", + field: "c9", + where: "c9 = ?", + bind: float32(9), + expectValue: uint64(9), + }, + { + name: "ubigint_float64", + field: "c9", + where: "c9 = ?", + bind: float64(9), + expectValue: uint64(9), + }, + { + name: "ubigint_int", + field: "c9", + where: "c9 = ?", + bind: int(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint", + field: "c9", + where: "c9 = ?", + bind: uint(9), + expectValue: uint64(9), + }, + + //float + { + name: "float_int8", + field: "c10", + where: "c10 = ?", + bind: int8(10), + expectValue: float32(10), + }, + { + name: "float_iny16", + field: "c10", + where: "c10 = ?", + bind: int16(10), + expectValue: float32(10), + }, + { + name: "float_int32", + field: "c10", + where: "c10 = ?", + bind: int32(10), + expectValue: float32(10), + }, + { + name: "float_int64", + field: "c10", + where: "c10 = ?", + bind: int64(10), + expectValue: float32(10), + }, + { + name: "float_uint8", + field: "c10", + where: "c10 = ?", + bind: uint8(10), + expectValue: float32(10), + }, + { + name: "float_uint16", + field: "c10", + where: "c10 = ?", + bind: uint16(10), + expectValue: float32(10), + }, + { + name: "float_uint32", + field: "c10", + where: "c10 = ?", + bind: uint32(10), + expectValue: float32(10), + }, + { + name: "float_uint64", + field: "c10", + where: "c10 = ?", + bind: uint64(10), + expectValue: float32(10), + }, + { + name: "float_float32", + field: "c10", + where: "c10 = ?", + bind: float32(10), + expectValue: float32(10), + }, + { + name: "float_float64", + field: "c10", + where: "c10 = ?", + bind: float64(10), + expectValue: float32(10), + }, + { + name: "float_int", + field: "c10", + where: "c10 = ?", + bind: int(10), + expectValue: float32(10), + }, + { + name: "float_uint", + field: "c10", + where: "c10 = ?", + bind: uint(10), + expectValue: float32(10), + }, + + //double + { + name: "double_int8", + field: "c11", + where: "c11 = ?", + bind: int8(11), + expectValue: float64(11), + }, + { + name: "double_iny16", + field: "c11", + where: "c11 = ?", + bind: int16(11), + expectValue: float64(11), + }, + { + name: "double_int32", + field: "c11", + where: "c11 = ?", + bind: int32(11), + expectValue: float64(11), + }, + { + name: "double_int64", + field: "c11", + where: "c11 = ?", + bind: int64(11), + expectValue: float64(11), + }, + { + name: "double_uint8", + field: "c11", + where: "c11 = ?", + bind: uint8(11), + expectValue: float64(11), + }, + { + name: "double_uint16", + field: "c11", + where: "c11 = ?", + bind: uint16(11), + expectValue: float64(11), + }, + { + name: "double_uint32", + field: "c11", + where: "c11 = ?", + bind: uint32(11), + expectValue: float64(11), + }, + { + name: "double_uint64", + field: "c11", + where: "c11 = ?", + bind: uint64(11), + expectValue: float64(11), + }, + { + name: "double_float32", + field: "c11", + where: "c11 = ?", + bind: float32(11), + expectValue: float64(11), + }, + { + name: "double_float64", + field: "c11", + where: "c11 = ?", + bind: float64(11), + expectValue: float64(11), + }, + { + name: "double_int", + field: "c11", + where: "c11 = ?", + bind: int(11), + expectValue: float64(11), + }, + { + name: "double_uint", + field: "c11", + where: "c11 = ?", + bind: uint(11), + expectValue: float64(11), + }, + + // binary + { + name: "binary_string", + field: "c12", + where: "c12 = ?", + bind: "binary", + expectValue: "binary", + }, + { + name: "binary_bytes", + field: "c12", + where: "c12 = ?", + bind: []byte("binary"), + expectValue: "binary", + }, + { + name: "binary_string_like", + field: "c12", + where: "c12 like ?", + bind: "bin%", + expectValue: "binary", + }, + + // nchar + { + name: "nchar_string", + field: "c13", + where: "c13 = ?", + bind: "nchar", + expectValue: "nchar", + }, + { + name: "nchar_bytes", + field: "c13", + where: "c13 = ?", + bind: []byte("nchar"), + expectValue: "nchar", + }, + { + name: "nchar_string", + field: "c13", + where: "c13 like ?", + bind: "nch%", + expectValue: "nchar", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql := fmt.Sprintf("select %s from t0 where %s", tt.field, tt.where) + + stmt, err := db.Prepare(sql) + if err != nil { + t.Error(err) + return + } + defer stmt.Close() + rows, err := stmt.Query(tt.bind) + if tt.expectError { + assert.NotNil(t, err) + stmt.Close() + return + } + if err != nil { + t.Error(err) + return + } + tts, err := rows.ColumnTypes() + typesL := make([]reflect.Type, 1) + for i, tp := range tts { + st := tp.ScanType() + if st == nil { + t.Errorf("scantype is null for column %q", tp.Name()) + continue + } + typesL[i] = st + } + var data []driver.Value + for rows.Next() { + values := make([]interface{}, 1) + for i := range values { + values[i] = reflect.New(typesL[i]).Interface() + } + err = rows.Scan(values...) + if err != nil { + t.Error(err) + return + } + v, err := values[0].(driver.Valuer).Value() + if err != nil { + t.Error(err) + } + data = append(data, v) + } + if tt.expectNoValue { + if len(data) > 0 { + t.Errorf("expect no value got %#v", data) + return + } + return + } + if len(data) != 1 { + t.Errorf("expect %d got %d", 1, len(data)) + return + } + if data[0] != tt.expectValue { + t.Errorf("expect %v got %v", tt.expectValue, data[0]) + return + } + }) + } +}