diff --git a/decompiler/backend/codevisitor.py b/decompiler/backend/codevisitor.py index ad064012b..27f3523c2 100644 --- a/decompiler/backend/codevisitor.py +++ b/decompiler/backend/codevisitor.py @@ -39,6 +39,7 @@ def __init__(self, task: DecompilerTask): self._int_repr_scope: int = task.options.getint("code-generator.int_representation_scope", fallback=256) self._neg_hex_as_twos_complement: bool = task.options.getboolean("code-generator.negative_hex_as_twos_complement", fallback=True) self._aggressive_array_detection: bool = task.options.getboolean("code-generator.aggressive_array_detection", fallback=False) + self._simplify_branches: bool = task.options.getboolean("code-generator.simplify_branches", fallback=True) self._preferred_true_branch: str = task.options.getstring("code-generator.preferred_true_branch", fallback="none") self.task = task @@ -69,10 +70,17 @@ def visit_loop_node(self, node: ast_nodes.LoopNode) -> str: def visit_condition_node(self, node: ast_nodes.ConditionNode) -> str: """Generate code for a conditional.""" true_str = self.visit(node.true_branch_child) + if self._simplify_branches and node.condition.is_true: + return true_str if node.false_branch is None: + if self._simplify_branches and (node.condition.is_false or not true_str): + return "" return f"if ({self._condition_string(node.condition)}) {{{true_str}}}" - false_str = self.visit(node.false_branch_child) + if self._simplify_branches and node.condition.is_false: + return false_str + if self._simplify_branches and not false_str: + return f"if ({self._condition_string(node.condition)}) {{{true_str}}}" condition = node.condition true_child = node.true_branch_child diff --git a/decompiler/util/default.json b/decompiler/util/default.json index e2eedadac..bf07da39b 100644 --- a/decompiler/util/default.json +++ b/decompiler/util/default.json @@ -418,6 +418,16 @@ "is_hidden_from_cli": false, "argument_name": "--variable-declarations-per-line" }, + { + "dest": 
"code-generator.simplify_branches", + "default": true, + "title": "Simplify branches with true or false conditions", + "type": "boolean", + "description": "Removes branches in the output that won't be reached because of a 'true' or 'false' condition", + "is_hidden_from_gui": false, + "is_hidden_from_cli": false, + "argument_name": "--simplify-branches" + }, { "dest": "code-generator.preferred_true_branch", "default": "smallest", diff --git a/tests/backend/test_codegenerator.py b/tests/backend/test_codegenerator.py index 021160101..496897c6b 100644 --- a/tests/backend/test_codegenerator.py +++ b/tests/backend/test_codegenerator.py @@ -62,6 +62,11 @@ def true_condition(context=None): return LogicCondition.initialize_true(context) +def false_condition(context=None): + context = LogicCondition.generate_new_context() if context is None else context + return LogicCondition.initialize_false(context) + + def logic_cond(name: str, context) -> LogicCondition: return LogicCondition.initialize_symbol(name, context) @@ -99,6 +104,7 @@ def _generate_options( twos_complement: bool = True, array_detection: bool = False, var_declarations_per_line: int = 1, + simplify_branches: bool = True, preferred_true_branch: str = "smallest", ): options = Options() @@ -112,6 +118,7 @@ def _generate_options( options.set("code-generator.negative_hex_as_twos_complement", twos_complement) options.set("code-generator.aggressive_array_detection", array_detection) options.set("code-generator.variable_declarations_per_line", var_declarations_per_line) + options.set("code-generator.simplify_branches", simplify_branches) options.set("code-generator.preferred_true_branch", preferred_true_branch) return options @@ -236,6 +243,106 @@ def test_function_with_true_condition(self): ast._add_edges_from(((root, condition_node), (seq_node, code_node))) assert self._regex_matches( r"^%int +test_function\(%int +a%,%int +b%\)%{%int%c;%if%\(%true%\)%{%c%=%5%;%return%c%;%}%}%$".replace("%", "\\s*"), + self._task(ast, 
params=[var_a.copy(), var_b.copy()], options=_generate_options(simplify_branches=False)), + ) + + def test_function_with_simplified_true_condition(self): + """ + if(true){ + c = 5 + return c + } + """ + context = LogicCondition.generate_new_context() + root = SeqNode(LogicCondition.initialize_true(context)) + ast = AbstractSyntaxTree(root, {x1_symbol(context): Condition(OperationType.less, [var_c.copy(), const_5.copy()])}) + seq_node = ast.factory.create_seq_node() + ast._add_node(seq_node) + code_node = ast._add_code_node([instructions.Assignment(var_c.copy(), const_5.copy()), instructions.Return([var_c.copy()])]) + condition_node = ast._add_condition_node_with(condition=true_condition(ast.factory.logic_context), true_branch=seq_node) + ast._add_edges_from(((root, condition_node), (seq_node, code_node))) + assert self._regex_matches( + r"^%int +test_function\(%int +a%,%int +b%\)%{%int%c;%c%=%5%;%return%c%;%}%$".replace("%", "\\s*"), + self._task(ast, params=[var_a.copy(), var_b.copy()]), + ) + + def test_function_with_simplified_false_condition(self): + """ + if(false){ + c = 5 + return c + } else { + return 0 + } + """ + context = LogicCondition.generate_new_context() + root = SeqNode(LogicCondition.initialize_true(context)) + ast = AbstractSyntaxTree(root, {x1_symbol(context): Condition(OperationType.less, [var_c.copy(), const_5.copy()])}) + true_seq_node = ast.factory.create_seq_node() + ast._add_node(true_seq_node) + true_code_node = ast._add_code_node([instructions.Assignment(var_c.copy(), const_5.copy()), instructions.Return([var_c.copy()])]) + false_seq_node = ast.factory.create_seq_node() + ast._add_node(false_seq_node) + false_code_node = ast._add_code_node([instructions.Return([const_0.copy()])]) + condition_node = ast._add_condition_node_with( + condition=false_condition(ast.factory.logic_context), true_branch=true_seq_node, false_branch=false_seq_node + ) + ast._add_edges_from(((root, condition_node), (true_seq_node, true_code_node), (false_seq_node, 
false_code_node))) + assert self._regex_matches( + r"^%int +test_function\(%int +a%,%int +b%\)%{%int%c;%return%0%;%}%$".replace("%", "\\s*"), + self._task(ast, params=[var_a.copy(), var_b.copy()]), + ) + + def test_function_with_simplified_false_condition_in_true_branch(self): + """ + if(a == 5){ + if(false){ + c = 5 + return c + } + } + """ + context = LogicCondition.generate_new_context() + root = SeqNode(LogicCondition.initialize_true(context)) + ast = AbstractSyntaxTree(root, {x1_symbol(context): Condition(OperationType.less, [var_c.copy(), const_5.copy()])}) + seq_node = ast.factory.create_seq_node() + ast._add_node(seq_node) + code_node = ast._add_code_node([instructions.Assignment(var_c.copy(), const_5.copy()), instructions.Return([var_c.copy()])]) + false_condition_node = ast._add_condition_node_with(condition=false_condition(ast.factory.logic_context), true_branch=seq_node) + condition_node = ast._add_condition_node_with(condition=x1_symbol(ast.factory.logic_context), true_branch=false_condition_node) + ast._add_edges_from(((root, condition_node), (seq_node, code_node))) + assert self._regex_matches( + r"^%int +test_function\(%int +a%,%int +b%\)%{%int%c;%}%$".replace("%", "\\s*"), + self._task(ast, params=[var_a.copy(), var_b.copy()]), + ) + + def test_function_with_simplified_false_condition_in_false_branch(self): + """ + if(a == 5){ + return 0 + } else { + if(false){ + c = 5 + return c + } + } + """ + context = LogicCondition.generate_new_context() + root = SeqNode(LogicCondition.initialize_true(context)) + ast = AbstractSyntaxTree(root, {x1_symbol(context): Condition(OperationType.less, [var_c.copy(), const_5.copy()])}) + seq_node = ast.factory.create_seq_node() + ast._add_node(seq_node) + false_condition_code_node = ast._add_code_node( + [instructions.Assignment(var_c.copy(), const_5.copy()), instructions.Return([var_c.copy()])] + ) + false_condition_node = ast._add_condition_node_with(condition=false_condition(ast.factory.logic_context), 
true_branch=seq_node) + code_node = ast._add_code_node([instructions.Return([const_0.copy()])]) + condition_node = ast._add_condition_node_with( + condition=x1_symbol(ast.factory.logic_context), true_branch=code_node, false_branch=false_condition_node + ) + ast._add_edges_from(((root, condition_node), (seq_node, false_condition_code_node))) + assert self._regex_matches( + r"^%int +test_function\(%int +a%,%int +b%\)%{%int%c;%if%\(%c%<%5%\)%{%return%0%;%}%}%$".replace("%", "\\s*"), self._task(ast, params=[var_a.copy(), var_b.copy()]), )