Skip to content

Commit

Permalink
Use RedisSessionBackend for storage and retrieval of session data
Browse files Browse the repository at this point in the history
  • Loading branch information
Rob Archibald committed Oct 12, 2016
1 parent 3dd4b4e commit b925435
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 36 deletions.
8 changes: 4 additions & 4 deletions authStore.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ type authStore struct {

var emailRegex = regexp.MustCompile(`^(?i)[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}$`)

func NewAuthStore(backend Backender, mailer Mailer, w http.ResponseWriter, r *http.Request, cookieKey []byte, cookiePrefix string, secureOnlyCookie bool) AuthStorer {
sessionStore := NewSessionStore(backend, w, r, cookieKey, cookiePrefix, secureOnlyCookie)
loginStore := NewLoginStore(backend, mailer, r)
return &authStore{backend, sessionStore, loginStore, mailer, NewCookieStore(w, r, cookieKey, secureOnlyCookie), r}
func NewAuthStore(b Backender, sb SessionBackender, mailer Mailer, w http.ResponseWriter, r *http.Request, cookieKey []byte, cookiePrefix string, secureOnlyCookie bool) AuthStorer {
sessionStore := NewSessionStore(sb, w, r, cookieKey, cookiePrefix, secureOnlyCookie)
loginStore := NewLoginStore(b, mailer, r)
return &authStore{b, sessionStore, loginStore, mailer, NewCookieStore(w, r, cookieKey, secureOnlyCookie), r}
}

func (s *authStore) GetSession() (*UserLoginSession, error) {
Expand Down
4 changes: 2 additions & 2 deletions authStore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestNewAuthStore(t *testing.T) {
r := &http.Request{}
b := &MockBackend{}
m := &TextMailer{}
actual := NewAuthStore(b, m, w, r, cookieKey, "prefix", false).(*authStore)
actual := NewAuthStore(b, b, m, w, r, cookieKey, "prefix", false).(*authStore)
if actual.backend != b || actual.cookieStore.(*cookieStore).w != w || actual.cookieStore.(*cookieStore).r != r {
t.Fatal("expected correct init")
}
Expand Down Expand Up @@ -68,7 +68,7 @@ func TestAuthStoreEndToEnd(t *testing.T) {
r := &http.Request{Header: http.Header{}}
b := NewBackendMemory().(*backendMemory)
m := &TextMailer{}
s := NewAuthStore(b, m, w, r, cookieKey, "prefix", false).(*authStore)
s := NewAuthStore(b, b, m, w, r, cookieKey, "prefix", false).(*authStore)

// register new user
// adds to users, logins and sessions
Expand Down
9 changes: 9 additions & 0 deletions nginxauth.conf
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ backendUser=postgres
backendDatabase=postgres
backendPassword='mysecretpassword'

###########################################
# Redis Backend Config
###########################################
redisServer="localhost"
redisPort=6379
redisPassword=""
redisMaxIdle=80
redisMaxConnections=12000

###########################################
# Database Queries
###########################################
Expand Down
15 changes: 12 additions & 3 deletions nginxauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ type authConf struct {
UpdatePasswordAndInvalidateSessionsQuery string
InvalidateUserSessionsQuery string

RedisServer string
RedisPort int
RedisPassword string
RedisMaxIdle int
RedisMaxConnections int
ConcurrentDownloads int

CookieBase64Key string

SMTPServer string
Expand All @@ -56,6 +63,7 @@ type authConf struct {

type nginxauth struct {
backend Backender
sb SessionBackender
mailer Mailer
cookieKey []byte
conf authConf
Expand All @@ -78,7 +86,8 @@ func newNginxAuth() (*nginxauth, error) {
log.Fatal(err)
}

backend := NewBackendMemory() //temporarily using the in-memory DB for testing
b := NewBackendMemory() //temporarily using the in-memory DB for testing
sb := NewRedisSessionBackend(config.RedisServer, config.RedisPort, config.RedisPassword, config.RedisMaxIdle, config.RedisMaxConnections)

mailer, err := config.NewEmailer()
if err != nil {
Expand All @@ -90,7 +99,7 @@ func newNginxAuth() (*nginxauth, error) {
return nil, err
}

return &nginxauth{backend, mailer, cookieKey, config}, nil
return &nginxauth{b, sb, mailer, cookieKey, config}, nil
}

func (n *authConf) newOnedbBackend() (Backender, error) {
Expand Down Expand Up @@ -166,7 +175,7 @@ func (s *nginxauth) method(name string, handler func(authStore AuthStorer, w htt
return
}
secureOnly := strings.HasPrefix(r.Referer(), "https") // proxy to back-end so if referer is secure connection, we can use secureOnly cookies
authStore := NewAuthStore(s.backend, s.mailer, w, r, s.cookieKey, "ef", secureOnly)
authStore := NewAuthStore(s.backend, s.sb, s.mailer, w, r, s.cookieKey, "ef", secureOnly)
handler(authStore, w, r)
}
}
46 changes: 34 additions & 12 deletions redisSessionBackend.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,32 @@ import (
"time"
)

var redisCreate redisCreator = &redisRealCreator{}

type redisCreator interface {
newConnPool(server string, port int, password string, maxIdle, maxConnections int) redisBackender
}

type redisRealCreator struct{}

func (c *redisRealCreator) newConnPool(server string, port int, password string, maxIdle, maxConnections int) redisBackender {
return &redis.Pool{
MaxIdle: maxIdle,
MaxActive: maxConnections,
Dial: func() (redis.Conn, error) {
if password != "" {
return redis.Dial("tcp", fmt.Sprintf("%s:%d", server, port), redis.DialPassword(password))
}
return redis.Dial("tcp", fmt.Sprintf("%s:%d", server, port))
},
}
}

type redisBackender interface {
Close() error
Get() redis.Conn
}

type SessionBackender interface {
CreateSession(loginID, userID int, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*UserLoginSession, *UserLoginRememberMe, error)
GetSession(sessionHash string) (*UserLoginSession, error)
Expand All @@ -21,20 +47,16 @@ type SessionBackender interface {
}

type RedisSessionBackend struct {
pool *redis.Pool
pool redisBackender
}

func newPool(server string, port int, password string, maxIdle, maxConnections int) *redis.Pool {
return &redis.Pool{
MaxIdle: maxIdle,
MaxActive: maxConnections,
Dial: func() (redis.Conn, error) {
if password != "" {
return redis.Dial("tcp", fmt.Sprintf("%s:%d", server, port), redis.DialPassword(password))
}
return redis.Dial("tcp", fmt.Sprintf("%s:%d", server, port))
},
}
func NewRedis(server string, port int, password string, maxIdle, maxConnections int) redisBackender {
return redisCreate.newConnPool(server, port, password, maxIdle, maxConnections)
}

func NewRedisSessionBackend(server string, port int, password string, maxIdle, maxConnections int) SessionBackender {
r := NewRedis(server, port, password, maxIdle, maxConnections)
return &RedisSessionBackend{pool: r}
}

func (r *RedisSessionBackend) CreateSession(loginID, userID int, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time,
Expand Down
22 changes: 11 additions & 11 deletions sessionStore.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@ type rememberMeCookie struct {
}

type sessionStore struct {
backend SessionBackender
b SessionBackender
cookieStore CookieStorer
r *http.Request
}

func NewSessionStore(backend Backender, w http.ResponseWriter, r *http.Request, cookieKey []byte, cookiePrefix string, secureOnlyCookie bool) SessionStorer {
func NewSessionStore(b SessionBackender, w http.ResponseWriter, r *http.Request, cookieKey []byte, cookiePrefix string, secureOnlyCookie bool) SessionStorer {
emailCookieName = cookiePrefix + "Email"
sessionCookieName = cookiePrefix + "Session"
rememberMeCookieName = cookiePrefix + "RememberMe"
return &sessionStore{backend, NewCookieStore(w, r, cookieKey, secureOnlyCookie), r}
return &sessionStore{b, NewCookieStore(w, r, cookieKey, secureOnlyCookie), r}
}

var emailCookieName = "Email"
Expand Down Expand Up @@ -61,7 +61,7 @@ func (s *sessionStore) GetSession() (*UserLoginSession, error) {
return s.renewSession(cookie.SessionID, sessionHash, &cookie.RenewTimeUTC, &cookie.ExpireTimeUTC)
}

session, err := s.backend.GetSession(sessionHash)
session, err := s.b.GetSession(sessionHash)
if err != nil {
if err == errSessionNotFound {
s.deleteSessionCookie()
Expand All @@ -81,7 +81,7 @@ func (s *sessionStore) getRememberMe() (*UserLoginRememberMe, error) {
return nil, newAuthError("RememberMe cookie has expired", nil)
}

rememberMe, err := s.backend.GetRememberMe(cookie.Selector)
rememberMe, err := s.b.GetRememberMe(cookie.Selector)
if err != nil {
if err == errRememberMeNotFound {
s.deleteRememberMeCookie()
Expand All @@ -93,7 +93,7 @@ func (s *sessionStore) getRememberMe() (*UserLoginRememberMe, error) {
return nil, newLoggedError("RememberMe cookie doesn't match backend token", nil)
}
if rememberMe.RenewTimeUTC.Before(time.Now().UTC()) {
rememberMe, err = s.backend.RenewRememberMe(cookie.Selector, time.Now().UTC().Add(rememberMeRenewDuration))
rememberMe, err = s.b.RenewRememberMe(cookie.Selector, time.Now().UTC().Add(rememberMeRenewDuration))
if err != nil {
if err == errRememberMeNotFound {
s.deleteRememberMeCookie()
Expand All @@ -106,7 +106,7 @@ func (s *sessionStore) getRememberMe() (*UserLoginRememberMe, error) {

func (s *sessionStore) renewSession(sessionID, sessionHash string, renewTimeUTC, expireTimeUTC *time.Time) (*UserLoginSession, error) {
if renewTimeUTC.Before(time.Now().UTC()) && expireTimeUTC.After(time.Now().UTC()) {
session, err := s.backend.RenewSession(sessionHash, time.Now().UTC().Add(sessionRenewDuration))
session, err := s.b.RenewSession(sessionHash, time.Now().UTC().Add(sessionRenewDuration))
if err != nil {
return nil, newLoggedError("Unable to renew session", err)
}
Expand All @@ -122,7 +122,7 @@ func (s *sessionStore) renewSession(sessionID, sessionHash string, renewTimeUTC,
return nil, newAuthError("Unable to renew session", err)
}

session, err := s.backend.RenewSession(sessionHash, time.Now().UTC().Add(sessionRenewDuration))
session, err := s.b.RenewSession(sessionHash, time.Now().UTC().Add(sessionRenewDuration))
if err != nil {
if err == errSessionNotFound {
s.deleteSessionCookie()
Expand Down Expand Up @@ -150,7 +150,7 @@ func (s *sessionStore) CreateSession(loginID, userID int, rememberMe bool) (*Use
return nil, newLoggedError("Problem generating sessionId", nil)
}

session, remember, err := s.backend.CreateSession(loginID, userID, sessionHash, time.Now().UTC().Add(sessionRenewDuration), time.Now().UTC().Add(sessionExpireDuration), rememberMe, selector, tokenHash, time.Now().UTC().Add(rememberMeRenewDuration), time.Now().UTC().Add(rememberMeExpireDuration))
session, remember, err := s.b.CreateSession(loginID, userID, sessionHash, time.Now().UTC().Add(sessionRenewDuration), time.Now().UTC().Add(sessionExpireDuration), rememberMe, selector, tokenHash, time.Now().UTC().Add(rememberMeRenewDuration), time.Now().UTC().Add(rememberMeExpireDuration))
if err != nil {
return nil, newLoggedError("Unable to create new session", err)
}
Expand All @@ -159,13 +159,13 @@ func (s *sessionStore) CreateSession(loginID, userID int, rememberMe bool) (*Use
if err == nil {
oldSessionHash, err := decodeStringToHash(sessionCookie.SessionID)
if err == nil {
s.backend.InvalidateSession(oldSessionHash)
s.b.InvalidateSession(oldSessionHash)
}
}

rememberCookie, err := s.getRememberMeCookie()
if err == nil {
s.backend.InvalidateRememberMe(rememberCookie.Selector)
s.b.InvalidateRememberMe(rememberCookie.Selector)
s.deleteRememberMeCookie()
}

Expand Down
8 changes: 4 additions & 4 deletions sessionStore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestGetSession(t *testing.T) {
backend := &MockBackend{GetSessionReturn: test.GetSessionReturn, RenewSessionReturn: test.RenewSessionReturn}
store := getSessionStore(nil, test.SessionCookie, nil, test.HasCookieGetError, test.HasCookiePutError, backend)
val, err := store.GetSession()
methods := store.backend.(*MockBackend).MethodsCalled
methods := store.b.(*MockBackend).MethodsCalled
if (err == nil && test.ExpectedErr != "" || err != nil && test.ExpectedErr != err.Error()) ||
!collectionEqual(test.MethodsCalled, methods) {
t.Errorf("Scenario[%d] failed: %s\nexpected err:%v\tactual err:%v\nexpected val:%v\tactual val:%v\nexpected methods: %s\tactual methods: %s", i, test.Scenario, test.ExpectedErr, err, test.ExpectedResult, val, test.MethodsCalled, methods)
Expand Down Expand Up @@ -143,7 +143,7 @@ func TestRenewSession(t *testing.T) {
backend := &MockBackend{RenewSessionReturn: test.RenewSessionReturn, GetRememberMeReturn: test.GetRememberMeReturn}
store := getSessionStore(nil, nil, test.RememberCookie, test.HasCookieGetError, test.HasCookiePutError, backend)
val, err := store.renewSession("sessionId", "sessionHash", &test.RenewTimeUTC, &test.ExpireTimeUTC)
methods := store.backend.(*MockBackend).MethodsCalled
methods := store.b.(*MockBackend).MethodsCalled
if (err == nil && test.ExpectedErr != "" || err != nil && test.ExpectedErr != err.Error()) ||
!collectionEqual(test.MethodsCalled, methods) {
t.Errorf("Scenario[%d] failed: %s\nexpected err:%v\tactual err:%v\nexpected val:%v\tactual val:%v\nexpected methods: %s\tactual methods: %s", i, test.Scenario, test.ExpectedErr, err, test.ExpectedResult, val, test.MethodsCalled, methods)
Expand Down Expand Up @@ -209,7 +209,7 @@ func TestRememberMe(t *testing.T) {
backend := &MockBackend{GetRememberMeReturn: test.GetRememberMeReturn, RenewRememberMeReturn: test.RenewRememberMeReturn}
store := getSessionStore(nil, nil, test.RememberCookie, test.HasCookieGetError, test.HasCookiePutError, backend)
val, err := store.getRememberMe()
methods := store.backend.(*MockBackend).MethodsCalled
methods := store.b.(*MockBackend).MethodsCalled
if (err == nil && test.ExpectedErr != "" || err != nil && test.ExpectedErr != err.Error()) ||
!collectionEqual(test.MethodsCalled, methods) {
t.Errorf("Scenario[%d] failed: %s\nexpected err:%v\tactual err:%v\nexpected val:%v\tactual val:%v\nexpected methods: %s\tactual methods: %s", i, test.Scenario, test.ExpectedErr, err, test.ExpectedResult, val, test.MethodsCalled, methods)
Expand Down Expand Up @@ -279,7 +279,7 @@ func TestCreateSession(t *testing.T) {
backend := &MockBackend{NewLoginSessionReturn: test.NewLoginSessionReturn}
store := getSessionStore(nil, test.SessionCookie, test.RememberMeCookie, test.HasCookieGetError, test.HasCookiePutError, backend)
val, err := store.CreateSession(1, 1, test.RememberMe)
methods := store.backend.(*MockBackend).MethodsCalled
methods := store.b.(*MockBackend).MethodsCalled
if (err == nil && test.ExpectedErr != "" || err != nil && test.ExpectedErr != err.Error()) ||
!collectionEqual(test.MethodsCalled, methods) {
t.Errorf("Scenario[%d] failed: %s\nexpected err:%v\tactual err:%v\nexpected val:%v\tactual val:%v\nexpected methods: %s\tactual methods: %s", i, test.Scenario, test.ExpectedErr, err, test.ExpectedResult, val, test.MethodsCalled, methods)
Expand Down

0 comments on commit b925435

Please sign in to comment.