Skip to content

Commit

Permalink
Merge branch 'main' into issue-374-Improve_else-if_chaining
Browse files Browse the repository at this point in the history
  • Loading branch information
ebehner authored Jan 24, 2024
2 parents 3202710 + 1878760 commit f82dce6
Show file tree
Hide file tree
Showing 16 changed files with 826 additions and 147 deletions.
147 changes: 61 additions & 86 deletions decompiler/backend/variabledeclarations.py
Original file line number Diff line number Diff line change
@@ -1,78 +1,54 @@
"""Module containing the visitors used to generate variable declarations."""
from collections import defaultdict
from typing import Iterable, Iterator, List, Set
from typing import Iterable, Iterator, List

from decompiler.backend.cexpressiongenerator import CExpressionGenerator
from decompiler.structures.ast.ast_nodes import ForLoopNode, LoopNode
from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree
from decompiler.structures.pseudo import (
Assignment,
BinaryOperation,
Constant,
DataflowObject,
Expression,
ExternConstant,
ExternFunctionPointer,
GlobalVariable,
Operation,
OperationType,
Pointer,
UnaryOperation,
Variable,
)
from decompiler.structures.pseudo.operations import MemberAccess
from decompiler.structures.visitors.ast_dataflowobjectvisitor import BaseAstDataflowObjectVisitor
from decompiler.task import DecompilerTask
from decompiler.util.insertion_ordered_set import InsertionOrderedSet
from decompiler.util.serialization.bytes_serializer import convert_bytes


class LocalDeclarationGenerator(BaseAstDataflowObjectVisitor):
"""Visits all nodes in the AST and produces the variable declarations."""

def __init__(self, vars_per_line: int = 1):
"""Initialize a new VariableCollector with an empty set of variables."""
self._variables: Set[Variable] = set()
self._vars_per_line: int = vars_per_line

@classmethod
def from_task(cls, task: DecompilerTask):
"""Class method for shorthand usage."""
param_names = list(param.name for param in task.function_parameters)
generator = cls(task.options.getint("code-generator.variable_declarations_per_line", fallback=1))
generator.visit_ast(task.syntax_tree)
return "\n".join(generator.generate(param_names))

def visit_assignment(self, instruction: Assignment):
"""Remember all defined variables."""
self._variables.update(instruction.definitions)

def visit_loop_node(self, node: LoopNode):
"""Visit the given loop node, taking node of the loop declaration."""
if isinstance(node, ForLoopNode) and isinstance(node.declaration, Assignment):
if isinstance(node.declaration.destination, Operation):
self._variables.add(node.declaration.destination[0])
else:
self._variables.add(node.declaration.destination)

def visit_unary_operation(self, unary: UnaryOperation):
"""Visit unary operations to remember all variables those memory location was read."""
if isinstance(unary, MemberAccess):
self._variables.add(unary.struct_variable)
if unary.operation == OperationType.address or unary.operation == OperationType.dereference:
if isinstance(unary.operand, Variable):
self._variables.add(unary.operand)
elif isinstance(unary.operand, BinaryOperation):
if isinstance(unary.operand.left, Variable):
self._variables.add(unary.operand.left)
else:
self.visit(unary.operand.left)

def generate(self, param_names: list[str] = []) -> Iterator[str]:
class LocalDeclarationGenerator:
@staticmethod
def from_task(task: DecompilerTask):
    """
    Generate the local variable declarations for the given task as a single newline-separated string.

    The number of declarations per line is read from the task option 'code-generator.variable_declarations_per_line'
    (fallback: 1). The function parameter names are passed on to generate(), which leaves them out of the declarations.
    """
    declarations_per_line = task.options.getint("code-generator.variable_declarations_per_line", fallback=1)

    names_of_parameters = {parameter.name for parameter in task.function_parameters}
    collected_variables = InsertionOrderedSet(LocalDeclarationGenerator._get_variables(task.syntax_tree))

    declaration_lines = LocalDeclarationGenerator.generate(names_of_parameters, collected_variables, declarations_per_line)
    return "\n".join(declaration_lines)

@staticmethod
def _get_variables(ast: AbstractSyntaxTree) -> Iterator[Variable]:
    """Lazily yield every Variable found among the subexpressions of the dataflow objects of all nodes of the given syntax tree."""
    for ast_node in ast.nodes:
        for dataflow_object in ast_node.get_dataflow_objets(ast.condition_map):
            # Duplicates are not filtered here; every occurrence of a variable is yielded.
            yield from (candidate for candidate in dataflow_object.subexpressions() if isinstance(candidate, Variable))

@staticmethod
def generate(parameter_names: Iterable[str], variables: Iterable[Variable], vars_per_line: int) -> Iterator[str]:
"""Generate a string containing the variable definitions for the visited variables."""

variable_type_mapping = defaultdict(list)
for variable in sorted(self._variables, key=lambda x: str(x)):
if not isinstance(variable, GlobalVariable) and variable.name not in param_names:
for variable in sorted(variables, key=lambda x: str(x)):
if not isinstance(variable, GlobalVariable) and variable.name not in parameter_names:
variable_type_mapping[variable.type].append(variable)

for variable_type, variables in sorted(variable_type_mapping.items(), key=lambda x: str(x)):
for chunked_variables in self._chunks(variables, self._vars_per_line):
for chunked_variables in LocalDeclarationGenerator._chunks(variables, vars_per_line):
yield CExpressionGenerator.format_variables_declaration(variable_type, [var.name for var in chunked_variables]) + ";"

@staticmethod
Expand All @@ -83,44 +59,43 @@ def _chunks(lst: List, n: int) -> Iterator[List]:


class GlobalDeclarationGenerator(BaseAstDataflowObjectVisitor):
"""Visits all nodes in the AST and produces the declarations of global variables."""
@staticmethod
def from_asts(asts: Iterable[AbstractSyntaxTree]) -> str:
    """Collect the global variables and extern constants of the given syntax trees and return their declarations, one per line."""
    found_globals, found_constants = GlobalDeclarationGenerator._get_global_variables_and_constants(asts)
    declaration_lines = GlobalDeclarationGenerator.generate(iter(found_globals), found_constants)
    return "\n".join(declaration_lines)

def __init__(self):
"""Generate a new declarator with an empty sets of visited globals."""
self._extern_constants: Set[ExternConstant] = set()
self._global_variables: Set[GlobalVariable] = set()
@staticmethod
def _get_global_variables_and_constants(asts: Iterable[AbstractSyntaxTree]) -> tuple[set[GlobalVariable], set[ExternConstant]]:
global_variables = InsertionOrderedSet()
extern_constants = InsertionOrderedSet()

# if this gets more complex, a visitor pattern should perhaps be used instead
def handle_obj(obj: DataflowObject):
match obj:
case GlobalVariable():
global_variables.add(obj)
if isinstance(obj.initial_value, Expression):
for subexpression in obj.initial_value.subexpressions():
handle_obj(subexpression)

case ExternConstant():
extern_constants.add(obj)

@classmethod
def from_asts(cls, asts: Iterable[AbstractSyntaxTree]) -> str:
"""Class method for shorthand usage."""
generator = cls()
for ast in asts:
generator.visit_ast(ast)
return "\n".join(generator.generate())
for node in ast.nodes:
for obj in node.get_dataflow_objets(ast.condition_map):
for expression in obj.subexpressions():
handle_obj(expression)

def generate(self) -> Iterator[str]:
"""Generate a string containing the variable definitions for the visited variables."""
for variable in self._global_variables:
yield f"extern {variable.type} {variable.name} = {self.get_initial_value(variable)};"
for constant in sorted(self._extern_constants, key=lambda x: x.value):
yield f"extern {constant.type} {constant.value};"
return global_variables, extern_constants

def visit_unary_operation(self, unary: UnaryOperation):
"""Visit an unary operation, visiting variable operands and nested operations along the way."""
if isinstance(unary.operand, UnaryOperation) or isinstance(unary.operand, Variable):
self.visit(unary.operand)

def visit_variable(self, expression: Variable):
"""Visit the given variable, remembering all visited global Variables."""
if isinstance(expression, GlobalVariable):
self._global_variables.add(expression)
if isinstance(expression.initial_value, UnaryOperation):
self.visit(expression.initial_value)

def visit_constant(self, expression: Constant):
"""Visit the given constant, checking if it has been defined externally."""
if isinstance(expression, ExternConstant):
self._extern_constants.add(expression)
@staticmethod
def generate(global_variables: Iterable[GlobalVariable], extern_constants: Iterable[ExternConstant]) -> Iterator[str]:
    """
    Yield the extern declarations for the given globals.

    Global variables are emitted first, in the order they are passed in and together with their initial value;
    the extern constants follow, sorted by their value.
    """
    for global_variable in global_variables:
        yield f"extern {global_variable.type} {global_variable.name} = {GlobalDeclarationGenerator.get_initial_value(global_variable)};"
    yield from (f"extern {extern_constant.type} {extern_constant.value};" for extern_constant in sorted(extern_constants, key=lambda c: c.value))

@staticmethod
def get_initial_value(variable: GlobalVariable) -> str:
Expand Down
115 changes: 111 additions & 4 deletions decompiler/pipeline/controlflowanalysis/loop_utility_methods.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Optional

from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, CaseNode, CodeNode, ConditionNode, LoopNode, SeqNode, SwitchNode
from typing import Dict, List, Optional

from decompiler.structures.ast.ast_nodes import (
AbstractSyntaxTreeNode,
CaseNode,
CodeNode,
ConditionNode,
LoopNode,
SeqNode,
SwitchNode,
WhileLoopNode,
)
from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree
from decompiler.structures.logic.logic_condition import LogicCondition
from decompiler.structures.pseudo import Assignment, Condition, Variable
from decompiler.structures.pseudo import (
Assignment,
BinaryOperation,
Condition,
Constant,
Expression,
OperationType,
UnaryOperation,
Variable,
)
from decompiler.structures.visitors.assignment_visitor import AssignmentVisitor


Expand Down Expand Up @@ -210,3 +228,92 @@ def _requirement_without_reinitialization(ast: AbstractSyntaxTree, node: Abstrac
return True
elif variable in assignment.requirements:
return True


def _get_equalizable_last_definitions(loop_node: WhileLoopNode, continuation: AstInstruction) -> Optional[List[Assignment]]:
    """
    Find the last definitions of the continuation variable in all code nodes of the while loop that end with a continue statement.

    A last definition qualifies if its value is a constant or if it is an assignment with a simple binary operation.
    Only if every continue node qualifies, the continuation instruction is brought into its unified representation
    (exactly once); in all other cases the continuation instruction is left untouched.

    :param loop_node: While-loop to search in
    :param continuation: Instruction defining the for-loops modification
    :return: List of the equalizable last definitions (empty if no code node of the loop body ends with continue),
        or None if the continuation instruction or at least one continue node does not match the requirements.
    """
    continue_nodes = [
        node for node in loop_node.body.get_descendant_code_nodes_interrupting_ancestor_loop() if node.does_end_with_continue
    ]
    if not continue_nodes:
        return []

    if not _is_assignment_with_simple_binary_operation(continuation.instruction):
        return None

    last_definitions = []
    for code_node in continue_nodes:
        last_definition_index = _get_last_definition_index_of(code_node, continuation.instruction.destination)
        if last_definition_index == -1:
            return None

        last_definition = code_node.instructions[last_definition_index]
        if not (isinstance(last_definition.value, Constant) or _is_assignment_with_simple_binary_operation(last_definition)):
            return None
        last_definitions.append(last_definition)

    # Unify only after all continue nodes qualified: a rejected loop must not end up with a rewritten continuation,
    # and the rewrite does not need to be repeated for every continue node.
    _unify_binary_operation_in_assignment(continuation.instruction)
    return last_definitions


def _is_assignment_with_simple_binary_operation(assignment: Assignment) -> bool:
    """
    Tell whether the assignment's value is a simple binary operation that uses the variable the assignment defines.

    'Simple' means: the operation is a plus or a minus, at least one operand is a (possibly negated) constant and at least
    one operand is a (possibly negated) variable. In addition, the variable of the operation must equal the destination.
    """
    value = assignment.value
    if not isinstance(value, BinaryOperation):
        return False
    if value.operation not in {OperationType.plus, OperationType.minus}:
        return False

    has_constant = any(isinstance(operand, Constant) or _is_negated_constant_variable(operand, Constant) for operand in value.operands)
    if not has_constant:
        return False

    has_variable = any(isinstance(operand, Variable) or _is_negated_constant_variable(operand, Variable) for operand in value.operands)
    if not has_variable:
        return False

    return assignment.destination == _get_variable_in_binary_operation(value)


def _is_negated_constant_variable(operand: Expression, expression: type[Constant] | type[Variable]) -> bool:
    """
    Check whether the operand is a negation whose negated operand is an instance of the given expression class.

    :param operand: Expression to inspect
    :param expression: The class (Constant or Variable) the negated operand has to be an instance of
    :return: True if operand is a unary negate operation applied to an instance of the given class
    """
    return isinstance(operand, UnaryOperation) and operand.operation == OperationType.negate and isinstance(operand.operand, expression)


def _get_variable_in_binary_operation(binaryoperation: BinaryOperation) -> Optional[Variable]:
    """
    Return the first variable used as an operand of the binary operation.

    An operand counts if it is a variable itself or the negation of a variable; in the latter case the negated variable is returned.

    :param binaryoperation: Binary operation whose operands are searched
    :return: The first matching variable, or None if no operand is a (negated) variable
    """
    for operand in binaryoperation.operands:
        if isinstance(operand, Variable):
            return operand
        if _is_negated_constant_variable(operand, Variable):
            return operand.operand
    return None


def _unify_binary_operation_in_assignment(assignment: Assignment):
    """
    Rewrite the binary operation of the assignment in place into a unified form.

    First, an operation that is not a plus is replaced by a plus whose right operand is negated.
    Afterwards, if the left operand contains a constant, the two operands are swapped.
    """
    if assignment.value.operation != OperationType.plus:
        negated_right = UnaryOperation(OperationType.negate, [assignment.value.right])
        as_addition = BinaryOperation(OperationType.plus, [assignment.value.left, negated_right])
        assignment.substitute(assignment.value, as_addition)

    # Re-read assignment.value here: the substitution above may have replaced it.
    left_contains_constant = any(isinstance(subexpression, Constant) for subexpression in assignment.value.left.subexpressions())
    if left_contains_constant:
        swapped = BinaryOperation(OperationType.plus, [assignment.value.right, assignment.value.left])
        assignment.substitute(assignment.value, swapped)


def _substract_continuation_from_last_definition(last_definition: Assignment, continuation: AstInstruction):
    """
    Subtract the right operand of the continuation's value from the value of the last definition (in place).

    If the left operand of the continuation's value is a negated variable, the resulting difference is additionally negated.

    :param last_definition: Last definition that is to be changed
    :param continuation: Instruction defining the for-loops modification
    """
    continuation_value = continuation.instruction.value
    difference = BinaryOperation(OperationType.minus, [last_definition.value, continuation_value.right])
    if _is_negated_constant_variable(continuation_value.left, Variable):
        replacement = UnaryOperation(OperationType.negate, [difference])
    else:
        replacement = difference
    last_definition.substitute(last_definition.value, replacement)
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
from decompiler.pipeline.controlflowanalysis.loop_utility_methods import (
AstInstruction,
_find_continuation_instruction,
_get_equalizable_last_definitions,
_get_variable_initialisation,
_initialization_reaches_loop_node,
_is_single_instruction_loop_node,
_single_defininition_reaches_node,
_substract_continuation_from_last_definition,
)
from decompiler.pipeline.stage import PipelineStage
from decompiler.structures.ast.ast_nodes import ConditionNode, DoWhileLoopNode, ForLoopNode, WhileLoopNode
Expand Down Expand Up @@ -76,6 +78,7 @@ def run(self):
-> loop condition complexity < condition complexity
-> possible modification complexity < modification complexity
-> if condition is only a symbol: check condition type for allowed one
-> has a continue statement which must and can be equalized
If 'force_for_loops' is enabled, the complexity options are ignored and every while loop after the
initial transformation will be forced into a for loop with an empty declaration/modification
Expand All @@ -90,9 +93,6 @@ def run(self):
):
continue

if any(node.does_end_with_continue for node in loop_node.body.get_descendant_code_nodes_interrupting_ancestor_loop()):
continue

if not self._force_for_loops and loop_node.condition.get_complexity(self._ast.condition_map) > self._condition_max_complexity:
continue

Expand All @@ -103,6 +103,10 @@ def run(self):
continue
if not self._force_for_loops and continuation.instruction.complexity > self._modification_max_complexity:
continue
if (equalizable_last_definitions := _get_equalizable_last_definitions(loop_node, continuation)) is None:
continue
for last_definition in equalizable_last_definitions:
_substract_continuation_from_last_definition(last_definition, continuation)
self._replace_with_for_loop(loop_node, continuation, variable_init)
break

Expand Down
9 changes: 6 additions & 3 deletions decompiler/pipeline/preprocessing/missing_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from decompiler.structures.pseudo.instructions import Assignment, Instruction, Phi, Relation
from decompiler.structures.pseudo.operations import Call, ListOperation, OperationType, UnaryOperation
from decompiler.task import DecompilerTask
from decompiler.util.insertion_ordered_set import InsertionOrderedSet
from networkx import DiGraph

from .util import _init_basicblocks_of_definition, _init_basicblocks_usages_variable, _init_maps
Expand Down Expand Up @@ -129,7 +130,7 @@ def insert_missing_definitions(self):
- Depending whether the memory-changing instruction changes the aliased-variable we insert the definition as an assignment
(no change) or a relation (change).
"""
undefined_variables: Set[Variable] = self._get_undefined_variables()
undefined_variables: InsertionOrderedSet[Variable] = self._get_undefined_variables()
variable_copies: _VariableCopyPool = _VariableCopyPool(undefined_variables | self._def_map.defined_variables)
variable_copies.sort_copies_of(*undefined_variables)

Expand All @@ -143,7 +144,7 @@ def insert_missing_definitions(self):
self._insert_definition_if_undefined(variable, previous_ssa_labels, undefined_variables)
previous_ssa_labels.add(variable.ssa_label)

def _get_undefined_variables(self) -> Set[Variable]:
def _get_undefined_variables(self) -> InsertionOrderedSet[Variable]:
"""
Compute the set of undefined variables.
Expand All @@ -160,7 +161,9 @@ def _get_undefined_variables(self) -> Set[Variable]:
undefined_variables.add(aliased_variable)
return undefined_variables

def _insert_definition_if_undefined(self, variable: Variable, previous_ssa_labels: Set[int], undefined_variables: Set[Variable]):
def _insert_definition_if_undefined(
self, variable: Variable, previous_ssa_labels: Set[int], undefined_variables: InsertionOrderedSet[Variable]
):
"""Insert definition for the given variable if it is undefined or raises an error when it is a not an aliased variable."""
if variable in undefined_variables:
if not variable.is_aliased:
Expand Down
Loading

0 comments on commit f82dce6

Please sign in to comment.