From 083bba9d9c183dd3a4b7dfb5566f48e4592cffef Mon Sep 17 00:00:00 2001 From: Eva-Maria Behner Date: Wed, 25 Oct 2023 10:17:08 +0200 Subject: [PATCH] apply black --- decompiler/backend/cexpressiongenerator.py | 2 +- decompiler/backend/codegenerator.py | 7 +- decompiler/backend/variabledeclarations.py | 7 +- decompiler/frontend/binaryninja/frontend.py | 11 +- .../binaryninja/handlers/assignments.py | 9 +- .../frontend/binaryninja/handlers/calls.py | 13 +- .../binaryninja/handlers/constants.py | 6 +- .../frontend/binaryninja/handlers/globals.py | 115 ++++---- .../frontend/binaryninja/handlers/symbols.py | 23 +- decompiler/frontend/binaryninja/parser.py | 2 +- decompiler/frontend/binaryninja/tagging.py | 6 +- .../commons/expressionpropagationcommons.py | 4 +- .../constant_folding.py | 19 +- .../rules/collapse_add_neg.py | 16 +- .../rules/collapse_constants.py | 5 +- .../rules/collapse_nested_constants.py | 13 +- .../rules/positive_constants.py | 19 +- .../rules/simplify_redundant_reference.py | 3 +- .../rules/simplify_trivial_arithmetic.py | 3 +- .../rules/sub_to_add.py | 9 +- .../expression_simplification/stages.py | 24 +- .../loop_name_generator.py | 2 +- .../loop_utility_methods.py | 3 +- .../readability_based_refinement.py | 67 +++-- .../variable_name_generation.py | 36 +-- .../dataflowanalysis/expressionpropagation.py | 20 +- .../expressionpropagationfunctioncall.py | 26 +- decompiler/pipeline/default.py | 8 +- .../bitfieldcomparisonunrolling.py | 1 + .../preprocessing/remove_stack_canary.py | 2 +- .../switch_variable_detection.py | 2 +- .../pipeline/ssa/sreedhar_out_of_ssa.py | 1 - decompiler/structures/logic/custom_logic.py | 6 +- decompiler/structures/pseudo/expressions.py | 3 +- decompiler/structures/pseudo/operations.py | 4 +- decompiler/structures/pseudo/typing.py | 2 +- .../structures/visitors/substitute_visitor.py | 5 +- decompiler/task.py | 4 +- decompiler/util/integer_util.py | 2 +- tests/backend/test_codegenerator.py | 10 +- tests/frontend/test_parser.py | 2 + .../rules/test_positive_constants.py | 1 - .../test_constant_folding.py | 20 +- .../expression_simplification/test_stage.py | 57 +--- .../test_loop_name_generator.py | 26 +- .../test_readability_based_refinement.py | 101 ++++--- .../test_variable_name_generation.py | 66 +++-- .../test_array_access_detection.py | 15 +- .../test_expression_propagation.py | 30 +- .../test_expression_propagation_mem.py | 274 +++++++++--------- .../test_insert_missing_definition.py | 6 +- .../preprocessing/test_remove_stack_canary.py | 14 +- .../structures/logic/test_logic_condition.py | 4 +- .../logic/test_z3_logic_converter.py | 7 +- tests/structures/pseudo/test_expressions.py | 2 +- tests/structures/pseudo/test_typing.py | 3 + tests/structures/test_maps.py | 7 +- .../visitors/test_substitute_visitor.py | 144 +++------ tests/test_sample_binaries.py | 26 +- 59 files changed, 644 insertions(+), 681 deletions(-) diff --git a/decompiler/backend/cexpressiongenerator.py b/decompiler/backend/cexpressiongenerator.py index 6ff59150a..8ce03e778 100644 --- a/decompiler/backend/cexpressiongenerator.py +++ b/decompiler/backend/cexpressiongenerator.py @@ -348,7 +348,7 @@ def _format_string_literal(constant: expressions.Constant) -> str: @staticmethod def format_variables_declaration(var_type: Type, var_names: list[str]) -> str: - """ Return a string representation of variable declarations.""" + """Return a string representation of variable declarations.""" match var_type: case Pointer(type=FunctionTypeDef() as fun_type): parameter_names = ", ".join(str(parameter) for parameter in fun_type.parameters) diff --git a/decompiler/backend/codegenerator.py b/decompiler/backend/codegenerator.py index 677d2466b..827fdedfe 100644 --- a/decompiler/backend/codegenerator.py +++ b/decompiler/backend/codegenerator.py @@ -39,10 +39,9 @@ def generate_function(self, task: DecompilerTask) -> str: return self.TEMPLATE.substitute( return_type=task.function_return_type, name=task.name, - parameters=", ".join(map( - lambda param: CExpressionGenerator.format_variables_declaration(param.type, [param.name]), - task.function_parameters - )), + parameters=", ".join( + map(lambda param: CExpressionGenerator.format_variables_declaration(param.type, [param.name]), task.function_parameters) + ), local_declarations=LocalDeclarationGenerator.from_task(task) if not task.failed else "", function_body=CodeVisitor(task).visit(task.syntax_tree.root) if not task.failed else task.failure_message, ) diff --git a/decompiler/backend/variabledeclarations.py b/decompiler/backend/variabledeclarations.py index f79ae589c..901ba99db 100644 --- a/decompiler/backend/variabledeclarations.py +++ b/decompiler/backend/variabledeclarations.py @@ -73,10 +73,7 @@ def generate(self, param_names: list[str] = []) -> Iterator[str]: variable_type_mapping[variable.type].append(variable) for variable_type, variables in sorted(variable_type_mapping.items(), key=lambda x: str(x)): for chunked_variables in self._chunks(variables, self._vars_per_line): - yield CExpressionGenerator.format_variables_declaration( - variable_type, - [var.name for var in chunked_variables] - ) + ";" + yield CExpressionGenerator.format_variables_declaration(variable_type, [var.name for var in chunked_variables]) + ";" @staticmethod def _chunks(lst: List, n: int) -> Iterator[List]: @@ -134,7 +131,7 @@ def get_initial_value(variable: GlobalVariable) -> str: return str(variable.initial_value.value) if isinstance(variable.initial_value, bytes): return str(convert_bytes(variable.initial_value, variable.type)) - if isinstance(operation:=variable.initial_value, Operation): + if isinstance(operation := variable.initial_value, Operation): for requirement in operation.requirements: if isinstance(requirement, GlobalVariable): requirement.unsubscript() diff --git a/decompiler/frontend/binaryninja/frontend.py b/decompiler/frontend/binaryninja/frontend.py index bdd1ad2d1..8c5fe8dc8 100644 --- a/decompiler/frontend/binaryninja/frontend.py +++ b/decompiler/frontend/binaryninja/frontend.py @@ -130,13 +130,16 @@ def create_task(self, function_identifier: Union[str, Function], options: Option try: cfg, complex_types = self._extract_cfg(function.function, options) task = DecompilerTask( - function.name, cfg, function_return_type=function.return_type, function_parameters=function.params, - options=options, complex_types=complex_types + function.name, + cfg, + function_return_type=function.return_type, + function_parameters=function.params, + options=options, + complex_types=complex_types, ) except Exception as e: task = DecompilerTask( - function.name, None, function_return_type=function.return_type, function_parameters=function.params, - options=options + function.name, None, function_return_type=function.return_type, function_parameters=function.params, options=options ) task.fail(origin="CFG creation") logging.error(f"Failed to decompile {task.name}, error during CFG creation: {e}") diff --git a/decompiler/frontend/binaryninja/handlers/assignments.py b/decompiler/frontend/binaryninja/handlers/assignments.py index b34efc670..25b494b01 100644 --- a/decompiler/frontend/binaryninja/handlers/assignments.py +++ b/decompiler/frontend/binaryninja/handlers/assignments.py @@ -68,7 +68,7 @@ def lift_set_field(self, assignment: mediumlevelil.MediumLevelILSetVarField, is_ # case 1 (struct), avoid set field of named integers: dest_type = self._lifter.lift(assignment.dest.type) if isinstance(assignment.dest.type, binaryninja.NamedTypeReferenceType) and not ( - isinstance(dest_type, Pointer) and isinstance(dest_type.type, Integer) + isinstance(dest_type, Pointer) and isinstance(dest_type.type, Integer) ): struct_variable = self._lifter.lift(assignment.dest, is_aliased=True, parent=assignment) destination = MemberAccess( @@ -105,11 +105,8 @@ def lift_get_field(self, instruction: mediumlevelil.MediumLevelILVarField, is_al if instruction.offset: return UnaryOperation( OperationType.cast, - [BinaryOperation( - OperationType.right_shift_us, - [source, Constant(instruction.offset, Integer.int32_t())] - )], - cast_type + [BinaryOperation(OperationType.right_shift_us, [source, Constant(instruction.offset, Integer.int32_t())])], + cast_type, ) return UnaryOperation(OperationType.cast, [source], vartype=cast_type, contraction=True) diff --git a/decompiler/frontend/binaryninja/handlers/calls.py b/decompiler/frontend/binaryninja/handlers/calls.py index e84ca9e2e..18af28b8b 100644 --- a/decompiler/frontend/binaryninja/handlers/calls.py +++ b/decompiler/frontend/binaryninja/handlers/calls.py @@ -53,7 +53,7 @@ def lift_syscall(self, call: mediumlevelil.MediumLevelILSyscall, ssa: bool = Fal [self._lifter.lift(parameter, parent=call) for parameter in call.params], vartype=dest.type.copy(), writes_memory=call.output_dest_memory if ssa else None, - meta_data={"param_names": self._lift_syscall_parameter_names(call)} + meta_data={"param_names": self._lift_syscall_parameter_names(call)}, ), ) @@ -74,12 +74,15 @@ def lift_intrinsic(self, call: mediumlevelil.MediumLevelILIntrinsic, ssa: bool = @staticmethod def _lift_call_parameter_names(instruction: mediumlevelil.MediumLevelILCall) -> List[str]: """Lift parameter names of call by iterating over the function parameters where the call is pointing to (if available)""" - if instruction.dest.expr_type is None or not isinstance(instruction.dest.expr_type, PointerType) or \ - not isinstance(instruction.dest.expr_type.target, FunctionType): + if ( + instruction.dest.expr_type is None + or not isinstance(instruction.dest.expr_type, PointerType) + or not isinstance(instruction.dest.expr_type.target, FunctionType) + ): return [] return [param.name for param in instruction.dest.expr_type.target.parameters] - @staticmethod + @staticmethod def _lift_syscall_parameter_names(instruction: mediumlevelil.MediumLevelILSyscall) -> List[str]: """Lift syscall identifier (e.G. sys_open) from a syscall instruction""" - return [str(instruction).split("syscall(")[1].split(' ')[0]] + return [str(instruction).split("syscall(")[1].split(" ")[0]] diff --git a/decompiler/frontend/binaryninja/handlers/constants.py b/decompiler/frontend/binaryninja/handlers/constants.py index 20d38f2dd..295b89ff4 100644 --- a/decompiler/frontend/binaryninja/handlers/constants.py +++ b/decompiler/frontend/binaryninja/handlers/constants.py @@ -7,6 +7,7 @@ BYTE_SIZE = 8 + class ConstantHandler(Handler): def register(self): """Register the handler at its parent lifter.""" @@ -24,7 +25,7 @@ def register(self): def lift_constant(self, constant: mediumlevelil.MediumLevelILConst, **kwargs) -> Constant: """Lift the given constant value.""" - if(constant.constant in [math.inf, -math.inf, math.nan]): + if constant.constant in [math.inf, -math.inf, math.nan]: return NotUseableConstant(str(constant.constant)) return Constant(constant.constant, vartype=self._lifter.lift(constant.expr_type)) @@ -33,7 +34,6 @@ def lift_integer_literal(value: int, **kwargs) -> Constant: """Lift the given literal, which is most likely an artefact from shift operations and the like.""" return Constant(value, vartype=Integer.int32_t()) - def lift_constant_data(self, pointer: mediumlevelil.MediumLevelILConstData, **kwargs) -> Constant: """Lift const data as a non mute able constant string""" return StringSymbol(str(pointer), pointer.address) @@ -67,4 +67,4 @@ def _in_read_only_section(self, addr: int, view: BinaryView) -> bool: for _, section in view.sections.items(): if addr >= section.start and addr <= section.end and section.semantics == SectionSemantics.ReadOnlyDataSectionSemantics: return True - return False \ No newline at end of file + return False diff --git a/decompiler/frontend/binaryninja/handlers/globals.py b/decompiler/frontend/binaryninja/handlers/globals.py index bc66170a3..ee2b5f958 100644 --- a/decompiler/frontend/binaryninja/handlers/globals.py +++ b/decompiler/frontend/binaryninja/handlers/globals.py @@ -31,6 +31,7 @@ MAX_GLOBAL_STRINGBYTES_LENGTH = 129 + class GlobalHandler(Handler): """Handle for global variables.""" @@ -39,7 +40,7 @@ class GlobalHandler(Handler): def __init__(self, lifter): super().__init__(lifter) - self._lift_datavariable_by_type : dict[Type, Callable] = { + self._lift_datavariable_by_type: dict[Type, Callable] = { CharType: self._lift_basic_type, IntegerType: self._lift_basic_type, FloatType: self._lift_basic_type, @@ -47,92 +48,106 @@ def __init__(self, lifter): VoidType: self._lift_void_type, ArrayType: self._lift_constant_type, PointerType: self._lift_pointer_type, - NamedTypeReferenceType : self._lift_named_type_ref, # Lift DataVariable with type NamedTypeRef + NamedTypeReferenceType: self._lift_named_type_ref, # Lift DataVariable with type NamedTypeRef } def register(self): """Register the handler at its parent lifter.""" self._lifter.HANDLERS.update({DataVariable: self.lift_global_variable}) - def lift_global_variable(self, variable: DataVariable, view: BinaryView, - parent: Optional[MediumLevelILInstruction] = None, caller_addr: int = None, **kwargs + def lift_global_variable( + self, variable: DataVariable, view: BinaryView, parent: Optional[MediumLevelILInstruction] = None, caller_addr: int = None, **kwargs ) -> Union[Symbol, UnaryOperation, GlobalVariable, StringSymbol]: """Lift global variables via datavariable type. Check bninja error case + recursive datavariable first""" if not self._addr_in_section(view, variable.address): - return Constant(variable.address, vartype=Integer(view.address_size*BYTE_SIZE, False)) + return Constant(variable.address, vartype=Integer(view.address_size * BYTE_SIZE, False)) - if caller_addr == variable.address: - return self._lifter.lift(variable.symbol) if variable.symbol else \ - Symbol(GLOBAL_VARIABLE_PREFIX + f"{variable.address:x}", variable.address, vartype=Integer.uint32_t()) + if caller_addr == variable.address: + return ( + self._lifter.lift(variable.symbol) + if variable.symbol + else Symbol(GLOBAL_VARIABLE_PREFIX + f"{variable.address:x}", variable.address, vartype=Integer.uint32_t()) + ) return self._lift_datavariable_by_type[type(variable.type)](variable, view, parent) - - def _lift_constant_type(self, variable: DataVariable, view: BinaryView, parent: Optional[MediumLevelILInstruction] = None) -> StringSymbol: - """Lift constant data type (bninja only uses strings) into code""" # jump table ist auch constant + def _lift_constant_type( + self, variable: DataVariable, view: BinaryView, parent: Optional[MediumLevelILInstruction] = None + ) -> StringSymbol: + """Lift constant data type (bninja only uses strings) into code""" # jump table ist auch constant if str(variable).find("char const") != -1: - string = str(variable.value)[2:-1].rstrip('\\x00') # we want to keep escaped control chars (\n), therefore we take the raw string representation of bytes and purge b"" + string = str(variable.value)[2:-1].rstrip( + "\\x00" + ) # we want to keep escaped control chars (\n), therefore we take the raw string representation of bytes and purge b"" return StringSymbol(f'"{string}"', variable.address, vartype=Pointer(Integer.char(), view.address_size * BYTE_SIZE)) - return StringSymbol(f"&{variable.name}" if variable.name else GLOBAL_VARIABLE_PREFIX + f"{variable.address:x}", variable.address) # Else - + return StringSymbol( + f"&{variable.name}" if variable.name else GLOBAL_VARIABLE_PREFIX + f"{variable.address:x}", variable.address + ) # Else def _lift_pointer_type(self, variable: DataVariable, view: BinaryView, parent: Optional[MediumLevelILInstruction] = None): """Lift pointer as: - 1. Function pointer: If bninja already knows it's a function pointer. - 2. Type pointer: As normal type pointer (there _should_ be a datavariable at the pointers dest.) - 3. Void pointer: Try to extract a datavariable (recover type of void* directly), string (char*) or raw bytes (void*) at the given address + 1. Function pointer: If bninja already knows it's a function pointer. + 2. Type pointer: As normal type pointer (there _should_ be a datavariable at the pointers dest.) + 3. Void pointer: Try to extract a datavariable (recover type of void* directly), string (char*) or raw bytes (void*) at the given address """ if isinstance(variable.type.target, FunctionType): return ImportedFunctionSymbol(variable.name, variable.address, vartype=Pointer(Integer.char(), view.address_size * BYTE_SIZE)) if isinstance(variable.type.target, VoidType): init_value, type = self._get_unknown_value(variable.value, view, variable.address) - if not isinstance(type, PointerType): # Fix type to be a pointer (happens when a datavariable is at the dest.) + if not isinstance(type, PointerType): # Fix type to be a pointer (happens when a datavariable is at the dest.) type = Type.pointer(view.arch, type) else: - init_value, type = self._lifter.lift(view.get_data_var_at(variable.value), view=view, caller_addr=variable.address), variable.type + init_value, type = ( + self._lifter.lift(view.get_data_var_at(variable.value), view=view, caller_addr=variable.address), + variable.type, + ) return UnaryOperation( OperationType.address, - [ - GlobalVariable( + [ + GlobalVariable( name=self._lifter.lift(variable.symbol).name if variable.symbol else GLOBAL_VARIABLE_PREFIX + f"{variable.address:x}", vartype=self._lifter.lift(type), ssa_label=parent.ssa_memory_version if parent else 0, - initial_value=init_value - ) - ], + initial_value=init_value, + ) + ], ) - - def _lift_basic_type(self, variable: DataVariable, view: BinaryView, parent: Optional[MediumLevelILInstruction] = None) -> UnaryOperation: + def _lift_basic_type( + self, variable: DataVariable, view: BinaryView, parent: Optional[MediumLevelILInstruction] = None + ) -> UnaryOperation: """Lift basic known type""" return UnaryOperation( OperationType.address, - [ - GlobalVariable( + [ + GlobalVariable( name=self._lifter.lift(variable.symbol).name if variable.symbol else GLOBAL_VARIABLE_PREFIX + f"{variable.address:x}", vartype=self._lifter.lift(variable.type), ssa_label=parent.ssa_memory_version if parent else 0, - initial_value=Constant(variable.value) + initial_value=Constant(variable.value), ) ], ) - - def _lift_void_type(self, variable: DataVariable, view: BinaryView, parent: Optional[MediumLevelILInstruction] = None) -> GlobalVariable: + def _lift_void_type( + self, variable: DataVariable, view: BinaryView, parent: Optional[MediumLevelILInstruction] = None + ) -> GlobalVariable: "Lift unknown type, by checking the value at the given address. Will always be lifted as a pointer. Try to extract datavariable, string or bytes as value" value, type = self._get_unknown_value(variable.address, view, variable.address) return GlobalVariable( - name=self._lifter.lift(variable.symbol).name if variable.symbol else GLOBAL_VARIABLE_PREFIX + f"{variable.address:x}", - vartype=self._lifter.lift(type), - ssa_label=parent.ssa_memory_version if parent else 0, - initial_value=value - ) - - - def _lift_named_type_ref(self, variable: DataVariable, view: BinaryView, parent: Optional[MediumLevelILInstruction] = None) -> GlobalVariable: - """Lift a named custom type (Enum, Structs)""" - return Constant("Unknown value", self._lifter.lift(variable.type)) # BNinja error, need to check with the issue to get the correct value + name=self._lifter.lift(variable.symbol).name if variable.symbol else GLOBAL_VARIABLE_PREFIX + f"{variable.address:x}", + vartype=self._lifter.lift(type), + ssa_label=parent.ssa_memory_version if parent else 0, + initial_value=value, + ) + def _lift_named_type_ref( + self, variable: DataVariable, view: BinaryView, parent: Optional[MediumLevelILInstruction] = None + ) -> GlobalVariable: + """Lift a named custom type (Enum, Structs)""" + return Constant( + "Unknown value", self._lifter.lift(variable.type) + ) # BNinja error, need to check with the issue to get the correct value def _get_unknown_value(self, addr: int, view: BinaryView, caller_addr: int = 0): """Return symbol, datavariable, address, string or raw bytes at given address.""" @@ -148,7 +163,6 @@ def _get_unknown_value(self, addr: int, view: BinaryView, caller_addr: int = 0): if len(data) > MAX_GLOBAL_STRINGBYTES_LENGTH: data = data[:MAX_GLOBAL_STRINGBYTES_LENGTH] + '..."' return data, type - def _get_raw_bytes(self, addr: int, view: BinaryView) -> str: """Returns raw bytes as hex string after a given address to the next data structure or section""" @@ -157,10 +171,9 @@ def _get_raw_bytes(self, addr: int, view: BinaryView) -> str: else: data = view.read(addr, view.get_sections_at(addr)[0].end) - string = ''.join("\\x{:02x}".format(x) for x in data) + string = "".join("\\x{:02x}".format(x) for x in data) return f'"{string}"' - def _get_different_string_types_at(self, addr: int, view: BinaryView) -> Tuple[Optional[str], Type]: """Extract string with char/wchar16/wchar32 type if there is one""" types: list[Type] = [Type.char(), Type.wide_char(2), Type.wide_char(4)] @@ -170,13 +183,12 @@ def _get_different_string_types_at(self, addr: int, view: BinaryView) -> Tuple[O break return string, type - def _get_string_at(self, view: BinaryView, addr: int, width: int) -> Optional[str]: """Read string with specified width from location. Explanation for the magic parsing: - - we read 1, 2 or 4 long integers which should be interpreted as a byte in ASCII range (while Loop; can't use chr() for checking) - - afterwards we convert bytes array manually to a string by removing the "bytearray(...)" parts from the string - - this string now consists of readable chars (A, b), escaped hex values (\\x17) and control chars (\n, \t) - - we consider a it a string, if it only consists of readable chars + control chars + - we read 1, 2 or 4 long integers which should be interpreted as a byte in ASCII range (while Loop; can't use chr() for checking) + - afterwards we convert bytes array manually to a string by removing the "bytearray(...)" parts from the string + - this string now consists of readable chars (A, b), escaped hex values (\\x17) and control chars (\n, \t) + - we consider a it a string, if it only consists of readable chars + control chars """ raw_bytes = bytearray() match width: @@ -198,11 +210,10 @@ def _get_string_at(self, view: BinaryView, addr: int, width: int) -> Optional[st raw_bytes.append(byte) string = str(raw_bytes)[12:-2] - if len(string) < 2 or string.find("\\x") != -1: # escaped + if len(string) < 2 or string.find("\\x") != -1: # escaped return None - - return identifier + f'"{string}"' + return identifier + f'"{string}"' def _addr_in_section(self, view: BinaryView, addr: int) -> bool: """Returns True if address is contained in a section, False otherwise""" diff --git a/decompiler/frontend/binaryninja/handlers/symbols.py b/decompiler/frontend/binaryninja/handlers/symbols.py index fc301bc89..dfe78b576 100644 --- a/decompiler/frontend/binaryninja/handlers/symbols.py +++ b/decompiler/frontend/binaryninja/handlers/symbols.py @@ -11,6 +11,7 @@ MAX_SYMBOL_NAME_LENGTH = 64 GLOBAL_VARIABLE_PREFIX = "data_" + class SymbolHandler(Handler): """Handler for phi instructions emitted by binaryninja.""" @@ -35,7 +36,11 @@ def register(self): } ) - def lift_symbol(self, symbol: CoreSymbol, **kwargs,) -> Union[GlobalVariable, Constant]: + def lift_symbol( + self, + symbol: CoreSymbol, + **kwargs, + ) -> Union[GlobalVariable, Constant]: """Lift the given symbol from binaryninja MLIL.""" if not (symbol_type := self.SYMBOL_MAP.get(symbol.type, None)): warning(f"[Lifter] Can not handle symbols of type {symbol.type}, falling back to constant lifting.") @@ -44,11 +49,13 @@ def lift_symbol(self, symbol: CoreSymbol, **kwargs,) -> Union[GlobalVariable, Co def _purge_symbol_name(self, name: str, addr: int) -> str: """Purge invalid chars from symbol names or lift as data_addr if name is too long""" - if name[:2] == "??" or len(name) > MAX_SYMBOL_NAME_LENGTH: # strip useless PDB debug names which start with `??` + if name[:2] == "??" or len(name) > MAX_SYMBOL_NAME_LENGTH: # strip useless PDB debug names which start with `??` return GLOBAL_VARIABLE_PREFIX + f"{hex(addr)}" - return name.translate({ - ord(' '): '_', - ord("'"): "", - ord('.'): "_", - ord('`'): "", - }) + return name.translate( + { + ord(" "): "_", + ord("'"): "", + ord("."): "_", + ord("`"): "", + } + ) diff --git a/decompiler/frontend/binaryninja/parser.py b/decompiler/frontend/binaryninja/parser.py index 9c24b97dc..119804475 100644 --- a/decompiler/frontend/binaryninja/parser.py +++ b/decompiler/frontend/binaryninja/parser.py @@ -60,7 +60,7 @@ def complex_types(self) -> ComplexTypeMap: def _recover_switch_edge_cases(self, edge: BasicBlockEdge, lookup_table: dict): """ - If edge.target.source_block.start address is not in lookup table, + If edge.target.source_block.start address is not in lookup table, try to recover matching address by inspecting addresses used in edge.target. Return matched case list for edge.target. """ diff --git a/decompiler/frontend/binaryninja/tagging.py b/decompiler/frontend/binaryninja/tagging.py index c30fb3f93..50c2e261a 100644 --- a/decompiler/frontend/binaryninja/tagging.py +++ b/decompiler/frontend/binaryninja/tagging.py @@ -7,6 +7,7 @@ class CompilerIdiomsTagging: """Generates binary view tags for the matched compiler idioms.""" + TAG_SYMBOL = "⚙" TAG_PREFIX = "compiler_idiom: " @@ -33,8 +34,9 @@ def run(self): for match in matches: for address in match.addresses: - self._set_tag(self._bv, tag_name=f"{self.TAG_PREFIX}{match.operation}", address=address, - text=f"{match.operand},{match.constant}") + self._set_tag( + self._bv, tag_name=f"{self.TAG_PREFIX}{match.operation}", address=address, text=f"{match.operand},{match.constant}" + ) @staticmethod def _set_tag(binary_view: BinaryView, tag_name: str, address: int, text: str): diff --git a/decompiler/pipeline/commons/expressionpropagationcommons.py b/decompiler/pipeline/commons/expressionpropagationcommons.py index 7dd06f06f..750ec3af0 100644 --- a/decompiler/pipeline/commons/expressionpropagationcommons.py +++ b/decompiler/pipeline/commons/expressionpropagationcommons.py @@ -172,7 +172,7 @@ def _is_aliased_postponed_for_propagation(self, target: Instruction, definition: We can propagate this in case the variable is used once (in the example used twice). This way we revert insertion of redundant missing definition. If possible, such propagation is done after everything else is propagated. """ - if self._is_aliased_variable(aliased:=definition.destination): + if self._is_aliased_variable(aliased := definition.destination): if self._is_aliased_redefinition(aliased, target): self._postponed_aliased.add(aliased) return True @@ -319,7 +319,6 @@ def _has_any_of_dangerous_uses_between_definition_and_target( raise RuntimeError(f"Same definition {definition} in multiple blocks") definition_block, definition_index = list(definition_block_info)[0] for target_block, target_index in self._blocks_map[str(target)]: - if target_block == definition_block: # then one of dangerous uses should lie between definition and target in the same block for use in dangerous_uses: @@ -494,7 +493,6 @@ def _is_variable_contraction(expression: Expression) -> bool: and expression.operand.complexity == 1 ) - def _is_aliased_redefinition(self, aliased_variable: Variable, instruction: Instruction): """ Given aliased variable check if the instruction is re-definition: diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py index f770ff6b1..07827f126 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py @@ -22,9 +22,7 @@ def constant_fold(operation: OperationType, constants: list[Constant]) -> Consta def _constant_fold_arithmetic_binary( - constants: list[Constant], - fun: Callable[[int, int], int], - norm_sign: Optional[bool] = None + constants: list[Constant], fun: Callable[[int, int], int], norm_sign: Optional[bool] = None ) -> Constant: """ Fold an arithmetic binary operation with constants as operands. @@ -53,10 +51,7 @@ def _constant_fold_arithmetic_binary( left_value = normalize_int(left_value, left.type.size, norm_sign) right_value = normalize_int(right_value, right.type.size, norm_sign) - return Constant( - normalize_int(fun(left_value, right_value), left.type.size, left.type.signed), - left.type - ) + return Constant(normalize_int(fun(left_value, right_value), left.type.size, left.type.signed), left.type) def _constant_fold_arithmetic_unary(constants: list[Constant], fun: Callable[[int], int]) -> Constant: @@ -94,14 +89,8 @@ def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], in left, right = constants - shifted_value = fun( - normalize_int(left.value, left.type.size, left.type.signed and signed), - right.value - ) - return Constant( - normalize_int(shifted_value, left.type.size, left.type.signed), - left.type - ) + shifted_value = fun(normalize_int(left.value, left.type.size, left.type.signed and signed), right.value) + return Constant(normalize_int(shifted_value, left.type.size, left.type.signed), left.type) _OPERATION_TO_FOLD_FUNCTION: dict[OperationType, Callable[[list[Constant]], Constant]] = { diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_add_neg.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_add_neg.py index c978991fc..70cc75408 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_add_neg.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_add_neg.py @@ -20,11 +20,13 @@ def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: if not isinstance(right, UnaryOperation) or right.operation != OperationType.negate: return [] - return [( - operation, - BinaryOperation( - OperationType.minus if operation.operation == OperationType.plus else OperationType.plus, - [operation.left, right.operand], - operation.type + return [ + ( + operation, + BinaryOperation( + OperationType.minus if operation.operation == OperationType.plus else OperationType.plus, + [operation.left, right.operand], + operation.type, + ), ) - )] + ] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_constants.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_constants.py index 03db3c00d..295a0bb03 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_constants.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_constants.py @@ -14,7 +14,4 @@ def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: if operation.operation not in FOLDABLE_OPERATIONS: return [] - return [( - operation, - constant_fold(operation.operation, operation.operands) - )] + return [(operation, constant_fold(operation.operation, operation.operands))] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_nested_constants.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_nested_constants.py index 2f9346a72..cf36ce84b 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_nested_constants.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_nested_constants.py @@ -8,12 +8,14 @@ _COLLAPSIBLE_OPERATIONS = COMMUTATIVE_OPERATIONS & FOLDABLE_OPERATIONS + class CollapseNestedConstants(SimplificationRule): """ This rule walks the dafaflow tree and collects and folds constants in commutative operations. The first constant of the tree is replaced with the folded result and all remaining constants are replaced with the identity. This stage exploits associativity and is the only stage doing so. Therefore, it cannot be replaced by a combination of `TermOrder` and `CollapseConstants`. """ + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: if operation.operation not in _COLLAPSIBLE_OPERATIONS: return [] @@ -26,17 +28,10 @@ def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: first, *rest = constants - folded_constant = reduce( - lambda c0, c1: constant_fold(operation.operation, [c0, c1]), - rest, - first - ) + folded_constant = reduce(lambda c0, c1: constant_fold(operation.operation, [c0, c1]), rest, first) identity_constant = _identity_constant(operation.operation, operation.type) - return [ - (first, folded_constant), - *((constant, identity_constant) for constant in rest) - ] + return [(first, folded_constant), *((constant, identity_constant) for constant in rest)] def _collect_constants(operation: Operation) -> Iterator[Constant]: diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/positive_constants.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/positive_constants.py index 42da06986..96263ec04 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/positive_constants.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/positive_constants.py @@ -30,14 +30,13 @@ def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: if signed_normalized_constant >= 0: return [] - neg_constant = Constant( - normalize_int(-signed_normalized_constant, constant_type.size, constant_type.signed), - constant_type - ) - return [( - operation, - BinaryOperation( - OperationType.plus if operation.operation == OperationType.minus else OperationType.minus, - [operation.left, neg_constant] + neg_constant = Constant(normalize_int(-signed_normalized_constant, constant_type.size, constant_type.signed), constant_type) + return [ + ( + operation, + BinaryOperation( + OperationType.plus if operation.operation == OperationType.minus else OperationType.minus, + [operation.left, neg_constant], + ), ) - )] + ] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_redundant_reference.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_redundant_reference.py index 9337be30b..83a9d0f23 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_redundant_reference.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_redundant_reference.py @@ -12,8 +12,7 @@ class SimplifyRedundantReference(SimplificationRule): def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: match operation: case UnaryOperation( - operation=OperationType.dereference, - operand=UnaryOperation(operation=OperationType.address, operand=inner_operand) + operation=OperationType.dereference, operand=UnaryOperation(operation=OperationType.address, operand=inner_operand) ): return [(operation, inner_operand)] case _: diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_arithmetic.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_arithmetic.py index 07ebac6c2..7e2d30014 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_arithmetic.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_arithmetic.py @@ -31,8 +31,7 @@ def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: case BinaryOperation(operation=OperationType.multiply | OperationType.multiply_us, right=Constant(value=0)): return [(operation, Constant(0, operation.type))] case BinaryOperation( - operation=OperationType.multiply | OperationType.multiply_us | OperationType.divide, - right=Constant(value=-1) + operation=OperationType.multiply | OperationType.multiply_us | OperationType.divide, right=Constant(value=-1) ): return [(operation, UnaryOperation(OperationType.negate, [operation.left]))] case _: diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/sub_to_add.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/sub_to_add.py index 9943d365d..10aadb52d 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/sub_to_add.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/sub_to_add.py @@ -17,11 +17,4 @@ def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: neg_op = UnaryOperation(OperationType.negate, [operation.right]) - return [( - operation, - BinaryOperation( - OperationType.plus, - [operation.left, neg_op], - operation.type - ) - )] + return [(operation, BinaryOperation(OperationType.plus, [operation.left, neg_op], operation.type))] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/stages.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/stages.py index cc05a59ba..bdad4333f 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/stages.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/stages.py @@ -26,7 +26,6 @@ class _ExpressionSimplificationBase(PipelineStage, ABC): - def run(self, task: DecompilerTask): max_iterations = task.options.getint("expression-simplification.max_iterations") self._simplify_instructions(self._get_instructions(task), max_iterations) @@ -37,11 +36,7 @@ def _get_instructions(self, task: DecompilerTask) -> list[Instruction]: @classmethod def _simplify_instructions(cls, instructions: list[Instruction], max_iterations: int): - rule_sets = [ - ("pre-rules", _pre_rules), - ("rules", _rules), - ("post-rules", _post_rules) - ] + rule_sets = [("pre-rules", _pre_rules), ("rules", _rules), ("post-rules", _post_rules)] for rule_name, rule_set in rule_sets: # max_iterations is counted per rule_set iteration_count = cls._simplify_instructions_with_rule_set(instructions, rule_set, max_iterations) @@ -52,10 +47,7 @@ def _simplify_instructions(cls, instructions: list[Instruction], max_iterations: @classmethod def _simplify_instructions_with_rule_set( - cls, - instructions: list[Instruction], - rule_set: list[SimplificationRule], - max_iterations: int + cls, instructions: list[Instruction], rule_set: list[SimplificationRule], max_iterations: int ) -> int: iteration_count = 0 @@ -76,12 +68,7 @@ def _simplify_instructions_with_rule_set( return iteration_count @classmethod - def _simplify_instruction_with_rule( - cls, - instruction: Instruction, - rule: SimplificationRule, - max_iterations: int - ) -> int: + def _simplify_instruction_with_rule(cls, instruction: Instruction, rule: SimplificationRule, max_iterations: int) -> int: iteration_count = 0 for expression in instruction.subexpressions(): while True: @@ -163,7 +150,4 @@ def _get_instructions(self, task: DecompilerTask) -> list[Instruction]: CollapseConstants(), CollapseNestedConstants(), ] -_post_rules: list[SimplificationRule] = [ - CollapseAddNeg(), - PositiveConstants() -] +_post_rules: list[SimplificationRule] = [CollapseAddNeg(), PositiveConstants()] diff --git a/decompiler/pipeline/controlflowanalysis/loop_name_generator.py b/decompiler/pipeline/controlflowanalysis/loop_name_generator.py index 169a111d3..0d9590d6f 100644 --- a/decompiler/pipeline/controlflowanalysis/loop_name_generator.py +++ b/decompiler/pipeline/controlflowanalysis/loop_name_generator.py @@ -118,6 +118,6 @@ def run(self, task: DecompilerTask): if rename_while_loops: WhileLoopVariableRenamer(task._ast).rename() - + if for_loop_names: ForLoopVariableRenamer(task._ast, for_loop_names).rename() diff --git a/decompiler/pipeline/controlflowanalysis/loop_utility_methods.py b/decompiler/pipeline/controlflowanalysis/loop_utility_methods.py index 6a2f7a4bc..f9390fecf 100644 --- a/decompiler/pipeline/controlflowanalysis/loop_utility_methods.py +++ b/decompiler/pipeline/controlflowanalysis/loop_utility_methods.py @@ -16,6 +16,7 @@ class AstInstruction: position: int node: CodeNode + def _is_single_instruction_loop_node(loop_node: LoopNode) -> bool: """ Check if the loop body contains only one instruction. @@ -208,4 +209,4 @@ def _requirement_without_reinitialization(ast: AbstractSyntaxTree, node: Abstrac elif variable in assignment.definitions and variable in assignment.requirements: return True elif variable in assignment.requirements: - return True \ No newline at end of file + return True diff --git a/decompiler/pipeline/controlflowanalysis/readability_based_refinement.py b/decompiler/pipeline/controlflowanalysis/readability_based_refinement.py index 7a334fd44..e0d609e42 100644 --- a/decompiler/pipeline/controlflowanalysis/readability_based_refinement.py +++ b/decompiler/pipeline/controlflowanalysis/readability_based_refinement.py @@ -26,11 +26,11 @@ def _get_potential_guarded_do_while_loops(ast: AbstractSyntaxTree) -> tuple(Unio def remove_guarded_do_while(ast: AbstractSyntaxTree): - """ Removes a if statement which guards a do-while loop/while loop when: - -> there is nothing in between the if-node and the do-while-node/while-node - -> the if-node has only one branch (true branch) - -> the condition of the branch is the same as the condition of the do-while-node - Replacement is a WhileLoop, otherwise the control flow would not be correct + """Removes a if statement which guards a do-while loop/while loop when: + -> there is nothing in between the if-node and the do-while-node/while-node + -> the if-node has only one branch (true branch) + -> the condition of the branch is the same as the condition of the do-while-node + Replacement is a WhileLoop, otherwise the control flow would not be correct """ for do_while_node, condition_node in _get_potential_guarded_do_while_loops(ast): if condition_node.false_branch: @@ -43,14 +43,14 @@ def remove_guarded_do_while(ast: AbstractSyntaxTree): class WhileLoopReplacer: """Convert WhileLoopNodes to ForLoopNodes depending on the configuration. - -> keep_empty_for_loops will keep empty for-loops in the code - -> force_for_loops will transform every while-loop into a for-loop, worst case with empty declaration/modification statement - -> forbidden_condition_types_in_simple_for_loops will not transform trivial for-loop candidates (with only one condition) into for-loops - if the operator matches one of the forbidden operator list - -> max_condition_complexity_for_loop_recovery will transform for-loop candidates only into for-loops if the condition complexity is - less/equal then the threshold - -> max_modification_complexity_for_loop_recovery will transform for-loop candidates only into for-loops if the modification complexity is - less/equal then the threshold + -> keep_empty_for_loops will keep empty for-loops in the code + -> force_for_loops will transform every while-loop into a for-loop, worst case with empty declaration/modification statement + -> forbidden_condition_types_in_simple_for_loops will not transform trivial for-loop candidates (with only one condition) into for-loops + if the operator matches one of the forbidden operator list + -> max_condition_complexity_for_loop_recovery will transform for-loop candidates only into for-loops if the condition complexity is + less/equal then the threshold + -> max_modification_complexity_for_loop_recovery will transform for-loop candidates only into for-loops if the modification complexity is + less/equal then the threshold """ def __init__(self, ast: AbstractSyntaxTree, options: Options): @@ -58,25 +58,34 @@ def __init__(self, ast: AbstractSyntaxTree, options: Options): self._keep_empty_for_loops = options.getboolean("readability-based-refinement.keep_empty_for_loops", fallback=False) self._hide_non_init_decl = options.getboolean("readability-based-refinement.hide_non_initializing_declaration", fallback=False) self._force_for_loops = options.getboolean("readability-based-refinement.force_for_loops", fallback=False) - self._forbidden_condition_types = options.getlist("readability-based-refinement.forbidden_condition_types_in_simple_for_loops", fallback=[]) - self._condition_max_complexity = options.getint("readability-based-refinement.max_condition_complexity_for_loop_recovery", fallback=100) - self._modification_max_complexity = options.getint("readability-based-refinement.max_modification_complexity_for_loop_recovery", fallback=100) + self._forbidden_condition_types = options.getlist( + "readability-based-refinement.forbidden_condition_types_in_simple_for_loops", fallback=[] + ) + self._condition_max_complexity = options.getint( + "readability-based-refinement.max_condition_complexity_for_loop_recovery", fallback=100 + ) + self._modification_max_complexity = options.getint( + "readability-based-refinement.max_modification_complexity_for_loop_recovery", fallback=100 + ) def run(self): """For each WhileLoop in AST check the following conditions: -> any variable in loop condition has a valid continuation instruction in loop body -> variable is initialized - -> loop condition complexity < condition complexity + -> loop condition complexity < condition complexity -> possible modification complexity < modification complexity -> if condition is only a symbol: check condition type for allowed one - - If 'force_for_loops' is enabled, the complexity options are ignored and every while loop after the - initial transformation will be forced into a for loop with an empty declaration/modification + + If 'force_for_loops' is enabled, the complexity options are ignored and every while loop after the + initial transformation will be forced into a for loop with an empty declaration/modification """ for loop_node in list(self._ast.get_while_loop_nodes_topological_order()): - if loop_node.is_endless_loop or (not self._keep_empty_for_loops and _is_single_instruction_loop_node(loop_node)) \ - or self._invalid_simple_for_loop_condition_type(loop_node.condition): + if ( + loop_node.is_endless_loop + or (not self._keep_empty_for_loops and _is_single_instruction_loop_node(loop_node)) + or self._invalid_simple_for_loop_condition_type(loop_node.condition) + ): continue if any(node.does_end_with_continue for node in loop_node.body.get_descendant_code_nodes_interrupting_ancestor_loop()): @@ -100,11 +109,11 @@ def run(self): self._ast.substitute_loop_node( loop_node, ForLoopNode( - declaration=None, - condition=loop_node.condition, - modification=None, - reaching_condition=loop_node.reaching_condition, - ) + declaration=None, + condition=loop_node.condition, + modification=None, + reaching_condition=loop_node.reaching_condition, + ), ) def _replace_with_for_loop(self, loop_node: WhileLoopNode, continuation: AstInstruction, init: AstInstruction): @@ -139,9 +148,9 @@ def _replace_with_for_loop(self, loop_node: WhileLoopNode, continuation: AstInst ) continuation.node.instructions.remove(continuation.instruction) self._ast.clean_up() - + def _invalid_simple_for_loop_condition_type(self, logic_condition) -> bool: - """ Checks if a logic condition is only a symbol, if true checks condition type of symbol for forbidden ones""" + """Checks if a logic condition is only a symbol, if true checks condition type of symbol for forbidden ones""" if not logic_condition.is_symbol or not self._forbidden_condition_types: return False diff --git a/decompiler/pipeline/controlflowanalysis/variable_name_generation.py b/decompiler/pipeline/controlflowanalysis/variable_name_generation.py index 97f838cb7..8627c93da 100644 --- a/decompiler/pipeline/controlflowanalysis/variable_name_generation.py +++ b/decompiler/pipeline/controlflowanalysis/variable_name_generation.py @@ -18,6 +18,7 @@ ==> Therefore we always collect EVERY variable used + check with a method (already_renamed) if we already renamed it to our new naming scheme """ + def _get_var_counter(var_name: str) -> Optional[str]: """Return the counter of a given variable name, if any is present.""" if counter := re.match(r".*?([0-9]+)$", var_name): @@ -64,6 +65,7 @@ def visit_variable(self, expression: Variable): class NamingConvention(str, Enum): """Enum for the currently available naming conventions.""" + default = "default" system_hungarian = "system_hungarian" @@ -76,18 +78,20 @@ def __init__(self, task: DecompilerTask) -> None: collector = VariableCollector(task._ast.condition_map) collector.visit_ast(task._ast) self._params: List[Variable] = task._function_parameters - self._loop_vars : List[Variable] = collector.get_loop_variables() + self._loop_vars: List[Variable] = collector.get_loop_variables() self._variables: List[Variable] = list(filter(self._filter_variables, collector.get_variables())) - def _filter_variables(self, item: Variable) -> bool: """Return False if variable is either a: - - parameter - - renamed loop variable - - GlobalVariable + - parameter + - renamed loop variable + - GlobalVariable """ - return not item in self._params and not (item in self._loop_vars and item.name.find("var_") == -1) and not isinstance(item, GlobalVariable) - + return ( + not item in self._params + and not (item in self._loop_vars and item.name.find("var_") == -1) + and not isinstance(item, GlobalVariable) + ) @abstractmethod def renameVariableNames(self): @@ -103,10 +107,7 @@ class HungarianScheme(RenamingScheme): Integer: {8: "ch", 16: "s", 32: "i", 64: "l", 128: "i128"}, } - custom_var_names = { - "tmp_": "Tmp", - "loop_break": "LoopBreak" - } + custom_var_names = {"tmp_": "Tmp", "loop_break": "LoopBreak"} def __init__(self, task: DecompilerTask) -> None: super().__init__(task) @@ -115,7 +116,6 @@ def __init__(self, task: DecompilerTask) -> None: self._pointer_base: bool = task.options.getboolean(f"{self._name}.pointer_base", fallback=True) self._type_separator: str = task.options.getstring(f"{self._name}.type_separator", fallback="") self._counter_separator: str = task.options.getstring(f"{self._name}.counter_separator", fallback="") - def renameVariableNames(self): """Rename all collected variables to the hungarian notation.""" @@ -124,13 +124,11 @@ def renameVariableNames(self): continue counter = _get_var_counter(var.name) var._name = self._hungarian_notation(var, counter if counter else "") - def _hungarian_notation(self, var: Variable, counter: int) -> str: """Return hungarian notation to a given variable.""" return f"{self._hungarian_prefix(var.type)}{self._type_separator}{self.custom_var_names.get(var._name.rstrip(counter), self._var_name)}{self._counter_separator}{counter}" - def _hungarian_prefix(self, var_type: Type) -> str: """Return hungarian prefix to a given variable type.""" if isinstance(var_type, Pointer): @@ -150,36 +148,34 @@ def _hungarian_prefix(self, var_type: Type) -> str: return f"{sign}{prefix}" return "" - - def alread_renamed(self, name) -> bool: + def alread_renamed(self, name) -> bool: """Return true if variable with custom name was already renamed, false otherwise""" renamed_keys_words = [key for key in self.custom_var_names.values()] + ["unk", self._var_name] return any(keyword in name for keyword in renamed_keys_words) + class DefaultScheme(RenamingScheme): """Class which renames variables into the default scheme.""" def __init__(self, task: DecompilerTask) -> None: super().__init__(task) - def renameVariableNames(self): # Maybe make the suboptions more generic, so that the default scheme can also be changed by some parameters? pass class VariableNameGeneration(PipelineStage): - """ + """ Pipelinestage in charge of renaming variables to a configured format. Currently only the 'default' or 'hungarian' system are supported. """ - name : str = "variable-name-generation" + name: str = "variable-name-generation" def __init__(self): self._notation: str = None - def run(self, task: DecompilerTask): """Rename variable names to the given scheme.""" self._notation = task.options.getstring(f"{self.name}.notation", fallback="default") diff --git a/decompiler/pipeline/dataflowanalysis/expressionpropagation.py b/decompiler/pipeline/dataflowanalysis/expressionpropagation.py index 07691b354..3e7df7e9c 100644 --- a/decompiler/pipeline/dataflowanalysis/expressionpropagation.py +++ b/decompiler/pipeline/dataflowanalysis/expressionpropagation.py @@ -26,14 +26,14 @@ def _definition_can_be_propagated_into_target(self, definition: Assignment, targ :return: true if propagation is allowed false otherwise """ return isinstance(definition, Assignment) and not ( - self._is_phi(definition) - or self._is_call_assignment(definition) - or self._defines_unknown_expression(definition) - or self._contains_aliased_variables(definition) - or self._is_address_assignment(definition) - or self._contains_global_variable(definition) - or self._operation_is_propagated_in_phi(target, definition) - or self._resulting_instruction_is_too_long(target, definition) - or self._is_invalid_propagation_into_address_operation(target, definition) - or self._is_dereference_assignment(definition) + self._is_phi(definition) + or self._is_call_assignment(definition) + or self._defines_unknown_expression(definition) + or self._contains_aliased_variables(definition) + or self._is_address_assignment(definition) + or self._contains_global_variable(definition) + or self._operation_is_propagated_in_phi(target, definition) + or self._resulting_instruction_is_too_long(target, definition) + or self._is_invalid_propagation_into_address_operation(target, definition) + or self._is_dereference_assignment(definition) ) diff --git a/decompiler/pipeline/dataflowanalysis/expressionpropagationfunctioncall.py b/decompiler/pipeline/dataflowanalysis/expressionpropagationfunctioncall.py index 4efd7cb3c..99817ee90 100644 --- a/decompiler/pipeline/dataflowanalysis/expressionpropagationfunctioncall.py +++ b/decompiler/pipeline/dataflowanalysis/expressionpropagationfunctioncall.py @@ -85,19 +85,19 @@ def _definition_can_be_propagated_into_target(self, definition: Assignment, targ self._is_call_assignment(definition) and self._is_call_value_used_exactly_once(definition) and not ( - self._is_phi(definition) - or self._defines_unknown_expression(definition) - or self._contains_aliased_variables(definition) - or self._is_address_assignment(definition) - or self._contains_global_variable(definition) - or self._operation_is_propagated_in_phi(target, definition) - or self._resulting_instruction_is_too_long(target, definition) - or self._is_invalid_propagation_into_address_operation(target, definition) - or self._is_dereference_assignment(definition) - or self._definition_value_could_be_modified_via_memory_access_between_definition_and_target(definition, target) - or self._pointer_value_used_in_definition_could_be_modified_via_memory_access_between_definition_and_target( - definition, target - ) + self._is_phi(definition) + or self._defines_unknown_expression(definition) + or self._contains_aliased_variables(definition) + or self._is_address_assignment(definition) + or self._contains_global_variable(definition) + or self._operation_is_propagated_in_phi(target, definition) + or self._resulting_instruction_is_too_long(target, definition) + or self._is_invalid_propagation_into_address_operation(target, definition) + or self._is_dereference_assignment(definition) + or self._definition_value_could_be_modified_via_memory_access_between_definition_and_target(definition, target) + or self._pointer_value_used_in_definition_could_be_modified_via_memory_access_between_definition_and_target( + definition, target + ) ) ) diff --git a/decompiler/pipeline/default.py b/decompiler/pipeline/default.py index c33a6c9ff..7627b3f3c 100644 --- a/decompiler/pipeline/default.py +++ b/decompiler/pipeline/default.py @@ -44,10 +44,4 @@ EdgePruner, ] -AST_STAGES = [ - ReadabilityBasedRefinement, - ExpressionSimplificationAst, - InstructionLengthHandler, - VariableNameGeneration, - LoopNameGenerator -] +AST_STAGES = [ReadabilityBasedRefinement, ExpressionSimplificationAst, InstructionLengthHandler, VariableNameGeneration, LoopNameGenerator] diff --git a/decompiler/pipeline/expressions/bitfieldcomparisonunrolling.py b/decompiler/pipeline/expressions/bitfieldcomparisonunrolling.py index 51fb37814..bb7326d45 100644 --- a/decompiler/pipeline/expressions/bitfieldcomparisonunrolling.py +++ b/decompiler/pipeline/expressions/bitfieldcomparisonunrolling.py @@ -18,6 +18,7 @@ class FoldedCase: """ Class for storing information of folded case. """ + basic_block: BasicBlock switch_variable: Expression case_values: List[int] diff --git a/decompiler/pipeline/preprocessing/remove_stack_canary.py b/decompiler/pipeline/preprocessing/remove_stack_canary.py index 14b2f2ed2..0a37977f2 100644 --- a/decompiler/pipeline/preprocessing/remove_stack_canary.py +++ b/decompiler/pipeline/preprocessing/remove_stack_canary.py @@ -21,7 +21,7 @@ def run(self, task: DecompilerTask): if task.options.getboolean(f"{self.name}.remove_canary", fallback=False) and task.name != self.STACK_FAIL_STR: self._cfg = task.graph if len(self._cfg) == 1: - return # do not remove the only node + return # do not remove the only node for fail_node in list(self._contains_stack_check_fail()): self._patch_canary(fail_node) diff --git a/decompiler/pipeline/preprocessing/switch_variable_detection.py b/decompiler/pipeline/preprocessing/switch_variable_detection.py index e459c4b14..c1ac78992 100644 --- a/decompiler/pipeline/preprocessing/switch_variable_detection.py +++ b/decompiler/pipeline/preprocessing/switch_variable_detection.py @@ -123,7 +123,7 @@ def _is_used_in_condition_assignment(self, value: Variable): for usage in self._use_map.get(value): if isinstance(usage, Assignment) and isinstance(usage.value, Condition) and usage.requirements == [value]: return True - return False + return False def _is_used_in_branch(self, value: Variable): """ diff --git a/decompiler/pipeline/ssa/sreedhar_out_of_ssa.py b/decompiler/pipeline/ssa/sreedhar_out_of_ssa.py index bca8a8b21..d0b0d719f 100644 --- a/decompiler/pipeline/ssa/sreedhar_out_of_ssa.py +++ b/decompiler/pipeline/ssa/sreedhar_out_of_ssa.py @@ -136,7 +136,6 @@ def _insert_copy(self, x, instr): self._stmt_block_map[instr] = current_block else: - xnew = expressions.Var(x.name, x.type) self._copy_counter += 1 xnew_copy = expressions.Assignment(x, xnew) diff --git a/decompiler/structures/logic/custom_logic.py b/decompiler/structures/logic/custom_logic.py index 26e5716a2..8e766358f 100644 --- a/decompiler/structures/logic/custom_logic.py +++ b/decompiler/structures/logic/custom_logic.py @@ -307,7 +307,9 @@ def remove_redundancy(self, condition_handler: ConditionHandler) -> LOGICCLASS: return self @classmethod - def get_logic_condition(cls, real_condition: PseudoCustomLogicCondition, condition_handler: ConditionHandler) -> Optional[CustomLogicCondition]: + def get_logic_condition( + cls, real_condition: PseudoCustomLogicCondition, condition_handler: ConditionHandler + ) -> Optional[CustomLogicCondition]: """Generate a symbol condition given the real-condition together with the condition handler.""" context = real_condition.context non_logic_operands = { @@ -318,7 +320,7 @@ def get_logic_condition(cls, real_condition: PseudoCustomLogicCondition, conditi replacement_dict = dict() expression_of_variables: Dict[Variable, pseudo.Expression] = dict() for symbol in condition_handler: - replacement_dict[condition_handler.get_z3_condition_of(symbol)._condition] =symbol._condition + replacement_dict[condition_handler.get_z3_condition_of(symbol)._condition] = symbol._condition for operand in [op for op in condition_handler.get_condition_of(symbol) if not isinstance(op, pseudo.Constant)]: expression_of_variables[context.variable(real_condition._variable_name_for(operand))] = operand diff --git a/decompiler/structures/pseudo/expressions.py b/decompiler/structures/pseudo/expressions.py index 5c59afaf2..0b849628c 100644 --- a/decompiler/structures/pseudo/expressions.py +++ b/decompiler/structures/pseudo/expressions.py @@ -257,12 +257,13 @@ def copy(self) -> ExternFunctionPointer: class NotUseableConstant(Constant): """Represents a non useable constant like 'inf' or 'NaN' as a string""" + def __init__(self, value: str, tags: Optional[Tuple[Tag, ...]] = None): super().__init__(value, CustomType("double", 0), tags=tags) def __str__(self) -> str: """Return a string because NotUseableConstant are string only""" - return self.value + return self.value def __repr__(self): """Return the non usable constant.""" diff --git a/decompiler/structures/pseudo/operations.py b/decompiler/structures/pseudo/operations.py index c21002068..7f5b6a37a 100644 --- a/decompiler/structures/pseudo/operations.py +++ b/decompiler/structures/pseudo/operations.py @@ -160,7 +160,7 @@ class OperationType(Enum): OperationType.logical_or, OperationType.logical_and, OperationType.equal, - OperationType.not_equal + OperationType.not_equal, } NON_COMPOUNDABLE_OPERATIONS = { @@ -172,7 +172,7 @@ class OperationType(Enum): OperationType.logical_or, OperationType.logical_and, OperationType.equal, - OperationType.not_equal + OperationType.not_equal, } diff --git a/decompiler/structures/pseudo/typing.py b/decompiler/structures/pseudo/typing.py index 8f367da62..d679023c9 100644 --- a/decompiler/structures/pseudo/typing.py +++ b/decompiler/structures/pseudo/typing.py @@ -191,7 +191,7 @@ def __init__(self, text: str, size: int): @classmethod def bool(cls) -> CustomType: """Return a boolean type representing either TRUE or FALSE.""" - return cls("bool", 8) # BN bool has size 8 + return cls("bool", 8) # BN bool has size 8 @classmethod def void(cls) -> CustomType: diff --git a/decompiler/structures/visitors/substitute_visitor.py b/decompiler/structures/visitors/substitute_visitor.py index b4d7646e2..4e3928dc4 100644 --- a/decompiler/structures/visitors/substitute_visitor.py +++ b/decompiler/structures/visitors/substitute_visitor.py @@ -152,10 +152,7 @@ def visit_binary_operation(self, op: BinaryOperation) -> Optional[DataflowObject def visit_call(self, op: Call) -> Optional[DataflowObject]: if (function_replacement := op.function.accept(self)) is not None: - op._function = _assert_type( - function_replacement, - Union[FunctionSymbol, ImportedFunctionSymbol, IntrinsicSymbol, Variable] - ) + op._function = _assert_type(function_replacement, Union[FunctionSymbol, ImportedFunctionSymbol, IntrinsicSymbol, Variable]) return self._visit_operation(op) diff --git a/decompiler/task.py b/decompiler/task.py index d476e71ec..f46a28777 100644 --- a/decompiler/task.py +++ b/decompiler/task.py @@ -20,7 +20,7 @@ def __init__( options: Optional[Options] = None, function_return_type: Type = Integer(32), function_parameters: Optional[List[Variable]] = None, - complex_types: Optional[ComplexTypeMap] = None + complex_types: Optional[ComplexTypeMap] = None, ): """ Init a new decompiler task. @@ -109,4 +109,4 @@ def code(self) -> str: @code.setter def code(self, value): """Setter function for C-Code representation of the Task""" - self._code = value \ No newline at end of file + self._code = value diff --git a/decompiler/util/integer_util.py b/decompiler/util/integer_util.py index 1e96f62bf..3c641abad 100644 --- a/decompiler/util/integer_util.py +++ b/decompiler/util/integer_util.py @@ -16,4 +16,4 @@ def normalize_int(v: int, size: int, signed: bool) -> int: if signed and value & (1 << (size - 1)): return value - (1 << size) else: - return value \ No newline at end of file + return value diff --git a/tests/backend/test_codegenerator.py b/tests/backend/test_codegenerator.py index 47726e32d..caaade2c9 100644 --- a/tests/backend/test_codegenerator.py +++ b/tests/backend/test_codegenerator.py @@ -167,7 +167,8 @@ def test_empty_function_two_function_parameters(self): code_node = ast._add_code_node([]) ast._add_edge(root, code_node) assert self._regex_matches( - r"^\s*int +test_function\(\s*int +\(\*\s*p\)\(int\)\s*,\s*int +\(\*\s*p0\)\(int\)\s*\){\s*}\s*$", self._task(ast, params=[var_fun_p.copy(), var_fun_p0.copy()]) + r"^\s*int +test_function\(\s*int +\(\*\s*p\)\(int\)\s*,\s*int +\(\*\s*p0\)\(int\)\s*\){\s*}\s*$", + self._task(ast, params=[var_fun_p.copy(), var_fun_p0.copy()]), ) def test_function_with_instruction(self): @@ -465,7 +466,6 @@ def test_branch_condition(self, context, condition: LogicCondition, condition_ma regex = r"^%int +test_function\(\)%{(?s).*if%\(%COND_STR%\)%{%return%0%;%}%}%$" assert self._regex_matches(regex.replace("COND_STR", expected).replace("%", "\\s*"), self._task(ast)) - def test_loop_declaration_ListOp(self): """ a = 5; @@ -483,7 +483,11 @@ def test_loop_declaration_ListOp(self): Assignment(Variable("a"), Constant(5)), ] ) - loop_node = ast.factory.create_for_loop_node(Assignment(ListOperation([Variable("b")]), Call(ImportedFunctionSymbol("foo", 0), [])), logic_cond("x1", context), Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))) + loop_node = ast.factory.create_for_loop_node( + Assignment(ListOperation([Variable("b")]), Call(ImportedFunctionSymbol("foo", 0), [])), + logic_cond("x1", context), + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), + ) loop_node_body = ast._add_code_node( [ Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("1")])), diff --git a/tests/frontend/test_parser.py b/tests/frontend/test_parser.py index 678e4f326..12b7656dc 100644 --- a/tests/frontend/test_parser.py +++ b/tests/frontend/test_parser.py @@ -161,6 +161,7 @@ def __init__(self, address: int): self.dest.constant = address self.dest.function = MockFunction([]) # need .function.view to lift + class MockTailcall(Mock): """Mock object representing a constant jump.""" @@ -315,6 +316,7 @@ def test_convert_indirect_edge_to_unconditional_no_valid_edge(parser): assert not isinstance(cfg_edge, UnconditionalEdge) assert len(list(cfg.instructions)) == 1 + def test_tailcall_address_recovery(parser): """ Address of edge.target.source_block.start is not in lookup table. diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_positive_constants.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_positive_constants.py index 8142753f9..e8e2be53e 100644 --- a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_positive_constants.py +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_positive_constants.py @@ -19,7 +19,6 @@ ), (BinaryOperation(OperationType.plus, [var_x_i, (Constant(3, Integer.int32_t()))]), []), (BinaryOperation(OperationType.minus, [var_x_i, (Constant(3, Integer.int32_t()))]), []), - ( BinaryOperation(OperationType.minus, [var_x_u, (Constant(4294967293, Integer.uint32_t()))]), [BinaryOperation(OperationType.plus, [var_x_u, Constant(3, Integer.uint32_t())])], diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py b/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py index 24f88ed93..40efe858b 100644 --- a/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py +++ b/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py @@ -21,10 +21,7 @@ def _c_float(value: float) -> Constant: return Constant(value, Float.float()) -@pytest.mark.parametrize( - ["operation"], - [(operation,) for operation in OperationType if operation not in FOLDABLE_OPERATIONS] -) +@pytest.mark.parametrize(["operation"], [(operation,) for operation in OperationType if operation not in FOLDABLE_OPERATIONS]) def test_constant_fold_invalid_operations(operation: OperationType): with pytest.raises(ValueError): constant_fold(operation, []) @@ -40,7 +37,6 @@ def test_constant_fold_invalid_operations(operation: OperationType): (OperationType.plus, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), (OperationType.plus, [_c_i32(3)], None, pytest.raises(ValueError)), (OperationType.plus, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.minus, [_c_i32(3), _c_i32(4)], _c_i32(-1), nullcontext()), (OperationType.minus, [_c_i32(-2147483648), _c_i32(1)], _c_i32(2147483647), nullcontext()), (OperationType.minus, [_c_u32(3), _c_u32(4)], _c_u32(4294967295), nullcontext()), @@ -48,7 +44,6 @@ def test_constant_fold_invalid_operations(operation: OperationType): (OperationType.minus, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), (OperationType.minus, [_c_i32(3)], None, pytest.raises(ValueError)), (OperationType.minus, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.multiply, [_c_i32(3), _c_i32(4)], _c_i32(12), nullcontext()), (OperationType.multiply, [_c_i32(-1073741824), _c_i32(2)], _c_i32(-2147483648), nullcontext()), (OperationType.multiply, [_c_u32(3221225472), _c_u32(2)], _c_u32(2147483648), nullcontext()), @@ -56,7 +51,6 @@ def test_constant_fold_invalid_operations(operation: OperationType): (OperationType.multiply, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), (OperationType.multiply, [_c_i32(3)], None, pytest.raises(ValueError)), (OperationType.multiply, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.multiply_us, [_c_i32(3), _c_i32(4)], _c_i32(12), nullcontext()), (OperationType.multiply_us, [_c_i32(-1073741824), _c_i32(2)], _c_i32(-2147483648), nullcontext()), (OperationType.multiply_us, [_c_u32(3221225472), _c_u32(2)], _c_u32(2147483648), nullcontext()), @@ -64,44 +58,37 @@ def test_constant_fold_invalid_operations(operation: OperationType): (OperationType.multiply_us, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), (OperationType.multiply_us, [_c_i32(3)], None, pytest.raises(ValueError)), (OperationType.multiply_us, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.divide, [_c_i32(12), _c_i32(4)], _c_i32(3), nullcontext()), (OperationType.divide, [_c_i32(-2147483648), _c_i32(2)], _c_i32(-1073741824), nullcontext()), (OperationType.divide, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), (OperationType.divide, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), (OperationType.divide, [_c_i32(3)], None, pytest.raises(ValueError)), (OperationType.divide, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.divide_us, [_c_i32(12), _c_i32(4)], _c_i32(3), nullcontext()), (OperationType.divide_us, [_c_i32(-2147483648), _c_i32(2)], _c_i32(1073741824), nullcontext()), (OperationType.divide_us, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), (OperationType.divide_us, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), (OperationType.divide_us, [_c_i32(3)], None, pytest.raises(ValueError)), (OperationType.divide_us, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.negate, [_c_i32(3)], _c_i32(-3), nullcontext()), (OperationType.negate, [_c_i32(-2147483648)], _c_i32(-2147483648), nullcontext()), (OperationType.negate, [], None, pytest.raises(ValueError)), (OperationType.negate, [_c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.left_shift, [_c_i32(3), _c_i32(4)], _c_i32(48), nullcontext()), (OperationType.left_shift, [_c_i32(1073741824), _c_i32(1)], _c_i32(-2147483648), nullcontext()), (OperationType.left_shift, [_c_u32(1073741824), _c_u32(1)], _c_u32(2147483648), nullcontext()), (OperationType.left_shift, [_c_i32(3)], None, pytest.raises(ValueError)), (OperationType.left_shift, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.right_shift, [_c_i32(32), _c_i32(4)], _c_i32(2), nullcontext()), (OperationType.right_shift, [_c_i32(-2147483648), _c_i32(1)], _c_i32(-1073741824), nullcontext()), (OperationType.right_shift, [_c_u32(2147483648), _c_u32(1)], _c_u32(1073741824), nullcontext()), (OperationType.right_shift, [_c_i32(3)], None, pytest.raises(ValueError)), (OperationType.right_shift, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.right_shift_us, [_c_i32(32), _c_i32(4)], _c_i32(2), nullcontext()), (OperationType.right_shift_us, [_c_i32(-2147483648), _c_i32(1)], _c_i32(1073741824), nullcontext()), (OperationType.right_shift_us, [_c_u32(2147483648), _c_u32(1)], _c_u32(1073741824), nullcontext()), (OperationType.right_shift_us, [_c_i32(3)], None, pytest.raises(ValueError)), (OperationType.right_shift_us, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.bitwise_or, [_c_i32(85), _c_i32(34)], _c_i32(119), nullcontext()), (OperationType.bitwise_or, [_c_i32(-2147483648), _c_i32(1)], _c_i32(-2147483647), nullcontext()), (OperationType.bitwise_or, [_c_u32(2147483648), _c_u32(1)], _c_u32(2147483649), nullcontext()), @@ -109,7 +96,6 @@ def test_constant_fold_invalid_operations(operation: OperationType): (OperationType.bitwise_or, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), (OperationType.bitwise_or, [_c_i32(3)], None, pytest.raises(ValueError)), (OperationType.bitwise_or, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.bitwise_and, [_c_i32(85), _c_i32(51)], _c_i32(17), nullcontext()), (OperationType.bitwise_and, [_c_i32(-2147483647), _c_i32(3)], _c_i32(1), nullcontext()), (OperationType.bitwise_and, [_c_u32(2147483649), _c_u32(3)], _c_u32(1), nullcontext()), @@ -117,7 +103,6 @@ def test_constant_fold_invalid_operations(operation: OperationType): (OperationType.bitwise_and, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), (OperationType.bitwise_and, [_c_i32(3)], None, pytest.raises(ValueError)), (OperationType.bitwise_and, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.bitwise_xor, [_c_i32(85), _c_i32(51)], _c_i32(102), nullcontext()), (OperationType.bitwise_xor, [_c_i32(-2147483647), _c_i32(-2147483646)], _c_i32(3), nullcontext()), (OperationType.bitwise_xor, [_c_u32(2147483649), _c_u32(2147483650)], _c_u32(3), nullcontext()), @@ -125,13 +110,12 @@ def test_constant_fold_invalid_operations(operation: OperationType): (OperationType.bitwise_xor, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), (OperationType.bitwise_xor, [_c_i32(3)], None, pytest.raises(ValueError)), (OperationType.bitwise_xor, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.bitwise_not, [_c_i32(6)], _c_i32(-7), nullcontext()), (OperationType.bitwise_not, [_c_i32(-2147483648)], _c_i32(2147483647), nullcontext()), (OperationType.bitwise_not, [_c_u32(2147483648)], _c_u32(2147483647), nullcontext()), (OperationType.bitwise_not, [], None, pytest.raises(ValueError)), (OperationType.bitwise_not, [_c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - ] + ], ) def test_constant_fold(operation: OperationType, constants: list[Constant], result: Constant, context): with context: diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/test_stage.py b/tests/pipeline/controlflowanalysis/expression_simplification/test_stage.py index a9b74833c..449f1428a 100644 --- a/tests/pipeline/controlflowanalysis/expression_simplification/test_stage.py +++ b/tests/pipeline/controlflowanalysis/expression_simplification/test_stage.py @@ -22,7 +22,7 @@ class _RedundantChanges(SimplificationRule): def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: return [(operation, operation)] - + class _NoChanges(SimplificationRule): def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: @@ -48,62 +48,29 @@ def _v_i32(name: str) -> Variable: @pytest.mark.parametrize( ["rule_set", "instruction", "expected_result"], [ - ( - [TermOrder()], - Assignment(_v_i32("a"), _add(_c_i32(1), _v_i32("b"))), - Assignment(_v_i32("a"), _add(_v_i32("b"), _c_i32(1))) - ), - ( - [CollapseConstants()], - Assignment(_v_i32("a"), _sub(_c_i32(10), _add(_c_i32(3), _c_i32(2)))), - Assignment(_v_i32("a"), _c_i32(5)) - ), + ([TermOrder()], Assignment(_v_i32("a"), _add(_c_i32(1), _v_i32("b"))), Assignment(_v_i32("a"), _add(_v_i32("b"), _c_i32(1)))), + ([CollapseConstants()], Assignment(_v_i32("a"), _sub(_c_i32(10), _add(_c_i32(3), _c_i32(2)))), Assignment(_v_i32("a"), _c_i32(5))), ( [SubToAdd(), SimplifyTrivialArithmetic(), CollapseConstants(), CollapseNestedConstants()], Assignment(_v_i32("a"), _sub(_add(_v_i32("a"), _c_i32(5)), _c_i32(5))), - Assignment(_v_i32("a"), _v_i32("a")) + Assignment(_v_i32("a"), _v_i32("a")), ), - ] + ], ) -def test_simplify_instructions_with_rule_set( - rule_set: list[SimplificationRule], - instruction: Instruction, - expected_result: Instruction -): - _ExpressionSimplificationBase._simplify_instructions_with_rule_set( - [instruction], - rule_set, - 100 - ) +def test_simplify_instructions_with_rule_set(rule_set: list[SimplificationRule], instruction: Instruction, expected_result: Instruction): + _ExpressionSimplificationBase._simplify_instructions_with_rule_set([instruction], rule_set, 100) assert instruction == expected_result @pytest.mark.parametrize( ["rule_set", "instruction", "max_iterations", "expect_exceed_max_iterations"], [ - ( - [_RedundantChanges()], - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Constant(1), Variable("b")])), - 10, - True - ), - ( - [_NoChanges()], - Assignment(_v_i32("a"), _v_i32("b")), - 0, - False - ) - ] + ([_RedundantChanges()], Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Constant(1), Variable("b")])), 10, True), + ([_NoChanges()], Assignment(_v_i32("a"), _v_i32("b")), 0, False), + ], ) def test_simplify_instructions_with_rule_set_max_iterations( - rule_set: list[SimplificationRule], - instruction: Instruction, - max_iterations: int, - expect_exceed_max_iterations: bool + rule_set: list[SimplificationRule], instruction: Instruction, max_iterations: int, expect_exceed_max_iterations: bool ): - iterations = _ExpressionSimplificationBase._simplify_instructions_with_rule_set( - [instruction], - rule_set, - max_iterations - ) + iterations = _ExpressionSimplificationBase._simplify_instructions_with_rule_set([instruction], rule_set, max_iterations) assert (iterations > max_iterations) == expect_exceed_max_iterations diff --git a/tests/pipeline/controlflowanalysis/test_loop_name_generator.py b/tests/pipeline/controlflowanalysis/test_loop_name_generator.py index 8a1a1aeec..45d13033c 100644 --- a/tests/pipeline/controlflowanalysis/test_loop_name_generator.py +++ b/tests/pipeline/controlflowanalysis/test_loop_name_generator.py @@ -25,9 +25,11 @@ # Test For/WhileLoop Renamer + def logic_cond(name: str, context) -> LogicCondition: return LogicCondition.initialize_symbol(name, context) + @pytest.fixture def ast_call_for_loop() -> AbstractSyntaxTree: """ @@ -46,7 +48,11 @@ def ast_call_for_loop() -> AbstractSyntaxTree: Assignment(Variable("a"), Constant(5)), ] ) - loop_node = ast.factory.create_for_loop_node(Assignment(ListOperation([Variable("b")]), Call(ImportedFunctionSymbol("foo", 0), [])), logic_cond("x1", context), Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))) + loop_node = ast.factory.create_for_loop_node( + Assignment(ListOperation([Variable("b")]), Call(ImportedFunctionSymbol("foo", 0), [])), + logic_cond("x1", context), + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), + ) loop_node_body = ast._add_code_node( [ Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("1")])), @@ -65,12 +71,12 @@ def test_declaration_listop(ast_call_for_loop): for node in ast_call_for_loop: if isinstance(node, ForLoopNode): assert node.declaration.destination.operands[0].name == "i" - + def test_for_loop_variable_generation(): renamer = ForLoopVariableRenamer( AbstractSyntaxTree(SeqNode(LogicCondition.initialize_true(LogicCondition.generate_new_context())), {}), - ["i", "j", "k", "l", "m", "n"] + ["i", "j", "k", "l", "m", "n"], ) assert [renamer._get_variable_name() for _ in range(14)] == [ "i", @@ -96,11 +102,20 @@ def test_while_loop_variable_generation(): ) assert [renamer._get_variable_name() for _ in range(5)] == ["counter", "counter1", "counter2", "counter3", "counter4"] + # Test Readabilitybasedrefinement + LoopNameGenerator together -def _generate_options(empty_loops: bool = False, hide_decl: bool = False, rename_for: bool = True, rename_while: bool = True, \ - max_condition: int = 100, max_modification: int = 100, force_for_loops: bool = False, blacklist : List[str] = []) -> Options: +def _generate_options( + empty_loops: bool = False, + hide_decl: bool = False, + rename_for: bool = True, + rename_while: bool = True, + max_condition: int = 100, + max_modification: int = 100, + force_for_loops: bool = False, + blacklist: List[str] = [], +) -> Options: options = Options() options.set("readability-based-refinement.keep_empty_for_loops", empty_loops) options.set("readability-based-refinement.hide_non_initializing_declaration", hide_decl) @@ -957,7 +972,6 @@ def run_rbr(ast: AbstractSyntaxTree, options: Options = _generate_options()): ReadabilityBasedRefinement().run(task) LoopNameGenerator().run(task) - def test_no_replacement(self, ast_while_true): self.run_rbr(ast_while_true) assert all(not isinstance(node, ForLoopNode) for node in ast_while_true.topological_order()) diff --git a/tests/pipeline/controlflowanalysis/test_readability_based_refinement.py b/tests/pipeline/controlflowanalysis/test_readability_based_refinement.py index 15a602a1b..3616f3dde 100644 --- a/tests/pipeline/controlflowanalysis/test_readability_based_refinement.py +++ b/tests/pipeline/controlflowanalysis/test_readability_based_refinement.py @@ -32,8 +32,14 @@ def logic_cond(name: str, context) -> LogicCondition: return LogicCondition.initialize_symbol(name, context) -def _generate_options(empty_loops: bool = False, hide_decl: bool = False, max_condition: int = 100, max_modification: int = 100, \ - force_for_loops: bool = False, blacklist : List[str] = []) -> Options: +def _generate_options( + empty_loops: bool = False, + hide_decl: bool = False, + max_condition: int = 100, + max_modification: int = 100, + force_for_loops: bool = False, + blacklist: List[str] = [], +) -> Options: options = Options() options.set("readability-based-refinement.keep_empty_for_loops", empty_loops) options.set("readability-based-refinement.hide_non_initializing_declaration", hide_decl) @@ -75,14 +81,19 @@ def ast_innerWhile_simple_condition_complexity() -> AbstractSyntaxTree: outer_while = ast.factory.create_while_loop_node(logic_cond("x1", context)) outer_while_body = ast.factory.create_seq_node() - outer_while_init = ast._add_code_node([Assignment(Variable("b"), Constant(0)), Assignment(Variable("c"), Constant(0)) - , Assignment(Variable("d"), Constant(0))]) + outer_while_init = ast._add_code_node( + [Assignment(Variable("b"), Constant(0)), Assignment(Variable("c"), Constant(0)), Assignment(Variable("d"), Constant(0))] + ) outer_while_exit = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)]))]) inner_while = ast.factory.create_while_loop_node(logic_cond("x2", context) & logic_cond("x3", context) & logic_cond("x4", context)) - inner_while_body = ast._add_code_node([Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), - Assignment(Variable("c"), BinaryOperation(OperationType.plus, [Variable("c"), Constant(1)])), - Assignment(Variable("d"), BinaryOperation(OperationType.plus, [Variable("d"), Constant(1)]))]) + inner_while_body = ast._add_code_node( + [ + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), + Assignment(Variable("c"), BinaryOperation(OperationType.plus, [Variable("c"), Constant(1)])), + Assignment(Variable("d"), BinaryOperation(OperationType.plus, [Variable("d"), Constant(1)])), + ] + ) ast._add_nodes_from((outer_while, outer_while_body, inner_while)) ast._add_edges_from( @@ -99,7 +110,7 @@ def ast_innerWhile_simple_condition_complexity() -> AbstractSyntaxTree: return ast -def generate_ast_with_modification_complexity(complexity : int) -> AbstractSyntaxTree: +def generate_ast_with_modification_complexity(complexity: int) -> AbstractSyntaxTree: """ a = 0; while (a < 10) { @@ -121,7 +132,7 @@ def generate_ast_with_modification_complexity(complexity : int) -> AbstractSynta return ast -def generate_ast_with_condition_type(op : OperationType) -> AbstractSyntaxTree: +def generate_ast_with_condition_type(op: OperationType) -> AbstractSyntaxTree: """ a = 0; while (a 10) { @@ -161,9 +172,17 @@ def ast_guarded_do_while_if() -> AbstractSyntaxTree: ast._add_node(cond_node) ast._add_node(true_branch) ast._add_node(do_while_loop) - ast._add_edges_from([(root, init_code_node), (root, cond_node), (cond_node, true_branch), (true_branch, do_while_loop), (do_while_loop, do_while_loop_body)]) + ast._add_edges_from( + [ + (root, init_code_node), + (root, cond_node), + (cond_node, true_branch), + (true_branch, do_while_loop), + (do_while_loop, do_while_loop_body), + ] + ) return ast - + @pytest.fixture def ast_guarded_do_while_else() -> AbstractSyntaxTree: @@ -188,7 +207,15 @@ def ast_guarded_do_while_else() -> AbstractSyntaxTree: ast._add_node(cond_node) ast._add_node(false_branch) ast._add_node(do_while_loop) - ast._add_edges_from([(root, init_code_node), (root, cond_node), (cond_node, false_branch), (false_branch, do_while_loop), (do_while_loop, do_while_loop_body)]) + ast._add_edges_from( + [ + (root, init_code_node), + (root, cond_node), + (cond_node, false_branch), + (false_branch, do_while_loop), + (do_while_loop, do_while_loop_body), + ] + ) return ast @@ -248,7 +275,8 @@ def ast_while_in_else() -> AbstractSyntaxTree: class TestForLoopRecovery: - """ Test options for for-loop recovery """ + """Test options for for-loop recovery""" + @staticmethod def run_rbr(ast: AbstractSyntaxTree, options: Options = _generate_options()): ReadabilityBasedRefinement().run(DecompilerTask("func", cfg=None, ast=ast, options=options)) @@ -262,7 +290,6 @@ def test_max_condition_complexity(self, ast_innerWhile_simple_condition_complexi else: assert isinstance(loop_node, WhileLoopNode) - @pytest.mark.parametrize("modification_nesting", [1, 2]) def test_max_modification_complexity(self, modification_nesting): ast = generate_ast_with_modification_complexity(modification_nesting) @@ -276,11 +303,10 @@ def test_max_modification_complexity(self, modification_nesting): assert isinstance(loop_node, WhileLoopNode) for condition_variable in loop_node.get_required_variables(ast.condition_map): instruction = _find_continuation_instruction(ast, loop_node, condition_variable) - assert instruction is not None + assert instruction is not None assert instruction.instruction.complexity > max_modi_complexity - - @pytest.mark.parametrize("operation", [OperationType.equal, OperationType.not_equal ,OperationType.less_or_equal, OperationType.less]) + @pytest.mark.parametrize("operation", [OperationType.equal, OperationType.not_equal, OperationType.less_or_equal, OperationType.less]) def test_for_loop_recovery_blacklist(self, operation): ast = generate_ast_with_condition_type(operation) forbidden_conditon_types = ["not_equal", "equal"] @@ -811,7 +837,6 @@ def test_separated_by_loop_node_4(self, ast_while_in_else): assert _initialization_reaches_loop_node(init_code_node, inner_while) is False - def test_skip_for_loop_recovery_if_continue_in_while(self): """ a = 0 @@ -828,15 +853,12 @@ def test_skip_for_loop_recovery_if_continue_in_while(self): root := SeqNode(true_value), condition_map={ logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)]), - logic_cond("x2", context): Condition(OperationType.equal, [Variable("a"), Constant(2)]) - } + logic_cond("x2", context): Condition(OperationType.equal, [Variable("a"), Constant(2)]), + }, ) true_branch = ast._add_code_node( - [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(2)])), - Continue() - ] + [Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(2)])), Continue()] ) if_condition = ast._add_condition_node_with(logic_cond("x2", context), true_branch) @@ -844,7 +866,9 @@ def test_skip_for_loop_recovery_if_continue_in_while(self): while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) while_loop_body = ast.factory.create_seq_node() - while_loop_iteration = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)]))]) + while_loop_iteration = ast._add_code_node( + [Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)]))] + ) ast._add_node(while_loop) ast._add_node(while_loop_body) @@ -854,7 +878,7 @@ def test_skip_for_loop_recovery_if_continue_in_while(self): (root, while_loop), (while_loop, while_loop_body), (while_loop_body, if_condition), - (while_loop_body, while_loop_iteration) + (while_loop_body, while_loop_iteration), ] ) @@ -881,28 +905,31 @@ def test_skip_for_loop_recovery_if_continue_in_nested_while(self): condition_map={ logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(5)]), logic_cond("x2", context): Condition(OperationType.less, [Variable("b"), Constant(10)]), - logic_cond("x3", context): Condition(OperationType.less, [Variable("b"), Constant(0)]) - } + logic_cond("x3", context): Condition(OperationType.less, [Variable("b"), Constant(0)]), + }, ) true_branch = ast._add_code_node( - [ - Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(2)])), - Continue() - ] + [Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(2)])), Continue()] ) if_condition = ast._add_condition_node_with(logic_cond("x3", context), true_branch) while_loop_outer = ast.factory.create_while_loop_node(logic_cond("x1", context)) while_loop_body_outer = ast.factory.create_seq_node() - while_loop_iteration_outer_1 = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")]))]) - while_loop_iteration_outer_2 = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)]))]) + while_loop_iteration_outer_1 = ast._add_code_node( + [Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")]))] + ) + while_loop_iteration_outer_2 = ast._add_code_node( + [Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)]))] + ) ast._add_node(while_loop_outer) ast._add_node(while_loop_body_outer) while_loop_inner = ast.factory.create_while_loop_node(logic_cond("x2", context)) while_loop_body_inner = ast.factory.create_seq_node() - while_loop_iteration_inner = ast._add_code_node([Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))]) + while_loop_iteration_inner = ast._add_code_node( + [Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))] + ) ast._add_node(while_loop_inner) ast._add_node(while_loop_body_inner) @@ -915,10 +942,10 @@ def test_skip_for_loop_recovery_if_continue_in_nested_while(self): (while_loop_body_outer, while_loop_iteration_outer_2), (while_loop_inner, while_loop_body_inner), (while_loop_body_inner, if_condition), - (while_loop_body_inner, while_loop_iteration_inner) + (while_loop_body_inner, while_loop_iteration_inner), ] ) WhileLoopReplacer(ast, _generate_options()).run() loop_nodes = list(ast.get_loop_nodes_post_order()) - assert not isinstance(loop_nodes[0], ForLoopNode) and isinstance(loop_nodes[1], ForLoopNode) \ No newline at end of file + assert not isinstance(loop_nodes[0], ForLoopNode) and isinstance(loop_nodes[1], ForLoopNode) diff --git a/tests/pipeline/controlflowanalysis/test_variable_name_generation.py b/tests/pipeline/controlflowanalysis/test_variable_name_generation.py index 1f0242d55..166a026eb 100644 --- a/tests/pipeline/controlflowanalysis/test_variable_name_generation.py +++ b/tests/pipeline/controlflowanalysis/test_variable_name_generation.py @@ -31,10 +31,46 @@ VOID = CustomType.void() ALL_TYPES = [I8, I16, I32, I64, I128, UI8, UI16, UI32, UI64, UI128, HALF, FLOAT, DOUBLE, LONG_DOUBLE, QUADRUPLE, OCTUPLE, BOOL, VOID] -EXPECTED_BASE_NAMES = ["chVar0", "sVar1", "iVar2", "lVar3", "i128Var4", "uchVar5", "usVar6", "uiVar7", "ulVar8", "ui128Var9", "hVar10", - "fVar11", "dVar12", "ldVar13", "qVar14", "oVar15", "bVar16", "vVar17"] -EXPECTED_POINTER_NAMES = ["chpVar0", "spVar1", "ipVar2", "lpVar3", "i128pVar4", "uchpVar5", "uspVar6", "uipVar7", "ulpVar8", "ui128pVar9", - "hpVar10", "fpVar11", "dpVar12", "ldpVar13", "qpVar14", "opVar15", "bpVar16", "vpVar17"] +EXPECTED_BASE_NAMES = [ + "chVar0", + "sVar1", + "iVar2", + "lVar3", + "i128Var4", + "uchVar5", + "usVar6", + "uiVar7", + "ulVar8", + "ui128Var9", + "hVar10", + "fVar11", + "dVar12", + "ldVar13", + "qVar14", + "oVar15", + "bVar16", + "vVar17", +] +EXPECTED_POINTER_NAMES = [ + "chpVar0", + "spVar1", + "ipVar2", + "lpVar3", + "i128pVar4", + "uchpVar5", + "uspVar6", + "uipVar7", + "ulpVar8", + "ui128pVar9", + "hpVar10", + "fpVar11", + "dpVar12", + "ldpVar13", + "qpVar14", + "opVar15", + "bpVar16", + "vpVar17", +] def _generate_options(notation: str = "system_hungarian", pointer_base: bool = True, type_sep: str = "", counter_sep: str = "") -> Options: @@ -65,13 +101,8 @@ def test_default_notation_1(): @pytest.mark.parametrize( "variable, name", - [ - (Variable("var_" + str(i), typ), EXPECTED_BASE_NAMES[i]) for i, typ in enumerate(ALL_TYPES) - ] + - [ - (Variable("var_" + str(i), Pointer(typ)), EXPECTED_POINTER_NAMES[i]) for i, typ in enumerate(ALL_TYPES) - ] - , + [(Variable("var_" + str(i), typ), EXPECTED_BASE_NAMES[i]) for i, typ in enumerate(ALL_TYPES)] + + [(Variable("var_" + str(i), Pointer(typ)), EXPECTED_POINTER_NAMES[i]) for i, typ in enumerate(ALL_TYPES)], ) def test_hungarian_notation(variable, name): true_value = LogicCondition.initialize_true(LogicCondition.generate_new_context()) @@ -114,12 +145,13 @@ def test_same_variable(): true_value = LogicCondition.initialize_true(LogicCondition.generate_new_context()) var1 = Variable("tmp_42", Float(64)) var2 = Variable("var_0", Integer(104, True)) - ast = AbstractSyntaxTree(CodeNode([ - Assignment(var1, Constant(0)), - Assignment(var1, Constant(0)), - Assignment(var2, Constant(0)), - Assignment(var2, Constant(0))], true_value), {}) + ast = AbstractSyntaxTree( + CodeNode( + [Assignment(var1, Constant(0)), Assignment(var1, Constant(0)), Assignment(var2, Constant(0)), Assignment(var2, Constant(0))], + true_value, + ), + {}, + ) _run_vng(ast, _generate_options()) assert var1._name == "dTmp42" assert var2._name == "unkVar0" - \ No newline at end of file diff --git a/tests/pipeline/dataflowanalysis/test_array_access_detection.py b/tests/pipeline/dataflowanalysis/test_array_access_detection.py index 838e20f76..538f24242 100644 --- a/tests/pipeline/dataflowanalysis/test_array_access_detection.py +++ b/tests/pipeline/dataflowanalysis/test_array_access_detection.py @@ -458,6 +458,7 @@ def test10(): run_array_access_detection(input_cfg) assert graphs_equal(input_cfg, output_cfg) + def test11(): """Test array-access-detection when array type is bool -> RuntimeError: Unexpected size 1 @@ -2042,6 +2043,7 @@ def graphs_test10(): ) return cfg, out_cfg + def graphs_test11(): bl = CustomType.bool() cfg = ControlFlowGraph() @@ -2055,10 +2057,7 @@ def graphs_test11(): [ BinaryOperation( OperationType.plus, - [ - base := Variable("arg1", Pointer(bl, 32), 0, False), - index := Variable("var_11", Integer.int64_t()) - ], + [base := Variable("arg1", Pointer(bl, 32), 0, False), index := Variable("var_11", Integer.int64_t())], ), ], ), @@ -2078,13 +2077,11 @@ def graphs_test11(): [ BinaryOperation( OperationType.plus, - [ - Variable("arg1", Pointer(bl, 32), 0, False), - Variable("var_11", Integer.int64_t()) - ], + [Variable("arg1", Pointer(bl, 32), 0, False), Variable("var_11", Integer.int64_t())], ), ], - array_info=ArrayInfo(base, index, True)), + array_info=ArrayInfo(base, index, True), + ), Constant(10), ) ], diff --git a/tests/pipeline/dataflowanalysis/test_expression_propagation.py b/tests/pipeline/dataflowanalysis/test_expression_propagation.py index 8188033d1..f8111d876 100644 --- a/tests/pipeline/dataflowanalysis/test_expression_propagation.py +++ b/tests/pipeline/dataflowanalysis/test_expression_propagation.py @@ -966,21 +966,21 @@ def test_correct_propagation_relation(): def test_contraction_copy(): """ - Original copy error with the following steps: - - two different variables have the same contraction as a value - ==> will be propagated to both, but both have the same value (same variables in memory) - - - later in exp. prop. mem. a part of the value (a third variable; pointer) for the first variable will be propagated by the value of the pointer - ==> because both have the same value and only a subexpr is being propagated, the value will change in both variables - ==> error on the propagation of the second value, because the definition of that variable is not correct anymore - - Variables affected in 'tr' 'main' for easy reconstruction: - - rdi_16#28, rdi_17#31 have the same value after expr. prop. - - rax_55#44 is the part which will be propagated by expr. prop. mem. in rdi_16#28 first and yields the side effect at rdi_17#31 - ==> rdi_17#31 will yield the error after trying to get a definition - - Test will only check if the value (memory) of the propagated variables is not the same. - (Variables itself are no contractions, but still works) + Original copy error with the following steps: + - two different variables have the same contraction as a value + ==> will be propagated to both, but both have the same value (same variables in memory) + + - later in exp. prop. mem. a part of the value (a third variable; pointer) for the first variable will be propagated by the value of the pointer + ==> because both have the same value and only a subexpr is being propagated, the value will change in both variables + ==> error on the propagation of the second value, because the definition of that variable is not correct anymore + + Variables affected in 'tr' 'main' for easy reconstruction: + - rdi_16#28, rdi_17#31 have the same value after expr. prop. + - rax_55#44 is the part which will be propagated by expr. prop. mem. in rdi_16#28 first and yields the side effect at rdi_17#31 + ==> rdi_17#31 will yield the error after trying to get a definition + + Test will only check if the value (memory) of the propagated variables is not the same. + (Variables itself are no contractions, but still works) """ variable0 = Variable("var0", Integer(64)) variable1 = Variable("var1", Integer(64)) diff --git a/tests/pipeline/dataflowanalysis/test_expression_propagation_mem.py b/tests/pipeline/dataflowanalysis/test_expression_propagation_mem.py index 03fc141a7..f3cb69d44 100644 --- a/tests/pipeline/dataflowanalysis/test_expression_propagation_mem.py +++ b/tests/pipeline/dataflowanalysis/test_expression_propagation_mem.py @@ -25,45 +25,45 @@ def test_postponed_aliased_propagation_handles_aliases_correctly(): """ - +--------------------------------+ - | 0. | - | var_18#1 = var_18#0 | - | func() | - | var_18#2 = var_18#1 | - | var_28#1 = &(var_18#2) | - | scanf(var_28#1) | - | var_18#3 -> var_18#2 | - | eax#1 = var_18#3 | - | var_14#4 = eax#1 |<--------var_14 is now an alias of var_18 - | func() | - | var_18#4 = var_18#3 | - | var_10#1 = &(var_18#4) | - | *(var_10#1) = 0x7 |<--------var_18 is changed via deref, so does var_14, since they are aliases - | var_18#5 -> var_18#4 | - | var_14#5 = var_14#4 |<--------do not propagate old value of var_14 here, cause of change above - | eax_2#3 = var_18#5 | - | return (&(var_14#5)) + eax_2#3 | - +--------------------------------+ - + +--------------------------------+ + | 0. | + | var_18#1 = var_18#0 | + | func() | + | var_18#2 = var_18#1 | + | var_28#1 = &(var_18#2) | + | scanf(var_28#1) | + | var_18#3 -> var_18#2 | + | eax#1 = var_18#3 | + | var_14#4 = eax#1 |<--------var_14 is now an alias of var_18 + | func() | + | var_18#4 = var_18#3 | + | var_10#1 = &(var_18#4) | + | *(var_10#1) = 0x7 |<--------var_18 is changed via deref, so does var_14, since they are aliases + | var_18#5 -> var_18#4 | + | var_14#5 = var_14#4 |<--------do not propagate old value of var_14 here, cause of change above + | eax_2#3 = var_18#5 | + | return (&(var_14#5)) + eax_2#3 | + +--------------------------------+ + + +---------------------------------+ + | 0. | + | var_18#1 = var_18#0 | + | func() | + | var_18#2 = var_18#0 | + | var_28#1 = &(var_18#2) | + | scanf(&(var_18#2)) | + | var_18#3 -> var_18#2 | + | eax#1 = var_18#3 | + | var_14#4 = var_18#3 | + | func() | + | var_18#4 = var_18#3 | + | var_10#1 = &(var_18#4) | + | *(var_10#1) = 0x7 | + | var_18#5 -> var_18#4 | + | var_14#5 = var_14#4 |<--------this instruction should not be changed after epm + | eax_2#3 = var_18#5 | + | return (&(var_14#5)) + var_18#5 | +---------------------------------+ - | 0. | - | var_18#1 = var_18#0 | - | func() | - | var_18#2 = var_18#0 | - | var_28#1 = &(var_18#2) | - | scanf(&(var_18#2)) | - | var_18#3 -> var_18#2 | - | eax#1 = var_18#3 | - | var_14#4 = var_18#3 | - | func() | - | var_18#4 = var_18#3 | - | var_10#1 = &(var_18#4) | - | *(var_10#1) = 0x7 | - | var_18#5 -> var_18#4 | - | var_14#5 = var_14#4 |<--------this instruction should not be changed after epm - | eax_2#3 = var_18#5 | - | return (&(var_14#5)) + var_18#5 | -+---------------------------------+ """ input_cfg, output_cfg = graphs_with_aliases() _run_expression_propagation(input_cfg) @@ -71,7 +71,6 @@ def test_postponed_aliased_propagation_handles_aliases_correctly(): def graphs_with_aliases(): - var_18 = vars("var_18", 6, aliased=True) var_14 = vars("var_14", 6, aliased=True) var_28 = vars("var_28", 2, type=Pointer(int32)) @@ -82,23 +81,24 @@ def graphs_with_aliases(): in_n0 = BasicBlock( 0, - [_assign(var_18[1], var_18[0]), - _call("func", [], []), - _assign(var_18[2], var_18[1]), - _assign(var_28[1], _addr(var_18[2])), - _call("scanf", [], [var_28[1]]), - Relation(var_18[3], var_18[2]), - _assign(eax[1], var_18[3]), - _assign(var_14[4], eax[1]), - _call("func", [], []), - _assign(var_18[4], var_18[3]), - _assign(var_10[1], _addr(var_18[4])), - _assign(_deref(var_10[1]), c[7]), - Relation(var_18[5], var_18[4]), - _assign(var_14[5], var_14[4]), - _assign(eax_2[3], var_18[5]), - _ret(_add(_addr(var_14[5]), eax_2[3])) - ] + [ + _assign(var_18[1], var_18[0]), + _call("func", [], []), + _assign(var_18[2], var_18[1]), + _assign(var_28[1], _addr(var_18[2])), + _call("scanf", [], [var_28[1]]), + Relation(var_18[3], var_18[2]), + _assign(eax[1], var_18[3]), + _assign(var_14[4], eax[1]), + _call("func", [], []), + _assign(var_18[4], var_18[3]), + _assign(var_10[1], _addr(var_18[4])), + _assign(_deref(var_10[1]), c[7]), + Relation(var_18[5], var_18[4]), + _assign(var_14[5], var_14[4]), + _assign(eax_2[3], var_18[5]), + _ret(_add(_addr(var_14[5]), eax_2[3])), + ], ) in_cfg = ControlFlowGraph() in_cfg.add_node(in_n0) @@ -106,23 +106,24 @@ def graphs_with_aliases(): out_cfg.add_node( BasicBlock( 0, - [_assign(var_18[1], var_18[0]), - _call("func", [], []), - _assign(var_18[2], var_18[0]), - _assign(var_28[1], _addr(var_18[2])), - _call("scanf", [], [_addr(var_18[2])]), - Relation(var_18[3], var_18[2]), - _assign(eax[1], var_18[3]), - _assign(var_14[4], var_18[3]), - _call("func", [], []), - _assign(var_18[4], var_18[3]), - _assign(var_10[1], _addr(var_18[4])), - _assign(_deref(var_10[1]), c[7]), - Relation(var_18[5], var_18[4]), - _assign(var_14[5], var_14[4]), - _assign(eax_2[3], var_18[5]), - _ret(_add(_addr(var_14[5]), var_18[5])) - ] + [ + _assign(var_18[1], var_18[0]), + _call("func", [], []), + _assign(var_18[2], var_18[0]), + _assign(var_28[1], _addr(var_18[2])), + _call("scanf", [], [_addr(var_18[2])]), + Relation(var_18[3], var_18[2]), + _assign(eax[1], var_18[3]), + _assign(var_14[4], var_18[3]), + _call("func", [], []), + _assign(var_18[4], var_18[3]), + _assign(var_10[1], _addr(var_18[4])), + _assign(_deref(var_10[1]), c[7]), + Relation(var_18[5], var_18[4]), + _assign(var_14[5], var_14[4]), + _assign(eax_2[3], var_18[5]), + _ret(_add(_addr(var_14[5]), var_18[5])), + ], ) ) return in_cfg, out_cfg @@ -130,40 +131,40 @@ def graphs_with_aliases(): def test_address_propagation_does_not_break_relations_between_aliased_versions(): """ - +------------------+ - | 0. | - | x#0 = 0x0 | - | y#0 = 0x0 | <--- DO NOT propagate - | ptr_x#1 = &(x#0) | <--- can propagate - | ptr_y#1 = &(y#0) | <--- can propagate - | func(ptr_x#1) | - | y#1 = y#0 | <--- propagation will cause connection loss between lhs and rhs variable - | x#1 -> x#0 | - | func(ptr_y#1) | - | y#2 -> y#1 | - | x#2 = x#1 | - | x#3 = x#2 | <--- can propagate (aliased) definition x#2=x#1 here, as x#2 is not used anywhere else - | y#3 = y#2 | - | return x#3 + y#3 | - +------------------+ - - After: - +------------------+ - | 0. | - | x#0 = 0x0 | - | y#0 = 0x0 | - | ptr_x#1 = &(x#0) | - | ptr_y#1 = &(y#0) | - | func(&(x#0)) | - | y#1 = y#0 | - | x#1 -> x#0 | - | func(&(y#0)) | - | y#2 -> y#1 | - | x#2 = x#1 | - | x#3 = x#1 | - | y#3 = y#2 | - | return x#1 + y#2 | - +------------------+ + +------------------+ + | 0. | + | x#0 = 0x0 | + | y#0 = 0x0 | <--- DO NOT propagate + | ptr_x#1 = &(x#0) | <--- can propagate + | ptr_y#1 = &(y#0) | <--- can propagate + | func(ptr_x#1) | + | y#1 = y#0 | <--- propagation will cause connection loss between lhs and rhs variable + | x#1 -> x#0 | + | func(ptr_y#1) | + | y#2 -> y#1 | + | x#2 = x#1 | + | x#3 = x#2 | <--- can propagate (aliased) definition x#2=x#1 here, as x#2 is not used anywhere else + | y#3 = y#2 | + | return x#3 + y#3 | + +------------------+ + + After: + +------------------+ + | 0. | + | x#0 = 0x0 | + | y#0 = 0x0 | + | ptr_x#1 = &(x#0) | + | ptr_y#1 = &(y#0) | + | func(&(x#0)) | + | y#1 = y#0 | + | x#1 -> x#0 | + | func(&(y#0)) | + | y#2 -> y#1 | + | x#2 = x#1 | + | x#3 = x#1 | + | y#3 = y#2 | + | return x#1 + y#2 | + +------------------+ """ input_cfg, output_cfg = graphs_with_address_propagation_does_not_break_relations_between_aliased_versions() _run_expression_propagation(input_cfg) @@ -179,20 +180,21 @@ def graphs_with_address_propagation_does_not_break_relations_between_aliased_ver in_n0 = BasicBlock( 0, - [_assign(x[0], c[0]), - _assign(y[0], c[0]), - _assign(ptr_x[1], _addr(x[0])), - _assign(ptr_y[1], _addr(y[0])), - _call("func", [], [ptr_x[1]]), - _assign(y[1], y[0]), - Relation(x[1], x[0]), - _call("func", [], [ptr_y[1]]), - Relation(y[2], y[1]), - _assign(x[2], x[1]), - _assign(x[3], x[2]), - _assign(y[3], y[2]), - _ret(_add(x[3], y[3])), - ] + [ + _assign(x[0], c[0]), + _assign(y[0], c[0]), + _assign(ptr_x[1], _addr(x[0])), + _assign(ptr_y[1], _addr(y[0])), + _call("func", [], [ptr_x[1]]), + _assign(y[1], y[0]), + Relation(x[1], x[0]), + _call("func", [], [ptr_y[1]]), + Relation(y[2], y[1]), + _assign(x[2], x[1]), + _assign(x[3], x[2]), + _assign(y[3], y[2]), + _ret(_add(x[3], y[3])), + ], ) in_cfg = ControlFlowGraph() in_cfg.add_node(in_n0) @@ -200,24 +202,26 @@ def graphs_with_address_propagation_does_not_break_relations_between_aliased_ver out_cfg.add_node( BasicBlock( 0, - [_assign(x[0], c[0]), - _assign(y[0], c[0]), - _assign(ptr_x[1], _addr(x[0])), - _assign(ptr_y[1], _addr(y[0])), - _call("func", [], [_addr(x[0])]), - _assign(y[1], y[0]), - Relation(x[1], x[0]), - _call("func", [], [_addr(y[0])]), - Relation(y[2], y[1]), - _assign(x[2], x[1]), - _assign(x[3], x[1]), - _assign(y[3], y[2]), - _ret(_add(x[1], y[2])), - ], + [ + _assign(x[0], c[0]), + _assign(y[0], c[0]), + _assign(ptr_x[1], _addr(x[0])), + _assign(ptr_y[1], _addr(y[0])), + _call("func", [], [_addr(x[0])]), + _assign(y[1], y[0]), + Relation(x[1], x[0]), + _call("func", [], [_addr(y[0])]), + Relation(y[2], y[1]), + _assign(x[2], x[1]), + _assign(x[3], x[1]), + _assign(y[3], y[2]), + _ret(_add(x[1], y[2])), + ], ) ) return in_cfg, out_cfg + def test_assignments_with_dereference_subexpressions_on_rhs_are_propagated_when_no_modification_between_def_and_use(): """ +-------------------------------------+ diff --git a/tests/pipeline/preprocessing/test_insert_missing_definition.py b/tests/pipeline/preprocessing/test_insert_missing_definition.py index f674a8963..6fcdcac02 100644 --- a/tests/pipeline/preprocessing/test_insert_missing_definition.py +++ b/tests/pipeline/preprocessing/test_insert_missing_definition.py @@ -413,7 +413,6 @@ def construct_graph_aliased(number: int) -> (List[Instruction], List[Variable], return list_instructions, aliased_variables, task if number == 8: - list_instructions[23].value._writes_memory = 7 nodes[2].instructions = [i.copy() for i in list_instructions[15:17]] + [ list_instructions[23].copy(), @@ -1089,7 +1088,9 @@ def test_relation_and_assignment_insertion_after_memory_changing_instructions(): b = [Variable("b", Integer.int32_t(), i, is_aliased=True) for i in range(10)] instruction_0 = Assignment(a[1], Constant(0x1)) instruction_1 = Assignment(b[2], UnaryOperation(OperationType.address, [a[1]], writes_memory=2)) - instruction_2 = Assignment(ListOperation([]), Call(function_symbol("scanf"), [UnaryOperation(OperationType.address, [a[2]])], writes_memory=3)) + instruction_2 = Assignment( + ListOperation([]), Call(function_symbol("scanf"), [UnaryOperation(OperationType.address, [a[2]])], writes_memory=3) + ) cfg = ControlFlowGraph() cfg.add_node(BasicBlock(0, [instruction_0, instruction_1, instruction_2])) task = DecompilerTask("test", cfg) @@ -1106,4 +1107,3 @@ def test_relation_and_assignment_insertion_after_memory_changing_instructions(): # test last 2 inserted definitions separately # since I am not sure if the definitions insertion order is deterministic assert {Relation(a[3], a[2]), Assignment(b[3], b[2])} == set(task.graph.nodes[0].instructions[-2:]) - diff --git a/tests/pipeline/preprocessing/test_remove_stack_canary.py b/tests/pipeline/preprocessing/test_remove_stack_canary.py index 953f93ba5..1931c21a8 100644 --- a/tests/pipeline/preprocessing/test_remove_stack_canary.py +++ b/tests/pipeline/preprocessing/test_remove_stack_canary.py @@ -52,6 +52,7 @@ def test_trivial_no_change(): assert isinstance(cfg.get_edge(n1, n2), TrueCase) assert isinstance(cfg.get_edge(n1, n3), FalseCase) + def test_no_change_to_single_block_function(): """ +--------------------+ @@ -68,6 +69,7 @@ def test_no_change_to_single_block_function(): _run_remove_stack_canary(cfg) assert set(cfg) == {b} + def test_one_branch_to_stack_fail(): """ Check if one Branch to stack fail gets removed. Block 3 will be removed. @@ -261,7 +263,7 @@ def test_multiple_returns_one_stackcheck(): def test_one_branch_single_empty_block_between_stack_fail(): """ - Check if one Branch to stack fail gets removed. + Check if one Branch to stack fail gets removed. One empty block between __stack_chk_fail, should be removed as well. +--------------------+ | 0. | @@ -305,10 +307,9 @@ def test_one_branch_single_empty_block_between_stack_fail(): assert isinstance(cfg.get_edge(n1, n2), UnconditionalEdge) - def test_single_branch_multiple_empty_blocks_between_stack_fail(): """ - Check if one Branch to stack fail gets removed. + Check if one Branch to stack fail gets removed. Multiple empty blocks in the __stack_chk_fail branch should all be removed. +--------------------+ | 0. | @@ -351,14 +352,16 @@ def test_single_branch_multiple_empty_blocks_between_stack_fail(): n5 := BasicBlock(5, instructions=[Assignment(ListOperation([]), Call(ImportedFunctionSymbol("__stack_chk_fail", 0), []))]), ] ) - cfg.add_edges_from([UnconditionalEdge(n0, n1), TrueCase(n1, n2), FalseCase(n1, n3), UnconditionalEdge(n3, n4), \ - UnconditionalEdge(n4, n5)]) + cfg.add_edges_from( + [UnconditionalEdge(n0, n1), TrueCase(n1, n2), FalseCase(n1, n3), UnconditionalEdge(n3, n4), UnconditionalEdge(n4, n5)] + ) _run_remove_stack_canary(cfg) assert set(cfg) == {n0, n1, n2} assert n1.instructions == [] assert isinstance(cfg.get_edge(n0, n1), UnconditionalEdge) assert isinstance(cfg.get_edge(n1, n2), UnconditionalEdge) + def test_one_branch_single_non_empty_block_between_stack_fail(): """ Check if cfg error will be detected. @@ -407,6 +410,7 @@ def test_one_branch_single_non_empty_block_between_stack_fail(): error = True assert error is True + def test_multiple_returns_multiple_empty_blocks_one_stackcheck(): """ Test with multiple returns that each share a branch to __stack_chk_fail (does this even happen?). diff --git a/tests/structures/logic/test_logic_condition.py b/tests/structures/logic/test_logic_condition.py index 46abcf133..9a5997b2a 100644 --- a/tests/structures/logic/test_logic_condition.py +++ b/tests/structures/logic/test_logic_condition.py @@ -352,9 +352,7 @@ def test_substitute_by_true_basics(self, term, condition, result): ), ( logic_x[1].copy() | logic_x[2].copy(), - (logic_x[4].copy() | logic_x[5].copy()) - & logic_x[6].copy() - & logic_x[7].copy(), + (logic_x[4].copy() | logic_x[5].copy()) & logic_x[6].copy() & logic_x[7].copy(), ), ( (logic_x[1].copy() | logic_x[2].copy() | logic_x[3].copy()) diff --git a/tests/structures/logic/test_z3_logic_converter.py b/tests/structures/logic/test_z3_logic_converter.py index bee8fe803..eb573cd22 100644 --- a/tests/structures/logic/test_z3_logic_converter.py +++ b/tests/structures/logic/test_z3_logic_converter.py @@ -19,7 +19,8 @@ def _get_condition_branch(second_operand): return Branch( - Condition(OperationType.not_equal, + Condition( + OperationType.not_equal, [ Constant(42, Integer.int32_t()), second_operand, @@ -27,6 +28,7 @@ def _get_condition_branch(second_operand): ) ) + def _generate_instr_bool_as_numbers(op: OperationType) -> Branch: return Branch( Condition( @@ -65,6 +67,7 @@ def test_instruction_conv(instr): # Assert z3 compatible logic_converter.check(condition) + def test_logic_converter_z3(): logic_converter: BaseConverter = Z3Converter() instr1 = _get_condition_branch(Constant(inf, Float.double())) @@ -76,4 +79,4 @@ def test_logic_converter_z3(): # Covered by 'Dead Path Elimination' yields ValueError (will be skipped for z3 stuff) with pytest.raises(ValueError): - logic_converter.convert(instr2, define_expr=True) \ No newline at end of file + logic_converter.convert(instr2, define_expr=True) diff --git a/tests/structures/pseudo/test_expressions.py b/tests/structures/pseudo/test_expressions.py index 96389c850..b3cf2d9e7 100644 --- a/tests/structures/pseudo/test_expressions.py +++ b/tests/structures/pseudo/test_expressions.py @@ -223,7 +223,7 @@ def test_copy(self): original = NotUseableConstant(str(inf)) copy = original.copy() assert id(original) != id(copy) and original == copy - + class TestExternConstant: def test_copy(self): diff --git a/tests/structures/pseudo/test_typing.py b/tests/structures/pseudo/test_typing.py index d8801c774..61eaef274 100644 --- a/tests/structures/pseudo/test_typing.py +++ b/tests/structures/pseudo/test_typing.py @@ -4,6 +4,7 @@ SIZEOF_BOOL = 8 + def test_representation(): """Test the text representation of various types.""" # Integer tests @@ -64,10 +65,12 @@ def test_is_bool(): assert not CustomType.void().is_boolean assert not Integer.int32_t().is_boolean + def test_bool_size(): """Test if bool has the correct size""" assert CustomType.bool().size == SIZEOF_BOOL + def test_type_parser(): """Test the type parser to support basic type guessing.""" parser = TypeParser() diff --git a/tests/structures/test_maps.py b/tests/structures/test_maps.py index dc4131d79..fd7342cdc 100644 --- a/tests/structures/test_maps.py +++ b/tests/structures/test_maps.py @@ -180,12 +180,13 @@ def test_used_variables(): Variable("w", Integer.int32_t(), 1), } + def test_use_map_remove_use(): instruction_list, use_map = define_use_map() assert use_map.get(Variable("v", Integer.int32_t(), 3)) == {instruction_list[1], instruction_list[2]} - #remove existing use + # remove existing use use_map.remove_use(Variable("v", Integer.int32_t(), 3), instruction_list[1]) - #remove non-existing use + # remove non-existing use use_map.remove_use(Variable("v", Integer.int32_t(), 3), instruction_list[4]) assert use_map._map == { Variable("u", Integer.int32_t()): {instruction_list[0], instruction_list[2]}, @@ -195,5 +196,3 @@ def test_use_map_remove_use(): Variable("v", Integer.int32_t(), 2): {instruction_list[4]}, Variable("w", Integer.int32_t(), 1): {instruction_list[5]}, } - - diff --git a/tests/structures/visitors/test_substitute_visitor.py b/tests/structures/visitors/test_substitute_visitor.py index 89a874081..3f0d76ab5 100644 --- a/tests/structures/visitors/test_substitute_visitor.py +++ b/tests/structures/visitors/test_substitute_visitor.py @@ -30,160 +30,112 @@ @pytest.mark.parametrize( ["initial_obj", "expected_result", "visitor"], [ - ( - o := Variable("v", _i32, 0), - r := Variable("x", _i32, 1), - SubstituteVisitor.identity(o, r) - ), - ( - o := Variable("v", _i32, 0), - r := Variable("x", _i32, 1), - SubstituteVisitor.equality(o, r) - ), - ( - o := Variable("v", _i32, 0), - o, - SubstituteVisitor.identity(Variable("v", _i32, 0), Variable("x", _i32, 1)) - ), - ( - o := Variable("v", _i32, 0), - r := Variable("x", _i32, 1), - SubstituteVisitor.equality(Variable("v", _i32, 0), r) - ), - ( - Assignment(a := Variable("a"), b := Variable("b")), - Assignment(a, c := Variable("c")), - SubstituteVisitor.identity(b, c) - ), - ( - Assignment(a := Variable("a"), b := Variable("b")), - Assignment(c := Variable("c"), b), - SubstituteVisitor.identity(a, c) - ), + (o := Variable("v", _i32, 0), r := Variable("x", _i32, 1), SubstituteVisitor.identity(o, r)), + (o := Variable("v", _i32, 0), r := Variable("x", _i32, 1), SubstituteVisitor.equality(o, r)), + (o := Variable("v", _i32, 0), o, SubstituteVisitor.identity(Variable("v", _i32, 0), Variable("x", _i32, 1))), + (o := Variable("v", _i32, 0), r := Variable("x", _i32, 1), SubstituteVisitor.equality(Variable("v", _i32, 0), r)), + (Assignment(a := Variable("a"), b := Variable("b")), Assignment(a, c := Variable("c")), SubstituteVisitor.identity(b, c)), + (Assignment(a := Variable("a"), b := Variable("b")), Assignment(c := Variable("c"), b), SubstituteVisitor.identity(a, c)), ( UnaryOperation(OperationType.dereference, [a := Variable("a")]), UnaryOperation(OperationType.dereference, [b := Variable("b")]), - SubstituteVisitor.identity(a, b) + SubstituteVisitor.identity(a, b), ), ( UnaryOperation( OperationType.dereference, [BinaryOperation(OperationType.plus, [a := Variable("a", _p_i32), Constant(4, _i32)])], - array_info=ArrayInfo(a, 1) + array_info=ArrayInfo(a, 1), ), UnaryOperation( OperationType.dereference, [BinaryOperation(OperationType.plus, [b := Variable("b", _p_i32), Constant(4, _i32)])], - array_info=ArrayInfo(b, 1) + array_info=ArrayInfo(b, 1), ), - SubstituteVisitor.identity(a, b) + SubstituteVisitor.identity(a, b), ), ( UnaryOperation( OperationType.dereference, - [BinaryOperation( - OperationType.plus, - [ - a := Variable("a", _p_i32), - BinaryOperation(OperationType.multiply, [b := Variable("b", _i32), Constant(4, _i32)]) - ] - )], - array_info=ArrayInfo(a, b) + [ + BinaryOperation( + OperationType.plus, + [ + a := Variable("a", _p_i32), + BinaryOperation(OperationType.multiply, [b := Variable("b", _i32), Constant(4, _i32)]), + ], + ) + ], + array_info=ArrayInfo(a, b), ), UnaryOperation( OperationType.dereference, - [BinaryOperation( - OperationType.plus, - [ - a := Variable("a", _p_i32), - BinaryOperation(OperationType.multiply, [c := Variable("c", _i32), Constant(4, _i32)]) - ] - )], - array_info=ArrayInfo(a, c) + [ + BinaryOperation( + OperationType.plus, + [ + a := Variable("a", _p_i32), + BinaryOperation(OperationType.multiply, [c := Variable("c", _i32), Constant(4, _i32)]), + ], + ) + ], + array_info=ArrayInfo(a, c), ), - SubstituteVisitor.identity(b, c) + SubstituteVisitor.identity(b, c), ), ( BinaryOperation(OperationType.multiply, [a := Variable("a"), b := Variable("b")]), BinaryOperation(OperationType.multiply, [a, c := Variable("c")]), - SubstituteVisitor.identity(b, c) - ), - ( - RegisterPair(a := Variable("a"), b := Variable("b")), - RegisterPair(a, c := Variable("c")), - SubstituteVisitor.identity(b, c) - ), - ( - Call(f := Variable("f"), [a := Variable("a")]), - Call(f, [b := Variable("b")]), - SubstituteVisitor.identity(a, b) - ), - ( - Call(f := Variable("f"), [a := Variable("a")]), - Call(g := Variable("g"), [a]), - SubstituteVisitor.identity(f, g) + SubstituteVisitor.identity(b, c), ), + (RegisterPair(a := Variable("a"), b := Variable("b")), RegisterPair(a, c := Variable("c")), SubstituteVisitor.identity(b, c)), + (Call(f := Variable("f"), [a := Variable("a")]), Call(f, [b := Variable("b")]), SubstituteVisitor.identity(a, b)), + (Call(f := Variable("f"), [a := Variable("a")]), Call(g := Variable("g"), [a]), SubstituteVisitor.identity(f, g)), ( Phi( a3 := Variable("a", _i32, 3), - [ - a2 := Variable("a", _i32, 2), - a1 := Variable("a", _i32, 1) - ], + [a2 := Variable("a", _i32, 2), a1 := Variable("a", _i32, 1)], { BasicBlock(2): a2, BasicBlock(1): a1, - } + }, ), Phi( a3, - [ - a2, - a0 := Variable("a", _i32, 0) - ], + [a2, a0 := Variable("a", _i32, 0)], { BasicBlock(2): a2, BasicBlock(1): a0, - } + }, ), - SubstituteVisitor.identity(a1, a0) + SubstituteVisitor.identity(a1, a0), ), ( Phi( a3 := Variable("a", _i32, 3), - [ - a2 := Variable("a", _i32, 2), - a1 := Variable("a", _i32, 1) - ], + [a2 := Variable("a", _i32, 2), a1 := Variable("a", _i32, 1)], { BasicBlock(2): a2, BasicBlock(1): a1, - } + }, ), Phi( a4 := Variable("a", _i32, 4), - [ - a2, - a1 - ], + [a2, a1], { BasicBlock(2): a2, BasicBlock(1): a1, - } + }, ), - SubstituteVisitor.identity(a3, a4) + SubstituteVisitor.identity(a3, a4), ), ( Branch(a := Condition(OperationType.equal, [])), Branch(b := Condition(OperationType.not_equal, [])), - SubstituteVisitor.identity(a, b) - ), - ( - Return([a := Variable("a")]), - Return([b := Variable("b")]), - SubstituteVisitor.identity(a, b) + SubstituteVisitor.identity(a, b), ), - ] + (Return([a := Variable("a")]), Return([b := Variable("b")]), SubstituteVisitor.identity(a, b)), + ], ) def test_substitute(initial_obj: DataflowObject, expected_result: DataflowObject, visitor: SubstituteVisitor): result = initial_obj.accept(visitor) diff --git a/tests/test_sample_binaries.py b/tests/test_sample_binaries.py index d013b2e18..3b6c84e7b 100644 --- a/tests/test_sample_binaries.py +++ b/tests/test_sample_binaries.py @@ -7,7 +7,7 @@ def test_sample(test_cases): """Test the decompiler with the given test case.""" sample, function_name = test_cases - output = subprocess.run(("python", "decompile.py", sample, function_name), check=True, capture_output=True).stdout.decode('utf-8') + output = subprocess.run(("python", "decompile.py", sample, function_name), check=True, capture_output=True).stdout.decode("utf-8") assert "Failed to decompile due to error during " not in output @@ -87,7 +87,7 @@ def test_global_ptr(): # Assert global pointer correct assert output.count("c = 0x0") == 1 assert output.count("d = 0x0") == 1 - # Assert call correct + # Assert call correct len(re.findall("var_[0-9]+= d", output)) == 1 len(re.findall("var_[0-9]+= c", output)) == 1 len(re.findall("_add(var_[0-9]+, var_[0-9]+)", output)) == 1 @@ -162,7 +162,7 @@ def test_global_indirect_ptr2(): output = str(subprocess.run(args1, check=True, capture_output=True).stdout) # Assert global variables correct - assert output.count("p = 0xffffffbe") == 2 # should be one, still one lifter issue + assert output.count("p = 0xffffffbe") == 2 # should be one, still one lifter issue assert output.count("o = &(p)") == 1 assert output.count("n = &(o)") == 1 assert output.count("m = &(n)") == 1 @@ -215,13 +215,13 @@ def test_global_import_address_symbol(): assert output.count("g_38 = 0x7cb0be9") == 1 # test types and initial values (dec or hex) are correct in declarations - assert re.search(r'unsigned short\s*g_22\s*=\s*54249', output) or re.search(r'unsigned short\s*g_22\s*=\s*0xd3e9', output) - assert re.search(r'unsigned char\s*g_26\s*=\s*157', output) or re.search(r'unsigned char\s*g_26\s*=\s*0x9d', output) - assert re.search(r'unsigned int\s*g_29\s*=\s*65537', output) or re.search(r'unsigned int\s*g_29\s*=\s*0x10001', output) - assert re.search(r'unsigned char\s*g_30\s*=\s*236', output) or re.search(r'unsigned char\s*g_30\s*=\s*0xec', output) - assert re.search(r'unsigned int\s*g_32\s*=\s*1578356047', output) or re.search(r'unsigned int\s*g_32\s*=\s*0x5e13cd4f', output) - assert re.search(r'unsigned char\s*g_35\s*=\s*255', output) or re.search(r'unsigned char\s*g_35\s*=\s*0xff', output) - assert re.search(r'unsigned int\s*g_38\s*=\s*130747369', output) or re.search(r'unsigned int\s*g_38\s*=\s*0x7cb0be9', output) + assert re.search(r"unsigned short\s*g_22\s*=\s*54249", output) or re.search(r"unsigned short\s*g_22\s*=\s*0xd3e9", output) + assert re.search(r"unsigned char\s*g_26\s*=\s*157", output) or re.search(r"unsigned char\s*g_26\s*=\s*0x9d", output) + assert re.search(r"unsigned int\s*g_29\s*=\s*65537", output) or re.search(r"unsigned int\s*g_29\s*=\s*0x10001", output) + assert re.search(r"unsigned char\s*g_30\s*=\s*236", output) or re.search(r"unsigned char\s*g_30\s*=\s*0xec", output) + assert re.search(r"unsigned int\s*g_32\s*=\s*1578356047", output) or re.search(r"unsigned int\s*g_32\s*=\s*0x5e13cd4f", output) + assert re.search(r"unsigned char\s*g_35\s*=\s*255", output) or re.search(r"unsigned char\s*g_35\s*=\s*0xff", output) + assert re.search(r"unsigned int\s*g_38\s*=\s*130747369", output) or re.search(r"unsigned int\s*g_38\s*=\s*0x7cb0be9", output) def test_string_with_pointer_compare(): @@ -230,7 +230,7 @@ def test_string_with_pointer_compare(): args1 = base_args + ["global_string_compare"] output = str(subprocess.run(args1, check=True, capture_output=True).stdout) - assert output.count("Hello Decompiler") == 1 # it's enough to test if the output has the string. Would crash if not possible in if + assert output.count("Hello Decompiler") == 1 # it's enough to test if the output has the string. Would crash if not possible in if def test_w_char(): @@ -249,7 +249,7 @@ def test_string_length(): args1 = base_args + ["global_string_length"] output = str(subprocess.run(args1, check=True, capture_output=True).stdout) - assert output.count('...') == 2 + assert output.count("...") == 2 def test_tailcall_display(): @@ -282,4 +282,4 @@ def test_iat_entries_are_decompiled_correctly(): args = ["python", "decompile.py", "tests/samples/others/test.exe", "0x401865"] subprocess.run(args, check=True) output = str(subprocess.run(args, check=True, capture_output=True).stdout) - assert re.search(r'=\s*GetModuleHandleW\((0x0|/\* lpModuleName \*/ 0x0\))', output) \ No newline at end of file + assert re.search(r"=\s*GetModuleHandleW\((0x0|/\* lpModuleName \*/ 0x0\))", output)