Skip to content

Commit

Permalink
some cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
ebehner committed Mar 26, 2024
1 parent 80812c3 commit 129b514
Showing 1 changed file with 72 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass, field
from dataclasses import dataclass
from itertools import combinations
from typing import DefaultDict, Dict, Iterator, List, Optional, Set, Tuple, Union

Expand All @@ -14,10 +14,7 @@
from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, SeqNode
from decompiler.structures.ast.reachability_graph import SiblingReachability
from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest
from decompiler.structures.graphs.interface import GraphNodeInterface, GraphEdgeInterface
from decompiler.structures.graphs.nxgraph import NetworkXGraph
from decompiler.structures.logic.logic_condition import LogicCondition
from decompiler.util.insertion_ordered_set import InsertionOrderedSet


@dataclass
Expand Down Expand Up @@ -63,9 +60,8 @@ def __init__(self, candidates: List[AbstractSyntaxTreeNode]) -> None:
self._logic_graph: DiGraph = DiGraph()
self._formulas_containing_symbol: DefaultDict[Symbol, Set[Formula]] = defaultdict(set)
self._symbols_of_formula: DefaultDict[Formula, Set[Symbol]] = defaultdict(set)
self._removable_nodes: Set[Union[Formula, Clause, Symbol]] = set()
self._initialize_logic_graph_and_dictionaries()
self._clean_up()
self._remove_nodes_from(set(symbol for symbol, formulas in self._formulas_containing_symbol.items() if len(formulas) == 1))

def _initialize_logic_graph_and_dictionaries(self):
for formula in self._candidates.values():
Expand All @@ -77,18 +73,11 @@ def _initialize_logic_graph_and_dictionaries(self):
self._logic_graph.add_edge(clause, symbol := Symbol(symbol_name))
self._formulas_containing_symbol[symbol].add(formula)
self._symbols_of_formula[formula].add(symbol)
self._removable_nodes = set(symbol for symbol, formulas in self._formulas_containing_symbol.items() if len(formulas) == 1)

@property
def candidates(self) -> Iterator[AbstractSyntaxTreeNode]:
yield from self._candidates

def get_clauses(self, node: AbstractSyntaxTreeNode) -> List[LogicCondition]:
return [clause.condition for clause in self._logic_graph.successors(self._candidates[node])]

def get_symbols(self, node: AbstractSyntaxTreeNode) -> Set[str]:
return {symbol.name for symbol in self._symbols_of_formula[self._candidates[node]]}

@property
def maximum_subexpression_size(self) -> int:
if len(self._candidates) < 2:
Expand All @@ -97,96 +86,90 @@ def maximum_subexpression_size(self) -> int:
all_sizes.remove(max(all_sizes))
return max(all_sizes)

def get_symbols_of(self, node: AbstractSyntaxTreeNode) -> Set[str]:
return {symbol.name for symbol in self._symbols_of_formula[self._candidates[node]]}

def get_next_subexpression(self) -> Iterator[Tuple[AbstractSyntaxTreeNode, LogicCondition]]:
"""Consider Candidates in sequence-node order and start with the largest possible subexpression."""
all_candidates = list(self._candidates)
for ast_node in all_candidates:
for ast_node in list(self._candidates):
if ast_node not in self._candidates:
continue
if (max_expr_size := self.maximum_subexpression_size) == 0:
break
clauses = self.get_clauses(ast_node)
clauses = self._get_clauses(ast_node)
current_size = min(len(clauses), max_expr_size)
while current_size > 0 and ast_node in self._candidates:
if current_size == 1:
for operand in clauses:
yield ast_node, operand
else:
for new_operands in combinations(clauses, current_size):
yield ast_node, LogicCondition.conjunction_of(new_operands)
for new_operands in combinations(clauses, current_size):
yield ast_node, LogicCondition.conjunction_of(new_operands)
current_size -= 1

def remove_nodes(self, nodes_to_remove: List[AbstractSyntaxTreeNode]):
"""Remove the given nodes from the graph."""
for node in nodes_to_remove:
self._remove_formula(self._candidates[node])
self._clean_up()
self._remove_nodes_from(set(self._candidates[node] for node in nodes_to_remove))

def _clean_up(self):
while self._removable_nodes:
node = self._removable_nodes.pop()
def _get_clauses(self, node: AbstractSyntaxTreeNode) -> List[LogicCondition]:
return [clause.condition for clause in self._logic_graph.successors(self._candidates[node])]

def _remove_nodes_from(self, removable_nodes: Set[Union[Formula, Clause, Symbol]]):
while removable_nodes:
node = removable_nodes.pop()
match node:
case Formula():
self._remove_formula(node)
removable_nodes.update(self._remove_formula(node))
case Clause():
self._remove_clause(node)
removable_nodes.update(self._remove_clause(node))
case Symbol():
self._remove_symbol(node)
removable_nodes.update(self._remove_symbol(node))

def _remove_formula(self, formula: Formula):
def _remove_formula(self, formula: Formula) -> Iterator[Union[Clause, Symbol]]:
"""Remove the given formula from the graph."""
self._removable_nodes.update(self._logic_graph.successors(formula))
yield from self._logic_graph.successors(formula)
self._logic_graph.remove_node(formula)
for symbol in self._symbols_of_formula[formula]:
self._remove_formula_from_formula_containing_symbol(formula, symbol)
yield from self._remove_symbol_from_formula(formula, symbol)
del self._candidates[formula.ast_node]

def _remove_clause(self, clause: Clause):
def _remove_clause(self, clause: Clause) -> Iterator[Union[Formula, Symbol]]:
"""Remove the given clause from the graph."""
if clause.formula in self._logic_graph:
if self._logic_graph.out_degree(clause.formula) == 1:
self._removable_nodes.add(clause.formula)
yield clause.formula
else:
for symbol in (s for s in self._logic_graph.successors(clause) if not has_path(self._logic_graph, clause.formula, s)):
self._symbols_of_formula[clause.formula].remove(symbol)
self._remove_formula_from_formula_containing_symbol(clause.formula, symbol)
yield from self._remove_symbol_from_formula(clause.formula, symbol)
self._logic_graph.remove_node(clause)

def _remove_symbol(self, symbol: Symbol):
def _remove_symbol(self, symbol: Symbol) -> Iterator[Clause]:
"""Remove the given symbol from the graph."""
self._removable_nodes.update(self._logic_graph.predecessors(symbol))
yield from self._logic_graph.predecessors(symbol)
for formula in self._formulas_containing_symbol[symbol]:
self._symbols_of_formula[formula].remove(symbol)
self._logic_graph.remove_node(symbol)

def _remove_formula_from_formula_containing_symbol(self, formula: Formula, symbol: Symbol):
def _remove_symbol_from_formula(self, formula: Formula, symbol: Symbol) -> Iterator[Symbol]:
"""
Update the dictionaries and decides whether we also remove the given symbol, if the symbol is not contained in the given formula.
"""
self._formulas_containing_symbol[symbol].remove(formula)
if len(self._formulas_containing_symbol[symbol]) <= 1:
self._removable_nodes.add(symbol)


#
# # def get_next_subexpression(self) -> Iterator[Tuple[AbstractSyntaxTreeNode, LogicCondition]]:
# # """Get the next subexpression together with the node it comes from and start with the largest possible subexpression!"""
# # TODO: only compute "useful" subexpressions!
# # while (current_size := self.maximum_subexpression_size) > 0:
# # children_to_consider = [c for c, p in self._candidates.items() if p.number_of_interesting_operands >= current_size]
# # for child in children_to_consider:
# # if child not in self._candidates:
# # continue
# # if current_size > self.maximum_subexpression_size:
# # break
# # if current_size == 1:
# # for operand in self._candidates[child].operands:
# # yield child, operand
# # if child not in self._candidates or current_size > self.maximum_subexpression_size:
# # break
# # else:
# # for new_operands in combinations(self._candidates[child].operands, current_size):
# # yield child, LogicCondition.conjunction_of(new_operands)
# # if child not in self._candidates or current_size > self._max_subexpression_size:
# # break
# # self._max_subexpression_size = current_size - 1
yield symbol

# def get_next_subexpression(self) -> Iterator[Tuple[AbstractSyntaxTreeNode, LogicCondition]]:
# """Get the next subexpression together with the node it comes from and start with the largest possible subexpression!"""
# current_size = self.maximum_subexpression_size
# while current_size > 0:
# for ast_node in [c for c, p in self._candidates.items() if self._logic_graph.out_degree(p) >= current_size]:
# if ast_node not in self._candidates:
# continue
# if current_size > self.maximum_subexpression_size:
# break
# clauses = self._get_clauses(ast_node)
# for new_operands in combinations(clauses, current_size):
# yield ast_node, LogicCondition.conjunction_of(new_operands)
# if ast_node not in self._candidates or current_size > self.maximum_subexpression_size:
# break
# current_size = min(self.maximum_subexpression_size, current_size - 1)


class ConditionBasedRefinement:
Expand Down Expand Up @@ -295,7 +278,7 @@ def _cluster_by_condition(
:param sub_expression: The condition for which we check whether it or its negation is a subexpression of the list of input nodes.
:param node_with_subexpression: The node of which the given sub_expression is a sub-expression
:param condition_candidates: TODO The children of the sequence node we want to cluster and that have a reaching condition.
:param condition_candidates: class-object handling all condition candidates.
:return: A 2-tuple, where the first list is the set of nodes that have condition as subexpression, the second list is the set of
nodes that have the negated condition as subexpression.
"""
Expand All @@ -304,7 +287,7 @@ def _cluster_by_condition(
symbols_of_condition = set(sub_expression.get_symbols_as_string())
negated_condition = None
for ast_node in condition_candidates.candidates:
if symbols_of_condition - condition_candidates.get_symbols(ast_node):
if symbols_of_condition - condition_candidates.get_symbols_of(ast_node):
continue
if ast_node == node_with_subexpression or self._is_subexpression_of_cnf_formula(sub_expression, ast_node.reaching_condition):
true_children.append(ast_node)
Expand All @@ -321,39 +304,39 @@ def _get_negated_condition_of(condition: LogicCondition, negated_condition: Opti
return ~condition
return negated_condition

def _is_subexpression_of_cnf_formula(self, term: LogicCondition, expression: LogicCondition) -> bool:
def _is_subexpression_of_cnf_formula(self, sub_expression: LogicCondition, condition: LogicCondition) -> bool:
"""
Check whether the input term is a conjunction of a subset of clauses of a CNF expression.
:param term: assumed to be CNF. May contain more than one clause.
:param expression: no assumptions made.
:param sub_expression: assumed to be CNF. May contain more than one clause.
:param condition: no assumptions made.
Examples:
term = a∨b, expression = (a∨b)∧c, returns True
term = a∨b, expression = a∨b∨c, returns False; expression's CNF is (a∨b∨c).
term = (a∨b)∧c, expression = (a∨b)∧(b∨d)∧c, returns True
term = ¬(a∨b), expression = ¬a∧¬b∧¬c, returns False; term is not CNF and will not match (although this case should not occur).
sub_expression = a∨b, condition = (a∨b)∧c, returns True
sub_expression = a∨b, condition = a∨b∨c, returns False; expression's CNF is (a∨b∨c).
sub_expression = (a∨b)∧c, condition = (a∨b)∧(b∨d)∧c, returns True
sub_expression = ¬(a∨b), condition = ¬a∧¬b∧¬c, returns False; sub_expression is not CNF and will not match (although this case should not occur).
"""
if (is_subexpression := self._preliminary_subexpression_checks(term, expression)) is not None:
if (is_subexpression := self._preliminary_subexpression_checks(sub_expression, condition)) is not None:
return is_subexpression

expression_operands = expression.operands
term_operands = term.operands
numb_of_arg_expr = len(expression_operands) if expression.is_conjunction else 1
numb_of_arg_term = len(term_operands) if term.is_conjunction else 1
condition_operands = condition.operands
sub_expression_operands = sub_expression.operands
numb_of_arg_condition = len(condition_operands) if condition.is_conjunction else 1
numb_of_arg_sub_expression = len(sub_expression_operands) if sub_expression.is_conjunction else 1

if numb_of_arg_expr <= numb_of_arg_term:
if numb_of_arg_condition <= numb_of_arg_sub_expression:
return False

subexpressions = [term] if numb_of_arg_term == 1 else term_operands
# Not sure whether we not want first the expression and then the term, since we do the same when inserting the condition-node.
# However, we could compare which operands are removed, and then decide whether this is something we want.
updated_expression_operands = (expression & term).operands
if len(updated_expression_operands) > len(expression_operands) or sum(len(op) for op in updated_expression_operands) > sum(
len(op) for op in expression_operands
):
clauses_of_sub_expression = [sub_expression] if numb_of_arg_sub_expression == 1 else sub_expression_operands
updated_expression_operands = (condition & sub_expression).operands
if self._first_expression_is_complexer_than_second(updated_expression_operands, condition_operands):
return False
if len(updated_expression_operands) < len(expression_operands):
if len(updated_expression_operands) < len(condition_operands):
return True
return all(self._is_contained_in_logic_conditions(sub_expr, updated_expression_operands) for sub_expr in subexpressions)
return all(self._is_contained_in_logic_conditions(sub_expr, updated_expression_operands) for sub_expr in clauses_of_sub_expression)

def _first_expression_is_complexer_than_second(self, expression_1: List[LogicCondition], expression_2: List[LogicCondition]):
"""Check whether the clauses belonging to the first-expression are more complex than the clauses of the second expression."""
return len(expression_1) > len(expression_2) or sum(len(op) for op in expression_1) > sum(len(op) for op in expression_2)

@staticmethod
def _preliminary_subexpression_checks(term: LogicCondition, expression: LogicCondition) -> Optional[bool]:
Expand Down

0 comments on commit 129b514

Please sign in to comment.