Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improves user verification feature #35

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions backend/aashub/api/handler/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,12 @@ func (h *UserHandler) RegisterUser(w http.ResponseWriter, r *http.Request) {

err = h.Repo.RegisterUser(user.Username, user.Email, user.Password)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
switch err {
case repositories.ErrUserRepoEmailUsernameExists:
http.Error(w, "Email or username already exists", http.StatusBadRequest)
default:
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}

w.WriteHeader(http.StatusCreated)
Expand Down Expand Up @@ -131,12 +135,14 @@ func (h *UserHandler) LoginUser(w http.ResponseWriter, r *http.Request) {

token, err := h.Repo.LoginUser(identifier, password)
if err != nil {
if err == repositories.ErrUserRepoNotFound {
switch err {
case repositories.ErrUserRepoNotFound:
http.Error(w, "User not found", http.StatusNotFound)
return
case repositories.ErrUserRepoNotVerified:
http.Error(w, "User not verified", http.StatusForbidden)
default:
http.Error(w, err.Error(), http.StatusInternalServerError)
}
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

// Create a cookie
Expand Down
10 changes: 8 additions & 2 deletions backend/aashub/cmd/aashub/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
docs "github.com/aas-hub-org/aashub/docs"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"github.com/joho/godotenv"
swaggerfiles "github.com/swaggo/files"
ginSwagger "github.com/swaggo/gin-swagger"
)
Expand Down Expand Up @@ -46,12 +47,17 @@ func main() {
// Initialize database
database, err := database.NewDB()
if err != nil {
log.Fatalf("Could not connect to the database: %v", err)
log.Printf("Could not connect to the database: %v", err)
}

env_err := godotenv.Load("/workspace/backend/aashub/.env")
if env_err != nil {
log.Printf("Error loading .env file")
}

// Initialize repositories
verificationRepo := &repositories.VerificationRepository{DB: database}
mailVerificationRepo := &repositories.EmailVerificationRepository{VerificationRepository: verificationRepo}
mailVerificationRepo := &repositories.VerificationRepository{DB: database}
userRepo := &repositories.UserRepository{DB: database, VerificationRepository: mailVerificationRepo}

// Initialize handlers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,7 @@ func (e *EmailVerificationRepository) Verify(email string, verificationCode stri
}
return "", nil
}

func (e *EmailVerificationRepository) IsVerified(email string) (bool, error) {
return e.VerificationRepository.IsVerified(email)
}
43 changes: 35 additions & 8 deletions backend/aashub/internal/database/repositories/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"database/sql"
"errors"
"log"
"os"
"strings"

auth "github.com/aas-hub-org/aashub/internal/auth"
interfaces "github.com/aas-hub-org/aashub/internal/interfaces"
Expand All @@ -14,6 +16,8 @@ import (
)

var ErrUserRepoNotFound = errors.New("identifier or password wrong")
var ErrUserRepoNotVerified = errors.New("user not verified")
var ErrUserRepoEmailUsernameExists = errors.New("email or username already exists")

type UserRepository struct {
DB *sql.DB
Expand All @@ -39,20 +43,23 @@ func (repo *UserRepository) RegisterUser(username string, email string, password
userid := uuid.New().String()
hashedpassword, err := HashPassword(password)
if err != nil {
log.Fatalf("Error hashing password: %v", err)
log.Printf("Error hashing password: %v", err)
return err
}
_, err = repo.DB.Exec("INSERT INTO Users (id, username, email, password_hash) VALUES (?, ?, ?, ?)", userid, username, email, hashedpassword)

if err != nil {
log.Fatalf("Error inserting user: %v", err)
// Check if error includes "Duplicate"
if strings.Contains(err.Error(), "Duplicate") {
return ErrUserRepoEmailUsernameExists
}
return err
}

_, err = repo.VerificationRepository.CreateVerification(email)

if IsVerificationEnabled() {
_, err = repo.VerificationRepository.CreateVerification(email)
}
if err != nil {
log.Fatalf("Error inserting verification: %v", err)
log.Printf("Error inserting verification: %v", err)
return err
}

Expand All @@ -68,21 +75,41 @@ func (repo *UserRepository) LoginUser(identifier string, password string) (strin
if err != nil {
return "", ErrUserRepoNotFound
}
if IsVerificationEnabled() {
isVerified, err := repo.VerificationRepository.IsVerified(user.Email)
if err != nil {
log.Printf("Error checking verification: %v", err)
return "", err
}
if !isVerified {
return "", ErrUserRepoNotVerified
}
}

if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
return "", ErrUserRepoNotFound
}

secret, fileReadError := utils.ReadFile("/workspace/backend/aashub/privatekey.txt")

if fileReadError != nil {
log.Fatalf("Error reading file: %v", fileReadError)
log.Printf("Error reading file: %v", fileReadError)
return "", fileReadError
}

jwt, err := auth.GenerateJWT(user.ID, secret)
if err != nil {
log.Fatalf("Error generating JWT: %v", err)
log.Printf("Error generating JWT: %v", err)
return "", err
}

return jwt, nil
}

func IsVerificationEnabled() bool {
enabled, found := os.LookupEnv("VERIFICATION_ENABLED")
if !found {
return false
}
return enabled == "true"
}
17 changes: 15 additions & 2 deletions backend/aashub/internal/database/repositories/verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (v *VerificationRepository) Verify(email string, verificationCode string) (
result, select_err := v.DB.Query("SELECT * FROM Verifications WHERE email = ? AND verification_code = ? AND verified = ?", email, verificationCode, false)

if select_err != nil {
log.Fatalf(select_err.Error())
log.Printf(select_err.Error())
return "system", select_err
}

Expand All @@ -61,9 +61,22 @@ func (v *VerificationRepository) Verify(email string, verificationCode string) (

_, err := v.DB.Exec("UPDATE Verifications SET verified = ? WHERE email = ? AND verification_code = ?", true, email, verificationCode)
if err != nil {
log.Fatalf(err.Error())
log.Printf(err.Error())
return "system", err
}

return "", nil
}

func (v *VerificationRepository) IsVerified(email string) (bool, error) {
result, select_err := v.DB.Query("SELECT * FROM Verifications WHERE email = ? && verified = false", email)

if select_err != nil {
log.Printf(select_err.Error())
return false, select_err
}

defer result.Close()

return !result.Next(), nil
}
1 change: 1 addition & 0 deletions backend/aashub/internal/interfaces/verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ package interfaces
type VerificationRepositoryInterface interface {
CreateVerification(email string) (string, error)
Verify(email string, verificationCode string) (string, error)
IsVerified(email string) (bool, error)
}
18 changes: 9 additions & 9 deletions backend/aashub/internal/mail/mail.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func SendEmail(to, subject, body string) error {
// Load .env
env_err := godotenv.Load("/workspace/backend/aashub/.env")
if env_err != nil {
log.Fatalf("Error loading .env file")
log.Printf("Error loading .env file")
}

// Get from .env
Expand All @@ -41,49 +41,49 @@ func SendEmail(to, subject, body string) error {
// Connect to the SMTP Server
conn, err := tls.Dial("tcp", smtpHost+":"+smtpPort, tlsconfig)
if err != nil {
log.Fatalf("Error connecting to SMTP server: %v", err)
log.Printf("Error connecting to SMTP server: %v", err)
return err
}

client, err := smtp.NewClient(conn, smtpHost)
if err != nil {
log.Fatalf("Error creating SMTP client: %v", err)
log.Printf("Error creating SMTP client: %v", err)
return err
}

// Authentication
auth := smtp.PlainAuth("", from, pass, smtpHost)
if err = client.Auth(auth); err != nil {
log.Fatalf("Error authenticating: %v", err)
log.Printf("Error authenticating: %v", err)
return err
}

// To && From
if err = client.Mail(from); err != nil {
log.Fatalf("Error setting sender: %v", err)
log.Printf("Error setting sender: %v", err)
return err
}
if err = client.Rcpt(to); err != nil {
log.Fatalf("Error setting recipient: %v", err)
log.Printf("Error setting recipient: %v", err)
return err
}

// Data
w, err := client.Data()
if err != nil {
log.Fatalf("Error getting SMTP data writer: %v", err)
log.Printf("Error getting SMTP data writer: %v", err)
return err
}

_, err = w.Write(message)
if err != nil {
log.Fatalf("Error writing message: %v", err)
log.Printf("Error writing message: %v", err)
return err
}

err = w.Close()
if err != nil {
log.Fatalf("Error closing SMTP data writer: %v", err)
log.Printf("Error closing SMTP data writer: %v", err)
return err
}

Expand Down
16 changes: 9 additions & 7 deletions backend/aashub/tests/integration/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"os"

"io"
"log"
Expand Down Expand Up @@ -37,15 +38,16 @@ func teardown(database *sql.DB) {

// Execute the query for the test username
if _, err := database.Exec(query, "testuser"); err != nil {
log.Fatalf("Failed to clean up test user: %v", err)
log.Printf("Failed to clean up test user: %v", err)
}
}

func TestRegisterUser(t *testing.T) {
os.Setenv("VERIFICATION_ENABLED", "true")
// Initialize the database connection
database, err := db.NewDB()
if err != nil {
log.Fatalf("Could not connect to the database: %v", err)
log.Printf("Could not connect to the database: %v", err)
}

// Ensure teardown is called no matter what happens in the test
Expand Down Expand Up @@ -121,12 +123,12 @@ func TestRegisterUser(t *testing.T) {

// Assert the status code
if resp.StatusCode != tc.expectedStatus {
t.Errorf("Expected status code %d, got %d", tc.expectedStatus, resp.StatusCode)
t.Fatalf("Expected status code %d, got %d", tc.expectedStatus, resp.StatusCode)
}

// Assert the response body if expected
if tc.expectedBody != "" && !strings.Contains(string(responseBody), tc.expectedBody) {
t.Errorf("Expected response body to contain %q, got %q", tc.expectedBody, string(responseBody))
t.Fatalf("Expected response body to contain %q, got %q", tc.expectedBody, string(responseBody))
}

// Call the verify user function
Expand Down Expand Up @@ -177,10 +179,10 @@ func verifyUser(t *testing.T, tc testCase, ts *httptest.Server, database *sql.DB

// Assert the verification status code and response body
if verifyResp.StatusCode != http.StatusOK {
t.Errorf("Expected verification status code %d, got %d", http.StatusOK, verifyResp.StatusCode)
t.Fatalf("Expected verification status code %d, got %d", http.StatusOK, verifyResp.StatusCode)
}
if !strings.Contains(string(verifyResponseBody), "User verified successfully") {
t.Errorf("Expected verification response body to contain %q, got %q", "User verified successfully", string(verifyResponseBody))
t.Fatalf("Expected verification response body to contain %q, got %q", "User verified successfully", string(verifyResponseBody))
}
}
}
Expand Down Expand Up @@ -227,6 +229,6 @@ func TestLoginUser(t *testing.T) {

// Check the status code is what we expect
if status := rr.Code; status != http.StatusNoContent {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusNoContent)
t.Fatalf("handler returned wrong status code: got %v want %v", status, http.StatusNoContent)
}
}
4 changes: 2 additions & 2 deletions backend/aashub/tests/unit/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestGenerateJWTAndValidate(t *testing.T) {
}

if !isValid {
t.Errorf("The token was expected to be valid")
t.Fatalf("The token was expected to be valid")
}
}

Expand Down Expand Up @@ -56,7 +56,7 @@ func TestGenerateJWTAndValidateWithManipulatedPayload(t *testing.T) {
}

if !isValid {
t.Errorf("The token was expected to be valid")
t.Fatalf("The token was expected to be valid")
}

// Split token at .
Expand Down
12 changes: 8 additions & 4 deletions backend/aashub/tests/unit/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ func (m *MockRepository) Verify(email, code string) (string, error) {
return "", nil
}

func (m *MockRepository) IsVerified(email string) (bool, error) {
return true, nil
}

func TestRegisterUser_Success(t *testing.T) {
mockRepo := &MockRepository{}
handler := api.UserHandler{Repo: mockRepo}
Expand Down Expand Up @@ -129,13 +133,13 @@ func TestVerifyUser_Success(t *testing.T) {

// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
t.Fatalf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
}

// Check the response body is what we expect.
expected := "User verified successfully"
if rr.Body.String() != expected {
t.Errorf("handler returned unexpected body: got %v want %v", rr.Body.String(), expected)
t.Fatalf("handler returned unexpected body: got %v want %v", rr.Body.String(), expected)
}
}

Expand Down Expand Up @@ -168,14 +172,14 @@ func TestVerifyUser_Failure(t *testing.T) {

// Check the status code is what we expect for failure.
if status := rr.Code; status != http.StatusBadRequest {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusBadRequest)
t.Fatalf("handler returned wrong status code: got %v want %v", status, http.StatusBadRequest)
}

// Check the response body is what we expect for failure.
actual := strings.TrimSpace(rr.Body.String())
expected := strings.TrimSpace("Invalid email or code")

if actual != expected {
t.Errorf("handler returned unexpected body: got %v want %v", actual, expected)
t.Fatalf("handler returned unexpected body: got %v want %v", actual, expected)
}
}
Loading