Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve else-if chaining #375

Merged
merged 13 commits into from
Jan 24, 2024
30 changes: 27 additions & 3 deletions decompiler/backend/codevisitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(self, task: DecompilerTask):
self._int_repr_scope: int = task.options.getint("code-generator.int_representation_scope", fallback=256)
self._neg_hex_as_twos_complement: bool = task.options.getboolean("code-generator.negative_hex_as_twos_complement", fallback=True)
self._aggressive_array_detection: bool = task.options.getboolean("code-generator.aggressive_array_detection", fallback=False)
self._preferred_true_branch: str = task.options.getstring("code-generator.preferred_true_branch", fallback="none")
self.task = task

def visit_seq_node(self, node: ast_nodes.SeqNode) -> str:
Expand Down Expand Up @@ -70,10 +71,33 @@ def visit_condition_node(self, node: ast_nodes.ConditionNode) -> str:
true_str = self.visit(node.true_branch_child)
if node.false_branch is None:
return f"if ({self._condition_string(node.condition)}) {{{true_str}}}"

false_str = self.visit(node.false_branch_child)
if isinstance(node.false_branch_child, ast_nodes.ConditionNode):
return f"if ({self._condition_string(node.condition)}){{{true_str}}} else {false_str}"
return f"if ({self._condition_string(node.condition)}){{{true_str}}} else{{{false_str}}}"

condition = node.condition
true_child = node.true_branch_child
false_child = node.false_branch_child

swap_branches = None

# if only one branch is a condition node, we want to decide swapping by which branch is a condition node
if isinstance(false_child, ast_nodes.ConditionNode) != isinstance(true_child, ast_nodes.ConditionNode):
swap_branches = not isinstance(false_child, ast_nodes.ConditionNode)

# if we haven't already decided on swapping (swap_branches is None), decide by length
if swap_branches is None:
length_comparisons = {"none": None, "smallest": len(true_str) > len(false_str), "largest": len(true_str) < len(false_str)}
swap_branches = length_comparisons[self._preferred_true_branch]

if swap_branches:
condition = ~condition
true_str, false_str = false_str, true_str
true_child, false_child = false_child, true_child

if isinstance(false_child, ast_nodes.ConditionNode):
return f"if ({self._condition_string(condition)}) {{{true_str}}} else {false_str}"
else:
return f"if ({self._condition_string(condition)}) {{{true_str}}} else {{{false_str}}}"

def visit_true_node(self, node: ast_nodes.TrueNode) -> str:
"""Generate code for the given TrueNode by evaluating its child (Wrapper)."""
Expand Down
20 changes: 20 additions & 0 deletions decompiler/util/default.json
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,26 @@
"is_hidden_from_cli": false,
"argument_name": "--variable-declarations-per-line"
},
{
"dest": "code-generator.preferred_true_branch",
"default": "smallest",
"title": "Preferred type of true branch in if-else",
"type": "string",
"enum": [
"smallest",
"largest",
"none"
],
"enumDescriptions": [
"Swap branches, so that the true branch is the smaller one length-wise.",
"Swap branches, so that the true branch is the larger one length-wise.",
rihi marked this conversation as resolved.
Show resolved Hide resolved
"Don't swap branches based on length."
],
"description": "Swap branches of if-else structures based on the given criteria",
rihi marked this conversation as resolved.
Show resolved Hide resolved
"is_hidden_from_gui": false,
"is_hidden_from_cli": false,
"argument_name": "--preferred_true_branch"
},
{
"dest": "pattern-independent-restructuring.switch_reconstruction",
"default": true,
Expand Down
122 changes: 119 additions & 3 deletions tests/backend/test_codegenerator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Dict, List
from typing import Dict, List, Optional

import decompiler.structures.pseudo.instructions as instructions
import decompiler.structures.pseudo.operations as operations
Expand Down Expand Up @@ -99,6 +99,7 @@ def _generate_options(
twos_complement: bool = True,
array_detection: bool = False,
var_declarations_per_line: int = 1,
preferred_true_branch: str = "smallest",
):
options = Options()
options.set("code-generator.max_complexity", max_complx)
Expand All @@ -111,15 +112,17 @@ def _generate_options(
options.set("code-generator.negative_hex_as_twos_complement", twos_complement)
options.set("code-generator.aggressive_array_detection", array_detection)
options.set("code-generator.variable_declarations_per_line", var_declarations_per_line)
options.set("code-generator.preferred_true_branch", preferred_true_branch)
return options


class TestCodeGeneration:
@staticmethod
def _task(ast: AbstractSyntaxTree, params: List[DataflowObject] = None, return_type: Type = int32):
def _task(ast: AbstractSyntaxTree, params: List[DataflowObject] = None, return_type: Type = int32, options: Optional[Options] = None):
if not params:
params = []
options = _generate_options(max_complx=100, compounding=False)
if not options:
options = _generate_options(compounding=False)
return DecompilerTask("test_function", None, ast=ast, options=options, function_parameters=params, function_return_type=return_type)

@staticmethod
Expand Down Expand Up @@ -253,6 +256,119 @@ def test_function_with_ifelse(self):
r"^%int +test_function\(%int +a%,%int +b%\)%{%int%c;%if%\(%c%<%5%\)%{%c%=%5%;%return%c%;%}%else%{%return%0%;%}%}%$".replace(
"%", "\\s*"
),
self._task(ast, params=[var_a.copy(), var_b.copy()], options=_generate_options(preferred_true_branch="none")),
)

def test_function_with_ifelseif(self):
context = LogicCondition.generate_new_context()
root = SeqNode(LogicCondition.initialize_true(context))
ast = AbstractSyntaxTree(
root,
{
x1_symbol(context): Condition(OperationType.less, [var_a, const_3]),
x2_symbol(context): Condition(OperationType.less, [var_a, const_5]),
},
)

x2_true_node = ast._add_code_node([instructions.Return([const_1])])
x2_false_node = ast._add_code_node([instructions.Return([const_2])])
x1_true_node = ast._add_code_node([instructions.Return([const_0])])
x1_false_node = ast._add_condition_node_with(
condition=x2_symbol(ast.factory.logic_context), true_branch=x2_true_node, false_branch=x2_false_node
)
condition_node = ast._add_condition_node_with(
condition=x1_symbol(ast.factory.logic_context), true_branch=x1_true_node, false_branch=x1_false_node
)

ast._add_edges_from([(root, condition_node)])

assert self._regex_matches(
r"^%int +test_function\(%int +a%,%int +b%\)%{%if%\(%a%<%3%\)%{%return%0%;%}%else +if%\(%a%<%5%\)%{%return%1%;%}%else%{%return%2%;%}%}%$".replace(
"%", "\\s*"
),
self._task(ast, params=[var_a.copy(), var_b.copy()]),
)

def test_function_with_ifelseif_prioritize_elseif_over_length(self):
context = LogicCondition.generate_new_context()
root = SeqNode(LogicCondition.initialize_true(context))
ast = AbstractSyntaxTree(
root,
{
x1_symbol(context): Condition(OperationType.less, [var_a, const_3]),
x2_symbol(context): Condition(OperationType.less, [var_a, const_5]),
},
)

x2_true_node = ast._add_code_node([instructions.Return([const_1])])
x2_false_node = ast._add_code_node([instructions.Return([const_2])])
x1_true_node = ast._add_code_node([instructions.Return([const_0])])
x1_false_node = ast._add_condition_node_with(
condition=x2_symbol(ast.factory.logic_context), true_branch=x2_true_node, false_branch=x2_false_node
)
condition_node = ast._add_condition_node_with(
condition=x1_symbol(ast.factory.logic_context), true_branch=x1_true_node, false_branch=x1_false_node
)

ast._add_edges_from([(root, condition_node)])

assert self._regex_matches(
r"^%int +test_function\(%int +a%,%int +b%\)%{%if%\(%a%<%3%\)%{%return%0%;%}%else +if%\(%a%<%5%\)%{%return%1%;%}%else%{%return%2%;%}%}%$".replace(
"%", "\\s*"
),
self._task(ast, params=[var_a.copy(), var_b.copy()], options=_generate_options(preferred_true_branch="largest")),
)

def test_function_with_ifelseif_swapped_because_elseif(self):
context = LogicCondition.generate_new_context()
root = SeqNode(LogicCondition.initialize_true(context))
ast = AbstractSyntaxTree(
root,
{
x1_symbol(context): Condition(OperationType.greater_or_equal, [var_a, const_3]),
x2_symbol(context): Condition(OperationType.less, [var_a, const_5]),
},
)

x2_true_node = ast._add_code_node([instructions.Return([const_1])])
x2_false_node = ast._add_code_node([instructions.Return([const_2])])
x1_true_node = ast._add_condition_node_with(
condition=x2_symbol(ast.factory.logic_context), true_branch=x2_true_node, false_branch=x2_false_node
)
x1_false_node = ast._add_code_node([instructions.Comment("Long comment to pad branch length..."), instructions.Return([const_0])])
condition_node = ast._add_condition_node_with(
condition=x1_symbol(ast.factory.logic_context), true_branch=x1_true_node, false_branch=x1_false_node
)

ast._add_edges_from([(root, condition_node)])

assert self._regex_matches(
r"^%int +test_function\(%int +a%,%int +b%\)%{%if%\(%a%<%3%\)%{%\/\*%Long comment to pad branch length...%\*\/%return%0%;%}%else +if%\(%a%<%5%\)%{%return%1%;%}%else%{%return%2%;%}%}%$".replace(
"%", "\\s*"
),
self._task(ast, params=[var_a.copy(), var_b.copy()]),
)

def test_function_with_ifelseif_swapped_because_length(self):
context = LogicCondition.generate_new_context()
root = SeqNode(LogicCondition.initialize_true(context))
ast = AbstractSyntaxTree(
root,
{x1_symbol(context): Condition(OperationType.greater_or_equal, [var_a, const_3])},
)

x1_true_node = ast._add_code_node([instructions.Comment("Long comment to pad branch length..."), instructions.Return([const_1])])
x1_false_node = ast._add_code_node([instructions.Return([const_0])])
condition_node = ast._add_condition_node_with(
condition=x1_symbol(ast.factory.logic_context), true_branch=x1_true_node, false_branch=x1_false_node
)

ast._add_edges_from([(root, condition_node)])

assert self._regex_matches(
r"^%int +test_function\(%int +a%,%int +b%\)%{%if%\(%a%<%3%\)%{%return%0%;%}%else%{%\/\*%Long comment to pad branch length...%\*\/%return%1%;%}%}%$".replace(
"%", "\\s*"
),
self._task(ast, params=[var_a.copy(), var_b.copy()]),
)

rihi marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
Loading