Skip to content

Commit

Permalink
SwitchExtractor Improvement (#424)
Browse files Browse the repository at this point in the history
ebehner authored Jun 24, 2024
1 parent df6e516 commit 68d1d13
Showing 4 changed files with 176 additions and 60 deletions.
Original file line number Diff line number Diff line change
@@ -37,7 +37,7 @@ def find(cls, asforest: AbstractSyntaxForest, options: RestructuringOptions):
case_candidate_information.case_node, case_candidate_information.case_constants, case_candidate_information.switch_node
)
if case_candidate_information.in_sequence:
asforest.extract_switch_from_condition_sequence(case_candidate_information.switch_node, condition_node)
asforest.extract_switch_from_sequence(case_candidate_information.switch_node)
else:
asforest.replace_condition_node_by_single_branch(condition_node)

Original file line number Diff line number Diff line change
@@ -4,54 +4,46 @@
BaseClassConditionAwareRefinement,
)
from decompiler.pipeline.controlflowanalysis.restructuring_options import RestructuringOptions
from decompiler.structures.ast.ast_nodes import ConditionNode, FalseNode, SeqNode, TrueNode
from decompiler.structures.ast.ast_nodes import ConditionNode, FalseNode, SeqNode, SwitchNode, TrueNode
from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest
from decompiler.structures.logic.logic_condition import LogicCondition


class SwitchExtractor(BaseClassConditionAwareRefinement):
"""Extract switch nodes from condition nodes if the condition node is irrelevant for the switch node."""

def __init__(self, asforest: AbstractSyntaxForest, options: RestructuringOptions):
"""
self.current_cond_node: The condition node which we consider to extract switch nodes.
"""
super().__init__(asforest, options)
self._current_cond_node: Optional[ConditionNode] = None

@classmethod
def extract(cls, asforest: AbstractSyntaxForest, options: RestructuringOptions):
"""
Extract switch nodes from condition nodes, i.e., if a switch node is a branch of a condition node whose condition is redundant for
the switch node, we extract it from the condition node.
"""
"""Extract switch nodes from condition nodes, or sequence-nodes with a non-trivial reaching-condition."""
switch_extractor = cls(asforest, options)
for condition_node in asforest.get_condition_nodes_post_order(asforest.current_root):
switch_extractor._current_cond_node = condition_node
switch_extractor._extract_switches_from_condition()
for switch_node in list(asforest.get_switch_nodes_post_order(asforest.current_root)):
while switch_extractor._successfully_extracts_switch_nodes(switch_node):
pass

def _extract_switches_from_condition(self) -> None:
"""Extract switch nodes in the true and false branch of the given condition node."""
if self._current_cond_node.false_branch:
self._try_to_extract_switch_from_branch(self._current_cond_node.false_branch)
if self._current_cond_node.true_branch:
self._try_to_extract_switch_from_branch(self._current_cond_node.true_branch)
if self._current_cond_node in self.asforest:
self._current_cond_node.clean()

def _try_to_extract_switch_from_branch(self, branch: Union[TrueNode, FalseNode]) -> None:
def _successfully_extracts_switch_nodes(self, switch_node: SwitchNode) -> bool:
"""
1. If the given branch of the condition node is a switch node,
then extract it if the reaching condition is redundant for the switch node.
2. If the given branch of the condition node is a sequence node whose first or last node is a switch node,
then extract it if the reaching condition is redundant for the switch node.
We extract the given switch-node, if possible, and return whether it was successfully extracted.
1. If the switch node has a sequence node as parent and is its first or last child
i) Sequence node has a non-trivial reaching-condition
--> extract the switch from the sequence node if the reaching-condition is redundant for the switch
ii) Sequence node has a trivial reaching-condition, and its parent is a branch of a condition node
--> extract the switch from the condition-node if the branch-condition is redundant for the switch
2. If the switch node has a branch of a condition-node as parent
--> extract the switch from the condition node if the branch-condition is redundant for the switch
"""
branch_condition = branch.branch_condition
if self._condition_is_redundant_for_switch_node(branch.child, branch_condition):
self._extract_switch_node_from_branch(branch)
elif isinstance(sequence_node := branch.child, SeqNode):
for switch_node in [sequence_node.children[0], sequence_node.children[-1]]:
if self._condition_is_redundant_for_switch_node(switch_node, branch_condition):
self.asforest.extract_switch_from_condition_sequence(switch_node, self._current_cond_node)
switch_parent = switch_node.parent
if isinstance(switch_parent, SeqNode):
if not switch_parent.reaching_condition.is_true:
return self._successfully_extract_switch_from_first_or_last_child_of(switch_parent, switch_parent.reaching_condition)
elif isinstance(branch := switch_parent.parent, TrueNode | FalseNode):
return self._successfully_extract_switch_from_first_or_last_child_of(switch_parent, branch.branch_condition)
elif isinstance(switch_parent, TrueNode | FalseNode) and self._condition_is_redundant_for_switch_node(
switch_node, switch_parent.branch_condition
):
self._extract_switch_node_from_branch(switch_parent)
return True
return False

def _extract_switch_node_from_branch(self, branch: Union[TrueNode, FalseNode]) -> None:
"""
@@ -64,7 +56,20 @@ def _extract_switch_node_from_branch(self, branch: Union[TrueNode, FalseNode]) -
:param branch: The branch from which we extract the switch node.
:return: If we introduce a new sequence node, then return this node, otherwise return None.
"""
if len(self._current_cond_node.children) != 2:
self.asforest.replace_condition_node_by_single_branch(self._current_cond_node)
assert isinstance(condition_node := branch.parent, ConditionNode), "The parent of a true/false-branch must be a condition node!"
if len(condition_node.children) != 2:
self.asforest.replace_condition_node_by_single_branch(condition_node)
else:
self.asforest.extract_branch_from_condition_node(self._current_cond_node, branch, False)
self.asforest.extract_branch_from_condition_node(condition_node, branch, False)

def _successfully_extract_switch_from_first_or_last_child_of(self, sequence_node: SeqNode, condition: LogicCondition) -> bool:
"""
Check whether the first or last child of the sequence node is a switch-node for which the given condition is redundant.
If this is the case, extract the switch-node from the sequence.
"""
for switch_node in [sequence_node.children[0], sequence_node.children[-1]]:
if self._condition_is_redundant_for_switch_node(switch_node, condition):
assert isinstance(switch_node, SwitchNode), f"The node {switch_node} must be a switch-node!"
self.asforest.extract_switch_from_sequence(switch_node)
return True
return False
50 changes: 33 additions & 17 deletions decompiler/structures/ast/syntaxforest.py
Original file line number Diff line number Diff line change
@@ -241,26 +241,42 @@ def extract_branch_from_condition_node(
if new_seq.parent is not None:
new_seq.parent.clean()

def extract_switch_from_condition_sequence(self, switch_node: SwitchNode, condition_node: ConditionNode):
"""Extract the given switch-node, that is the first or last child of a seq-node Branch from the condition node"""
seq_node_branch = switch_node.parent
seq_node_branch_children = seq_node_branch.children
assert seq_node_branch.parent in condition_node.children, f"{seq_node_branch} must be a branch of {condition_node}"
new_seq_node = self._add_sequence_node_before(condition_node)
self._remove_edge(seq_node_branch, switch_node)
self._add_edge(new_seq_node, switch_node)
if switch_node is seq_node_branch_children[0]:
new_seq_node._sorted_children = (new_seq_node, condition_node)
seq_node_branch._sorted_children = seq_node_branch_children[1:]
elif switch_node is seq_node_branch_children[-1]:
new_seq_node._sorted_children = (condition_node, new_seq_node)
seq_node_branch._sorted_children = seq_node_branch_children[:-1]

seq_node_branch.clean()
condition_node.clean()
def extract_switch_from_sequence(self, switch_node: SwitchNode):
"""
Extract the given switch-node, that is the first or last child of a seq-node Branch from the condition node
or sequence node with a non-trivial reaching-condition.
"""
switch_parent = switch_node.parent
assert isinstance(switch_parent, SeqNode), f"The parent of the switch-node {switch_node} must be a sequence node!"
if isinstance(switch_parent, SeqNode) and not switch_parent.reaching_condition.is_true:
new_seq_node = self._extract_switch_from_subtree(switch_parent, switch_node)
elif isinstance(condition_node := switch_parent.parent.parent, ConditionNode):
new_seq_node = self._extract_switch_from_subtree(condition_node, switch_node)
condition_node.clean()
else:
raise ValueError(
f"The parent of the switch node {switch_node} must either have a non-trivial reaching-condition or is a branch of a condition-node!"
)

if new_seq_node.parent is not None:
new_seq_node.parent.clean()

def _extract_switch_from_subtree(self, subtree_head: AbstractSyntaxTreeNode, switch_node: SwitchNode):
switch_parent = switch_node.parent
switch_parent_children = switch_node.children
new_seq_node = self._add_sequence_node_before(subtree_head)
self._remove_edge(switch_parent, switch_node)
self._add_edge(new_seq_node, switch_node)
if switch_node is switch_parent_children[0]:
new_seq_node._sorted_children = (new_seq_node, subtree_head)
switch_parent._sorted_children = switch_parent_children[1:]
elif switch_node is switch_parent_children[-1]:
new_seq_node._sorted_children = (subtree_head, new_seq_node)
switch_parent._sorted_children = switch_parent_children[:-1]

switch_parent.clean()
return new_seq_node

def extract_all_breaks_from_condition_node(self, cond_node: ConditionNode):
"""Remove all break instructions at the end of the condition node and extracts them, i.e., add a break after the condition."""
for node in cond_node.get_end_nodes():
Original file line number Diff line number Diff line change
@@ -8,15 +8,18 @@
from decompiler.pipeline.controlflowanalysis.restructuring_commons.condition_aware_refinement_commons.missing_case_finder_intersecting_constants import (
MissingCaseFinderIntersectingConstants,
)
from decompiler.pipeline.controlflowanalysis.restructuring_commons.condition_aware_refinement_commons.switch_extractor import (
SwitchExtractor,
)
from decompiler.pipeline.controlflowanalysis.restructuring_options import LoopBreakOptions, RestructuringOptions
from decompiler.structures.ast.ast_nodes import ConditionNode, SeqNode, SwitchNode
from decompiler.structures.ast.ast_nodes import CodeNode, ConditionNode, SeqNode, SwitchNode
from decompiler.structures.ast.condition_symbol import ConditionHandler
from decompiler.structures.ast.reachability_graph import SiblingReachabilityGraph
from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest
from decompiler.structures.graphs.cfg import BasicBlock, ControlFlowGraph, FalseCase, TrueCase, UnconditionalEdge
from decompiler.structures.pseudo.expressions import Constant, Variable
from decompiler.structures.pseudo.expressions import Constant, ImportedFunctionSymbol, Variable
from decompiler.structures.pseudo.instructions import Assignment, Branch, Return
from decompiler.structures.pseudo.operations import BinaryOperation, Condition, OperationType
from decompiler.structures.pseudo.operations import BinaryOperation, Call, Condition, ListOperation, OperationType
from decompiler.structures.pseudo.typing import CustomType, Integer
from decompiler.task import DecompilerTask

@@ -183,3 +186,95 @@ def test_insert_intersecting_cases_anywhere(task):

assert isinstance(ast.current_root, SeqNode) and len(ast.current_root.children) == 1
assert isinstance(switch := ast.current_root.children[0], SwitchNode) and switch.cases == (case2, case1)


def test_switch_extractor_sequence(task):
"""Test, switch gets extracted from sequence nodes with Reaching Condition."""
condition_handler = ConditionHandler()
# cond_1_symbol = condition_handler.add_condition(Condition(OperationType.equal, [var_c, const[1]]))
cond_2_symbol = condition_handler.add_condition(Condition(OperationType.not_equal, [var_c, const[1]]))

ast = AbstractSyntaxForest(condition_handler=condition_handler)
root = ast.factory.create_seq_node(reaching_condition=cond_2_symbol)
code_node = ast.factory.create_code_node(
[Assignment(ListOperation([]), Call(ImportedFunctionSymbol("scanf", 0x42), [Constant(0x804B01F), var_c]))]
)
switch = ast.factory.create_switch_node(var_c)
case1 = ast.factory.create_case_node(var_c, const[2], break_case=True)
case2 = ast.factory.create_case_node(var_c, const[3], break_case=True)
case_content = [
ast.factory.create_code_node([Assignment(var_b, BinaryOperation(OperationType.plus, [var_b, const[i + 1]]))]) for i in range(2)
]
ast._add_nodes_from(case_content + [root, code_node, switch, case1, case2])
ast._add_edges_from(
[
(root, code_node),
(root, switch),
(switch, case1),
(switch, case2),
(case1, case_content[0]),
(case2, case_content[1]),
]
)
ast._code_node_reachability_graph.add_reachability_from(
[(code_node, case_content[0]), (code_node, case_content[1]), (case_content[0], case_content[1])]
)
root.sort_children()
switch.sort_cases()
ast.set_current_root(root)

SwitchExtractor.extract(ast, RestructuringOptions(True, True, 2, LoopBreakOptions.structural_variable))
assert isinstance(ast.current_root, SeqNode) and ast.current_root.reaching_condition.is_true and len(ast.current_root.children) == 2
assert ast.current_root.children[0].reaching_condition == cond_2_symbol
assert isinstance(switch := ast.current_root.children[1], SwitchNode) and switch.cases == (case1, case2)


def test_switch_extractor_sequence_no_extraction(task):
"""Test, switch gets extracted from sequence nodes with Reaching Condition."""
condition_handler = ConditionHandler()
# cond_1_symbol = condition_handler.add_condition(Condition(OperationType.equal, [var_c, const[1]]))
cond_1_symbol = condition_handler.add_condition(Condition(OperationType.not_equal, [var_b, const[1]]))
cond_2_symbol = condition_handler.add_condition(Condition(OperationType.not_equal, [var_c, const[1]]))

ast = AbstractSyntaxForest(condition_handler=condition_handler)
root = ast.factory.create_condition_node(cond_2_symbol)
true_node = ast.factory.create_true_node()
seq_node = ast.factory.create_seq_node(reaching_condition=cond_1_symbol)
code_node = ast.factory.create_code_node(
[Assignment(ListOperation([]), Call(ImportedFunctionSymbol("scanf", 0x42), [Constant(0x804B01F), var_c]))]
)
switch = ast.factory.create_switch_node(var_c)
case1 = ast.factory.create_case_node(var_c, const[2], break_case=True)
case2 = ast.factory.create_case_node(var_c, const[3], break_case=True)
case_content = [
ast.factory.create_code_node([Assignment(var_b, BinaryOperation(OperationType.plus, [var_b, const[i + 1]]))]) for i in range(2)
]
ast._add_nodes_from(case_content + [root, true_node, seq_node, code_node, switch, case1, case2])
ast._add_edges_from(
[
(root, true_node),
(true_node, seq_node),
(seq_node, code_node),
(seq_node, switch),
(switch, case1),
(switch, case2),
(case1, case_content[0]),
(case2, case_content[1]),
]
)
ast._code_node_reachability_graph.add_reachability_from(
[(code_node, case_content[0]), (code_node, case_content[1]), (case_content[0], case_content[1])]
)
seq_node.sort_children()
switch.sort_cases()
ast.set_current_root(root)

SwitchExtractor.extract(ast, RestructuringOptions(True, True, 2, LoopBreakOptions.structural_variable))
assert isinstance(cond := ast.current_root, ConditionNode) and cond.false_branch is None
assert (
isinstance(seq_node := cond.true_branch_child, SeqNode)
and seq_node.reaching_condition == cond_1_symbol
and len(seq_node.children) == 2
)
assert isinstance(seq_node.children[0], CodeNode)
assert isinstance(switch := seq_node.children[1], SwitchNode) and switch.cases == (case1, case2)

0 comments on commit 68d1d13

Please sign in to comment.