diff --git a/stores/goredis/go.mod b/stores/goredis/go.mod index 9a5f4b2..b268fed 100644 --- a/stores/goredis/go.mod +++ b/stores/goredis/go.mod @@ -6,7 +6,7 @@ require ( github.com/alicebob/miniredis/v2 v2.32.1 github.com/redis/go-redis/v9 v9.5.1 github.com/stretchr/testify v1.9.0 - github.com/vividvilla/simplesessions v0.2.0 + github.com/vividvilla/simplesessions/conv v1.0.0 ) require ( diff --git a/stores/goredis/store.go b/stores/goredis/store.go index fc0825e..2e4de30 100644 --- a/stores/goredis/store.go +++ b/stores/goredis/store.go @@ -2,14 +2,37 @@ package goredis import ( "context" + "crypto/rand" "sync" "time" + "unicode" "github.com/redis/go-redis/v9" - "github.com/vividvilla/simplesessions" "github.com/vividvilla/simplesessions/conv" ) +var ( + // Error codes for store errors. This should match the codes + // defined in the /simplesessions package exactly. + ErrInvalidSession = &Err{code: 1, msg: "invalid session"} + ErrFieldNotFound = &Err{code: 2, msg: "field not found"} + ErrAssertType = &Err{code: 3, msg: "assertion failed"} + ErrNil = &Err{code: 4, msg: "nil returned"} +) + +type Err struct { + code int + msg string +} + +func (e *Err) Error() string { + return e.msg +} + +func (e *Err) Code() int { + return e.code +} + // Store represents redis session store for simple sessions. // Each session is stored as redis hashmap. type Store struct { @@ -54,20 +77,9 @@ func (s *Store) SetTTL(d time.Duration) { s.ttl = d } -// isValidSessionID checks is the given session id is valid. -func (s *Store) isValidSessionID(sess *simplesessions.Session, id string) bool { - return len(id) == sessionIDLen && sess.IsValidRandomString(id) -} - -// IsValid checks if the session is set for the id. -func (s *Store) IsValid(sess *simplesessions.Session, id string) (bool, error) { - // Validate session is valid generate string or not - return s.isValidSessionID(sess, id), nil -} - // Create returns a new session id but doesn't stores it in redis since empty hashmap can't be created. -func (s *Store) Create(sess *simplesessions.Session) (string, error) { - id, err := sess.GenerateRandomString(sessionIDLen) +func (s *Store) Create() (string, error) { + id, err := generateID(sessionIDLen) if err != nil { return "", err } @@ -76,25 +88,23 @@ func (s *Store) Create(sess *simplesessions.Session) (string, error) { } // Get gets a field in hashmap. If field is nill then ErrFieldNotFound is raised -func (s *Store) Get(sess *simplesessions.Session, id, key string) (interface{}, error) { - // Check if valid session - if !s.isValidSessionID(sess, id) { - return nil, simplesessions.ErrInvalidSession +func (s *Store) Get(id, key string) (interface{}, error) { + if !validateID(id) { + return nil, ErrInvalidSession } v, err := s.client.HGet(s.clientCtx, s.prefix+id, key).Result() if err == redis.Nil { - return nil, simplesessions.ErrFieldNotFound + return nil, ErrFieldNotFound } return v, err } // GetMulti gets a map for values for multiple keys. If key is not found then its set as nil. -func (s *Store) GetMulti(sess *simplesessions.Session, id string, keys ...string) (map[string]interface{}, error) { - // Check if valid session - if !s.isValidSessionID(sess, id) { - return nil, simplesessions.ErrInvalidSession +func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, error) { + if !validateID(id) { + return nil, ErrInvalidSession } v, err := s.client.HMGet(s.clientCtx, s.prefix+id, keys...).Result() @@ -113,10 +123,9 @@ func (s *Store) GetMulti(sess *simplesessions.Session, id string, keys ...string } // GetAll gets all fields from hashmap. -func (s *Store) GetAll(sess *simplesessions.Session, id string) (map[string]interface{}, error) { - // Check if valid session - if !s.isValidSessionID(sess, id) { - return nil, simplesessions.ErrInvalidSession +func (s *Store) GetAll(id string) (map[string]interface{}, error) { + if !validateID(id) { + return nil, ErrInvalidSession } res, err := s.client.HGetAll(s.clientCtx, s.prefix+id).Result() @@ -136,10 +145,9 @@ func (s *Store) GetAll(sess *simplesessions.Session, id string) (map[string]inte } // Set sets a value to given session but stored only on commit -func (s *Store) Set(sess *simplesessions.Session, id, key string, val interface{}) error { - // Check if valid session - if !s.isValidSessionID(sess, id) { - return simplesessions.ErrInvalidSession +func (s *Store) Set(id, key string, val interface{}) error { + if !validateID(id) { + return ErrInvalidSession } s.mu.Lock() @@ -156,11 +164,10 @@ func (s *Store) Set(sess *simplesessions.Session, id, key string, val interface{ return nil } -// Commit sets all set values -func (s *Store) Commit(sess *simplesessions.Session, id string) error { - // Check if valid session - if !s.isValidSessionID(sess, id) { - return simplesessions.ErrInvalidSession +// Commit sets all set values. +func (s *Store) Commit(id string) error { + if !validateID(id) { + return ErrInvalidSession } s.mu.RLock() @@ -200,10 +207,9 @@ func (s *Store) Commit(sess *simplesessions.Session, id string) error { } // Delete deletes a key from redis session hashmap. -func (s *Store) Delete(sess *simplesessions.Session, id string, key string) error { - // Check if valid session - if !s.isValidSessionID(sess, id) { - return simplesessions.ErrInvalidSession +func (s *Store) Delete(id string, key string) error { + if !validateID(id) { + return ErrInvalidSession } // Clear temp map for given session id @@ -213,16 +219,15 @@ func (s *Store) Delete(sess *simplesessions.Session, id string, key string) erro err := s.client.HDel(s.clientCtx, s.prefix+id, key).Err() if err == redis.Nil { - return simplesessions.ErrFieldNotFound + return ErrFieldNotFound } return err } // Clear clears session in redis. -func (s *Store) Clear(sess *simplesessions.Session, id string) error { - // Check if valid session - if !s.isValidSessionID(sess, id) { - return simplesessions.ErrInvalidSession +func (s *Store) Clear(id string) error { + if !validateID(id) { + return ErrInvalidSession } return s.client.Del(s.clientCtx, s.prefix+id).Err() @@ -262,3 +267,32 @@ func (s *Store) Bytes(r interface{}, err error) ([]byte, error) { func (s *Store) Bool(r interface{}, err error) (bool, error) { return conv.Bool(r, err) } + +func validateID(id string) bool { + if len(id) != sessionIDLen { + return false + } + + for _, r := range id { + if !unicode.IsDigit(r) && !unicode.IsLetter(r) { + return false + } + } + + return true +} + +// generateID generates a random alpha-num session ID. +func generateID(n int) (string, error) { + const dict = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + bytes := make([]byte, n) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + + for k, v := range bytes { + bytes[k] = dict[v%byte(len(dict))] + } + + return string(bytes), nil +} diff --git a/stores/goredis/store_test.go b/stores/goredis/store_test.go index 6c1a8d2..0f51d94 100644 --- a/stores/goredis/store_test.go +++ b/stores/goredis/store_test.go @@ -9,7 +9,6 @@ import ( "github.com/alicebob/miniredis/v2" "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" - "github.com/vividvilla/simplesessions" ) var ( @@ -56,67 +55,22 @@ func TestSetTTL(t *testing.T) { assert.Equal(str.ttl, testDur) } -func TestIsValidSessionID(t *testing.T) { - assert := assert.New(t) - str := New(context.TODO(), getRedisClient()) - sess := &simplesessions.Session{} - - // Not valid since length doesn't match - testString := "abc123" - assert.NotEqual(len(testString), sessionIDLen) - assert.False(str.isValidSessionID(sess, testString)) - - // Not valid since length is same but not alpha numeric - invalidTestString := "0dIHy6S2uBuKaNnTUszB218L898ikGY$" - assert.Equal(len(invalidTestString), sessionIDLen) - assert.False(str.isValidSessionID(sess, invalidTestString)) - - // Valid - validTestString := "1dIHy6S2uBuKaNnTUszB218L898ikGY1" - assert.Equal(len(validTestString), sessionIDLen) - assert.True(str.isValidSessionID(sess, validTestString)) -} - -func TestIsValid(t *testing.T) { - assert := assert.New(t) - str := New(context.TODO(), getRedisClient()) - sess := &simplesessions.Session{} - - // Not valid since length doesn't match - testString := "abc123" - assert.NotEqual(len(testString), sessionIDLen) - assert.False(str.IsValid(sess, testString)) - - // Not valid since length is same but not alpha numeric - invalidTestString := "2dIHy6S2uBuKaNnTUszB218L898ikGY$" - assert.Equal(len(invalidTestString), sessionIDLen) - assert.False(str.IsValid(sess, invalidTestString)) - - // Valid - validTestString := "3dIHy6S2uBuKaNnTUszB218L898ikGY1" - assert.Equal(len(validTestString), sessionIDLen) - assert.True(str.IsValid(sess, validTestString)) -} - func TestCreate(t *testing.T) { assert := assert.New(t) str := New(context.TODO(), getRedisClient()) - sess := &simplesessions.Session{} - id, err := str.Create(sess) + id, err := str.Create() assert.Nil(err) assert.Equal(len(id), sessionIDLen) - assert.True(str.IsValid(sess, id)) } func TestGetInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New(context.TODO(), getRedisClient()) - sess := &simplesessions.Session{} - val, err := str.Get(sess, "invalidkey", "invalidkey") + val, err := str.Get("invalidkey", "invalidkey") assert.Nil(val) - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + assert.Error(err, ErrInvalidSession) } func TestGet(t *testing.T) { @@ -131,9 +85,8 @@ func TestGet(t *testing.T) { assert.NoError(err) str := New(context.TODO(), client) - sess := &simplesessions.Session{} - val, err := str.Int(str.Get(sess, key, field)) + val, err := str.Int(str.Get(key, field)) assert.NoError(err) assert.Equal(val, value) } @@ -141,32 +94,29 @@ func TestGet(t *testing.T) { func TestGetFieldNotFoundError(t *testing.T) { assert := assert.New(t) str := New(context.TODO(), getRedisClient()) - sess := &simplesessions.Session{} key := "10IHy6S2uBuKaNnTUszB218L898ikGY1" - val, err := str.Get(sess, key, "invalidkey") + val, err := str.Get(key, "invalidkey") assert.Nil(val) - assert.Error(err, simplesessions.ErrFieldNotFound.Error()) + assert.Error(err, ErrFieldNotFound.Error()) } func TestGetMultiInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New(context.TODO(), getRedisClient()) - sess := &simplesessions.Session{} - val, err := str.GetMulti(sess, "invalidkey", "invalidkey") + val, err := str.GetMulti("invalidkey", "invalidkey") assert.Nil(val) - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + assert.Error(err, ErrInvalidSession.Error()) } func TestGetMultiFieldEmptySession(t *testing.T) { assert := assert.New(t) str := New(context.TODO(), getRedisClient()) - sess := &simplesessions.Session{} key := "11IHy6S2uBuKaNnTUszB218L898ikGY1" field := "somefield" - _, err := str.GetMulti(sess, key, field) + _, err := str.GetMulti(key, field) assert.Nil(err) } @@ -186,9 +136,8 @@ func TestGetMulti(t *testing.T) { assert.NoError(err) str := New(context.TODO(), client) - sess := &simplesessions.Session{} - vals, err := str.GetMulti(sess, key, field1, field2) + vals, err := str.GetMulti(key, field1, field2) assert.NoError(err) assert.Contains(vals, field1) assert.Contains(vals, field2) @@ -206,11 +155,10 @@ func TestGetMulti(t *testing.T) { func TestGetAllInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New(context.TODO(), getRedisClient()) - sess := &simplesessions.Session{} - val, err := str.GetAll(sess, "invalidkey") + val, err := str.GetAll("invalidkey") assert.Nil(val) - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + assert.Error(err, ErrInvalidSession.Error()) } func TestGetAll(t *testing.T) { @@ -229,9 +177,8 @@ func TestGetAll(t *testing.T) { assert.NoError(err) str := New(context.TODO(), client) - sess := &simplesessions.Session{} - vals, err := str.GetAll(sess, key) + vals, err := str.GetAll(key) assert.NoError(err) assert.Contains(vals, field1) assert.Contains(vals, field2) @@ -253,10 +200,9 @@ func TestGetAll(t *testing.T) { func TestSetInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New(context.TODO(), getRedisClient()) - sess := &simplesessions.Session{} - err := str.Set(sess, "invalidid", "key", "value") - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + err := str.Set("invalidid", "key", "value") + assert.Error(err, ErrInvalidSession.Error()) } func TestSet(t *testing.T) { @@ -264,7 +210,6 @@ func TestSet(t *testing.T) { assert := assert.New(t) client := getRedisClient() str := New(context.TODO(), client) - sess := &simplesessions.Session{} // this key is unique across all tests key := "7dIHy6S2uBuKaNnTUszB218L898ikGY9" @@ -274,7 +219,7 @@ func TestSet(t *testing.T) { assert.NotNil(str.tempSetMap) assert.NotContains(str.tempSetMap, key) - err := str.Set(sess, key, field, value) + err := str.Set(key, field, value) assert.NoError(err) assert.Contains(str.tempSetMap, key) assert.Contains(str.tempSetMap[key], field) @@ -289,18 +234,16 @@ func TestSet(t *testing.T) { func TestCommitInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New(context.TODO(), getRedisClient()) - sess := &simplesessions.Session{} - err := str.Commit(sess, "invalidkey") - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + err := str.Commit("invalidkey") + assert.Error(err, ErrInvalidSession.Error()) } func TestEmptyCommit(t *testing.T) { assert := assert.New(t) str := New(context.TODO(), getRedisClient()) - sess := &simplesessions.Session{} - err := str.Commit(sess, "15IHy6S2uBuKaNnTUszB2180898ikGY1") + err := str.Commit("15IHy6S2uBuKaNnTUszB2180898ikGY1") assert.NoError(err) } @@ -309,7 +252,6 @@ func TestCommit(t *testing.T) { assert := assert.New(t) client := getRedisClient() str := New(context.TODO(), client) - sess := &simplesessions.Session{} str.SetTTL(10 * time.Second) @@ -320,13 +262,13 @@ func TestCommit(t *testing.T) { field2 := "someotherkey" value2 := "abc123" - err := str.Set(sess, key, field1, value1) + err := str.Set(key, field1, value1) assert.NoError(err) - err = str.Set(sess, key, field2, value2) + err = str.Set(key, field2, value2) assert.NoError(err) - err = str.Commit(sess, key) + err = str.Commit(key) assert.NoError(err) vals, err := client.HGetAll(context.TODO(), defaultPrefix+key).Result() @@ -340,10 +282,9 @@ func TestCommit(t *testing.T) { func TestDeleteInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New(context.TODO(), getRedisClient()) - sess := &simplesessions.Session{} - err := str.Delete(sess, "invalidkey", "somefield") - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + err := str.Delete("invalidkey", "somefield") + assert.Error(err, ErrInvalidSession.Error()) } func TestDelete(t *testing.T) { @@ -351,7 +292,6 @@ func TestDelete(t *testing.T) { assert := assert.New(t) client := getRedisClient() str := New(context.TODO(), client) - sess := &simplesessions.Session{} // this key is unique across all tests key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" @@ -363,7 +303,7 @@ func TestDelete(t *testing.T) { err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2).Err() assert.NoError(err) - err = str.Delete(sess, key, field1) + err = str.Delete(key, field1) assert.NoError(err) val, err := client.HExists(context.TODO(), defaultPrefix+key, field1).Result() @@ -376,10 +316,9 @@ func TestDelete(t *testing.T) { func TestClearInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New(context.TODO(), getRedisClient()) - sess := &simplesessions.Session{} - err := str.Clear(sess, "invalidkey") - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + err := str.Clear("invalidkey") + assert.Error(err, ErrInvalidSession.Error()) } func TestClear(t *testing.T) { @@ -387,7 +326,6 @@ func TestClear(t *testing.T) { assert := assert.New(t) client := getRedisClient() str := New(context.TODO(), client) - sess := &simplesessions.Session{} // this key is unique across all tests key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" @@ -404,7 +342,7 @@ func TestClear(t *testing.T) { assert.NoError(err) assert.NotEqual(val, int64(0)) - err = str.Clear(sess, key) + err = str.Clear(key) assert.NoError(err) val, err = client.Exists(context.TODO(), defaultPrefix+key).Result() diff --git a/stores/memory/go.mod b/stores/memory/go.mod index 83b24d4..b39b0e1 100644 --- a/stores/memory/go.mod +++ b/stores/memory/go.mod @@ -1,11 +1,8 @@ -module github.com/vividvilla/simplesessions/stores/memory +module github.com/vividvilla/simplesessions/stores/memory/v2 go 1.18 -require ( - github.com/stretchr/testify v1.9.0 - github.com/vividvilla/simplesessions v0.2.0 -) +require github.com/stretchr/testify v1.9.0 require ( github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/stores/memory/store.go b/stores/memory/store.go index 6d43569..69e78d6 100644 --- a/stores/memory/store.go +++ b/stores/memory/store.go @@ -1,15 +1,37 @@ package memory import ( + "crypto/rand" "sync" - - "github.com/vividvilla/simplesessions" + "unicode" ) const ( sessionIDLen = 32 ) +var ( + // Error codes for store errors. This should match the codes + // defined in the /simplesessions package exactly. + ErrInvalidSession = &Err{code: 1, msg: "invalid session"} + ErrFieldNotFound = &Err{code: 2, msg: "field not found"} + ErrAssertType = &Err{code: 3, msg: "assertion failed"} + ErrNil = &Err{code: 4, msg: "nil returned"} +) + +type Err struct { + code int + msg string +} + +func (e *Err) Error() string { + return e.msg +} + +func (e *Err) Code() int { + return e.code +} + // Store represents in-memory session store type Store struct { // map to store all sessions and its values @@ -25,21 +47,11 @@ func New() *Store { } } -// isValidSessionID checks is the given session id is valid. -func (s *Store) isValidSessionID(sess *simplesessions.Session, id string) bool { - return len(id) == sessionIDLen && sess.IsValidRandomString(id) -} - -// IsValid checks if the session is set for the id -func (s *Store) IsValid(sess *simplesessions.Session, id string) (bool, error) { - return s.isValidSessionID(sess, id), nil -} - // Create creates a new session id and returns it. This doesn't create the session in // sessions map since memory can be saved by not storing empty sessions and system // can not be stressed by just creating new sessions -func (s *Store) Create(sess *simplesessions.Session) (string, error) { - id, err := sess.GenerateRandomString(sessionIDLen) +func (s *Store) Create() (string, error) { + id, err := generateID(sessionIDLen) if err != nil { return "", err } @@ -48,9 +60,9 @@ func (s *Store) Create(sess *simplesessions.Session) (string, error) { } // Get gets a field in session -func (s *Store) Get(sess *simplesessions.Session, id, key string) (interface{}, error) { - if !s.isValidSessionID(sess, id) { - return nil, simplesessions.ErrInvalidSession +func (s *Store) Get(id, key string) (interface{}, error) { + if !validateID(id) { + return nil, ErrInvalidSession } var val interface{} @@ -65,17 +77,16 @@ func (s *Store) Get(sess *simplesessions.Session, id, key string) (interface{}, // If session doesn't exist or field doesn't exist then send field not found error // since we don't add session to sessions map on session create if !ok || v == nil { - return nil, simplesessions.ErrFieldNotFound + return nil, ErrFieldNotFound } return val, nil } // GetMulti gets a map for values for multiple keys. If key is not present in session then nil is returned. -func (s *Store) GetMulti(sess *simplesessions.Session, id string, keys ...string) (map[string]interface{}, error) { - // Check if valid session - if !s.isValidSessionID(sess, id) { - return nil, simplesessions.ErrInvalidSession +func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, error) { + if !validateID(id) { + return nil, ErrInvalidSession } s.mu.RLock() @@ -101,10 +112,9 @@ func (s *Store) GetMulti(sess *simplesessions.Session, id string, keys ...string } // GetAll gets all fields in session -func (s *Store) GetAll(sess *simplesessions.Session, id string) (map[string]interface{}, error) { - // Check if valid session - if !s.isValidSessionID(sess, id) { - return nil, simplesessions.ErrInvalidSession +func (s *Store) GetAll(id string) (map[string]interface{}, error) { + if !validateID(id) { + return nil, ErrInvalidSession } s.mu.RLock() @@ -115,10 +125,9 @@ func (s *Store) GetAll(sess *simplesessions.Session, id string) (map[string]inte } // Set sets a value to given session but stored only on commit -func (s *Store) Set(sess *simplesessions.Session, id, key string, val interface{}) error { - // Check if valid session - if !s.isValidSessionID(sess, id) { - return simplesessions.ErrInvalidSession +func (s *Store) Set(id, key string, val interface{}) error { + if !validateID(id) { + return ErrInvalidSession } s.mu.Lock() @@ -135,15 +144,14 @@ func (s *Store) Set(sess *simplesessions.Session, id, key string, val interface{ } // Commit does nothing here since Set sets the value. -func (s *Store) Commit(sess *simplesessions.Session, id string) error { +func (s *Store) Commit(id string) error { return nil } // Delete deletes a key from session. -func (s *Store) Delete(sess *simplesessions.Session, id string, key string) error { - // Check if valid session - if !s.isValidSessionID(sess, id) { - return simplesessions.ErrInvalidSession +func (s *Store) Delete(id string, key string) error { + if !validateID(id) { + return ErrInvalidSession } s.mu.Lock() @@ -161,10 +169,9 @@ func (s *Store) Delete(sess *simplesessions.Session, id string, key string) erro } // Clear clears session in redis. -func (s *Store) Clear(sess *simplesessions.Session, id string) error { - // Check if valid session - if !s.isValidSessionID(sess, id) { - return simplesessions.ErrInvalidSession +func (s *Store) Clear(id string) error { + if !validateID(id) { + return ErrInvalidSession } s.mu.Lock() @@ -186,7 +193,7 @@ func (s *Store) Int(r interface{}, err error) (int, error) { v, ok := r.(int) if !ok { - err = simplesessions.ErrAssertType + err = ErrAssertType } return v, err @@ -200,7 +207,7 @@ func (s *Store) Int64(r interface{}, err error) (int64, error) { v, ok := r.(int64) if !ok { - err = simplesessions.ErrAssertType + err = ErrAssertType } return v, err @@ -214,7 +221,7 @@ func (s *Store) UInt64(r interface{}, err error) (uint64, error) { v, ok := r.(uint64) if !ok { - err = simplesessions.ErrAssertType + err = ErrAssertType } return v, err @@ -228,7 +235,7 @@ func (s *Store) Float64(r interface{}, err error) (float64, error) { v, ok := r.(float64) if !ok { - err = simplesessions.ErrAssertType + err = ErrAssertType } return v, err @@ -242,7 +249,7 @@ func (s *Store) String(r interface{}, err error) (string, error) { v, ok := r.(string) if !ok { - err = simplesessions.ErrAssertType + err = ErrAssertType } return v, err @@ -256,7 +263,7 @@ func (s *Store) Bytes(r interface{}, err error) ([]byte, error) { v, ok := r.([]byte) if !ok { - err = simplesessions.ErrAssertType + err = ErrAssertType } return v, err @@ -270,8 +277,37 @@ func (s *Store) Bool(r interface{}, err error) (bool, error) { v, ok := r.(bool) if !ok { - err = simplesessions.ErrAssertType + err = ErrAssertType } return v, err } + +func validateID(id string) bool { + if len(id) != sessionIDLen { + return false + } + + for _, r := range id { + if !unicode.IsDigit(r) && !unicode.IsLetter(r) { + return false + } + } + + return true +} + +// generateID generates a random alpha-num session ID. +func generateID(n int) (string, error) { + const dict = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + bytes := make([]byte, n) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + + for k, v := range bytes { + bytes[k] = dict[v%byte(len(dict))] + } + + return string(bytes), nil +} diff --git a/stores/memory/store_test.go b/stores/memory/store_test.go index 509103e..1149189 100644 --- a/stores/memory/store_test.go +++ b/stores/memory/store_test.go @@ -5,7 +5,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/vividvilla/simplesessions" ) func TestNew(t *testing.T) { @@ -16,65 +15,59 @@ func TestNew(t *testing.T) { func TestIsValidSessionID(t *testing.T) { assert := assert.New(t) - str := New() - sess := &simplesessions.Session{} // Not valid since length doesn't match testString := "abc123" assert.NotEqual(len(testString), sessionIDLen) - assert.False(str.isValidSessionID(sess, testString)) + assert.False(validateID(testString)) // Not valid since length is same but not alpha numeric invalidTestString := "0dIHy6S2uBuKaNnTUszB218L898ikGY$" assert.Equal(len(invalidTestString), sessionIDLen) - assert.False(str.isValidSessionID(sess, invalidTestString)) + assert.False(validateID(invalidTestString)) // Valid validTestString := "1dIHy6S2uBuKaNnTUszB218L898ikGY1" assert.Equal(len(validTestString), sessionIDLen) - assert.True(str.isValidSessionID(sess, validTestString)) + assert.True(validateID(validTestString)) } func TestIsValid(t *testing.T) { assert := assert.New(t) - str := New() - sess := &simplesessions.Session{} // Not valid since length doesn't match testString := "abc123" assert.NotEqual(len(testString), sessionIDLen) - assert.False(str.IsValid(sess, testString)) + assert.False(validateID(testString)) // Not valid since length is same but not alpha numeric invalidTestString := "2dIHy6S2uBuKaNnTUszB218L898ikGY$" assert.Equal(len(invalidTestString), sessionIDLen) - assert.False(str.IsValid(sess, invalidTestString)) + assert.False(validateID(invalidTestString)) // Valid validTestString := "3dIHy6S2uBuKaNnTUszB218L898ikGY1" assert.Equal(len(validTestString), sessionIDLen) - assert.True(str.IsValid(sess, validTestString)) + assert.True(validateID(validTestString)) } func TestCreate(t *testing.T) { assert := assert.New(t) str := New() - sess := &simplesessions.Session{} - id, err := str.Create(sess) + id, err := str.Create() assert.Nil(err) assert.Equal(len(id), sessionIDLen) - assert.True(str.IsValid(sess, id)) + assert.True(validateID(id)) } func TestGetInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New() - sess := &simplesessions.Session{} - val, err := str.Get(sess, "invalidkey", "invalidkey") + val, err := str.Get("invalidkey", "invalidkey") assert.Nil(val) - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + assert.Error(err, ErrInvalidSession.Error()) } func TestGet(t *testing.T) { @@ -85,11 +78,11 @@ func TestGet(t *testing.T) { // Set a key str := New() - sess := &simplesessions.Session{} + str.sessions[key] = make(map[string]interface{}) str.sessions[key][field] = value - val, err := str.Get(sess, key, field) + val, err := str.Get(key, field) assert.NoError(err) assert.Equal(val, value) } @@ -97,31 +90,28 @@ func TestGet(t *testing.T) { func TestGetFieldNotFoundError(t *testing.T) { assert := assert.New(t) str := New() - sess := &simplesessions.Session{} key := "10IHy6S2uBuKaNnTUszB218L898ikGY1" - val, err := str.Get(sess, key, "invalidkey") + val, err := str.Get(key, "invalidkey") assert.Nil(val) - assert.Error(err, simplesessions.ErrFieldNotFound.Error()) + assert.Error(err, ErrFieldNotFound.Error()) } func TestGetMultiInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New() - sess := &simplesessions.Session{} - val, err := str.GetMulti(sess, "invalidkey", "invalidkey") + val, err := str.GetMulti("invalidkey", "invalidkey") assert.Nil(val) - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + assert.Error(err, ErrInvalidSession.Error()) } func TestGetMultiFieldEmptySession(t *testing.T) { assert := assert.New(t) str := New() - sess := &simplesessions.Session{} key := "11IHy6S2uBuKaNnTUszB218L898ikGY1" - _, err := str.GetMulti(sess, key) + _, err := str.GetMulti(key) assert.Nil(err) } @@ -136,7 +126,6 @@ func TestGetMulti(t *testing.T) { value3 := 100.10 str := New() - sess := &simplesessions.Session{} // Set a key str.sessions[key] = make(map[string]interface{}) @@ -144,7 +133,7 @@ func TestGetMulti(t *testing.T) { str.sessions[key][field2] = value2 str.sessions[key][field3] = value3 - vals, err := str.GetMulti(sess, key, field1, field2) + vals, err := str.GetMulti(key, field1, field2) assert.NoError(err) assert.Contains(vals, field1) assert.Contains(vals, field2) @@ -160,11 +149,10 @@ func TestGetMulti(t *testing.T) { func TestGetAllInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New() - sess := &simplesessions.Session{} - val, err := str.GetAll(sess, "invalidkey") + val, err := str.GetAll("invalidkey") assert.Nil(val) - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + assert.Error(err, ErrInvalidSession.Error()) } func TestGetAll(t *testing.T) { @@ -178,7 +166,6 @@ func TestGetAll(t *testing.T) { value3 := 100.10 str := New() - sess := &simplesessions.Session{} // Set a key str.sessions[key] = make(map[string]interface{}) @@ -186,7 +173,7 @@ func TestGetAll(t *testing.T) { str.sessions[key][field2] = value2 str.sessions[key][field3] = value3 - vals, err := str.GetAll(sess, key) + vals, err := str.GetAll(key) assert.NoError(err) assert.Contains(vals, field1) assert.Contains(vals, field2) @@ -205,17 +192,15 @@ func TestGetAll(t *testing.T) { func TestSetInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New() - sess := &simplesessions.Session{} - err := str.Set(sess, "invalidid", "key", "value") - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + err := str.Set("invalidid", "key", "value") + assert.Error(err, ErrInvalidSession.Error()) } func TestSet(t *testing.T) { // Test should only set in internal map and not in redis assert := assert.New(t) str := New() - sess := &simplesessions.Session{} // this key is unique across all tests key := "7dIHy6S2uBuKaNnTUszB218L898ikGY9" @@ -224,7 +209,7 @@ func TestSet(t *testing.T) { assert.NotContains(str.sessions, key) - err := str.Set(sess, key, field, value) + err := str.Set(key, field, value) assert.NoError(err) assert.Contains(str.sessions, key) assert.Contains(str.sessions[key], field) @@ -234,26 +219,23 @@ func TestSet(t *testing.T) { func TestCommit(t *testing.T) { assert := assert.New(t) str := New() - sess := &simplesessions.Session{} - err := str.Commit(sess, "invalidkey") + err := str.Commit("invalidkey") assert.Nil(err) } func TestDeleteInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New() - sess := &simplesessions.Session{} - err := str.Delete(sess, "invalidkey", "somekey") - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + err := str.Delete("invalidkey", "somekey") + assert.Error(err, ErrInvalidSession.Error()) } func TestDelete(t *testing.T) { // Test should only set in internal map and not in redis assert := assert.New(t) str := New() - sess := &simplesessions.Session{} // this key is unique across all tests key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" @@ -263,7 +245,7 @@ func TestDelete(t *testing.T) { str.sessions[key][field1] = 10 str.sessions[key][field2] = 10 - err := str.Delete(sess, key, field1) + err := str.Delete(key, field1) assert.NoError(err) assert.Contains(str.sessions[key], field2) assert.NotContains(str.sessions[key], field1) @@ -272,23 +254,21 @@ func TestDelete(t *testing.T) { func TestClearInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New() - sess := &simplesessions.Session{} - err := str.Clear(sess, "invalidkey") - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + err := str.Clear("invalidkey") + assert.Error(err, ErrInvalidSession.Error()) } func TestClear(t *testing.T) { // Test should only set in internal map and not in redis assert := assert.New(t) str := New() - sess := &simplesessions.Session{} // this key is unique across all tests key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" str.sessions[key] = make(map[string]interface{}) - err := str.Clear(sess, key) + err := str.Clear(key) assert.NoError(err) assert.NotContains(str.sessions, key) } @@ -308,7 +288,7 @@ func TestInt(t *testing.T) { assert.Error(testError) _, err = str.Int("string", nil) - assert.Error(simplesessions.ErrAssertType) + assert.Error(ErrAssertType) } func TestInt64(t *testing.T) { @@ -325,7 +305,7 @@ func TestInt64(t *testing.T) { assert.Error(testError) _, err = str.Int64("string", nil) - assert.Error(simplesessions.ErrAssertType) + assert.Error(ErrAssertType) } func TestUInt64(t *testing.T) { @@ -342,7 +322,7 @@ func TestUInt64(t *testing.T) { assert.Error(testError) _, err = str.UInt64("string", nil) - assert.Error(simplesessions.ErrAssertType) + assert.Error(ErrAssertType) } func TestFloat64(t *testing.T) { @@ -359,7 +339,7 @@ func TestFloat64(t *testing.T) { assert.Error(testError) _, err = str.Float64("string", nil) - assert.Error(simplesessions.ErrAssertType) + assert.Error(ErrAssertType) } func TestString(t *testing.T) { @@ -376,7 +356,7 @@ func TestString(t *testing.T) { assert.Error(testError) _, err = str.String(123, nil) - assert.Error(simplesessions.ErrAssertType) + assert.Error(ErrAssertType) } func TestBytes(t *testing.T) { @@ -393,7 +373,7 @@ func TestBytes(t *testing.T) { assert.Error(testError) _, err = str.Bytes("string", nil) - assert.Error(simplesessions.ErrAssertType) + assert.Error(ErrAssertType) } func TestBool(t *testing.T) { @@ -410,5 +390,5 @@ func TestBool(t *testing.T) { assert.Error(testError) _, err = str.Bool("string", nil) - assert.Error(simplesessions.ErrAssertType) + assert.Error(ErrAssertType) } diff --git a/stores/redis/go.mod b/stores/redis/go.mod index 46108a4..b15a078 100644 --- a/stores/redis/go.mod +++ b/stores/redis/go.mod @@ -1,4 +1,4 @@ -module github.com/vividvilla/simplesessions/stores/redis +module github.com/vividvilla/simplesessions/stores/redis/v2 go 1.18 @@ -6,7 +6,6 @@ require ( github.com/alicebob/miniredis/v2 v2.32.1 github.com/gomodule/redigo v1.9.2 github.com/stretchr/testify v1.9.0 - github.com/vividvilla/simplesessions v0.2.0 ) require ( diff --git a/stores/redis/store.go b/stores/redis/store.go index 792ee23..50977dd 100644 --- a/stores/redis/store.go +++ b/stores/redis/store.go @@ -1,14 +1,37 @@ package redis import ( + "crypto/rand" "errors" "sync" "time" + "unicode" "github.com/gomodule/redigo/redis" - "github.com/vividvilla/simplesessions" ) +var ( + // Error codes for store errors. This should match the codes + // defined in the /simplesessions package exactly. + ErrInvalidSession = &Err{code: 1, msg: "invalid session"} + ErrFieldNotFound = &Err{code: 2, msg: "field not found"} + ErrAssertType = &Err{code: 3, msg: "assertion failed"} + ErrNil = &Err{code: 4, msg: "nil returned"} +) + +type Err struct { + code int + msg string +} + +func (e *Err) Error() string { + return e.msg +} + +func (e *Err) Code() int { + return e.code +} + // Store represents redis session store for simple sessions. // Each session is stored as redis hashmap. type Store struct { @@ -51,20 +74,9 @@ func (s *Store) SetTTL(d time.Duration) { s.ttl = d } -// isValidSessionID checks is the given session id is valid. -func (s *Store) isValidSessionID(sess *simplesessions.Session, id string) bool { - return len(id) == sessionIDLen && sess.IsValidRandomString(id) -} - -// IsValid checks if the session is set for the id. -func (s *Store) IsValid(sess *simplesessions.Session, id string) (bool, error) { - // Validate session is valid generate string or not - return s.isValidSessionID(sess, id), nil -} - // Create returns a new session id but doesn't stores it in redis since empty hashmap can't be created. -func (s *Store) Create(sess *simplesessions.Session) (string, error) { - id, err := sess.GenerateRandomString(sessionIDLen) +func (s *Store) Create() (string, error) { + id, err := generateID(sessionIDLen) if err != nil { return "", err } @@ -73,10 +85,9 @@ func (s *Store) Create(sess *simplesessions.Session) (string, error) { } // Get gets a field in hashmap. If field is nill then ErrFieldNotFound is raised -func (s *Store) Get(sess *simplesessions.Session, id, key string) (interface{}, error) { - // Check if valid session - if !s.isValidSessionID(sess, id) { - return nil, simplesessions.ErrInvalidSession +func (s *Store) Get(id, key string) (interface{}, error) { + if !validateID(id) { + return nil, ErrInvalidSession } conn := s.pool.Get() @@ -84,17 +95,16 @@ func (s *Store) Get(sess *simplesessions.Session, id, key string) (interface{}, v, err := conn.Do("HGET", s.prefix+id, key) if v == nil || err == redis.ErrNil { - return nil, simplesessions.ErrFieldNotFound + return nil, ErrFieldNotFound } return v, err } // GetMulti gets a map for values for multiple keys. If key is not found then its set as nil. -func (s *Store) GetMulti(sess *simplesessions.Session, id string, keys ...string) (map[string]interface{}, error) { - // Check if valid session - if !s.isValidSessionID(sess, id) { - return nil, simplesessions.ErrInvalidSession +func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, error) { + if !validateID(id) { + return nil, ErrInvalidSession } conn := s.pool.Get() @@ -123,10 +133,9 @@ func (s *Store) GetMulti(sess *simplesessions.Session, id string, keys ...string } // GetAll gets all fields from hashmap. -func (s *Store) GetAll(sess *simplesessions.Session, id string) (map[string]interface{}, error) { - // Check if valid session - if !s.isValidSessionID(sess, id) { - return nil, simplesessions.ErrInvalidSession +func (s *Store) GetAll(id string) (map[string]interface{}, error) { + if !validateID(id) { + return nil, ErrInvalidSession } conn := s.pool.Get() @@ -136,10 +145,9 @@ func (s *Store) GetAll(sess *simplesessions.Session, id string) (map[string]inte } // Set sets a value to given session but stored only on commit -func (s *Store) Set(sess *simplesessions.Session, id, key string, val interface{}) error { - // Check if valid session - if !s.isValidSessionID(sess, id) { - return simplesessions.ErrInvalidSession +func (s *Store) Set(id, key string, val interface{}) error { + if !validateID(id) { + return ErrInvalidSession } s.mu.Lock() @@ -157,10 +165,9 @@ func (s *Store) Set(sess *simplesessions.Session, id, key string, val interface{ } // Commit sets all set values -func (s *Store) Commit(sess *simplesessions.Session, id string) error { - // Check if valid session - if !s.isValidSessionID(sess, id) { - return simplesessions.ErrInvalidSession +func (s *Store) Commit(id string) error { + if !validateID(id) { + return ErrInvalidSession } s.mu.RLock() @@ -217,10 +224,9 @@ func (s *Store) Commit(sess *simplesessions.Session, id string) error { } // Delete deletes a key from redis session hashmap. -func (s *Store) Delete(sess *simplesessions.Session, id string, key string) error { - // Check if valid session - if !s.isValidSessionID(sess, id) { - return simplesessions.ErrInvalidSession +func (s *Store) Delete(id string, key string) error { + if !validateID(id) { + return ErrInvalidSession } // Clear temp map for given session id @@ -236,10 +242,9 @@ func (s *Store) Delete(sess *simplesessions.Session, id string, key string) erro } // Clear clears session in redis. -func (s *Store) Clear(sess *simplesessions.Session, id string) error { - // Check if valid session - if !s.isValidSessionID(sess, id) { - return simplesessions.ErrInvalidSession +func (s *Store) Clear(id string) error { + if !validateID(id) { + return ErrInvalidSession } conn := s.pool.Get() @@ -307,3 +312,32 @@ func (s *Store) Bytes(r interface{}, err error) ([]byte, error) { func (s *Store) Bool(r interface{}, err error) (bool, error) { return redis.Bool(r, err) } + +func validateID(id string) bool { + if len(id) != sessionIDLen { + return false + } + + for _, r := range id { + if !unicode.IsDigit(r) && !unicode.IsLetter(r) { + return false + } + } + + return true +} + +// generateID generates a random alpha-num session ID. +func generateID(n int) (string, error) { + const dict = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + bytes := make([]byte, n) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + + for k, v := range bytes { + bytes[k] = dict[v%byte(len(dict))] + } + + return string(bytes), nil +} diff --git a/stores/redis/store_test.go b/stores/redis/store_test.go index dd10473..f9bcd5e 100644 --- a/stores/redis/store_test.go +++ b/stores/redis/store_test.go @@ -5,8 +5,6 @@ import ( "testing" "time" - "github.com/vividvilla/simplesessions" - "github.com/alicebob/miniredis/v2" "github.com/gomodule/redigo/redis" "github.com/stretchr/testify/assert" @@ -62,67 +60,22 @@ func TestSetTTL(t *testing.T) { assert.Equal(str.ttl, testDur) } -func TestIsValidSessionID(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) - sess := &simplesessions.Session{} - - // Not valid since length doesn't match - testString := "abc123" - assert.NotEqual(len(testString), sessionIDLen) - assert.False(str.isValidSessionID(sess, testString)) - - // Not valid since length is same but not alpha numeric - invalidTestString := "0dIHy6S2uBuKaNnTUszB218L898ikGY$" - assert.Equal(len(invalidTestString), sessionIDLen) - assert.False(str.isValidSessionID(sess, invalidTestString)) - - // Valid - validTestString := "1dIHy6S2uBuKaNnTUszB218L898ikGY1" - assert.Equal(len(validTestString), sessionIDLen) - assert.True(str.isValidSessionID(sess, validTestString)) -} - -func TestIsValid(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) - sess := &simplesessions.Session{} - - // Not valid since length doesn't match - testString := "abc123" - assert.NotEqual(len(testString), sessionIDLen) - assert.False(str.IsValid(sess, testString)) - - // Not valid since length is same but not alpha numeric - invalidTestString := "2dIHy6S2uBuKaNnTUszB218L898ikGY$" - assert.Equal(len(invalidTestString), sessionIDLen) - assert.False(str.IsValid(sess, invalidTestString)) - - // Valid - validTestString := "3dIHy6S2uBuKaNnTUszB218L898ikGY1" - assert.Equal(len(validTestString), sessionIDLen) - assert.True(str.IsValid(sess, validTestString)) -} - func TestCreate(t *testing.T) { assert := assert.New(t) str := New(getRedisPool()) - sess := &simplesessions.Session{} - id, err := str.Create(sess) + id, err := str.Create() assert.Nil(err) assert.Equal(len(id), sessionIDLen) - assert.True(str.IsValid(sess, id)) } func TestGetInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New(getRedisPool()) - sess := &simplesessions.Session{} - val, err := str.Get(sess, "invalidkey", "invalidkey") + val, err := str.Get("invalidkey", "invalidkey") assert.Nil(val) - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + assert.Error(err, ErrInvalidSession.Error()) } func TestGet(t *testing.T) { @@ -139,9 +92,8 @@ func TestGet(t *testing.T) { assert.NoError(err) str := New(redisPool) - sess := &simplesessions.Session{} - val, err := redis.Int(str.Get(sess, key, field)) + val, err := redis.Int(str.Get(key, field)) assert.NoError(err) assert.Equal(val, value) } @@ -149,32 +101,29 @@ func TestGet(t *testing.T) { func TestGetFieldNotFoundError(t *testing.T) { assert := assert.New(t) str := New(getRedisPool()) - sess := &simplesessions.Session{} key := "10IHy6S2uBuKaNnTUszB218L898ikGY1" - val, err := str.Get(sess, key, "invalidkey") + val, err := str.Get(key, "invalidkey") assert.Nil(val) - assert.Error(err, simplesessions.ErrFieldNotFound.Error()) + assert.Error(err, ErrFieldNotFound.Error()) } func TestGetMultiInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New(getRedisPool()) - sess := &simplesessions.Session{} - val, err := str.GetMulti(sess, "invalidkey", "invalidkey") + val, err := str.GetMulti("invalidkey", "invalidkey") assert.Nil(val) - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + assert.Error(err, ErrInvalidSession.Error()) } func TestGetMultiFieldEmptySession(t *testing.T) { assert := assert.New(t) str := New(getRedisPool()) - sess := &simplesessions.Session{} key := "11IHy6S2uBuKaNnTUszB218L898ikGY1" field := "somefield" - _, err := str.GetMulti(sess, key, field) + _, err := str.GetMulti(key, field) assert.Nil(err) } @@ -196,9 +145,8 @@ func TestGetMulti(t *testing.T) { assert.NoError(err) str := New(redisPool) - sess := &simplesessions.Session{} - vals, err := str.GetMulti(sess, key, field1, field2) + vals, err := str.GetMulti(key, field1, field2) assert.NoError(err) assert.Contains(vals, field1) assert.Contains(vals, field2) @@ -216,11 +164,10 @@ func TestGetMulti(t *testing.T) { func TestGetAllInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New(getRedisPool()) - sess := &simplesessions.Session{} - val, err := str.GetAll(sess, "invalidkey") + val, err := str.GetAll("invalidkey") assert.Nil(val) - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + assert.Error(err, ErrInvalidSession.Error()) } func TestGetAll(t *testing.T) { @@ -241,9 +188,8 @@ func TestGetAll(t *testing.T) { assert.NoError(err) str := New(redisPool) - sess := &simplesessions.Session{} - vals, err := str.GetAll(sess, key) + vals, err := str.GetAll(key) assert.NoError(err) assert.Contains(vals, field1) assert.Contains(vals, field2) @@ -265,10 +211,9 @@ func TestGetAll(t *testing.T) { func TestSetInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New(getRedisPool()) - sess := &simplesessions.Session{} - err := str.Set(sess, "invalidid", "key", "value") - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + err := str.Set("invalidid", "key", "value") + assert.Error(err, ErrInvalidSession.Error()) } func TestSet(t *testing.T) { @@ -276,7 +221,6 @@ func TestSet(t *testing.T) { assert := assert.New(t) redisPool := getRedisPool() str := New(redisPool) - sess := &simplesessions.Session{} // this key is unique across all tests key := "7dIHy6S2uBuKaNnTUszB218L898ikGY9" @@ -286,7 +230,7 @@ func TestSet(t *testing.T) { assert.NotNil(str.tempSetMap) assert.NotContains(str.tempSetMap, key) - err := str.Set(sess, key, field, value) + err := str.Set(key, field, value) assert.NoError(err) assert.Contains(str.tempSetMap, key) assert.Contains(str.tempSetMap[key], field) @@ -304,18 +248,16 @@ func TestSet(t *testing.T) { func TestCommitInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New(getRedisPool()) - sess := &simplesessions.Session{} - err := str.Commit(sess, "invalidkey") - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + err := str.Commit("invalidkey") + assert.Error(err, ErrInvalidSession.Error()) } func TestEmptyCommit(t *testing.T) { assert := assert.New(t) str := New(getRedisPool()) - sess := &simplesessions.Session{} - err := str.Commit(sess, "15IHy6S2uBuKaNnTUszB2180898ikGY1") + err := str.Commit("15IHy6S2uBuKaNnTUszB2180898ikGY1") assert.NoError(err) } @@ -324,7 +266,6 @@ func TestCommit(t *testing.T) { assert := assert.New(t) redisPool := getRedisPool() str := New(redisPool) - sess := &simplesessions.Session{} str.SetTTL(10 * time.Second) @@ -335,13 +276,13 @@ func TestCommit(t *testing.T) { field2 := "someotherkey" value2 := "abc123" - err := str.Set(sess, key, field1, value1) + err := str.Set(key, field1, value1) assert.NoError(err) - err = str.Set(sess, key, field2, value2) + err = str.Set(key, field2, value2) assert.NoError(err) - err = str.Commit(sess, key) + err = str.Commit(key) assert.NoError(err) conn := redisPool.Get() @@ -358,10 +299,9 @@ func TestCommit(t *testing.T) { func TestDeleteInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New(getRedisPool()) - sess := &simplesessions.Session{} - err := str.Delete(sess, "invalidkey", "somefield") - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + err := str.Delete("invalidkey", "somefield") + assert.Error(err, ErrInvalidSession.Error()) } func TestDelete(t *testing.T) { @@ -369,7 +309,6 @@ func TestDelete(t *testing.T) { assert := assert.New(t) redisPool := getRedisPool() str := New(redisPool) - sess := &simplesessions.Session{} // this key is unique across all tests key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" @@ -383,7 +322,7 @@ func TestDelete(t *testing.T) { _, err := conn.Do("HMSET", defaultPrefix+key, field1, value1, field2, value2) assert.NoError(err) - err = str.Delete(sess, key, field1) + err = str.Delete(key, field1) assert.NoError(err) val, err := redis.Bool(conn.Do("HEXISTS", defaultPrefix+key, field1)) @@ -396,10 +335,9 @@ func TestDelete(t *testing.T) { func TestClearInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New(getRedisPool()) - sess := &simplesessions.Session{} - err := str.Clear(sess, "invalidkey") - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + err := str.Clear("invalidkey") + assert.Error(err, ErrInvalidSession.Error()) } func TestClear(t *testing.T) { @@ -407,7 +345,6 @@ func TestClear(t *testing.T) { assert := assert.New(t) redisPool := getRedisPool() str := New(redisPool) - sess := &simplesessions.Session{} // this key is unique across all tests key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" @@ -427,7 +364,7 @@ func TestClear(t *testing.T) { // -2 represents key doesn't exist assert.NotEqual(val, int64(-2)) - err = str.Clear(sess, key) + err = str.Clear(key) assert.NoError(err) val, err = conn.Do("TTL", defaultPrefix+key) diff --git a/stores/securecookie/go.mod b/stores/securecookie/go.mod index a0913b2..fbff51e 100644 --- a/stores/securecookie/go.mod +++ b/stores/securecookie/go.mod @@ -1,9 +1,8 @@ -module github.com/vividvilla/simplesessions/stores/securecookie +module github.com/vividvilla/simplesessions/stores/securecookie/v2 go 1.14 require ( github.com/gorilla/securecookie v1.1.2 github.com/stretchr/testify v1.9.0 - github.com/vividvilla/simplesessions v0.2.0 ) diff --git a/stores/securecookie/secure_cookie.go b/stores/securecookie/secure_cookie.go index 02b76e6..4016f94 100644 --- a/stores/securecookie/secure_cookie.go +++ b/stores/securecookie/secure_cookie.go @@ -1,16 +1,38 @@ package securecookie import ( + "errors" "sync" "github.com/gorilla/securecookie" - "github.com/vividvilla/simplesessions" ) const ( defaultCookieName = "session" ) +var ( + // Error codes for store errors. This should match the codes + // defined in the /simplesessions package exactly. + ErrInvalidSession = &Err{code: 1, msg: "invalid session"} + ErrFieldNotFound = &Err{code: 2, msg: "field not found"} + ErrAssertType = &Err{code: 3, msg: "assertion failed"} + ErrNil = &Err{code: 4, msg: "nil returned"} +) + +type Err struct { + code int + msg string +} + +func (e *Err) Error() string { + return e.msg +} + +func (e *Err) Code() int { + return e.code +} + // Store represents secure cookie session store type Store struct { // Temp map to store values before commit. @@ -54,7 +76,7 @@ func (s *Store) SetCookieName(cookieName string) { } // IsValid checks if the given cookie value is valid. -func (s *Store) IsValid(sess *simplesessions.Session, cv string) (bool, error) { +func (s *Store) IsValid(cv string) (bool, error) { if _, err := s.decode(cv); err != nil { return false, nil } @@ -63,23 +85,23 @@ func (s *Store) IsValid(sess *simplesessions.Session, cv string) (bool, error) { } // Create creates a new secure cookie session with empty map. -func (s *Store) Create(sess *simplesessions.Session) (string, error) { +func (s *Store) Create() (string, error) { // Create empty cookie return s.encode(make(map[string]interface{})) } // Get returns a field value from session -func (s *Store) Get(sess *simplesessions.Session, cv, key string) (interface{}, error) { +func (s *Store) Get(cv, key string) (interface{}, error) { // Decode cookie value vals, err := s.decode(cv) if err != nil { - return nil, simplesessions.ErrInvalidSession + return nil, ErrInvalidSession } // Get given field val, ok := vals[key] if !ok { - return nil, simplesessions.ErrFieldNotFound + return nil, ErrFieldNotFound } return val, nil @@ -87,11 +109,11 @@ func (s *Store) Get(sess *simplesessions.Session, cv, key string) (interface{}, // GetMulti returns values for multiple fields in session. // If a field is not present then nil is returned. -func (s *Store) GetMulti(sess *simplesessions.Session, cv string, keys ...string) (map[string]interface{}, error) { +func (s *Store) GetMulti(cv string, keys ...string) (map[string]interface{}, error) { // Decode cookie value vals, err := s.decode(cv) if err != nil { - return nil, simplesessions.ErrInvalidSession + return nil, ErrInvalidSession } // Get all given fields @@ -104,17 +126,17 @@ func (s *Store) GetMulti(sess *simplesessions.Session, cv string, keys ...string } // GetAll returns all field for given session. -func (s *Store) GetAll(sess *simplesessions.Session, cv string) (map[string]interface{}, error) { +func (s *Store) GetAll(cv string) (map[string]interface{}, error) { vals, err := s.decode(cv) if err != nil { - return nil, simplesessions.ErrInvalidSession + return nil, ErrInvalidSession } return vals, nil } // Set sets a field in session but not saved untill commit is called. -func (s *Store) Set(sess *simplesessions.Session, cv, key string, val interface{}) error { +func (s *Store) Set(cv, key string, val interface{}) error { s.mu.Lock() defer s.mu.Unlock() @@ -129,77 +151,60 @@ func (s *Store) Set(sess *simplesessions.Session, cv, key string, val interface{ return nil } -// Commit saves all the field set previously to cookie. -func (s *Store) Commit(sess *simplesessions.Session, cv string) error { - // Decode current cookie - vals, err := s.decode(cv) - if err != nil { - return simplesessions.ErrInvalidSession - } - - s.mu.RLock() - tempVals, ok := s.tempSetMap[cv] - s.mu.RUnlock() - if !ok { - // Nothing to commit - return nil - } +// Commit is unsupported in this store. +func (s *Store) Commit(cv string) error { + return errors.New("Commit() is not supported. Use Flush() to get values and write to cookie externally.") +} - // Assign new fields to current values - for k, v := range tempVals { - vals[k] = v - } +// Flush flushes the 'set' buffer and returns encoded secure cookie value ready to be saved. +// This value should be written to the cookie externally. +func (s *Store) Flush(cv string) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() - // Encode new values - encoded, err := s.encode(vals) - if err != nil { - return err + vals, ok := s.tempSetMap[cv] + if !ok { + return "", nil } - // Clear temp map for given session id - s.mu.Lock() delete(s.tempSetMap, cv) - s.mu.Unlock() - // Write cookie - return sess.WriteCookie(encoded) + encoded, err := s.encode(vals) + return encoded, err } -// Delete deletes a field from session. -func (s *Store) Delete(sess *simplesessions.Session, cv, key string) error { +// Delete deletes a field from session. Once called, Flush() should be +// called to retrieve the updated, unflushed values and written to the cookie +// externally. +func (s *Store) Delete(cv, key string) error { // Decode current cookie vals, err := s.decode(cv) if err != nil { - return simplesessions.ErrInvalidSession + return ErrInvalidSession } - // Delete given key in current values + // Delete given key in current values. delete(vals, key) - // Encode new values - encoded, err := s.encode(vals) - if err != nil { - return err + // Create session map if doesn't exist. + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.tempSetMap[cv]; !ok { + s.tempSetMap[cv] = make(map[string]interface{}) } - // Clear temp map for given session id - s.mu.Lock() - delete(s.tempSetMap, cv) - s.mu.Unlock() + for k, v := range vals { + s.tempSetMap[cv][k] = v + } - // Write new value to cookie - return sess.WriteCookie(encoded) + // After this, Flush() should be called to obtain the updated encoded + // values to be written to the cookie externally. + return nil } // Clear clears the session. -func (s *Store) Clear(sess *simplesessions.Session, id string) error { - encoded, err := s.encode(make(map[string]interface{})) - if err != nil { - return err - } - - // Write new value to cookie - return sess.WriteCookie(encoded) +func (s *Store) Clear(cv string) error { + return errors.New("Clear() is not supported. Use Create() to create an empty map and write to cookie externally.") } // Int is a helper method to type assert as integer @@ -210,7 +215,7 @@ func (s *Store) Int(r interface{}, err error) (int, error) { v, ok := r.(int) if !ok { - err = simplesessions.ErrAssertType + err = ErrAssertType } return v, err @@ -224,7 +229,7 @@ func (s *Store) Int64(r interface{}, err error) (int64, error) { v, ok := r.(int64) if !ok { - err = simplesessions.ErrAssertType + err = ErrAssertType } return v, err @@ -238,7 +243,7 @@ func (s *Store) UInt64(r interface{}, err error) (uint64, error) { v, ok := r.(uint64) if !ok { - err = simplesessions.ErrAssertType + err = ErrAssertType } return v, err @@ -252,7 +257,7 @@ func (s *Store) Float64(r interface{}, err error) (float64, error) { v, ok := r.(float64) if !ok { - err = simplesessions.ErrAssertType + err = ErrAssertType } return v, err @@ -266,7 +271,7 @@ func (s *Store) String(r interface{}, err error) (string, error) { v, ok := r.(string) if !ok { - err = simplesessions.ErrAssertType + err = ErrAssertType } return v, err @@ -280,7 +285,7 @@ func (s *Store) Bytes(r interface{}, err error) ([]byte, error) { v, ok := r.([]byte) if !ok { - err = simplesessions.ErrAssertType + err = ErrAssertType } return v, err @@ -294,7 +299,7 @@ func (s *Store) Bool(r interface{}, err error) (bool, error) { v, ok := r.(bool) if !ok { - err = simplesessions.ErrAssertType + err = ErrAssertType } return v, err diff --git a/stores/securecookie/secure_cookie_test.go b/stores/securecookie/secure_cookie_test.go index 5d13276..d61101a 100644 --- a/stores/securecookie/secure_cookie_test.go +++ b/stores/securecookie/secure_cookie_test.go @@ -2,11 +2,9 @@ package securecookie import ( "errors" - "net/http" "testing" "github.com/stretchr/testify/assert" - "github.com/vividvilla/simplesessions" ) var ( @@ -35,33 +33,30 @@ func TestSetCookieName(t *testing.T) { func TestIsValid(t *testing.T) { assert := assert.New(t) str := New(secretKey, blockKey) - sess := &simplesessions.Session{} - assert.False(str.IsValid(sess, "")) + assert.False(str.IsValid("")) encoded, err := str.encode(make(map[string]interface{})) assert.Nil(err) - assert.True(str.IsValid(sess, encoded)) + assert.True(str.IsValid(encoded)) } func TestCreate(t *testing.T) { assert := assert.New(t) str := New(secretKey, blockKey) - sess := &simplesessions.Session{} - id, err := str.Create(sess) + id, err := str.Create() assert.Nil(err) - assert.True(str.IsValid(sess, id)) + assert.True(str.IsValid(id)) } func TestGetInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New(secretKey, blockKey) - sess := &simplesessions.Session{} - val, err := str.Get(sess, "invalidkey", "invalidkey") + val, err := str.Get("invalidkey", "invalidkey") assert.Nil(val) - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + assert.Error(err, ErrInvalidSession.Error()) } func TestGet(t *testing.T) { @@ -71,14 +66,13 @@ func TestGet(t *testing.T) { // Set a key str := New(secretKey, blockKey) - sess := &simplesessions.Session{} m := make(map[string]interface{}) m[field] = value cv, err := str.encode(m) assert.Nil(err) - val, err := str.Get(sess, cv, field) + val, err := str.Get(cv, field) assert.NoError(err) assert.Equal(val, value) } @@ -89,36 +83,33 @@ func TestGetFieldNotFoundError(t *testing.T) { // Set a key str := New(secretKey, blockKey) - sess := &simplesessions.Session{} m := make(map[string]interface{}) cv, err := str.encode(m) assert.Nil(err) - _, err = str.Get(sess, cv, field) - assert.Error(simplesessions.ErrFieldNotFound) + _, err = str.Get(cv, field) + assert.Error(ErrFieldNotFound) } func TestGetMultiInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New(secretKey, blockKey) - sess := &simplesessions.Session{} - val, err := str.GetMulti(sess, "invalidkey", "invalidkey") + val, err := str.GetMulti("invalidkey", "invalidkey") assert.Nil(val) - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + assert.Error(err, ErrInvalidSession.Error()) } func TestGetMultiFieldEmptySession(t *testing.T) { assert := assert.New(t) str := New(secretKey, blockKey) - sess := &simplesessions.Session{} m := make(map[string]interface{}) cv, err := str.encode(m) assert.Nil(err) - _, err = str.GetMulti(sess, cv) + _, err = str.GetMulti(cv) assert.Nil(err) } @@ -132,7 +123,6 @@ func TestGetMulti(t *testing.T) { value3 := 100.10 str := New(secretKey, blockKey) - sess := &simplesessions.Session{} // Set a key m := make(map[string]interface{}) @@ -142,7 +132,7 @@ func TestGetMulti(t *testing.T) { cv, err := str.encode(m) assert.Nil(err) - vals, err := str.GetMulti(sess, cv, field1, field2) + vals, err := str.GetMulti(cv, field1, field2) assert.NoError(err) assert.Contains(vals, field1) assert.Contains(vals, field2) @@ -158,11 +148,10 @@ func TestGetMulti(t *testing.T) { func TestGetAllInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New(secretKey, blockKey) - sess := &simplesessions.Session{} - val, err := str.GetAll(sess, "invalidkey") + val, err := str.GetAll("invalidkey") assert.Nil(val) - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + assert.Error(err, ErrInvalidSession.Error()) } func TestGetAll(t *testing.T) { @@ -173,7 +162,6 @@ func TestGetAll(t *testing.T) { value2 := "abc123" str := New(secretKey, blockKey) - sess := &simplesessions.Session{} // Set a key m := make(map[string]interface{}) @@ -182,7 +170,7 @@ func TestGetAll(t *testing.T) { cv, err := str.encode(m) assert.Nil(err) - vals, err := str.GetAll(sess, cv) + vals, err := str.GetAll(cv) assert.NoError(err) assert.Contains(vals, field1) assert.Contains(vals, field2) @@ -198,7 +186,6 @@ func TestSet(t *testing.T) { // Test should only set in internal map and not in redis assert := assert.New(t) str := New(secretKey, blockKey) - sess := &simplesessions.Session{} // this key is unique across all tests field := "somekey" @@ -208,53 +195,29 @@ func TestSet(t *testing.T) { cv, err := str.encode(m) assert.Nil(err) - err = str.Set(sess, cv, field, value) + err = str.Set(cv, field, value) assert.NoError(err) assert.Contains(str.tempSetMap, cv) assert.Contains(str.tempSetMap[cv], field) assert.Equal(str.tempSetMap[cv][field], value) } -func TestCommitInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(secretKey, blockKey) - sess := &simplesessions.Session{} - - err := str.Commit(sess, "invalid") - assert.Error(err, simplesessions.ErrInvalidSession.Error()) -} - func TestEmptyCommit(t *testing.T) { assert := assert.New(t) str := New(secretKey, blockKey) - sess := &simplesessions.Session{} m := make(map[string]interface{}) cv, err := str.encode(m) assert.Nil(err) - err = str.Commit(sess, cv) + v, err := str.Flush(cv) + assert.Empty(v) assert.NoError(err) } func TestCommit(t *testing.T) { assert := assert.New(t) str := New(secretKey, blockKey) - sessMan := simplesessions.New(simplesessions.Options{}) - sessMan.UseStore(str) - - var receivedCookieValue string - sessMan.RegisterSetCookie(func(cookie *http.Cookie, w interface{}) error { - receivedCookieValue = cookie.Value - return nil - }) - - sessMan.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - - sess, err := simplesessions.NewSession(sessMan, nil, nil) - assert.Nil(err) // this key is unique across all tests field := "somekey" @@ -264,15 +227,16 @@ func TestCommit(t *testing.T) { cv, err := str.encode(m) assert.Nil(err) - err = str.Set(sess, cv, field, value) + err = str.Set(cv, field, value) assert.NoError(err) assert.Equal(len(str.tempSetMap), 1) - err = str.Commit(sess, cv) + v, err := str.Flush(cv) + assert.NotEmpty(v) assert.NoError(err) assert.Equal(len(str.tempSetMap), 0) - decoded, err := str.decode(receivedCookieValue) + decoded, err := str.decode(v) assert.NoError(err) assert.Contains(decoded, field) assert.Equal(decoded[field], value) @@ -281,84 +245,29 @@ func TestCommit(t *testing.T) { func TestDeleteInvalidSessionError(t *testing.T) { assert := assert.New(t) str := New(secretKey, blockKey) - sess := &simplesessions.Session{} - err := str.Delete(sess, "invalidkey", "somekey") - assert.Error(err, simplesessions.ErrInvalidSession.Error()) + err := str.Delete("invalidkey", "somekey") + assert.Error(err, ErrInvalidSession.Error()) } func TestDelete(t *testing.T) { assert := assert.New(t) str := New(secretKey, blockKey) - sessMan := simplesessions.New(simplesessions.Options{}) - sessMan.UseStore(str) - - var receivedCookieValue string - sessMan.RegisterSetCookie(func(cookie *http.Cookie, w interface{}) error { - receivedCookieValue = cookie.Value - return nil - }) - - sessMan.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - - sess, err := simplesessions.NewSession(sessMan, nil, nil) - assert.Nil(err) - - // this key is unique across all tests - field := "somekey" - value := 100 m := make(map[string]interface{}) - m[field] = value + m["key1"] = "val1" + m["key2"] = "val2" cv, err := str.encode(m) assert.Nil(err) - err = str.Delete(sess, cv, field) - assert.NoError(err) - assert.Equal(len(str.tempSetMap), 0) - - decoded, err := str.decode(receivedCookieValue) - assert.NoError(err) - assert.NotContains(decoded, field) -} - -func TestClear(t *testing.T) { - assert := assert.New(t) - str := New(secretKey, blockKey) - sessMan := simplesessions.New(simplesessions.Options{}) - sessMan.UseStore(str) - - var receivedCookieValue string - sessMan.RegisterSetCookie(func(cookie *http.Cookie, w interface{}) error { - receivedCookieValue = cookie.Value - return nil - }) - - sessMan.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - - sess, err := simplesessions.NewSession(sessMan, nil, nil) - assert.Nil(err) - - // this key is unique across all tests - field := "somekey" - value := 100 - - m := make(map[string]interface{}) - m[field] = value - cv, err := str.encode(m) - assert.Nil(err) + assert.NoError(str.Delete(cv, "key1")) - err = str.Clear(sess, cv) + v, err := str.Flush(cv) assert.NoError(err) - assert.Equal(len(str.tempSetMap), 0) - decoded, err := str.decode(receivedCookieValue) + decoded, err := str.decode(v) assert.NoError(err) - assert.NotContains(decoded, field) + assert.NotContains(decoded, "key1") } func TestInt(t *testing.T) { @@ -376,7 +285,7 @@ func TestInt(t *testing.T) { assert.Error(testError) _, err = str.Int("string", nil) - assert.Error(simplesessions.ErrAssertType) + assert.Error(ErrAssertType) } func TestInt64(t *testing.T) { @@ -393,7 +302,7 @@ func TestInt64(t *testing.T) { assert.Error(testError) _, err = str.Int64("string", nil) - assert.Error(simplesessions.ErrAssertType) + assert.Error(ErrAssertType) } func TestUInt64(t *testing.T) { @@ -410,7 +319,7 @@ func TestUInt64(t *testing.T) { assert.Error(testError) _, err = str.UInt64("string", nil) - assert.Error(simplesessions.ErrAssertType) + assert.Error(ErrAssertType) } func TestFloat64(t *testing.T) { @@ -427,7 +336,7 @@ func TestFloat64(t *testing.T) { assert.Error(testError) _, err = str.Float64("string", nil) - assert.Error(simplesessions.ErrAssertType) + assert.Error(ErrAssertType) } func TestString(t *testing.T) { @@ -444,7 +353,7 @@ func TestString(t *testing.T) { assert.Error(testError) _, err = str.String(123, nil) - assert.Error(simplesessions.ErrAssertType) + assert.Error(ErrAssertType) } func TestBytes(t *testing.T) { @@ -461,7 +370,7 @@ func TestBytes(t *testing.T) { assert.Error(testError) _, err = str.Bytes("string", nil) - assert.Error(simplesessions.ErrAssertType) + assert.Error(ErrAssertType) } func TestBool(t *testing.T) { @@ -478,5 +387,5 @@ func TestBool(t *testing.T) { assert.Error(testError) _, err = str.Bool("string", nil) - assert.Error(simplesessions.ErrAssertType) + assert.Error(ErrAssertType) }