Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
NotBioWaste905 committed Nov 5, 2024
1 parent 07b083e commit ad3a5e1
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,23 @@
"""


class AlgorithmRegistry:
_algorithms = {}

@classmethod
def register(cls, name=None, input_type=None, output_type=None):
def decorator(func):
algorithm_name = name or func.__name__
cls._algorithms[algorithm_name] = {
'type': func,
'input_type': input_type,
'output_type': output_type
}
cls._algorithms[algorithm_name] = {"type": func, "input_type": input_type, "output_type": output_type}
return func

return decorator

@classmethod
def get_all(cls):
return cls._algorithms

@classmethod
def get_by_category(cls, category):
return {name: data for name, data in cls._algorithms.items()
if data['category'] == category}
return {name: data for name, data in cls._algorithms.items() if data["category"] == category}
Original file line number Diff line number Diff line change
Expand Up @@ -7,44 +7,44 @@
import datetime


with open('dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/test_data/data.json') as f:
with open("dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/test_data/data.json") as f:
test_data = json.load(f)
test_dialogue = test_data['dialogue']
test_dialogue = test_data["dialogue"]
test_dialogue = Dialogue(dialogue=test_dialogue)
test_graph = test_data['graph']
test_graph = test_data["graph"]
test_graph = Graph(graph_dict=test_graph)


def run_all_algorithms():
# Get all registered classes
algorithms = AlgorithmRegistry.get_all()
print('Classes to test:', *algorithms.keys(), sep='\n')
print("Classes to test:", *algorithms.keys(), sep="\n")
total_metrics = {}
for class_ in algorithms:
class_instance = algorithms[class_]['type']()
class_instance = algorithms[class_]["type"]()
metrics = {}
if algorithms[class_]['input_type'] is BaseGraph:
if algorithms[class_]["input_type"] is BaseGraph:
result = class_instance.invoke(test_graph)
metrics = {
"all_paths_sampled": all_paths_sampled(test_graph, result[0]),
"all_utterances_present": all_utterances_present(test_graph, result)
"all_utterances_present": all_utterances_present(test_graph, result),
}

elif algorithms[class_]['input_type'] is list[Dialogue]:
elif algorithms[class_]["input_type"] is list[Dialogue]:
result = class_instance.invoke(test_dialogue)

total_metrics[class_] = metrics

return total_metrics


if __name__=="__main__":
if __name__ == "__main__":
with open("dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/results/results.json") as f:
old_data = json.load(f)

new_metrics = {str(datetime.datetime.now()): run_all_algorithms()}

old_data.update(new_metrics)
old_data.update(new_metrics)

with open("dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/results/results.json", 'w') as f:
f.write(json.dumps(old_data, indent=2, ensure_ascii=False))
with open("dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/results/results.json", "w") as f:
f.write(json.dumps(old_data, indent=2, ensure_ascii=False))
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,9 @@
from chatsky_llm_autoconfig.autometrics.registry import AlgorithmRegistry


@AlgorithmRegistry.register(
input_type=BaseGraph,
output_type=list[Dialogue]
)
@AlgorithmRegistry.register(input_type=BaseGraph, output_type=list[Dialogue])
class DialogueSampler(DialogueGenerator):

def invoke(self, graph: BaseGraph, start_node: int = 1, end_node: int = -1, topic="") -> list[Dialogue]:
nx_graph = graph.graph
if end_node == -1:
Expand Down

0 comments on commit ad3a5e1

Please sign in to comment.