This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Commit

Merge pull request #40 from pinecone-io/server
Adapted app.py to work with latest code
miararoy authored Sep 13, 2023
2 parents 52fd4ee + d0d7df3 commit 257d835
Showing 17 changed files with 336 additions and 37 deletions.
42 changes: 30 additions & 12 deletions context_engine/chat_engine/chat_engine.py
@@ -1,5 +1,6 @@
import os
from abc import ABC, abstractmethod
from typing import Iterable, Union, Optional
from typing import Iterable, Union, Optional, cast

from context_engine.chat_engine.models import HistoryPruningMethod
from context_engine.chat_engine.prompt_builder import PromptBuilder
@@ -9,9 +10,12 @@
from context_engine.knoweldge_base.tokenizer import Tokenizer
from context_engine.llm import BaseLLM
from context_engine.llm.models import ModelParams, SystemMessage
from context_engine.models.api_models import StreamingChatResponse, ChatResponse
from context_engine.models.api_models import (StreamingChatChunk, ChatResponse,
StreamingChatResponse, )
from context_engine.models.data_models import Context, Messages

CE_DEBUG_INFO = os.getenv("CE_DEBUG_INFO", "FALSE").lower() == "true"


DEFAULT_SYSTEM_PROMPT = """"Use the following pieces of context to answer the user question at the next messages. This context retrieved from a knowledge database and you should use only the facts from the context to answer. Always remember to include the source to the documents you used from their 'source' field in the format 'Source: $SOURCE_HERE'.
If you don't know the answer, just say that you don't know, don't try to make up an answer, use the context."
@@ -25,7 +29,7 @@ def chat(self,
*,
stream: bool = False,
model_params: Optional[ModelParams] = None
) -> Union[ChatResponse, Iterable[StreamingChatResponse]]:
) -> Union[ChatResponse, StreamingChatResponse]:
pass

# TODO: Decide if we want it for first release in the API
@@ -39,7 +43,7 @@ async def achat(self,
*,
stream: bool = False,
model_params: Optional[ModelParams] = None
) -> Union[ChatResponse, Iterable[StreamingChatResponse]]:
) -> Union[ChatResponse, StreamingChatResponse]:
pass

@abstractmethod
@@ -55,8 +59,8 @@ def __init__(self,
*,
llm: BaseLLM,
context_engine: ContextEngine,
max_prompt_tokens: int,
max_generated_tokens: int,
max_prompt_tokens: int = 4096,
max_generated_tokens: Optional[int] = None,
max_context_tokens: Optional[int] = None,
query_builder: Optional[QueryGenerator] = None,
system_prompt: Optional[str] = None,
@@ -95,18 +99,32 @@ def chat(self,
*,
stream: bool = False,
model_params: Optional[ModelParams] = None
) -> Union[ChatResponse, Iterable[StreamingChatResponse]]:
) -> Union[ChatResponse, StreamingChatResponse]:
context = self.get_context(messages)
system_prompt = self.system_prompt_template + f"\nContext: {context.to_text()}"
llm_messages = self._prompt_builder.build(
system_prompt,
messages,
max_tokens=self.max_prompt_tokens
)
return self.llm.chat_completion(llm_messages,
max_tokens=self.max_generated_tokens,
stream=stream,
model_params=model_params)
llm_response = self.llm.chat_completion(llm_messages,
max_tokens=self.max_generated_tokens,
stream=stream,
model_params=model_params)
debug_info = {}
if CE_DEBUG_INFO:
debug_info['context'] = context.dict()
debug_info['context'].update(context.debug_info)

if stream:
return StreamingChatResponse(
chunks=cast(Iterable[StreamingChatChunk], llm_response),
debug_info=debug_info
)
else:
response = cast(ChatResponse, llm_response)
response.debug_info = debug_info
return response

def get_context(self,
messages: Messages,
@@ -120,7 +138,7 @@ async def achat(self,
*,
stream: bool = False,
model_params: Optional[ModelParams] = None
) -> Union[ChatResponse, Iterable[StreamingChatResponse]]:
) -> Union[ChatResponse, StreamingChatResponse]:
raise NotImplementedError

async def aget_context(self, messages: Messages) -> Context:
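The chat() change above means callers now receive two distinct types rather than a bare chunk iterator: a ChatResponse when stream=False, and a StreamingChatResponse whose chunks field carries the StreamingChatChunk iterator when stream=True. A minimal usage sketch, assuming an already-constructed ChatEngine and a valid Messages object (neither is built in this diff):

from context_engine.models.api_models import ChatResponse, StreamingChatResponse

# `chat_engine` and `messages` are assumed to exist; see the ChatEngine
# constructor above for the new default arguments.
response = chat_engine.chat(messages, stream=False)
assert isinstance(response, ChatResponse)
print(response.debug_info)  # stays empty unless CE_DEBUG_INFO=true

streaming = chat_engine.chat(messages, stream=True)
assert isinstance(streaming, StreamingChatResponse)
for chunk in streaming.chunks:
    # Each chunk is a StreamingChatChunk; the fields of _StreamChoice are not
    # shown in this diff, so the chunk is just serialized here.
    print(chunk.json())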
6 changes: 6 additions & 0 deletions context_engine/context_engine/context_engine.py
@@ -1,3 +1,4 @@
import os
from abc import ABC, abstractmethod
from typing import List, Optional

@@ -6,6 +7,8 @@
from context_engine.knoweldge_base.base import BaseKnowledgeBase
from context_engine.models.data_models import Context, Query

CE_DEBUG_INFO = os.getenv("CE_DEBUG_INFO", "FALSE").lower() == "true"


class BaseContextEngine(ABC):

@@ -38,6 +41,9 @@ def query(self, queries: List[Query], max_context_tokens: int, ) -> Context:
queries,
global_metadata_filter=self.global_metadata_filter)
context = self.context_builder.build(query_results, max_context_tokens)

if CE_DEBUG_INFO:
context.debug_info["query_results"] = [qr.dict() for qr in query_results]
return context

async def aquery(self, queries: List[Query], max_context_tokens: int, ) -> Context:
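Both this file and chat_engine.py read the new CE_DEBUG_INFO environment variable at import time; only when it is set to "true" do query results and context get copied into debug_info. A small sketch of how the flag is observed, assuming a constructed ContextEngine and a list of Query objects (neither is built in this diff):

import os
os.environ["CE_DEBUG_INFO"] = "true"   # must be set before context_engine modules are imported

# `context_engine` and `queries` (a List[Query]) are assumed to exist already.
context = context_engine.query(queries, max_context_tokens=512)
# Raw knowledge-base results are attached only when the flag is on:
print(context.debug_info["query_results"])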
4 changes: 4 additions & 0 deletions context_engine/knoweldge_base/base.py
@@ -26,6 +26,10 @@ def upsert(self,
def delete(self, document_ids: List[str], namespace: str = "", ) -> None:
pass

@abstractmethod
def verify_connection_health(self) -> None:
pass

@abstractmethod
async def aquery(self,
queries: List[Query],
4 changes: 2 additions & 2 deletions context_engine/knoweldge_base/chunker/token_chunker.py
@@ -9,8 +9,8 @@
class TokenChunker(Chunker):

def __init__(self,
max_chunk_size: int = 200,
overlap: int = 0, ):
max_chunk_size: int = 256,
overlap: int = 30, ):
if overlap < 0:
cls_name = self.__class__.__name__
raise ValueError(
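TokenChunker's defaults move from 200-token chunks with no overlap to 256-token chunks with a 30-token overlap. A sketch of explicit construction; the import path is inferred from the file location, and chunk_documents() is taken from its use in KnowledgeBase.upsert() further down:

from context_engine.knoweldge_base.chunker.token_chunker import TokenChunker
from context_engine.models.data_models import Document

chunker = TokenChunker()  # equivalent to TokenChunker(max_chunk_size=256, overlap=30)
# metadata can be omitted thanks to the data_models.py change below.
chunks = chunker.chunk_documents([Document(id="doc1", text="a long passage of text ...")])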
35 changes: 29 additions & 6 deletions context_engine/knoweldge_base/knowledge_base.py
@@ -87,6 +87,29 @@ def _connect_index(cls,
) from e
return index

def verify_connection_health(self) -> None:
self._verify_not_deleted()

try:
self._index.describe_index_stats() # type: ignore
except Exception as e:
try:
pinecone_whoami()
except Exception:
raise RuntimeError(
"Failed to connect to Pinecone. "
"Please check your credentials and try again"
) from e

if self._index_name not in list_indexes():
raise RuntimeError(
f"index {self._index_name} does not exist anymore"
"and was probably deleted. "
"Please create it first using `create_with_new_index()`"
) from e
raise RuntimeError("Index unexpectedly did not respond. "
"Please try again in few moments") from e

@classmethod
def create_with_new_index(cls,
index_name: str,
@@ -189,11 +212,11 @@ def index_name(self) -> str:
return self._index_name

def delete_index(self):
self._validate_not_deleted()
self._verify_not_deleted()
delete_index(self._index_name)
self._index = None

def _validate_not_deleted(self):
def _verify_not_deleted(self):
if self._index is None:
raise RuntimeError(
"index was deleted. "
@@ -218,7 +241,7 @@ def query(self,
def _query_index(self,
query: KBQuery,
global_metadata_filter: Optional[dict]) -> KBQueryResult:
self._validate_not_deleted()
self._verify_not_deleted()

metadata_filter = deepcopy(query.metadata_filter)
if global_metadata_filter is not None:
@@ -252,7 +275,7 @@ def upsert(self,
documents: List[Document],
namespace: str = "",
batch_size: int = 100):
self._validate_not_deleted()
self._verify_not_deleted()

chunks = self._chunker.chunk_documents(documents)
encoded_chunks = self._encoder.encode_documents(chunks)
@@ -290,7 +313,7 @@ def upsert_dataframe(self,
df: pd.DataFrame,
namespace: str = "",
batch_size: int = 100):
self._validate_not_deleted()
self._verify_not_deleted()

expected_columns = ["id", "text", "metadata"]
if not all([c in df.columns for c in expected_columns]):
@@ -305,7 +328,7 @@ def upsert_dataframe(self,
def delete(self,
document_ids: List[str],
namespace: str = "") -> None:
self._validate_not_deleted()
self._verify_not_deleted()
self._index.delete( # type: ignore
filter={"document_id": {"$in": document_ids}},
namespace=namespace
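The new verify_connection_health() separates three failure modes: Pinecone credentials that no longer work (pinecone_whoami() fails), an index that was deleted out from under the client (missing from list_indexes()), and an index that simply did not respond. A sketch of how a caller, such as the new service layer, might use it; `kb` is assumed to be an already-connected KnowledgeBase:

try:
    kb.verify_connection_health()
except RuntimeError as e:
    # The message distinguishes bad credentials, a deleted index, and an
    # unresponsive index, as implemented above.
    print(f"Knowledge base unavailable: {e}")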
@@ -15,8 +15,11 @@ class DenseRecordEncoder(RecordEncoder):
DEFAULT_MODEL_NAME = "text-embedding-ada-002"

def __init__(self,
dense_encoder: Optional[BaseDenseEncoder] = None, **kwargs):
super().__init__(**kwargs)
dense_encoder: Optional[BaseDenseEncoder] = None,
*,
batch_size: int = 500,
**kwargs):
super().__init__(batch_size=batch_size, **kwargs)
if dense_encoder is None:
dense_encoder = self.DEFAULT_DENSE_ENCODER(self.DEFAULT_MODEL_NAME)
self._dense_encoder = dense_encoder
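DenseRecordEncoder now exposes batch_size as an explicit keyword-only argument (default 500) forwarded to the RecordEncoder base class rather than being left to **kwargs. A one-line sketch; the import path is an assumption, since the file header for this hunk is not shown:

from context_engine.knoweldge_base.record_encoder import DenseRecordEncoder

# Uses the default "text-embedding-ada-002" encoder; batch_size must be passed by keyword.
encoder = DenseRecordEncoder(batch_size=100)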
6 changes: 3 additions & 3 deletions context_engine/llm/base.py
@@ -2,7 +2,7 @@
from typing import Union, Iterable, Optional, List

from context_engine.llm.models import Function, ModelParams
from context_engine.models.api_models import ChatResponse, StreamingChatResponse
from context_engine.models.api_models import ChatResponse, StreamingChatChunk
from context_engine.models.data_models import Messages, Query


@@ -23,7 +23,7 @@ def chat_completion(self,
stream: bool = False,
max_tokens: Optional[int] = None,
model_params: Optional[ModelParams] = None,
) -> Union[ChatResponse, Iterable[StreamingChatResponse]]:
) -> Union[ChatResponse, Iterable[StreamingChatChunk]]:
pass

@abstractmethod
Expand All @@ -44,7 +44,7 @@ async def achat_completion(self,
max_generated_tokens: Optional[int] = None,
model_params: Optional[ModelParams] = None,
) -> Union[ChatResponse,
Iterable[StreamingChatResponse]]:
Iterable[StreamingChatChunk]]:
pass

@abstractmethod
8 changes: 4 additions & 4 deletions context_engine/llm/openai.py
@@ -4,7 +4,7 @@
import json
from context_engine.llm import BaseLLM
from context_engine.llm.models import Function, ModelParams
from context_engine.models.api_models import ChatResponse, StreamingChatResponse
from context_engine.models.api_models import ChatResponse, StreamingChatChunk
from context_engine.models.data_models import Messages, Query


@@ -30,7 +30,7 @@ def chat_completion(self,
stream: bool = False,
max_tokens: Optional[int] = None,
model_params: Optional[ModelParams] = None,
) -> Union[ChatResponse, Iterable[StreamingChatResponse]]:
) -> Union[ChatResponse, Iterable[StreamingChatChunk]]:

model_params_dict: Dict[str, Any] = {}
model_params_dict.update(
@@ -48,7 +48,7 @@

def streaming_iterator(response):
for chunk in response:
yield StreamingChatResponse(**chunk)
yield StreamingChatChunk(**chunk)

if stream:
return streaming_iterator(response)
@@ -90,7 +90,7 @@ async def achat_completion(self,
max_generated_tokens: Optional[int] = None,
model_params: Optional[ModelParams] = None
) -> Union[ChatResponse,
Iterable[StreamingChatResponse]]:
Iterable[StreamingChatChunk]]:
raise NotImplementedError()

async def agenerate_queries(self,
9 changes: 7 additions & 2 deletions context_engine/models/api_models.py
@@ -1,4 +1,4 @@
from typing import Optional, Sequence
from typing import Optional, Sequence, Iterable

from pydantic import BaseModel, Field, validator

@@ -37,9 +37,14 @@ class ChatResponse(BaseModel):
debug_info: dict = Field(default_factory=dict, exclude=True)


class StreamingChatResponse(BaseModel):
class StreamingChatChunk(BaseModel):
id: str
object: str
created: int
model: str
choices: Sequence[_StreamChoice]


class StreamingChatResponse(BaseModel):
chunks: Iterable[StreamingChatChunk]
debug_info: dict = Field(default_factory=dict, exclude=True)
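api_models.py renames the per-chunk model to StreamingChatChunk and introduces a wrapper StreamingChatResponse whose chunks field is typed as Iterable[StreamingChatChunk], so the generator produced by OpenAILLM's streaming_iterator can be supplied directly (worth confirming against the pinned pydantic 1.10 that it is not eagerly consumed). A minimal sketch mirroring what ChatEngine.chat() now does for streaming calls:

from typing import Iterable
from context_engine.models.api_models import StreamingChatChunk, StreamingChatResponse

def wrap_stream(chunks: Iterable[StreamingChatChunk], debug_info: dict) -> StreamingChatResponse:
    # The raw chunk iterator from the LLM and any debug info travel together.
    return StreamingChatResponse(chunks=chunks, debug_info=debug_info)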
2 changes: 1 addition & 1 deletion context_engine/models/data_models.py
@@ -22,7 +22,7 @@ class Document(BaseModel):
id: str
text: str
source: str = ""
metadata: Metadata
metadata: Metadata = Field(default_factory=dict)

@validator('metadata')
def metadata_reseved_fields(cls, v):
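With default_factory=dict, Document.metadata becomes optional, which is what lets the new upsert request model below accept documents carrying only an id and text. A two-line sketch:

from context_engine.models.data_models import Document

doc = Document(id="doc1", text="hello world")  # source defaults to "", metadata to {}
assert doc.metadata == {}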
7 changes: 6 additions & 1 deletion pyproject.toml
@@ -14,10 +14,15 @@ openai = "^0.27.5"
tiktoken = "^0.3.3"
pinecone-datasets = "^0.6.1"
pydantic = "^1.10.7"
pinecone-text = "^0.5.3"
pinecone-text = { version = "^0.5.4", extras = ["openai"] }
flake8-pyproject = "^1.2.3"
pandas-stubs = "^2.0.3.230814"
langchain = "^0.0.188"
fastapi = "^0.92.0"
uvicorn = "^0.20.0"
tenacity = "^8.2.1"
sse-starlette = "^1.6.5"


[tool.poetry.group.dev.dependencies]
jupyter = "^1.0.0"
Empty file added service/__init__.py
23 changes: 23 additions & 0 deletions service/api_models.py
@@ -0,0 +1,23 @@
from typing import Optional, List

from pydantic import BaseModel

from context_engine.models.data_models import Messages, Query, Document


class ChatRequest(BaseModel):
model: str = ""
messages: Messages
stream: bool = False
user: Optional[str] = None


class ContextQueryRequest(BaseModel):
queries: List[Query]
max_tokens: int


class ContextUpsertRequest(BaseModel):
documents: List[Document]
namespace: str = ""
batch_size: int = 100
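These request models back the FastAPI service referenced in the commit message ("Adapted app.py to work with latest code"); the endpoint paths themselves are not among the files shown here. A sketch of payloads that would validate, assuming `messages` is a valid Messages object as defined in data_models.py:

from service.api_models import ChatRequest, ContextUpsertRequest

chat_req = ChatRequest(messages=messages, stream=True)  # `messages` assumed to exist

upsert_req = ContextUpsertRequest.parse_obj({
    "documents": [{"id": "doc1", "text": "Pinecone is a vector database."}],
    "namespace": "",
    "batch_size": 50,
})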