From c8877535762baa74926b7f4e21d758eb6e82e507 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Sun, 16 Apr 2023 14:32:45 +0800 Subject: [PATCH] rate limit (#118) --- api/auth_user_handler.go | 1 - api/middleware_authenticate.go | 8 +++++++- api/middleware_rateLimit.go | 18 ++++++------------ api/sqlc/queries/auth_user_management.sql | 3 ++- api/sqlc_queries/auth_user_management.sql.go | 2 +- web/src/api/index.ts | 11 ----------- 6 files changed, 16 insertions(+), 27 deletions(-) diff --git a/api/auth_user_handler.go b/api/auth_user_handler.go index fe85c67c..73ad52f0 100644 --- a/api/auth_user_handler.go +++ b/api/auth_user_handler.go @@ -30,7 +30,6 @@ func (h *AuthUserHandler) Register(router *mux.Router) { router.HandleFunc("/users/{id}", h.UpdateUser).Methods(http.MethodPut) router.HandleFunc("/signup", h.SignUp).Methods(http.MethodPost) router.HandleFunc("/login", h.Login).Methods(http.MethodPost) - router.HandleFunc("/verify", h.verify).Methods(http.MethodPost) router.HandleFunc("/config", h.configHandler).Methods(http.MethodPost) // rate limit handler router.HandleFunc("/admin/rate_limit", h.UpdateRateLimit).Methods(http.MethodPost) diff --git a/api/middleware_authenticate.go b/api/middleware_authenticate.go index 5bfc364b..ba6831c2 100644 --- a/api/middleware_authenticate.go +++ b/api/middleware_authenticate.go @@ -48,9 +48,14 @@ const ( ) func IsAuthorizedMiddleware(handler http.Handler) http.Handler { + noAuthPaths := map[string]bool{ + "/": true, + "/login": true, + "/signup": true, + } jwtSigningKey := []byte(jwtSecretAndAud.Secret) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/" || r.URL.Path == "/verify" || r.URL.Path == "/login" || r.URL.Path == "/signup" || strings.HasPrefix(r.URL.Path, "/static") { + if _, ok := noAuthPaths[r.URL.Path]; ok || strings.HasPrefix(r.URL.Path, "/static") { handler.ServeHTTP(w, r) return } @@ -113,3 +118,4 @@ func IsAuthorizedMiddleware(handler http.Handler) http.Handler { } }) } + diff --git a/api/middleware_rateLimit.go b/api/middleware_rateLimit.go index 6c8489c9..ba78c88c 100644 --- a/api/middleware_rateLimit.go +++ b/api/middleware_rateLimit.go @@ -1,10 +1,7 @@ package main import ( - "database/sql" - "errors" "net/http" - "strconv" "github.com/rotisserie/eris" "github.com/swuecho/chat_backend/sqlc_queries" @@ -14,13 +11,15 @@ import ( func RateLimitByUserID(q *sqlc_queries.Queries) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Get the user ID from the request, e.g. from a JWT token. if r.URL.Path == "/chat" || r.URL.Path == "/chat_stream" { ctx := r.Context() - userIDStr := ctx.Value(userContextKey).(string) - userIDInt, err := strconv.Atoi(userIDStr) + userIDInt, err := getUserID(ctx) + // role := ctx.Value(roleContextKey).(string) + if err != nil { - http.Error(w, "Error: '"+userIDStr+"' is not a valid user ID. Please enter a valid user ID.", http.StatusBadRequest) + RespondWithError(w, http.StatusUnauthorized, "no user", err) return } messageCount, err := q.GetChatMessagesCount(r.Context(), int32(userIDInt)) @@ -30,12 +29,7 @@ func RateLimitByUserID(q *sqlc_queries.Queries) func(http.Handler) http.Handler } maxRate, err := q.GetRateLimit(r.Context(), int32(userIDInt)) if err != nil { - if errors.Is(err, sql.ErrNoRows) { - maxRate = int32(appConfig.OPENAI.RATELIMIT) - } else { - http.Error(w, "Could not get rate limit.", http.StatusInternalServerError) - return - } + maxRate = int32(appConfig.OPENAI.RATELIMIT) } if messageCount >= int64(maxRate) { diff --git a/api/sqlc/queries/auth_user_management.sql b/api/sqlc/queries/auth_user_management.sql index d48ad175..2ad0847f 100644 --- a/api/sqlc/queries/auth_user_management.sql +++ b/api/sqlc/queries/auth_user_management.sql @@ -1,6 +1,7 @@ -- name: GetRateLimit :one -- GetRateLimit retrieves the rate limit for a user from the auth_user_management table. -- If no rate limit is set for the user, it returns the default rate limit of 100. -SELECT COALESCE(rate_limit, 100) AS rate_limit +SELECT rate_limit AS rate_limit FROM auth_user_management WHERE user_id = $1; + diff --git a/api/sqlc_queries/auth_user_management.sql.go b/api/sqlc_queries/auth_user_management.sql.go index 871bfc3a..51b84198 100644 --- a/api/sqlc_queries/auth_user_management.sql.go +++ b/api/sqlc_queries/auth_user_management.sql.go @@ -10,7 +10,7 @@ import ( ) const getRateLimit = `-- name: GetRateLimit :one -SELECT COALESCE(rate_limit, 100) AS rate_limit +SELECT rate_limit AS rate_limit FROM auth_user_management WHERE user_id = $1 ` diff --git a/web/src/api/index.ts b/web/src/api/index.ts index 827e3575..ae8f78c8 100644 --- a/web/src/api/index.ts +++ b/web/src/api/index.ts @@ -77,17 +77,6 @@ export function fetchChatAPIProcess( }) } -export async function fetchVerify(token: string) { - try { - const response = await request.post('/verify', { token }) - return response.data - } - catch (error) { - console.error(error) - throw error - } -} - export async function fetchLogin(email: string, password: string) { try { const response = await request.post('/login', { email, password })