Skip to content

Commit

Permalink
start with issue
Browse files Browse the repository at this point in the history
  • Loading branch information
ebehner committed Apr 19, 2024
1 parent be57b90 commit 13b8dfa
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def _group_by_reaching_conditions(self, nodes: Tuple[AbstractSyntaxTreeNode]) ->
:param nodes: The AST nodes that we want to group.
:return: A dictionary that assigns to a reaching condition the list of AST code nodes with this reaching condition,
if it are at least two with the same.
if there are at least two with the same.
"""
initial_groups: Dict[LogicCondition, List[AbstractSyntaxTreeNode]] = dict()
for node in nodes:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def _update_reaching_condition_for_case_node_children(self, switch_node: SwitchN
case_node.reaching_condition.is_disjunction_of_literals
), f"The condition of a case node should be a disjunction, but it is {case_node.reaching_condition}!"

if isinstance(cond_node := case_node.child, ConditionNode) and cond_node.false_branch is None:
if (cond_node := case_node.child).is_single_branch:
self._update_condition_for(cond_node, case_node)

case_node.child.reaching_condition = case_node.child.reaching_condition.substitute_by_true(case_node.reaching_condition)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from itertools import chain, combinations
from typing import Dict, Iterator, List, Optional, Set, Tuple

from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, SeqNode
from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, SeqNode, ConditionNode
from decompiler.structures.ast.reachability_graph import SiblingReachability
from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest
from decompiler.structures.logic.logic_condition import LogicCondition
Expand All @@ -23,17 +23,32 @@ class Formula:
- setting eq to false implies that two objects are equal and have the same hash iff they are the same object
"""

condition: LogicCondition
ast_node: AbstractSyntaxTreeNode

def clauses(self) -> List[LogicCondition]:
@property
def is_if_else_formula(self) -> bool:
"""Check whether condition of formula belongs to an if-else condition."""
return self.ast_node.reaching_condition.is_true and not self.ast_node.is_single_branch

@property
def condition(self) -> LogicCondition:
if self.ast_node.reaching_condition.is_true:
assert isinstance(self.ast_node, ConditionNode), "The ast-node must be a condition node if the RC is true"
return self.ast_node.condition
return self.ast_node.reaching_condition

def clauses(self) -> List[Clause]:
"""
Returns all clauses of the given formula in cnf-form.
- formula = (a | b) & (c | d) & e, it returns [a | b, c | d, e] --> here each operand is a new logic-condition
- formula = a | b | c, it returns [a | b | c] --> to ensure that we get a new logic-condition we copy it in this case.
"""
return list(self.condition.operands) if self.condition.is_conjunction else [self.condition.copy()]
if self.is_if_else_formula:
return [ClauseFormula(self.ast_node.condition.copy(), self)]
else:
clauses = list(self.condition.operands) if self.condition.is_conjunction else [self.condition.copy()]
return [Clause(c, self) for c in clauses]


@dataclass(frozen=True, eq=False)
Expand All @@ -47,6 +62,14 @@ class Clause:
formula: Formula


@dataclass(frozen=True, eq=False)
class ClauseFormula(Clause):
    """
    Dataclass for a logic-formula that cannot be split into clauses for the grouping and is therefore handled as one single clause.
    - Formula.clauses() creates it for if-else formulas, wrapping a copy of the complete condition of the condition node.
    - setting eq to false implies that two objects are equal and have the same hash iff they are the same object
    """


@dataclass(frozen=True, eq=True)
class Symbol:
"""
Expand All @@ -70,7 +93,7 @@ def __init__(self, candidates: List[AbstractSyntaxTreeNode]) -> None:
- unconsidered_nodes: a set of all nodes that we still have to consider for grouping into conditions.
- logic_graph: representation of all logic-formulas relevant
"""
self._candidates: Dict[AbstractSyntaxTreeNode, Formula] = {c: Formula(c.reaching_condition, c) for c in candidates}
self._candidates: Dict[AbstractSyntaxTreeNode, Formula] = {c: Formula(c) for c in candidates}
self._unconsidered_nodes: InsertionOrderedSet[AbstractSyntaxTreeNode] = InsertionOrderedSet()
self._logic_graph: DiGraph = DiGraph()
self._initialize_logic_graph()
Expand All @@ -89,9 +112,9 @@ def _initialize_logic_graph(self) -> None:
all_symbols = set()
for formula in self._candidates.values():
self._logic_graph.add_node(formula)
for logic_clause in formula.clauses():
self._logic_graph.add_edge(formula, clause := Clause(logic_clause, formula))
for symbol_name in logic_clause.get_symbols_as_string():
for clause in formula.clauses():
self._logic_graph.add_edge(formula, clause)
for symbol_name in clause.condition.get_symbols_as_string():
self._logic_graph.add_edge(clause, symbol := Symbol(symbol_name))
self._logic_graph.add_edge(formula, symbol, auxiliary=True)
all_symbols.add(symbol)
Expand All @@ -102,6 +125,10 @@ def candidates(self) -> Iterator[AbstractSyntaxTreeNode]:
"""Iterates over all candidates considered for grouping into conditions."""
yield from self._candidates

def get_condition(self, ast_node: AbstractSyntaxTreeNode) -> Tuple[LogicCondition, bool]:
    """
    Return the condition of the given candidate that is relevant for grouping into branches.

    :param ast_node: A candidate node; it must be a key of the candidate mapping.
    :return: The condition of the formula belonging to the node, and a flag telling whether this formula is an if-else formula.
    """
    formula = self._candidates[ast_node]
    return formula.condition, formula.is_if_else_formula

def maximum_subexpression_size(self) -> int:
"""Returns the maximum possible subexpression that is relevant to consider for clustering into conditions."""
if len(self._candidates) < 2:
Expand Down Expand Up @@ -130,6 +157,18 @@ def remove_ast_nodes(self, nodes_to_remove: List[AbstractSyntaxTreeNode]) -> Non
"""Remove formulas associated with the given nodes from the graph."""
self._remove_formulas(set(self._candidates[node] for node in nodes_to_remove))

def add_ast_node(self, condition_node: ConditionNode):
    """
    Register the given condition node as a new candidate and insert its formula into the logic-graph.

    The formula is connected to each of its clauses, each clause to the symbols it contains,
    and the formula additionally to each of these symbols via an auxiliary edge.
    """
    new_formula = Formula(condition_node)
    self._candidates[condition_node] = new_formula
    self._unconsidered_nodes.add(condition_node)

    graph = self._logic_graph
    graph.add_node(new_formula)
    for clause in new_formula.clauses():
        graph.add_edge(new_formula, clause)
        for name in clause.condition.get_symbols_as_string():
            symbol = Symbol(name)
            graph.add_edge(clause, symbol)
            # auxiliary edges connect a formula directly with the symbols occurring in it
            graph.add_edge(new_formula, symbol, auxiliary=True)

@property
def _auxiliary_graph(self) -> DiGraph:
"""Return a read-only view of the logic-graph containing only the auxiliary-edges, i.e., the edges between formulas and symbols."""
Expand Down Expand Up @@ -231,6 +270,7 @@ def __init__(self, asforest: AbstractSyntaxForest):
"""Init an instance of the condition-based refinement."""
self.asforest: AbstractSyntaxForest = asforest
self.root: AbstractSyntaxTreeNode = asforest.current_root
self._condition_candidates: Optional[ConditionCandidates] = None

@classmethod
def refine(cls, asforest: AbstractSyntaxForest) -> None:
Expand Down Expand Up @@ -297,28 +337,41 @@ def _structure_sequence_node(self, sequence_node: SeqNode) -> Set[SeqNode]:
"""
newly_created_sequence_nodes: Set[SeqNode] = set()
sibling_reachability: SiblingReachability = self.asforest.get_sibling_reachability_of_children_of(sequence_node)
condition_candidates = ConditionCandidates([child for child in sequence_node.children if not child.reaching_condition.is_true])
for child, subexpression in condition_candidates.get_next_subexpression():
true_cluster, false_cluster = self._cluster_by_condition(subexpression, child, condition_candidates)
self._condition_candidates = ConditionCandidates(
[child for child in sequence_node.children if not child.reaching_condition.is_true or isinstance(child, ConditionNode)]
)
for child, subexpression in self._condition_candidates.get_next_subexpression():
true_cluster, false_cluster, existing_if_else_conditions = self._cluster_by_condition(subexpression, child)
all_cluster_nodes = true_cluster + false_cluster

if len(all_cluster_nodes) < 2:
continue
if self._can_place_condition_node_with_branches(all_cluster_nodes, sibling_reachability):
for existing_if_else_cond in existing_if_else_conditions:
if existing_if_else_cond in true_cluster:
true_cluster.remove(existing_if_else_cond)
true_cluster.append(existing_if_else_cond.true_branch_child)
false_cluster.append(existing_if_else_cond.false_branch_child)
else:
false_cluster.remove(existing_if_else_cond)
true_cluster.append(existing_if_else_cond.false_branch_child)
false_cluster.append(existing_if_else_cond.true_branch_child)
self.asforest.transform_branch_to_reaching_conditions(existing_if_else_cond)
condition_node = self.asforest.create_condition_node_with(subexpression, true_cluster, false_cluster)
if len(true_cluster) > 1:
newly_created_sequence_nodes.add(condition_node.true_branch_child)
if len(false_cluster) > 1:
newly_created_sequence_nodes.add(condition_node.false_branch_child)
sibling_reachability.merge_siblings_to(condition_node, all_cluster_nodes)
sequence_node._sorted_children = sibling_reachability.sorted_nodes()
condition_candidates.remove_ast_nodes(all_cluster_nodes)
self._condition_candidates.add_ast_node(condition_node)
self._condition_candidates.remove_ast_nodes(all_cluster_nodes)

return newly_created_sequence_nodes

def _cluster_by_condition(
self, sub_expression: LogicCondition, node_with_subexpression: AbstractSyntaxTreeNode, condition_candidates: ConditionCandidates
) -> Tuple[List[AbstractSyntaxTreeNode], List[AbstractSyntaxTreeNode]]:
self, sub_expression: LogicCondition, node_with_subexpression: AbstractSyntaxTreeNode
) -> Tuple[List[AbstractSyntaxTreeNode], List[AbstractSyntaxTreeNode], List[ConditionNode]]:
"""
Cluster the nodes in sequence_nodes according to the input condition.
Expand All @@ -331,17 +384,30 @@ def _cluster_by_condition(
true_children = []
false_children = []
symbols_of_condition = set(sub_expression.get_symbols_as_string())
negated_condition = None
for ast_node in condition_candidates.candidates:
if symbols_of_condition - condition_candidates.get_symbol_names_of(ast_node):
negated_condition: Optional[LogicCondition] = None
existing_if_else_condition: List[ConditionNode] = []
for ast_node in self._condition_candidates.candidates:
if symbols_of_condition - self._condition_candidates.get_symbol_names_of(ast_node):
continue
if ast_node == node_with_subexpression or self._is_subexpression_of_cnf_formula(sub_expression, ast_node.reaching_condition):
condition, is_if_else_node = self._condition_candidates.get_condition(ast_node)
if (
ast_node == node_with_subexpression
or (not is_if_else_node and self._is_subexpression_of_cnf_formula(sub_expression, condition))
or (is_if_else_node and sub_expression.is_equivalent_to(condition))
):
true_children.append(ast_node)
if is_if_else_node:
existing_if_else_condition.append(ast_node)
else:
negated_condition = self._get_negated_condition_of(sub_expression, negated_condition)
if self._is_subexpression_of_cnf_formula(negated_condition, ast_node.reaching_condition):
if (not is_if_else_node and self._is_subexpression_of_cnf_formula(negated_condition, condition)) or (
is_if_else_node and negated_condition.is_equivalent_to(condition)
):
false_children.append(ast_node)
return true_children, false_children
if is_if_else_node:
existing_if_else_condition.append(ast_node)

return true_children, false_children, existing_if_else_condition

@staticmethod
def _get_negated_condition_of(condition: LogicCondition, negated_condition: Optional[LogicCondition]) -> LogicCondition:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ def can_be_applied(loop_node: AbstractSyntaxTreeNode):
return (
loop_node.is_endless_loop
and isinstance(body := loop_node.body, SeqNode)
and isinstance(condition_node := body.children[-1], ConditionNode)
and len(condition_node.children) == 1
and body.children[-1].is_single_branch
and not any(child._has_descendant_code_node_breaking_ancestor_loop() for child in body.children[:-1])
)

Expand Down
11 changes: 10 additions & 1 deletion decompiler/structures/ast/ast_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,11 @@ def is_code_node_ending_with_return(self) -> bool:
"""Checks whether the node is a CodeNode and ends with a return."""
return isinstance(self, CodeNode) and self.does_end_with_return

@property
def is_single_branch(self) -> bool:
    """Return True iff this node is a condition node that has exactly one child, i.e., exactly one branch."""
    if not isinstance(self, ConditionNode):
        return False
    return len(self.children) == 1

def get_end_nodes(self) -> Iterable[Union[CodeNode, SwitchNode, LoopNode, ConditionNode]]:
"""Yields all nodes where the subtree can terminate."""
for child in self.children:
Expand Down Expand Up @@ -595,8 +600,12 @@ def clean(self) -> None:
"""Standardizing a Condition node is to remove empty True/False Branches and to make sure that the true branch always exists."""
for dead_child in (child for child in self.children if child.child is None):
self._ast.remove_subtree(dead_child)
if len(self.children) == 1 and self.true_branch is None:
if (len(self.children) == 1 and self.true_branch is None) or self.condition.is_false:
self.switch_branches()
if self.condition.is_true:
if self.false_branch is not None:
self._ast.remove_subtree(self.false_branch)
self._ast.replace_condition_node_by_single_branch(self)
super().clean()

def replace_variable(self, replacee: Variable, replacement: Variable) -> None:
Expand Down
30 changes: 28 additions & 2 deletions decompiler/structures/ast/syntaxforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def extract_branch_from_condition_node(
"""
Extract the given Branch from the condition node.
-> Afterwards, the Branch must always be executed after the condition node.
-> Afterward, the Branch must always be executed after the condition node.
"""
assert isinstance(cond_node, ConditionNode) and branch in cond_node.children, f"{branch} must be a child of {cond_node}."
new_seq = self._add_sequence_node_before(cond_node)
Expand Down Expand Up @@ -311,6 +311,9 @@ def create_condition_node_with(
condition_node = self._add_condition_node_with(condition, true_branch, false_branch)
self._add_edge(parent, condition_node)

for branch in true_cases + false_cases:
branch.clean()

return condition_node

def __create_branch_for(self, branch_nodes: List[AbstractSyntaxTreeNode], condition: LogicCondition):
Expand All @@ -323,11 +326,34 @@ def __create_branch_for(self, branch_nodes: List[AbstractSyntaxTreeNode], condit
else:
branch = self.add_seq_node_with_reaching_condition_before(branch_nodes, self.condition_handler.get_true_value())
for node in branch_nodes:
node.reaching_condition.substitute_by_true(condition)
if node.reaching_condition.is_true:
assert isinstance(node, ConditionNode), "The node must be a condition node if its RC is true"
node.condition.substitute_by_true(condition)
else:
node.reaching_condition.substitute_by_true(condition)

self._remove_edge(branch.parent, branch)
return branch

def transform_branch_to_reaching_conditions(self, condition_node: ConditionNode) -> None:
    """
    Transform a condition node into a sequence-node having the branch-children as children with the according reaching-condition.

    - The true-branch child gets the condition of the condition node as reaching condition.
    - The false-branch child, if a false branch exists, gets the negated condition as reaching condition.

    :param condition_node: The condition node to dissolve; it and its branch nodes are handed to _remove_nodes_from.
    """
    # Bring the condition node into its standard form before its branches are accessed.
    condition_node.clean()
    # Save the parent now, because the condition node is removed further below.
    parent = condition_node.parent
    # NOTE(review): presumably this inserts a new sequence node between parent and the condition node -- confirm.
    new_seq_node = self._add_sequence_node_before(condition_node)

    self._add_edge(new_seq_node, condition_node.true_branch_child)
    condition_node.true_branch_child.reaching_condition = condition_node.condition
    nodes = [condition_node.true_branch_child]
    if condition_node.false_branch:
        self._add_edge(new_seq_node, condition_node.false_branch_child)
        condition_node.false_branch_child.reaching_condition = ~condition_node.condition
        nodes.append(condition_node.false_branch_child)
    # NOTE(review): false_branch can be None here (single-branch condition node);
    # this assumes _remove_nodes_from tolerates a None entry -- confirm.
    self._remove_nodes_from([condition_node, condition_node.true_branch, condition_node.false_branch])

    # Order of the new children: the former true-branch child first, then the former false-branch child.
    new_seq_node._sorted_children = tuple(nodes)
    new_seq_node.clean()
    parent.clean()

def create_switch_node_with(self, expression: Expression, cases: List[Tuple[CaseNode, AbstractSyntaxTreeNode]]) -> SwitchNode:
"""Create a switch node with the given expression and the given list of case nodes."""
assert (parent := self.have_same_parent([case[1] for case in cases])) is not None, "All case nodes must have the same parent."
Expand Down

0 comments on commit 13b8dfa

Please sign in to comment.