Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust code generation for the function pointer type. #327

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions decompiler/backend/cexpressiongenerator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import logging
from ctypes import c_byte, c_int, c_long, c_short, c_ubyte, c_uint, c_ulong, c_ushort
from itertools import chain, repeat
from typing import Union

from decompiler.structures import pseudo as expressions
from decompiler.structures.pseudo import Float, Integer, OperationType, StringSymbol
from decompiler.structures.pseudo import Float, FunctionTypeDef, Integer, OperationType, Pointer, StringSymbol, Type
from decompiler.structures.pseudo import instructions as instructions
from decompiler.structures.pseudo import operations as operations
from decompiler.structures.visitors.interfaces import DataflowObjectVisitorInterface
Expand Down Expand Up @@ -361,3 +360,14 @@ def _format_string_literal(constant: expressions.Constant) -> str:
escaped = string_representation.replace('"', '\\"')
return f'"{escaped}"'
return f"{constant}"

@staticmethod
def format_variables_declaration(var_type: Type, var_names: list[str]) -> str:
    """
    Build the C declaration text for one or more variables sharing the same type.

    A pointer to a function type is rendered as `<return type> (* name)(<parameters>)`,
    with one `(* name)(<parameters>)` declarator per variable name, separated by commas.
    Every other type is rendered as `<type> name1, name2, ...`.
    No trailing semicolon is appended; the caller adds it where needed.
    """
    # Only a Pointer whose pointee is a function type needs the special declarator syntax.
    pointee = getattr(var_type, "type", None) if isinstance(var_type, Pointer) else None
    if not isinstance(pointee, FunctionTypeDef):
        return f"{var_type} {', '.join(var_names)}"

    parameter_list = ", ".join(map(str, pointee.parameters))
    # The return type is written once; each variable gets its own `(* name)(...)` declarator.
    declarators = ", ".join(f"(* {name})({parameter_list})" for name in var_names)
    return f"{pointee.return_type} {declarators}"
6 changes: 5 additions & 1 deletion decompiler/backend/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from string import Template
from typing import Iterable, List

from decompiler.backend.cexpressiongenerator import CExpressionGenerator
from decompiler.backend.codevisitor import CodeVisitor
from decompiler.backend.variabledeclarations import GlobalDeclarationGenerator, LocalDeclarationGenerator
from decompiler.task import DecompilerTask
Expand Down Expand Up @@ -37,7 +38,10 @@ def generate_function(self, task: DecompilerTask) -> str:
return self.TEMPLATE.substitute(
return_type=task.function_return_type,
name=task.name,
parameters=", ".join(map(lambda param: f"{param.type} {param.name}", task.function_parameters)),
parameters=", ".join(map(
lambda param: CExpressionGenerator.format_variables_declaration(param.type, [param.name]),
task.function_parameters
)),
local_declarations=LocalDeclarationGenerator.from_task(task) if not task.failed else "",
function_body=CodeVisitor(task).visit(task.syntax_tree.root) if not task.failed else task.failure_message,
)
13 changes: 7 additions & 6 deletions decompiler/backend/variabledeclarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict
from typing import Iterable, Iterator, List, Set

from decompiler.backend.cexpressiongenerator import CExpressionGenerator
from decompiler.structures.ast.ast_nodes import ForLoopNode, LoopNode
from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree
from decompiler.structures.pseudo import (
Expand Down Expand Up @@ -61,19 +62,19 @@ def visit_unary_operation(self, unary: UnaryOperation):
else:
self.visit(unary.operand.left)

def generate(self, param_names: list[str] | None = None) -> Iterator[str]:
    """
    Generate the declaration statements for the visited variables.

    Global variables and variables whose name is contained in param_names are skipped.
    The remaining variables are grouped by type, split into chunks of at most
    self._vars_per_line variables, and one declaration (terminated by ';') is yielded per chunk.

    :param param_names: names that must not be declared again (e.g. function parameters);
        None is treated as "no names to exclude".
    :return: an iterator over the declaration strings.
    """
    # Avoid a shared mutable default: fall back to an empty, immutable collection.
    excluded_names = param_names if param_names is not None else ()

    variable_type_mapping = defaultdict(list)
    for variable in sorted(self._variables, key=lambda x: str(x)):
        # Filter per variable, before chunking, so a parameter is dropped even when it
        # would otherwise share a declaration line with other variables.
        if not isinstance(variable, GlobalVariable) and variable.name not in excluded_names:
            variable_type_mapping[variable.type].append(variable)

    for variable_type, variables in sorted(variable_type_mapping.items(), key=lambda x: str(x)):
        for chunked_variables in self._chunks(variables, self._vars_per_line):
            # Delegate the formatting so function pointer types get the proper declarator syntax.
            yield CExpressionGenerator.format_variables_declaration(
                variable_type,
                [var.name for var in chunked_variables]
            ) + ";"

@staticmethod
def _chunks(lst: List, n: int) -> Iterator[List]:
Expand Down
14 changes: 14 additions & 0 deletions tests/backend/test_codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from decompiler.structures.ast.ast_nodes import CodeNode, SeqNode, SwitchNode
from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree
from decompiler.structures.logic.logic_condition import LogicCondition
from decompiler.structures.pseudo import FunctionTypeDef
from decompiler.structures.pseudo.expressions import (
Constant,
DataflowObject,
Expand Down Expand Up @@ -75,6 +76,8 @@ def logic_cond(name: str, context) -> LogicCondition:
# Shared fixtures for the code generator tests; tests take `.copy()` of these before use.
var_x_u = Variable("x_u", uint32)
var_y_u = Variable("y_u", uint32)
var_p = Variable("p", Pointer(int32))
# Function pointer variables; the expected outputs in the tests below render them as `int (* p)(int)`.
# NOTE: var_fun_p deliberately(?) reuses the name "p" of var_p — do not put both into the same test case.
# NOTE(review): the first FunctionTypeDef argument (0) is presumably the type size — confirm against its definition.
var_fun_p = Variable("p", Pointer(FunctionTypeDef(0, int32, (int32,))))
var_fun_p0 = Variable("p0", Pointer(FunctionTypeDef(0, int32, (int32,))))

# Integer constants of type int32.
const_0 = Constant(0, int32)
const_1 = Constant(1, int32)
Expand Down Expand Up @@ -155,6 +158,15 @@ def test_empty_function_two_parameters(self):
r"^\s*int +test_function\(\s*int +a\s*,\s*int +b\s*\){\s*}\s*$", self._task(ast, params=[var_a.copy(), var_b.copy()])
)

def test_empty_function_two_function_parameters(self):
    """An empty function with two function pointer parameters declares both as `int (* name)(int)`."""
    context = LogicCondition.generate_new_context()
    seq_root = SeqNode(LogicCondition.initialize_true(context))
    syntax_tree = AbstractSyntaxTree(seq_root, {})
    # Attach a single empty code node so the generated function body is empty.
    syntax_tree._add_edge(seq_root, syntax_tree._add_code_node([]))

    expected_pattern = r"^\s*int +test_function\(\s*int +\(\*\s*p\)\(int\)\s*,\s*int +\(\*\s*p0\)\(int\)\s*\){\s*}\s*$"
    function_parameters = [var_fun_p.copy(), var_fun_p0.copy()]
    assert self._regex_matches(expected_pattern, self._task(syntax_tree, params=function_parameters))

def test_function_with_instruction(self):
root = SeqNode(LogicCondition.initialize_true(LogicCondition.generate_new_context()))
ast = AbstractSyntaxTree(root, {})
Expand Down Expand Up @@ -1069,6 +1081,8 @@ def test_operation(self, op, expected):
(1, [var_x.copy(), var_y.copy(), var_x_f.copy(), var_y_f.copy()], "float x_f;\nfloat y_f;\nint x;\nint y;"),
(2, [var_x.copy(), var_y.copy(), var_x_f.copy(), var_y_f.copy()], "float x_f, y_f;\nint x, y;"),
(1, [var_x.copy(), var_y.copy(), var_p.copy()], "int x;\nint y;\nint * p;"),
(1, [var_x.copy(), var_y.copy(), var_fun_p.copy()], "int x;\nint y;\nint (* p)(int);"),
(2, [var_x.copy(), var_y.copy(), var_fun_p.copy(), var_fun_p0.copy()], "int x, y;\nint (* p)(int), (* p0)(int);"),
],
)
def test_variable_declaration(self, vars_per_line: int, variables: List[Variable], expected: str):
Expand Down