Skip to content

Commit

Permalink
Merge branch 'main' into improve_insert_missing_cases
Browse files Browse the repository at this point in the history
  • Loading branch information
steffenenders authored Mar 28, 2024
2 parents 9df83d0 + 43c19d9 commit a954aeb
Show file tree
Hide file tree
Showing 26 changed files with 547 additions and 587 deletions.
69 changes: 66 additions & 3 deletions decompiler/backend/cexpressiongenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,37 @@
from itertools import chain, repeat

from decompiler.structures import pseudo as expressions
from decompiler.structures.pseudo import Float, FunctionTypeDef, Integer, OperationType, Pointer, StringSymbol, Type
from decompiler.structures.pseudo import (
ArrayType,
CustomType,
Float,
FunctionTypeDef,
GlobalVariable,
Integer,
OperationType,
Pointer,
Type,
)
from decompiler.structures.pseudo import instructions as instructions
from decompiler.structures.pseudo import operations as operations
from decompiler.structures.pseudo.operations import MemberAccess
from decompiler.structures.visitors.interfaces import DataflowObjectVisitorInterface
from decompiler.util.integer_util import normalize_int

MAX_GLOBAL_INIT_LENGTH = 128


def inline_global_variable(var) -> bool:
if not var.is_constant:
return False
match var.type:
case ArrayType():
if var.type.type in [Integer.char(), CustomType.wchar16(), CustomType.wchar32()]:
return True
case _:
return False
return False


class CExpressionGenerator(DataflowObjectVisitorInterface):
"""Generate C code for Expressions.
Expand Down Expand Up @@ -145,17 +169,52 @@ def visit_unknown_expression(self, expr: expressions.UnknownExpression) -> str:

def visit_constant(self, expr: expressions.Constant) -> str:
"""Return constant in a format that will be parsed correctly by a compiler."""
if isinstance(expr, expressions.NotUseableConstant):
return expr.value
if isinstance(expr, expressions.Symbol):
return expr.name
if isinstance(expr.type, Integer):
value = self._get_integer_literal_value(expr)
return self._format_integer_literal(expr.type, value)
if isinstance(expr, StringSymbol):
return expr.name
if isinstance(expr.type, Pointer):
match (expr.value):
case (
str()
): # Technically every string will be lifted as an ConstantArray. Will still leave this, if someone creates a string as a char*
string = expr.value if len(expr.value) <= MAX_GLOBAL_INIT_LENGTH else expr.value[:MAX_GLOBAL_INIT_LENGTH] + "..."
match expr.type.type:
case CustomType(text="wchar16") | CustomType(text="wchar32"):
return f'L"{string}"'
case _:
return f'"{string}"'
case bytes():
val = "".join("\\x{:02x}".format(x) for x in expr.value)
return f'"{val}"' if len(val) <= MAX_GLOBAL_INIT_LENGTH else f'"{val[:MAX_GLOBAL_INIT_LENGTH]}..."'

return self._format_string_literal(expr)

def visit_constant_composition(self, expr: expressions.ConstantComposition):
"""Visit a Constant Array."""
match expr.type.type:
case CustomType(text="wchar16") | CustomType(text="wchar32"):
val = "".join([x.value for x in expr.value])
return f'L"{val}"' if len(val) <= MAX_GLOBAL_INIT_LENGTH else f'L"{val[:MAX_GLOBAL_INIT_LENGTH]}..."'
case Integer(8):
val = "".join([x.value for x in expr.value][:MAX_GLOBAL_INIT_LENGTH])
return f'"{val}"' if len(val) <= MAX_GLOBAL_INIT_LENGTH else f'"{val[:MAX_GLOBAL_INIT_LENGTH]}..."'
case _:
return f'{", ".join([self.visit(x) for x in expr.value])}' # Todo: Should we print every member? Could get pretty big

def visit_variable(self, expr: expressions.Variable) -> str:
"""Return a string representation of the variable."""
return f"{expr.name}" if (label := expr.ssa_label) is None else f"{expr.name}_{label}"

def visit_global_variable(self, expr: expressions.GlobalVariable):
"""Inline a global variable if its initial value is constant and not of void type"""
if inline_global_variable(expr):
return self.visit(expr.initial_value)
return expr.name

def visit_register_pair(self, expr: expressions.Variable) -> str:
"""Return a string representation of the register pair and log."""
logging.error(f"generated code for register pair {expr}")
Expand All @@ -168,6 +227,8 @@ def visit_list_operation(self, op: operations.ListOperation) -> str:
def visit_unary_operation(self, op: operations.UnaryOperation) -> str:
"""Return a string representation of the given unary operation (e.g. !a or &a)."""
operand = self._visit_bracketed(op.operand) if self._has_lower_precedence(op.operand, op) else self.visit(op.operand)
if op.operation == OperationType.address and isinstance(op.operand, GlobalVariable) and isinstance(op.operand.type, ArrayType):
return operand
if isinstance(op, MemberAccess):
operator_str = "->" if isinstance(op.struct_variable.type, Pointer) else self.C_SYNTAX[op.operation]
return f"{operand}{operator_str}{op.member_name}"
Expand Down Expand Up @@ -353,5 +414,7 @@ def format_variables_declaration(var_type: Type, var_names: list[str]) -> str:
parameter_names = ", ".join(str(parameter) for parameter in fun_type.parameters)
declarations_without_return_type = [f"(* {var_name})({parameter_names})" for var_name in var_names]
return f"{fun_type.return_type} {', '.join(declarations_without_return_type)}"
case ArrayType():
return f"{var_type.type}* {', '.join(var_names)}"
case _:
return f"{var_type} {', '.join(var_names)}"
94 changes: 34 additions & 60 deletions decompiler/backend/variabledeclarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,13 @@
from collections import defaultdict
from typing import Iterable, Iterator, List

from decompiler.backend.cexpressiongenerator import CExpressionGenerator
from decompiler.backend.cexpressiongenerator import CExpressionGenerator, inline_global_variable
from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree
from decompiler.structures.pseudo import (
DataflowObject,
Expression,
ExternConstant,
ExternFunctionPointer,
GlobalVariable,
Operation,
Pointer,
Variable,
)
from decompiler.structures.pseudo import GlobalVariable, Integer, Variable
from decompiler.structures.pseudo.typing import ArrayType, CustomType, Pointer
from decompiler.structures.visitors.ast_dataflowobjectvisitor import BaseAstDataflowObjectVisitor
from decompiler.task import DecompilerTask
from decompiler.util.insertion_ordered_set import InsertionOrderedSet
from decompiler.util.serialization.bytes_serializer import convert_bytes


class LocalDeclarationGenerator:
Expand Down Expand Up @@ -60,57 +51,40 @@ def _chunks(lst: List, n: int) -> Iterator[List]:


class GlobalDeclarationGenerator(BaseAstDataflowObjectVisitor):
@staticmethod
def from_asts(asts: Iterable[AbstractSyntaxTree]) -> str:
global_variables, extern_constants = GlobalDeclarationGenerator._get_global_variables_and_constants(asts)
return "\n".join(GlobalDeclarationGenerator.generate(global_variables.__iter__(), extern_constants))
def __init__(self) -> None:
self._global_vars = InsertionOrderedSet()
super().__init__()

@staticmethod
def _get_global_variables_and_constants(asts: Iterable[AbstractSyntaxTree]) -> tuple[set[GlobalVariable], set[ExternConstant]]:
global_variables = InsertionOrderedSet()
extern_constants = InsertionOrderedSet()

# if this gets more complex, a visitor pattern should perhaps be used instead
def handle_obj(obj: DataflowObject):
match obj:
case GlobalVariable():
global_variables.add(obj)
if isinstance(obj.initial_value, Expression):
for subexpression in obj.initial_value.subexpressions():
handle_obj(subexpression)

case ExternConstant():
extern_constants.add(obj)

for ast in asts:
for node in ast.nodes:
for obj in node.get_dataflow_objets(ast.condition_map):
for expression in obj.subexpressions():
handle_obj(expression)

return global_variables, extern_constants

@staticmethod
def generate(global_variables: Iterable[GlobalVariable], extern_constants: Iterable[ExternConstant]) -> Iterator[str]:
def _generate_definitions(global_variables: set[GlobalVariable]) -> Iterator[str]:
"""Generate all definitions"""
for variable in global_variables:
yield f"extern {variable.type} {variable.name} = {GlobalDeclarationGenerator.get_initial_value(variable)};"
for constant in sorted(extern_constants, key=lambda x: x.value):
yield f"extern {constant.type} {constant.value};"
base = f"extern {'const ' if variable.is_constant else ''}"
match variable.type:
case ArrayType():
br, bl = "", ""
if not variable.type.type in [Integer.char(), CustomType.wchar16(), CustomType.wchar32()]:
br, bl = "{", "}"
yield f"{base}{variable.type.type} {variable.name}[{hex(variable.type.elements)}] = {br}{CExpressionGenerator().visit(variable.initial_value)}{bl};"
case _:
yield f"{base}{variable.type} {variable.name} = {CExpressionGenerator().visit(variable.initial_value)};"

@staticmethod
def get_initial_value(variable: GlobalVariable) -> str:
"""Get a string representation of the initial value of the given variable."""
if isinstance(variable.initial_value, GlobalVariable):
return variable.initial_value.name
elif isinstance(variable.initial_value, ExternFunctionPointer):
return str(variable.initial_value.value)
if isinstance(variable.initial_value, bytes):
return str(convert_bytes(variable.initial_value, variable.type))
if isinstance(operation := variable.initial_value, Operation):
for requirement in operation.requirements:
if isinstance(requirement, GlobalVariable):
requirement.unsubscript()
if isinstance(variable.type, Pointer) and isinstance(variable.initial_value, int):
return hex(variable.initial_value)
return str(variable.initial_value)
def from_asts(asts: Iterable[AbstractSyntaxTree]) -> str:
"""Generate"""
globals = InsertionOrderedSet()
for ast in asts:
globals |= GlobalDeclarationGenerator().visit_ast(ast)
return "\n".join(GlobalDeclarationGenerator._generate_definitions(globals))

def visit_ast(self, ast: AbstractSyntaxTree) -> InsertionOrderedSet:
"""Visit ast and return all collected global variables"""
super().visit_ast(ast)
return self._global_vars

def visit_global_variable(self, expr: GlobalVariable):
"""Visit global variables. Only collect ones which will not be inlined by CExprGenerator. Strip SSA label to remove duplicates"""
if not inline_global_variable(expr):
self._global_vars.add(expr.copy(ssa_label=0, ssa_name=None))
if not expr.is_constant or expr.type == Pointer(CustomType.void()):
self._global_vars.add(expr.copy(ssa_label=0, ssa_name=None))
55 changes: 33 additions & 22 deletions decompiler/frontend/binaryninja/handlers/constants.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
"""Module implementing the ConstantHandler for the binaryninja frontend."""

import math
from typing import Union

from binaryninja import BinaryView, DataVariable, SectionSemantics, SymbolType, Type, mediumlevelil
from binaryninja import DataVariable, SymbolType, Type, mediumlevelil
from decompiler.frontend.lifter import Handler
from decompiler.structures.pseudo import Constant, GlobalVariable, Integer, NotUseableConstant, Pointer, StringSymbol
from decompiler.structures.pseudo import (
Constant,
CustomType,
GlobalVariable,
Integer,
NotUseableConstant,
OperationType,
Pointer,
Symbol,
UnaryOperation,
)

BYTE_SIZE = 8

Expand Down Expand Up @@ -36,36 +47,36 @@ def lift_integer_literal(value: int, **kwargs) -> Constant:
return Constant(value, vartype=Integer.int32_t())

def lift_constant_data(self, pointer: mediumlevelil.MediumLevelILConstData, **kwargs) -> Constant:
"""Lift const data as a non mute able constant string"""
return StringSymbol(str(pointer), pointer.address)
"""Lift data as a non mute able constant string (register string)"""
return NotUseableConstant(str(pointer))

def lift_constant_pointer(self, pointer: mediumlevelil.MediumLevelILConstPtr, **kwargs):
def lift_constant_pointer(self, pointer: mediumlevelil.MediumLevelILConstPtr, **kwargs) -> Union[GlobalVariable, Symbol]:
"""Lift the given constant pointer, e.g. &0x80000."""
view = pointer.function.view

if variable := view.get_data_var_at(pointer.constant):
return self._lifter.lift(variable, view=view, parent=pointer)
res = self._lifter.lift(variable, view=view, parent=pointer)

if (symbol := view.get_symbol_at(pointer.constant)) and symbol.type != SymbolType.DataSymbol:
elif (symbol := view.get_symbol_at(pointer.constant)) and symbol.type != SymbolType.DataSymbol:
return self._lifter.lift(symbol)

if function := view.get_function_at(pointer.constant):
elif function := view.get_function_at(pointer.constant):
return self._lifter.lift(function.symbol)

variable = DataVariable(view, pointer.constant, Type.void(), False)
global_variable = self._lifter.lift(variable, view=view, parent=pointer)
else:
res = self._lifter.lift(DataVariable(view, pointer.constant, Type.void(), False), view=view, parent=pointer)

return self._replace_global_variable_with_value(global_variable, variable, view)
if isinstance(res, Constant): # BNinja Error case handling
return res

def _replace_global_variable_with_value(self, globalVariable: GlobalVariable, variable: DataVariable, view: BinaryView) -> StringSymbol:
"""Replace global variable with it's value, if it's a char/wchar16/wchar32* and in a read only section"""
if not self._in_read_only_section(variable.address, view) or str(globalVariable.type) == "void *":
return globalVariable
return StringSymbol(globalVariable.initial_value, variable.address, vartype=Pointer(Integer.char(), view.address_size * BYTE_SIZE))
if isinstance(res.type, Pointer) and res.type.type == CustomType.void():
return res

def _in_read_only_section(self, addr: int, view: BinaryView) -> bool:
"""Returns True if address is contained in a read only section, False otherwise"""
for _, section in view.sections.items():
if addr >= section.start and addr <= section.end and section.semantics == SectionSemantics.ReadOnlyDataSectionSemantics:
return True
return False
if isinstance(pointer, mediumlevelil.MediumLevelILImport): # Temp fix for '&'
return res

return UnaryOperation(
OperationType.address,
[res],
vartype=res.type,
)
Loading

0 comments on commit a954aeb

Please sign in to comment.