From 99ad1490cbe3f5c53c2359cc10c3fa266431c98b Mon Sep 17 00:00:00 2001
From: Pedro Rico Pinazo
Date: Mon, 20 May 2024 22:24:54 +0100
Subject: [PATCH] add openai support

---
 src/comai/chain.py  | 10 +++++++++-
 src/comai/cli.py    | 34 ++++++++++++++++++++++++----------
 src/comai/openai.py | 12 ++++++++++++
 3 files changed, 45 insertions(+), 11 deletions(-)
 create mode 100644 src/comai/openai.py

diff --git a/src/comai/chain.py b/src/comai/chain.py
index baeda70..1f64c4f 100644
--- a/src/comai/chain.py
+++ b/src/comai/chain.py
@@ -1,6 +1,7 @@
 import itertools
 import logging
 import re
+import os
 import textwrap
 import uuid
 from dataclasses import dataclass
@@ -19,6 +20,7 @@
 )
 from langchain_core.runnables.utils import ConfigurableFieldSpec
 from langchain_core.tools import tool
+from langchain_openai import ChatOpenAI
 
 from comai.context import Context
 from comai.history import load_history
@@ -87,7 +89,13 @@ def attatch_history(
 
 def create_chain_stream(settings: Settings, context: Context):
     prompt = create_prompt(context)
-    model = ChatOllama(model=settings.model, temperature=0)
+    if settings.provider == "ollama":
+        model = ChatOllama(model=settings.model, temperature=0)
+    elif settings.provider == "openai":
+        default_key = os.environ.get("OPENAI_API_KEY")
+        comai_key = os.environ.get("COMAI_OPENAI_API_KEY")
+        api_key = comai_key if comai_key is not None else default_key
+        model = ChatOpenAI(model=settings.model, temperature=0, api_key=api_key)
     base_chain = prompt | model
 
     if context.session_id is not None:
diff --git a/src/comai/cli.py b/src/comai/cli.py
index a494a18..8415c70 100755
--- a/src/comai/cli.py
+++ b/src/comai/cli.py
@@ -12,6 +12,7 @@
 from comai import __version__
 from comai.chain import StreamStart, Token, FinalCommand, query_command
 from comai.ollama import get_ollama_model_names
+from comai.openai import get_openai_model_names
 from comai.prompt import prompt_bool, prompt_options
 from comai.settings import (
     InvalidSettingsFileException,
@@ -51,19 +52,32 @@ def show_settings_callback(value: bool):
 def settings_callback(value: bool):
     if value:
         settings = load_settings()
-        ollama_models = get_ollama_model_names()
-        if settings.model in ollama_models:
+
+        provider = prompt_options(
+            "Which provider do you want to use?",
+            ["ollama", "openai"],
+            settings.provider,
+        )
+        if provider == "ollama":
+            models = get_ollama_model_names()
+        elif provider == "openai":
+            models = get_openai_model_names()
+        else:
+            raise Exception(f"Got unknown provider option: {provider}")
+
+        if settings.model in models:
             default_model = settings.model
-        elif "llama3" in ollama_models:
+        elif "llama3" in models:
             default_model = "llama3"
-        elif len(ollama_models) > 0:
-            default_model = ollama_models[0]
+        elif "gpt-3.5-turbo" in models:
+            default_model = "gpt-3.5-turbo"
+        elif len(models) > 0:
+            default_model = models[0]
         else:
-            default_model = "llama3"
-        model = prompt_options(
-            "Which model do you want to use?", ollama_models, default_model
-        )
-        settings.provider = "ollama"
+            raise Exception("No models available for the selected provider")
+        model = prompt_options("Which model do you want to use?", models, default_model)
+
+        settings.provider = provider
         settings.model = model
         write_settings(settings)
     raise typer.Exit()
diff --git a/src/comai/openai.py b/src/comai/openai.py
new file mode 100644
index 0000000..64e053d
--- /dev/null
+++ b/src/comai/openai.py
@@ -0,0 +1,12 @@
+from openai import OpenAI
+import os
+
+
+def get_openai_model_names() -> list[str]:
+    # FIXME: this piece of code is duplicated
+    default_key = os.environ.get("OPENAI_API_KEY")
+    comai_key = os.environ.get("COMAI_OPENAI_API_KEY")
+    api_key = comai_key if comai_key is not None else default_key
+
+    client = OpenAI(api_key=api_key)
+    return [model.id for model in client.models.list() if model.id.startswith("gpt-")]
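
Note on the FIXME in src/comai/openai.py: the same three-line API-key lookup
(COMAI_OPENAI_API_KEY falling back to OPENAI_API_KEY) also appears in
create_chain_stream in src/comai/chain.py. A possible follow-up, sketched here
only as a suggestion and not part of this patch, is to keep the lookup in a
single helper (the name resolve_openai_api_key is hypothetical) and reuse it
from both call sites:

    # src/comai/openai.py -- hypothetical follow-up sketch, not in this patch
    import os
    from typing import Optional

    from openai import OpenAI


    def resolve_openai_api_key() -> Optional[str]:
        # Prefer the comai-specific key, fall back to the standard OpenAI one.
        comai_key = os.environ.get("COMAI_OPENAI_API_KEY")
        return comai_key if comai_key is not None else os.environ.get("OPENAI_API_KEY")


    def get_openai_model_names() -> list[str]:
        # Keep only the "gpt-*" chat models, as the patch above does.
        client = OpenAI(api_key=resolve_openai_api_key())
        return [model.id for model in client.models.list() if model.id.startswith("gpt-")]

create_chain_stream in chain.py could then pass api_key=resolve_openai_api_key()
to ChatOpenAI instead of repeating the os.environ lookups.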