Skip to content

Commit

Permalink
add openai support
Browse files Browse the repository at this point in the history
  • Loading branch information
ricopinazo committed May 20, 2024
1 parent 38043be commit 99ad149
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 11 deletions.
10 changes: 9 additions & 1 deletion src/comai/chain.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools
import logging
import re
import os
import textwrap
import uuid
from dataclasses import dataclass
Expand All @@ -19,6 +20,7 @@
)
from langchain_core.runnables.utils import ConfigurableFieldSpec
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI

from comai.context import Context
from comai.history import load_history
Expand Down Expand Up @@ -87,7 +89,13 @@ def attatch_history(

def create_chain_stream(settings: Settings, context: Context):
prompt = create_prompt(context)
model = ChatOllama(model=settings.model, temperature=0)
if settings.provider == "ollama":
model = ChatOllama(model=settings.model, temperature=0)
elif settings.provider == "openai":
default_key = os.environ.get("OPENAI_API_KEY")
comai_key = os.environ.get("COMAI_OPENAI_API_KEY")
api_key = comai_key if comai_key is not None else default_key
model = ChatOpenAI(model=settings.model, temperature=0, api_key=api_key)
base_chain = prompt | model

if context.session_id is not None:
Expand Down
34 changes: 24 additions & 10 deletions src/comai/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from comai import __version__
from comai.chain import StreamStart, Token, FinalCommand, query_command
from comai.ollama import get_ollama_model_names
from comai.openai import get_openai_model_names
from comai.prompt import prompt_bool, prompt_options
from comai.settings import (
InvalidSettingsFileException,
Expand Down Expand Up @@ -51,19 +52,32 @@ def show_settings_callback(value: bool):
def settings_callback(value: bool):
if value:
settings = load_settings()
ollama_models = get_ollama_model_names()
if settings.model in ollama_models:

provider = prompt_options(
"Which provider do you want to use?",
["ollama", "openai"],
settings.provider,
)
if provider == "ollama":
models = get_ollama_model_names()
elif provider == "openai":
models = get_openai_model_names()
else:
raise Exception(f"Got unknown provider option: {provider}")

if settings.model in models:
default_model = settings.model
elif "llama3" in ollama_models:
elif "llama3" in models:
default_model = "llama3"
elif len(ollama_models) > 0:
default_model = ollama_models[0]
elif "gpt-3.5-turbo" in models:
default_model = "gpt-3.5-turbo"
elif len(models) > 0:
default_model = models[0]
else:
default_model = "llama3"
model = prompt_options(
"Which model do you want to use?", ollama_models, default_model
)
settings.provider = "ollama"
raise Exception("No models available for the selected provider")
model = prompt_options("Which model do you want to use?", models, default_model)

settings.provider = provider
settings.model = model
write_settings(settings)
raise typer.Exit()
Expand Down
12 changes: 12 additions & 0 deletions src/comai/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from openai import OpenAI
import os


def get_openai_model_names() -> list[str]:
# FIXME: this piece of code is duplicated
default_key = os.environ.get("OPENAI_API_KEY")
comai_key = os.environ.get("COMAI_OPENAI_API_KEY")
api_key = comai_key if comai_key is not None else default_key

client = OpenAI(api_key=api_key)
return [model.id for model in client.models.list() if model.id.startswith("gpt-")]

0 comments on commit 99ad149

Please sign in to comment.