diff --git a/engine.go b/engine.go index 876004ba4..4984d3746 100644 --- a/engine.go +++ b/engine.go @@ -49,6 +49,35 @@ type Engine struct { tagHandlers map[string]tagHandler engineGroup *EngineGroup + + cachers map[string]core.Cacher + cacherLock sync.RWMutex +} + +func (engine *Engine) setCacher(tableName string, cacher core.Cacher) { + engine.cacherLock.Lock() + engine.cachers[tableName] = cacher + engine.cacherLock.Unlock() +} + +func (engine *Engine) SetCacher(tableName string, cacher core.Cacher) { + engine.setCacher(tableName, cacher) +} + +func (engine *Engine) getCacher(tableName string) core.Cacher { + var cacher core.Cacher + var ok bool + engine.cacherLock.RLock() + cacher, ok = engine.cachers[tableName] + engine.cacherLock.RUnlock() + if !ok && !engine.disableGlobalCache { + cacher = engine.Cacher + } + return cacher +} + +func (engine *Engine) GetCacher(tableName string) core.Cacher { + return engine.getCacher(tableName) } // BufferSize sets buffer size for iterate @@ -245,13 +274,7 @@ func (engine *Engine) NoCascade() *Session { // MapCacher Set a table use a special cacher func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) error { - v := rValue(bean) - tb, err := engine.autoMapType(v) - if err != nil { - return err - } - - tb.Cacher = cacher + engine.setCacher(engine.TableName(bean, true), cacher) return nil } @@ -834,15 +857,6 @@ func addIndex(indexName string, table *core.Table, col *core.Column, indexType i } } -func (engine *Engine) newTable() *core.Table { - table := core.NewEmptyTable() - - if !engine.disableGlobalCache { - table.Cacher = engine.Cacher - } - return table -} - // TableName table name interface to define customerize table name type TableName interface { TableName() string @@ -854,7 +868,7 @@ var ( func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { t := v.Type() - table := engine.newTable() + table := core.NewEmptyTable() table.Type = t table.Name = engine.tbNameForMap(v) @@ -1010,15 +1024,15 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { if hasCacheTag { if engine.Cacher != nil { // !nash! use engine's cacher if provided engine.logger.Info("enable cache on table:", table.Name) - table.Cacher = engine.Cacher + engine.setCacher(table.Name, engine.Cacher) } else { engine.logger.Info("enable LRU cache on table:", table.Name) - table.Cacher = NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) // !nashtsai! HACK use LRU cacher for now + engine.setCacher(table.Name, NewLRUCacher2(NewMemoryStore(), time.Hour, 10000)) } } if hasNoCacheTag { - engine.logger.Info("no cache on table:", table.Name) - table.Cacher = nil + engine.logger.Info("disable cache on table:", table.Name) + engine.setCacher(table.Name, nil) } return table, nil @@ -1123,26 +1137,10 @@ func (engine *Engine) CreateUniques(bean interface{}) error { return session.CreateUniques(bean) } -func (engine *Engine) getCacher2(table *core.Table) core.Cacher { - return table.Cacher -} - // ClearCacheBean if enabled cache, clear the cache bean func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { - v := rValue(bean) - t := v.Type() - if t.Kind() != reflect.Struct { - return errors.New("error params") - } tableName := engine.TableName(bean) - table, err := engine.autoMapType(v) - if err != nil { - return err - } - cacher := table.Cacher - if cacher == nil { - cacher = engine.Cacher - } + cacher := engine.getCacher(tableName) if cacher != nil { cacher.ClearIds(tableName) cacher.DelBean(tableName, id) @@ -1153,21 +1151,8 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { // ClearCache if enabled cache, clear some tables' cache func (engine *Engine) ClearCache(beans ...interface{}) error { for _, bean := range beans { - v := rValue(bean) - t := v.Type() - if t.Kind() != reflect.Struct { - return errors.New("error params") - } tableName := engine.TableName(bean) - table, err := engine.autoMapType(v) - if err != nil { - return err - } - - cacher := table.Cacher - if cacher == nil { - cacher = engine.Cacher - } + cacher := engine.getCacher(tableName) if cacher != nil { cacher.ClearIds(tableName) cacher.ClearBeans(tableName) diff --git a/interface.go b/interface.go index 39d50edd2..0bc12ba00 100644 --- a/interface.go +++ b/interface.go @@ -77,6 +77,7 @@ type EngineInterface interface { Dialect() core.Dialect DropTables(...interface{}) error DumpAllToFile(fp string, tp ...core.DbType) error + GetCacher(string) core.Cacher GetColumnMapper() core.IMapper GetDefaultCacher() core.Cacher GetTableMapper() core.IMapper @@ -85,6 +86,7 @@ type EngineInterface interface { NewSession() *Session NoAutoTime() *Session Quote(string) string + SetCacher(string, core.Cacher) SetDefaultCacher(core.Cacher) SetLogLevel(core.LogLevel) SetMapper(core.IMapper) diff --git a/session_delete.go b/session_delete.go index eb91614ce..d9cf3ea93 100644 --- a/session_delete.go +++ b/session_delete.go @@ -27,7 +27,7 @@ func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string, return ErrCacheFailed } - cacher := session.engine.getCacher2(table) + cacher := session.engine.getCacher(tableName) pkColumns := table.PKColumns() ids, err := core.GetCacheSql(cacher, tableName, newsql, args) if err != nil { @@ -199,7 +199,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { }) } - if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { + if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache { session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...) } diff --git a/session_find.go b/session_find.go index 758ef2127..46bbf26c9 100644 --- a/session_find.go +++ b/session_find.go @@ -176,7 +176,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } if session.canCache() { - if cacher := session.engine.getCacher2(table); cacher != nil && + if cacher := session.engine.getCacher(table.Name); cacher != nil && !session.statement.IsDistinct && !session.statement.unscoped { err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) @@ -321,6 +321,12 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in return ErrCacheFailed } + tableName := session.statement.TableName() + cacher := session.engine.getCacher(tableName) + if cacher == nil { + return nil + } + for _, filter := range session.engine.dialect.Filters() { sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable) } @@ -330,9 +336,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in return ErrCacheFailed } - tableName := session.statement.TableName() table := session.statement.RefTable - cacher := session.engine.getCacher2(table) ids, err := core.GetCacheSql(cacher, tableName, newsql, args) if err != nil { rows, err := session.queryRows(newsql, args...) diff --git a/session_get.go b/session_get.go index 58191de18..3b2c9493c 100644 --- a/session_get.go +++ b/session_get.go @@ -57,7 +57,7 @@ func (session *Session) get(bean interface{}) (bool, error) { table := session.statement.RefTable if session.canCache() && beanValue.Elem().Kind() == reflect.Struct { - if cacher := session.engine.getCacher2(table); cacher != nil && + if cacher := session.engine.getCacher(table.Name); cacher != nil && !session.statement.unscoped { has, err := session.cacheGet(bean, sqlStr, args...) if err != ErrCacheFailed { @@ -134,8 +134,9 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf return false, ErrCacheFailed } - cacher := session.engine.getCacher2(session.statement.RefTable) tableName := session.statement.TableName() + cacher := session.engine.getCacher(tableName) + session.engine.logger.Debug("[cacheGet] find sql:", newsql, args) table := session.statement.RefTable ids, err := core.GetCacheSql(cacher, tableName, newsql, args) diff --git a/session_insert.go b/session_insert.go index 6d1851282..c1182fe64 100644 --- a/session_insert.go +++ b/session_insert.go @@ -70,7 +70,8 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error return 0, err } - if len(session.statement.TableName()) <= 0 { + tableName := session.statement.TableName() + if len(tableName) <= 0 { return 0, ErrTableNotFound } @@ -205,7 +206,6 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)" var statement string - var tableName = session.statement.TableName() if session.engine.dialect.DBType() == core.ORACLE { sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL" temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (", @@ -232,9 +232,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error return 0, err } - if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { - session.cacheInsert(table, tableName) - } + session.cacheInsert(tableName) lenAfterClosures := len(session.afterClosures) for i := 0; i < size; i++ { @@ -394,9 +392,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { defer handleAfterInsertProcessorFunc(bean) - if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { - session.cacheInsert(table, tableName) - } + session.cacheInsert(tableName) if table.Version != "" && session.statement.checkVersion { verValue, err := table.VersionColumn().ValueOf(bean) @@ -439,9 +435,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } defer handleAfterInsertProcessorFunc(bean) - if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { - session.cacheInsert(table, tableName) - } + session.cacheInsert(tableName) if table.Version != "" && session.statement.checkVersion { verValue, err := table.VersionColumn().ValueOf(bean) @@ -482,9 +476,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { defer handleAfterInsertProcessorFunc(bean) - if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { - session.cacheInsert(table, tableName) - } + session.cacheInsert(tableName) if table.Version != "" && session.statement.checkVersion { verValue, err := table.VersionColumn().ValueOf(bean) @@ -531,17 +523,16 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { return session.innerInsert(bean) } -func (session *Session) cacheInsert(table *core.Table, tables ...string) error { - if table == nil { - return ErrCacheFailed +func (session *Session) cacheInsert(table string) error { + if !session.statement.UseCache { + return nil } - - cacher := session.engine.getCacher2(table) - for _, t := range tables { - session.engine.logger.Debug("[cache] clear sql:", t) - cacher.ClearIds(t) + cacher := session.engine.getCacher(table) + if cacher == nil { + return nil } - + session.engine.logger.Debug("[cache] clear sql:", table) + cacher.ClearIds(table) return nil } diff --git a/session_update.go b/session_update.go index 2115d67f7..84c7e7fec 100644 --- a/session_update.go +++ b/session_update.go @@ -40,7 +40,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, } } - cacher := session.engine.getCacher2(table) + cacher := session.engine.getCacher(tableName) session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:]) ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:]) if err != nil { @@ -361,12 +361,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - if table != nil { - if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { - //session.cacheUpdate(table, tableName, sqlStr, args...) - cacher.ClearIds(tableName) - cacher.ClearBeans(tableName) - } + if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache { + //session.cacheUpdate(table, tableName, sqlStr, args...) + session.engine.logger.Debug("[cacheUpdate] clear table ", tableName) + cacher.ClearIds(tableName) + cacher.ClearBeans(tableName) } // handle after update processors diff --git a/session_update_test.go b/session_update_test.go index 78f58df17..4183dc092 100644 --- a/session_update_test.go +++ b/session_update_test.go @@ -1198,7 +1198,7 @@ func TestUpdateMapContent(t *testing.T) { assert.EqualValues(t, 0, c1.Age) cnt, err = testEngine.Table(new(UpdateMapContent)).ID(c.Id).Update(map[string]interface{}{ - "age": 16, + "age": 16, "is_man": false, "gender": 2, }) @@ -1212,4 +1212,20 @@ func TestUpdateMapContent(t *testing.T) { assert.EqualValues(t, 16, c2.Age) assert.EqualValues(t, false, c2.IsMan) assert.EqualValues(t, 2, c2.Gender) + + cnt, err = testEngine.Table(testEngine.TableName(new(UpdateMapContent))).ID(c.Id).Update(map[string]interface{}{ + "age": 15, + "is_man": true, + "gender": 1, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var c3 UpdateMapContent + has, err = testEngine.ID(c.Id).Get(&c3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 15, c3.Age) + assert.EqualValues(t, true, c3.IsMan) + assert.EqualValues(t, 1, c3.Gender) } diff --git a/statement.go b/statement.go index 15b9048aa..38fa26d2e 100644 --- a/statement.go +++ b/statement.go @@ -950,14 +950,14 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, columnStr = "*" } - if err := statement.processIDParam(); err != nil { - return "", nil, err - } - if isStruct { if err := statement.mergeConds(bean); err != nil { return "", nil, err } + } else { + if err := statement.processIDParam(); err != nil { + return "", nil, err + } } condSQL, condArgs, err := builder.ToSQL(statement.cond) if err != nil { @@ -1143,7 +1143,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n } func (statement *Statement) processIDParam() error { - if statement.idParam == nil { + if statement.idParam == nil || statement.RefTable == nil { return nil } diff --git a/tag_cache_test.go b/tag_cache_test.go index 14a65fb80..30e2c51aa 100644 --- a/tag_cache_test.go +++ b/tag_cache_test.go @@ -19,9 +19,7 @@ func TestCacheTag(t *testing.T) { } assert.NoError(t, testEngine.CreateTables(&CacheDomain{})) - - table := testEngine.TableInfo(&CacheDomain{}) - assert.True(t, table.Cacher != nil) + assert.True(t, testEngine.GetCacher(testEngine.TableName(&CacheDomain{})) != nil) } func TestNoCacheTag(t *testing.T) { @@ -33,7 +31,5 @@ func TestNoCacheTag(t *testing.T) { } assert.NoError(t, testEngine.CreateTables(&NoCacheDomain{})) - - table := testEngine.TableInfo(&NoCacheDomain{}) - assert.True(t, table.Cacher == nil) + assert.True(t, testEngine.GetCacher(testEngine.TableName(&NoCacheDomain{})) == nil) } diff --git a/xorm.go b/xorm.go index b731aa5f4..8b2316bb3 100644 --- a/xorm.go +++ b/xorm.go @@ -90,6 +90,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { TagIdentifier: "xorm", TZLocation: time.Local, tagHandlers: defaultTagHandlers, + cachers: make(map[string]core.Cacher), } if uri.DbType == core.SQLITE {