Skip to content

Commit

Permalink
format the rest of the files
Browse files Browse the repository at this point in the history
  • Loading branch information
ricopinazo committed May 6, 2024
1 parent f3b7892 commit 1b58761
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 30 deletions.
2 changes: 2 additions & 0 deletions src/comai/animations.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@ def query_animation() -> Generator[None, None, None]:
# for chunk in command_chunks:
# print(f"[{COMMAND_COLOR}]{chunk}", end="", flush=True)


def start_printing_command():
print(f"[{ANSWER_PROMPT_COLOR}]{ANSWER_PROMPT}", end="", flush=True)


def print_command_token(chunk: str):
print(f"[{COMMAND_COLOR}]{chunk}", end="", flush=True)

Expand Down
2 changes: 1 addition & 1 deletion src/comai/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def main_normal_flow(instructions: List[str]):
output = query_command(input_text, settings, context)
with query_animation():
stream_start = next(output)
assert(type(stream_start) == StreamStart)
assert type(stream_start) == StreamStart

start_printing_command()
for chunk in output:
Expand Down
58 changes: 30 additions & 28 deletions src/comai/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path

from langchain_community.chat_message_histories import SQLChatMessageHistory

# from langchain.memory import ChatMessageHistory
from langchain_core.chat_history import InMemoryChatMessageHistory

Expand All @@ -15,7 +16,8 @@
except Exception:
pass

def load_history(session_id: str)-> SQLChatMessageHistory | InMemoryChatMessageHistory:

def load_history(session_id: str) -> SQLChatMessageHistory | InMemoryChatMessageHistory:
if history_path:
return SQLChatMessageHistory(
session_id=session_id, connection_string=f"sqlite:///{history_path}"
Expand All @@ -30,36 +32,36 @@ def load_history(session_id: str)-> SQLChatMessageHistory | InMemoryChatMessageH
# self.messages = messages
# self.filepath = filepath

# @classmethod
# def load_from_file(cls, filepath: os.PathLike) -> History:
# messages = []
# try:
# with open(filepath, "br") as history_file:
# messages = pickle.load(history_file)
# except Exception:
# pass
# return History(messages, filepath)
# @classmethod
# def load_from_file(cls, filepath: os.PathLike) -> History:
# messages = []
# try:
# with open(filepath, "br") as history_file:
# messages = pickle.load(history_file)
# except Exception:
# pass
# return History(messages, filepath)

# @classmethod
# def create_local(cls) -> History:
# return History([], None)
# @classmethod
# def create_local(cls) -> History:
# return History([], None)

# def append_user_message(self, request: str) -> None:
# user_message = {"role": "user", "content": request}
# self.messages += [user_message]
# def append_user_message(self, request: str) -> None:
# user_message = {"role": "user", "content": request}
# self.messages += [user_message]

# def append_assistant_message(self, command: str) -> None:
# content = f"""COMMAND {command} END"""
# assistant_message = {"role": "user", "content": content}
# self.messages += [assistant_message]
# def append_assistant_message(self, command: str) -> None:
# content = f"""COMMAND {command} END"""
# assistant_message = {"role": "user", "content": content}
# self.messages += [assistant_message]

# def get_messages(self) -> List:
# return self.messages
# def get_messages(self) -> List:
# return self.messages

# def checkpoint(self) -> None:
# if self.filepath:
# with open(self.filepath, "bw") as history_file:
# pickle.dump(self.messages, history_file)
# def checkpoint(self) -> None:
# if self.filepath:
# with open(self.filepath, "bw") as history_file:
# pickle.dump(self.messages, history_file)

# def copy(self) -> History:
# return History(copy(self.messages), self.filepath)
# def copy(self) -> History:
# return History(copy(self.messages), self.filepath)
7 changes: 6 additions & 1 deletion src/comai/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@
config_dir = typer.get_app_dir(APP_NAME, force_posix=True)
settings_path = os.path.join(config_dir, "settings.json")


class Settings(BaseModel):
provider: Literal["ollama", "openai"]
model: str = "llama3" # TODO: improve this, should be typed per provider, although possible models can be queried at runtime
model: str = (
"llama3" # TODO: improve this, should be typed per provider, although possible models can be queried at runtime
)


DEFAULT_SETTINGS: Settings = Settings(provider="ollama", model="llama3")


def load_settings() -> Settings:
try:
return Settings.parse_file(settings_path)
Expand Down

0 comments on commit 1b58761

Please sign in to comment.