
Commit

fix type errors
ricopinazo committed May 21, 2024
1 parent caa8577 · commit 64b69a2
Showing 6 changed files with 28 additions and 42 deletions.
9 changes: 7 additions & 2 deletions pyproject.toml
@@ -23,7 +23,6 @@ dependencies = [
     "langchain==0.1.17",
     "langchain-openai==0.1.6",
     "ollama==0.1.9",
-    "types-requests==2.31.0.20240406",
 ]

 [project.urls]
@@ -35,7 +34,13 @@ issues = "https://github.com/ricopinazo/comai/issues"
 comai = "comai.cli:app"

 [project.optional-dependencies]
-test = ["pytest", "hatchling"]
+test = [
+    "pytest",
+    "hatchling",
+    "types-requests==2.31.0.20240406",
+    "mypy==1.9.0",
+    "black==24.4.0",
+]

 [tool.hatch.version]
 path = "src/comai/__init__.py"
3 changes: 1 addition & 2 deletions src/comai/animations.py
@@ -68,8 +68,7 @@ def print_command_prompt(command: str):
     message = [
         ("class:mark", ANSWER_PROMPT),
     ]
-    # return "ls"
-    return prompt(message, default="%s" % command, style=style)
+    return prompt(message, default="%s" % command, style=style)  # type: ignore


 def hide_cursor() -> None:
19 changes: 16 additions & 3 deletions src/comai/chain.py
@@ -6,6 +6,7 @@
 import uuid
 from dataclasses import dataclass
 from typing import Iterable, Iterator, List
+from pydantic import BaseModel, SecretStr

 from langchain.globals import set_debug, set_verbose
 from langchain_community.chat_message_histories import SQLChatMessageHistory
@@ -87,15 +88,27 @@ def attatch_history(
     )


+class OpenaiSecrets(BaseModel):
+    default_key: SecretStr | None
+    comai_key: SecretStr | None
+
+
+def extract_secret(key: str | None):
+    if key is None:
+        return None
+    else:
+        return SecretStr(key)
+
+
 def create_chain_stream(settings: Settings, context: Context):
     prompt = create_prompt(context)
     if settings.provider == "ollama":
         model = ChatOllama(model=settings.model, temperature=0)
     elif settings.provider == "openai":
-        default_key = os.environ.get("OPENAI_API_KEY")
-        comai_key = os.environ.get("COMAI_OPENAI_API_KEY")
+        default_key = extract_secret(os.environ.get("OPENAI_API_KEY"))
+        comai_key = extract_secret(os.environ.get("COMAI_OPENAI_API_KEY"))
         api_key = comai_key if comai_key is not None else default_key
-        model = ChatOpenAI(model=settings.model, temperature=0, api_key=api_key)
+        model = ChatOpenAI(model=settings.model, temperature=0, api_key=api_key)  # type: ignore
     base_chain = prompt | model

     if context.session_id is not None:
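For readers skimming the diff, a minimal standalone sketch of the pattern introduced above (not the full chain.py; it assumes pydantic v2 and Python 3.10+ union syntax, as used in the diff): optional environment variables are wrapped in SecretStr, or left as None, so the value lines up with the SecretStr-typed api_key that langchain's ChatOpenAI appears to expect, which is also why a targeted # type: ignore remains on that call.

```python
# Sketch of the helper added in chain.py: wrap an optional env var in
# pydantic's SecretStr (or propagate None) so type checkers accept it where
# a SecretStr | None is expected.
import os

from pydantic import SecretStr


def extract_secret(key: str | None) -> SecretStr | None:
    if key is None:
        return None
    return SecretStr(key)


# COMAI_OPENAI_API_KEY takes precedence over the default OPENAI_API_KEY.
default_key = extract_secret(os.environ.get("OPENAI_API_KEY"))
comai_key = extract_secret(os.environ.get("COMAI_OPENAI_API_KEY"))
api_key = comai_key if comai_key is not None else default_key
```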
5 changes: 3 additions & 2 deletions src/comai/cli.py
@@ -21,7 +21,6 @@
     Settings,
 )
 from comai.context import get_context
-from comai.menu import get_option_from_menu, MenuOption
 from comai.animations import (
     print_command_token,
     query_animation,
@@ -31,7 +30,9 @@
     print_command_prompt,
 )

-typer.core.rich = None  # this is to avoid using right for the help panel
+# this is to avoid using right for the help panel
+typer.core.rich = None  # type: ignore
+
 app = typer.Typer(pretty_exceptions_enable=False, add_completion=False)


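Likewise, a short sketch of the cli.py setup shown above (nothing assumed beyond what the diff itself contains): rich-formatted help panels are disabled by assigning None to typer.core.rich, and it is that assignment which carries the inline # type: ignore.

```python
# Sketch of the cli.py pattern: disable typer's rich help rendering.
# Mypy flags the None assignment on this line, hence the inline ignore.
import typer
import typer.core

typer.core.rich = None  # type: ignore  # avoid rich-formatted help panels
app = typer.Typer(pretty_exceptions_enable=False, add_completion=False)
```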
32 changes: 0 additions & 32 deletions src/comai/menu.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/comai/prompt.py
@@ -18,7 +18,7 @@ def prompt_str(question: str, default: str) -> str:
         ("class:mark", "? "),
         ("class:question", f"{question.strip()} "),
     ]
-    return prompt(message, default="%s" % default, style=style)
+    return prompt(message, default="%s" % default, style=style)  # type: ignore


 def prompt_options(question: str, options: list[str], default: str) -> str:
