Skip to content

Commit

Permalink
Recover additional Switch cases (if switch in sequence of branch) (#396
Browse files Browse the repository at this point in the history
)
  • Loading branch information
ebehner authored Mar 28, 2024
1 parent 43c19d9 commit d9b8337
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
from typing import Optional, Set
from dataclasses import dataclass
from typing import Optional, Set, Tuple

from decompiler.pipeline.controlflowanalysis.restructuring_commons.condition_aware_refinement_commons.missing_case_finder import (
MissingCaseFinder,
)
from decompiler.pipeline.controlflowanalysis.restructuring_options import RestructuringOptions
from decompiler.structures.ast.ast_nodes import ConditionNode, SwitchNode
from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, ConditionalNode, ConditionNode, FalseNode, SeqNode, SwitchNode
from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest
from decompiler.structures.logic.logic_condition import LogicCondition
from decompiler.structures.pseudo import Constant


@dataclass
class CaseCandidateInformation:
    """
    Bundle describing a branch of a condition node that can be added to a switch node as a missing case.

    - case_node: the AST node that is handed over as the new case node (may be None).
    - case_constants: the constants the new case is inserted for.
    - switch_node: the switch node that receives the new case.
    - in_sequence: True if the parent of the switch node is a sequence node; in this case the switch is extracted from
      the condition-sequence instead of replacing the condition node by its single remaining branch.
    """

    # The field order is part of the interface: instances are created with positional arguments.
    case_node: Optional[AbstractSyntaxTreeNode]
    case_constants: Set[Constant]
    switch_node: SwitchNode
    in_sequence: bool


class MissingCaseFinderCondition(MissingCaseFinder):
"""
Class in charge of finding missing case for switch nodes in Condition nodes.
Expand All @@ -23,29 +32,33 @@ def find(cls, asforest: AbstractSyntaxForest, options: RestructuringOptions):
"""Try to find missing cases that are branches of condition nodes."""
missing_case_finder = cls(asforest, options)
for condition_node in asforest.get_condition_nodes_post_order(asforest.current_root):
if new_case_constants := missing_case_finder._can_insert_missing_case_node(condition_node):
if (case_candidate_information := missing_case_finder._can_insert_missing_case_node(condition_node)) is not None:
missing_case_finder._insert_case_node(
condition_node.false_branch_child, new_case_constants, condition_node.true_branch_child
case_candidate_information.case_node, case_candidate_information.case_constants, case_candidate_information.switch_node
)
asforest.replace_condition_node_by_single_branch(condition_node)
if case_candidate_information.in_sequence:
asforest.extract_switch_from_condition_sequence(case_candidate_information.switch_node, condition_node)
else:
asforest.replace_condition_node_by_single_branch(condition_node)

def _can_insert_missing_case_node(self, condition_node: ConditionNode) -> Optional[Set[Constant]]:
def _can_insert_missing_case_node(self, condition_node: ConditionNode) -> Optional[CaseCandidateInformation]:
"""
Check whether one of the branches is a possible case node for the other branch that should be a switch node.
Check whether one of the branches is a possible case node for the other branch that should be a switch node or a sequence node
having a switch-node as first or last child.
If this is the case, return the case-constants for the new case node.
-> We have to make sure that there exists a switch node where we can insert it and that it has the correct condition
-> The case constants can not exist in the switch node where we want to insert the case node.
-> We have to make sure that there exists a switch node where we can insert it and that it has the correct condition.
-> The case constants cannot exist in the switch node where we want to insert the case node.
:param condition_node: The condition node where we want to find a missing case.
:return: Return the set of constant for this switch node if it is a missing case and None otherwise.
"""
if len(condition_node.children) == 1 or not any(isinstance(branch.child, SwitchNode) for branch in condition_node.children):
if len(condition_node.children) == 1 or (switch_candidate := self.get_switch_candidate(condition_node)) is None:
return None
if isinstance(condition_node.false_branch_child, SwitchNode):
switch_node, branch = switch_candidate
if isinstance(branch, FalseNode):
condition_node.switch_branches()

switch_node: SwitchNode = condition_node.true_branch_child
possible_case_node = condition_node.false_branch_child
case_condition = condition_node.false_branch.branch_condition

Expand All @@ -62,7 +75,17 @@ def _can_insert_missing_case_node(self, condition_node: ConditionNode) -> Option

new_case_constants = set(self._get_case_constants_for_condition(case_condition))
if all(case.constant not in new_case_constants for case in switch_node.cases):
return new_case_constants
return CaseCandidateInformation(possible_case_node, new_case_constants, switch_node, isinstance(switch_node.parent, SeqNode))

def get_switch_candidate(self, condition_node: ConditionNode) -> Optional[Tuple[SwitchNode, ConditionalNode]]:
    """
    Search the branches of the given condition node for a switch node a missing case could be added to.

    A branch whose child is a switch node always wins. Only if no such branch exists, the branches whose child is a sequence node
    are inspected, and the first respectively last child of the sequence is considered (in this order).

    :param condition_node: The condition node whose branches are inspected.
    :return: The switch node together with the branch of the condition node containing it, or None if no switch node is found.
    """
    branches = condition_node.children

    # Branches that are directly a switch node have priority over switch nodes contained in a sequence.
    direct_branch = next((branch for branch in branches if isinstance(branch.child, SwitchNode)), None)
    if direct_branch is not None:
        return direct_branch.child, direct_branch

    for seq_branch in branches:
        if not isinstance(seq_branch.child, SeqNode):
            continue
        seq_children = seq_branch.child.children
        # Only the boundary positions of the sequence are candidates: first child, then last child.
        for position in (0, -1):
            boundary_node = seq_children[position]
            if isinstance(boundary_node, SwitchNode):
                return boundary_node, seq_branch
    return None

def __reachable_in_switch(self, case_condition: LogicCondition, switch_node: SwitchNode):
if switch_node.reaching_condition.is_true:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5735,3 +5735,87 @@ def test_intersecting_cases(task):
assert isinstance(children[0], CodeNode) and isinstance(children[5], CodeNode)
assert all(isinstance(child, ConditionNode) for child in children[1:4])
assert isinstance(children[4], SwitchNode) and isinstance(children[3].true_branch_child, SwitchNode)


def test_missing_cases_switch_in_sequence(task):
    """
    Sample: coreutils expr mpn_base_power_of_two_p.

    Here the switch-node is not a direct child of a condition-node. It is the first child of a sequence-node, and the parent of
    that sequence-node is the condition-node.
    """
    var_b = Variable("b", Integer(32, False), None, False, Variable("b", Integer(32, False), 0, False, None))
    var_0_16 = Variable("var_0", Integer(64, True), None, True, Variable("rax", Integer(64, True), 16, True, None))
    var_0_4 = Variable("var_0", Integer(64, True), None, True, Variable("rax_6", Integer(64, True), 4, True, None))

    def branch_block(address, operation, value):
        # Basic block consisting of a single branch comparing b against a 32-bit constant.
        return BasicBlock(address, [Branch(Condition(operation, [var_b, Constant(value, Integer(32, True))], CustomType("bool", 1)))])

    def assign_block(address, target, value):
        # Basic block consisting of a single assignment of a 64-bit constant.
        return BasicBlock(address, [Assignment(target, Constant(value, Integer(64, True)))])

    def return_block(address, returned):
        # Basic block consisting of a single return statement.
        return BasicBlock(address, [Return(ListOperation([returned]))])

    vertices = [
        branch_block(0, OperationType.greater_us, 32),
        branch_block(1, OperationType.equal, 128),
        branch_block(2, OperationType.less_or_equal_us, 1),
        assign_block(3, var_0_16, 7),
        branch_block(4, OperationType.not_equal, 256),
        assign_block(5, var_0_16, 0),
        return_block(6, Constant(0, Integer(64, True))),
        return_block(7, var_0_16),
        branch_block(8, OperationType.equal, 64),
        assign_block(9, var_0_16, 8),
        assign_block(10, var_0_4, 6),
        assign_block(11, var_0_4, 0),
        return_block(12, var_0_4),
    ]
    task.graph.add_nodes_from(vertices)

    # (edge type, index of source vertex, index of sink vertex)
    edge_table = [
        (TrueCase, 0, 1),
        (FalseCase, 0, 2),
        (TrueCase, 1, 3),
        (FalseCase, 1, 4),
        (TrueCase, 2, 5),
        (FalseCase, 2, 6),
        (UnconditionalEdge, 3, 7),
        (TrueCase, 4, 8),
        (FalseCase, 4, 9),
        (UnconditionalEdge, 5, 7),
        (TrueCase, 8, 10),
        (FalseCase, 8, 11),
        (UnconditionalEdge, 9, 7),
        (UnconditionalEdge, 10, 12),
        (UnconditionalEdge, 11, 12),
    ]
    task.graph.add_edges_from([edge_type(vertices[source], vertices[sink]) for edge_type, source, sink in edge_table])

    PatternIndependentRestructuring().run(task)

    # Root: a sequence consisting of a condition node followed by the code of vertex 7.
    seq_node = task._ast.root
    assert isinstance(seq_node, SeqNode)
    root_children = seq_node.children
    assert len(root_children) == 2
    cn = root_children[0]
    assert isinstance(cn, ConditionNode)
    assert cn.true_branch
    assert cn.false_branch
    assert isinstance(root_children[1], CodeNode)
    assert root_children[1].instructions == vertices[7].instructions

    # Normalize the condition node so that its condition is a plain symbol before inspecting the branches.
    if cn.condition.is_negation:
        cn.switch_branches()
    assert cn.condition.is_symbol
    assert task._ast.condition_map[cn.condition] == vertices[0].instructions[0].condition

    # True branch
    true_seq = cn.true_branch_child
    assert isinstance(true_seq, SeqNode)
    true_children = true_seq.children
    assert len(true_children) == 2

    sw = true_children[0]
    assert isinstance(sw, SwitchNode)
    assert len(sw.cases) == 3
    assert sw.default is not None

    cn_t = true_children[1]
    assert isinstance(cn_t, ConditionNode)
    assert cn_t.condition.is_conjunction
    operands = cn_t.condition.operands
    assert len(operands) == 2

    code = cn_t.true_branch_child
    assert isinstance(code, CodeNode)
    assert code.instructions == vertices[12].instructions
    assert cn_t.false_branch is None

    # One operand of the conjunction has to be a symbol, one has to be a negation (of a symbol).
    operand_pair = (operands[0], operands[1])
    cc_1 = next((operand for operand in operand_pair if operand.is_symbol), None)
    assert cc_1 is not None
    cc_2 = next((operand for operand in operand_pair if operand.is_negation), None)
    assert cc_2 is not None
    cc_2 = ~cc_2
    assert cc_2.is_symbol
    actual_conditions = (task._ast.condition_map[cc_1], task._ast.condition_map[cc_2])
    expected_conditions = (vertices[4].instructions[0].condition, vertices[1].instructions[0].condition)
    assert actual_conditions == expected_conditions

    # False branch
    false_seq = cn.false_branch_child
    assert isinstance(false_seq, SeqNode)
    false_children = false_seq.children
    assert len(false_children) == 2

    cn_false = false_children[0]
    assert isinstance(cn_false, ConditionNode)
    assert cn_false.false_branch is None
    assert cn_false.condition.is_negation

    code_return = cn_false.true_branch_child
    assert isinstance(code_return, CodeNode)
    assert code_return.instructions == vertices[6].instructions
    assert task._ast.condition_map[~cn_false.condition] == vertices[2].instructions[0].condition

    code = false_children[1]
    assert isinstance(code, CodeNode)
    assert code.instructions == vertices[5].instructions

0 comments on commit d9b8337

Please sign in to comment.