diff --git a/context_engine/chat_engine/chat_engine.py b/context_engine/chat_engine/chat_engine.py
index 565108dc..17d128f0 100644
--- a/context_engine/chat_engine/chat_engine.py
+++ b/context_engine/chat_engine/chat_engine.py
@@ -1,5 +1,6 @@
+import os
 from abc import ABC, abstractmethod
-from typing import Iterable, Union, Optional
+from typing import Iterable, Union, Optional, cast
 
 from context_engine.chat_engine.models import HistoryPruningMethod
 from context_engine.chat_engine.prompt_builder import PromptBuilder
@@ -9,9 +10,12 @@
 from context_engine.knoweldge_base.tokenizer import Tokenizer
 from context_engine.llm import BaseLLM
 from context_engine.llm.models import ModelParams, SystemMessage
-from context_engine.models.api_models import StreamingChatResponse, ChatResponse
+from context_engine.models.api_models import (StreamingChatChunk, ChatResponse,
+                                              StreamingChatResponse, )
 from context_engine.models.data_models import Context, Messages
 
+CE_DEBUG_INFO = os.getenv("CE_DEBUG_INFO", "FALSE").lower() == "true"
+
 DEFAULT_SYSTEM_PROMPT = """"Use the following pieces of context to answer the user question at the next messages. This context retrieved from a knowledge database and you should use only the facts from the context to answer. Always remember to include the source to the documents you used from their 'source' field in the format 'Source: $SOURCE_HERE'.
 If you don't know the answer, just say that you don't know, don't try to make up an answer, use the context."
 
@@ -25,7 +29,7 @@ def chat(self,
              *,
              stream: bool = False,
              model_params: Optional[ModelParams] = None
-             ) -> Union[ChatResponse, Iterable[StreamingChatResponse]]:
+             ) -> Union[ChatResponse, StreamingChatResponse]:
         pass
 
     # TODO: Decide if we want it for first release in the API
@@ -39,7 +43,7 @@ async def achat(self,
                     *,
                     stream: bool = False,
                     model_params: Optional[ModelParams] = None
-                    ) -> Union[ChatResponse, Iterable[StreamingChatResponse]]:
+                    ) -> Union[ChatResponse, StreamingChatResponse]:
         pass
 
     @abstractmethod
@@ -55,8 +59,8 @@ def __init__(self,
                  *,
                  llm: BaseLLM,
                  context_engine: ContextEngine,
-                 max_prompt_tokens: int,
-                 max_generated_tokens: int,
+                 max_prompt_tokens: int = 4096,
+                 max_generated_tokens: Optional[int] = None,
                  max_context_tokens: Optional[int] = None,
                  query_builder: Optional[QueryGenerator] = None,
                  system_prompt: Optional[str] = None,
@@ -95,7 +99,7 @@ def chat(self,
              *,
              stream: bool = False,
              model_params: Optional[ModelParams] = None
-             ) -> Union[ChatResponse, Iterable[StreamingChatResponse]]:
+             ) -> Union[ChatResponse, StreamingChatResponse]:
         context = self.get_context(messages)
         system_prompt = self.system_prompt_template + f"\nContext: {context.to_text()}"
         llm_messages = self._prompt_builder.build(
@@ -103,10 +107,24 @@ def chat(self,
             messages,
             max_tokens=self.max_prompt_tokens
         )
-        return self.llm.chat_completion(llm_messages,
-                                        max_tokens=self.max_generated_tokens,
-                                        stream=stream,
-                                        model_params=model_params)
+        llm_response = self.llm.chat_completion(llm_messages,
+                                                max_tokens=self.max_generated_tokens,
+                                                stream=stream,
+                                                model_params=model_params)
+        debug_info = {}
+        if CE_DEBUG_INFO:
+            debug_info['context'] = context.dict()
+            debug_info['context'].update(context.debug_info)
+
+        if stream:
+            return StreamingChatResponse(
+                chunks=cast(Iterable[StreamingChatChunk], llm_response),
+                debug_info=debug_info
+            )
+        else:
+            response = cast(ChatResponse, llm_response)
+            response.debug_info = debug_info
+            return response
 
     def get_context(self,
                     messages: Messages,
@@ -120,7 +138,7 @@ async def achat(self,
                     *,
                     stream: bool = False,
                    model_params: Optional[ModelParams] = None
-                    ) -> Union[ChatResponse, Iterable[StreamingChatResponse]]:
+                    ) -> Union[ChatResponse, StreamingChatResponse]:
         raise NotImplementedError
 
     async def aget_context(self, messages: Messages) -> Context:
diff --git a/context_engine/context_engine/context_engine.py b/context_engine/context_engine/context_engine.py
index ea39ea95..da269c07 100644
--- a/context_engine/context_engine/context_engine.py
+++ b/context_engine/context_engine/context_engine.py
@@ -1,3 +1,4 @@
+import os
 from abc import ABC, abstractmethod
 from typing import List, Optional
 
@@ -6,6 +7,8 @@
 from context_engine.knoweldge_base.base import BaseKnowledgeBase
 from context_engine.models.data_models import Context, Query
 
+CE_DEBUG_INFO = os.getenv("CE_DEBUG_INFO", "FALSE").lower() == "true"
+
 
 class BaseContextEngine(ABC):
 
@@ -38,6 +41,9 @@ def query(self, queries: List[Query], max_context_tokens: int, ) -> Context:
             queries, global_metadata_filter=self.global_metadata_filter)
         context = self.context_builder.build(query_results, max_context_tokens)
+
+        if CE_DEBUG_INFO:
+            context.debug_info["query_results"] = [qr.dict() for qr in query_results]
         return context
 
     async def aquery(self, queries: List[Query], max_context_tokens: int, ) -> Context:
diff --git a/context_engine/knoweldge_base/base.py b/context_engine/knoweldge_base/base.py
index 82d7d6bb..e1a504bf 100644
--- a/context_engine/knoweldge_base/base.py
+++ b/context_engine/knoweldge_base/base.py
@@ -26,6 +26,10 @@ def upsert(self,
     def delete(self, document_ids: List[str], namespace: str = "", ) -> None:
         pass
 
+    @abstractmethod
+    def verify_connection_health(self) -> None:
+        pass
+
     @abstractmethod
     async def aquery(self,
                      queries: List[Query],
diff --git a/context_engine/knoweldge_base/chunker/token_chunker.py b/context_engine/knoweldge_base/chunker/token_chunker.py
index 22de41fc..7a2bf3b2 100644
--- a/context_engine/knoweldge_base/chunker/token_chunker.py
+++ b/context_engine/knoweldge_base/chunker/token_chunker.py
@@ -9,8 +9,8 @@ class TokenChunker(Chunker):
 
     def __init__(self,
-                 max_chunk_size: int = 200,
-                 overlap: int = 0, ):
+                 max_chunk_size: int = 256,
+                 overlap: int = 30, ):
         if overlap < 0:
             cls_name = self.__class__.__name__
             raise ValueError(
diff --git a/context_engine/knoweldge_base/knowledge_base.py b/context_engine/knoweldge_base/knowledge_base.py
index 4489f866..e763319b 100644
--- a/context_engine/knoweldge_base/knowledge_base.py
+++ b/context_engine/knoweldge_base/knowledge_base.py
@@ -87,6 +87,29 @@ def _connect_index(cls,
             ) from e
         return index
 
+    def verify_connection_health(self) -> None:
+        self._verify_not_deleted()
+
+        try:
+            self._index.describe_index_stats()  # type: ignore
+        except Exception as e:
+            try:
+                pinecone_whoami()
+            except Exception:
+                raise RuntimeError(
+                    "Failed to connect to Pinecone. "
+                    "Please check your credentials and try again"
+                ) from e
+
+            if self._index_name not in list_indexes():
+                raise RuntimeError(
+                    f"index {self._index_name} does not exist anymore "
+                    "and was probably deleted. "
+                    "Please create it first using `create_with_new_index()`"
+                ) from e
+            raise RuntimeError("Index unexpectedly did not respond. "
" + "Please try again in few moments") from e + @classmethod def create_with_new_index(cls, index_name: str, @@ -189,11 +212,11 @@ def index_name(self) -> str: return self._index_name def delete_index(self): - self._validate_not_deleted() + self._verify_not_deleted() delete_index(self._index_name) self._index = None - def _validate_not_deleted(self): + def _verify_not_deleted(self): if self._index is None: raise RuntimeError( "index was deleted. " @@ -218,7 +241,7 @@ def query(self, def _query_index(self, query: KBQuery, global_metadata_filter: Optional[dict]) -> KBQueryResult: - self._validate_not_deleted() + self._verify_not_deleted() metadata_filter = deepcopy(query.metadata_filter) if global_metadata_filter is not None: @@ -252,7 +275,7 @@ def upsert(self, documents: List[Document], namespace: str = "", batch_size: int = 100): - self._validate_not_deleted() + self._verify_not_deleted() chunks = self._chunker.chunk_documents(documents) encoded_chunks = self._encoder.encode_documents(chunks) @@ -290,7 +313,7 @@ def upsert_dataframe(self, df: pd.DataFrame, namespace: str = "", batch_size: int = 100): - self._validate_not_deleted() + self._verify_not_deleted() expected_columns = ["id", "text", "metadata"] if not all([c in df.columns for c in expected_columns]): @@ -305,7 +328,7 @@ def upsert_dataframe(self, def delete(self, document_ids: List[str], namespace: str = "") -> None: - self._validate_not_deleted() + self._verify_not_deleted() self._index.delete( # type: ignore filter={"document_id": {"$in": document_ids}}, namespace=namespace diff --git a/context_engine/knoweldge_base/record_encoder/dense_record_encoder.py b/context_engine/knoweldge_base/record_encoder/dense_record_encoder.py index 4c2bd499..752b8562 100644 --- a/context_engine/knoweldge_base/record_encoder/dense_record_encoder.py +++ b/context_engine/knoweldge_base/record_encoder/dense_record_encoder.py @@ -15,8 +15,11 @@ class DenseRecordEncoder(RecordEncoder): DEFAULT_MODEL_NAME = "text-embedding-ada-002" def __init__(self, - dense_encoder: Optional[BaseDenseEncoder] = None, **kwargs): - super().__init__(**kwargs) + dense_encoder: Optional[BaseDenseEncoder] = None, + *, + batch_size: int = 500, + **kwargs): + super().__init__(batch_size=batch_size, **kwargs) if dense_encoder is None: dense_encoder = self.DEFAULT_DENSE_ENCODER(self.DEFAULT_MODEL_NAME) self._dense_encoder = dense_encoder diff --git a/context_engine/llm/base.py b/context_engine/llm/base.py index bf292f9a..b89dee5a 100644 --- a/context_engine/llm/base.py +++ b/context_engine/llm/base.py @@ -2,7 +2,7 @@ from typing import Union, Iterable, Optional, List from context_engine.llm.models import Function, ModelParams -from context_engine.models.api_models import ChatResponse, StreamingChatResponse +from context_engine.models.api_models import ChatResponse, StreamingChatChunk from context_engine.models.data_models import Messages, Query @@ -23,7 +23,7 @@ def chat_completion(self, stream: bool = False, max_tokens: Optional[int] = None, model_params: Optional[ModelParams] = None, - ) -> Union[ChatResponse, Iterable[StreamingChatResponse]]: + ) -> Union[ChatResponse, Iterable[StreamingChatChunk]]: pass @abstractmethod @@ -44,7 +44,7 @@ async def achat_completion(self, max_generated_tokens: Optional[int] = None, model_params: Optional[ModelParams] = None, ) -> Union[ChatResponse, - Iterable[StreamingChatResponse]]: + Iterable[StreamingChatChunk]]: pass @abstractmethod diff --git a/context_engine/llm/openai.py b/context_engine/llm/openai.py index c4c6567c..5cf99e4a 100644 
--- a/context_engine/llm/openai.py
+++ b/context_engine/llm/openai.py
@@ -4,7 +4,7 @@
 import json
 
 from context_engine.llm import BaseLLM
 from context_engine.llm.models import Function, ModelParams
-from context_engine.models.api_models import ChatResponse, StreamingChatResponse
+from context_engine.models.api_models import ChatResponse, StreamingChatChunk
 from context_engine.models.data_models import Messages, Query
 
@@ -30,7 +30,7 @@ def chat_completion(self,
                         stream: bool = False,
                         max_tokens: Optional[int] = None,
                         model_params: Optional[ModelParams] = None,
-                        ) -> Union[ChatResponse, Iterable[StreamingChatResponse]]:
+                        ) -> Union[ChatResponse, Iterable[StreamingChatChunk]]:
 
         model_params_dict: Dict[str, Any] = {}
         model_params_dict.update(
@@ -48,7 +48,7 @@ def chat_completion(self,
 
         def streaming_iterator(response):
             for chunk in response:
-                yield StreamingChatResponse(**chunk)
+                yield StreamingChatChunk(**chunk)
 
         if stream:
             return streaming_iterator(response)
@@ -90,7 +90,7 @@ async def achat_completion(self,
                                max_generated_tokens: Optional[int] = None,
                                model_params: Optional[ModelParams] = None
                                ) -> Union[ChatResponse,
-                                          Iterable[StreamingChatResponse]]:
+                                          Iterable[StreamingChatChunk]]:
         raise NotImplementedError()
 
     async def agenerate_queries(self,
diff --git a/context_engine/models/api_models.py b/context_engine/models/api_models.py
index 044bee90..cdf55dfa 100644
--- a/context_engine/models/api_models.py
+++ b/context_engine/models/api_models.py
@@ -1,4 +1,4 @@
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Iterable
 
 from pydantic import BaseModel, Field, validator
 
@@ -37,9 +37,14 @@ class ChatResponse(BaseModel):
     debug_info: dict = Field(default_factory=dict, exclude=True)
 
 
-class StreamingChatResponse(BaseModel):
+class StreamingChatChunk(BaseModel):
     id: str
     object: str
     created: int
     model: str
     choices: Sequence[_StreamChoice]
+
+
+class StreamingChatResponse(BaseModel):
+    chunks: Iterable[StreamingChatChunk]
+    debug_info: dict = Field(default_factory=dict, exclude=True)
diff --git a/context_engine/models/data_models.py b/context_engine/models/data_models.py
index 3c7b7653..b6de2019 100644
--- a/context_engine/models/data_models.py
+++ b/context_engine/models/data_models.py
@@ -22,7 +22,7 @@ class Document(BaseModel):
     id: str
     text: str
     source: str = ""
-    metadata: Metadata
+    metadata: Metadata = Field(default_factory=dict)
 
     @validator('metadata')
     def metadata_reseved_fields(cls, v):
diff --git a/pyproject.toml b/pyproject.toml
index 8af395c2..eafc5afc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,10 +14,15 @@ openai = "^0.27.5"
 tiktoken = "^0.3.3"
 pinecone-datasets = "^0.6.1"
 pydantic = "^1.10.7"
-pinecone-text = "^0.5.3"
+pinecone-text = { version = "^0.5.4", extras = ["openai"] }
 flake8-pyproject = "^1.2.3"
 pandas-stubs = "^2.0.3.230814"
 langchain = "^0.0.188"
+fastapi = "^0.92.0"
+uvicorn = "^0.20.0"
+tenacity = "^8.2.1"
+sse-starlette = "^1.6.5"
+
 
 [tool.poetry.group.dev.dependencies]
 jupyter = "^1.0.0"
diff --git a/service/__init__.py b/service/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/service/api_models.py b/service/api_models.py
new file mode 100644
index 00000000..ba80d140
--- /dev/null
+++ b/service/api_models.py
@@ -0,0 +1,23 @@
+from typing import Optional, List
+
+from pydantic import BaseModel
+
+from context_engine.models.data_models import Messages, Query, Document
+
+
+class ChatRequest(BaseModel):
+    model: str = ""
+    messages: Messages
+    stream: bool = False
+    user: Optional[str] = None
+
+
+class ContextQueryRequest(BaseModel):
+    queries: List[Query]
+    max_tokens: int
+
+
+class ContextUpsertRequest(BaseModel):
+    documents: List[Document]
+    namespace: str = ""
+    batch_size: int = 100
diff --git a/service/app.py b/service/app.py
new file mode 100644
index 00000000..2de020bc
--- /dev/null
+++ b/service/app.py
@@ -0,0 +1,186 @@
+import logging
+import os
+import sys
+import uuid
+from dotenv import load_dotenv
+
+from context_engine.llm import BaseLLM
+from context_engine.llm.models import UserMessage
+from context_engine.knoweldge_base.tokenizer import OpenAITokenizer, Tokenizer
+from context_engine.knoweldge_base import KnowledgeBase
+from context_engine.context_engine import ContextEngine
+from context_engine.chat_engine import ChatEngine
+from starlette.concurrency import run_in_threadpool
+from sse_starlette.sse import EventSourceResponse
+
+from fastapi import FastAPI, HTTPException, Body
+import uvicorn
+from typing import cast
+
+from context_engine.models.api_models import StreamingChatResponse, ChatResponse
+from context_engine.models.data_models import Context
+from service.api_models import \
+    ChatRequest, ContextQueryRequest, ContextUpsertRequest
+
+load_dotenv()  # load env vars before import of openai
+from context_engine.llm.openai import OpenAILLM  # noqa: E402
+
+
+INDEX_NAME = os.getenv("INDEX_NAME")
+app = FastAPI()
+
+context_engine: ContextEngine
+chat_engine: ChatEngine
+kb: KnowledgeBase
+llm: BaseLLM
+logger: logging.Logger
+
+
+@app.post(
+    "/context/chat/completions",
+)
+async def chat(
+    request: ChatRequest = Body(...),
+):
+    try:
+        session_id = request.user or "None"  # noqa: F841
+        question_id = str(uuid.uuid4())
+        logger.debug(f"Received chat request: {request.messages[-1].content}")
+        answer = await run_in_threadpool(chat_engine.chat,
+                                         messages=request.messages,
+                                         stream=request.stream)
+
+        if request.stream:
+            def stringify_content(response: StreamingChatResponse):
+                for chunk in response.chunks:
+                    chunk.id = question_id
+                    data = chunk.json()
+                    yield data
+            content_stream = stringify_content(cast(StreamingChatResponse, answer))
+            return EventSourceResponse(content_stream, media_type='text/event-stream')
+
+        else:
+            chat_response = cast(ChatResponse, answer)
+            chat_response.id = question_id
+            return chat_response
+
+    except Exception as e:
+        logger.exception(f"Chat with question_id {question_id} failed")
+        raise HTTPException(
+            status_code=500, detail=f"Internal Service Error: {str(e)}")
+
+
+@app.get(
+    "/context/query",
+)
+async def query(
+    request: ContextQueryRequest = Body(...),
+):
+    try:
+        context: Context = await run_in_threadpool(
+            context_engine.query,
+            queries=request.queries,
+            max_context_tokens=request.max_tokens)
+
+        return context.content
+
+    except Exception as e:
+        logger.exception(e)
+        raise HTTPException(
+            status_code=500, detail=f"Internal Service Error: {str(e)}")
+
+
+@app.post(
+    "/context/upsert",
+)
+async def upsert(
+    request: ContextUpsertRequest = Body(...),
+):
+    try:
+        logger.info(f"Upserting {len(request.documents)} documents")
+        upsert_results = await run_in_threadpool(
+            kb.upsert,
+            documents=request.documents,
+            namespace=request.namespace,
+            batch_size=request.batch_size)
+
+        return upsert_results
+
+    except Exception as e:
+        logger.exception(e)
+        raise HTTPException(
+            status_code=500, detail=f"Internal Service Error: {str(e)}")
+
+
+@app.get(
+    "/health",
+)
+async def health_check():
+    try:
+        await run_in_threadpool(kb.verify_connection_health)
+    except Exception as e:
+        err_msg = f"Failed connecting to Pinecone Index {kb._index_name}"
Pinecone Index {kb._index_name}" + logger.exception(err_msg) + raise HTTPException( + status_code=500, detail=f"{err_msg}. Error: {str(e)}") from e + + try: + msg = UserMessage(content="This is a health check. Are you alive? Be concise") + await run_in_threadpool(llm.chat_completion, + messages=[msg], + max_tokens=50) + except Exception as e: + err_msg = f"Failed to communicate with {llm.__class__.__name__}" + logger.exception(err_msg) + raise HTTPException( + status_code=500, detail=f"{err_msg}. Error: {str(e)}") from e + + return "All clear!" + + +@app.on_event("startup") +async def startup(): + _init_logging() + _init_engines() + + +def _init_logging(): + global logger + + file_handler = logging.FileHandler( + filename=os.getenv("CE_LOG_FILENAME", "context_engine.log") + ) + stdout_handler = logging.StreamHandler(stream=sys.stdout) + handlers = [file_handler, stdout_handler] + logging.basicConfig( + format='%(asctime)s - %(processName)s - %(name)-10s [%(levelname)-8s]: ' + '%(message)s', + level=os.getenv("CE_LOG_LEVEL", "INFO").upper(), + handlers=handlers, + force=True + ) + logger = logging.getLogger(__name__) + + +def _init_engines(): + global kb, context_engine, chat_engine, llm + Tokenizer.initialize(OpenAITokenizer, model_name='gpt-3.5-turbo-0613') + + if not INDEX_NAME: + raise ValueError("INDEX_NAME environment variable must be set") + + kb = KnowledgeBase(index_name=INDEX_NAME) + context_engine = ContextEngine(knowledge_base=kb) + llm = OpenAILLM(model_name='gpt-3.5-turbo-0613') + + chat_engine = ChatEngine(llm=llm, context_engine=context_engine) + + +def start(): + uvicorn.run("service.app:app", + host="0.0.0.0", port=8000, reload=True) + + +if __name__ == "__main__": + start() diff --git a/tests/system/knowledge_base/test_knowledge_base.py b/tests/system/knowledge_base/test_knowledge_base.py index 0cbd7c4a..ecf66778 100644 --- a/tests/system/knowledge_base/test_knowledge_base.py +++ b/tests/system/knowledge_base/test_knowledge_base.py @@ -112,6 +112,10 @@ def test_create_index(index_full_name, knowledge_base): assert knowledge_base._index.describe_index_stats() +def test_is_verify_connection_health_happy_path(knowledge_base): + knowledge_base.verify_connection_health() + + def test_init_with_context_engine_prefix(index_full_name, chunker, encoder): kb = KnowledgeBase(index_name=index_full_name, encoder=encoder, @@ -262,6 +266,13 @@ def test_delete_index_for_non_existing(knowledge_base): assert "index was deleted." 
 
 
+def test_verify_connection_health_raise_for_deleted_index(knowledge_base):
+    with pytest.raises(RuntimeError) as e:
+        knowledge_base.verify_connection_health()
+
+    assert "index was deleted" in str(e.value)
+
+
 def test_create_with_text_in_indexed_field_raise(index_name,
                                                  chunker,
                                                  encoder):
diff --git a/tests/system/llm/test_openai.py b/tests/system/llm/test_openai.py
index b601182e..b2a3d7d9 100644
--- a/tests/system/llm/test_openai.py
+++ b/tests/system/llm/test_openai.py
@@ -3,7 +3,7 @@
 
 from context_engine.models.data_models import Role, MessageBase  # noqa
-from context_engine.models.api_models import ChatResponse, StreamingChatResponse  # noqa
+from context_engine.models.api_models import ChatResponse, StreamingChatChunk  # noqa
 from context_engine.llm.openai import OpenAILLM  # noqa
 from context_engine.llm.models import \
     Function, FunctionParameters, FunctionArrayProperty, ModelParams  # noqa
@@ -143,7 +143,7 @@ def test_chat_streaming(openai_llm, messages):
         messages_received = [message for message in response]
         assert len(messages_received) > 0
         for message in messages_received:
-            assert isinstance(message, StreamingChatResponse)
+            assert isinstance(message, StreamingChatChunk)
 
     @staticmethod
     def test_max_tokens(openai_llm, messages):
diff --git a/tests/unit/chat_engine/test_chat_engine.py b/tests/unit/chat_engine/test_chat_engine.py
index ffa0393c..1fe6d4d2 100644
--- a/tests/unit/chat_engine/test_chat_engine.py
+++ b/tests/unit/chat_engine/test_chat_engine.py
@@ -9,7 +9,8 @@
 from context_engine.context_engine.models import ContextQueryResult, ContextSnippet
 from context_engine.llm import BaseLLM
 from context_engine.llm.models import SystemMessage
-from context_engine.models.data_models import MessageBase, Query, Context
+from context_engine.models.api_models import ChatResponse, _Choice, TokenCounts
+from context_engine.models.data_models import MessageBase, Query, Context, Role
 from .. import random_words
 
 MOCK_SYSTEM_PROMPT = "This is my mock prompt"
@@ -70,7 +71,21 @@ def _get_inputs_and_expected(self,
         expected_prompt = [SystemMessage(
             content=system_prompt + f"\nContext: {mock_context.to_text()}"
         )] + messages
-        mock_chat_response = "Photosynthesis is a process used by plants..."
+
+        mock_chat_response = ChatResponse(
+            id='chatcmpl-7xuuGZzniUGiqxDSTJnqwb0l1xtfp',
+            object='chat.completion',
+            created=1694514456,
+            model='gpt-3.5-turbo',
+            choices=[_Choice(index=0,
+                             message=MessageBase(
+                                 role=Role.ASSISTANT,
+                                 content="Photosynthesis is a process used by plants"),
+                             finish_reason='stop')],
+            usage=TokenCounts(prompt_tokens=25,
+                              completion_tokens=9,
+                              total_tokens=34),
+            debug_info={})
 
         # Set the return values of the mocked methods
         self.mock_query_builder.generate.return_value = mock_queries