Skip to content

Commit

Permalink
✨ Improve type annotations, cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
shroominic committed Nov 11, 2023
1 parent 1123176 commit 2844413
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions funcchain/utils/model_defaults.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from dotenv import load_dotenv
from langchain.chat_models import AzureChatOpenAI, ChatAnthropic, ChatGooglePalm, ChatOpenAI, JinaChat
from langchain.chat_models.base import BaseChatModel
Expand All @@ -6,15 +8,15 @@
from funcchain.config import settings


def auto_model(**kwargs) -> BaseChatModel | RunnableWithFallbacks:
def auto_model(**kwargs: Any) -> BaseChatModel | RunnableWithFallbacks:
if settings.AZURE_DEPLOYMENT_NAME_LONG:
return create_long_llm()
return model_from_env(**kwargs)


def model_from_env(
dotenv_path: str = "./.env",
**kwargs,
**kwargs: Any,
) -> BaseChatModel:
"""
Automatically search your env variables for api keys
Expand Down Expand Up @@ -55,7 +57,7 @@ def model_from_env(
def model_from_name(
model_name: str,
/,
**kwargs,
**kwargs: Any,
) -> BaseChatModel:
"""
Input model_name using this schema
Expand Down Expand Up @@ -120,15 +122,15 @@ def create_long_llm() -> RunnableWithFallbacks:
)
print("Model: AZURE")
return AzureChatOpenAI(
deployment_name=settings.AZURE_DEPLOYMENT_NAME,
azure_deployment=settings.AZURE_DEPLOYMENT_NAME,
model=config.pop("model", None) or "gpt-3.5-turbo",
**config, # type: ignore
**config,
).with_fallbacks(
[
AzureChatOpenAI(
model=(model + "-32k") if (model := config.pop("model", None)) else "gpt-3.5-turbo-16k",
deployment_name=settings.AZURE_DEPLOYMENT_NAME_LONG or "gpt-4-32k",
**config, # type: ignore
azure_deployment=settings.AZURE_DEPLOYMENT_NAME_LONG or "gpt-4-32k",
**config,
)
]
)
Expand All @@ -137,12 +139,12 @@ def create_long_llm() -> RunnableWithFallbacks:
print("Model: OPENAI")
return ChatOpenAI(
model=config.pop("model", None) or "gpt-3.5-turbo",
**config, # type: ignore
**config,
).with_fallbacks(
[
ChatOpenAI(
model=(model + "-32k") if (model := config.pop("model", None)) else "gpt-3.5-turbo-16k",
**config, # type: ignore
**config,
),
]
)
Expand Down

0 comments on commit 2844413

Please sign in to comment.