From 703ddf47d1ef61a39c045af58c192500eba68737 Mon Sep 17 00:00:00 2001 From: Rob Archibald Date: Tue, 24 Jan 2017 12:22:26 -0800 Subject: [PATCH] Change GetLogin to Login. Include userid and full name in Redis session. --- authStore.go | 13 +++++------ authStore_test.go | 42 ++++++++++++++-------------------- backend.go | 23 ++++++++++--------- backendLDAPLogin.go | 28 +++++++++++++---------- backendLDAPLogin_test.go | 21 +++++++++-------- backendMemory.go | 44 +++++++++++++++++------------------- backendMemory_test.go | 45 ++++++++++++++++++++----------------- backendRedisSession.go | 5 ++--- backendRedisSession_test.go | 6 ++--- backend_test.go | 30 ++++++++++++------------- 10 files changed, 126 insertions(+), 131 deletions(-) diff --git a/authStore.go b/authStore.go index 85678a3..1ae8369 100644 --- a/authStore.go +++ b/authStore.go @@ -195,18 +195,15 @@ func (s *authStore) login(email, password string, rememberMe bool) (*loginSessio // add in check for DDOS attack. Slow down or lock out checks for same account // or same IP with multiple failed attempts - login, err := s.backend.GetLogin(email, loginProviderDefaultName) + login, err := s.backend.Login(email, password) if err != nil { return nil, newLoggedError("Invalid username or password", err) } - if err := cryptoHashEquals(password, login.ProviderKey); err != nil { - return nil, newLoggedError("Invalid username or password", err) - } - return s.createSession(email, rememberMe) + return s.createSession(email, login.UserID, login.FullName, rememberMe) } -func (s *authStore) createSession(email string, rememberMe bool) (*loginSession, error) { +func (s *authStore) createSession(email string, userID int, fullname string, rememberMe bool) (*loginSession, error) { var err error var selector, token, tokenHash string if rememberMe { @@ -220,7 +217,7 @@ func (s *authStore) createSession(email string, rememberMe bool) (*loginSession, return nil, newLoggedError("Problem generating sessionId", nil) } - session, remember, err := s.backend.CreateSession(1, email, sessionHash, time.Now().UTC().Add(sessionRenewDuration), time.Now().UTC().Add(sessionExpireDuration), rememberMe, selector, tokenHash, time.Now().UTC().Add(rememberMeRenewDuration), time.Now().UTC().Add(rememberMeExpireDuration)) + session, remember, err := s.backend.CreateSession(userID, email, fullname, sessionHash, time.Now().UTC().Add(sessionRenewDuration), time.Now().UTC().Add(sessionExpireDuration), rememberMe, selector, tokenHash, time.Now().UTC().Add(rememberMeRenewDuration), time.Now().UTC().Add(rememberMeExpireDuration)) if err != nil { return nil, newLoggedError("Unable to create new session", err) } @@ -358,7 +355,7 @@ func (s *authStore) createProfile(fullName, organization, password, picturePath return newLoggedError("Unable to create login", err) } - _, err = s.createSession(session.Email, false) + _, err = s.createSession(session.Email, session.UserID, fullName, false) if err != nil { return err } diff --git a/authStore_test.go b/authStore_test.go index ba692ea..215bbec 100644 --- a/authStore_test.go +++ b/authStore_test.go @@ -303,7 +303,7 @@ func TestCreateSession(t *testing.T) { for i, test := range createSessionTests { backend := &mockBackend{CreateSessionReturn: test.CreateSessionReturn} store := getAuthStore(nil, test.SessionCookie, test.RememberMeCookie, test.HasCookieGetError, test.HasCookiePutError, nil, backend) - val, err := store.createSession("test@test.com", test.RememberMe) + val, err := store.createSession("test@test.com", 1, "fullname", test.RememberMe) methods := store.backend.(*mockBackend).MethodsCalled if (err == nil && test.ExpectedErr != "" || err != nil && test.ExpectedErr != err.Error()) || !collectionEqual(test.MethodsCalled, methods) { @@ -320,20 +320,20 @@ func TestAuthGetBasicAuth(t *testing.T) { } // Credential error - store = getAuthStore(nil, nil, nil, true, false, nil, &mockBackend{GetUserLoginReturn: loginErr()}) + store = getAuthStore(nil, nil, nil, true, false, nil, &mockBackend{LoginReturn: loginErr()}) if _, err := store.GetBasicAuth(); err == nil || err.Error() != "Problem decoding credentials from basic auth" { t.Error("expected error") } // login error - store = getAuthStore(nil, nil, nil, true, false, nil, &mockBackend{GetUserLoginReturn: loginErr(), GetSessionReturn: sessionSuccess(futureTime, futureTime)}) + store = getAuthStore(nil, nil, nil, true, false, nil, &mockBackend{LoginReturn: loginErr(), GetSessionReturn: sessionSuccess(futureTime, futureTime)}) store.r = basicAuthRequest("test@test.com", "password") if _, err := store.GetBasicAuth(); err == nil || err.Error() != "Invalid username or password" { t.Error("expected error", err) } // login success - store = getAuthStore(nil, nil, nil, true, false, nil, &mockBackend{GetUserLoginReturn: loginSuccess(), GetSessionReturn: sessionSuccess(futureTime, futureTime), CreateSessionReturn: sessionRemember(futureTime, futureTime)}) + store = getAuthStore(nil, nil, nil, true, false, nil, &mockBackend{LoginReturn: loginSuccess(), GetSessionReturn: sessionSuccess(futureTime, futureTime), CreateSessionReturn: sessionRemember(futureTime, futureTime)}) store.r = basicAuthRequest("test@test.com", "correctPassword") if _, err := store.GetBasicAuth(); err != nil { t.Error("expected success") @@ -382,9 +382,9 @@ func TestAuthStoreEndToEnd(t *testing.T) { // create profile err = s.createProfile("fullName", "company", "password", "picturePath", 1, 1) - hashErr := cryptoHashEquals("password", b.Logins[0].ProviderKey) + hashErr := cryptoHashEquals("password", b.Logins[0].PasswordHash) if err != nil || len(b.Users) != 1 || len(b.Sessions) != 1 || len(b.Logins) != 1 || b.Logins[0].Email != "test@test.com" || len(b.EmailSessions) != 0 || hashErr != nil { - t.Fatal("expected valid user, login and session", b.Logins[0], b.Logins[0].ProviderKey, hashErr) + t.Fatal("expected valid user, login and session", b.Logins[0], b.Logins[0].PasswordHash, hashErr) } // decode session cookie @@ -643,7 +643,7 @@ var loginTests = []struct { Password string RememberMe bool CreateSessionReturn *SessionRememberReturn - GetUserLoginReturn *LoginReturn + LoginReturn *LoginReturn ErrReturn error MethodsCalled []string ExpectedResult *rememberMeSession @@ -661,34 +661,26 @@ var loginTests = []struct { ExpectedErr: passwordValidationMessage, }, { - Scenario: "Can't get login", - Email: "email@example.com", - Password: "validPassword", - GetUserLoginReturn: loginErr(), - MethodsCalled: []string{"GetLogin"}, - ExpectedErr: "Invalid username or password", - }, - { - Scenario: "Incorrect password", - Email: "email@example.com", - Password: "wrongPassword", - GetUserLoginReturn: &LoginReturn{Login: &userLogin{Email: "test@test.com", ProviderKey: "1234"}}, - MethodsCalled: []string{"GetLogin"}, - ExpectedErr: "Invalid username or password", + Scenario: "Can't get login", + Email: "email@example.com", + Password: "validPassword", + LoginReturn: loginErr(), + MethodsCalled: []string{"Login"}, + ExpectedErr: "Invalid username or password", }, { Scenario: "Got session", Email: "email@example.com", Password: "correctPassword", - GetUserLoginReturn: loginSuccess(), + LoginReturn: loginSuccess(), CreateSessionReturn: sessionRemember(futureTime, futureTime), - MethodsCalled: []string{"GetLogin", "CreateSession", "InvalidateSession", "InvalidateRememberMe"}, + MethodsCalled: []string{"Login", "CreateSession", "InvalidateSession", "InvalidateRememberMe"}, }, } func TestAuthLogin(t *testing.T) { for i, test := range loginTests { - backend := &mockBackend{GetUserLoginReturn: test.GetUserLoginReturn, ErrReturn: test.ErrReturn, CreateSessionReturn: test.CreateSessionReturn} + backend := &mockBackend{LoginReturn: test.LoginReturn, ErrReturn: test.ErrReturn, CreateSessionReturn: test.CreateSessionReturn} store := getAuthStore(nil, nil, nil, false, false, nil, backend) val, err := store.login(test.Email, test.Password, test.RememberMe) methods := store.backend.(*mockBackend).MethodsCalled @@ -771,7 +763,7 @@ func TestLoginJson(t *testing.T) { var buf bytes.Buffer buf.WriteString(`{"Email":"test@test.com", "Password":"password", "RememberMe":true}`) r := &http.Request{Body: ioutil.NopCloser(&buf)} - backend := &mockBackend{GetUserLoginReturn: loginErr()} + backend := &mockBackend{LoginReturn: loginErr()} store := getAuthStore(nil, nil, nil, true, false, nil, backend) store.r = r err := store.Login().(*authError).innerError diff --git a/backend.go b/backend.go index c073b7e..cb027db 100644 --- a/backend.go +++ b/backend.go @@ -12,6 +12,7 @@ var errInvalidSessionHash = errors.New("DB: Invalid SessionHash") var errRememberMeSelectorExists = errors.New("DB: RememberMe selector already exists") var errUserNotFound = errors.New("DB: User not found") var errLoginNotFound = errors.New("DB: Login not found") +var errInvalidCredentials = errors.New("DB: Invalid Credentials") var errSessionNotFound = errors.New("DB: Session not found") var errSessionAlreadyExists = errors.New("DB: Session already exists") var errRememberMeNotFound = errors.New("DB: RememberMe not found") @@ -27,7 +28,7 @@ type backender interface { // LoginBackender. Write out since it contains duplicate BackendCloser CreateLogin(userID int, email, passwordHash, fullName, homeDirectory string, uidNumber, gidNumber int, mailQuota, fileQuota string) (*userLogin, error) - GetLogin(email, loginProvider string) (*userLogin, error) + Login(email, password string) (*userLogin, error) UpdateEmail(email string, password string, newEmail string) (*loginSession, error) UpdatePassword(email string, oldPassword string, newPassword string) (*loginSession, error) @@ -47,7 +48,7 @@ type userBackender interface { type loginBackender interface { CreateLogin(userID int, email, passwordHash, fullName, homeDirectory string, uidNumber, gidNumber int, mailQuota, fileQuota string) (*userLogin, error) - GetLogin(email, loginProvider string) (*userLogin, error) + Login(email, password string) (*userLogin, error) UpdateEmail(email string, password string, newEmail string) (*loginSession, error) UpdatePassword(email string, oldPassword string, newPassword string) (*loginSession, error) backendCloser @@ -59,7 +60,7 @@ type sessionBackender interface { UpdateEmailSession(verifyHash string, userID int, email string) error DeleteEmailSession(verifyHash string) error - CreateSession(userID int, email string, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*loginSession, *rememberMeSession, error) + CreateSession(userID int, email, fullname, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*loginSession, *rememberMeSession, error) GetSession(sessionHash string) (*loginSession, error) RenewSession(sessionHash string, renewTimeUTC time.Time) (*loginSession, error) InvalidateSession(sessionHash string) error @@ -86,15 +87,15 @@ type user struct { } type userLogin struct { - UserID int - Email string - LoginProviderID int - ProviderKey string + UserID int + Email string + FullName string } type loginSession struct { UserID int Email string + FullName string SessionHash string RenewTimeUTC time.Time ExpireTimeUTC time.Time @@ -159,12 +160,12 @@ type backend struct { backendCloser } -func (b *backend) GetLogin(email, loginProvider string) (*userLogin, error) { - return b.l.GetLogin(email, loginProvider) +func (b *backend) Login(email, password string) (*userLogin, error) { + return b.l.Login(email, password) } -func (b *backend) CreateSession(userID int, email, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*loginSession, *rememberMeSession, error) { - return b.s.CreateSession(userID, email, sessionHash, sessionRenewTimeUTC, sessionExpireTimeUTC, rememberMe, rememberMeSelector, rememberMeTokenHash, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC) +func (b *backend) CreateSession(userID int, email, fullname, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*loginSession, *rememberMeSession, error) { + return b.s.CreateSession(userID, email, fullname, sessionHash, sessionRenewTimeUTC, sessionExpireTimeUTC, rememberMe, rememberMeSelector, rememberMeTokenHash, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC) } func (b *backend) GetSession(sessionHash string) (*loginSession, error) { diff --git a/backendLDAPLogin.go b/backendLDAPLogin.go index dcfb882..3c44fb7 100644 --- a/backendLDAPLogin.go +++ b/backendLDAPLogin.go @@ -22,25 +22,29 @@ func newBackendLDAPLogin(server string, port int, bindDn, password, baseDn, user } type ldapData struct { - UID []string - UserPassword []string - UIDNumber []string - GIDNumber []string - HomeDirectory []string + UID string + DbUserId string + Cn string } -func (l *backendLDAPLogin) GetLogin(email, loginProvider string) (*userLogin, error) { - req := ldap.NewSearchRequest(l.baseDn, ldap.ScopeSingleLevel, ldap.NeverDerefAliases, 0, 0, false, fmt.Sprintf(l.userLoginFilter, email), []string{"uid", "userPassword", "uidNumber", "gidNumber", "homeDirectory"}, nil) +func (l *backendLDAPLogin) Login(email, password string) (*userLogin, error) { + // check credentials + err := l.db.Execute(ldap.NewSimpleBindRequest(email, password, nil)) + if err != nil { + return nil, err + } + // get login info + req := ldap.NewSearchRequest(l.baseDn, ldap.ScopeSingleLevel, ldap.NeverDerefAliases, 0, 0, false, fmt.Sprintf(l.userLoginFilter, email), []string{"uid", "dbUserId", "cn"}, nil) data := &ldapData{} - err := l.db.QueryStructRow(req, data) + err = l.db.QueryStructRow(req, data) if err != nil { return nil, err } - var password string - if len(data.UserPassword) != 0 { - password = data.UserPassword[0] + dbUserID, err := strconv.Atoi(data.DbUserId) + if err != nil { + return nil, err } - return &userLogin{ProviderKey: password}, nil + return &userLogin{UserID: dbUserID, Email: data.UID, FullName: data.Cn}, nil } /**************** TODO: create different type of user if not using file and mail quotas **********************/ diff --git a/backendLDAPLogin_test.go b/backendLDAPLogin_test.go index ea67ef4..d76016d 100644 --- a/backendLDAPLogin_test.go +++ b/backendLDAPLogin_test.go @@ -23,31 +23,34 @@ func TestNewBackendLDAPLogin(t *testing.T) { t.Fatal("unable to login", err) } - _, err = l.GetLogin("test@test.com", "") + _, err = l.Login("test@test.com", "") if err == nil { t.Fatal("Expected no results", err) } } -func TestLdapGetLogin(t *testing.T) { +func TestLdapLogin(t *testing.T) { // success - data := ldapData{UserPassword: []string{"password"}} + data := ldapData{UID: "email", DbUserId: "1234"} m := onedb.NewMock(nil, nil, data) l := backendLDAPLogin{db: m, userLoginFilter: "%s"} - login, err := l.GetLogin("email", "provider") - if err != nil || login.ProviderKey != "password" { - t.Error("expected to find data", login) + login, err := l.Login("email", "password") + if err != nil || login.Email != "email" { + t.Error("expected to find data", login, err) } queries := m.QueriesRun() - if _, ok := queries[0].(*ldap.SearchRequest); !ok { - t.Error("expected ldap search request") + if _, ok := queries[0].(*ldap.SimpleBindRequest); !ok { + t.Error("expected ldap bind request first") + } + if _, ok := queries[1].(*ldap.SearchRequest); !ok { + t.Error("expected ldap searc request next") } // error m = onedb.NewMock(nil, nil, nil) l = backendLDAPLogin{db: m, userLoginFilter: "%s"} - _, err = l.GetLogin("email", "provider") + _, err = l.Login("email", "password") if err == nil { t.Error("expected error") } diff --git a/backendMemory.go b/backendMemory.go index eba5b80..26c5272 100644 --- a/backendMemory.go +++ b/backendMemory.go @@ -6,11 +6,18 @@ import ( "time" ) +type userLoginMemory struct { + UserID int + Email string + FullName string + PasswordHash string +} + type backendMemory struct { backender EmailSessions []*emailSession Users []*user - Logins []*userLogin + Logins []*userLoginMemory Sessions []*loginSession RememberMes []*rememberMeSession LoginProviders []*loginProvider @@ -24,21 +31,24 @@ func newBackendMemory() backender { return &backendMemory{LoginProviders: []*loginProvider{&loginProvider{LoginProviderID: 1, Name: loginProviderDefaultName}}} } -func (m *backendMemory) GetLogin(email, loginProvider string) (*userLogin, error) { +func (m *backendMemory) Login(email, password string) (*userLogin, error) { login := m.getLoginByEmail(email) if login == nil { return nil, errLoginNotFound } - return login, nil + if err := cryptoHashEquals(password, login.PasswordHash); err != nil { + return nil, err + } + return &userLogin{login.UserID, login.Email, login.FullName}, nil } -func (m *backendMemory) CreateSession(userID int, email, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*loginSession, *rememberMeSession, error) { +func (m *backendMemory) CreateSession(userID int, email, fullname, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*loginSession, *rememberMeSession, error) { session := m.getSessionByHash(sessionHash) if session != nil { return nil, nil, errSessionAlreadyExists } - session = &loginSession{userID, email, sessionHash, sessionRenewTimeUTC, sessionExpireTimeUTC} + session = &loginSession{userID, email, fullname, sessionHash, sessionRenewTimeUTC, sessionExpireTimeUTC} m.Sessions = append(m.Sessions, session) var rememberItem *rememberMeSession if rememberMe { @@ -156,12 +166,11 @@ func (m *backendMemory) UpdateUser(email, fullname string, company string, pictu return nil } -// This method needs to be fixed to work with the new data model using LDAP func (m *backendMemory) CreateLogin(userID int, email, passwordHash, fullName, homeDirectory string, uidNumber, gidNumber int, mailQuota, fileQuota string) (*userLogin, error) { - login := userLogin{userID, email, 1, passwordHash} + login := userLoginMemory{userID, email, fullName, passwordHash} m.Logins = append(m.Logins, &login) - return &login, nil + return &userLogin{userID, email, fullName}, nil } func (m *backendMemory) UpdateEmail(email string, password string, newEmail string) (*loginSession, error) { @@ -241,29 +250,16 @@ func (m *backendMemory) removeSession(sessionHash string) { } } -func (m *backendMemory) getLoginProvider(name string) *loginProvider { - for _, provider := range m.LoginProviders { - if provider.Name == name { - return provider - } - } - return nil -} - -func (m *backendMemory) getLoginByUser(email, loginProvider string) *userLogin { - provider := m.getLoginProvider(loginProvider) - if provider == nil { - return nil - } +func (m *backendMemory) getLoginByUser(email string) *userLoginMemory { for _, login := range m.Logins { - if login.Email == email && login.LoginProviderID == provider.LoginProviderID { + if login.Email == email { return login } } return nil } -func (m *backendMemory) getLoginByEmail(email string) *userLogin { +func (m *backendMemory) getLoginByEmail(email string) *userLoginMemory { for _, login := range m.Logins { if login.Email == email { return login diff --git a/backendMemory_test.go b/backendMemory_test.go index aa3a988..f517448 100644 --- a/backendMemory_test.go +++ b/backendMemory_test.go @@ -8,43 +8,53 @@ import ( var in5Minutes = time.Now().UTC().Add(5 * time.Minute) var in1Hour = time.Now().UTC().Add(time.Hour) -func TestMemoryGetLogin(t *testing.T) { +func TestMemoryLogin(t *testing.T) { + // can't get login backend := newBackendMemory().(*backendMemory) - if _, err := backend.GetLogin("email", loginProviderDefaultName); err != errLoginNotFound { + if _, err := backend.Login("email", "password"); err != errLoginNotFound { t.Error("expected no login since login not added yet", err) } - expected := &userLogin{Email: "email", LoginProviderID: 1} - backend.Logins = append(backend.Logins, expected) - if actual, _ := backend.GetLogin("email", loginProviderDefaultName); expected != actual { - t.Error("expected no login since login not added yet") + + // invalid credentials + expected := &userLoginMemory{Email: "email", UserID: 1, FullName: "name", PasswordHash: "$6$bogushash"} + backend.Logins = []*userLoginMemory{expected} + if _, err := backend.Login("email", "correctPassword"); err != nil && err.Error() != "input string does not match the supplied hash" { + t.Error("expected error", err) + } + + // success + expected = &userLoginMemory{Email: "email", UserID: 1, FullName: "name", PasswordHash: "$6$rounds=200000$pYt48w3PgDcRoCMx$sxbuADDhNI9nNe35HcrFYW7vpWLLMNiPBKcbqOgaRxTBYE8hePJWvmuN9dp.783JmDZBhDJRG956Wc/fzghhh."} // hash of "correctPassword"" + backend.Logins = []*userLoginMemory{expected} + if actual, _ := backend.Login("email", "correctPassword"); expected == nil || expected.Email != actual.Email || expected.FullName != actual.FullName || expected.UserID != actual.UserID { + t.Error("expected success", expected, actual) } } func TestMemoryCreateSession(t *testing.T) { backend := newBackendMemory().(*backendMemory) - if session, _, _ := backend.CreateSession(1, "test@test.com", "sessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); session.SessionHash != "sessionHash" || session.Email != "test@test.com" { + if session, _, _ := backend.CreateSession(1, "test@test.com", "fullname", "sessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); session.SessionHash != "sessionHash" || session.Email != "test@test.com" { t.Error("expected matching session", session) } // create again, should error - if _, _, err := backend.CreateSession(1, "test@test.com", "sessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); err == nil { + if _, _, err := backend.CreateSession(1, "test@test.com", "fullname", "sessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); err == nil { t.Error("expected error since session exists", err) } // new session ID since it was generated when no cookie was found (e.g. on another computer or browser) - if session, _, _ := backend.CreateSession(1, "test@test.com", "newSessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); session.SessionHash != "newSessionHash" || len(backend.Sessions) != 2 { + if session, _, _ := backend.CreateSession(1, "test@test.com", "fullname", "newSessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); session.SessionHash != "newSessionHash" || len(backend.Sessions) != 2 { t.Error("expected matching session", session) } // new rememberMe backend.Sessions = nil backend.RememberMes = nil - if session, rememberMe, err := backend.CreateSession(1, "test@test.com", "sessionHash", in5Minutes, in1Hour, true, "selector", "hash", time.Time{}, time.Time{}); session == nil || session.SessionHash != "sessionHash" || session.Email != "test@test.com" || + if session, rememberMe, err := backend.CreateSession(1, "test@test.com", "fullname", "sessionHash", in5Minutes, in1Hour, true, "selector", "hash", time.Time{}, time.Time{}); session == nil || session.SessionHash != "sessionHash" || session.Email != "test@test.com" || rememberMe == nil || rememberMe.Selector != "selector" || rememberMe.TokenHash != "hash" { t.Error("expected RememberMe to be created", session, rememberMe, err) } // existing rememberMe. Error backend.Sessions = nil - if _, _, err := backend.CreateSession(1, "test@test.com", "sessionHash", in5Minutes, in1Hour, true, "selector", "hash", time.Time{}, time.Time{}); err != errRememberMeSelectorExists { + if _, _, err := backend.CreateSession(1, "test@test.com", "fullname", "sessionHash", in5Minutes, in1Hour, true, "selector", "hash", time.Time{}, time.Time{}); err != errRememberMeSelectorExists { t.Error("expected error", err) } } @@ -200,27 +210,20 @@ func TestMemoryClose(t *testing.T) { func TestToString(t *testing.T) { backend := newBackendMemory().(*backendMemory) backend.Users = append(backend.Users, &user{}) - backend.Logins = append(backend.Logins, &userLogin{}) + backend.Logins = append(backend.Logins, &userLoginMemory{}) backend.Sessions = append(backend.Sessions, &loginSession{}) backend.RememberMes = append(backend.RememberMes, &rememberMeSession{}) actual := backend.ToString() - expected := "Users:\n {0 0}\nLogins:\n {0 0 }\nSessions:\n {0 0001-01-01 00:00:00 +0000 UTC 0001-01-01 00:00:00 +0000 UTC}\nRememberMe:\n {0 0001-01-01 00:00:00 +0000 UTC 0001-01-01 00:00:00 +0000 UTC}\n" + expected := "Users:\n {0 0}\nLogins:\n {0 }\nSessions:\n {0 0001-01-01 00:00:00 +0000 UTC 0001-01-01 00:00:00 +0000 UTC}\nRememberMe:\n {0 0001-01-01 00:00:00 +0000 UTC 0001-01-01 00:00:00 +0000 UTC}\n" if actual != expected { t.Error("expected different value", actual) } } -func TestGetLoginProvider(t *testing.T) { - backend := newBackendMemory().(*backendMemory) - if backend.getLoginProvider("bogus") != nil { - t.Error("expected no provider") - } -} - func TestGetLoginByUser(t *testing.T) { backend := newBackendMemory().(*backendMemory) - if backend.getLoginByUser("email", "bogus") != nil { + if backend.getLoginByUser("email") != nil { t.Error("expected no login") } } diff --git a/backendRedisSession.go b/backendRedisSession.go index cc3b698..aac6811 100644 --- a/backendRedisSession.go +++ b/backendRedisSession.go @@ -2,7 +2,6 @@ package main import ( "errors" - "fmt" "github.com/robarchibald/onedb" "math" "time" @@ -36,9 +35,9 @@ func (r *backendRedisSession) DeleteEmailSession(emailVerifyHash string) error { return nil } -func (r *backendRedisSession) CreateSession(userID int, email, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, +func (r *backendRedisSession) CreateSession(userID int, email, fullname, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, includeRememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*loginSession, *rememberMeSession, error) { - session := loginSession{userID, email, sessionHash, sessionRenewTimeUTC, sessionExpireTimeUTC} + session := loginSession{userID, email, fullname, sessionHash, sessionRenewTimeUTC, sessionExpireTimeUTC} err := r.saveSession(&session) if err != nil { return nil, nil, err diff --git a/backendRedisSession_test.go b/backendRedisSession_test.go index ecec630..3770208 100644 --- a/backendRedisSession_test.go +++ b/backendRedisSession_test.go @@ -14,13 +14,13 @@ func TestRedisCreateSession(t *testing.T) { // expired session error m := onedb.NewMock(nil, nil, nil) r := backendRedisSession{db: m, prefix: "test"} - _, _, err := r.CreateSession(1, "test@test.com", "hash", time.Now(), time.Now(), false, "selector", "token", time.Now(), time.Now()) + _, _, err := r.CreateSession(1, "test@test.com", "fullname", "hash", time.Now(), time.Now(), false, "selector", "token", time.Now(), time.Now()) if err == nil || len(m.QueriesRun()) != 0 { t.Error("expected error") } // expired rememberMe, but session should save. - _, _, err = r.CreateSession(1, "test@test.com", "hash", time.Now(), time.Now().AddDate(1, 0, 0), true, "selector", "token", time.Now(), time.Now()) + _, _, err = r.CreateSession(1, "test@test.com", "fullname", "hash", time.Now(), time.Now().AddDate(1, 0, 0), true, "selector", "token", time.Now(), time.Now()) if q := m.QueriesRun(); err == nil || len(q) != 1 || q[0].(*onedb.RedisCommand).Command != "SETEX" || len(q[0].(*onedb.RedisCommand).Args) != 3 || q[0].(*onedb.RedisCommand).Args[0] != "test/session/hash" { t.Error("expected error") } @@ -28,7 +28,7 @@ func TestRedisCreateSession(t *testing.T) { // success m = onedb.NewMock(nil, nil, nil) r = backendRedisSession{db: m, prefix: "test"} - session, rememberMe, err := r.CreateSession(1, "test@test.com", "hash", time.Now(), time.Now().AddDate(1, 0, 0), true, "selector", "token", time.Now(), time.Now().AddDate(1, 0, 0)) + session, rememberMe, err := r.CreateSession(1, "test@test.com", "fullname", "hash", time.Now(), time.Now().AddDate(1, 0, 0), true, "selector", "token", time.Now(), time.Now().AddDate(1, 0, 0)) if q := m.QueriesRun(); err != nil || len(q) != 2 || q[1].(*onedb.RedisCommand).Command != "SETEX" || len(q[1].(*onedb.RedisCommand).Args) != 3 || q[1].(*onedb.RedisCommand).Args[0] != "test/rememberMe/selector" { t.Error("expected success") } diff --git a/backend_test.go b/backend_test.go index a2a69f4..71b8cf5 100644 --- a/backend_test.go +++ b/backend_test.go @@ -18,11 +18,11 @@ func TestAuthError(t *testing.T) { } } -func TestGetLogin(t *testing.T) { - m := &mockBackend{GetUserLoginReturn: loginSuccess()} +func TestBackendLogin(t *testing.T) { + m := &mockBackend{LoginReturn: loginSuccess()} b := backend{u: m, l: m, s: m} - b.GetLogin("email", "provider") - if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "GetLogin" { + b.Login("email", "password") + if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "Login" { t.Error("Expected it would call backend", m.MethodsCalled) } } @@ -30,7 +30,7 @@ func TestGetLogin(t *testing.T) { func TestBackendCreateSession(t *testing.T) { m := &mockBackend{CreateSessionReturn: sessionRemember(time.Now(), time.Now())} b := backend{u: m, l: m, s: m} - b.CreateSession(1, "test@test.com", "hash", time.Now(), time.Now(), false, "", "", time.Now(), time.Now()) + b.CreateSession(1, "test@test.com", "fullname", "hash", time.Now(), time.Now(), false, "", "", time.Now(), time.Now()) if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "CreateSession" { t.Error("Expected it would call backend", m.MethodsCalled) } @@ -225,7 +225,7 @@ type getEmailSessionReturn struct { type mockBackend struct { backender - GetUserLoginReturn *LoginReturn + LoginReturn *LoginReturn ExpirationReturn *time.Time GetSessionReturn *SessionReturn CreateSessionReturn *SessionRememberReturn @@ -245,12 +245,12 @@ type mockBackend struct { MethodsCalled []string } -func (b *mockBackend) GetLogin(email, loginProvider string) (*userLogin, error) { - b.MethodsCalled = append(b.MethodsCalled, "GetLogin") - if b.GetUserLoginReturn == nil { - return nil, errors.New("GetUserLoginReturn not initialized") +func (b *mockBackend) Login(email, password string) (*userLogin, error) { + b.MethodsCalled = append(b.MethodsCalled, "Login") + if b.LoginReturn == nil { + return nil, errors.New("LoginReturn not initialized") } - return b.GetUserLoginReturn.Login, b.GetUserLoginReturn.Err + return b.LoginReturn.Login, b.LoginReturn.Err } func (b *mockBackend) GetSession(sessionHash string) (*loginSession, error) { @@ -261,7 +261,7 @@ func (b *mockBackend) GetSession(sessionHash string) (*loginSession, error) { return b.GetSessionReturn.Session, b.GetSessionReturn.Err } -func (b *mockBackend) CreateSession(userID int, email, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*loginSession, *rememberMeSession, error) { +func (b *mockBackend) CreateSession(userID int, email, fullname, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*loginSession, *rememberMeSession, error) { b.MethodsCalled = append(b.MethodsCalled, "CreateSession") if b.CreateSessionReturn == nil { return nil, nil, errors.New("CreateSessionReturn not initialized") @@ -376,7 +376,7 @@ func (b *mockBackend) Close() error { } func loginSuccess() *LoginReturn { - return &LoginReturn{&userLogin{Email: "test@test.com", ProviderKey: "$6$rounds=200000$pYt48w3PgDcRoCMx$sxbuADDhNI9nNe35HcrFYW7vpWLLMNiPBKcbqOgaRxTBYE8hePJWvmuN9dp.783JmDZBhDJRG956Wc/fzghhh."}, nil} // cryptoHash of "correctPassword" + return &LoginReturn{&userLogin{Email: "test@test.com"}, nil} } func loginErr() *LoginReturn { @@ -384,7 +384,7 @@ func loginErr() *LoginReturn { } func sessionSuccess(renewTimeUTC, expireTimeUTC time.Time) *SessionReturn { - return &SessionReturn{&loginSession{1, "test@test.com", "sessionHash", renewTimeUTC, expireTimeUTC}, nil} + return &SessionReturn{&loginSession{1, "test@test.com", "fullname", "sessionHash", renewTimeUTC, expireTimeUTC}, nil} } func sessionErr() *SessionReturn { @@ -400,7 +400,7 @@ func rememberErr() *RememberMeReturn { } func sessionRemember(renewTimeUTC, expireTimeUTC time.Time) *SessionRememberReturn { - return &SessionRememberReturn{&loginSession{1, "test@test.com", "sessionHash", renewTimeUTC, expireTimeUTC}, &rememberMeSession{TokenHash: "PEaenWxYddN6Q_NT1PiOYfz4EsZu7jRXRlpAsNpBU-A=", ExpireTimeUTC: expireTimeUTC, RenewTimeUTC: renewTimeUTC}, nil} + return &SessionRememberReturn{&loginSession{1, "test@test.com", "fullname", "sessionHash", renewTimeUTC, expireTimeUTC}, &rememberMeSession{TokenHash: "PEaenWxYddN6Q_NT1PiOYfz4EsZu7jRXRlpAsNpBU-A=", ExpireTimeUTC: expireTimeUTC, RenewTimeUTC: renewTimeUTC}, nil} } func sessionRememberErr() *SessionRememberReturn {