Skip to content

Commit

Permalink
Adjust code generation for the function pointer type. (#327)
Browse files Browse the repository at this point in the history
* Create draft PR for #322

* Add special case for function pointer type variable declarations

* Fix missing comma in function pointer variables declaration

* Add test cases

* Improve readability of format_variables_declaration(Type,list[str])

---------

Co-authored-by: rihi <[email protected]>
Co-authored-by: rihi <[email protected]>
  • Loading branch information
3 people authored Sep 7, 2023
1 parent 42a9781 commit 3bbd8be
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 9 deletions.
14 changes: 12 additions & 2 deletions decompiler/backend/cexpressiongenerator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import logging
from ctypes import c_byte, c_int, c_long, c_short, c_ubyte, c_uint, c_ulong, c_ushort
from itertools import chain, repeat
from typing import Union

from decompiler.structures import pseudo as expressions
from decompiler.structures.pseudo import Float, Integer, OperationType, StringSymbol
from decompiler.structures.pseudo import Float, FunctionTypeDef, Integer, OperationType, Pointer, StringSymbol, Type
from decompiler.structures.pseudo import instructions as instructions
from decompiler.structures.pseudo import operations as operations
from decompiler.structures.visitors.interfaces import DataflowObjectVisitorInterface
Expand Down Expand Up @@ -361,3 +360,14 @@ def _format_string_literal(constant: expressions.Constant) -> str:
escaped = string_representation.replace('"', '\\"')
return f'"{escaped}"'
return f"{constant}"

@staticmethod
def format_variables_declaration(var_type: Type, var_names: list[str]) -> str:
    """Build the C declaration text for variables that share one type.

    For a pointer to a function type, each name is wrapped in its own
    declarator, e.g. ``int (* p)(int), (* p0)(int)``: the return type is
    written once and the parameter list is repeated per name. Any other
    type is rendered as ``<type> <name>, <name>, ...``.

    :param var_type: the type shared by all variables being declared.
    :param var_names: the names of the variables to declare.
    :return: the declaration without a trailing semicolon.
    """
    # Only a Pointer whose pointee is a FunctionTypeDef needs the special declarator form.
    pointee = getattr(var_type, "type", None) if isinstance(var_type, Pointer) else None
    if not isinstance(pointee, FunctionTypeDef):
        return f"{var_type} {', '.join(var_names)}"

    signature = ", ".join(map(str, pointee.parameters))
    declarators = ", ".join(f"(* {name})({signature})" for name in var_names)
    return f"{pointee.return_type} {declarators}"
6 changes: 5 additions & 1 deletion decompiler/backend/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from string import Template
from typing import Iterable, List

from decompiler.backend.cexpressiongenerator import CExpressionGenerator
from decompiler.backend.codevisitor import CodeVisitor
from decompiler.backend.variabledeclarations import GlobalDeclarationGenerator, LocalDeclarationGenerator
from decompiler.task import DecompilerTask
Expand Down Expand Up @@ -37,7 +38,10 @@ def generate_function(self, task: DecompilerTask) -> str:
return self.TEMPLATE.substitute(
return_type=task.function_return_type,
name=task.name,
parameters=", ".join(map(lambda param: f"{param.type} {param.name}", task.function_parameters)),
parameters=", ".join(map(
lambda param: CExpressionGenerator.format_variables_declaration(param.type, [param.name]),
task.function_parameters
)),
local_declarations=LocalDeclarationGenerator.from_task(task) if not task.failed else "",
function_body=CodeVisitor(task).visit(task.syntax_tree.root) if not task.failed else task.failure_message,
)
13 changes: 7 additions & 6 deletions decompiler/backend/variabledeclarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict
from typing import Iterable, Iterator, List, Set

from decompiler.backend.cexpressiongenerator import CExpressionGenerator
from decompiler.structures.ast.ast_nodes import ForLoopNode, LoopNode
from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree
from decompiler.structures.pseudo import (
Expand Down Expand Up @@ -61,19 +62,19 @@ def visit_unary_operation(self, unary: UnaryOperation):
else:
self.visit(unary.operand.left)

def generate(self, param_names: list = []) -> Iterator[str]:
def generate(self, param_names: list[str] = []) -> Iterator[str]:
"""Generate a string containing the variable definitions for the visited variables."""
variable_type_mapping = defaultdict(list)
for variable in sorted(self._variables, key=lambda x: str(x)):
if not isinstance(variable, GlobalVariable):
if not isinstance(variable, GlobalVariable) and variable.name not in param_names:
variable_type_mapping[variable.type].append(variable)

for variable_type, variables in sorted(variable_type_mapping.items(), key=lambda x: str(x)):
for chunked_variables in self._chunks(variables, self._vars_per_line):
variable_names = ", ".join([var.name for var in chunked_variables])
if variable_names in param_names:
continue
yield f"{variable_type} {variable_names};"
yield CExpressionGenerator.format_variables_declaration(
variable_type,
[var.name for var in chunked_variables]
) + ";"

@staticmethod
def _chunks(lst: List, n: int) -> Iterator[List]:
Expand Down
14 changes: 14 additions & 0 deletions tests/backend/test_codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from decompiler.structures.ast.ast_nodes import CodeNode, SeqNode, SwitchNode
from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree
from decompiler.structures.logic.logic_condition import LogicCondition
from decompiler.structures.pseudo import FunctionTypeDef
from decompiler.structures.pseudo.expressions import (
Constant,
DataflowObject,
Expand Down Expand Up @@ -75,6 +76,8 @@ def logic_cond(name: str, context) -> LogicCondition:
var_x_u = Variable("x_u", uint32)
var_y_u = Variable("y_u", uint32)
var_p = Variable("p", Pointer(int32))
var_fun_p = Variable("p", Pointer(FunctionTypeDef(0, int32, (int32,))))
var_fun_p0 = Variable("p0", Pointer(FunctionTypeDef(0, int32, (int32,))))

const_0 = Constant(0, int32)
const_1 = Constant(1, int32)
Expand Down Expand Up @@ -155,6 +158,15 @@ def test_empty_function_two_parameters(self):
r"^\s*int +test_function\(\s*int +a\s*,\s*int +b\s*\){\s*}\s*$", self._task(ast, params=[var_a.copy(), var_b.copy()])
)

def test_empty_function_two_function_parameters(self):
    """An empty function taking two function-pointer parameters lists both as ``int (* name)(int)``."""
    context = LogicCondition.generate_new_context()
    root = SeqNode(LogicCondition.initialize_true(context))
    ast = AbstractSyntaxTree(root, {})
    # Attach a single empty code node below the root so the function body is empty.
    ast._add_edge(root, ast._add_code_node([]))

    expected_signature = r"^\s*int +test_function\(\s*int +\(\*\s*p\)\(int\)\s*,\s*int +\(\*\s*p0\)\(int\)\s*\){\s*}\s*$"
    parameters = [var_fun_p.copy(), var_fun_p0.copy()]
    assert self._regex_matches(expected_signature, self._task(ast, params=parameters))

def test_function_with_instruction(self):
root = SeqNode(LogicCondition.initialize_true(LogicCondition.generate_new_context()))
ast = AbstractSyntaxTree(root, {})
Expand Down Expand Up @@ -1069,6 +1081,8 @@ def test_operation(self, op, expected):
(1, [var_x.copy(), var_y.copy(), var_x_f.copy(), var_y_f.copy()], "float x_f;\nfloat y_f;\nint x;\nint y;"),
(2, [var_x.copy(), var_y.copy(), var_x_f.copy(), var_y_f.copy()], "float x_f, y_f;\nint x, y;"),
(1, [var_x.copy(), var_y.copy(), var_p.copy()], "int x;\nint y;\nint * p;"),
(1, [var_x.copy(), var_y.copy(), var_fun_p.copy()], "int x;\nint y;\nint (* p)(int);"),
(2, [var_x.copy(), var_y.copy(), var_fun_p.copy(), var_fun_p0.copy()], "int x, y;\nint (* p)(int), (* p0)(int);"),
],
)
def test_variable_declaration(self, vars_per_line: int, variables: List[Variable], expected: str):
Expand Down

0 comments on commit 3bbd8be

Please sign in to comment.