From 2ba873c1c2cdd13e9726611deba718b19a15c840 Mon Sep 17 00:00:00 2001 From: dreamhunter2333 Date: Mon, 4 Mar 2024 22:45:03 +0800 Subject: [PATCH] feat: add api --- .flake8 | 5 ++ README.md | 1 - config.py | 12 +-- docker-compose.yaml | 1 - main.py | 2 + models.py | 11 ++- router/chatgpt_router.py | 9 +-- router/divination/__init__.py | 7 ++ router/divination/birthday.py | 5 +- router/divination_router.py | 138 ++++++++++++++++++++++++++++++++++ router/file_logger.py | 19 ----- router/user.py | 5 +- vercel.json | 22 ------ 13 files changed, 176 insertions(+), 61 deletions(-) create mode 100644 .flake8 create mode 100644 router/divination_router.py delete mode 100644 router/file_logger.py delete mode 100644 vercel.json diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..07c2714 --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +[flake8] +max-line-length = 120 +per-file-ignores = + # imported but unused + __init__.py: F401 diff --git a/README.md b/README.md index 66353d1..2e9d616 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,6 @@ services: # - model=gpt-3.5-turbo # optional # - rate_limit=10/minute # optional # - user_rate_limit=600/hour # optional - # - log_dir=/logs/ # optional - github_client_id=xxx - github_client_secret=xxx - jwt_secret=secret diff --git a/config.py b/config.py index bc5353a..f87991d 100644 --- a/config.py +++ b/config.py @@ -1,24 +1,23 @@ import logging -from typing import Optional -from pydantic import BaseSettings +from pydantic import BaseSettings, Field logging.basicConfig( format="%(asctime)s: %(levelname)s: %(name)s: %(message)s", level=logging.INFO ) +_logger = logging.getLogger(__name__) class Settings(BaseSettings): - api_key: str = "sk-xxx" + api_key: str = Field(default="sk-xxx", exclude=True) api_base: str = "https://api.openai.com/v1" model: str = "gpt-3.5-turbo" rate_limit: str = "60/hour" user_rate_limit: str = "600/hour" - log_dir: str = "" github_client_id: str = "" - github_client_secret: str = "" - jwt_secret: str = "secret" + github_client_secret: str = Field(default="", exclude=True) + jwt_secret: str = Field(default="secret", exclude=True) ad_client: str = "" ad_slot: str = "" @@ -27,3 +26,4 @@ class Config: settings = Settings() +_logger.info(f"settings: {settings.json()}") diff --git a/docker-compose.yaml b/docker-compose.yaml index c2db598..9250879 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -11,7 +11,6 @@ services: # - model=gpt-3.5-turbo # optional # - rate_limit=10/minute # optional # - user_rate_limit=600/hour # optional - # - log_dir=/logs/ # optional - github_client_id=xxx - github_client_secret=xxx - jwt_secret=secret diff --git a/main.py b/main.py index f2d485e..67634e7 100644 --- a/main.py +++ b/main.py @@ -14,6 +14,7 @@ from router.limiter import limiter, get_real_ipaddr from router.date_router import router as date_router from router.chatgpt_router import router as chatgpt_router +from router.divination_router import router as divination_router from router.user_router import router as user_router @@ -38,6 +39,7 @@ app.include_router(date_router) app.include_router(chatgpt_router) +app.include_router(divination_router) app.include_router(user_router) if os.path.exists("dist"): diff --git a/models.py b/models.py index 94cb1a8..9fbef89 100644 --- a/models.py +++ b/models.py @@ -1,5 +1,5 @@ from typing import Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field class SettingsInfo(BaseModel): @@ -46,3 +46,12 @@ class DivinationBody(BaseModel): new_name: Optional[NewName] = None plum_flower: Optional[PlumFlower] = None fate: Optional[Fate] = None + + +class BirthdayBody(BaseModel): + birthday: str = Field(example="2000-08-17 00:00:00") + + +class CommonResponse(BaseModel): + content: str + request_id: str diff --git a/router/chatgpt_router.py b/router/chatgpt_router.py index 1efdd18..965cf43 100644 --- a/router/chatgpt_router.py +++ b/router/chatgpt_router.py @@ -4,7 +4,6 @@ import openai import logging -from datetime import datetime from fastapi import Depends, HTTPException, Request, status @@ -15,7 +14,6 @@ from router.user import get_user from .limiter import get_real_ipaddr, limiter from .divination import DivinationFactory -from .file_logger import file_logger openai.api_key = settings.api_key openai.api_base = settings.api_base @@ -26,9 +24,6 @@ "幫助", "現在", "開始", "开始", "start", "restart", "重新开始", "重新開始", "遵守", "遵循", "遵从", "遵從" ] -_logger.info( - f"Loaded divination types: {list(DivinationFactory.divination_map.keys())}" -) @limiter.limit(settings.rate_limit) @@ -68,13 +63,13 @@ async def divination( ) if any(w in divination_body.prompt.lower() for w in STOP_WORDS): raise HTTPException( - status_code=403, + status_code=status.HTTP_403_FORBIDDEN, detail="Prompt contains stop words" ) divination_obj = DivinationFactory.get(divination_body.prompt_type) if not divination_obj: raise HTTPException( - status_code=400, + status_code=status.HTTP_400_BAD_REQUEST, detail=f"No prompt type {divination_body.prompt_type} not supported" ) prompt, system_prompt = divination_obj.build_prompt(divination_body) diff --git a/router/divination/__init__.py b/router/divination/__init__.py index ab61e01..9ca4d4a 100644 --- a/router/divination/__init__.py +++ b/router/divination/__init__.py @@ -7,3 +7,10 @@ from . import plum_flower from . import fate from .base import DivinationFactory + +import logging + +_logger = logging.getLogger("divination factory") +_logger.info( + f"Loaded divination types: {list(DivinationFactory.divination_map.keys())}" +) diff --git a/router/divination/birthday.py b/router/divination/birthday.py index 29c93f6..d38244b 100644 --- a/router/divination/birthday.py +++ b/router/divination/birthday.py @@ -13,8 +13,11 @@ class BirthdayFactory(DivinationFactory): divination_type = "birthday" def build_prompt(self, divination_body: DivinationBody) -> tuple[str, str]: + return self.internal_build_prompt(divination_body.birthday) + + def internal_build_prompt(self, birthday: str) -> tuple[str, str]: birthday = datetime.datetime.strptime( - divination_body.birthday, '%Y-%m-%d %H:%M:%S' + birthday, '%Y-%m-%d %H:%M:%S' ) prompt = f"我的生日是{birthday.year}年{birthday.month}月{birthday.day}日{birthday.hour}时{birthday.minute}分{birthday.second}秒" return prompt, BIRTHDAY_PROMPT diff --git a/router/divination_router.py b/router/divination_router.py new file mode 100644 index 0000000..a310001 --- /dev/null +++ b/router/divination_router.py @@ -0,0 +1,138 @@ +import json +import uuid +import openai +import logging + +from datetime import datetime +from fastapi import Depends, HTTPException, Request +from fastapi.responses import StreamingResponse +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials + +from config import settings +from fastapi import APIRouter + +from models import BirthdayBody, CommonResponse +from .limiter import get_real_ipaddr +from .divination.birthday import BirthdayFactory + +openai.api_key = settings.api_key +openai.api_base = settings.api_base +router = APIRouter() +security = HTTPBearer() +_logger = logging.getLogger(__name__) + + +def get_token( + credentials: HTTPAuthorizationCredentials = Depends(security) +): + return credentials.credentials + + +@router.post("/api/streaming_divination/birthday", tags=["divination"]) +async def birthday_divination_streaming( + request: Request, + birthday_body: BirthdayBody, + token: str = Depends(get_token) +) -> StreamingResponse: + _logger.info( + f"Request from {get_real_ipaddr(request)}, birthday_body={birthday_body}" + ) + prompt, system_prompt = BirthdayFactory().internal_build_prompt( + birthday_body.birthday + ) + return common_openai_streaming_call(token, prompt, system_prompt) + + +@router.post("/api/divination/birthday", tags=["divination"]) +async def birthday_divination( + request: Request, + birthday_body: BirthdayBody, + token: str = Depends(get_token) +) -> CommonResponse: + _logger.info( + f"Request from {get_real_ipaddr(request)}, birthday_body={birthday_body}" + ) + prompt, system_prompt = BirthdayFactory().internal_build_prompt( + birthday_body.birthday + ) + return common_openai_call(request, token, prompt, system_prompt) + + +def common_openai_streaming_call( + token: str, + prompt: str, + system_prompt: str +) -> StreamingResponse: + def get_openai_generator(): + try: + openai_stream = openai.ChatCompletion.create( + api_key=token, + model=settings.model, + max_tokens=1000, + temperature=0.9, + top_p=1, + stream=True, + messages=[ + { + "role": "system", + "content": system_prompt + }, + {"role": "user", "content": prompt} + ] + ) + except openai.error.OpenAIError as e: + raise HTTPException( + status_code=500, + detail=f"OpenAI error: {e}" + ) + for event in openai_stream: + if "content" in event["choices"][0].delta: + current_response = event["choices"][0].delta.content + yield current_response + + return StreamingResponse(get_openai_generator(), media_type='text/event-stream') + + +def common_openai_call( + request: Request, + token: str, + prompt: str, + system_prompt: str +) -> CommonResponse: + start_time = datetime.now() + request_id = uuid.uuid4() + + try: + response = openai.ChatCompletion.create( + api_key=token, + model=settings.model, + max_tokens=1000, + temperature=0.9, + top_p=1, + messages=[ + { + "role": "system", + "content": system_prompt + }, + {"role": "user", "content": prompt} + ] + ) + except openai.error.OpenAIError as e: + raise HTTPException( + status_code=500, + detail=f"OpenAI error: {e}" + ) + + res = response['choices'][0]['message']['content'] + latency = datetime.now() - start_time + _logger.info( + f"Request {request_id}:" + f"Request from {get_real_ipaddr(request)}, " + f"latency_seconds={latency.total_seconds()}, " + f"prompt={prompt}," + f"res={json.dumps(res, ensure_ascii=False)}" + ) + return CommonResponse( + content=res, + request_id=request_id.hex + ) diff --git a/router/file_logger.py b/router/file_logger.py deleted file mode 100644 index eb517a9..0000000 --- a/router/file_logger.py +++ /dev/null @@ -1,19 +0,0 @@ -import os -import logging -from logging.handlers import RotatingFileHandler - -from config import settings - -file_logger = logging.getLogger(__name__) -file_logger.setLevel(logging.INFO) - -if settings.log_dir: - file_handler = RotatingFileHandler( - os.path.join(settings.log_dir, "divination.log"), - maxBytes=1024*1024*1024 - ) - file_handler.setLevel(logging.INFO) - formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s') - file_handler.setFormatter(formatter) - file_logger.addHandler(file_handler) diff --git a/router/user.py b/router/user.py index 3f063b1..2427f96 100644 --- a/router/user.py +++ b/router/user.py @@ -3,7 +3,7 @@ from typing import Optional import jwt -from fastapi import Depends, status, Request +from fastapi import Depends, status from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from config import settings @@ -29,6 +29,5 @@ def get_user( raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired") return jwt_payload - except Exception as e: - _logger.exception(e) + except Exception: return diff --git a/vercel.json b/vercel.json deleted file mode 100644 index 6eaaa01..0000000 --- a/vercel.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "builds": [ - { - "src": "frontend/package.json", - "use": "@vercel/static-build" - }, - { - "src": "main.py", - "use": "@vercel/python" - } - ], - "routes": [ - { - "src": "/api/(.*)", - "dest": "main.py" - }, - { - "src": "/(.*)", - "dest": "frontend/$1" - } - ] -}