From 4fd9e6bea60b77d66c18f8b09151490475151bc4 Mon Sep 17 00:00:00 2001 From: Zijian Zhang Date: Thu, 29 Aug 2024 15:07:22 -0400 Subject: [PATCH] feat: change to use dataclass for type safety --- mllm/chat.py | 8 +++--- mllm/config.py | 58 +++++++++++++++++++++++++++++++------------- mllm/utils/parser.py | 19 +++++++++------ test/test_parse.py | 4 +++ 4 files changed, 61 insertions(+), 28 deletions(-) diff --git a/mllm/chat.py b/mllm/chat.py index d4fa96e..92b2309 100644 --- a/mllm/chat.py +++ b/mllm/chat.py @@ -236,16 +236,16 @@ def complete(self, model=None, cache=False, expensive=False, parse=None, retry=T """ if options is None: options = {} - options = {**default_options, **options} + options = {**default_options.get_dict(), **options} contains_image = self.contains_image() if model is None: if not expensive: - model = default_models["normal"] + model = default_models.normal else: - model = default_models["expensive"] + model = default_models.expensive if contains_image: - model = default_models["vision"] + model = default_models.vision if get_llm_provider(model)[1] in ["openai"]: if parse == "dict": diff --git a/mllm/config.py b/mllm/config.py index ec2e2ec..ece2fef 100644 --- a/mllm/config.py +++ b/mllm/config.py @@ -1,22 +1,46 @@ +from dataclasses import dataclass -default_models = { - "normal": "gpt-3.5-turbo", - "expensive": "gpt-4-turbo-preview", - "vision": "gpt-4-vision-preview", - "embedding": "text-embedding-3-large" -} +@dataclass +class DefaultModels: + normal: str = "gpt-4o-mini" + expensive: str = "gpt-4-turbo-preview" + vision: str = "gpt-4o" + embedding: str = "text-embedding-3-large" -default_options = { - "temperature": 0.2, - "max_tokens": None, - "frequency_penalty": None, - "timeout": 600, - "seed": None, -} + def __getitem__(self, item): + return getattr(self, item) + def __setitem__(self, key, value): + setattr(self, key, value) -class CacheOptions: - def __init__(self): - self.max_embedding_vecs = 1000 - self.max_chat = 1000 \ No newline at end of file + def get_dict(self): + return {k: v for k, v in self.__dict__.items() if v is not None} + + +@dataclass +class DefaultOptions: + temperature: float = 0.2 + max_tokens: int = None + frequency_penalty: float = None + timeout: int = 600 + seed: int = None + + def __getitem__(self, item): + return getattr(self, item) + + def __setitem__(self, key, value): + setattr(self, key, value) + + def get_dict(self): + return {k: v for k, v in self.__dict__.items() if v is not None} + + +default_models = DefaultModels() + +default_options = DefaultOptions() + +from mllm.utils.parser import parse_options + +if __name__ == '__main__': + print(default_options.get_dict()) diff --git a/mllm/utils/parser.py b/mllm/utils/parser.py index 136bb2d..98d4f91 100644 --- a/mllm/utils/parser.py +++ b/mllm/utils/parser.py @@ -1,14 +1,19 @@ import ast import json +from dataclasses import dataclass """ # Parse """ -parse_options = { - "cheap_model": "gpt-4o-mini", - "correct_json_by_model": False, -} + +@dataclass +class ParseOptions: + cheap_model: str = "gpt-4o-mini" + correct_json_by_model: bool = False + + +parse_options = ParseOptions() class Parse: @@ -50,7 +55,7 @@ def dict(src: str): except: pass - if parse_options["correct_json_by_model"]: + if parse_options.correct_json_by_model: try: res = parse_json_by_cheap_model(json_src) return res @@ -87,7 +92,7 @@ def quotes(src: str): break else: raise ValueError("No closing quotes") - for end in range(start+1, len(split_lines)): + for end in range(start + 1, len(split_lines)): if split_lines[end].startswith("```"): break else: @@ -125,5 +130,5 @@ def parse_json_by_cheap_model(json_src): You should directly output the corrected JSON dict with a minimal modification. """ chat = Chat(prompt) - res = chat.complete(parse="dict", model=parse_options["cheap_model"]) + res = chat.complete(parse="dict", model=parse_options.cheap_model) return res diff --git a/test/test_parse.py b/test/test_parse.py index e8f8b5c..dd9e97f 100644 --- a/test/test_parse.py +++ b/test/test_parse.py @@ -37,7 +37,11 @@ def test_dict_gen(): res = chat.complete(parse="dict", cache=False) assert res == {"a": 1, "b": 2} + def test_model_correct(): + # this import is necessary to test whether it's imported into mllm.config + from mllm.config import parse_options + raw_json = """ { no_quote : "