Skip to content

Commit

Permalink
some refactoring to fulfill PR remarks
Browse files Browse the repository at this point in the history
  • Loading branch information
ebehner committed Apr 5, 2024
1 parent 493f73c commit 665bb7c
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,54 +4,59 @@

from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from itertools import combinations
from typing import DefaultDict, Dict, Iterator, List, Optional, Set, Tuple, Union
from itertools import combinations, chain
from typing import Dict, Iterator, List, Optional, Set, Tuple

from networkx import DiGraph, has_path, subgraph_view

from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, SeqNode
from decompiler.structures.ast.reachability_graph import SiblingReachability
from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest
from decompiler.structures.logic.logic_condition import LogicCondition
from networkx import DiGraph, has_path, subgraph_view

from decompiler.util.insertion_ordered_set import InsertionOrderedSet


@dataclass(frozen=True, eq=False)
class Formula:
    """
    Dataclass for logic-formulas.

    - frozen=True makes instances immutable after construction.
    - eq=False means no __eq__/__hash__ are generated, so two formulas are equal and hash the same
      iff they are the same object. An explicit id-based __hash__ is therefore not needed.
    """

    # The logic-condition represented by this formula.
    condition: LogicCondition
    # The ast-node this formula is attached to.
    ast_node: AbstractSyntaxTreeNode

    def clauses(self) -> List[LogicCondition]:
        """
        Return the clauses of the formula, which is assumed to be in cnf-form.

        - formula = (a | b) & (c | d) & e returns [a | b, c | d, e]; the operands are returned as they are.
        - formula = a | b | c returns [a | b | c]; the condition is copied so that the caller gets a new object.
        """
        if self.condition.is_conjunction:
            return list(self.condition.operands)
        return [self.condition.copy()]


@dataclass(frozen=True, eq=False)
class Clause:
    """
    Dataclass for logic-clauses.

    - frozen=True makes instances immutable after construction.
    - eq=False means no __eq__/__hash__ are generated, so two clauses are equal and hash the same
      iff they are the same object. An explicit id-based __hash__ is therefore not needed.
    """

    # The logic-condition of this single clause.
    condition: LogicCondition
    # The formula this clause was created for.
    formula: Formula


@dataclass(frozen=True, eq=True)
class Symbol:
    """
    Dataclass for logic-symbols.

    - frozen=True together with eq=True lets the dataclass generate __eq__ and __hash__ from the fields,
      so two symbols are equal and hash the same iff they have the same name.
    """

    # Name of the symbol; the only attribute that takes part in equality and hashing.
    name: str


class ConditionCandidates:
"""A graph implementation handling conditions for the condition-based refinement algorithm."""
Expand All @@ -77,16 +82,13 @@ def _initialize_logic_graph(self):
all_symbols = set()
for formula in self._candidates.values():
self._logic_graph.add_node(formula)
formula_clauses: List[LogicCondition] = (
list(formula.condition.operands) if formula.condition.is_conjunction else [formula.condition.copy()]
)
for logic_clause in formula_clauses:
for logic_clause in formula.clauses():
self._logic_graph.add_edge(formula, clause := Clause(logic_clause, formula))
for symbol_name in logic_clause.get_symbols_as_string():
self._logic_graph.add_edge(clause, symbol := Symbol(symbol_name))
self._logic_graph.add_edge(formula, symbol, auxiliary=True)
all_symbols.add(symbol)
self._remove_nodes_from(set(symbol for symbol in all_symbols if self._symbol_only_in_one_clause(symbol)))
self._remove_symbols(set(symbol for symbol in all_symbols if self._symbol_only_in_one_formula(symbol)))

@property
def candidates(self) -> Iterator[AbstractSyntaxTreeNode]:
Expand Down Expand Up @@ -118,66 +120,11 @@ def get_next_subexpression(self) -> Iterator[Tuple[AbstractSyntaxTreeNode, Logic
yield ast_node, LogicCondition.conjunction_of(new_operands)
current_size -= 1

def remove_nodes(self, nodes_to_remove: List[AbstractSyntaxTreeNode]):
def remove_ast_nodes(self, nodes_to_remove: List[AbstractSyntaxTreeNode]):
"""Remove the given nodes from the graph."""
self._remove_nodes_from(set(self._candidates[node] for node in nodes_to_remove))

def _get_clauses(self, node: AbstractSyntaxTreeNode) -> List[LogicCondition]:
"""Return all clauses that are contained in the formula of the given ast-node."""
return [clause.condition for clause in self._formula_graph.successors(self._candidates[node])]

def _remove_nodes_from(self, removable_nodes: Set[Union[Formula, Clause, Symbol]]):
"""
Remove all nodes from the given set of nodes from the logic-graph and also all nodes that have to be removed after removeing these.
"""
while removable_nodes:
node = removable_nodes.pop()
match node:
case Formula():
removable_nodes.update(self._remove_formula(node))
case Clause():
removable_nodes.update(self._remove_clause(node))
case Symbol():
removable_nodes.update(self._remove_symbol(node))

def _remove_formula(self, formula: Formula) -> Iterator[Union[Clause, Symbol]]:
"""Remove the given formula from the graph."""
yield from self._formula_graph.successors(formula)
symbols_of_formula = list(self._auxiliary_graph.successors(formula))
self._logic_graph.remove_node(formula)
for symbol in symbols_of_formula:
yield from self._remove_symbol_from_formula(symbol)
del self._candidates[formula.ast_node]
self._unconsidered_nodes.discard(formula.ast_node)

def _remove_clause(self, clause: Clause) -> Iterator[Union[Formula, Symbol]]:
"""Remove the given clause from the graph."""
if clause.formula in self._logic_graph:
if self._formula_graph.out_degree(clause.formula) == 1:
yield clause.formula
else:
for symbol in (s for s in self._logic_graph.successors(clause) if not has_path(self._formula_graph, clause.formula, s)):
self._logic_graph.remove_edge(clause.formula, symbol)
yield from self._remove_symbol_from_formula(symbol)
self._logic_graph.remove_node(clause)

def _remove_symbol(self, symbol: Symbol) -> Iterator[Clause]:
"""Remove the given symbol from the graph."""
yield from self._formula_graph.predecessors(symbol)
# for formula in self._auxiliary_graph.predecessors(symbol):
# self._symbols_of_formula[formula].remove(symbol)
self._logic_graph.remove_node(symbol)

def _remove_symbol_from_formula(self, symbol: Symbol) -> Iterator[Symbol]:
"""
Update the dictionaries and decides whether we also remove the given symbol, if the symbol is not contained in the given formula.
"""
if self._auxiliary_graph.in_degree(symbol) <= 1:
yield symbol

def _symbol_only_in_one_clause(self, symbol: Symbol):
"""Checks whether the symbol is only contained in one clause"""
return self._auxiliary_graph.in_degree(symbol) == 1
# for node in nodes_to_remove:
self._remove_formulas(set(self._candidates[node] for node in nodes_to_remove))
# self._remove_nodes_from(set(self._candidates[node] for node in nodes_to_remove))

@property
def _auxiliary_graph(self) -> DiGraph:
Expand All @@ -193,6 +140,67 @@ def filter_non_auxiliary_edges(source, sink):

return subgraph_view(self._logic_graph, filter_edge=filter_non_auxiliary_edges)

def _get_clauses(self, node: AbstractSyntaxTreeNode) -> List[LogicCondition]:
    """Collect the condition of every clause that succeeds the formula of the given ast-node in the formula-graph."""
    formula_graph = self._formula_graph
    formula = self._candidates[node]
    conditions = []
    for clause in formula_graph.successors(formula):
        conditions.append(clause.condition)
    return conditions

def _symbol_only_in_one_formula(self, symbol: Symbol):
    """Return whether the symbol has exactly one predecessor in the auxiliary-graph, i.e., is contained in exactly one formula."""
    number_of_formulas = self._auxiliary_graph.in_degree(symbol)
    return number_of_formulas == 1

def _remove_formulas(self, removing_formulas: Set[Formula]):
    """
    Remove all given formulas from the logic-graph and all nodes that have to be removed afterward.

    1. Remove each clause that is contained in one of the formulas.
    2. Remove each formula from the logic-graph.
    3. Remove all symbols that were contained in one of these formulas and are contained in exactly one
       formula after the removal.

    :param removing_formulas: The formulas to remove from the logic-graph.
    """
    symbols_of_formulas: Set[Symbol] = set()
    for formula in removing_formulas:
        # Materialize the successors first: removing nodes while iterating the graph view is not safe.
        self._logic_graph.remove_nodes_from(list(self._formula_graph.successors(formula)))
        # `update`, not `add`: `successors` returns an iterator, and `add` would store that iterator object
        # itself instead of the symbols. `update` also consumes it before the formula-node is removed.
        symbols_of_formulas.update(self._auxiliary_graph.successors(formula))
        self._remove_formula_node(formula)
    # NOTE(review): symbols that end up in no formula at all are not selected here and stay in the graph
    # as isolated nodes - confirm that this is intended.
    self._remove_symbols({symbol for symbol in symbols_of_formulas if self._symbol_only_in_one_formula(symbol)})

def _remove_formula_node(self, formula: Formula):
    """Delete the formula-node from the logic-graph and drop its ast-node from the candidates and the unconsidered-nodes."""
    self._logic_graph.remove_node(formula)
    ast_node = formula.ast_node
    del self._candidates[ast_node]
    # discard: the ast-node may already have been considered and thus be absent from the set.
    self._unconsidered_nodes.discard(ast_node)

def _remove_symbols(self, removing_symbols: Set[Symbol]):
    """
    Remove the given symbols from the logic-graph together with every node that becomes obsolete by this.

    1. If there is no symbol to remove, we do nothing.
    2. Remove each symbol from the logic-graph.
    3. For each clause that contains at least one of the symbols, we
        i. remove the clause
        ii. check for each remaining symbol of the clause whether another clause of the same formula still contains it
            - No: remove the auxiliary-edge between the formula and the symbol,
                  and remember the symbol if it is contained in exactly one formula afterward.
            - Yes: nothing
        iii. remove the formula of the clause if it has no clause left.
    4. Recursively remove all symbols that are now contained in exactly one formula.
    """
    if not removing_symbols:
        return
    # Collect the affected clauses eagerly, before the symbol-nodes (and thus their in-edges) disappear.
    affected_clauses = {clause for symbol in removing_symbols for clause in self._formula_graph.predecessors(symbol)}
    self._logic_graph.remove_nodes_from(removing_symbols)
    single_formula_symbols = set()
    for clause in affected_clauses:
        formula = clause.formula
        remaining_symbols = list(self._logic_graph.successors(clause))
        # The clause has to be gone before the path check, otherwise the formula reaches each symbol via this clause.
        self._logic_graph.remove_node(clause)
        for symbol in remaining_symbols:
            if has_path(self._formula_graph, formula, symbol):
                continue
            self._logic_graph.remove_edge(formula, symbol)
            if self._symbol_only_in_one_formula(symbol):
                single_formula_symbols.add(symbol)
        if self._formula_graph.out_degree(formula) == 0:
            self._remove_formula_node(formula)
    self._remove_symbols(single_formula_symbols)

# def get_next_subexpression(self) -> Iterator[Tuple[AbstractSyntaxTreeNode, LogicCondition]]:
# """Get the next subexpression together with the node it comes from and start with the largest possible subexpression!"""
# current_size = self.maximum_subexpression_size
Expand Down Expand Up @@ -307,7 +315,7 @@ def _structure_sequence_node(self, sequence_node: SeqNode) -> Set[SeqNode]:
newly_created_sequence_nodes.add(condition_node.false_branch_child)
sibling_reachability.merge_siblings_to(condition_node, all_cluster_nodes)
sequence_node._sorted_children = sibling_reachability.sorted_nodes()
condition_candidates.remove_nodes(all_cluster_nodes)
condition_candidates.remove_ast_nodes(all_cluster_nodes)

return newly_created_sequence_nodes

Expand Down
2 changes: 1 addition & 1 deletion decompiler/structures/graphs/nxgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __contains__(self, obj: Union[NODE, EDGE]):
"""Check if a node or edge is contained in the graph."""
if isinstance(obj, GraphNodeInterface):
return obj in self._graph
return any(obj in data["data"] for _, _, data in self._graph.edges(data=True))
return any(obj == data["data"] for _, _, data in self._graph.edges(data=True))

def iter_depth_first(self, source: NODE) -> Iterator[NODE]:
"""Iterate all nodes in dfs fashion."""
Expand Down
8 changes: 8 additions & 0 deletions tests/structures/graphs/test_graph_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,3 +467,11 @@ def test_get_shortest_path(self, nodes):
assert graph.get_shortest_path(nodes[0], nodes[6]) == (nodes[0], nodes[1], nodes[3], nodes[6])
graph.add_edge(BasicEdge(nodes[0], nodes[6]))
assert graph.get_shortest_path(nodes[0], nodes[6]) == (nodes[0], nodes[6])

def test_contains(self):
    """Check the membership test of the graph for nodes and edges, contained and not contained."""
    graph, nodes, edges = self.get_easy_graph()
    # a node of the graph vs. a fresh node that was never added
    assert nodes[0] in graph
    assert BasicNode(6) not in graph
    # an edge of the graph vs. a fresh edge that was never added
    assert edges[2] in graph
    assert BasicEdge(nodes[2], nodes[0]) not in graph

0 comments on commit 665bb7c

Please sign in to comment.