Skip to content

Commit

Permalink
feat: change to use dataclass for type safety
Browse files Browse the repository at this point in the history
  • Loading branch information
doomspec committed Aug 29, 2024
1 parent ae742e0 commit 4fd9e6b
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 28 deletions.
8 changes: 4 additions & 4 deletions mllm/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,16 +236,16 @@ def complete(self, model=None, cache=False, expensive=False, parse=None, retry=T
"""
if options is None:
options = {}
options = {**default_options, **options}
options = {**default_options.get_dict(), **options}

contains_image = self.contains_image()
if model is None:
if not expensive:
model = default_models["normal"]
model = default_models.normal
else:
model = default_models["expensive"]
model = default_models.expensive
if contains_image:
model = default_models["vision"]
model = default_models.vision

if get_llm_provider(model)[1] in ["openai"]:
if parse == "dict":
Expand Down
58 changes: 41 additions & 17 deletions mllm/config.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,46 @@
from dataclasses import dataclass
from typing import Optional

default_models = {
"normal": "gpt-3.5-turbo",
"expensive": "gpt-4-turbo-preview",
"vision": "gpt-4-vision-preview",
"embedding": "text-embedding-3-large"
}

@dataclass
class DefaultModels:
    """Default model names used when a completion call does not name one.

    Dict-style access is kept for backward compatibility with the plain
    dict this class replaced.
    """
    normal: str = "gpt-4o-mini"  # default for ordinary completions
    expensive: str = "gpt-4-turbo-preview"  # used when expensive=True is requested
    vision: str = "gpt-4o"  # used when the chat contains an image
    embedding: str = "text-embedding-3-large"  # model used for embeddings

default_options = {
"temperature": 0.2,
"max_tokens": None,
"frequency_penalty": None,
"timeout": 600,
"seed": None,
}
def __getitem__(self, item):
return getattr(self, item)

def __setitem__(self, key, value):
setattr(self, key, value)

class CacheOptions:
    """Size limits for in-memory caches.

    NOTE(review): limits appear to cap cache entry counts — confirm
    against the cache implementation.
    """
    def __init__(self):
        self.max_embedding_vecs = 1000  # presumably max cached embedding vectors
        self.max_chat = 1000  # presumably max cached chat completions
def get_dict(self):
return {k: v for k, v in self.__dict__.items() if v is not None}


@dataclass
class DefaultOptions:
    """Default keyword options forwarded to LLM completion calls.

    Fields left as ``None`` are treated as "unset" and omitted by
    :meth:`get_dict`, so they are never sent to the provider.  Dict-style
    access is kept for backward compatibility with the plain dict this
    class replaced.
    """
    temperature: float = 0.2  # sampling temperature
    # Fixed annotations: these default to None, so they must be Optional
    # rather than bare int/float.
    max_tokens: Optional[int] = None  # provider default when None
    frequency_penalty: Optional[float] = None  # provider default when None
    timeout: int = 600  # request timeout in seconds
    seed: Optional[int] = None  # deterministic sampling seed when set

    def __getitem__(self, item):
        # Dict-style read access for backward compatibility.
        return getattr(self, item)

    def __setitem__(self, key, value):
        # Dict-style write access for backward compatibility.
        setattr(self, key, value)

    def get_dict(self):
        """Return only the explicitly-set (non-None) options as a dict."""
        return {k: v for k, v in self.__dict__.items() if v is not None}


# Shared configuration instances read across the package (e.g. mllm.chat
# resolves default_models.normal / .expensive / .vision at call time).
default_models = DefaultModels()

default_options = DefaultOptions()

# NOTE(review): imported at the bottom of the module, presumably to avoid
# a circular import between mllm.config and mllm.utils.parser — confirm.
from mllm.utils.parser import parse_options

if __name__ == '__main__':
    # Smoke check: print the non-None default options.
    print(default_options.get_dict())
19 changes: 12 additions & 7 deletions mllm/utils/parser.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import ast
import json
from dataclasses import dataclass

"""
# Parse
"""

parse_options = {
"cheap_model": "gpt-4o-mini",
"correct_json_by_model": False,
}

@dataclass
class ParseOptions:
    """Configuration for parsing model output.

    Attributes mirror the plain dict this class replaced.
    """
    # Inexpensive model used to repair malformed JSON output.
    cheap_model: str = "gpt-4o-mini"
    # When True, malformed JSON is re-sent to cheap_model for correction.
    correct_json_by_model: bool = False


parse_options = ParseOptions()


class Parse:
Expand Down Expand Up @@ -50,7 +55,7 @@ def dict(src: str):
except:
pass

if parse_options["correct_json_by_model"]:
if parse_options.correct_json_by_model:
try:
res = parse_json_by_cheap_model(json_src)
return res
Expand Down Expand Up @@ -87,7 +92,7 @@ def quotes(src: str):
break
else:
raise ValueError("No closing quotes")
for end in range(start+1, len(split_lines)):
for end in range(start + 1, len(split_lines)):
if split_lines[end].startswith("```"):
break
else:
Expand Down Expand Up @@ -125,5 +130,5 @@ def parse_json_by_cheap_model(json_src):
You should directly output the corrected JSON dict with a minimal modification.
"""
chat = Chat(prompt)
res = chat.complete(parse="dict", model=parse_options["cheap_model"])
res = chat.complete(parse="dict", model=parse_options.cheap_model)
return res
4 changes: 4 additions & 0 deletions test/test_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ def test_dict_gen():
res = chat.complete(parse="dict", cache=False)
assert res == {"a": 1, "b": 2}


def test_model_correct():
# this import is necessary to test whether it's imported into mllm.config
from mllm.config import parse_options

raw_json = """
{
no_quote : "
Expand Down

0 comments on commit 4fd9e6b

Please sign in to comment.