remove assistant_id from run payload (#269)
mkorpela authored Apr 4, 2024
1 parent 38f5c3f commit 85a9409
Showing 5 changed files with 30 additions and 43 deletions.
44 changes: 22 additions & 22 deletions backend/app/api/runs.py
@@ -1,8 +1,7 @@
import json
from typing import Optional, Sequence

import langsmith.client
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
from fastapi import APIRouter, BackgroundTasks, HTTPException
from fastapi.exceptions import RequestValidationError
from langchain.pydantic_v1 import ValidationError
from langchain_core.messages import AnyMessage
@@ -15,7 +14,7 @@

from app.agent import agent
from app.schema import OpengptsUserId
from app.storage import get_assistant
from app.storage import get_assistant, get_thread
from app.stream import astream_messages, to_sse

router = APIRouter()
@@ -24,63 +23,64 @@
class CreateRunPayload(BaseModel):
"""Payload for creating a run."""

assistant_id: str
thread_id: str
input: Optional[Sequence[AnyMessage]] = Field(default_factory=list)
config: Optional[RunnableConfig] = None


async def _run_input_and_config(request: Request, opengpts_user_id: OpengptsUserId):
try:
body = await request.json()
except json.JSONDecodeError:
raise RequestValidationError(errors=["Invalid JSON body"])
assistant = await get_assistant(opengpts_user_id, body["assistant_id"])
async def _run_input_and_config(
payload: CreateRunPayload, opengpts_user_id: OpengptsUserId
):
thread = await get_thread(opengpts_user_id, payload.thread_id)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")

assistant = await get_assistant(opengpts_user_id, str(thread["assistant_id"]))
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")

config: RunnableConfig = {
**assistant["config"],
"configurable": {
**assistant["config"]["configurable"],
**(body.get("config", {}).get("configurable") or {}),
**((payload.config or {}).get("configurable") or {}),
"user_id": opengpts_user_id,
"thread_id": body["thread_id"],
"assistant_id": body["assistant_id"],
"thread_id": str(thread["thread_id"]),
"assistant_id": str(assistant["assistant_id"]),
},
}

try:
input_ = (
_unpack_input(agent.get_input_schema(config).validate(body["input"]))
if body["input"] is not None
_unpack_input(agent.get_input_schema(config).validate(payload.input))
if payload.input is not None
else None
)
except ValidationError as e:
raise RequestValidationError(e.errors(), body=body)
raise RequestValidationError(e.errors(), body=payload)

return input_, config


@router.post("")
async def create_run(
payload: CreateRunPayload, # for openapi docs
request: Request,
payload: CreateRunPayload,
opengpts_user_id: OpengptsUserId,
background_tasks: BackgroundTasks,
):
"""Create a run."""
input_, config = await _run_input_and_config(request, opengpts_user_id)
input_, config = await _run_input_and_config(payload, opengpts_user_id)
background_tasks.add_task(agent.ainvoke, input_, config)
return {"status": "ok"} # TODO add a run id


@router.post("/stream")
async def stream_run(
payload: CreateRunPayload, # for openapi docs
request: Request,
payload: CreateRunPayload,
opengpts_user_id: OpengptsUserId,
):
"""Create a run."""
input_, config = await _run_input_and_config(request, opengpts_user_id)
input_, config = await _run_input_and_config(payload, opengpts_user_id)

return EventSourceResponse(to_sse(astream_messages(agent, input_, config)))

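After this commit the backend resolves the assistant from the thread, so a client starting a run sends only thread_id, input, and optionally config. A minimal sketch of such a call, assuming the router above is mounted at /runs and the thread already exists; the path prefix, the placeholder id, and the {type, content} message shape are assumptions for illustration, not taken from the diff:

```ts
// Sketch of a client request after #269 (TypeScript).
// Assumptions: route prefix "/runs", an existing thread, and messages that
// serialize as { type, content }. User identity (OpengptsUserId) is resolved
// server-side and is not part of this body.
const res = await fetch("/runs", {
  method: "POST",
  headers: { "Content-Type": "application/json" },
  body: JSON.stringify({
    thread_id: "<thread-id>",                     // no assistant_id anymore
    input: [{ type: "human", content: "Hello" }], // AnyMessage-like input
    // config is optional; only its "configurable" keys are merged over the
    // assistant's stored config by _run_input_and_config above
  }),
});
console.log(await res.json()); // create_run above returns {"status": "ok"}
```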
4 changes: 2 additions & 2 deletions backend/app/checkpoint.py
@@ -1,10 +1,10 @@
from datetime import datetime
import pickle
from datetime import datetime
from typing import AsyncIterator, Optional

from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.checkpoint.base import Checkpoint, CheckpointTuple, CheckpointThreadTs
from langgraph.checkpoint.base import Checkpoint, CheckpointThreadTs, CheckpointTuple

from app.lifespan import get_pg_pool

2 changes: 1 addition & 1 deletion backend/app/upload.py
@@ -11,7 +10,6 @@
import os
from typing import Any, BinaryIO, List, Optional

from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_community.document_loaders.blob_loaders.schema import Blob
from langchain_community.vectorstores.pgvector import PGVector
from langchain_core.runnables import (
@@ -21,6 +20,7 @@
)
from langchain_core.vectorstores import VectorStore
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter

from app.ingest import ingest_blob
from app.parsing import MIMETYPE_BASED_PARSER
9 changes: 2 additions & 7 deletions frontend/src/App.tsx
@@ -27,11 +27,7 @@ function App() {
const { currentChat, assistantConfig, isLoading } = useThreadAndAssistant();

const startTurn = useCallback(
async (
message: MessageWithFiles | null,
thread_id: string,
assistant_id: string,
) => {
async (message: MessageWithFiles | null, thread_id: string) => {
const files = message?.files || [];
if (files.length > 0) {
const formData = files.reduce((formData, file) => {
@@ -58,7 +54,6 @@
},
]
: null,
assistant_id,
thread_id,
);
},
@@ -69,7 +64,7 @@
async (config: ConfigInterface, message: MessageWithFiles) => {
const chat = await createChat(message.message, config.assistant_id);
navigate(`/thread/${chat.thread_id}`);
return startTurn(message, chat.thread_id, chat.assistant_id);
return startTurn(message, chat.thread_id);
},
[createChat, navigate, startTurn],
);
14 changes: 3 additions & 11 deletions frontend/src/hooks/useStreamState.tsx
@@ -11,11 +11,7 @@ export interface StreamState {

export interface StreamStateProps {
stream: StreamState | null;
startStream: (
input: Message[] | null,
assistant_id: string,
thread_id: string,
) => Promise<void>;
startStream: (input: Message[] | null, thread_id: string) => Promise<void>;
stopStream?: (clear?: boolean) => void;
}

@@ -24,11 +20,7 @@ export function useStreamState(): StreamStateProps {
const [controller, setController] = useState<AbortController | null>(null);

const startStream = useCallback(
async (
input: Message[] | null,
assistant_id: string,
thread_id: string,
) => {
async (input: Message[] | null, thread_id: string) => {
const controller = new AbortController();
setController(controller);
setCurrent({ status: "inflight", messages: input || [], merge: true });
@@ -37,7 +29,7 @@
signal: controller.signal,
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ input, assistant_id, thread_id }),
body: JSON.stringify({ input, thread_id }),
openWhenHidden: true,
onmessage(msg) {
if (msg.event === "data") {
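The frontend mirrors the backend change: startStream now takes only the input messages and the thread id. A small caller sketch, assuming a component that already has the hook and a current thread in scope; the variable names and the message shape are illustrative, not from the repo:

```ts
// Hypothetical caller of the updated hook, inside an async event handler.
const { stream, startStream, stopStream } = useStreamState();

// Before #269: startStream(input, assistant_id, thread_id)
// After #269 the assistant comes from the thread on the backend:
await startStream(
  [{ type: "human", content: "Summarize this thread" }], // Message[] input
  currentChat.thread_id,                                  // thread_id only
);
```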
