From 9c4fed34728b0e20b19a4d337157caa69beca674 Mon Sep 17 00:00:00 2001 From: sunshinesmilelk <1176136681@qq.com> Date: Tue, 2 Jul 2024 18:24:05 +0800 Subject: [PATCH] feat(api): add plugin database & api --- .../versions/ac21c38c5e56_add_plugin_table.py | 78 +++++++++ api/models/__init__.py | 8 +- api/models/plugin.py | 60 +++++++ api/models/plugin_api.py | 84 ++++++++++ api/models/plugin_config.py | 79 +++++++++ api/routers/main.py | 3 +- api/routers/plugin/__init__.py | 0 api/routers/plugin/crud.py | 158 ++++++++++++++++++ api/routers/plugin/router.py | 94 +++++++++++ 9 files changed, 562 insertions(+), 2 deletions(-) create mode 100644 api/alembic/versions/ac21c38c5e56_add_plugin_table.py create mode 100644 api/models/plugin.py create mode 100644 api/models/plugin_api.py create mode 100644 api/models/plugin_config.py create mode 100644 api/routers/plugin/__init__.py create mode 100644 api/routers/plugin/crud.py create mode 100644 api/routers/plugin/router.py diff --git a/api/alembic/versions/ac21c38c5e56_add_plugin_table.py b/api/alembic/versions/ac21c38c5e56_add_plugin_table.py new file mode 100644 index 00000000..cf0df21f --- /dev/null +++ b/api/alembic/versions/ac21c38c5e56_add_plugin_table.py @@ -0,0 +1,78 @@ +"""add plugin table + +Revision ID: ac21c38c5e56 +Revises: 5918599719a1 +Create Date: 2024-07-02 09:56:29.437606 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'ac21c38c5e56' +down_revision: Union[str, None] = '5918599719a1' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('plugin', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('plugin_type', sa.Integer(), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('plugin_config', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('avatar', sa.String(length=255), nullable=True), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('plugin_id', sa.Integer(), nullable=True), + sa.Column('is_draft', sa.Boolean(), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['plugin_id'], ['plugin.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('plugin_api', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('plugin_config_id', sa.Integer(), nullable=True), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('openapi_desc', sa.Text(), nullable=True), + sa.Column('disabled', sa.Boolean(), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['plugin_config_id'], [ + 'plugin_config.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.alter_column('messages', 'message_type', + type_=sa.VARCHAR(length=50), + existing_type=sa.Enum( + 'MARKDOWN', 'TEXT', name='messagetype'), + existing_nullable=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('messages', 'message_type', + existing_type=sa.Enum( + 'MARKDOWN', 'TEXT', name='messagetype'), + type_=sa.VARCHAR(length=50), + existing_nullable=False) + op.drop_table('plugin_api') + op.drop_table('plugin_config') + op.drop_table('plugin') + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index b200ab00..3fd5a1bd 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -4,6 +4,9 @@ from .agent_config import AgentConfigORM from .account import AccountORM from .chat import ChatORM, MessageORM +from .plugin import PluginORM +from .plugin_config import PluginConfigORM +from .plugin_api import PluginApiORM __all__ = [ 'Base', @@ -14,5 +17,8 @@ 'AgentBotORM', 'AgentConfigORM', 'ChatORM', - 'MessageORM' + 'MessageORM', + 'PluginORM', + 'PluginConfigORM', + 'PluginApiORM' ] diff --git a/api/models/plugin.py b/api/models/plugin.py new file mode 100644 index 00000000..41be034e --- /dev/null +++ b/api/models/plugin.py @@ -0,0 +1,60 @@ +from typing import Optional +from pydantic import BaseModel +from datetime import datetime + +from sqlalchemy import ( + Column, + Integer, + DateTime, +) + +from db import Base +from models.plugin_config import PluginConfigModel + + +class PluginORM(Base): + ''' + plugin database model + ''' + __tablename__ = "plugin" + id = Column(Integer, primary_key=True) + plugin_type = Column(Integer, nullable=False) # 插件创建类型 + created_by = Column(Integer) + created_at = Column( + DateTime(), nullable=False, default=datetime.now + ) + updated_by = Column(Integer) + updated_at = Column(DateTime(), + nullable=False, onupdate=datetime.now) + # statistic_data = Column(JSON) # 保留字段,用于存储诸如 被多少个bot 引用了 + + +class PluginCreate(BaseModel): + ''' + plugin create + ''' + plugin_type: int + + +class PluginUpdate(BaseModel): + ''' + plugin update + ''' + id: int + plugin_type: int + + +class PluginModel(PluginCreate): + ''' + plugin + ''' + id: int + plugin_type: int + created_by: int + created_at: datetime + updated_by: int + updated_at: datetime + draft: Optional[PluginConfigModel] = None + + class Config: + from_attributes = True diff --git a/api/models/plugin_api.py b/api/models/plugin_api.py new file mode 100644 index 00000000..223579f3 --- /dev/null +++ b/api/models/plugin_api.py @@ -0,0 +1,84 @@ +from typing import Optional +from pydantic import BaseModel +from datetime import datetime + +from sqlalchemy import ( + Column, + ForeignKey, + Integer, + DateTime, + Text, + Boolean +) + +from db import Base + + +class PluginApiORM(Base): + ''' + plugin api model + ''' + __tablename__ = "plugin_api" + id = Column(Integer, primary_key=True) + plugin_config_id = Column(Integer, ForeignKey( + "plugin_config.id"), nullable=True) + description = Column(Text) # API 描述 + openapi_desc = Column(Text) # openapi 先调研 + # request_params = Column(JSON) # 入参 + # response_params = Column(JSON) # 出参 + # debug_example = Column(JSON) # 调试示例 + # debug_example_status = Column(Integer) # 调试示例状态 + # debug_status = Column(Integer) # 调试状态 + disabled = Column(Boolean, nullable=False, default=True) + # online_status = Column(Integer) # 服务状态 + # path = Column(String(255)) #路径 + created_by = Column(Integer) + created_at = Column( + DateTime(), nullable=False, default=datetime.now, + ) + updated_by = Column(Integer) + updated_at = Column(DateTime(), + nullable=False, onupdate=datetime.now) + # statistic_data = Column(JSON) # 保留字段,用于存储诸如 被多少个bot 引用了 + + +class PluginApiCreate(BaseModel): + ''' + plugin api create + ''' + plugin_config_id: int + description: str + openapi_desc: str + disabled: bool + created_by: int + created_at: datetime + updated_by: int + updated_at: datetime + + +class PluginApiUpdate(BaseModel): + ''' + plugin api update + ''' + id: int + description: str + openapi_desc: str + disabled: bool + + +class PluginApiModel(PluginApiCreate): + ''' + plugin api + ''' + id: int + plugin_config_id: int + description: str + openapi_desc: str + disabled: bool + created_by: int + created_at: datetime + updated_by: int + updated_at: datetime + + class Config: + from_attributes = True diff --git a/api/models/plugin_config.py b/api/models/plugin_config.py new file mode 100644 index 00000000..b035a0ff --- /dev/null +++ b/api/models/plugin_config.py @@ -0,0 +1,79 @@ +from typing import Optional +from datetime import datetime +from pydantic import BaseModel +from sqlalchemy import ( + Boolean, + Column, + Integer, + DateTime, + ForeignKey, + String, + Text, +) + +from db import Base + + +class PluginConfigORM(Base): + ''' + plugin config database model + ''' + __tablename__ = "plugin_config" + id = Column(Integer, primary_key=True) + name = Column(String(255), nullable=False) # 插件名字 + avatar = Column(String(255)) # 插件图标 + description = Column(Text) # 插件描述 + # meta_info = Column(JSON) # 插件元信息,包含 api url,service_token,oauth_info 等 + plugin_id = Column(Integer, ForeignKey("plugin.id"), nullable=True) + is_draft = Column(Boolean, nullable=False) # 是否为草稿 + # openapi_desc = Column(Text) # openapi 先调研 + # plugin_desc = Column(Text) + created_by = Column(Integer) + created_at = Column( + DateTime(), nullable=False, default=datetime.now, + ) + updated_by = Column(Integer) + updated_at = Column(DateTime(), + nullable=False, onupdate=datetime.now) + + +class PluginConfigCreate(BaseModel): + ''' + plugin config create + ''' + plugin_id: int + name: str + avatar: str + description: str + + is_draft: Optional[bool] + + +class PluginConfigUpdate(BaseModel): + ''' + plugin config update + ''' + id: int + is_draft: Optional[bool] + name: str + avatar: str + description: str + + +class PluginConfigModel(PluginConfigCreate): + ''' + plugin config + ''' + id: int + name: str + avatar: str + description: str + + plugin_id: int + created_by: int + created_at: datetime + updated_by: int + updated_at: datetime + + class Config: + from_attributes = True diff --git a/api/routers/main.py b/api/routers/main.py index 58f798ea..6d651ef8 100644 --- a/api/routers/main.py +++ b/api/routers/main.py @@ -8,13 +8,14 @@ from .agent.router import agent_router from .account.router import account_router from .chat.router import chat_router +from .plugin.router import plugin_router api_router = APIRouter() api_router.include_router(agent_router, prefix="/agent", tags=["agent"]) api_router.include_router(account_router, prefix="/accounts", tags=["account"]) api_router.include_router(chat_router, prefix="/chats", tags=["chat"]) - +api_router.include_router(plugin_router, prefix="/plugins", tags=["plugin"]) COUNTER = 0 diff --git a/api/routers/plugin/__init__.py b/api/routers/plugin/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/routers/plugin/crud.py b/api/routers/plugin/crud.py new file mode 100644 index 00000000..7413ce1a --- /dev/null +++ b/api/routers/plugin/crud.py @@ -0,0 +1,158 @@ +from datetime import datetime +from typing import Any +from sqlalchemy.orm import Session +from sqlalchemy import func +from models.plugin_config import PluginConfigCreate, PluginConfigORM, PluginConfigUpdate +from models.plugin import PluginCreate, PluginORM, PluginUpdate, PluginModel +from models.plugin_api import PluginApiCreate, PluginApiORM, PluginApiUpdate, PluginApiModel +from fastapi_pagination.ext.sqlalchemy import paginate +from fastapi_pagination import Page + + +class PluginHelper: + + @staticmethod + def count(session: Session) -> int: + return session.query(func.count(PluginORM.id)).scalar() + + @staticmethod + def get(session: Session, plugin_id: int) -> PluginORM | None: + return session.query(PluginORM).filter(PluginORM.id == plugin_id).one_or_none() + + @staticmethod + def get_all(session: Session, user_id: int) -> Page[PluginModel]: + return paginate( + session.query(PluginORM) + .filter(PluginORM.created_by == user_id) + .order_by(PluginORM.updated_at.desc()) + ) + + @staticmethod + def update(session: Session, operator: int, plugin_model: PluginUpdate) -> int: + plugin_id = plugin_model.id + now = datetime.now() + update_model: Any = { + **plugin_model.model_dump(), + "updated_by": operator, + "updated_at": now, + } + result = session.query(PluginORM).filter( + PluginORM.id == plugin_id).update(update_model) + session.commit() + return result + + @staticmethod + def create(session: Session, operator: int, plugin_model: PluginCreate) -> PluginORM: + now = datetime.now() + model = PluginORM(**{ + **plugin_model.model_dump(), + "created_by": operator, + "created_at": now, + "updated_by": operator, + "updated_at": now + }) + session.add(model) + session.commit() + session.refresh(model) + return model + + +class PluginConfigHelper: + @staticmethod + def get(session: Session, config_id: int) -> PluginConfigORM | None: + return session.query(PluginConfigORM).filter(PluginConfigORM.id == config_id).one_or_none() + + @staticmethod + def get_plugin_draft(session: Session, plugin_id: int) -> PluginConfigORM | None: + return session.query(PluginConfigORM).filter( + PluginConfigORM.plugin_id == plugin_id, + PluginConfigORM.is_draft == True + ).one_or_none() + + @staticmethod + def get_or_create_plugin_draft(session: Session, operator: int, plugin_id: int,) -> PluginConfigORM: + exist = PluginConfigHelper.get_plugin_draft(session, plugin_id) + if exist is None: + return PluginConfigHelper.create(session, operator, PluginConfigCreate(plugin_id=plugin_id, name='', avatar='', description='', is_draft=True)) + else: + return exist + + @staticmethod + def create(session: Session, operator: int, config_model: PluginConfigCreate) -> PluginConfigORM: + now = datetime.now() + dict = { + "is_draft": False, + **config_model.model_dump(), + "created_by": operator, + "created_at": now, + "updated_by": operator, + "updated_at": now, + } + model = PluginConfigORM(**dict) + session.add(model) + session.commit() + session.refresh(model) + return model + + @staticmethod + def update(session: Session, operator: int, config_model: PluginConfigUpdate) -> int: + config_id = config_model.id + now = datetime.now() + update_model: dict[Any, Any] = { + **config_model.model_dump(), + "updated_by": operator, + "updated_at": now, + } + update_model.pop('id') + result = session.query(PluginConfigORM).filter( + PluginConfigORM.id == config_id).update(update_model) + session.commit() + return result + + +class PluginAPIHelper: + + @staticmethod + def count(session: Session) -> int: + return session.query(func.count(PluginApiORM.id)).scalar() + + @staticmethod + def get(session: Session, plugin_api_id: int) -> PluginApiORM | None: + return session.query(PluginApiORM).filter(PluginApiORM.id == plugin_api_id).one_or_none() + + @staticmethod + def get_all(session: Session, user_id: int) -> Page[PluginApiModel]: + return paginate( + session.query(PluginApiORM) + .filter(PluginApiORM.created_by == user_id) + .order_by(PluginApiORM.updated_at.desc()) + ) + + @staticmethod + def update(session: Session, operator: int, plugin_api_model: PluginApiUpdate) -> int: + plugin_api_id = plugin_api_model.id + now = datetime.now() + update_model: Any = { + **plugin_api_model.model_dump(), + "updated_by": operator, + "updated_at": now, + } + result = session.query(PluginApiORM).filter( + PluginApiORM.id == plugin_api_id).update(update_model) + session.commit() + return result + + @staticmethod + def create(session: Session, operator: int, plugin_api_model: PluginApiCreate) -> PluginApiORM: + now = datetime.now() + model = PluginApiORM(**{ + **plugin_api_model.model_dump(), + "created_by": operator, + "created_at": now, + "updated_by": operator, + "updated_at": now + }) + session.add(model) + session.commit() + session.refresh(model) + return model diff --git a/api/routers/plugin/router.py b/api/routers/plugin/router.py new file mode 100644 index 00000000..556e9249 --- /dev/null +++ b/api/routers/plugin/router.py @@ -0,0 +1,94 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session + +from fastapi_pagination import Page +from models.plugin import PluginModel, PluginCreate, PluginUpdate +from models.plugin_api import PluginApiCreate, PluginApiModel, PluginApiUpdate +from models.plugin_config import PluginConfigModel, PluginConfigCreate, PluginConfigUpdate + +from db import get_db +from .crud import PluginHelper, PluginConfigHelper, PluginAPIHelper + +router = APIRouter() + +plugin_router = router + + +@router.post("/", response_model=PluginModel) +def create_plugin(user_id: int, plugin: PluginCreate, session: Session = Depends(get_db)): + model = PluginHelper.create(session, user_id, plugin) + return PluginModel.model_validate(model) + + +@router.get("/", response_model=Page[PluginModel]) +def get_plugins(user_id: int, session: Session = Depends(get_db)): + data = PluginHelper.get_all(session, user_id) + return data + + +@router.get("/{plugin_id}", response_model=PluginModel) +async def get_plugin(plugin_id, user_id: int, with_draft=False, session: Session = Depends(get_db)): + model = PluginHelper.get(session, plugin_id) + if model is None: + raise HTTPException(404) + plugin_model = PluginModel.model_validate(model) + if with_draft: + draft = PluginConfigHelper.get_or_create_plugin_draft( + session, user_id, plugin_model.id) + plugin_model.draft = draft + return plugin_model + + +@router.get("/{plugin_id}/draft", response_model=PluginConfigModel) +async def get_or_create_plugin_draft_config(user_id: int, plugin_id, session: Session = Depends(get_db)): + model = PluginConfigHelper.get_or_create_plugin_draft( + session, user_id, plugin_id) + if model is None: + raise HTTPException(404) + return PluginConfigModel.model_validate(model) + + +@router.put("/{plugin_id}") +async def update_plugin(user_id: int, plugin: PluginUpdate, db: Session = Depends(get_db)): + success = PluginHelper.update(db, user_id, plugin) + return success + + +@router.get("/configs/{config_id}", response_model=PluginConfigModel) +async def get_plugin_config(config_id, session: Session = Depends(get_db)): + model = PluginConfigHelper.get(session, config_id) + if model is None: + raise HTTPException(404) + return PluginConfigModel.model_validate(model) + + +@router.put("/configs/{config_id}") +async def update_plugin_config(user_id: int, config: PluginConfigUpdate, db: Session = Depends(get_db)): + success = PluginConfigHelper.update(db, user_id, config) + return success + + +@router.post("/configs", response_model=PluginConfigModel) +async def create_plugin_config(user_id: int, config: PluginConfigCreate, session: Session = Depends(get_db)): + model = PluginConfigHelper.create(session, user_id, config) + return PluginConfigModel.model_validate(model) + + +@router.get("/api/{api_id}", response_model=PluginApiModel) +async def get_plugin_api(api_id, session: Session = Depends(get_db)): + model = PluginAPIHelper.get(session, api_id) + if model is None: + raise HTTPException(404) + return PluginApiModel.model_validate(model) + + +@router.put("/api/{api_id}") +async def update_plugin_api(user_id: int, api: PluginApiUpdate, db: Session = Depends(get_db)): + success = PluginAPIHelper.update(db, user_id, api) + return success + + +@router.post("/api", response_model=PluginApiModel) +async def create_plugin_api(user_id: int, api: PluginApiCreate, session: Session = Depends(get_db)): + model = PluginAPIHelper.create(session, user_id, api) + return PluginApiModel.model_validate(model)