diff --git a/go.mod b/go.mod index b8af66f..3aa6170 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,5 @@ module github.com/vividvilla/simplesessions/v2 -require ( - github.com/stretchr/testify v1.9.0 - github.com/valyala/fasthttp v1.40.0 -) +require github.com/stretchr/testify v1.9.0 go 1.14 diff --git a/manager.go b/manager.go index 43d44cd..639a320 100644 --- a/manager.go +++ b/manager.go @@ -2,9 +2,11 @@ package simplesessions import ( "context" + "crypto/rand" "fmt" "net/http" "time" + "unicode" ) type ctxNameType string @@ -13,6 +15,9 @@ const ( // Default cookie name used to store session. defaultCookieName = "session" + // default sessionID length. + defaultSessIDLength = 32 + // ContextName is the key used to store session in context passed to acquire method. ContextName ctxNameType = "_simple_session" ) @@ -25,11 +30,17 @@ type Manager struct { // Store basic cookie details. opts *Options - // Callback to get http cookie. - getCookieCb func(name string, r interface{}) (*http.Cookie, error) + // Hook to get http cookie. + getCookieHook func(name string, r interface{}) (*http.Cookie, error) + + // Hook to set http cookie. + setCookieHook func(cookie *http.Cookie, w interface{}) error + + // generate cookie ID. + generateID func() (string, error) - // Callback to set http cookie. - setCookieCb func(cookie *http.Cookie, w interface{}) error + // validate cookie ID. + validateID func(string) bool } // Options are available options to configure Manager. @@ -59,6 +70,11 @@ type Options struct { // SameSite sets allows you to declare if your cookie should be restricted to a first-party or same-site context. SameSite http.SameSite + + // Cookie ID length. Defaults to alphanumeric 32 characters. + // Might not be applicable to some stores like SecureCookie. + // Also not applicable if custom generateID and validateID is set. + SessionIDLength int } // New creates a new session manager for given options. @@ -77,6 +93,14 @@ func New(opts Options) *Manager { m.opts.CookiePath = "/" } + if m.opts.SessionIDLength == 0 { + m.opts.SessionIDLength = defaultSessIDLength + } + + // Assign default set and validate generate ID. + m.generateID = m.defaultGenerateID + m.validateID = m.defaultValidateID + return m } @@ -85,14 +109,27 @@ func (m *Manager) UseStore(str Store) { m.store = str } -// RegisterGetCookie sets a callback to retrieve an HTTP cookie during session acquisition. -func (m *Manager) RegisterGetCookie(cb func(string, interface{}) (*http.Cookie, error)) { - m.getCookieCb = cb +// SetCookieHooks cane be used to get and set HTTP cookie for the session. +// +// getCookie hook takes session ID and reader interface and returns http.Cookie and error. +// In a HTTP request context reader interface will be the http request object and +// it should obtain http.Cookie from the request object for the given cookie ID. +// +// setCookie hook takes http.Cookie object and a writer interface and returns error. +// In a HTTP request context the write interface will be the http request object and +// it should write http request with the incoming cookie. +func (m *Manager) SetCookieHooks(getCookie func(string, interface{}) (*http.Cookie, error), setCookie func(*http.Cookie, interface{}) error) { + m.getCookieHook = getCookie + m.setCookieHook = setCookie } -// RegisterSetCookie sets a callback to set an HTTP cookie during session acquisition. -func (m *Manager) RegisterSetCookie(cb func(*http.Cookie, interface{}) error) { - m.setCookieCb = cb +// SetSessionIDHooks cane be used to generate and validate custom session ID. +// Bydefault alpha-numeric 32bit length session ID is used if its not set. +// - Generating custom session ID, which will be uses as the ID for storing sessions in the backend. +// - Validating custom session ID, which will be used to verify the ID before querying backend. +func (m *Manager) SetSessionIDHooks(generateID func() (string, error), validateID func(string) bool) { + m.generateID = generateID + m.validateID = validateID } // NewSession creates a new `Session` and updates the cookie with a new session ID, @@ -100,20 +137,24 @@ func (m *Manager) RegisterSetCookie(cb func(*http.Cookie, interface{}) error) { func (m *Manager) NewSession(r, w interface{}) (*Session, error) { // Check if any store is set if m.store == nil { - return nil, fmt.Errorf("session store is not set") + return nil, fmt.Errorf("session store not set") } - if m.setCookieCb == nil { - return nil, fmt.Errorf("callback `SetCookie` not set") + if m.setCookieHook == nil { + return nil, fmt.Errorf("`SetCookie` hook not set") } // Create new cookie in store and write to front. // Store also calls `WriteCookie`` to write to http interface. - id, err := m.store.Create() + id, err := m.generateID() if err != nil { return nil, errAs(err) } + if err = m.store.Create(id); err != nil { + return nil, errAs(err) + } + var sess = &Session{ id: id, manager: m, @@ -137,16 +178,16 @@ func (m *Manager) NewSession(r, w interface{}) (*Session, error) { func (m *Manager) Acquire(c context.Context, r, w interface{}) (*Session, error) { // Check if any store is set if m.store == nil { - return nil, fmt.Errorf("session store is not set") + return nil, fmt.Errorf("session store not set") } // Check if callbacks are set - if m.getCookieCb == nil { - return nil, fmt.Errorf("callback `GetCookie` not set") + if m.getCookieHook == nil { + return nil, fmt.Errorf("`GetCookie` hook not set") } - if m.setCookieCb == nil { - return nil, fmt.Errorf("callback `SetCookie` not set") + if m.setCookieHook == nil { + return nil, fmt.Errorf("`SetCookie` hook not set") } // If a session was already set in the context by a middleware somewhere, return that. @@ -159,7 +200,7 @@ func (m *Manager) Acquire(c context.Context, r, w interface{}) (*Session, error) // Get existing HTTP session cookie. // If there's no error and there's a session ID (unvalidated at this point), // return a session object. - ck, err := m.getCookieCb(m.opts.CookieName, r) + ck, err := m.getCookieHook(m.opts.CookieName, r) if err == nil && ck != nil && ck.Value != "" { return &Session{ manager: m, @@ -177,3 +218,37 @@ func (m *Manager) Acquire(c context.Context, r, w interface{}) (*Session, error) return m.NewSession(r, w) } + +// defaultGenerateID generates a random alpha-num session ID. +// This will be the default method to generate cookie ID and +// can override using `SetCookieIDGenerate` method. +func (m *Manager) defaultGenerateID() (string, error) { + const dict = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + bytes := make([]byte, m.opts.SessionIDLength) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + + for k, v := range bytes { + bytes[k] = dict[v%byte(len(dict))] + } + + return string(bytes), nil +} + +// defaultValidateID validates the incoming to ID to check +// if its alpha-numeric with configured cookie ID length. +// Can override using `SetCookieIDGenerate` method. +func (m *Manager) defaultValidateID(id string) bool { + if len(id) != m.opts.SessionIDLength { + return false + } + + for _, r := range id { + if !unicode.IsDigit(r) && !unicode.IsLetter(r) { + return false + } + } + + return true +} diff --git a/manager_test.go b/manager_test.go index f7e3dc4..9d0d6c5 100644 --- a/manager_test.go +++ b/manager_test.go @@ -97,7 +97,7 @@ func TestManagerRegisterGetCookie(t *testing.T) { m.RegisterGetCookie(cb) expectCbRes, expectCbErr := cb("", nil) - actualCbRes, actualCbErr := m.getCookieCb("", nil) + actualCbRes, actualCbErr := m.getCookieHook("", nil) assert.Equal(expectCbRes, actualCbRes) assert.Equal(expectCbErr, actualCbErr) @@ -118,7 +118,7 @@ func TestManagerRegisterSetCookie(t *testing.T) { m.RegisterSetCookie(cb) expectCbErr := cb(ck, nil) - actualCbErr := m.setCookieCb(ck, nil) + actualCbErr := m.setCookieHook(ck, nil) assert.Equal(expectCbErr, actualCbErr) } diff --git a/session.go b/session.go index 842ac8d..0cafba9 100644 --- a/session.go +++ b/session.go @@ -32,17 +32,13 @@ var ( // Store code = 1 ErrInvalidSession = errors.New("simplesession: invalid session") - // ErrFieldNotFound is raised when given key is not found in store + // ErrNil is raised when returned value is nil. // Store code = 2 - ErrFieldNotFound = errors.New("simplesession: session field not found in store") + ErrNil = errors.New("simplesession: nil returned") // ErrAssertType is raised when type assertion fails // Store code = 3 ErrAssertType = errors.New("simplesession: invalid type assertion") - - // ErrNil is raised when returned value is nil. - // Store code = 4 - ErrNil = errors.New("simplesession: nil returned") ) type errCode interface { @@ -63,7 +59,7 @@ func (s *Session) WriteCookie(id string) error { } // Call `SetCookie` callback to write cookie to response - return s.manager.setCookieCb(ck, s.writer) + return s.manager.setCookieHook(ck, s.writer) } // clearCookie sets the cookie's expiry to one day prior to clear it. @@ -76,7 +72,7 @@ func (s *Session) clearCookie() error { } // Call `SetCookie` callback to write cookie to response - return s.manager.setCookieCb(ck, s.writer) + return s.manager.setCookieHook(ck, s.writer) } // ID returns the acquired session ID. If cookie is not set then empty string is returned. @@ -117,7 +113,7 @@ func (s *Session) getCache(key ...string) map[string]interface{} { if ok { out[k] = v } else { - out[k] = ErrFieldNotFound + out[k] = nil } } @@ -325,11 +321,9 @@ func errAs(err error) error { case 1: return ErrInvalidSession case 2: - return ErrFieldNotFound + return ErrNil case 3: return ErrAssertType - case 4: - return ErrNil } return err diff --git a/store.go b/store.go index 476337c..24cd2db 100644 --- a/store.go +++ b/store.go @@ -4,7 +4,7 @@ package simplesessions // implemented to create various backend stores for session. type Store interface { // Create creates new session in the store and returns the session ID. - Create() (id string, err error) + Create(id string) (err error) // Get gets a value for given key from session. Get(id, key string) (value interface{}, err error) diff --git a/store_test.go b/store_test.go index e08dc59..07eaa08 100644 --- a/store_test.go +++ b/store_test.go @@ -18,7 +18,7 @@ func (s *MockStore) Get(id, key string) (interface{}, error) { d, ok := s.data[key] if !ok { - return nil, ErrFieldNotFound + return nil, nil } return d, s.err }