Skip to content

Commit

Permalink
tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
rihi committed Nov 15, 2023
1 parent e85ac6d commit c0e5aa0
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 57 deletions.
71 changes: 25 additions & 46 deletions decompiler/backend/variabledeclarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,76 +3,55 @@
from typing import Iterable, Iterator, List, Set

from decompiler.backend.cexpressiongenerator import CExpressionGenerator
from decompiler.structures.ast.ast_nodes import ForLoopNode, LoopNode
from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree
from decompiler.structures.pseudo import (
Assignment,
BinaryOperation,
Constant,
ExternConstant,
ExternFunctionPointer,
GlobalVariable,
Operation,
OperationType,
Pointer,
UnaryOperation,
Variable,
)
from decompiler.structures.pseudo.operations import MemberAccess
from decompiler.structures.visitors.ast_dataflowobjectvisitor import BaseAstDataflowObjectVisitor
from decompiler.task import DecompilerTask
from decompiler.util.insertion_ordered_set import InsertionOrderedSet
from decompiler.util.serialization.bytes_serializer import convert_bytes


class LocalDeclarationGenerator(BaseAstDataflowObjectVisitor):
class LocalDeclarationGenerator:
"""Visits all nodes in the AST and produces the variable declarations."""

def __init__(self, vars_per_line: int = 1):
"""Initialize a new VariableCollector with an empty set of variables."""
self._variables: Set[Variable] = set()
self._vars_per_line: int = vars_per_line

@classmethod
def from_task(cls, task: DecompilerTask):
@staticmethod
def from_task(task: DecompilerTask):
"""Class method for shorthand usage."""
param_names = list(param.name for param in task.function_parameters)
generator = cls(task.options.getint("code-generator.variable_declarations_per_line", fallback=1))
generator.visit_ast(task.syntax_tree)
return "\n".join(generator.generate(param_names))

def visit_assignment(self, instruction: Assignment):
"""Remember all defined variables."""
self._variables.update(instruction.definitions)

def visit_loop_node(self, node: LoopNode):
"""Visit the given loop node, taking note of the loop declaration."""
if isinstance(node, ForLoopNode) and isinstance(node.declaration, Assignment):
if isinstance(node.declaration.destination, Operation):
self._variables.add(node.declaration.destination[0])
else:
self._variables.add(node.declaration.destination)
vars_per_line = task.options.getint("code-generator.variable_declarations_per_line", fallback=1)

def visit_unary_operation(self, unary: UnaryOperation):
"""Visit unary operations to remember all variables whose memory location was read."""
if isinstance(unary, MemberAccess):
self._variables.add(unary.struct_variable)
if unary.operation == OperationType.address or unary.operation == OperationType.dereference:
if isinstance(unary.operand, Variable):
self._variables.add(unary.operand)
elif isinstance(unary.operand, BinaryOperation):
if isinstance(unary.operand.left, Variable):
self._variables.add(unary.operand.left)
else:
self.visit(unary.operand.left)

def generate(self, param_names: list[str] = []) -> Iterator[str]:
parameter_names = {p.name for p in task.function_parameters}
variables = InsertionOrderedSet(LocalDeclarationGenerator._get_variables(task.syntax_tree))

return "\n".join(LocalDeclarationGenerator.generate(parameter_names, variables, vars_per_line))

@staticmethod
def _get_variables(ast: AbstractSyntaxTree) -> Iterator[Variable]:
    """Yield every Variable occurring in the dataflow objects of the given syntax tree.

    Each node of the tree is asked for its dataflow objects (resolved against the
    tree's condition map), and all subexpressions of those objects are scanned.
    Nothing is de-duplicated here: a variable is yielded once per occurrence.
    """
    for ast_node in ast.nodes:
        for dataflow_object in ast_node.get_dataflow_objets(ast.condition_map):
            # Keep only the subexpressions that are variables; everything else is skipped.
            yield from (candidate for candidate in dataflow_object.subexpressions() if isinstance(candidate, Variable))

@staticmethod
def generate(parameter_names: set[str], variables: set[Variable], vars_per_line: int) -> Iterator[str]:
"""Generate a string containing the variable definitions for the visited variables."""

variable_type_mapping = defaultdict(list)
for variable in sorted(self._variables, key=lambda x: str(x)):
if not isinstance(variable, GlobalVariable) and variable.name not in param_names:
for variable in sorted(variables, key=lambda x: str(x)):
if not isinstance(variable, GlobalVariable) and variable.name not in parameter_names:
variable_type_mapping[variable.type].append(variable)

for variable_type, variables in sorted(variable_type_mapping.items(), key=lambda x: str(x)):
for chunked_variables in self._chunks(variables, self._vars_per_line):
for chunked_variables in LocalDeclarationGenerator._chunks(variables, vars_per_line):
yield CExpressionGenerator.format_variables_declaration(variable_type, [var.name for var in chunked_variables]) + ";"

@staticmethod
Expand Down
48 changes: 47 additions & 1 deletion decompiler/structures/ast/ast_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,18 @@
from decompiler.structures.ast.reachability_graph import CaseDependencyGraph, SiblingReachability
from decompiler.structures.graphs.interface import GraphNodeInterface
from decompiler.structures.logic.logic_condition import LogicCondition
from decompiler.structures.pseudo import Assignment, Break, Condition, Constant, Continue, Expression, Instruction, Return, Variable
from decompiler.structures.pseudo import (
Assignment,
Break,
Condition,
Constant,
Continue,
DataflowObject,
Expression,
Instruction,
Return,
Variable,
)

if TYPE_CHECKING:
from decompiler.structures.ast.syntaxgraph import AbstractSyntaxInterface
Expand Down Expand Up @@ -207,6 +218,10 @@ def get_defined_variables(self, condition_map: Optional[Dict[LogicCondition, Con
"""Return all variables that are defined in this node."""
yield from ()

def get_dataflow_objets(self, condition_map: Optional[Dict[LogicCondition, Condition]] = None) -> Iterable[DataflowObject]:
    """Yield the dataflow objects of this node; this default implementation yields none."""
    # Delegating to an empty tuple keeps this function a generator without producing any item.
    yield from ()


class VirtualRootNode(AbstractSyntaxTreeNode):
"""
Expand Down Expand Up @@ -451,6 +466,9 @@ def get_defined_variables(self, condition_map: Optional[Dict[LogicCondition, Con
for instruction in self.instructions:
yield from instruction.definitions

def get_dataflow_objets(self, condition_map: Optional[Dict[LogicCondition, Condition]] = None) -> Iterable[DataflowObject]:
    """Yield the instructions of this node, in order; the condition map is not consulted."""
    for instruction in self.instructions:
        yield instruction


class ConditionNode(AbstractSyntaxTreeNode):
"""
Expand Down Expand Up @@ -597,6 +615,16 @@ def get_required_variables(self, condition_map: Optional[Dict[LogicCondition, Co
continue
yield from condition_map[symbol].requirements

def get_dataflow_objets(self, condition_map: Optional[Dict[LogicCondition, Condition]] = None) -> Iterable[DataflowObject]:
    """Yield the condition-map entry for every symbol of this node's condition.

    Nothing is yielded when the condition map is None or empty. A symbol that has
    no entry in the map is skipped after logging a warning.
    """
    if condition_map:
        for symbol in self.condition.get_symbols():
            if symbol not in condition_map:
                logging.warning("LogicCondition not in condition map.")
                continue
            yield condition_map[symbol]


class ConditionalNode(AbstractSyntaxTreeNode, ABC):
"""Abstract Base class for nodes with one child, i.e. TrueNodes, FalseNodes and CaseNodes."""
Expand Down Expand Up @@ -766,6 +794,16 @@ def get_required_variables(self, condition_map: Optional[Dict[LogicCondition, Co
continue
yield from condition_map[symbol].requirements

def get_dataflow_objets(self, condition_map: Optional[Dict[LogicCondition, Condition]] = None) -> Iterable[DataflowObject]:
    """Yield, for each symbol of the loop condition, the entry stored for it in the condition map.

    Without a (non-empty) condition map nothing is yielded. Symbols missing from
    the map only produce a logged warning.
    """
    if not condition_map:
        return

    for symbol in self.condition.get_symbols():
        is_known = symbol in condition_map
        if not is_known:
            logging.warning("LogicCondition not in condition map.")
        else:
            yield condition_map[symbol]


class WhileLoopNode(LoopNode):
"""Class for While Loops."""
Expand Down Expand Up @@ -872,6 +910,11 @@ def get_defined_variables(self, condition_map: Optional[Dict[LogicCondition, Con
yield from self.declaration.definitions
yield from self.modification.definitions

def get_dataflow_objets(self, condition_map: Optional[Dict[LogicCondition, Condition]] = None) -> Iterable[DataflowObject]:
    """Yield the loop declaration, then the loop modification, then whatever the parent class yields."""
    yield from (self.declaration, self.modification)
    # The parent implementation contributes its own objects for the same condition map.
    yield from super().get_dataflow_objets(condition_map)


class SwitchNode(AbstractSyntaxTreeNode):
"""
Expand Down Expand Up @@ -992,6 +1035,9 @@ def accept(self, visitor: ASTVisitorInterface[T]) -> T:
def get_required_variables(self, condition_map: Optional[Dict[LogicCondition, Condition]] = None) -> Iterable[Variable]:
yield from self.expression.requirements

def get_dataflow_objets(self, condition_map: Optional[Dict[LogicCondition, Condition]] = None) -> Iterable[DataflowObject]:
    """Yield the expression of this node as its only dataflow object; the condition map is not consulted."""
    switch_expression = self.expression
    yield switch_expression


class CaseNode(ConditionalNode):
"""
Expand Down
20 changes: 10 additions & 10 deletions tests/backend/test_codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def test_function_with_while_condition_loop(self):
ast._add_edges_from(((root, child_1), (root, child_2), (child_2, body)))
ast._code_node_reachability_graph.add_reachability(child_1, body)

regex = r"^%void +test_function\(%int +a%,%int +b%\)%{%int%c;%c%=%5%;%while%\(%x%==%5%\)%{%c%=%c%\+%5%;%}%}%$"
regex = r"^%void +test_function\(%int +a%,%int +b%\)%{%int%c;%int%x;%c%=%5%;%while%\(%x%==%5%\)%{%c%=%c%\+%5%;%}%}%$"
assert self._regex_matches(regex.replace("%", "\\s*"), self._task(ast, params=[var_a.copy(), var_b.copy()], return_type=void))

def test_function_with_do_while_condition_loop(self):
Expand All @@ -358,7 +358,7 @@ def test_function_with_do_while_condition_loop(self):
ast._code_node_reachability_graph.add_reachability(child_1, body)

assert self._regex_matches(
r"^%void +test_function\(%int +a%,%int +b%\)%{%int%c;%c%=%5%;%do%{%c%=%c%\+%5%;%}%while%\(%x%==%5%\);%}%$".replace("%", "\\s*"),
r"^%void +test_function\(%int +a%,%int +b%\)%{%int%c;%int%x;%c%=%5%;%do%{%c%=%c%\+%5%;%}%while%\(%x%==%5%\);%}%$".replace("%", "\\s*"),
self._task(ast, params=[var_a.copy(), var_b.copy()], return_type=void),
)

Expand Down Expand Up @@ -408,7 +408,7 @@ def test_function_nested_loop(self):
ast._code_node_reachability_graph.add_reachability(child_1, nested_loop_body)

regex = (
r"^%void +test_function\(%int +a%,%int +b%\)%{%int%c;%c%=%5%;%" r"while%\(%true%\)%{%while%\(%x%!=%5%\)%{%c%=%c%\+%5%;%}%}%}%$"
r"^%void +test_function\(%int +a%,%int +b%\)%{%int%c;%int%x;%c%=%5%;%" r"while%\(%true%\)%{%while%\(%x%!=%5%\)%{%c%=%c%\+%5%;%}%}%}%$"
)
assert self._regex_matches(regex.replace("%", r"\s*"), self._task(ast, params=[var_a.copy(), var_b.copy()], return_type=void))

Expand All @@ -421,7 +421,7 @@ def test_varvisitor_condition_as_var(self):
ast._add_edge(root, condition_node)

assert self._regex_matches(
r"^%bool +test_function\(%\)%{%if%\(%c%\)%{return%c%;%}%}%$".replace("%", "\\s*"), self._task(ast, return_type=bool1)
r"^%bool +test_function\(%\)%{%int%c;%if%\(%c%\)%{return%c%;%}%}%$".replace("%", "\\s*"), self._task(ast, return_type=bool1)
)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -1143,17 +1143,17 @@ class TestLocalDeclarationGenerator:
"op, expected",
[
(ListOperation([]), []),
(ListOperation([var_x.copy()]), []),
(UnaryOperation(OperationType.negate, [var_x.copy()]), []),
(BinaryOperation(OperationType.minus, [var_x.copy(), const_3.copy()]), []),
(BinaryOperation(OperationType.minus, [var_x.copy(), var_y.copy()]), []),
(ListOperation([var_x.copy()]), ["int x;"]),
(UnaryOperation(OperationType.negate, [var_x.copy()]), ["int x;"]),
(BinaryOperation(OperationType.minus, [var_x.copy(), const_3.copy()]), ["int x;"]),
(BinaryOperation(OperationType.minus, [var_x.copy(), var_y.copy()]), ["int x;"]),
(Assignment(var_x.copy(), Constant(3)), ["int x;"]),
(Assignment(ListOperation([var_x.copy(), var_y.copy()]), Call(FunctionSymbol("foo", 0), [var_x.copy()])), ["int x;", "int y;"]),
],
)
def test_operation(self, op, expected):
def test_get_variables(self, op, expected):
"""Ensure variables are generated for operations."""
var_visitor = LocalDeclarationGenerator()
LocalDeclarationGenerator._get_variables()
var_visitor.visit_subexpressions(op)
assert list(var_visitor.generate()) == expected

Expand Down

0 comments on commit c0e5aa0

Please sign in to comment.