diff --git a/manager.go b/manager.go index 0431c5a..43d44cd 100644 --- a/manager.go +++ b/manager.go @@ -119,7 +119,7 @@ func (m *Manager) NewSession(r, w interface{}) (*Session, error) { manager: m, reader: r, writer: w, - values: make(map[string]interface{}), + cache: nil, } // Write cookie. if err := sess.WriteCookie(id); err != nil { @@ -166,7 +166,7 @@ func (m *Manager) Acquire(c context.Context, r, w interface{}) (*Session, error) reader: r, writer: w, id: ck.Value, - values: make(map[string]interface{}), + cache: nil, }, nil } diff --git a/manager_test.go b/manager_test.go index 3983257..f7e3dc4 100644 --- a/manager_test.go +++ b/manager_test.go @@ -15,7 +15,6 @@ func newMockStore() *MockStore { return &MockStore{ id: mockSessionID, data: map[string]interface{}{}, - temp: map[string]interface{}{}, err: nil, } } diff --git a/session.go b/session.go index 998fc0a..842ac8d 100644 --- a/session.go +++ b/session.go @@ -3,14 +3,16 @@ package simplesessions import ( "errors" "net/http" + "sync" "time" ) // Session provides the object to get, set, or clear session data. type Session struct { - // Map to store session data, loaded using `LoadValues` method. + // Map to store session data, loaded using `CacheAll` method. // All `Get` methods checks here before fetching from the store. - values map[string]interface{} + cache map[string]interface{} + cacheMux sync.RWMutex // Session manager. manager *Manager @@ -82,49 +84,126 @@ func (s *Session) ID() string { return s.id } -// LoadValues loads session values into memory for quick access. +// getCacheAll returns a copy of cached map. +func (s *Session) getCacheAll() map[string]interface{} { + s.cacheMux.RLock() + defer s.cacheMux.RUnlock() + + if s.cache == nil { + return nil + } + + out := map[string]interface{}{} + for k, v := range s.cache { + out[k] = v + } + + return out +} + +// getCacheAll returns a map of values for the given list of keys. +// If key doesn't exist then ErrFieldNotFound is returned. +func (s *Session) getCache(key ...string) map[string]interface{} { + s.cacheMux.RLock() + defer s.cacheMux.RUnlock() + + if s.cache == nil { + return nil + } + + out := map[string]interface{}{} + for _, k := range key { + v, ok := s.cache[k] + if ok { + out[k] = v + } else { + out[k] = ErrFieldNotFound + } + } + + return out +} + +// setCache sets a cache for given kv pairs. +func (s *Session) setCache(data map[string]interface{}) { + s.cacheMux.Lock() + defer s.cacheMux.Unlock() + + // If cacheAll() is not called the don't maintain cache. + if s.cache == nil { + return + } + + for k, v := range data { + s.cache[k] = v + } +} + +// deleteCache sets a cache for given kv pairs. +func (s *Session) deleteCache(key ...string) { + s.cacheMux.Lock() + defer s.cacheMux.Unlock() + + // If cacheAll() is not called the don't maintain cache. + if s.cache == nil { + return + } + + for _, k := range key { + delete(s.cache, k) + } +} + +// CacheAll loads session values into memory for quick access. // Ideal for centralized session fetching, e.g., in middleware. // Subsequent Get/GetMulti calls return cached values, avoiding store access. -// Use ResetValues() to ensure GetAll/Get/GetMulti fetches from the store. -// Set/SetMulti/Clear do not update the values, so this method must be called again for any changes. -func (s *Session) LoadValues() error { - var err error - s.values, err = s.GetAll() - return err +// Use ResetCache() to ensure GetAll/Get/GetMulti fetches from the store. +func (s *Session) CacheAll() error { + all, err := s.manager.store.GetAll(s.id) + if err != nil { + return err + } + + s.cacheMux.Lock() + defer s.cacheMux.Unlock() + s.cache = map[string]interface{}{} + for k, v := range all { + s.cache[k] = v + } + + return nil } -// ResetValues clears loaded values, ensuring subsequent Get, GetAll, and GetMulti calls fetch from the store. -func (s *Session) ResetValues() { - s.values = make(map[string]interface{}) +// ResetCache clears loaded values, ensuring subsequent Get, GetAll, and GetMulti calls fetch from the store. +func (s *Session) ResetCache() { + s.cacheMux.Lock() + defer s.cacheMux.Unlock() + s.cache = nil } // GetAll gets all the fields for the given session id. func (s *Session) GetAll() (map[string]interface{}, error) { - // Load value from map if its already loaded. - if len(s.values) > 0 { - return s.values, nil + // Try to get the values from cache. + c := s.getCacheAll() + if c != nil { + return c, nil } + // Get the values from store. out, err := s.manager.store.GetAll(s.id) return out, errAs(err) } // GetMulti retrieves values for multiple session fields. // If a field is not found in the store then its returned as nil. -func (s *Session) GetMulti(keys ...string) (map[string]interface{}, error) { - // Load values from map if its already loaded - if len(s.values) > 0 { - vals := make(map[string]interface{}) - for _, k := range keys { - if v, ok := s.values[k]; ok { - vals[k] = v - } - } - - return vals, nil +func (s *Session) GetMulti(key ...string) (map[string]interface{}, error) { + // Try to get the values from cache. + c := s.getCache(key...) + if c != nil { + return c, nil } - out, err := s.manager.store.GetMulti(s.id, keys...) + out, err := s.manager.store.GetMulti(s.id, key...) return out, errAs(err) } @@ -132,10 +211,14 @@ func (s *Session) GetMulti(keys ...string) (map[string]interface{}, error) { // If the session is already loaded, it returns the value from the existing map. // Otherwise, it fetches the value from the store. func (s *Session) Get(key string) (interface{}, error) { - // Return value from map if already loaded. - if len(s.values) > 0 { - if val, ok := s.values[key]; ok { - return val, nil + // Try to get the values from cache. + c := s.getCache(key) + if c != nil { + err, ok := c[key].(error) + if ok { + return nil, err + } else { + return c[key], nil } } @@ -145,45 +228,41 @@ func (s *Session) Get(key string) (interface{}, error) { } // Set assigns a value to the given key in the session. -// The store determines whether to commit all values at once or store them individually. -// Use Commit() method to commit all values if the store doesn't immediately persist them. func (s *Session) Set(key string, val interface{}) error { err := s.manager.store.Set(s.id, key, val) + if err == nil { + s.setCache(map[string]interface{}{ + key: val, + }) + } return errAs(err) } // SetMulti assigns multiple values to the session. -// The store determines whether to commit all values at once or store them individually. -func (s *Session) SetMulti(values map[string]interface{}) error { - for k, v := range values { - if err := s.manager.store.Set(s.id, k, v); err != nil { - return errAs(err) - } - } - return nil -} - -// Commit persists all values to the store. -// The store determines whether to commit all values at once or store them individually. -func (s *Session) Commit() error { - if err := s.manager.store.Commit(s.id); err != nil { - return errAs(err) +func (s *Session) SetMulti(data map[string]interface{}) error { + err := s.manager.store.SetMulti(s.id, data) + if err == nil { + s.setCache(data) } - return nil + return errAs(err) } // Delete deletes a field from session. -func (s *Session) Delete(key string) error { - if err := s.manager.store.Delete(s.id, key); err != nil { - return errAs(err) +func (s *Session) Delete(key ...string) error { + err := s.manager.store.Delete(s.id, key...) + if err == nil { + s.deleteCache(key...) } - return nil + return errAs(err) } // Clear clears session data from store and clears the cookie. func (s *Session) Clear() error { - if err := s.manager.store.Clear(s.id); err != nil { + err := s.manager.store.Clear(s.id) + if err != nil { return errAs(err) + } else { + s.ResetCache() } return s.clearCookie() } diff --git a/session_test.go b/session_test.go index d2051b3..ec20148 100644 --- a/session_test.go +++ b/session_test.go @@ -93,7 +93,7 @@ func TestSessionNewSession(t *testing.T) { assert.Equal(t, mgr, sess.manager) assert.Equal(t, reader, sess.reader) assert.Equal(t, writer, sess.writer) - assert.NotNil(t, sess.values) + assert.Nil(t, sess.cache) assert.Equal(t, mockSessionID, sess.id) assert.Equal(t, sess.id, sess.ID()) } @@ -204,7 +204,7 @@ func TestSessionClearCookie(t *testing.T) { assert.True(t, receCk.Expires.UnixNano() < time.Now().UnixNano()) } -func TestSessionLoadValues(t *testing.T) { +func TestSessionCacheAll(t *testing.T) { str := newMockStore() str.data = map[string]interface{}{ "key1": 1, @@ -215,12 +215,20 @@ func TestSessionLoadValues(t *testing.T) { sess, err := mgr.NewSession(nil, nil) assert.NoError(t, err) - err = sess.LoadValues() + // Test error. + str.err = errors.New("store error") + err = sess.CacheAll() + assert.ErrorIs(t, str.err, err) + assert.Nil(t, sess.cache) + + // Test without error. + str.err = nil + err = sess.CacheAll() assert.NoError(t, err) - assert.Equal(t, str.data, sess.values) + assert.Equal(t, str.data, sess.cache) } -func TestSessionResetValues(t *testing.T) { +func TestSessionResetCache(t *testing.T) { str := newMockStore() str.data = map[string]interface{}{ "key1": 1, @@ -228,11 +236,11 @@ func TestSessionResetValues(t *testing.T) { } mgr := newMockManager(str) sess, _ := mgr.NewSession(nil, nil) - sess.LoadValues() - assert.NotEqual(t, 0, len(sess.values)) + sess.CacheAll() + assert.Equal(t, str.data, sess.cache) - sess.ResetValues() - assert.Equal(t, 0, len(sess.values)) + sess.ResetCache() + assert.Nil(t, sess.cache) } func TestSessionGetStore(t *testing.T) { @@ -267,13 +275,13 @@ func TestSessionGetStore(t *testing.T) { assert.Equal(t, str.data["key1"], v3) } -func TestSessionGetLoaded(t *testing.T) { +func TestSessionGetCached(t *testing.T) { str := newMockStore() mgr := newMockManager(str) sess, err := mgr.NewSession(nil, nil) assert.NoError(t, err) - sess.values = map[string]interface{}{ + sess.cache = map[string]interface{}{ "key1": 1, "key2": 2, "key3": 3, @@ -282,22 +290,42 @@ func TestSessionGetLoaded(t *testing.T) { // GetAll. v1, err := sess.GetAll() assert.NoError(t, err) - assert.Equal(t, sess.values, v1) + assert.Equal(t, sess.cache, v1) // GetMulti. v2, err := sess.GetMulti("key1", "key2") assert.NoError(t, err) assert.Contains(t, v2, "key1") - assert.Equal(t, sess.values["key1"], v2["key1"]) + assert.Equal(t, sess.cache["key1"], v2["key1"]) assert.Contains(t, v2, "key2") - assert.Equal(t, sess.values["key2"], v2["key2"]) + assert.Equal(t, sess.cache["key2"], v2["key2"]) assert.NotContains(t, v2, "key3") // Get. v3, err := sess.Get("key1") assert.NoError(t, err) - assert.Contains(t, sess.values, "key1") - assert.Equal(t, sess.values["key1"], v3) + assert.Contains(t, sess.cache, "key1") + assert.Equal(t, sess.cache["key1"], v3) + + // Get unknowm field. + _, err = sess.Get("key99") + assert.ErrorIs(t, ErrFieldNotFound, err) + + // GetMulti unknown fields + v4, err := sess.GetMulti("key1", "key2", "key99", "key100") + assert.NoError(t, err) + assert.Contains(t, v4, "key1") + assert.Equal(t, sess.cache["key1"], v4["key1"]) + assert.Contains(t, v4, "key99") + assert.Contains(t, v4, "key100") + + err, ok := v4["key99"].(error) + assert.True(t, ok) + assert.ErrorIs(t, ErrFieldNotFound, err) + + err, ok = v4["key100"].(error) + assert.True(t, ok) + assert.ErrorIs(t, ErrFieldNotFound, err) } func TestSessionSet(t *testing.T) { @@ -310,19 +338,18 @@ func TestSessionSet(t *testing.T) { err = sess.Set("key1", 1) assert.NoError(t, err) - // Check if its set on temp. - assert.Contains(t, str.temp, "key1") - assert.NotContains(t, str.data, "key1") - assert.Equal(t, str.temp["key1"], 1) - - // Commit. - err = sess.Commit() - assert.NoError(t, err) - // Check if its set on data after commit. assert.Contains(t, str.data, "key1") - assert.NotContains(t, str.temp, "key1") assert.Equal(t, 1, str.data["key1"]) + assert.Nil(t, sess.cache) + + // Cache and set. + err = sess.CacheAll() + assert.NoError(t, err) + err = sess.Set("key1", 1) + assert.NoError(t, err) + assert.NotNil(t, sess.cache) + assert.Equal(t, sess.cache, str.data) } func TestSessionSetMulti(t *testing.T) { @@ -340,43 +367,28 @@ func TestSessionSetMulti(t *testing.T) { err = sess.SetMulti(data) assert.NoError(t, err) - // Check if its set on temp. - assert.Contains(t, str.temp, "key1") - assert.Contains(t, str.temp, "key2") - assert.Contains(t, str.temp, "key3") - assert.NotContains(t, str.data, "key1") - assert.NotContains(t, str.data, "key2") - assert.NotContains(t, str.data, "key3") - assert.Equal(t, data["key1"], str.temp["key1"]) - assert.Equal(t, data["key2"], str.temp["key2"]) - assert.Equal(t, data["key3"], str.temp["key3"]) - - // Commit. - err = sess.Commit() - assert.NoError(t, err) - // Check if its set on data after commit. assert.Contains(t, str.data, "key1") assert.Contains(t, str.data, "key2") assert.Contains(t, str.data, "key3") - assert.NotContains(t, str.temp, "key1") - assert.NotContains(t, str.temp, "key2") - assert.NotContains(t, str.temp, "key3") assert.Equal(t, data["key1"], str.data["key1"]) assert.Equal(t, data["key2"], str.data["key2"]) assert.Equal(t, data["key3"], str.data["key3"]) + assert.Nil(t, sess.cache) - // Test error. - str.err = errors.New("store error") + // Cache and set. + str.data = map[string]interface{}{} + err = sess.CacheAll() + assert.NoError(t, err) err = sess.SetMulti(data) - assert.ErrorIs(t, str.err, err) + assert.NoError(t, err) + assert.NotNil(t, sess.cache) + assert.Equal(t, sess.cache, str.data) // Test error. - str.err = nil - err = sess.SetMulti(data) - assert.NoError(t, err) + sess.ResetCache() str.err = errors.New("store error") - err = sess.Commit() + err = sess.SetMulti(data) assert.ErrorIs(t, str.err, err) } @@ -396,6 +408,14 @@ func TestSessionDelete(t *testing.T) { assert.NoError(t, err) assert.NotContains(t, str.data, "key1") + // Cache and set. + err = sess.CacheAll() + assert.NoError(t, err) + err = sess.Delete("key2") + assert.NoError(t, err) + assert.NotNil(t, sess.cache) + assert.Equal(t, sess.cache, str.data) + // Test error. str.err = errors.New("store error") err = sess.Delete("key2") @@ -433,6 +453,24 @@ func TestSessionClear(t *testing.T) { err = sess.Clear() assert.NoError(t, err) assert.Equal(t, 0, len(str.data)) + assert.Nil(t, sess.cache) + + // Test clear. + str = newMockStore() + str.data = map[string]interface{}{ + "key1": 1, + "key2": 2, + } + mgr = newMockManager(str) + sess, err = mgr.NewSession(nil, nil) + assert.NoError(t, err) + err = sess.CacheAll() + assert.NoError(t, err) + assert.NotNil(t, sess.cache) + err = sess.Clear() + assert.NoError(t, err) + assert.Equal(t, 0, len(str.data)) + assert.Nil(t, sess.cache) // Test deleteCookie callback. var ( diff --git a/store.go b/store.go index 5f22791..476337c 100644 --- a/store.go +++ b/store.go @@ -3,30 +3,30 @@ package simplesessions // Store represents store interface. This interface can be // implemented to create various backend stores for session. type Store interface { - // Create creates new session in store and returns the cookie value. - Create() (cookieValue string, err error) + // Create creates new session in the store and returns the session ID. + Create() (id string, err error) // Get gets a value for given key from session. - Get(cookieValue, key string) (value interface{}, err error) + Get(id, key string) (value interface{}, err error) // GetMulti gets a maps of multiple values for given keys. - GetMulti(cookieValue string, keys ...string) (values map[string]interface{}, err error) + // If some fields are not found then return ErrFieldNotFound for that field. + GetMulti(id string, keys ...string) (data map[string]interface{}, err error) - // GetAll gets all key and value from session, - GetAll(cookieValue string) (values map[string]interface{}, err error) + // GetAll gets all key and value from session. + GetAll(id string) (data map[string]interface{}, err error) // Set sets an value for a field in session. - // Its up to store to either store it in session right after set or after commit. - Set(cookieValue, key string, value interface{}) error + Set(id, key string, value interface{}) error - // Commit commits all the previously set values to store. - Commit(cookieValue string) error + // Set takes a map of kv pair and set the field in store. + SetMulti(id string, data map[string]interface{}) error - // Delete a field from session. - Delete(cookieValue string, key string) error + // Delete a given list of keys from session. + Delete(id string, key ...string) error - // Clear clears the session key from backend if exists. - Clear(cookieValue string) error + // Clear clears the entire session. + Clear(id string) error // Helper method for typecasting/asserting. Int(interface{}, error) (int, error) diff --git a/store_test.go b/store_test.go index ab1f43b..e08dc59 100644 --- a/store_test.go +++ b/store_test.go @@ -5,14 +5,13 @@ type MockStore struct { err error id string data map[string]interface{} - temp map[string]interface{} } -func (s *MockStore) Create() (cv string, err error) { +func (s *MockStore) Create() (string, error) { return s.id, s.err } -func (s *MockStore) Get(cv, key string) (value interface{}, err error) { +func (s *MockStore) Get(id, key string) (interface{}, error) { if s.id == "" { return nil, ErrInvalidSession } @@ -21,24 +20,19 @@ func (s *MockStore) Get(cv, key string) (value interface{}, err error) { if !ok { return nil, ErrFieldNotFound } - return d, s.err } -func (s *MockStore) GetMulti(cv string, keys ...string) (values map[string]interface{}, err error) { +func (s *MockStore) GetMulti(id string, keys ...string) (values map[string]interface{}, err error) { if s.id == "" { return nil, ErrInvalidSession } out := make(map[string]interface{}) for _, key := range keys { - v, err := s.Get(cv, key) - if err != nil { - if err == ErrFieldNotFound { - v = nil - } else { - return nil, err - } + v, ok := s.data[key] + if !ok { + v = err } out[key] = v } @@ -46,7 +40,7 @@ func (s *MockStore) GetMulti(cv string, keys ...string) (values map[string]inter return out, s.err } -func (s *MockStore) GetAll(cv string) (values map[string]interface{}, err error) { +func (s *MockStore) GetAll(id string) (values map[string]interface{}, err error) { if s.id == "" { return nil, ErrInvalidSession } @@ -59,39 +53,38 @@ func (s *MockStore) Set(cv, key string, value interface{}) error { return ErrInvalidSession } - s.temp[key] = value + s.data[key] = value return s.err } -func (s *MockStore) Commit(cv string) error { +func (s *MockStore) SetMulti(id string, data map[string]interface{}) error { if s.id == "" { return ErrInvalidSession } - for key, val := range s.temp { - s.data[key] = val + for k, v := range data { + s.data[k] = v } - s.temp = map[string]interface{}{} - return s.err } -func (s *MockStore) Delete(cv string, key string) error { +func (s *MockStore) Delete(id string, key ...string) error { if s.id == "" { return ErrInvalidSession } - s.temp = nil - delete(s.data, key) - delete(s.temp, key) + + for _, k := range key { + delete(s.data, k) + } return s.err } -func (s *MockStore) Clear(cv string) error { +func (s *MockStore) Clear(id string) error { if s.id == "" { return ErrInvalidSession } + s.data = map[string]interface{}{} - s.temp = map[string]interface{}{} return s.err }