diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_based_refinement.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_based_refinement.py index 5aa76ac3b..0cc0b9aaf 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_based_refinement.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_based_refinement.py @@ -14,25 +14,97 @@ from decompiler.structures.logic.logic_condition import LogicCondition +@dataclass +class CandidateProperties: + operands: List[LogicCondition] + symbols: Set[str] + + @classmethod + def initialize(cls, node: AbstractSyntaxTreeNode) -> CandidateProperties: + operands = list(node.reaching_condition.operands) if node.reaching_condition.is_conjunction else [node.reaching_condition.copy()] + symbols = set(node.reaching_condition.get_symbols_as_string()) + return CandidateProperties(operands, symbols) + + @property + def number_of_interesting_operands(self) -> int: + return len(self.operands) + + class ConditionCandidates: def __init__(self, candidates: List[AbstractSyntaxTreeNode]): - self._candidates: Dict[AbstractSyntaxTreeNode, List[LogicCondition]] = { - c: list(c.reaching_condition.operands) if c.reaching_condition.is_conjunction else [c.reaching_condition] for c in candidates - } - self._max_subexpression_size: Optional[int] = None + self._candidates: Dict[AbstractSyntaxTreeNode, CandidateProperties] = {c: CandidateProperties.initialize(c) for c in candidates} + self._max_subexpression_size: int = max( + (candidate_property.number_of_interesting_operands for candidate_property in self._candidates.values()), default=0 + ) + + def __iter__(self): + yield from self._candidates.items() + @property def maximum_subexpression_size(self) -> int: - max = 0 - second_max = 0 - for candidate, operands in self._candidates.values(): - if len(operands) >= max: - second_max = max - max = len(operands) - elif len(operands) > second_max: - second_max = len(operands) - self._max_subexpression_size = min(second_max, self._max_subexpression_size) + if len(self._candidates) < 2: + self._max_subexpression_size = 0 + else: + all_sizes = [candidate_property.number_of_interesting_operands for candidate_property in self._candidates.values()] + all_sizes.remove(max(all_sizes)) + self._max_subexpression_size = min(max(all_sizes), self._max_subexpression_size) return self._max_subexpression_size + def get_next_subexpression(self): + while (current_size := self.maximum_subexpression_size) > 0: + childrens_to_consider = [c for c, p in self._candidates.items() if p.number_of_interesting_operands >= current_size] + for child in childrens_to_consider: + if child not in self._candidates: + continue + if current_size > self.maximum_subexpression_size: + break + if current_size == 1: + for operand in self._candidates[child].operands: + yield child, operand + if child not in self._candidates or current_size > self.maximum_subexpression_size: + break + else: + for new_operands in combinations(self._candidates[child].operands, current_size): + yield child, LogicCondition.conjunction_of(new_operands) + if child not in self._candidates or current_size > self._max_subexpression_size: + break + self._max_subexpression_size -= 1 + + # def _get_logical_and_subexpressions_of(self, condition: LogicCondition) -> Iterator[LogicCondition]: + # """ + # Get logical and-subexpressions of the input condition. + # + # We get the following expressions + # - If the condition is a Symbol or a Not, the whole condition + # - If the condition is an And, every possible combination of its And-arguments + # - If the condition is an Or, either the condition if all arguments are Symbols or Not or nothing otherwise. + # """ + # if condition.is_true: + # yield from () + # elif condition.is_symbol or condition.is_negation or condition.is_disjunction: + # yield condition.copy() + # elif condition.is_conjunction: + # for sub_expression in self._all_subsets(condition.operands): + # if len(sub_expression) == 1: + # yield sub_expression[0] + # else: + # yield LogicCondition.conjunction_of(sub_expression) + # else: + # raise ValueError(f"Received a condition which is not a Symbol, Or, Not, or And: {condition}") + # + # @staticmethod + # def _all_subsets(arguments: List[LogicCondition]) -> Iterator[Tuple[LogicCondition]]: + # """ + # Given a set of elements, in our case z3-expressions, it returns an iterator that contains each combination of the input arguments + # as a tuple. + # + # (1,2,3) --> Iterator[(1,2,3) (1,2) (1,3) (1,) (2,) (3,)] + # """ + # return (arg for size in range(len(arguments), 0, -1) for arg in combinations(arguments, size)) + def remove(self, nodes_to_remove: List[AbstractSyntaxTreeNode]): + for node in nodes_to_remove: + del self._candidates[node] + class ConditionBasedRefinement: """ @@ -63,7 +135,7 @@ def _condition_based_refinement(self) -> None: 1. Find nodes with complementary reaching conditions. 2. Find nodes that have some factors in common. """ - assert isinstance(self.root, SeqNode), f"The root note {self.root} should be a sequence node!" + assert isinstance(self.root, SeqNode), f"The root node {self.root} should be a sequence node!" self._refine_code_nodes_with_complementary_conditions() newly_created_sequence_nodes: Set[SeqNode] = {self.root} @@ -111,53 +183,42 @@ def _structure_sequence_node(self, sequence_node: SeqNode) -> Set[SeqNode]: :param sequence_node: The sequence nodes whose children we want to structure. :return: The set of sequence nodes we add during structuring the given sequence node. """ - visited = set() + # visited = set() newly_created_sequence_nodes: Set[SeqNode] = set() sibling_reachability: SiblingReachability = self.asforest.get_sibling_reachability_of_children_of(sequence_node) - all_children = list() subexpression_of_node = dict() - interesting_children = [child for child in sequence_node.children if not child.reaching_condition.is_true] - if len(interesting_children) < 2: - return newly_created_sequence_nodes - for child in interesting_children: - all_children.append(repr(child)) + condition_candidates = ConditionCandidates([child for child in sequence_node.children if not child.reaching_condition.is_true]) + for child, subexpression in condition_candidates.get_next_subexpression(): # TODO Also stop if it is the last child with a reaching condition to consider! - if child in visited: - continue - subexpression_of_node[child] = 0 # TODO: only compute "useful" subexpressions! - for subexpression in self._get_logical_and_subexpressions_of(child.reaching_condition): - subexpression_of_node[child] += 1 - true_cluster, false_cluster = self._cluster_by_condition(subexpression, sequence_node) - all_cluster_nodes = true_cluster + false_cluster + # for subexpression in self._get_logical_and_subexpressions_of(child.reaching_condition): + true_cluster, false_cluster = self._cluster_by_condition(subexpression, condition_candidates) + all_cluster_nodes = true_cluster + false_cluster - if len(all_cluster_nodes) < 2: - continue - if self._can_place_condition_node_with_branches(all_cluster_nodes, sibling_reachability): - condition_node = self.asforest.create_condition_node_with(subexpression, true_cluster, false_cluster) - if len(true_cluster) > 1: - newly_created_sequence_nodes.add(condition_node.true_branch_child) - if len(false_cluster) > 1: - newly_created_sequence_nodes.add(condition_node.false_branch_child) - sibling_reachability.merge_siblings_to(condition_node, all_cluster_nodes) - visited.update(all_cluster_nodes) - sequence_node._sorted_children = sibling_reachability.sorted_nodes() - break - if subexpression_of_node: - # print("[" + ",\n".join(child for child in all_children) + "]") - for child, numb in subexpression_of_node.items(): - print(f"consider node {child} with {numb} subexpressions:") + if len(all_cluster_nodes) < 2: + continue + if self._can_place_condition_node_with_branches(all_cluster_nodes, sibling_reachability): + condition_node = self.asforest.create_condition_node_with(subexpression, true_cluster, false_cluster) + if len(true_cluster) > 1: + newly_created_sequence_nodes.add(condition_node.true_branch_child) + if len(false_cluster) > 1: + newly_created_sequence_nodes.add(condition_node.false_branch_child) + sibling_reachability.merge_siblings_to(condition_node, all_cluster_nodes) + sequence_node._sorted_children = sibling_reachability.sorted_nodes() + # TODO remove nodes from condition candidates! + condition_candidates.remove(all_cluster_nodes) + # break return newly_created_sequence_nodes def _cluster_by_condition( - self, condition: LogicCondition, sequence_node: SeqNode + self, condition: LogicCondition, condition_candidates: ConditionCandidates ) -> Tuple[List[AbstractSyntaxTreeNode], List[AbstractSyntaxTreeNode]]: """ Cluster the nodes in sequence_nodes according to the input condition. :param condition: The condition for which we check whether it or its negation is a subexpression of the list of input nodes. - :param sequence_node: The sequence node we want to cluster. + :param condition_candidates: TODO The sequence node we want to cluster. :return: A 2-tuple, where the first list is the set of nodes that have condition as subexpression, the second list is the set of nodes that have the negated condition as subexpression. """ @@ -166,8 +227,8 @@ def _cluster_by_condition( symbols_of_condition = set(condition.get_symbols_as_string()) negated_condition = None - for node in sequence_node.children: - if symbols_of_condition - set(node.reaching_condition.get_symbols_as_string()): + for node, properties in condition_candidates: + if symbols_of_condition - properties.symbols: continue # TODO: we should not check this for the node we currently consider! if self._is_subexpression_of_cnf_formula(condition, node.reaching_condition): @@ -239,38 +300,6 @@ def _is_contained_in_logic_conditions(sub_expression: LogicCondition, logic_cond """Check whether the given sub_expression is contained in the list of logic conditions""" return any(sub_expression.does_imply(condition) for condition in logic_conditions) - def _get_logical_and_subexpressions_of(self, condition: LogicCondition) -> Iterator[LogicCondition]: - """ - Get logical and-subexpressions of the input condition. - - We get the following expressions - - If the condition is a Symbol or a Not, the whole condition - - If the condition is an And, every possible combination of its And-arguments - - If the condition is an Or, either the condition if all arguments are Symbols or Not or nothing otherwise. - """ - if condition.is_true: - yield from () - elif condition.is_symbol or condition.is_negation or condition.is_disjunction: - yield condition.copy() - elif condition.is_conjunction: - for sub_expression in self._all_subsets(condition.operands): - if len(sub_expression) == 1: - yield sub_expression[0] - else: - yield LogicCondition.conjunction_of(sub_expression) - else: - raise ValueError(f"Received a condition which is not a Symbol, Or, Not, or And: {condition}") - - @staticmethod - def _all_subsets(arguments: List[LogicCondition]) -> Iterator[Tuple[LogicCondition]]: - """ - Given a set of elements, in our case z3-expressions, it returns an iterator that contains each combination of the input arguments - as a tuple. - - (1,2,3) --> Iterator[(1,2,3) (1,2) (1,3) (1,) (2,) (3,)] - """ - return (arg for size in range(len(arguments), 0, -1) for arg in combinations(arguments, size)) - @staticmethod def _can_place_condition_node_with_branches(branches: List[AbstractSyntaxTreeNode], sibling_reachability: SiblingReachability) -> bool: """ diff --git a/decompiler/structures/logic/z3_implementations.py b/decompiler/structures/logic/z3_implementations.py index 704127a46..793f1e955 100644 --- a/decompiler/structures/logic/z3_implementations.py +++ b/decompiler/structures/logic/z3_implementations.py @@ -179,9 +179,10 @@ def simplify_z3_condition(self, z3_condition: BoolRef, resolve_negations: bool = """ if self._resolve_negations and resolve_negations: z3_condition = self._resolve_negation(z3_condition) - if self._too_large_to_fully_simplify(z3_condition): - return simplify(Repeat(Tactic("ctx-simplify", ctx=z3_condition.ctx))(z3_condition).as_expr()) - return Repeat(Tactic("ctx-solver-simplify", ctx=z3_condition.ctx))(z3_condition).as_expr() + z3_condition = simplify(Repeat(Tactic("ctx-simplify", ctx=z3_condition.ctx))(z3_condition).as_expr()) + if not self._too_large_to_fully_simplify(z3_condition): + z3_condition = simplify(Repeat(Tactic("ctx-solver-simplify", ctx=z3_condition.ctx))(z3_condition).as_expr()) + return z3_condition @staticmethod def get_symbols(condition: BoolRef) -> Iterator[BoolRef]: