Skip to content

Commit

Permalink
Change GetLogin to Login. Include userid and full name in Redis session.
Browse files Browse the repository at this point in the history
  • Loading branch information
Rob Archibald committed Jan 24, 2017
1 parent 2a94ddd commit 703ddf4
Show file tree
Hide file tree
Showing 10 changed files with 126 additions and 131 deletions.
13 changes: 5 additions & 8 deletions authStore.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,18 +195,15 @@ func (s *authStore) login(email, password string, rememberMe bool) (*loginSessio

// add in check for DDOS attack. Slow down or lock out checks for same account
// or same IP with multiple failed attempts
login, err := s.backend.GetLogin(email, loginProviderDefaultName)
login, err := s.backend.Login(email, password)
if err != nil {
return nil, newLoggedError("Invalid username or password", err)
}

if err := cryptoHashEquals(password, login.ProviderKey); err != nil {
return nil, newLoggedError("Invalid username or password", err)
}
return s.createSession(email, rememberMe)
return s.createSession(email, login.UserID, login.FullName, rememberMe)
}

func (s *authStore) createSession(email string, rememberMe bool) (*loginSession, error) {
func (s *authStore) createSession(email string, userID int, fullname string, rememberMe bool) (*loginSession, error) {
var err error
var selector, token, tokenHash string
if rememberMe {
Expand All @@ -220,7 +217,7 @@ func (s *authStore) createSession(email string, rememberMe bool) (*loginSession,
return nil, newLoggedError("Problem generating sessionId", nil)
}

session, remember, err := s.backend.CreateSession(1, email, sessionHash, time.Now().UTC().Add(sessionRenewDuration), time.Now().UTC().Add(sessionExpireDuration), rememberMe, selector, tokenHash, time.Now().UTC().Add(rememberMeRenewDuration), time.Now().UTC().Add(rememberMeExpireDuration))
session, remember, err := s.backend.CreateSession(userID, email, fullname, sessionHash, time.Now().UTC().Add(sessionRenewDuration), time.Now().UTC().Add(sessionExpireDuration), rememberMe, selector, tokenHash, time.Now().UTC().Add(rememberMeRenewDuration), time.Now().UTC().Add(rememberMeExpireDuration))
if err != nil {
return nil, newLoggedError("Unable to create new session", err)
}
Expand Down Expand Up @@ -358,7 +355,7 @@ func (s *authStore) createProfile(fullName, organization, password, picturePath
return newLoggedError("Unable to create login", err)
}

_, err = s.createSession(session.Email, false)
_, err = s.createSession(session.Email, session.UserID, fullName, false)
if err != nil {
return err
}
Expand Down
42 changes: 17 additions & 25 deletions authStore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ func TestCreateSession(t *testing.T) {
for i, test := range createSessionTests {
backend := &mockBackend{CreateSessionReturn: test.CreateSessionReturn}
store := getAuthStore(nil, test.SessionCookie, test.RememberMeCookie, test.HasCookieGetError, test.HasCookiePutError, nil, backend)
val, err := store.createSession("[email protected]", test.RememberMe)
val, err := store.createSession("[email protected]", 1, "fullname", test.RememberMe)
methods := store.backend.(*mockBackend).MethodsCalled
if (err == nil && test.ExpectedErr != "" || err != nil && test.ExpectedErr != err.Error()) ||
!collectionEqual(test.MethodsCalled, methods) {
Expand All @@ -320,20 +320,20 @@ func TestAuthGetBasicAuth(t *testing.T) {
}

// Credential error
store = getAuthStore(nil, nil, nil, true, false, nil, &mockBackend{GetUserLoginReturn: loginErr()})
store = getAuthStore(nil, nil, nil, true, false, nil, &mockBackend{LoginReturn: loginErr()})
if _, err := store.GetBasicAuth(); err == nil || err.Error() != "Problem decoding credentials from basic auth" {
t.Error("expected error")
}

// login error
store = getAuthStore(nil, nil, nil, true, false, nil, &mockBackend{GetUserLoginReturn: loginErr(), GetSessionReturn: sessionSuccess(futureTime, futureTime)})
store = getAuthStore(nil, nil, nil, true, false, nil, &mockBackend{LoginReturn: loginErr(), GetSessionReturn: sessionSuccess(futureTime, futureTime)})
store.r = basicAuthRequest("[email protected]", "password")
if _, err := store.GetBasicAuth(); err == nil || err.Error() != "Invalid username or password" {
t.Error("expected error", err)
}

// login success
store = getAuthStore(nil, nil, nil, true, false, nil, &mockBackend{GetUserLoginReturn: loginSuccess(), GetSessionReturn: sessionSuccess(futureTime, futureTime), CreateSessionReturn: sessionRemember(futureTime, futureTime)})
store = getAuthStore(nil, nil, nil, true, false, nil, &mockBackend{LoginReturn: loginSuccess(), GetSessionReturn: sessionSuccess(futureTime, futureTime), CreateSessionReturn: sessionRemember(futureTime, futureTime)})
store.r = basicAuthRequest("[email protected]", "correctPassword")
if _, err := store.GetBasicAuth(); err != nil {
t.Error("expected success")
Expand Down Expand Up @@ -382,9 +382,9 @@ func TestAuthStoreEndToEnd(t *testing.T) {

// create profile
err = s.createProfile("fullName", "company", "password", "picturePath", 1, 1)
hashErr := cryptoHashEquals("password", b.Logins[0].ProviderKey)
hashErr := cryptoHashEquals("password", b.Logins[0].PasswordHash)
if err != nil || len(b.Users) != 1 || len(b.Sessions) != 1 || len(b.Logins) != 1 || b.Logins[0].Email != "[email protected]" || len(b.EmailSessions) != 0 || hashErr != nil {
t.Fatal("expected valid user, login and session", b.Logins[0], b.Logins[0].ProviderKey, hashErr)
t.Fatal("expected valid user, login and session", b.Logins[0], b.Logins[0].PasswordHash, hashErr)
}

// decode session cookie
Expand Down Expand Up @@ -643,7 +643,7 @@ var loginTests = []struct {
Password string
RememberMe bool
CreateSessionReturn *SessionRememberReturn
GetUserLoginReturn *LoginReturn
LoginReturn *LoginReturn
ErrReturn error
MethodsCalled []string
ExpectedResult *rememberMeSession
Expand All @@ -661,34 +661,26 @@ var loginTests = []struct {
ExpectedErr: passwordValidationMessage,
},
{
Scenario: "Can't get login",
Email: "[email protected]",
Password: "validPassword",
GetUserLoginReturn: loginErr(),
MethodsCalled: []string{"GetLogin"},
ExpectedErr: "Invalid username or password",
},
{
Scenario: "Incorrect password",
Email: "[email protected]",
Password: "wrongPassword",
GetUserLoginReturn: &LoginReturn{Login: &userLogin{Email: "[email protected]", ProviderKey: "1234"}},
MethodsCalled: []string{"GetLogin"},
ExpectedErr: "Invalid username or password",
Scenario: "Can't get login",
Email: "[email protected]",
Password: "validPassword",
LoginReturn: loginErr(),
MethodsCalled: []string{"Login"},
ExpectedErr: "Invalid username or password",
},
{
Scenario: "Got session",
Email: "[email protected]",
Password: "correctPassword",
GetUserLoginReturn: loginSuccess(),
LoginReturn: loginSuccess(),
CreateSessionReturn: sessionRemember(futureTime, futureTime),
MethodsCalled: []string{"GetLogin", "CreateSession", "InvalidateSession", "InvalidateRememberMe"},
MethodsCalled: []string{"Login", "CreateSession", "InvalidateSession", "InvalidateRememberMe"},
},
}

func TestAuthLogin(t *testing.T) {
for i, test := range loginTests {
backend := &mockBackend{GetUserLoginReturn: test.GetUserLoginReturn, ErrReturn: test.ErrReturn, CreateSessionReturn: test.CreateSessionReturn}
backend := &mockBackend{LoginReturn: test.LoginReturn, ErrReturn: test.ErrReturn, CreateSessionReturn: test.CreateSessionReturn}
store := getAuthStore(nil, nil, nil, false, false, nil, backend)
val, err := store.login(test.Email, test.Password, test.RememberMe)
methods := store.backend.(*mockBackend).MethodsCalled
Expand Down Expand Up @@ -771,7 +763,7 @@ func TestLoginJson(t *testing.T) {
var buf bytes.Buffer
buf.WriteString(`{"Email":"[email protected]", "Password":"password", "RememberMe":true}`)
r := &http.Request{Body: ioutil.NopCloser(&buf)}
backend := &mockBackend{GetUserLoginReturn: loginErr()}
backend := &mockBackend{LoginReturn: loginErr()}
store := getAuthStore(nil, nil, nil, true, false, nil, backend)
store.r = r
err := store.Login().(*authError).innerError
Expand Down
23 changes: 12 additions & 11 deletions backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ var errInvalidSessionHash = errors.New("DB: Invalid SessionHash")
var errRememberMeSelectorExists = errors.New("DB: RememberMe selector already exists")
var errUserNotFound = errors.New("DB: User not found")
var errLoginNotFound = errors.New("DB: Login not found")
var errInvalidCredentials = errors.New("DB: Invalid Credentials")
var errSessionNotFound = errors.New("DB: Session not found")
var errSessionAlreadyExists = errors.New("DB: Session already exists")
var errRememberMeNotFound = errors.New("DB: RememberMe not found")
Expand All @@ -27,7 +28,7 @@ type backender interface {

// LoginBackender. Write out since it contains duplicate BackendCloser
CreateLogin(userID int, email, passwordHash, fullName, homeDirectory string, uidNumber, gidNumber int, mailQuota, fileQuota string) (*userLogin, error)
GetLogin(email, loginProvider string) (*userLogin, error)
Login(email, password string) (*userLogin, error)
UpdateEmail(email string, password string, newEmail string) (*loginSession, error)
UpdatePassword(email string, oldPassword string, newPassword string) (*loginSession, error)

Expand All @@ -47,7 +48,7 @@ type userBackender interface {

type loginBackender interface {
CreateLogin(userID int, email, passwordHash, fullName, homeDirectory string, uidNumber, gidNumber int, mailQuota, fileQuota string) (*userLogin, error)
GetLogin(email, loginProvider string) (*userLogin, error)
Login(email, password string) (*userLogin, error)
UpdateEmail(email string, password string, newEmail string) (*loginSession, error)
UpdatePassword(email string, oldPassword string, newPassword string) (*loginSession, error)
backendCloser
Expand All @@ -59,7 +60,7 @@ type sessionBackender interface {
UpdateEmailSession(verifyHash string, userID int, email string) error
DeleteEmailSession(verifyHash string) error

CreateSession(userID int, email string, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*loginSession, *rememberMeSession, error)
CreateSession(userID int, email, fullname, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*loginSession, *rememberMeSession, error)
GetSession(sessionHash string) (*loginSession, error)
RenewSession(sessionHash string, renewTimeUTC time.Time) (*loginSession, error)
InvalidateSession(sessionHash string) error
Expand All @@ -86,15 +87,15 @@ type user struct {
}

type userLogin struct {
UserID int
Email string
LoginProviderID int
ProviderKey string
UserID int
Email string
FullName string
}

type loginSession struct {
UserID int
Email string
FullName string
SessionHash string
RenewTimeUTC time.Time
ExpireTimeUTC time.Time
Expand Down Expand Up @@ -159,12 +160,12 @@ type backend struct {
backendCloser
}

func (b *backend) GetLogin(email, loginProvider string) (*userLogin, error) {
return b.l.GetLogin(email, loginProvider)
func (b *backend) Login(email, password string) (*userLogin, error) {
return b.l.Login(email, password)
}

func (b *backend) CreateSession(userID int, email, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*loginSession, *rememberMeSession, error) {
return b.s.CreateSession(userID, email, sessionHash, sessionRenewTimeUTC, sessionExpireTimeUTC, rememberMe, rememberMeSelector, rememberMeTokenHash, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC)
func (b *backend) CreateSession(userID int, email, fullname, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*loginSession, *rememberMeSession, error) {
return b.s.CreateSession(userID, email, fullname, sessionHash, sessionRenewTimeUTC, sessionExpireTimeUTC, rememberMe, rememberMeSelector, rememberMeTokenHash, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC)
}

func (b *backend) GetSession(sessionHash string) (*loginSession, error) {
Expand Down
28 changes: 16 additions & 12 deletions backendLDAPLogin.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,29 @@ func newBackendLDAPLogin(server string, port int, bindDn, password, baseDn, user
}

type ldapData struct {
UID []string
UserPassword []string
UIDNumber []string
GIDNumber []string
HomeDirectory []string
UID string
DbUserId string
Cn string
}

func (l *backendLDAPLogin) GetLogin(email, loginProvider string) (*userLogin, error) {
req := ldap.NewSearchRequest(l.baseDn, ldap.ScopeSingleLevel, ldap.NeverDerefAliases, 0, 0, false, fmt.Sprintf(l.userLoginFilter, email), []string{"uid", "userPassword", "uidNumber", "gidNumber", "homeDirectory"}, nil)
func (l *backendLDAPLogin) Login(email, password string) (*userLogin, error) {
// check credentials
err := l.db.Execute(ldap.NewSimpleBindRequest(email, password, nil))
if err != nil {
return nil, err
}
// get login info
req := ldap.NewSearchRequest(l.baseDn, ldap.ScopeSingleLevel, ldap.NeverDerefAliases, 0, 0, false, fmt.Sprintf(l.userLoginFilter, email), []string{"uid", "dbUserId", "cn"}, nil)
data := &ldapData{}
err := l.db.QueryStructRow(req, data)
err = l.db.QueryStructRow(req, data)
if err != nil {
return nil, err
}
var password string
if len(data.UserPassword) != 0 {
password = data.UserPassword[0]
dbUserID, err := strconv.Atoi(data.DbUserId)
if err != nil {
return nil, err
}
return &userLogin{ProviderKey: password}, nil
return &userLogin{UserID: dbUserID, Email: data.UID, FullName: data.Cn}, nil
}

/**************** TODO: create different type of user if not using file and mail quotas **********************/
Expand Down
21 changes: 12 additions & 9 deletions backendLDAPLogin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,34 @@ func TestNewBackendLDAPLogin(t *testing.T) {
t.Fatal("unable to login", err)
}

_, err = l.GetLogin("[email protected]", "")
_, err = l.Login("[email protected]", "")
if err == nil {
t.Fatal("Expected no results", err)
}
}

func TestLdapGetLogin(t *testing.T) {
func TestLdapLogin(t *testing.T) {
// success
data := ldapData{UserPassword: []string{"password"}}
data := ldapData{UID: "email", DbUserId: "1234"}
m := onedb.NewMock(nil, nil, data)
l := backendLDAPLogin{db: m, userLoginFilter: "%s"}
login, err := l.GetLogin("email", "provider")
if err != nil || login.ProviderKey != "password" {
t.Error("expected to find data", login)
login, err := l.Login("email", "password")
if err != nil || login.Email != "email" {
t.Error("expected to find data", login, err)
}

queries := m.QueriesRun()
if _, ok := queries[0].(*ldap.SearchRequest); !ok {
t.Error("expected ldap search request")
if _, ok := queries[0].(*ldap.SimpleBindRequest); !ok {
t.Error("expected ldap bind request first")
}
if _, ok := queries[1].(*ldap.SearchRequest); !ok {
t.Error("expected ldap searc request next")
}

// error
m = onedb.NewMock(nil, nil, nil)
l = backendLDAPLogin{db: m, userLoginFilter: "%s"}
_, err = l.GetLogin("email", "provider")
_, err = l.Login("email", "password")
if err == nil {
t.Error("expected error")
}
Expand Down
Loading

0 comments on commit 703ddf4

Please sign in to comment.