Skip to content

Commit

Permalink
fix update map with table name (go-xorm#888)
Browse files Browse the repository at this point in the history
* fix update map with table name

* fix bug update map when cache enabled

* refactor cacheInsert

* fix cache test
  • Loading branch information
lunny authored Apr 11, 2018
1 parent bfdf773 commit 636ccef
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 101 deletions.
89 changes: 37 additions & 52 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,35 @@ type Engine struct {
tagHandlers map[string]tagHandler

engineGroup *EngineGroup

cachers map[string]core.Cacher
cacherLock sync.RWMutex
}

func (engine *Engine) setCacher(tableName string, cacher core.Cacher) {
engine.cacherLock.Lock()
engine.cachers[tableName] = cacher
engine.cacherLock.Unlock()
}

func (engine *Engine) SetCacher(tableName string, cacher core.Cacher) {
engine.setCacher(tableName, cacher)
}

func (engine *Engine) getCacher(tableName string) core.Cacher {
var cacher core.Cacher
var ok bool
engine.cacherLock.RLock()
cacher, ok = engine.cachers[tableName]
engine.cacherLock.RUnlock()
if !ok && !engine.disableGlobalCache {
cacher = engine.Cacher
}
return cacher
}

func (engine *Engine) GetCacher(tableName string) core.Cacher {
return engine.getCacher(tableName)
}

// BufferSize sets buffer size for iterate
Expand Down Expand Up @@ -245,13 +274,7 @@ func (engine *Engine) NoCascade() *Session {

// MapCacher Set a table use a special cacher
func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) error {
v := rValue(bean)
tb, err := engine.autoMapType(v)
if err != nil {
return err
}

tb.Cacher = cacher
engine.setCacher(engine.TableName(bean, true), cacher)
return nil
}

Expand Down Expand Up @@ -834,15 +857,6 @@ func addIndex(indexName string, table *core.Table, col *core.Column, indexType i
}
}

func (engine *Engine) newTable() *core.Table {
table := core.NewEmptyTable()

if !engine.disableGlobalCache {
table.Cacher = engine.Cacher
}
return table
}

// TableName table name interface to define customerize table name
type TableName interface {
TableName() string
Expand All @@ -854,7 +868,7 @@ var (

func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
t := v.Type()
table := engine.newTable()
table := core.NewEmptyTable()
table.Type = t
table.Name = engine.tbNameForMap(v)

Expand Down Expand Up @@ -1010,15 +1024,15 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
if hasCacheTag {
if engine.Cacher != nil { // !nash! use engine's cacher if provided
engine.logger.Info("enable cache on table:", table.Name)
table.Cacher = engine.Cacher
engine.setCacher(table.Name, engine.Cacher)
} else {
engine.logger.Info("enable LRU cache on table:", table.Name)
table.Cacher = NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) // !nashtsai! HACK use LRU cacher for now
engine.setCacher(table.Name, NewLRUCacher2(NewMemoryStore(), time.Hour, 10000))
}
}
if hasNoCacheTag {
engine.logger.Info("no cache on table:", table.Name)
table.Cacher = nil
engine.logger.Info("disable cache on table:", table.Name)
engine.setCacher(table.Name, nil)
}

return table, nil
Expand Down Expand Up @@ -1123,26 +1137,10 @@ func (engine *Engine) CreateUniques(bean interface{}) error {
return session.CreateUniques(bean)
}

func (engine *Engine) getCacher2(table *core.Table) core.Cacher {
return table.Cacher
}

// ClearCacheBean if enabled cache, clear the cache bean
func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
v := rValue(bean)
t := v.Type()
if t.Kind() != reflect.Struct {
return errors.New("error params")
}
tableName := engine.TableName(bean)
table, err := engine.autoMapType(v)
if err != nil {
return err
}
cacher := table.Cacher
if cacher == nil {
cacher = engine.Cacher
}
cacher := engine.getCacher(tableName)
if cacher != nil {
cacher.ClearIds(tableName)
cacher.DelBean(tableName, id)
Expand All @@ -1153,21 +1151,8 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
// ClearCache if enabled cache, clear some tables' cache
func (engine *Engine) ClearCache(beans ...interface{}) error {
for _, bean := range beans {
v := rValue(bean)
t := v.Type()
if t.Kind() != reflect.Struct {
return errors.New("error params")
}
tableName := engine.TableName(bean)
table, err := engine.autoMapType(v)
if err != nil {
return err
}

cacher := table.Cacher
if cacher == nil {
cacher = engine.Cacher
}
cacher := engine.getCacher(tableName)
if cacher != nil {
cacher.ClearIds(tableName)
cacher.ClearBeans(tableName)
Expand Down
2 changes: 2 additions & 0 deletions interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ type EngineInterface interface {
Dialect() core.Dialect
DropTables(...interface{}) error
DumpAllToFile(fp string, tp ...core.DbType) error
GetCacher(string) core.Cacher
GetColumnMapper() core.IMapper
GetDefaultCacher() core.Cacher
GetTableMapper() core.IMapper
Expand All @@ -85,6 +86,7 @@ type EngineInterface interface {
NewSession() *Session
NoAutoTime() *Session
Quote(string) string
SetCacher(string, core.Cacher)
SetDefaultCacher(core.Cacher)
SetLogLevel(core.LogLevel)
SetMapper(core.IMapper)
Expand Down
4 changes: 2 additions & 2 deletions session_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string,
return ErrCacheFailed
}

cacher := session.engine.getCacher2(table)
cacher := session.engine.getCacher(tableName)
pkColumns := table.PKColumns()
ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
if err != nil {
Expand Down Expand Up @@ -199,7 +199,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
})
}

if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache {
session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...)
}

Expand Down
10 changes: 7 additions & 3 deletions session_find.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
}

if session.canCache() {
if cacher := session.engine.getCacher2(table); cacher != nil &&
if cacher := session.engine.getCacher(table.Name); cacher != nil &&
!session.statement.IsDistinct &&
!session.statement.unscoped {
err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...)
Expand Down Expand Up @@ -321,6 +321,12 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
return ErrCacheFailed
}

tableName := session.statement.TableName()
cacher := session.engine.getCacher(tableName)
if cacher == nil {
return nil
}

for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable)
}
Expand All @@ -330,9 +336,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
return ErrCacheFailed
}

tableName := session.statement.TableName()
table := session.statement.RefTable
cacher := session.engine.getCacher2(table)
ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
if err != nil {
rows, err := session.queryRows(newsql, args...)
Expand Down
5 changes: 3 additions & 2 deletions session_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
table := session.statement.RefTable

if session.canCache() && beanValue.Elem().Kind() == reflect.Struct {
if cacher := session.engine.getCacher2(table); cacher != nil &&
if cacher := session.engine.getCacher(table.Name); cacher != nil &&
!session.statement.unscoped {
has, err := session.cacheGet(bean, sqlStr, args...)
if err != ErrCacheFailed {
Expand Down Expand Up @@ -134,8 +134,9 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
return false, ErrCacheFailed
}

cacher := session.engine.getCacher2(session.statement.RefTable)
tableName := session.statement.TableName()
cacher := session.engine.getCacher(tableName)

session.engine.logger.Debug("[cacheGet] find sql:", newsql, args)
table := session.statement.RefTable
ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
Expand Down
37 changes: 14 additions & 23 deletions session_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
return 0, err
}

if len(session.statement.TableName()) <= 0 {
tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}

Expand Down Expand Up @@ -205,7 +206,6 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error

var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)"
var statement string
var tableName = session.statement.TableName()
if session.engine.dialect.DBType() == core.ORACLE {
sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL"
temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (",
Expand All @@ -232,9 +232,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
return 0, err
}

if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(table, tableName)
}
session.cacheInsert(tableName)

lenAfterClosures := len(session.afterClosures)
for i := 0; i < size; i++ {
Expand Down Expand Up @@ -394,9 +392,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {

defer handleAfterInsertProcessorFunc(bean)

if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(table, tableName)
}
session.cacheInsert(tableName)

if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
Expand Down Expand Up @@ -439,9 +435,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
}
defer handleAfterInsertProcessorFunc(bean)

if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(table, tableName)
}
session.cacheInsert(tableName)

if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
Expand Down Expand Up @@ -482,9 +476,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {

defer handleAfterInsertProcessorFunc(bean)

if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(table, tableName)
}
session.cacheInsert(tableName)

if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
Expand Down Expand Up @@ -531,17 +523,16 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
return session.innerInsert(bean)
}

func (session *Session) cacheInsert(table *core.Table, tables ...string) error {
if table == nil {
return ErrCacheFailed
func (session *Session) cacheInsert(table string) error {
if !session.statement.UseCache {
return nil
}

cacher := session.engine.getCacher2(table)
for _, t := range tables {
session.engine.logger.Debug("[cache] clear sql:", t)
cacher.ClearIds(t)
cacher := session.engine.getCacher(table)
if cacher == nil {
return nil
}

session.engine.logger.Debug("[cache] clear sql:", table)
cacher.ClearIds(table)
return nil
}

Expand Down
13 changes: 6 additions & 7 deletions session_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
}
}

cacher := session.engine.getCacher2(table)
cacher := session.engine.getCacher(tableName)
session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:])
ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:])
if err != nil {
Expand Down Expand Up @@ -361,12 +361,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
}

if table != nil {
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
//session.cacheUpdate(table, tableName, sqlStr, args...)
cacher.ClearIds(tableName)
cacher.ClearBeans(tableName)
}
if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache {
//session.cacheUpdate(table, tableName, sqlStr, args...)
session.engine.logger.Debug("[cacheUpdate] clear table ", tableName)
cacher.ClearIds(tableName)
cacher.ClearBeans(tableName)
}

// handle after update processors
Expand Down
18 changes: 17 additions & 1 deletion session_update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1198,7 +1198,7 @@ func TestUpdateMapContent(t *testing.T) {
assert.EqualValues(t, 0, c1.Age)

cnt, err = testEngine.Table(new(UpdateMapContent)).ID(c.Id).Update(map[string]interface{}{
"age": 16,
"age": 16,
"is_man": false,
"gender": 2,
})
Expand All @@ -1212,4 +1212,20 @@ func TestUpdateMapContent(t *testing.T) {
assert.EqualValues(t, 16, c2.Age)
assert.EqualValues(t, false, c2.IsMan)
assert.EqualValues(t, 2, c2.Gender)

cnt, err = testEngine.Table(testEngine.TableName(new(UpdateMapContent))).ID(c.Id).Update(map[string]interface{}{
"age": 15,
"is_man": true,
"gender": 1,
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)

var c3 UpdateMapContent
has, err = testEngine.ID(c.Id).Get(&c3)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 15, c3.Age)
assert.EqualValues(t, true, c3.IsMan)
assert.EqualValues(t, 1, c3.Gender)
}
Loading

0 comments on commit 636ccef

Please sign in to comment.