Skip to content

Commit

Permalink
rate limit (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
swuecho authored Apr 16, 2023
1 parent ee6704b commit c887753
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 27 deletions.
1 change: 0 additions & 1 deletion api/auth_user_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ func (h *AuthUserHandler) Register(router *mux.Router) {
router.HandleFunc("/users/{id}", h.UpdateUser).Methods(http.MethodPut)
router.HandleFunc("/signup", h.SignUp).Methods(http.MethodPost)
router.HandleFunc("/login", h.Login).Methods(http.MethodPost)
router.HandleFunc("/verify", h.verify).Methods(http.MethodPost)
router.HandleFunc("/config", h.configHandler).Methods(http.MethodPost)
// rate limit handler
router.HandleFunc("/admin/rate_limit", h.UpdateRateLimit).Methods(http.MethodPost)
Expand Down
8 changes: 7 additions & 1 deletion api/middleware_authenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,14 @@ const (
)

func IsAuthorizedMiddleware(handler http.Handler) http.Handler {
noAuthPaths := map[string]bool{
"/": true,
"/login": true,
"/signup": true,
}
jwtSigningKey := []byte(jwtSecretAndAud.Secret)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" || r.URL.Path == "/verify" || r.URL.Path == "/login" || r.URL.Path == "/signup" || strings.HasPrefix(r.URL.Path, "/static") {
if _, ok := noAuthPaths[r.URL.Path]; ok || strings.HasPrefix(r.URL.Path, "/static") {
handler.ServeHTTP(w, r)
return
}
Expand Down Expand Up @@ -113,3 +118,4 @@ func IsAuthorizedMiddleware(handler http.Handler) http.Handler {
}
})
}

18 changes: 6 additions & 12 deletions api/middleware_rateLimit.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package main

import (
"database/sql"
"errors"
"net/http"
"strconv"

"github.com/rotisserie/eris"
"github.com/swuecho/chat_backend/sqlc_queries"
Expand All @@ -14,13 +11,15 @@ import (
func RateLimitByUserID(q *sqlc_queries.Queries) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

// Get the user ID from the request, e.g. from a JWT token.
if r.URL.Path == "/chat" || r.URL.Path == "/chat_stream" {
ctx := r.Context()
userIDStr := ctx.Value(userContextKey).(string)
userIDInt, err := strconv.Atoi(userIDStr)
userIDInt, err := getUserID(ctx)
// role := ctx.Value(roleContextKey).(string)

if err != nil {
http.Error(w, "Error: '"+userIDStr+"' is not a valid user ID. Please enter a valid user ID.", http.StatusBadRequest)
RespondWithError(w, http.StatusUnauthorized, "no user", err)
return
}
messageCount, err := q.GetChatMessagesCount(r.Context(), int32(userIDInt))
Expand All @@ -30,12 +29,7 @@ func RateLimitByUserID(q *sqlc_queries.Queries) func(http.Handler) http.Handler
}
maxRate, err := q.GetRateLimit(r.Context(), int32(userIDInt))
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
maxRate = int32(appConfig.OPENAI.RATELIMIT)
} else {
http.Error(w, "Could not get rate limit.", http.StatusInternalServerError)
return
}
maxRate = int32(appConfig.OPENAI.RATELIMIT)
}

if messageCount >= int64(maxRate) {
Expand Down
3 changes: 2 additions & 1 deletion api/sqlc/queries/auth_user_management.sql
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
-- name: GetRateLimit :one
-- GetRateLimit retrieves the rate limit for a user from the auth_user_management table.
-- If no rate limit is set for the user, it returns the default rate limit of 100.
SELECT COALESCE(rate_limit, 100) AS rate_limit
SELECT rate_limit AS rate_limit
FROM auth_user_management
WHERE user_id = $1;

2 changes: 1 addition & 1 deletion api/sqlc_queries/auth_user_management.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 0 additions & 11 deletions web/src/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,6 @@ export function fetchChatAPIProcess<T>(
})
}

export async function fetchVerify(token: string) {
try {
const response = await request.post('/verify', { token })
return response.data
}
catch (error) {
console.error(error)
throw error
}
}

export async function fetchLogin(email: string, password: string) {
try {
const response = await request.post('/login', { email, password })
Expand Down

0 comments on commit c887753

Please sign in to comment.