From cf1786de8f70efbc1093f88219a3b9b82407bb03 Mon Sep 17 00:00:00 2001 From: Kailash Nadh Date: Tue, 21 May 2024 16:25:17 +0530 Subject: [PATCH 1/4] Refactor the session `AutoCreate` behaviour. This is a breaking change. --- manager.go | 55 ++++++++++++++++++-- manager_test.go | 79 +++++++++++++++++++--------- session.go | 136 ++++++------------------------------------------ session_test.go | 134 +++++++++++++++++++---------------------------- store_test.go | 3 +- 5 files changed, 176 insertions(+), 231 deletions(-) diff --git a/manager.go b/manager.go index e46cd51..93c9e73 100644 --- a/manager.go +++ b/manager.go @@ -32,8 +32,8 @@ type Manager struct { // Options are available options to configure Manager. type Options struct { - // DisableAutoSet skips creation of session cookie in frontend and new session in store if session is not already set. - DisableAutoSet bool + // If enabled, Acquire() will always create and return a new session if one doesn't already exist. + EnableAutoCreate bool // CookieName sets http cookie name. This is also sent as cookie name in `GetCookie` callback. CookieName string @@ -94,6 +94,36 @@ func (m *Manager) RegisterSetCookie(cb func(*http.Cookie, interface{}) error) { m.setCookieCb = cb } +// NewSession creates a new session. Reads cookie info from `GetCookie“ callback +// and validate the session with current store. If cookie not set then it creates +// new session and calls `SetCookie“ callback. If `DisableAutoSet` is set then it +// skips new session creation and should be manually done using `Create` method. +// If a cookie is found but its invalid in store then `ErrInvalidSession` error is returned. +func (m *Manager) NewSession(r, w interface{}) (*Session, error) { + var ( + sess = &Session{ + manager: m, + reader: r, + writer: w, + values: make(map[string]interface{}), + } + ) + + // Create new cookie in store and write to front. + // Store also calls `WriteCookie`` to write to http interface. + id, err := m.store.Create() + if err != nil { + return nil, errAs(err) + } + + // Write cookie. + if err := sess.WriteCookie(id); err != nil { + return nil, err + } + + return sess, nil +} + // Acquire gets a `Session` for current session cookie from store. // If `Session` is not found on store then it creates a new session and sets on store. // If 'DisableAutoSet` is set in options then session has to be explicitly created before @@ -124,5 +154,24 @@ func (m *Manager) Acquire(r, w interface{}, c context.Context) (*Session, error) } } - return NewSession(m, r, w) + // Get existing HTTP session cookie. + // If there's no error and there's a session ID (unvalidated at this point), + // return a session object. + ck, err := m.getCookieCb(m.opts.CookieName, r) + if err == nil && ck != nil && ck.Value != "" { + return &Session{ + manager: m, + reader: r, + writer: w, + id: ck.Value, + values: make(map[string]interface{}), + }, nil + } + + // If auto-creation is disabled, return an error. + if !m.opts.EnableAutoCreate { + return nil, ErrInvalidSession + } + + return m.NewSession(r, w) } diff --git a/manager_test.go b/manager_test.go index 2ba4d7b..b949aa3 100644 --- a/manager_test.go +++ b/manager_test.go @@ -21,7 +21,7 @@ func TestNewManagerWithDefaultOptions(t *testing.T) { func TestManagerNewManagerWithOptions(t *testing.T) { opts := Options{ - DisableAutoSet: true, + EnableAutoCreate: true, CookieName: "testcookiename", CookieDomain: "somedomain", CookiePath: "/abc/123", @@ -36,7 +36,7 @@ func TestManagerNewManagerWithOptions(t *testing.T) { assert := assert.New(t) // Default cookie path is set to root - assert.Equal(m.opts.DisableAutoSet, opts.DisableAutoSet) + assert.Equal(m.opts.EnableAutoCreate, opts.EnableAutoCreate) assert.Equal(m.opts.CookieName, opts.CookieName) assert.Equal(m.opts.CookieDomain, opts.CookieDomain) assert.Equal(m.opts.CookiePath, opts.CookiePath) @@ -60,12 +60,12 @@ func TestManagerRegisterGetCookie(t *testing.T) { assert := assert.New(t) m := New(Options{}) - testCookie := &http.Cookie{ + ck := &http.Cookie{ Name: "testcookie", } cb := func(string, interface{}) (*http.Cookie, error) { - return testCookie, http.ErrNoCookie + return ck, http.ErrNoCookie } m.RegisterGetCookie(cb) @@ -81,7 +81,7 @@ func TestManagerRegisterSetCookie(t *testing.T) { assert := assert.New(t) m := New(Options{}) - testCookie := &http.Cookie{ + ck := &http.Cookie{ Name: "testcookie", } @@ -91,15 +91,15 @@ func TestManagerRegisterSetCookie(t *testing.T) { m.RegisterSetCookie(cb) - expectCbErr := cb(testCookie, nil) - actualCbErr := m.setCookieCb(testCookie, nil) + expectCbErr := cb(ck, nil) + actualCbErr := m.setCookieCb(ck, nil) assert.Equal(expectCbErr, actualCbErr) } func TestManagerAcquireFails(t *testing.T) { assert := assert.New(t) - m := New(Options{}) + m := New(Options{EnableAutoCreate: false}) _, err := m.Acquire(nil, nil, nil) assert.Error(err, "session store is not set") @@ -108,32 +108,60 @@ func TestManagerAcquireFails(t *testing.T) { _, err = m.Acquire(nil, nil, nil) assert.Error(err, "callback `GetCookie` not set") - getCb := func(string, interface{}) (*http.Cookie, error) { + m.RegisterGetCookie(func(string, interface{}) (*http.Cookie, error) { return nil, nil - } - m.RegisterGetCookie(getCb) + }) + _, err = m.Acquire(nil, nil, nil) assert.Error(err, "callback `SetCookie` not set") + + m.RegisterSetCookie(func(*http.Cookie, interface{}) error { + return nil + }) + _, err = m.Acquire(nil, nil, nil) + assert.ErrorIs(err, ErrInvalidSession) } -func TestManagerAcquireSucceeds(t *testing.T) { - m := New(Options{}) +func TestManagerAcquireNoAutocreate(t *testing.T) { + m := New(Options{EnableAutoCreate: false}) m.UseStore(&MockStore{ isValid: true, + id: "somerandomid", }) - getCb := func(string, interface{}) (*http.Cookie, error) { + m.RegisterGetCookie(func(string, interface{}) (*http.Cookie, error) { + return &http.Cookie{ + Name: "testcookie", + Value: "somerandomid", + }, nil + }) + + m.RegisterSetCookie(func(*http.Cookie, interface{}) error { + return nil + }) + + _, err := m.Acquire(nil, nil, nil) + assert := assert.New(t) + assert.NoError(err) +} + +func TestManagerAcquireAutocreate(t *testing.T) { + m := New(Options{EnableAutoCreate: true}) + m.UseStore(&MockStore{ + isValid: true, + id: "somerandomid", + }) + + m.RegisterGetCookie(func(string, interface{}) (*http.Cookie, error) { return &http.Cookie{ Name: "testcookie", Value: "", }, nil - } - m.RegisterGetCookie(getCb) + }) - setCb := func(*http.Cookie, interface{}) error { - return http.ErrNoCookie - } - m.RegisterSetCookie(setCb) + m.RegisterSetCookie(func(*http.Cookie, interface{}) error { + return nil + }) _, err := m.Acquire(nil, nil, nil) assert := assert.New(t) @@ -142,9 +170,10 @@ func TestManagerAcquireSucceeds(t *testing.T) { func TestManagerAcquireFromContext(t *testing.T) { assert := assert.New(t) - m := New(Options{}) + m := New(Options{EnableAutoCreate: true}) m.UseStore(&MockStore{ isValid: true, + id: "somerandomid", }) getCb := func(string, interface{}) (*http.Cookie, error) { @@ -156,20 +185,20 @@ func TestManagerAcquireFromContext(t *testing.T) { m.RegisterGetCookie(getCb) setCb := func(*http.Cookie, interface{}) error { - return http.ErrNoCookie + return nil } m.RegisterSetCookie(setCb) sess, err := m.Acquire(nil, nil, nil) assert.NoError(err) - sess.cookie.Value = "updated" + sess.id = "updated" sessNew, err := m.Acquire(nil, nil, nil) assert.NoError(err) - assert.NotEqual(sessNew.cookie.Value, sess.cookie.Value) + assert.NotEqual(sessNew.id, sess.id) ctx := context.Background() ctx = context.WithValue(ctx, ContextName, sess) sessNext, err := m.Acquire(nil, nil, ctx) - assert.Equal(sessNext.cookie.Value, sess.cookie.Value) + assert.Equal(sessNext.id, sess.id) } diff --git a/session.go b/session.go index e34d742..efb5b24 100644 --- a/session.go +++ b/session.go @@ -15,18 +15,13 @@ type Session struct { // Session manager. manager *Manager - // Current http cookie. This is passed down to `SetCookie` callback. - cookie *http.Cookie + // Session ID. + id string // HTTP reader and writer interfaces which are passed on to // `GetCookie`` and `SetCookie`` callback respectively. reader interface{} writer interface{} - - // Track if session is set in store or not - // used to throw and error is autoSet is not enabled and user - // explicitly didn't create new session in store. - isSet bool } var ( @@ -52,60 +47,10 @@ type errCode interface { Code() int } -// NewSession creates a new session. Reads cookie info from `GetCookie“ callback -// and validate the session with current store. If cookie not set then it creates -// new session and calls `SetCookie“ callback. If `DisableAutoSet` is set then it -// skips new session creation and should be manually done using `Create` method. -// If a cookie is found but its invalid in store then `ErrInvalidSession` error is returned. -func NewSession(m *Manager, r, w interface{}) (*Session, error) { - var ( - err error - sess = &Session{ - manager: m, - reader: r, - writer: w, - values: make(map[string]interface{}), - } - ) - - // Get existing http session cookie - sess.cookie, err = m.getCookieCb(m.opts.CookieName, r) - - // Create new session - if err == http.ErrNoCookie { - // Skip creating new cookie in store. User has to manually create before doing Get or Set. - if m.opts.DisableAutoSet { - return sess, nil - } - - // Create new cookie in store and write to front - // Store also calls `WriteCookie`` to write to http interface - cv, err := m.store.Create() - if err != nil { - return nil, errAs(err) - } - - // Write cookie - if err := sess.WriteCookie(cv); err != nil { - return nil, err - } - - // Set isSet flag - sess.isSet = true - } else if err != nil { - return nil, err - } - - // Set isSet flag - sess.isSet = true - - return sess, nil -} - // WriteCookie updates the cookie and calls `SetCookie` callback. // This method can also be used by store to update cookie whenever the cookie value changes. func (s *Session) WriteCookie(cv string) error { - s.cookie = &http.Cookie{ + ck := &http.Cookie{ Value: cv, Name: s.manager.opts.CookieName, Domain: s.manager.opts.CookieDomain, @@ -115,18 +60,13 @@ func (s *Session) WriteCookie(cv string) error { SameSite: s.manager.opts.SameSite, } - // Set cookie expiry - if s.manager.opts.CookieLifetime != 0 { - s.cookie.Expires = time.Now().Add(s.manager.opts.CookieLifetime) - } - // Call `SetCookie` callback to write cookie to response - return s.manager.setCookieCb(s.cookie, s.writer) + return s.manager.setCookieCb(ck, s.writer) } // clearCookie sets expiry of the cookie to one day before to clear it. func (s *Session) clearCookie() error { - s.cookie = &http.Cookie{ + ck := &http.Cookie{ Name: s.manager.opts.CookieName, Value: "", // Set expiry to previous date to clear it from browser @@ -134,7 +74,7 @@ func (s *Session) clearCookie() error { } // Call `SetCookie` callback to write cookie to response - return s.manager.setCookieCb(s.cookie, s.writer) + return s.manager.setCookieCb(ck, s.writer) } // Create a new session. This is implicit when option `DisableAutoSet` is false @@ -151,18 +91,12 @@ func (s *Session) Create() error { return err } - // Set isSet flag - s.isSet = true - return nil } // ID returns the acquired session ID. If cookie is not set then empty string is returned. func (s *Session) ID() string { - if s.cookie != nil { - return s.cookie.Value - } - return "" + return s.id } // LoadValues loads the session values in memory. @@ -181,27 +115,17 @@ func (s *Session) ResetValues() { // GetAll gets all the fields in the session. func (s *Session) GetAll() (map[string]interface{}, error) { - // Check if session is set before accessing it - if !s.isSet { - return nil, ErrInvalidSession - } - // Load value from map if its already loaded if len(s.values) > 0 { return s.values, nil } - out, err := s.manager.store.GetAll(s.cookie.Value) + out, err := s.manager.store.GetAll(s.id) return out, errAs(err) } // GetMulti gets a map of values for multiple session keys. func (s *Session) GetMulti(keys ...string) (map[string]interface{}, error) { - // Check if session is set before accessing it - if !s.isSet { - return nil, ErrInvalidSession - } - // Load values from map if its already loaded if len(s.values) > 0 { vals := make(map[string]interface{}) @@ -214,7 +138,7 @@ func (s *Session) GetMulti(keys ...string) (map[string]interface{}, error) { return vals, nil } - out, err := s.manager.store.GetMulti(s.cookie.Value, keys...) + out, err := s.manager.store.GetMulti(s.id, keys...) return out, errAs(err) } @@ -222,11 +146,6 @@ func (s *Session) GetMulti(keys ...string) (map[string]interface{}, error) { // If session is already loaded using `Load` then returns values from // existing map instead of getting it from store. func (s *Session) Get(key string) (interface{}, error) { - // Check if session is set before accessing it - if !s.isSet { - return nil, ErrInvalidSession - } - // Load value from map if its already loaded if len(s.values) > 0 { if val, ok := s.values[key]; ok { @@ -235,19 +154,14 @@ func (s *Session) Get(key string) (interface{}, error) { } // Get from backend if not found in previous step - out, err := s.manager.store.Get(s.cookie.Value, key) + out, err := s.manager.store.Get(s.id, key) return out, errAs(err) } // Set sets a value for given key in session. Its up to store to commit // all previously set values at once or store it on each set. func (s *Session) Set(key string, val interface{}) error { - // Check if session is set before accessing it - if !s.isSet { - return ErrInvalidSession - } - - err := s.manager.store.Set(s.cookie.Value, key, val) + err := s.manager.store.Set(s.id, key, val) return errAs(err) } @@ -255,13 +169,8 @@ func (s *Session) Set(key string, val interface{}) error { // Its up to store to commit all previously // set values at once or store it on each set. func (s *Session) SetMulti(values map[string]interface{}) error { - // Check if session is set before accessing it - if !s.isSet { - return ErrInvalidSession - } - for k, v := range values { - if err := s.manager.store.Set(s.cookie.Value, k, v); err != nil { + if err := s.manager.store.Set(s.id, k, v); err != nil { return errAs(err) } } @@ -272,12 +181,7 @@ func (s *Session) SetMulti(values map[string]interface{}) error { // Commit commits all set to store. Its up to store to commit // all previously set values at once or store it on each set. func (s *Session) Commit() error { - // Check if session is set before accessing it - if !s.isSet { - return ErrInvalidSession - } - - if err := s.manager.store.Commit(s.cookie.Value); err != nil { + if err := s.manager.store.Commit(s.id); err != nil { return errAs(err) } @@ -286,12 +190,7 @@ func (s *Session) Commit() error { // Delete deletes a field from session. func (s *Session) Delete(key string) error { - // Check if session is set before accessing it - if !s.isSet { - return ErrInvalidSession - } - - if err := s.manager.store.Delete(s.cookie.Value, key); err != nil { + if err := s.manager.store.Delete(s.id, key); err != nil { return errAs(err) } @@ -300,12 +199,7 @@ func (s *Session) Delete(key string) error { // Clear clears session data from store and clears the cookie func (s *Session) Clear() error { - // Check if session is set before accessing it - if !s.isSet { - return ErrInvalidSession - } - - if err := s.manager.store.Clear(s.cookie.Value); err != nil { + if err := s.manager.store.Clear(s.id); err != nil { return errAs(err) } diff --git a/session_test.go b/session_test.go index 6a130f1..ccd13ea 100644 --- a/session_test.go +++ b/session_test.go @@ -90,16 +90,13 @@ func TestSessionNewSession(t *testing.T) { mockManager := newMockManager(mockStore) assert := assert.New(t) - sess, err := NewSession(mockManager, reader, writer) + sess, err := mockManager.NewSession(reader, writer) assert.NoError(err) assert.Equal(sess.manager, mockManager) assert.Equal(sess.reader, reader) assert.Equal(sess.writer, writer) assert.NotNil(sess.values) - assert.NotNil(sess.cookie) - assert.Equal(sess.cookie.Name, defaultCookieName) - assert.Equal(sess.cookie.Value, testCookieValue) - assert.True(sess.isSet) + assert.Equal(sess.id, testCookieValue) } func TestSessionNewSessionErrorStoreCreate(t *testing.T) { @@ -109,14 +106,14 @@ func TestSessionNewSessionErrorStoreCreate(t *testing.T) { testError := errors.New("this is test error") newCookieVal := "somerandomid" - mockStore.val = newCookieVal + mockStore.id = newCookieVal mockStore.err = testError mockManager := newMockManager(mockStore) mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { return nil, http.ErrNoCookie }) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.Error(err, testError.Error()) assert.Nil(sess) } @@ -128,7 +125,7 @@ func TestSessionNewSessionErrorWriteCookie(t *testing.T) { testError := errors.New("this is test error") newCookieVal := "somerandomid" - mockStore.val = newCookieVal + mockStore.id = newCookieVal mockManager := newMockManager(mockStore) mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { return nil, http.ErrNoCookie @@ -137,7 +134,7 @@ func TestSessionNewSessionErrorWriteCookie(t *testing.T) { return testError }) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.Error(err, testError.Error()) assert.Nil(sess) } @@ -151,7 +148,7 @@ func TestSessionNewSessionInvalidGetCookie(t *testing.T) { return nil, testError }) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.Error(err, testError.Error()) assert.Nil(sess) } @@ -161,32 +158,31 @@ func TestSessionNewSessionCreateNewCookie(t *testing.T) { mockStore := newMockStore() newCookieVal := "somerandomid" - mockStore.val = newCookieVal + mockStore.id = newCookieVal mockStore.isValid = true mockManager := newMockManager(mockStore) mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { return nil, http.ErrNoCookie }) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) - assert.True(sess.isSet) - assert.Equal(sess.cookie.Value, newCookieVal) + + assert.Equal(sess.id, newCookieVal) } -func TestSessionNewSessionWithDisableAutoSet(t *testing.T) { +func TestSessionNewSessionWithDisableAuto(t *testing.T) { assert := assert.New(t) mockStore := newMockStore() mockManager := newMockManager(mockStore) - mockManager.opts.DisableAutoSet = true + mockManager.opts.EnableAutoCreate = true mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { return nil, http.ErrNoCookie }) - sess, err := NewSession(mockManager, nil, nil) + _, err := mockManager.NewSession(nil, nil) assert.NoError(err) - assert.False(sess.isSet) } func TestSessionNewSessionGetCookieCb(t *testing.T) { @@ -195,7 +191,7 @@ func TestSessionNewSessionGetCookieCb(t *testing.T) { // Calls write cookie callback if cookie is not set already newCookieVal := "somerandomid" - mockStore.val = newCookieVal + mockStore.id = newCookieVal mockStore.isValid = true mockManager := newMockManager(mockStore) @@ -210,9 +206,9 @@ func TestSessionNewSessionGetCookieCb(t *testing.T) { }) var reader = "this is reader interface" - sess, err := NewSession(mockManager, reader, nil) + _, err := mockManager.NewSession(reader, nil) assert.NoError(err) - assert.True(sess.isSet) + assert.True(isCallbackTriggered) assert.Equal(receivedName, mockManager.opts.CookieName) assert.Equal(receivedReader, reader) @@ -224,7 +220,7 @@ func TestSessionNewSessionSetCookieCb(t *testing.T) { // Calls write cookie callback if cookie is not set already newCookieVal := "somerandomid" - mockStore.val = newCookieVal + mockStore.id = newCookieVal mockStore.isValid = true mockManager := newMockManager(mockStore) mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { @@ -242,9 +238,9 @@ func TestSessionNewSessionSetCookieCb(t *testing.T) { }) var writer = "this is writer interface" - sess, err := NewSession(mockManager, nil, writer) + _, err := mockManager.NewSession(nil, writer) assert.NoError(err) - assert.True(sess.isSet) + assert.True(isCallbackTriggered) assert.Equal(receivedCookie.Value, newCookieVal) assert.Equal(receivedWriter, writer) @@ -261,27 +257,20 @@ func TestSessionWriteCookie(t *testing.T) { CookieLifetime: time.Second * 1000, IsHTTPOnlyCookie: true, IsSecureCookie: true, - DisableAutoSet: true, + EnableAutoCreate: false, SameSite: http.SameSiteDefaultMode, } mockStore.isValid = true - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) - sess.WriteCookie("testvalue") - assert.Equal(sess.cookie.Name, mockManager.opts.CookieName) - assert.Equal(sess.cookie.Value, "testvalue") - assert.Equal(sess.cookie.Domain, mockManager.opts.CookieDomain) - assert.Equal(sess.cookie.Path, mockManager.opts.CookiePath) - assert.Equal(sess.cookie.Secure, mockManager.opts.IsSecureCookie) - assert.Equal(sess.cookie.SameSite, mockManager.opts.SameSite) - assert.Equal(sess.cookie.HttpOnly, mockManager.opts.IsHTTPOnlyCookie) + assert.NoError(sess.WriteCookie("testvalue")) // Ignore seconds - expiry := time.Now().Add(mockManager.opts.CookieLifetime) - assert.Equal(sess.cookie.Expires.Format("2006-01-02 15:04:05"), expiry.Format("2006-01-02 15:04:05")) - assert.WithinDuration(expiry, sess.cookie.Expires, time.Millisecond*1000) + // expiry := time.Now().Add(mockManager.opts.CookieLifetime) + // assert.Equal(sess.id.Expires.Format("2006-01-02 15:04:05"), expiry.Format("2006-01-02 15:04:05")) + // assert.WithinDuration(expiry, sess.id.Expires, time.Millisecond*1000) } func TestSessionClearCookie(t *testing.T) { @@ -298,7 +287,7 @@ func TestSessionClearCookie(t *testing.T) { return nil }) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) err = sess.clearCookie() @@ -315,7 +304,7 @@ func TestSessionCreate(t *testing.T) { mockStore.isValid = true mockStore.val = "test" mockManager := newMockManager(mockStore) - mockManager.opts.DisableAutoSet = true + mockManager.opts.EnableAutoCreate = true mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { return nil, http.ErrNoCookie }) @@ -326,15 +315,14 @@ func TestSessionCreate(t *testing.T) { return nil }) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) - assert.False(sess.isSet) assert.False(isCallbackTriggered) err = sess.Create() assert.NoError(err) assert.True(isCallbackTriggered) - assert.True(sess.isSet) + } func TestSessionLoadValues(t *testing.T) { @@ -344,7 +332,7 @@ func TestSessionLoadValues(t *testing.T) { mockManager := newMockManager(mockStore) assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) err = sess.LoadValues() @@ -360,7 +348,7 @@ func TestSessionResetValues(t *testing.T) { mockManager := newMockManager(mockStore) assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) err = sess.LoadValues() @@ -379,7 +367,7 @@ func TestSessionGetAllFromStore(t *testing.T) { mockManager := newMockManager(mockStore) assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) vals, err := sess.GetAll() @@ -394,7 +382,7 @@ func TestSessionGetAllLoadedValues(t *testing.T) { mockManager := newMockManager(mockStore) assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) setVals := make(map[string]interface{}) @@ -410,13 +398,13 @@ func TestSessionGetAllLoadedValues(t *testing.T) { func TestSessionGetAllInvalidSession(t *testing.T) { mockStore := newMockStore() mockManager := newMockManager(mockStore) - mockManager.opts.DisableAutoSet = true + mockManager.opts.EnableAutoCreate = true mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { return nil, http.ErrNoCookie }) assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) vals, err := sess.GetAll() @@ -431,7 +419,7 @@ func TestSessionGetMultiFromStore(t *testing.T) { mockManager := newMockManager(mockStore) assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) vals, err := sess.GetMulti("val") @@ -446,7 +434,7 @@ func TestSessionGetMultiLoadedValues(t *testing.T) { mockManager := newMockManager(mockStore) assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) setVals := make(map[string]interface{}) @@ -464,13 +452,13 @@ func TestSessionGetMultiLoadedValues(t *testing.T) { func TestSessionGetMultiInvalidSession(t *testing.T) { mockStore := newMockStore() mockManager := newMockManager(mockStore) - mockManager.opts.DisableAutoSet = true + mockManager.opts.EnableAutoCreate = true mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { return nil, http.ErrNoCookie }) assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) vals, err := sess.GetMulti("val") @@ -485,7 +473,7 @@ func TestSessionGetFromStore(t *testing.T) { mockManager := newMockManager(mockStore) assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) val, err := sess.Get("val") @@ -499,7 +487,7 @@ func TestSessionGetLoadedValues(t *testing.T) { mockManager := newMockManager(mockStore) assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) setVals := make(map[string]interface{}) @@ -515,13 +503,13 @@ func TestSessionGetLoadedValues(t *testing.T) { func TestSessionGetInvalidSession(t *testing.T) { mockStore := newMockStore() mockManager := newMockManager(mockStore) - mockManager.opts.DisableAutoSet = true + mockManager.opts.EnableAutoCreate = true mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { return nil, http.ErrNoCookie }) assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) vals, err := sess.Get("val") @@ -535,7 +523,7 @@ func TestSessionSet(t *testing.T) { mockManager := newMockManager(mockStore) assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) err = sess.Set("key", 100) @@ -546,13 +534,13 @@ func TestSessionSet(t *testing.T) { func TestSessionSetInvalidSession(t *testing.T) { mockStore := newMockStore() mockManager := newMockManager(mockStore) - mockManager.opts.DisableAutoSet = true + mockManager.opts.EnableAutoCreate = true mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { return nil, http.ErrNoCookie }) assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) err = sess.Set("key", 100) @@ -565,7 +553,7 @@ func TestSessionCommit(t *testing.T) { mockManager := newMockManager(mockStore) assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) err = sess.Set("key", 100) @@ -580,13 +568,13 @@ func TestSessionCommit(t *testing.T) { func TestSessionCommitInvalidSession(t *testing.T) { mockStore := newMockStore() mockManager := newMockManager(mockStore) - mockManager.opts.DisableAutoSet = true + mockManager.opts.EnableAutoCreate = true mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { return nil, http.ErrNoCookie }) assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) err = sess.Commit() @@ -600,7 +588,7 @@ func TestSessionDelete(t *testing.T) { mockStore.isValid = true mockStore.val = 100 - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) assert.Equal(mockStore.val, 100) @@ -627,7 +615,7 @@ func TestSessionClear(t *testing.T) { return nil }) - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) assert.Equal(mockStore.val, 100) @@ -644,7 +632,7 @@ func TestSessionClearError(t *testing.T) { mockManager := newMockManager(mockStore) mockStore.isValid = true - sess, err := NewSession(mockManager, nil, nil) + sess, err := mockManager.NewSession(nil, nil) assert.NoError(err) testError := errors.New("this is test error") @@ -653,22 +641,6 @@ func TestSessionClearError(t *testing.T) { assert.Error(err, testError.Error()) } -func TestSessionClearInvalidSession(t *testing.T) { - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockManager.opts.DisableAutoSet = true - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - - assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - - err = sess.Clear() - assert.Error(err, ErrInvalidSession.Error()) -} - type Err struct { code int msg string diff --git a/store_test.go b/store_test.go index 4eba6fe..99ec0ee 100644 --- a/store_test.go +++ b/store_test.go @@ -5,6 +5,7 @@ type MockStore struct { isValid bool cookieValue string err error + id string val interface{} isCommited bool } @@ -18,7 +19,7 @@ func (s *MockStore) reset() { } func (s *MockStore) Create() (cv string, err error) { - return s.val.(string), s.err + return s.id, s.err } func (s *MockStore) Get(cv, key string) (value interface{}, err error) { From d20435524faa5de4965b31724623cae1bd606703 Mon Sep 17 00:00:00 2001 From: Vivek R Date: Wed, 22 May 2024 17:39:57 +0530 Subject: [PATCH 2/4] feat: refactor tests --- manager.go | 42 +-- manager_test.go | 162 +++++----- session.go | 19 +- session_test.go | 845 ++++++++++++++++++------------------------------ store_test.go | 87 +++-- 5 files changed, 473 insertions(+), 682 deletions(-) diff --git a/manager.go b/manager.go index 93c9e73..f2080df 100644 --- a/manager.go +++ b/manager.go @@ -7,12 +7,14 @@ import ( "time" ) +type ctxNameType string + const ( // Default cookie name used to store session. defaultCookieName = "session" // ContextName is the key used to store session in context passed to acquire method. - ContextName = "_simple_session" + ContextName ctxNameType = "_simple_session" ) // Manager is a utility to scaffold session and store. @@ -94,20 +96,16 @@ func (m *Manager) RegisterSetCookie(cb func(*http.Cookie, interface{}) error) { m.setCookieCb = cb } -// NewSession creates a new session. Reads cookie info from `GetCookie“ callback -// and validate the session with current store. If cookie not set then it creates -// new session and calls `SetCookie“ callback. If `DisableAutoSet` is set then it -// skips new session creation and should be manually done using `Create` method. -// If a cookie is found but its invalid in store then `ErrInvalidSession` error is returned. +// NewSession creates a new session. func (m *Manager) NewSession(r, w interface{}) (*Session, error) { - var ( - sess = &Session{ - manager: m, - reader: r, - writer: w, - values: make(map[string]interface{}), - } - ) + // Check if any store is set + if m.store == nil { + return nil, fmt.Errorf("session store is not set") + } + + if m.setCookieCb == nil { + return nil, fmt.Errorf("callback `SetCookie` not set") + } // Create new cookie in store and write to front. // Store also calls `WriteCookie`` to write to http interface. @@ -116,6 +114,13 @@ func (m *Manager) NewSession(r, w interface{}) (*Session, error) { return nil, errAs(err) } + var sess = &Session{ + id: id, + manager: m, + reader: r, + writer: w, + values: make(map[string]interface{}), + } // Write cookie. if err := sess.WriteCookie(id); err != nil { return nil, err @@ -125,14 +130,15 @@ func (m *Manager) NewSession(r, w interface{}) (*Session, error) { } // Acquire gets a `Session` for current session cookie from store. -// If `Session` is not found on store then it creates a new session and sets on store. -// If 'DisableAutoSet` is set in options then session has to be explicitly created before -// using `Session` for getting or setting. +// If `Session` is not found and `opt.EnableAutoCreate` option is true then +// then it creates a new session and sets on store. +// If `Session` is not found and `opt.EnableAutoCreate` option is false then +// then it returns ErrInvalidSession. // `r` and `w` is request and response interfaces which are sent back in GetCookie and SetCookie callbacks respectively. // In case of net/http `r` will be r` // Optionally context can be passed around which is used to get already loaded session. This is useful when // handler is wrapped with multiple middlewares and `Acquire` is already called in any of the middleware. -func (m *Manager) Acquire(r, w interface{}, c context.Context) (*Session, error) { +func (m *Manager) Acquire(c context.Context, r, w interface{}) (*Session, error) { // Check if any store is set if m.store == nil { return nil, fmt.Errorf("session store is not set") diff --git a/manager_test.go b/manager_test.go index b949aa3..3983257 100644 --- a/manager_test.go +++ b/manager_test.go @@ -9,14 +9,44 @@ import ( "github.com/stretchr/testify/assert" ) +const mockSessionID = "sometestcookievalue" + +func newMockStore() *MockStore { + return &MockStore{ + id: mockSessionID, + data: map[string]interface{}{}, + temp: map[string]interface{}{}, + err: nil, + } +} + +func newMockManager(store *MockStore) *Manager { + m := New(Options{}) + m.UseStore(store) + m.RegisterGetCookie(mockGetCookieCb) + m.RegisterSetCookie(mockSetCookieCb) + return m +} + +func mockGetCookieCb(name string, r interface{}) (*http.Cookie, error) { + return &http.Cookie{ + Name: name, + Value: mockSessionID, + }, nil +} + +func mockSetCookieCb(*http.Cookie, interface{}) error { + return nil +} + func TestNewManagerWithDefaultOptions(t *testing.T) { m := New(Options{}) assert := assert.New(t) // Default cookie path is set to root - assert.Equal(m.opts.CookiePath, "/") + assert.Equal("/", m.opts.CookiePath) // Default cookie name is set - assert.Equal(m.opts.CookieName, defaultCookieName) + assert.Equal(defaultCookieName, m.opts.CookieName) } func TestManagerNewManagerWithOptions(t *testing.T) { @@ -36,24 +66,21 @@ func TestManagerNewManagerWithOptions(t *testing.T) { assert := assert.New(t) // Default cookie path is set to root - assert.Equal(m.opts.EnableAutoCreate, opts.EnableAutoCreate) - assert.Equal(m.opts.CookieName, opts.CookieName) - assert.Equal(m.opts.CookieDomain, opts.CookieDomain) - assert.Equal(m.opts.CookiePath, opts.CookiePath) - assert.Equal(m.opts.IsSecureCookie, opts.IsSecureCookie) - assert.Equal(m.opts.SameSite, opts.SameSite) - assert.Equal(m.opts.IsHTTPOnlyCookie, opts.IsHTTPOnlyCookie) - assert.Equal(m.opts.CookieLifetime, opts.CookieLifetime) + assert.Equal(opts.EnableAutoCreate, m.opts.EnableAutoCreate) + assert.Equal(opts.CookieName, m.opts.CookieName) + assert.Equal(opts.CookieDomain, m.opts.CookieDomain) + assert.Equal(opts.CookiePath, m.opts.CookiePath) + assert.Equal(opts.IsSecureCookie, m.opts.IsSecureCookie) + assert.Equal(opts.SameSite, m.opts.SameSite) + assert.Equal(opts.IsHTTPOnlyCookie, m.opts.IsHTTPOnlyCookie) + assert.Equal(opts.CookieLifetime, m.opts.CookieLifetime) } func TestManagerUseStore(t *testing.T) { assert := assert.New(t) - mockStr := &MockStore{} - assert.Implements((*Store)(nil), mockStr) - - m := New(Options{}) - m.UseStore(mockStr) - assert.Equal(m.store, mockStr) + s := newMockStore() + m := newMockManager(s) + assert.Equal(s, m.store) } func TestManagerRegisterGetCookie(t *testing.T) { @@ -99,106 +126,63 @@ func TestManagerRegisterSetCookie(t *testing.T) { func TestManagerAcquireFails(t *testing.T) { assert := assert.New(t) - m := New(Options{EnableAutoCreate: false}) + m := New(Options{}) - _, err := m.Acquire(nil, nil, nil) - assert.Error(err, "session store is not set") + // Fail if store is not assigned. + _, err := m.Acquire(context.Background(), nil, nil) + assert.Equal("session store is not set", err.Error()) + // Fail if getCookie callback is not assigned. m.UseStore(&MockStore{}) - _, err = m.Acquire(nil, nil, nil) - assert.Error(err, "callback `GetCookie` not set") + _, err = m.Acquire(context.Background(), nil, nil) + assert.Equal("callback `GetCookie` not set", err.Error()) + // Assign getCookie, returns nil cookie to make sure it + // fails in create session with invalid session. m.RegisterGetCookie(func(string, interface{}) (*http.Cookie, error) { return nil, nil }) - _, err = m.Acquire(nil, nil, nil) - assert.Error(err, "callback `SetCookie` not set") - - m.RegisterSetCookie(func(*http.Cookie, interface{}) error { - return nil - }) - _, err = m.Acquire(nil, nil, nil) - assert.ErrorIs(err, ErrInvalidSession) -} - -func TestManagerAcquireNoAutocreate(t *testing.T) { - m := New(Options{EnableAutoCreate: false}) - m.UseStore(&MockStore{ - isValid: true, - id: "somerandomid", - }) - - m.RegisterGetCookie(func(string, interface{}) (*http.Cookie, error) { - return &http.Cookie{ - Name: "testcookie", - Value: "somerandomid", - }, nil - }) + // Fail if setCookie callback is not assigned. + _, err = m.Acquire(context.Background(), nil, nil) + assert.Equal("callback `SetCookie` not set", err.Error()) + // Register setCookie callback. m.RegisterSetCookie(func(*http.Cookie, interface{}) error { return nil }) - _, err := m.Acquire(nil, nil, nil) - assert := assert.New(t) - assert.NoError(err) + // By default EnableAutoCreate is disabled + // Check if it returns invalid session. + _, err = m.Acquire(context.Background(), nil, nil) + assert.ErrorIs(err, ErrInvalidSession) } func TestManagerAcquireAutocreate(t *testing.T) { - m := New(Options{EnableAutoCreate: true}) - m.UseStore(&MockStore{ - isValid: true, - id: "somerandomid", - }) - + m := newMockManager(newMockStore()) + // Enable autocreate. + m.opts.EnableAutoCreate = true m.RegisterGetCookie(func(string, interface{}) (*http.Cookie, error) { - return &http.Cookie{ - Name: "testcookie", - Value: "", - }, nil + return nil, ErrInvalidSession }) - m.RegisterSetCookie(func(*http.Cookie, interface{}) error { - return nil - }) - - _, err := m.Acquire(nil, nil, nil) + // If cookie doesn't exist then should return a new one without error. + sess, err := m.Acquire(context.Background(), nil, nil) assert := assert.New(t) assert.NoError(err) + assert.Equal(mockSessionID, sess.id) } func TestManagerAcquireFromContext(t *testing.T) { assert := assert.New(t) - m := New(Options{EnableAutoCreate: true}) - m.UseStore(&MockStore{ - isValid: true, - id: "somerandomid", - }) - - getCb := func(string, interface{}) (*http.Cookie, error) { - return &http.Cookie{ - Name: "testcookie", - Value: "", - }, nil - } - m.RegisterGetCookie(getCb) - - setCb := func(*http.Cookie, interface{}) error { - return nil - } - m.RegisterSetCookie(setCb) + m := newMockManager(newMockStore()) - sess, err := m.Acquire(nil, nil, nil) - assert.NoError(err) + sess, err := m.Acquire(context.Background(), nil, nil) sess.id = "updated" - - sessNew, err := m.Acquire(nil, nil, nil) assert.NoError(err) - assert.NotEqual(sessNew.id, sess.id) - ctx := context.Background() - ctx = context.WithValue(ctx, ContextName, sess) - sessNext, err := m.Acquire(nil, nil, ctx) - assert.Equal(sessNext.id, sess.id) + ctx := context.WithValue(context.Background(), ContextName, sess) + sessNext, err := m.Acquire(ctx, nil, nil) + assert.Equal(sess.id, sessNext.id) + assert.NoError(err) } diff --git a/session.go b/session.go index efb5b24..d72a801 100644 --- a/session.go +++ b/session.go @@ -77,23 +77,6 @@ func (s *Session) clearCookie() error { return s.manager.setCookieCb(ck, s.writer) } -// Create a new session. This is implicit when option `DisableAutoSet` is false -// else session has to be manually created before setting or getting values. -func (s *Session) Create() error { - // Create new cookie in store and write to front. - cv, err := s.manager.store.Create() - if err != nil { - return errAs(err) - } - - // Write cookie - if err := s.WriteCookie(cv); err != nil { - return err - } - - return nil -} - // ID returns the acquired session ID. If cookie is not set then empty string is returned. func (s *Session) ID() string { return s.id @@ -197,7 +180,7 @@ func (s *Session) Delete(key string) error { return nil } -// Clear clears session data from store and clears the cookie +// Clear clears session data from store and clears the cookie. func (s *Session) Clear() error { if err := s.manager.store.Clear(s.id); err != nil { return errAs(err) diff --git a/session_test.go b/session_test.go index ccd13ea..d2051b3 100644 --- a/session_test.go +++ b/session_test.go @@ -9,37 +9,38 @@ import ( "github.com/stretchr/testify/assert" ) -var ( - testCookieName = "sometestcookie" - testCookieValue = "sometestcookievalue" -) - -func newMockStore() *MockStore { - return &MockStore{} +type Err struct { + code int + msg string } -func newMockManager(store *MockStore) *Manager { - mockManager := New(Options{}) - mockManager.UseStore(store) - mockManager.RegisterGetCookie(getCookieCb) - mockManager.RegisterSetCookie(setCookieCb) - - return mockManager +func (e *Err) Error() string { + return e.msg } -func getCookieCb(name string, r interface{}) (*http.Cookie, error) { - return &http.Cookie{ - Name: name, - Value: testCookieValue, - }, nil +func (e *Err) Code() int { + return e.code } -func setCookieCb(*http.Cookie, interface{}) error { - return nil +func TestErrorTypes(t *testing.T) { + var ( + // Error codes for store errors. This should match the codes + // defined in the /simplesessions package exactly. + errInvalidSession = &Err{code: 1, msg: "invalid session"} + errFieldNotFound = &Err{code: 2, msg: "field not found"} + errAssertType = &Err{code: 3, msg: "assertion failed"} + errNil = &Err{code: 4, msg: "nil returned"} + errCustom = &Err{msg: "custom error"} + ) + + assert.Equal(t, errAs(errInvalidSession), ErrInvalidSession) + assert.Equal(t, errAs(errFieldNotFound), ErrFieldNotFound) + assert.Equal(t, errAs(errAssertType), ErrAssertType) + assert.Equal(t, errAs(errNil), ErrNil) + assert.Equal(t, errAs(errCustom), errCustom) } func TestSessionHelpers(t *testing.T) { - assert := assert.New(t) sess := Session{ manager: newMockManager(newMockStore()), } @@ -47,210 +48,124 @@ func TestSessionHelpers(t *testing.T) { // Int var inp1 = 100 v1, err := sess.Int(inp1, errors.New("test error")) - assert.Equal(v1, inp1) - assert.Error(err, "test error") + assert.Equal(t, inp1, v1) + assert.Equal(t, "test error", err.Error()) // Int64 var inp2 int64 = 100 v2, err := sess.Int64(inp2, errors.New("test error")) - assert.Equal(v2, inp2) - assert.Error(err, "test error") + assert.Equal(t, inp2, v2) + assert.Equal(t, "test error", err.Error()) var inp3 uint64 = 100 v3, err := sess.UInt64(inp3, errors.New("test error")) - assert.Equal(v3, inp3) - assert.Error(err, "test error") + assert.Equal(t, inp3, v3) + assert.Equal(t, "test error", err.Error()) var inp4 float64 = 100 v4, err := sess.Float64(inp4, errors.New("test error")) - assert.Equal(v4, inp4) - assert.Error(err, "test error") + assert.Equal(t, inp4, v4) + assert.Equal(t, "test error", err.Error()) var inp5 = "abc123" v5, err := sess.String(inp5, errors.New("test error")) - assert.Equal(v5, inp5) - assert.Error(err, "test error") + assert.Equal(t, inp5, v5) + assert.Equal(t, "test error", err.Error()) var inp6 = true v6, err := sess.Bool(inp6, errors.New("test error")) - assert.Equal(v6, inp6) - assert.Error(err, "test error") + assert.Equal(t, inp6, v6) + assert.Equal(t, "test error", err.Error()) var inp7 = []byte{} v7, err := sess.Bytes(inp7, errors.New("test error")) - assert.Equal(v7, inp7) - assert.Error(err, "test error") + assert.Equal(t, inp7, v7) + assert.Equal(t, "test error", err.Error()) } func TestSessionNewSession(t *testing.T) { reader := "some reader" writer := "some writer" - mockStore := newMockStore() - mockStore.isValid = true - mockManager := newMockManager(mockStore) + mgr := newMockManager(newMockStore()) - assert := assert.New(t) - sess, err := mockManager.NewSession(reader, writer) - assert.NoError(err) - assert.Equal(sess.manager, mockManager) - assert.Equal(sess.reader, reader) - assert.Equal(sess.writer, writer) - assert.NotNil(sess.values) - assert.Equal(sess.id, testCookieValue) + sess, err := mgr.NewSession(reader, writer) + assert.NoError(t, err) + assert.Equal(t, mgr, sess.manager) + assert.Equal(t, reader, sess.reader) + assert.Equal(t, writer, sess.writer) + assert.NotNil(t, sess.values) + assert.Equal(t, mockSessionID, sess.id) + assert.Equal(t, sess.id, sess.ID()) } -func TestSessionNewSessionErrorStoreCreate(t *testing.T) { +func TestSessionNewSessionErrors(t *testing.T) { assert := assert.New(t) - mockStore := newMockStore() - mockStore.isValid = true - - testError := errors.New("this is test error") - newCookieVal := "somerandomid" - mockStore.id = newCookieVal - mockStore.err = testError - mockManager := newMockManager(mockStore) - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - sess, err := mockManager.NewSession(nil, nil) - assert.Error(err, testError.Error()) + mgr := New(Options{}) + sess, err := mgr.NewSession(nil, nil) + assert.Equal("session store is not set", err.Error()) assert.Nil(sess) -} -func TestSessionNewSessionErrorWriteCookie(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - mockStore.isValid = true - - testError := errors.New("this is test error") - newCookieVal := "somerandomid" - mockStore.id = newCookieVal - mockManager := newMockManager(mockStore) - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - mockManager.RegisterSetCookie(func(cookie *http.Cookie, w interface{}) error { - return testError - }) + mgr = New(Options{}) + mgr.UseStore(&MockStore{}) + sess, err = mgr.NewSession(nil, nil) + assert.Equal("callback `SetCookie` not set", err.Error()) + assert.Nil(sess) - sess, err := mockManager.NewSession(nil, nil) - assert.Error(err, testError.Error()) + // Store error. + tErr := errors.New("store error") + str := newMockStore() + str.err = tErr + mgr = newMockManager(str) + sess, err = mgr.NewSession(nil, nil) + assert.ErrorIs(tErr, err) assert.Nil(sess) -} -func TestSessionNewSessionInvalidGetCookie(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - testError := errors.New("custom error") - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, testError + // Cookie write error. + str.err = nil + wErr := errors.New("write cookie error") + mgr.RegisterSetCookie(func(cookie *http.Cookie, w interface{}) error { + return wErr }) - - sess, err := mockManager.NewSession(nil, nil) - assert.Error(err, testError.Error()) + sess, err = mgr.NewSession(nil, nil) + assert.ErrorIs(wErr, err) assert.Nil(sess) } func TestSessionNewSessionCreateNewCookie(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - - newCookieVal := "somerandomid" - mockStore.id = newCookieVal - mockStore.isValid = true - mockManager := newMockManager(mockStore) - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) - - assert.Equal(sess.id, newCookieVal) -} - -func TestSessionNewSessionWithDisableAuto(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - - mockManager := newMockManager(mockStore) - mockManager.opts.EnableAutoCreate = true - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - - _, err := mockManager.NewSession(nil, nil) - assert.NoError(err) -} - -func TestSessionNewSessionGetCookieCb(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - - // Calls write cookie callback if cookie is not set already - newCookieVal := "somerandomid" - mockStore.id = newCookieVal - mockStore.isValid = true - mockManager := newMockManager(mockStore) - - var receivedName string - var receivedReader interface{} - var isCallbackTriggered bool - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - isCallbackTriggered = true - receivedName = name - receivedReader = r - return nil, http.ErrNoCookie - }) - - var reader = "this is reader interface" - _, err := mockManager.NewSession(reader, nil) - assert.NoError(err) - - assert.True(isCallbackTriggered) - assert.Equal(receivedName, mockManager.opts.CookieName) - assert.Equal(receivedReader, reader) + mgr := newMockManager(newMockStore()) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) + assert.Equal(t, sess.id, mockSessionID) } func TestSessionNewSessionSetCookieCb(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - - // Calls write cookie callback if cookie is not set already - newCookieVal := "somerandomid" - mockStore.id = newCookieVal - mockStore.isValid = true - mockManager := newMockManager(mockStore) - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) + var ( + mgr = newMockManager(newMockStore()) + receCk *http.Cookie + receWr interface{} + isCb bool + ) - var receivedCookie *http.Cookie - var receivedWriter interface{} - var isCallbackTriggered bool - mockManager.RegisterSetCookie(func(cookie *http.Cookie, w interface{}) error { - receivedCookie = cookie - receivedWriter = w - isCallbackTriggered = true + mgr.RegisterSetCookie(func(cookie *http.Cookie, w interface{}) error { + receCk = cookie + receWr = w + isCb = true return nil }) var writer = "this is writer interface" - _, err := mockManager.NewSession(nil, writer) - assert.NoError(err) + _, err := mgr.NewSession(nil, writer) + assert.NoError(t, err) - assert.True(isCallbackTriggered) - assert.Equal(receivedCookie.Value, newCookieVal) - assert.Equal(receivedWriter, writer) + assert.True(t, isCb) + assert.Equal(t, mockSessionID, receCk.Value) + assert.Equal(t, writer, receWr) } func TestSessionWriteCookie(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockManager.opts = &Options{ + mgr := newMockManager(newMockStore()) + mgr.opts = &Options{ CookieName: "somename", CookieDomain: "abc.xyz", CookiePath: "/abc/xyz", @@ -260,412 +175,278 @@ func TestSessionWriteCookie(t *testing.T) { EnableAutoCreate: false, SameSite: http.SameSiteDefaultMode, } - mockStore.isValid = true - - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) - - assert.NoError(sess.WriteCookie("testvalue")) - // Ignore seconds - // expiry := time.Now().Add(mockManager.opts.CookieLifetime) - // assert.Equal(sess.id.Expires.Format("2006-01-02 15:04:05"), expiry.Format("2006-01-02 15:04:05")) - // assert.WithinDuration(expiry, sess.id.Expires, time.Millisecond*1000) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) + assert.NoError(t, sess.WriteCookie("testvalue")) } func TestSessionClearCookie(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockStore.isValid = true - - var receivedCookie *http.Cookie - var isCallbackTriggered bool - mockManager.RegisterSetCookie(func(cookie *http.Cookie, w interface{}) error { - receivedCookie = cookie - isCallbackTriggered = true + var ( + mgr = newMockManager(newMockStore()) + receCk *http.Cookie + isCb bool + ) + mgr.RegisterSetCookie(func(cookie *http.Cookie, w interface{}) error { + receCk = cookie + isCb = true return nil }) - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) err = sess.clearCookie() - assert.NoError(err) - - assert.True(isCallbackTriggered) - assert.Equal(receivedCookie.Value, "") - assert.True(receivedCookie.Expires.UnixNano() < time.Now().UnixNano()) -} - -func TestSessionCreate(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - mockStore.isValid = true - mockStore.val = "test" - mockManager := newMockManager(mockStore) - mockManager.opts.EnableAutoCreate = true - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - - var isCallbackTriggered bool - mockManager.RegisterSetCookie(func(cookie *http.Cookie, w interface{}) error { - isCallbackTriggered = true - return nil - }) - - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) - assert.False(isCallbackTriggered) - - err = sess.Create() - assert.NoError(err) - assert.True(isCallbackTriggered) + assert.NoError(t, err) + assert.True(t, isCb) + assert.Equal(t, "", receCk.Value) + assert.True(t, receCk.Expires.UnixNano() < time.Now().UnixNano()) } func TestSessionLoadValues(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockStore.val = 100 - mockManager := newMockManager(mockStore) + str := newMockStore() + str.data = map[string]interface{}{ + "key1": 1, + "key2": 2, + } + mgr := newMockManager(str) - assert := assert.New(t) - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) err = sess.LoadValues() - assert.NoError(err) - assert.Contains(sess.values, "val") - assert.Equal(sess.values["val"], 100) + assert.NoError(t, err) + assert.Equal(t, str.data, sess.values) } func TestSessionResetValues(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockStore.val = 100 - mockManager := newMockManager(mockStore) - - assert := assert.New(t) - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) - - err = sess.LoadValues() - assert.NoError(err) - assert.Contains(sess.values, "val") - assert.Equal(sess.values["val"], 100) + str := newMockStore() + str.data = map[string]interface{}{ + "key1": 1, + "key2": 2, + } + mgr := newMockManager(str) + sess, _ := mgr.NewSession(nil, nil) + sess.LoadValues() + assert.NotEqual(t, 0, len(sess.values)) sess.ResetValues() - assert.Equal(len(sess.values), 0) -} - -func TestSessionGetAllFromStore(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockStore.val = 100 - mockManager := newMockManager(mockStore) - - assert := assert.New(t) - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) - - vals, err := sess.GetAll() - assert.NoError(err) - assert.Contains(vals, "val") - assert.Equal(vals["val"], 100) -} - -func TestSessionGetAllLoadedValues(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockManager := newMockManager(mockStore) - - assert := assert.New(t) - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) - - setVals := make(map[string]interface{}) - setVals["sample"] = "someval" - sess.values = setVals - - vals, err := sess.GetAll() - assert.NoError(err) - assert.Contains(vals, "sample") - assert.Equal(vals["sample"], "someval") -} - -func TestSessionGetAllInvalidSession(t *testing.T) { - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockManager.opts.EnableAutoCreate = true - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - - assert := assert.New(t) - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) - - vals, err := sess.GetAll() - assert.Error(err, ErrInvalidSession.Error()) - assert.Nil(vals) -} - -func TestSessionGetMultiFromStore(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockStore.val = 100 - mockManager := newMockManager(mockStore) - - assert := assert.New(t) - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) - - vals, err := sess.GetMulti("val") - assert.NoError(err) - assert.Contains(vals, "val") - assert.Equal(vals["val"], 100) -} - -func TestSessionGetMultiLoadedValues(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockManager := newMockManager(mockStore) - - assert := assert.New(t) - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) - - setVals := make(map[string]interface{}) - setVals["key1"] = "someval" - setVals["key2"] = "someval" - sess.values = setVals - - vals, err := sess.GetMulti("key1") - assert.NoError(err) - assert.Contains(vals, "key1") - assert.Equal(vals["key1"], "someval") - assert.NotContains(vals, "key2") -} - -func TestSessionGetMultiInvalidSession(t *testing.T) { - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockManager.opts.EnableAutoCreate = true - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - - assert := assert.New(t) - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) - - vals, err := sess.GetMulti("val") - assert.Error(err, ErrInvalidSession.Error()) - assert.Nil(vals) -} - -func TestSessionGetFromStore(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockStore.val = 100 - mockManager := newMockManager(mockStore) - - assert := assert.New(t) - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) - - val, err := sess.Get("val") - assert.NoError(err) - assert.Equal(val, 100) + assert.Equal(t, 0, len(sess.values)) } -func TestSessionGetLoadedValues(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockManager := newMockManager(mockStore) - - assert := assert.New(t) - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) - - setVals := make(map[string]interface{}) - setVals["key1"] = "someval1" - setVals["key2"] = "someval2" - sess.values = setVals - - val, err := sess.Get("key1") - assert.NoError(err) - assert.Equal(val, "someval1") -} +func TestSessionGetStore(t *testing.T) { + str := newMockStore() + str.data = map[string]interface{}{ + "key1": 1, + "key2": 2, + "key3": 3, + } + mgr := newMockManager(str) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) + + // GetAll. + v1, err := sess.GetAll() + assert.NoError(t, err) + assert.Equal(t, str.data, v1) + + // Get Multi. + v2, err := sess.GetMulti("key1", "key2") + assert.NoError(t, err) + assert.Contains(t, v2, "key1") + assert.Equal(t, str.data["key1"], v2["key1"]) + assert.Contains(t, v2, "key2") + assert.Equal(t, str.data["key2"], v2["key2"]) + assert.NotContains(t, v2, "key3") + + // Get. + v3, err := sess.Get("key1") + assert.NoError(t, err) + assert.Contains(t, str.data, "key1") + assert.Equal(t, str.data["key1"], v3) +} + +func TestSessionGetLoaded(t *testing.T) { + str := newMockStore() + mgr := newMockManager(str) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) + + sess.values = map[string]interface{}{ + "key1": 1, + "key2": 2, + "key3": 3, + } -func TestSessionGetInvalidSession(t *testing.T) { - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockManager.opts.EnableAutoCreate = true - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) + // GetAll. + v1, err := sess.GetAll() + assert.NoError(t, err) + assert.Equal(t, sess.values, v1) - assert := assert.New(t) - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) + // GetMulti. + v2, err := sess.GetMulti("key1", "key2") + assert.NoError(t, err) + assert.Contains(t, v2, "key1") + assert.Equal(t, sess.values["key1"], v2["key1"]) + assert.Contains(t, v2, "key2") + assert.Equal(t, sess.values["key2"], v2["key2"]) + assert.NotContains(t, v2, "key3") - vals, err := sess.Get("val") - assert.Error(err, ErrInvalidSession.Error()) - assert.Nil(vals) + // Get. + v3, err := sess.Get("key1") + assert.NoError(t, err) + assert.Contains(t, sess.values, "key1") + assert.Equal(t, sess.values["key1"], v3) } func TestSessionSet(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockManager := newMockManager(mockStore) + str := newMockStore() + str.data = map[string]interface{}{} + mgr := newMockManager(str) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) - assert := assert.New(t) - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) + err = sess.Set("key1", 1) + assert.NoError(t, err) - err = sess.Set("key", 100) - assert.NoError(err) - assert.Equal(mockStore.val, 100) -} + // Check if its set on temp. + assert.Contains(t, str.temp, "key1") + assert.NotContains(t, str.data, "key1") + assert.Equal(t, str.temp["key1"], 1) -func TestSessionSetInvalidSession(t *testing.T) { - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockManager.opts.EnableAutoCreate = true - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - - assert := assert.New(t) - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) + // Commit. + err = sess.Commit() + assert.NoError(t, err) - err = sess.Set("key", 100) - assert.Error(err, ErrInvalidSession.Error()) + // Check if its set on data after commit. + assert.Contains(t, str.data, "key1") + assert.NotContains(t, str.temp, "key1") + assert.Equal(t, 1, str.data["key1"]) } -func TestSessionCommit(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockManager := newMockManager(mockStore) +func TestSessionSetMulti(t *testing.T) { + str := newMockStore() + str.data = map[string]interface{}{} + mgr := newMockManager(str) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) - assert := assert.New(t) - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) - - err = sess.Set("key", 100) - assert.NoError(err) - assert.NoError(err) - assert.False(mockStore.isCommited) + data := map[string]interface{}{ + "key1": 1, + "key2": 2, + "key3": 3, + } + err = sess.SetMulti(data) + assert.NoError(t, err) + + // Check if its set on temp. + assert.Contains(t, str.temp, "key1") + assert.Contains(t, str.temp, "key2") + assert.Contains(t, str.temp, "key3") + assert.NotContains(t, str.data, "key1") + assert.NotContains(t, str.data, "key2") + assert.NotContains(t, str.data, "key3") + assert.Equal(t, data["key1"], str.temp["key1"]) + assert.Equal(t, data["key2"], str.temp["key2"]) + assert.Equal(t, data["key3"], str.temp["key3"]) + + // Commit. err = sess.Commit() - assert.NoError(err) - assert.True(mockStore.isCommited) -} - -func TestSessionCommitInvalidSession(t *testing.T) { - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockManager.opts.EnableAutoCreate = true - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - - assert := assert.New(t) - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) - + assert.NoError(t, err) + + // Check if its set on data after commit. + assert.Contains(t, str.data, "key1") + assert.Contains(t, str.data, "key2") + assert.Contains(t, str.data, "key3") + assert.NotContains(t, str.temp, "key1") + assert.NotContains(t, str.temp, "key2") + assert.NotContains(t, str.temp, "key3") + assert.Equal(t, data["key1"], str.data["key1"]) + assert.Equal(t, data["key2"], str.data["key2"]) + assert.Equal(t, data["key3"], str.data["key3"]) + + // Test error. + str.err = errors.New("store error") + err = sess.SetMulti(data) + assert.ErrorIs(t, str.err, err) + + // Test error. + str.err = nil + err = sess.SetMulti(data) + assert.NoError(t, err) + str.err = errors.New("store error") err = sess.Commit() - assert.Error(err, ErrInvalidSession.Error()) + assert.ErrorIs(t, str.err, err) } func TestSessionDelete(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockStore.isValid = true - mockStore.val = 100 - - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) - assert.Equal(mockStore.val, 100) - - err = sess.Delete("somekey") - assert.NoError(err) - assert.Nil(mockStore.val) - - testError := errors.New("this is test error") - mockStore.err = testError - err = sess.Delete("somekey") - assert.Error(err, testError.Error()) -} - -func TestSessionClear(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockStore.isValid = true - mockStore.val = 100 - - var isCallbackTriggered bool - mockManager.RegisterSetCookie(func(cookie *http.Cookie, w interface{}) error { - isCallbackTriggered = true - return nil - }) - - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) - assert.Equal(mockStore.val, 100) + str := newMockStore() + str.data = map[string]interface{}{ + "key1": 1, + "key2": 2, + "key3": 3, + } + mgr := newMockManager(str) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) - err = sess.Clear() - assert.NoError(err) + assert.Contains(t, str.data, "key1") + err = sess.Delete("key1") + assert.NoError(t, err) + assert.NotContains(t, str.data, "key1") - assert.True(isCallbackTriggered) - assert.Equal(mockStore.val, nil) + // Test error. + str.err = errors.New("store error") + err = sess.Delete("key2") + assert.ErrorIs(t, str.err, err) } -func TestSessionClearError(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockStore.isValid = true - - sess, err := mockManager.NewSession(nil, nil) - assert.NoError(err) - - testError := errors.New("this is test error") - mockStore.err = testError +func TestSessionClear(t *testing.T) { + // Test errors. + str := newMockStore() + mgr := newMockManager(str) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) + str.err = errors.New("store error") err = sess.Clear() - assert.Error(err, testError.Error()) -} + assert.ErrorIs(t, str.err, err) -type Err struct { - code int - msg string -} - -func (e *Err) Error() string { - return e.msg -} + // Test cookie write error. + str.err = nil + ckErr := errors.New("cookie error") + mgr.RegisterSetCookie(func(cookie *http.Cookie, w interface{}) error { + return ckErr + }) + err = sess.Clear() + assert.ErrorIs(t, ckErr, err) -func (e *Err) Code() int { - return e.code -} + // Test clear. + str = newMockStore() + str.data = map[string]interface{}{ + "key1": 1, + "key2": 2, + } + mgr = newMockManager(str) + sess, err = mgr.NewSession(nil, nil) + assert.NoError(t, err) + err = sess.Clear() + assert.NoError(t, err) + assert.Equal(t, 0, len(str.data)) -func TestErrorTypes(t *testing.T) { + // Test deleteCookie callback. var ( - // Error codes for store errors. This should match the codes - // defined in the /simplesessions package exactly. - errInvalidSession = &Err{code: 1, msg: "invalid session"} - errFieldNotFound = &Err{code: 2, msg: "field not found"} - errAssertType = &Err{code: 3, msg: "assertion failed"} - errNil = &Err{code: 4, msg: "nil returned"} + receCk *http.Cookie + isCb bool ) - - assert.Equal(t, errAs(errInvalidSession), ErrInvalidSession) - assert.Equal(t, errAs(errFieldNotFound), ErrFieldNotFound) - assert.Equal(t, errAs(errAssertType), ErrAssertType) - assert.Equal(t, errAs(errNil), ErrNil) + mgr.RegisterSetCookie(func(cookie *http.Cookie, w interface{}) error { + receCk = cookie + isCb = true + return nil + }) + err = sess.Clear() + assert.NoError(t, err) + assert.Equal(t, 0, len(str.data)) + assert.True(t, isCb) + assert.Greater(t, time.Now(), receCk.Expires) } diff --git a/store_test.go b/store_test.go index 99ec0ee..ab1f43b 100644 --- a/store_test.go +++ b/store_test.go @@ -2,20 +2,10 @@ package simplesessions // MockStore mocks the store for testing type MockStore struct { - isValid bool - cookieValue string - err error - id string - val interface{} - isCommited bool -} - -func (s *MockStore) reset() { - s.isValid = false - s.cookieValue = "" - s.err = nil - s.val = nil - s.isCommited = false + err error + id string + data map[string]interface{} + temp map[string]interface{} } func (s *MockStore) Create() (cv string, err error) { @@ -23,38 +13,85 @@ func (s *MockStore) Create() (cv string, err error) { } func (s *MockStore) Get(cv, key string) (value interface{}, err error) { - return s.val, s.err + if s.id == "" { + return nil, ErrInvalidSession + } + + d, ok := s.data[key] + if !ok { + return nil, ErrFieldNotFound + } + + return d, s.err } func (s *MockStore) GetMulti(cv string, keys ...string) (values map[string]interface{}, err error) { - vals := make(map[string]interface{}) - vals["val"] = s.val - return vals, s.err + if s.id == "" { + return nil, ErrInvalidSession + } + + out := make(map[string]interface{}) + for _, key := range keys { + v, err := s.Get(cv, key) + if err != nil { + if err == ErrFieldNotFound { + v = nil + } else { + return nil, err + } + } + out[key] = v + } + + return out, s.err } func (s *MockStore) GetAll(cv string) (values map[string]interface{}, err error) { - vals := make(map[string]interface{}) - vals["val"] = s.val - return vals, s.err + if s.id == "" { + return nil, ErrInvalidSession + } + + return s.data, s.err } func (s *MockStore) Set(cv, key string, value interface{}) error { - s.val = value + if s.id == "" { + return ErrInvalidSession + } + + s.temp[key] = value return s.err } func (s *MockStore) Commit(cv string) error { - s.isCommited = true + if s.id == "" { + return ErrInvalidSession + } + + for key, val := range s.temp { + s.data[key] = val + } + s.temp = map[string]interface{}{} + return s.err } func (s *MockStore) Delete(cv string, key string) error { - s.val = nil + if s.id == "" { + return ErrInvalidSession + } + s.temp = nil + delete(s.data, key) + delete(s.temp, key) return s.err } func (s *MockStore) Clear(cv string) error { - s.val = nil + if s.id == "" { + return ErrInvalidSession + } + s.data = map[string]interface{}{} + s.temp = map[string]interface{}{} return s.err } From c128b02c5d51effe97789bbfdc7dfa83db990130 Mon Sep 17 00:00:00 2001 From: Vivek R Date: Thu, 23 May 2024 11:55:54 +0530 Subject: [PATCH 3/4] fix: update code comments --- manager.go | 26 ++++++++---------- session.go | 77 +++++++++++++++++++++++++++--------------------------- 2 files changed, 49 insertions(+), 54 deletions(-) diff --git a/manager.go b/manager.go index f2080df..0431c5a 100644 --- a/manager.go +++ b/manager.go @@ -17,7 +17,7 @@ const ( ContextName ctxNameType = "_simple_session" ) -// Manager is a utility to scaffold session and store. +// Manager handles the storage and management of HTTP cookies. type Manager struct { // Store to be used. store Store @@ -35,6 +35,7 @@ type Manager struct { // Options are available options to configure Manager. type Options struct { // If enabled, Acquire() will always create and return a new session if one doesn't already exist. + // If disabled then new session can only be created using NewSession() method. EnableAutoCreate bool // CookieName sets http cookie name. This is also sent as cookie name in `GetCookie` callback. @@ -84,19 +85,18 @@ func (m *Manager) UseStore(str Store) { m.store = str } -// RegisterGetCookie sets a callback to get http cookie from any reader interface which -// is sent on session acquisition using `Acquire` method. +// RegisterGetCookie sets a callback to retrieve an HTTP cookie during session acquisition. func (m *Manager) RegisterGetCookie(cb func(string, interface{}) (*http.Cookie, error)) { m.getCookieCb = cb } -// RegisterSetCookie sets a callback to set cookie from http writer interface which -// is sent on session acquisition using `Acquire` method. +// RegisterSetCookie sets a callback to set an HTTP cookie during session acquisition. func (m *Manager) RegisterSetCookie(cb func(*http.Cookie, interface{}) error) { m.setCookieCb = cb } -// NewSession creates a new session. +// NewSession creates a new `Session` and updates the cookie with a new session ID, +// replacing any existing session ID if it exists. func (m *Manager) NewSession(r, w interface{}) (*Session, error) { // Check if any store is set if m.store == nil { @@ -129,15 +129,11 @@ func (m *Manager) NewSession(r, w interface{}) (*Session, error) { return sess, nil } -// Acquire gets a `Session` for current session cookie from store. -// If `Session` is not found and `opt.EnableAutoCreate` option is true then -// then it creates a new session and sets on store. -// If `Session` is not found and `opt.EnableAutoCreate` option is false then -// then it returns ErrInvalidSession. -// `r` and `w` is request and response interfaces which are sent back in GetCookie and SetCookie callbacks respectively. -// In case of net/http `r` will be r` -// Optionally context can be passed around which is used to get already loaded session. This is useful when -// handler is wrapped with multiple middlewares and `Acquire` is already called in any of the middleware. +// Acquire retrieves a `Session` from the store using the current session cookie. +// If not found and `opt.EnableAutoCreate` is true, a new session is created and stored. +// If not found and `opt.EnableAutoCreate` is false which is the default, it returns ErrInvalidSession. +// `r` and `w` are request and response interfaces which is passed back in in GetCookie and SetCookie callbacks. +// Optionally, a context can be passed to get an already loaded session, useful in middleware chains. func (m *Manager) Acquire(c context.Context, r, w interface{}) (*Session, error) { // Check if any store is set if m.store == nil { diff --git a/session.go b/session.go index d72a801..998fc0a 100644 --- a/session.go +++ b/session.go @@ -6,10 +6,10 @@ import ( "time" ) -// Session is utility for get, set or clear session. +// Session provides the object to get, set, or clear session data. type Session struct { - // Map to store session data which can be loaded using `Load` method. - // Get session method check if the field is available here before getting from store directly. + // Map to store session data, loaded using `LoadValues` method. + // All `Get` methods checks here before fetching from the store. values map[string]interface{} // Session manager. @@ -19,7 +19,7 @@ type Session struct { id string // HTTP reader and writer interfaces which are passed on to - // `GetCookie`` and `SetCookie`` callback respectively. + // `GetCookie`` and `SetCookie`` callbacks. reader interface{} writer interface{} } @@ -47,11 +47,11 @@ type errCode interface { Code() int } -// WriteCookie updates the cookie and calls `SetCookie` callback. -// This method can also be used by store to update cookie whenever the cookie value changes. -func (s *Session) WriteCookie(cv string) error { +// WriteCookie creates a cookie with the given session ID and parameters, +// then calls the `SetCookie` callback. This can be used to update the cookie externally. +func (s *Session) WriteCookie(id string) error { ck := &http.Cookie{ - Value: cv, + Value: id, Name: s.manager.opts.CookieName, Domain: s.manager.opts.CookieDomain, Path: s.manager.opts.CookiePath, @@ -64,7 +64,7 @@ func (s *Session) WriteCookie(cv string) error { return s.manager.setCookieCb(ck, s.writer) } -// clearCookie sets expiry of the cookie to one day before to clear it. +// clearCookie sets the cookie's expiry to one day prior to clear it. func (s *Session) clearCookie() error { ck := &http.Cookie{ Name: s.manager.opts.CookieName, @@ -82,23 +82,25 @@ func (s *Session) ID() string { return s.id } -// LoadValues loads the session values in memory. -// Get session field tries to get value from memory before hitting store. +// LoadValues loads session values into memory for quick access. +// Ideal for centralized session fetching, e.g., in middleware. +// Subsequent Get/GetMulti calls return cached values, avoiding store access. +// Use ResetValues() to ensure GetAll/Get/GetMulti fetches from the store. +// Set/SetMulti/Clear do not update the values, so this method must be called again for any changes. func (s *Session) LoadValues() error { var err error s.values, err = s.GetAll() return err } -// ResetValues reset the loaded values using `LoadValues` method.ResetValues -// Subsequent Get, GetAll and GetMulti +// ResetValues clears loaded values, ensuring subsequent Get, GetAll, and GetMulti calls fetch from the store. func (s *Session) ResetValues() { s.values = make(map[string]interface{}) } -// GetAll gets all the fields in the session. +// GetAll gets all the fields for the given session id. func (s *Session) GetAll() (map[string]interface{}, error) { - // Load value from map if its already loaded + // Load value from map if its already loaded. if len(s.values) > 0 { return s.values, nil } @@ -107,7 +109,8 @@ func (s *Session) GetAll() (map[string]interface{}, error) { return out, errAs(err) } -// GetMulti gets a map of values for multiple session keys. +// GetMulti retrieves values for multiple session fields. +// If a field is not found in the store then its returned as nil. func (s *Session) GetMulti(keys ...string) (map[string]interface{}, error) { // Load values from map if its already loaded if len(s.values) > 0 { @@ -125,49 +128,47 @@ func (s *Session) GetMulti(keys ...string) (map[string]interface{}, error) { return out, errAs(err) } -// Get gets a value for given key in session. -// If session is already loaded using `Load` then returns values from -// existing map instead of getting it from store. +// Get retrieves a value for the given key in the session. +// If the session is already loaded, it returns the value from the existing map. +// Otherwise, it fetches the value from the store. func (s *Session) Get(key string) (interface{}, error) { - // Load value from map if its already loaded + // Return value from map if already loaded. if len(s.values) > 0 { if val, ok := s.values[key]; ok { return val, nil } } - // Get from backend if not found in previous step + // Fetch from store if not found in the map. out, err := s.manager.store.Get(s.id, key) return out, errAs(err) } -// Set sets a value for given key in session. Its up to store to commit -// all previously set values at once or store it on each set. +// Set assigns a value to the given key in the session. +// The store determines whether to commit all values at once or store them individually. +// Use Commit() method to commit all values if the store doesn't immediately persist them. func (s *Session) Set(key string, val interface{}) error { err := s.manager.store.Set(s.id, key, val) return errAs(err) } -// SetMulti sets all values in the session. -// Its up to store to commit all previously -// set values at once or store it on each set. +// SetMulti assigns multiple values to the session. +// The store determines whether to commit all values at once or store them individually. func (s *Session) SetMulti(values map[string]interface{}) error { for k, v := range values { if err := s.manager.store.Set(s.id, k, v); err != nil { return errAs(err) } } - return nil } -// Commit commits all set to store. Its up to store to commit -// all previously set values at once or store it on each set. +// Commit persists all values to the store. +// The store determines whether to commit all values at once or store them individually. func (s *Session) Commit() error { if err := s.manager.store.Commit(s.id); err != nil { return errAs(err) } - return nil } @@ -176,7 +177,6 @@ func (s *Session) Delete(key string) error { if err := s.manager.store.Delete(s.id, key); err != nil { return errAs(err) } - return nil } @@ -185,47 +185,46 @@ func (s *Session) Clear() error { if err := s.manager.store.Clear(s.id); err != nil { return errAs(err) } - return s.clearCookie() } -// Int is a helper to get values as integer +// Int is a helper to get values as integer. func (s *Session) Int(r interface{}, err error) (int, error) { out, err := s.manager.store.Int(r, err) return out, errAs(err) } -// Int64 is a helper to get values as Int64 +// Int64 is a helper to get values as Int64. func (s *Session) Int64(r interface{}, err error) (int64, error) { out, err := s.manager.store.Int64(r, err) return out, errAs(err) } -// UInt64 is a helper to get values as UInt64 +// UInt64 is a helper to get values as UInt64. func (s *Session) UInt64(r interface{}, err error) (uint64, error) { out, err := s.manager.store.UInt64(r, err) return out, errAs(err) } -// Float64 is a helper to get values as Float64 +// Float64 is a helper to get values as Float64. func (s *Session) Float64(r interface{}, err error) (float64, error) { out, err := s.manager.store.Float64(r, err) return out, errAs(err) } -// String is a helper to get values as String +// String is a helper to get values as String. func (s *Session) String(r interface{}, err error) (string, error) { out, err := s.manager.store.String(r, err) return out, errAs(err) } -// Bytes is a helper to get values as Bytes +// Bytes is a helper to get values as Bytes. func (s *Session) Bytes(r interface{}, err error) ([]byte, error) { out, err := s.manager.store.Bytes(r, err) return out, errAs(err) } -// Bool is a helper to get values as Bool +// Bool is a helper to get values as Bool. func (s *Session) Bool(r interface{}, err error) (bool, error) { out, err := s.manager.store.Bool(r, err) return out, errAs(err) From 7f3248818115bbe56432d3357b0b83db11c49670 Mon Sep 17 00:00:00 2001 From: Vivek R Date: Thu, 23 May 2024 15:19:54 +0530 Subject: [PATCH 4/4] refactor: remove Commit() pattern and introduce setMulti to store interface - remove Commit() pattern, Set and SetMulti should immediately set the values to backend. - rename LoadValues() to CacheAll() and update cache on Set, SetMulti, Delete and Clear calls. --- manager.go | 4 +- manager_test.go | 1 - session.go | 187 ++++++++++++++++++++++++++++++++++-------------- session_test.go | 140 +++++++++++++++++++++++------------- store.go | 28 ++++---- store_test.go | 43 +++++------ 6 files changed, 256 insertions(+), 147 deletions(-) diff --git a/manager.go b/manager.go index 0431c5a..43d44cd 100644 --- a/manager.go +++ b/manager.go @@ -119,7 +119,7 @@ func (m *Manager) NewSession(r, w interface{}) (*Session, error) { manager: m, reader: r, writer: w, - values: make(map[string]interface{}), + cache: nil, } // Write cookie. if err := sess.WriteCookie(id); err != nil { @@ -166,7 +166,7 @@ func (m *Manager) Acquire(c context.Context, r, w interface{}) (*Session, error) reader: r, writer: w, id: ck.Value, - values: make(map[string]interface{}), + cache: nil, }, nil } diff --git a/manager_test.go b/manager_test.go index 3983257..f7e3dc4 100644 --- a/manager_test.go +++ b/manager_test.go @@ -15,7 +15,6 @@ func newMockStore() *MockStore { return &MockStore{ id: mockSessionID, data: map[string]interface{}{}, - temp: map[string]interface{}{}, err: nil, } } diff --git a/session.go b/session.go index 998fc0a..842ac8d 100644 --- a/session.go +++ b/session.go @@ -3,14 +3,16 @@ package simplesessions import ( "errors" "net/http" + "sync" "time" ) // Session provides the object to get, set, or clear session data. type Session struct { - // Map to store session data, loaded using `LoadValues` method. + // Map to store session data, loaded using `CacheAll` method. // All `Get` methods checks here before fetching from the store. - values map[string]interface{} + cache map[string]interface{} + cacheMux sync.RWMutex // Session manager. manager *Manager @@ -82,49 +84,126 @@ func (s *Session) ID() string { return s.id } -// LoadValues loads session values into memory for quick access. +// getCacheAll returns a copy of cached map. +func (s *Session) getCacheAll() map[string]interface{} { + s.cacheMux.RLock() + defer s.cacheMux.RUnlock() + + if s.cache == nil { + return nil + } + + out := map[string]interface{}{} + for k, v := range s.cache { + out[k] = v + } + + return out +} + +// getCacheAll returns a map of values for the given list of keys. +// If key doesn't exist then ErrFieldNotFound is returned. +func (s *Session) getCache(key ...string) map[string]interface{} { + s.cacheMux.RLock() + defer s.cacheMux.RUnlock() + + if s.cache == nil { + return nil + } + + out := map[string]interface{}{} + for _, k := range key { + v, ok := s.cache[k] + if ok { + out[k] = v + } else { + out[k] = ErrFieldNotFound + } + } + + return out +} + +// setCache sets a cache for given kv pairs. +func (s *Session) setCache(data map[string]interface{}) { + s.cacheMux.Lock() + defer s.cacheMux.Unlock() + + // If cacheAll() is not called the don't maintain cache. + if s.cache == nil { + return + } + + for k, v := range data { + s.cache[k] = v + } +} + +// deleteCache sets a cache for given kv pairs. +func (s *Session) deleteCache(key ...string) { + s.cacheMux.Lock() + defer s.cacheMux.Unlock() + + // If cacheAll() is not called the don't maintain cache. + if s.cache == nil { + return + } + + for _, k := range key { + delete(s.cache, k) + } +} + +// CacheAll loads session values into memory for quick access. // Ideal for centralized session fetching, e.g., in middleware. // Subsequent Get/GetMulti calls return cached values, avoiding store access. -// Use ResetValues() to ensure GetAll/Get/GetMulti fetches from the store. -// Set/SetMulti/Clear do not update the values, so this method must be called again for any changes. -func (s *Session) LoadValues() error { - var err error - s.values, err = s.GetAll() - return err +// Use ResetCache() to ensure GetAll/Get/GetMulti fetches from the store. +func (s *Session) CacheAll() error { + all, err := s.manager.store.GetAll(s.id) + if err != nil { + return err + } + + s.cacheMux.Lock() + defer s.cacheMux.Unlock() + s.cache = map[string]interface{}{} + for k, v := range all { + s.cache[k] = v + } + + return nil } -// ResetValues clears loaded values, ensuring subsequent Get, GetAll, and GetMulti calls fetch from the store. -func (s *Session) ResetValues() { - s.values = make(map[string]interface{}) +// ResetCache clears loaded values, ensuring subsequent Get, GetAll, and GetMulti calls fetch from the store. +func (s *Session) ResetCache() { + s.cacheMux.Lock() + defer s.cacheMux.Unlock() + s.cache = nil } // GetAll gets all the fields for the given session id. func (s *Session) GetAll() (map[string]interface{}, error) { - // Load value from map if its already loaded. - if len(s.values) > 0 { - return s.values, nil + // Try to get the values from cache. + c := s.getCacheAll() + if c != nil { + return c, nil } + // Get the values from store. out, err := s.manager.store.GetAll(s.id) return out, errAs(err) } // GetMulti retrieves values for multiple session fields. // If a field is not found in the store then its returned as nil. -func (s *Session) GetMulti(keys ...string) (map[string]interface{}, error) { - // Load values from map if its already loaded - if len(s.values) > 0 { - vals := make(map[string]interface{}) - for _, k := range keys { - if v, ok := s.values[k]; ok { - vals[k] = v - } - } - - return vals, nil +func (s *Session) GetMulti(key ...string) (map[string]interface{}, error) { + // Try to get the values from cache. + c := s.getCache(key...) + if c != nil { + return c, nil } - out, err := s.manager.store.GetMulti(s.id, keys...) + out, err := s.manager.store.GetMulti(s.id, key...) return out, errAs(err) } @@ -132,10 +211,14 @@ func (s *Session) GetMulti(keys ...string) (map[string]interface{}, error) { // If the session is already loaded, it returns the value from the existing map. // Otherwise, it fetches the value from the store. func (s *Session) Get(key string) (interface{}, error) { - // Return value from map if already loaded. - if len(s.values) > 0 { - if val, ok := s.values[key]; ok { - return val, nil + // Try to get the values from cache. + c := s.getCache(key) + if c != nil { + err, ok := c[key].(error) + if ok { + return nil, err + } else { + return c[key], nil } } @@ -145,45 +228,41 @@ func (s *Session) Get(key string) (interface{}, error) { } // Set assigns a value to the given key in the session. -// The store determines whether to commit all values at once or store them individually. -// Use Commit() method to commit all values if the store doesn't immediately persist them. func (s *Session) Set(key string, val interface{}) error { err := s.manager.store.Set(s.id, key, val) + if err == nil { + s.setCache(map[string]interface{}{ + key: val, + }) + } return errAs(err) } // SetMulti assigns multiple values to the session. -// The store determines whether to commit all values at once or store them individually. -func (s *Session) SetMulti(values map[string]interface{}) error { - for k, v := range values { - if err := s.manager.store.Set(s.id, k, v); err != nil { - return errAs(err) - } - } - return nil -} - -// Commit persists all values to the store. -// The store determines whether to commit all values at once or store them individually. -func (s *Session) Commit() error { - if err := s.manager.store.Commit(s.id); err != nil { - return errAs(err) +func (s *Session) SetMulti(data map[string]interface{}) error { + err := s.manager.store.SetMulti(s.id, data) + if err == nil { + s.setCache(data) } - return nil + return errAs(err) } // Delete deletes a field from session. -func (s *Session) Delete(key string) error { - if err := s.manager.store.Delete(s.id, key); err != nil { - return errAs(err) +func (s *Session) Delete(key ...string) error { + err := s.manager.store.Delete(s.id, key...) + if err == nil { + s.deleteCache(key...) } - return nil + return errAs(err) } // Clear clears session data from store and clears the cookie. func (s *Session) Clear() error { - if err := s.manager.store.Clear(s.id); err != nil { + err := s.manager.store.Clear(s.id) + if err != nil { return errAs(err) + } else { + s.ResetCache() } return s.clearCookie() } diff --git a/session_test.go b/session_test.go index d2051b3..ec20148 100644 --- a/session_test.go +++ b/session_test.go @@ -93,7 +93,7 @@ func TestSessionNewSession(t *testing.T) { assert.Equal(t, mgr, sess.manager) assert.Equal(t, reader, sess.reader) assert.Equal(t, writer, sess.writer) - assert.NotNil(t, sess.values) + assert.Nil(t, sess.cache) assert.Equal(t, mockSessionID, sess.id) assert.Equal(t, sess.id, sess.ID()) } @@ -204,7 +204,7 @@ func TestSessionClearCookie(t *testing.T) { assert.True(t, receCk.Expires.UnixNano() < time.Now().UnixNano()) } -func TestSessionLoadValues(t *testing.T) { +func TestSessionCacheAll(t *testing.T) { str := newMockStore() str.data = map[string]interface{}{ "key1": 1, @@ -215,12 +215,20 @@ func TestSessionLoadValues(t *testing.T) { sess, err := mgr.NewSession(nil, nil) assert.NoError(t, err) - err = sess.LoadValues() + // Test error. + str.err = errors.New("store error") + err = sess.CacheAll() + assert.ErrorIs(t, str.err, err) + assert.Nil(t, sess.cache) + + // Test without error. + str.err = nil + err = sess.CacheAll() assert.NoError(t, err) - assert.Equal(t, str.data, sess.values) + assert.Equal(t, str.data, sess.cache) } -func TestSessionResetValues(t *testing.T) { +func TestSessionResetCache(t *testing.T) { str := newMockStore() str.data = map[string]interface{}{ "key1": 1, @@ -228,11 +236,11 @@ func TestSessionResetValues(t *testing.T) { } mgr := newMockManager(str) sess, _ := mgr.NewSession(nil, nil) - sess.LoadValues() - assert.NotEqual(t, 0, len(sess.values)) + sess.CacheAll() + assert.Equal(t, str.data, sess.cache) - sess.ResetValues() - assert.Equal(t, 0, len(sess.values)) + sess.ResetCache() + assert.Nil(t, sess.cache) } func TestSessionGetStore(t *testing.T) { @@ -267,13 +275,13 @@ func TestSessionGetStore(t *testing.T) { assert.Equal(t, str.data["key1"], v3) } -func TestSessionGetLoaded(t *testing.T) { +func TestSessionGetCached(t *testing.T) { str := newMockStore() mgr := newMockManager(str) sess, err := mgr.NewSession(nil, nil) assert.NoError(t, err) - sess.values = map[string]interface{}{ + sess.cache = map[string]interface{}{ "key1": 1, "key2": 2, "key3": 3, @@ -282,22 +290,42 @@ func TestSessionGetLoaded(t *testing.T) { // GetAll. v1, err := sess.GetAll() assert.NoError(t, err) - assert.Equal(t, sess.values, v1) + assert.Equal(t, sess.cache, v1) // GetMulti. v2, err := sess.GetMulti("key1", "key2") assert.NoError(t, err) assert.Contains(t, v2, "key1") - assert.Equal(t, sess.values["key1"], v2["key1"]) + assert.Equal(t, sess.cache["key1"], v2["key1"]) assert.Contains(t, v2, "key2") - assert.Equal(t, sess.values["key2"], v2["key2"]) + assert.Equal(t, sess.cache["key2"], v2["key2"]) assert.NotContains(t, v2, "key3") // Get. v3, err := sess.Get("key1") assert.NoError(t, err) - assert.Contains(t, sess.values, "key1") - assert.Equal(t, sess.values["key1"], v3) + assert.Contains(t, sess.cache, "key1") + assert.Equal(t, sess.cache["key1"], v3) + + // Get unknowm field. + _, err = sess.Get("key99") + assert.ErrorIs(t, ErrFieldNotFound, err) + + // GetMulti unknown fields + v4, err := sess.GetMulti("key1", "key2", "key99", "key100") + assert.NoError(t, err) + assert.Contains(t, v4, "key1") + assert.Equal(t, sess.cache["key1"], v4["key1"]) + assert.Contains(t, v4, "key99") + assert.Contains(t, v4, "key100") + + err, ok := v4["key99"].(error) + assert.True(t, ok) + assert.ErrorIs(t, ErrFieldNotFound, err) + + err, ok = v4["key100"].(error) + assert.True(t, ok) + assert.ErrorIs(t, ErrFieldNotFound, err) } func TestSessionSet(t *testing.T) { @@ -310,19 +338,18 @@ func TestSessionSet(t *testing.T) { err = sess.Set("key1", 1) assert.NoError(t, err) - // Check if its set on temp. - assert.Contains(t, str.temp, "key1") - assert.NotContains(t, str.data, "key1") - assert.Equal(t, str.temp["key1"], 1) - - // Commit. - err = sess.Commit() - assert.NoError(t, err) - // Check if its set on data after commit. assert.Contains(t, str.data, "key1") - assert.NotContains(t, str.temp, "key1") assert.Equal(t, 1, str.data["key1"]) + assert.Nil(t, sess.cache) + + // Cache and set. + err = sess.CacheAll() + assert.NoError(t, err) + err = sess.Set("key1", 1) + assert.NoError(t, err) + assert.NotNil(t, sess.cache) + assert.Equal(t, sess.cache, str.data) } func TestSessionSetMulti(t *testing.T) { @@ -340,43 +367,28 @@ func TestSessionSetMulti(t *testing.T) { err = sess.SetMulti(data) assert.NoError(t, err) - // Check if its set on temp. - assert.Contains(t, str.temp, "key1") - assert.Contains(t, str.temp, "key2") - assert.Contains(t, str.temp, "key3") - assert.NotContains(t, str.data, "key1") - assert.NotContains(t, str.data, "key2") - assert.NotContains(t, str.data, "key3") - assert.Equal(t, data["key1"], str.temp["key1"]) - assert.Equal(t, data["key2"], str.temp["key2"]) - assert.Equal(t, data["key3"], str.temp["key3"]) - - // Commit. - err = sess.Commit() - assert.NoError(t, err) - // Check if its set on data after commit. assert.Contains(t, str.data, "key1") assert.Contains(t, str.data, "key2") assert.Contains(t, str.data, "key3") - assert.NotContains(t, str.temp, "key1") - assert.NotContains(t, str.temp, "key2") - assert.NotContains(t, str.temp, "key3") assert.Equal(t, data["key1"], str.data["key1"]) assert.Equal(t, data["key2"], str.data["key2"]) assert.Equal(t, data["key3"], str.data["key3"]) + assert.Nil(t, sess.cache) - // Test error. - str.err = errors.New("store error") + // Cache and set. + str.data = map[string]interface{}{} + err = sess.CacheAll() + assert.NoError(t, err) err = sess.SetMulti(data) - assert.ErrorIs(t, str.err, err) + assert.NoError(t, err) + assert.NotNil(t, sess.cache) + assert.Equal(t, sess.cache, str.data) // Test error. - str.err = nil - err = sess.SetMulti(data) - assert.NoError(t, err) + sess.ResetCache() str.err = errors.New("store error") - err = sess.Commit() + err = sess.SetMulti(data) assert.ErrorIs(t, str.err, err) } @@ -396,6 +408,14 @@ func TestSessionDelete(t *testing.T) { assert.NoError(t, err) assert.NotContains(t, str.data, "key1") + // Cache and set. + err = sess.CacheAll() + assert.NoError(t, err) + err = sess.Delete("key2") + assert.NoError(t, err) + assert.NotNil(t, sess.cache) + assert.Equal(t, sess.cache, str.data) + // Test error. str.err = errors.New("store error") err = sess.Delete("key2") @@ -433,6 +453,24 @@ func TestSessionClear(t *testing.T) { err = sess.Clear() assert.NoError(t, err) assert.Equal(t, 0, len(str.data)) + assert.Nil(t, sess.cache) + + // Test clear. + str = newMockStore() + str.data = map[string]interface{}{ + "key1": 1, + "key2": 2, + } + mgr = newMockManager(str) + sess, err = mgr.NewSession(nil, nil) + assert.NoError(t, err) + err = sess.CacheAll() + assert.NoError(t, err) + assert.NotNil(t, sess.cache) + err = sess.Clear() + assert.NoError(t, err) + assert.Equal(t, 0, len(str.data)) + assert.Nil(t, sess.cache) // Test deleteCookie callback. var ( diff --git a/store.go b/store.go index 5f22791..476337c 100644 --- a/store.go +++ b/store.go @@ -3,30 +3,30 @@ package simplesessions // Store represents store interface. This interface can be // implemented to create various backend stores for session. type Store interface { - // Create creates new session in store and returns the cookie value. - Create() (cookieValue string, err error) + // Create creates new session in the store and returns the session ID. + Create() (id string, err error) // Get gets a value for given key from session. - Get(cookieValue, key string) (value interface{}, err error) + Get(id, key string) (value interface{}, err error) // GetMulti gets a maps of multiple values for given keys. - GetMulti(cookieValue string, keys ...string) (values map[string]interface{}, err error) + // If some fields are not found then return ErrFieldNotFound for that field. + GetMulti(id string, keys ...string) (data map[string]interface{}, err error) - // GetAll gets all key and value from session, - GetAll(cookieValue string) (values map[string]interface{}, err error) + // GetAll gets all key and value from session. + GetAll(id string) (data map[string]interface{}, err error) // Set sets an value for a field in session. - // Its up to store to either store it in session right after set or after commit. - Set(cookieValue, key string, value interface{}) error + Set(id, key string, value interface{}) error - // Commit commits all the previously set values to store. - Commit(cookieValue string) error + // Set takes a map of kv pair and set the field in store. + SetMulti(id string, data map[string]interface{}) error - // Delete a field from session. - Delete(cookieValue string, key string) error + // Delete a given list of keys from session. + Delete(id string, key ...string) error - // Clear clears the session key from backend if exists. - Clear(cookieValue string) error + // Clear clears the entire session. + Clear(id string) error // Helper method for typecasting/asserting. Int(interface{}, error) (int, error) diff --git a/store_test.go b/store_test.go index ab1f43b..e08dc59 100644 --- a/store_test.go +++ b/store_test.go @@ -5,14 +5,13 @@ type MockStore struct { err error id string data map[string]interface{} - temp map[string]interface{} } -func (s *MockStore) Create() (cv string, err error) { +func (s *MockStore) Create() (string, error) { return s.id, s.err } -func (s *MockStore) Get(cv, key string) (value interface{}, err error) { +func (s *MockStore) Get(id, key string) (interface{}, error) { if s.id == "" { return nil, ErrInvalidSession } @@ -21,24 +20,19 @@ func (s *MockStore) Get(cv, key string) (value interface{}, err error) { if !ok { return nil, ErrFieldNotFound } - return d, s.err } -func (s *MockStore) GetMulti(cv string, keys ...string) (values map[string]interface{}, err error) { +func (s *MockStore) GetMulti(id string, keys ...string) (values map[string]interface{}, err error) { if s.id == "" { return nil, ErrInvalidSession } out := make(map[string]interface{}) for _, key := range keys { - v, err := s.Get(cv, key) - if err != nil { - if err == ErrFieldNotFound { - v = nil - } else { - return nil, err - } + v, ok := s.data[key] + if !ok { + v = err } out[key] = v } @@ -46,7 +40,7 @@ func (s *MockStore) GetMulti(cv string, keys ...string) (values map[string]inter return out, s.err } -func (s *MockStore) GetAll(cv string) (values map[string]interface{}, err error) { +func (s *MockStore) GetAll(id string) (values map[string]interface{}, err error) { if s.id == "" { return nil, ErrInvalidSession } @@ -59,39 +53,38 @@ func (s *MockStore) Set(cv, key string, value interface{}) error { return ErrInvalidSession } - s.temp[key] = value + s.data[key] = value return s.err } -func (s *MockStore) Commit(cv string) error { +func (s *MockStore) SetMulti(id string, data map[string]interface{}) error { if s.id == "" { return ErrInvalidSession } - for key, val := range s.temp { - s.data[key] = val + for k, v := range data { + s.data[k] = v } - s.temp = map[string]interface{}{} - return s.err } -func (s *MockStore) Delete(cv string, key string) error { +func (s *MockStore) Delete(id string, key ...string) error { if s.id == "" { return ErrInvalidSession } - s.temp = nil - delete(s.data, key) - delete(s.temp, key) + + for _, k := range key { + delete(s.data, k) + } return s.err } -func (s *MockStore) Clear(cv string) error { +func (s *MockStore) Clear(id string) error { if s.id == "" { return ErrInvalidSession } + s.data = map[string]interface{}{} - s.temp = map[string]interface{}{} return s.err }