Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor all stores to remove the simplesessions dependency and adhere #31

Merged
merged 1 commit into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion stores/goredis/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require (
github.com/alicebob/miniredis/v2 v2.32.1
github.com/redis/go-redis/v9 v9.5.1
github.com/stretchr/testify v1.9.0
github.com/vividvilla/simplesessions v0.2.0
github.com/vividvilla/simplesessions/conv v1.0.0
)

require (
Expand Down
124 changes: 79 additions & 45 deletions stores/goredis/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,37 @@ package goredis

import (
"context"
"crypto/rand"
"sync"
"time"
"unicode"

"github.com/redis/go-redis/v9"
"github.com/vividvilla/simplesessions"
"github.com/vividvilla/simplesessions/conv"
)

var (
// Error codes for store errors. This should match the codes
// defined in the /simplesessions package exactly.
ErrInvalidSession = &Err{code: 1, msg: "invalid session"}
ErrFieldNotFound = &Err{code: 2, msg: "field not found"}
ErrAssertType = &Err{code: 3, msg: "assertion failed"}
ErrNil = &Err{code: 4, msg: "nil returned"}
)

type Err struct {
code int
msg string
}

func (e *Err) Error() string {
return e.msg
}

func (e *Err) Code() int {
return e.code
}

// Store represents redis session store for simple sessions.
// Each session is stored as redis hashmap.
type Store struct {
Expand Down Expand Up @@ -54,20 +77,9 @@ func (s *Store) SetTTL(d time.Duration) {
s.ttl = d
}

// isValidSessionID checks is the given session id is valid.
func (s *Store) isValidSessionID(sess *simplesessions.Session, id string) bool {
return len(id) == sessionIDLen && sess.IsValidRandomString(id)
}

// IsValid checks if the session is set for the id.
func (s *Store) IsValid(sess *simplesessions.Session, id string) (bool, error) {
// Validate session is valid generate string or not
return s.isValidSessionID(sess, id), nil
}

// Create returns a new session id but doesn't stores it in redis since empty hashmap can't be created.
func (s *Store) Create(sess *simplesessions.Session) (string, error) {
id, err := sess.GenerateRandomString(sessionIDLen)
func (s *Store) Create() (string, error) {
id, err := generateID(sessionIDLen)
if err != nil {
return "", err
}
Expand All @@ -76,25 +88,23 @@ func (s *Store) Create(sess *simplesessions.Session) (string, error) {
}

// Get gets a field in hashmap. If field is nill then ErrFieldNotFound is raised
func (s *Store) Get(sess *simplesessions.Session, id, key string) (interface{}, error) {
// Check if valid session
if !s.isValidSessionID(sess, id) {
return nil, simplesessions.ErrInvalidSession
func (s *Store) Get(id, key string) (interface{}, error) {
if !validateID(id) {
return nil, ErrInvalidSession
}

v, err := s.client.HGet(s.clientCtx, s.prefix+id, key).Result()
if err == redis.Nil {
return nil, simplesessions.ErrFieldNotFound
return nil, ErrFieldNotFound
}

return v, err
}

// GetMulti gets a map for values for multiple keys. If key is not found then its set as nil.
func (s *Store) GetMulti(sess *simplesessions.Session, id string, keys ...string) (map[string]interface{}, error) {
// Check if valid session
if !s.isValidSessionID(sess, id) {
return nil, simplesessions.ErrInvalidSession
func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, error) {
if !validateID(id) {
return nil, ErrInvalidSession
}

v, err := s.client.HMGet(s.clientCtx, s.prefix+id, keys...).Result()
Expand All @@ -113,10 +123,9 @@ func (s *Store) GetMulti(sess *simplesessions.Session, id string, keys ...string
}

// GetAll gets all fields from hashmap.
func (s *Store) GetAll(sess *simplesessions.Session, id string) (map[string]interface{}, error) {
// Check if valid session
if !s.isValidSessionID(sess, id) {
return nil, simplesessions.ErrInvalidSession
func (s *Store) GetAll(id string) (map[string]interface{}, error) {
if !validateID(id) {
return nil, ErrInvalidSession
}

res, err := s.client.HGetAll(s.clientCtx, s.prefix+id).Result()
Expand All @@ -136,10 +145,9 @@ func (s *Store) GetAll(sess *simplesessions.Session, id string) (map[string]inte
}

// Set sets a value to given session but stored only on commit
func (s *Store) Set(sess *simplesessions.Session, id, key string, val interface{}) error {
// Check if valid session
if !s.isValidSessionID(sess, id) {
return simplesessions.ErrInvalidSession
func (s *Store) Set(id, key string, val interface{}) error {
if !validateID(id) {
return ErrInvalidSession
}

s.mu.Lock()
Expand All @@ -156,11 +164,10 @@ func (s *Store) Set(sess *simplesessions.Session, id, key string, val interface{
return nil
}

// Commit sets all set values
func (s *Store) Commit(sess *simplesessions.Session, id string) error {
// Check if valid session
if !s.isValidSessionID(sess, id) {
return simplesessions.ErrInvalidSession
// Commit sets all set values.
func (s *Store) Commit(id string) error {
if !validateID(id) {
return ErrInvalidSession
}

s.mu.RLock()
Expand Down Expand Up @@ -200,10 +207,9 @@ func (s *Store) Commit(sess *simplesessions.Session, id string) error {
}

// Delete deletes a key from redis session hashmap.
func (s *Store) Delete(sess *simplesessions.Session, id string, key string) error {
// Check if valid session
if !s.isValidSessionID(sess, id) {
return simplesessions.ErrInvalidSession
func (s *Store) Delete(id string, key string) error {
if !validateID(id) {
return ErrInvalidSession
}

// Clear temp map for given session id
Expand All @@ -213,16 +219,15 @@ func (s *Store) Delete(sess *simplesessions.Session, id string, key string) erro

err := s.client.HDel(s.clientCtx, s.prefix+id, key).Err()
if err == redis.Nil {
return simplesessions.ErrFieldNotFound
return ErrFieldNotFound
}
return err
}

// Clear clears session in redis.
func (s *Store) Clear(sess *simplesessions.Session, id string) error {
// Check if valid session
if !s.isValidSessionID(sess, id) {
return simplesessions.ErrInvalidSession
func (s *Store) Clear(id string) error {
if !validateID(id) {
return ErrInvalidSession
}

return s.client.Del(s.clientCtx, s.prefix+id).Err()
Expand Down Expand Up @@ -262,3 +267,32 @@ func (s *Store) Bytes(r interface{}, err error) ([]byte, error) {
func (s *Store) Bool(r interface{}, err error) (bool, error) {
return conv.Bool(r, err)
}

func validateID(id string) bool {
if len(id) != sessionIDLen {
return false
}

for _, r := range id {
if !unicode.IsDigit(r) && !unicode.IsLetter(r) {
return false
}
}

return true
}

// generateID generates a random alpha-num session ID.
func generateID(n int) (string, error) {
const dict = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
bytes := make([]byte, n)
if _, err := rand.Read(bytes); err != nil {
return "", err
}

for k, v := range bytes {
bytes[k] = dict[v%byte(len(dict))]
}

return string(bytes), nil
}
Loading
Loading