diff --git a/application/core/settings.py b/application/core/settings.py
index 42dea0ff6..d9b68ed75 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -39,6 +39,9 @@ class Settings(BaseSettings):
     SAGEMAKER_ACCESS_KEY: Optional[str] = None  # SageMaker access key
     SAGEMAKER_SECRET_KEY: Optional[str] = None  # SageMaker secret key
 
+    # prem ai project id
+    PREMAI_PROJECT_ID: Optional[str] = None
+
 
 path = Path(__file__).parent.parent.absolute()
 settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
diff --git a/application/llm/llm_creator.py b/application/llm/llm_creator.py
index d0d6ae3f8..b4fdaebf5 100644
--- a/application/llm/llm_creator.py
+++ b/application/llm/llm_creator.py
@@ -4,6 +4,7 @@
 from application.llm.llama_cpp import LlamaCpp
 from application.llm.anthropic import AnthropicLLM
 from application.llm.docsgpt_provider import DocsGPTAPILLM
+from application.llm.premai import PremAILLM
 
 
 
@@ -15,7 +16,8 @@ class LLMCreator:
         'huggingface': HuggingFaceLLM,
         'llama.cpp': LlamaCpp,
         'anthropic': AnthropicLLM,
-        'docsgpt': DocsGPTAPILLM
+        'docsgpt': DocsGPTAPILLM,
+        'premai': PremAILLM,
     }
 
     @classmethod
diff --git a/application/llm/premai.py b/application/llm/premai.py
new file mode 100644
index 000000000..4bc8a898c
--- /dev/null
+++ b/application/llm/premai.py
@@ -0,0 +1,33 @@
+from application.llm.base import BaseLLM
+from application.core.settings import settings
+
+class PremAILLM(BaseLLM):
+
+    def __init__(self, api_key):
+        from premai import Prem
+
+        self.client = Prem(
+            api_key=api_key
+        )
+        self.api_key = api_key
+        self.project_id = settings.PREMAI_PROJECT_ID
+
+    def gen(self, model, engine, messages, stream=False, **kwargs):
+        response = self.client.chat.completions.create(model=model,
+            project_id=self.project_id,
+            messages=messages,
+            stream=stream,
+            **kwargs)
+
+        return response.choices[0].message["content"]
+
+    def gen_stream(self, model, engine, messages, stream=True, **kwargs):
+        response = self.client.chat.completions.create(model=model,
+            project_id=self.project_id,
+            messages=messages,
+            stream=stream,
+            **kwargs)
+
+        for line in response:
+            if line.choices[0].delta["content"] is not None:
+                yield line.choices[0].delta["content"]