From c407fae61ed148bfef0ee6659c5e7ef581eb5ce9 Mon Sep 17 00:00:00 2001
From: Shroominic
Date: Thu, 28 Dec 2023 15:23:12 +0100
Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AC=20fix=20vision=20model=20type=20ga?=
 =?UTF-8?q?ther?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/funcchain/utils/helpers.py | 91 ++++++++++++++++++----------------
 1 file changed, 49 insertions(+), 42 deletions(-)

diff --git a/src/funcchain/utils/helpers.py b/src/funcchain/utils/helpers.py
index b03ca1b..6603fb9 100644
--- a/src/funcchain/utils/helpers.py
+++ b/src/funcchain/utils/helpers.py
@@ -1,10 +1,9 @@
 from typing import Any, NoReturn, Type
 
 from docstring_parser import parse
-from langchain.chat_models import ChatOpenAI
-from langchain_core.language_models import BaseChatModel, BaseLanguageModel
+from langchain.chat_models import ChatOpenAI, ChatOllama
+from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import HumanMessage, SystemMessage
-from langchain_core.runnables import Runnable, RunnableWithFallbacks
 from pydantic import BaseModel
 from tiktoken import encoding_for_model
 
@@ -17,7 +16,7 @@ def count_tokens(text: str, model: str = "gpt-4") -> int:
     return len(encoding_for_model(model).encode(text))
 
 
-verified_function_models = [
+verified_openai_function_models = [
     "gpt-4",
     "gpt-4-0613",
     "gpt-4-1106-preview",
@@ -30,59 +29,67 @@ def count_tokens(text: str, model: str = "gpt-4") -> int:
     "gpt-3.5-turbo-16k-0613",
 ]
 
-verified_vision_models = [
+verified_openai_vision_models = [
     "gpt-4-vision-preview",
 ]
 
+verified_ollama_vision_models = [
+    "llava",
+    "bakllava",
+]
+
 
-def gather_llm_type(llm: BaseLanguageModel | Runnable, func_check: bool = True) -> str:
-    if isinstance(llm, RunnableWithFallbacks):
-        llm = llm.runnable
+def gather_llm_type(llm: BaseChatModel, func_check: bool = True) -> str:
     if not isinstance(llm, BaseChatModel):
         return "base_model"
-    if not isinstance(llm, ChatOpenAI):
-        return "chat_model"
-    if llm.model_name in verified_vision_models:
-        return "vision_model"
-    if llm.model_name in verified_function_models:
-        return "function_model"
-    try:
-        if func_check:
-            llm.predict_messages(
-                [
-                    SystemMessage(
-                        content="This is a test message to see if the model can run functions."
-                    ),
-                    HumanMessage(content="Hello!"),
-                ],
-                functions=[
-                    {
-                        "name": "print",
-                        "description": "show the input",
-                        "parameters": {
-                            "properties": {
-                                "__arg1": {"title": "__arg1", "type": "string"},
+    if isinstance(llm, ChatOpenAI):
+        if llm.model_name in verified_openai_vision_models:
+            return "vision_model"
+        if llm.model_name in verified_openai_function_models:
+            return "function_model"
+        try:
+            if func_check:
+                llm.predict_messages(
+                    [
+                        SystemMessage(
+                            content="This is a test message to see if the model can run functions."
+                        ),
+                        HumanMessage(content="Hello!"),
+                    ],
+                    functions=[
+                        {
+                            "name": "print",
+                            "description": "show the input",
+                            "parameters": {
+                                "properties": {
+                                    "__arg1": {"title": "__arg1", "type": "string"},
+                                },
+                                "required": ["__arg1"],
+                                "type": "object",
                             },
-                            "required": ["__arg1"],
-                            "type": "object",
-                        },
-                    }
-                ],
-            )
-    except Exception:
-        return "chat_model"
-    else:
-        return "function_model"
+                        }
+                    ],
+                )
+        except Exception:
+            return "chat_model"
+        else:
+            return "function_model"
+    elif isinstance(llm, ChatOllama):
+        for model in verified_ollama_vision_models:
+            if llm.model in model:
+                return "vision_model"
+
+    return "chat_model"
 
 
 def is_function_model(
-    llm: BaseLanguageModel | RunnableWithFallbacks,
+    llm: BaseChatModel,
 ) -> bool:
     return gather_llm_type(llm) == "function_model"
 
 
 def is_vision_model(
-    llm: BaseLanguageModel | RunnableWithFallbacks,
+    llm: BaseChatModel,
 ) -> bool:
     return gather_llm_type(llm) == "vision_model"
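
Reviewer note, not part of the commit: a minimal sketch of how the patched type
gathering behaves, assuming funcchain is installed from this branch and the
helpers remain importable as funcchain.utils.helpers (the module this patch
touches). The placeholder API key exists only so ChatOpenAI can be constructed;
nothing below makes a network call, since every branch exercised returns before
the func_check probe runs.

    from langchain.chat_models import ChatOllama, ChatOpenAI

    from funcchain.utils.helpers import gather_llm_type, is_vision_model

    # OpenAI models are classified by exact model_name lookup in the
    # verified_openai_* lists, so the vision preview short-circuits to
    # "vision_model" before any function-calling probe is attempted.
    gpt4v = ChatOpenAI(
        model_name="gpt-4-vision-preview",
        openai_api_key="sk-placeholder",  # never used; no request is sent
    )
    assert gather_llm_type(gpt4v) == "vision_model"
    assert is_vision_model(gpt4v)

    # Ollama models are matched by substring: llm.model must occur inside a
    # verified name, so the bare tags pass...
    assert is_vision_model(ChatOllama(model="llava"))
    assert is_vision_model(ChatOllama(model="bakllava"))

    # ...while a tag-qualified name such as "llava:13b" is not a substring of
    # "llava" or "bakllava" and therefore falls through to "chat_model".
    assert gather_llm_type(ChatOllama(model="llava:13b")) == "chat_model"

If tag-qualified Ollama names should also count as vision models, reversing
the membership test to "model in llm.model" would match them as well; the
sketch above documents the check exactly as committed.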