Skip to content

Commit

Permalink
sagemaker + llm creator class
Browse files Browse the repository at this point in the history
  • Loading branch information
dartpain committed Sep 29, 2023
1 parent c1c54f4 commit e4be38b
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 25 deletions.
26 changes: 5 additions & 21 deletions application/api/answer/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


from application.core.settings import settings
from application.llm.openai import OpenAILLM, AzureOpenAILLM
from application.llm.llm_creator import LLMCreator
from application.vectorstore.faiss import FaissStore
from application.error import bad_request

Expand Down Expand Up @@ -128,16 +128,8 @@ def is_azure_configured():


def complete_stream(question, docsearch, chat_history, api_key, conversation_id):
if is_azure_configured():
llm = AzureOpenAILLM(
openai_api_key=api_key,
openai_api_base=settings.OPENAI_API_BASE,
openai_api_version=settings.OPENAI_API_VERSION,
deployment_name=settings.AZURE_DEPLOYMENT_NAME,
)
else:
logger.debug("plain OpenAI")
llm = OpenAILLM(api_key=api_key)
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)


docs = docsearch.search(question, k=2)
# join all page_content together with a newline
Expand Down Expand Up @@ -270,16 +262,8 @@ def api_answer():
# Note if you have used other embeddings than OpenAI, you need to change the embeddings
docsearch = FaissStore(vectorstore, embeddings_key)

if is_azure_configured():
llm = AzureOpenAILLM(
openai_api_key=api_key,
openai_api_base=settings.OPENAI_API_BASE,
openai_api_version=settings.OPENAI_API_VERSION,
deployment_name=settings.AZURE_DEPLOYMENT_NAME,
)
else:
logger.debug("plain OpenAI")
llm = OpenAILLM(api_key=api_key)

llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)



Expand Down
2 changes: 1 addition & 1 deletion application/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class Settings(BaseSettings):
LLM_NAME: str = "openai_chat"
LLM_NAME: str = "openai"
EMBEDDINGS_NAME: str = "openai_text-embedding-ada-002"
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
Expand Down
20 changes: 20 additions & 0 deletions application/llm/llm_creator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from application.llm.openai import OpenAILLM, AzureOpenAILLM
from application.llm.sagemaker import SagemakerAPILLM
from application.llm.huggingface import HuggingFaceLLM



class LLMCreator:
llms = {
'openai': OpenAILLM,
'azure_openai': AzureOpenAILLM,
'sagemaker': SagemakerAPILLM,
'huggingface': HuggingFaceLLM
}

@classmethod
def create_llm(cls, type, *args, **kwargs):
llm_class = cls.llms.get(type.lower())
if not llm_class:
raise ValueError(f"No LLM class found for type {type}")
return llm_class(*args, **kwargs)
7 changes: 4 additions & 3 deletions application/llm/openai.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from application.llm.base import BaseLLM
from application.core.settings import settings

class OpenAILLM(BaseLLM):

Expand Down Expand Up @@ -44,9 +45,9 @@ class AzureOpenAILLM(OpenAILLM):

def __init__(self, openai_api_key, openai_api_base, openai_api_version, deployment_name):
super().__init__(openai_api_key)
self.api_base = openai_api_base
self.api_version = openai_api_version
self.deployment_name = deployment_name
self.api_base = settings.OPENAI_API_BASE,
self.api_version = settings.OPENAI_API_VERSION,
self.deployment_name = settings.AZURE_DEPLOYMENT_NAME,

def _get_openai(self):
openai = super()._get_openai()
Expand Down
27 changes: 27 additions & 0 deletions application/llm/sagemaker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from application.llm.base import BaseLLM
from application.core.settings import settings
import requests
import json

class SagemakerAPILLM(BaseLLM):

def __init__(self, *args, **kwargs):
self.url = settings.SAGEMAKER_API_URL

def gen(self, model, engine, messages, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

response = requests.post(
url=self.url,
headers={
"Content-Type": "application/json; charset=utf-8",
},
data=json.dumps({"input": prompt})
)

return response.json()['answer']

def gen_stream(self, model, engine, messages, stream=True, **kwargs):
raise NotImplementedError("Sagemaker does not support streaming")

0 comments on commit e4be38b

Please sign in to comment.