diff --git a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/registry.py b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/registry.py index 08cbf2a..3202f56 100644 --- a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/registry.py +++ b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/registry.py @@ -6,26 +6,23 @@ """ + class AlgorithmRegistry: _algorithms = {} - + @classmethod def register(cls, name=None, input_type=None, output_type=None): def decorator(func): algorithm_name = name or func.__name__ - cls._algorithms[algorithm_name] = { - 'type': func, - 'input_type': input_type, - 'output_type': output_type - } + cls._algorithms[algorithm_name] = {"type": func, "input_type": input_type, "output_type": output_type} return func + return decorator - + @classmethod def get_all(cls): return cls._algorithms - + @classmethod def get_by_category(cls, category): - return {name: data for name, data in cls._algorithms.items() - if data['category'] == category} + return {name: data for name, data in cls._algorithms.items() if data["category"] == category} diff --git a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/run_autometrics.py b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/run_autometrics.py index 0bc927a..27682e5 100644 --- a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/run_autometrics.py +++ b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/run_autometrics.py @@ -7,44 +7,44 @@ import datetime -with open('dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/test_data/data.json') as f: +with open("dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/test_data/data.json") as f: test_data = json.load(f) - test_dialogue = test_data['dialogue'] + test_dialogue = test_data["dialogue"] test_dialogue = Dialogue(dialogue=test_dialogue) - test_graph = test_data['graph'] + test_graph = test_data["graph"] test_graph = Graph(graph_dict=test_graph) def run_all_algorithms(): # Get all registered classes algorithms = AlgorithmRegistry.get_all() - print('Classes to test:', *algorithms.keys(), sep='\n') + print("Classes to test:", *algorithms.keys(), sep="\n") total_metrics = {} for class_ in algorithms: - class_instance = algorithms[class_]['type']() + class_instance = algorithms[class_]["type"]() metrics = {} - if algorithms[class_]['input_type'] is BaseGraph: + if algorithms[class_]["input_type"] is BaseGraph: result = class_instance.invoke(test_graph) metrics = { "all_paths_sampled": all_paths_sampled(test_graph, result[0]), - "all_utterances_present": all_utterances_present(test_graph, result) + "all_utterances_present": all_utterances_present(test_graph, result), } - elif algorithms[class_]['input_type'] is list[Dialogue]: + elif algorithms[class_]["input_type"] is list[Dialogue]: result = class_instance.invoke(test_dialogue) - + total_metrics[class_] = metrics return total_metrics -if __name__=="__main__": +if __name__ == "__main__": with open("dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/results/results.json") as f: old_data = json.load(f) - + new_metrics = {str(datetime.datetime.now()): run_all_algorithms()} - old_data.update(new_metrics) + old_data.update(new_metrics) - with open("dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/results/results.json", 'w') as f: - f.write(json.dumps(old_data, indent=2, ensure_ascii=False)) \ No newline at end of file + with open("dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/results/results.json", "w") as f: + f.write(json.dumps(old_data, indent=2, ensure_ascii=False)) diff --git a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/sample_dialogue.py b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/sample_dialogue.py index 50728d4..6da561c 100644 --- a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/sample_dialogue.py +++ b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/sample_dialogue.py @@ -5,12 +5,9 @@ from chatsky_llm_autoconfig.autometrics.registry import AlgorithmRegistry -@AlgorithmRegistry.register( - input_type=BaseGraph, - output_type=list[Dialogue] - ) +@AlgorithmRegistry.register(input_type=BaseGraph, output_type=list[Dialogue]) class DialogueSampler(DialogueGenerator): - + def invoke(self, graph: BaseGraph, start_node: int = 1, end_node: int = -1, topic="") -> list[Dialogue]: nx_graph = graph.graph if end_node == -1: