Skip to content

Commit

Permalink
Refine Initial Switch Construction (#114)
Browse files Browse the repository at this point in the history
Co-authored-by: jnhols <[email protected]>
Co-authored-by: jnhols <[email protected]>
Co-authored-by: Eva-Maria Behner <[email protected]>
Co-authored-by: Steffen Enders <[email protected]>
  • Loading branch information
5 people authored Jun 20, 2023
1 parent 7b3264b commit d3d8001
Show file tree
Hide file tree
Showing 23 changed files with 990 additions and 142 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
SwitchExtractor,
)
from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest
from decompiler.structures.logic.logic_condition import LogicCondition


class ConditionAwareRefinement(BaseClassConditionAwareRefinement):
Expand All @@ -28,10 +27,6 @@ class ConditionAwareRefinement(BaseClassConditionAwareRefinement):
MissingCaseFinder.find_in_sequence,
]

def __init__(self, asforest: AbstractSyntaxForest):
self.asforest = asforest
super().__init__(asforest.condition_handler)

@classmethod
def refine(cls, asforest: AbstractSyntaxForest):
condition_aware_refinement = cls(asforest)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,10 @@

from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, CaseNode, SwitchNode
from decompiler.structures.ast.condition_symbol import ConditionHandler
from decompiler.structures.ast.switch_node_handler import ExpressionUsages
from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest
from decompiler.structures.logic.logic_condition import LogicCondition, PseudoLogicCondition
from decompiler.structures.pseudo import Condition, Constant, Expression, OperationType, Variable


@dataclass(frozen=True)
class ExpressionUsages:
"""Dataclass that maintain for a condition the used SSA-variables."""

expression: Expression
ssa_usages: Tuple[Optional[Variable]]
from decompiler.structures.pseudo import Condition, Constant, Expression, OperationType


@dataclass
Expand Down Expand Up @@ -48,12 +42,13 @@ def __hash__(self) -> int:
class BaseClassConditionAwareRefinement:
"""Base Class in charge of logic and condition related things we need during the condition aware refinement."""

def __init__(self, condition_handler: ConditionHandler):
self.condition_handler: ConditionHandler = condition_handler
def __init__(self, asforest: AbstractSyntaxForest):
self.asforest: AbstractSyntaxForest = asforest
self.condition_handler: ConditionHandler = asforest.condition_handler

def _get_constant_equality_check_expressions_and_conditions(
self, condition: LogicCondition
) -> Iterator[Tuple[Expression, LogicCondition]]:
) -> Iterator[Tuple[ExpressionUsages, LogicCondition]]:
"""
Check whether the given condition is a simple comparison of an expression with one or more constants + perhaps a conjunction
with another condition.
Expand All @@ -65,11 +60,11 @@ def _get_constant_equality_check_expressions_and_conditions(
if condition.is_conjunction:
for disjunction in condition.operands:
if expression := self._get_const_eq_check_expression_of_disjunction(disjunction):
yield (expression, disjunction)
yield expression, disjunction
elif expression := self._get_const_eq_check_expression_of_disjunction(condition):
yield (expression, condition)
yield expression, condition

def _get_const_eq_check_expression_of_disjunction(self, condition: LogicCondition) -> Optional[Expression]:
def _get_const_eq_check_expression_of_disjunction(self, condition: LogicCondition) -> Optional[ExpressionUsages]:
"""
Check whether the given condition is a composition of comparisons of the same expression with constants.
Expand All @@ -89,47 +84,21 @@ def _get_const_eq_check_expression_of_disjunction(self, condition: LogicConditio
compared_expressions = [self._get_expression_compared_with_constant(literal) for literal in operands]
if len(set(compared_expressions)) != 1 or compared_expressions[0] is None:
return None
used_variables = tuple(var.ssa_name for var in compared_expressions[0].requirements)
return (
compared_expressions[0]
if all(used_variables == tuple(var.ssa_name for var in expression.requirements) for expression in compared_expressions[1:])
else None
)

def _get_expression_compared_with_constant(self, reaching_condition: LogicCondition) -> Optional[Expression]:
return compared_expressions[0]

def _get_expression_compared_with_constant(self, reaching_condition: LogicCondition) -> Optional[ExpressionUsages]:
"""
Check whether the given reaching condition, which is a literal, i.e., a z3-symbol or its negation is of the form `expr == const`.
If this is the case, then we return the expression `expr`.
"""
condition = self._get_literal_condition(reaching_condition)
if condition is not None and condition.operation == OperationType.equal:
return self._get_expression_compared_with_constant_in(condition)
return None

def _get_literal_condition(self, condition: LogicCondition) -> Optional[Condition]:
"""Check whether the given condition is a literal. If this is the case then it returns the condition that belongs to the literal."""
if condition.is_symbol:
return self.condition_handler.get_condition_of(condition)
if condition.is_negation and (neg_cond := ~condition).is_symbol:
return self.condition_handler.get_condition_of(neg_cond).negate()
return None

@staticmethod
def _get_expression_compared_with_constant_in(condition: Condition) -> Optional[Expression]:
"""
Check whether the given condition, of type Condition, compares a constant with an expression
return self.asforest.switch_node_handler.get_potential_switch_expression(reaching_condition)

- If this is the case, the function returns the expression
- Otherwise, it returns None.
def _get_constant_compared_with_expression(self, reaching_condition: LogicCondition) -> Optional[Constant]:
"""
non_constants = [operand for operand in condition.operands if not isinstance(operand, Constant)]
return non_constants[0] if len(non_constants) == 1 else None

@staticmethod
def _get_constant_compared_in_condition(condition: Condition) -> Optional[Constant]:
"""Return the constant of a Condition, i.e., for `expr == const` it returns `const`."""
constant_operands = [operand for operand in condition.operands if isinstance(operand, Constant)]
return constant_operands[0] if len(constant_operands) == 1 else None
Check whether the given reaching condition, which is a literal, i.e., a z3-symbol or its negation is of the form `expr == const`.
If this is the case, then we return the constant `const`.
"""
return self.asforest.switch_node_handler.get_potential_switch_constant(reaching_condition)

def _convert_to_z3_condition(self, condition: LogicCondition) -> PseudoLogicCondition:
return PseudoLogicCondition.initialize_from_formula(condition, self.condition_handler.get_z3_condition_map())
Expand All @@ -145,7 +114,7 @@ def _condition_is_redundant_for_switch_node(self, switch_node: AbstractSyntaxTre
"""
1. Check whether the given node is a switch node.
2. If this is the case then we check whether condition is always fulfilled when one of the switch cases is fulfilled
and return the switch node. Otherwise we return None.
and return the switch node. Otherwise, we return None.
- If the switch node has a default case, then we can not add any more cases.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from decompiler.pipeline.controlflowanalysis.restructuring_commons.condition_aware_refinement_commons.base_class_car import (
BaseClassConditionAwareRefinement,
CaseNodeCandidate,
ExpressionUsages,
)
from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, CaseNode, CodeNode, ConditionNode, SeqNode, SwitchNode
from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, CaseNode, CodeNode, ConditionNode, SeqNode, SwitchNode, TrueNode
from decompiler.structures.ast.reachability_graph import CaseDependencyGraph, LinearOrderDependency, SiblingReachability
from decompiler.structures.ast.switch_node_handler import ExpressionUsages
from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest
from decompiler.structures.logic.logic_condition import LogicCondition
from decompiler.structures.pseudo import Condition, Constant, Expression
from decompiler.structures.pseudo import Break, Constant, Expression


@dataclass
Expand All @@ -24,26 +24,117 @@ class SwitchNodeCandidate:
def construct_switch_cases(self) -> Iterator[Tuple[CaseNode, AbstractSyntaxTreeNode]]:
"""Construct Switch-case for itself."""
for case_candidate in self.cases:
yield (case_candidate.construct_case_node(self.expression), case_candidate.node)
yield case_candidate.construct_case_node(self.expression), case_candidate.node


class InitialSwitchNodeConstructor(BaseClassConditionAwareRefinement):
"""Class that constructs switch nodes."""

def __init__(self, asforest: AbstractSyntaxForest):
"""
self.asforest: The asforst where we try to construct switch nodes
"""
self.asforest = asforest
super().__init__(asforest.condition_handler)

@classmethod
def construct(cls, asforest: AbstractSyntaxForest):
"""Constructs initial switch nodes if possible."""
initial_switch_constructor = cls(asforest)
for cond_node in asforest.get_condition_nodes_post_order(asforest.current_root):
initial_switch_constructor._extract_case_nodes_from_nested_condition(cond_node)
for seq_node in asforest.get_sequence_nodes_post_order(asforest.current_root):
initial_switch_constructor._try_to_construct_initial_switch_node_for(seq_node)

def _extract_case_nodes_from_nested_condition(self, cond_node: ConditionNode) -> None:
"""
Extract CaseNodeCandidates from nested if-conditions.
- Nested if-conditions can belong to a switch, i.e., Condition node whose condition is a '==' or '!=' comparison of a variable v and
a constant, i.e., v == 2 or v != 2
- The branch with the '!=' condition is
(i) either a Condition node whose condition is a '==' or '!=' comparison of the same variable v and a different constant, or a
Code node whose reaching condition is of this form, i.e., v == 1 or v != 1
(ii) a sequence node whose first and last node is a condition node or code node with the properties described in (i)
- We extract the conditions into a sequence, such that _try_to_construct_initial_switch_node_for can reconstruct the switch.
"""
if cond_node.false_branch is None:
return
if first_case_candidate_expression := self._get_possible_case_candidate_for_condition_node(cond_node):
if second_case_candidate := self._second_case_candidate_exists_in_branch(
cond_node.false_branch_child, first_case_candidate_expression
):
self._extract_conditions_to_obtain_switch(cond_node, second_case_candidate)

def _get_possible_case_candidate_for_condition_node(self, cond_node: ConditionNode) -> Optional[ExpressionUsages]:
"""
Check whether one branch condition is a possible switch case
- Make sure, that the possible switch case is always the true-branch
- If we find a candidate, return a CaseNodeCandidate containing the branch and the switch expression, else return None.
"""
possible_expressions: List[Tuple[ExpressionUsages, LogicCondition]] = list(
self._get_constant_equality_check_expressions_and_conditions(cond_node.condition)
)
if not possible_expressions and cond_node.false_branch_child:
if possible_expressions := list(self._get_constant_equality_check_expressions_and_conditions(~cond_node.condition)):
cond_node.switch_branches()

if len(possible_expressions) == 1:
return possible_expressions[0][0]

def _second_case_candidate_exists_in_branch(
self, ast_node: AbstractSyntaxTreeNode, first_case_expression: ExpressionUsages
) -> Optional[AbstractSyntaxTreeNode]:
"""
Check whether a possible case candidate whose expression is equal to first_case_expression, is contained in the given ast_node.
- The case candidate can either be:
- the ast-node itself if the reaching condition matches a case-condition
- the true or false branch if the ast_node is a condition node where the condition or negation matches a case-condition
- the first or last child, if the node is a Sequence node and it has one of the above conditions.
"""
candidates = [ast_node]
if isinstance(ast_node, SeqNode):
candidates += [ast_node.children[0], ast_node.children[-1]]
for node in candidates:
second_case_candidate = self._find_second_case_candidate_in(node)
if second_case_candidate is not None and second_case_candidate[0] == first_case_expression:
return second_case_candidate[1]

def _find_second_case_candidate_in(self, ast_node: AbstractSyntaxTreeNode) -> Optional[Tuple[ExpressionUsages, AbstractSyntaxTreeNode]]:
"""Check whether the ast-node fulfills the properties of the second-case node to extract from nested conditions."""
if isinstance(ast_node, ConditionNode):
return self._get_possible_case_candidate_for_condition_node(ast_node), ast_node.true_branch_child
if case_candidate := self._get_possible_case_candidate_for(ast_node):
return case_candidate.expression, ast_node

def _extract_conditions_to_obtain_switch(self, cond_node: ConditionNode, second_case_node: AbstractSyntaxTreeNode) -> None:
"""
First of all, we extract both branches of the condition node and handle the reaching conditions.
If a branch contains a sequence node, we propagate the reaching condition to its children. This ensures that
the sequence node can be cleaned and the possible case candidates are all children of the same sequence node.
"""
first_case_node = cond_node.true_branch_child
first_case_node.reaching_condition &= cond_node.condition

common_condition = LogicCondition.conjunction_of(self.__parent_conditions(second_case_node, cond_node))
second_case_node.reaching_condition &= common_condition

default_case_node = None

if isinstance(second_case_node.parent, TrueNode):
inner_condition_node = second_case_node.parent.parent
assert isinstance(inner_condition_node, ConditionNode), "parent of True Branch must be a condition node."
second_case_node.reaching_condition &= inner_condition_node.condition
if default_case_node := inner_condition_node.false_branch_child:
default_case_node.reaching_condition &= LogicCondition.conjunction_of(
(common_condition, ~inner_condition_node.condition, ~cond_node.condition)
)

cond_node.reaching_condition = self.condition_handler.get_true_value()
self.asforest.extract_branch_from_condition_node(cond_node, cond_node.true_branch, update_reachability=False)
new_seq_node = cond_node.parent
if default_case_node:
self.asforest._remove_edge(default_case_node.parent, default_case_node)
self.asforest._add_edge(new_seq_node, default_case_node)
self.asforest._remove_edge(second_case_node.parent, second_case_node)
self.asforest._add_edge(new_seq_node, second_case_node)
self.asforest.clean_up(new_seq_node)

def _try_to_construct_initial_switch_node_for(self, seq_node: SeqNode) -> None:
"""
Construct a switch node whose cases are children of the current sequence node.
Expand Down Expand Up @@ -96,14 +187,13 @@ def _get_possible_case_candidate_for(self, ast_node: AbstractSyntaxTreeNode) ->
- Otherwise, the function returns None.
- Note: Cases can not end with a loop-break statement
"""
possible_expressions: List[Tuple[Expression, LogicCondition]] = list()
possible_conditions: List[Tuple[ExpressionUsages, LogicCondition]] = list()
if (possible_case_condition := ast_node.get_possible_case_candidate_condition()) is not None:
possible_expressions = list(self._get_constant_equality_check_expressions_and_conditions(possible_case_condition))
possible_conditions = list(self._get_constant_equality_check_expressions_and_conditions(possible_case_condition))

if len(possible_expressions) == 1:
expression, condition = possible_expressions[0]
used_variables = tuple(var.ssa_name for var in expression.requirements)
return CaseNodeCandidate(ast_node, ExpressionUsages(expression, used_variables), possible_expressions[0][1])
if len(possible_conditions) == 1:
expression_usage, condition = possible_conditions[0]
return CaseNodeCandidate(ast_node, expression_usage, condition)

return None

Expand Down Expand Up @@ -168,6 +258,8 @@ def _add_constants_to_cases(self, switch_node: SwitchNode, case_dependency_graph
new_start_node = self._add_constants_for_linear_order_starting_at(
starting_case, linear_ordering_starting_at, linear_order_dependency_graph, considered_conditions
)
if starting_case in cross_nodes and starting_case != new_start_node:
cross_nodes = [new_start_node if id(n) == id(starting_case) else n for n in cross_nodes]
conditions_considered_at[new_start_node] = considered_conditions
self._get_linear_order_for(cross_nodes, linear_ordering_starting_at, linear_order_dependency_graph)
else:
Expand Down Expand Up @@ -237,8 +329,7 @@ def _add_constants_to_cases_for(
self._update_reaching_condition_of(case_node, considered_conditions)

if case_node.reaching_condition.is_literal:
condition: Condition = self._get_literal_condition(case_node.reaching_condition)
case_node.constant = self._get_constant_compared_in_condition(condition)
case_node.constant = self._get_constant_compared_with_expression(case_node.reaching_condition)
considered_conditions.add(case_node.reaching_condition)
elif case_node.reaching_condition.is_false:
case_node.constant = Constant("add_to_previous_case")
Expand Down Expand Up @@ -320,8 +411,8 @@ def prepend_empty_cases_to_case_with_or_condition(self, case: CaseNode) -> List[
"""
condition_for_constant: Dict[Constant, LogicCondition] = dict()
for literal in case.reaching_condition.operands:
if condition := self._get_literal_condition(literal):
condition_for_constant[self._get_constant_compared_in_condition(condition)] = literal
if constant := self._get_constant_compared_with_expression(literal):
condition_for_constant[constant] = literal
else:
raise ValueError(
f"The case node should have a reaching-condition that is a disjunction of literals, but it has the clause {literal}."
Expand Down Expand Up @@ -471,3 +562,9 @@ def _clean_up_reaching_conditions(self, switch_node: SwitchNode) -> None:
continue
elif not case_node.reaching_condition.is_true:
raise ValueError(f"{case_node} should have a literal as reaching condition, but RC = {case_node.reaching_condition}.")

def __parent_conditions(self, second_case_node: AbstractSyntaxTreeNode, cond_node: ConditionNode):
yield self.condition_handler.get_true_value()
current_node = second_case_node
while (current_node := current_node.parent) != cond_node:
yield current_node.reaching_condition
Loading

0 comments on commit d3d8001

Please sign in to comment.