This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Adapted app.py to work with latest code #40

Merged
33 commits merged on Sep 13, 2023
Commits (33)
0d0c261
initial service
acatav Aug 30, 2023
359d764
implement upsert
acatav Aug 31, 2023
a05157c
lint
acatav Aug 31, 2023
d453b7a
change upsert to post command
miararoy Sep 4, 2023
97bf7cf
[pyproj] Added pytest-xdist
igiloh-pinecone Sep 7, 2023
fd3a4de
[CI] Added system tests to CI
igiloh-pinecone Sep 7, 2023
68655e0
Merge branch 'int_context_tokens' into ci_system_tests
igiloh-pinecone Sep 7, 2023
bf9cb4e
Merge remote-tracking branch 'origin/initial-chat-service' into server
igiloh-pinecone Sep 7, 2023
cac0961
[app] Adapt FastAPI to work with latest changes
igiloh-pinecone Sep 7, 2023
1c8e1dc
[kb] Set more meaningful defaults
igiloh-pinecone Sep 7, 2023
4f174d6
[chat] Set defaults to max_prompt_tokens and generated_tokens
igiloh-pinecone Sep 7, 2023
a1229d6
Merge remote-tracking branch 'origin/dev' into server
igiloh-pinecone Sep 10, 2023
566e887
[app.py] Slightly improve static type analysis
igiloh-pinecone Sep 10, 2023
0292276
[main] moved service outside the library
igiloh-pinecone Sep 10, 2023
b8a4023
[main] Index name shouldn't be hardcoded
igiloh-pinecone Sep 10, 2023
f88a387
[main] Added proper question_id and session_id
igiloh-pinecone Sep 10, 2023
d0754ff
[main] renamed models to api_models
igiloh-pinecone Sep 10, 2023
7cb59df
[chat] Properly support propagating debug_info to chat response
igiloh-pinecone Sep 10, 2023
ffdeb07
[chat] Added debug_info propagation
igiloh-pinecone Sep 10, 2023
c0a1a83
[app] Added very basic logging
igiloh-pinecone Sep 10, 2023
884e0a6
[app] Fixed basic logging
igiloh-pinecone Sep 11, 2023
fda3ee1
[app] Renamed main -> service
igiloh-pinecone Sep 11, 2023
a727464
[app] Further improve logging
igiloh-pinecone Sep 11, 2023
1755b77
[app] Added health check
igiloh-pinecone Sep 11, 2023
28caab0
minor fixes
igiloh-pinecone Sep 11, 2023
042deda
Merge remote-tracking branch 'origin/dev' into server
igiloh-pinecone Sep 11, 2023
ca9d374
[models] made metadata optional in Document
igiloh-pinecone Sep 11, 2023
198f0ec
Merge branch 'dev' into server
miararoy Sep 12, 2023
e3fe689
fix tests
acatav Sep 12, 2023
9167d7a
import dotenv before importing openai
acatav Sep 12, 2023
0ec0e47
align kb init
acatav Sep 12, 2023
3ab52f6
add verify index connection to KB
acatav Sep 12, 2023
d0d7df3
lint
acatav Sep 12, 2023
42 changes: 30 additions & 12 deletions context_engine/chat_engine/chat_engine.py
@@ -1,5 +1,6 @@
import os
from abc import ABC, abstractmethod
from typing import Iterable, Union, Optional
from typing import Iterable, Union, Optional, cast

from context_engine.chat_engine.models import HistoryPruningMethod
from context_engine.chat_engine.prompt_builder import PromptBuilder
@@ -9,9 +10,12 @@
from context_engine.knoweldge_base.tokenizer import Tokenizer
from context_engine.llm import BaseLLM
from context_engine.llm.models import ModelParams, SystemMessage
from context_engine.models.api_models import StreamingChatResponse, ChatResponse
from context_engine.models.api_models import (StreamingChatChunk, ChatResponse,
StreamingChatResponse, )
from context_engine.models.data_models import Context, Messages

CE_DEBUG_INFO = os.getenv("CE_DEBUG_INFO", "FALSE").lower() == "true"


DEFAULT_SYSTEM_PROMPT = """Use the following pieces of context to answer the user question at the next messages. This context retrieved from a knowledge database and you should use only the facts from the context to answer. Always remember to include the source to the documents you used from their 'source' field in the format 'Source: $SOURCE_HERE'.
If you don't know the answer, just say that you don't know, don't try to make up an answer, use the context.
@@ -25,7 +29,7 @@ def chat(self,
*,
stream: bool = False,
model_params: Optional[ModelParams] = None
) -> Union[ChatResponse, Iterable[StreamingChatResponse]]:
) -> Union[ChatResponse, StreamingChatResponse]:
pass

# TODO: Decide if we want it for first release in the API
@@ -39,7 +43,7 @@ async def achat(self,
*,
stream: bool = False,
model_params: Optional[ModelParams] = None
) -> Union[ChatResponse, Iterable[StreamingChatResponse]]:
) -> Union[ChatResponse, StreamingChatResponse]:
pass

@abstractmethod
@@ -55,8 +59,8 @@ def __init__(self,
*,
llm: BaseLLM,
context_engine: ContextEngine,
max_prompt_tokens: int,
max_generated_tokens: int,
max_prompt_tokens: int = 4096,
max_generated_tokens: Optional[int] = None,
max_context_tokens: Optional[int] = None,
query_builder: Optional[QueryGenerator] = None,
system_prompt: Optional[str] = None,
@@ -95,18 +99,32 @@ def chat(self,
*,
stream: bool = False,
model_params: Optional[ModelParams] = None
) -> Union[ChatResponse, Iterable[StreamingChatResponse]]:
) -> Union[ChatResponse, StreamingChatResponse]:
context = self.get_context(messages)
system_prompt = self.system_prompt_template + f"\nContext: {context.to_text()}"
llm_messages = self._prompt_builder.build(
system_prompt,
messages,
max_tokens=self.max_prompt_tokens
)
return self.llm.chat_completion(llm_messages,
max_tokens=self.max_generated_tokens,
stream=stream,
model_params=model_params)
llm_response = self.llm.chat_completion(llm_messages,
max_tokens=self.max_generated_tokens,
stream=stream,
model_params=model_params)
debug_info = {}
if CE_DEBUG_INFO:
debug_info['context'] = context.dict()
debug_info['context'].update(context.debug_info)

if stream:
return StreamingChatResponse(
chunks=cast(Iterable[StreamingChatChunk], llm_response),
debug_info=debug_info
)
else:
response = cast(ChatResponse, llm_response)
response.debug_info = debug_info
return response

def get_context(self,
messages: Messages,
Expand All @@ -120,7 +138,7 @@ async def achat(self,
*,
stream: bool = False,
model_params: Optional[ModelParams] = None
) -> Union[ChatResponse, Iterable[StreamingChatResponse]]:
) -> Union[ChatResponse, StreamingChatResponse]:
raise NotImplementedError

async def aget_context(self, messages: Messages) -> Context:
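A minimal sketch of how a caller might consume the two return shapes after this change. The `ChatResponse`/`_StreamChoice` payload layouts are assumed to mirror OpenAI's API, and engine construction is elided:

```python
from context_engine.models.api_models import StreamingChatResponse

def print_chat(engine, messages):
    # `engine` is an assumed, already-constructed ChatEngine instance.
    result = engine.chat(messages, stream=True)
    if isinstance(result, StreamingChatResponse):
        for chunk in result.chunks:
            # Assumes _StreamChoice carries an OpenAI-style `delta` payload.
            print(chunk.choices[0].delta.get("content", ""), end="", flush=True)
    else:
        # Non-streaming path; response shape assumed OpenAI-like.
        print(result.choices[0].message.content)
    # result.debug_info is populated only when CE_DEBUG_INFO=true.
```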
6 changes: 6 additions & 0 deletions context_engine/context_engine/context_engine.py
@@ -1,3 +1,4 @@
import os
from abc import ABC, abstractmethod
from typing import List, Optional

@@ -6,6 +7,8 @@
from context_engine.knoweldge_base.base import BaseKnowledgeBase
from context_engine.models.data_models import Context, Query

CE_DEBUG_INFO = os.getenv("CE_DEBUG_INFO", "FALSE").lower() == "true"


class BaseContextEngine(ABC):

@@ -38,6 +41,9 @@ def query(self, queries: List[Query], max_context_tokens: int, ) -> Context:
queries,
global_metadata_filter=self.global_metadata_filter)
context = self.context_builder.build(query_results, max_context_tokens)

if CE_DEBUG_INFO:
context.debug_info["query_results"] = [qr.dict() for qr in query_results]
return context

async def aquery(self, queries: List[Query], max_context_tokens: int, ) -> Context:
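Note that `CE_DEBUG_INFO` is read once at import time, so it must be set before the `context_engine` modules are imported. A small sketch of opting in (engine construction elided):

```python
import os
os.environ["CE_DEBUG_INFO"] = "true"  # must happen before importing context_engine

from context_engine.models.data_models import Query

# `engine` is an assumed, already-constructed ContextEngine instance.
context = engine.query([Query(text="How do I create an index?")],
                       max_context_tokens=2048)
# Raw knowledge-base results are attached only while the flag is on.
print(context.debug_info.get("query_results"))
```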
4 changes: 4 additions & 0 deletions context_engine/knoweldge_base/base.py
@@ -26,6 +26,10 @@ def upsert(self,
def delete(self, document_ids: List[str], namespace: str = "", ) -> None:
pass

@abstractmethod
def verify_connection_health(self) -> None:
pass

@abstractmethod
async def aquery(self,
queries: List[Query],
4 changes: 2 additions & 2 deletions context_engine/knoweldge_base/chunker/token_chunker.py
@@ -9,8 +9,8 @@
class TokenChunker(Chunker):

def __init__(self,
max_chunk_size: int = 200,
overlap: int = 0, ):
max_chunk_size: int = 256,
overlap: int = 30, ):
if overlap < 0:
cls_name = self.__class__.__name__
raise ValueError(
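These new defaults (256-token chunks with a 30-token overlap) apply to every caller that doesn't pass explicit values; a quick sketch:

```python
from context_engine.knoweldge_base.chunker.token_chunker import TokenChunker
from context_engine.models.data_models import Document

chunker = TokenChunker()  # now equivalent to TokenChunker(max_chunk_size=256, overlap=30)
chunks = chunker.chunk_documents([Document(id="doc1", text="some long text ...")])
```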
35 changes: 29 additions & 6 deletions context_engine/knoweldge_base/knowledge_base.py
@@ -87,6 +87,29 @@ def _connect_index(cls,
) from e
return index

def verify_connection_health(self) -> None:
self._verify_not_deleted()

try:
self._index.describe_index_stats() # type: ignore
except Exception as e:
try:
pinecone_whoami()
except Exception:
raise RuntimeError(
"Failed to connect to Pinecone. "
"Please check your credentials and try again"
) from e

if self._index_name not in list_indexes():
raise RuntimeError(
f"index {self._index_name} does not exist anymore"
"and was probably deleted. "
"Please create it first using `create_with_new_index()`"
) from e
raise RuntimeError("Index unexpectedly did not respond. "
"Please try again in few moments") from e

@classmethod
def create_with_new_index(cls,
index_name: str,
@@ -189,11 +212,11 @@ def index_name(self) -> str:
return self._index_name

def delete_index(self):
self._validate_not_deleted()
self._verify_not_deleted()
delete_index(self._index_name)
self._index = None

def _validate_not_deleted(self):
def _verify_not_deleted(self):
if self._index is None:
raise RuntimeError(
"index was deleted. "
@@ -218,7 +241,7 @@ def query(self,
def _query_index(self,
query: KBQuery,
global_metadata_filter: Optional[dict]) -> KBQueryResult:
self._validate_not_deleted()
self._verify_not_deleted()

metadata_filter = deepcopy(query.metadata_filter)
if global_metadata_filter is not None:
@@ -252,7 +275,7 @@ def upsert(self,
documents: List[Document],
namespace: str = "",
batch_size: int = 100):
self._validate_not_deleted()
self._verify_not_deleted()

chunks = self._chunker.chunk_documents(documents)
encoded_chunks = self._encoder.encode_documents(chunks)
@@ -290,7 +313,7 @@ def upsert_dataframe(self,
df: pd.DataFrame,
namespace: str = "",
batch_size: int = 100):
self._validate_not_deleted()
self._verify_not_deleted()

expected_columns = ["id", "text", "metadata"]
if not all([c in df.columns for c in expected_columns]):
@@ -305,7 +328,7 @@ def delete(self,
def delete(self,
document_ids: List[str],
namespace: str = "") -> None:
self._validate_not_deleted()
self._verify_not_deleted()
self._index.delete( # type: ignore
filter={"document_id": {"$in": document_ids}},
namespace=namespace
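`verify_connection_health()` separates three failure modes: bad credentials, a deleted index, and a transient non-response. A sketch of using it as a startup check; the constructor arguments here are assumptions:

```python
kb = KnowledgeBase(index_name="my-index")  # hypothetical constructor arguments
try:
    kb.verify_connection_health()
except RuntimeError as err:
    raise SystemExit(f"Knowledge base is not reachable: {err}")
```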
@@ -15,8 +15,11 @@ class DenseRecordEncoder(RecordEncoder):
DEFAULT_MODEL_NAME = "text-embedding-ada-002"

def __init__(self,
dense_encoder: Optional[BaseDenseEncoder] = None, **kwargs):
super().__init__(**kwargs)
dense_encoder: Optional[BaseDenseEncoder] = None,
*,
batch_size: int = 500,
**kwargs):
super().__init__(batch_size=batch_size, **kwargs)
if dense_encoder is None:
dense_encoder = self.DEFAULT_DENSE_ENCODER(self.DEFAULT_MODEL_NAME)
self._dense_encoder = dense_encoder
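With this change `batch_size` becomes keyword-only and defaults to 500, so a positional call like `DenseRecordEncoder(my_encoder, 100)` no longer works. A short sketch (import path assumed):

```python
from context_engine.knoweldge_base.record_encoder import DenseRecordEncoder  # assumed path

encoder = DenseRecordEncoder()                # text-embedding-ada-002, batch_size=500
smaller = DenseRecordEncoder(batch_size=100)  # batch_size must now be passed by keyword
```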
6 changes: 3 additions & 3 deletions context_engine/llm/base.py
@@ -2,7 +2,7 @@
from typing import Union, Iterable, Optional, List

from context_engine.llm.models import Function, ModelParams
from context_engine.models.api_models import ChatResponse, StreamingChatResponse
from context_engine.models.api_models import ChatResponse, StreamingChatChunk
from context_engine.models.data_models import Messages, Query


@@ -23,7 +23,7 @@ def chat_completion(self,
stream: bool = False,
max_tokens: Optional[int] = None,
model_params: Optional[ModelParams] = None,
) -> Union[ChatResponse, Iterable[StreamingChatResponse]]:
) -> Union[ChatResponse, Iterable[StreamingChatChunk]]:
pass

@abstractmethod
@@ -44,7 +44,7 @@ async def achat_completion(self,
max_generated_tokens: Optional[int] = None,
model_params: Optional[ModelParams] = None,
) -> Union[ChatResponse,
Iterable[StreamingChatResponse]]:
Iterable[StreamingChatChunk]]:
pass

@abstractmethod
8 changes: 4 additions & 4 deletions context_engine/llm/openai.py
@@ -4,7 +4,7 @@
import json
from context_engine.llm import BaseLLM
from context_engine.llm.models import Function, ModelParams
from context_engine.models.api_models import ChatResponse, StreamingChatResponse
from context_engine.models.api_models import ChatResponse, StreamingChatChunk
from context_engine.models.data_models import Messages, Query


@@ -30,7 +30,7 @@ def chat_completion(self,
stream: bool = False,
max_tokens: Optional[int] = None,
model_params: Optional[ModelParams] = None,
) -> Union[ChatResponse, Iterable[StreamingChatResponse]]:
) -> Union[ChatResponse, Iterable[StreamingChatChunk]]:

model_params_dict: Dict[str, Any] = {}
model_params_dict.update(
Expand All @@ -48,7 +48,7 @@ def chat_completion(self,

def streaming_iterator(response):
for chunk in response:
yield StreamingChatResponse(**chunk)
yield StreamingChatChunk(**chunk)

if stream:
return streaming_iterator(response)
@@ -90,7 +90,7 @@ async def achat_completion(self,
max_generated_tokens: Optional[int] = None,
model_params: Optional[ModelParams] = None
) -> Union[ChatResponse,
Iterable[StreamingChatResponse]]:
Iterable[StreamingChatChunk]]:
raise NotImplementedError()

async def agenerate_queries(self,
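At the LLM layer, streaming now yields `StreamingChatChunk` objects directly; the `StreamingChatResponse` wrapper is added one level up, in `ChatEngine.chat()`. A sketch, where the `OpenAILLM` class name and constructor arguments are assumptions:

```python
from context_engine.llm.openai import OpenAILLM  # class name assumed

llm = OpenAILLM(model_name="gpt-3.5-turbo")  # hypothetical constructor arguments
messages = [{"role": "user", "content": "Hello"}]  # shape of Messages assumed
for chunk in llm.chat_completion(messages, stream=True, max_tokens=256):
    print(chunk.id, chunk.model)  # fields defined on StreamingChatChunk
```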
9 changes: 7 additions & 2 deletions context_engine/models/api_models.py
@@ -1,4 +1,4 @@
from typing import Optional, Sequence
from typing import Optional, Sequence, Iterable

from pydantic import BaseModel, Field, validator

@@ -37,9 +37,14 @@ class ChatResponse(BaseModel):
debug_info: dict = Field(default_factory=dict, exclude=True)


class StreamingChatResponse(BaseModel):
class StreamingChatChunk(BaseModel):
id: str
object: str
created: int
model: str
choices: Sequence[_StreamChoice]


class StreamingChatResponse(BaseModel):
chunks: Iterable[StreamingChatChunk]
debug_info: dict = Field(default_factory=dict, exclude=True)
2 changes: 1 addition & 1 deletion context_engine/models/data_models.py
@@ -22,7 +22,7 @@ class Document(BaseModel):
id: str
text: str
source: str = ""
metadata: Metadata
metadata: Metadata = Field(default_factory=dict)

@validator('metadata')
def metadata_reseved_fields(cls, v):
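Making metadata optional means a `Document` can now be constructed from bare text; a minimal check:

```python
from context_engine.models.data_models import Document

doc = Document(id="doc1", text="hello world")
assert doc.metadata == {}  # default_factory=dict fills it in
```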
7 changes: 6 additions & 1 deletion pyproject.toml
@@ -14,10 +14,15 @@ openai = "^0.27.5"
tiktoken = "^0.3.3"
pinecone-datasets = "^0.6.1"
pydantic = "^1.10.7"
pinecone-text = "^0.5.3"
pinecone-text = { version = "^0.5.4", extras = ["openai"] }
flake8-pyproject = "^1.2.3"
pandas-stubs = "^2.0.3.230814"
langchain = "^0.0.188"
fastapi = "^0.92.0"
uvicorn = "^0.20.0"
tenacity = "^8.2.1"
sse-starlette = "^1.6.5"


[tool.poetry.group.dev.dependencies]
jupyter = "^1.0.0"
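The new `fastapi`, `uvicorn`, and `sse-starlette` dependencies are what the service layer below builds on. A sketch of the streaming pattern they enable; the route path and the module-level `chat_engine` are assumptions, not necessarily the PR's actual `app.py`:

```python
from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

from service.api_models import ChatRequest  # defined below in this PR

app = FastAPI()

@app.post("/context/chat/completions")  # hypothetical route
async def chat(request: ChatRequest):
    # `chat_engine` is an assumed module-level ChatEngine instance.
    result = chat_engine.chat(request.messages, stream=request.stream)
    if request.stream:
        # Serialize each StreamingChatChunk as a server-sent event.
        return EventSourceResponse(chunk.json() for chunk in result.chunks)
    return result
```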
Empty file added service/__init__.py
Empty file.
23 changes: 23 additions & 0 deletions service/api_models.py
@@ -0,0 +1,23 @@
from typing import Optional, List

from pydantic import BaseModel

from context_engine.models.data_models import Messages, Query, Document


class ChatRequest(BaseModel):
model: str = ""
messages: Messages
stream: bool = False
user: Optional[str] = None


class ContextQueryRequest(BaseModel):
queries: List[Query]
max_tokens: int


class ContextUpsertRequest(BaseModel):
documents: List[Document]
namespace: str = ""
batch_size: int = 100
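Continuing the sketch above, a hypothetical upsert route built on `ContextUpsertRequest` (the route path and the module-level `kb` are assumptions):

```python
@app.post("/context/upsert")  # hypothetical route
async def upsert(request: ContextUpsertRequest):
    # `kb` is an assumed module-level KnowledgeBase instance.
    kb.upsert(request.documents,
              namespace=request.namespace,
              batch_size=request.batch_size)
    return {"upserted": len(request.documents)}
```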