diff --git a/decompiler/backend/cexpressiongenerator.py b/decompiler/backend/cexpressiongenerator.py index f310fe341..09bcfac65 100644 --- a/decompiler/backend/cexpressiongenerator.py +++ b/decompiler/backend/cexpressiongenerator.py @@ -1,12 +1,12 @@ import logging from ctypes import c_byte, c_int, c_long, c_short, c_ubyte, c_uint, c_ulong, c_ushort from itertools import chain, repeat -from typing import Union from decompiler.structures import pseudo as expressions -from decompiler.structures.pseudo import Float, Integer, OperationType, StringSymbol +from decompiler.structures.pseudo import Float, FunctionTypeDef, Integer, OperationType, Pointer, StringSymbol, Type from decompiler.structures.pseudo import instructions as instructions from decompiler.structures.pseudo import operations as operations +from decompiler.structures.pseudo.operations import MemberAccess from decompiler.structures.visitors.interfaces import DataflowObjectVisitorInterface @@ -65,6 +65,7 @@ class CExpressionGenerator(DataflowObjectVisitorInterface): OperationType.greater_or_equal_us: ">=", OperationType.dereference: "*", OperationType.address: "&", + OperationType.member_access: ".", # Handled in code # OperationType.cast: "cast", # OperationType.pointer: "point", @@ -146,7 +147,7 @@ class CExpressionGenerator(DataflowObjectVisitorInterface): # OperationType.low: "low", OperationType.ternary: 30, OperationType.call: 150, - OperationType.field: 150, + OperationType.member_access: 150, OperationType.list_op: 10, # TODO: Figure out what these are / how to handle this # OperationType.adc: "adc", @@ -180,6 +181,9 @@ def visit_list_operation(self, op: operations.ListOperation) -> str: def visit_unary_operation(self, op: operations.UnaryOperation) -> str: """Return a string representation of the given unary operation (e.g. !a or &a).""" + if isinstance(op, MemberAccess): + operator_str = "->" if isinstance(op.struct_variable.type, Pointer) else self.C_SYNTAX[op.operation] + return f"{self.visit(op.struct_variable)}{operator_str}{op.member_name}" operand = self._visit_bracketed(op.operand) if self._has_lower_precedence(op.operand, op) else self.visit(op.operand) if op.operation == OperationType.cast and op.contraction: return f"({int(op.type.size / 8)}: ){operand}" @@ -361,3 +365,14 @@ def _format_string_literal(constant: expressions.Constant) -> str: escaped = string_representation.replace('"', '\\"') return f'"{escaped}"' return f"{constant}" + + @staticmethod + def format_variables_declaration(var_type: Type, var_names: list[str]) -> str: + """ Return a string representation of variable declarations.""" + match var_type: + case Pointer(type=FunctionTypeDef() as fun_type): + parameter_names = ", ".join(str(parameter) for parameter in fun_type.parameters) + declarations_without_return_type = [f"(* {var_name})({parameter_names})" for var_name in var_names] + return f"{fun_type.return_type} {', '.join(declarations_without_return_type)}" + case _: + return f"{var_type} {', '.join(var_names)}" diff --git a/decompiler/backend/codegenerator.py b/decompiler/backend/codegenerator.py index 4ce11655a..677d2466b 100644 --- a/decompiler/backend/codegenerator.py +++ b/decompiler/backend/codegenerator.py @@ -2,6 +2,7 @@ from string import Template from typing import Iterable, List +from decompiler.backend.cexpressiongenerator import CExpressionGenerator from decompiler.backend.codevisitor import CodeVisitor from decompiler.backend.variabledeclarations import GlobalDeclarationGenerator, LocalDeclarationGenerator from decompiler.task import DecompilerTask @@ -29,6 +30,7 @@ def generate(self, tasks: Iterable[DecompilerTask], run_cleanup: bool = True): for task in tasks: if run_cleanup and not task.failed: task.syntax_tree.clean_up() + string_blocks.append(task.complex_types.declarations()) string_blocks.append(self.generate_function(task)) return "\n\n".join(string_blocks) @@ -37,7 +39,10 @@ def generate_function(self, task: DecompilerTask) -> str: return self.TEMPLATE.substitute( return_type=task.function_return_type, name=task.name, - parameters=", ".join(map(lambda param: f"{param.type} {param.name}", task.function_parameters)), + parameters=", ".join(map( + lambda param: CExpressionGenerator.format_variables_declaration(param.type, [param.name]), + task.function_parameters + )), local_declarations=LocalDeclarationGenerator.from_task(task) if not task.failed else "", function_body=CodeVisitor(task).visit(task.syntax_tree.root) if not task.failed else task.failure_message, ) diff --git a/decompiler/backend/variabledeclarations.py b/decompiler/backend/variabledeclarations.py index 56b2ae2c8..f79ae589c 100644 --- a/decompiler/backend/variabledeclarations.py +++ b/decompiler/backend/variabledeclarations.py @@ -2,6 +2,7 @@ from collections import defaultdict from typing import Iterable, Iterator, List, Set +from decompiler.backend.cexpressiongenerator import CExpressionGenerator from decompiler.structures.ast.ast_nodes import ForLoopNode, LoopNode from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree from decompiler.structures.pseudo import ( @@ -17,6 +18,7 @@ UnaryOperation, Variable, ) +from decompiler.structures.pseudo.operations import MemberAccess from decompiler.structures.visitors.ast_dataflowobjectvisitor import BaseAstDataflowObjectVisitor from decompiler.task import DecompilerTask from decompiler.util.serialization.bytes_serializer import convert_bytes @@ -52,6 +54,8 @@ def visit_loop_node(self, node: LoopNode): def visit_unary_operation(self, unary: UnaryOperation): """Visit unary operations to remember all variables those memory location was read.""" + if isinstance(unary, MemberAccess): + self._variables.add(unary.struct_variable) if unary.operation == OperationType.address or unary.operation == OperationType.dereference: if isinstance(unary.operand, Variable): self._variables.add(unary.operand) @@ -61,19 +65,18 @@ def visit_unary_operation(self, unary: UnaryOperation): else: self.visit(unary.operand.left) - def generate(self, param_names: list = []) -> Iterator[str]: + def generate(self, param_names: list[str] = []) -> Iterator[str]: """Generate a string containing the variable definitions for the visited variables.""" variable_type_mapping = defaultdict(list) for variable in sorted(self._variables, key=lambda x: str(x)): - if not isinstance(variable, GlobalVariable): + if not isinstance(variable, GlobalVariable) and variable.name not in param_names: variable_type_mapping[variable.type].append(variable) - for variable_type, variables in sorted(variable_type_mapping.items(), key=lambda x: str(x)): for chunked_variables in self._chunks(variables, self._vars_per_line): - variable_names = ", ".join([var.name for var in chunked_variables]) - if variable_names in param_names: - continue - yield f"{variable_type} {variable_names};" + yield CExpressionGenerator.format_variables_declaration( + variable_type, + [var.name for var in chunked_variables] + ) + ";" @staticmethod def _chunks(lst: List, n: int) -> Iterator[List]: diff --git a/decompiler/frontend/binaryninja/frontend.py b/decompiler/frontend/binaryninja/frontend.py index ed3ba266f..bdd1ad2d1 100644 --- a/decompiler/frontend/binaryninja/frontend.py +++ b/decompiler/frontend/binaryninja/frontend.py @@ -2,11 +2,12 @@ from __future__ import annotations import logging -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union -from binaryninja import BinaryView, BinaryViewType, Function +from binaryninja import BinaryView, Function, load from binaryninja.types import SymbolType from decompiler.structures.graphs.cfg import ControlFlowGraph +from decompiler.structures.pseudo.complextypes import ComplexTypeMap from decompiler.structures.pseudo.expressions import Variable from decompiler.structures.pseudo.typing import Type from decompiler.task import DecompilerTask @@ -112,7 +113,7 @@ def __init__(self, bv: BinaryView): def from_path(cls, path: str, options: Options): """Create a frontend object by invoking binaryninja on the given sample.""" file_options = {"analysis.limits.maxFunctionSize": options.getint("binaryninja.max_function_size")} - if (bv := BinaryViewType.get_view_of_file_with_options(path, options=file_options)) is not None: + if (bv := load(path, options=file_options)) is not None: return cls(bv) raise RuntimeError("Failed to create binary view") @@ -127,10 +128,10 @@ def create_task(self, function_identifier: Union[str, Function], options: Option tagging = CompilerIdiomsTagging(self._bv, function.function.start, options) tagging.run() try: - cfg = self._extract_cfg(function.function, options) + cfg, complex_types = self._extract_cfg(function.function, options) task = DecompilerTask( function.name, cfg, function_return_type=function.return_type, function_parameters=function.params, - options=options + options=options, complex_types=complex_types ) except Exception as e: task = DecompilerTask( @@ -154,9 +155,9 @@ def get_all_function_names(self): functions.append(function.name) return functions - def _extract_cfg(self, function: Function, options: Options) -> ControlFlowGraph: + def _extract_cfg(self, function: Function, options: Options) -> Tuple[ControlFlowGraph, ComplexTypeMap]: """Extract a control flow graph utilizing the parser and fixing it afterwards.""" report_threshold = options.getint("lifter.report_threshold", fallback=3) no_masks = options.getboolean("lifter.no_bit_masks", fallback=True) - parser = BinaryninjaParser(BinaryninjaLifter(no_masks), report_threshold) - return parser.parse(function) + parser = BinaryninjaParser(BinaryninjaLifter(no_masks, bv=function.view), report_threshold) + return parser.parse(function), parser.complex_types diff --git a/decompiler/frontend/binaryninja/handlers/assignments.py b/decompiler/frontend/binaryninja/handlers/assignments.py index 8d9e5ab7d..85d81038c 100644 --- a/decompiler/frontend/binaryninja/handlers/assignments.py +++ b/decompiler/frontend/binaryninja/handlers/assignments.py @@ -1,13 +1,15 @@ """Module implementing the AssignmentHandler for binaryninja.""" +import logging from functools import partial -from typing import Union +import binaryninja from binaryninja import mediumlevelil from decompiler.frontend.lifter import Handler from decompiler.structures.pseudo import ( Assignment, BinaryOperation, Constant, + Expression, GlobalVariable, Integer, Operation, @@ -16,6 +18,8 @@ RegisterPair, UnaryOperation, ) +from decompiler.structures.pseudo.complextypes import Struct, Union +from decompiler.structures.pseudo.operations import MemberAccess class AssignmentHandler(Handler): @@ -38,8 +42,8 @@ def register(self): mediumlevelil.MediumLevelILVarAliasedField: partial(self.lift_get_field, is_aliased=True), mediumlevelil.MediumLevelILStore: self.lift_store, mediumlevelil.MediumLevelILStoreSsa: self.lift_store, - mediumlevelil.MediumLevelILStoreStruct: self._lift_store_struct, - mediumlevelil.MediumLevelILStoreStructSsa: self._lift_store_struct, + mediumlevelil.MediumLevelILStoreStruct: self.lift_store_struct, + mediumlevelil.MediumLevelILStoreStructSsa: self.lift_store_struct, mediumlevelil.MediumLevelILLowPart: self._lift_mask_high, } ) @@ -54,16 +58,31 @@ def lift_assignment(self, assignment: mediumlevelil.MediumLevelILSetVar, is_alia def lift_set_field(self, assignment: mediumlevelil.MediumLevelILSetVarField, is_aliased=False, **kwargs) -> Assignment: """ Lift an instruction writing to a subset of the given value. - - In case of lower register (offset 0) lift as contraction - e.g. eax.al = .... <=> (char)eax .... - - In case higher registers use masking - e.g. eax.ah = x <=> eax = (eax & 0xffff00ff) + (x << 2) + case 1: writing into struct member: book.title = value + lift as struct_member(book, title, writes_memory) = value + case 2: writing into lower register part (offset 0): eax.al = value + lift as contraction (char) eax = value + case 3: writing into higher register part: eax.ah = value + lift using bit masking eax = (eax & 0xffff00ff) + (value << 2) """ - if assignment.offset == 0 and self._lifter.is_omitting_masks: + # case 1 (struct), avoid set field of named integers: + dest_type = self._lifter.lift(assignment.dest.type) + if isinstance(assignment.dest.type, binaryninja.NamedTypeReferenceType) and not ( + isinstance(dest_type, Pointer) and isinstance(dest_type.type, Integer) + ): + struct_variable = self._lifter.lift(assignment.dest, is_aliased=True, parent=assignment) + destination = MemberAccess( + offset=assignment.offset, + member_name=struct_variable.type.get_member_by_offset(assignment.offset).name, + operands=[struct_variable], + writes_memory=assignment.ssa_memory_version, + ) + value = self._lifter.lift(assignment.src) + # case 2 (contraction): + elif assignment.offset == 0 and self._lifter.is_omitting_masks: destination = self._lift_contraction(assignment, is_aliased=is_aliased, parent=assignment) value = self._lifter.lift(assignment.src) + # case 3 (bit masking): else: destination = self._lifter.lift(assignment.dest, is_aliased=is_aliased, parent=assignment) value = self._lift_masked_operand(assignment) @@ -72,9 +91,16 @@ def lift_set_field(self, assignment: mediumlevelil.MediumLevelILSetVarField, is_ def lift_get_field(self, instruction: mediumlevelil.MediumLevelILVarField, is_aliased=False, **kwargs) -> Operation: """ Lift an instruction accessing a field from the outside. - e.g. x = eax.ah <=> x = eax & 0x0000ff00 + + case 1: struct member read access e.g. (x = )book.title + lift as (x = ) struct_member(book, title) + case 2: accessing register portion e.g. (x = )eax.ah + lift as (x = ) eax & 0x0000ff00 + (x = ) <- for the sake of example, only rhs expression is lifted here. """ source = self._lifter.lift(instruction.src, is_aliased=is_aliased, parent=instruction) + if isinstance(source.type, Struct) or isinstance(source.type, Union): + return self._get_field_as_member_access(instruction, source, **kwargs) cast_type = source.type.resize(instruction.size * self.BYTE_SIZE) if instruction.offset: return BinaryOperation( @@ -84,6 +110,22 @@ def lift_get_field(self, instruction: mediumlevelil.MediumLevelILVarField, is_al ) return UnaryOperation(OperationType.cast, [source], vartype=cast_type, contraction=True) + def _get_field_as_member_access(self, instruction: mediumlevelil.MediumLevelILVarField, source: Expression, **kwargs) -> MemberAccess: + """Lift MLIL var_field as struct or union member read access.""" + if isinstance(source.type, Struct): + member_name = source.type.get_member_by_offset(instruction.offset).name + elif parent := kwargs.get("parent", None): + parent_type = self._lifter.lift(parent.dest.type) + member_name = source.type.get_member_by_type(parent_type).name + else: + logging.warning(f"Cannot get member name for instruction {instruction}") + member_name = f"field_{hex(instruction.offset)}" + return MemberAccess( + offset=instruction.offset, + member_name=member_name, + operands=[source], + ) + def lift_store(self, assignment: mediumlevelil.MediumLevelILStoreSsa, **kwargs) -> Assignment: """Lift a store operation to pseudo (e.g. [ebp+4] = eax, or [global_var_label] = 25).""" return Assignment( @@ -91,7 +133,7 @@ def lift_store(self, assignment: mediumlevelil.MediumLevelILStoreSsa, **kwargs) self._lifter.lift(assignment.src), ) - def _lift_store_destination(self, store_assignment: mediumlevelil.MediumLevelILStoreSsa) -> Union[UnaryOperation, GlobalVariable]: + def _lift_store_destination(self, store_assignment: mediumlevelil.MediumLevelILStoreSsa) -> UnaryOperation | GlobalVariable: """ Lift destination operand of store operation which is used for modelling both assignments of dereferences and global variables. """ @@ -167,24 +209,16 @@ def lift_split_assignment(self, assignment: mediumlevelil.MediumLevelILSetVarSpl self._lifter.lift(assignment.src, parent=assignment), ) - def _lift_store_struct(self, instruction: mediumlevelil.MediumLevelILStoreStruct, **kwargs) -> Assignment: + def lift_store_struct(self, instruction: mediumlevelil.MediumLevelILStoreStruct, **kwargs) -> Assignment: """Lift a MLIL_STORE_STRUCT_SSA instruction to pseudo (e.g. object->field = x).""" vartype = self._lifter.lift(instruction.dest.expr_type) - return Assignment( - UnaryOperation( - OperationType.dereference, - [ - BinaryOperation( - OperationType.plus, - [ - UnaryOperation(OperationType.cast, [self._lifter.lift(instruction.dest)], vartype=Pointer(Integer.char())), - Constant(instruction.offset), - ], - vartype=vartype, - ), - ], - vartype=Pointer(vartype), - writes_memory=instruction.dest_memory - ), - self._lifter.lift(instruction.src), + struct_variable = self._lifter.lift(instruction.dest, is_aliased=True, parent=instruction) + struct_member_access = MemberAccess( + member_name=vartype.type.members.get(instruction.offset), + offset=instruction.offset, + operands=[struct_variable], + vartype=vartype, + writes_memory=instruction.dest_memory, ) + src = self._lifter.lift(instruction.src) + return Assignment(struct_member_access, src) diff --git a/decompiler/frontend/binaryninja/handlers/calls.py b/decompiler/frontend/binaryninja/handlers/calls.py index d3467d5a7..e84ca9e2e 100644 --- a/decompiler/frontend/binaryninja/handlers/calls.py +++ b/decompiler/frontend/binaryninja/handlers/calls.py @@ -2,7 +2,7 @@ from functools import partial from typing import List -from binaryninja import MediumLevelILInstruction, Tailcall, mediumlevelil +from binaryninja import FunctionType, PointerType, Tailcall, mediumlevelil from decompiler.frontend.lifter import Handler from decompiler.structures.pseudo import Assignment, Call, ImportedFunctionSymbol, IntrinsicSymbol, ListOperation @@ -73,11 +73,11 @@ def lift_intrinsic(self, call: mediumlevelil.MediumLevelILIntrinsic, ssa: bool = @staticmethod def _lift_call_parameter_names(instruction: mediumlevelil.MediumLevelILCall) -> List[str]: - """Lift parameter names of call from type string of instruction.dest.expr_type""" - if instruction.dest.expr_type is None: + """Lift parameter names of call by iterating over the function parameters where the call is pointing to (if available)""" + if instruction.dest.expr_type is None or not isinstance(instruction.dest.expr_type, PointerType) or \ + not isinstance(instruction.dest.expr_type.target, FunctionType): return [] - clean_type_string_of_parameters = instruction.dest.expr_type.get_string_after_name().strip("()") - return [type_parameter.rsplit(" ", 1)[-1] for type_parameter in clean_type_string_of_parameters.split(",")] + return [param.name for param in instruction.dest.expr_type.target.parameters] @staticmethod def _lift_syscall_parameter_names(instruction: mediumlevelil.MediumLevelILSyscall) -> List[str]: diff --git a/decompiler/frontend/binaryninja/handlers/types.py b/decompiler/frontend/binaryninja/handlers/types.py index 2dd211a8a..353a5922a 100644 --- a/decompiler/frontend/binaryninja/handlers/types.py +++ b/decompiler/frontend/binaryninja/handlers/types.py @@ -1,15 +1,20 @@ import logging +from abc import abstractmethod +from typing import Optional, Union +from binaryninja import BinaryView, StructureVariant from binaryninja.types import ( ArrayType, BoolType, CharType, + EnumerationMember, EnumerationType, FloatType, FunctionType, IntegerType, NamedTypeReferenceType, PointerType, + StructureMember, StructureType, Type, VoidType, @@ -17,6 +22,8 @@ ) from decompiler.frontend.lifter import Handler from decompiler.structures.pseudo import CustomType, Float, FunctionTypeDef, Integer, Pointer, UnknownType, Variable +from decompiler.structures.pseudo.complextypes import ComplexTypeMember, ComplexTypeName, Enum, Struct +from decompiler.structures.pseudo.complextypes import Union as Union_ class TypeHandler(Handler): @@ -33,10 +40,12 @@ def register(self): VoidType: self.lift_void, CharType: self.lift_integer, WideCharType: self.lift_custom, - NamedTypeReferenceType: self.lift_custom, - StructureType: self.lift_custom, + NamedTypeReferenceType: self.lift_named_type_reference_type, + StructureType: self.lift_struct, + StructureMember: self.lift_struct_member, FunctionType: self.lift_function_type, - EnumerationType: self.lift_custom, + EnumerationType: self.lift_enum, + EnumerationMember: self.lift_enum_member, type(None): self.lift_none, } ) @@ -47,9 +56,80 @@ def lift_none(self, _: None, **kwargs): def lift_custom(self, custom: Type, **kwargs) -> CustomType: """Lift custom types such as structs as a custom type.""" - logging.debug(f"[TypeHandler] lifting custom type: {custom}") return CustomType(str(custom), custom.width * self.BYTE_SIZE) + def lift_named_type_reference_type(self, custom: NamedTypeReferenceType, **kwargs) -> Union[Type, CustomType]: + """Lift a special type that binary ninja uses a references on complex types like structs, unions, etc. as well + as user-defined types. Examples: + typedef PVOID HANDLE; # NamedTypeReferenceType pointing to void pointer + struct IO_FILE; #NamedTypeReferenceType pointing to a structure with that name + + Binary Ninja expressions do not get complex type in that case, but the NamedTypeReferenceType on that type. + We try to retrieve the original complex type from binary view using this placeholder type, and lift it correspondingly. + """ + view: BinaryView = self._lifter.bv + if defined_type := view.get_type_by_name(custom.name): # actually should always be the case + return self._lifter.lift(defined_type, name=str(custom.name)) + logging.warning(f"NamedTypeReferenceType {custom} was not found in binary view types.") + return CustomType(str(custom), custom.width * self.BYTE_SIZE) + + def lift_enum(self, binja_enum: EnumerationType, name: str = None, **kwargs) -> Enum: + """Lift enum type.""" + enum_name = name if name else self._get_data_type_name(binja_enum, keyword="enum") + enum = Enum(binja_enum.width * self.BYTE_SIZE, enum_name, {}) + for member in binja_enum.members: + enum.add_member(self._lifter.lift(member)) + self._lifter.complex_types.add(enum) + return enum + + def lift_enum_member(self, enum_member: EnumerationMember, **kwargs) -> ComplexTypeMember: + """Lift enum member type.""" + return ComplexTypeMember(size=0, name=enum_member.name, offset=-1, type=Integer(32), value=int(enum_member.value)) + + def lift_struct(self, struct: StructureType, name: str = None, **kwargs) -> Union[Struct, ComplexTypeName]: + """Lift struct or union type.""" + if struct.type == StructureVariant.StructStructureType: + type_name = name if name else self._get_data_type_name(struct, keyword="struct") + lifted_struct = Struct(struct.width * self.BYTE_SIZE, type_name, {}) + elif struct.type == StructureVariant.UnionStructureType: + type_name = name if name else self._get_data_type_name(struct, keyword="union") + lifted_struct = Union_(struct.width * self.BYTE_SIZE, type_name, []) + else: + raise RuntimeError(f"Unknown struct type {struct.type.name}") + for member in struct.members: + lifted_struct.add_member(self.lift_struct_member(member, type_name)) + self._lifter.complex_types.add(lifted_struct) + return lifted_struct + + @abstractmethod + def _get_data_type_name(self, complex_type: Union[StructureType, EnumerationType], keyword: str) -> str: + """Parse out the name of complex type.""" + string = complex_type.get_string() + if keyword in string: + return complex_type.get_string().split(keyword)[1] + return string + + def lift_struct_member(self, member: StructureMember, parent_struct_name: str = None) -> ComplexTypeMember: + """Lift struct or union member.""" + # handle the case when struct member is a pointer on the same struct + if structPtr := self._get_member_pointer_on_the_parent_struct(member, parent_struct_name): + return structPtr + else: + # if member is an embedded struct/union, the name is already available + member_type = self._lifter.lift(member.type, name=member.name) + return ComplexTypeMember(0, name=member.name, offset=member.offset, type=member_type) + + @abstractmethod + def _get_member_pointer_on_the_parent_struct(self, member: StructureMember, parent_struct_name: str) -> ComplexTypeMember: + """Constructs struct or union member which is a pointer on parent struct or union type.""" + if ( + isinstance(member.type, PointerType) + and (isinstance(member.type.target, StructureType) or isinstance(member.type.target, NamedTypeReferenceType)) + and str(member.type.target.name) == parent_struct_name + ): + member_type = Pointer(ComplexTypeName(0, parent_struct_name)) + return ComplexTypeMember(0, name=member.name, offset=member.offset, type=member_type) + def lift_void(self, _, **kwargs) -> CustomType: """Lift the void-type (should only be used as function return type).""" return CustomType.void() diff --git a/decompiler/frontend/binaryninja/handlers/unary.py b/decompiler/frontend/binaryninja/handlers/unary.py index 7470fb200..180aecfd0 100644 --- a/decompiler/frontend/binaryninja/handlers/unary.py +++ b/decompiler/frontend/binaryninja/handlers/unary.py @@ -1,4 +1,5 @@ """Module implementing the UnaryOperationHandler.""" +import logging from functools import partial from typing import Union @@ -14,6 +15,8 @@ Pointer, UnaryOperation, ) +from decompiler.structures.pseudo.complextypes import Struct +from decompiler.structures.pseudo.operations import MemberAccess class UnaryOperationHandler(Handler): @@ -52,7 +55,7 @@ def lift_dereference_or_global_variable( self, operation: Union[mediumlevelil.MediumLevelILLoad, mediumlevelil.MediumLevelILLoadSsa], **kwargs ) -> Union[GlobalVariable, UnaryOperation]: """Lift load operation which is used both to model dereference operation and global variable read.""" - load_operand : UnaryOperation = self._lifter.lift(operation.src, parent=operation) + load_operand: UnaryOperation = self._lifter.lift(operation.src, parent=operation) if load_operand and isinstance(global_variable := load_operand, GlobalVariable): global_variable.ssa_label = operation.ssa_memory_version return global_variable @@ -91,22 +94,12 @@ def _lift_zx_operation(self, instruction: MediumLevelILInstruction, **kwargs) -> ) return self.lift_cast(instruction, **kwargs) - def _lift_load_struct(self, instruction: mediumlevelil.MediumLevelILLoadStruct, **kwargs) -> UnaryOperation: - """Lift a MLIL_LOAD_STRUCT_SSA instruction.""" - return UnaryOperation( - OperationType.dereference, - [ - BinaryOperation( - OperationType.plus, - [ - UnaryOperation(OperationType.cast, [self._lifter.lift(instruction.src)], vartype=Pointer(Integer.char())), - Constant(instruction.offset), - ], - vartype=self._lifter.lift(instruction.src.expr_type), - ), - ], - vartype=Pointer(self._lifter.lift(instruction.src.expr_type)), - ) + def _lift_load_struct(self, instruction: mediumlevelil.MediumLevelILLoadStruct, **kwargs) -> MemberAccess: + """Lift a MLIL_LOAD_STRUCT_SSA (struct member access e.g. var#n->x) instruction.""" + struct_variable = self._lifter.lift(instruction.src) + struct_ptr: Pointer = self._lifter.lift(instruction.src.expr_type) + struct_member = struct_ptr.type.get_member_by_offset(instruction.offset) + return MemberAccess(vartype=struct_ptr, operands=[struct_variable], offset=struct_member.offset, member_name=struct_member.name) def _lift_ftrunc(self, instruction: mediumlevelil.MediumLevelILFtrunc, **kwargs) -> UnaryOperation: """Lift a MLIL_FTRUNC operation.""" diff --git a/decompiler/frontend/binaryninja/lifter.py b/decompiler/frontend/binaryninja/lifter.py index e502d1b6b..e42761763 100644 --- a/decompiler/frontend/binaryninja/lifter.py +++ b/decompiler/frontend/binaryninja/lifter.py @@ -2,18 +2,21 @@ from logging import warning from typing import Optional, Tuple, Union -from binaryninja import MediumLevelILInstruction, Type +from binaryninja import BinaryView, MediumLevelILInstruction, Type from decompiler.frontend.lifter import ObserverLifter from decompiler.structures.pseudo import DataflowObject, Tag, UnknownExpression, UnknownType +from ...structures.pseudo.complextypes import ComplexTypeMap from .handlers import HANDLERS class BinaryninjaLifter(ObserverLifter): """Lifter converting Binaryninja.mediumlevelil expressions to pseudo expressions.""" - def __init__(self, no_bit_masks: bool = True): + def __init__(self, no_bit_masks: bool = True, bv: BinaryView = None): self.no_bit_masks = no_bit_masks + self.bv: BinaryView = bv + self.complex_types: ComplexTypeMap = ComplexTypeMap() for handler in HANDLERS: handler(self).register() diff --git a/decompiler/frontend/binaryninja/parser.py b/decompiler/frontend/binaryninja/parser.py index 59f5afde6..43ccc684d 100644 --- a/decompiler/frontend/binaryninja/parser.py +++ b/decompiler/frontend/binaryninja/parser.py @@ -1,20 +1,23 @@ """Implements the parser for the binaryninja frontend.""" from logging import info, warning -from typing import Dict, Iterator, List +from typing import Dict, Iterator, List, Tuple from binaryninja import ( + BasicBlockEdge, BranchType, Function, MediumLevelILBasicBlock, MediumLevelILConstPtr, MediumLevelILInstruction, MediumLevelILJumpTo, + MediumLevelILTailcallSsa, RegisterValueType, ) from decompiler.frontend.lifter import Lifter from decompiler.frontend.parser import Parser from decompiler.structures.graphs.cfg import BasicBlock, ControlFlowGraph, FalseCase, IndirectEdge, SwitchCase, TrueCase, UnconditionalEdge from decompiler.structures.pseudo import Constant, Instruction +from decompiler.structures.pseudo.complextypes import ComplexTypeMap class BinaryninjaParser(Parser): @@ -33,6 +36,7 @@ def __init__(self, lifter: Lifter, report_threshold: int = 3): self._lifter = lifter self._unlifted_instructions: List[MediumLevelILInstruction] = [] self._report_threshold = int(report_threshold) + self._complex_types = None def parse(self, function: Function) -> ControlFlowGraph: """Generate a cfg from the given function.""" @@ -43,9 +47,32 @@ def parse(self, function: Function) -> ControlFlowGraph: cfg.add_node(index_to_BasicBlock[basic_block.index]) for basic_block in function.medium_level_il.ssa_form: self._add_basic_block_edges(cfg, index_to_BasicBlock, basic_block) + self._complex_types = self._lifter.complex_types self._report_lifter_errors() return cfg + @property + def complex_types(self) -> ComplexTypeMap: + """Return complex type map for the given function.""" + return self._complex_types + + def _recover_switch_edge_cases(self, edge: BasicBlockEdge, lookup_table: dict): + """ + If edge.target.source_block.start address is not in lookup table, + try to recover matching address by inspecting addresses used in edge.target. + Return matched case list for edge.target. + """ + possible_matches = set() + for instruction in edge.target: + match instruction: + # tail calls destroy edge address mapping + case MediumLevelILTailcallSsa(dest=MediumLevelILConstPtr()): + possible_matches.add(instruction.dest.constant) + # we have found exactly one address and that address is used in lookup table. + if len(possible_matches) == 1 and possible_matches & set(lookup_table): + return lookup_table[possible_matches.pop()] # return cases for matched address + raise KeyError("Can not recover address used in lookup table") + def _add_basic_block_edges(self, cfg: ControlFlowGraph, vertices: dict, basic_block: MediumLevelILBasicBlock) -> None: """Add all outgoing edges of the given basic block to the given cfg.""" if self._can_convert_single_outedge_to_unconditional(basic_block): @@ -56,11 +83,15 @@ def _add_basic_block_edges(self, cfg: ControlFlowGraph, vertices: dict, basic_bl # check if the block ends with a switch statement elif lookup_table := self._get_lookup_table(basic_block): for edge in basic_block.outgoing_edges: + if edge.target.source_block.start not in lookup_table: + case_list = self._recover_switch_edge_cases(edge, lookup_table) + else: + case_list = lookup_table[edge.target.source_block.start] cfg.add_edge( SwitchCase( vertices[edge.source.index], vertices[edge.target.index], - lookup_table[edge.target.source_block.start], + case_list, ) ) else: diff --git a/decompiler/frontend/lifter.py b/decompiler/frontend/lifter.py index b4d4a3259..c58e76382 100644 --- a/decompiler/frontend/lifter.py +++ b/decompiler/frontend/lifter.py @@ -8,6 +8,9 @@ class Lifter(ABC): """Represents a basic lifter emmiting decompiler IR.""" + def __init__(self): + self.complex_types = None + @abstractmethod def lift(self, expression, **kwargs) -> Expression: """Lift the given expression to pseudo IR.""" diff --git a/decompiler/pipeline/controlflowanalysis/readability_based_refinement.py b/decompiler/pipeline/controlflowanalysis/readability_based_refinement.py index e8b82bb10..048d4c098 100644 --- a/decompiler/pipeline/controlflowanalysis/readability_based_refinement.py +++ b/decompiler/pipeline/controlflowanalysis/readability_based_refinement.py @@ -287,6 +287,9 @@ def run(self): or self._invalid_simple_for_loop_condition_type(loop_node.condition): continue + if any(node.does_end_with_continue for node in loop_node.body.get_descendant_code_nodes_interrupting_ancestor_loop()): + continue + if not self._force_for_loops and loop_node.condition.get_complexity(self._ast.condition_map) > self._condition_max_complexity: continue diff --git a/decompiler/pipeline/default.py b/decompiler/pipeline/default.py index 79c70cf42..7266e0596 100644 --- a/decompiler/pipeline/default.py +++ b/decompiler/pipeline/default.py @@ -20,12 +20,13 @@ RedundantCastsElimination, TypePropagation, ) -from decompiler.pipeline.expressions import DeadComponentPruner, EdgePruner, GraphExpressionFolding +from decompiler.pipeline.expressions import BitFieldComparisonUnrolling, DeadComponentPruner, EdgePruner, GraphExpressionFolding CFG_STAGES = [ GraphExpressionFolding, DeadComponentPruner, ExpressionPropagation, + BitFieldComparisonUnrolling, TypePropagation, DeadPathElimination, DeadLoopElimination, diff --git a/decompiler/pipeline/expressions/__init__.py b/decompiler/pipeline/expressions/__init__.py index c5ec8a836..482025f6a 100644 --- a/decompiler/pipeline/expressions/__init__.py +++ b/decompiler/pipeline/expressions/__init__.py @@ -1,3 +1,4 @@ +from .bitfieldcomparisonunrolling import BitFieldComparisonUnrolling from .deadcomponentpruner import DeadComponentPruner from .edgepruner import EdgePruner from .expressionfolding import GraphExpressionFolding diff --git a/decompiler/pipeline/expressions/bitfieldcomparisonunrolling.py b/decompiler/pipeline/expressions/bitfieldcomparisonunrolling.py new file mode 100644 index 000000000..51fb37814 --- /dev/null +++ b/decompiler/pipeline/expressions/bitfieldcomparisonunrolling.py @@ -0,0 +1,169 @@ +from dataclasses import dataclass +from logging import debug, warning +from typing import List, Optional, Tuple, Union + +from decompiler.pipeline.stage import PipelineStage +from decompiler.structures.graphs.basicblock import BasicBlock +from decompiler.structures.graphs.branches import ConditionalEdge, FalseCase, TrueCase, UnconditionalEdge +from decompiler.structures.graphs.cfg import ControlFlowGraph +from decompiler.structures.pseudo import Constant, Expression +from decompiler.structures.pseudo.expressions import Variable +from decompiler.structures.pseudo.instructions import Branch +from decompiler.structures.pseudo.operations import BinaryOperation, Condition, OperationType, UnaryOperation +from decompiler.task import DecompilerTask + + +@dataclass +class FoldedCase: + """ + Class for storing information of folded case. + """ + basic_block: BasicBlock + switch_variable: Expression + case_values: List[int] + edge_type_to_case_node: type[FalseCase] | type[TrueCase] + + def get_case_node_and_other_node(self, cfg: ControlFlowGraph) -> Tuple[BasicBlock, BasicBlock]: + """ + Return the case node and the other node based on which branch condition corresponds to the case node. + """ + out_edges = cfg.get_out_edges(self.basic_block) + assert len(out_edges) == 2, "expext two out edges (TrueCase/FalseCase)" + if isinstance(out_edges[0], self.edge_type_to_case_node): + return out_edges[0].sink, out_edges[1].sink + elif isinstance(out_edges[1], self.edge_type_to_case_node): + return out_edges[1].sink, out_edges[0].sink + raise ValueError("Outedges do not match type") + + +class BitFieldComparisonUnrolling(PipelineStage): + """ + Transform bit-field compiler optimization to readable comparison: + + var = 1 << amount; + if ((var & 0b11010) != 0) { ... } + + // becomes: + + if ( amount == 1 || amount == 3 || amount == 4 ) { ... } + + This can subsequently be used to reconstruct switch-case statements. + + This stage requires expression-propagation PipelineStage, such that bit-shift + gets forwarded into Branch.condition: + + if ( (1 << amount) & bit_mask) == 0) ) { ... } + """ + + name = "bit-field-comparison-unrolling" + dependencies = ["expression-propagation"] + + def run(self, task: DecompilerTask): + """Run the pipeline stage: Check all viable Branch-instructions.""" + folded_cases: List[FoldedCase] = [] + for block in task.graph: + if (folded_case := self._get_folded_case(block)) is not None: + folded_cases.append(folded_case) + for folded_case in folded_cases: + self._modify_cfg(task.graph, folded_case) + + def _modify_cfg(self, cfg: ControlFlowGraph, folded_case: FoldedCase): + """ + Create nested if blocks for each case in unfolded values. + Note: with the Branch condition encountered so far (== 0x0), the node of the collected cases is adjacent to the FalseCase edge. + However, negated conditions may exist. In this case, pass condition type (flag) and swap successor nodes accordingly. + """ + debug("modifying cfg") + case_node, other_node = folded_case.get_case_node_and_other_node(cfg) + # remove condition from block + folded_case.basic_block.remove_instruction(folded_case.basic_block[-1]) + cfg.remove_edges_from(cfg.get_out_edges(folded_case.basic_block)) + # create condition chain + nested_if_blocks = [ + self._create_condition_block(cfg, folded_case.switch_variable, case_value) for case_value in folded_case.case_values + ] + for pred, succ in zip(nested_if_blocks, nested_if_blocks[1:]): + cfg.add_edge(TrueCase(pred, case_node)) + cfg.add_edge(FalseCase(pred, succ)) + # add edges for last and first block + cfg.add_edge(TrueCase(nested_if_blocks[-1], case_node)) + cfg.add_edge(FalseCase(nested_if_blocks[-1], other_node)) + cfg.add_edge(UnconditionalEdge(folded_case.basic_block, nested_if_blocks[0])) + + def _create_condition_block(self, cfg: ControlFlowGraph, switch_var: Expression, case_value: int) -> BasicBlock: + """Create conditional block in CFG, e.g., `if (var == 0x42)`.""" + const = Constant(value=case_value, vartype=switch_var.type) + return cfg.create_block([Branch(condition=Condition(OperationType.equal, [switch_var, const]))]) + + def _get_folded_case(self, block: BasicBlock) -> Optional[FoldedCase]: + """Unfold Branch condition (checking bit field) into switch variable and list of case values.""" + if not len(block): + return None + if not isinstance(branch_instruction := block[-1], Branch): + return None + match branch_instruction.condition: + case Condition(OperationType.equal, subexpr, Constant(value=0x0)): + edge_type_to_case_node = FalseCase + case Condition(OperationType.not_equal, subexpr, Constant(value=0x0)): + edge_type_to_case_node = TrueCase + case Condition(OperationType.equal, Constant(value=0x0), subexpr): + edge_type_to_case_node = FalseCase + case Condition(OperationType.not_equal, Constant(value=0x0), subexpr): + edge_type_to_case_node = TrueCase + case _: + return None + if (matched_expression := self._get_switch_var_and_bitfield(subexpr)) is not None: + switch_var, bit_field = matched_expression + cleaned_var = self._clean_variable(switch_var) + case_values = self._get_values(bit_field) + if cleaned_var and case_values: + return FoldedCase( + basic_block=block, switch_variable=cleaned_var, case_values=case_values, edge_type_to_case_node=edge_type_to_case_node + ) + return None + + def _get_switch_var_and_bitfield(self, subexpr: Expression) -> Optional[Tuple[Expression, Constant]]: + """ + Match expression of folded switch case: + a) ((1 << (cast)var) & 0xffffffff) & bit_field_constant + b) (0x1 << ((1: ) ecx#1)) & bit_field_constant + Return the Variable (or Expression) that is switched on, and bit field Constant. + """ + match subexpr: + case BinaryOperation( + OperationType.bitwise_and, + BinaryOperation( + OperationType.bitwise_and, BinaryOperation(OperationType.left_shift, Constant(value=1), switch_var), Constant() + ), + Constant() as bit_field, + ) if bit_field.value != 0xFFFFFFFF: + return switch_var, bit_field + case BinaryOperation( + OperationType.bitwise_and, + BinaryOperation(OperationType.left_shift, Constant(value=1), switch_var), + Constant() as bit_field, + ) if bit_field.value != 0xFFFFFFFF: + return switch_var, bit_field + case _: + debug(f"no match for {subexpr}") + return None + + def _get_values(self, const: Constant) -> List[int]: + """Return positions of set bits from integer Constant.""" + bitmask = const.value + values = [] + if not isinstance(bitmask, int): + warning("not an integer") + return [] + for pos, bit in enumerate(bin(bitmask)[:1:-1]): + if bit == "1": + values.append(pos) + return values + + def _clean_variable(self, expr: Expression) -> Optional[Variable]: + """Remove cast from Variable.""" + if isinstance(expr, Variable): + return expr + if isinstance(expr, UnaryOperation) and expr.operation == OperationType.cast: + if len(expr.requirements) == 1: + return expr.requirements[0] diff --git a/decompiler/pipeline/preprocessing/remove_stack_canary.py b/decompiler/pipeline/preprocessing/remove_stack_canary.py index fe053ef60..14b2f2ed2 100644 --- a/decompiler/pipeline/preprocessing/remove_stack_canary.py +++ b/decompiler/pipeline/preprocessing/remove_stack_canary.py @@ -20,6 +20,8 @@ class RemoveStackCanary(PipelineStage): def run(self, task: DecompilerTask): if task.options.getboolean(f"{self.name}.remove_canary", fallback=False) and task.name != self.STACK_FAIL_STR: self._cfg = task.graph + if len(self._cfg) == 1: + return # do not remove the only node for fail_node in list(self._contains_stack_check_fail()): self._patch_canary(fail_node) diff --git a/decompiler/pipeline/preprocessing/switch_variable_detection.py b/decompiler/pipeline/preprocessing/switch_variable_detection.py index ebb32e02d..e459c4b14 100644 --- a/decompiler/pipeline/preprocessing/switch_variable_detection.py +++ b/decompiler/pipeline/preprocessing/switch_variable_detection.py @@ -69,21 +69,21 @@ class BackwardSliceSwitchVariableDetection(PipelineStage): name = "backward-slice-switch-variable-detection" def __init__(self): - self._use_map = None - self._def_map = None - self._dereferences_used_in_branches = None + self._def_map: DefMap + self._use_map: UseMap + self._dereferences_used_in_branches: set def run(self, task: DecompilerTask): """ + Replace switch variable containing offset calculations with a "cleaner" predecessor. + Jump table offset calculations become then the dead code and will be removed during the dead code elimination stage. - iterate through the basic blocks - on switch block found: - if switch block has only one conditional block predecessor: - track the variable in indirect jump backwards till its first related use in the switch basic block - find the variable common between the first use instruction and condition in conditional predecessor - and substitute the jump variable with the common variable; - jump table offset calculations become then the dead code and will be removed during - the dead code elimination stage - + - track the variable in indirect jump backwards until it matches a replacement criterion: + a) defined in copy assignment Var1 = Var2 + b) is used in an Assignment with RHS being Condition solely requiring `variable` + c) is used in Branch with single requirement + d) if any predecessors of `variable` are used as dereferences in branches Overcomes issues with dummy heuristic. """ self._init_map(task.graph) @@ -116,13 +116,28 @@ def find_switch_expression(self, switch_instruction: Instruction): return variable raise ValueError("No switch variable candidate found.") - def _is_bounds_checked(self, value: Variable) -> bool: - """Check if the given variable is a direct copy of another one. It that is the case, return the copied variable.""" + def _is_used_in_condition_assignment(self, value: Variable): + """ + Check if `value` is used in an Assignment with RHS being Condition solely requiring `value` + """ for usage in self._use_map.get(value): if isinstance(usage, Assignment) and isinstance(usage.value, Condition) and usage.requirements == [value]: return True + return False + + def _is_used_in_branch(self, value: Variable): + """ + Check if `value` is used in Branch solely requiring `value` + """ + for usage in self._use_map.get(value): if isinstance(usage, Branch) and usage.requirements == [value]: return True + return False + + def _is_predecessor_dereferenced_in_branch(self, value: Variable) -> bool: + """ + Check if any predecessors of `value` are used as dereferences in branches. + """ if definition := self._def_map.get(value): return ( any(exp in self._dereferences_used_in_branches for exp in definition.value) @@ -130,6 +145,27 @@ def _is_bounds_checked(self, value: Variable) -> bool: ) return False + def _is_copy_assigned(self, value: Variable) -> bool: + """ + Check if variable is defined in copy assignment of the form Var1 = Var2. + """ + if definition := self._def_map.get(value): + return isinstance(definition.value, Variable) + return False + + def _is_bounds_checked(self, value: Variable) -> bool: + """ + Check if variable can be used in switch expression. + """ + return any( + [ + self._is_copy_assigned(value), + self._is_used_in_condition_assignment(value), + self._is_used_in_branch(value), + self._is_predecessor_dereferenced_in_branch(value), + ] + ) + def _backwardslice(self, value: Variable): """Do a breadth-first search on variable predecessors.""" visited = set() diff --git a/decompiler/structures/pseudo/__init__.py b/decompiler/structures/pseudo/__init__.py index 314a94cf4..0df23d585 100644 --- a/decompiler/structures/pseudo/__init__.py +++ b/decompiler/structures/pseudo/__init__.py @@ -1,3 +1,4 @@ +from .complextypes import ComplexType, ComplexTypeMember, ComplexTypeName, Enum, Struct, Union from .delogic_logic import DelogicConverter from .expressions import ( Constant, diff --git a/decompiler/structures/pseudo/complextypes.py b/decompiler/structures/pseudo/complextypes.py new file mode 100644 index 000000000..b32528b4a --- /dev/null +++ b/decompiler/structures/pseudo/complextypes.py @@ -0,0 +1,140 @@ +import copy +import logging +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, List, Optional + +from decompiler.structures.pseudo.typing import Type + + +class ComplexTypeSpecifier(Enum): + STRUCT = "struct" + UNION = "union" + ENUM = "enum" + CLASS = "class" + + +@dataclass(frozen=True, order=True) +class ComplexType(Type): + size = 0 + name: str + + def __str__(self): + return self.name + + def copy(self, **kwargs) -> Type: + return copy.deepcopy(self) + + def declaration(self) -> str: + raise NotImplementedError + + +@dataclass(frozen=True, order=True) +class ComplexTypeMember(ComplexType): + """Class representing a member of a struct type. + @param name: name of the struct member + @param offset: offset of the member within the struct + @param type: datatype of the member + @param value: initial value of the member, enums only + """ + + name: str + offset: int + type: Type + value: Optional[int] = None + + def __str__(self) -> str: + return f"{self.name}" + + def declaration(self) -> str: + """Return declaration field for the complex type member.""" + if isinstance(self.type, Union): + return self.type.declaration() + return f"{self.type.__str__()} {self.name}" + + +@dataclass(frozen=True, order=True) +class Struct(ComplexType): + """Class representing a struct type.""" + + members: Dict[int, ComplexTypeMember] = field(compare=False) + type_specifier: ComplexTypeSpecifier = ComplexTypeSpecifier.STRUCT + + def add_member(self, member: ComplexTypeMember): + self.members[member.offset] = member + + def get_member_by_offset(self, offset: int) -> ComplexTypeMember: + return self.members.get(offset) + + def declaration(self) -> str: + members = ";\n\t".join(self.members[k].declaration() for k in sorted(self.members.keys())) + ";" + return f"{self.type_specifier.value} {self.name} {{\n\t{members}\n}}" + + +@dataclass(frozen=True, order=True) +class Union(ComplexType): + members: List[ComplexTypeMember] = field(compare=False) + type_specifier = ComplexTypeSpecifier.UNION + + def add_member(self, member: ComplexTypeMember): + self.members.append(member) + + def declaration(self) -> str: + members = ";\n\t".join(x.declaration() for x in self.members) + ";" + return f"{self.type_specifier.value} {self.name} {{\n\t{members}\n}}" + + def get_member_by_type(self, _type: Type) -> ComplexTypeMember: + """Retrieve member of union by its type.""" + for member in self.members: + if member.type == _type: + return member + + +@dataclass(frozen=True, order=True) +class Enum(ComplexType): + members: Dict[int, ComplexTypeMember] = field(compare=False) + type_specifier = ComplexTypeSpecifier.ENUM + + def add_member(self, member: ComplexTypeMember): + self.members[member.value] = member + + def get_name_by_value(self, value: int) -> str: + return self.members.get(value).name + + def declaration(self) -> str: + members = ",\n\t".join(f"{x.name} = {x.value}" for x in self.members.values()) + return f"{self.type_specifier.value} {self.name} {{\n\t{members}\n}}" + + +@dataclass(frozen=True, order=True) +class ComplexTypeName(Type): + """Class that store a name of a complex type. Used to prevent recursions when constructing + struct(...) members of the same complex type""" + + name: str + + def __str__(self) -> str: + return self.name + + +class ComplexTypeMap: + """A class in charge of storing complex custom/user defined types by their string representation""" + + def __init__(self): + self._name_to_type_map: Dict[ComplexTypeName, ComplexType] = {} + + def retrieve_by_name(self, typename: ComplexTypeName) -> ComplexType: + """Get complex type by name; used to avoid recursion.""" + return self._name_to_type_map.get(typename, None) + + def add(self, complex_type: ComplexType): + """Add complex type to the mapping.""" + self._name_to_type_map[ComplexTypeName(0, complex_type.name)] = complex_type + + def pretty_print(self): + for t in self._name_to_type_map.values(): + logging.error(t.declaration()) + + def declarations(self) -> str: + """Returns declarations of all complex types used in decompiled function.""" + return ";\n".join(t.declaration() for t in self._name_to_type_map.values()) + ";" if self._name_to_type_map else "" diff --git a/decompiler/structures/pseudo/expressions.py b/decompiler/structures/pseudo/expressions.py index a7c9f91ee..5c59afaf2 100644 --- a/decompiler/structures/pseudo/expressions.py +++ b/decompiler/structures/pseudo/expressions.py @@ -32,6 +32,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Generic, Iterator, List, Optional, Tuple, TypeVar, Union +from .complextypes import Enum from .typing import CustomType, Type, UnknownType T = TypeVar("T") @@ -186,7 +187,11 @@ def __repr__(self) -> str: return f"{value} type: {self.type}" def __str__(self) -> str: - """Return a hex-based string representation for integers, strings are printed with double quotation marks""" + """Return a hex-based string representation for integers, strings are printed with double quotation marks. + Constants of type Enum are represented as strings (corresponding enumerator identifiers). + """ + if isinstance(self._type, Enum): + return self._type.get_name_by_value(self.value) if self._type.is_boolean: return "true" if self.value else "false" if isinstance(self.value, str): diff --git a/decompiler/structures/pseudo/operations.py b/decompiler/structures/pseudo/operations.py index aeaa1b514..c21002068 100644 --- a/decompiler/structures/pseudo/operations.py +++ b/decompiler/structures/pseudo/operations.py @@ -11,7 +11,7 @@ from decompiler.util.insertion_ordered_set import InsertionOrderedSet from .expressions import Constant, Expression, FunctionSymbol, ImportedFunctionSymbol, IntrinsicSymbol, Symbol, Tag, Variable -from .typing import CustomType, Type, UnknownType +from .typing import CustomType, Pointer, Type, UnknownType T = TypeVar("T") @@ -73,6 +73,7 @@ class OperationType(Enum): field = auto() list_op = auto() adc = auto() + member_access = auto() # For pretty-printing and debug @@ -124,9 +125,9 @@ class OperationType(Enum): OperationType.low: "low", OperationType.ternary: "?", OperationType.call: "func", - OperationType.field: "->", OperationType.list_op: "list", OperationType.adc: "adc", + OperationType.member_access: ".", } UNSIGNED_OPERATIONS = { @@ -385,9 +386,59 @@ def accept(self, visitor: DataflowObjectVisitorInterface[T]) -> T: return visitor.visit_unary_operation(self) +class MemberAccess(UnaryOperation): + def __init__( + self, + offset: int, + member_name: str, + operands: List[Expression], + vartype: Type = UnknownType(), + writes_memory: Optional[int] = None, + ): + super().__init__(OperationType.member_access, operands, vartype, writes_memory=writes_memory) + self.member_offset = offset + self.member_name = member_name + + def __str__(self): + # use -> when accessing member via a pointer to a struct: ptrBook->title + # use . when accessing struct member directly: book.title + if isinstance(self.struct_variable.type, Pointer): + return f"{self.struct_variable}->{self.member_name}" + return f"{self.struct_variable}.{self.member_name}" + + @property + def struct_variable(self) -> Expression: + """Variable of complex type, which member is being accessed here.""" + return self.operand + + def substitute(self, replacee: Expression, replacement: Expression) -> None: + if isinstance(replacee, Variable) and replacee == self.struct_variable and isinstance(replacement, Variable): + self.operands[:] = [replacement] + + def copy(self) -> MemberAccess: + """Copy the current UnaryOperation, copying all operands and the type.""" + return MemberAccess( + self.member_offset, + self.member_name, + [operand.copy() for operand in self._operands], + self._type.copy(), + writes_memory=self.writes_memory, + ) + + def is_read_access(self) -> bool: + """Read-only member access.""" + return self.writes_memory is None + + def is_write_access(self) -> bool: + """Member is being accessed for writing.""" + return self.writes_memory is not None + + class BinaryOperation(Operation): """Class representing operations with two operands.""" + __match_args__ = ("operation", "left", "right") + def __str__(self) -> str: """Return a string representation with infix notation.""" str_left = f"({self.left})" if isinstance(self.left, Operation) else f"{self.left}" diff --git a/decompiler/task.py b/decompiler/task.py index 1128cc6ae..38f149a7b 100644 --- a/decompiler/task.py +++ b/decompiler/task.py @@ -1,8 +1,9 @@ """Module describing tasks to be handled by the decompiler pipleline.""" -from typing import List, Optional +from typing import Dict, List, Optional from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree from decompiler.structures.graphs.cfg import ControlFlowGraph +from decompiler.structures.pseudo.complextypes import ComplexTypeMap from decompiler.structures.pseudo.expressions import Variable from decompiler.structures.pseudo.typing import Integer, Type from decompiler.util.options import Options @@ -19,6 +20,7 @@ def __init__( options: Optional[Options] = None, function_return_type: Type = Integer(32), function_parameters: Optional[List[Variable]] = None, + complex_types: Optional[ComplexTypeMap] = None ): """ Init a new decompiler task. @@ -36,6 +38,7 @@ def __init__( self._options: Options = options if options else Options.load_default_options() self._failed = False self._failure_origin = None + self._complex_types = complex_types if complex_types else ComplexTypeMap() @property def name(self) -> str: @@ -91,3 +94,8 @@ def failure_message(self) -> str: if self._failure_origin: msg += f" due to error during {self._failure_origin}." return msg + + @property + def complex_types(self) -> ComplexTypeMap: + """Return complex types present in the function (structs, unions, enums, etc.).""" + return self._complex_types diff --git a/decompiler/util/default.json b/decompiler/util/default.json index 7ce41db93..5ae9ba0b5 100644 --- a/decompiler/util/default.json +++ b/decompiler/util/default.json @@ -602,6 +602,7 @@ "dest": "pipeline.cfg_stages", "default": [ "expression-propagation", + "bit-field-comparison-unrolling", "type-propagation", "dead-path-elimination", "dead-loop-elimination", diff --git a/tests/backend/test_codegenerator.py b/tests/backend/test_codegenerator.py index 4f90e61f5..47726e32d 100644 --- a/tests/backend/test_codegenerator.py +++ b/tests/backend/test_codegenerator.py @@ -11,6 +11,7 @@ from decompiler.structures.ast.ast_nodes import CodeNode, SeqNode, SwitchNode from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree from decompiler.structures.logic.logic_condition import LogicCondition +from decompiler.structures.pseudo import FunctionTypeDef from decompiler.structures.pseudo.expressions import ( Constant, DataflowObject, @@ -28,6 +29,7 @@ Call, Condition, ListOperation, + MemberAccess, OperationType, UnaryOperation, ) @@ -59,6 +61,7 @@ def true_condition(context=None): context = LogicCondition.generate_new_context() if context is None else context return LogicCondition.initialize_true(context) + def logic_cond(name: str, context) -> LogicCondition: return LogicCondition.initialize_symbol(name, context) @@ -75,6 +78,8 @@ def logic_cond(name: str, context) -> LogicCondition: var_x_u = Variable("x_u", uint32) var_y_u = Variable("y_u", uint32) var_p = Variable("p", Pointer(int32)) +var_fun_p = Variable("p", Pointer(FunctionTypeDef(0, int32, (int32,)))) +var_fun_p0 = Variable("p0", Pointer(FunctionTypeDef(0, int32, (int32,)))) const_0 = Constant(0, int32) const_1 = Constant(1, int32) @@ -127,6 +132,7 @@ def test_init(self): def test_function_with_comment(self): root = SeqNode(LogicCondition.initialize_true(LogicCondition.generate_new_context())) + ast = AbstractSyntaxTree(root, {}) code_node = ast._add_code_node([Comment("test_comment", comment_style="debug")]) ast._add_edge(root, code_node) @@ -155,6 +161,15 @@ def test_empty_function_two_parameters(self): r"^\s*int +test_function\(\s*int +a\s*,\s*int +b\s*\){\s*}\s*$", self._task(ast, params=[var_a.copy(), var_b.copy()]) ) + def test_empty_function_two_function_parameters(self): + root = SeqNode(LogicCondition.initialize_true(LogicCondition.generate_new_context())) + ast = AbstractSyntaxTree(root, {}) + code_node = ast._add_code_node([]) + ast._add_edge(root, code_node) + assert self._regex_matches( + r"^\s*int +test_function\(\s*int +\(\*\s*p\)\(int\)\s*,\s*int +\(\*\s*p0\)\(int\)\s*\){\s*}\s*$", self._task(ast, params=[var_fun_p.copy(), var_fun_p0.copy()]) + ) + def test_function_with_instruction(self): root = SeqNode(LogicCondition.initialize_true(LogicCondition.generate_new_context())) ast = AbstractSyntaxTree(root, {}) @@ -450,7 +465,7 @@ def test_branch_condition(self, context, condition: LogicCondition, condition_ma regex = r"^%int +test_function\(\)%{(?s).*if%\(%COND_STR%\)%{%return%0%;%}%}%$" assert self._regex_matches(regex.replace("COND_STR", expected).replace("%", "\\s*"), self._task(ast)) - + def test_loop_declaration_ListOp(self): """ a = 5; @@ -479,7 +494,7 @@ def test_loop_declaration_ListOp(self): ast._code_node_reachability_graph.add_reachability(code_node, loop_node_body) root._sorted_children = (code_node, loop_node) source_code = CodeGenerator().generate([self._task(ast)]).replace("\n", "") - assert source_code.find("for (b = foo();") != -1 + assert source_code.find("for (b = foo();") != -1 class TestExpression: @@ -691,6 +706,85 @@ def test_array_element_access_default(self, operation, result): def test_array_element_access_aggressive(self, operation, result): assert self._visit_code(operation, _generate_options(array_detection=True)) == result + @pytest.mark.parametrize( + "operation, result", + [ + ( + MemberAccess(operands=[Variable("a", Integer.int32_t())], member_name="x", offset=0, vartype=Integer.int32_t()), + "a.x", + ), + ( + MemberAccess( + operands=[ + MemberAccess(operands=[Variable("a", Integer.int32_t())], member_name="x", offset=0, vartype=Integer.int32_t()) + ], + member_name="z", + offset=0, + vartype=Integer.int32_t(), + ), + "a.x.z", + ), + ( + MemberAccess(operands=[Variable("ptr", Pointer(Integer.int32_t()))], member_name="x", offset=0, vartype=Integer.int32_t()), + "ptr->x", + ), + ( + MemberAccess( + operands=[ + MemberAccess( + operands=[Variable("ptr", Pointer(Integer.int32_t()))], member_name="x", offset=0, vartype=Integer.int32_t() + ) + ], + member_name="z", + offset=0, + vartype=Pointer(Integer.int32_t()), + ), + "ptr->x.z", + ), + ( + MemberAccess( + operands=[ + MemberAccess( + operands=[Variable("ptr", Pointer(Integer.int32_t()))], + member_name="x", + offset=0, + vartype=Pointer(Integer.int32_t()), + ) + ], + member_name="z", + offset=0, + vartype=Pointer(Pointer(Integer.int32_t())), + ), + "ptr->x->z", + ), + ( + MemberAccess( + operands=[ + MemberAccess( + operands=[ + MemberAccess( + operands=[Variable("ptr", Pointer(Integer.int32_t()))], + member_name="x", + offset=0, + vartype=Pointer(Integer.int32_t()), + ) + ], + member_name="z", + offset=0, + vartype=Pointer(Pointer(Integer.int32_t())), + ) + ], + member_name="w", + offset=8, + vartype=Pointer(Pointer(Pointer(Integer.int32_t()))), + ), + "ptr->x->z->w", + ), + ], + ) + def test_member_access(self, operation, result): + assert self._visit_code(operation) == result + @pytest.mark.parametrize( "expr, result", [ @@ -1069,6 +1163,8 @@ def test_operation(self, op, expected): (1, [var_x.copy(), var_y.copy(), var_x_f.copy(), var_y_f.copy()], "float x_f;\nfloat y_f;\nint x;\nint y;"), (2, [var_x.copy(), var_y.copy(), var_x_f.copy(), var_y_f.copy()], "float x_f, y_f;\nint x, y;"), (1, [var_x.copy(), var_y.copy(), var_p.copy()], "int x;\nint y;\nint * p;"), + (1, [var_x.copy(), var_y.copy(), var_fun_p.copy()], "int x;\nint y;\nint (* p)(int);"), + (2, [var_x.copy(), var_y.copy(), var_fun_p.copy(), var_fun_p0.copy()], "int x, y;\nint (* p)(int), (* p0)(int);"), ], ) def test_variable_declaration(self, vars_per_line: int, variables: List[Variable], expected: str): diff --git a/tests/frontend/test_parser.py b/tests/frontend/test_parser.py index 11d4a4278..678e4f326 100644 --- a/tests/frontend/test_parser.py +++ b/tests/frontend/test_parser.py @@ -4,21 +4,24 @@ import pytest from binaryninja import ( + BasicBlockEdge, BranchType, Function, MediumLevelILBasicBlock, MediumLevelILConstPtr, MediumLevelILInstruction, MediumLevelILJumpTo, + MediumLevelILTailcallSsa, PossibleValueSet, RegisterValueType, Variable, ) from decompiler.frontend.binaryninja.lifter import BinaryninjaLifter from decompiler.frontend.binaryninja.parser import BinaryninjaParser -from decompiler.structures.graphs.branches import UnconditionalEdge +from decompiler.structures.graphs.branches import SwitchCase, UnconditionalEdge from decompiler.structures.graphs.cfg import BasicBlockEdgeCondition from decompiler.structures.pseudo.expressions import Constant +from decompiler.util.decoration import DecoratedCFG class MockEdge: @@ -158,6 +161,24 @@ def __init__(self, address: int): self.dest.constant = address self.dest.function = MockFunction([]) # need .function.view to lift +class MockTailcall(Mock): + """Mock object representing a constant jump.""" + + def __init__(self, address: int): + """Create new MediumLevelILJumpTo object""" + super().__init__(spec=MediumLevelILTailcallSsa) + self.ssa_memory_version = 0 + self.function = None # prevents lifting of tags + self.dest = Mock(spec=MediumLevelILConstPtr) + self.dest.constant = address + self.params = [] + self.output = [] + self.dest.function = MockFunction([]) # need .function.view to lift + + def _get_child_mock(self, **kw: Any) -> NonCallableMock: + """Return Mock as child mock.""" + return Mock(params=[])._get_child_mock(**kw) + @pytest.fixture def parser(): @@ -293,3 +314,35 @@ def test_convert_indirect_edge_to_unconditional_no_valid_edge(parser): assert (cfg_edge.source.address, cfg_edge.sink.address) == (0, 42) assert not isinstance(cfg_edge, UnconditionalEdge) assert len(list(cfg.instructions)) == 1 + +def test_tailcall_address_recovery(parser): + """ + Address of edge.target.source_block.start is not in lookup table. + """ + jmp_instr = MockSwitch({"a": 42}) + function = MockFunction( + [ + MockBlock(0, [MockEdge(0, 0, BranchType.IndirectBranch)], instructions=[jmp_instr]), + MockBlock(1, []), + ] + ) + with pytest.raises(KeyError): + cfg = parser.parse(function) + + # extract address from tailcall in successor + tailcall = MockTailcall(address=42) + broken_edge = Mock() + broken_edge.type = BranchType.IndirectBranch + + function = MockFunction( + [ + switch_block := MockBlock(0, [broken_edge], instructions=[jmp_instr]), + tailcall_block := MockBlock(1, [], instructions=[tailcall]), + ] + ) + broken_edge.source = switch_block + broken_edge.target = tailcall_block + broken_edge.target.source_block.start = 0 + cfg = parser.parse(function) + v0, v1 = cfg.nodes + assert isinstance(cfg.get_edge(v0, v1), SwitchCase) diff --git a/tests/pipeline/controlflowanalysis/test_readability_based_refinement.py b/tests/pipeline/controlflowanalysis/test_readability_based_refinement.py index d01311c6c..eecd157a3 100644 --- a/tests/pipeline/controlflowanalysis/test_readability_based_refinement.py +++ b/tests/pipeline/controlflowanalysis/test_readability_based_refinement.py @@ -4,6 +4,7 @@ from decompiler.pipeline.controlflowanalysis.readability_based_refinement import ( ForLoopVariableRenamer, ReadabilityBasedRefinement, + WhileLoopReplacer, WhileLoopVariableRenamer, _find_continuation_instruction, _has_deep_requirement, @@ -19,6 +20,7 @@ Call, Condition, Constant, + Continue, ImportedFunctionSymbol, ListOperation, OperationType, @@ -2077,3 +2079,114 @@ def test_declaration_listop(self, ast_call_for_loop): for node in ast_call_for_loop: if isinstance(node, ForLoopNode): assert node.declaration.destination.operands[0].name == "i" + + def test_skip_for_loop_recovery_if_continue_in_while(self): + """ + a = 0 + while(a < 10) { + if(a == 2) { + a = a + 2 + continue + } + a = a + 1 + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={ + logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)]), + logic_cond("x2", context): Condition(OperationType.equal, [Variable("a"), Constant(2)]) + } + ) + + true_branch = ast._add_code_node( + [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(2)])), + Continue() + ] + ) + if_condition = ast._add_condition_node_with(logic_cond("x2", context), true_branch) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) + + while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_body = ast.factory.create_seq_node() + while_loop_iteration = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)]))]) + ast._add_node(while_loop) + ast._add_node(while_loop_body) + + ast._add_edges_from( + [ + (root, init_code_node), + (root, while_loop), + (while_loop, while_loop_body), + (while_loop_body, if_condition), + (while_loop_body, while_loop_iteration) + ] + ) + + WhileLoopReplacer(ast, _generate_options()).run() + assert not any(isinstance(loop_node, ForLoopNode) for loop_node in list(ast.get_loop_nodes_post_order())) + + def test_skip_for_loop_recovery_if_continue_in_nested_while(self): + """ + while(a < 5) { + a = a + b + while(b < 10) { + if(b < 0) { + b = b + 2 + continue + } + b = b + 1 + } + a = a + 1 + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={ + logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(5)]), + logic_cond("x2", context): Condition(OperationType.less, [Variable("b"), Constant(10)]), + logic_cond("x3", context): Condition(OperationType.less, [Variable("b"), Constant(0)]) + } + ) + + true_branch = ast._add_code_node( + [ + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(2)])), + Continue() + ] + ) + if_condition = ast._add_condition_node_with(logic_cond("x3", context), true_branch) + + while_loop_outer = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_body_outer = ast.factory.create_seq_node() + while_loop_iteration_outer_1 = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")]))]) + while_loop_iteration_outer_2 = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)]))]) + ast._add_node(while_loop_outer) + ast._add_node(while_loop_body_outer) + + while_loop_inner = ast.factory.create_while_loop_node(logic_cond("x2", context)) + while_loop_body_inner = ast.factory.create_seq_node() + while_loop_iteration_inner = ast._add_code_node([Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))]) + ast._add_node(while_loop_inner) + ast._add_node(while_loop_body_inner) + + ast._add_edges_from( + [ + (root, while_loop_outer), + (while_loop_outer, while_loop_body_outer), + (while_loop_body_outer, while_loop_inner), + (while_loop_body_outer, while_loop_iteration_outer_1), + (while_loop_body_outer, while_loop_iteration_outer_2), + (while_loop_inner, while_loop_body_inner), + (while_loop_body_inner, if_condition), + (while_loop_body_inner, while_loop_iteration_inner) + ] + ) + + WhileLoopReplacer(ast, _generate_options()).run() + loop_nodes = list(ast.get_loop_nodes_post_order()) + assert not isinstance(loop_nodes[0], ForLoopNode) and isinstance(loop_nodes[1], ForLoopNode) \ No newline at end of file diff --git a/tests/pipeline/expressions/__init__.py b/tests/pipeline/expressions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/pipeline/expressions/test_bitfieldcomparisonunrolling.py b/tests/pipeline/expressions/test_bitfieldcomparisonunrolling.py new file mode 100644 index 000000000..cdc61605a --- /dev/null +++ b/tests/pipeline/expressions/test_bitfieldcomparisonunrolling.py @@ -0,0 +1,92 @@ +from decompiler.pipeline.expressions.bitfieldcomparisonunrolling import BitFieldComparisonUnrolling +from decompiler.structures.graphs.basicblock import BasicBlock +from decompiler.structures.graphs.branches import FalseCase, TrueCase, UnconditionalEdge +from decompiler.structures.graphs.cfg import ControlFlowGraph +from decompiler.structures.pseudo.expressions import Constant, Variable +from decompiler.structures.pseudo.instructions import Branch, Comment, Return +from decompiler.structures.pseudo.operations import BinaryOperation, Condition, OperationType + + +class MockTask: + def __init__(self, cfg: ControlFlowGraph): + self.graph = cfg + + +def get_tf_successors(cfg: ControlFlowGraph, block: BasicBlock): + match cfg.get_out_edges(block): + case (TrueCase() as true_edge, FalseCase() as false_edge): + pass + case (FalseCase() as false_edge, TrueCase() as true_edge): + pass + case _: + raise ValueError("Block does not have outgoing T/F edges.") + return true_edge.sink, false_edge.sink + + +def test_unrolling_with_bitmask(): + """ + +-------------------+ +------------------------------------------------+ + | 2. | | 0. | + | /* other block */ | | if((((0x1 << var) & 0xffffffff) & 0x7) == 0x0) | + | return 0x1 | <-- | | + +-------------------+ +------------------------------------------------+ + | + | + v + +------------------------------------------------+ + | 1. | + | /* case block */ | + | return 0x0 | + +------------------------------------------------+ + """ + cfg = ControlFlowGraph() + switch_var = Variable("var") + bit_field = Constant(0b111) + branch_subexpr = BinaryOperation( + OperationType.bitwise_and, + [ + BinaryOperation( + OperationType.bitwise_and, + [BinaryOperation(OperationType.left_shift, [Constant(value=1), switch_var]), Constant(0xFFFFFFFF)], + ), + bit_field, + ], + ) + branch = Branch(condition=Condition(OperationType.equal, [branch_subexpr, Constant(0x0)])) + cfg.add_nodes_from( + [ + block := BasicBlock( + 0, + [branch], + ), + case_block := BasicBlock(1, [Comment("case block"), Return([Constant(0)])]), + other_block := BasicBlock(2, [Comment("other block"), Return([Constant(1)])]), + ] + ) + cfg.add_edges_from([TrueCase(block, other_block), FalseCase(block, case_block)]) + task = MockTask(cfg) + BitFieldComparisonUnrolling().run(task) + assert len(block) == 0, "removing of branch instruction failed" + block_out_edges = cfg.get_out_edges(block) + assert len(block_out_edges) == 1 + assert isinstance(block_out_edges[0], UnconditionalEdge) + successors = cfg.get_successors(block) + assert len(successors) == 1 + s1 = successors[0] + target, s2 = get_tf_successors(cfg, s1) + assert target == case_block + target, s3 = get_tf_successors(cfg, s2) + assert target == case_block + target, other = get_tf_successors(cfg, s3) + assert target == case_block + assert other == other_block + assert str(s1.instructions[-1].condition) == "var == 0x0" + assert str(s2.instructions[-1].condition) == "var == 0x1" + assert str(s3.instructions[-1].condition) == "var == 0x2" + assert isinstance(cfg.get_edge(block, s1), UnconditionalEdge) + assert isinstance(cfg.get_edge(s1, s2), FalseCase) + assert isinstance(cfg.get_edge(s2, s3), FalseCase) + assert isinstance(cfg.get_edge(s3, other_block), FalseCase) + assert isinstance(cfg.get_edge(s1, case_block), TrueCase) + assert isinstance(cfg.get_edge(s2, case_block), TrueCase) + assert isinstance(cfg.get_edge(s3, case_block), TrueCase) diff --git a/tests/pipeline/preprocessing/test_remove_stack_canary.py b/tests/pipeline/preprocessing/test_remove_stack_canary.py index 2f9711122..953f93ba5 100644 --- a/tests/pipeline/preprocessing/test_remove_stack_canary.py +++ b/tests/pipeline/preprocessing/test_remove_stack_canary.py @@ -52,6 +52,21 @@ def test_trivial_no_change(): assert isinstance(cfg.get_edge(n1, n2), TrueCase) assert isinstance(cfg.get_edge(n1, n3), FalseCase) +def test_no_change_to_single_block_function(): + """ + +--------------------+ + | 0. | + | __stack_chk_fail() | + +--------------------+ + """ + cfg = ControlFlowGraph() + cfg.add_nodes_from( + [ + b := BasicBlock(0, instructions=[Assignment(ListOperation([]), Call(ImportedFunctionSymbol("__stack_chk_fail", 0), []))]), + ] + ) + _run_remove_stack_canary(cfg) + assert set(cfg) == {b} def test_one_branch_to_stack_fail(): """ @@ -476,4 +491,4 @@ def test_multiple_returns_multiple_empty_blocks_one_stackcheck(): assert isinstance(cfg.get_edge(n2, n1), TrueCase) assert isinstance(cfg.get_edge(n2, n4), FalseCase) assert isinstance(cfg.get_edge(n3, n10), UnconditionalEdge) - assert isinstance(cfg.get_edge(n4, n5), UnconditionalEdge) \ No newline at end of file + assert isinstance(cfg.get_edge(n4, n5), UnconditionalEdge) diff --git a/tests/pipeline/preprocessing/test_switch_variable_detection.py b/tests/pipeline/preprocessing/test_switch_variable_detection.py index 4bf51db6f..afbcf2d4d 100644 --- a/tests/pipeline/preprocessing/test_switch_variable_detection.py +++ b/tests/pipeline/preprocessing/test_switch_variable_detection.py @@ -156,7 +156,6 @@ def test_switch_variable_in_condition_assignment(self): Check whether we track the switch expression correctly even if it was used in a dedicated condition statement." This test is based on the output of gcc 9.2.1 on ubuntu switch sample test_switch test8. - +----------+ +------------------------------+ | | | 0. | | 2. | | | @@ -170,10 +169,10 @@ def test_switch_variable_in_condition_assignment(self): | | 1. | | | | | y#0 = x | | 4. | | | y#1 = 0xfffff + (y#0 << 0x2) | | bar(0x2) | - | | jmp y#1 | ..> | | + | | jmp y#1 | --> | | | +------------------------------+ +----------+ - | : | - | : | + | | | + | | | | v | | +------------------------------+ | | | 3. | | @@ -188,31 +187,34 @@ def test_switch_variable_in_condition_assignment(self): +------------------------------+ """ cfg = ControlFlowGraph() + y0 = Variable("y", ssa_label=0) + y1 = Variable("y", ssa_label=1) + x = Variable("x") cfg.add_nodes_from( [ start := BasicBlock( 0, instructions=[ - Assignment(Variable("cond:0", ssa_label=0), Condition(OperationType.less_us, [Variable("x"), Constant(8)])), + Assignment(Variable("cond:0", ssa_label=0), Condition(OperationType.less_us, [x, Constant(8)])), Branch(Condition(OperationType.not_equal, [Variable("cond:0", ssa_label=0), Constant(0)])), ], ), switch_block := BasicBlock( 1, instructions=[ - Assignment(Variable("y", ssa_label=0), Variable("x")), + Assignment(y0, x), Assignment( - Variable("y", ssa_label=1), + y1, BinaryOperation( OperationType.plus, - [Constant(0xFFFFF), BinaryOperation(OperationType.left_shift, [Variable("y", ssa_label=0), Constant(2)])], + [Constant(0xFFFFF), BinaryOperation(OperationType.left_shift, [y0, Constant(2)])], ), ), - switch := IndirectBranch(Variable("y", ssa_label=1)), + switch := IndirectBranch(y1), ], ), default := BasicBlock(2, instructions=[Assignment(ListOperation([]), Call(function_symbol("foo"), [Constant(0)]))]), - end := BasicBlock(-1, instructions=[Return([Variable("x")])]), + end := BasicBlock(-1, instructions=[Return([x])]), case_1 := BasicBlock( 3, instructions=[ @@ -240,7 +242,7 @@ def test_switch_variable_in_condition_assignment(self): ) svd = SwitchVariableDetection() svd.run(MockTask(cfg)) - assert svd.find_switch_expression(switch) == Variable("x") + assert svd.find_switch_expression(switch) == y0 a0 = Variable("a", Integer.int32_t(), 0) @@ -423,3 +425,147 @@ def test_constant_pointer(): task = MockTask(cfg) SwitchVariableDetection().run(task) assert vertices[2].instructions[-1] == IndirectBranch(rax) + + +def test_first_simple_assignment(): + """ + +--------------+ +--------------------------++--------------+ + | 1. | | 0. || 7. | + | return 0x4 | | x = arg0 || rbx#4 = 0xb | + | | <-- | if((*(0x423658)) u> 0x4) || | ----------------------+ + +--------------+ +--------------------------++--------------+ | + | ^ | + | | | + v | | + +--------------+ +------------------------------------------+ +--------------+ | + | | | 2. | | | | + | 5. | | rax#1 = x | | 6. | | + | rbx#2 = 0x14 | | rax#2 = (*(0xffffff42)) + (rax#1 << 0x3) | | rbx#3 = 0x17 | | + | | | rax#3 = rax#2 + (*(0xffffff42)) | | | | + | | <-- | jmp rax#3 | --> | | | + +--------------+ +------------------------------------------+ +--------------+ | + | | | | | + | | | | | + | v v | | + | +--------------------------++--------------+ | | + | | 3. || 4. | | | + | | rbx#0 = 0x2c || rbx#1 = 0x28 | | | + | +--------------------------++--------------+ | | + | | | | | + | | | | | + | v v | | + | +------------------------------------------+ | | + +----------------> | 8. | <-----+ | + | rbx#5 = ϕ(rbx#0,rbx#1,rbx#2,rbx#3,rbx#4) | | + | return rbx#5 | | + | | <---------------------+ + +------------------------------------------+ + """ + cfg = ControlFlowGraph() + cont_pointer = UnaryOperation( + OperationType.dereference, + [Constant(4339288, Pointer(Integer(32, False), 64))], + Integer(32, False), + None, + False, + ) + rax1 = Variable("rax", Integer(64, False), 1, False, None) + rax2 = Variable("rax", Integer(64, False), 2, False, None) + rax3 = Variable("rax", Integer(64, False), 3, False, None) + rbx = [Variable("rbx", Integer(64, False), i, False, None) for i in range(6)] + def3 = Assignment( + rax3, + BinaryOperation( + OperationType.plus, + [ + rax2, + UnaryOperation(OperationType.dereference, [Constant(JT_OFFSET)], Integer(64, True)), + ], + Integer(64, True), + ), + ) + def2 = Assignment( + rax2, + BinaryOperation( + OperationType.plus, + [ + UnaryOperation(OperationType.dereference, [Constant(JT_OFFSET)], Integer(64, True)), + BinaryOperation( + OperationType.left_shift, + [rax1, Constant(3, Integer(8, True))], + Integer(64, False), + ), + ], + Integer(64, True), + ), + ) + def1 = Assignment(rax1, Variable("x")) + def0 = Assignment(Variable("x"), Variable("arg0")) + + cfg.add_nodes_from( + vertices := [ + BasicBlock( + 0, + [ + def0, + Branch(Condition(OperationType.greater_us, [cont_pointer, Constant(4, Integer(32, True))], CustomType("bool", 1))), + ], + ), + BasicBlock(1, [Return([Constant(4, Integer(64, True))])]), + BasicBlock( + 2, + [ + def1, + def2, + def3, + switch := IndirectBranch(rax3), + ], + ), + BasicBlock(3, [Assignment(rbx[0], Constant(44, Integer(64, True)))]), + BasicBlock(4, [Assignment(rbx[1], Constant(40, Integer(64, True)))]), + BasicBlock( + 5, + [Assignment(rbx[2], Constant(20, Integer(64, True)))], + ), + BasicBlock( + 6, + [Assignment(rbx[3], Constant(23, Integer(64, True)))], + ), + BasicBlock( + 7, + [Assignment(rbx[4], Constant(11, Integer(64, True)))], + ), + BasicBlock( + 8, + [ + Phi( + rbx[5], + rbx[0:5], + {}, + ), + Return([rbx[5]]), + ], + ), + ] + ) + cfg.add_edges_from( + [ + TrueCase(vertices[0], vertices[1]), + FalseCase(vertices[0], vertices[2]), + SwitchCase(vertices[2], vertices[3], [Constant(3)]), + SwitchCase(vertices[2], vertices[4], [Constant(4)]), + SwitchCase(vertices[2], vertices[5], [Constant(0)]), + SwitchCase(vertices[2], vertices[6], [Constant(1)]), + SwitchCase(vertices[2], vertices[7], [Constant(2)]), + UnconditionalEdge(vertices[3], vertices[8]), + UnconditionalEdge(vertices[4], vertices[8]), + UnconditionalEdge(vertices[5], vertices[8]), + UnconditionalEdge(vertices[6], vertices[8]), + UnconditionalEdge(vertices[7], vertices[8]), + ] + ) + task = MockTask(cfg) + svd = SwitchVariableDetection() + svd.run(task) + assert vertices[2].instructions[-1] == IndirectBranch(rax1) + assert svd.find_switch_expression(switch) == rax1 diff --git a/tests/structures/pseudo/test_complextypes.py b/tests/structures/pseudo/test_complextypes.py new file mode 100644 index 000000000..3bad97d60 --- /dev/null +++ b/tests/structures/pseudo/test_complextypes.py @@ -0,0 +1,207 @@ +import pytest +from decompiler.structures.pseudo import Float, Integer, Pointer +from decompiler.structures.pseudo.complextypes import ( + ComplexTypeMap, + ComplexTypeMember, + ComplexTypeName, + ComplexTypeSpecifier, + Enum, + Struct, + Union, +) + + +class TestStruct: + def test_declaration(self, book: Struct, record_id: Union): + assert book.declaration() == "struct Book {\n\tchar * title;\n\tint num_pages;\n\tchar * author;\n}" + # nest complex type + book.add_member( + m := ComplexTypeMember(size=64, name="id", offset=12, type=record_id), + ) + result = f"struct Book {{\n\tchar * title;\n\tint num_pages;\n\tchar * author;\n\t{m.declaration()};\n}}" + assert book.declaration() == result + + def test_str(self, book: Struct): + assert str(book) == "Book" + + def test_copy(self, book: Struct): + new_book: Struct = book.copy() + assert id(new_book) != id(book) + assert new_book.size == book.size + assert new_book.type_specifier == book.type_specifier == ComplexTypeSpecifier.STRUCT + assert id(new_book.members) != id(book.members) + assert new_book.get_member_by_offset(0) == book.get_member_by_offset(0) + assert id(new_book.get_member_by_offset(0)) != id(book.get_member_by_offset(0)) + assert len(new_book.members) == len(book.members) + + def test_add_members(self, book, title, num_pages, author): + empty_book = Struct(name="Book", members={}, size=96) + empty_book.add_member(title) + empty_book.add_member(author) + empty_book.add_member(num_pages) + assert empty_book == book + + def test_get_member_by_offset(self, book, title, num_pages, author): + assert book.get_member_by_offset(0) == title + assert book.get_member_by_offset(4) == num_pages + assert book.get_member_by_offset(8) == author + + +@pytest.fixture +def book() -> Struct: + return Struct( + name="Book", + members={ + 0: ComplexTypeMember(size=32, name="title", offset=0, type=Pointer(Integer.char())), + 4: ComplexTypeMember(size=32, name="num_pages", offset=4, type=Integer.int32_t()), + 8: ComplexTypeMember(size=32, name="author", offset=8, type=Pointer(Integer.char())), + }, + size=96, + ) + + +@pytest.fixture +def title() -> ComplexTypeMember: + return ComplexTypeMember(size=32, name="title", offset=0, type=Pointer(Integer.char())) + + +@pytest.fixture +def num_pages() -> ComplexTypeMember: + return ComplexTypeMember(size=32, name="num_pages", offset=4, type=Integer.int32_t()) + + +@pytest.fixture +def author() -> ComplexTypeMember: + return ComplexTypeMember(size=32, name="author", offset=8, type=Pointer(Integer.char())) + + +class TestUnion: + def test_declaration(self, record_id): + assert record_id.declaration() == "union RecordID {\n\tfloat float_id;\n\tint int_id;\n\tdouble double_id;\n}" + + def test_str(self, record_id): + assert str(record_id) == "RecordID" + + def test_copy(self, record_id): + new_record_id: Union = record_id.copy() + assert new_record_id == record_id + assert id(new_record_id) != id(record_id) + assert id(new_record_id.members) != id(record_id.members) + assert new_record_id.get_member_by_type(Float.float()) == record_id.get_member_by_type(Float.float()) + assert id(new_record_id.get_member_by_type(Float.float())) != id(record_id.get_member_by_type(Float.float())) + + def test_add_members(self, empty_record_id, record_id, float_id, int_id, double_id): + empty_record_id.add_member(float_id) + empty_record_id.add_member(int_id) + empty_record_id.add_member(double_id) + assert empty_record_id == record_id + + def test_get_member_by_type(self, record_id, float_id, int_id, double_id): + assert record_id.get_member_by_type(Float.float()) == float_id + assert record_id.get_member_by_type(Integer.int32_t()) == int_id + assert record_id.get_member_by_type(Float.double()) == double_id + + +@pytest.fixture +def record_id() -> Union: + return Union( + name="RecordID", + size=64, + members=[ + ComplexTypeMember(size=32, name="float_id", offset=0, type=Float.float()), + ComplexTypeMember(size=32, name="int_id", offset=0, type=Integer.int32_t()), + ComplexTypeMember(size=Float.double().size, name="double_id", offset=0, type=Float.double()), + ], + ) + + +@pytest.fixture +def empty_record_id() -> Union: + return Union(name="RecordID", size=64, members=[]) + + +@pytest.fixture +def float_id() -> ComplexTypeMember: + return ComplexTypeMember(size=32, name="float_id", offset=0, type=Float.float()) + + +@pytest.fixture +def int_id() -> ComplexTypeMember: + return ComplexTypeMember(size=32, name="int_id", offset=0, type=Integer.int32_t()) + + +@pytest.fixture +def double_id() -> ComplexTypeMember: + return ComplexTypeMember(size=Float.double().size, name="double_id", offset=0, type=Float.double()) + + +class TestEnum: + def test_declaration(self, color): + assert color.declaration() == "enum Color {\n\tred = 0,\n\tgreen = 1,\n\tblue = 2\n}" + + def test_str(self, color): + assert str(color) == "Color" + + def test_copy(self, color): + new_color = color.copy() + assert new_color == color + assert id(new_color) != color + + def test_add_members(self, empty_color, color, red, green, blue): + empty_color.add_member(red) + empty_color.add_member(green) + empty_color.add_member(blue) + assert empty_color == color + + +@pytest.fixture +def color(): + return Enum( + 0, + "Color", + { + 0: ComplexTypeMember(0, "red", value=0, offset=0, type=Integer.int32_t()), + 1: ComplexTypeMember(0, "green", value=1, offset=0, type=Integer.int32_t()), + 2: ComplexTypeMember(0, "blue", value=2, offset=0, type=Integer.int32_t()), + }, + ) + + +@pytest.fixture +def empty_color(): + return Enum(0, "Color", {}) + + +@pytest.fixture +def red(): + return ComplexTypeMember(0, "red", value=0, offset=0, type=Integer.int32_t()) + + +@pytest.fixture +def green(): + return ComplexTypeMember(0, "green", value=1, offset=0, type=Integer.int32_t()) + + +@pytest.fixture +def blue(): + return ComplexTypeMember(0, "blue", value=2, offset=0, type=Integer.int32_t()) + + +class TestComplexTypeMap: + def test_declarations(self, complex_types: ComplexTypeMap, book: Struct, color: Enum, record_id: Union): + assert complex_types.declarations() == f"{book.declaration()};\n{color.declaration()};\n{record_id.declaration()};" + complex_types.add(book) + assert complex_types.declarations() == f"{book.declaration()};\n{color.declaration()};\n{record_id.declaration()};" + + def test_retrieve_by_name(self, complex_types: ComplexTypeMap, book: Struct, color: Enum, record_id: Union): + assert complex_types.retrieve_by_name(ComplexTypeName(0, "Book")) == book + assert complex_types.retrieve_by_name(ComplexTypeName(0, "RecordID")) == record_id + assert complex_types.retrieve_by_name(ComplexTypeName(0, "Color")) == color + + @pytest.fixture + def complex_types(self, book: Struct, color: Enum, record_id: Union): + complex_types = ComplexTypeMap() + complex_types.add(book) + complex_types.add(color) + complex_types.add(record_id) + return complex_types diff --git a/tests/structures/pseudo/test_operations.py b/tests/structures/pseudo/test_operations.py index c3074cc8f..b3d28ab03 100644 --- a/tests/structures/pseudo/test_operations.py +++ b/tests/structures/pseudo/test_operations.py @@ -9,15 +9,17 @@ Call, Condition, ListOperation, + MemberAccess, OperationType, TernaryExpression, UnaryOperation, ) -from decompiler.structures.pseudo.typing import Integer +from decompiler.structures.pseudo.typing import Integer, Pointer a = Variable("a", Integer.int32_t(), 0) b = Variable("b", Integer.int32_t(), 1) c = Variable("c", Integer.int32_t(), 2) +ptr = Variable("ptr", Pointer(Integer.int32_t()), 0) neg = OperationType.negate add = OperationType.plus @@ -75,6 +77,9 @@ def test_substitute(): op = BinaryOperation(add, [BinaryOperation(add, [a, a]), b]) op.substitute(a, BinaryOperation(add, [b, c])) assert str(op) == "((b#1 + c#2) + (b#1 + c#2)) + b#1" + op = MemberAccess(operands=[a], member_name="x", offset=0, vartype=Integer.int32_t()) + op.substitute(a, b) + assert str(op) == "b#1.x" def test_substitute_loop(): @@ -99,6 +104,7 @@ def test_complexity(): UnaryOperation(OperationType.cast, [UnaryOperation(OperationType.negate, [UnaryOperation(OperationType.cast, [a])])]).complexity == 1 ) + assert MemberAccess(operands=[a], member_name="x", offset=0, vartype=Integer.int32_t()).complexity == 1 def test_requirements(): @@ -111,6 +117,7 @@ def test_requirements(): assert set(ListOperation([a, BinaryOperation(add, [a, b])]).requirements) == {a, b} assert set(ListOperation([a, BinaryOperation(add, [b, c])]).requirements) == {a, b, c} assert ListOperation([Constant(2), a]).requirements == [a] + assert MemberAccess(operands=[a], member_name="x", offset=0, vartype=Integer.int32_t()).requirements == [a] def test_repr(): @@ -172,6 +179,59 @@ def test_str(): assert str(BinaryOperation(div, [a, b])) == "a#0 / b#1" assert str(BinaryOperation(udiv, [a, b])) == "a#0 u/ b#1" assert str(ListOperation([a, b])) == "a#0,b#1" + assert str(MemberAccess(operands=[a], member_name="x", offset=0, vartype=Integer.int32_t())) == "a#0.x" + assert str(MemberAccess(operands=[ptr], member_name="x", offset=0, vartype=Integer.int32_t())) == "ptr#0->x" + assert ( + str( + MemberAccess( + operands=[MemberAccess(operands=[a], member_name="x", offset=0, vartype=Integer.int32_t())], + member_name="z", + offset=0, + vartype=Integer.int32_t(), + ) + ) + == "a#0.x.z" + ) + assert ( + str( + MemberAccess( + operands=[MemberAccess(operands=[ptr], member_name="x", offset=0, vartype=Integer.int32_t())], + member_name="z", + offset=0, + vartype=Pointer(Integer.int32_t()), + ) + ) + == "ptr#0->x.z" + ) + assert ( + str( + MemberAccess( + operands=[MemberAccess(operands=[ptr], member_name="x", offset=0, vartype=Pointer(Integer.int32_t()))], + member_name="z", + offset=0, + vartype=Pointer(Pointer(Integer.int32_t())), + ) + ) + == "ptr#0->x->z" + ) + assert ( + str( + MemberAccess( + operands=[ + MemberAccess( + operands=[MemberAccess(operands=[ptr], member_name="x", offset=0, vartype=Pointer(Integer.int32_t()))], + member_name="z", + offset=0, + vartype=Pointer(Pointer(Integer.int32_t())), + ) + ], + member_name="w", + offset=8, + vartype=Pointer(Pointer(Pointer(Integer.int32_t()))), + ) + ) + == "ptr#0->x->z->w" + ) def test_iter(): @@ -250,6 +310,23 @@ def test_copy(): ) original.array_info.index = Variable("y") assert copy != original + original = MemberAccess(operands=[a], member_name="x", offset=0, vartype=Integer.int32_t(), writes_memory=1) + copy = original.copy() + assert copy == original + assert id(copy) != original + assert copy == MemberAccess(operands=[a], member_name="x", offset=0, vartype=Integer.int32_t(), writes_memory=1) + + +def test_member_access_properties(): + member_access = MemberAccess(operands=[a], member_name="x", offset=4, vartype=Integer.int32_t(), writes_memory=1) + assert member_access.member_name == "x" + assert member_access.member_offset == 4 + assert member_access.struct_variable == a + assert member_access.is_write_access() + assert not member_access.is_read_access() + member_access = MemberAccess(operands=[a], member_name="x", offset=4, vartype=Integer.int32_t(), writes_memory=None) + assert not member_access.is_write_access() + assert member_access.is_read_access() func = FunctionSymbol("func", 0x42) diff --git a/tests/test_sample_binaries.py b/tests/test_sample_binaries.py index 8537c0660..d013b2e18 100644 --- a/tests/test_sample_binaries.py +++ b/tests/test_sample_binaries.py @@ -260,6 +260,17 @@ def test_tailcall_display(): assert output.count("return fseeko(") == 1 +def test_member_access_is_in_decompiled_code(): + """Test that arg1#0->_IO_read_ptr, arg1#0->_IO_write_base and arg1#0->_IO_save_base + are displayed as member accesses in the decompiled code.""" + args = ["python", "decompile.py", "tests/coreutils/binaries/sha224sum", "rpl_fseeko"] + output = str(subprocess.run(args, check=True, capture_output=True).stdout) + + assert "->_IO_read_ptr" in output + assert "->_IO_save" in output + assert "->_IO_write_base" in output + + def test_issue_70(): """Test Issue #70.""" args = ["python", "decompile.py", "tests/samples/others/issue-70.bin", "main"]