From 8607165bf335ac37b822c2706f8439bcdd53f187 Mon Sep 17 00:00:00 2001 From: "longbin.lailb" Date: Mon, 25 Nov 2024 19:56:39 +0800 Subject: [PATCH] refactor tests --- python/graphy/extractor/paper_extractor.py | 2 +- python/graphy/graph/nodes/paper_reading_nodes.py | 5 +++++ python/graphy/memory/memory_block.py | 2 +- python/graphy/tests/workflow/paper_inspector_test.py | 6 ++++++ python/graphy/utils/arxiv_fetcher.py | 4 ++-- 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/python/graphy/extractor/paper_extractor.py b/python/graphy/extractor/paper_extractor.py index f1d7cc83b..403c126aa 100644 --- a/python/graphy/extractor/paper_extractor.py +++ b/python/graphy/extractor/paper_extractor.py @@ -1168,7 +1168,7 @@ def _get_section_title_info(self): cur_page_num = int(item["page_num"]) if cur_page_num >= last_page_num and int(cur_sec) <= last_sec: - logger.warn(f"SECTION NUMBER CONFLICTS") + logger.warning(f"SECTION NUMBER CONFLICTS") continue section_order.append({"sec_name": "", "section_id": cur_sec}) diff --git a/python/graphy/graph/nodes/paper_reading_nodes.py b/python/graphy/graph/nodes/paper_reading_nodes.py index a55ec0e91..dbae71c0b 100644 --- a/python/graphy/graph/nodes/paper_reading_nodes.py +++ b/python/graphy/graph/nodes/paper_reading_nodes.py @@ -107,6 +107,11 @@ def __str__(self): def __repr__(self) -> str: return f"ProgressInfo [ Number: {self.number}, Completed: {self.completed} ]" + def __eq__(self, other): + if isinstance(other, ProgressInfo): + return self.completed == other.completed and self.number == other.number + return False + class ExtractNode(BaseChainNode): def __init__( diff --git a/python/graphy/memory/memory_block.py b/python/graphy/memory/memory_block.py index 6fc30fe83..97fd86226 100644 --- a/python/graphy/memory/memory_block.py +++ b/python/graphy/memory/memory_block.py @@ -216,7 +216,7 @@ def __init__(self) -> None: def add_block(self, block: MemoryBlock): block_id = block.get_block_id() if block_id in self.blocks: - logger.warn(f"Block with id {block_id} already exists in the manager.") + logger.warning(f"Block with id {block_id} already exists in the manager.") else: self.blocks[block_id] = block return block_id diff --git a/python/graphy/tests/workflow/paper_inspector_test.py b/python/graphy/tests/workflow/paper_inspector_test.py index c0f3dad5a..19097b590 100644 --- a/python/graphy/tests/workflow/paper_inspector_test.py +++ b/python/graphy/tests/workflow/paper_inspector_test.py @@ -196,5 +196,11 @@ def test_inspector_execute(): assert persist_store.get_state(data_id, "Paper") assert persist_store.get_state(data_id, "Contribution") assert persist_store.get_state(data_id, "Challenge") + assert persist_store.get_state(data_id, "_DONE") + + assert inspector.progress["Paper"] == ProgressInfo(completed=1, number=1) + assert inspector.progress["Contribution"] == ProgressInfo(completed=1, number=1) + assert inspector.progress["Challenge"] == ProgressInfo(completed=1, number=1) + assert inspector.progress["total"] == ProgressInfo(completed=3, number=3) temp_dir.cleanup() diff --git a/python/graphy/utils/arxiv_fetcher.py b/python/graphy/utils/arxiv_fetcher.py index 9b66ac1cd..03d36006d 100644 --- a/python/graphy/utils/arxiv_fetcher.py +++ b/python/graphy/utils/arxiv_fetcher.py @@ -130,7 +130,7 @@ def find_paper_from_arxiv(self, name, max_results): if highest_similarity > 0.9 or found_result: break - logger.warn(f"Not Found: {query}") + logger.warning(f"Not Found: {query}") if highest_similarity > 0.9: break @@ -164,7 +164,7 @@ def download_paper(self, name: str, max_results): return download_list else: - logger.warn(f"Failed to fetch paper with arxiv: {name}") + logger.warning(f"Failed to fetch paper with arxiv: {name}") return [] def fetch_paper(self, name: str, max_results):