Skip to content

Commit

Permalink
NewLoginSession include userid for easier querying. Rename sessionSto…
Browse files Browse the repository at this point in the history
…re to authStore
  • Loading branch information
Rob Archibald committed Oct 8, 2016
1 parent f6920f5 commit 0a9bdbe
Show file tree
Hide file tree
Showing 10 changed files with 42 additions and 26 deletions.
16 changes: 8 additions & 8 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ type loginData struct {
Password string
}

func auth(sessionStore SessionStorer, w http.ResponseWriter, r *http.Request) {
func auth(sessionStore AuthStorer, w http.ResponseWriter, r *http.Request) {
session, err := sessionStore.GetSession()
if err != nil {
http.Error(w, "Authentication required: "+err.Error(), http.StatusUnauthorized)
Expand All @@ -23,7 +23,7 @@ func auth(sessionStore SessionStorer, w http.ResponseWriter, r *http.Request) {
}
}

func authBasic(sessionStore SessionStorer, w http.ResponseWriter, r *http.Request) {
func authBasic(sessionStore AuthStorer, w http.ResponseWriter, r *http.Request) {
session, err := sessionStore.GetBasicAuth()
if err != nil {
w.Header().Set("WWW-Authenticate", "Basic realm='Endfirst.com'")
Expand All @@ -33,27 +33,27 @@ func authBasic(sessionStore SessionStorer, w http.ResponseWriter, r *http.Reques
}
}

func login(sessionStore SessionStorer, w http.ResponseWriter, r *http.Request) {
func login(sessionStore AuthStorer, w http.ResponseWriter, r *http.Request) {
run(sessionStore.Login, w)
}

func register(sessionStore SessionStorer, w http.ResponseWriter, r *http.Request) {
func register(sessionStore AuthStorer, w http.ResponseWriter, r *http.Request) {
run(sessionStore.Register, w)
}

func createProfile(sessionStore SessionStorer, w http.ResponseWriter, r *http.Request) {
func createProfile(sessionStore AuthStorer, w http.ResponseWriter, r *http.Request) {
run(sessionStore.CreateProfile, w)
}

func updateEmail(sessionStore SessionStorer, w http.ResponseWriter, r *http.Request) {
func updateEmail(sessionStore AuthStorer, w http.ResponseWriter, r *http.Request) {
run(sessionStore.UpdateEmail, w)
}

func updatePassword(sessionStore SessionStorer, w http.ResponseWriter, r *http.Request) {
func updatePassword(sessionStore AuthStorer, w http.ResponseWriter, r *http.Request) {
run(sessionStore.UpdatePassword, w)
}

func verifyEmail(sessionStore SessionStorer, w http.ResponseWriter, r *http.Request) {
func verifyEmail(sessionStore AuthStorer, w http.ResponseWriter, r *http.Request) {
run(sessionStore.VerifyEmail, w)
}

Expand Down
8 changes: 4 additions & 4 deletions sessionStore.go → authStore.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"time"
)

type SessionStorer interface {
type AuthStorer interface {
GetSession() (*UserLoginSession, error)
GetBasicAuth() (*UserLoginSession, error)
Login() error
Expand Down Expand Up @@ -203,10 +203,10 @@ func (s *SessionStore) login(email, password string, rememberMe bool) (*UserLogi
return nil, NewLoggedError("Invalid username or password", nil)
}

return s.createSession(login.LoginId, rememberMe)
return s.createSession(login.LoginId, login.UserId, rememberMe)
}

func (s *SessionStore) createSession(loginId int, rememberMe bool) (*UserLoginSession, error) {
func (s *SessionStore) createSession(loginId, userId int, rememberMe bool) (*UserLoginSession, error) {
var err error
var selector, token, tokenHash string
if rememberMe {
Expand All @@ -220,7 +220,7 @@ func (s *SessionStore) createSession(loginId int, rememberMe bool) (*UserLoginSe
return nil, NewLoggedError("Problem generating sessionId", nil)
}

session, remember, err := s.backend.NewLoginSession(loginId, sessionHash, time.Now().UTC().Add(sessionRenewDuration), time.Now().UTC().Add(sessionExpireDuration), rememberMe, selector, tokenHash, time.Now().UTC().Add(rememberMeRenewDuration), time.Now().UTC().Add(rememberMeExpireDuration))
session, remember, err := s.backend.NewLoginSession(loginId, userId, sessionHash, time.Now().UTC().Add(sessionRenewDuration), time.Now().UTC().Add(sessionExpireDuration), rememberMe, selector, tokenHash, time.Now().UTC().Add(rememberMeRenewDuration), time.Now().UTC().Add(rememberMeExpireDuration))
if err != nil {
return nil, NewLoggedError("Unable to create new session", err)
}
Expand Down
2 changes: 1 addition & 1 deletion sessionStore_test.go → authStore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ func TestCreateSession(t *testing.T) {
for i, test := range createSessionTests {
backend := &MockBackend{NewSessionReturn: test.NewSessionReturn}
store := getStore(nil, test.SessionCookie, test.RememberMeCookie, test.HasCookieGetError, test.HasCookiePutError, backend)
val, err := store.createSession(1, test.RememberMe)
val, err := store.createSession(1, 1, test.RememberMe)
methods := store.backend.(*MockBackend).MethodsCalled
if (err == nil && test.ExpectedErr != "" || err != nil && test.ExpectedErr != err.Error()) ||
!collectionEqual(test.MethodsCalled, methods) {
Expand Down
2 changes: 1 addition & 1 deletion backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (

type BackendQuerier interface {
GetUserLogin(email, loginProvider string) (*UserLogin, error)
NewLoginSession(loginId int, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*UserLoginSession, *UserLoginRememberMe, error)
NewLoginSession(loginId, userId int, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*UserLoginSession, *UserLoginRememberMe, error)
GetSession(sessionHash string) (*UserLoginSession, error)
RenewSession(sessionHash string, renewTimeUTC time.Time) (*UserLoginSession, error)
GetRememberMe(selector string) (*UserLoginRememberMe, error)
Expand Down
4 changes: 2 additions & 2 deletions backendMemory.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (m *BackendMemory) GetUserLogin(email, loginProvider string) (*UserLogin, e
return login, nil
}

func (m *BackendMemory) NewLoginSession(loginId int, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*UserLoginSession, *UserLoginRememberMe, error) {
func (m *BackendMemory) NewLoginSession(loginId, userId int, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*UserLoginSession, *UserLoginRememberMe, error) {
login := m.getLoginByLoginId(loginId)
if login == nil {
return nil, nil, ErrLoginNotFound
Expand Down Expand Up @@ -142,7 +142,7 @@ func (m *BackendMemory) CreateLogin(emailVerifyHash, passwordHash string, fullNa
m.Logins = append(m.Logins, &login)

// don't set remember me
session, _, err := m.NewLoginSession(login.LoginId, sessionHash, sessionRenewTimeUTC, sessionExpireTimeUTC, false, "", "", time.Time{}, time.Time{})
session, _, err := m.NewLoginSession(login.LoginId, login.UserId, sessionHash, sessionRenewTimeUTC, sessionExpireTimeUTC, false, "", "", time.Time{}, time.Time{})
return session, err
}

Expand Down
14 changes: 7 additions & 7 deletions backendMemory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,38 +26,38 @@ func TestBackendGetUserLogin(t *testing.T) {

func TestBackendNewLoginSession(t *testing.T) {
backend := NewBackendMemory()
if _, _, err := backend.NewLoginSession(1, "sessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); err != ErrLoginNotFound {
if _, _, err := backend.NewLoginSession(1, 1, "sessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); err != ErrLoginNotFound {
t.Error("expected error since login doesn't exist")
}
backend.Logins = append(backend.Logins, &UserLogin{UserId: 1, LoginId: 1})
if session, _, _ := backend.NewLoginSession(1, "sessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); session.SessionHash != "sessionHash" || session.LoginId != 1 || session.UserId != 1 {
if session, _, _ := backend.NewLoginSession(1, 1, "sessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); session.SessionHash != "sessionHash" || session.LoginId != 1 || session.UserId != 1 {
t.Error("expected matching session", session)
}
// create again, shouldn't create new Session, just update
if session, _, _ := backend.NewLoginSession(1, "sessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); session.SessionHash != "sessionHash" || session.LoginId != 1 || session.UserId != 1 || len(backend.Sessions) != 1 {
if session, _, _ := backend.NewLoginSession(1, 1, "sessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); session.SessionHash != "sessionHash" || session.LoginId != 1 || session.UserId != 1 || len(backend.Sessions) != 1 {
t.Error("expected matching session", session)
}
// new session ID since it was generated when no cookie was found
if session, _, _ := backend.NewLoginSession(1, "newSessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); session.SessionHash != "newSessionHash" || len(backend.Sessions) != 2 {
if session, _, _ := backend.NewLoginSession(1, 1, "newSessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); session.SessionHash != "newSessionHash" || len(backend.Sessions) != 2 {
t.Error("expected matching session", session)
}

// existing remember already exists
backend.RememberMes = append(backend.RememberMes, &UserLoginRememberMe{LoginId: 1, Selector: "selector"})
if session, rememberMe, err := backend.NewLoginSession(1, "sessionHash", in5Minutes, in1Hour, true, "selector", "hash", time.Time{}, time.Time{}); session.SessionHash != "sessionHash" || session.LoginId != 1 || session.UserId != 1 ||
if session, rememberMe, err := backend.NewLoginSession(1, 1, "sessionHash", in5Minutes, in1Hour, true, "selector", "hash", time.Time{}, time.Time{}); session.SessionHash != "sessionHash" || session.LoginId != 1 || session.UserId != 1 ||
rememberMe.LoginId != 1 || rememberMe.Selector != "selector" || rememberMe.TokenHash != "hash" {
t.Error("expected RememberMe to be created", session, rememberMe, err)
}

// create new rememberMe
if session, rememberMe, err := backend.NewLoginSession(1, "sessionHash", in5Minutes, in1Hour, true, "newselector", "hash", time.Time{}, time.Time{}); session.SessionHash != "sessionHash" || session.LoginId != 1 || session.UserId != 1 ||
if session, rememberMe, err := backend.NewLoginSession(1, 1, "sessionHash", in5Minutes, in1Hour, true, "newselector", "hash", time.Time{}, time.Time{}); session.SessionHash != "sessionHash" || session.LoginId != 1 || session.UserId != 1 ||
rememberMe.LoginId != 1 || rememberMe.Selector != "newselector" || rememberMe.TokenHash != "hash" {
t.Error("expected RememberMe to be created", session, rememberMe, err)
}

// existing remember is for different login... error
backend.RememberMes = append(backend.RememberMes, &UserLoginRememberMe{LoginId: 2, Selector: "otherselector"})
if _, _, err := backend.NewLoginSession(1, "sessionHash", in5Minutes, in1Hour, true, "otherselector", "hash", time.Time{}, time.Time{}); err != ErrRememberMeSelectorExists {
if _, _, err := backend.NewLoginSession(1, 1, "sessionHash", in5Minutes, in1Hour, true, "otherselector", "hash", time.Time{}, time.Time{}); err != ErrRememberMeSelectorExists {
t.Error("expected error", err)
}
}
Expand Down
16 changes: 16 additions & 0 deletions backendOnedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ type BackendOnedb struct {

GetUserLoginQuery string
GetSessionQuery string
NewLoginSessionQuery string
NewRememberMeQuery string
RenewSessionQuery string
GetRememberMeQuery string
RenewRememberMeQuery string
Expand All @@ -33,6 +35,20 @@ func (b *BackendOnedb) GetSession(sessionHash string) (*UserLoginSession, error)
return session, b.Db.QueryStructRow(onedb.NewSqlQuery(b.GetSessionQuery, sessionHash), session)
}

func (m *BackendOnedb) NewLoginSession(loginId, userId int, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*UserLoginSession, *UserLoginRememberMe, error) {
var session *UserLoginSession
var remember *UserLoginRememberMe
err := m.Db.QueryStructRow(onedb.NewSqlQuery(m.NewLoginSessionQuery, loginId, userId, sessionHash, sessionRenewTimeUTC, sessionExpireTimeUTC), session)
if err != nil {
return nil, nil, err
}
err = m.Db.QueryStructRow(onedb.NewSqlQuery(m.NewRememberMeQuery, loginId, userId, sessionHash, sessionRenewTimeUTC, sessionExpireTimeUTC), rememberMe)
if err != nil {
return nil, nil, err
}
return session, remember, nil
}

func (b *BackendOnedb) RenewSession(sessionHash string, renewTimeUTC time.Time) (*UserLoginSession, error) {
var session *UserLoginSession
return session, b.Db.QueryStructRow(onedb.NewSqlQuery(b.RenewSessionQuery, sessionHash), session)
Expand Down
2 changes: 1 addition & 1 deletion backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (b *MockBackend) GetSession(sessionHash string) (*UserLoginSession, error)
b.MethodsCalled = append(b.MethodsCalled, "GetSession")
return b.GetSessionReturn.Session, b.GetSessionReturn.Err
}
func (b *MockBackend) NewLoginSession(loginId int, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*UserLoginSession, *UserLoginRememberMe, error) {
func (b *MockBackend) NewLoginSession(loginId, userId int, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*UserLoginSession, *UserLoginRememberMe, error) {
b.MethodsCalled = append(b.MethodsCalled, "NewSession")
return b.NewSessionReturn.Session, b.NewSessionReturn.RememberMe, b.NewSessionReturn.Err
}
Expand Down
2 changes: 1 addition & 1 deletion nginxauth.conf
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ getUserLoginQuery="select userId, loginProvider, providerKey, twoFactorEnabled,
from sharedpeople p
inner join shareduserlogins l on l.userid = p.id
where p.email = $1 and l.loginProvider = $2 and p.isdeleted = 'false'"
getSessionQuery=""
getSessionQuery="select loginId, sessionhash, userid, renewtimeutc, expiretimeutc from shareduserloginsessions where sessionhash = $1"
renewSessionQuery=""
renewSessionUsingRememberMeQuery=""
getRememberMeQuery=""
Expand Down
2 changes: 1 addition & 1 deletion nginxauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func fileLoggerHandler(h http.Handler) http.Handler {
return handlers.CombinedLoggingHandler(logFile, h)
}

func (s *nginxauth) method(name string, handler func(sessionStore SessionStorer, w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
func (s *nginxauth) method(name string, handler func(sessionStore AuthStorer, w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != name {
http.Error(w, "Unsupported method", http.StatusInternalServerError)
Expand Down

0 comments on commit 0a9bdbe

Please sign in to comment.