forked from renqabs/A2OA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
68 lines (54 loc) · 1.92 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import json
import os
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from adapters.ai_pro_adapter import AIProAdapter
from models import models_list
# --- Runtime configuration: every setting is read from the environment ---
LOG_LEVEL = os.environ.get("LOG_LEVEL", "info")
PORT = int(os.environ.get("PORT", 8000))
PROXY = os.environ.get("PROXY")
# NOTE(review): a hard-coded fallback password ships in source; consider
# requiring PASSWORD to be set explicitly in every deployment.
PASSWORD = os.environ.get("PASSWORD", "ninomae")
API_PROXY = os.environ.get("API_PROXY")

# One shared adapter instance, used by the chat handler for every request.
adapter = AIProAdapter(
    password=PASSWORD,
    proxy=PROXY,
    api_proxy=API_PROXY,
)
print("adapter:", adapter)

app = FastAPI()
print("app:", app)

# CORS: accept any origin, any method (OPTIONS included) and any header,
# and allow credentials.
# NOTE(review): browsers reject "Access-Control-Allow-Origin: *" together with
# credentials, so Starlette echoes the caller's origin instead. That
# effectively trusts every site; confirm this permissiveness is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
async def _aclose(iterator):
    """Close an async iterator early if it supports ``aclose()``; otherwise do nothing."""
    aclose = getattr(iterator, "aclose", None)
    if aclose is not None:
        await aclose()


@app.api_route(
    "/v1/chat/completions",
    methods=["POST", "OPTIONS"],
)
@app.api_route(
    "/hf/v1/chat/completions",
    methods=["POST", "OPTIONS"],
)
async def chat(request: Request):
    """Proxy an OpenAI-style chat completion request to ``adapter.chat``.

    The JSON body is read here only to inspect the ``stream`` flag. The whole
    ``request`` is then passed to ``adapter.chat``, whose result is consumed
    as an async iterator.

    * ``stream`` truthy: every item from the adapter is forwarded as a
      Server-Sent Event ``data: <json>``. The sentinel string ``"[DONE]"``
      is forwarded verbatim and ends the stream.
    * otherwise: the first item from the adapter is returned as the JSON body.

    In both modes the adapter's iterator is closed as soon as it is no longer
    needed, rather than being left suspended until garbage collection.

    NOTE(review): an OPTIONS request that is not a CORS preflight (and so is
    not answered by the CORS middleware) reaches this handler with no body,
    and ``request.json()`` will fail on it; confirm this is acceptable.
    """
    openai_params = await request.json()
    if openai_params.get("stream", False):
        async def generate():
            chunks = adapter.chat(request)
            try:
                async for response in chunks:
                    if response == "[DONE]":
                        # An SSE event is only dispatched once it is terminated
                        # by a blank line, so the final event needs "\n\n" too.
                        yield "data: [DONE]\n\n"
                        break
                    yield f"data: {json.dumps(response)}\n\n"
            finally:
                # Runs on normal end, on the early break above, and when the
                # response is torn down (e.g. the client disconnects).
                await _aclose(chunks)

        return StreamingResponse(generate(), media_type="text/event-stream")

    chunks = adapter.chat(request)
    try:
        openai_response = await chunks.__anext__()
    finally:
        # Only the first item is used; release the iterator immediately.
        await _aclose(chunks)
    return JSONResponse(content=openai_response)
@app.get("/v1/models")
@app.get("/hf/v1/models")
async def models(request: Request):
    """List the available models in the OpenAI ``GET /v1/models`` shape.

    The body is ``{"object": "list", "data": models_list}``; ``request`` is
    accepted but not used.
    """
    payload = dict(object="list", data=models_list)
    return JSONResponse(content=payload)
if __name__ == "__main__":
    # Run the app under uvicorn on all interfaces when executed as a script.
    uvicorn.run(app, port=PORT, host="0.0.0.0", log_level=LOG_LEVEL)