From bfa9bb540edeed5c76e543af98ce6c60c8f0b090 Mon Sep 17 00:00:00 2001
From: "longbin.lailb"
Date: Tue, 19 Nov 2024 02:07:49 +0800
Subject: [PATCH] add paper_inspector_test

---
 python/graphy/graph/nodes/base_node.py        |   4 +-
 .../graphy/graph/nodes/paper_reading_nodes.py |  46 +++-----
 .../tests/workflow/paper_inspector_test.py    | 105 +++++++++++++++++-
 python/graphy/utils/arxiv_fetcher.py          |   2 -
 python/graphy/utils/scholar_fetcher.py        |   5 -
 python/graphy/workflow/executor.py            |   2 +-
 6 files changed, 122 insertions(+), 42 deletions(-)

diff --git a/python/graphy/graph/nodes/base_node.py b/python/graphy/graph/nodes/base_node.py
index 57ce9acdd..0c95815fe 100644
--- a/python/graphy/graph/nodes/base_node.py
+++ b/python/graphy/graph/nodes/base_node.py
@@ -33,11 +33,11 @@ def __init__(self, name: str, node_type: NodeType = NodeType.BASE):
 
     def pre_execute(self, state: Dict[str, Any] = None):
         """define pre-execution logic"""
-        logger.info(f"Executing node: {self.get_node_key()}")
+        logger.info(f"Executing node: {self.name}")
 
     def post_execute(self, output: Dict[str, Any] = None):
         """define post-execution logic"""
-        logger.info(f"Complete executing node: {self.get_node_key()}")
+        logger.info(f"Complete executing node: {self.name}")
 
     def execute(
         self, state: Dict[str, Any], input: DataGenerator = None
diff --git a/python/graphy/graph/nodes/paper_reading_nodes.py b/python/graphy/graph/nodes/paper_reading_nodes.py
index f944f7b89..67bae1bd4 100644
--- a/python/graphy/graph/nodes/paper_reading_nodes.py
+++ b/python/graphy/graph/nodes/paper_reading_nodes.py
@@ -378,69 +378,59 @@ def run_through(
             DataGenerator: Outputs generated by the workflow nodes.
         """
 
-        current_node_name = self.graph.get_first_node_name()
-        next_nodes = [current_node_name]
+        next_nodes = [self.graph.get_first_node()]
         while next_nodes:
-            current_node_name = next_nodes.pop()  # Run in DFS order
-            if current_node_name in skipped_nodes:
-                self.progress[current_node_name].complete()
+            current_node = next_nodes.pop()  # Run in DFS order
+            if current_node.name in skipped_nodes:
+                self.progress[current_node.name].complete()
                 self.progress["total"].complete()
                 continue
 
-            curr_node = self.graph.get_node(current_node_name)
-            if not curr_node:
-                logger.error(f"Node '{current_node_name}' not found in the graph.")
-                continue
-
             last_output = None
             try:
-                curr_node.pre_execute(state)
+                current_node.pre_execute(state)
                 # Execute the current node
-                output_generator = curr_node.execute(state)
+                output_generator = current_node.execute(state)
                 for output in output_generator:
                     last_output = output
             except Exception as e:
-                logger.error(f"Error executing node '{current_node_name}': {e}")
+                logger.error(f"Error executing node '{current_node.name}': {e}")
                 if continue_on_error:
                     continue
                 else:
-                    raise ValueError(f"Error executing node '{current_node_name}': {e}")
+                    raise ValueError(f"Error executing node '{current_node.name}': {e}")
             finally:
                 # Complete progress tracking
-                self.progress[current_node_name].complete()
+                self.progress[current_node.name].complete()
                 self.progress["total"].complete()
 
                 # Persist the output and queries if applicable
                 if last_output and is_persist and self.persist_store:
                     self.persist_store.save_state(
-                        data_id, current_node_name, last_output
+                        data_id, current_node.name, last_output
                     )
-                    if curr_node.get_query():
-                        input_query = f"""
-                        **************QUERY***************:
-                        {curr_node.get_query()}
-                        **************MEMORY**************:
-                        {curr_node.get_memory()}
-                        """
+                    if current_node.get_query():
+                        input_query = (
+                            f"**************QUERY***************:\n{current_node.get_query()}\n"
+                            f"**************MEMORY**************:\n{current_node.get_memory()}"
+                        )
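+                        # Persist the node's query together with its memory context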
                         self.persist_store.save_query(
-                            data_id, current_node_name, input_query
+                            data_id, current_node.name, input_query
                         )
 
                 # Cache the output
                 if last_output:
                     node_caches: dict = state.get(WF_STATE_CACHE_KEY, {})
                     node_cache: NodeCache = node_caches.setdefault(
-                        current_node_name, NodeCache(current_node_name)
+                        current_node.name, NodeCache(current_node.name)
                     )
                     if node_cache:
                         node_cache.add_chat_cache("", last_output)
-            curr_node.post_execute(last_output)
+            current_node.post_execute(last_output)
 
             # Add adjacent nodes to the processing queue
-            for next_node in reversed(self.graph.get_adjacent_nodes(current_node_name)):
-                next_nodes.append(next_node.name)
+            for next_node in reversed(self.graph.get_adjacent_nodes(current_node.name)):
+                next_nodes.append(next_node)
 
     @profiler.profile
     def execute(
diff --git a/python/graphy/tests/workflow/paper_inspector_test.py b/python/graphy/tests/workflow/paper_inspector_test.py
index 2369d35da..b53e2ac80 100644
--- a/python/graphy/tests/workflow/paper_inspector_test.py
+++ b/python/graphy/tests/workflow/paper_inspector_test.py
@@ -6,14 +6,25 @@
 from unittest.mock import MagicMock, create_autospec
 from graph import BaseGraph, BaseEdge
 from graph.nodes import PaperInspector, BaseNode
-from graph.nodes.paper_reading_nodes import ProgressInfo, ExtractNode
+from graph.nodes.paper_reading_nodes import (
+    ProgressInfo,
+    ExtractNode,
+    create_inspector_graph,
+)
 from graph.types import DataGenerator
-from models import LLM
-from db import PersistentStore
-from config import WF_STATE_CACHE_KEY, WF_STATE_MEMORY_KEY, WF_STATE_EXTRACTOR_KEY
+from models import LLM, set_llm_model, DEFAULT_LLM_MODEL_CONFIG, DefaultEmbedding
+from db import PersistentStore, JsonFileStore
+from config import (
+    WF_STATE_CACHE_KEY,
+    WF_STATE_MEMORY_KEY,
+    WF_STATE_EXTRACTOR_KEY,
+    WF_OUTPUT_DIR,
+)
 from langchain_core.embeddings import Embeddings
 
+import os
+
 
 @pytest.fixture
 def mock_graph():
@@ -87,3 +98,89 @@ def test_execute(mock_paper_inspector):
 
     assert len(outputs) == 1
     assert outputs[0] == {"result": "node_output"}
+
+
+@pytest.mark.skip(reason="Requires a configured LLM model to run.")
+def test_inspector_execute():
+    llm_model = set_llm_model(DEFAULT_LLM_MODEL_CONFIG)
+    embeddings_model = DefaultEmbedding()
+
+    workflow = {
+        "nodes": [
+            {"name": "Paper"},
+            {
+                "name": "Contribution",
+                "query": "**Question**:\nList all contributions of the paper...",
+                "extract_from": ["1"],
+                "output_schema": {
+                    "type": "array",
+                    "description": "A list of contributions.",
+                    "item": [
+                        {
+                            "name": "original",
+                            "type": "string",
+                            "description": "The original contribution sentences.",
+                        },
+                        {
+                            "name": "summary",
+                            "type": "string",
+                            "description": "The summary of the contribution.",
+                        },
+                    ],
+                },
+            },
+            {
+                "name": "Challenge",
+                "query": "**Question**:\nPlease summarize some challenges in this paper...",
+                "extract_from": [],
+                "output_schema": {
+                    "type": "array",
+                    "description": "A list of challenges...",
+                    "item": [
+                        {
+                            "name": "name",
+                            "type": "string",
+                            "description": "The summarized name of the challenge.",
+                        },
+                        {
+                            "name": "description",
+                            "type": "string",
+                            "description": "The description of the challenge.",
+                        },
+                        {
+                            "name": "solution",
+                            "type": "string",
+                            "description": "The solution of the challenge.",
+                        },
+                    ],
+                },
+            },
+        ],
+        "edges": [
+            {"source": "Paper", "target": "Contribution"},
+            {"source": "Contribution", "target": "Challenge"},
+        ],
+    }
+
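+    # Build the inspector graph from the workflow definition above; the same
+    # llm_model instance is passed for both LLM arguments of create_inspector_graph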
+    graph = create_inspector_graph(
+        workflow, llm_model, llm_model, embeddings_model.chroma_embedding_model()
+    )
+
+    persist_store = JsonFileStore(WF_OUTPUT_DIR)
+
+    inspector = PaperInspector(
+        "PaperInspector",
+        llm_model,
+        embeddings_model.chroma_embedding_model(),
+        graph,
+        persist_store,
+    )
+
+    state = {}
+    inputs = [
+        {"paper_file_path": "inputs/samples/graphrag.pdf"},
+        {"paper_file_path": "inputs/samples/huge-sigmod21.pdf"},
+    ]
+
+    for output in inspector.execute(state, iter(inputs)):
+        print(output)
diff --git a/python/graphy/utils/arxiv_fetcher.py b/python/graphy/utils/arxiv_fetcher.py
index 5f52830b8..6d4d91800 100644
--- a/python/graphy/utils/arxiv_fetcher.py
+++ b/python/graphy/utils/arxiv_fetcher.py
@@ -239,8 +239,6 @@ def fetch_papers_concurrently(
             # Append file names to the list
             filenames.append(file)
 
-    print(filenames)
-
     for file_name in filenames:
         download_foler = os.path.join(f"{WF_DOWNLOADS_DIR}", file_name.split(".")[0])
         fetcher = ArxivFetcher(download_folder=download_foler)
diff --git a/python/graphy/utils/scholar_fetcher.py b/python/graphy/utils/scholar_fetcher.py
index c2982fdd1..a0451850e 100644
--- a/python/graphy/utils/scholar_fetcher.py
+++ b/python/graphy/utils/scholar_fetcher.py
@@ -49,7 +49,6 @@ def fetch_paper(self, name: str, mode="vague"):
             logger.error(f"Error searching google scholar: {e}")
             paper_info = None
 
-        print(paper_info)
         # logger.debug(paper_info)
         if paper_info is None:
             return None, None
@@ -58,7 +57,3 @@
         paper_bib = paper_info.get("bib", None)
 
         return fetch_result, paper_bib
-
-
-if __name__ == "__main__":
-    pass
diff --git a/python/graphy/workflow/executor.py b/python/graphy/workflow/executor.py
index cd6424aeb..c4d02bd5d 100644
--- a/python/graphy/workflow/executor.py
+++ b/python/graphy/workflow/executor.py
@@ -77,7 +77,7 @@ def execute(self, initial_inputs: List[DataType]) -> Dict[str, Any]:
         state = self.workflow.state
 
         # Add all initial inputs to the task queue with the first node
-        first_node = self.workflow.graph.get_node_names()[0]
+        first_node = self.workflow.graph.get_first_node_name()
         for input_data in initial_inputs:
             self.task_queue.put((input_data, first_node))