From 5c2af838175807337daec3f0055a70944eb58f00 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 11 Apr 2018 15:39:04 +0800 Subject: [PATCH] add test and remove unused warn log (#886) --- engine_cond.go | 5 +- error.go | 13 + interface.go | 1 + session.go | 700 ++++++++++++++++++++++--------------------- session_cols_test.go | 40 +++ 5 files changed, 411 insertions(+), 348 deletions(-) diff --git a/engine_cond.go b/engine_cond.go index 6c8e3879c..4dde8662e 100644 --- a/engine_cond.go +++ b/engine_cond.go @@ -9,6 +9,7 @@ import ( "encoding/json" "fmt" "reflect" + "strings" "time" "github.com/go-xorm/builder" @@ -51,7 +52,9 @@ func (engine *Engine) buildConds(table *core.Table, bean interface{}, fieldValuePtr, err := col.ValueOf(bean) if err != nil { - engine.logger.Error(err) + if !strings.Contains(err.Error(), "is not valid") { + engine.logger.Warn(err) + } continue } diff --git a/error.go b/error.go index cfeefc31e..1694683cf 100644 --- a/error.go +++ b/error.go @@ -6,6 +6,7 @@ package xorm import ( "errors" + "fmt" ) var ( @@ -25,4 +26,16 @@ var ( ErrNotImplemented = errors.New("Not implemented") // ErrConditionType condition type unsupported ErrConditionType = errors.New("Unsupported conditon type") + // ErrColumnIsNotExist columns is not exist + ErrFieldIsNotExist = errors.New("Field is not exist") ) + +// ErrFieldIsNotValid is not valid +type ErrFieldIsNotValid struct { + FieldName string + TableName string +} + +func (e ErrFieldIsNotValid) Error() string { + return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName) +} diff --git a/interface.go b/interface.go index 5ce49f480..39d50edd2 100644 --- a/interface.go +++ b/interface.go @@ -42,6 +42,7 @@ type Interface interface { IsTableExist(beanOrTableName interface{}) (bool, error) Iterate(interface{}, IterFunc) error Limit(int, ...int) *Session + MustCols(columns ...string) *Session NoAutoCondition(...bool) *Session NotIn(string, ...interface{}) *Session Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session diff --git a/session.go b/session.go index 15283624c..48baf768e 100644 --- a/session.go +++ b/session.go @@ -278,24 +278,22 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt, return } -func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) *reflect.Value { +func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) (*reflect.Value, error) { var col *core.Column if col = table.GetColumnIdx(key, idx); col == nil { - //session.engine.logger.Warnf("table %v has no column %v. %v", table.Name, key, table.ColumnsSeq()) - return nil + return nil, ErrFieldIsNotExist } fieldValue, err := col.ValueOfV(dataStruct) if err != nil { - session.engine.logger.Error(err) - return nil + return nil, err } if !fieldValue.IsValid() || !fieldValue.CanSet() { - session.engine.logger.Warnf("table %v's column %v is not valid or cannot set", table.Name, key) - return nil + return nil, ErrFieldIsNotValid{key, table.Name} } - return fieldValue + + return fieldValue, nil } // Cell cell is a result of one column field @@ -407,409 +405,417 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } tempMap[lKey] = idx - if fieldValue := session.getField(dataStruct, key, table, idx); fieldValue != nil { - rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii])) - - // if row is null then ignore - if rawValue.Interface() == nil { - continue + fieldValue, err := session.getField(dataStruct, key, table, idx) + if err != nil { + if !strings.Contains(err.Error(), "is not valid") { + session.engine.logger.Warn(err) } + continue + } + if fieldValue == nil { + continue + } + rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii])) - if fieldValue.CanAddr() { - if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { - if data, err := value2Bytes(&rawValue); err == nil { - if err := structConvert.FromDB(data); err != nil { - return nil, err - } - } else { - return nil, err - } - continue - } - } + // if row is null then ignore + if rawValue.Interface() == nil { + continue + } - if _, ok := fieldValue.Interface().(core.Conversion); ok { + if fieldValue.CanAddr() { + if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { if data, err := value2Bytes(&rawValue); err == nil { - if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + if err := structConvert.FromDB(data); err != nil { + return nil, err } - fieldValue.Interface().(core.Conversion).FromDB(data) } else { return nil, err } continue } + } - rawValueType := reflect.TypeOf(rawValue.Interface()) - vv := reflect.ValueOf(rawValue.Interface()) - col := table.GetColumnIdx(key, idx) - if col.IsPrimaryKey { - pk = append(pk, rawValue.Interface()) + if _, ok := fieldValue.Interface().(core.Conversion); ok { + if data, err := value2Bytes(&rawValue); err == nil { + if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + } + fieldValue.Interface().(core.Conversion).FromDB(data) + } else { + return nil, err } - fieldType := fieldValue.Type() - hasAssigned := false + continue + } - if col.SQLType.IsJson() { - var bs []byte - if rawValueType.Kind() == reflect.String { - bs = []byte(vv.String()) - } else if rawValueType.ConvertibleTo(core.BytesType) { - bs = vv.Bytes() - } else { - return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind()) - } + rawValueType := reflect.TypeOf(rawValue.Interface()) + vv := reflect.ValueOf(rawValue.Interface()) + col := table.GetColumnIdx(key, idx) + if col.IsPrimaryKey { + pk = append(pk, rawValue.Interface()) + } + fieldType := fieldValue.Type() + hasAssigned := false + + if col.SQLType.IsJson() { + var bs []byte + if rawValueType.Kind() == reflect.String { + bs = []byte(vv.String()) + } else if rawValueType.ConvertibleTo(core.BytesType) { + bs = vv.Bytes() + } else { + return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind()) + } - hasAssigned = true + hasAssigned = true - if len(bs) > 0 { - if fieldType.Kind() == reflect.String { - fieldValue.SetString(string(bs)) - continue + if len(bs) > 0 { + if fieldType.Kind() == reflect.String { + fieldValue.SetString(string(bs)) + continue + } + if fieldValue.CanAddr() { + err := json.Unmarshal(bs, fieldValue.Addr().Interface()) + if err != nil { + return nil, err } - if fieldValue.CanAddr() { - err := json.Unmarshal(bs, fieldValue.Addr().Interface()) - if err != nil { - return nil, err - } - } else { - x := reflect.New(fieldType) - err := json.Unmarshal(bs, x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) + } else { + x := reflect.New(fieldType) + err := json.Unmarshal(bs, x.Interface()) + if err != nil { + return nil, err } + fieldValue.Set(x.Elem()) } - - continue } - switch fieldType.Kind() { - case reflect.Complex64, reflect.Complex128: - // TODO: reimplement this - var bs []byte - if rawValueType.Kind() == reflect.String { - bs = []byte(vv.String()) - } else if rawValueType.ConvertibleTo(core.BytesType) { - bs = vv.Bytes() - } + continue + } - hasAssigned = true - if len(bs) > 0 { - if fieldValue.CanAddr() { - err := json.Unmarshal(bs, fieldValue.Addr().Interface()) - if err != nil { - return nil, err - } - } else { - x := reflect.New(fieldType) - err := json.Unmarshal(bs, x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) + switch fieldType.Kind() { + case reflect.Complex64, reflect.Complex128: + // TODO: reimplement this + var bs []byte + if rawValueType.Kind() == reflect.String { + bs = []byte(vv.String()) + } else if rawValueType.ConvertibleTo(core.BytesType) { + bs = vv.Bytes() + } + + hasAssigned = true + if len(bs) > 0 { + if fieldValue.CanAddr() { + err := json.Unmarshal(bs, fieldValue.Addr().Interface()) + if err != nil { + return nil, err + } + } else { + x := reflect.New(fieldType) + err := json.Unmarshal(bs, x.Interface()) + if err != nil { + return nil, err } + fieldValue.Set(x.Elem()) } + } + case reflect.Slice, reflect.Array: + switch rawValueType.Kind() { case reflect.Slice, reflect.Array: - switch rawValueType.Kind() { - case reflect.Slice, reflect.Array: - switch rawValueType.Elem().Kind() { - case reflect.Uint8: - if fieldType.Elem().Kind() == reflect.Uint8 { - hasAssigned = true - if col.SQLType.IsText() { - x := reflect.New(fieldType) - err := json.Unmarshal(vv.Bytes(), x.Interface()) - if err != nil { - return nil, err + switch rawValueType.Elem().Kind() { + case reflect.Uint8: + if fieldType.Elem().Kind() == reflect.Uint8 { + hasAssigned = true + if col.SQLType.IsText() { + x := reflect.New(fieldType) + err := json.Unmarshal(vv.Bytes(), x.Interface()) + if err != nil { + return nil, err + } + fieldValue.Set(x.Elem()) + } else { + if fieldValue.Len() > 0 { + for i := 0; i < fieldValue.Len(); i++ { + if i < vv.Len() { + fieldValue.Index(i).Set(vv.Index(i)) + } } - fieldValue.Set(x.Elem()) } else { - if fieldValue.Len() > 0 { - for i := 0; i < fieldValue.Len(); i++ { - if i < vv.Len() { - fieldValue.Index(i).Set(vv.Index(i)) - } - } - } else { - for i := 0; i < vv.Len(); i++ { - fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) - } + for i := 0; i < vv.Len(); i++ { + fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) } } } } } - case reflect.String: - if rawValueType.Kind() == reflect.String { - hasAssigned = true - fieldValue.SetString(vv.String()) - } - case reflect.Bool: - if rawValueType.Kind() == reflect.Bool { - hasAssigned = true - fieldValue.SetBool(vv.Bool()) - } + } + case reflect.String: + if rawValueType.Kind() == reflect.String { + hasAssigned = true + fieldValue.SetString(vv.String()) + } + case reflect.Bool: + if rawValueType.Kind() == reflect.Bool { + hasAssigned = true + fieldValue.SetBool(vv.Bool()) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch rawValueType.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - switch rawValueType.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - hasAssigned = true - fieldValue.SetInt(vv.Int()) - } + hasAssigned = true + fieldValue.SetInt(vv.Int()) + } + case reflect.Float32, reflect.Float64: + switch rawValueType.Kind() { case reflect.Float32, reflect.Float64: - switch rawValueType.Kind() { - case reflect.Float32, reflect.Float64: - hasAssigned = true - fieldValue.SetFloat(vv.Float()) - } + hasAssigned = true + fieldValue.SetFloat(vv.Float()) + } + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + switch rawValueType.Kind() { case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - switch rawValueType.Kind() { - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - hasAssigned = true - fieldValue.SetUint(vv.Uint()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - hasAssigned = true - fieldValue.SetUint(uint64(vv.Int())) + hasAssigned = true + fieldValue.SetUint(vv.Uint()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + hasAssigned = true + fieldValue.SetUint(uint64(vv.Int())) + } + case reflect.Struct: + if fieldType.ConvertibleTo(core.TimeType) { + dbTZ := session.engine.DatabaseTZ + if col.TimeZone != nil { + dbTZ = col.TimeZone } - case reflect.Struct: - if fieldType.ConvertibleTo(core.TimeType) { - dbTZ := session.engine.DatabaseTZ - if col.TimeZone != nil { - dbTZ = col.TimeZone - } - - if rawValueType == core.TimeType { - hasAssigned = true - t := vv.Convert(core.TimeType).Interface().(time.Time) - - z, _ := t.Zone() - // set new location if database don't save timezone or give an incorrect timezone - if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location - session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) - t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), - t.Minute(), t.Second(), t.Nanosecond(), dbTZ) - } + if rawValueType == core.TimeType { + hasAssigned = true - t = t.In(session.engine.TZLocation) - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - } else if rawValueType == core.IntType || rawValueType == core.Int64Type || - rawValueType == core.Int32Type { - hasAssigned = true + t := vv.Convert(core.TimeType).Interface().(time.Time) - t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - } else { - if d, ok := vv.Interface().([]uint8); ok { - hasAssigned = true - t, err := session.byte2Time(col, d) - if err != nil { - session.engine.logger.Error("byte2Time error:", err.Error()) - hasAssigned = false - } else { - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - } - } else if d, ok := vv.Interface().(string); ok { - hasAssigned = true - t, err := session.str2Time(col, d) - if err != nil { - session.engine.logger.Error("byte2Time error:", err.Error()) - hasAssigned = false - } else { - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - } - } else { - return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface()) - } - } - } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - // !! 增加支持sql.Scanner接口的结构,如sql.NullString - hasAssigned = true - if err := nulVal.Scan(vv.Interface()); err != nil { - session.engine.logger.Error("sql.Sanner error:", err.Error()) - hasAssigned = false - } - } else if col.SQLType.IsJson() { - if rawValueType.Kind() == reflect.String { - hasAssigned = true - x := reflect.New(fieldType) - if len([]byte(vv.String())) > 0 { - err := json.Unmarshal([]byte(vv.String()), x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } - } else if rawValueType.Kind() == reflect.Slice { - hasAssigned = true - x := reflect.New(fieldType) - if len(vv.Bytes()) > 0 { - err := json.Unmarshal(vv.Bytes(), x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } - } - } else if session.statement.UseCascade { - table, err := session.engine.autoMapType(*fieldValue) - if err != nil { - return nil, err + z, _ := t.Zone() + // set new location if database don't save timezone or give an incorrect timezone + if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location + session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) + t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), + t.Minute(), t.Second(), t.Nanosecond(), dbTZ) } + t = t.In(session.engine.TZLocation) + fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) + } else if rawValueType == core.IntType || rawValueType == core.Int64Type || + rawValueType == core.Int32Type { hasAssigned = true - if len(table.PrimaryKeys) != 1 { - return nil, errors.New("unsupported non or composited primary key cascade") - } - var pk = make(core.PK, len(table.PrimaryKeys)) - pk[0], err = asKind(vv, rawValueType) - if err != nil { - return nil, err - } - if !isPKZero(pk) { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - structInter := reflect.New(fieldValue.Type()) - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) + t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) + fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) + } else { + if d, ok := vv.Interface().([]uint8); ok { + hasAssigned = true + t, err := session.byte2Time(col, d) if err != nil { - return nil, err + session.engine.logger.Error("byte2Time error:", err.Error()) + hasAssigned = false + } else { + fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) } - if has { - fieldValue.Set(structInter.Elem()) + } else if d, ok := vv.Interface().(string); ok { + hasAssigned = true + t, err := session.str2Time(col, d) + if err != nil { + session.engine.logger.Error("byte2Time error:", err.Error()) + hasAssigned = false } else { - return nil, errors.New("cascade obj is not exist") + fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) } + } else { + return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface()) } } - case reflect.Ptr: - // !nashtsai! TODO merge duplicated codes above - switch fieldType { - // following types case matching ptr's native type, therefore assign ptr directly - case core.PtrStringType: - if rawValueType.Kind() == reflect.String { - x := vv.String() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrBoolType: - if rawValueType.Kind() == reflect.Bool { - x := vv.Bool() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrTimeType: - if rawValueType == core.PtrTimeType { - hasAssigned = true - var x = rawValue.Interface().(time.Time) - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrFloat64Type: - if rawValueType.Kind() == reflect.Float64 { - x := vv.Float() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrUint64Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint64(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt64Type: - if rawValueType.Kind() == reflect.Int64 { - x := vv.Int() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrFloat32Type: - if rawValueType.Kind() == reflect.Float64 { - var x = float32(vv.Float()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrIntType: - if rawValueType.Kind() == reflect.Int64 { - var x = int(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt32Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int32(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt8Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int8(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt16Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int16(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrUintType: - if rawValueType.Kind() == reflect.Int64 { - var x = uint(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrUint32Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint32(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.Uint8Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint8(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.Uint16Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint16(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.Complex64Type: - var x complex64 + } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { + // !! 增加支持sql.Scanner接口的结构,如sql.NullString + hasAssigned = true + if err := nulVal.Scan(vv.Interface()); err != nil { + session.engine.logger.Error("sql.Sanner error:", err.Error()) + hasAssigned = false + } + } else if col.SQLType.IsJson() { + if rawValueType.Kind() == reflect.String { + hasAssigned = true + x := reflect.New(fieldType) if len([]byte(vv.String())) > 0 { - err := json.Unmarshal([]byte(vv.String()), &x) + err := json.Unmarshal([]byte(vv.String()), x.Interface()) if err != nil { return nil, err } - fieldValue.Set(reflect.ValueOf(&x)) + fieldValue.Set(x.Elem()) } + } else if rawValueType.Kind() == reflect.Slice { hasAssigned = true - case core.Complex128Type: - var x complex128 - if len([]byte(vv.String())) > 0 { - err := json.Unmarshal([]byte(vv.String()), &x) + x := reflect.New(fieldType) + if len(vv.Bytes()) > 0 { + err := json.Unmarshal(vv.Bytes(), x.Interface()) if err != nil { return nil, err } - fieldValue.Set(reflect.ValueOf(&x)) + fieldValue.Set(x.Elem()) } - hasAssigned = true - } // switch fieldType - } // switch fieldType.Kind() - - // !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value - if !hasAssigned { - data, err := value2Bytes(&rawValue) + } + } else if session.statement.UseCascade { + table, err := session.engine.autoMapType(*fieldValue) if err != nil { return nil, err } - if err = session.bytes2Value(col, fieldValue, data); err != nil { + hasAssigned = true + if len(table.PrimaryKeys) != 1 { + return nil, errors.New("unsupported non or composited primary key cascade") + } + var pk = make(core.PK, len(table.PrimaryKeys)) + pk[0], err = asKind(vv, rawValueType) + if err != nil { return nil, err } + + if !isPKZero(pk) { + // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch + // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne + // property to be fetched lazily + structInter := reflect.New(fieldValue.Type()) + has, err := session.ID(pk).NoCascade().get(structInter.Interface()) + if err != nil { + return nil, err + } + if has { + fieldValue.Set(structInter.Elem()) + } else { + return nil, errors.New("cascade obj is not exist") + } + } + } + case reflect.Ptr: + // !nashtsai! TODO merge duplicated codes above + switch fieldType { + // following types case matching ptr's native type, therefore assign ptr directly + case core.PtrStringType: + if rawValueType.Kind() == reflect.String { + x := vv.String() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrBoolType: + if rawValueType.Kind() == reflect.Bool { + x := vv.Bool() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrTimeType: + if rawValueType == core.PtrTimeType { + hasAssigned = true + var x = rawValue.Interface().(time.Time) + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrFloat64Type: + if rawValueType.Kind() == reflect.Float64 { + x := vv.Float() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUint64Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint64(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt64Type: + if rawValueType.Kind() == reflect.Int64 { + x := vv.Int() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrFloat32Type: + if rawValueType.Kind() == reflect.Float64 { + var x = float32(vv.Float()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrIntType: + if rawValueType.Kind() == reflect.Int64 { + var x = int(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt32Type: + if rawValueType.Kind() == reflect.Int64 { + var x = int32(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt8Type: + if rawValueType.Kind() == reflect.Int64 { + var x = int8(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt16Type: + if rawValueType.Kind() == reflect.Int64 { + var x = int16(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUintType: + if rawValueType.Kind() == reflect.Int64 { + var x = uint(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUint32Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint32(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.Uint8Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint8(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.Uint16Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint16(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.Complex64Type: + var x complex64 + if len([]byte(vv.String())) > 0 { + err := json.Unmarshal([]byte(vv.String()), &x) + if err != nil { + return nil, err + } + fieldValue.Set(reflect.ValueOf(&x)) + } + hasAssigned = true + case core.Complex128Type: + var x complex128 + if len([]byte(vv.String())) > 0 { + err := json.Unmarshal([]byte(vv.String()), &x) + if err != nil { + return nil, err + } + fieldValue.Set(reflect.ValueOf(&x)) + } + hasAssigned = true + } // switch fieldType + } // switch fieldType.Kind() + + // !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value + if !hasAssigned { + data, err := value2Bytes(&rawValue) + if err != nil { + return nil, err + } + + if err = session.bytes2Value(col, fieldValue, data); err != nil { + return nil, err } } } diff --git a/session_cols_test.go b/session_cols_test.go index 6ec17130b..3315e6b55 100644 --- a/session_cols_test.go +++ b/session_cols_test.go @@ -67,3 +67,43 @@ func TestCols(t *testing.T) { assert.EqualValues(t, "", tb.Col1) assert.EqualValues(t, "", tb.Col2) } + +func TestMustCol(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type CustomerUpdate struct { + Id int64 `form:"id" json:"id"` + Username string `form:"username" json:"username" binding:"required"` + Email string `form:"email" json:"email"` + Sex int `form:"sex" json:"sex"` + Name string `form:"name" json:"name" binding:"required"` + Telephone string `form:"telephone" json:"telephone"` + Type int `form:"type" json:"type" binding:"required"` + ParentId int64 `form:"parent_id" json:"parent_id" xorm:"int null"` + Remark string `form:"remark" json:"remark"` + Status int `form:"status" json:"status" binding:"required"` + Age int `form:"age" json:"age"` + CreatedAt int64 `xorm:"created" form:"created_at" json:"created_at"` + UpdatedAt int64 `xorm:"updated" form:"updated_at" json:"updated_at"` + BirthDate int64 `form:"birth_date" json:"birth_date"` + Password string `xorm:"varchar(200)" form:"password" json:"password"` + } + + assertSync(t, new(CustomerUpdate)) + + var customer = CustomerUpdate{ + ParentId: 1, + } + cnt, err := testEngine.Insert(&customer) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + type CustomerOnlyId struct { + Id int64 + } + + customer.ParentId = 0 + affected, err := testEngine.MustCols("parent_id").Update(&customer, &CustomerOnlyId{Id: customer.Id}) + assert.NoError(t, err) + assert.EqualValues(t, 1, affected) +}