Skip to content

Commit

Permalink
add paper_inspector_test
Browse files Browse the repository at this point in the history
  • Loading branch information
longbinlai committed Nov 18, 2024
1 parent 7805b5a commit bfa9bb5
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 42 deletions.
4 changes: 2 additions & 2 deletions python/graphy/graph/nodes/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def __init__(self, name: str, node_type: NodeType = NodeType.BASE):

def pre_execute(self, state: Dict[str, Any] = None):
"""define pre-execution logic"""
logger.info(f"Executing node: {self.get_node_key()}")
logger.info(f"Executing node: {self.name}")

def post_execute(self, output: Dict[str, Any] = None):
"""define post-execution logic"""
logger.info(f"Complete executing node: {self.get_node_key()}")
logger.info(f"Complete executing node: {self.name}")

def execute(
self, state: Dict[str, Any], input: DataGenerator = None
Expand Down
46 changes: 18 additions & 28 deletions python/graphy/graph/nodes/paper_reading_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,69 +378,59 @@ def run_through(
DataGenerator: Outputs generated by the workflow nodes.
"""

current_node_name = self.graph.get_first_node_name()
next_nodes = [current_node_name]
next_nodes = [self.graph.get_first_node()]

while next_nodes:
current_node_name = next_nodes.pop() # Run in DFS order
if current_node_name in skipped_nodes:
self.progress[current_node_name].complete()
current_node = next_nodes.pop() # Run in DFS order
if current_node.name in skipped_nodes:
self.progress[current_node.name].complete()
self.progress["total"].complete()
continue

curr_node = self.graph.get_node(current_node_name)
if not curr_node:
logger.error(f"Node '{current_node_name}' not found in the graph.")
continue

last_output = None
try:
curr_node.pre_execute(state)
current_node.pre_execute(state)
# Execute the current node
output_generator = curr_node.execute(state)
output_generator = current_node.execute(state)
for output in output_generator:
last_output = output
except Exception as e:
logger.error(f"Error executing node '{current_node_name}': {e}")
logger.error(f"Error executing node '{current_node.name}': {e}")
if continue_on_error:
continue
else:
raise ValueError(f"Error executing node '{current_node_name}': {e}")
raise ValueError(f"Error executing node '{current_node.name}': {e}")
finally:
# Complete progress tracking
self.progress[current_node_name].complete()
self.progress[current_node.name].complete()
self.progress["total"].complete()

# Persist the output and queries if applicable
if last_output and is_persist and self.persist_store:
self.persist_store.save_state(
data_id, current_node_name, last_output
data_id, current_node.name, last_output
)
if curr_node.get_query():
input_query = f"""
**************QUERY***************:
{curr_node.get_query()}
**************MEMORY**************:
{curr_node.get_memory()}
"""
if current_node.get_query():
input_query = f"**************QUERY***************: \n {current_node.get_query()} \
**************MEMORY**************: \n {current_node.get_memory()}"
self.persist_store.save_query(
data_id, current_node_name, input_query
data_id, current_node.name, input_query
)

# Cache the output
if last_output:
node_caches: dict = state.get(WF_STATE_CACHE_KEY, {})
node_cache: NodeCache = node_caches.setdefault(
current_node_name, NodeCache(current_node_name)
current_node.name, NodeCache(current_node.name)
)
if node_cache:
node_cache.add_chat_cache("", last_output)

curr_node.post_execute(last_output)
current_node.post_execute(last_output)

# Add adjacent nodes to the processing queue
for next_node in reversed(self.graph.get_adjacent_nodes(current_node_name)):
next_nodes.append(next_node.name)
for next_node in reversed(self.graph.get_adjacent_nodes(current_node.name)):
next_nodes.append(next_node)

@profiler.profile
def execute(
Expand Down
105 changes: 101 additions & 4 deletions python/graphy/tests/workflow/paper_inspector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,25 @@
from unittest.mock import MagicMock, create_autospec
from graph import BaseGraph, BaseEdge
from graph.nodes import PaperInspector, BaseNode
from graph.nodes.paper_reading_nodes import ProgressInfo, ExtractNode
from graph.nodes.paper_reading_nodes import (
ProgressInfo,
ExtractNode,
create_inspector_graph,
)
from graph.types import DataGenerator
from models import LLM
from db import PersistentStore
from config import WF_STATE_CACHE_KEY, WF_STATE_MEMORY_KEY, WF_STATE_EXTRACTOR_KEY
from models import LLM, set_llm_model, DEFAULT_LLM_MODEL_CONFIG, DefaultEmbedding
from db import PersistentStore, JsonFileStore
from config import (
WF_STATE_CACHE_KEY,
WF_STATE_MEMORY_KEY,
WF_STATE_EXTRACTOR_KEY,
WF_OUTPUT_DIR,
)

from langchain_core.embeddings import Embeddings

import os


@pytest.fixture
def mock_graph():
Expand Down Expand Up @@ -87,3 +98,89 @@ def test_execute(mock_paper_inspector):

assert len(outputs) == 1
assert outputs[0] == {"result": "node_output"}


@pytest.mark.skip(reason="The LLM model must be set to run this.")
def test_inspector_execute():
    """End-to-end smoke test: run PaperInspector over sample PDFs.

    Builds a three-node paper-reading workflow (Paper -> Contribution ->
    Challenge), wires it into a ``PaperInspector`` backed by a real LLM and
    embedding model, and streams results for two sample papers. Skipped by
    default because it requires a configured LLM model (and the sample PDF
    inputs to exist on disk — TODO confirm paths relative to the test cwd).
    """
    # Real (non-mock) models: both the LLM and the embedding model are used
    # when constructing the inspector graph below.
    llm_model = set_llm_model(DEFAULT_LLM_MODEL_CONFIG)
    embeddings_model = DefaultEmbedding()

    # Declarative workflow definition consumed by create_inspector_graph().
    # Each non-root node carries a query prompt plus an output_schema that
    # constrains the structure of the LLM's extraction result.
    workflow = {
        "nodes": [
            # Root node: the paper itself (no query/schema).
            {"name": "Paper"},
            {
                "name": "Contribution",
                "query": "**Question**:\nList all contributions of the paper...",
                # presumably a list of section/source ids to extract from —
                # verify against ExtractNode's handling of extract_from
                "extract_from": ["1"],
                "output_schema": {
                    "type": "array",
                    "description": "A list of contributions.",
                    "item": [
                        {
                            "name": "original",
                            "type": "string",
                            "description": "The original contribution sentences.",
                        },
                        {
                            "name": "summary",
                            "type": "string",
                            "description": "The summary of the contribution.",
                        },
                    ],
                },
            },
            {
                "name": "Challenge",
                "query": "**Question**:\nPlease summarize some challenges in this paper...",
                # Empty extract_from: no targeted extraction for this node.
                "extract_from": [],
                "output_schema": {
                    "type": "array",
                    "description": "A list of challenges...",
                    "item": [
                        {
                            "name": "name",
                            "type": "string",
                            "description": "The summarized name of the challenge.",
                        },
                        {
                            "name": "description",
                            "type": "string",
                            "description": "The description of the challenge.",
                        },
                        {
                            "name": "solution",
                            "type": "string",
                            "description": "The solution of the challenge.",
                        },
                    ],
                },
            },
        ],
        # Linear DAG: Paper -> Contribution -> Challenge.
        "edges": [
            {"source": "Paper", "target": "Contribution"},
            {"source": "Contribution", "target": "Challenge"},
        ],
    }

    # NOTE(review): llm_model is passed twice — looks like separate "parser"
    # and "summarizer" LLM slots sharing one model here; confirm against
    # create_inspector_graph's signature.
    graph = create_inspector_graph(
        workflow, llm_model, llm_model, embeddings_model.chroma_embedding_model()
    )

    # Persist intermediate node outputs/queries as JSON under WF_OUTPUT_DIR.
    persist_store = JsonFileStore(WF_OUTPUT_DIR)

    inspector = PaperInspector(
        "PaperInspector",
        llm_model,
        embeddings_model.chroma_embedding_model(),
        graph,
        persist_store,
    )

    # Shared mutable workflow state; one input dict per paper to inspect.
    state = {}
    inputs = [
        {"paper_file_path": "inputs/samples/graphrag.pdf"},
        {"paper_file_path": "inputs/samples/huge-sigmod21.pdf"},
    ]

    # execute() yields one output per processed paper; print for manual
    # inspection (this skipped test has no assertions).
    for output in inspector.execute(state, iter(inputs)):
        print(output)
2 changes: 0 additions & 2 deletions python/graphy/utils/arxiv_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,6 @@ def fetch_papers_concurrently(
# Append file names to the list
filenames.append(file)

print(filenames)

for file_name in filenames:
download_foler = os.path.join(f"{WF_DOWNLOADS_DIR}", file_name.split(".")[0])
fetcher = ArxivFetcher(download_folder=download_foler)
Expand Down
5 changes: 0 additions & 5 deletions python/graphy/utils/scholar_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def fetch_paper(self, name: str, mode="vague"):
logger.error(f"Error searching google scholar: {e}")
paper_info = None

print(paper_info)
# logger.debug(paper_info)
if paper_info is None:
return None, None
Expand All @@ -58,7 +57,3 @@ def fetch_paper(self, name: str, mode="vague"):
paper_bib = paper_info.get("bib", None)

return fetch_result, paper_bib


if __name__ == "__main__":
pass
2 changes: 1 addition & 1 deletion python/graphy/workflow/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def execute(self, initial_inputs: List[DataType]) -> Dict[str, Any]:
state = self.workflow.state

# Add all initial inputs to the task queue with the first node
first_node = self.workflow.graph.get_node_names()[0]
first_node = self.workflow.graph.get_first_node_name()
for input_data in initial_inputs:
self.task_queue.put((input_data, first_node))

Expand Down

0 comments on commit bfa9bb5

Please sign in to comment.