Feat/chat poc (#91)
* ✨ Feat(chat): poc of chat using langstream + sse

* ♻️  Refactoring(llm): move around class/defs/files

* ✨ Feat(llm): select llm backend via envs

* ✨ Feat(chat): integrate with rag capabilities

* ✨ Feat(chat): select rag document from chat

* 🐛 Bug(parrot): make parrot work with rag stream

* ♻️  Refactoring(sse): cleanup and fix sse format

* ♻️  Refactoring(openai): use template

* ♻️  Refactoring: misc cleanup
MasterKenth authored May 23, 2024
1 parent 0b2ea96 commit e2c1929
Showing 29 changed files with 665 additions and 87 deletions.
15 changes: 15 additions & 0 deletions fai-rag-app/.env.example
@@ -94,3 +94,18 @@ FILE_UPLOAD_PATH=uploads
# Refs: https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
#
TOKENIZERS_PARALLELISM=false

# LLM (Large Language Model) configuration.
# Select backend for LLM. Options: 'parrot', 'openai'.
# Defaults to 'parrot', which simply echoes back the given question.
LLM_BACKEND=parrot

# API key required for operations that use the OpenAI API.
# See https://platform.openai.com/ for more info.
OPENAI_API_KEY=my-openai-key

# Model name to use for chat prompt
CHAT_MODEL=gpt-4o

# Model name to use for RAG scoring prompt
SCORING_MODEL=gpt-3.5-turbo
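
For reference, a minimal sketch (not part of this commit) of how these variables can be read at runtime; the defaults mirror the comments above:

import os

llm_backend = os.environ.get("LLM_BACKEND", "parrot")
chat_model = os.environ.get("CHAT_MODEL", "gpt-4o")
scoring_model = os.environ.get("SCORING_MODEL", "gpt-3.5-turbo")
openai_api_key = os.environ.get("OPENAI_API_KEY")  # only required when LLM_BACKEND=openai
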
80 changes: 75 additions & 5 deletions fai-rag-app/fai-backend/fai_backend/chat/template.py
@@ -1,5 +1,74 @@
import dataclasses
import os
from fai_backend.chat.prompt import UserChatPrompt, SystemChatPrompt
from typing import List, Callable, Any

from fai_backend.chat.prompt import UserChatPrompt, SystemChatPrompt, MessageChatPrompt


@dataclasses.dataclass
class PromptTemplate:
name: str
messages: List[MessageChatPrompt]

# Takes the input T from a Stream[T,U] and returns a dict of replacement variables
input_map_fn: Callable[[Any], dict[str, str]]

settings: dict[str, Any]


chatPromptTemplate = PromptTemplate(
name="ChatStream",
messages=[
SystemChatPrompt(
"You are a helpful AI assistant that helps people with answering questions about planning "
"permission.<br> If you can't find the answer in the search result below, just say (in Swedish) "
"\"Tyvärr kan jag inte svara på det.\" Don't try to make up an answer.<br> If the "
"question is not related to the context, politely respond that you are tuned to only "
"answer questions that are related to the context.<br> The questions are going to be "
"asked in Swedish. Your response must always be in Swedish."
),
UserChatPrompt("{query}"),
UserChatPrompt("Here are the results of the search:\n\n {results}"),
],
input_map_fn=lambda vector_results: {
"query": list(vector_results)[0]['query'],
"results": ' | '.join([doc for doc, _ in list(vector_results)[0]['results']])
},
settings={
"model": os.environ.get("CHAT_MODEL", "gpt-4o"),
"temperature": 0
}
)

scoringPromptTemplate = PromptTemplate(
name="ScoringStream",
messages=[
SystemChatPrompt(
"You are a scoring systems that classifies documents from 0-100 based on how well they answer a query."),
UserChatPrompt("Query: {query}\n\nDocument: {document}"),
],
input_map_fn=lambda input: {**(input)},
settings={
"model": os.environ.get("SCORING_MODEL", "gpt-3.5-turbo"),
"temperature": 0,
"functions": [
{
"name": "score_document",
"description": "Scores the previous document according to the user query\n\n Parameters\n ----------\n score\n A number from 0-100 scoring how well does the document matches the query. The higher the score, the better match for the query\n ",
"parameters": {
"type": "object",
"properties": {
"score": {
"type": "number",
}
},
"required": ["score"],
}
}
],
"function_call": {"name": "score_document"},
}
)

CHAT_PROMPT_TEMPLATE_ARGS = {
"name": "ChatStream",
@@ -20,20 +89,21 @@
"results": ' | '.join([doc for doc, _ in list(input)[0]['results']])
},
"settings": {
"model": os.environ.get("GPT_4_MODEL_NAME", "gpt-4"),
"model": os.environ.get("CHAT_MODEL", "gpt-4o"),
"temperature": 0
},
}

SCORING_PROMPT_TEMPLATE_ARGS = {
"name": "ScoringStream",
"messages": [
SystemChatPrompt("You are a scoring systems that classifies documents from 0-100 based on how well they answer a query."),
SystemChatPrompt(
"You are a scoring systems that classifies documents from 0-100 based on how well they answer a query."),
UserChatPrompt("Query: {query}\n\nDocument: {document}"),
],
"input_map_fn": lambda input: {**(input)},
"settings": {
"model": "gpt-3.5-turbo",
"model": os.environ.get("SCORING_MODEL", "gpt-3.5-turbo"),
"temperature": 0,
"functions": [
{
@@ -52,4 +122,4 @@
],
"function_call": {"name": "score_document"},
},
}
}
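
As an illustration (not part of this commit), the input_map_fn of chatPromptTemplate expects an iterable whose first element carries the original query plus a list of (document, score) pairs; the shape below is inferred from the lambda above and the sample values are made up:

sample_vector_results = [{
    "query": "Behöver jag bygglov för ett attefallshus?",
    "results": [
        ("Attefallshus kräver normalt en anmälan ...", 0.87),
        ("Bygglov krävs för nybyggnad av ...", 0.61),
    ],
}]

variables = chatPromptTemplate.input_map_fn(sample_vector_results)
# variables == {"query": "Behöver jag bygglov ...", "results": "Attefallshus ... | Bygglov ..."}
# These values fill the "{query}" and "{results}" placeholders in the template messages.
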
1 change: 1 addition & 0 deletions fai-rag-app/fai-backend/fai_backend/config.py
@@ -30,6 +30,7 @@ class Settings(BaseSettings, extra=Extra.ignore):
LOG_LEVEL: str = 'INFO'
DEFAULT_LANGUAGE: str = 'en'
FILE_UPLOAD_PATH: str = 'uploads'
LLM_BACKEND: Literal['parrot', 'openai'] = 'parrot'

class Config:
env_file = '.env'
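
A hypothetical factory (not shown in this diff) illustrating how LLM_BACKEND could select between the two implementations added further down; the settings instance and function name are assumptions:

from fai_backend.chat.template import chatPromptTemplate
from fai_backend.config import settings  # assumed module-level Settings instance
from fai_backend.llm.impl.openai import OpenAILLM
from fai_backend.llm.impl.parrot import ParrotLLM
from fai_backend.llm.protocol import ILLMStreamProtocol


def create_llm_backend() -> ILLMStreamProtocol:
    # 'parrot' is the default and needs no API key; 'openai' requires OPENAI_API_KEY.
    if settings.LLM_BACKEND == 'openai':
        return OpenAILLM(template=chatPromptTemplate)
    return ParrotLLM()
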
2 changes: 2 additions & 0 deletions fai-rag-app/fai-backend/fai_backend/documents/routes.py
@@ -34,13 +34,15 @@ def list_view(
{
'file_name': document.file_name,
'file_size': document.file_size.human_readable(),
'collection': document.collection,
'mime_type': document.mime_type,
'upload_date': document.upload_date.date(),
}
for document in documents
],
columns=[
{'key': 'file_name', 'label': _('file_name', 'File name')},
{'key': 'collection', 'label': _('collection', 'Collection')},
{'key': 'file_size', 'label': _('file_size', 'File size')},
{'key': 'mime_type', 'label': _('mime_type', 'Mime type')},
{'key': 'upload_date', 'label': _('upload_date', 'Upload date')},
1 change: 1 addition & 0 deletions fai-rag-app/fai-backend/fai_backend/files/models.py
@@ -7,6 +7,7 @@ class FileInfo(BaseModel):
file_name: str
file_size: ByteSize
path: str
collection: str
mime_type: str
last_modified: datetime
upload_date: datetime
17 changes: 10 additions & 7 deletions fai-rag-app/fai-backend/fai_backend/files/service.py
@@ -46,6 +46,7 @@ def get_file_infos(self, directory_path, upload_date: datetime) -> list[FileInfo
file_name=file_name,
file_size=ByteSize(stat.st_size),
path=file_path,
collection=file_path.split('/')[-2], # TODO: niceify
mime_type=mime_type or 'application/octet-stream',
last_modified=datetime.fromtimestamp(stat.st_mtime),
upload_date=upload_date,
@@ -55,19 +56,21 @@ def get_file_infos(self, directory_path, upload_date: datetime) -> list[FileInfo
return file_infos

def list_files(self, project_id: str) -> list[FileInfo]:
project_directories = [d for d in os.listdir(self.upload_dir) if d.startswith(f'{PROJECT_PATH_PREFIX}_{project_id}_')]
project_directories = [d for d in os.listdir(self.upload_dir) if
d.startswith(f'{PROJECT_PATH_PREFIX}_{project_id}_')]
if not project_directories:
return []

latest_directory = sorted(project_directories, key=lambda x: (x.split('_')[2], x.split('_')[3]), reverse=True)[
0]
latest_directory_path = os.path.join(self.upload_dir, latest_directory)
upload_date = datetime.fromtimestamp(os.path.getctime(latest_directory_path))
full_paths = [os.path.join(self.upload_dir, path) for path in project_directories]

all_files = [file for path in full_paths for file in
self.get_file_infos(path, datetime.fromtimestamp(os.path.getctime(path)))]

return self.get_file_infos(latest_directory_path, upload_date)
return all_files

def get_latest_upload_path(self, project_id: str) -> str | None:
project_directories = [d for d in os.listdir(self.upload_dir) if d.startswith(f'{PROJECT_PATH_PREFIX}_{project_id}_')]
project_directories = [d for d in os.listdir(self.upload_dir) if
d.startswith(f'{PROJECT_PATH_PREFIX}_{project_id}_')]
if not project_directories:
return None

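
For clarity (illustration only, the path is made up): the new collection value is simply the name of the directory that contains the uploaded file:

file_path = "uploads/project_abc_2024_05_23/report.pdf"
collection = file_path.split('/')[-2]  # "project_abc_2024_05_23"
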
15 changes: 14 additions & 1 deletion fai-rag-app/fai-backend/fai_backend/framework/components.py
@@ -25,6 +25,7 @@
'Radio',
'ChatBubble',
'FileInput',
'SSEChat',
# then `AnyComponent` itself
'AnyUI',
)
@@ -257,9 +258,21 @@ class ChatBubble(UIComponent):
footer: 'list[AnyUI] | None' = Field(None, serialization_alias='slot.footer')


class SSEDocument(BaseModel):
id: str
name: str


class SSEChat(UIComponent):
type: Literal['SSEChat'] = 'SSEChat'
documents: list[SSEDocument]
endpoint: str


AnyUI = Annotated[
(Div | Form | InputField | Button | FireEvent | Heading |
AppShell | AppDrawer | AppContent | AppFooter | PageHeader |
PageContent | Menu | Link | Textarea | Text | Table | Pagination | Select | Radio | ChatBubble | FileInput),
PageContent | Menu | Link | Textarea | Text | Table | Pagination | Select | Radio |
ChatBubble | FileInput | SSEChat),
Field(discriminator='type')
]
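
An illustrative construction of the new component (not part of this commit; the endpoint and document values are made up):

from fai_backend.framework.components import SSEChat, SSEDocument

chat = SSEChat(
    endpoint='/api/chat-stream',  # assumed SSE endpoint path
    documents=[
        SSEDocument(id='doc-1', name='detaljplan_2024.pdf'),
    ],
)
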
Empty file.
32 changes: 32 additions & 0 deletions fai-rag-app/fai-backend/fai_backend/llm/impl/openai.py
@@ -0,0 +1,32 @@
from typing import Iterable, Any

from langstream import Stream
from langstream.contrib import OpenAIChatStream, OpenAIChatDelta, OpenAIChatMessage

from fai_backend.chat.stream import create_chat_prompt
from fai_backend.chat.template import PromptTemplate
from fai_backend.llm.protocol import ILLMStreamProtocol
from fai_backend.llm.models import LLMDataPacket


class OpenAILLM(ILLMStreamProtocol):

def __init__(self, template: PromptTemplate):
self.template = template

async def create(self) -> Stream[str, LLMDataPacket]:
def messages(in_data: Any) -> Iterable[OpenAIChatMessage]:
prompt = create_chat_prompt({
"name": self.template.name,
"messages": self.template.messages,
"settings": self.template.settings,
})
prompt.format_prompt(self.template.input_map_fn(in_data))
return prompt.to_messages()

return OpenAIChatStream[str, OpenAIChatDelta](
"RecipeStream",
messages,
model="gpt-4",
temperature=0,
).map(lambda delta: LLMDataPacket(content=delta.content, user_friendly=True))
40 changes: 40 additions & 0 deletions fai-rag-app/fai-backend/fai_backend/llm/impl/parrot.py
@@ -0,0 +1,40 @@
import asyncio
from random import uniform
from typing import Any

from langstream import Stream

from fai_backend.llm.protocol import ILLMStreamProtocol
from fai_backend.llm.models import LLMDataPacket


class ParrotLLM(ILLMStreamProtocol):
"""
Parrot (mock) LLM protocol reference implementation.
Parrot will respond with the same message as its input, with a random delay between tokens (words).
"""

def __init__(self, min_delay: float = 0.1, max_delay: float = 1.0):
self.min_delay = min_delay
self.max_delay = max_delay

async def to_generator(self, input_message: str | Any):
if not isinstance(input_message, str):
if isinstance(input_message, list) and "query" in input_message[0]:
input_message = input_message[0]["query"]
else:
yield "squawk?"
return

import re
parts = re.findall(r'\S+\s*', input_message)
for part in parts:
yield part
await asyncio.sleep(uniform(self.min_delay, self.max_delay))

async def create(self) -> Stream[str, LLMDataPacket]:
return Stream[str, str](
"ParrotStream",
self.to_generator
).map(lambda delta: LLMDataPacket(content=delta, user_friendly=True))
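
A minimal consumption sketch (not part of this commit), assuming langstream streams are async callables whose outputs expose a .data field:

import asyncio

from fai_backend.llm.impl.parrot import ParrotLLM
from fai_backend.llm.models import LLMDataPacket


async def main():
    stream = await ParrotLLM(min_delay=0.0, max_delay=0.0).create()
    async for output in stream("polly wants a cracker"):
        # Both the upstream string deltas and the mapped LLMDataPacket outputs are emitted,
        # so filter on the packet type before printing.
        if isinstance(output.data, LLMDataPacket):
            print(output.data.content, end="")


asyncio.run(main())
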
26 changes: 26 additions & 0 deletions fai-rag-app/fai-backend/fai_backend/llm/impl/rag_wrapper.py
@@ -0,0 +1,26 @@
from langstream import Stream

from fai_backend.llm.protocol import ILLMStreamProtocol
from fai_backend.llm.models import LLMDataPacket
from fai_backend.llm.service import create_rag_stream


class RAGWrapper(ILLMStreamProtocol):
"""
Wraps an underlying Stream with RAG capabilities.
The underlying stream will be supplied with document extracts in plaintext
from the given collection along with the original question.
"""

def __init__(self, input_query: str, base_llm: ILLMStreamProtocol, rag_collection_name: str):
self.input_query = input_query
self.rag_collection_name = rag_collection_name
self.base_llm = base_llm

async def create(self) -> Stream[str, LLMDataPacket]:
rag_stream = await create_rag_stream(self.input_query, self.rag_collection_name)
base_stream = await self.base_llm.create()

return (rag_stream
.and_then(base_stream))
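
A hypothetical composition (the query and collection name are made up) showing how the wrapper decorates a base backend; create() must be awaited inside an async context:

from fai_backend.llm.impl.parrot import ParrotLLM
from fai_backend.llm.impl.rag_wrapper import RAGWrapper

llm = RAGWrapper(
    input_query="Behöver jag bygglov för ett växthus?",
    base_llm=ParrotLLM(),
    rag_collection_name="planning-docs",  # assumed collection name
)
# stream = await llm.create()  # then consume like any other ILLMStreamProtocol stream
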
16 changes: 16 additions & 0 deletions fai-rag-app/fai-backend/fai_backend/llm/models.py
@@ -0,0 +1,16 @@
import dataclasses
from datetime import datetime

from pydantic import BaseModel


class LLMMessage(BaseModel):
date: datetime
source: str | None = None
content: str | None = None


@dataclasses.dataclass
class LLMDataPacket:
content: str
user_friendly: bool = False
14 changes: 14 additions & 0 deletions fai-rag-app/fai-backend/fai_backend/llm/protocol.py
@@ -0,0 +1,14 @@
from typing import Protocol

from langstream import Stream

from fai_backend.llm.models import LLMDataPacket


class ILLMStreamProtocol(Protocol):
async def create(self) -> Stream[str, LLMDataPacket]:
"""
Create a Stream that takes a str (generally a question) and returns
a stream of LLMDataPacket tokens carrying the LLM's response.
"""
...
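
Because this is a structural Protocol, any object with a matching async create method satisfies it without inheriting from it. A small sketch (assuming langstream output objects expose .data):

from fai_backend.llm.models import LLMDataPacket
from fai_backend.llm.protocol import ILLMStreamProtocol


async def answer(llm: ILLMStreamProtocol, question: str) -> str:
    stream = await llm.create()
    parts = [
        output.data.content
        async for output in stream(question)
        if isinstance(output.data, LLMDataPacket) and output.data.user_friendly
    ]
    return "".join(parts)
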
Empty file.
Empty file.
11 changes: 11 additions & 0 deletions fai-rag-app/fai-backend/fai_backend/llm/serializer/impl/base64.py
@@ -0,0 +1,11 @@
import base64

from pydantic import BaseModel

from fai_backend.llm.serializer.protocol import ISerializer


class Base64Serializer(ISerializer):
def serialize(self, input_data: BaseModel) -> str:
output_data: str = input_data.model_dump_json(exclude_none=True)
return base64.b64encode(output_data.encode("utf-8")).decode("utf-8")
@@ -0,0 +1,8 @@
from pydantic import BaseModel

from fai_backend.llm.serializer.protocol import ISerializer


class JSONSerializer(ISerializer):
def serialize(self, input_data: BaseModel) -> str:
return input_data.model_dump_json(exclude_none=True)
11 changes: 11 additions & 0 deletions fai-rag-app/fai-backend/fai_backend/llm/serializer/protocol.py
@@ -0,0 +1,11 @@
from typing import Protocol

from pydantic import BaseModel


class ISerializer(Protocol):
def serialize(self, input_data: BaseModel) -> str:
"""
"""
...
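
A usage sketch (not part of this commit) combining the serializers with the LLMMessage model added above; the JSON serializer's module path is an assumption, since its file header is missing from this diff:

from datetime import datetime

from fai_backend.llm.models import LLMMessage
from fai_backend.llm.serializer.impl.base64 import Base64Serializer
from fai_backend.llm.serializer.impl.json import JSONSerializer  # assumed module name

message = LLMMessage(date=datetime.now(), source="llm", content="Hej!")
print(JSONSerializer().serialize(message))    # compact JSON without None fields
print(Base64Serializer().serialize(message))  # the same JSON, base64-encoded
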