refactor: check persist store
longbinlai committed Nov 19, 2024
1 parent 7a43799 commit 96ddd9b
Showing 3 changed files with 23 additions and 9 deletions.
2 changes: 1 addition & 1 deletion python/graphy/db/base_store.py
@@ -108,7 +108,7 @@ def get_total_states(self, name: str) -> List[str]:
             json_items = [
                 os.path.splitext(item)[0]  # Get the file name without extension
                 for item in items
-                if item.endswith(".json") and not item.startswith(".")
+                if item.endswith(".json") and not item.startswith("_")
             ]
             return json_items
         except Exception as e:
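For context, the filter change above matters because this commit saves internal markers ("_Edges", and now "_DONE") alongside node states; prefixing them with "_" and skipping them here keeps get_total_states() counting only real node outputs. A minimal sketch of the new behavior, assuming states are stored as JSON files in a per-paper directory (list_node_states is an illustrative name, not the project's API):

import os

def list_node_states(state_dir: str) -> list:
    # Skip internal "_"-prefixed markers such as "_Edges.json" and "_DONE.json".
    items = os.listdir(state_dir)
    return [
        os.path.splitext(item)[0]  # file name without extension
        for item in items
        if item.endswith(".json") and not item.startswith("_")
    ]

Given ["Paper.json", "_Edges.json", "_DONE.json"], this returns ["Paper"], so the completion check later in this commit compares only genuine node outputs against the node count.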
29 changes: 22 additions & 7 deletions python/graphy/graph/nodes/paper_reading_nodes.py
@@ -364,7 +364,6 @@ def run_through(
         state,
         parent_id=None,
         continue_on_error: bool = True,
-        is_persist: bool = True,
         skipped_nodes: List[str] = [],
     ):
         """
@@ -381,6 +380,7 @@

         first_node = self.graph.get_first_node()
         next_nodes = [first_node]
+        is_persist = True

         while next_nodes:
             current_node = next_nodes.pop()  # Run in DFS order
@@ -393,10 +393,19 @@
             last_output = None
             try:
                 current_node.pre_execute(state)
-                # Execute the current node
-                output_generator = current_node.execute(state)
-                for output in output_generator:
-                    last_output = output
+                persist_results = self.persist_store.get_state(
+                    data_id, current_node.name
+                )
+                if persist_results:
+                    logger.info(f"Found persisted data for node '{current_node.name}'")
+                    last_output = persist_results
+                    is_persist = False
+                else:
+                    # Execute the current node
+                    output_generator = current_node.execute(state)
+                    for output in output_generator:
+                        last_output = output
+                    is_persist = True
             except Exception as e:
                 logger.error(f"Error executing node '{current_node.name}': {e}")
                 if continue_on_error:
@@ -428,7 +437,7 @@ def run_through(
                     edges.append(f"{parent_id}|{curr_id}")
                 else:
                     edges = [f"{parent_id}|{curr_id}"]
-                    self.persist_store.save_state(data_id, "_Edges", edges)
+                self.persist_store.save_state(data_id, "_Edges", edges)

             # Cache the output
             if last_output:
@@ -478,7 +487,7 @@ def execute(
         data_id = process_id(base_name)
         pdf_extractor.set_img_path(f"{WF_IMAGE_DIR}/{data_id}")
         first_node_name = self.graph.get_first_node_name()
-        if data_id in state["processed_data"]:
+        if self.persist_store.get_state(data_id, "_DONE"):
             # This means that the data has already been processed
             logger.info(f"Input with ID '{data_id}' already processed.")
             yield self.persist_store.get_state(data_id, first_node_name)
@@ -505,6 +514,12 @@

         self.progress["total"].add(ProgressInfo(self.graph.nodes_count(), 0))
         self.run_through(data_id, state[data_id], parent_id)
+        # Mark the data as DONE
+        if (
+            len(self.persist_store.get_total_states(data_id))
+            == self.graph.nodes_count()
+        ):
+            self.persist_store.save_state(data_id, "_DONE", {"done": True})

         yield state[data_id][WF_STATE_CACHE_KEY][first_node_name].get_response()
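The heart of the refactor is the check-then-execute pattern in run_through: before running a node, look up its persisted output and execute only on a miss, tracking per node whether the result still needs saving. A standalone sketch of that pattern, assuming a persist store with the get_state(data_id, name) API shown in the diff (run_node is an illustrative helper, not part of the codebase):

def run_node(persist_store, data_id, node, state):
    # Returns (output, is_persist): reuse a persisted result when present,
    # otherwise execute the node and flag its output for saving.
    cached = persist_store.get_state(data_id, node.name)
    if cached:
        return cached, False  # already on disk; no need to re-save
    last_output = None
    for output in node.execute(state):  # execute() yields outputs
        last_output = output
    return last_output, True  # freshly computed; save downstream

Deriving is_persist per node this way is also why the parameter could be dropped from run_through's signature.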
1 change: 0 additions & 1 deletion python/graphy/workflow/survey_paper_reading.py
@@ -55,7 +55,6 @@ def __init__(
             graph,
             persist_store,
         )
-        self.state["processed_data"] = set(persist_store.get_total_data())

     def _create_graph(
         self, workflow_dict, llm_model, parser_model, embeddings_model, persist_store
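Taken together with the deletion above, completion tracking moves from an in-memory processed_data set (rebuilt on startup) to an on-disk "_DONE" marker, so fully processed papers are skipped even across restarts. An illustrative sketch of the two halves, assuming the persist-store API used in the diffs:

def mark_done_if_complete(persist_store, data_id, nodes_count):
    # get_total_states() ignores "_"-prefixed files, so this compares
    # only real node outputs against the number of workflow nodes.
    if len(persist_store.get_total_states(data_id)) == nodes_count:
        persist_store.save_state(data_id, "_DONE", {"done": True})

def already_done(persist_store, data_id):
    # Mirrors the check at the top of execute() that skips finished inputs.
    return bool(persist_store.get_state(data_id, "_DONE"))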
