diff --git a/app/graph.py b/app/graph.py index 81e929e..1c6e14f 100644 --- a/app/graph.py +++ b/app/graph.py @@ -3,11 +3,13 @@ from typing import Annotated, Sequence, TypedDict from langchain.agents import create_openai_tools_agent, AgentExecutor -from langchain_core.messages import BaseMessage, HumanMessage +from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, FunctionMessage, ChatMessage, \ + ToolMessage from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_openai import ChatOpenAI from langgraph.graph import StateGraph, END from langgraph.pregel import Pregel +from pydantic import BaseModel from app.chains.supervisor import build_supervisor_chain from app.dependencies.openai_chat_model import openai_chat_model @@ -78,4 +80,9 @@ def build_graph() -> Pregel: return workflow.compile() -graph = build_graph() +class Input(BaseModel): + messages: Sequence[HumanMessage | AIMessage | SystemMessage | FunctionMessage | ChatMessage | ToolMessage] + next: str | None + + +graph = build_graph().with_types(input_type=Input)