
Commit 020ba2c

chore(refactor): Use dep injection, refactor error checking, update deps & related config/imports
Avantol13 committed Jan 22, 2024
1 parent 46dfe22 commit 020ba2c
Showing 11 changed files with 1,830 additions and 1,717 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -127,7 +127,7 @@ TOPICS=default,anothertopic,gen3docs
# when a configuration is not provided. e.g. if you don't provide FOOBAR_SYSTEM_PROMPT then the DEFAULT_SYSTEM_PROMPT
# will be used
DEFAULT_SYSTEM_PROMPT=You are acting as a search assistant for a researcher who will be asking you questions about data available in a particular system. If you believe the question is not relevant to data in the system, do not answer. The researcher is likely trying to find data of interest for a particular reason or with specific criteria. You answer and recommend datasets that may be of interest to that researcher based on the context you're provided. If you are using any particular context to answer, you should cite that and tell the user where they can find more information. The user may not be able to see the documents you are citing, so provide the relevant information in your response. If you don't know the answer, just say that you don't know, don't try to make up an answer. If you don't believe what the user is looking for is available in the system based on the context, say so instead of trying to explain how to go somewhere else.
-DEFAULT_RAW_METADATA=model_name:chat-bison,model_temperature:0.3,max_output_tokens:512,num_similar_docs_to_find:7,similarity_score_threshold:0.6
+DEFAULT_RAW_METADATA=model_name:chat-bison,embedding_model_name:textembedding-gecko@003,model_temperature:0.3,max_output_tokens:512,num_similar_docs_to_find:7,similarity_score_threshold:0.6
DEFAULT_DESCRIPTION=Ask about available datasets, powered by public dataset metadata like study descriptions
# Additional topic configurations
@@ -137,7 +137,7 @@ ANOTHERTOPIC_SYSTEM_PROMPT=You answer questions about datasets that are availabl
ANOTHERTOPIC_CHAIN_NAME=TopicChainOpenAiQuestionAnswerRAG
GEN3DOCS_SYSTEM_PROMPT=You will be given relevant context from all the public documentation surrounding an open source software called Gen3. You are acting as an assistant to a new Gen3 developer, who is going to ask a question. Try to answer their question based on the context, but know that some of the context may be out of date. Let the developer know where they can get more information if relevant and cite portions of the context.
-GEN3DOCS_RAW_METADATA=model_name:chat-bison,model_temperature:0.5,max_output_tokens:512,num_similar_docs_to_find:7,similarity_score_threshold:0.5
+GEN3DOCS_RAW_METADATA=model_name:chat-bison,embedding_model_name:textembedding-gecko-multilingual@001,model_temperature:0.5,max_output_tokens:512,num_similar_docs_to_find:7,similarity_score_threshold:0.5
GEN3DOCS_DESCRIPTION=Ask about Gen3, powered by public markdown files in the UChicago Center for Translational Data Science's GitHub
########## Debugging and Logging Configurations ##########
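The `*_RAW_METADATA` values above are flat `key:value` pairs joined by commas, and any topic-prefixed variable (e.g. `GEN3DOCS_SYSTEM_PROMPT`) falls back to its `DEFAULT_*` counterpart when unset. A minimal sketch of that resolution and parsing, using hypothetical helper names (`resolve_topic_setting`, `parse_raw_metadata`) rather than the service's actual functions:

```python
import os


def resolve_topic_setting(topic: str, setting: str) -> str:
    """Prefer TOPIC_SETTING from the environment, falling back to DEFAULT_SETTING."""
    return os.environ.get(
        f"{topic.upper()}_{setting}", os.environ.get(f"DEFAULT_{setting}", "")
    )


def parse_raw_metadata(raw: str) -> dict:
    """Split 'key:value,key:value' pairs into a dict; values stay strings."""
    return dict(pair.split(":", 1) for pair in raw.split(",") if ":" in pair)


metadata = parse_raw_metadata(resolve_topic_setting("gen3docs", "RAW_METADATA"))
# e.g. {'model_name': 'chat-bison', 'model_temperature': '0.5', ...}
```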
4 changes: 2 additions & 2 deletions bin/load_into_knowledge_store.py
@@ -3,8 +3,8 @@
import os

import click
-from langchain.document_loaders import UnstructuredMarkdownLoader
-from langchain.document_loaders.csv_loader import CSVLoader
+from langchain_community.document_loaders import UnstructuredMarkdownLoader
+from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain.text_splitter import TokenTextSplitter

from gen3discoveryai import logging
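For context, a sketch of how the updated `langchain_community` loaders are typically used together with the splitter this script imports; the file paths and chunk sizes are illustrative, and `UnstructuredMarkdownLoader` additionally requires the `unstructured` package:

```python
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain.text_splitter import TokenTextSplitter

# Load raw documents from markdown and CSV sources (paths are examples)
docs = UnstructuredMarkdownLoader("docs/example.md").load()
docs += CSVLoader(file_path="metadata/example.csv").load()

# Split into token-bounded chunks before inserting into the knowledge store
splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=64)
chunks = splitter.split_documents(docs)
```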
22 changes: 16 additions & 6 deletions gen3discoveryai/auth.py
@@ -1,9 +1,13 @@
from authutils.token.fastapi import access_token
-from fastapi import HTTPException, Request
+from fastapi import HTTPException, Request, Depends
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from gen3authz.client.arborist.async_client import ArboristClient
-from starlette.status import HTTP_401_UNAUTHORIZED as HTTP_401_UNAUTHENTICATED
-from starlette.status import HTTP_403_FORBIDDEN, HTTP_503_SERVICE_UNAVAILABLE
+from starlette.status import (
+    HTTP_401_UNAUTHORIZED as HTTP_401_UNAUTHENTICATED,
+    HTTP_403_FORBIDDEN,
+    HTTP_503_SERVICE_UNAVAILABLE,
+    HTTP_429_TOO_MANY_REQUESTS,
+)

from gen3discoveryai import config, logging

@@ -104,8 +108,9 @@ async def get_user_id(
    return token_claims["sub"]


-async def has_user_exceeded_limits(
-    token: HTTPAuthorizationCredentials = None, request: Request = None
+async def raise_if_user_exceeded_limits(
+    token: HTTPAuthorizationCredentials = Depends(get_bearer_token),
+    request: Request = None,
):
"""
Checks if the user has exceeded certain limits which should prevent them from using the AI.
Expand All @@ -125,7 +130,12 @@ async def has_user_exceeded_limits(
    # TODO logic to determine if it's been exceeded
    # make sure you try to handle the case where ALLOW_ANONYMOUS_ACCESS is on

-    return user_limit_exceeded
+    if user_limit_exceeded:
+        logging.debug("has_user_exceeded_limits is True")
+        raise HTTPException(
+            HTTP_429_TOO_MANY_REQUESTS,
+            "You've reached a limit for your user. Please try again later.",
+        )


async def raise_if_global_ai_limit_exceeded():
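The rename from `has_user_exceeded_limits` to `raise_if_user_exceeded_limits` captures the dependency-injection pattern this commit adopts: instead of returning a boolean for each route to check, the dependency raises and FastAPI short-circuits the request before the handler runs. A self-contained sketch of the same pattern, with an illustrative app and a stand-in limit check:

```python
from fastapi import Depends, FastAPI, HTTPException
from starlette.status import HTTP_429_TOO_MANY_REQUESTS

app = FastAPI()


async def raise_if_over_limit() -> None:
    user_limit_exceeded = False  # stand-in for a real quota lookup
    if user_limit_exceeded:
        # Raising inside a dependency stops the request before the route body runs
        raise HTTPException(HTTP_429_TOO_MANY_REQUESTS, "Please try again later.")


@app.post("/ask", dependencies=[Depends(raise_if_over_limit)])
async def ask(data: dict) -> dict:
    return {"ok": True}
```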
2 changes: 1 addition & 1 deletion gen3discoveryai/config.py
@@ -42,7 +42,7 @@
)

OPENAI_API_KEY = config("OPENAI_API_KEY", cast=Secret, default=None)
URL_PREFIX = config("URL_PREFIX", default="/")
URL_PREFIX = config("URL_PREFIX", default=None)

# csv strings for all topic names
#
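Changing the `URL_PREFIX` default from `"/"` to `None` lets callers distinguish "explicitly mounted at root" from "no prefix configured". One common way such a prefix is consumed (an assumption here, not confirmed by this diff) is as the ASGI `root_path`, so generated docs and links respect a reverse-proxy prefix:

```python
from fastapi import FastAPI

URL_PREFIX = None  # stand-in for config.URL_PREFIX

# With default=None, `URL_PREFIX or ""` cleanly degrades to "no prefix"
app = FastAPI(root_path=URL_PREFIX or "")
```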
50 changes: 18 additions & 32 deletions gen3discoveryai/routes.py
@@ -3,29 +3,35 @@
from importlib.metadata import version
from typing import Any

-import openai
-from fastapi import APIRouter, HTTPException, Request
+from fastapi import APIRouter, HTTPException, Request, Depends

from starlette.status import (
    HTTP_400_BAD_REQUEST,
    HTTP_404_NOT_FOUND,
    HTTP_429_TOO_MANY_REQUESTS,
    HTTP_503_SERVICE_UNAVAILABLE,
)

from gen3discoveryai import config, logging
from gen3discoveryai.auth import (
    authorize_request,
    get_user_id,
-    has_user_exceeded_limits,
+    raise_if_user_exceeded_limits,
    raise_if_global_ai_limit_exceeded,
)
-from gen3discoveryai.topic_chains.logging import CustomCallbackHandlerForLogging
+from gen3discoveryai.topic_chains.logging import LoggingCallbackHandler

root_router = APIRouter()


@root_router.post("/ask/")
@root_router.post("/ask", include_in_schema=False)
@root_router.post(
"/ask",
include_in_schema=False,
dependencies=[
Depends(raise_if_global_ai_limit_exceeded),
Depends(raise_if_user_exceeded_limits),
],
)
async def ask_route(
    request: Request, data: dict, topic: str = "default", conversation_id: str = None
) -> dict:
@@ -40,8 +46,6 @@ async def ask_route(
        existing conversation. Must match a valid conversation ID for this
        user AND topic must support conversation-based queries.
    """
-    await raise_if_global_ai_limit_exceeded()
-
    await authorize_request(
        request=request,
        authz_access_method="read",
@@ -65,34 +69,16 @@ async def ask_route(
    conversation = None
    if conversation_id:
        conversation = await _get_conversation_for_user(conversation_id, user_id)
-        # TODO handle conversation
+        # TODO (PXP-11239) handle conversation

    logging.debug(f"conversation: {conversation}")

-    if await has_user_exceeded_limits(request=request):
-        logging.debug("has_user_exceeded_limits is True")
-        raise HTTPException(
-            HTTP_429_TOO_MANY_REQUESTS,
-            "You've reached a limit for your user. Please try again later.",
-        )

    start_time = time.time()
    try:
        topic_config = config.topics[topic]
        raw_response = topic_config["topic_chain"].run(
-            query=query, callbacks=[CustomCallbackHandlerForLogging()]
+            query=query, callbacks=[LoggingCallbackHandler()]
        )
-    except openai.RateLimitError as exc:
-        logging.debug("openai.RateLimitError")
-        raise HTTPException(
-            HTTP_429_TOO_MANY_REQUESTS, "Please try again later."
-        ) from exc
-    except openai.OpenAIError as exc:
-        logging.debug("openai.OpenAIError")
-        raise HTTPException(
-            HTTP_400_BAD_REQUEST,
-            "Error. You may have too much text in your query.",
-        ) from exc
    except Exception as exc:
        logging.error(
            f"Returning service unavailable. Got unexpected error from chain: {exc}"
@@ -118,7 +104,7 @@ async def ask_route(
        f"user_query={query}, topic={topic}, response={response['response']}, response_time_seconds={end_time - start_time}"
    )

-    # TODO
+    # TODO (PXP-11239)
    if not conversation_id:
        conversation_id = await _get_conversation_id()
        await _store_conversation(user_id, conversation_id)
@@ -249,17 +235,17 @@ async def get_status(request: Request) -> dict:


async def _get_conversation_id() -> str:
-    # TODO
+    # TODO (PXP-11239)
    return str(uuid.uuid4())


async def _store_conversation(user_id, conversation_id) -> None:
-    # TODO
+    # TODO (PXP-11239)
    logging.debug(f"storing conversation {conversation_id} for {user_id}")


async def _get_conversation_for_user(conversation_id, user_id) -> Any:
-    # TODO conversation for now
+    # TODO (PXP-11239) conversation for now
    # should actually retrieve something based on conversation_id
    # to see what user's conversation ID it is
    logging.debug(f"getting conversation {conversation_id} for {user_id}")
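The double decorator on `ask_route` is worth noting: the same handler is registered for both `/ask/` and `/ask`, with the bare path hidden from the OpenAPI schema so the docs don't show a duplicate entry. A minimal sketch of that trick (router and handler are illustrative):

```python
from fastapi import APIRouter

router = APIRouter()


@router.post("/ask/")
@router.post("/ask", include_in_schema=False)
async def ask(data: dict) -> dict:
    # Both paths invoke this one handler; /ask is hidden from the docs page
    return {"query": data.get("query")}
```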
2 changes: 1 addition & 1 deletion gen3discoveryai/topic_chains/logging.py
@@ -6,7 +6,7 @@
from gen3discoveryai import config, logging


-class CustomCallbackHandlerForLogging(BaseCallbackHandler):
+class LoggingCallbackHandler(BaseCallbackHandler):
"""
Custom Callback handles all possible callbacks from chains and allows for our
own handing. For now, this just ensures the use of our logging library
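For reference, a sketch of what subclassing `BaseCallbackHandler` looks like; the hook signatures below match langchain's, but the bodies are illustrative rather than this module's actual implementation:

```python
from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler


class SketchLoggingHandler(BaseCallbackHandler):
    """Illustrative handler that reports a couple of chain lifecycle events."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        print(f"LLM starting with {len(prompts)} prompt(s)")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        print(f"chain finished with output keys: {list(outputs)}")
```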
39 changes: 24 additions & 15 deletions gen3discoveryai/topic_chains/question_answer_google.py
@@ -16,8 +16,7 @@
import chromadb
import langchain
from langchain.chains import RetrievalQA
-from langchain.chat_models import ChatVertexAI
-from langchain.embeddings import VertexAIEmbeddings
+from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.vectorstores.chroma import Chroma

@@ -69,6 +68,16 @@ def __init__(self, topic: str, metadata: Dict[str, Any] = None) -> None:
        llm_top_p = get_from_cfg_metadata("top_p", metadata, default=0.95, type_=float)
        llm_top_k = get_from_cfg_metadata("top_k", metadata, default=0, type_=int)

+        embedding_model_name = get_from_cfg_metadata(
+            "embedding_model_name",
+            metadata,
+            # NOTE: using latest here _could_ result in unexpected updates in behavior if
+            # Google releases a new version. Recommended to use an explicit version by specifying
+            # embedding_model_name in the configuration
+            default="textembedding-gecko@latest",
+            type_=str,
+        )
+
        system_prompt = metadata.get("system_prompt", "")

        self.llm = ChatVertexAI(
@@ -109,7 +118,7 @@ def __init__(self, topic: str, metadata: Dict[str, Any] = None) -> None:
        vectorstore = Chroma(
            client=persistent_client,
            collection_name=topic,
-            embedding_function=VertexAIEmbeddings(),
+            embedding_function=VertexAIEmbeddings(model_name=embedding_model_name),
            # We've heard the `cosine` distance function performs better
            # https://docs.trychroma.com/usage-guide#changing-the-distance-function
            collection_metadata={"hnsw:space": "cosine"},
@@ -154,19 +163,19 @@ def store_knowledge(
        Args:
            documents (list[langchain.schema.document.Document]): documents to store in the knowledge store
        """
-        try:
-            # get all docs but don't include anything other than ids
-            docs = self.vectorstore.get(include=[])
-            if docs["ids"]:
-                logging.debug(
-                    f"Deleting current knowledge store collection for {self.topic}..."
-                )
-                self.vectorstore.delete(ids=docs["ids"])
-        except Exception as exc:
-            logging.debug(
-                "Exception while deleting collection and recreating client, "
-                "assume the collection just didn't exist and continue. Exc: {exc}"
-            )
-            # doesn't exist so just continue adding
+        # try:
+        # get all docs but don't include anything other than ids
+        docs = self.vectorstore.get(include=[])
+        if docs["ids"]:
+            logging.debug(
+                f"Deleting current knowledge store collection for {self.topic}..."
+            )
+            self.vectorstore.delete(ids=docs["ids"])
+        # except Exception as exc:
+        #     logging.debug(
+        #         "Exception while deleting collection and recreating client, "
+        #         "assume the collection just didn't exist and continue. Exc: {exc}"
+        #     )
+        #     # doesn't exist so just continue adding

        self.insert_documents_into_vectorstore(documents)
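The chain reads every tunable (temperature, token limits, and now the embedding model) through `get_from_cfg_metadata`, which looks a value up in the topic metadata, casts it, and falls back to a default. A sketch of that behavior, re-implemented here for illustration rather than copied from the project:

```python
from typing import Any, Dict


def get_from_cfg_metadata(
    field: str, metadata: Dict[str, Any], default: Any, type_: type
) -> Any:
    """Illustrative re-implementation: typed lookup with a fallback default."""
    try:
        return type_(metadata.get(field, default))
    except (TypeError, ValueError):
        return default


metadata = {"model_temperature": "0.5", "max_output_tokens": "512"}
temperature = get_from_cfg_metadata("model_temperature", metadata, default=0.3, type_=float)
assert temperature == 0.5
```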
32 changes: 30 additions & 2 deletions gen3discoveryai/topic_chains/question_answer_openai.py
@@ -14,12 +14,18 @@
from typing import Any, Dict

import chromadb
+from fastapi import HTTPException
import langchain
from langchain.chains import RetrievalQA
-from langchain.chat_models import ChatOpenAI
-from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.vectorstores.chroma import Chroma
+import openai
+from starlette.status import (
+    HTTP_400_BAD_REQUEST,
+    HTTP_429_TOO_MANY_REQUESTS,
+)


from gen3discoveryai import config, logging
from gen3discoveryai.topic_chains.base import TopicChain
@@ -164,3 +170,25 @@ def store_knowledge(
            # doesn't exist so just continue adding

        self.insert_documents_into_vectorstore(documents)

+    def run(self, query: str, *args, **kwargs):
+        """
+        Run the query on the underlying chain, overriding base to add OpenAI-specific
+        error catching.
+
+        Args:
+            query (str): query to provide to chain
+        """
+        try:
+            # delegate to the base chain; calling self.run here would recurse forever
+            return super().run(query, *args, **kwargs)
+        except openai.RateLimitError as exc:
+            logging.debug("openai.RateLimitError")
+            raise HTTPException(
+                HTTP_429_TOO_MANY_REQUESTS, "Please try again later."
+            ) from exc
+        except openai.OpenAIError as exc:
+            logging.debug("openai.OpenAIError")
+            raise HTTPException(
+                HTTP_400_BAD_REQUEST,
+                "Error. You may have too much text in your query.",
+            ) from exc
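Moving the OpenAI-specific `except` blocks out of `routes.py` and into this override means callers can treat every topic chain uniformly: provider failures surface as `HTTPException` and FastAPI renders them directly. A sketch of that flow with an illustrative stand-in chain, not the project's actual `TopicChain`:

```python
from fastapi import FastAPI, HTTPException
from starlette.status import HTTP_429_TOO_MANY_REQUESTS


class FakeChain:
    """Stand-in chain that maps provider errors to HTTP errors internally."""

    def run(self, query: str) -> str:
        try:
            return f"echo: {query}"  # a real chain would call its provider here
        except Exception as exc:  # stand-in for openai.RateLimitError etc.
            raise HTTPException(
                HTTP_429_TOO_MANY_REQUESTS, "Please try again later."
            ) from exc


app = FastAPI()
chain = FakeChain()


@app.post("/ask")
async def ask(data: dict) -> dict:
    # No provider-specific handling here; any HTTPException raised inside
    # chain.run is rendered by FastAPI as the corresponding HTTP response.
    return {"response": chain.run(data.get("query", ""))}
```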