From aba3568e673e13163f12ca53b9e64c246a960dba Mon Sep 17 00:00:00 2001 From: Aaron Zielstorff Date: Sun, 1 Sep 2024 19:24:43 +0000 Subject: [PATCH] Improves user verification feature --- backend/aashub/api/handler/users.go | 18 +++++--- backend/aashub/cmd/aashub/main.go | 10 ++++- .../database/repositories/mailverification.go | 4 ++ .../internal/database/repositories/users.go | 43 +++++++++++++++---- .../database/repositories/verification.go | 17 +++++++- .../internal/interfaces/verification.go | 1 + backend/aashub/internal/mail/mail.go | 18 ++++---- backend/aashub/tests/integration/user_test.go | 16 ++++--- backend/aashub/tests/unit/jwt_test.go | 4 +- backend/aashub/tests/unit/user_test.go | 12 ++++-- 10 files changed, 103 insertions(+), 40 deletions(-) diff --git a/backend/aashub/api/handler/users.go b/backend/aashub/api/handler/users.go index fa7ea80..2fb3cd9 100644 --- a/backend/aashub/api/handler/users.go +++ b/backend/aashub/api/handler/users.go @@ -51,8 +51,12 @@ func (h *UserHandler) RegisterUser(w http.ResponseWriter, r *http.Request) { err = h.Repo.RegisterUser(user.Username, user.Email, user.Password) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return + switch err { + case repositories.ErrUserRepoEmailUsernameExists: + http.Error(w, "Email or username already exists", http.StatusBadRequest) + default: + http.Error(w, err.Error(), http.StatusInternalServerError) + } } w.WriteHeader(http.StatusCreated) @@ -131,12 +135,14 @@ func (h *UserHandler) LoginUser(w http.ResponseWriter, r *http.Request) { token, err := h.Repo.LoginUser(identifier, password) if err != nil { - if err == repositories.ErrUserRepoNotFound { + switch err { + case repositories.ErrUserRepoNotFound: http.Error(w, "User not found", http.StatusNotFound) - return + case repositories.ErrUserRepoNotVerified: + http.Error(w, "User not verified", http.StatusForbidden) + default: + http.Error(w, err.Error(), http.StatusInternalServerError) } - http.Error(w, err.Error(), http.StatusInternalServerError) - return } // Create a cookie diff --git a/backend/aashub/cmd/aashub/main.go b/backend/aashub/cmd/aashub/main.go index cb990a6..412806a 100644 --- a/backend/aashub/cmd/aashub/main.go +++ b/backend/aashub/cmd/aashub/main.go @@ -12,6 +12,7 @@ import ( docs "github.com/aas-hub-org/aashub/docs" "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" + "github.com/joho/godotenv" swaggerfiles "github.com/swaggo/files" ginSwagger "github.com/swaggo/gin-swagger" ) @@ -46,12 +47,17 @@ func main() { // Initialize database database, err := database.NewDB() if err != nil { - log.Fatalf("Could not connect to the database: %v", err) + log.Printf("Could not connect to the database: %v", err) + } + + env_err := godotenv.Load("/workspace/backend/aashub/.env") + if env_err != nil { + log.Printf("Error loading .env file") } // Initialize repositories verificationRepo := &repositories.VerificationRepository{DB: database} - mailVerificationRepo := &repositories.EmailVerificationRepository{VerificationRepository: verificationRepo} + mailVerificationRepo := &repositories.VerificationRepository{DB: database} userRepo := &repositories.UserRepository{DB: database, VerificationRepository: mailVerificationRepo} // Initialize handlers diff --git a/backend/aashub/internal/database/repositories/mailverification.go b/backend/aashub/internal/database/repositories/mailverification.go index af49d79..c575e24 100644 --- a/backend/aashub/internal/database/repositories/mailverification.go +++ b/backend/aashub/internal/database/repositories/mailverification.go @@ -38,3 +38,7 @@ func (e *EmailVerificationRepository) Verify(email string, verificationCode stri } return "", nil } + +func (e *EmailVerificationRepository) IsVerified(email string) (bool, error) { + return e.VerificationRepository.IsVerified(email) +} diff --git a/backend/aashub/internal/database/repositories/users.go b/backend/aashub/internal/database/repositories/users.go index f660fea..780c72c 100644 --- a/backend/aashub/internal/database/repositories/users.go +++ b/backend/aashub/internal/database/repositories/users.go @@ -4,6 +4,8 @@ import ( "database/sql" "errors" "log" + "os" + "strings" auth "github.com/aas-hub-org/aashub/internal/auth" interfaces "github.com/aas-hub-org/aashub/internal/interfaces" @@ -14,6 +16,8 @@ import ( ) var ErrUserRepoNotFound = errors.New("identifier or password wrong") +var ErrUserRepoNotVerified = errors.New("user not verified") +var ErrUserRepoEmailUsernameExists = errors.New("email or username already exists") type UserRepository struct { DB *sql.DB @@ -39,20 +43,23 @@ func (repo *UserRepository) RegisterUser(username string, email string, password userid := uuid.New().String() hashedpassword, err := HashPassword(password) if err != nil { - log.Fatalf("Error hashing password: %v", err) + log.Printf("Error hashing password: %v", err) return err } _, err = repo.DB.Exec("INSERT INTO Users (id, username, email, password_hash) VALUES (?, ?, ?, ?)", userid, username, email, hashedpassword) if err != nil { - log.Fatalf("Error inserting user: %v", err) + // Check if error includes "Duplicate" + if strings.Contains(err.Error(), "Duplicate") { + return ErrUserRepoEmailUsernameExists + } return err } - - _, err = repo.VerificationRepository.CreateVerification(email) - + if IsVerificationEnabled() { + _, err = repo.VerificationRepository.CreateVerification(email) + } if err != nil { - log.Fatalf("Error inserting verification: %v", err) + log.Printf("Error inserting verification: %v", err) return err } @@ -68,21 +75,41 @@ func (repo *UserRepository) LoginUser(identifier string, password string) (strin if err != nil { return "", ErrUserRepoNotFound } + if IsVerificationEnabled() { + isVerified, err := repo.VerificationRepository.IsVerified(user.Email) + if err != nil { + log.Printf("Error checking verification: %v", err) + return "", err + } + if !isVerified { + return "", ErrUserRepoNotVerified + } + } + if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil { return "", ErrUserRepoNotFound } + secret, fileReadError := utils.ReadFile("/workspace/backend/aashub/privatekey.txt") if fileReadError != nil { - log.Fatalf("Error reading file: %v", fileReadError) + log.Printf("Error reading file: %v", fileReadError) return "", fileReadError } jwt, err := auth.GenerateJWT(user.ID, secret) if err != nil { - log.Fatalf("Error generating JWT: %v", err) + log.Printf("Error generating JWT: %v", err) return "", err } return jwt, nil } + +func IsVerificationEnabled() bool { + enabled, found := os.LookupEnv("VERIFICATION_ENABLED") + if !found { + return false + } + return enabled == "true" +} diff --git a/backend/aashub/internal/database/repositories/verification.go b/backend/aashub/internal/database/repositories/verification.go index e8b6e90..4e64112 100644 --- a/backend/aashub/internal/database/repositories/verification.go +++ b/backend/aashub/internal/database/repositories/verification.go @@ -49,7 +49,7 @@ func (v *VerificationRepository) Verify(email string, verificationCode string) ( result, select_err := v.DB.Query("SELECT * FROM Verifications WHERE email = ? AND verification_code = ? AND verified = ?", email, verificationCode, false) if select_err != nil { - log.Fatalf(select_err.Error()) + log.Printf(select_err.Error()) return "system", select_err } @@ -61,9 +61,22 @@ func (v *VerificationRepository) Verify(email string, verificationCode string) ( _, err := v.DB.Exec("UPDATE Verifications SET verified = ? WHERE email = ? AND verification_code = ?", true, email, verificationCode) if err != nil { - log.Fatalf(err.Error()) + log.Printf(err.Error()) return "system", err } return "", nil } + +func (v *VerificationRepository) IsVerified(email string) (bool, error) { + result, select_err := v.DB.Query("SELECT * FROM Verifications WHERE email = ? && verified = false", email) + + if select_err != nil { + log.Printf(select_err.Error()) + return false, select_err + } + + defer result.Close() + + return !result.Next(), nil +} diff --git a/backend/aashub/internal/interfaces/verification.go b/backend/aashub/internal/interfaces/verification.go index 5c8ab99..ea06269 100644 --- a/backend/aashub/internal/interfaces/verification.go +++ b/backend/aashub/internal/interfaces/verification.go @@ -3,4 +3,5 @@ package interfaces type VerificationRepositoryInterface interface { CreateVerification(email string) (string, error) Verify(email string, verificationCode string) (string, error) + IsVerified(email string) (bool, error) } diff --git a/backend/aashub/internal/mail/mail.go b/backend/aashub/internal/mail/mail.go index 6fbbeb6..0e8bf0a 100644 --- a/backend/aashub/internal/mail/mail.go +++ b/backend/aashub/internal/mail/mail.go @@ -14,7 +14,7 @@ func SendEmail(to, subject, body string) error { // Load .env env_err := godotenv.Load("/workspace/backend/aashub/.env") if env_err != nil { - log.Fatalf("Error loading .env file") + log.Printf("Error loading .env file") } // Get from .env @@ -41,49 +41,49 @@ func SendEmail(to, subject, body string) error { // Connect to the SMTP Server conn, err := tls.Dial("tcp", smtpHost+":"+smtpPort, tlsconfig) if err != nil { - log.Fatalf("Error connecting to SMTP server: %v", err) + log.Printf("Error connecting to SMTP server: %v", err) return err } client, err := smtp.NewClient(conn, smtpHost) if err != nil { - log.Fatalf("Error creating SMTP client: %v", err) + log.Printf("Error creating SMTP client: %v", err) return err } // Authentication auth := smtp.PlainAuth("", from, pass, smtpHost) if err = client.Auth(auth); err != nil { - log.Fatalf("Error authenticating: %v", err) + log.Printf("Error authenticating: %v", err) return err } // To && From if err = client.Mail(from); err != nil { - log.Fatalf("Error setting sender: %v", err) + log.Printf("Error setting sender: %v", err) return err } if err = client.Rcpt(to); err != nil { - log.Fatalf("Error setting recipient: %v", err) + log.Printf("Error setting recipient: %v", err) return err } // Data w, err := client.Data() if err != nil { - log.Fatalf("Error getting SMTP data writer: %v", err) + log.Printf("Error getting SMTP data writer: %v", err) return err } _, err = w.Write(message) if err != nil { - log.Fatalf("Error writing message: %v", err) + log.Printf("Error writing message: %v", err) return err } err = w.Close() if err != nil { - log.Fatalf("Error closing SMTP data writer: %v", err) + log.Printf("Error closing SMTP data writer: %v", err) return err } diff --git a/backend/aashub/tests/integration/user_test.go b/backend/aashub/tests/integration/user_test.go index 31c3816..0e52610 100644 --- a/backend/aashub/tests/integration/user_test.go +++ b/backend/aashub/tests/integration/user_test.go @@ -9,6 +9,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "os" "io" "log" @@ -37,15 +38,16 @@ func teardown(database *sql.DB) { // Execute the query for the test username if _, err := database.Exec(query, "testuser"); err != nil { - log.Fatalf("Failed to clean up test user: %v", err) + log.Printf("Failed to clean up test user: %v", err) } } func TestRegisterUser(t *testing.T) { + os.Setenv("VERIFICATION_ENABLED", "true") // Initialize the database connection database, err := db.NewDB() if err != nil { - log.Fatalf("Could not connect to the database: %v", err) + log.Printf("Could not connect to the database: %v", err) } // Ensure teardown is called no matter what happens in the test @@ -121,12 +123,12 @@ func TestRegisterUser(t *testing.T) { // Assert the status code if resp.StatusCode != tc.expectedStatus { - t.Errorf("Expected status code %d, got %d", tc.expectedStatus, resp.StatusCode) + t.Fatalf("Expected status code %d, got %d", tc.expectedStatus, resp.StatusCode) } // Assert the response body if expected if tc.expectedBody != "" && !strings.Contains(string(responseBody), tc.expectedBody) { - t.Errorf("Expected response body to contain %q, got %q", tc.expectedBody, string(responseBody)) + t.Fatalf("Expected response body to contain %q, got %q", tc.expectedBody, string(responseBody)) } // Call the verify user function @@ -177,10 +179,10 @@ func verifyUser(t *testing.T, tc testCase, ts *httptest.Server, database *sql.DB // Assert the verification status code and response body if verifyResp.StatusCode != http.StatusOK { - t.Errorf("Expected verification status code %d, got %d", http.StatusOK, verifyResp.StatusCode) + t.Fatalf("Expected verification status code %d, got %d", http.StatusOK, verifyResp.StatusCode) } if !strings.Contains(string(verifyResponseBody), "User verified successfully") { - t.Errorf("Expected verification response body to contain %q, got %q", "User verified successfully", string(verifyResponseBody)) + t.Fatalf("Expected verification response body to contain %q, got %q", "User verified successfully", string(verifyResponseBody)) } } } @@ -227,6 +229,6 @@ func TestLoginUser(t *testing.T) { // Check the status code is what we expect if status := rr.Code; status != http.StatusNoContent { - t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusNoContent) + t.Fatalf("handler returned wrong status code: got %v want %v", status, http.StatusNoContent) } } diff --git a/backend/aashub/tests/unit/jwt_test.go b/backend/aashub/tests/unit/jwt_test.go index aa74386..697e17a 100644 --- a/backend/aashub/tests/unit/jwt_test.go +++ b/backend/aashub/tests/unit/jwt_test.go @@ -28,7 +28,7 @@ func TestGenerateJWTAndValidate(t *testing.T) { } if !isValid { - t.Errorf("The token was expected to be valid") + t.Fatalf("The token was expected to be valid") } } @@ -56,7 +56,7 @@ func TestGenerateJWTAndValidateWithManipulatedPayload(t *testing.T) { } if !isValid { - t.Errorf("The token was expected to be valid") + t.Fatalf("The token was expected to be valid") } // Split token at . diff --git a/backend/aashub/tests/unit/user_test.go b/backend/aashub/tests/unit/user_test.go index 8938857..ad2972a 100644 --- a/backend/aashub/tests/unit/user_test.go +++ b/backend/aashub/tests/unit/user_test.go @@ -45,6 +45,10 @@ func (m *MockRepository) Verify(email, code string) (string, error) { return "", nil } +func (m *MockRepository) IsVerified(email string) (bool, error) { + return true, nil +} + func TestRegisterUser_Success(t *testing.T) { mockRepo := &MockRepository{} handler := api.UserHandler{Repo: mockRepo} @@ -129,13 +133,13 @@ func TestVerifyUser_Success(t *testing.T) { // Check the status code is what we expect. if status := rr.Code; status != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK) + t.Fatalf("handler returned wrong status code: got %v want %v", status, http.StatusOK) } // Check the response body is what we expect. expected := "User verified successfully" if rr.Body.String() != expected { - t.Errorf("handler returned unexpected body: got %v want %v", rr.Body.String(), expected) + t.Fatalf("handler returned unexpected body: got %v want %v", rr.Body.String(), expected) } } @@ -168,7 +172,7 @@ func TestVerifyUser_Failure(t *testing.T) { // Check the status code is what we expect for failure. if status := rr.Code; status != http.StatusBadRequest { - t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusBadRequest) + t.Fatalf("handler returned wrong status code: got %v want %v", status, http.StatusBadRequest) } // Check the response body is what we expect for failure. @@ -176,6 +180,6 @@ func TestVerifyUser_Failure(t *testing.T) { expected := strings.TrimSpace("Invalid email or code") if actual != expected { - t.Errorf("handler returned unexpected body: got %v want %v", actual, expected) + t.Fatalf("handler returned unexpected body: got %v want %v", actual, expected) } }