diff --git a/engine_table.go b/engine_table.go index 1319871f3..94871a4bc 100644 --- a/engine_table.go +++ b/engine_table.go @@ -45,16 +45,17 @@ func (session *Session) tbNameNoSchema(table *core.Table) string { } func (engine *Engine) tbNameForMap(v reflect.Value) string { - t := v.Type() - if tb, ok := v.Interface().(TableName); ok { - return tb.TableName() + if v.Type().Implements(tpTableName) { + return v.Interface().(TableName).TableName() } - if v.CanAddr() { - if tb, ok := v.Addr().Interface().(TableName); ok { - return tb.TableName() + if v.Kind() == reflect.Ptr { + v = v.Elem() + if v.Type().Implements(tpTableName) { + return v.Interface().(TableName).TableName() } } - return engine.TableMapper.Obj2Table(t.Name()) + + return engine.TableMapper.Obj2Table(v.Type().Name()) } func (engine *Engine) tbNameNoSchema(tablename interface{}) string { @@ -97,6 +98,9 @@ func (engine *Engine) tbNameNoSchema(tablename interface{}) string { return tablename.(TableName).TableName() case string: return tablename.(string) + case reflect.Value: + v := tablename.(reflect.Value) + return engine.tbNameForMap(v) default: v := rValue(tablename) t := v.Type() diff --git a/session_find.go b/session_find.go index 4323dc7eb..758ef2127 100644 --- a/session_find.go +++ b/session_find.go @@ -75,7 +75,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Elem().Kind() == reflect.Struct { pv := reflect.New(sliceElementType.Elem()) - if err := session.statement.setRefValue(pv.Elem()); err != nil { + if err := session.statement.setRefValue(pv); err != nil { return err } } else { @@ -83,7 +83,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } } else if sliceElementType.Kind() == reflect.Struct { pv := reflect.New(sliceElementType) - if err := session.statement.setRefValue(pv.Elem()); err != nil { + if err := session.statement.setRefValue(pv); err != nil { return err } } else { diff --git a/session_find_test.go b/session_find_test.go index 4db7f9ce3..6a04dc57a 100644 --- a/session_find_test.go +++ b/session_find_test.go @@ -584,3 +584,76 @@ func TestFindAndCountOneFunc(t *testing.T) { assert.EqualValues(t, 1, len(results)) assert.EqualValues(t, 1, cnt) } + +type FindMapDevice struct { + Deviceid string `xorm:"pk"` + Status int +} + +func (device *FindMapDevice) TableName() string { + return "devices" +} + +func TestFindMapStringId(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(FindMapDevice)) + + cnt, err := testEngine.Insert(&FindMapDevice{ + Deviceid: "1", + Status: 1, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + deviceIDs := []string{"1"} + + deviceMaps := make(map[string]*FindMapDevice, len(deviceIDs)) + err = testEngine. + Where("status = ?", 1). + In("deviceid", deviceIDs). + Find(&deviceMaps) + assert.NoError(t, err) + + deviceMaps2 := make(map[string]FindMapDevice, len(deviceIDs)) + err = testEngine. + Where("status = ?", 1). + In("deviceid", deviceIDs). + Find(&deviceMaps2) + assert.NoError(t, err) + + devices := make([]*FindMapDevice, 0, len(deviceIDs)) + err = testEngine.Find(&devices) + assert.NoError(t, err) + + devices2 := make([]FindMapDevice, 0, len(deviceIDs)) + err = testEngine.Find(&devices2) + assert.NoError(t, err) + + var device FindMapDevice + has, err := testEngine.Get(&device) + assert.NoError(t, err) + assert.True(t, has) + + has, err = testEngine.Exist(&FindMapDevice{}) + assert.NoError(t, err) + assert.True(t, has) + + cnt, err = testEngine.Count(new(FindMapDevice)) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = testEngine.ID("1").Update(&FindMapDevice{ + Status: 2, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + sum, err := testEngine.SumInt(new(FindMapDevice), "status") + assert.NoError(t, err) + assert.EqualValues(t, 2, sum) + + cnt, err = testEngine.ID("1").Delete(new(FindMapDevice)) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + +} diff --git a/session_insert.go b/session_insert.go index 31bee5bcb..6d1851282 100644 --- a/session_insert.go +++ b/session_insert.go @@ -66,7 +66,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error return 0, errors.New("could not insert a empty slice") } - if err := session.statement.setRefValue(reflect.ValueOf(sliceValue.Index(0).Interface())); err != nil { + if err := session.statement.setRefBean(sliceValue.Index(0).Interface()); err != nil { return 0, err } diff --git a/statement.go b/statement.go index f6827b38a..15b9048aa 100644 --- a/statement.go +++ b/statement.go @@ -208,7 +208,7 @@ func (statement *Statement) setRefValue(v reflect.Value) error { if err != nil { return err } - statement.tableName = statement.Engine.TableName(v.Interface(), true) + statement.tableName = statement.Engine.TableName(v, true) return nil }