diff --git a/session_find.go b/session_find.go index a8ff6b7..7303f88 100644 --- a/session_find.go +++ b/session_find.go @@ -84,6 +84,8 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte } func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { + defer session.resetStatement() + if session.statement.lastError != nil { return session.statement.lastError } diff --git a/session_get.go b/session_get.go index ff1f855..4a92fe7 100644 --- a/session_get.go +++ b/session_get.go @@ -24,6 +24,8 @@ func (session *Session) Get(bean interface{}) (bool, error) { } func (session *Session) get(bean interface{}) (bool, error) { + defer session.resetStatement() + if session.statement.lastError != nil { return false, session.statement.lastError } @@ -86,6 +88,8 @@ func (session *Session) get(bean interface{}) (bool, error) { if context != nil { res := context.Get(fmt.Sprintf("%v-%v", sqlStr, args)) if res != nil { + session.engine.logger.Debug("hit context cache", sqlStr) + structValue := reflect.Indirect(reflect.ValueOf(bean)) structValue.Set(reflect.Indirect(reflect.ValueOf(res))) session.lastSQL = "" @@ -93,13 +97,16 @@ func (session *Session) get(bean interface{}) (bool, error) { return true, nil } } + has, err := session.nocacheGet(beanValue.Elem().Kind(), table, bean, sqlStr, args...) if err != nil || !has { return has, err } + if context != nil { context.Put(fmt.Sprintf("%v-%v", sqlStr, args), bean) } + return true, nil } @@ -138,6 +145,114 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bea vvv.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(Value(v))) } + return true, nil + case *string: + var res sql.NullString + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*string)) = res.String + } + return true, nil + case *int: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*int)) = int(res.Int64) + } + return true, nil + case *int8: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*int8)) = int8(res.Int64) + } + return true, nil + case *int16: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*int16)) = int16(res.Int64) + } + return true, nil + case *int32: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*int32)) = int32(res.Int64) + } + return true, nil + case *int64: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*int64)) = int64(res.Int64) + } + return true, nil + case *uint: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*uint)) = uint(res.Int64) + } + return true, nil + case *uint8: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*uint8)) = uint8(res.Int64) + } + return true, nil + case *uint16: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*uint16)) = uint16(res.Int64) + } + return true, nil + case *uint32: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*uint32)) = uint32(res.Int64) + } + return true, nil + case *uint64: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*uint64)) = uint64(res.Int64) + } + return true, nil + case *bool: + var res sql.NullBool + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*bool)) = res.Bool + } return true, nil } @@ -167,6 +282,9 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bea err = rows.ScanSlice(bean) case reflect.Map: err = rows.ScanMap(bean) + case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + err = rows.Scan(&bean) default: err = rows.Scan(bean) } diff --git a/session_get_test.go b/session_get_test.go index 0dadba5..87931e7 100644 --- a/session_get_test.go +++ b/session_get_test.go @@ -47,6 +47,12 @@ func TestGetVar(t *testing.T) { assert.Equal(t, true, has) assert.Equal(t, 28, age) + var ageMax int + has, err = testEngine.SQL("SELECT max(age) FROM "+testEngine.TableName("get_var", true)+" WHERE `id` = ?", data.Id).Get(&ageMax) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, 28, ageMax) + var age2 int64 has, err = testEngine.Table("get_var").Cols("age"). Where("age > ?", 20). @@ -56,6 +62,69 @@ func TestGetVar(t *testing.T) { assert.Equal(t, true, has) assert.EqualValues(t, 28, age2) + var age3 int8 + has, err = testEngine.Table("get_var").Cols("age").Get(&age3) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age3) + + var age4 int16 + has, err = testEngine.Table("get_var").Cols("age"). + Where("age > ?", 20). + And("age < ?", 30). + Get(&age4) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age4) + + var age5 int32 + has, err = testEngine.Table("get_var").Cols("age"). + Where("age > ?", 20). + And("age < ?", 30). + Get(&age5) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age5) + + var age6 int + has, err = testEngine.Table("get_var").Cols("age").Get(&age6) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age6) + + var age7 int64 + has, err = testEngine.Table("get_var").Cols("age"). + Where("age > ?", 20). + And("age < ?", 30). + Get(&age7) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age7) + + var age8 int8 + has, err = testEngine.Table("get_var").Cols("age").Get(&age8) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age8) + + var age9 int16 + has, err = testEngine.Table("get_var").Cols("age"). + Where("age > ?", 20). + And("age < ?", 30). + Get(&age9) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age9) + + var age10 int32 + has, err = testEngine.Table("get_var").Cols("age"). + Where("age > ?", 20). + And("age < ?", 30). + Get(&age10) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age10) + var id sql.NullInt64 has, err = testEngine.Table("get_var").Cols("id").Get(&id) assert.NoError(t, err) @@ -433,3 +502,85 @@ func TestGetCustomTableInterface(t *testing.T) { assert.NoError(t, err) assert.True(t, has) } + +func TestGetNullVar(t *testing.T) { + type TestGetNullVarStruct struct { + Id int64 + Name string + Age int + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(TestGetNullVarStruct)) + + affected, err := testEngine.Exec("insert into " + testEngine.TableName(new(TestGetNullVarStruct), true) + " (name,age) values (null,null)") + assert.NoError(t, err) + a, _ := affected.RowsAffected() + assert.EqualValues(t, 1, a) + + var name string + has, err := testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("name").Get(&name) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "", name) + + var age int + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age) + + var age2 int8 + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age2) + + var age3 int16 + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age3) + + var age4 int32 + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age4) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age4) + + var age5 int64 + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age5) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age5) + + var age6 uint + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age6) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age6) + + var age7 uint8 + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age7) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age7) + + var age8 int16 + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age8) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age8) + + var age9 int32 + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age9) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age9) + + var age10 int64 + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age10) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age10) +} diff --git a/xorm.go b/xorm.go index fdb3205..2fbaea0 100644 --- a/xorm.go +++ b/xorm.go @@ -20,7 +20,7 @@ import ( const ( // Version show the xorm's version - Version string = "0.7.4.0724" + Version string = "0.7.5.0803" ) func regDrvsNDialects() bool {