Skip to content

Commit

Permalink
Implemented basic query stream
Browse files Browse the repository at this point in the history
  • Loading branch information
antoninoLorenzo committed Jun 21, 2024
1 parent f94ef78 commit bc4d7d7
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 9 deletions.
50 changes: 43 additions & 7 deletions src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,41 @@
- /collections/list: Returns available Collections.
- /collections/new: Creates a new Collection.
"""
import argparse
from argparse import ArgumentParser

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse


from src import upload_knowledge
from src.agent import Agent
from src.agent.llm import LLM
from src.agent.knowledge import Store
from src.agent.tools import TOOLS

# setup
# Agent Setup
model = 'llama3'
tools = '\n'.join([tool.get_documentation() for tool in TOOLS])
store = Store()
upload_knowledge('../data/json', store)
agent = Agent(model=model, tools_docs=tools, knowledge_base=store)
# store = Store()
# upload_knowledge('../data/json', store)
# agent = Agent(model=model, tools_docs=tools)# , knowledge_base=store)
llm = LLM(model)

# API Setup
origins = [
'http://localhost:3000' # default frontend port
]

app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)


# --- SESSION RELATED
Expand All @@ -58,15 +78,15 @@ def get_session(sid: int):
def new_session(name: str):
"""
Creates a new session.
Returns True (Success) or False (Failure).
Returns True (Success) or False (Failure). (or redirects to load the new session)
"""


@app.get('/session/{sid}/rename/')
def rename_session(sid: int, new_name: str):
"""
Rename a session.
Returns True (Success) or False (Failure).
Returns True (Success) or False (Failure). (or redirects to load the renamed session)
"""


Expand All @@ -82,17 +102,29 @@ def save_session(sid: int):
def delete_session(sid: int):
"""
Delete a session.
Returns True (Success) or False (Failure).
Returns True (Success) or False (Failure). (or redirects to load the updated responses)
"""


# --- AGENT RELATED

def query_generator(sid: int, q: str):
# testing with llm only
stream = llm.query(messages=[
{'role': 'system', 'content': 'You are an assistant'},
{'role': 'user', 'content': q}
])
for chunk in stream:
yield chunk['message']['content']


@app.get('/session/{sid}/query/')
def query(sid: int, q: str):
"""
Makes a query to the Agent.
Returns the stream for the response.
"""
return StreamingResponse(query_generator(sid, q))


# --- PLAN RELATED
Expand Down Expand Up @@ -145,3 +177,7 @@ def create_collection(title: str, base_path: str, topics: list):
]
"""


if __name__ == "__main__":
# get api settings ...
pass
25 changes: 23 additions & 2 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import requests

from src import upload_knowledge
from src.agent import Agent
from src.agent.knowledge import Store
from src.agent.tools import TOOLS


# Enter: new 1
# Enter: rename plan_no_rag
# Enter: save 1
Expand All @@ -18,7 +21,7 @@ def cli_test():
# upload_knowledge('../data/json', vector_db)

# =================================================================
agent = Agent(model=ollama_model, tools_docs=tools_documentation)#, knowledge_base=vector_db)
agent = Agent(model=ollama_model, tools_docs=tools_documentation) # , knowledge_base=vector_db)
current_session = 0
while True:
user_input = input("Enter: ")
Expand Down Expand Up @@ -49,5 +52,23 @@ def cli_test():
print()


def api_test():
s = requests.Session()
url = 'http://127.0.0.1:8000/session/0/query'
params = {'q': 'tell me how to make a search engine'}

with s.get(url, params=params, headers=None, stream=True) as resp:
print('Assistant: ')
text = ''
for chunk in resp.iter_content():
if chunk:
text += chunk.decode()
print(chunk.decode(), end='')
if len(text) % 200 == 0:
print()
print()


if __name__ == "__main__":
cli_test()
# cli_test()
api_test()

0 comments on commit bc4d7d7

Please sign in to comment.