diff --git a/authStore.go b/authStore.go index 8bd8002..cc17082 100644 --- a/authStore.go +++ b/authStore.go @@ -40,7 +40,7 @@ type AuthStorer interface { Login(w http.ResponseWriter, r *http.Request) (*LoginSession, error) Register(w http.ResponseWriter, r *http.Request) error CreateProfile(w http.ResponseWriter, r *http.Request) (*LoginSession, error) - VerifyEmail(w http.ResponseWriter, r *http.Request) (string, string, error) + VerifyEmail(w http.ResponseWriter, r *http.Request) (string, map[string]interface{}, error) CreateSecondaryEmail(w http.ResponseWriter, r *http.Request) error SetPrimaryEmail(w http.ResponseWriter, r *http.Request) error UpdatePassword(w http.ResponseWriter, r *http.Request) error @@ -209,63 +209,49 @@ func (s *authStore) login(w http.ResponseWriter, r *http.Request, email, passwor // add in check for DDOS attack. Slow down or lock out checks for same account // or same IP with multiple failed attempts - login, err := s.backend.Login(email, password) + login, err := s.backend.LoginAndGetUser(email, password) if err != nil { return nil, newLoggedError("Invalid username or password", err) } - return s.createSession(w, r, email, login.UserID, login.FullName, rememberMe) + return s.createSession(w, r, login.UserID, email, login.Info, rememberMe) } func (s *authStore) OAuthLogin(w http.ResponseWriter, r *http.Request) (string, error) { - email, fullname, err := getOAuthCredentials(r) + email, info, err := getOAuthCredentials(r) if err != nil { return "", err } - return s.oauthLogin(w, r, email, fullname) + return s.oauthLogin(w, r, email, info) } -func (s *authStore) oauthLogin(w http.ResponseWriter, r *http.Request, email, fullname string) (string, error) { - var userID string +func (s *authStore) oauthLogin(w http.ResponseWriter, r *http.Request, email string, info map[string]interface{}) (string, error) { user, err := s.backend.GetUser(email) if user == nil || err != nil { - userID, err = s.backend.AddUser(email) - if err != nil { - return "", newLoggedError("Failed to create new user in database", err) - } - - err = s.backend.UpdateUser(userID, fullname, "", "") - if err != nil { - return "", newLoggedError("Unable to update user", err) - } - } else { - userID = user.UserID - } - - if _, err := s.backend.GetLogin(email); err != nil { - _, err = s.backend.CreateLogin(userID, email, "", fullname) + user, err = s.backend.AddUserFull(email, "", info) if err != nil { return "", newLoggedError("Unable to create login", err) } } - session, err := s.createSession(w, r, email, userID, fullname, false) + session, err := s.createSession(w, r, user.UserID, email, info, false) if err != nil { return "", err } return session.CSRFToken, nil } -func getOAuthCredentials(r *http.Request) (string, string, error) { - var fullname, email, email2 string +func getOAuthCredentials(r *http.Request) (string, map[string]interface{}, error) { + var email, email2 string + info := make(map[string]interface{}) authHeader := r.Header.Get("Authorization") if authHeader == "" { - return "", "", fmt.Errorf("No authorization found") + return "", nil, fmt.Errorf("No authorization found") } authHeaderParts := strings.Split(authHeader, " ") if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { - return "", "", fmt.Errorf("Authorization header format must be Bearer {token}") + return "", nil, fmt.Errorf("Authorization header format must be Bearer {token}") } // need to actually parse here and handle error @@ -282,19 +268,19 @@ func getOAuthCredentials(r *http.Request) (string, string, error) { claims, ok := token.Claims.(jwt.MapClaims) if ok { - fullname = fmt.Sprintf("%v", claims["name"]) + info["fullname"] = fmt.Sprintf("%v", claims["name"]) email = fmt.Sprintf("%v", claims["unique_name"]) fmt.Println("unique_name:", email) email2 = fmt.Sprintf("%v", claims["email"]) fmt.Println("email:", email2) - if email == "" || fullname == "" { - return "", "", fmt.Errorf("expected email and fullname") + if email == "" || info["fullname"] == "" { + return "", nil, fmt.Errorf("expected email and fullname") } } - return email, fullname, nil + return email, info, nil } -func (s *authStore) createSession(w http.ResponseWriter, r *http.Request, email, userID, fullname string, rememberMe bool) (*LoginSession, error) { +func (s *authStore) createSession(w http.ResponseWriter, r *http.Request, userID, email string, info map[string]interface{}, rememberMe bool) (*LoginSession, error) { var err error var selector, token, tokenHash string if rememberMe { @@ -313,7 +299,7 @@ func (s *authStore) createSession(w http.ResponseWriter, r *http.Request, email, return nil, newLoggedError("Problem generating csrf token", nil) } - session, err := s.backend.CreateSession(userID, email, fullname, sessionHash, csrfToken, time.Now().UTC().Add(sessionRenewDuration), time.Now().UTC().Add(sessionExpireDuration)) + session, err := s.backend.CreateSession(userID, email, info, sessionHash, csrfToken, time.Now().UTC().Add(sessionRenewDuration), time.Now().UTC().Add(sessionExpireDuration)) if err != nil { return nil, newLoggedError("Unable to create new session", err) } @@ -368,10 +354,10 @@ func (s *authStore) Register(w http.ResponseWriter, r *http.Request) error { if err != nil { return newAuthError("Unable to get email", err) } - return s.register(r, registration.Email, registration.DestinationURL) + return s.register(r, registration.Email, registration.Info) } -func (s *authStore) register(r *http.Request, email, destinationURL string) error { +func (s *authStore) register(r *http.Request, email string, info map[string]interface{}) error { if !isValidEmail(email) { return newAuthError("Invalid email", nil) } @@ -381,7 +367,7 @@ func (s *authStore) register(r *http.Request, email, destinationURL string) erro return newAuthError("User already registered", err) } - verifyCode, err := s.addEmailSession(email, destinationURL) + verifyCode, err := s.addEmailSession(email, info) if err != nil { return newLoggedError("Unable to save user", err) } @@ -406,7 +392,7 @@ func getBaseURL(url string) string { return url[:protoIndex+3+firstSlash] } -func (s *authStore) addEmailSession(email, destinationURL string) (string, error) { +func (s *authStore) addEmailSession(email string, info map[string]interface{}) (string, error) { verifyCode, verifyHash, err := generateStringAndHash() if err != nil { return "", newLoggedError("Problem generating email confirmation code", err) @@ -417,7 +403,7 @@ func (s *authStore) addEmailSession(email, destinationURL string) (string, error return "", newLoggedError("Problem generating csrf token", err) } - err = s.backend.CreateEmailSession(email, verifyHash, csrfToken, destinationURL) + err = s.backend.CreateEmailSession(email, info, verifyHash, csrfToken) if err != nil { return "", newLoggedError("Problem adding user to database", err) } @@ -434,10 +420,10 @@ func (s *authStore) CreateProfile(w http.ResponseWriter, r *http.Request) (*Logi if csrfToken == "" { return nil, errMissingCSRF } - return s.createProfile(w, r, csrfToken, profile.FullName, profile.Organization, profile.Password, profile.PicturePath) + return s.createProfile(w, r, csrfToken, profile.Password, profile.Info) } -func (s *authStore) createProfile(w http.ResponseWriter, r *http.Request, csrfToken, fullName, organization, password, picturePath string) (*LoginSession, error) { +func (s *authStore) createProfile(w http.ResponseWriter, r *http.Request, csrfToken, password string, info map[string]interface{}) (*LoginSession, error) { emailCookie, err := s.getEmailCookie(w, r) if err != nil || emailCookie.EmailVerificationCode == "" { return nil, newLoggedError("Unable to get email verification cookie", err) @@ -456,7 +442,7 @@ func (s *authStore) createProfile(w http.ResponseWriter, r *http.Request, csrfTo return nil, errInvalidCSRF } - err = s.backend.UpdateUser(session.UserID, fullName, organization, picturePath) + err = s.backend.UpdateUser(session.UserID, password, info) if err != nil { return nil, newLoggedError("Unable to update user", err) } @@ -466,12 +452,7 @@ func (s *authStore) createProfile(w http.ResponseWriter, r *http.Request, csrfTo return nil, newLoggedError("Error while creating profile", err) } - _, err = s.backend.CreateLogin(session.UserID, session.Email, password, fullName) - if err != nil { - return nil, newLoggedError("Unable to create login", err) - } - - ls, err := s.createSession(w, r, session.Email, session.UserID, fullName, false) + ls, err := s.createSession(w, r, session.UserID, session.Email, info, false) if err != nil { return nil, err } @@ -481,52 +462,52 @@ func (s *authStore) createProfile(w http.ResponseWriter, r *http.Request, csrfTo } // move to sessionStore -func (s *authStore) VerifyEmail(w http.ResponseWriter, r *http.Request) (string, string, error) { +func (s *authStore) VerifyEmail(w http.ResponseWriter, r *http.Request) (string, map[string]interface{}, error) { verify, err := getVerificationCode(r) if err != nil { - return "", "", newAuthError("Unable to get verification email from JSON", err) + return "", nil, newAuthError("Unable to get verification email from JSON", err) } return s.verifyEmail(w, r, verify.EmailVerificationCode) } -func (s *authStore) verifyEmail(w http.ResponseWriter, r *http.Request, emailVerificationCode string) (string, string, error) { +func (s *authStore) verifyEmail(w http.ResponseWriter, r *http.Request, emailVerificationCode string) (string, map[string]interface{}, error) { if !strings.HasSuffix(emailVerificationCode, "=") { // add back the "=" then decode emailVerificationCode = emailVerificationCode + "=" } emailVerifyHash, err := decodeStringToHash(emailVerificationCode) if err != nil { - return "", "", newLoggedError("Invalid verification code", err) + return "", nil, newLoggedError("Invalid verification code", err) } session, err := s.backend.GetEmailSession(emailVerifyHash) if err != nil { - return "", "", newLoggedError("Failed to verify email", err) + return "", nil, newLoggedError("Failed to verify email", err) } - userID, err := s.backend.AddUser(session.Email) + userID, err := s.backend.AddUser(session.Email, session.Info) if err != nil { user, err := s.backend.GetUser(session.Email) if err != nil { - return "", "", newLoggedError("Failed to get user in database", err) + return "", nil, newLoggedError("Failed to get user in database", err) } userID = user.UserID } err = s.backend.UpdateEmailSession(emailVerifyHash, userID) if err != nil { - return "", "", newLoggedError("Failed to update email session", err) + return "", nil, newLoggedError("Failed to update email session", err) } err = s.saveEmailCookie(w, r, emailVerificationCode, time.Now().UTC().Add(emailExpireDuration)) if err != nil { - return "", "", newLoggedError("Failed to save email cookie", err) + return "", nil, newLoggedError("Failed to save email cookie", err) } err = s.mailer.SendWelcome(session.Email, nil) if err != nil { - return "", "", newLoggedError("Failed to send welcome email", err) + return "", nil, newLoggedError("Failed to send welcome email", err) } - return session.CSRFToken, session.DestinationURL, nil + return session.CSRFToken, session.Info, nil } func (s *authStore) CreateSecondaryEmail(w http.ResponseWriter, r *http.Request) error { @@ -598,8 +579,8 @@ func (s *authStore) saveRememberMeCookie(w http.ResponseWriter, r *http.Request, } type registration struct { - Email string `json:"email"` - DestinationURL string `json:"destinationURL"` + Email string `json:"email"` + Info map[string]interface{} `json:"info"` } func getRegistration(r *http.Request) (*registration, error) { @@ -649,15 +630,14 @@ func generateThumbnail(filename string) (string, error) { } type profile struct { - FullName string - Organization string - Password string - PicturePath string + Password string + Info map[string]interface{} } func getProfile(r *http.Request) (*profile, error) { - profile := &profile{} + profile := &profile{Info: make(map[string]interface{})} r.ParseMultipartForm(32 << 20) // 32 MB file + file, handler, err := r.FormFile("file") if err == nil { // received the file, so save it defer file.Close() @@ -668,14 +648,16 @@ func getProfile(r *http.Request) (*profile, error) { } defer f.Close() io.Copy(f, file) - profile.PicturePath = handler.Filename + profile.Info["filename"] = handler.Filename } - // ************** TODO: change to generic way to get other parameters ******************* - - profile.FullName = r.FormValue("fullName") - profile.Organization = r.FormValue("organization") - profile.Password = r.FormValue("password") + for key := range r.Form { // save form values + if key == "password" { + profile.Password = r.FormValue(key) + } else { + profile.Info[key] = r.FormValue(key) + } + } return profile, nil } diff --git a/authStore_e2e_test.go b/authStore_e2e_test.go index d4e6bf7..18033e0 100644 --- a/authStore_e2e_test.go +++ b/authStore_e2e_test.go @@ -102,7 +102,7 @@ func _register(email string, b *backendMemory, m *TextMailer) (string, error) { // register new user // adds to users, logins and sessions - err := s.register(r, email, "returnURL") + err := s.register(r, email, map[string]interface{}{"key": "value"}) if err != nil { return "", err } @@ -112,8 +112,8 @@ func _register(email string, b *backendMemory, m *TextMailer) (string, error) { // get code from "email" data := m.MessageData.(*sendVerifyParams) emailVerifyHash, _ := decodeStringToHash(data.VerificationCode + "=") - if b.EmailSessions[lenSessions].Email != email || b.EmailSessions[lenSessions].EmailVerifyHash != emailVerifyHash || b.EmailSessions[lenSessions].DestinationURL != "returnURL" { - return "", errors.Errorf("expected to have valid session: %s, %v, %v", b.EmailSessions[lenSessions].Email, b.EmailSessions[lenSessions].EmailVerifyHash != emailVerifyHash, b.EmailSessions[lenSessions].DestinationURL != "returnURL") + if b.EmailSessions[lenSessions].Email != email || b.EmailSessions[lenSessions].EmailVerifyHash != emailVerifyHash || b.EmailSessions[lenSessions].Info == nil || b.EmailSessions[lenSessions].Info["key"] != "value" { + return "", errors.Errorf("expected to have valid session: %s, %v, %v", b.EmailSessions[lenSessions].Email, b.EmailSessions[lenSessions].EmailVerifyHash != emailVerifyHash, b.EmailSessions[lenSessions].Info != nil && b.EmailSessions[lenSessions].Info["key"] != "value") } return data.VerificationCode, nil @@ -129,12 +129,12 @@ func _verify(verifyCode string, b *backendMemory, m *TextMailer) (string, *email emailSession := b.getEmailSessionByEmailVerifyHash(emailVerifyHash) // verify Email. Should 1. add user to b.Users, 2. set UserID in EmailSession, 3. add session - csrfToken, destinationURL, err := s.verifyEmail(nil, r, verifyCode) + csrfToken, info, err := s.verifyEmail(nil, r, verifyCode) if err != nil { return "", nil, err } - if destinationURL != "returnURL" { - return "", nil, errors.Errorf("expected to get back destinationURL that we entered during register phase") + if info == nil || info["key"] != "value" { + return "", nil, errors.Errorf("expected to get back info that we entered during register phase") } if len(b.Users) != +lenUsers+1 || len(b.EmailSessions) != lenEmailSessions { return "", nil, errors.Errorf("expected to add user and update existing session: %v, %v", len(b.Users) != lenUsers+1, len(b.EmailSessions) != lenEmailSessions) @@ -162,32 +162,25 @@ func _createProfile(fullName, password string, emailCookie *emailCookie, b *back emailVerifyHash, _ := decodeStringToHash(emailCookie.EmailVerificationCode) oldEmailSession := b.getEmailSessionByEmailVerifyHash(emailVerifyHash) var user *user - var oldLogin *userLoginMemory if oldEmailSession != nil { user = b.getUserByEmail(oldEmailSession.Email) - oldLogin = b.getLoginByEmail(oldEmailSession.Email) } // create profile - newSession, err := s.createProfile(nil, r, csrfToken, fullName, "company", password, "picturePath") + newSession, err := s.createProfile(nil, r, csrfToken, password, map[string]interface{}{"myKey": "value"}) if err != nil { return "", nil, err } // check password was saved correctly h := &hashStore{} - hashErr := h.HashEquals(password, b.Logins[0].PasswordHash) - if hashErr != nil { + passwordHash, err := h.Hash(password) + if err != nil { return "", nil, err } // check user was saved correctly - if user == nil || user.FullName != fullName || oldEmailSession == nil || user.PrimaryEmail != oldEmailSession.Email || user.UserID != oldEmailSession.UserID { + if user == nil || oldEmailSession == nil || user.PrimaryEmail != oldEmailSession.Email || user.UserID != oldEmailSession.UserID || user.PasswordHash != passwordHash || user.Info == nil || user.Info["myKey"] != "value" { return "", nil, errors.Errorf("expected user to be updated with expected values: %v, %v", user, oldEmailSession) } - // check login was saved correctly - login := b.getLoginByEmail(oldEmailSession.Email) - if oldLogin != nil || login == nil || login.UserID != user.UserID || login.FullName != fullName { - return "", nil, errors.Errorf("expected new login to be created with expected values: %v, %v", oldLogin, login) - } // verify email session was deleted if emailSession := b.getEmailSessionByEmailVerifyHash(emailVerifyHash); emailSession != nil { return "", nil, errors.Errorf("expected Email session to be deleted: %v", emailSession) @@ -197,7 +190,7 @@ func _createProfile(fullName, password string, emailCookie *emailCookie, b *back sessionCookie := c.cookies["Session"].(*sessionCookie) sessionHash, _ := decodeStringToHash(sessionCookie.SessionID) session := b.getSessionByHash(sessionHash) - if session == nil || session.SessionHash != sessionHash || session.Email != oldEmailSession.Email || session.UserID != oldEmailSession.UserID || session.FullName != fullName { + if session == nil || session.SessionHash != sessionHash || session.Email != oldEmailSession.Email || session.UserID != oldEmailSession.UserID || session.Info == nil || session.Info["myKey"] != "value" { return "", nil, errors.Errorf("expected session to be created, %v", session) } return newSession.CSRFToken, sessionCookie, nil @@ -207,7 +200,6 @@ func _login(email, password string, remember bool, clientSessionCookie *sessionC r := &http.Request{Header: http.Header{}} c := NewMockCookieStore(map[string]interface{}{"Session": clientSessionCookie, "RememberMe": rememberCookie}, false, false) s := &authStore{b, m, c} - lenLogins := len(b.Logins) lenUsers := len(b.Users) // login @@ -217,12 +209,12 @@ func _login(email, password string, remember bool, clientSessionCookie *sessionC } // verify session is valid user := b.getUserByEmail(email) - if session == nil || session.Email != email || session.FullName != user.FullName || session.UserID != user.UserID { + if session == nil || session.Email != email || session.UserID != user.UserID || session.Info == nil || session.Info["myKey"] != "value" { return "", nil, nil, errors.Errorf("session wasn't created correctly, %v", session) } - // verify no logins, or users were created - if lenLogins != len(b.Logins) || lenUsers != len(b.Users) { - return "", nil, nil, errors.Errorf("expected no new users or logins to be created, %v, %v", lenLogins != len(b.Logins), lenUsers != len(b.Users)) + // verify no users were created + if lenUsers != len(b.Users) { + return "", nil, nil, errors.Errorf("expected no new users to be created, %v", lenUsers != len(b.Users)) } // verify old session and old remember me were deleted if clientSessionCookie != nil { @@ -243,7 +235,7 @@ func _login(email, password string, remember bool, clientSessionCookie *sessionC newSessionCookie := c.cookies["Session"].(*sessionCookie) sessionHash, _ := decodeStringToHash(newSessionCookie.SessionID) newSession := b.getSessionByHash(sessionHash) - if newSession == nil || newSession.SessionHash != sessionHash || newSession.Email != session.Email || newSession.UserID != session.UserID || newSession.FullName != session.FullName { + if newSession == nil || newSession.SessionHash != sessionHash || newSession.Email != session.Email || newSession.UserID != session.UserID || newSession.Info == nil || newSession.Info["myKey"] != "value" { return "", nil, nil, errors.Errorf("expected session to be created in database that matches return from function, %v", newSession) } // verify rememberMe cookie diff --git a/authStore_test.go b/authStore_test.go index 4c00e5d..588f5e5 100644 --- a/authStore_test.go +++ b/authStore_test.go @@ -345,7 +345,7 @@ func TestCreateSession(t *testing.T) { for i, test := range createSessionTests { backend := &mockBackend{CreateSessionReturn: test.CreateSessionReturn, CreateRememberMeReturn: test.CreateRememberMeReturn} store := getAuthStore(nil, test.SessionCookie, test.RememberMeCookie, test.HasCookieGetError, test.HasCookiePutError, nil, backend) - val, err := store.createSession(nil, &http.Request{}, "test@test.com", "1", "fullname", test.RememberMe) + val, err := store.createSession(nil, &http.Request{}, "test@test.com", "1", map[string]interface{}{"key": "value"}, test.RememberMe) methods := store.backend.(*mockBackend).MethodsCalled if (err == nil && test.ExpectedErr != "" || err != nil && test.ExpectedErr != err.Error()) || !collectionEqual(test.MethodsCalled, methods) { @@ -364,19 +364,19 @@ func TestAuthGetBasicAuth(t *testing.T) { } // Credential error - store = getAuthStore(nil, nil, nil, true, false, nil, &mockBackend{LoginReturn: loginErr()}) + store = getAuthStore(nil, nil, nil, true, false, nil, &mockBackend{}) if _, err := store.GetBasicAuth(nil, &http.Request{}); err == nil || err.Error() != "Problem decoding credentials from basic auth" { t.Error("expected error") } // login error - store = getAuthStore(nil, nil, nil, true, false, nil, &mockBackend{LoginReturn: loginErr(), GetSessionReturn: sessionSuccess(futureTime, futureTime)}) + store = getAuthStore(nil, nil, nil, true, false, nil, &mockBackend{LoginAndGetUserReturn: userErr(), GetSessionReturn: sessionSuccess(futureTime, futureTime)}) if _, err := store.GetBasicAuth(nil, basicAuthRequest("test@test.com", "password")); err == nil || err.Error() != "Invalid username or password" { t.Error("expected error", err) } // login success - store = getAuthStore(nil, nil, nil, true, false, nil, &mockBackend{LoginReturn: loginSuccess(), GetSessionReturn: sessionSuccess(futureTime, futureTime), CreateSessionReturn: sessionSuccess(futureTime, futureTime)}) + store = getAuthStore(nil, nil, nil, true, false, nil, &mockBackend{LoginAndGetUserReturn: userSuccess(), GetSessionReturn: sessionSuccess(futureTime, futureTime), CreateSessionReturn: sessionSuccess(futureTime, futureTime)}) if _, err := store.GetBasicAuth(nil, basicAuthRequest("test@test.com", "correctPassword")); err != nil { t.Error("expected success") } @@ -394,7 +394,7 @@ var registerTests = []struct { Scenario string Email string CreateEmailSessionReturn error - GetUserReturn *GetUserReturn + GetUserReturn *UserReturn MailErr error MethodsCalled []string ExpectedErr string @@ -407,7 +407,7 @@ var registerTests = []struct { { Scenario: "User Already Exists", Email: "validemail@test.com", - GetUserReturn: getUserSuccess(), + GetUserReturn: userSuccess(), MethodsCalled: []string{"GetUser"}, ExpectedErr: "User already registered", }, @@ -415,13 +415,13 @@ var registerTests = []struct { Scenario: "Add User error", Email: "validemail@test.com", CreateEmailSessionReturn: errors.New("failed"), - GetUserReturn: getUserErr(), + GetUserReturn: userErr(), MethodsCalled: []string{"GetUser", "CreateEmailSession"}, ExpectedErr: "Unable to save user", }, { Scenario: "Can't send email", - GetUserReturn: getUserErr(), + GetUserReturn: userErr(), Email: "validemail@test.com", MailErr: errors.New("fail"), MethodsCalled: []string{"GetUser", "CreateEmailSession"}, @@ -429,7 +429,7 @@ var registerTests = []struct { }, { Scenario: "Send verify email", - GetUserReturn: getUserErr(), + GetUserReturn: userErr(), Email: "validemail@test.com", MethodsCalled: []string{"GetUser", "CreateEmailSession"}, }, @@ -439,7 +439,7 @@ func TestAuthRegister(t *testing.T) { for i, test := range registerTests { backend := &mockBackend{ErrReturn: test.CreateEmailSessionReturn, GetUserReturn: test.GetUserReturn} store := getAuthStore(nil, nil, nil, false, false, test.MailErr, backend) - err := store.register(&http.Request{}, test.Email, "destinationURL") + err := store.register(&http.Request{}, test.Email, map[string]interface{}{"key": "value"}) methods := store.backend.(*mockBackend).MethodsCalled if (err == nil && test.ExpectedErr != "" || err != nil && test.ExpectedErr != err.Error()) || !collectionEqual(test.MethodsCalled, methods) { @@ -452,10 +452,10 @@ var createProfileTests = []struct { Scenario string HasCookieGetError bool HasCookiePutError bool - getEmailSessionReturn *getEmailSessionReturn + getEmailSessionReturn *EmailSessionReturn EmailCookie *emailCookie CSRFToken string - LoginReturn *LoginReturn + LoginReturn *UserReturn UpdateUserErr error DeleteEmailSessionErr error CreateSessionReturn *SessionReturn @@ -493,7 +493,7 @@ var createProfileTests = []struct { EmailCookie: &emailCookie{EmailVerificationCode: "nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0=", ExpireTimeUTC: time.Now()}, getEmailSessionReturn: getEmailSessionSuccess(), UpdateUserErr: errors.New("failed"), - LoginReturn: loginErr(), + LoginReturn: userErr(), MethodsCalled: []string{"GetEmailSession", "UpdateUser"}, ExpectedErr: "Unable to update user", }, @@ -503,27 +503,18 @@ var createProfileTests = []struct { EmailCookie: &emailCookie{EmailVerificationCode: "nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0=", ExpireTimeUTC: time.Now()}, getEmailSessionReturn: getEmailSessionSuccess(), DeleteEmailSessionErr: errors.New("failed"), - LoginReturn: loginErr(), + LoginReturn: userErr(), MethodsCalled: []string{"GetEmailSession", "UpdateUser", "DeleteEmailSession"}, ExpectedErr: "Error while creating profile", }, - { - Scenario: "Error Creating login", - CSRFToken: "csrfToken", - EmailCookie: &emailCookie{EmailVerificationCode: "nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0=", ExpireTimeUTC: time.Now()}, - getEmailSessionReturn: getEmailSessionSuccess(), - LoginReturn: loginErr(), - MethodsCalled: []string{"GetEmailSession", "UpdateUser", "DeleteEmailSession", "CreateLogin"}, - ExpectedErr: "Unable to create login", - }, { Scenario: "Error creating session", CSRFToken: "csrfToken", EmailCookie: &emailCookie{EmailVerificationCode: "nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0=", ExpireTimeUTC: time.Now()}, getEmailSessionReturn: getEmailSessionSuccess(), - LoginReturn: loginSuccess(), + LoginReturn: userSuccess(), CreateSessionReturn: sessionErr(), - MethodsCalled: []string{"GetEmailSession", "UpdateUser", "DeleteEmailSession", "CreateLogin", "CreateSession"}, + MethodsCalled: []string{"GetEmailSession", "UpdateUser", "DeleteEmailSession", "CreateSession"}, ExpectedErr: "Unable to create new session", }, { @@ -531,17 +522,17 @@ var createProfileTests = []struct { CSRFToken: "csrfToken", EmailCookie: &emailCookie{EmailVerificationCode: "nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0=", ExpireTimeUTC: time.Now()}, getEmailSessionReturn: getEmailSessionSuccess(), - LoginReturn: loginSuccess(), + LoginReturn: userSuccess(), CreateSessionReturn: sessionSuccess(futureTime, futureTime), - MethodsCalled: []string{"GetEmailSession", "UpdateUser", "DeleteEmailSession", "CreateLogin", "CreateSession"}, + MethodsCalled: []string{"GetEmailSession", "UpdateUser", "DeleteEmailSession", "CreateSession"}, }, } func TestAuthCreateProfile(t *testing.T) { for i, test := range createProfileTests { - backend := &mockBackend{ErrReturn: test.UpdateUserErr, getEmailSessionReturn: test.getEmailSessionReturn, CreateLoginReturn: test.LoginReturn, CreateSessionReturn: test.CreateSessionReturn, DeleteEmailSessionErr: test.DeleteEmailSessionErr} + backend := &mockBackend{UpdateUserErr: test.UpdateUserErr, getEmailSessionReturn: test.getEmailSessionReturn, CreateSessionReturn: test.CreateSessionReturn, DeleteEmailSessionErr: test.DeleteEmailSessionErr} store := getAuthStore(test.EmailCookie, nil, nil, test.HasCookieGetError, test.HasCookiePutError, nil, backend) - _, err := store.createProfile(nil, &http.Request{}, test.CSRFToken, "name", "organization", "password", "path") + _, err := store.createProfile(nil, &http.Request{}, test.CSRFToken, "password", map[string]interface{}{"key": "value"}) methods := store.backend.(*mockBackend).MethodsCalled if (err == nil && test.ExpectedErr != "" || err != nil && test.ExpectedErr != err.Error()) || !collectionEqual(test.MethodsCalled, methods) { @@ -554,13 +545,13 @@ var verifyEmailTests = []struct { Scenario string EmailVerificationCode string HasCookiePutError bool - getEmailSessionReturn *getEmailSessionReturn + getEmailSessionReturn *EmailSessionReturn AddUserErr error UpdateEmailSessionErr error MailErr error MethodsCalled []string ExpectedErr string - DestinatinURL string + InfoValue string }{ { Scenario: "Decode error", @@ -611,7 +602,7 @@ var verifyEmailTests = []struct { Scenario: "Email sent", EmailVerificationCode: "nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0", getEmailSessionReturn: getEmailSessionSuccess(), - DestinatinURL: "destinationURL", + InfoValue: "value", MethodsCalled: []string{"GetEmailSession", "AddUser", "UpdateEmailSession"}, }, } @@ -620,26 +611,26 @@ func TestAuthVerifyEmail(t *testing.T) { for i, test := range verifyEmailTests { backend := &mockBackend{getEmailSessionReturn: test.getEmailSessionReturn, AddUserErr: test.AddUserErr, UpdateEmailSessionErr: test.UpdateEmailSessionErr} store := getAuthStore(nil, nil, nil, false, test.HasCookiePutError, test.MailErr, backend) - _, destinationURL, err := store.verifyEmail(nil, &http.Request{}, test.EmailVerificationCode) + _, info, err := store.verifyEmail(nil, &http.Request{}, test.EmailVerificationCode) methods := store.backend.(*mockBackend).MethodsCalled if (err == nil && test.ExpectedErr != "" || err != nil && test.ExpectedErr != err.Error()) || - !collectionEqual(test.MethodsCalled, methods) || test.DestinatinURL != destinationURL { - t.Errorf("Scenario[%d] failed: %s\nexpected err:%v\tactual err:%v\nexpected methods: %s\tactual methods: %s, destination: %s", i, test.Scenario, test.ExpectedErr, err, test.MethodsCalled, methods, destinationURL) + !collectionEqual(test.MethodsCalled, methods) || test.InfoValue != "" && (info == nil || info["key"] != test.InfoValue) { + t.Errorf("Scenario[%d] failed: %s\nexpected err:%v\tactual err:%v\nexpected methods: %s\tactual methods: %s, info: %v", i, test.Scenario, test.ExpectedErr, err, test.MethodsCalled, methods, info) } } } var loginTests = []struct { - Scenario string - Email string - Password string - RememberMe bool - CreateSessionReturn *SessionReturn - LoginReturn *LoginReturn - ErrReturn error - MethodsCalled []string - ExpectedResult *rememberMeSession - ExpectedErr string + Scenario string + Email string + Password string + RememberMe bool + CreateSessionReturn *SessionReturn + LoginAndGetUserReturn *UserReturn + ErrReturn error + MethodsCalled []string + ExpectedResult *rememberMeSession + ExpectedErr string }{ { Scenario: "Invalid email", @@ -653,26 +644,26 @@ var loginTests = []struct { ExpectedErr: passwordValidationMessage, }, { - Scenario: "Can't get login", - Email: "email@example.com", - Password: "validPassword", - LoginReturn: loginErr(), - MethodsCalled: []string{"Login"}, - ExpectedErr: "Invalid username or password", + Scenario: "Can't get login", + Email: "email@example.com", + Password: "validPassword", + LoginAndGetUserReturn: userErr(), + MethodsCalled: []string{"LoginAndGetUser"}, + ExpectedErr: "Invalid username or password", }, { - Scenario: "Got session", - Email: "email@example.com", - Password: "correctPassword", - LoginReturn: loginSuccess(), - CreateSessionReturn: sessionSuccess(futureTime, futureTime), - MethodsCalled: []string{"Login", "CreateSession"}, + Scenario: "Got session", + Email: "email@example.com", + Password: "correctPassword", + LoginAndGetUserReturn: userSuccess(), + CreateSessionReturn: sessionSuccess(futureTime, futureTime), + MethodsCalled: []string{"LoginAndGetUser", "CreateSession"}, }, } func TestAuthLogin(t *testing.T) { for i, test := range loginTests { - backend := &mockBackend{LoginReturn: test.LoginReturn, ErrReturn: test.ErrReturn, CreateSessionReturn: test.CreateSessionReturn} + backend := &mockBackend{LoginAndGetUserReturn: test.LoginAndGetUserReturn, ErrReturn: test.ErrReturn, CreateSessionReturn: test.CreateSessionReturn} store := getAuthStore(nil, nil, nil, false, false, nil, backend) val, err := store.login(nil, &http.Request{}, test.Email, test.Password, test.RememberMe) methods := store.backend.(*mockBackend).MethodsCalled @@ -754,7 +745,7 @@ func TestLoginJson(t *testing.T) { var buf bytes.Buffer buf.WriteString(`{"Email":"test@test.com", "Password":"password", "RememberMe":true}`) r := &http.Request{Body: ioutil.NopCloser(&buf)} - backend := &mockBackend{LoginReturn: loginErr()} + backend := &mockBackend{LoginAndGetUserReturn: userErr()} store := getAuthStore(nil, nil, nil, true, false, nil, backend) _, lErr := store.Login(nil, r) err := lErr.(*AuthError).innerError @@ -787,7 +778,7 @@ func TestGetProfile(t *testing.T) { r, _ := http.NewRequest("PUT", "url", &buf) r.Header.Add("Content-Type", w.FormDataContentType()) profile, err := getProfile(r) - if err != nil || profile == nil || profile.FullName != "name" || profile.Organization != "org" || profile.Password != "pass" { + if err != nil || profile == nil || profile.Password != "pass" || profile.Info == nil || profile.Info["fullName"] != "name" || profile.Info["organization"] != "org" || profile.Info["mailQuota"] != "1" || profile.Info["fileQuota"] != "1" { t.Error("expected correct profile", profile, err) } } diff --git a/backend.go b/backend.go index bb129d2..26ecdfb 100644 --- a/backend.go +++ b/backend.go @@ -1,6 +1,7 @@ package auth import ( + "fmt" "time" "github.com/pkg/errors" @@ -23,19 +24,19 @@ var errUserAlreadyExists = errors.New("DB: User already exists") // Backender interface contains all the methods needed to read and write users, sessions and logins type Backender interface { - // UserBackender. Write out since it contains duplicate BackendCloser - AddUser(email string) (string, error) - GetUser(email string) (*user, error) - UpdateUser(userID, fullname string, company string, pictureURL string) error - CreateSecondaryEmail(userID, secondaryEmail string) error - - // LoginBackender. Write out since it contains duplicate BackendCloser - CreateLogin(userID, email, password, fullName string) (*UserLogin, error) - GetLogin(email string) (*UserLogin, error) - Login(email, password string) (*UserLogin, error) - SetPrimaryEmail(userID, newPrimaryEmail string) error + // duplicates userBackender since we can't have duplicate Close methods + AddUser(email string, info map[string]interface{}) (string, error) + AddUserFull(email, password string, info map[string]interface{}) (*User, error) + GetUser(email string) (*User, error) + UpdateUser(userID, password string, info map[string]interface{}) error + UpdateInfo(userID string, info map[string]interface{}) error UpdatePassword(userID, newPassword string) error + Login(email, password string) error + LoginAndGetUser(email, password string) (*User, error) + AddSecondaryEmail(userID, secondaryEmail string) error + UpdatePrimaryEmail(userID, newPrimaryEmail string) error + sessionBackender } @@ -44,30 +45,28 @@ type backendCloser interface { } // UserBackender interface holds methods for user management -type UserBackender interface { - AddUser(email string) (string, error) - GetUser(email string) (*user, error) - UpdateUser(userID, fullname string, company string, pictureURL string) error - CreateSecondaryEmail(userID, secondaryEmail string) error - backendCloser -} - -type loginBackender interface { - CreateLogin(userID, email, password, fullName string) (*UserLogin, error) - GetLogin(email string) (*UserLogin, error) - Login(email, password string) (*UserLogin, error) - SetPrimaryEmail(userID, newPrimaryEmail string) error +type userBackender interface { + AddUser(email string, info map[string]interface{}) (string, error) + AddUserFull(email, password string, info map[string]interface{}) (*User, error) + GetUser(email string) (*User, error) + UpdateUser(userID, password string, info map[string]interface{}) error + UpdateInfo(userID string, info map[string]interface{}) error UpdatePassword(userID, newPassword string) error + + Login(email, password string) error + LoginAndGetUser(email, password string) (*User, error) + AddSecondaryEmail(userID, secondaryEmail string) error + UpdatePrimaryEmail(userID, newPrimaryEmail string) error backendCloser } type sessionBackender interface { - CreateEmailSession(email, emailVerifyHash, csrfToken, destinationURL string) error + CreateEmailSession(email string, info map[string]interface{}, emailVerifyHash, csrfToken string) error GetEmailSession(verifyHash string) (*emailSession, error) UpdateEmailSession(verifyHash string, userID string) error DeleteEmailSession(verifyHash string) error - CreateSession(userID, email, fullname, sessionHash, csrfToken string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time) (*LoginSession, error) + CreateSession(userID, email string, info map[string]interface{}, sessionHash, csrfToken string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time) (*LoginSession, error) GetSession(sessionHash string) (*LoginSession, error) UpdateSession(sessionHash string, renewTimeUTC, expireTimeUTC time.Time) error DeleteSession(sessionHash string) error @@ -81,38 +80,77 @@ type sessionBackender interface { } type emailSession struct { - UserID string `bson:"userID" json:"userID"` - Email string `bson:"email" json:"email"` - EmailVerifyHash string `bson:"_id" json:"emailVerifyHash"` - CSRFToken string `bson:"csrfToken" json:"csrfToken"` - DestinationURL string `bson:"destinationURL" json:"destinationURL"` + UserID string `bson:"userID" json:"userID"` + Email string `bson:"email" json:"email"` + Info map[string]interface{} `bson:"info" json:"info"` + EmailVerifyHash string `bson:"_id" json:"emailVerifyHash"` + CSRFToken string `bson:"csrfToken" json:"csrfToken"` } type user struct { UserID string - FullName string PrimaryEmail string + PasswordHash string + Info map[string]interface{} LockoutEndTimeUTC *time.Time AccessFailedCount int - Roles []string } -// UserLogin is the struct which holds login information -type UserLogin struct { - UserID string `json:"userID"` - Email string `json:"email"` - FullName string `json:"fullName"` +// User is the struct which holds user information +type User struct { + UserID string `json:"userID"` + Email string `json:"email"` + Info map[string]interface{} `json:"info"` } // LoginSession is the struct which holds session information type LoginSession struct { - UserID string `bson:"userID" json:"userID"` - Email string `bson:"email" json:"email"` - FullName string `bson:"fullName" json:"fullName"` - SessionHash string `bson:"_id" json:"sessionHash"` - CSRFToken string `bson:"csrfToken" json:"csrfToken"` - RenewTimeUTC time.Time `bson:"renewTimeUTC" json:"renewTimeUTC"` - ExpireTimeUTC time.Time `bson:"expireTimeUTC" json:"expireTimeUTC"` + UserID string `bson:"userID" json:"userID"` + Email string `bson:"email" json:"email"` + Info map[string]interface{} `bson:"info" json:"info"` + SessionHash string `bson:"_id" json:"sessionHash"` + CSRFToken string `bson:"csrfToken" json:"csrfToken"` + RenewTimeUTC time.Time `bson:"renewTimeUTC" json:"renewTimeUTC"` + ExpireTimeUTC time.Time `bson:"expireTimeUTC" json:"expireTimeUTC"` +} + +// GetInfo will return the named info as an interface{} +func (l *LoginSession) GetInfo(name string) interface{} { + if l == nil || l.Info == nil { + return nil + } + return l.Info[name] +} + +// GetInfoString will return the named info as a string +func (l *LoginSession) GetInfoString(name string) string { + v := l.GetInfo(name) + if v == nil { + return "" + } + if i, ok := v.(string); ok { + return i + } + return fmt.Sprint(v) +} + +// GetInfoStrings will return the named info as an array of strings +func (l *LoginSession) GetInfoStrings(name string) []string { + if strs, ok := l.GetInfo(name).([]string); ok { + return strs + } + if v, ok := l.GetInfo(name).([]interface{}); ok { + strArr := make([]string, len(v)) + for i, str := range v { + if s, ok := str.(string); ok { + strArr[i] = s + } else { + strArr[i] = fmt.Sprint(str) + } + } + return strArr + } + return nil } type rememberMeSession struct { @@ -169,27 +207,26 @@ func (a *AuthError) Trace() string { } type backend struct { - u UserBackender - l loginBackender + u userBackender s sessionBackender backendCloser } // NewBackend returns a Backender from a UserBackender, LoginBackender and SessionBackender -func NewBackend(u UserBackender, l loginBackender, s sessionBackender) Backender { - return &backend{u: u, l: l, s: s} +func NewBackend(u userBackender, s sessionBackender) Backender { + return &backend{u: u, s: s} } -func (b *backend) GetLogin(email string) (*UserLogin, error) { - return b.l.GetLogin(email) +func (b *backend) Login(email, password string) error { + return b.u.Login(email, password) } -func (b *backend) Login(email, password string) (*UserLogin, error) { - return b.l.Login(email, password) +func (b *backend) LoginAndGetUser(email, password string) (*User, error) { + return b.u.LoginAndGetUser(email, password) } -func (b *backend) CreateSession(userID string, email, fullname, sessionHash, csrfToken string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time) (*LoginSession, error) { - return b.s.CreateSession(userID, email, fullname, sessionHash, csrfToken, sessionRenewTimeUTC, sessionExpireTimeUTC) +func (b *backend) CreateSession(userID, email string, info map[string]interface{}, sessionHash, csrfToken string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time) (*LoginSession, error) { + return b.s.CreateSession(userID, email, info, sessionHash, csrfToken, sessionRenewTimeUTC, sessionExpireTimeUTC) } func (b *backend) GetSession(sessionHash string) (*LoginSession, error) { @@ -212,8 +249,8 @@ func (b *backend) UpdateRememberMe(selector string, renewTimeUTC time.Time) erro return b.s.UpdateRememberMe(selector, renewTimeUTC) } -func (b *backend) CreateEmailSession(email, emailVerifyHash, csrfToken, destinationURL string) error { - return b.s.CreateEmailSession(email, emailVerifyHash, csrfToken, destinationURL) +func (b *backend) CreateEmailSession(email string, info map[string]interface{}, emailVerifyHash, csrfToken string) error { + return b.s.CreateEmailSession(email, info, emailVerifyHash, csrfToken) } func (b *backend) GetEmailSession(emailVerifyHash string) (*emailSession, error) { @@ -228,32 +265,36 @@ func (b *backend) DeleteEmailSession(emailVerifyHash string) error { return b.s.DeleteEmailSession(emailVerifyHash) } -func (b *backend) AddUser(email string) (string, error) { - return b.u.AddUser(email) +func (b *backend) AddUser(email string, info map[string]interface{}) (string, error) { + return b.u.AddUser(email, info) +} + +func (b *backend) AddUserFull(email, password string, info map[string]interface{}) (*User, error) { + return b.u.AddUserFull(email, password, info) } -func (b *backend) GetUser(email string) (*user, error) { +func (b *backend) GetUser(email string) (*User, error) { return b.u.GetUser(email) } -func (b *backend) UpdateUser(userID, fullname string, company string, pictureURL string) error { - return b.u.UpdateUser(userID, fullname, company, pictureURL) +func (b *backend) UpdateUser(userID, password string, info map[string]interface{}) error { + return b.u.UpdateUser(userID, password, info) } -func (b *backend) CreateLogin(userID, email, password, fullName string) (*UserLogin, error) { - return b.l.CreateLogin(userID, email, password, fullName) +func (b *backend) UpdateInfo(userID string, info map[string]interface{}) error { + return b.u.UpdateInfo(userID, info) } -func (b *backend) CreateSecondaryEmail(userID string, secondaryEmail string) error { - return b.u.CreateSecondaryEmail(userID, secondaryEmail) +func (b *backend) AddSecondaryEmail(userID string, secondaryEmail string) error { + return b.u.AddSecondaryEmail(userID, secondaryEmail) } -func (b *backend) SetPrimaryEmail(userID, secondaryEmail string) error { - return b.l.SetPrimaryEmail(userID, secondaryEmail) +func (b *backend) UpdatePrimaryEmail(userID, secondaryEmail string) error { + return b.u.UpdatePrimaryEmail(userID, secondaryEmail) } func (b *backend) UpdatePassword(userID, password string) error { - return b.l.UpdatePassword(userID, password) + return b.u.UpdatePassword(userID, password) } func (b *backend) DeleteSession(sessionHash string) error { @@ -272,8 +313,5 @@ func (b *backend) Close() error { if err := b.s.Close(); err != nil { return err } - if err := b.u.Close(); err != nil { - return err - } - return b.l.Close() + return b.u.Close() } diff --git a/backendDbUser.go b/backendDbUser.go deleted file mode 100644 index 4b3f502..0000000 --- a/backendDbUser.go +++ /dev/null @@ -1,56 +0,0 @@ -package auth - -import ( - "strconv" - - "github.com/EndFirstCorp/onedb" - "github.com/pkg/errors" -) - -type backendDbUser struct { - Db onedb.DBer - - AddUserQuery string - GetUserQuery string - UpdateUserQuery string - CreateSecondaryEmailQuery string -} - -// NewBackendDbUser creates a Postgres-based UserBackender -func NewBackendDbUser(server string, port int, username, password, database string, addUserQuery, getUserQuery, updateUserQuery, createSecondaryEmailQuery string) (UserBackender, error) { - db, err := onedb.NewPgx(server, uint16(port), username, password, database) - if err != nil { - return nil, err - } - return &backendDbUser{Db: db, - GetUserQuery: getUserQuery, - AddUserQuery: addUserQuery, - UpdateUserQuery: updateUserQuery, - CreateSecondaryEmailQuery: createSecondaryEmailQuery}, nil -} - -func (u *backendDbUser) AddUser(email string) (string, error) { - var userID int32 = -1 - return strconv.Itoa(int(userID)), u.Db.QueryValues(onedb.NewSqlQuery(u.AddUserQuery, email), &userID) -} - -func (u *backendDbUser) GetUser(email string) (*user, error) { - r := &user{} - err := u.Db.QueryStructRow(onedb.NewSqlQuery(u.GetUserQuery, email), r) - if err != nil { - return nil, errors.New("Unable to get user: " + err.Error()) - } - return r, err -} - -func (u *backendDbUser) UpdateUser(userID string, fullname string, company string, pictureURL string) error { - return u.Db.Execute(onedb.NewSqlQuery(u.UpdateUserQuery, userID, fullname)) -} - -func (u *backendDbUser) CreateSecondaryEmail(userID, secondaryEmail string) error { - return u.Db.Execute(onedb.NewSqlQuery(u.CreateSecondaryEmailQuery, userID, secondaryEmail)) -} - -func (u *backendDbUser) Close() error { - return nil -} diff --git a/backendDbUser_test.go b/backendDbUser_test.go deleted file mode 100644 index b7614a2..0000000 --- a/backendDbUser_test.go +++ /dev/null @@ -1,12 +0,0 @@ -package auth - -import ( - "testing" -) - -func TestNewBackendDbUser(t *testing.T) { - if testing.Short() { - t.SkipNow() - } - -} diff --git a/backendLDAPLogin.go b/backendLDAPLogin.go deleted file mode 100644 index b6cf79d..0000000 --- a/backendLDAPLogin.go +++ /dev/null @@ -1,92 +0,0 @@ -package auth - -import ( - "fmt" - - "github.com/EndFirstCorp/onedb" - "gopkg.in/ldap.v2" -) - -type backendLDAPLogin struct { - db onedb.DBer - baseDn string - userLoginFilter string -} - -// NewBackendLDAPLogin creates a LoginBackender using OpenLDAP -func NewBackendLDAPLogin(server string, port int, bindDn, password, baseDn, userLoginFilter string) (loginBackender, error) { - db, err := onedb.NewLdap(server, port, bindDn, password) - if err != nil { - return nil, err - } - return &backendLDAPLogin{db, baseDn, userLoginFilter}, nil -} - -type ldapData struct { - UID string - DbUserId string - Cn string -} - -func (l *backendLDAPLogin) Login(email, password string) (*UserLogin, error) { - // check credentials - err := l.db.Execute(ldap.NewSimpleBindRequest(fmt.Sprintf("uid=%s,%s", email, l.baseDn), password, nil)) - if err != nil { - return nil, err - } - return l.GetLogin(email) -} - -func (l *backendLDAPLogin) GetLogin(email string) (*UserLogin, error) { - // get login info - req := ldap.NewSearchRequest(l.baseDn, ldap.ScopeSingleLevel, ldap.NeverDerefAliases, 0, 0, false, fmt.Sprintf(l.userLoginFilter, email), []string{"uid", "dbUserId", "cn"}, nil) - data := &ldapData{} - err := l.db.QueryStructRow(req, data) - if err != nil { - return nil, err - } - return &UserLogin{UserID: data.DbUserId, Email: data.UID, FullName: data.Cn}, nil -} - -/**************** TODO: create different type of user if not using file and mail quotas **********************/ -func (l *backendLDAPLogin) CreateLogin(userID, email, password, fullName string) (*UserLogin, error) { - req := ldap.NewAddRequest("uid=" + email + "," + l.baseDn) - req.Attribute("objectClass", []string{"endfirstAccount"}) - req.Attribute("uid", []string{email}) - req.Attribute("dbUserId", []string{userID}) - req.Attribute("cn", []string{fullName}) - req.Attribute("userPassword", []string{password}) - err := l.db.Execute(req) - return &UserLogin{}, err -} - -/*func (l *backendLDAPLogin) CreateSubscriber(userID int, email, password, fullName, homeDirectory string, uidNumber, gidNumber int, mailQuota, fileQuota string) (*UserLogin, error) { - req := ldap.NewAddRequest("uid=" + email + "," + l.baseDn) - req.Attribute("objectClass", []string{"endfirstAccount", "endfirstSubscriber"}) - req.Attribute("uid", []string{email}) - req.Attribute("dbUserId", []string{strconv.Itoa(userID)}) - req.Attribute("cn", []string{fullName}) - req.Attribute("userPassword", []string{password}) - req.Attribute("uidNumber", []string{strconv.Itoa(uidNumber)}) - req.Attribute("gidNumber", []string{strconv.Itoa(gidNumber)}) - req.Attribute("mailFolder", []string{homeDirectory}) - req.Attribute("mailQuota", []string{mailQuota}) - req.Attribute("fileQuota", []string{fileQuota}) - err := l.db.Execute(req) - return &UserLogin{}, err -}*/ - -func (l *backendLDAPLogin) CreateSecondaryEmail(userID, secondaryEmail string) error { - return nil -} - -func (l *backendLDAPLogin) SetPrimaryEmail(userID, newPrimaryEmail string) error { - return nil -} -func (l *backendLDAPLogin) UpdatePassword(userID, newPassword string) error { - return nil -} - -func (l *backendLDAPLogin) Close() error { - return l.db.Close() -} diff --git a/backendLDAPLogin_test.go b/backendLDAPLogin_test.go deleted file mode 100644 index 84129c0..0000000 --- a/backendLDAPLogin_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package auth - -import ( - "testing" - - "github.com/EndFirstCorp/onedb" - "github.com/pkg/errors" - "gopkg.in/ldap.v2" -) - -func TestNewBackendLDAPLogin(t *testing.T) { - if testing.Short() { - t.SkipNow() - } - - ldapServer := "ldap" - ldapPort := 389 - ldapBindDn := "uid=admin,ou=SystemAccounts,dc=example,dc=com" - ldapPassword := "secret" - ldapBaseDn := "ou=Users,dc=example,dc=com" - ldapUserFilter := "(&(objectClass=account)(uid=%s))" - - l, err := NewBackendLDAPLogin(ldapServer, ldapPort, ldapBindDn, ldapPassword, ldapBaseDn, ldapUserFilter) - if err != nil { - t.Fatal("unable to login", err) - } - - _, err = l.Login("test@test.com", "") - if err == nil { - t.Fatal("Expected no results", err) - } -} - -func TestLdapLogin(t *testing.T) { - // success - data := ldapData{UID: "email", DbUserId: "1234"} - m := onedb.NewMock(nil, nil, data) - l := backendLDAPLogin{db: m, userLoginFilter: "%s"} - login, err := l.Login("email", "password") - if err != nil || login.Email != "email" { - t.Error("expected to find data", login, err) - } - - queries := m.QueriesRun() - if _, ok := queries[0].(*ldap.SimpleBindRequest); !ok { - t.Error("expected ldap bind request first") - } - if _, ok := queries[1].(*ldap.SearchRequest); !ok { - t.Error("expected ldap searc request next") - } - - // error - m = onedb.NewMock(nil, nil, nil) - l = backendLDAPLogin{db: m, userLoginFilter: "%s"} - _, err = l.Login("email", "password") - if err == nil { - t.Error("expected error") - } -} - -// replace with test that does something when code does something -func TestLdapCreateSecondaryEmail(t *testing.T) { - m := onedb.NewMock(nil, nil, nil) - l := backendLDAPLogin{db: m} - l.CreateSecondaryEmail("userID", "secondaryEmail") -} - -func TestLdapSetPrimaryEmail(t *testing.T) { - m := onedb.NewMock(nil, nil, nil) - l := backendLDAPLogin{db: m} - l.SetPrimaryEmail("userID", "newPrimaryEmail") -} - -// replace with test that does something when code does something -func TestLdapUpdatePassword(t *testing.T) { - m := onedb.NewMock(nil, nil, nil) - l := backendLDAPLogin{db: m} - l.UpdatePassword("userID", "newPassword") -} - -func TestLdapClose(t *testing.T) { - m := onedb.NewMock(errors.New("failed"), nil, nil) - l := backendLDAPLogin{db: m} - if err := l.Close(); err == nil { - t.Error("expected close to error out") - } -} diff --git a/backendMemory.go b/backendMemory.go index 076c3c4..e1ff3dd 100644 --- a/backendMemory.go +++ b/backendMemory.go @@ -7,18 +7,10 @@ import ( "time" ) -type userLoginMemory struct { - UserID string - Email string - FullName string - PasswordHash string -} - type backendMemory struct { Backender EmailSessions []*emailSession Users []*user - Logins []*userLoginMemory Sessions []*LoginSession RememberMes []*rememberMeSession LoginProviders []*loginProvider @@ -34,32 +26,29 @@ func NewBackendMemory(c Crypter) Backender { return &backendMemory{c: c, LoginProviders: []*loginProvider{&loginProvider{LoginProviderID: 1, Name: loginProviderDefaultName}}} } -func (m *backendMemory) GetLogin(email string) (*UserLogin, error) { - login := m.getLoginByEmail(email) - if login == nil { - return nil, errLoginNotFound - } - return &UserLogin{login.UserID, login.Email, login.FullName}, nil +func (m *backendMemory) Login(email, password string) error { + _, err := m.LoginAndGetUser(email, password) + return err } -func (m *backendMemory) Login(email, password string) (*UserLogin, error) { - login := m.getLoginByEmail(email) - if login == nil { - return nil, errLoginNotFound +func (m *backendMemory) LoginAndGetUser(email, password string) (*User, error) { + user := m.getUserByEmail(email) + if user == nil { + return nil, errUserNotFound } - if err := m.c.HashEquals(password, login.PasswordHash); err != nil { + if err := m.c.HashEquals(password, user.PasswordHash); err != nil { return nil, err } - return &UserLogin{login.UserID, login.Email, login.FullName}, nil + return &User{user.UserID, user.PrimaryEmail, user.Info}, nil } -func (m *backendMemory) CreateSession(userID, email, fullname, sessionHash, csrfToken string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time) (*LoginSession, error) { +func (m *backendMemory) CreateSession(userID, email string, info map[string]interface{}, sessionHash, csrfToken string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time) (*LoginSession, error) { session := m.getSessionByHash(sessionHash) if session != nil { return nil, errSessionAlreadyExists } - session = &LoginSession{userID, email, fullname, sessionHash, csrfToken, sessionRenewTimeUTC, sessionExpireTimeUTC} + session = &LoginSession{userID, email, info, sessionHash, csrfToken, sessionRenewTimeUTC, sessionExpireTimeUTC} m.Sessions = append(m.Sessions, session) return session, nil } @@ -110,7 +99,7 @@ func (m *backendMemory) UpdateRememberMe(selector string, renewTimeUTC time.Time return nil } -func (m *backendMemory) CreateEmailSession(email, emailVerifyHash, csrfToken, destinationURL string) error { +func (m *backendMemory) CreateEmailSession(email string, info map[string]interface{}, emailVerifyHash, csrfToken string) error { if m.getUserByEmail(email) != nil { return errUserAlreadyExists } @@ -118,7 +107,7 @@ func (m *backendMemory) CreateEmailSession(email, emailVerifyHash, csrfToken, de return errEmailVerifyHashExists } - m.EmailSessions = append(m.EmailSessions, &emailSession{"", email, emailVerifyHash, csrfToken, destinationURL}) + m.EmailSessions = append(m.EmailSessions, &emailSession{"", email, info, emailVerifyHash, csrfToken}) return nil } @@ -147,53 +136,89 @@ func (m *backendMemory) DeleteEmailSession(emailVerifyHash string) error { return nil } -func (m *backendMemory) AddUser(email string) (string, error) { +func (m *backendMemory) AddUser(email string, info map[string]interface{}) (string, error) { u := m.getUserByEmail(email) if u != nil { return "", errUserAlreadyExists } m.LastUserID++ - m.Users = append(m.Users, &user{strconv.Itoa(m.LastUserID), "", email, nil, 0, nil}) + m.Users = append(m.Users, &user{strconv.Itoa(m.LastUserID), email, "", info, nil, 0}) return strconv.Itoa(m.LastUserID), nil } -func (m *backendMemory) GetUser(email string) (*user, error) { +func (m *backendMemory) AddUserFull(email, password string, info map[string]interface{}) (*User, error) { + passwordHash, err := m.c.Hash(password) + if err != nil { + return nil, err + } + u := m.getUserByEmail(email) + if u != nil { + return nil, errUserAlreadyExists + } + m.LastUserID++ + m.Users = append(m.Users, &user{strconv.Itoa(m.LastUserID), email, passwordHash, info, nil, 0}) + return &User{u.UserID, u.PrimaryEmail, u.Info}, nil +} + +func (m *backendMemory) GetUser(email string) (*User, error) { u := m.getUserByEmail(email) if u == nil { return nil, errUserNotFound } - return u, nil + return &User{u.UserID, u.PrimaryEmail, u.Info}, nil +} + +func (m *backendMemory) UpdateUser(userID, password string, info map[string]interface{}) error { + passwordHash, err := m.c.Hash(password) + if err != nil { + return err + } + user := m.getUserByID(userID) + if user == nil { + return errUserNotFound + } + if user.Info == nil { + user.Info = make(map[string]interface{}) + } + for key := range info { + user.Info[key] = info[key] + } + user.PasswordHash = passwordHash + return nil } -func (m *backendMemory) UpdateUser(userID, fullname string, company string, pictureURL string) error { +func (m *backendMemory) UpdateInfo(userID string, info map[string]interface{}) error { user := m.getUserByID(userID) if user == nil { return errUserNotFound } - user.FullName = fullname - // need to be able to create company and set pictureURL + if user.Info == nil { + user.Info = make(map[string]interface{}) + } + for key := range info { + user.Info[key] = info[key] + } return nil } -func (m *backendMemory) CreateLogin(userID, email, password, fullName string) (*UserLogin, error) { +func (m *backendMemory) UpdatePassword(userID, password string) error { passwordHash, err := m.c.Hash(password) if err != nil { - return nil, err + return err } - login := userLoginMemory{userID, email, fullName, passwordHash} - m.Logins = append(m.Logins, &login) - - return &UserLogin{userID, email, fullName}, nil -} -func (m *backendMemory) CreateSecondaryEmail(userID, secondaryEmail string) error { + user := m.getUserByID(userID) + if user == nil { + return errUserNotFound + } + user.PasswordHash = passwordHash return nil } -func (m *backendMemory) SetPrimaryEmail(userID, newPrimaryEmail string) error { +func (m *backendMemory) AddSecondaryEmail(userID, secondaryEmail string) error { return nil } -func (m *backendMemory) UpdatePassword(userID string, newPassword string) error { +func (m *backendMemory) UpdatePrimaryEmail(userID, newPrimaryEmail string) error { return nil } @@ -217,10 +242,6 @@ func (m *backendMemory) ToString() string { for _, user := range m.Users { buf.WriteString(fmt.Sprintln(" ", *user)) } - buf.WriteString("Logins:\n") - for _, login := range m.Logins { - buf.WriteString(fmt.Sprintln(" ", *login)) - } buf.WriteString("Sessions:\n") for _, session := range m.Sessions { buf.WriteString(fmt.Sprintln(" ", *session)) @@ -266,15 +287,6 @@ func (m *backendMemory) removeSession(sessionHash string) { } } -func (m *backendMemory) getLoginByEmail(email string) *userLoginMemory { - for _, login := range m.Logins { - if login.Email == email { - return login - } - } - return nil -} - func (m *backendMemory) getUserByID(userID string) *user { for _, user := range m.Users { if user.UserID == userID { diff --git a/backendMemory_test.go b/backendMemory_test.go index 532f911..46da2e0 100644 --- a/backendMemory_test.go +++ b/backendMemory_test.go @@ -11,36 +11,41 @@ var in1Hour = time.Now().UTC().Add(time.Hour) func TestMemoryLogin(t *testing.T) { // can't get login backend := NewBackendMemory(&hashStore{}).(*backendMemory) - if _, err := backend.Login("email", "password"); err != errLoginNotFound { - t.Error("expected no login since login not added yet", err) + if err := backend.Login("email", "password"); err != errUserNotFound { + t.Error("expected no login since user not added yet", err) } // invalid credentials - expected := &userLoginMemory{Email: "email", UserID: "1", FullName: "name", PasswordHash: "zVNfmBbTwQZwyMsAizV1Guh_j7kcFbyG7-LRJeeJfXc="} - backend.Logins = []*userLoginMemory{expected} - if _, err := backend.Login("email", "wrongPassword"); err != nil && err.Error() != "supplied token and tokenHash do not match" { + expected := &user{PrimaryEmail: "email", UserID: "1", Info: map[string]interface{}{"key": "value"}, PasswordHash: "zVNfmBbTwQZwyMsAizV1Guh_j7kcFbyG7-LRJeeJfXc="} + backend.Users = []*user{expected} + if err := backend.Login("email", "wrongPassword"); err != nil && err.Error() != "supplied token and tokenHash do not match" { t.Error("expected error", err) } // success - expected = &userLoginMemory{Email: "email", UserID: "1", FullName: "name", PasswordHash: "zVNfmBbTwQZwyMsAizV1Guh_j7kcFbyG7-LRJeeJfXc="} // hash of "correctPassword"" - backend.Logins = []*userLoginMemory{expected} - if actual, err := backend.Login("email", "correctPassword"); err != nil || expected == nil || expected.Email != actual.Email || expected.FullName != actual.FullName || expected.UserID != actual.UserID { - t.Error("expected success", expected, actual, err) + expected = &user{PrimaryEmail: "email", UserID: "1", Info: map[string]interface{}{"key": "value"}, PasswordHash: "zVNfmBbTwQZwyMsAizV1Guh_j7kcFbyG7-LRJeeJfXc="} // hash of "correctPassword"" + backend.Users = []*user{expected} + if err := backend.Login("email", "correctPassword"); err != nil { + t.Error("expected success", err) + } + actual := backend.Users[0] + if expected.PrimaryEmail != actual.PrimaryEmail || expected.UserID != actual.UserID || actual.Info == nil || actual.Info["key"] != "value" { + t.Error("expected matching values", expected, actual) } } func TestMemoryCreateSession(t *testing.T) { backend := NewBackendMemory(&hashStore{}).(*backendMemory) - if session, _ := backend.CreateSession("1", "test@test.com", "fullname", "sessionHash", "csrfToken", in5Minutes, in1Hour); session.SessionHash != "sessionHash" || session.Email != "test@test.com" { + if session, _ := backend.CreateSession("1", "test@test.com", map[string]interface{}{"key": "value"}, "sessionHash", "csrfToken", in5Minutes, in1Hour); session.SessionHash != "sessionHash" || session.Email != "test@test.com" || + session.CSRFToken != "csrfToken" || session.Info == nil || session.Info["key"] != "value" || session.UserID != "1" || session.ExpireTimeUTC != in1Hour || session.RenewTimeUTC != in5Minutes { t.Error("expected matching session", session) } // create again, should error - if _, err := backend.CreateSession("1", "test@test.com", "fullname", "sessionHash", "csrfToken", in5Minutes, in1Hour); err == nil { + if _, err := backend.CreateSession("1", "test@test.com", map[string]interface{}{"key": "value"}, "sessionHash", "csrfToken", in5Minutes, in1Hour); err == nil { t.Error("expected error since session exists", err) } // new session ID since it was generated when no cookie was found (e.g. on another computer or browser) - if session, _ := backend.CreateSession("1", "test@test.com", "fullname", "newSessionHash", "csrfToken", in5Minutes, in1Hour); session.SessionHash != "newSessionHash" || len(backend.Sessions) != 2 { + if session, _ := backend.CreateSession("1", "test@test.com", map[string]interface{}{"key": "value"}, "newSessionHash", "csrfToken", in5Minutes, in1Hour); session.SessionHash != "newSessionHash" || len(backend.Sessions) != 2 { t.Error("expected matching session", session) } } @@ -116,11 +121,11 @@ func TestMemoryRenewRememberMe(t *testing.T) { func TestMemoryAddUser(t *testing.T) { backend := NewBackendMemory(&hashStore{}).(*backendMemory) - if userID, err := backend.AddUser("email"); err != nil || len(backend.Users) != 1 || userID != "1" { + if userID, err := backend.AddUser("email", map[string]interface{}{"key": "value"}); err != nil || len(backend.Users) != 1 || userID != "1" { t.Error("expected valid session", err, backend.Users) } - if userID, err := backend.AddUser("email"); err != errUserAlreadyExists || userID != "" { + if userID, err := backend.AddUser("email", map[string]interface{}{"key": "value"}); err != errUserAlreadyExists || userID != "" { t.Error("expected user to already exist", err) } } @@ -140,14 +145,14 @@ func TestMemoryGetEmailSession(t *testing.T) { func TestMemoryUpdateUser(t *testing.T) { backend := NewBackendMemory(&hashStore{}).(*backendMemory) - err := backend.UpdateUser("1", "fullname", "company", "pictureUrl") + err := backend.UpdateUser("1", "password", map[string]interface{}{"key": "value"}) if err != errUserNotFound { t.Error("expected to be unable to update non-existant user") } backend = NewBackendMemory(&hashStore{}).(*backendMemory) backend.Users = append(backend.Users, &user{UserID: "1", PrimaryEmail: "email"}) - err = backend.UpdateUser("1", "fullname", "company", "pictureUrl") + err = backend.UpdateUser("1", "password", map[string]interface{}{"key": "value"}) if err != nil { t.Error("expected success", err) } @@ -155,12 +160,12 @@ func TestMemoryUpdateUser(t *testing.T) { func TestMemoryCreateSecondaryEmail(t *testing.T) { backend := NewBackendMemory(&hashStore{}).(*backendMemory) - backend.CreateSecondaryEmail("userID", "secondaryEmail") + backend.AddSecondaryEmail("userID", "secondaryEmail") } func TestMemorySetPrimaryEmail(t *testing.T) { backend := NewBackendMemory(&hashStore{}).(*backendMemory) - backend.SetPrimaryEmail("userID", "newPrimaryEmail") + backend.UpdatePrimaryEmail("userID", "newPrimaryEmail") } func TestMemoryUpdatePassword(t *testing.T) { @@ -199,13 +204,12 @@ func TestMemoryClose(t *testing.T) { func TestToString(t *testing.T) { backend := NewBackendMemory(&hashStore{}).(*backendMemory) backend.Users = append(backend.Users, &user{}) - backend.Logins = append(backend.Logins, &userLoginMemory{}) backend.Sessions = append(backend.Sessions, &LoginSession{}) backend.RememberMes = append(backend.RememberMes, &rememberMeSession{}) actual := backend.ToString() - expected := "Users:\n { 0 []}\nLogins:\n { }\nSessions:\n { 0001-01-01 00:00:00 +0000 UTC 0001-01-01 00:00:00 +0000 UTC}\nRememberMe:\n { 0001-01-01 00:00:00 +0000 UTC 0001-01-01 00:00:00 +0000 UTC}\n" + expected := "Users:\n { map[] 0}\nSessions:\n { map[] 0001-01-01 00:00:00 +0000 UTC 0001-01-01 00:00:00 +0000 UTC}\nRememberMe:\n { 0001-01-01 00:00:00 +0000 UTC 0001-01-01 00:00:00 +0000 UTC}\n" if actual != expected { - t.Error("expected different value", actual) + t.Error("expected different value", expected, "\n", actual) } } diff --git a/backendMongo.go b/backendMongo.go index 3fefd98..24cb157 100644 --- a/backendMongo.go +++ b/backendMongo.go @@ -14,14 +14,13 @@ type backendMongo struct { } type mongoUser struct { - ID bson.ObjectId `bson:"_id" json:"id"` - PrimaryEmail string `bson:"primaryEmail" json:"primaryEmail"` - SecondaryEmails []email `bson:"secondaryEmails" json:"secondaryEmails"` - FullName string `bson:"fullName" json:"fullName"` - PasswordHash string `bson:"passwordHash" json:"passwordHash"` - LockoutEndTimeUTC *time.Time `bson:"lockoutEndTimeUTC" json:"lockoutEndTimeUTC"` - AccessFailedCount int `bson:"accessFailedCount" json:"accessFailedCount"` - Roles []string `bson:"roles" json:"roles"` + ID bson.ObjectId `bson:"_id" json:"id"` + PrimaryEmail string `bson:"primaryEmail" json:"primaryEmail"` + SecondaryEmails []email `bson:"secondaryEmails" json:"secondaryEmails"` + PasswordHash string `bson:"passwordHash" json:"passwordHash"` + Info map[string]interface{} `bson:"info" json:"info"` + LockoutEndTimeUTC *time.Time `bson:"lockoutEndTimeUTC" json:"lockoutEndTimeUTC"` + AccessFailedCount int `bson:"accessFailedCount" json:"accessFailedCount"` } type email struct { @@ -35,14 +34,28 @@ func NewBackendMongo(m mgo.Sessioner, c Crypter) Backender { return &backendMongo{m, c} } -func (b *backendMongo) AddUser(email string) (string, error) { +func (b *backendMongo) AddUser(email string, info map[string]interface{}) (string, error) { u, err := b.getUser(email) if err == nil { return u.ID.Hex(), errors.New("user already exists") } id := bson.NewObjectId() - return id.Hex(), b.users().Insert(mongoUser{ID: id, PrimaryEmail: email}) + return id.Hex(), b.users().Insert(mongoUser{ID: id, PrimaryEmail: email, Info: info}) +} + +func (b *backendMongo) AddUserFull(email, password string, info map[string]interface{}) (*User, error) { + passwordHash, err := b.c.Hash(password) + if err != nil { + return nil, err + } + _, err = b.getUser(email) + if err == nil { + return nil, errors.New("user already exists") + } + + id := bson.NewObjectId() + return &User{id.Hex(), email, info}, b.users().Insert(mongoUser{ID: id, PrimaryEmail: email, PasswordHash: passwordHash, Info: info}) } func (b *backendMongo) getUser(email string) (*mongoUser, error) { @@ -50,16 +63,41 @@ func (b *backendMongo) getUser(email string) (*mongoUser, error) { return u, b.users().Find(bson.M{"primaryEmail": email}).One(u) } -func (b *backendMongo) GetUser(email string) (*user, error) { +func (b *backendMongo) GetUser(email string) (*User, error) { u, err := b.getUser(email) if err != nil { return nil, err } - return &user{UserID: u.ID.Hex(), FullName: u.FullName, PrimaryEmail: u.PrimaryEmail, AccessFailedCount: u.AccessFailedCount, LockoutEndTimeUTC: u.LockoutEndTimeUTC, Roles: u.Roles}, nil + return &User{u.ID.Hex(), u.PrimaryEmail, u.Info}, nil +} + +func (b *backendMongo) UpdateUser(userID, password string, info map[string]interface{}) error { + passwordHash, err := b.c.Hash(password) + if err != nil { + return err + } + set := bson.M{} + for key := range info { + set["info."+key] = info[key] + } + set["passwordHash"] = passwordHash + return b.users().UpdateId(bson.ObjectIdHex(userID), bson.M{"$set": set}) +} + +func (b *backendMongo) UpdatePassword(userID, password string) error { + passwordHash, err := b.c.Hash(password) + if err != nil { + return err + } + return b.users().UpdateId(bson.ObjectIdHex(userID), bson.M{"$set": bson.M{"passwordHash": passwordHash}}) } -func (b *backendMongo) UpdateUser(userID string, fullname string, company string, pictureURL string) error { - return b.users().UpdateId(bson.ObjectIdHex(userID), bson.M{"$set": bson.M{"fullName": fullname}}) +func (b *backendMongo) UpdateInfo(userID string, info map[string]interface{}) error { + var set bson.M + for key := range info { + set["info."+key] = info[key] + } + return b.users().UpdateId(bson.ObjectIdHex(userID), bson.M{"$set": set}) } func (b *backendMongo) Close() error { @@ -67,7 +105,7 @@ func (b *backendMongo) Close() error { return nil } -func (b *backendMongo) Login(email, password string) (*UserLogin, error) { +func (b *backendMongo) LoginAndGetUser(email, password string) (*User, error) { u, err := b.getUser(email) if err != nil { return nil, err @@ -75,46 +113,29 @@ func (b *backendMongo) Login(email, password string) (*UserLogin, error) { if err := b.c.HashEquals(password, u.PasswordHash); err != nil { return nil, err } - return &UserLogin{UserID: u.ID.Hex(), FullName: u.FullName, Email: u.PrimaryEmail}, nil + return &User{u.ID.Hex(), u.PrimaryEmail, u.Info}, nil } -func (b *backendMongo) GetLogin(email string) (*UserLogin, error) { - u, err := b.getUser(email) - if err != nil { - return nil, err - } - return &UserLogin{UserID: u.ID.Hex(), FullName: u.FullName, Email: u.PrimaryEmail}, nil +func (b *backendMongo) Login(email, password string) error { + _, err := b.LoginAndGetUser(email, password) + return err } -func (b *backendMongo) CreateLogin(userID, email, password, fullName string) (*UserLogin, error) { - passwordHash, err := b.c.Hash(password) - if err != nil { - return nil, err - } - return &UserLogin{UserID: userID, FullName: fullName, Email: email}, - b.users().UpdateId(bson.ObjectIdHex(userID), bson.M{"$set": bson.M{"passwordHash": passwordHash}}) -} - -func (b *backendMongo) CreateSecondaryEmail(userID, secondaryEmail string) error { +func (b *backendMongo) AddSecondaryEmail(userID, secondaryEmail string) error { return nil } -func (b *backendMongo) SetPrimaryEmail(userID, secondaryEmail string) error { + +func (b *backendMongo) UpdatePrimaryEmail(userID, secondaryEmail string) error { return nil } -func (b *backendMongo) UpdatePassword(userID, newPassword string) error { - passwordHash, err := b.c.Hash(newPassword) - if err != nil { - return err - } - return b.users().UpdateId(bson.ObjectIdHex(userID), bson.M{"$set": bson.M{"passwordHash": passwordHash}}) -} -func (b *backendMongo) CreateEmailSession(email, emailVerifyHash, csrfToken, destinationURL string) error { + +func (b *backendMongo) CreateEmailSession(email string, info map[string]interface{}, emailVerifyHash, csrfToken string) error { s := b.emailSessions() c, _ := s.FindId(emailVerifyHash).Count() if c > 0 { return errors.New("invalid emailVerifyHash") } - return s.Insert(&emailSession{"", email, emailVerifyHash, csrfToken, destinationURL}) + return s.Insert(&emailSession{"", email, info, emailVerifyHash, csrfToken}) } func (b *backendMongo) GetEmailSession(verifyHash string) (*emailSession, error) { @@ -128,8 +149,8 @@ func (b *backendMongo) UpdateEmailSession(verifyHash, userID string) error { func (b *backendMongo) DeleteEmailSession(verifyHash string) error { return b.emailSessions().RemoveId(verifyHash) } -func (b *backendMongo) CreateSession(userID, email, fullname, sessionHash, csrfToken string, renewTimeUTC, expireTimeUTC time.Time) (*LoginSession, error) { - s := LoginSession{userID, email, fullname, sessionHash, csrfToken, renewTimeUTC, expireTimeUTC} +func (b *backendMongo) CreateSession(userID, email string, info map[string]interface{}, sessionHash, csrfToken string, renewTimeUTC, expireTimeUTC time.Time) (*LoginSession, error) { + s := LoginSession{userID, email, info, sessionHash, csrfToken, renewTimeUTC, expireTimeUTC} return &s, b.loginSessions().Insert(s) } diff --git a/backendRedisSession.go b/backendRedisSession.go index 4f7b857..9217ac4 100644 --- a/backendRedisSession.go +++ b/backendRedisSession.go @@ -20,8 +20,8 @@ func NewBackendRedisSession(server string, port int, password string, maxIdle, m } // need to first check that this emailVerifyHash isn't being used, otherwise we'll clobber existing -func (r *backendRedisSession) CreateEmailSession(email, emailVerifyHash, csrfToken, destinationURL string) error { - return r.saveEmailSession(&emailSession{"", email, emailVerifyHash, csrfToken, destinationURL}) +func (r *backendRedisSession) CreateEmailSession(email string, info map[string]interface{}, emailVerifyHash, csrfToken string) error { + return r.saveEmailSession(&emailSession{"", email, info, emailVerifyHash, csrfToken}) } func (r *backendRedisSession) GetEmailSession(emailVerifyHash string) (*emailSession, error) { @@ -42,8 +42,8 @@ func (r *backendRedisSession) DeleteEmailSession(emailVerifyHash string) error { return nil } -func (r *backendRedisSession) CreateSession(userID, email, fullname, sessionHash, csrfToken string, renewTimeUTC, expireTimeUTC time.Time) (*LoginSession, error) { - session := LoginSession{userID, email, fullname, sessionHash, csrfToken, renewTimeUTC, expireTimeUTC} +func (r *backendRedisSession) CreateSession(userID, email string, info map[string]interface{}, sessionHash, csrfToken string, renewTimeUTC, expireTimeUTC time.Time) (*LoginSession, error) { + session := LoginSession{userID, email, info, sessionHash, csrfToken, renewTimeUTC, expireTimeUTC} return &session, r.saveSession(&session) } diff --git a/backendRedisSession_test.go b/backendRedisSession_test.go index f49c7e1..e760f57 100644 --- a/backendRedisSession_test.go +++ b/backendRedisSession_test.go @@ -15,13 +15,13 @@ func TestRedisCreateSession(t *testing.T) { // expired session error m := onedb.NewMock(nil, nil, nil) r := backendRedisSession{db: m, prefix: "test"} - _, err := r.CreateSession("1", "test@test.com", "fullname", "hash", "csrfToken", time.Now(), time.Now()) + _, err := r.CreateSession("1", "test@test.com", map[string]interface{}{"info": "values"}, "hash", "csrfToken", time.Now(), time.Now()) if err == nil || len(m.QueriesRun()) != 0 { t.Error("expected error") } // success - session, err := r.CreateSession("1", "test@test.com", "fullname", "hash", "csrfToken", time.Now(), time.Now().AddDate(1, 0, 0)) + session, err := r.CreateSession("1", "test@test.com", map[string]interface{}{"info": "values"}, "hash", "csrfToken", time.Now(), time.Now().AddDate(1, 0, 0)) if q := m.QueriesRun(); err != nil || len(q) != 1 || q[0].(*onedb.RedisCommand).Command != "SETEX" || len(q[0].(*onedb.RedisCommand).Args) != 3 || q[0].(*onedb.RedisCommand).Args[0] != "test/session/hash" { t.Error("expected success", len(q), q[0].(*onedb.RedisCommand)) } diff --git a/backend_test.go b/backend_test.go index 5bc75e2..fd81651 100644 --- a/backend_test.go +++ b/backend_test.go @@ -20,8 +20,8 @@ func TestAuthError(t *testing.T) { } func TestBackendLogin(t *testing.T) { - m := &mockBackend{LoginReturn: loginSuccess()} - b := backend{u: m, l: m, s: m} + m := &mockBackend{} + b := backend{u: m, s: m} b.Login("email", "password") if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "Login" { t.Error("Expected it would call backend", m.MethodsCalled) @@ -30,8 +30,8 @@ func TestBackendLogin(t *testing.T) { func TestBackendCreateSession(t *testing.T) { m := &mockBackend{CreateSessionReturn: sessionSuccess(time.Now(), time.Now())} - b := backend{u: m, l: m, s: m} - b.CreateSession("1", "test@test.com", "fullname", "hash", "csrfToken", time.Now(), time.Now()) + b := backend{u: m, s: m} + b.CreateSession("1", "test@test.com", map[string]interface{}{"info": "values"}, "hash", "csrfToken", time.Now(), time.Now()) if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "CreateSession" { t.Error("Expected it would call backend", m.MethodsCalled) } @@ -39,7 +39,7 @@ func TestBackendCreateSession(t *testing.T) { func TestBackendCreateRemember(t *testing.T) { m := &mockBackend{CreateRememberMeReturn: rememberMe(time.Now(), time.Now())} - b := backend{u: m, l: m, s: m} + b := backend{u: m, s: m} b.CreateRememberMe("1", "test@test.com", "", "", time.Now(), time.Now()) if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "CreateRememberMe" { t.Error("Expected it would call backend", m.MethodsCalled) @@ -48,7 +48,7 @@ func TestBackendCreateRemember(t *testing.T) { func TestBackendGetSession(t *testing.T) { m := &mockBackend{GetSessionReturn: sessionErr()} - b := backend{u: m, l: m, s: m} + b := backend{u: m, s: m} b.GetSession("hash") if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "GetSession" { t.Error("Expected it would call backend", m.MethodsCalled) @@ -57,7 +57,7 @@ func TestBackendGetSession(t *testing.T) { func TestBackendUpdateSession(t *testing.T) { m := &mockBackend{UpdateSessionErr: errors.New("failed")} - b := backend{u: m, l: m, s: m} + b := backend{u: m, s: m} b.UpdateSession("hash", time.Now(), time.Now()) if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "UpdateSession" { t.Error("Expected it would call backend", m.MethodsCalled) @@ -66,7 +66,7 @@ func TestBackendUpdateSession(t *testing.T) { func TestBackendGetRememberMe(t *testing.T) { m := &mockBackend{GetRememberMeReturn: rememberErr()} - b := backend{u: m, l: m, s: m} + b := backend{u: m, s: m} b.GetRememberMe("selector") if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "GetRememberMe" { t.Error("Expected it would call backend", m.MethodsCalled) @@ -75,7 +75,7 @@ func TestBackendGetRememberMe(t *testing.T) { func TestBackendUpdateRememberMe(t *testing.T) { m := &mockBackend{UpdateRememberMeErr: errors.New("failed")} - b := backend{u: m, l: m, s: m} + b := backend{u: m, s: m} b.UpdateRememberMe("selector", time.Now()) if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "UpdateRememberMe" { t.Error("Expected it would call backend", m.MethodsCalled) @@ -84,8 +84,8 @@ func TestBackendUpdateRememberMe(t *testing.T) { func TestBackendAddUser(t *testing.T) { m := &mockBackend{AddUserErr: nil} - b := backend{u: m, l: m, s: m} - b.AddUser("mail") + b := backend{u: m, s: m} + b.AddUser("mail", map[string]interface{}{"info": "value"}) if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "AddUser" { t.Error("Expected it would call backend", m.MethodsCalled) } @@ -93,7 +93,7 @@ func TestBackendAddUser(t *testing.T) { func TestBackendGetEmailSession(t *testing.T) { m := &mockBackend{getEmailSessionReturn: getEmailSessionErr()} - b := backend{u: m, l: m, s: m} + b := backend{u: m, s: m} b.GetEmailSession("hash") if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "GetEmailSession" { t.Error("Expected it would call backend", m.MethodsCalled) @@ -102,43 +102,43 @@ func TestBackendGetEmailSession(t *testing.T) { func TestBackendUpdateUser(t *testing.T) { m := &mockBackend{} - b := backend{u: m, l: m, s: m} - b.UpdateUser("1", "name", "company", "url") + b := backend{u: m, s: m} + b.UpdateUser("1", "password", map[string]interface{}{"fullName": "name", "company": "companyName"}) if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "UpdateUser" { t.Error("Expected it would call backend", m.MethodsCalled) } } func TestBackendSetPrimaryEmail(t *testing.T) { - m := &mockBackend{SetPrimaryEmailErr: errors.New("fail")} - b := backend{u: m, l: m, s: m} - b.SetPrimaryEmail("userID", "newEmail") - if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "SetPrimaryEmail" { + m := &mockBackend{UpdatePrimaryEmailErr: errors.New("fail")} + b := backend{u: m, s: m} + b.UpdatePrimaryEmail("userID", "newEmail") + if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "UpdatePrimaryEmail" { t.Error("Expected it would call backend", m.MethodsCalled) } } func TestBackendUpdatePassword(t *testing.T) { m := &mockBackend{UpdatePasswordErr: errors.New("fail")} - b := backend{u: m, l: m, s: m} + b := backend{u: m, s: m} b.UpdatePassword("userID", "newPassword") if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "UpdatePassword" { t.Error("Expected it would call backend", m.MethodsCalled) } } -func TestBackendCreateSecondaryEmail(t *testing.T) { - m := &mockBackend{CreateSecondaryEmailErr: errors.New("fail")} - b := backend{u: m, l: m, s: m} - b.CreateSecondaryEmail("userID", "secondaryEmail@test.com") - if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "CreateSecondaryEmail" { +func TestBackendAddSecondaryEmail(t *testing.T) { + m := &mockBackend{AddSecondaryEmailErr: errors.New("fail")} + b := backend{u: m, s: m} + b.AddSecondaryEmail("userID", "secondaryEmail@test.com") + if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "AddSecondaryEmail" { t.Error("Expected it would call backend", m.MethodsCalled) } } func TestBackendDeleteSession(t *testing.T) { m := &mockBackend{} - b := backend{u: m, l: m, s: m} + b := backend{u: m, s: m} b.DeleteSession("hash") if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "DeleteSession" { t.Error("Expected it would call backend", m.MethodsCalled) @@ -147,7 +147,7 @@ func TestBackendDeleteSession(t *testing.T) { func TestBackendInvalidateSessions(t *testing.T) { m := &mockBackend{} - b := backend{u: m, l: m, s: m} + b := backend{u: m, s: m} b.InvalidateSessions("email") if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "InvalidateSessions" { t.Error("Expected it would call backend", m.MethodsCalled) @@ -156,7 +156,7 @@ func TestBackendInvalidateSessions(t *testing.T) { func TestBackendDeleteRememberMe(t *testing.T) { m := &mockBackend{} - b := backend{u: m, l: m, s: m} + b := backend{u: m, s: m} b.DeleteRememberMe("selector") if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "DeleteRememberMe" { t.Error("Expected it would call backend", m.MethodsCalled) @@ -166,16 +166,16 @@ func TestBackendDeleteRememberMe(t *testing.T) { func TestBackendClose(t *testing.T) { // all succeed m := &mockBackend{} - b := backend{u: m, l: m, s: m} + b := backend{u: m, s: m} b.Close() - if len(m.MethodsCalled) != 3 || m.MethodsCalled[0] != "Close" || m.MethodsCalled[1] != "Close" || m.MethodsCalled[2] != "Close" { + if len(m.MethodsCalled) != 2 || m.MethodsCalled[0] != "Close" || m.MethodsCalled[1] != "Close" { t.Error("Expected it would call backend", m.MethodsCalled) } // error on session close m = &mockBackend{} e := &mockBackend{ErrReturn: errors.New("failed")} - b = backend{u: m, l: m, s: e} + b = backend{u: m, s: e} b.Close() if len(m.MethodsCalled) != 0 || len(e.MethodsCalled) != 1 || e.MethodsCalled[0] != "Close" { t.Error("Expected fail on session close", m.MethodsCalled) @@ -184,27 +184,18 @@ func TestBackendClose(t *testing.T) { // error on user close m = &mockBackend{} e = &mockBackend{ErrReturn: errors.New("failed")} - b = backend{u: e, l: m, s: m} + b = backend{u: e, s: m} b.Close() if len(e.MethodsCalled) != 1 || len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "Close" || e.MethodsCalled[0] != "Close" { t.Error("Expected fail on user close", m.MethodsCalled) } - - // error on login close - m = &mockBackend{} - e = &mockBackend{ErrReturn: errors.New("failed")} - b = backend{u: m, l: e, s: m} - b.Close() - if len(m.MethodsCalled) != 2 || len(e.MethodsCalled) != 1 || m.MethodsCalled[0] != "Close" || m.MethodsCalled[1] != "Close" || e.MethodsCalled[0] != "Close" { - t.Error("Expected it would call backend", m.MethodsCalled) - } } /***********************************************************************/ -type LoginReturn struct { - Login *UserLogin - Err error +type UserReturn struct { + User *User + Err error } type SessionReturn struct { @@ -217,54 +208,56 @@ type RememberMeReturn struct { Err error } -type GetUserReturn struct { - User *user - Err error -} - -type getEmailSessionReturn struct { +type EmailSessionReturn struct { Session *emailSession Err error } type mockBackend struct { Backender - GetLoginReturn *LoginReturn - LoginReturn *LoginReturn - ExpirationReturn *time.Time - GetSessionReturn *SessionReturn - CreateSessionReturn *SessionReturn - CreateRememberMeReturn *RememberMeReturn - UpdateSessionErr error - AddUserErr error - DeleteEmailSessionErr error - UpdateEmailSessionErr error - GetUserReturn *GetUserReturn - getEmailSessionReturn *getEmailSessionReturn - CreateLoginReturn *LoginReturn - CreateSecondaryEmailErr error - SetPrimaryEmailErr error - UpdatePasswordErr error - GetRememberMeReturn *RememberMeReturn - UpdateRememberMeErr error - ErrReturn error - MethodsCalled []string -} - -func (b *mockBackend) GetLogin(email string) (*UserLogin, error) { + GetLoginReturn *UserReturn + LoginAndGetUserReturn *UserReturn + LoginErr error + ExpirationReturn *time.Time + GetSessionReturn *SessionReturn + CreateSessionReturn *SessionReturn + CreateRememberMeReturn *RememberMeReturn + UpdateSessionErr error + AddUserErr error + DeleteEmailSessionErr error + UpdateEmailSessionErr error + GetUserReturn *UserReturn + getEmailSessionReturn *EmailSessionReturn + AddSecondaryEmailErr error + UpdatePrimaryEmailErr error + UpdateUserErr error + UpdatePasswordErr error + UpdateInfoErr error + GetRememberMeReturn *RememberMeReturn + UpdateRememberMeErr error + ErrReturn error + MethodsCalled []string +} + +func (b *mockBackend) GetLogin(email string) (*User, error) { b.MethodsCalled = append(b.MethodsCalled, "GetLogin") if b.GetLoginReturn == nil { return nil, errors.New("GetLoginReturn not initialized") } - return b.GetLoginReturn.Login, b.GetLoginReturn.Err + return b.GetLoginReturn.User, b.GetLoginReturn.Err } -func (b *mockBackend) Login(email, password string) (*UserLogin, error) { - b.MethodsCalled = append(b.MethodsCalled, "Login") - if b.LoginReturn == nil { - return nil, errors.New("LoginReturn not initialized") +func (b *mockBackend) LoginAndGetUser(email, password string) (*User, error) { + b.MethodsCalled = append(b.MethodsCalled, "LoginAndGetUser") + if b.LoginAndGetUserReturn == nil { + return nil, errors.New("LoginAndGetUserReturn not initialized") } - return b.LoginReturn.Login, b.LoginReturn.Err + return b.LoginAndGetUserReturn.User, b.LoginAndGetUserReturn.Err +} + +func (b *mockBackend) Login(email, password string) error { + b.MethodsCalled = append(b.MethodsCalled, "Login") + return b.LoginErr } func (b *mockBackend) GetSession(sessionHash string) (*LoginSession, error) { @@ -275,7 +268,7 @@ func (b *mockBackend) GetSession(sessionHash string) (*LoginSession, error) { return b.GetSessionReturn.Session, b.GetSessionReturn.Err } -func (b *mockBackend) CreateSession(userID, email, fullname, sessionHash, csrfToken string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time) (*LoginSession, error) { +func (b *mockBackend) CreateSession(userID, email string, info map[string]interface{}, sessionHash, csrfToken string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time) (*LoginSession, error) { b.MethodsCalled = append(b.MethodsCalled, "CreateSession") if b.CreateSessionReturn == nil { return nil, errors.New("CreateSessionReturn not initialized") @@ -306,12 +299,12 @@ func (b *mockBackend) UpdateRememberMe(selector string, renewTimeUTC time.Time) b.MethodsCalled = append(b.MethodsCalled, "UpdateRememberMe") return b.UpdateRememberMeErr } -func (b *mockBackend) AddUser(email string) (string, error) { +func (b *mockBackend) AddUser(email string, info map[string]interface{}) (string, error) { b.MethodsCalled = append(b.MethodsCalled, "AddUser") return "1", b.AddUserErr } -func (b *mockBackend) CreateEmailSession(email, emailVerifyHash, csrfToken, destinationURL string) error { +func (b *mockBackend) CreateEmailSession(email string, info map[string]interface{}, emailVerifyHash, csrfToken string) error { b.MethodsCalled = append(b.MethodsCalled, "CreateEmailSession") return b.ErrReturn } @@ -334,7 +327,7 @@ func (b *mockBackend) DeleteEmailSession(emailVerifyHash string) error { return b.DeleteEmailSessionErr } -func (b *mockBackend) GetUser(email string) (*user, error) { +func (b *mockBackend) GetUser(email string) (*User, error) { b.MethodsCalled = append(b.MethodsCalled, "GetUser") if b.GetUserReturn == nil { return nil, errors.New("GetUserReturn not initialized") @@ -342,33 +335,30 @@ func (b *mockBackend) GetUser(email string) (*user, error) { return b.GetUserReturn.User, b.GetUserReturn.Err } -func (b *mockBackend) UpdateUser(userID, fullname, company, pictureURL string) error { +func (b *mockBackend) UpdateUser(userID, password string, info map[string]interface{}) error { b.MethodsCalled = append(b.MethodsCalled, "UpdateUser") - return b.ErrReturn + return b.UpdateUserErr } -func (b *mockBackend) CreateLogin(userID, email, password, fullName string) (*UserLogin, error) { - b.MethodsCalled = append(b.MethodsCalled, "CreateLogin") - if b.CreateLoginReturn == nil { - return nil, errors.New("CreateLoginReturn not initialized") - } - return b.CreateLoginReturn.Login, b.CreateLoginReturn.Err +func (b *mockBackend) UpdatePassword(userID, password string) error { + b.MethodsCalled = append(b.MethodsCalled, "UpdatePassword") + return b.UpdatePasswordErr } -func (b *mockBackend) CreateSecondaryEmail(userID, secondaryEmail string) error { - b.MethodsCalled = append(b.MethodsCalled, "CreateSecondaryEmail") - return b.CreateSecondaryEmailErr +func (b *mockBackend) UpdateInfo(userID string, info map[string]interface{}) error { + b.MethodsCalled = append(b.MethodsCalled, "UpdateInfo") + return b.UpdateInfoErr } -func (b *mockBackend) SetPrimaryEmail(userID, newPrimaryEmail string) error { - b.MethodsCalled = append(b.MethodsCalled, "SetPrimaryEmail") - return b.SetPrimaryEmailErr +func (b *mockBackend) AddSecondaryEmail(userID, secondaryEmail string) error { + b.MethodsCalled = append(b.MethodsCalled, "AddSecondaryEmail") + return b.AddSecondaryEmailErr } -func (b *mockBackend) UpdatePassword(userID, newPassword string) error { - b.MethodsCalled = append(b.MethodsCalled, "UpdatePassword") - return b.UpdatePasswordErr -} +func (b *mockBackend) UpdatePrimaryEmail(userID, newPrimaryEmail string) error { + b.MethodsCalled = append(b.MethodsCalled, "UpdatePrimaryEmail") + return b.UpdatePrimaryEmailErr +} func (b *mockBackend) DeleteSession(sessionHash string) error { b.MethodsCalled = append(b.MethodsCalled, "DeleteSession") return b.ErrReturn @@ -389,16 +379,16 @@ func (b *mockBackend) Close() error { return b.ErrReturn } -func loginSuccess() *LoginReturn { - return &LoginReturn{&UserLogin{Email: "test@test.com"}, nil} +func userSuccess() *UserReturn { + return &UserReturn{&User{Email: "test@test.com"}, nil} } -func loginErr() *LoginReturn { - return &LoginReturn{nil, errors.New("failed")} +func userErr() *UserReturn { + return &UserReturn{nil, errors.New("failed")} } func sessionSuccess(renewTimeUTC, expireTimeUTC time.Time) *SessionReturn { - return &SessionReturn{&LoginSession{"1", "test@test.com", "fullname", "sessionHash", "csrfToken", renewTimeUTC, expireTimeUTC}, nil} + return &SessionReturn{&LoginSession{"1", "test@test.com", map[string]interface{}{"info": "values"}, "sessionHash", "csrfToken", renewTimeUTC, expireTimeUTC}, nil} } func sessionErr() *SessionReturn { @@ -413,17 +403,9 @@ func rememberErr() *RememberMeReturn { return &RememberMeReturn{&rememberMeSession{}, errors.New("failed")} } -func getEmailSessionSuccess() *getEmailSessionReturn { - return &getEmailSessionReturn{&emailSession{Email: "email@test.com", EmailVerifyHash: "hash", DestinationURL: "destinationURL", CSRFToken: "csrfToken"}, nil} -} -func getEmailSessionErr() *getEmailSessionReturn { - return &getEmailSessionReturn{nil, errors.New("failed")} -} - -func getUserSuccess() *GetUserReturn { - return &GetUserReturn{&user{FullName: "name", PrimaryEmail: "test@test.com"}, nil} +func getEmailSessionSuccess() *EmailSessionReturn { + return &EmailSessionReturn{&emailSession{Email: "email@test.com", EmailVerifyHash: "hash", Info: map[string]interface{}{"key": "value"}, CSRFToken: "csrfToken"}, nil} } - -func getUserErr() *GetUserReturn { - return &GetUserReturn{nil, errors.New("failed")} +func getEmailSessionErr() *EmailSessionReturn { + return &EmailSessionReturn{nil, errors.New("failed")} } diff --git a/nginx/nginxauth.go b/nginx/nginxauth.go index 1397781..08952d2 100644 --- a/nginx/nginxauth.go +++ b/nginx/nginxauth.go @@ -12,24 +12,14 @@ import ( "github.com/EndFirstCorp/auth" "github.com/EndFirstCorp/configReader" + "github.com/EndFirstCorp/onedb/mgo" "github.com/gorilla/handlers" ) type authConf struct { AuthServerListenPort int StoragePrefix string - DbType string - DbServer string - DbPort int - DbUser string - DbDatabase string - DbPassword string - LdapServer string - LdapPort int - LdapBindDn string - LdapPassword string - LdapBaseDn string - LdapUserFilter string + ConnectionURI string GetSessionQuery string RenewSessionQuery string GetRememberMeQuery string @@ -105,17 +95,13 @@ func newNginxAuth(configFle, logfile string) (*nginxauth, error) { if err != nil { log.Fatal(err) } - - s := auth.NewBackendRedisSession(config.RedisServer, config.RedisPort, config.RedisPassword, config.RedisMaxIdle, config.RedisMaxConnections, config.StoragePrefix) - l, err := auth.NewBackendLDAPLogin(config.LdapServer, config.LdapPort, config.LdapBindDn, config.LdapPassword, config.LdapBaseDn, config.LdapUserFilter) - if err != nil { - return nil, err - } - u, err := auth.NewBackendDbUser(config.DbServer, config.DbPort, config.DbUser, config.DbPassword, config.DbDatabase, config.AddUserQuery, config.GetUserQuery, config.UpdateUserQuery, config.CreateSecondaryEmailQuery) + m, err := mgo.Dial(config.ConnectionURI) if err != nil { - return nil, err + log.Fatal(err) } - b := auth.NewBackend(u, l, s) + + s := auth.NewBackendRedisSession(config.RedisServer, config.RedisPort, config.RedisPassword, config.RedisMaxIdle, config.RedisMaxConnections, config.StoragePrefix) + b := auth.NewBackend(auth.NewBackendMongo(m, &auth.CryptoHashStore{}), s) mailer, err := config.NewEmailer() if err != nil { @@ -187,7 +173,7 @@ func authCookie(authStore auth.AuthStorer, w http.ResponseWriter, r *http.Reques return } - user, err := json.Marshal(&auth.UserLogin{Email: session.Email, UserID: session.UserID, FullName: session.FullName}) + user, err := json.Marshal(&auth.User{Email: session.Email, UserID: session.UserID, Info: session.Info}) if err != nil { authErr(w, r, err) return @@ -216,7 +202,7 @@ func authBasic(authStore auth.AuthStorer, w http.ResponseWriter, r *http.Request return } - user, err := json.Marshal(&auth.UserLogin{Email: session.Email, UserID: session.UserID, FullName: session.FullName}) + user, err := json.Marshal(&auth.User{Email: session.Email, UserID: session.UserID, Info: session.Info}) if err != nil { basicErr(w, r, err) return @@ -275,7 +261,7 @@ func runWithProfile(method func(http.ResponseWriter, *http.Request) (*auth.Login return } - user, err := json.Marshal(&auth.UserLogin{Email: s.Email, UserID: s.UserID, FullName: s.FullName}) + user, err := json.Marshal(&auth.User{Email: s.Email, UserID: s.UserID, Info: s.Info}) if err != nil { authErr(w, r, err) return diff --git a/nginx/nginxauth_test.go b/nginx/nginxauth_test.go index d1bb340..b3939fb 100644 --- a/nginx/nginxauth_test.go +++ b/nginx/nginxauth_test.go @@ -54,9 +54,9 @@ func TestAuth(t *testing.T) { } w = httptest.NewRecorder() - storer = &mockAuthStorer{SessionReturn: &auth.LoginSession{UserID: "1", Email: "test@test.com", FullName: "Name"}} + storer = &mockAuthStorer{SessionReturn: &auth.LoginSession{UserID: "1", Email: "test@test.com", Info: map[string]interface{}{"fullName": "Name"}}} authCookie(storer, w, nil) - if w.Header().Get("X-User") != `{"userID":"1","email":"test@test.com","fullName":"Name"}` || storer.LastRun != "GetSession" { + if w.Header().Get("X-User") != `{"userID":"1","email":"test@test.com","info":{"fullName":"Name"}}` || storer.LastRun != "GetSession" { t.Error("expected User header to be set", w.Header().Get("X-User"), storer.LastRun) } } @@ -73,7 +73,7 @@ func TestAuthBasic(t *testing.T) { w = httptest.NewRecorder() storer = &mockAuthStorer{SessionReturn: &auth.LoginSession{UserID: "0", Email: "test@test.com"}} authBasic(storer, w, nil) - if w.Header().Get("X-User") != `{"userID":"0","email":"test@test.com","fullName":""}` || storer.LastRun != "GetBasicAuth" { + if w.Header().Get("X-User") != `{"userID":"0","email":"test@test.com","info":null}` || storer.LastRun != "GetBasicAuth" { t.Error("expected User header to be set", w.Header().Get("X-User"), storer.LastRun) } } @@ -165,10 +165,10 @@ func TestAddUserHeader(t *testing.T) { /*******************************************************/ type mockAuthStorer struct { - SessionReturn *auth.LoginSession - DestinationURLReturn string - ErrReturn error - LastRun string + SessionReturn *auth.LoginSession + InfoReturn map[string]interface{} + ErrReturn error + LastRun string } func (s *mockAuthStorer) GetSession(w http.ResponseWriter, r *http.Request) (*auth.LoginSession, error) { @@ -196,9 +196,9 @@ func (s *mockAuthStorer) CreateProfile(w http.ResponseWriter, r *http.Request) ( s.LastRun = "CreateProfile" return s.SessionReturn, s.ErrReturn } -func (s *mockAuthStorer) VerifyEmail(w http.ResponseWriter, r *http.Request) (string, string, error) { +func (s *mockAuthStorer) VerifyEmail(w http.ResponseWriter, r *http.Request) (string, map[string]interface{}, error) { s.LastRun = "VerifyEmail" - return "csrfToken", s.DestinationURLReturn, s.ErrReturn + return "csrfToken", s.InfoReturn, s.ErrReturn } func (s *mockAuthStorer) CreateSecondaryEmail(w http.ResponseWriter, r *http.Request) error { s.LastRun = "CreateSecondaryEmail"