Skip to content

Commit

Permalink
fix tablename bug (go-xorm#887)
Browse files Browse the repository at this point in the history
* fix tablename bug

* fix test
  • Loading branch information
lunny authored Apr 11, 2018
1 parent 5c2af83 commit bfdf773
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 11 deletions.
18 changes: 11 additions & 7 deletions engine_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,17 @@ func (session *Session) tbNameNoSchema(table *core.Table) string {
}

func (engine *Engine) tbNameForMap(v reflect.Value) string {
t := v.Type()
if tb, ok := v.Interface().(TableName); ok {
return tb.TableName()
if v.Type().Implements(tpTableName) {
return v.Interface().(TableName).TableName()
}
if v.CanAddr() {
if tb, ok := v.Addr().Interface().(TableName); ok {
return tb.TableName()
if v.Kind() == reflect.Ptr {
v = v.Elem()
if v.Type().Implements(tpTableName) {
return v.Interface().(TableName).TableName()
}
}
return engine.TableMapper.Obj2Table(t.Name())

return engine.TableMapper.Obj2Table(v.Type().Name())
}

func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
Expand Down Expand Up @@ -97,6 +98,9 @@ func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
return tablename.(TableName).TableName()
case string:
return tablename.(string)
case reflect.Value:
v := tablename.(reflect.Value)
return engine.tbNameForMap(v)
default:
v := rValue(tablename)
t := v.Type()
Expand Down
4 changes: 2 additions & 2 deletions session_find.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,15 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
if sliceElementType.Kind() == reflect.Ptr {
if sliceElementType.Elem().Kind() == reflect.Struct {
pv := reflect.New(sliceElementType.Elem())
if err := session.statement.setRefValue(pv.Elem()); err != nil {
if err := session.statement.setRefValue(pv); err != nil {
return err
}
} else {
tp = tpNonStruct
}
} else if sliceElementType.Kind() == reflect.Struct {
pv := reflect.New(sliceElementType)
if err := session.statement.setRefValue(pv.Elem()); err != nil {
if err := session.statement.setRefValue(pv); err != nil {
return err
}
} else {
Expand Down
73 changes: 73 additions & 0 deletions session_find_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -584,3 +584,76 @@ func TestFindAndCountOneFunc(t *testing.T) {
assert.EqualValues(t, 1, len(results))
assert.EqualValues(t, 1, cnt)
}

type FindMapDevice struct {
Deviceid string `xorm:"pk"`
Status int
}

func (device *FindMapDevice) TableName() string {
return "devices"
}

func TestFindMapStringId(t *testing.T) {
assert.NoError(t, prepareEngine())
assertSync(t, new(FindMapDevice))

cnt, err := testEngine.Insert(&FindMapDevice{
Deviceid: "1",
Status: 1,
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)

deviceIDs := []string{"1"}

deviceMaps := make(map[string]*FindMapDevice, len(deviceIDs))
err = testEngine.
Where("status = ?", 1).
In("deviceid", deviceIDs).
Find(&deviceMaps)
assert.NoError(t, err)

deviceMaps2 := make(map[string]FindMapDevice, len(deviceIDs))
err = testEngine.
Where("status = ?", 1).
In("deviceid", deviceIDs).
Find(&deviceMaps2)
assert.NoError(t, err)

devices := make([]*FindMapDevice, 0, len(deviceIDs))
err = testEngine.Find(&devices)
assert.NoError(t, err)

devices2 := make([]FindMapDevice, 0, len(deviceIDs))
err = testEngine.Find(&devices2)
assert.NoError(t, err)

var device FindMapDevice
has, err := testEngine.Get(&device)
assert.NoError(t, err)
assert.True(t, has)

has, err = testEngine.Exist(&FindMapDevice{})
assert.NoError(t, err)
assert.True(t, has)

cnt, err = testEngine.Count(new(FindMapDevice))
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)

cnt, err = testEngine.ID("1").Update(&FindMapDevice{
Status: 2,
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)

sum, err := testEngine.SumInt(new(FindMapDevice), "status")
assert.NoError(t, err)
assert.EqualValues(t, 2, sum)

cnt, err = testEngine.ID("1").Delete(new(FindMapDevice))
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)

}
2 changes: 1 addition & 1 deletion session_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
return 0, errors.New("could not insert a empty slice")
}

if err := session.statement.setRefValue(reflect.ValueOf(sliceValue.Index(0).Interface())); err != nil {
if err := session.statement.setRefBean(sliceValue.Index(0).Interface()); err != nil {
return 0, err
}

Expand Down
2 changes: 1 addition & 1 deletion statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ func (statement *Statement) setRefValue(v reflect.Value) error {
if err != nil {
return err
}
statement.tableName = statement.Engine.TableName(v.Interface(), true)
statement.tableName = statement.Engine.TableName(v, true)
return nil
}

Expand Down

0 comments on commit bfdf773

Please sign in to comment.