Review notes (free text before the first "diff --git" header is skipped by
"git apply" / "patch"):

  * This patch was re-flowed from a whitespace-collapsed copy. Indentation and
    blank context lines were reconstructed from the hunk line counts; apply
    with "git apply --ignore-whitespace" if a context line does not match.
  * Post-image blob ids on the "index" lines of src/chatgpt_router.py,
    src/divination_router.py, src/limiter.py and src/user_router.py were not
    recomputed after the fixes below.
  * src/chatgpt_router.py: removed the debug stub (for i in range(100) ...
    return) that made the OpenAI streaming call unreachable.
  * src/divination_router.py: openai>=1.0 does not accept api_key= on
    chat.completions.create(); the per-request token is now applied with
    client.with_options(api_key=token). The stream loop now tests
    delta.content by attribute instead of '"content" in delta'.
  * src/limiter.py: the limit is checked before the request is recorded, so
    exactly max_requests requests pass per window and rejected requests do
    not extend the block; list.pop(0) replaced by deque.popleft(); the
    catch-all that mapped internal errors to HTTP 400 was dropped.
  * src/user_router.py: oauth() stays a plain "def" because it calls the
    blocking requests.post(); FastAPI runs sync handlers in a threadpool.

diff --git a/config.py b/config.py
deleted file mode 100644
index cb15742..0000000
--- a/config.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import logging
-
-from pydantic import BaseSettings, Field
-
-logging.basicConfig(
-    format="%(asctime)s: %(levelname)s: %(name)s: %(message)s",
-    level=logging.INFO
-)
-_logger = logging.getLogger(__name__)
-
-
-class Settings(BaseSettings):
-    api_key: str = Field(default="sk-xxx", exclude=True)
-    api_base: str = "https://api.openai.com/v1"
-    model: str = "gpt-3.5-turbo"
-    rate_limit: str = "60/hour"
-    user_rate_limit: str = "600/hour"
-    github_client_id: str = ""
-    github_client_secret: str = Field(default="", exclude=True)
-    jwt_secret: str = Field(default="secret", exclude=True)
-    ad_client: str = ""
-    ad_slot: str = ""
-
-    class Config:
-        env_file = ".env"
-
-
-settings = Settings()
-_logger.info(f"settings: {settings.json(indent=2)}")
diff --git a/frontend/src/views/Index.vue b/frontend/src/views/Index.vue
index 12cc9a4..1d970e9 100644
--- a/frontend/src/views/Index.vue
+++ b/frontend/src/views/Index.vue
@@ -65,8 +65,8 @@ const onSubmit = async () => {
       async onopen(response) {
         if (response.ok && response.headers.get('content-type') === EventStreamContentType) {
           return;
-        } else if (response.status >= 400 && response.status < 500 && response.status !== 429) {
-          throw new Error(`占卜失败: ${response.status}`);
+        } else if (response.status >= 400 && response.status < 500) {
+          throw new Error(`${response.status} ${await response.text()}`);
         }
       },
       onmessage(msg) {
diff --git a/main.py b/main.py
index 044580d..304a643 100644
--- a/main.py
+++ b/main.py
@@ -1,64 +1,17 @@
-import os
 import logging
 import uvicorn
 
-from fastapi import FastAPI, Request, status
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import PlainTextResponse, FileResponse
-from fastapi.staticfiles import StaticFiles
+from src.app import app
+from src.config import settings
 
-from slowapi import _rate_limit_exceeded_handler
-from slowapi.errors import RateLimitExceeded
-from slowapi.middleware import SlowAPIMiddleware
-
-from src.limiter import limiter, get_real_ipaddr
-from src.chatgpt_router import router as chatgpt_router
-from src.divination_router import router as divination_router
-from src.user_router import router as user_router
-
-
-_logger = logging.getLogger(__name__)
-
-app = FastAPI(title="Chatgpt Divination API")
-app.state.limiter = limiter
-app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
-app.add_middleware(SlowAPIMiddleware)
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=[
-        "http://localhost:5173",
-        "http://localhost",
-        "http://127.0.0.1"
-    ],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
+logging.basicConfig(
+    format="%(asctime)s: %(levelname)s: %(name)s: %(message)s",
+    level=logging.INFO
 )
+_logger = logging.getLogger(__name__)
 
-app.include_router(chatgpt_router)
-app.include_router(divination_router)
-app.include_router(user_router)
-
-if os.path.exists("dist"):
-    @app.get("/")
-    @app.get("/login/{path}")
-    async def read_index(request: Request):
-        _logger.info(f"Request from {get_real_ipaddr(request)}")
-        return FileResponse(
-            "dist/index.html",
-            headers={"Cache-Control": "no-cache"}
-        )
-
-    app.mount("/", StaticFiles(directory="dist"), name="static")
-
+_logger.info(f"settings: {settings.model_dump_json(indent=2)}")
 
-@app.exception_handler(Exception)
-async def exception_handler(request: Request, exc: Exception):
-    return PlainTextResponse(
-        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-        content=f"Internal Server Error: {exc}",
-    )
 
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/requirements.txt b/requirements.txt
index 343b546..e1324e5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
-fastapi==0.95.0
-python-dotenv==1.0.0
-uvicorn==0.21.1
-openai==0.27.4
-python-multipart==0.0.6
-slowapi==0.1.8
-pyjwt==2.7.0
+fastapi==0.110.2
+pydantic-settings==2.2.1
+pydantic==2.7.1
+uvicorn==0.29.0
+openai==1.23.6
+pyjwt==2.8.0
+requests==2.31.0
diff --git a/src/app.py b/src/app.py
new file mode 100644
index 0000000..b169031
--- /dev/null
+++ b/src/app.py
@@ -0,0 +1,52 @@
+import os
+import logging
+
+from fastapi import FastAPI, Request, status
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import PlainTextResponse, FileResponse
+from fastapi.staticfiles import StaticFiles
+
+from src.limiter import get_real_ipaddr
+from src.chatgpt_router import router as chatgpt_router
+from src.divination_router import router as divination_router
+from src.user_router import router as user_router
+
+
+_logger = logging.getLogger(__name__)
+
+app = FastAPI(title="Chatgpt Divination API")
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=[
+        "http://localhost:5173",
+        "http://localhost",
+        "http://127.0.0.1"
+    ],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+app.include_router(chatgpt_router)
+app.include_router(divination_router)
+app.include_router(user_router)
+
+if os.path.exists("dist"):
+    @app.get("/")
+    @app.get("/login/{path}")
+    async def read_index(request: Request):
+        _logger.info(f"Request from {get_real_ipaddr(request)}")
+        return FileResponse(
+            "dist/index.html",
+            headers={"Cache-Control": "no-cache"}
+        )
+
+    app.mount("/", StaticFiles(directory="dist"), name="static")
+
+
+@app.exception_handler(Exception)
+async def exception_handler(request: Request, exc: Exception):
+    return PlainTextResponse(
+        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+        content=f"Internal Server Error: {exc}",
+    )
diff --git a/src/chatgpt_router.py b/src/chatgpt_router.py
index e438794..30e64b0 100644
--- a/src/chatgpt_router.py
+++ b/src/chatgpt_router.py
@@ -1,22 +1,22 @@
 import json
 from typing import Optional
 from fastapi.responses import StreamingResponse
-import openai
+from openai import OpenAI
+
 import logging
 
 from fastapi import Depends, HTTPException, Request, status
 
 
-from config import settings
+from src.config import settings
 from fastapi import APIRouter
 
-from models import DivinationBody, User
+from src.models import DivinationBody, User
 from src.user import get_user
-from .limiter import get_real_ipaddr, limiter
-from .divination import DivinationFactory
+from src.limiter import get_real_ipaddr, check_rate_limit
+from src.divination import DivinationFactory
 
-openai.api_key = settings.api_key
-openai.api_base = settings.api_base
+client = OpenAI(api_key=settings.api_key, base_url=settings.api_base)
 router = APIRouter()
 _logger = logging.getLogger(__name__)
 STOP_WORDS = [
@@ -26,25 +26,6 @@
 ]
 
 
-@limiter.limit(settings.rate_limit)
-def limit_when_not_login(request: Request):
-    """
-    Limit when not login
-    """
-
-
-def limit_when_login(request: Request, user: User):
-    """
-    Limit when login
-    """
-    @limiter.limit(settings.user_rate_limit, key_func=lambda: (user.user_name, user.login_type))
-    def limit(request: Request):
-        """
-        Limit when login
-        """
-    limit(request)
-
-
 @router.post("/api/divination")
 async def divination(
     request: Request,
@@ -52,14 +33,21 @@ async def divination(
     user: Optional[User] = Depends(get_user)
 ):
 
+    real_ip = get_real_ipaddr(request)
     # rate limit when not login
     if not user:
-        limit_when_not_login(request)
+        max_reqs, time_window_seconds = settings.rate_limit
+        check_rate_limit(real_ip, time_window_seconds, max_reqs)
     else:
-        limit_when_login(request, user)
+        max_reqs, time_window_seconds = settings.user_rate_limit
+        check_rate_limit(
+            f"{user.login_type}:{user.user_name}", time_window_seconds, max_reqs
+        )
 
     _logger.info(
-        f"Request from {get_real_ipaddr(request)}, user={user.json(ensure_ascii=False) if user else None} body={divination_body.json(ensure_ascii=False)}"
+        f"Request from {real_ip}, "
+        f"user={user.model_dump_json(context=dict(ensure_ascii=False)) if user else None}, "
+        f"body={divination_body.model_dump_json(context=dict(ensure_ascii=False))}"
     )
     if any(w in divination_body.prompt.lower() for w in STOP_WORDS):
         raise HTTPException(
@@ -75,7 +63,7 @@ async def divination(
     prompt, system_prompt = divination_obj.build_prompt(divination_body)
 
     def get_openai_generator():
-        openai_stream = openai.ChatCompletion.create(
+        openai_stream = client.chat.completions.create(
             model=settings.model,
             max_tokens=1000,
             temperature=0.9,
@@ -90,8 +78,8 @@ def get_openai_generator():
             ]
         )
         for event in openai_stream:
-            if "content" in event["choices"][0].delta:
-                current_response = event["choices"][0].delta.content
+            if event.choices and event.choices[0].delta and event.choices[0].delta.content:
+                current_response = event.choices[0].delta.content
                 yield f"data: {json.dumps(current_response)}\n\n"
 
     return StreamingResponse(get_openai_generator(), media_type='text/event-stream')
diff --git a/src/config.py b/src/config.py
new file mode 100644
index 0000000..dd071d7
--- /dev/null
+++ b/src/config.py
@@ -0,0 +1,37 @@
+import logging
+from typing import Tuple
+
+from pydantic import Field
+from pydantic_settings import BaseSettings
+
+_logger = logging.getLogger(__name__)
+
+
+class Settings(BaseSettings):
+    api_key: str = Field(default="sk-xxx", exclude=True)
+    api_base: str = "https://api.openai.com/v1"
+    model: str = "gpt-3.5-turbo"
+    # rate limit xxx request per xx seconds
+    rate_limit: Tuple[int, int] = (60, 60 * 60)
+    user_rate_limit: Tuple[int, int] = (600, 60 * 60)
+    github_client_id: str = ""
+    github_client_secret: str = Field(default="", exclude=True)
+    jwt_secret: str = Field(default="secret", exclude=True)
+    ad_client: str = ""
+    ad_slot: str = ""
+
+    def get_human_rate_limit(self) -> str:
+        max_reqs, time_window_seconds = self.rate_limit
+        # convert to human readable format
+        return f"{max_reqs}req/{time_window_seconds}seconds"
+
+    def get_human_user_rate_limit(self) -> str:
+        max_reqs, time_window_seconds = self.user_rate_limit
+        # convert to human readable format
+        return f"{max_reqs}req/{time_window_seconds}seconds"
+
+    class Config:
+        env_file = ".env"
+
+
+settings = Settings()
diff --git a/src/divination/base.py b/src/divination/base.py
index 51d038e..ca9536c 100644
--- a/src/divination/base.py
+++ b/src/divination/base.py
@@ -1,5 +1,5 @@
 
-from models import DivinationBody
+from src.models import DivinationBody
 from typing import Optional
 
 
diff --git a/src/divination/birthday.py b/src/divination/birthday.py
index d38244b..768828a 100644
--- a/src/divination/birthday.py
+++ b/src/divination/birthday.py
@@ -1,5 +1,5 @@
 import datetime
-from models import DivinationBody
+from src.models import DivinationBody
 from .base import DivinationFactory
 
 BIRTHDAY_PROMPT = "我请求你担任中国传统的生辰八字算命的角色。" \
diff --git a/src/divination/dream.py b/src/divination/dream.py
index c3233f3..6895709 100644
--- a/src/divination/dream.py
+++ b/src/divination/dream.py
@@ -1,5 +1,5 @@
 from fastapi import HTTPException
-from models import DivinationBody
+from src.models import DivinationBody
 from .base import DivinationFactory
 
 DREAM_PROMPT = "我请求你担任中国传统的周公解梦师的角色。" \
diff --git a/src/divination/fate.py b/src/divination/fate.py
index 2d567d1..1b77372 100644
--- a/src/divination/fate.py
+++ b/src/divination/fate.py
@@ -1,5 +1,5 @@
 from fastapi import HTTPException
-from models import DivinationBody
+from src.models import DivinationBody
 from .base import DivinationFactory
 
 SYS_PROMPT = "你是一个姻缘助手,我给你发两个人的名字,用逗号隔开,"\
diff --git a/src/divination/name.py b/src/divination/name.py
index ee560dc..1e2925a 100644
--- a/src/divination/name.py
+++ b/src/divination/name.py
@@ -1,5 +1,5 @@
 from fastapi import HTTPException
-from models import DivinationBody
+from src.models import DivinationBody
 from .base import DivinationFactory
 
 NAME_PROMPT = "我请求你担任中国传统的姓名五格算命师的角色。" \
diff --git a/src/divination/new_name.py b/src/divination/new_name.py
index 92c6c8c..03e37d2 100644
--- a/src/divination/new_name.py
+++ b/src/divination/new_name.py
@@ -1,6 +1,6 @@
 import datetime
 from fastapi import HTTPException
-from models import DivinationBody
+from src.models import DivinationBody
 from .base import DivinationFactory
 
 NEW_NAME_PROMPT = (
diff --git a/src/divination/plum_flower.py b/src/divination/plum_flower.py
index b7d0254..1ae2f26 100644
--- a/src/divination/plum_flower.py
+++ b/src/divination/plum_flower.py
@@ -1,5 +1,5 @@
 from fastapi import HTTPException
-from models import DivinationBody
+from src.models import DivinationBody
 from .base import DivinationFactory
 
 SYS_PROMPT = "我请求你担任中国传统的梅花易数占卜师的角色。" \
diff --git a/src/divination/tarot.py b/src/divination/tarot.py
index bca2c2c..2ecfeb0 100644
--- a/src/divination/tarot.py
+++ b/src/divination/tarot.py
@@ -1,5 +1,5 @@
 from fastapi import HTTPException
-from models import DivinationBody
+from src.models import DivinationBody
 from .base import DivinationFactory
 
 TAROT_PROMPT = "我请求你担任塔罗占卜师的角色。 " \
diff --git a/src/divination_router.py b/src/divination_router.py
index a310001..76c58ba 100644
--- a/src/divination_router.py
+++ b/src/divination_router.py
@@ -1,6 +1,8 @@
 import json
 import uuid
 import openai
+from openai import OpenAI
+
 import logging
 
 from datetime import datetime
@@ -8,15 +10,14 @@
 from fastapi.responses import StreamingResponse
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 
-from config import settings
+from src.config import settings
 from fastapi import APIRouter
 
-from models import BirthdayBody, CommonResponse
-from .limiter import get_real_ipaddr
+from src.models import BirthdayBody, CommonResponse
+from src.limiter import get_real_ipaddr
 from .divination.birthday import BirthdayFactory
 
-openai.api_key = settings.api_key
-openai.api_base = settings.api_base
+client = OpenAI(api_key="", base_url=settings.api_base)
 router = APIRouter()
 security = HTTPBearer()
 _logger = logging.getLogger(__name__)
@@ -65,7 +66,6 @@ def common_openai_streaming_call(
 ) -> StreamingResponse:
     def get_openai_generator():
         try:
-            openai_stream = openai.ChatCompletion.create(
-                api_key=token,
+            openai_stream = client.with_options(api_key=token).chat.completions.create(
                 model=settings.model,
                 max_tokens=1000,
@@ -80,14 +80,14 @@ def get_openai_generator():
                     {"role": "user", "content": prompt}
                 ]
             )
-        except openai.error.OpenAIError as e:
+        except openai.OpenAIError as e:
             raise HTTPException(
                 status_code=500,
                 detail=f"OpenAI error: {e}"
             )
         for event in openai_stream:
-            if "content" in event["choices"][0].delta:
-                current_response = event["choices"][0].delta.content
+            if event.choices and event.choices[0].delta and event.choices[0].delta.content:
+                current_response = event.choices[0].delta.content
                 yield current_response
 
     return StreamingResponse(get_openai_generator(), media_type='text/event-stream')
@@ -103,7 +103,6 @@ def common_openai_call(
     request_id = uuid.uuid4()
 
     try:
-        response = openai.ChatCompletion.create(
-            api_key=token,
+        response = client.with_options(api_key=token).chat.completions.create(
             model=settings.model,
             max_tokens=1000,
@@ -117,13 +116,13 @@ def common_openai_call(
                 {"role": "user", "content": prompt}
             ]
         )
-    except openai.error.OpenAIError as e:
+    except openai.OpenAIError as e:
         raise HTTPException(
             status_code=500,
             detail=f"OpenAI error: {e}"
         )
 
-    res = response['choices'][0]['message']['content']
+    res = response.choices[0].message.content
     latency = datetime.now() - start_time
     _logger.info(
         f"Request {request_id}:"
diff --git a/src/limiter.py b/src/limiter.py
index 0776569..c8243ca 100644
--- a/src/limiter.py
+++ b/src/limiter.py
@@ -1,6 +1,11 @@
-from fastapi import Request
+import time
+import logging
 
-from slowapi import Limiter
+from collections import defaultdict, deque
+from fastapi import HTTPException, Request
+
+_logger = logging.getLogger(__name__)
+request_limit_map = defaultdict(deque)
 
 
 def get_real_ipaddr(request: Request) -> str:
@@ -13,4 +18,20 @@ def get_real_ipaddr(request: Request) -> str:
     return request.client.host
 
 
-limiter = Limiter(key_func=get_real_ipaddr)
+def check_rate_limit(key: str, time_window_seconds: int, max_requests: int) -> None:
+    """
+    Sliding-window rate limit: allow at most max_requests per key
+    within the last time_window_seconds, raise HTTP 429 otherwise
+    """
+    cur_timestamp = int(time.time())
+    records = request_limit_map[key]
+    # remove expired records
+    while records and records[0] < (cur_timestamp - time_window_seconds):
+        records.popleft()
+    # reject before recording, so request number max_requests is still
+    # allowed and rejected requests do not extend the block
+    if len(records) >= max_requests:
+        raise HTTPException(
+            status_code=429, detail="Rate limit exceeded"
+        )
+    records.append(cur_timestamp)
diff --git a/models.py b/src/models.py
similarity index 100%
rename from models.py
rename to src/models.py
diff --git a/src/user.py b/src/user.py
index 2427f96..5745642 100644
--- a/src/user.py
+++ b/src/user.py
@@ -6,8 +6,8 @@
 from fastapi import Depends, status
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 
-from config import settings
-from models import User
+from src.config import settings
+from src.models import User
 from fastapi import HTTPException
 
 _logger = logging.getLogger(__name__)
diff --git a/src/user_router.py b/src/user_router.py
index 7222d51..4bdb8d5 100644
--- a/src/user_router.py
+++ b/src/user_router.py
@@ -1,13 +1,14 @@
-from typing import Optional
 import jwt
+import requests
 import datetime
 import logging
 
+from typing import Optional
+
 from fastapi import APIRouter, Depends, HTTPException, status
-import requests
 
-from config import settings
-from models import OauthBody, SettingsInfo, User
+from src.config import settings
+from src.models import OauthBody, SettingsInfo, User
 from src.user import get_user
 
 router = APIRouter()
@@ -23,19 +24,19 @@
 
 
 @router.get("/api/v1/settings", tags=["User"])
-def info(user: Optional[User] = Depends(get_user)):
+async def info(user: Optional[User] = Depends(get_user)):
     return SettingsInfo(
         login_type=user.login_type if user else "",
         user_name=user.user_name if user else "",
         ad_client=settings.ad_client,
         ad_slot=settings.ad_slot,
-        rate_limit=settings.rate_limit,
-        user_rate_limit=settings.user_rate_limit
+        rate_limit=settings.get_human_rate_limit(),
+        user_rate_limit=settings.get_human_user_rate_limit(),
     )
 
 
 @router.get("/api/v1/login", tags=["User"])
-def login(login_type: str, redirect_url: str):
+async def login(login_type: str, redirect_url: str):
     if login_type == "github":
         return f"{GITHUB_URL}&redirect_uri={redirect_url}"
     raise HTTPException(