Skip to content

Commit

Permalink
updated the way we pass facts to reason again to be same as reasoning…
Browse files Browse the repository at this point in the history
… for the first time
  • Loading branch information
dyumanaditya committed Feb 27, 2025
1 parent 956ef89 commit 49aa83a
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 23 deletions.
31 changes: 11 additions & 20 deletions pyreason/pyreason.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ def add_annotation_function(function: Callable) -> None:
__annotation_functions.append(function)


def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bound_threshold: float = -1, queries: List[Query] = None, again: bool = False, facts: List[Fact] = None):
def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bound_threshold: float = -1, queries: List[Query] = None, again: bool = False):
"""Function to start the main reasoning process. Graph and rules must already be loaded.
:param timesteps: Max number of timesteps to run. -1 specifies run till convergence. If reasoning again, this is the number of timesteps to reason for extra (no zero timestep), defaults to -1
Expand Down Expand Up @@ -653,10 +653,10 @@ def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bou
else:
if settings.memory_profile:
start_mem = mp.memory_usage(max_usage=True)
mem_usage, interp = mp.memory_usage((_reason_again, [timesteps, convergence_threshold, convergence_bound_threshold, facts]), max_usage=True, retval=True)
mem_usage, interp = mp.memory_usage((_reason_again, [timesteps, convergence_threshold, convergence_bound_threshold]), max_usage=True, retval=True)
print(f"\nProgram used {mem_usage-start_mem} MB of memory")
else:
interp = _reason_again(timesteps, convergence_threshold, convergence_bound_threshold, facts)
interp = _reason_again(timesteps, convergence_threshold, convergence_bound_threshold)

return interp

Expand Down Expand Up @@ -742,34 +742,25 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold, queri
# Run Program and get final interpretation
interpretation = __program.reason(timesteps, convergence_threshold, convergence_bound_threshold, settings.verbose)

# Clear facts after reasoning, so that reasoning again is possible with any added facts
__node_facts = None
__edge_facts = None

return interpretation


def _reason_again(timesteps, convergence_threshold, convergence_bound_threshold, facts):
def _reason_again(timesteps, convergence_threshold, convergence_bound_threshold):
# Globals
global __graph, __rules, __node_facts, __edge_facts, __ipl, __specific_node_labels, __specific_edge_labels, __graphml_parser
global settings, __timestamp, __program

assert __program is not None, 'To run `reason_again` you need to have reasoned once before'

# Parse new facts and Extend current set of facts with the new facts supplied
# Extend facts
all_node_facts = numba.typed.List.empty_list(fact_node.fact_type)
all_edge_facts = numba.typed.List.empty_list(fact_edge.fact_type)
fact_cnt = 1
for fact in facts:
if fact.type == 'node':
print(fact.name)
if fact.name is None:
fact.name = f'fact_{len(__node_facts)+len(__edge_facts)+fact_cnt}'
f = fact_node.Fact(fact.name, fact.component, fact.pred, fact.bound, fact.start_time, fact.end_time, fact.static)
all_node_facts.append(f)
fact_cnt += 1
else:
if fact.name is None:
fact.name = f'fact_{len(__node_facts)+len(__edge_facts)+fact_cnt}'
f = fact_edge.Fact(fact.name, fact.component, fact.pred, fact.bound, fact.start_time, fact.end_time, fact.static)
all_edge_facts.append(f)
fact_cnt += 1
all_node_facts.extend(numba.typed.List(__node_facts))
all_edge_facts.extend(numba.typed.List(__edge_facts))

# Run Program and get final interpretation
interpretation = __program.reason_again(timesteps, convergence_threshold, convergence_bound_threshold, all_node_facts, all_edge_facts, settings.verbose)
Expand Down
18 changes: 18 additions & 0 deletions pyreason/scripts/interpretation/interpretation.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
timestep_loop = True
facts_to_be_applied_node_new = numba.typed.List.empty_list(facts_to_be_applied_node_type)
facts_to_be_applied_edge_new = numba.typed.List.empty_list(facts_to_be_applied_edge_type)
facts_to_be_applied_node_trace_new = numba.typed.List.empty_list(numba.types.string)
facts_to_be_applied_edge_trace_new = numba.typed.List.empty_list(numba.types.string)
rules_to_remove_idx = set()
rules_to_remove_idx.add(-1)
while timestep_loop:
Expand Down Expand Up @@ -260,6 +262,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
# Start by applying facts
# Nodes
facts_to_be_applied_node_new.clear()
facts_to_be_applied_node_trace_new.clear()
nodes_set = set(nodes)
for i in range(len(facts_to_be_applied_node)):
if facts_to_be_applied_node[i][0] == t:
Expand Down Expand Up @@ -317,17 +320,25 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi

if static:
facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, l, bnd, static, graph_attribute))
if atom_trace:
facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i])

# If time doesn't match, fact to be applied later
else:
facts_to_be_applied_node_new.append(facts_to_be_applied_node[i])
if atom_trace:
facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i])

# Update list of facts with ones that have not been applied yet (delete applied facts)
facts_to_be_applied_node[:] = facts_to_be_applied_node_new.copy()
if atom_trace:
facts_to_be_applied_node_trace[:] = facts_to_be_applied_node_trace_new.copy()
facts_to_be_applied_node_new.clear()
facts_to_be_applied_node_trace_new.clear()

# Edges
facts_to_be_applied_edge_new.clear()
facts_to_be_applied_edge_trace_new.clear()
edges_set = set(edges)
for i in range(len(facts_to_be_applied_edge)):
if facts_to_be_applied_edge[i][0]==t:
Expand Down Expand Up @@ -383,14 +394,21 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi

if static:
facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, l, bnd, static, graph_attribute))
if atom_trace:
facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i])

# Time doesn't match, fact to be applied later
else:
facts_to_be_applied_edge_new.append(facts_to_be_applied_edge[i])
if atom_trace:
facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i])

# Update list of facts with ones that have not been applied yet (delete applied facts)
facts_to_be_applied_edge[:] = facts_to_be_applied_edge_new.copy()
if atom_trace:
facts_to_be_applied_edge_trace[:] = facts_to_be_applied_edge_trace_new.copy()
facts_to_be_applied_edge_new.clear()
facts_to_be_applied_edge_trace_new.clear()

in_loop = True
while in_loop:
Expand Down
18 changes: 18 additions & 0 deletions pyreason/scripts/interpretation/interpretation_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
timestep_loop = True
facts_to_be_applied_node_new = numba.typed.List.empty_list(facts_to_be_applied_node_type)
facts_to_be_applied_edge_new = numba.typed.List.empty_list(facts_to_be_applied_edge_type)
facts_to_be_applied_node_trace_new = numba.typed.List.empty_list(numba.types.string)
facts_to_be_applied_edge_trace_new = numba.typed.List.empty_list(numba.types.string)
rules_to_remove_idx = set()
rules_to_remove_idx.add(-1)
while timestep_loop:
Expand Down Expand Up @@ -260,6 +262,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
# Start by applying facts
# Nodes
facts_to_be_applied_node_new.clear()
facts_to_be_applied_node_trace_new.clear()
nodes_set = set(nodes)
for i in range(len(facts_to_be_applied_node)):
if facts_to_be_applied_node[i][0] == t:
Expand Down Expand Up @@ -317,17 +320,25 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi

if static:
facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, l, bnd, static, graph_attribute))
if atom_trace:
facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i])

# If time doesn't match, fact to be applied later
else:
facts_to_be_applied_node_new.append(facts_to_be_applied_node[i])
if atom_trace:
facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i])

# Update list of facts with ones that have not been applied yet (delete applied facts)
facts_to_be_applied_node[:] = facts_to_be_applied_node_new.copy()
if atom_trace:
facts_to_be_applied_node_trace[:] = facts_to_be_applied_node_trace_new.copy()
facts_to_be_applied_node_new.clear()
facts_to_be_applied_node_trace_new.clear()

# Edges
facts_to_be_applied_edge_new.clear()
facts_to_be_applied_edge_trace_new.clear()
edges_set = set(edges)
for i in range(len(facts_to_be_applied_edge)):
if facts_to_be_applied_edge[i][0]==t:
Expand Down Expand Up @@ -383,14 +394,21 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi

if static:
facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, l, bnd, static, graph_attribute))
if atom_trace:
facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i])

# Time doesn't match, fact to be applied later
else:
facts_to_be_applied_edge_new.append(facts_to_be_applied_edge[i])
if atom_trace:
facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i])

# Update list of facts with ones that have not been applied yet (delete applied facts)
facts_to_be_applied_edge[:] = facts_to_be_applied_edge_new.copy()
if atom_trace:
facts_to_be_applied_edge_trace[:] = facts_to_be_applied_edge_trace_new.copy()
facts_to_be_applied_edge_new.clear()
facts_to_be_applied_edge_trace_new.clear()

in_loop = True
while in_loop:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

setup(
name='pyreason',
version='3.0.2',
version='3.0.3',
author='Dyuman Aditya',
author_email='[email protected]',
description='An explainable inference software supporting annotated, real valued, graph based and temporal logic',
Expand Down
4 changes: 2 additions & 2 deletions tests/test_reason_again.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def test_reason_again():

# Now reason again
new_fact = pr.Fact('popular(Mary)', 'popular_fact2', 2, 4)
interpretation = pr.reason(timesteps=3, again=True, facts=[new_fact])
pr.save_rule_trace(interpretation)
pr.add_fact(new_fact)
interpretation = pr.reason(timesteps=3, again=True)

# Display the changes in the interpretation for each timestep
dataframes = pr.filter_and_sort_nodes(interpretation, ['popular'])
Expand Down

0 comments on commit 49aa83a

Please sign in to comment.