Skip to content

Commit

Permalink
🔍 Improve LLM type detection
Browse files Browse the repository at this point in the history
  • Loading branch information
shroominic committed Nov 10, 2023
1 parent d454b04 commit 2d5e5c3
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions funcchain/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,22 @@ def count_tokens(text: str, model: str = "gpt-4") -> int:
return len(encoding_for_model(model).encode(text))


def gather_llm_type(llm: BaseLanguageModel | Runnable, func_check: bool = False) -> str:
def gather_llm_type(llm: BaseLanguageModel | Runnable, func_check: bool = True) -> str:
if isinstance(llm, RunnableWithFallbacks):
llm = llm.runnable
if not isinstance(llm, BaseChatModel):
return "base_model"
if not isinstance(llm, ChatOpenAI):
return "chat_model"
if llm.model_name == "gpt-4-vision-preview":
return "vision_model"
if llm.model_name == "gpt-4":
return "function_model"
try:
if func_check:
llm.predict_messages(
[
SystemMessage(content="You are an AI assistant."),
SystemMessage(content="This is a test message to see if the model can run functions."),
HumanMessage(content="Hello!"),
],
functions=[
Expand All @@ -96,21 +100,27 @@ def gather_llm_type(llm: BaseLanguageModel | Runnable, func_check: bool = False)
],
)
except Exception:
return "openai_model"
return "chat_model"
else:
return "function_model"


FUNCTION_MODEL: bool | None = None


def is_function_model(llm: BaseLanguageModel | RunnableWithFallbacks) -> bool:
global FUNCTION_MODEL
if FUNCTION_MODEL is None:
FUNCTION_MODEL = gather_llm_type(llm, True) == "function_model"
FUNCTION_MODEL = gather_llm_type(llm) == "function_model"
return FUNCTION_MODEL


VISION_MODEL: bool | None = None
def is_vision_model(llm: BaseLanguageModel | RunnableWithFallbacks) -> bool:
global VISION_MODEL
if VISION_MODEL is None:
VISION_MODEL = gather_llm_type(llm) == "vision_model"
return VISION_MODEL


def _remove_a_key(d, remove_key) -> None:
"""Remove a key from a dictionary recursively"""
if isinstance(d, dict):
Expand Down

0 comments on commit 2d5e5c3

Please sign in to comment.