From 22bd74b3990a68f8b239bf7e36cd3cc0f9801832 Mon Sep 17 00:00:00 2001 From: BuildTools Date: Sat, 31 Aug 2024 15:34:47 -0700 Subject: [PATCH] feat(server): replace Flask with FastAPI and Uvicorn - replace Flask with FastAPI and Uvicorn - fix web page not found error - port is now defaulted to 7001 - bind to localhost (127.0.0.1) instead of 0.0.0.0 - improve performance by using Uvicorn - add OpenAPI docs for endpoints --- requirements.txt | 3 +- src/main.py | 214 ++++++++++++++++++++++++++++++++++------------- 2 files changed, 158 insertions(+), 59 deletions(-) diff --git a/requirements.txt b/requirements.txt index ecd8764..1c3962b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,9 +6,10 @@ sentencepiece~=0.2.0 PyYAML~=6.0.2 pynvml~=11.5.3 PySide6~=6.7.2 -flask~=3.0.3 python-dotenv~=1.0.1 safetensors~=0.4.4 setuptools~=68.2.0 huggingface-hub~=0.24.6 transformers~=4.44.2 +fastapi~=0.112.2 +uvicorn~=0.30.6 diff --git a/src/main.py b/src/main.py index 5c93d15..f137611 100644 --- a/src/main.py +++ b/src/main.py @@ -1,80 +1,178 @@ import os import sys import threading +from enum import Enum +from typing import List, Optional from PySide6.QtCore import QTimer from PySide6.QtWidgets import QApplication +from fastapi import FastAPI, Query +from pydantic import BaseModel, Field +from uvicorn import Config, Server + from AutoGGUF import AutoGGUF -from flask import Flask, Response, jsonify +from Localizations import AUTOGGUF_VERSION -server = Flask(__name__) +app = FastAPI( + title="AutoGGUF", + description="API for AutoGGUF - automatically quant GGUF models", + version=AUTOGGUF_VERSION, + license_info={ + "name": "Apache 2.0", + "url": "https://raw.githubusercontent.com/leafspark/AutoGGUF/main/LICENSE", + }, +) +# Global variable to hold the window reference +window = None -def main() -> None: - @server.route("/v1/models", methods=["GET"]) - def models() -> Response: - if window: - return jsonify({"models": window.get_models_data()}) - return jsonify({"models": 
[]}) - - @server.route("/v1/tasks", methods=["GET"]) - def tasks() -> Response: - if window: - return jsonify({"tasks": window.get_tasks_data()}) - return jsonify({"tasks": []}) - - @server.route("/v1/health", methods=["GET"]) - def ping() -> Response: - return jsonify({"status": "alive"}) - - @server.route("/v1/backends", methods=["GET"]) - def get_backends() -> Response: - backends = [] + +class ModelType(str, Enum): + single = "single" + sharded = "sharded" + + +class Model(BaseModel): + name: str = Field(..., description="Name of the model") + type: str = Field(..., description="Type of the model") + path: str = Field(..., description="Path to the model file") + size: Optional[int] = Field(None, description="Size of the model in bytes") + + class Config: + json_schema_extra = { + "example": { + "name": "Llama-3.1-8B-Instruct.fp16.gguf", + "type": "single", + "path": "Llama-3.1-8B-Instruct.fp16.gguf", + "size": 13000000000, + } + } + + +class Task(BaseModel): + id: str = Field(..., description="Unique identifier for the task") + status: str = Field(..., description="Current status of the task") + progress: float = Field(..., description="Progress of the task as a percentage") + + class Config: + json_schema_extra = { + "example": {"id": "task_123", "status": "running", "progress": 75.5} + } + + +class Backend(BaseModel): + name: str = Field(..., description="Name of the backend") + path: str = Field(..., description="Path to the backend executable") + + +class Plugin(BaseModel): + name: str = Field(..., description="Name of the plugin") + version: str = Field(..., description="Version of the plugin") + description: str = Field(..., description="Description of the plugin") + author: str = Field(..., description="Author of the plugin") + + +@app.get("/v1/models", response_model=List[Model], tags=["Models"]) +async def get_models( + type: Optional[ModelType] = Query(None, description="Filter models by type") +) -> List[Model]: + """ + Get a list of all 
available models. + + - **type**: Optional filter for model type + + Returns a list of Model objects containing name, type, path, and optional size. + """ + if window: + models = window.get_models_data() + if type: + models = [m for m in models if m["type"] == type] + + # Convert to Pydantic models, handling missing 'size' field + return [Model(**m) for m in models] + return [] + + +@app.get("/v1/tasks", response_model=List[Task], tags=["Tasks"]) +async def get_tasks() -> List[Task]: + """ + Get a list of all current tasks. + + Returns a list of Task objects containing id, status, and progress. + """ + if window: + return window.get_tasks_data() + return [] + + +@app.get("/v1/health", tags=["System"]) +async def health_check() -> dict: + """ + Check the health status of the API. + + Returns a simple status message indicating the API is alive. + """ + return {"status": "alive"} + + +@app.get("/v1/backends", response_model=List[Backend], tags=["System"]) +async def get_backends() -> List[Backend]: + """ + Get a list of all available llama.cpp backends. + + Returns a list of Backend objects containing name and path. 
+ """ + backends = [] + if window: for i in range(window.backend_combo.count()): backends.append( - { - "name": window.backend_combo.itemText(i), - "path": window.backend_combo.itemData(i), - } - ) - return jsonify({"backends": backends}) - - @server.route("/v1/plugins", methods=["GET"]) - def get_plugins() -> Response: - if window: - return jsonify( - { - "plugins": [ - { - "name": plugin_data["data"]["name"], - "version": plugin_data["data"]["version"], - "description": plugin_data["data"]["description"], - "author": plugin_data["data"]["author"], - } - for plugin_data in window.plugins.values() - ] - } - ) - return jsonify({"plugins": []}) - - def run_flask() -> None: - if os.environ.get("AUTOGGUF_SERVER", "").lower() == "enabled": - server.run( - host="0.0.0.0", - port=int(os.environ.get("AUTOGGUF_SERVER_PORT", 5000)), - debug=False, - use_reloader=False, + Backend( + name=window.backend_combo.itemText(i), + path=window.backend_combo.itemData(i), + ) ) + return backends - app = QApplication(sys.argv) + +@app.get("/v1/plugins", response_model=List[Plugin], tags=["System"]) +async def get_plugins() -> List[Plugin]: + """ + Get a list of all installed plugins. + + Returns a list of Plugin objects containing name, version, description, and author. 
+ """ + if window: + return [ + Plugin(**plugin_data["data"]) for plugin_data in window.plugins.values() + ] + return [] + + +def run_uvicorn() -> None: + if os.environ.get("AUTOGGUF_SERVER", "").lower() == "enabled": + config = Config( + app=app, + host="127.0.0.1", + port=int(os.environ.get("AUTOGGUF_SERVER_PORT", 7001)), + log_level="info", + ) + server = Server(config) + server.run() + + +def main() -> None: + global window + qt_app = QApplication(sys.argv) window = AutoGGUF(sys.argv) window.show() - # Start Flask in a separate thread after a short delay + + # Start Uvicorn in a separate thread after a short delay timer = QTimer() timer.singleShot( - 100, lambda: threading.Thread(target=run_flask, daemon=True).start() + 100, lambda: threading.Thread(target=run_uvicorn, daemon=True).start() ) - sys.exit(app.exec()) + + sys.exit(qt_app.exec()) if __name__ == "__main__":