Skip to content

Commit

Permalink
Rework variables declaration generation
Browse files Browse the repository at this point in the history
  • Loading branch information
rihi committed Nov 15, 2023
1 parent e85ac6d commit 3c68a43
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 126 deletions.
152 changes: 58 additions & 94 deletions decompiler/backend/variabledeclarations.py
Original file line number Diff line number Diff line change
@@ -1,78 +1,45 @@
"""Module containing the visitors used to generate variable declarations."""
from collections import defaultdict
from typing import Iterable, Iterator, List, Set
from typing import Iterable, Iterator, List

from decompiler.backend.cexpressiongenerator import CExpressionGenerator
from decompiler.structures.ast.ast_nodes import ForLoopNode, LoopNode
from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree
from decompiler.structures.pseudo import (
Assignment,
BinaryOperation,
Constant,
ExternConstant,
ExternFunctionPointer,
GlobalVariable,
Operation,
OperationType,
Pointer,
UnaryOperation,
Variable,
)
from decompiler.structures.pseudo.operations import MemberAccess
from decompiler.structures.pseudo import DataflowObject, ExternConstant, ExternFunctionPointer, GlobalVariable, Operation, Pointer, Variable
from decompiler.structures.visitors.ast_dataflowobjectvisitor import BaseAstDataflowObjectVisitor
from decompiler.task import DecompilerTask
from decompiler.util.insertion_ordered_set import InsertionOrderedSet
from decompiler.util.serialization.bytes_serializer import convert_bytes


class LocalDeclarationGenerator(BaseAstDataflowObjectVisitor):
"""Visits all nodes in the AST and produces the variable declarations."""

def __init__(self, vars_per_line: int = 1):
"""Initialize a new VariableCollector with an empty set of variables."""
self._variables: Set[Variable] = set()
self._vars_per_line: int = vars_per_line

@classmethod
def from_task(cls, task: DecompilerTask):
"""Class method for shorthand usage."""
param_names = list(param.name for param in task.function_parameters)
generator = cls(task.options.getint("code-generator.variable_declarations_per_line", fallback=1))
generator.visit_ast(task.syntax_tree)
return "\n".join(generator.generate(param_names))

def visit_assignment(self, instruction: Assignment):
"""Remember all defined variables."""
self._variables.update(instruction.definitions)

def visit_loop_node(self, node: LoopNode):
"""Visit the given loop node, taking node of the loop declaration."""
if isinstance(node, ForLoopNode) and isinstance(node.declaration, Assignment):
if isinstance(node.declaration.destination, Operation):
self._variables.add(node.declaration.destination[0])
else:
self._variables.add(node.declaration.destination)

def visit_unary_operation(self, unary: UnaryOperation):
"""Visit unary operations to remember all variables those memory location was read."""
if isinstance(unary, MemberAccess):
self._variables.add(unary.struct_variable)
if unary.operation == OperationType.address or unary.operation == OperationType.dereference:
if isinstance(unary.operand, Variable):
self._variables.add(unary.operand)
elif isinstance(unary.operand, BinaryOperation):
if isinstance(unary.operand.left, Variable):
self._variables.add(unary.operand.left)
else:
self.visit(unary.operand.left)

def generate(self, param_names: list[str] = []) -> Iterator[str]:
class LocalDeclarationGenerator:
@staticmethod
def from_task(task: DecompilerTask) -> str:
    """
    Generate the local variable declarations for the given task.

    Every Variable found in the task's syntax tree is gathered into an InsertionOrderedSet and handed to
    generate() together with the names of the function parameters and the configured number of
    declarations per line.

    :param task: The decompiler task providing the options, the function parameters and the syntax tree.
    :return: The declaration lines produced by generate(), joined with newlines.
    """
    # Falls back to one declaration per line when the option is not set.
    vars_per_line = task.options.getint("code-generator.variable_declarations_per_line", fallback=1)

    parameter_names = {p.name for p in task.function_parameters}
    variables = InsertionOrderedSet(LocalDeclarationGenerator._get_variables(task.syntax_tree))

    return "\n".join(LocalDeclarationGenerator.generate(parameter_names, variables, vars_per_line))

@staticmethod
def _get_variables(ast: AbstractSyntaxTree) -> Iterator[Variable]:
    """
    Yield every Variable occurring in the given syntax tree.

    For each node of the tree, the node's dataflow objects are requested (using the tree's condition map)
    and all of their subexpressions are inspected. No de-duplication happens here: a variable is yielded
    once per occurrence.

    :param ast: The syntax tree whose nodes are inspected.
    :return: An iterator over all Variable subexpressions, in traversal order.
    """
    for node in ast.nodes:
        for obj in node.get_dataflow_objets(ast.condition_map):
            for expression in obj.subexpressions():
                if isinstance(expression, Variable):
                    yield expression

@staticmethod
def generate(parameter_names: Iterable[str], variables: Iterable[Variable], vars_per_line: int) -> Iterator[str]:
"""Generate a string containing the variable definitions for the visited variables."""

variable_type_mapping = defaultdict(list)
for variable in sorted(self._variables, key=lambda x: str(x)):
if not isinstance(variable, GlobalVariable) and variable.name not in param_names:
for variable in sorted(variables, key=lambda x: str(x)):
if not isinstance(variable, GlobalVariable) and variable.name not in parameter_names:
variable_type_mapping[variable.type].append(variable)

for variable_type, variables in sorted(variable_type_mapping.items(), key=lambda x: str(x)):
for chunked_variables in self._chunks(variables, self._vars_per_line):
for chunked_variables in LocalDeclarationGenerator._chunks(variables, vars_per_line):
yield CExpressionGenerator.format_variables_declaration(variable_type, [var.name for var in chunked_variables]) + ";"

@staticmethod
Expand All @@ -83,44 +50,41 @@ def _chunks(lst: List, n: int) -> Iterator[List]:


class GlobalDeclarationGenerator(BaseAstDataflowObjectVisitor):
"""Visits all nodes in the AST and produces the declarations of global variables."""
@staticmethod
def from_asts(asts: Iterable[AbstractSyntaxTree]) -> str:
    """
    Generate the global declarations for all given syntax trees.

    :param asts: The syntax trees to search for global variables and extern constants.
    :return: The declaration lines produced by generate(), joined with newlines.
    """
    global_variables, extern_constants = GlobalDeclarationGenerator._get_global_variables_and_constants(asts)
    # generate() only iterates over its arguments, so the collections can be passed directly
    # instead of calling __iter__() by hand.
    return "\n".join(GlobalDeclarationGenerator.generate(global_variables, extern_constants))

@staticmethod
def _get_global_variables_and_constants(asts: Iterable[AbstractSyntaxTree]) -> tuple[set[GlobalVariable], set[ExternConstant]]:
global_variables = InsertionOrderedSet()
extern_constants = InsertionOrderedSet()

def handle_obj(obj: DataflowObject):
match obj:
case GlobalVariable():
global_variables.add(obj)
if isinstance(obj.initial_value, GlobalVariable):
handle_obj(obj.initial_value)

def __init__(self):
"""Generate a new declarator with an empty sets of visited globals."""
self._extern_constants: Set[ExternConstant] = set()
self._global_variables: Set[GlobalVariable] = set()
case ExternConstant():
extern_constants.add(obj)

@classmethod
def from_asts(cls, asts: Iterable[AbstractSyntaxTree]) -> str:
"""Class method for shorthand usage."""
generator = cls()
for ast in asts:
generator.visit_ast(ast)
return "\n".join(generator.generate())
for node in ast.nodes:
for obj in node.get_dataflow_objets(ast.condition_map):
for expression in obj.subexpressions():
handle_obj(expression)

def generate(self) -> Iterator[str]:
"""Generate a string containing the variable definitions for the visited variables."""
for variable in self._global_variables:
yield f"extern {variable.type} {variable.name} = {self.get_initial_value(variable)};"
for constant in sorted(self._extern_constants, key=lambda x: x.value):
yield f"extern {constant.type} {constant.value};"
return global_variables, extern_constants

def visit_unary_operation(self, unary: UnaryOperation):
"""Visit an unary operation, visiting variable operands and nested operations along the way."""
if isinstance(unary.operand, UnaryOperation) or isinstance(unary.operand, Variable):
self.visit(unary.operand)

def visit_variable(self, expression: Variable):
"""Visit the given variable, remembering all visited global Variables."""
if isinstance(expression, GlobalVariable):
self._global_variables.add(expression)
if isinstance(expression.initial_value, UnaryOperation):
self.visit(expression.initial_value)

def visit_constant(self, expression: Constant):
"""Visit the given constant, checking if it has been defined externally."""
if isinstance(expression, ExternConstant):
self._extern_constants.add(expression)
@staticmethod
def generate(global_variables: Iterable[GlobalVariable], extern_constants: Iterable[ExternConstant]) -> Iterator[str]:
    """
    Generate all extern declarations, one string per declaration.

    Global variables come first, in the order given, each with its initial value. Extern constants
    follow, sorted by their value attribute.

    :param global_variables: The global variables to declare.
    :param extern_constants: The extern constants to declare.
    :return: An iterator over the declaration lines.
    """
    for variable in global_variables:
        yield f"extern {variable.type} {variable.name} = {GlobalDeclarationGenerator.get_initial_value(variable)};"
    for constant in sorted(extern_constants, key=lambda x: x.value):
        yield f"extern {constant.type} {constant.value};"

@staticmethod
def get_initial_value(variable: GlobalVariable) -> str:
Expand Down
48 changes: 47 additions & 1 deletion decompiler/structures/ast/ast_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,18 @@
from decompiler.structures.ast.reachability_graph import CaseDependencyGraph, SiblingReachability
from decompiler.structures.graphs.interface import GraphNodeInterface
from decompiler.structures.logic.logic_condition import LogicCondition
from decompiler.structures.pseudo import Assignment, Break, Condition, Constant, Continue, Expression, Instruction, Return, Variable
from decompiler.structures.pseudo import (
Assignment,
Break,
Condition,
Constant,
Continue,
DataflowObject,
Expression,
Instruction,
Return,
Variable,
)

if TYPE_CHECKING:
from decompiler.structures.ast.syntaxgraph import AbstractSyntaxInterface
Expand Down Expand Up @@ -207,6 +218,10 @@ def get_defined_variables(self, condition_map: Optional[Dict[LogicCondition, Con
"""Return all variables that are defined in this node."""
yield from ()

def get_dataflow_objets(self, condition_map: Optional[Dict[LogicCondition, Condition]] = None) -> Iterable[DataflowObject]:
    """Return all dataflow objects of this node; this default implementation yields none."""
    # Same empty-generator form as the neighbouring get_defined_variables.
    yield from ()


class VirtualRootNode(AbstractSyntaxTreeNode):
"""
Expand Down Expand Up @@ -451,6 +466,9 @@ def get_defined_variables(self, condition_map: Optional[Dict[LogicCondition, Con
for instruction in self.instructions:
yield from instruction.definitions

def get_dataflow_objets(self, condition_map: Optional[Dict[LogicCondition, Condition]] = None) -> Iterable[DataflowObject]:
    """Yield the instructions of this node in order; the condition map is not consulted."""
    yield from self.instructions


class ConditionNode(AbstractSyntaxTreeNode):
"""
Expand Down Expand Up @@ -597,6 +615,16 @@ def get_required_variables(self, condition_map: Optional[Dict[LogicCondition, Co
continue
yield from condition_map[symbol].requirements

def get_dataflow_objets(self, condition_map: Optional[Dict[LogicCondition, Condition]] = None) -> Iterable[DataflowObject]:
    """
    Yield the conditions that the condition map assigns to the symbols of this node's condition.

    Nothing is yielded when the condition map is None or empty. A symbol that is missing from a
    non-empty map is skipped and reported with a warning.
    """
    if not condition_map:
        return

    for symbol in self.condition.get_symbols():
        if symbol in condition_map:
            yield condition_map[symbol]
        else:
            logging.warning("LogicCondition not in condition map.")


class ConditionalNode(AbstractSyntaxTreeNode, ABC):
"""Abstract Base class for nodes with one child, i.e. TrueNodes, FalseNodes and CaseNodes."""
Expand Down Expand Up @@ -766,6 +794,16 @@ def get_required_variables(self, condition_map: Optional[Dict[LogicCondition, Co
continue
yield from condition_map[symbol].requirements

def get_dataflow_objets(self, condition_map: Optional[Dict[LogicCondition, Condition]] = None) -> Iterable[DataflowObject]:
    """
    Yield the conditions that the condition map assigns to the symbols of this node's condition.

    Nothing is yielded when the condition map is None or empty. A symbol that is missing from a
    non-empty map is skipped and reported with a warning.
    """
    if not condition_map:
        return

    for symbol in self.condition.get_symbols():
        if symbol in condition_map:
            yield condition_map[symbol]
        else:
            logging.warning("LogicCondition not in condition map.")


class WhileLoopNode(LoopNode):
"""Class for While Loops."""
Expand Down Expand Up @@ -872,6 +910,11 @@ def get_defined_variables(self, condition_map: Optional[Dict[LogicCondition, Con
yield from self.declaration.definitions
yield from self.modification.definitions

def get_dataflow_objets(self, condition_map: Optional[Dict[LogicCondition, Condition]] = None) -> Iterable[DataflowObject]:
    """Yield this loop's declaration and modification, followed by the dataflow objects of the parent class."""
    yield self.declaration
    yield self.modification
    yield from super().get_dataflow_objets(condition_map)


class SwitchNode(AbstractSyntaxTreeNode):
"""
Expand Down Expand Up @@ -992,6 +1035,9 @@ def accept(self, visitor: ASTVisitorInterface[T]) -> T:
def get_required_variables(self, condition_map: Optional[Dict[LogicCondition, Condition]] = None) -> Iterable[Variable]:
yield from self.expression.requirements

def get_dataflow_objets(self, condition_map: Optional[Dict[LogicCondition, Condition]] = None) -> Iterable[DataflowObject]:
    """Yield the expression of this node; the condition map is not consulted."""
    yield self.expression


class CaseNode(ConditionalNode):
"""
Expand Down
57 changes: 27 additions & 30 deletions tests/backend/test_codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def test_function_with_while_condition_loop(self):
ast._add_edges_from(((root, child_1), (root, child_2), (child_2, body)))
ast._code_node_reachability_graph.add_reachability(child_1, body)

regex = r"^%void +test_function\(%int +a%,%int +b%\)%{%int%c;%c%=%5%;%while%\(%x%==%5%\)%{%c%=%c%\+%5%;%}%}%$"
regex = r"^%void +test_function\(%int +a%,%int +b%\)%{%int%c;%int%x;%c%=%5%;%while%\(%x%==%5%\)%{%c%=%c%\+%5%;%}%}%$"
assert self._regex_matches(regex.replace("%", "\\s*"), self._task(ast, params=[var_a.copy(), var_b.copy()], return_type=void))

def test_function_with_do_while_condition_loop(self):
Expand All @@ -358,7 +358,7 @@ def test_function_with_do_while_condition_loop(self):
ast._code_node_reachability_graph.add_reachability(child_1, body)

assert self._regex_matches(
r"^%void +test_function\(%int +a%,%int +b%\)%{%int%c;%c%=%5%;%do%{%c%=%c%\+%5%;%}%while%\(%x%==%5%\);%}%$".replace("%", "\\s*"),
r"^%void +test_function\(%int +a%,%int +b%\)%{%int%c;%int%x;%c%=%5%;%do%{%c%=%c%\+%5%;%}%while%\(%x%==%5%\);%}%$".replace("%", "\\s*"),
self._task(ast, params=[var_a.copy(), var_b.copy()], return_type=void),
)

Expand Down Expand Up @@ -408,7 +408,7 @@ def test_function_nested_loop(self):
ast._code_node_reachability_graph.add_reachability(child_1, nested_loop_body)

regex = (
r"^%void +test_function\(%int +a%,%int +b%\)%{%int%c;%c%=%5%;%" r"while%\(%true%\)%{%while%\(%x%!=%5%\)%{%c%=%c%\+%5%;%}%}%}%$"
r"^%void +test_function\(%int +a%,%int +b%\)%{%int%c;%int%x;%c%=%5%;%" r"while%\(%true%\)%{%while%\(%x%!=%5%\)%{%c%=%c%\+%5%;%}%}%}%$"
)
assert self._regex_matches(regex.replace("%", r"\s*"), self._task(ast, params=[var_a.copy(), var_b.copy()], return_type=void))

Expand All @@ -421,7 +421,7 @@ def test_varvisitor_condition_as_var(self):
ast._add_edge(root, condition_node)

assert self._regex_matches(
r"^%bool +test_function\(%\)%{%if%\(%c%\)%{return%c%;%}%}%$".replace("%", "\\s*"), self._task(ast, return_type=bool1)
r"^%bool +test_function\(%\)%{%int%c;%if%\(%c%\)%{return%c%;%}%}%$".replace("%", "\\s*"), self._task(ast, return_type=bool1)
)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -1140,25 +1140,7 @@ def test_escaped_string_constant(self, expr, expected):

class TestLocalDeclarationGenerator:
@pytest.mark.parametrize(
"op, expected",
[
(ListOperation([]), []),
(ListOperation([var_x.copy()]), []),
(UnaryOperation(OperationType.negate, [var_x.copy()]), []),
(BinaryOperation(OperationType.minus, [var_x.copy(), const_3.copy()]), []),
(BinaryOperation(OperationType.minus, [var_x.copy(), var_y.copy()]), []),
(Assignment(var_x.copy(), Constant(3)), ["int x;"]),
(Assignment(ListOperation([var_x.copy(), var_y.copy()]), Call(FunctionSymbol("foo", 0), [var_x.copy()])), ["int x;", "int y;"]),
],
)
def test_operation(self, op, expected):
"""Ensure variables are generated for operations."""
var_visitor = LocalDeclarationGenerator()
var_visitor.visit_subexpressions(op)
assert list(var_visitor.generate()) == expected

@pytest.mark.parametrize(
"vars_per_line, variables, expected",
["vars_per_line", "variables", "expected"],
[
(1, [var_x.copy(), var_y.copy()], "int x;\nint y;"),
(2, [var_x.copy(), var_y.copy()], "int x, y;"),
Expand All @@ -1176,7 +1158,7 @@ def test_variable_declaration(self, vars_per_line: int, variables: List[Variable
options = _generate_options(var_declarations_per_line=vars_per_line)
ast = AbstractSyntaxTree(
CodeNode(
ListOperation([Assignment(var, const_1.copy()) for var in variables]),
[Assignment(var, const_1.copy()) for var in variables],
LogicCondition.initialize_true(LogicCondition.generate_new_context()),
),
{},
Expand All @@ -1195,16 +1177,31 @@ class TestGlobalVisitor:
)
def test_operation(self, op):
"""Ensure that GlobalVariable and ExternConstant are generated for global printing"""
global_visitor = GlobalDeclarationGenerator()
global_visitor.visit_subexpressions(op)
assert len(list(global_visitor.generate())) != 0
ast = AbstractSyntaxTree(
CodeNode(
[Assignment(var_a, op)],
LogicCondition.initialize_true(LogicCondition.generate_new_context()),
),
{}
)

assert len(GlobalDeclarationGenerator.from_asts([ast])) != 0

def test_nested_global_variable(self):
"""Ensure that GlobalVariableVisitor can visit global variables nested within a global variable"""

var1 = ExternFunctionPointer("ExternFunction")
var2 = GlobalVariable("var_glob1", initial_value=var1)
var3 = GlobalVariable("var_glob2", initial_value=var2)
var4 = GlobalVariable("var_glob3", initial_value=var3)
global_visitor = GlobalDeclarationGenerator()
global_visitor.visit_subexpressions(ListOperation([var2, var3, var4]))
assert len(global_visitor._global_variables) == 3

ast = AbstractSyntaxTree(
CodeNode(
[Assignment(var_a, var4)],
LogicCondition.initialize_true(LogicCondition.generate_new_context()),
),
{}
)

global_variables, _ = GlobalDeclarationGenerator._get_global_variables_and_constants([ast])
assert len(global_variables) == 3
Loading

0 comments on commit 3c68a43

Please sign in to comment.