diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_based_refinement.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_based_refinement.py index 18a9a84a6..9730456e7 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_based_refinement.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_based_refinement.py @@ -5,7 +5,7 @@ from __future__ import annotations from collections import defaultdict -from dataclasses import dataclass, field +from dataclasses import dataclass from itertools import combinations from typing import DefaultDict, Dict, Iterator, List, Optional, Set, Tuple, Union @@ -14,10 +14,7 @@ from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, SeqNode from decompiler.structures.ast.reachability_graph import SiblingReachability from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest -from decompiler.structures.graphs.interface import GraphNodeInterface, GraphEdgeInterface -from decompiler.structures.graphs.nxgraph import NetworkXGraph from decompiler.structures.logic.logic_condition import LogicCondition -from decompiler.util.insertion_ordered_set import InsertionOrderedSet @dataclass @@ -63,9 +60,8 @@ def __init__(self, candidates: List[AbstractSyntaxTreeNode]) -> None: self._logic_graph: DiGraph = DiGraph() self._formulas_containing_symbol: DefaultDict[Symbol, Set[Formula]] = defaultdict(set) self._symbols_of_formula: DefaultDict[Formula, Set[Symbol]] = defaultdict(set) - self._removable_nodes: Set[Union[Formula, Clause, Symbol]] = set() self._initialize_logic_graph_and_dictionaries() - self._clean_up() + self._remove_nodes_from(set(symbol for symbol, formulas in self._formulas_containing_symbol.items() if len(formulas) == 1)) def _initialize_logic_graph_and_dictionaries(self): for formula in self._candidates.values(): @@ -77,18 +73,11 @@ def _initialize_logic_graph_and_dictionaries(self): self._logic_graph.add_edge(clause, symbol := Symbol(symbol_name)) self._formulas_containing_symbol[symbol].add(formula) self._symbols_of_formula[formula].add(symbol) - self._removable_nodes = set(symbol for symbol, formulas in self._formulas_containing_symbol.items() if len(formulas) == 1) @property def candidates(self) -> Iterator[AbstractSyntaxTreeNode]: yield from self._candidates - def get_clauses(self, node: AbstractSyntaxTreeNode) -> List[LogicCondition]: - return [clause.condition for clause in self._logic_graph.successors(self._candidates[node])] - - def get_symbols(self, node: AbstractSyntaxTreeNode) -> Set[str]: - return {symbol.name for symbol in self._symbols_of_formula[self._candidates[node]]} - @property def maximum_subexpression_size(self) -> int: if len(self._candidates) < 2: @@ -97,96 +86,90 @@ def maximum_subexpression_size(self) -> int: all_sizes.remove(max(all_sizes)) return max(all_sizes) + def get_symbols_of(self, node: AbstractSyntaxTreeNode) -> Set[str]: + return {symbol.name for symbol in self._symbols_of_formula[self._candidates[node]]} + def get_next_subexpression(self) -> Iterator[Tuple[AbstractSyntaxTreeNode, LogicCondition]]: """Consider Candidates in sequence-node order and start with the largest possible subexpression.""" - all_candidates = list(self._candidates) - for ast_node in all_candidates: + for ast_node in list(self._candidates): if ast_node not in self._candidates: continue if (max_expr_size := self.maximum_subexpression_size) == 0: break - clauses = self.get_clauses(ast_node) + clauses = self._get_clauses(ast_node) current_size = min(len(clauses), max_expr_size) while current_size > 0 and ast_node in self._candidates: - if current_size == 1: - for operand in clauses: - yield ast_node, operand - else: - for new_operands in combinations(clauses, current_size): - yield ast_node, LogicCondition.conjunction_of(new_operands) + for new_operands in combinations(clauses, current_size): + yield ast_node, LogicCondition.conjunction_of(new_operands) current_size -= 1 def remove_nodes(self, nodes_to_remove: List[AbstractSyntaxTreeNode]): """Remove the given nodes from the graph.""" - for node in nodes_to_remove: - self._remove_formula(self._candidates[node]) - self._clean_up() + self._remove_nodes_from(set(self._candidates[node] for node in nodes_to_remove)) - def _clean_up(self): - while self._removable_nodes: - node = self._removable_nodes.pop() + def _get_clauses(self, node: AbstractSyntaxTreeNode) -> List[LogicCondition]: + return [clause.condition for clause in self._logic_graph.successors(self._candidates[node])] + + def _remove_nodes_from(self, removable_nodes: Set[Union[Formula, Clause, Symbol]]): + while removable_nodes: + node = removable_nodes.pop() match node: case Formula(): - self._remove_formula(node) + removable_nodes.update(self._remove_formula(node)) case Clause(): - self._remove_clause(node) + removable_nodes.update(self._remove_clause(node)) case Symbol(): - self._remove_symbol(node) + removable_nodes.update(self._remove_symbol(node)) - def _remove_formula(self, formula: Formula): + def _remove_formula(self, formula: Formula) -> Iterator[Union[Clause, Symbol]]: """Remove the given formula from the graph.""" - self._removable_nodes.update(self._logic_graph.successors(formula)) + yield from self._logic_graph.successors(formula) self._logic_graph.remove_node(formula) for symbol in self._symbols_of_formula[formula]: - self._remove_formula_from_formula_containing_symbol(formula, symbol) + yield from self._remove_symbol_from_formula(formula, symbol) del self._candidates[formula.ast_node] - def _remove_clause(self, clause: Clause): + def _remove_clause(self, clause: Clause) -> Iterator[Union[Formula, Symbol]]: """Remove the given clause from the graph.""" if clause.formula in self._logic_graph: if self._logic_graph.out_degree(clause.formula) == 1: - self._removable_nodes.add(clause.formula) + yield clause.formula else: for symbol in (s for s in self._logic_graph.successors(clause) if not has_path(self._logic_graph, clause.formula, s)): self._symbols_of_formula[clause.formula].remove(symbol) - self._remove_formula_from_formula_containing_symbol(clause.formula, symbol) + yield from self._remove_symbol_from_formula(clause.formula, symbol) self._logic_graph.remove_node(clause) - def _remove_symbol(self, symbol: Symbol): + def _remove_symbol(self, symbol: Symbol) -> Iterator[Clause]: """Remove the given symbol from the graph.""" - self._removable_nodes.update(self._logic_graph.predecessors(symbol)) + yield from self._logic_graph.predecessors(symbol) for formula in self._formulas_containing_symbol[symbol]: self._symbols_of_formula[formula].remove(symbol) self._logic_graph.remove_node(symbol) - def _remove_formula_from_formula_containing_symbol(self, formula: Formula, symbol: Symbol): + def _remove_symbol_from_formula(self, formula: Formula, symbol: Symbol) -> Iterator[Symbol]: + """ + Update the dictionaries and decides whether we also remove the given symbol, if the symbol is not contained in the given formula. + """ self._formulas_containing_symbol[symbol].remove(formula) if len(self._formulas_containing_symbol[symbol]) <= 1: - self._removable_nodes.add(symbol) - - -# -# # def get_next_subexpression(self) -> Iterator[Tuple[AbstractSyntaxTreeNode, LogicCondition]]: -# # """Get the next subexpression together with the node it comes from and start with the largest possible subexpression!""" -# # TODO: only compute "useful" subexpressions! -# # while (current_size := self.maximum_subexpression_size) > 0: -# # children_to_consider = [c for c, p in self._candidates.items() if p.number_of_interesting_operands >= current_size] -# # for child in children_to_consider: -# # if child not in self._candidates: -# # continue -# # if current_size > self.maximum_subexpression_size: -# # break -# # if current_size == 1: -# # for operand in self._candidates[child].operands: -# # yield child, operand -# # if child not in self._candidates or current_size > self.maximum_subexpression_size: -# # break -# # else: -# # for new_operands in combinations(self._candidates[child].operands, current_size): -# # yield child, LogicCondition.conjunction_of(new_operands) -# # if child not in self._candidates or current_size > self._max_subexpression_size: -# # break -# # self._max_subexpression_size = current_size - 1 + yield symbol + + # def get_next_subexpression(self) -> Iterator[Tuple[AbstractSyntaxTreeNode, LogicCondition]]: + # """Get the next subexpression together with the node it comes from and start with the largest possible subexpression!""" + # current_size = self.maximum_subexpression_size + # while current_size > 0: + # for ast_node in [c for c, p in self._candidates.items() if self._logic_graph.out_degree(p) >= current_size]: + # if ast_node not in self._candidates: + # continue + # if current_size > self.maximum_subexpression_size: + # break + # clauses = self._get_clauses(ast_node) + # for new_operands in combinations(clauses, current_size): + # yield ast_node, LogicCondition.conjunction_of(new_operands) + # if ast_node not in self._candidates or current_size > self.maximum_subexpression_size: + # break + # current_size = min(self.maximum_subexpression_size, current_size - 1) class ConditionBasedRefinement: @@ -295,7 +278,7 @@ def _cluster_by_condition( :param sub_expression: The condition for which we check whether it or its negation is a subexpression of the list of input nodes. :param node_with_subexpression: The node of which the given sub_expression is a sub-expression - :param condition_candidates: TODO The children of the sequence node we want to cluster and that have a reaching condition. + :param condition_candidates: class-object handling all condition candidates. :return: A 2-tuple, where the first list is the set of nodes that have condition as subexpression, the second list is the set of nodes that have the negated condition as subexpression. """ @@ -304,7 +287,7 @@ def _cluster_by_condition( symbols_of_condition = set(sub_expression.get_symbols_as_string()) negated_condition = None for ast_node in condition_candidates.candidates: - if symbols_of_condition - condition_candidates.get_symbols(ast_node): + if symbols_of_condition - condition_candidates.get_symbols_of(ast_node): continue if ast_node == node_with_subexpression or self._is_subexpression_of_cnf_formula(sub_expression, ast_node.reaching_condition): true_children.append(ast_node) @@ -321,39 +304,39 @@ def _get_negated_condition_of(condition: LogicCondition, negated_condition: Opti return ~condition return negated_condition - def _is_subexpression_of_cnf_formula(self, term: LogicCondition, expression: LogicCondition) -> bool: + def _is_subexpression_of_cnf_formula(self, sub_expression: LogicCondition, condition: LogicCondition) -> bool: """ Check whether the input term is a conjunction of a subset of clauses of a CNF expression. - :param term: assumed to be CNF. May contain more than one clause. - :param expression: no assumptions made. + :param sub_expression: assumed to be CNF. May contain more than one clause. + :param condition: no assumptions made. Examples: - term = a∨b, expression = (a∨b)∧c, returns True - term = a∨b, expression = a∨b∨c, returns False; expression's CNF is (a∨b∨c). - term = (a∨b)∧c, expression = (a∨b)∧(b∨d)∧c, returns True - term = ¬(a∨b), expression = ¬a∧¬b∧¬c, returns False; term is not CNF and will not match (although this case should not occur). + sub_expression = a∨b, condition = (a∨b)∧c, returns True + sub_expression = a∨b, condition = a∨b∨c, returns False; expression's CNF is (a∨b∨c). + sub_expression = (a∨b)∧c, condition = (a∨b)∧(b∨d)∧c, returns True + sub_expression = ¬(a∨b), condition = ¬a∧¬b∧¬c, returns False; sub_expression is not CNF and will not match (although this case should not occur). """ - if (is_subexpression := self._preliminary_subexpression_checks(term, expression)) is not None: + if (is_subexpression := self._preliminary_subexpression_checks(sub_expression, condition)) is not None: return is_subexpression - expression_operands = expression.operands - term_operands = term.operands - numb_of_arg_expr = len(expression_operands) if expression.is_conjunction else 1 - numb_of_arg_term = len(term_operands) if term.is_conjunction else 1 + condition_operands = condition.operands + sub_expression_operands = sub_expression.operands + numb_of_arg_condition = len(condition_operands) if condition.is_conjunction else 1 + numb_of_arg_sub_expression = len(sub_expression_operands) if sub_expression.is_conjunction else 1 - if numb_of_arg_expr <= numb_of_arg_term: + if numb_of_arg_condition <= numb_of_arg_sub_expression: return False - subexpressions = [term] if numb_of_arg_term == 1 else term_operands - # Not sure whether we not want first the expression and then the term, since we do the same when inserting the condition-node. - # However, we could compare which operands are removed, and then decide whether this is something we want. - updated_expression_operands = (expression & term).operands - if len(updated_expression_operands) > len(expression_operands) or sum(len(op) for op in updated_expression_operands) > sum( - len(op) for op in expression_operands - ): + clauses_of_sub_expression = [sub_expression] if numb_of_arg_sub_expression == 1 else sub_expression_operands + updated_expression_operands = (condition & sub_expression).operands + if self._first_expression_is_complexer_than_second(updated_expression_operands, condition_operands): return False - if len(updated_expression_operands) < len(expression_operands): + if len(updated_expression_operands) < len(condition_operands): return True - return all(self._is_contained_in_logic_conditions(sub_expr, updated_expression_operands) for sub_expr in subexpressions) + return all(self._is_contained_in_logic_conditions(sub_expr, updated_expression_operands) for sub_expr in clauses_of_sub_expression) + + def _first_expression_is_complexer_than_second(self, expression_1: List[LogicCondition], expression_2: List[LogicCondition]): + """Check whether the clauses belonging to the first-expression are more complex than the clauses of the second expression.""" + return len(expression_1) > len(expression_2) or sum(len(op) for op in expression_1) > sum(len(op) for op in expression_2) @staticmethod def _preliminary_subexpression_checks(term: LogicCondition, expression: LogicCondition) -> Optional[bool]: