Skip to content

Commit

Permalink
refactor(api): add services and dao
Browse files Browse the repository at this point in the history
  • Loading branch information
sunshinesmilelk authored and BroKun committed Jul 15, 2024
1 parent cbbe593 commit c067d2d
Show file tree
Hide file tree
Showing 14 changed files with 448 additions and 175 deletions.
Empty file added api/dao/__init__.py
Empty file.
File renamed without changes.
23 changes: 13 additions & 10 deletions api/routers/agent/crud.py → api/dao/agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from datetime import datetime
from typing import Any
from typing import Any, List
from sqlalchemy.orm import Session
from sqlalchemy import func
from models.agent_bot import AgentBotModel, AgentBotORM, AgentBotCreate, AgentBotUpdate
from models.agent_bot import AgentBotORM, AgentBotCreate, AgentBotUpdate
from models.agent_config import AgentConfigCreate, AgentConfigORM, AgentConfigUpdate
from fastapi_pagination.ext.sqlalchemy import paginate
from fastapi_pagination import Page


class AgentBotHelper:
Expand All @@ -19,12 +17,17 @@ def get(session: Session, bot_id: int) -> AgentBotORM | None:
return session.query(AgentBotORM).filter(AgentBotORM.id == bot_id).one_or_none()

@staticmethod
def get_all(session: Session, user_id: int) -> Page[AgentBotModel]:
return paginate(
session.query(AgentBotORM)
.filter(AgentBotORM.created_by == user_id)
.order_by(AgentBotORM.updated_at.desc())
)
def get_all(session: Session) -> List[AgentBotORM]:
return session.query(AgentBotORM).order_by(AgentBotORM.updated_at.desc()).all()

@staticmethod
def get_by_user(session: Session, user_id: int) -> List[AgentBotORM]:
return session.query(AgentBotORM).filter(AgentBotORM.created_by == user_id).order_by(AgentBotORM.updated_at.desc()).all()
# return paginate(
# session.query(AgentBotORM)
# .filter(AgentBotORM.created_by == user_id)
# .order_by(AgentBotORM.updated_at.desc())
# )

@staticmethod
def update(session: Session, operator: int, bot_model: AgentBotUpdate) -> int:
Expand Down
20 changes: 12 additions & 8 deletions api/routers/chat/crud.py → api/dao/chat.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import List
from datetime import datetime
from sqlalchemy.orm import Session
from models.agent_config import AgentConfigModel
from models.agent_config import AgentConfigModel, AgentConfigORM
from models.chat import ChatModel, ChatORM, MessageModel, MessageModelCreate, MessageORM
from routers.agent.crud import AgentConfigHelper
from dao.agent import AgentConfigHelper

from sqlalchemy import inspect

Expand All @@ -21,12 +21,16 @@ def getattr_from_column_name(instance, name, default=Ellipsis):

class ChatHelper:
@staticmethod
def get_chat_bot_config(session: Session, operator: int, coversation_id: int) -> AgentConfigModel:
chat_orm = session.query(ChatORM).filter(ChatORM.id ==
coversation_id).one_or_none()
chat_model = ChatModel.model_validate(chat_orm)
config_id = chat_model.bot_config_id
return AgentConfigHelper.get(session, config_id)
def get_chat(session: Session, chat_id: int) -> ChatORM:
chat_orm = session.query(ChatORM).filter(
ChatORM.id == chat_id).one_or_none()
return chat_orm

@staticmethod
def get_chat_by_agent_config(session: Session, agent_config_id: int) -> ChatORM:
chat_orm = session.query(ChatORM).filter(
ChatORM.bot_config_id == agent_config_id).one_or_none()
return chat_orm

@staticmethod
def get_or_create_bot_chat(session: Session, operator: int, agent_config_id: int) -> ChatORM:
Expand Down
50 changes: 27 additions & 23 deletions api/routers/plugin/crud.py → api/dao/plugin.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from datetime import datetime
from typing import Any
from typing import Any, List
from sqlalchemy.orm import Session
from sqlalchemy import func
from models.plugin_config import PluginConfigCreate, PluginConfigORM, PluginConfigUpdate
from models.plugin import PluginCreate, PluginORM, PluginUpdate, PluginModel
from models.plugin_api import PluginApiCreate, PluginApiORM, PluginApiUpdate, PluginApiModel
from fastapi_pagination.ext.sqlalchemy import paginate
from fastapi_pagination import Page
from models.plugin import PluginCreate, PluginORM, PluginUpdate
from models.plugin_api import PluginApiCreate, PluginApiORM, PluginApiUpdate


class PluginHelper:
Expand All @@ -20,20 +18,21 @@ def get(session: Session, plugin_id: int) -> PluginORM | None:
return session.query(PluginORM).filter(PluginORM.id == plugin_id).one_or_none()

@staticmethod
def get_user_plugin(session: Session, user_id: int) -> Page[PluginModel]:
print('user_id_test', user_id)
return paginate(
session.query(PluginORM)
.filter(PluginORM.created_by == user_id)
.order_by(PluginORM.updated_at.desc())
)
def get_user_plugin(session: Session, user_id: int) -> List[PluginORM]:
return session.query(PluginORM).filter(PluginORM.created_by == user_id).order_by(PluginORM.updated_at.desc()).all()
# return paginate(
# session.query(PluginORM)
# .filter(PluginORM.created_by == user_id)
# .order_by(PluginORM.updated_at.desc())
# )

@staticmethod
def get_all_plugin(session: Session) -> Page[PluginModel]:
return paginate(
session.query(PluginORM)
.order_by(PluginORM.updated_at.desc())
)
def get_all_plugin(session: Session) -> List[PluginORM]:
return session.query(PluginORM).order_by(PluginORM.updated_at.desc()).all()
# return paginate(
# session.query(PluginORM)
# .order_by(PluginORM.updated_at.desc())
# )

@staticmethod
def update(session: Session, operator: int, plugin_model: PluginUpdate) -> int:
Expand Down Expand Up @@ -129,12 +128,17 @@ def get(session: Session, plugin_api_id: int) -> PluginApiORM | None:
return session.query(PluginApiORM).filter(PluginApiORM.id == plugin_api_id).one_or_none()

@staticmethod
def get_all(session: Session, user_id: int) -> Page[PluginApiModel]:
return paginate(
session.query(PluginApiORM)
.filter(PluginApiORM.created_by == user_id)
.order_by(PluginApiORM.updated_at.desc())
)
def get_user_all(session: Session, user_id: int) -> List[PluginApiORM]:
return session.query(PluginApiORM).filter(PluginApiORM.created_by == user_id).order_by(PluginApiORM.updated_at.desc()).all()
# paginate(
# session.query(PluginApiORM)
# .filter(PluginApiORM.created_by == user_id)
# .order_by(PluginApiORM.updated_at.desc())
# )

@staticmethod
def get_all(session: Session) -> List[PluginApiORM]:
return session.query(PluginApiORM).order_by(PluginApiORM.updated_at.desc()).all()

@staticmethod
def update(session: Session, operator: int, plugin_api_model: PluginApiUpdate) -> int:
Expand Down
23 changes: 10 additions & 13 deletions api/routers/account/router.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,29 @@
from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.orm import Session
from fastapi import APIRouter, HTTPException

from models.account import AccountModel, AccountCreate
from db import get_db

from .crud import AccountHelper
from services.account import AccountService

router = APIRouter()
account_router = router


@router.post("/", response_model=AccountModel)
async def create_account(model: AccountCreate, session: Session = Depends(get_db), ):
model = await AccountHelper.create(session, model)
return AccountModel.model_validate(model)
async def create_account(model: AccountCreate) -> AccountModel:
account_model = await AccountService.create(model)
return account_model


@router.get("/{user_id}", response_model=AccountModel)
async def get_account_by_id(user_id, db: Session = Depends(get_db)):
model = AccountHelper.get_by_id(db, user_id)
async def get_account_by_id(user_id):
model = AccountService.get_by_id(user_id)
if model is None:
raise HTTPException(404)
return AccountModel.model_validate(model)
return model


@router.get("/email/{email}", response_model=AccountModel)
async def get_account_by_email(email, db: Session = Depends(get_db)):
model = AccountHelper.get_by_email(db, email)
async def get_account_by_email(email):
model = AccountService.get_by_email(email)
if model is None:
raise HTTPException(404)
return AccountModel.model_validate(model)
61 changes: 28 additions & 33 deletions api/routers/agent/router.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,70 @@
from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.orm import Session
from fastapi import APIRouter, HTTPException

from fastapi_pagination import Page
from fastapi_pagination import Page, paginate

from models.agent_bot import AgentBotModel, AgentBotCreate, AgentBotUpdate
from models.agent_config import AgentConfigModel, AgentConfigUpdate, AgentConfigCreate

from db import get_db

from .crud import AgentBotHelper, AgentConfigHelper

from services.agent import AgentConfigService, AgentService

router = APIRouter()

agent_router = router


@router.post("/bots", response_model=AgentBotModel)
def create_agent_bot(user_id: int, bot: AgentBotCreate, session: Session = Depends(get_db)):
model = AgentBotHelper.create(session, user_id, bot)
return AgentBotModel.model_validate(model)
def create_agent_bot(user_id: int, bot: AgentBotCreate):
model = AgentService.create(user_id, bot)
return model


@router.get("/bots", response_model=Page[AgentBotModel])
def get_agent_bots(user_id: int, session: Session = Depends(get_db)):
data = AgentBotHelper.get_all(session, user_id)
return data
def get_agent_bots(user_id: int):
data = AgentService.get_by_user(user_id)
res = paginate(data)
return res


@router.get("/bots/{bot_id}", response_model=AgentBotModel)
async def get_agent_bot(bot_id, user_id: int, with_draft=False, session: Session = Depends(get_db)):
model = AgentBotHelper.get(session, bot_id)
async def get_agent_bot(bot_id, user_id: int, with_draft=False):
model = AgentService.get_by_id(bot_id)
if model is None:
raise HTTPException(404)
bot = AgentBotModel.model_validate(model)
if with_draft:
draft = AgentConfigHelper.get_or_create_bot_draft(
session, user_id, bot.id)
bot.draft = draft
return bot
draft = AgentConfigService.get_or_create_bot_draft(user_id, bot_id)
model.draft = draft
return model


@router.get("/bots/{bot_id}/draft", response_model=AgentConfigModel)
async def get_or_create_agent_bot_draft_config(user_id: int, bot_id, session: Session = Depends(get_db)):
model = AgentConfigHelper.get_or_create_bot_draft(session, user_id, bot_id)
async def get_or_create_agent_bot_draft_config(user_id: int, bot_id):
model = AgentConfigService.get_or_create_bot_draft(user_id, bot_id)
if model is None:
raise HTTPException(404)
return AgentConfigModel.model_validate(model)
return model


@router.put("/bots/{bot_id}")
async def update_agent_bot(user_id: int, bot: AgentBotUpdate, db: Session = Depends(get_db)):
success = AgentBotHelper.update(db, user_id, bot)
async def update_agent_bot(user_id: int, bot: AgentBotUpdate):
success = AgentService.update(user_id, bot)
return success


@router.get("/configs/{config_id}", response_model=AgentConfigModel)
async def get_agent_config(config_id, session: Session = Depends(get_db)):
model = AgentConfigHelper.get(session, config_id)
async def get_agent_config(config_id):
model = AgentConfigService.get_by_id(config_id)
if model is None:
raise HTTPException(404)
return AgentConfigModel.model_validate(model)
return model


@router.put("/configs/{bot_id}")
async def update_agent_config(user_id: int, config: AgentConfigUpdate, db: Session = Depends(get_db)):
success = AgentConfigHelper.update(db, user_id, config)
async def update_agent_config(user_id: int, config: AgentConfigUpdate):
success = AgentConfigService.update(user_id, config)
return success


@router.post("/configs", response_model=AgentConfigModel)
async def create_agent_config(user_id: int, config: AgentConfigCreate, session: Session = Depends(get_db)):
model = AgentConfigHelper.create(session, user_id, config)
return AgentConfigModel.model_validate(model)
async def create_agent_config(user_id: int, config: AgentConfigCreate):
model = AgentConfigService.create(user_id, config)
return model
Loading

0 comments on commit c067d2d

Please sign in to comment.