diff --git a/decompiler/backend/cexpressiongenerator.py b/decompiler/backend/cexpressiongenerator.py index 63e6df7fa..d39b7e061 100644 --- a/decompiler/backend/cexpressiongenerator.py +++ b/decompiler/backend/cexpressiongenerator.py @@ -239,6 +239,9 @@ def visit_constant_composition(self, expr: expressions.ConstantComposition): case CustomType(text="wchar16") | CustomType(text="wchar32"): val = "".join([x.value for x in expr.value]) return f'L"{val}"' if len(val) <= MAX_GLOBAL_INIT_LENGTH else f'L"{val[:MAX_GLOBAL_INIT_LENGTH]}..."' + case Integer(size=8, signed=False): + val = "".join([f"\\x{x.value:02X}" for x in expr.value][:MAX_GLOBAL_INIT_LENGTH]) + return f'"{val}"' if len(val) <= MAX_GLOBAL_INIT_LENGTH else f'"{val[:MAX_GLOBAL_INIT_LENGTH]}..."' case Integer(8): val = "".join([x.value for x in expr.value][:MAX_GLOBAL_INIT_LENGTH]) return f'"{val}"' if len(val) <= MAX_GLOBAL_INIT_LENGTH else f'"{val[:MAX_GLOBAL_INIT_LENGTH]}..."' diff --git a/decompiler/backend/variabledeclarations.py b/decompiler/backend/variabledeclarations.py index 14735d0b8..efd0cc708 100644 --- a/decompiler/backend/variabledeclarations.py +++ b/decompiler/backend/variabledeclarations.py @@ -70,7 +70,7 @@ def _generate_definitions(global_variables: set[GlobalVariable]) -> Iterator[str match variable.type: case ArrayType(): br, bl = "", "" - if not variable.type.type in [Integer.char(), CustomType.wchar16(), CustomType.wchar32()]: + if not variable.type.type in [Integer.char(), Integer.uint8_t(), CustomType.wchar16(), CustomType.wchar32()]: br, bl = "{", "}" yield f"{base}{variable.type.type} {variable.name}[{hex(variable.type.elements)}] = {br}{CExpressionGenerator().visit(variable.initial_value)}{bl};" case Struct(): diff --git a/decompiler/frontend/binaryninja/handlers/constants.py b/decompiler/frontend/binaryninja/handlers/constants.py index d351a2a2e..e7e5dffdd 100644 --- a/decompiler/frontend/binaryninja/handlers/constants.py +++ b/decompiler/frontend/binaryninja/handlers/constants.py @@ -72,9 +72,6 @@ def lift_constant_pointer(self, pointer: mediumlevelil.MediumLevelILConstPtr, ** if isinstance(res, Constant): # BNinja Error case handling return res - if isinstance(res.type, Pointer) and res.type.type == CustomType.void(): - return res - if isinstance(pointer, mediumlevelil.MediumLevelILImport): # Temp fix for '&' return res diff --git a/decompiler/frontend/binaryninja/handlers/globals.py b/decompiler/frontend/binaryninja/handlers/globals.py index beaa439ad..beb526725 100644 --- a/decompiler/frontend/binaryninja/handlers/globals.py +++ b/decompiler/frontend/binaryninja/handlers/globals.py @@ -308,7 +308,9 @@ def _get_unknown_value(self, variable: DataVariable): type = PseudoArrayType(self._lifter.lift(data[1]), len(data[0])) data = ConstantComposition([Constant(x, type.type) for x in data[0]], type) else: - data, type = get_raw_bytes(variable.address, self._view), Pointer(CustomType.void(), self._view.address_size * BYTE_SIZE) + rbytes = get_raw_bytes(variable.address, self._view) + type = PseudoArrayType(Integer.uint8_t(), len(rbytes)) + data = ConstantComposition([Constant(b, type.type) for b in rbytes], type) return data, type def _get_unknown_pointer_value(self, variable: DataVariable, callers: list[int] = None): diff --git a/decompiler/pipeline/commons/expressionpropagationcommons.py b/decompiler/pipeline/commons/expressionpropagationcommons.py index 36c57c14b..c9e95f4ff 100644 --- a/decompiler/pipeline/commons/expressionpropagationcommons.py +++ 
b/decompiler/pipeline/commons/expressionpropagationcommons.py @@ -225,8 +225,9 @@ def _is_address_into_dereference(self, definition: Assignment, target: Instructi if self._is_address(definition.value): for subexpr in target: for sub in self._find_subexpressions(subexpr): - if self._is_dereference(sub) and sub.operand == definition.destination: + if self._is_dereference(sub) and sub.operand in definition.definitions: return True + return False def _contains_aliased_variables(self, definition: Assignment) -> bool: """ @@ -326,14 +327,13 @@ def _has_any_of_dangerous_uses_between_definition_and_target( def _get_dangerous_uses_of_variable_address(self, var: Variable) -> Set[Instruction]: """ Dangerous use of & of x is func(&x) cause it can potentially modify x. - *(&x) could also do the job but I consider it to be too exotic so that we could get such instruction from Binary Ninja - If it happens we can handle it later. + Another case is an Assignment where the left side is *(&). :param var: aliased variable :return: set of function call assignments that take &var as parameter """ dangerous_uses = set() for use in self._use_map.get(var): - if not self._is_call_assignment(use): + if not self._is_call_assignment(use) and not (isinstance(use, Assignment) and self._is_dereference(use.destination)): continue for subexpr in self._find_subexpressions(use): if self._is_address(subexpr): diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py index 31656b667..88212abde 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py @@ -65,7 +65,12 @@ def constant_fold(operation: OperationType, constants: list[Constant], result_ty ) -def _constant_fold_arithmetic_binary(constants: list[Constant], fun: Callable[[int, int], int], norm_sign: Optional[bool] = None) -> int: +def _constant_fold_arithmetic_binary( + constants: list[Constant], + fun: Callable[[int, int], int], + norm_sign: Optional[bool] = None, + allow_mismatched_sizes: bool = False, +) -> int: """ Fold an arithmetic binary operation with constants as operands. 
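# Side note (illustration only, not part of the patch): the hunks that follow add constant folding for
# OperationType.modulo / modulo_us. When reading the `remainder` helper introduced there, it helps to keep
# the two remainder conventions apart: Python's `%` is floored (the sign follows the divisor), while a
# C-style `%` truncates toward zero (the sign follows the dividend). A minimal standalone sketch of the
# difference, using plain Python only:
def truncated_remainder(n: int, d: int) -> int:
    """C-style remainder: n - trunc(n / d) * d, so the result takes the sign of the dividend n."""
    return n - int(n / d) * d
for n, d in [(7, 3), (-7, 3), (7, -3), (-7, -3)]:
    print(f"n={n:3}, d={d:3}: python n%d={n % d:3}, truncated={truncated_remainder(n, d):3}")
# n=  7, d=  3: python n%d=  1, truncated=  1
# n= -7, d=  3: python n%d=  2, truncated= -1
# n=  7, d= -3: python n%d= -2, truncated=  1
# n= -7, d= -3: python n%d= -1, truncated= -1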
@@ -84,7 +89,7 @@ def _constant_fold_arithmetic_binary(constants: list[Constant], fun: Callable[[i if len(constants) != 2: raise IncompatibleOperandCount(f"Expected exactly 2 constants to fold, got {len(constants)}.") - if not all(constant.type.size == constants[0].type.size for constant in constants): + if not allow_mismatched_sizes and not all(constant.type.size == constants[0].type.size for constant in constants): raise UnsupportedMismatchedSizes(f"Can not fold constants with different sizes: {[constant.type for constant in constants]}") left, right = constants @@ -137,6 +142,10 @@ def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], in return fun(normalize_int(left.value, left.type.size, norm_signed), right.value) +def remainder(n, d): + return (-1 if n < 0 else 1) * (n % d) + + _OPERATION_TO_FOLD_FUNCTION: dict[OperationType, Callable[[list[Constant]], int]] = { OperationType.minus: partial(_constant_fold_arithmetic_binary, fun=operator.sub), OperationType.plus: partial(_constant_fold_arithmetic_binary, fun=operator.add), @@ -144,6 +153,8 @@ def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], in OperationType.multiply_us: partial(_constant_fold_arithmetic_binary, fun=operator.mul, norm_sign=False), OperationType.divide: partial(_constant_fold_arithmetic_binary, fun=operator.floordiv, norm_sign=True), OperationType.divide_us: partial(_constant_fold_arithmetic_binary, fun=operator.floordiv, norm_sign=False), + OperationType.modulo: partial(_constant_fold_arithmetic_binary, fun=remainder, norm_sign=True, allow_mismatched_sizes=True), + OperationType.modulo_us: partial(_constant_fold_arithmetic_binary, fun=operator.mod, norm_sign=False, allow_mismatched_sizes=True), OperationType.negate: partial(_constant_fold_arithmetic_unary, fun=operator.neg), OperationType.left_shift: partial(_constant_fold_shift, fun=operator.lshift, signed=True), OperationType.right_shift: partial(_constant_fold_shift, fun=operator.rshift, signed=True), diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/acyclic_restructuring.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/acyclic_restructuring.py index b251a3f3a..38f3e6063 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/acyclic_restructuring.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/acyclic_restructuring.py @@ -94,7 +94,10 @@ def _construct_refined_ast(self, seq_node_root: SeqNode) -> AbstractSyntaxTreeNo ConditionBasedRefinement.refine(self.asforest) acyclic_processor.preprocess_condition_aware_refinement() if self.options.reconstruct_switch: - ConditionAwareRefinement.refine(self.asforest, self.options) + updated_switch_nodes = ConditionAwareRefinement.refine(self.asforest, self.options) + for switch_node in updated_switch_nodes: + for sequence_case in (c for c in switch_node.cases if isinstance(c.child, SeqNode)): + ConditionBasedRefinement.refine(self.asforest, sequence_case.child) acyclic_processor.postprocess_condition_refinement() root = self.asforest.current_root self.asforest.remove_current_root() diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/ast_processor.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/ast_processor.py index 8edf915e6..11b96e35b 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/ast_processor.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/ast_processor.py @@ -232,7 +232,7 @@ def 
_group_by_reaching_conditions(self, nodes: Tuple[AbstractSyntaxTreeNode]) -> :param nodes: The AST nodes that we want to group. :return: A dictionary that assigns to a reaching condition the list of AST code nodes with this reaching condition, - if it are at least two with the same. + if there are at least two with the same. """ initial_groups: Dict[LogicCondition, List[AbstractSyntaxTreeNode]] = dict() for node in nodes: diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement.py index ce57fbd2a..df17d9a54 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement.py @@ -2,6 +2,8 @@ Module for Condition Aware Refinement """ +from typing import Set + from decompiler.pipeline.controlflowanalysis.restructuring_commons.condition_aware_refinement_commons.base_class_car import ( BaseClassConditionAwareRefinement, ) @@ -21,6 +23,7 @@ SwitchExtractor, ) from decompiler.pipeline.controlflowanalysis.restructuring_options import RestructuringOptions +from decompiler.structures.ast.ast_nodes import SwitchNode from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest @@ -35,13 +38,14 @@ class ConditionAwareRefinement(BaseClassConditionAwareRefinement): ] @classmethod - def refine(cls, asforest: AbstractSyntaxForest, options: RestructuringOptions): + def refine(cls, asforest: AbstractSyntaxForest, options: RestructuringOptions) -> Set[SwitchNode]: condition_aware_refinement = cls(asforest, options) for stage in condition_aware_refinement.REFINEMENT_PIPELINE: asforest.clean_up(asforest.current_root) - stage(asforest, options) + condition_aware_refinement.updated_switch_nodes.update(stage(asforest, options)) condition_aware_refinement._remove_redundant_reaching_condition_from_switch_nodes() asforest.clean_up(asforest.current_root) + return set(switch for switch in condition_aware_refinement.updated_switch_nodes if switch in asforest) def _remove_redundant_reaching_condition_from_switch_nodes(self): """Remove the reaching condition from all switch nodes if it is redundant.""" diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/base_class_car.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/base_class_car.py index ff93d7af2..477377360 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/base_class_car.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/base_class_car.py @@ -1,10 +1,9 @@ from dataclasses import dataclass -from typing import Iterator, Optional, Tuple +from typing import Iterator, Optional, Set, Tuple from decompiler.pipeline.controlflowanalysis.restructuring_options import LoopBreakOptions, RestructuringOptions from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, CaseNode, FalseNode, SwitchNode, TrueNode -from decompiler.structures.ast.condition_symbol import ConditionHandler -from decompiler.structures.ast.switch_node_handler import ExpressionUsages +from decompiler.structures.ast.condition_symbol import ConditionHandler, ExpressionUsages from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest from decompiler.structures.logic.logic_condition import 
LogicCondition, PseudoLogicCondition from decompiler.structures.pseudo import Condition, Constant, Expression, OperationType @@ -63,6 +62,7 @@ def __init__(self, asforest: AbstractSyntaxForest, options: RestructuringOptions self.asforest: AbstractSyntaxForest = asforest self.condition_handler: ConditionHandler = asforest.condition_handler self.options: RestructuringOptions = options + self.updated_switch_nodes: Set[SwitchNode] = set() def _get_constant_equality_check_expressions_and_conditions( self, condition: LogicCondition @@ -109,14 +109,14 @@ def _get_expression_compared_with_constant(self, reaching_condition: LogicCondit Check whether the given reaching condition, which is a literal, i.e., a z3-symbol or its negation is of the form `expr == const`. If this is the case, then we return the expression `expr`. """ - return self.asforest.switch_node_handler.get_potential_switch_expression(reaching_condition) + return self.asforest.condition_handler.get_potential_switch_expression_of(reaching_condition) def _get_constant_compared_with_expression(self, reaching_condition: LogicCondition) -> Optional[Constant]: """ Check whether the given reaching condition, which is a literal, i.e., a z3-symbol or its negation is of the form `expr == const`. If this is the case, then we return the constant `const`. """ - return self.asforest.switch_node_handler.get_potential_switch_constant(reaching_condition) + return self.asforest.condition_handler.get_potential_switch_constant_of(reaching_condition) def _convert_to_z3_condition(self, condition: LogicCondition) -> PseudoLogicCondition: return PseudoLogicCondition.initialize_from_formula(condition, self.condition_handler.get_z3_condition_map()) diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/initial_switch_node_constructer.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/initial_switch_node_constructer.py index 1950c543b..fbda16638 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/initial_switch_node_constructer.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/initial_switch_node_constructer.py @@ -11,8 +11,8 @@ ) from decompiler.pipeline.controlflowanalysis.restructuring_options import RestructuringOptions from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, CaseNode, CodeNode, ConditionNode, SeqNode, SwitchNode, TrueNode +from decompiler.structures.ast.condition_symbol import ExpressionUsages from decompiler.structures.ast.reachability_graph import CaseDependencyGraph, LinearOrderDependency, SiblingReachability -from decompiler.structures.ast.switch_node_handler import ExpressionUsages from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest from decompiler.structures.logic.logic_condition import LogicCondition from decompiler.structures.pseudo import Constant, Expression @@ -90,8 +90,8 @@ def _clean_up_reachability(self): """ for candidate_1, candidate_2 in permutations(self.switch_candidate.cases, 2): if self.sibling_reachability.reaches(candidate_1.node, candidate_2.node) and not ( - set(self.asforest.switch_node_handler.get_constants_for(candidate_1.condition)) - & set(self.asforest.switch_node_handler.get_constants_for(candidate_2.condition)) + set(self.asforest.condition_handler.get_constants_of(candidate_1.condition)) + & 
set(self.asforest.condition_handler.get_constants_of(candidate_2.condition)) ): self.asforest._code_node_reachability_graph.remove_reachability_between([candidate_1.node, candidate_2.node]) self.sibling_reachability.remove_reachability_between([candidate_1.node, candidate_2.node]) @@ -214,13 +214,14 @@ class InitialSwitchNodeConstructor(BaseClassConditionAwareRefinement): """Class that constructs switch nodes.""" @classmethod - def construct(cls, asforest: AbstractSyntaxForest, options: RestructuringOptions): + def construct(cls, asforest: AbstractSyntaxForest, options: RestructuringOptions) -> Set[SwitchNode]: """Constructs initial switch nodes if possible.""" initial_switch_constructor = cls(asforest, options) for cond_node in asforest.get_condition_nodes_post_order(asforest.current_root): initial_switch_constructor._extract_case_nodes_from_nested_condition(cond_node) for seq_node in asforest.get_sequence_nodes_post_order(asforest.current_root): initial_switch_constructor._try_to_construct_initial_switch_node_for(seq_node) + return initial_switch_constructor.updated_switch_nodes def _extract_case_nodes_from_nested_condition(self, cond_node: ConditionNode) -> None: """ @@ -336,6 +337,7 @@ def _try_to_construct_initial_switch_node_for(self, seq_node: SeqNode) -> None: sibling_reachability = self.asforest.get_sibling_reachability_of_children_of(seq_node) switch_cases = list(possible_switch_node.construct_switch_cases()) switch_node = self.asforest.create_switch_node_with(possible_switch_node.expression, switch_cases) + self.updated_switch_nodes.add(switch_node) case_dependency = CaseDependencyGraph.construct_case_dependency_for(self.asforest.children(switch_node), sibling_reachability) self._update_reaching_condition_for_case_node_children(switch_node) self._add_constants_to_cases(switch_node, case_dependency) @@ -393,7 +395,7 @@ def _update_reaching_condition_for_case_node_children(self, switch_node: SwitchN case_node.reaching_condition.is_disjunction_of_literals ), f"The condition of a case node should be a disjunction, but it is {case_node.reaching_condition}!" - if isinstance(cond_node := case_node.child, ConditionNode) and cond_node.false_branch is None: + if (cond_node := case_node.child).is_single_branch: self._update_condition_for(cond_node, case_node) case_node.child.reaching_condition = case_node.child.reaching_condition.substitute_by_true(case_node.reaching_condition) @@ -519,7 +521,7 @@ def _add_constants_to_cases_for( case_node.constant = Constant("add_to_previous_case") else: considered_conditions.update( - (c, l) for l, c in self.asforest.switch_node_handler.get_literal_and_constant_for(case_node.reaching_condition) + (c, l) for l, c in self.asforest.condition_handler.get_literal_and_constant_of(case_node.reaching_condition) ) def _update_reaching_condition_of(self, case_node: CaseNode, considered_conditions: Dict[Constant, LogicCondition]) -> None: @@ -535,8 +537,7 @@ def _update_reaching_condition_of(self, case_node: CaseNode, considered_conditio :param considered_conditions: The conditions (literals) that are already fulfilled when we reach the given case node. 
""" constant_of_case_node_literal = { - const: literal - for literal, const in self.asforest.switch_node_handler.get_literal_and_constant_for(case_node.reaching_condition) + const: literal for literal, const in self.asforest.condition_handler.get_literal_and_constant_of(case_node.reaching_condition) } exception_condition: LogicCondition = self.condition_handler.get_true_value() @@ -576,7 +577,7 @@ def prepend_empty_cases_to_case_with_or_condition(self, case: CaseNode) -> List[ the list of new case nodes. """ condition_for_constant: Dict[Constant, LogicCondition] = dict() - for l, c in self.asforest.switch_node_handler.get_literal_and_constant_for(case.reaching_condition): + for l, c in self.asforest.condition_handler.get_literal_and_constant_of(case.reaching_condition): if c is None: raise ValueError( f"The case node should have a reaching-condition that is a disjunction of literals, but it has the clause {l}." diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder.py index b9c50a190..faf30682a 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder.py @@ -33,6 +33,7 @@ def _insert_case_node(self, new_case_node: AbstractSyntaxTreeNode, case_constant if default_case := switch_node.default: new_children.append(default_case) switch_node._sorted_cases = tuple(new_children) + self.updated_switch_nodes.add(switch_node) def _new_case_nodes_for( self, new_case_node: AbstractSyntaxTreeNode, switch_node: SwitchNode, sorted_case_constants: List[Constant] diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder_condition.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder_condition.py index 0b50b9e82..1dd3794d8 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder_condition.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder_condition.py @@ -28,7 +28,7 @@ class MissingCaseFinderCondition(MissingCaseFinder): """ @classmethod - def find(cls, asforest: AbstractSyntaxForest, options: RestructuringOptions): + def find(cls, asforest: AbstractSyntaxForest, options: RestructuringOptions) -> Set[SwitchNode]: """Try to find missing cases that are branches of condition nodes.""" missing_case_finder = cls(asforest, options) for condition_node in asforest.get_condition_nodes_post_order(asforest.current_root): @@ -37,9 +37,10 @@ def find(cls, asforest: AbstractSyntaxForest, options: RestructuringOptions): case_candidate_information.case_node, case_candidate_information.case_constants, case_candidate_information.switch_node ) if case_candidate_information.in_sequence: - asforest.extract_switch_from_condition_sequence(case_candidate_information.switch_node, condition_node) + asforest.extract_switch_from_sequence(case_candidate_information.switch_node) else: asforest.replace_condition_node_by_single_branch(condition_node) + return missing_case_finder.updated_switch_nodes def _can_insert_missing_case_node(self, condition_node: ConditionNode) 
-> Optional[CaseCandidateInformation]: """ diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder_intersecting_constants.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder_intersecting_constants.py index e180c6249..6e4f4b16d 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder_intersecting_constants.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder_intersecting_constants.py @@ -47,7 +47,7 @@ def insert(self, possible_case: CaseNodeCandidate): first fallthrough-cases. - If the possible-case node is reached by the switch-node, then the content must be after any other code. Thus, it must contain all constants from a block of fallthrough-cases. But here, it can contain more. - - If neither one reaches the other, then it can be insert anywhere, at long as it can be archived by only + - If neither one reaches the other, then it can be inserted anywhere, as long as it can be achieved by only resorting fallthrough-cases all leading to the same code-execution. """ cases_of_switch_node = {case.constant for case in self._switch_node.children} @@ -70,6 +70,7 @@ def insert(self, possible_case: CaseNodeCandidate): return self._sibling_reachability_graph.update_when_inserting_new_case_node(compare_node, self._switch_node) + self.updated_switch_nodes.add(self._switch_node) compare_node.clean() def _add_case_before(self, intersecting_linear_case: Tuple[CaseNode], possible_case_properties: IntersectingCaseNodeProperties) -> bool: diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder_sequence.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder_sequence.py index 16d4398dc..d4b35ceda 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder_sequence.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder_sequence.py @@ -12,8 +12,8 @@ ) from decompiler.pipeline.controlflowanalysis.restructuring_options import RestructuringOptions from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, ConditionNode, FalseNode, SeqNode, SwitchNode, TrueNode +from decompiler.structures.ast.condition_symbol import ExpressionUsages from decompiler.structures.ast.reachability_graph import SiblingReachabilityGraph -from decompiler.structures.ast.switch_node_handler import ExpressionUsages from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest from decompiler.structures.logic.logic_condition import LogicCondition, PseudoLogicCondition from decompiler.structures.pseudo import Condition, Constant, OperationType @@ -38,7 +38,7 @@ def __init__(self, asforest: AbstractSyntaxForest, options: RestructuringOptions self._switch_node_of_expression: Dict[ExpressionUsages, SwitchNode] = dict() @classmethod - def find(cls, asforest: AbstractSyntaxForest, options: RestructuringOptions): + def find(cls, asforest: AbstractSyntaxForest, options: RestructuringOptions) -> Set[SwitchNode]: """ Try to find missing cases that are children of sequence nodes. 
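# Aside (toy sketch, not project code): a recurring change in these files is that each condition-aware
# refinement stage now records the switch nodes it created or modified in `updated_switch_nodes` and
# returns them, so refine() can report the surviving ones to the caller, which then re-runs
# ConditionBasedRefinement on their sequence cases (see the acyclic_restructuring.py hunk above).
# A self-contained illustration of that accumulate-and-filter driver pattern; the class and stage
# names below are made up for the example:
class ToyForest:
    def __init__(self) -> None:
        self.nodes: set[str] = set()
    def __contains__(self, node: str) -> bool:
        return node in self.nodes
def stage_a(forest: ToyForest) -> set[str]:
    forest.nodes.add("switch_1")          # a stage creates a switch node and reports it
    return {"switch_1"}
def stage_b(forest: ToyForest) -> set[str]:
    forest.nodes.discard("switch_1")      # a later stage may remove it again
    forest.nodes.add("switch_2")
    return {"switch_2"}
def refine(forest: ToyForest) -> set[str]:
    updated: set[str] = set()
    for stage in (stage_a, stage_b):
        updated.update(stage(forest))
    return {node for node in updated if node in forest}   # only report nodes still in the forest
assert refine(ToyForest()) == {"switch_2"}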
@@ -58,6 +58,7 @@ def find(cls, asforest: AbstractSyntaxForest, options: RestructuringOptions): if seq_node in asforest: seq_node.clean() + return missing_case_finder.updated_switch_nodes def _initialize_switch_node_of_expression_dictionary(self) -> None: """ diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/switch_extractor.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/switch_extractor.py index ba5479a5b..e19d176de 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/switch_extractor.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/switch_extractor.py @@ -1,57 +1,50 @@ -from typing import Optional, Union +from typing import Optional, Set, Union from decompiler.pipeline.controlflowanalysis.restructuring_commons.condition_aware_refinement_commons.base_class_car import ( BaseClassConditionAwareRefinement, ) from decompiler.pipeline.controlflowanalysis.restructuring_options import RestructuringOptions -from decompiler.structures.ast.ast_nodes import ConditionNode, FalseNode, SeqNode, TrueNode +from decompiler.structures.ast.ast_nodes import ConditionNode, FalseNode, SeqNode, SwitchNode, TrueNode from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest +from decompiler.structures.logic.logic_condition import LogicCondition class SwitchExtractor(BaseClassConditionAwareRefinement): """Extract switch nodes from condition nodes if the condition node is irrelevant for the switch node.""" - def __init__(self, asforest: AbstractSyntaxForest, options: RestructuringOptions): - """ - self.current_cond_node: The condition node which we consider to extract switch nodes. - """ - super().__init__(asforest, options) - self._current_cond_node: Optional[ConditionNode] = None - @classmethod def extract(cls, asforest: AbstractSyntaxForest, options: RestructuringOptions): - """ - Extract switch nodes from condition nodes, i.e., if a switch node is a branch of a condition node whose condition is redundant for - the switch node, we extract it from the condition node. - """ + """Extract switch nodes from condition nodes, or sequence-nodes with a non-trivial reaching-condition.""" switch_extractor = cls(asforest, options) - for condition_node in asforest.get_condition_nodes_post_order(asforest.current_root): - switch_extractor._current_cond_node = condition_node - switch_extractor._extract_switches_from_condition() + for switch_node in list(asforest.get_switch_nodes_post_order(asforest.current_root)): + while switch_extractor._successfully_extracts_switch_nodes(switch_node): + pass + return switch_extractor.updated_switch_nodes - def _extract_switches_from_condition(self) -> None: - """Extract switch nodes in the true and false branch of the given condition node.""" - if self._current_cond_node.false_branch: - self._try_to_extract_switch_from_branch(self._current_cond_node.false_branch) - if self._current_cond_node.true_branch: - self._try_to_extract_switch_from_branch(self._current_cond_node.true_branch) - if self._current_cond_node in self.asforest: - self._current_cond_node.clean() - - def _try_to_extract_switch_from_branch(self, branch: Union[TrueNode, FalseNode]) -> None: + def _successfully_extracts_switch_nodes(self, switch_node: SwitchNode) -> bool: """ - 1. 
If the given branch of the condition node is a switch node, - then extract it if the reaching condition is redundant for the switch node. - 2. If the given branch of the condition node is a sequence node whose first or last node is a switch node, - then extract it if the reaching condition is redundant for the switch node. + We extract the given switch-node, if possible, and return whether it was successfully extracted. + + 1. If the switch node has a sequence node as parent and is its first or last child + i) Sequence node has a non-trivial reaching-condition + --> extract the switch from the sequence node if the reaching-condition is redundant for the switch + ii) Sequence node has a trivial reaching-condition, and its parent is a branch of a condition node + --> extract the switch from the condition-node if the branch-condition is redundant for the switch + 2. If the switch node has a branch of a condition-node as parent + --> extract the switch from the condition node if the branch-condition is redundant for the switch """ - branch_condition = branch.branch_condition - if self._condition_is_redundant_for_switch_node(branch.child, branch_condition): - self._extract_switch_node_from_branch(branch) - elif isinstance(sequence_node := branch.child, SeqNode): - for switch_node in [sequence_node.children[0], sequence_node.children[-1]]: - if self._condition_is_redundant_for_switch_node(switch_node, branch_condition): - self.asforest.extract_switch_from_condition_sequence(switch_node, self._current_cond_node) + switch_parent = switch_node.parent + if isinstance(switch_parent, SeqNode): + if not switch_parent.reaching_condition.is_true: + return self._successfully_extract_switch_from_first_or_last_child_of(switch_parent, switch_parent.reaching_condition) + elif isinstance(branch := switch_parent.parent, TrueNode | FalseNode): + return self._successfully_extract_switch_from_first_or_last_child_of(switch_parent, branch.branch_condition) + elif isinstance(switch_parent, TrueNode | FalseNode) and self._condition_is_redundant_for_switch_node( + switch_node, switch_parent.branch_condition + ): + self._extract_switch_node_from_branch(switch_parent) + return True + return False def _extract_switch_node_from_branch(self, branch: Union[TrueNode, FalseNode]) -> None: """ @@ -64,7 +57,20 @@ def _extract_switch_node_from_branch(self, branch: Union[TrueNode, FalseNode]) - :param branch: The branch from which we extract the switch node. :return: If we introduce a new sequence node, then return this node, otherwise return None. """ - if len(self._current_cond_node.children) != 2: - self.asforest.replace_condition_node_by_single_branch(self._current_cond_node) + assert isinstance(condition_node := branch.parent, ConditionNode), "The parent of a true/false-branch must be a condition node!" + if len(condition_node.children) != 2: + self.asforest.replace_condition_node_by_single_branch(condition_node) else: - self.asforest.extract_branch_from_condition_node(self._current_cond_node, branch, False) + self.asforest.extract_branch_from_condition_node(condition_node, branch, False) + + def _successfully_extract_switch_from_first_or_last_child_of(self, sequence_node: SeqNode, condition: LogicCondition) -> bool: + """ + Check whether the first or last child of the sequence node is a switch-node for which the given condition is redundant. + If this is the case, extract the switch-node from the sequence. 
+ """ + for switch_node in [sequence_node.children[0], sequence_node.children[-1]]: + if self._condition_is_redundant_for_switch_node(switch_node, condition): + assert isinstance(switch_node, SwitchNode), f"The node {switch_node} must be a switch-node!" + self.asforest.extract_switch_from_sequence(switch_node) + return True + return False diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_based_refinement.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_based_refinement.py index 97a58d87b..4a519f088 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_based_refinement.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_based_refinement.py @@ -6,9 +6,9 @@ from dataclasses import dataclass from itertools import chain, combinations -from typing import Dict, Iterator, List, Optional, Set, Tuple +from typing import Dict, Iterator, List, Literal, Optional, Set, Tuple -from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, SeqNode +from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, ConditionNode, SeqNode from decompiler.structures.ast.reachability_graph import SiblingReachability from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest from decompiler.structures.logic.logic_condition import LogicCondition @@ -23,17 +23,41 @@ class Formula: - setting eq to false implies that two objects are equal and have the same hash iff they are the same object """ - condition: LogicCondition ast_node: AbstractSyntaxTreeNode - def clauses(self) -> List[LogicCondition]: + @property + def is_if_else_formula(self) -> bool: + """ + Check whether the condition of the formula belongs to an if-else condition. + The condition-node can only be grouped if it has not reaching-condition. + """ + return self.ast_node.reaching_condition.is_true and not self.ast_node.is_single_branch + + @property + def condition(self) -> LogicCondition: + """ + Return the condition of the formula. + + If the AST-node of the formula as a reaching-condition that is not true, we return the reaching-condition. + Otherwise, the AS-node must be a condition node and we return its condition. + """ + if self.ast_node.reaching_condition.is_true: + assert isinstance(self.ast_node, ConditionNode), "The ast-node must be a condition node if the RC is true" + return self.ast_node.condition + return self.ast_node.reaching_condition + + def clauses(self) -> List[Clause]: """ Returns all clauses of the given formula in cnf-form. - formula = (a | b) & (c | d) & e, it returns [a | b, c | d, e] --> here each operand is a new logic-condition - formula = a | b | c, it returns [a | b | c] --> to ensure that we get a new logic-condition we copy it in this case. """ - return list(self.condition.operands) if self.condition.is_conjunction else [self.condition.copy()] + if self.is_if_else_formula: + return [ClauseFormula(self.ast_node.condition.copy(), self)] + else: + clauses = list(self.condition.operands) if self.condition.is_conjunction else [self.condition.copy()] + return [Clause(c, self) for c in clauses] @dataclass(frozen=True, eq=False) @@ -47,6 +71,14 @@ class Clause: formula: Formula +@dataclass(frozen=True, eq=False) +class ClauseFormula(Clause): + """ + Dataclass for logic-formula that can not be split into clauses for the grouping. 
+ - setting eq to false implies that two objects are equal and have the same hash iff they are the same object + """ + + @dataclass(frozen=True, eq=True) class Symbol: """ @@ -60,17 +92,24 @@ class Symbol: class ConditionCandidates: """A graph implementation handling conditions for the condition-based refinement algorithm.""" - def __init__(self, candidates: List[AbstractSyntaxTreeNode]) -> None: + def __init__(self, sequence_node: SeqNode, sibling_reachability: SiblingReachability) -> None: """ Init for the condition-candidates. - param candidates:: list of all AST-nodes that we want to cluster into conditions. + param sequence-node:: The sequence node whose children we want to cluster + param sibling_reachability: The sibling-reachability of the given sequence-node - candidates: maps all relevant ast-nodes to their formula (reaching condition) - unconsidered_nodes: a set of all nodes that we still have to consider for grouping into conditions. - logic_graph: representation of all logic-formulas relevant """ - self._candidates: Dict[AbstractSyntaxTreeNode, Formula] = {c: Formula(c.reaching_condition, c) for c in candidates} + self.sequence_node: SeqNode = sequence_node + self.sibling_reachability: SiblingReachability = sibling_reachability + self._candidates: Dict[AbstractSyntaxTreeNode, Formula] = { + child: Formula(child) + for child in sequence_node.children + if not child.reaching_condition.is_true or isinstance(child, ConditionNode) + } self._unconsidered_nodes: InsertionOrderedSet[AbstractSyntaxTreeNode] = InsertionOrderedSet() self._logic_graph: DiGraph = DiGraph() self._initialize_logic_graph() @@ -89,9 +128,9 @@ def _initialize_logic_graph(self) -> None: all_symbols = set() for formula in self._candidates.values(): self._logic_graph.add_node(formula) - for logic_clause in formula.clauses(): - self._logic_graph.add_edge(formula, clause := Clause(logic_clause, formula)) - for symbol_name in logic_clause.get_symbols_as_string(): + for clause in formula.clauses(): + self._logic_graph.add_edge(formula, clause) + for symbol_name in clause.condition.get_symbols_as_string(): self._logic_graph.add_edge(clause, symbol := Symbol(symbol_name)) self._logic_graph.add_edge(formula, symbol, auxiliary=True) all_symbols.add(symbol) @@ -102,6 +141,10 @@ def candidates(self) -> Iterator[AbstractSyntaxTreeNode]: """Iterates over all candidates considered for grouping into conditions.""" yield from self._candidates + def get_condition(self, ast_node: AbstractSyntaxTreeNode) -> Tuple[LogicCondition, bool]: + """Return the condition that is relevant for grouping into branches.""" + return self._candidates[ast_node].condition, self._candidates[ast_node].is_if_else_formula + def maximum_subexpression_size(self) -> int: """Returns the maximum possible subexpression that is relevant to consider for clustering into conditions.""" if len(self._candidates) < 2: @@ -112,6 +155,8 @@ def maximum_subexpression_size(self) -> int: def get_symbol_names_of(self, node: AbstractSyntaxTreeNode) -> Set[str]: """Return all symbols that are used in the formula of the given ast-node.""" + if node not in self._candidates: + return set() return {symbol.name for symbol in self._auxiliary_graph.successors(self._candidates[node])} def get_next_subexpression(self) -> Iterator[Tuple[AbstractSyntaxTreeNode, LogicCondition]]: @@ -124,12 +169,57 @@ def get_next_subexpression(self) -> Iterator[Tuple[AbstractSyntaxTreeNode, Logic while current_size > 0 and ast_node in self._candidates: for new_operands in combinations(clauses, 
current_size): yield ast_node, LogicCondition.conjunction_of(new_operands) + if ast_node not in self._candidates: + break current_size -= 1 - def remove_ast_nodes(self, nodes_to_remove: List[AbstractSyntaxTreeNode]) -> None: + def update_properties_for_integrating_second_node_into_first(self, cond_node: ConditionNode, merged_node: AbstractSyntaxTreeNode): + """ + Update the condition-candidate properties when the merged-node is integrated into the given condition-node. + + - Update the sibling-reachability + - Add the condition-node if it is new and update it otherwise + - remove the merged-node from the candidate list. + """ + self.sibling_reachability.merge_siblings_to(cond_node, [merged_node]) + if cond_node not in self._candidates: + self._add_ast_node(cond_node) + else: + self._update_ast_node(cond_node) + self._remove_ast_nodes([merged_node]) + + def _remove_ast_nodes(self, nodes_to_remove: List[AbstractSyntaxTreeNode]) -> None: """Remove formulas associated with the given nodes from the graph.""" self._remove_formulas(set(self._candidates[node] for node in nodes_to_remove)) + def _add_ast_node(self, condition_node: ConditionNode): + """Add new node to the logic-graph""" + formula = Formula(condition_node) + self._candidates[condition_node] = formula + self._unconsidered_nodes.add(condition_node) + self._logic_graph.add_node(formula) + for clause in formula.clauses(): + self._logic_graph.add_edge(formula, clause) + for symbol_name in clause.condition.get_symbols_as_string(): + self._logic_graph.add_edge(clause, symbol := Symbol(symbol_name)) + self._logic_graph.add_edge(formula, symbol, auxiliary=True) + + def _update_ast_node(self, ast_node: ConditionNode): + """ + Update the graph properties for the given condition-node. + + - If it was a single-branch condition node before and is now a condition node with two branches, then we have to update its clauses + in the logic-graph. + """ + assert ast_node in self._candidates, "The condition node must be a candidate." + formula = self._candidates[ast_node] + if not ast_node.is_single_branch and not all(isinstance(c, ClauseFormula) for c in self._formula_graph.successors(formula)): + assert len(clauses := formula.clauses()) == 1, "A non-single condition node should have one formula clause!" + self._logic_graph.remove_nodes_from(list(self._formula_graph.successors(formula))) + self._logic_graph.add_edge(ast_node, clauses[0]) + for symbol in self._auxiliary_graph.successors(formula): + self._logic_graph.add_edge(clauses[0], symbol) + @property def _auxiliary_graph(self) -> DiGraph: """Return a read-only view of the logic-graph containing only the auxiliary-edges, i.e., the edges between formulas and symbols.""" @@ -227,17 +317,20 @@ class ConditionBasedRefinement: Because ¬b1 ∨ ¬b2 is equivalent to ¬(b1∧b2) according to De Morgan's law. 
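# Side check (plain Python, not the LogicCondition API): the class docstring above appeals to
# De Morgan's law, i.e. that ¬b1 ∨ ¬b2 is equivalent to ¬(b1 ∧ b2). A one-line standalone
# verification over all truth assignments:
from itertools import product
assert all((not b1 or not b2) == (not (b1 and b2)) for b1, b2 in product((False, True), repeat=2))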
""" - def __init__(self, asforest: AbstractSyntaxForest): + def __init__(self, asforest: AbstractSyntaxForest, root: Optional[AbstractSyntaxTreeNode] = None): """Init an instance of the condition-based refinement.""" self.asforest: AbstractSyntaxForest = asforest - self.root: AbstractSyntaxTreeNode = asforest.current_root + self.root: AbstractSyntaxTreeNode = asforest.current_root if root is None else root + self._condition_candidates: Optional[ConditionCandidates] = None @classmethod - def refine(cls, asforest: AbstractSyntaxForest) -> None: - """Apply the condition-based-refinement to the given abstract-syntax-forest.""" - if not isinstance(asforest.current_root, SeqNode): + def refine(cls, asforest: AbstractSyntaxForest, root: Optional[AbstractSyntaxTreeNode] = None) -> None: + """Apply the condition-based-refinement to the given abstract-syntax-forest starting at the given root.""" + if root is None: + root = asforest.current_root + if not isinstance(root, SeqNode): return - if_refinement = cls(asforest) + if_refinement = cls(asforest, root) if_refinement._condition_based_refinement() def _condition_based_refinement(self) -> None: @@ -297,51 +390,88 @@ def _structure_sequence_node(self, sequence_node: SeqNode) -> Set[SeqNode]: """ newly_created_sequence_nodes: Set[SeqNode] = set() sibling_reachability: SiblingReachability = self.asforest.get_sibling_reachability_of_children_of(sequence_node) - condition_candidates = ConditionCandidates([child for child in sequence_node.children if not child.reaching_condition.is_true]) - for child, subexpression in condition_candidates.get_next_subexpression(): - true_cluster, false_cluster = self._cluster_by_condition(subexpression, child, condition_candidates) - all_cluster_nodes = true_cluster + false_cluster - - if len(all_cluster_nodes) < 2: - continue - if self._can_place_condition_node_with_branches(all_cluster_nodes, sibling_reachability): - condition_node = self.asforest.create_condition_node_with(subexpression, true_cluster, false_cluster) - if len(true_cluster) > 1: - newly_created_sequence_nodes.add(condition_node.true_branch_child) - if len(false_cluster) > 1: - newly_created_sequence_nodes.add(condition_node.false_branch_child) - sibling_reachability.merge_siblings_to(condition_node, all_cluster_nodes) - sequence_node._sorted_children = sibling_reachability.sorted_nodes() - condition_candidates.remove_ast_nodes(all_cluster_nodes) + self._condition_candidates = ConditionCandidates(sequence_node, sibling_reachability) + for child, subexpression in self._condition_candidates.get_next_subexpression(): + newly_created_sequence_nodes.update(self._cluster_by_condition(subexpression, child)) return newly_created_sequence_nodes - def _cluster_by_condition( - self, sub_expression: LogicCondition, node_with_subexpression: AbstractSyntaxTreeNode, condition_candidates: ConditionCandidates - ) -> Tuple[List[AbstractSyntaxTreeNode], List[AbstractSyntaxTreeNode]]: + def _cluster_by_condition(self, sub_expression: LogicCondition, current_node: AbstractSyntaxTreeNode) -> List[SeqNode]: """ - Cluster the nodes in sequence_nodes according to the input condition. + Cluster the nodes of the current-sequence node according to the input condition belonging to the given ast-node. - :param sub_expression: The condition for which we check whether it or its negation is a subexpression of the list of input nodes. 
- :param node_with_subexpression: The node of which the given sub_expression is a sub-expression - :param condition_candidates: class-object handling all condition candidates. - :return: A 2-tuple, where the first list is the set of nodes that have condition as subexpression, the second list is the set of - nodes that have the negated condition as subexpression. + :param sub_expression: The condition for which we check whether it or its negation is a subexpression of a sequence-node child. + :param current_node: The node of which the given sub_expression is a sub-expression. + :return: A list of all newly created sequence-nodes that should be considered for a future iteration of the CBR. """ - true_children = [] - false_children = [] symbols_of_condition = set(sub_expression.get_symbols_as_string()) - negated_condition = None - for ast_node in condition_candidates.candidates: - if symbols_of_condition - condition_candidates.get_symbol_names_of(ast_node): + negated_condition: Optional[LogicCondition] = None + for ast_node in [candidate for candidate in self._condition_candidates.candidates if candidate != current_node]: + if symbols_of_condition - self._condition_candidates.get_symbol_names_of( + ast_node + ) or not self._can_place_condition_node_with_branches( + [current_node, ast_node], self._condition_candidates.sibling_reachability + ): continue - if ast_node == node_with_subexpression or self._is_subexpression_of_cnf_formula(sub_expression, ast_node.reaching_condition): - true_children.append(ast_node) + if self._is_possible_branch(ast_node, sub_expression): + current_node = self._add_condition_node_if_needed(current_node, sub_expression) + self._add_node_to_condition(current_node, ast_node, "true") + elif self._is_possible_branch(ast_node, negated_condition := self._get_negated_condition_of(sub_expression, negated_condition)): + current_node = self._add_condition_node_if_needed(current_node, sub_expression) + self._add_node_to_condition(current_node, ast_node, "false") + + if isinstance(current_node, ConditionNode): + return [branch.child for branch in current_node.children if isinstance(branch.child, SeqNode)] + return [] + + def _is_possible_branch(self, ast_node: AbstractSyntaxTreeNode, sub_expression: LogicCondition) -> bool: + """ + Check whether the given ast-node is a possible branch for a condition-node where one branch-condition is the given sub-expression. + """ + condition, is_if_else_node = self._condition_candidates.get_condition(ast_node) + return (not is_if_else_node and self._is_subexpression_of_cnf_formula(sub_expression, condition)) or ( + is_if_else_node and sub_expression.is_equivalent_to(condition) + ) + + def _add_condition_node_if_needed(self, ast_node: AbstractSyntaxTreeNode, sub_expression: LogicCondition) -> ConditionNode: + """ + If the given AST-node is not a condition-node whose condition is equal to the sub-expression, + then we create a condition node with the given sub-expression as condition and the ast-node as true-child. + We return a condition node having the given sub-expression as condition.
+ """ + if not isinstance(ast_node, ConditionNode) or not sub_expression.is_equal_to(ast_node.condition): + tmp = ast_node + ast_node = self.asforest.create_condition_node_with(sub_expression, [ast_node], []) + self._condition_candidates.update_properties_for_integrating_second_node_into_first(ast_node, tmp) + return ast_node + + def _add_node_to_condition(self, condition_node: ConditionNode, ast_node: AbstractSyntaxTreeNode, branch: Literal["true", "false"]): + """ + Add the given ast-node as a branch to the given condition-node. + + - If the branch is "true" the ast-node will be part of the true-branch of the given condition-node. + - If the branch is "false" the ast-node will be part of the false-branch of the given condition-node. + """ + true_cluster, false_cluster = None, None + if isinstance(ast_node, ConditionNode) and ast_node.reaching_condition.is_true: + true_cluster = ast_node.true_branch_child + if ast_node.false_branch: + false_cluster = ast_node.false_branch_child else: - negated_condition = self._get_negated_condition_of(sub_expression, negated_condition) - if self._is_subexpression_of_cnf_formula(negated_condition, ast_node.reaching_condition): - false_children.append(ast_node) - return true_children, false_children + assert true_cluster.reaching_condition.is_true, "single-branch Conditions should not have a RC at this point." + true_cluster.reaching_condition = ast_node.true_branch.branch_condition.copy() + true_expression = condition_node.condition if branch == "true" else ~condition_node.condition + true_cluster.reaching_condition.substitute_by_true(true_expression) + else: + true_cluster = ast_node + true_expression = condition_node.condition if branch == "true" else ~condition_node.condition + true_cluster.reaching_condition.substitute_by_true(true_expression) + if branch == "false": + true_cluster, false_cluster = false_cluster, true_cluster + + self.asforest.add_branches_to_condition_node(condition_node, true_cluster, false_cluster) + self._condition_candidates.update_properties_for_integrating_second_node_into_first(condition_node, ast_node) + self._condition_candidates.sequence_node._sorted_children = self._condition_candidates.sibling_reachability.sorted_nodes() @staticmethod def _get_negated_condition_of(condition: LogicCondition, negated_condition: Optional[LogicCondition]) -> LogicCondition: diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/loop_structuring_rules.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/loop_structuring_rules.py index bd1b31cc5..e1a17ee13 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/loop_structuring_rules.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/loop_structuring_rules.py @@ -119,8 +119,7 @@ def can_be_applied(loop_node: AbstractSyntaxTreeNode): return ( loop_node.is_endless_loop and isinstance(body := loop_node.body, SeqNode) - and isinstance(condition_node := body.children[-1], ConditionNode) - and len(condition_node.children) == 1 + and body.children[-1].is_single_branch and not any(child._has_descendant_code_node_breaking_ancestor_loop() for child in body.children[:-1]) ) diff --git a/decompiler/pipeline/dataflowanalysis/type_propagation.py b/decompiler/pipeline/dataflowanalysis/type_propagation.py index d234cb01c..2d093bf14 100644 --- a/decompiler/pipeline/dataflowanalysis/type_propagation.py +++ b/decompiler/pipeline/dataflowanalysis/type_propagation.py @@ -2,11 +2,11 @@ from __future__ import annotations -from collections import 
Counter, defaultdict +from collections import Counter from enum import Enum from itertools import chain from logging import info -from typing import DefaultDict, Iterator, List, Set, Tuple +from typing import Iterator, List, Tuple from decompiler.pipeline.stage import PipelineStage from decompiler.structures.graphs.cfg import ControlFlowGraph @@ -29,7 +29,6 @@ class EdgeType(Enum): def __init__(self, **attr): """Generate a new TypeGraph, appending a dict for usage tracking.""" super().__init__(**attr) - self._usages: DefaultDict[Expression, Set] = defaultdict(set) @classmethod def from_cfg(cls, cfg: ControlFlowGraph) -> TypeGraph: @@ -57,7 +56,6 @@ def add_expression(self, expression: Expression, parent: Instruction): while todo: head = todo.pop() self.add_node(self._make_node(head), **{str(id(head)): head}) - self._usages[self._make_node(head)].add(parent) children = list(head) todo.extend(children) for sub_expression in children: diff --git a/decompiler/pipeline/ssa/dependency_graph.py b/decompiler/pipeline/ssa/dependency_graph.py index 928d84fa3..941f67f05 100644 --- a/decompiler/pipeline/ssa/dependency_graph.py +++ b/decompiler/pipeline/ssa/dependency_graph.py @@ -1,74 +1,114 @@ -from typing import Iterable, List, Optional, Set +import itertools +from itertools import combinations +from typing import Iterator +import networkx from decompiler.structures.graphs.cfg import ControlFlowGraph from decompiler.structures.interferencegraph import InterferenceGraph +from decompiler.structures.pseudo import Expression, Operation, OperationType from decompiler.structures.pseudo.expressions import Variable from decompiler.structures.pseudo.instructions import Assignment -from decompiler.structures.pseudo.operations import Call -from networkx import DiGraph, weakly_connected_components +from decompiler.util.decoration import DecoratedGraph +from networkx import MultiDiGraph +# Multiplicative constant applied to dependency scores when encountering operations, to penalize too much nesting. OPERATION_PENALTY = 0.9 -def _non_call_assignments(cfg: ControlFlowGraph) -> Iterable[Assignment]: + +def decorate_dependency_graph(dependency_graph: MultiDiGraph, interference_graph: InterferenceGraph) -> DecoratedGraph: + """ + Creates a decorated graph from the given dependency and interference graphs. + + This function constructs a new graph where: + - Variables are represented as nodes. + - Dependencies between variables are represented as directed edges. + - Interferences between variables are represented as red, undirected edges. + """ + decorated_graph = MultiDiGraph() + for node in dependency_graph.nodes: + decorated_graph.add_node(hash(node), label="\n".join(map(lambda n: f"{n}: {n.type}, aliased: {n.is_aliased}", node))) + for u, v, data in dependency_graph.edges.data(): + decorated_graph.add_edge(hash(u), hash(v), label=f"{data['score']}") + for nodes in networkx.weakly_connected_components(dependency_graph): + for node_1, node_2 in combinations(nodes, 2): + if any(interference_graph.has_edge(pair[0], pair[1]) for pair in itertools.product(node_1, node_2)): + decorated_graph.add_edge(hash(node_1), hash(node_2), color="red", dir="none") + + return DecoratedGraph(decorated_graph) + + +def dependency_graph_from_cfg(cfg: ControlFlowGraph) -> MultiDiGraph: + """ + Construct the dependency graph of the given CFG, i.e. adds an edge between two variables if they depend on each other. + - Add an edge from the definition to at most one requirement for each instruction. 
+ - All variables that were not defined via Phi-functions before have out-degree of at most 1, because they are defined at most once. + - Variables that are defined via Phi-functions can have one successor for each required variable of the Phi-function. + """ + dependency_graph = MultiDiGraph() + + for variable in _collect_variables(cfg): + dependency_graph.add_node((variable,)) + for instruction in _assignments_in_cfg(cfg): + defined_variables = instruction.definitions + for used_variable, score in _expression_dependencies(instruction.value).items(): + if score > 0: + dependency_graph.add_edges_from((((dvar,), (used_variable,)) for dvar in defined_variables), score=score) + + return dependency_graph + + +def _collect_variables(cfg: ControlFlowGraph) -> Iterator[Variable]: + """ + Yields all variables contained in the given control flow graph. + """ + for instruction in cfg.instructions: + for subexpression in instruction.subexpressions(): + if isinstance(subexpression, Variable): + yield subexpression + + +def _assignments_in_cfg(cfg: ControlFlowGraph) -> Iterator[Assignment]: """Yield all interesting assignments for the dependency graph.""" for instr in cfg.instructions: - if isinstance(instr, Assignment) and isinstance(instr.destination, Variable) and not isinstance(instr.value, Call): + if isinstance(instr, Assignment): yield instr -class DependencyGraph(DiGraph): - def __init__(self, interference_graph: Optional[InterferenceGraph] = None): - super().__init__() - self.add_nodes_from(interference_graph.nodes) - self.interference_graph = interference_graph - - @classmethod - def from_cfg(cls, cfg: ControlFlowGraph, interference_graph: InterferenceGraph): - """ - Construct the dependency graph of the given CFG, i.e. adds an edge between two variables if they depend on each other. - - Add an edge the definition to at most one requirement for each instruction. - - All variables that where not defined via Phi-functions before have out-degree at most 1, because they are defined at most once 
- """ - dependency_graph = cls(interference_graph) - for instruction in _non_call_assignments(cfg): - defined_variable = instruction.destination - if isinstance(instruction.value, Variable): - if dependency_graph._variables_can_have_same_name(defined_variable, instruction.value): - dependency_graph.add_edge(defined_variable, instruction.requirements[0], strength="high") - elif len(instruction.requirements) == 1: - if dependency_graph._variables_can_have_same_name(defined_variable, instruction.requirements[0]): - dependency_graph.add_edge(defined_variable, instruction.requirements[0], strength="medium") - else: - if non_interfering_variable := dependency_graph._non_interfering_requirements(instruction.requirements, defined_variable): - dependency_graph.add_edge(defined_variable, non_interfering_variable, strength="low") - return dependency_graph - - def _non_interfering_requirements(self, requirements: List[Variable], defined_variable: Variable) -> Optional[Variable]: - """Get the unique non-interfering requirement if it exists, otherwise we return None.""" - non_interfering_requirement = None - for required_variable in requirements: - if self._variables_can_have_same_name(defined_variable, required_variable): - if non_interfering_requirement: - return None - non_interfering_requirement = required_variable - return non_interfering_requirement - - def _variables_can_have_same_name(self, source: Variable, sink: Variable) -> bool: - """ - Two variable can have the same name, if they have the same type, are both aliased or both non-aliased variables, and if they - do not interfere. - - :param source: The potential source vertex. - :param sink: The potential sink vertex - :return: True, if the given variables can have the same name, and false otherwise. - """ - if self.interference_graph.are_interfering(source, sink) or source.type != sink.type or source.is_aliased != sink.is_aliased: - return False - if source.is_aliased and sink.is_aliased and source.name != sink.name: - return False - return True - - def get_components(self) -> Iterable[Set[Variable]]: - """Returns the weakly connected components of the dependency graph.""" - for component in weakly_connected_components(self): - yield set(component) +def _expression_dependencies(expression: Expression) -> dict[Variable, float]: + """ + Calculate the dependencies of an expression in terms of its constituent variables. + + This function analyzes the given `expression` and returns a dictionary mapping each + `Variable` to a float score representing its contribution or dependency weight within + the expression. + The scoring mechanism accounts for different types of operations and + penalizes nested operations to reflect their complexity. 
+ """ + match expression: + case Variable(): + return {expression: 1.0} + case Operation(): + if expression.operation in { + OperationType.call, + OperationType.address, + OperationType.dereference, + OperationType.member_access, + }: + return {} + + operands_dependencies = list(filter(lambda d: d, (_expression_dependencies(operand) for operand in expression.operands))) + dependencies: dict[Variable, float] = {} + for deps in operands_dependencies: + for var in deps: + score = deps[var] + score /= len(operands_dependencies) + score *= OPERATION_PENALTY # penalize operations, so that expressions like (a + (a + (a + (a + a)))) gets a lower score than just (a) + + if var not in dependencies: + dependencies[var] = score + else: + dependencies[var] += score + + return dependencies + case _: + return {} diff --git a/decompiler/pipeline/ssa/outofssatranslation.py b/decompiler/pipeline/ssa/outofssatranslation.py index cd76fe4c5..e6007fd58 100644 --- a/decompiler/pipeline/ssa/outofssatranslation.py +++ b/decompiler/pipeline/ssa/outofssatranslation.py @@ -4,12 +4,12 @@ from collections import defaultdict from configparser import NoOptionError from enum import Enum -from typing import DefaultDict, List +from typing import Callable, DefaultDict, List from decompiler.pipeline.ssa.phi_cleaner import PhiFunctionCleaner from decompiler.pipeline.ssa.phi_dependency_resolver import PhiDependencyResolver from decompiler.pipeline.ssa.phi_lifting import PhiFunctionLifter -from decompiler.pipeline.ssa.variable_renaming import MinimalVariableRenamer, SimpleVariableRenamer +from decompiler.pipeline.ssa.variable_renaming import ConditionalVariableRenamer, MinimalVariableRenamer, SimpleVariableRenamer from decompiler.pipeline.stage import PipelineStage from decompiler.structures.graphs.cfg import BasicBlock from decompiler.structures.interferencegraph import InterferenceGraph @@ -98,12 +98,13 @@ def _out_of_ssa(self) -> None: -> There are different optimization levels """ - try: - self.out_of_ssa_strategy[self._optimization](self) - except KeyError: - error_message = f"The Out of SSA according to the optimization level {self._optimization.value} is not implemented so far." - logging.error(error_message) - raise NotImplementedError(error_message) + strategy = self.out_of_ssa_strategy.get(self._optimization, None) + if strategy is None: + raise NotImplementedError( + f"The Out of SSA according to the optimization level {self._optimization.value} is not implemented so far." + ) + + strategy(self) def _simple_out_of_ssa(self) -> None: """ @@ -158,12 +159,15 @@ def _conditional_out_of_ssa(self) -> None: This is a more advanced algorithm for out of SSA: - We first remove the circular dependency of the Phi-functions - Then, we remove the Phi-functions by lifting them to their predecessor basic blocks. - - Afterwards, we rename the variables, by considering their dependency on each other. + - Afterwards, we rename the variables by considering their dependency on each other. """ - pass + PhiDependencyResolver(self._phi_functions_of).resolve() + self.interference_graph = InterferenceGraph(self.task.graph) + PhiFunctionLifter(self.task.graph, self.interference_graph, self._phi_functions_of).lift() + ConditionalVariableRenamer(self.task, self.interference_graph).rename() # This translator maps the optimization levels to the functions. 
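# Minimal, self-contained sketch of the scoring idea in `_expression_dependencies` from dependency_graph.py
# above: nested tuples such as ("+", "a", ("+", "a", "a")) stand in here for Operation/Variable objects,
# so this is only an illustration of the recursion, not the project's API.
OPERATION_PENALTY = 0.9

def toy_dependencies(expression) -> dict:
    if isinstance(expression, str):  # a plain "variable" contributes a full score of 1.0
        return {expression: 1.0}
    operand_deps = [deps for deps in (toy_dependencies(op) for op in expression[1:]) if deps]
    dependencies: dict = {}
    for deps in operand_deps:
        for var, score in deps.items():
            # split the score between operands and apply the nesting penalty
            dependencies[var] = dependencies.get(var, 0.0) + score / len(operand_deps) * OPERATION_PENALTY
    return dependencies

print(toy_dependencies("a"))                          # {'a': 1.0}
print(toy_dependencies(("+", "a", ("+", "a", "a"))))  # roughly {'a': 0.855} -- nesting lowers the score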
- out_of_ssa_strategy = { + out_of_ssa_strategy: dict[SSAOptions, Callable[["OutOfSsaTranslation"], None]] = { SSAOptions.simple: _simple_out_of_ssa, SSAOptions.minimization: _minimization_out_of_ssa, SSAOptions.lift_minimal: _lift_minimal_out_of_ssa, diff --git a/decompiler/pipeline/ssa/variable_renaming.py b/decompiler/pipeline/ssa/variable_renaming.py index 910e0d6cf..8280cacb6 100644 --- a/decompiler/pipeline/ssa/variable_renaming.py +++ b/decompiler/pipeline/ssa/variable_renaming.py @@ -7,6 +7,9 @@ from operator import attrgetter from typing import DefaultDict, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union +import networkx +from decompiler.pipeline.ssa.dependency_graph import dependency_graph_from_cfg +from decompiler.structures.graphs.cfg import ControlFlowGraph from decompiler.structures.interferencegraph import InterferenceGraph from decompiler.structures.pseudo.expressions import GlobalVariable, Variable from decompiler.structures.pseudo.instructions import BaseAssignment, Instruction, Relation @@ -14,7 +17,7 @@ from decompiler.task import DecompilerTask from decompiler.util.insertion_ordered_set import InsertionOrderedSet from decompiler.util.lexicographical_bfs import LexicographicalBFS -from networkx import Graph, connected_components +from networkx import Graph, MultiDiGraph, connected_components @dataclass @@ -121,10 +124,11 @@ def rename(self): def _replace_variable_in_instruction(self, variable: Variable, instruction: Instruction) -> None: """Replace the given variable in the given instruction""" - if variable.ssa_label is None: + if variable not in self.renaming_map: return replacement_variable = self.renaming_map[variable].copy() - replacement_variable.ssa_name = variable.copy() + if variable.ssa_label is not None: + replacement_variable.ssa_name = variable.copy() instruction.substitute(variable, replacement_variable) if isinstance(instruction, Relation): instruction.rename(variable, replacement_variable) @@ -334,3 +338,91 @@ def _classes_of(self, neighborhood: Iterable[Variable]) -> Iterable[Variable]: for neighbor in neighborhood: if neighbor in self._variable_classes_handler.color_class_of: yield self._variable_classes_handler.color_class_of[neighbor] + + +class ConditionalVariableRenamer(VariableRenamer): + """ + A renaming strategy that renames the SSA-variables, such that only variables that have a relation with each other can get the same name. + Therefore, we construct a dependency-graph with weights, telling us how likely it is that two variables are the same variable, i.e., + copy-assignments are more likely to be identical than complicated computations. + """ + + def __init__(self, task, interference_graph: InterferenceGraph): + """ + self._color_classes is a dictionary where the set of keys is the set of colors + and to each color we assign the set of variables of this color. + """ + super().__init__(task, interference_graph.copy()) + self._generate_renaming_map(task.graph) + + def _generate_renaming_map(self, cfg: ControlFlowGraph): + """ + Generate the renaming map for SSA variables. + + This function constructs a dependency graph from the given CFG, merges contracted variables, + creates variable classes, and computes new names for each variable. The process ensures that + only variables with specific relationships can share the same name, as determined by the + dependency graph. + + :param cfg: The control flow graph from which the dependency graph is derived.
+ """ + dependency_graph = dependency_graph_from_cfg(cfg) + dependency_graph = self.merge_contracted_variables(dependency_graph) + + self.create_variable_classes(dependency_graph) + self.compute_new_name_for_each_variable() + + def merge_contracted_variables(self, dependency_graph: MultiDiGraph): + """Merge nodes which need to be contracted from self._variables_contracted_to""" + mapping: dict[tuple[Variable], tuple[Variable, ...]] = {} + for variable in self.interference_graph.nodes(): + contracted = tuple(self._variables_contracted_to[variable]) + for var in contracted: + mapping[(var,)] = contracted + + return networkx.relabel_nodes(dependency_graph, mapping) + + def create_variable_classes(self, dependency_graph: MultiDiGraph): + """Create the variable classes based on the given dependency graph.""" + while True: + merged_edges: dict[frozenset[tuple[Variable, ...]], float] = defaultdict(lambda: 0) + for u, v, score in dependency_graph.edges(data="score"): + if u != v: + merged_edges[frozenset([u, v])] += score + + for (u, v), _ in sorted(merged_edges.items(), key=lambda edge: edge[1], reverse=True): + if u == v: # self loop + continue + if not self._variables_can_have_same_name(u, v): + continue + + break + else: + # We didn't find any remaining nodes to contract, break outer loop + break + + networkx.relabel_nodes(dependency_graph, {u: (*u, *v), v: (*u, *v)}, copy=False) + + self._variable_classes_handler = VariableClassesHandler(defaultdict(set)) + for i, vars in enumerate(dependency_graph.nodes): + for var in vars: + self._variable_classes_handler.add_variable_to_class(var, i) + + def _variables_can_have_same_name(self, source: tuple[Variable, ...], sink: tuple[Variable, ...]) -> bool: + """ + Two sets of variables can have the same name, if they have the same type, are both aliased or both non-aliased variables, and if they + do not interfere. + + :param source: The potential source vertex. + :param sink: The potential sink vertex + :return: True, if the given sets of variables can have the same name, and false otherwise. 
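# Standalone sketch of the greedy contraction loop in `create_variable_classes` above: sum the scores of
# parallel edges, pick the heaviest pair, and merge its endpoints into one tuple-node via
# networkx.relabel_nodes. Toy string "variables" are used here instead of the project's Variable class.
from collections import defaultdict
import networkx

graph = networkx.MultiDiGraph()
graph.add_edge(("a",), ("b",), score=1.0)   # a depends strongly on b (plain copy)
graph.add_edge(("b",), ("c",), score=0.45)  # b depends on c with a penalized score

merged_edges: dict = defaultdict(float)
for u, v, score in graph.edges(data="score"):
    if u != v:
        merged_edges[frozenset((u, v))] += score
u, v = tuple(max(merged_edges, key=merged_edges.get))
networkx.relabel_nodes(graph, {u: (*u, *v), v: (*u, *v)}, copy=False)
print(list(graph.nodes))  # e.g. [('a', 'b'), ('c',)] -- 'a' and 'b' now form one naming class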
+ """ + if ( + self.interference_graph.are_interfering(*(source + sink)) + or source[0].type != sink[0].type + or source[0].is_aliased != sink[0].is_aliased + ): + return False + if source[0].is_aliased and sink[0].is_aliased and source[0].name != sink[0].name: + return False + return True diff --git a/decompiler/structures/ast/ast_nodes.py b/decompiler/structures/ast/ast_nodes.py index 77b611fa4..4d0c63888 100644 --- a/decompiler/structures/ast/ast_nodes.py +++ b/decompiler/structures/ast/ast_nodes.py @@ -166,6 +166,11 @@ def is_code_node_ending_with_return(self) -> bool: """Checks whether the node is a CodeNode and ends with a return.""" return isinstance(self, CodeNode) and self.does_end_with_return + @property + def is_single_branch(self) -> bool: + """Check whether the node is a condition node with one branch.""" + return isinstance(self, ConditionNode) and len(self.children) == 1 + def get_end_nodes(self) -> Iterable[Union[CodeNode, SwitchNode, LoopNode, ConditionNode]]: """Yields all nodes where the subtree can terminate.""" for child in self.children: @@ -261,7 +266,7 @@ def is_empty(self) -> bool: def copy(self) -> VirtualRootNode: """Return a copy of the ast node.""" - return VirtualRootNode(self.reaching_condition) + return VirtualRootNode(self.reaching_condition.copy()) def accept(self, visitor: ASTVisitorInterface[T]) -> T: return visitor.visit_root_node(self) @@ -288,7 +293,7 @@ def __repr__(self) -> str: def copy(self) -> SeqNode: """Return a copy of the ast node.""" - return SeqNode(self.reaching_condition) + return SeqNode(self.reaching_condition.copy()) @property def children(self) -> Tuple[AbstractSyntaxTreeNode, ...]: @@ -375,7 +380,7 @@ def __repr__(self) -> str: def copy(self) -> CodeNode: """Return a copy of the ast node.""" - return CodeNode(self.instructions.copy(), self.reaching_condition) + return CodeNode([i.copy() for i in self.instructions], self.reaching_condition.copy()) @property def children(self) -> Tuple[AbstractSyntaxTreeNode, ...]: @@ -508,7 +513,7 @@ def __repr__(self) -> str: def copy(self) -> ConditionNode: """Return a copy of the ast node.""" - return ConditionNode(self.condition, self.reaching_condition) + return ConditionNode(self.condition.copy(), self.reaching_condition.copy()) @property def children(self) -> Tuple[Union[TrueNode, FalseNode], ...]: @@ -655,7 +660,7 @@ def __repr__(self) -> str: def copy(self) -> TrueNode: """Return a copy of the ast node.""" - return TrueNode(self.reaching_condition) + return TrueNode(self.reaching_condition.copy()) @property def branch_condition(self) -> LogicCondition: @@ -680,7 +685,7 @@ def __repr__(self) -> str: def copy(self) -> FalseNode: """Return a copy of the ast node.""" - return FalseNode(self.reaching_condition) + return FalseNode(self.reaching_condition.copy()) @property def branch_condition(self) -> LogicCondition: @@ -810,7 +815,7 @@ class WhileLoopNode(LoopNode): def copy(self) -> WhileLoopNode: """Return a copy of the ast node.""" - return WhileLoopNode(self.condition, self.reaching_condition) + return WhileLoopNode(self.condition.copy(), self.reaching_condition.copy()) @property def loop_type(self) -> LoopType: @@ -837,7 +842,7 @@ class DoWhileLoopNode(LoopNode): def copy(self) -> DoWhileLoopNode: """Return a copy of the ast node.""" - return DoWhileLoopNode(self.condition, self.reaching_condition) + return DoWhileLoopNode(self.condition.copy(), self.reaching_condition.copy()) @property def loop_type(self) -> LoopType: @@ -882,7 +887,7 @@ def __str__(self) -> str: def copy(self) -> 
ForLoopNode: """Return a copy of the ast node.""" - return ForLoopNode(self.declaration, self.condition, self.modification, self.reaching_condition) + return ForLoopNode(self.declaration.copy(), self.condition.copy(), self.modification.copy(), self.reaching_condition.copy()) @property def loop_type(self) -> LoopType: @@ -950,7 +955,7 @@ def __repr__(self) -> str: def copy(self) -> SwitchNode: """Return a copy of the ast node.""" - return SwitchNode(self.expression, self.reaching_condition) + return SwitchNode(self.expression.copy(), self.reaching_condition.copy()) @property def children(self) -> Tuple[CaseNode]: @@ -1084,7 +1089,7 @@ def __repr__(self) -> str: def copy(self) -> CaseNode: """Return a copy of the ast node.""" - return CaseNode(self.expression, self.constant, self.reaching_condition, self.break_case) + return CaseNode(self.expression.copy(), self.constant.copy(), self.reaching_condition.copy(), self.break_case) @property def does_end_with_break(self) -> bool: diff --git a/decompiler/structures/ast/condition_symbol.py b/decompiler/structures/ast/condition_symbol.py index 8c9c2c374..c6e45b6e4 100644 --- a/decompiler/structures/ast/condition_symbol.py +++ b/decompiler/structures/ast/condition_symbol.py @@ -1,29 +1,180 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Dict, Iterable, Optional, Set +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union from decompiler.structures.logic.logic_condition import LogicCondition, PseudoLogicCondition -from decompiler.structures.pseudo import Condition +from decompiler.structures.logic.z3_implementations import Z3Implementation +from decompiler.structures.pseudo import Condition, Constant, Expression, OperationType, Variable, Z3Converter +from z3 import BoolRef + + +def _is_equivalent(cond1: BoolRef, cond2: BoolRef): + """Check whether the given conditions are equivalent.""" + z3_implementation = Z3Implementation(True) + if z3_implementation.is_equal(cond1, cond2): + return True + return z3_implementation.does_imply(cond1, cond2) and z3_implementation.does_imply(cond2, cond1) + + +def _get_ssa_expression(expression_usage: ExpressionUsages) -> Expression: + """Construct SSA-expression of the given expression.""" + if isinstance(expression_usage.expression, Variable): + return expression_usage.expression.ssa_name if expression_usage.expression.ssa_name else expression_usage.expression + ssa_expression = expression_usage.expression.copy() + for variable in [var for var in ssa_expression.requirements if var.ssa_name]: + ssa_expression.substitute(variable, variable.ssa_name) + return ssa_expression + + +@dataclass(frozen=True) +class ExpressionUsages: + """Dataclass maintaining for a condition the used SSA-variables.""" + + expression: Expression + ssa_usages: Tuple[Optional[Variable], ...] 
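# Why ExpressionUsages is frozen and stores a *tuple* of SSA names: instances must be hashable so they
# can serve as dictionary keys (e.g. in zero_case_of_switch_expression below). A stand-alone analogue
# with plain strings instead of pseudo-expressions, for illustration only:
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class ToyUsages:
    expression: str                        # stands in for a pseudo Expression
    ssa_usages: Tuple[Optional[str], ...]  # a tuple (not a list) keeps the instance hashable

cache = {ToyUsages("x + 1", ("x#3",)): "zero-case"}
assert ToyUsages("x + 1", ("x#3",)) in cache  # value equality plus hashing make the lookup work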
+ + @classmethod + def from_expression(cls, expression: Expression) -> ExpressionUsages: + return ExpressionUsages(expression, tuple(var.ssa_name for var in expression.requirements)) @dataclass(frozen=True) +class ZeroCaseCondition: + """Possible switch expression together with its zero-case condition.""" + + expression: Expression + ssa_usages: Set[Optional[Variable]] + z3_condition: BoolRef + + def are_equivalent(self, other: Union[ZeroCaseCondition, PotentialZeroCaseCondition]) -> bool: + return self.ssa_usages == other.ssa_usages and _is_equivalent(self.z3_condition, other.z3_condition) + + +@dataclass(frozen=True) +class PotentialZeroCaseCondition: + """Possible zero-case condition with its z3-condition and ssa-usages.""" + + expression: Condition + ssa_usages: Set[Optional[Variable]] + z3_condition: BoolRef + + def are_equivalent(self, other: Union[ZeroCaseCondition, PotentialZeroCaseCondition]) -> bool: + return self.ssa_usages == other.ssa_usages and _is_equivalent(self.z3_condition, other.z3_condition) + + +@dataclass(frozen=True) +class CaseNodeProperties: + """ + Class for mapping possible expression and constant of a symbol for a switch-case. + + -> symbol: symbol that belongs to the expression and constant + -> constant: the compared constant + -> negation: whether the symbol or its negation belongs to a switch-case + -> The condition that the new case node should get. + """ + + symbol: LogicCondition + expression: ExpressionUsages + constant: Constant + negation: bool + + def __eq__(self, other) -> bool: + """ + We want to be able to compare CaseNodeCandidates with AST-nodes, more precisely, + we want that an CaseNodeCandidate 'case_node' is equal to the AST node 'case_node.node'. + """ + if isinstance(other, CaseNodeProperties): + return self.symbol == other.symbol + return False + + def copy(self) -> CaseNodeProperties: + return CaseNodeProperties(self.symbol, self.expression, self.constant, self.negation) + + +@dataclass class ConditionSymbol: """Dataclass that maintains for each symbol the according condition and its transition in a z3-condition.""" - condition: Condition - symbol: LogicCondition + _condition: Condition + _symbol: LogicCondition z3_condition: PseudoLogicCondition + case_node_property: Optional[CaseNodeProperties] = None + + @property + def condition(self) -> Condition: + return self._condition + + @property + def symbol(self) -> LogicCondition: + return self._symbol + + def __hash__(self) -> int: + return hash((self.condition, self.symbol)) def __eq__(self, other): """Check whether two condition-symbols are equal.""" - return ( - isinstance(other, ConditionSymbol) - and self.condition == other.condition - and self.symbol == other.symbol - and self.z3_condition.is_equivalent_to(other.z3_condition) + return isinstance(other, ConditionSymbol) and self.condition == other.condition and self.symbol == other.symbol + + +@dataclass +class SwitchHandler: + z3_converter: Z3Converter + zero_case_of_switch_expression: Dict[ExpressionUsages, ZeroCaseCondition] + potential_zero_cases: Dict[ConditionSymbol, PotentialZeroCaseCondition] + + @classmethod + def initialize(cls, condition_map: Optional[Dict[LogicCondition, ConditionSymbol]]) -> SwitchHandler: + handler = cls(Z3Converter(), {}, {}) + if condition_map is None: + return handler + for cond_symbol in condition_map.values(): + if cond_symbol.case_node_property is not None: + handler.have_new_zero_case_for(cond_symbol.case_node_property.expression) + elif cond_symbol.condition.operation in {OperationType.equal, 
OperationType.not_equal} and not any( + isinstance(operand, Constant) for operand in cond_symbol.condition.operands + ): + handler.have_new_potential_zero_case_for(cond_symbol) + return handler + + def have_new_zero_case_for(self, expression_usage: ExpressionUsages) -> bool: + """Returns whether we added a new zero-case condition for the given expression.""" + return expression_usage not in self.zero_case_of_switch_expression and self._successfully_compute_zero_case_condition_for( + expression_usage ) + def have_new_potential_zero_case_for(self, condition_symbol: ConditionSymbol) -> bool: + """Returns whether we added a new zero-case condition for the given expression.""" + return self._successfully_compute_potential_zero_case_condition_for(condition_symbol) + + def _successfully_compute_zero_case_condition_for(self, expression_usage: ExpressionUsages) -> bool: + """Return whether the construction of the zero-case condition was successful and add it to the dictionary.""" + ssa_expression = _get_ssa_expression(expression_usage) + try: + z3_condition = self.z3_converter.convert(Condition(OperationType.equal, [ssa_expression, Constant(0, ssa_expression.type)])) + self.zero_case_of_switch_expression[expression_usage] = ZeroCaseCondition( + expression_usage.expression, set(expression_usage.ssa_usages), z3_condition + ) + return True + except ValueError: + return False + + def _successfully_compute_potential_zero_case_condition_for(self, condition_symbol: ConditionSymbol) -> bool: + """Construct the potential zero-case condition.""" + condition = condition_symbol.condition + expression_usage = ExpressionUsages.from_expression(condition) + ssa_condition = _get_ssa_expression(expression_usage) + assert isinstance(ssa_condition, Condition), f"{ssa_condition} must be of type Condition!" 
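# The following normalization lets a potential zero-case be matched later: a not_equal condition such as
# "a != b" is negated to "a == b" before the z3 conversion, so that _is_equivalent can compare it against
# a stored ZeroCaseCondition of the form "<switch expression> == 0" (e.g. "a - b == 0"). The original
# not_equal polarity is preserved separately via the negation flag on CaseNodeProperties.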
+ ssa_condition = ssa_condition.negate() if ssa_condition.operation == OperationType.not_equal else ssa_condition + try: + z3_condition = self.z3_converter.convert(ssa_condition) + self.potential_zero_cases[condition_symbol] = PotentialZeroCaseCondition( + condition, set(expression_usage.ssa_usages), z3_condition + ) + return True + except ValueError: + return False + class ConditionHandler: """Class that handles all the conditions of a transition graph and syntax-forest.""" @@ -33,6 +184,7 @@ def __init__(self, condition_map: Optional[Dict[LogicCondition, ConditionSymbol] self._condition_map: Dict[LogicCondition, ConditionSymbol] = dict() if condition_map is None else condition_map self._symbol_counter = 0 self._logic_context = next(iter(self._condition_map)).context if self._condition_map else LogicCondition.generate_new_context() + self._switch_handler: SwitchHandler = SwitchHandler.initialize(condition_map) def __eq__(self, other) -> bool: """Checks whether two condition handlers are equal.""" @@ -58,7 +210,12 @@ def logic_context(self): def copy(self) -> ConditionHandler: """Return a copy of the condition handler""" condition_map = { - symbol: ConditionSymbol(condition_symbol.condition.copy(), condition_symbol.symbol, condition_symbol.z3_condition) + symbol: ConditionSymbol( + condition_symbol.condition.copy(), + condition_symbol.symbol, + condition_symbol.z3_condition, + condition_symbol.case_node_property.copy(), + ) for symbol, condition_symbol in self._condition_map.items() } return ConditionHandler(condition_map) @@ -71,6 +228,10 @@ def get_z3_condition_of(self, symbol: LogicCondition) -> PseudoLogicCondition: """Return the z3-condition to the given symbol""" return self._condition_map[symbol].z3_condition + def get_case_node_property_of(self, symbol: LogicCondition) -> CaseNodeProperties: + """Return the case-node property of the given symbol""" + return self._condition_map[symbol].case_node_property + def get_all_symbols(self) -> Set[LogicCondition]: """Return all existing symbols""" return set(self._condition_map.keys()) @@ -87,12 +248,33 @@ def get_reverse_z3_condition_map(self) -> Dict[PseudoLogicCondition, LogicCondit """Return the reverse z3-condition map that maps z3-conditions to symbols.""" return dict((condition_symbol.z3_condition, symbol) for symbol, condition_symbol in self._condition_map.items()) - def update_z3_condition_of(self, symbol: LogicCondition, condition: Condition): - """Change the z3-condition of the given symbol according to the given condition.""" - assert symbol.is_symbol, "Input must be a symbol!"
- z3_condition = PseudoLogicCondition.initialize_from_condition(condition, self._logic_context) - pseudo_condition = self.get_condition_of(symbol) - self._condition_map[symbol] = ConditionSymbol(pseudo_condition, symbol, z3_condition) + def get_true_value(self) -> LogicCondition: + """Return a true value.""" + return LogicCondition.initialize_true(self._logic_context) + + def get_false_value(self) -> LogicCondition: + """Return a false value.""" + return LogicCondition.initialize_false(self._logic_context) + + def get_literal_and_constant_of(self, condition: LogicCondition) -> Iterable[LogicCondition, Constant]: + """Get the constant for each literal of the given condition.""" + for literal in condition.get_literals(): + yield literal, self.get_potential_switch_constant_of(literal) + + def get_constants_of(self, condition: LogicCondition) -> Iterable[Constant]: + """Get the constant for each literal of the given condition.""" + for literal in condition.get_literals(): + yield self.get_potential_switch_constant_of(literal) + + def get_potential_switch_constant_of(self, condition: LogicCondition) -> Optional[Constant]: + """Check whether the given condition is a potential switch case, and if return the corresponding constant.""" + if (case_node_property := self._get_case_node_property_of(condition)) is not None: + return case_node_property.constant + + def get_potential_switch_expression_of(self, condition: LogicCondition) -> Optional[ExpressionUsages]: + """Check whether the given condition is a potential switch case, and if return the corresponding expression.""" + if (case_node_property := self._get_case_node_property_of(condition)) is not None: + return case_node_property.expression def add_condition(self, condition: Condition) -> LogicCondition: """Adds a new condition to the condition map and returns the corresponding condition_symbol""" @@ -102,6 +284,7 @@ def add_condition(self, condition: Condition) -> LogicCondition: symbol = self._get_next_symbol() condition_symbol = ConditionSymbol(condition, symbol, z3_condition) + self._set_switch_case_property_for_condition(condition_symbol) self._condition_map[symbol] = condition_symbol return symbol @@ -118,10 +301,87 @@ def _get_next_symbol(self) -> LogicCondition: self._symbol_counter += 1 return LogicCondition.initialize_symbol(f"x{self._symbol_counter}", self._logic_context) - def get_true_value(self) -> LogicCondition: - """Return a true value.""" - return LogicCondition.initialize_true(self._logic_context) + def _set_switch_case_property_for_condition(self, condition_symbol: ConditionSymbol) -> None: + """Compute the switch-case property.""" + condition: Condition = condition_symbol.condition + if condition.operation not in {OperationType.equal, OperationType.not_equal}: + return None + constants: List[Constant] = [operand for operand in condition.operands if isinstance(operand, Constant)] + expressions: List[Expression] = [operand for operand in condition.operands if not isinstance(operand, Constant)] - def get_false_value(self) -> LogicCondition: - """Return a false value.""" - return LogicCondition.initialize_false(self._logic_context) + if len(constants) == 1 and len(expressions) == 1: + expression_usage = ExpressionUsages.from_expression(expressions[0]) + condition_symbol.case_node_property = CaseNodeProperties( + condition_symbol.symbol, expression_usage, constants[0], condition.operation == OperationType.not_equal + ) + self._update_potential_zero_cases_for(expression_usage) + elif len(constants) == 0: + if 
self._switch_handler.have_new_potential_zero_case_for(condition_symbol): + self._add_zero_case_condition_for(condition_symbol) + + def _update_potential_zero_cases_for(self, expression_usage: ExpressionUsages) -> None: + """ + Update the Zero-cases for the given expression. + + If the switch handler adds a new zero-case condition, we check whether one of the potential zero-cases matches this zero-case. + """ + if self._switch_handler.have_new_zero_case_for(expression_usage): + self._add_missing_zero_cases_for(self._switch_handler.zero_case_of_switch_expression[expression_usage]) + + def _add_missing_zero_cases_for(self, zero_case: ZeroCaseCondition) -> None: + """We check for each potential zero-case whether it matches the given zero-case.""" + found_zero_cases = set() + for condition_symbol, potential_zero_case in self._switch_handler.potential_zero_cases.items(): + if zero_case.are_equivalent(potential_zero_case): + self._update_case_property_for( + condition_symbol, potential_zero_case, ExpressionUsages.from_expression(zero_case.expression) + ) + found_zero_cases.add(condition_symbol) + for zero_case_condition_symbol in found_zero_cases: + del self._switch_handler.potential_zero_cases[zero_case_condition_symbol] + + def _add_zero_case_condition_for(self, potential_zero_case_condition_symbol: ConditionSymbol) -> None: + """ + Check whether the condition belongs to a zero-case of a switch expression. + + If this is the case, we return the switch expression and the zero-constant + """ + potential_zero_case: PotentialZeroCaseCondition = self._switch_handler.potential_zero_cases[potential_zero_case_condition_symbol] + for expression_usage, zero_case in self._switch_handler.zero_case_of_switch_expression.items(): + if potential_zero_case.are_equivalent(zero_case): + self._update_case_property_for(potential_zero_case_condition_symbol, potential_zero_case, expression_usage) + del self._switch_handler.potential_zero_cases[potential_zero_case_condition_symbol] + return None + return None + + def _update_case_property_for( + self, condition_symbol: ConditionSymbol, zero_case: PotentialZeroCaseCondition, expression_usage: ExpressionUsages + ): + """ + Update the case_node_property of the given condition-symbol which belongs to the potential zero-case with the given expression. 
+ """ + condition_symbol.z3_condition = PseudoLogicCondition.initialize_from_condition( + Condition( + zero_case.expression.operation, + [expression_usage.expression, (Constant(0, expression_usage.expression.type))], + ), + self._logic_context, + ) + condition_symbol.case_node_property = CaseNodeProperties( + condition_symbol.symbol, + expression_usage, + Constant(0, expression_usage.expression.type), + zero_case.expression.operation == OperationType.not_equal, + ) + + def _get_case_node_property_of(self, condition: LogicCondition) -> Optional[CaseNodeProperties]: + """Return the case-property of a given literal.""" + negation = False + if condition.is_negation: + condition = condition.operands[0] + negation = True + if condition.is_symbol: + case_node_property = self.get_case_node_property_of(condition) + if case_node_property is not None and case_node_property.negation == negation: + return case_node_property + return None diff --git a/decompiler/structures/ast/switch_node_handler.py b/decompiler/structures/ast/switch_node_handler.py deleted file mode 100644 index 96a2d7ecf..000000000 --- a/decompiler/structures/ast/switch_node_handler.py +++ /dev/null @@ -1,205 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple - -from decompiler.structures.ast.condition_symbol import ConditionHandler -from decompiler.structures.logic.logic_condition import LogicCondition -from decompiler.structures.logic.z3_implementations import Z3Implementation -from decompiler.structures.pseudo import Condition, Constant, Expression, OperationType, Variable, Z3Converter -from z3 import BoolRef - - -@dataclass(frozen=True) -class ExpressionUsages: - """Dataclass maintaining for a condition the used SSA-variables.""" - - expression: Expression - ssa_usages: Tuple[Optional[Variable]] - - -@dataclass -class ZeroCaseCondition: - """Possible switch expression together with its zero-case condition.""" - - expression: Expression - ssa_usages: Set[Optional[Variable]] - z3_condition: BoolRef - - -@dataclass -class CaseNodeProperties: - """ - Class for mapping possible expression and constant of a symbol for a switch-case. - - -> symbol: symbol that belongs to the expression and constant - -> constant: the compared constant - -> The condition that the new case node should get. - """ - - symbol: LogicCondition - expression: ExpressionUsages - constant: Constant - negation: bool - - def __eq__(self, other) -> bool: - """ - We want to be able to compare CaseNodeCandidates with AST-nodes, more precisely, - we want that an CaseNodeCandidate 'case_node' is equal to the AST node 'case_node.node'. - """ - if isinstance(other, CaseNodeProperties): - return self.symbol == other.symbol - return False - - -class SwitchNodeHandler: - """Handler for switch node reconstruction knowing possible constants and expressions for switch-nodes for each symbol.""" - - def __init__(self, condition_handler: ConditionHandler): - """ - Initialize the switch-node constructor. - - self._zero_case_of_switch_expression: maps to each possible switch-expression the possible zero-case condition. - self._case_node_property_of_symbol: maps to each symbol the possible expression and constant for a switch it can belong to. 
- """ - self._condition_handler: ConditionHandler = condition_handler - self._z3_converter: Z3Converter = Z3Converter() - self._zero_case_of_switch_expression: Dict[ExpressionUsages, ZeroCaseCondition] = dict() - self._get_zero_cases_for_possible_switch_expressions() - self._case_node_properties_of_symbol: Dict[LogicCondition, Optional[CaseNodeProperties]] = dict() - self._initialize_case_node_properties_for_symbols() - - def is_potential_switch_case(self, condition: LogicCondition) -> bool: - """Check whether the given condition is a potential switch case.""" - return self._get_case_node_property_of(condition) is not None - - def get_potential_switch_expression(self, condition: LogicCondition) -> Optional[ExpressionUsages]: - """Check whether the given condition is a potential switch case, and if return the corresponding expression.""" - if (case_node_property := self._get_case_node_property_of(condition)) is not None: - return case_node_property.expression - - def get_potential_switch_constant(self, condition: LogicCondition) -> Optional[Constant]: - """Check whether the given condition is a potential switch case, and if return the corresponding constant.""" - if (case_node_property := self._get_case_node_property_of(condition)) is not None: - return case_node_property.constant - - def get_literal_and_constant_for(self, condition: LogicCondition) -> Iterable[LogicCondition, Constant]: - """Get the constant for each literal of the given condition.""" - for literal in condition.get_literals(): - yield literal, self.get_potential_switch_constant(literal) - - def get_constants_for(self, condition: LogicCondition) -> Iterable[Constant]: - """Get the constant for each literal of the given condition.""" - for literal in condition.get_literals(): - yield self.get_potential_switch_constant(literal) - - def _get_case_node_property_of(self, condition: LogicCondition) -> Optional[CaseNodeProperties]: - """Return the case-property of a given literal.""" - negation = False - if condition.is_negation: - condition = condition.operands[0] - negation = True - if condition.is_symbol: - if condition not in self._case_node_properties_of_symbol: - self._case_node_properties_of_symbol[condition] = self.__get_case_node_property_of_symbol(condition) - if (case_property := self._case_node_properties_of_symbol[condition]) is not None and case_property.negation == negation: - return case_property - return None - - def _get_zero_cases_for_possible_switch_expressions(self) -> None: - """Get all possible switch expressions, i.e., all expression compared with a constant, together with the potential zero case.""" - for symbol in self._condition_handler.get_all_symbols(): - self.__add_switch_expression_and_zero_case_for_symbol(symbol) - - def __add_switch_expression_and_zero_case_for_symbol(self, symbol: LogicCondition) -> None: - """Add possible switch condition for symbol if comparison of expression with constant.""" - assert symbol.is_symbol, f"Each symbol should be a single Literal, but we have {symbol}" - non_constants = [op for op in self._condition_handler.get_condition_of(symbol).operands if not isinstance(op, Constant)] - if len(non_constants) != 1: - return None - expression_usage = ExpressionUsages(non_constants[0], tuple(var.ssa_name for var in non_constants[0].requirements)) - if expression_usage not in self._zero_case_of_switch_expression: - self.__add_switch_expression(expression_usage) - - def __add_switch_expression(self, expression_usage: ExpressionUsages) -> None: - """Construct the zero case condition 
and add it to the dictionary.""" - ssa_expression = self.__get_ssa_expression(expression_usage) - try: - z3_condition = self._z3_converter.convert(Condition(OperationType.equal, [ssa_expression, Constant(0, ssa_expression.type)])) - except ValueError: - return - self._zero_case_of_switch_expression[expression_usage] = ZeroCaseCondition( - expression_usage.expression, set(expression_usage.ssa_usages), z3_condition - ) - - @staticmethod - def __get_ssa_expression(expression_usage: ExpressionUsages) -> Expression: - """Construct SSA-expression of the given expression.""" - if isinstance(expression_usage.expression, Variable): - return expression_usage.expression.ssa_name if expression_usage.expression.ssa_name else expression_usage.expression - ssa_expression = expression_usage.expression.copy() - for variable in [var for var in ssa_expression.requirements if var.ssa_name]: - ssa_expression.substitute(variable, variable.ssa_name) - return ssa_expression - - def _initialize_case_node_properties_for_symbols(self) -> None: - """Initialize for each symbol the possible switch case properties""" - for symbol in self._condition_handler.get_all_symbols(): - self._case_node_properties_of_symbol[symbol] = self.__get_case_node_property_of_symbol(symbol) - - def __get_case_node_property_of_symbol(self, symbol: LogicCondition) -> Optional[CaseNodeProperties]: - """Return CaseNodeProperty of the given symbol, if it exists.""" - condition = self._condition_handler.get_condition_of(symbol) - if condition.operation not in {OperationType.equal, OperationType.not_equal}: - return None - constants: List[Constant] = [operand for operand in condition.operands if isinstance(operand, Constant)] - expressions: List[Expression] = [operand for operand in condition.operands if not isinstance(operand, Constant)] - - if len(constants) == 1 or len(expressions) == 1: - expression_usage = ExpressionUsages(expressions[0], tuple(var.ssa_name for var in expressions[0].requirements)) - const: Constant = constants[0] - elif len(constants) == 0 and (zero_case_condition := self.__check_for_zero_case_condition(condition)): - expression_usage, const = zero_case_condition - self._condition_handler.update_z3_condition_of(symbol, Condition(condition.operation, [expression_usage.expression, const])) - else: - return None - if expression_usage not in self._zero_case_of_switch_expression: - self.__add_switch_expression(expression_usage) - return CaseNodeProperties(symbol, expression_usage, const, condition.operation == OperationType.not_equal) - - def __check_for_zero_case_condition(self, condition: Condition) -> Optional[Tuple[ExpressionUsages, Constant]]: - """ - Check whether the condition belongs to a zero-case of a switch expression. 
- - If this is the case, we return the switch expression and the zero-constant - """ - tuple_ssa_usages = tuple(var.ssa_name for var in condition.requirements) - ssa_usages = set(tuple_ssa_usages) - ssa_condition = None - for expression_usage, zero_case_condition in self._zero_case_of_switch_expression.items(): - if zero_case_condition.ssa_usages != ssa_usages: - continue - if ssa_condition is None: - if (ssa_condition := self.__get_z3_condition(ExpressionUsages(condition, tuple_ssa_usages))) is None: - return None - zero_case_z3_condition = zero_case_condition.z3_condition - if self.__is_equivalent(ssa_condition, zero_case_z3_condition): - return expression_usage, Constant(0, expression_usage.expression.type) - - def __get_z3_condition(self, expression_usage: ExpressionUsages) -> Optional[BoolRef]: - """Get z3-condition of the expression usage in SSA-form if there is one""" - ssa_condition = self.__get_ssa_expression(expression_usage) - assert isinstance(ssa_condition, Condition), f"{ssa_condition} must be of type Condition!" - ssa_condition = ssa_condition.negate() if ssa_condition.operation == OperationType.not_equal else ssa_condition - try: - return self._z3_converter.convert(ssa_condition) - except ValueError: - return None - - @staticmethod - def __is_equivalent(cond1: BoolRef, cond2: BoolRef): - """Check whether the given conditions are equivalent.""" - z3_implementation = Z3Implementation(True) - if z3_implementation.is_equal(cond1, cond2): - return True - return z3_implementation.does_imply(cond1, cond2) and z3_implementation.does_imply(cond2, cond1) diff --git a/decompiler/structures/ast/syntaxforest.py b/decompiler/structures/ast/syntaxforest.py index 30662c4bf..4f8fe856f 100644 --- a/decompiler/structures/ast/syntaxforest.py +++ b/decompiler/structures/ast/syntaxforest.py @@ -16,7 +16,6 @@ VirtualRootNode, ) from decompiler.structures.ast.condition_symbol import ConditionHandler -from decompiler.structures.ast.switch_node_handler import SwitchNodeHandler from decompiler.structures.ast.syntaxgraph import AbstractSyntaxInterface from decompiler.structures.graphs.restructuring_graph.transition_cfg import TransitionBlock from decompiler.structures.logic.logic_condition import LogicCondition @@ -37,7 +36,6 @@ def __init__(self, condition_handler: ConditionHandler): self.condition_handler: ConditionHandler = condition_handler self._current_root: VirtualRootNode = self.factory.create_virtual_node() self._add_node(self._current_root) - self.switch_node_handler: SwitchNodeHandler = SwitchNodeHandler(condition_handler) @property def current_root(self) -> Optional[AbstractSyntaxTreeNode]: @@ -225,7 +223,7 @@ def extract_branch_from_condition_node( """ Extract the given Branch from the condition node. - -> Afterwards, the Branch must always be executed after the condition node. + -> Afterward, the Branch must always be executed after the condition node. """ assert isinstance(cond_node, ConditionNode) and branch in cond_node.children, f"{branch} must be a child of {cond_node}." 
new_seq = self._add_sequence_node_before(cond_node) @@ -241,26 +239,42 @@ def extract_branch_from_condition_node( if new_seq.parent is not None: new_seq.parent.clean() - def extract_switch_from_condition_sequence(self, switch_node: SwitchNode, condition_node: ConditionNode): - """Extract the given switch-node, that is the first or last child of a seq-node Branch from the condition node""" - seq_node_branch = switch_node.parent - seq_node_branch_children = seq_node_branch.children - assert seq_node_branch.parent in condition_node.children, f"{seq_node_branch} must be a branch of {condition_node}" - new_seq_node = self._add_sequence_node_before(condition_node) - self._remove_edge(seq_node_branch, switch_node) - self._add_edge(new_seq_node, switch_node) - if switch_node is seq_node_branch_children[0]: - new_seq_node._sorted_children = (new_seq_node, condition_node) - seq_node_branch._sorted_children = seq_node_branch_children[1:] - elif switch_node is seq_node_branch_children[-1]: - new_seq_node._sorted_children = (condition_node, new_seq_node) - seq_node_branch._sorted_children = seq_node_branch_children[:-1] - - seq_node_branch.clean() - condition_node.clean() + def extract_switch_from_sequence(self, switch_node: SwitchNode): + """ + Extract the given switch-node, that is the first or last child of a seq-node Branch from the condition node + or sequence node with a non-trivial reaching-condition. + """ + switch_parent = switch_node.parent + assert isinstance(switch_parent, SeqNode), f"The parent of the switch-node {switch_node} must be a sequence node!" + if isinstance(switch_parent, SeqNode) and not switch_parent.reaching_condition.is_true: + new_seq_node = self._extract_switch_from_subtree(switch_parent, switch_node) + elif isinstance(condition_node := switch_parent.parent.parent, ConditionNode): + new_seq_node = self._extract_switch_from_subtree(condition_node, switch_node) + condition_node.clean() + else: + raise ValueError( + f"The parent of the switch node {switch_node} must either have a non-trivial reaching-condition or is a branch of a condition-node!" 
+ ) + if new_seq_node.parent is not None: new_seq_node.parent.clean() + def _extract_switch_from_subtree(self, subtree_head: AbstractSyntaxTreeNode, switch_node: SwitchNode): + switch_parent = switch_node.parent + switch_parent_children = switch_parent.children + new_seq_node = self._add_sequence_node_before(subtree_head) + self._remove_edge(switch_parent, switch_node) + self._add_edge(new_seq_node, switch_node) + if switch_node is switch_parent_children[0]: + new_seq_node._sorted_children = (new_seq_node, subtree_head) + switch_parent._sorted_children = switch_parent_children[1:] + elif switch_node is switch_parent_children[-1]: + new_seq_node._sorted_children = (subtree_head, new_seq_node) + switch_parent._sorted_children = switch_parent_children[:-1] + + switch_parent.clean() + return new_seq_node + def extract_all_breaks_from_condition_node(self, cond_node: ConditionNode): """Remove all break instructions at the end of the condition node and extracts them, i.e., add a break after the condition.""" for node in cond_node.get_end_nodes(): @@ -323,11 +337,43 @@ def __create_branch_for(self, branch_nodes: List[AbstractSyntaxTreeNode], condit else: branch = self.add_seq_node_with_reaching_condition_before(branch_nodes, self.condition_handler.get_true_value()) for node in branch_nodes: - node.reaching_condition.substitute_by_true(condition) + if node.reaching_condition.is_true and isinstance(node, ConditionNode): + node.condition.substitute_by_true(condition) + else: + node.reaching_condition.substitute_by_true(condition) self._remove_edge(branch.parent, branch) return branch + def add_branches_to_condition_node( + self, + condition_node: ConditionNode, + true_branch: Optional[AbstractSyntaxTreeNode] = None, + false_branch: Optional[AbstractSyntaxTreeNode] = None, + ): + """ + Add the given branches to the given condition-node. + + -> true-branch will be part of the true-branch of the condition node after this transformation + -> false-branch will be part of the false-branch of the condition node after this transformation + - since a clean-condition-node always has a true-branch but may not have a false-branch, we have to add it if it does not exist. + """ + if true_branch: + self._remove_edge(true_branch.parent, true_branch) + new_seq_node = self._add_sequence_node_before(condition_node.true_branch_child) + self._add_edge(new_seq_node, true_branch) + new_seq_node.clean() + if false_branch: + self._remove_edge(false_branch.parent, false_branch) + if condition_node.false_branch is None: + false_node = self.factory.create_false_node() + self._add_node(false_node) + self._add_edges_from(((condition_node, false_node), (false_node, false_branch))) + else: + new_seq_node = self._add_sequence_node_before(condition_node.false_branch_child) + self._add_edge(new_seq_node, false_branch) + new_seq_node.clean() + def create_switch_node_with(self, expression: Expression, cases: List[Tuple[CaseNode, AbstractSyntaxTreeNode]]) -> SwitchNode: """Create a switch node with the given expression and the given list of case nodes.""" assert (parent := self.have_same_parent([case[1] for case in cases])) is not None, "All case nodes must have the same parent."
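# Toy model of the ordering rule used by _extract_switch_from_subtree above: a switch that was the first
# or last child of its parent is hoisted next to the subtree head while the relative execution order is
# preserved. Plain strings stand in for AST nodes; this is only an illustration of the rule.
def hoist_order(parent_children: list, subtree_head: str, switch: str) -> tuple:
    """Return the sibling order of the new sequence node placed above `subtree_head`."""
    if switch == parent_children[0]:
        return (switch, subtree_head)   # the switch ran first, so it stays in front
    if switch == parent_children[-1]:
        return (subtree_head, switch)   # the switch ran last, so it stays behind
    raise ValueError("only a first or last child can be extracted this way")

print(hoist_order(["switch", "stmt"], "if-node", "switch"))  # ('switch', 'if-node')
print(hoist_order(["stmt", "switch"], "if-node", "switch"))  # ('if-node', 'switch')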
diff --git a/decompiler/structures/ast/syntaxgraph.py b/decompiler/structures/ast/syntaxgraph.py index 6f6a42fa4..765c14e2b 100644 --- a/decompiler/structures/ast/syntaxgraph.py +++ b/decompiler/structures/ast/syntaxgraph.py @@ -330,7 +330,7 @@ def clean_up(self, root: Optional[AbstractSyntaxTreeNode] = None) -> None: node.clean() def replace_condition_node_by_single_branch(self, node: ConditionNode): - """This function replaces the given AST- condition node by its single child in the AST.""" + """This function replaces the given AST-condition node by its single child in the AST.""" assert isinstance(node, ConditionNode), f"This transformation works only for condition nodes!" assert len(node.children) == 1, f"This works only if the Condition node has only one child!" node.clean() diff --git a/decompiler/structures/pseudo/expressions.py b/decompiler/structures/pseudo/expressions.py index 6e3e24aca..2f0eeecc9 100644 --- a/decompiler/structures/pseudo/expressions.py +++ b/decompiler/structures/pseudo/expressions.py @@ -35,7 +35,8 @@ from ...util.insertion_ordered_set import InsertionOrderedSet from .complextypes import Enum, Struct -from .typing import ArrayType, CustomType, Type, UnknownType +from .typing import CustomType, Type, UnknownType + T = TypeVar("T") DecompiledType = TypeVar("DecompiledType", bound=Type) @@ -58,18 +59,6 @@ class DataflowObject(ABC): def __init__(self, tags: Optional[Tuple[Tag, ...]] = None): self.tags = tags - def __eq__(self, other) -> bool: - """Check for equality.""" - return type(other) == type(self) and hash(self) == hash(other) - - def __hash__(self) -> int: - """Return a hash value for the expression.""" - return hash(repr(self)) - - def __repr__(self): - """Return a debug representation.""" - return str(self) - @abstractmethod def __iter__(self) -> Iterator[DataflowObject]: """Iterate all nested DataflowObjects.""" @@ -149,6 +138,12 @@ def __init__(self, msg: str, tags: Optional[Tuple[Tag, ...]] = None): self.msg = msg super().__init__(tags) + def __eq__(self, __value): + return isinstance(__value, UnknownExpression) and self.msg == __value.msg + + def __hash__(self): + return hash(self.msg) + def __str__(self) -> str: """Return the error message as string representation.""" return self.msg @@ -183,6 +178,17 @@ def __init__( self._pointee = pointee super().__init__(tags) + def __eq__(self, __value): + return ( + isinstance(__value, Constant) + and self.value == __value.value + and self._type == __value._type + and self._pointee == __value.pointee + ) + + def __hash__(self): + return hash((tuple(self.value) if isinstance(self.value, list) else self.value, self._type, self._pointee)) + def __repr__(self) -> str: value = str(self) if isinstance(self.value, str) else self.value if self.pointee: @@ -235,6 +241,12 @@ class NotUseableConstant(Constant): def __init__(self, value: str, tags: Optional[Tuple[Tag, ...]] = None): super().__init__(value, CustomType("double", 0), tags=tags) + def __eq__(self, __value): + return isinstance(__value, NotUseableConstant) and self.value == __value.value + + def __hash__(self): + return hash(self.value) + def __str__(self) -> str: """Return a string because NotUseableConstant are string only""" return self.value @@ -255,6 +267,12 @@ def __init__(self, name: str, value: Union[int, float], vartype: Type = UnknownT super().__init__(value, vartype, tags=tags) self._name = name + def __eq__(self, __value): + return isinstance(__value, Symbol) and self._name == __value._name and self.value == __value.value + + def __hash__(self): + 
return hash((self._name, self.value)) + @property def name(self) -> str: return self._name @@ -278,6 +296,12 @@ def copy(self) -> Symbol: class FunctionSymbol(Symbol): """Represents a function name""" + def __eq__(self, __value): + return isinstance(__value, FunctionSymbol) and super().__eq__(__value) + + def __hash__(self): + return super().__hash__() + def copy(self) -> FunctionSymbol: return FunctionSymbol(self.name, self.value, self._type.copy(), self.tags) @@ -285,6 +309,12 @@ def copy(self) -> FunctionSymbol: class ImportedFunctionSymbol(FunctionSymbol): """Represents an imported function name""" + def __eq__(self, __value): + return isinstance(__value, ImportedFunctionSymbol) and super().__eq__(__value) + + def __hash__(self): + return super().__hash__() + def copy(self) -> ImportedFunctionSymbol: return ImportedFunctionSymbol(self._name, self.value, self._type.copy(), self.tags) @@ -297,6 +327,12 @@ class IntrinsicSymbol(FunctionSymbol): def __init__(self, name: str): super().__init__(name, self.INTRINSIC_ADDRESS) + def __eq__(self, __value): + return isinstance(__value, IntrinsicSymbol) and self.name == __value.name + + def __hash__(self): + return hash(self.name) + def __repr__(self): return f"intrinsic '{self.name}'" @@ -324,6 +360,18 @@ def __init__( self.ssa_name = ssa_name super().__init__(tags) + def __eq__(self, __value): + return ( + isinstance(__value, Variable) + and self._name == __value._name + and self.ssa_label == __value.ssa_label + and self._type == __value._type + and self.is_aliased == __value.is_aliased + ) + + def __hash__(self): + return hash((self._name, self.ssa_label, self._type, self.is_aliased)) + def __repr__(self) -> str: """Return a debug representation of the variable, which includes all the attributes""" return f"{self.name}#{self.ssa_label} (type: {self.type} aliased: {self.is_aliased})" @@ -399,6 +447,12 @@ def __init__( self.initial_value = initial_value self.is_constant = is_constant + def __eq__(self, __value): + return isinstance(__value, GlobalVariable) and super().__eq__(__value) + + def __hash__(self): + return super().__hash__() + def copy( self, name: str = None, @@ -445,6 +499,14 @@ def __init__(self, high: Variable, low: Variable, vartype: Type = UnknownType(), self._low = low self._type = vartype + def __eq__(self, __value): + return ( + isinstance(__value, RegisterPair) and self._high == __value._high and self._low == __value._low and self._type == __value._type + ) + + def __hash__(self): + return hash((self._high, self._low, self._type)) + def __repr__(self) -> str: """Return debug representation of register pair""" return f"{repr(self._high)}:{repr(self._low)} type: {self.type}" @@ -507,6 +569,12 @@ def __init__(self, value: list[Constant], vartype: DecompiledType = UnknownType( tags, ) + def __eq__(self, __value): + return isinstance(__value, ConstantComposition) and super().__eq__(__value) + + def __hash__(self): + return super().__hash__() + def __str__(self) -> str: """Return a string representation of the ConstantComposition""" return "{" + ",".join([str(x) for x in self.value]) + "}" diff --git a/decompiler/structures/pseudo/instructions.py b/decompiler/structures/pseudo/instructions.py index 75625e392..aa1b69dda 100644 --- a/decompiler/structures/pseudo/instructions.py +++ b/decompiler/structures/pseudo/instructions.py @@ -57,6 +57,12 @@ def __init__(self, comment: str, comment_style: str = "C", tags: Optional[Tuple[ self._comment_style = comment_style self._open_comment, self._close_comment = self.STYLES.get(comment_style, 
self.STYLES[self.DEFAULT_STYLE]) + def __eq__(self, __value): + return isinstance(__value, Comment) and self._comment == __value._comment and self._comment_style == __value._comment_style + + def __hash__(self): + return hash((self._comment, self._comment_style)) + def __repr__(self) -> str: """Return representation of comment.""" return f"{self._open_comment} {self._comment} {self._close_comment}" @@ -161,6 +167,12 @@ def __init__(self, destination: Expression, value: Expression, tags: Optional[Tu """Init a new Assignment.""" super(Assignment, self).__init__(destination, value, tags=tags) + def __eq__(self, __value): + return isinstance(__value, Assignment) and self._destination == __value._destination and self._value == __value._value + + def __hash__(self): + return hash((self._destination, self._value)) + def __str__(self) -> str: """Return a string representation starting with the lhs.""" if isinstance(self._destination, ListOperation) and not self._destination.operands: @@ -211,6 +223,12 @@ def __init__(self, destination: Variable, value: Variable, tags: Optional[Tuple[ """Init a new Relation.""" super(Relation, self).__init__(destination, value, tags=tags) + def __eq__(self, __value): + return isinstance(__value, Relation) and self._destination == __value._destination and self._value == __value._value + + def __hash__(self): + return hash((self._destination, self._value)) + def __str__(self) -> str: """Return a string representation starting with the lhs.""" return f"{self.destination} -> {self.value}" @@ -314,6 +332,12 @@ def __init__(self, condition: Condition, tags: Optional[Tuple[Tag, ...]] = None) """Init a new branch instruction.""" super(Branch, self).__init__(condition, tags=tags) + def __eq__(self, __value): + return isinstance(__value, Branch) and self._condition == __value._condition + + def __hash__(self): + return hash(self._condition) + def __repr__(self) -> str: """Return a debug representation of a branch""" return f"if {repr(self.condition)}" @@ -333,6 +357,12 @@ def __init__(self, condition: Expression, tags: Optional[Tuple[Tag, ...]] = None """Init a new branch instruction.""" super(IndirectBranch, self).__init__(condition, tags=tags) + def __eq__(self, __value): + return isinstance(__value, IndirectBranch) and self._condition == __value._condition + + def __hash__(self): + return hash(self._condition) + def __repr__(self) -> str: """Return a debug representation of a branch""" return f"jmp {repr(self.condition)}" @@ -355,6 +385,12 @@ def __init__(self, values, tags: Optional[Tuple[Tag, ...]] = None): super().__init__(tags) self._values = ListOperation(values) + def __eq__(self, __value): + return isinstance(__value, Return) and self._values == __value._values + + def __hash__(self): + return hash(self._values) + def __repr__(self) -> str: return f"return {repr(self._values)}" @@ -395,6 +431,12 @@ def accept(self, visitor: DataflowObjectVisitorInterface[T]) -> T: class Break(Instruction): + def __eq__(self, __value): + return isinstance(__value, Break) + + def __hash__(self): + return hash(Break) + def __iter__(self) -> Iterator[Expression]: yield from () @@ -417,6 +459,12 @@ def accept(self, visitor: DataflowObjectVisitorInterface[T]) -> T: class Continue(Instruction): + def __eq__(self, __value): + return isinstance(__value, Continue) + + def __hash__(self): + return hash(Continue) + def __iter__(self) -> Iterator[Expression]: yield from () @@ -457,6 +505,12 @@ def __init__( self._origin_block = origin_block if origin_block else {} super().__init__(destination, ListOperation(value),
tags=tags) + def __eq__(self, __value): + return isinstance(__value, Phi) and self._destination == __value._destination and self._value == __value._value + + def __hash__(self): + return hash((self._destination, self._value)) + def __repr__(self): return f"{repr(self.destination)} = ϕ({repr(self.value)})" @@ -516,6 +570,12 @@ class MemPhi(Phi): def __init__(self, destination_var: Variable, source_vars: Sequence[Variable], tags: Optional[Tuple[Tag, ...]] = None): super().__init__(destination_var, source_vars, tags=tags) + def __eq__(self, __value): + return isinstance(__value, MemPhi) and super().__eq__(__value) + + def __hash__(self): + return super().__hash__() + def __str__(self) -> str: return f"{self.destination} = ϕ({self.value})" diff --git a/decompiler/structures/pseudo/operations.py b/decompiler/structures/pseudo/operations.py index 3c05d6f95..127263214 100644 --- a/decompiler/structures/pseudo/operations.py +++ b/decompiler/structures/pseudo/operations.py @@ -193,6 +193,17 @@ def __init__( self._type = vartype super().__init__(tags) + def __eq__(self, __value): + return ( + isinstance(__value, Operation) + and self._operation == __value._operation + and self._operands == __value._operands + and self.type == __value.type + ) + + def __hash__(self): + return hash((self._operation, tuple(self._operands), self.type)) + def __repr__(self) -> str: """Return debug representation of an operation. Used in equality checks""" return f"{self.operation.name} [{','.join(map(repr, self._operands))}] {self.type}" @@ -267,6 +278,12 @@ class ListOperation(Operation): def __init__(self, operands: Sequence[Expression], tags: Optional[Tuple[Tag, ...]] = None): super().__init__(OperationType.list_op, operands, tags=tags) + def __eq__(self, __value): + return isinstance(__value, ListOperation) and super().__eq__(__value) + + def __hash__(self): + return super().__hash__() + def __str__(self) -> str: return ",".join(map(str, self.operands)) @@ -283,7 +300,7 @@ def accept(self, visitor: DataflowObjectVisitorInterface[T]) -> T: return visitor.visit_list_operation(self) -@dataclass +@dataclass(unsafe_hash=True) class ArrayInfo: """Class to store array info information for dereference if available base: variable storing start address of an array @@ -331,6 +348,17 @@ def __init__( self.contraction = contraction self.array_info = array_info + def __eq__(self, __value): + return ( + isinstance(__value, UnaryOperation) + and self.contraction == __value.contraction + and self.array_info == __value.array_info + and super().__eq__(__value) + ) + + def __hash__(self): + return hash((self.contraction, self.array_info, super().__hash__())) + def __str__(self): """Return a string representation of the unary operation""" if self.operation == OperationType.cast and self.contraction: @@ -401,6 +429,12 @@ def __init__( self.member_offset = offset self.member_name = member_name + def __eq__(self, __value): + return isinstance(__value, MemberAccess) and super().__eq__(__value) + + def __hash__(self): + return super().__hash__() + def __str__(self): # use -> when accessing member via a pointer to a struct: ptrBook->title # use . 
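The switch from @dataclass to @dataclass(unsafe_hash=True) on ArrayInfo is what allows UnaryOperation.__hash__ above to include self.array_info: a mutable dataclass with the default eq=True gets __hash__ set to None. A small standalone illustration (the single field is hypothetical, not ArrayInfo's real layout):

from dataclasses import dataclass

@dataclass
class PlainInfo:            # default eq=True: __hash__ is set to None
    base: str

@dataclass(unsafe_hash=True)
class HashableInfo:         # generates a field-based __hash__ despite being mutable
    base: str

# hash(PlainInfo("arr")) would raise "TypeError: unhashable type: 'PlainInfo'"
assert hash(HashableInfo("arr")) == hash(HashableInfo("arr"))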
when accessing struct member directly: book.title @@ -441,6 +475,12 @@ class BinaryOperation(Operation): __match_args__ = ("operation", "left", "right") + def __eq__(self, __value): + return isinstance(__value, BinaryOperation) and super().__eq__(__value) + + def __hash__(self): + return super().__hash__() + def __str__(self) -> str: """Return a string representation with infix notation.""" str_left = f"({self.left})" if isinstance(self.left, Operation) else f"{self.left}" @@ -484,6 +524,12 @@ def __init__( self._writes_memory = writes_memory self._meta_data = meta_data + def __eq__(self, __value): + return isinstance(__value, Call) and self._function == __value._function and self._operands == __value._operands + + def __hash__(self): + return hash((self._function, tuple(self._operands))) + def __repr__(self): """Return debug representation of a call""" if self._meta_data is not None: @@ -574,6 +620,13 @@ class Condition(BinaryOperation): OperationType.less_us: OperationType.greater_or_equal_us, } + def __eq__(self, __value): + v_ = isinstance(__value, Condition) and super().__eq__(__value) + return v_ + + def __hash__(self): + return super().__hash__() + @property def type(self) -> Type: """Conditions always return a boolean value.""" @@ -608,6 +661,12 @@ def __init__(self, condition: Expression, true: Expression, false: Expression, t """Initialize a new inline-if operation.""" super().__init__(OperationType.ternary, [condition, true, false], true.type, tags=tags) + def __eq__(self, __value): + return isinstance(__value, TernaryExpression) and super().__eq__(__value) + + def __hash__(self): + return super().__hash__() + def __str__(self) -> str: """Returns string representation""" return f"{self.condition} ? {self.true} : {self.false}" diff --git a/decompiler/util/to_dot_converter.py b/decompiler/util/to_dot_converter.py index b1d459d71..8275e70e1 100644 --- a/decompiler/util/to_dot_converter.py +++ b/decompiler/util/to_dot_converter.py @@ -2,14 +2,14 @@ from networkx import DiGraph -HEADER = "strict digraph {" +HEADER = "digraph {" FOOTER = "}" class ToDotConverter: """Class in charge of writing a networkx DiGraph into dot-format""" - ATTRIBUTES = {"color", "fillcolor", "label", "shape", "style"} + ATTRIBUTES = {"color", "fillcolor", "label", "shape", "style", "dir"} def __init__(self, graph: DiGraph): self._graph = graph diff --git a/tests/pipeline/SSA/test_out_of_ssa_renaming.py b/tests/pipeline/SSA/test_out_of_ssa_renaming.py index 4c7dab76d..bd24c15ee 100644 --- a/tests/pipeline/SSA/test_out_of_ssa_renaming.py +++ b/tests/pipeline/SSA/test_out_of_ssa_renaming.py @@ -1,8 +1,16 @@ """Pytest for renaming SSA-variables to non-SSA-variables.""" +import string + from decompiler.pipeline.ssa.phi_lifting import PhiFunctionLifter -from decompiler.pipeline.ssa.variable_renaming import MinimalVariableRenamer, SimpleVariableRenamer, VariableRenamer +from decompiler.pipeline.ssa.variable_renaming import ( + ConditionalVariableRenamer, + MinimalVariableRenamer, + SimpleVariableRenamer, + VariableRenamer, +) from decompiler.structures.interferencegraph import InterferenceGraph +from decompiler.structures.pseudo import Expression, Float, GlobalVariable from tests.pipeline.SSA.utils_out_of_ssa_tests import * @@ -492,6 +500,23 @@ def test_minimal_renaming_basic_relation(graph_with_relations_easy, variable): } +def test_conditional_renaming_basic_relation(graph_with_relations_easy, variable): + """Checks that conditional renaming can handle relations.""" + task, interference_graph = 
graph_with_relations_easy + conditional_variable_renamer = ConditionalVariableRenamer(task, interference_graph) + + var_18 = [Variable("var_18", Integer(32, True), i, True, None) for i in range(4)] + var_10_1 = Variable("var_10", Pointer(Integer(32, True), 32), 1, False, None) + variable[0].is_aliased = True + variable[1]._type = Pointer(Integer(32, True), 32) + + assert conditional_variable_renamer.renaming_map == { + var_10_1: variable[1], + var_18[2]: variable[0], + var_18[3]: variable[0], + } + + @pytest.fixture() def graph_with_relation() -> Tuple[DecompilerTask, InterferenceGraph]: """ @@ -772,3 +797,112 @@ def test_minimal_renaming_relation(graph_with_relation, variable): var_1c[3]: variable[1], var_1c[4]: variable[1], } + + +def test_conditional_renaming_relation(graph_with_relation, variable): + """Test for relations with conditional renaming""" + task, interference_graph = graph_with_relation + conditional_variable_renamer = ConditionalVariableRenamer(task, interference_graph) + + var_28 = Variable("var_28", Pointer(Integer(32, True), 32), 1, False, None) + var_1c = [Variable("var_1c", Integer(32, True), i, True, None) for i in range(5)] + edx_3 = Variable("edx_3", Integer(32, True), 4, False, None) + eax_7 = Variable("eax_7", Integer(32, True), 8, False, None) + variable[0].is_aliased = True + variable[1]._type = Pointer(Integer(32, True), 32) + variable[2].is_aliased = True + + assert conditional_variable_renamer.renaming_map == { + var_28: variable[1], + edx_3: variable[3], + eax_7: variable[3], + var_1c[0]: variable[0], + var_1c[2]: variable[0], + var_1c[3]: variable[2], + var_1c[4]: variable[2], + } + + +def test_conditional_renaming(): + """Test that conditional renaming only combines related variables""" + orig_variables = [Variable(letter, Integer.int32_t()) for letter in string.ascii_lowercase] + new_variables = [Variable(f"var_{index}", Integer.int32_t()) for index in range(10)] + + cfg = ControlFlowGraph() + cfg.add_node( + BasicBlock( + 0, + [ + Assignment(orig_variables[0], Constant(0, Integer.int32_t())), + Assignment(ListOperation([]), Call(FunctionSymbol("fun", 0), [orig_variables[0]])), + Assignment(orig_variables[1], Constant(1, Integer.int32_t())), + Assignment(ListOperation([]), Call(FunctionSymbol("fun", 0), [orig_variables[1]])), + Assignment(orig_variables[2], orig_variables[1]), + Assignment(ListOperation([]), Call(FunctionSymbol("fun", 0), [orig_variables[2]])), + Assignment(orig_variables[3], Constant(3, Integer.int32_t())), + Assignment(ListOperation([]), Call(FunctionSymbol("fun", 0), [orig_variables[3]])), + ], + ) + ) + + task = decompiler_task(cfg, SSAOptions.conditional) + interference_graph = InterferenceGraph(cfg) + renamer = ConditionalVariableRenamer(task, interference_graph) + + assert renamer.renaming_map == { + orig_variables[0]: new_variables[0], + orig_variables[1]: new_variables[2], + orig_variables[2]: new_variables[2], + orig_variables[3]: new_variables[1], + } + + +def test_conditional_parallel_edges(): + """ + Test that conditional renaming prioritizes parallel edges over single edges when the sum of their + weights is bigger than the weight of the single edge + """ + + def _v(name: str) -> Variable: + return Variable(name, Float.float()) + + def _c(value: float) -> Constant: + return Constant(value, Float.float()) + + def _op(exp: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.plus_float, [exp, _c(0)]) + + cfg = ControlFlowGraph() + cfg.add_node( + b1 := BasicBlock( + 1, + [ + Assignment(_v("b"),
_op(BinaryOperation(OperationType.plus_float, [_v("a0"), GlobalVariable("g0", Float.float(), _c(0))]))), + Assignment(_v("c"), _v("b")), + Assignment(_v("a1"), BinaryOperation(OperationType.plus_float, [_op(_v("b")), _v("c")])), + Assignment(_v("a0"), _v("a1")), # lifted phi function + ], + ) + ) + cfg.add_node( + b0 := BasicBlock( + 0, + [ + # Phi(_v("a0"), [_c(0), _v("a1")], origin_block={b1: _v("a1")}), + Branch(Condition(OperationType.less, [_v("a0"), _c(100)])) + ], + ) + ) + cfg.add_node(b2 := BasicBlock(2, [Return([])])) + + cfg.add_edge(TrueCase(b0, b1)) + cfg.add_edge(FalseCase(b0, b2)) + cfg.add_edge(UnconditionalEdge(b1, b0)) + + task = decompiler_task(cfg, SSAOptions.conditional) + interference_graph = InterferenceGraph(cfg) + renamer = ConditionalVariableRenamer(task, interference_graph) + + assert frozenset(frozenset(c) for c in renamer._variable_classes_handler.variable_class.values()) == frozenset( + {frozenset({GlobalVariable("g0", Float.float(), _c(0))}), frozenset({_v("c")}), frozenset({_v("a0"), _v("a1"), _v("b")})} + ) diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py b/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py index 44ff99ea7..03ac11327 100644 --- a/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py +++ b/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py @@ -93,6 +93,18 @@ def test_constant_fold_invalid_value_type( (OperationType.divide_us, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)), (OperationType.divide_us, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), (OperationType.divide_us, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.modulo, [_c_i32(13), _c_i32(4)], Integer.int32_t(), _c_i32(1), nullcontext()), + (OperationType.modulo, [_c_i32(-2147483647), _c_i32(2)], Integer.int32_t(), _c_i32(-1), nullcontext()), + (OperationType.modulo, [_c_u32(4), _c_i32(3)], Integer.int32_t(), _c_i32(1), nullcontext()), + (OperationType.modulo, [_c_i32(4), _c_i16(3)], Integer.int32_t(), _c_i32(1), nullcontext()), + (OperationType.modulo, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.modulo, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.modulo_us, [_c_i32(13), _c_i32(4)], Integer.int32_t(), _c_i32(1), nullcontext()), + (OperationType.modulo_us, [_c_i32(-2147483647), _c_i32(2)], Integer.int32_t(), _c_i32(1), nullcontext()), + (OperationType.modulo_us, [_c_u32(4), _c_i32(3)], Integer.int32_t(), _c_i32(1), nullcontext()), + (OperationType.modulo_us, [_c_i32(4), _c_i16(3)], Integer.int32_t(), _c_i32(1), nullcontext()), + (OperationType.modulo_us, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.modulo_us, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), (OperationType.negate, [_c_i32(3)], Integer.int32_t(), _c_i32(-3), nullcontext()), (OperationType.negate, [_c_i32(-2147483648)], Integer.int32_t(), _c_i32(-2147483648), nullcontext()), (OperationType.negate, [], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), diff --git a/tests/pipeline/controlflowanalysis/restructuring_commons/test_condition_aware_refinement.py 
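The new modulo rows above encode the difference between signed and unsigned folding of the same 32-bit operands. A rough model of that arithmetic, assuming C-style truncated division; this is a sketch explaining the expected values in the table, not the decompiler's actual folding helper:

def normalize(value: int, size: int, signed: bool) -> int:
    """Wrap value to `size` bits, optionally reinterpreting it as two's complement."""
    value &= (1 << size) - 1
    if signed and value >= 1 << (size - 1):
        value -= 1 << size
    return value

def fold_modulo(a: int, b: int, size: int = 32, signed: bool = True) -> int:
    a, b = normalize(a, size, signed), normalize(b, size, signed)
    r = abs(a) % abs(b)                  # C-style: remainder takes the dividend's sign
    return -r if signed and a < 0 else r

assert fold_modulo(13, 4) == 1                           # modulo / modulo_us
assert fold_modulo(-2147483647, 2, signed=True) == -1    # signed remainder keeps the sign
assert fold_modulo(-2147483647, 2, signed=False) == 1    # reinterpreted as 0x80000001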
b/tests/pipeline/controlflowanalysis/restructuring_commons/test_condition_aware_refinement.py index 90d99a864..62dc48bf8 100644 --- a/tests/pipeline/controlflowanalysis/restructuring_commons/test_condition_aware_refinement.py +++ b/tests/pipeline/controlflowanalysis/restructuring_commons/test_condition_aware_refinement.py @@ -4156,12 +4156,15 @@ def test_only_one_occurrence_of_each_case(task): assert isinstance(switch := seq_node.children[1], SwitchNode) and len(switch.cases) == 7 assert isinstance(case1 := switch.cases[0], CaseNode) and case1.constant.value == 1 and isinstance(case1_seq := case1.child, SeqNode) assert all(case1.constant != case2.constant for case1, case2 in combinations(switch.cases, 2)) - assert len(case1_seq.children) == 3 - assert isinstance(cn := case1_seq.children[0], ConditionNode) and cn.false_branch is None + assert len(case1_seq.children) == 2 + assert isinstance(cn := case1_seq.children[0], ConditionNode) + assert cn.condition.is_literal + pseudo_condition = task.ast.condition_map[cn.condition] if cn.condition.is_symbol else task.ast.condition_map[~cn.condition].negate() + if pseudo_condition.operation != OperationType.equal: + cn.switch_branches() assert isinstance(tb := cn.true_branch_child, CodeNode) and tb.instructions == vertices[3].instructions - assert isinstance(cn := case1_seq.children[1], ConditionNode) and cn.false_branch is None - assert isinstance(tb := cn.true_branch_child, CodeNode) and tb.instructions == vertices[17].instructions - assert isinstance(cn := case1_seq.children[2], CodeNode) and cn.instructions == vertices[5].instructions + assert isinstance(fb := cn.false_branch_child, CodeNode) and fb.instructions == vertices[17].instructions + assert isinstance(cn := case1_seq.children[1], CodeNode) and cn.instructions == vertices[5].instructions assert isinstance(seq_node.children[2], CodeNode) and seq_node.children[2].instructions == vertices[18].instructions @@ -5716,10 +5719,10 @@ def test_intersecting_cases(task): PatternIndependentRestructuring().run(task) assert len(list(task.ast.get_switch_nodes_post_order())) == 2 - assert isinstance(seq_node := task.ast.root, SeqNode) and len(children := seq_node.children) == 6 - assert isinstance(children[0], CodeNode) and isinstance(children[5], CodeNode) - assert all(isinstance(child, ConditionNode) for child in children[1:4]) - assert isinstance(children[4], SwitchNode) and isinstance(children[3].true_branch_child, SwitchNode) + assert isinstance(seq_node := task.ast.root, SeqNode) and len(children := seq_node.children) == 5 + assert isinstance(children[0], CodeNode) and isinstance(children[4], CodeNode) + assert all(isinstance(child, ConditionNode) for child in children[1:3]) + assert isinstance(children[3], SwitchNode) and isinstance(children[2].true_branch_child, SwitchNode) def test_missing_cases_switch_in_sequence(task): diff --git a/tests/pipeline/controlflowanalysis/test_pattern_independent_restructuring.py b/tests/pipeline/controlflowanalysis/test_pattern_independent_restructuring.py index 265f67b1c..0146743ff 100644 --- a/tests/pipeline/controlflowanalysis/test_pattern_independent_restructuring.py +++ b/tests/pipeline/controlflowanalysis/test_pattern_independent_restructuring.py @@ -4932,107 +4932,188 @@ def test_extract_return(task): assert branch.instructions == vertices[3].instructions -# fix in Issue 28 -# def test_hash_eq_problem(task): -# """ -# Hash and eq are not the same, therefore we have to be careful which one we want: -# -# - eq: Same condition node in sense of same condition -# 
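The rewritten assertions above no longer hard-code which branch holds which vertex; they resolve the literal back to its pseudo condition and flip the branches if needed. A hedged helper capturing that pattern, with the API details taken from the diff itself rather than from library documentation:

def resolve_literal(ast, literal):
    """Map a condition literal (symbol or negated symbol) to its pseudo condition."""
    if literal.is_symbol:
        return ast.condition_map[literal]
    return ast.condition_map[~literal].negate()

# Usage mirroring the new assertions:
#   pseudo_condition = resolve_literal(task.ast, cn.condition)
#   if pseudo_condition.operation != OperationType.equal:
#       cn.switch_branches()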
- hash: same node in the graph -# """ -# arg1 = Variable("arg1", Integer.int32_t(), ssa_name=Variable("arg1", Integer.int32_t(), 0)) -# arg2 = Variable("arg2", Integer.int32_t(), ssa_name=Variable("arg2", Integer.int32_t(), 0)) -# var_2 = Variable("var_2", Integer.int32_t(), None, True, Variable("rax_1", Integer.int32_t(), 1, True, None)) -# var_5 = Variable("var_5", Integer.int32_t(), None, True, Variable("rax_2", Integer.int32_t(), 2, True, None)) -# var_6 = Variable("var_6", Integer.int32_t(), None, True, Variable("rax_5", Integer.int32_t(), 30, True, None)) -# var_7 = Variable("var_7", Integer.int32_t(), None, True, Variable("rax_3", Integer.int32_t(), 3, True, None)) -# task.graph.add_nodes_from( -# vertices := [ -# BasicBlock(0, instructions=[Branch(Condition(OperationType.equal, [arg1, Constant(1, Integer.int32_t())]))]), -# BasicBlock( -# 1, -# instructions=[ -# Assignment(var_2, BinaryOperation(OperationType.plus, [var_2, Constant(1, Integer.int32_t())])), -# Branch(Condition(OperationType.not_equal, [var_2, Constant(0, Integer.int32_t())])), -# ], -# ), -# BasicBlock( -# 2, -# instructions=[ -# Assignment(ListOperation([]), Call(imp_function_symbol("sub_140019288"), [arg2])), -# Branch(Condition(OperationType.equal, [arg1, Constant(0, Integer.int32_t())])), -# ], -# ), -# BasicBlock( -# 3, -# instructions=[ -# Assignment(ListOperation([]), Call(imp_function_symbol("scanf"), [Constant(0x804B01F), var_5])), -# Branch(Condition(OperationType.not_equal, [var_5, Constant(0, Integer.int32_t())])), -# ], -# ), -# BasicBlock( -# 4, instructions=[Assignment(var_5, Constant(0, Integer.int32_t())), Assignment(var_7, Constant(-1, Integer.int32_t()))] -# ), -# BasicBlock( -# 5, -# instructions=[ -# Assignment(var_5, Constant(0, Integer.int32_t())), -# Assignment(var_7, Constant(-1, Integer.int32_t())), -# Assignment(arg1, Constant(0, Integer.int32_t())), -# Assignment(var_2, Constant(0, Integer.int32_t())), -# ], -# ), -# BasicBlock( -# 6, -# instructions=[ -# Assignment(var_5, Constant(0, Integer.int32_t())), -# Assignment(var_7, Constant(-1, Integer.int32_t())), -# Assignment(var_2, Constant(0, Integer.int32_t())), -# ], -# ), -# BasicBlock(7, instructions=[Assignment(ListOperation([]), Call(imp_function_symbol("sub_1400193a8"), []))]), -# BasicBlock( -# 8, -# instructions=[ -# Assignment(ListOperation([]), Call(imp_function_symbol("scanf"), [Constant(0x804B01F), var_6])), -# Branch(Condition(OperationType.greater_us, [var_6, Constant(0, Integer.int32_t())])), -# ], -# ), -# BasicBlock(9, instructions=[Assignment(arg1, Constant(1, Integer.int32_t()))]), -# BasicBlock(10, instructions=[Return([arg1])]), -# ] -# ) -# task.graph.add_edges_from( -# [ -# TrueCase(vertices[0], vertices[1]), -# FalseCase(vertices[0], vertices[2]), -# TrueCase(vertices[1], vertices[3]), -# FalseCase(vertices[1], vertices[4]), -# TrueCase(vertices[2], vertices[5]), -# FalseCase(vertices[2], vertices[6]), -# TrueCase(vertices[3], vertices[7]), -# FalseCase(vertices[3], vertices[8]), -# UnconditionalEdge(vertices[4], vertices[7]), -# UnconditionalEdge(vertices[5], vertices[10]), -# UnconditionalEdge(vertices[6], vertices[9]), -# UnconditionalEdge(vertices[7], vertices[9]), -# TrueCase(vertices[8], vertices[9]), -# FalseCase(vertices[8], vertices[10]), -# UnconditionalEdge(vertices[9], vertices[10]), -# ] -# ) -# PatternIndependentRestructuring().run(task) -# assert any(isinstance(node, SwitchNode) for node in task.syntax_tree) -# var_2_conditions = [] -# for node in task.syntax_tree.get_condition_nodes_post_order(): 
-# if ( -# not node.condition.is_symbol -# and node.condition.is_literal -# and str(task.syntax_tree.condition_map[~node.condition]) in {"var_2 != 0x0"} -# ): -# node.switch_branches() -# if node.condition.is_symbol and str(task.syntax_tree.condition_map[node.condition]) in {"var_2 != 0x0"}: -# var_2_conditions.append(node) -# assert len(var_2_conditions) == 2 -# assert var_2_conditions[0] == var_2_conditions[1] -# assert hash(var_2_conditions[0]) != hash(var_2_conditions[1]) +def test_hash_eq_problem(task): + """ + Hash and eq are not the same, therefore we have to be careful which one we want: + + - eq: Same condition node in sense of same condition + - hash: same node in the graph + """ + arg1 = Variable("arg1", Integer.int32_t(), ssa_name=Variable("arg1", Integer.int32_t(), 0)) + arg2 = Variable("arg2", Integer.int32_t(), ssa_name=Variable("arg2", Integer.int32_t(), 0)) + var_2 = Variable("var_2", Integer.int32_t(), None, True, Variable("rax_1", Integer.int32_t(), 1, True, None)) + var_5 = Variable("var_5", Integer.int32_t(), None, True, Variable("rax_2", Integer.int32_t(), 2, True, None)) + var_6 = Variable("var_6", Integer.int32_t(), None, True, Variable("rax_5", Integer.int32_t(), 30, True, None)) + var_7 = Variable("var_7", Integer.int32_t(), None, True, Variable("rax_3", Integer.int32_t(), 3, True, None)) + task.graph.add_nodes_from( + vertices := [ + BasicBlock(0, instructions=[Branch(Condition(OperationType.equal, [arg1, Constant(1, Integer.int32_t())]))]), + BasicBlock( + 1, + instructions=[ + Assignment(var_2, BinaryOperation(OperationType.plus, [var_2, Constant(1, Integer.int32_t())])), + Branch(Condition(OperationType.not_equal, [var_2, Constant(0, Integer.int32_t())])), + ], + ), + BasicBlock( + 2, + instructions=[ + Assignment(ListOperation([]), Call(imp_function_symbol("sub_140019288"), [arg2])), + Branch(Condition(OperationType.equal, [arg1, Constant(0, Integer.int32_t())])), + ], + ), + BasicBlock( + 3, + instructions=[ + Assignment(ListOperation([]), Call(imp_function_symbol("scanf"), [Constant(0x804B01F), var_5])), + Branch(Condition(OperationType.not_equal, [var_5, Constant(0, Integer.int32_t())])), + ], + ), + BasicBlock( + 4, instructions=[Assignment(var_5, Constant(0, Integer.int32_t())), Assignment(var_7, Constant(-1, Integer.int32_t()))] + ), + BasicBlock( + 5, + instructions=[ + Assignment(var_5, Constant(0, Integer.int32_t())), + Assignment(var_7, Constant(-1, Integer.int32_t())), + Assignment(arg1, Constant(0, Integer.int32_t())), + Assignment(var_2, Constant(0, Integer.int32_t())), + ], + ), + BasicBlock( + 6, + instructions=[ + Assignment(var_5, Constant(0, Integer.int32_t())), + Assignment(var_7, Constant(-1, Integer.int32_t())), + Assignment(var_2, Constant(0, Integer.int32_t())), + ], + ), + BasicBlock(7, instructions=[Assignment(ListOperation([]), Call(imp_function_symbol("sub_1400193a8"), []))]), + BasicBlock( + 8, + instructions=[ + Assignment(ListOperation([]), Call(imp_function_symbol("scanf"), [Constant(0x804B01F), var_6])), + Branch(Condition(OperationType.greater_us, [var_6, Constant(0, Integer.int32_t())])), + ], + ), + BasicBlock(9, instructions=[Assignment(arg1, Constant(1, Integer.int32_t()))]), + BasicBlock(10, instructions=[Return([arg1])]), + ] + ) + task.graph.add_edges_from( + [ + TrueCase(vertices[0], vertices[1]), + FalseCase(vertices[0], vertices[2]), + TrueCase(vertices[1], vertices[3]), + FalseCase(vertices[1], vertices[4]), + TrueCase(vertices[2], vertices[5]), + FalseCase(vertices[2], vertices[6]), + TrueCase(vertices[3], 
vertices[7]), + FalseCase(vertices[3], vertices[8]), + UnconditionalEdge(vertices[4], vertices[7]), + UnconditionalEdge(vertices[5], vertices[10]), + UnconditionalEdge(vertices[6], vertices[9]), + UnconditionalEdge(vertices[7], vertices[9]), + TrueCase(vertices[8], vertices[9]), + FalseCase(vertices[8], vertices[10]), + UnconditionalEdge(vertices[9], vertices[10]), + ] + ) + PatternIndependentRestructuring().run(task) + var_2_conditions = [] + for node in task.syntax_tree.get_condition_nodes_post_order(): + if ( + not node.condition.is_symbol + and node.condition.is_literal + and str(task.syntax_tree.condition_map[~node.condition]) in {"arg1 == 0x0"} + ): + node.switch_branches() + if node.condition.is_symbol and str(task.syntax_tree.condition_map[node.condition]) in {"arg1 == 0x0"}: + var_2_conditions.append(node) + assert len(var_2_conditions) == 2 + assert var_2_conditions[0] == var_2_conditions[1] + assert hash(var_2_conditions[0]) != hash(var_2_conditions[1]) + + +def test_condition_based_refined_considers_conditions(task): + """Test condition test 16""" + var_5_0 = Variable("var_5", Integer.int32_t(), None, True, Variable("var_10", Integer.int32_t(), 0, True, None)) + var_5_2 = Variable("var_5", Integer.int32_t(), None, True, Variable("var_10", Integer.int32_t(), 2, True, None)) + var_5_3 = Variable("var_5", Integer.int32_t(), None, True, Variable("var_10", Integer.int32_t(), 3, True, None)) + var_5_4 = Variable("var_5", Integer.int32_t(), None, True, Variable("var_10", Integer.int32_t(), 4, True, None)) + var_5_6 = Variable("var_5", Integer.int32_t(), None, True, Variable("var_10", Integer.int32_t(), 6, True, None)) + var_6_0 = Variable("var_6", Integer.int32_t(), None, True, Variable("var_14", Integer.int32_t(), 0, True, None)) + var_6_2 = Variable("var_6", Integer.int32_t(), None, True, Variable("var_14", Integer.int32_t(), 2, True, None)) + var_6_5 = Variable("var_6", Integer.int32_t(), None, True, Variable("var_14", Integer.int32_t(), 5, True, None)) + var_6_6 = Variable("var_6", Integer.int32_t(), None, True, Variable("var_14", Integer.int32_t(), 6, True, None)) + var_7 = Variable("var_7", Integer.int32_t(), None, True, Variable("c0", Integer.int32_t(), 0, True, None)) + task.graph.add_nodes_from( + vertices := [ + BasicBlock( + 0, + instructions=[ + Assignment( + ListOperation([]), + Call( + imp_function_symbol("sub_140019288"), + [ + UnaryOperation(OperationType.address, [var_5_0], Pointer(Integer.int32_t(), 32)), + UnaryOperation(OperationType.address, [var_6_0], Pointer(Integer.int32_t(), 32)), + ], + ), + ), + Branch(Condition(OperationType.greater, [var_5_2, Constant(4, Integer.int32_t())])), + ], + ), + BasicBlock( + 1, instructions=[Assignment(var_5_3, BinaryOperation(OperationType.minus, [var_5_2, Constant(5, Integer.int32_t())]))] + ), + BasicBlock( + 2, + instructions=[ + Assignment(var_7, BinaryOperation(OperationType.plus, [var_5_2, Constant(5, Integer.int32_t())])), + Assignment(var_5_4, var_7), + Branch(Condition(OperationType.greater, [var_6_0, Constant(4, Integer.int32_t())])), + ], + ), + BasicBlock( + 3, + instructions=[ + Assignment(ListOperation([]), Call(imp_function_symbol("printf"), [Constant(0x804B01F), var_5_6, var_6_6])), + Return([Constant(0, Integer.int32_t())]), + ], + ), + BasicBlock( + 4, + instructions=[ + Assignment(var_6_5, BinaryOperation(OperationType.plus, [var_7, var_6_2])), + ], + ), + ] + ) + task.graph.add_edges_from( + [ + TrueCase(vertices[0], vertices[1]), + FalseCase(vertices[0], vertices[2]), + UnconditionalEdge(vertices[1], 
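The re-enabled test_hash_eq_problem pins down a deliberate asymmetry for AST condition nodes: __eq__ compares the condition, while the hash stays tied to the node's identity so two equal-looking nodes remain distinct vertices in the syntax tree. A generic sketch of that pattern, not the project's ConditionNode implementation:

class NodeByCondition:
    """Equality by content, hash by identity (the contract the test asserts)."""

    def __init__(self, condition):
        self.condition = condition

    def __eq__(self, other):
        return isinstance(other, NodeByCondition) and self.condition == other.condition

    def __hash__(self):
        return id(self)   # distinct graph nodes keep distinct hashes

a, b = NodeByCondition("arg1 == 0x0"), NodeByCondition("arg1 == 0x0")
assert a == b and hash(a) != hash(b)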
vertices[3]), + TrueCase(vertices[2], vertices[3]), + FalseCase(vertices[2], vertices[4]), + UnconditionalEdge(vertices[4], vertices[3]), + ] + ) + PatternIndependentRestructuring().run(task) + assert isinstance(root := task.ast.root, SeqNode) and len(children := root.children) == 3 + assert isinstance(children[0], CodeNode) and children[0].instructions == vertices[0].instructions[:-1] + assert isinstance(cn := children[1], ConditionNode) + assert isinstance(children[2], CodeNode) and children[2].instructions == vertices[3].instructions + + if cn.condition.is_literal and not cn.condition.is_symbol: + cn.switch_branches() + assert task.ast.condition_map[cn.condition] == vertices[0].instructions[-1].condition + assert isinstance(tb := cn.true_branch_child, CodeNode) and tb.instructions == vertices[1].instructions + assert isinstance(fb := cn.false_branch_child, SeqNode) and len(fb_children := fb.children) == 2 + assert isinstance(code_node := fb_children[0], CodeNode) and code_node.instructions == vertices[2].instructions[:-1] + assert isinstance(nested_if := fb_children[1], ConditionNode) and nested_if.condition.is_literal and not nested_if.condition.is_symbol + assert nested_if.false_branch is None and isinstance(nested_code := nested_if.true_branch_child, CodeNode) + assert nested_code.instructions == vertices[4].instructions diff --git a/tests/pipeline/controlflowanalysis/test_pattern_independent_restructuring_blackbox.py b/tests/pipeline/controlflowanalysis/test_pattern_independent_restructuring_blackbox.py index ef0adbd3d..0b61091e5 100644 --- a/tests/pipeline/controlflowanalysis/test_pattern_independent_restructuring_blackbox.py +++ b/tests/pipeline/controlflowanalysis/test_pattern_independent_restructuring_blackbox.py @@ -8,15 +8,18 @@ from decompiler.pipeline.controlflowanalysis.restructuring_commons.condition_aware_refinement_commons.missing_case_finder_intersecting_constants import ( MissingCaseFinderIntersectingConstants, ) +from decompiler.pipeline.controlflowanalysis.restructuring_commons.condition_aware_refinement_commons.switch_extractor import ( + SwitchExtractor, +) from decompiler.pipeline.controlflowanalysis.restructuring_options import LoopBreakOptions, RestructuringOptions -from decompiler.structures.ast.ast_nodes import ConditionNode, SeqNode, SwitchNode +from decompiler.structures.ast.ast_nodes import CodeNode, ConditionNode, SeqNode, SwitchNode from decompiler.structures.ast.condition_symbol import ConditionHandler from decompiler.structures.ast.reachability_graph import SiblingReachabilityGraph from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest from decompiler.structures.graphs.cfg import BasicBlock, ControlFlowGraph, FalseCase, TrueCase, UnconditionalEdge -from decompiler.structures.pseudo.expressions import Constant, Variable +from decompiler.structures.pseudo.expressions import Constant, ImportedFunctionSymbol, Variable from decompiler.structures.pseudo.instructions import Assignment, Branch, Return -from decompiler.structures.pseudo.operations import BinaryOperation, Condition, OperationType +from decompiler.structures.pseudo.operations import BinaryOperation, Call, Condition, ListOperation, OperationType from decompiler.structures.pseudo.typing import CustomType, Integer from decompiler.task import DecompilerTask @@ -183,3 +186,95 @@ def test_insert_intersecting_cases_anywhere(task): assert isinstance(ast.current_root, SeqNode) and len(ast.current_root.children) == 1 assert isinstance(switch := ast.current_root.children[0], SwitchNode) and 
switch.cases == (case2, case1) + + +def test_switch_extractor_sequence(task): + """Test that the switch gets extracted from a sequence node with a reaching condition.""" + condition_handler = ConditionHandler() + # cond_1_symbol = condition_handler.add_condition(Condition(OperationType.equal, [var_c, const[1]])) + cond_2_symbol = condition_handler.add_condition(Condition(OperationType.not_equal, [var_c, const[1]])) + + ast = AbstractSyntaxForest(condition_handler=condition_handler) + root = ast.factory.create_seq_node(reaching_condition=cond_2_symbol) + code_node = ast.factory.create_code_node( + [Assignment(ListOperation([]), Call(ImportedFunctionSymbol("scanf", 0x42), [Constant(0x804B01F), var_c]))] + ) + switch = ast.factory.create_switch_node(var_c) + case1 = ast.factory.create_case_node(var_c, const[2], break_case=True) + case2 = ast.factory.create_case_node(var_c, const[3], break_case=True) + case_content = [ + ast.factory.create_code_node([Assignment(var_b, BinaryOperation(OperationType.plus, [var_b, const[i + 1]]))]) for i in range(2) + ] + ast._add_nodes_from(case_content + [root, code_node, switch, case1, case2]) + ast._add_edges_from( + [ + (root, code_node), + (root, switch), + (switch, case1), + (switch, case2), + (case1, case_content[0]), + (case2, case_content[1]), + ] + ) + ast._code_node_reachability_graph.add_reachability_from( + [(code_node, case_content[0]), (code_node, case_content[1]), (case_content[0], case_content[1])] + ) + root.sort_children() + switch.sort_cases() + ast.set_current_root(root) + + SwitchExtractor.extract(ast, RestructuringOptions(True, True, 2, LoopBreakOptions.structural_variable)) + assert isinstance(ast.current_root, SeqNode) and ast.current_root.reaching_condition.is_true and len(ast.current_root.children) == 2 + assert ast.current_root.children[0].reaching_condition == cond_2_symbol + assert isinstance(switch := ast.current_root.children[1], SwitchNode) and switch.cases == (case1, case2) + + +def test_switch_extractor_sequence_no_extraction(task): + """Test that the switch does not get extracted from a sequence node with a reaching condition.""" + condition_handler = ConditionHandler() + # cond_1_symbol = condition_handler.add_condition(Condition(OperationType.equal, [var_c, const[1]])) + cond_1_symbol = condition_handler.add_condition(Condition(OperationType.not_equal, [var_b, const[1]])) + cond_2_symbol = condition_handler.add_condition(Condition(OperationType.not_equal, [var_c, const[1]])) + + ast = AbstractSyntaxForest(condition_handler=condition_handler) + root = ast.factory.create_condition_node(cond_2_symbol) + true_node = ast.factory.create_true_node() + seq_node = ast.factory.create_seq_node(reaching_condition=cond_1_symbol) + code_node = ast.factory.create_code_node( + [Assignment(ListOperation([]), Call(ImportedFunctionSymbol("scanf", 0x42), [Constant(0x804B01F), var_c]))] + ) + switch = ast.factory.create_switch_node(var_c) + case1 = ast.factory.create_case_node(var_c, const[2], break_case=True) + case2 = ast.factory.create_case_node(var_c, const[3], break_case=True) + case_content = [ + ast.factory.create_code_node([Assignment(var_b, BinaryOperation(OperationType.plus, [var_b, const[i + 1]]))]) for i in range(2) + ] + ast._add_nodes_from(case_content + [root, true_node, seq_node, code_node, switch, case1, case2]) + ast._add_edges_from( + [ + (root, true_node), + (true_node, seq_node), + (seq_node, code_node), + (seq_node, switch), + (switch, case1), + (switch, case2), + (case1, case_content[0]), + (case2, case_content[1]), + ] + ) +
ast._code_node_reachability_graph.add_reachability_from( + [(code_node, case_content[0]), (code_node, case_content[1]), (case_content[0], case_content[1])] + ) + seq_node.sort_children() + switch.sort_cases() + ast.set_current_root(root) + + SwitchExtractor.extract(ast, RestructuringOptions(True, True, 2, LoopBreakOptions.structural_variable)) + assert isinstance(cond := ast.current_root, ConditionNode) and cond.false_branch is None + assert ( + isinstance(seq_node := cond.true_branch_child, SeqNode) + and seq_node.reaching_condition == cond_1_symbol + and len(seq_node.children) == 2 + ) + assert isinstance(seq_node.children[0], CodeNode) + assert isinstance(switch := seq_node.children[1], SwitchNode) and switch.cases == (case1, case2) diff --git a/tests/pipeline/dataflowanalysis/test_expression_propagation_mem.py b/tests/pipeline/dataflowanalysis/test_expression_propagation_mem.py index 5e6f4d124..a288538b5 100644 --- a/tests/pipeline/dataflowanalysis/test_expression_propagation_mem.py +++ b/tests/pipeline/dataflowanalysis/test_expression_propagation_mem.py @@ -1676,6 +1676,81 @@ def test_correct_propagation_relation(): ] +def test_address_into_dereference(): + """ + Test with cast in destination (x#0 stays the same type) + +---------------------+ + | 0. | + | (long) x#0 = &(x#1) | + | *(x#0) = x#0 | + +---------------------+ + + +---------------------+ + | 0. | + | (long) x#0 = &(x#1) | + | *(x#0) = x#0 | + +---------------------+ + """ + input_cfg, output_cfg = graphs_addr_into_deref() + _run_expression_propagation(input_cfg) + assert _graphs_equal(input_cfg, output_cfg) + + +def test_address_into_dereference_with_multiple_defs(): + """ + Extended test of above where we have two definitions (as a ListOp). + +---------------------+ + | 0. | + | (long) x#1 = &(x#0) | + | *(x#1),y#0 = x#1 | + +---------------------+ + + +---------------------+ + | 0. 
| + | (long) x#1 = &(x#0) | + | *(x#1),y#0 = x#1 | + +---------------------+ + """ + input_cfg, output_cfg = graphs_addr_into_deref_multiple_defs() + _run_expression_propagation(input_cfg) + assert _graphs_equal(input_cfg, output_cfg) + + +def graphs_addr_into_deref(): + x = vars("x", 2) + in_n0 = BasicBlock( + 0, + [_assign(_cast(int64, x[0]), _addr(x[1])), _assign(_deref(x[0]), x[0])], + ) + in_cfg = ControlFlowGraph() + in_cfg.add_node(in_n0) + out_n0 = BasicBlock( + 0, + [_assign(_cast(int64, x[0]), _addr(x[1])), _assign(_deref(x[0]), x[0])], + ) + out_cfg = ControlFlowGraph() + out_cfg.add_node(out_n0) + return in_cfg, out_cfg + + +def graphs_addr_into_deref_multiple_defs(): + x = vars("x", 2) + y = vars("y", 1) + in_n0 = BasicBlock( + 0, + [_assign(_cast(int64, x[1]), _addr(x[0])), _assign(ListOperation([_deref(x[1]), y[0]]), x[1])], + ) + in_cfg = ControlFlowGraph() + in_cfg.add_node(in_n0) + out_n0 = BasicBlock( + 0, + [_assign(_cast(int64, x[1]), _addr(x[0])), _assign(ListOperation([_deref(x[1]), y[0]]), x[1])], + ) + out_cfg = ControlFlowGraph() + out_cfg.add_node(out_n0) + return in_cfg, out_cfg + + def graphs_with_no_propagation_of_contraction_address_assignment(): x = vars("x", 3) ptr = vars("ptr", 1, type=Pointer(int32)) diff --git a/tests/util/test_decoration.py b/tests/util/test_decoration.py index e055eeeb3..ed2761e4d 100644 --- a/tests/util/test_decoration.py +++ b/tests/util/test_decoration.py @@ -191,7 +191,7 @@ def test_convert_to_dot(self, simple_graph): content = dot_converter._create_dot() assert ( content - == """strict digraph { + == """digraph { 0 [shape="box", color="blue", label="0.\\na#0 = 0x2\\nb#0 = foo(a#0)\\nif(a#0 < b#0)"]; 1 [shape="box", color="blue", label="1.\\nb#2 = a#0 - b#0"]; 2 [shape="box", color="blue", label="2.\\nb#1 = ϕ(b#0,b#2)\\nreturn b#1"]; @@ -207,7 +207,7 @@ def test_convert_to_dot_with_string(self, graph_with_string): content = dot_converter._create_dot() assert ( content - == """strict digraph { + == """digraph { 0 [shape="box", color="blue", label="0.\\na#0 = 0x2\\nb#0 = foo(a#0)\\nif(a#0 < b#0)"]; 1 [shape="box", color="blue", label="1.\\nb#2 = a#0 - b#0"]; 2 [shape="box", color="blue", label="2.\\nb#1 = ϕ(b#0,b#2)\\nprintf(\\"The result is : %i\\", b#1)\\nreturn b#1"]; @@ -470,7 +470,7 @@ def test_dotviz_output(self, ast_condition): [ x in data for x in [ - r"strict digraph {", + r"digraph {", r'[style="filled", fillcolor="#e6f5c9", label="0. SeqNode\n\nSequence"];', r'[style="filled", fillcolor="#e6f5c9", label="1. ConditionNode\n\nif (true)"]', r'[style="filled", fillcolor="#e6f5c9", label="2. SeqNode\n\nSequence"];', @@ -584,7 +584,7 @@ def test_convert_to_dot_if(self, ast_condition): content = dot_converter._create_dot() assert ( content - == """strict digraph { + == """digraph { 0 [style="filled", fillcolor="#e6f5c9", label="0. SeqNode\\n\\nSequence"]; 1 [style="filled", fillcolor="#e6f5c9", label="1. ConditionNode\\n\\nif (true)"]; 2 [style="filled", fillcolor="#e6f5c9", label="2. SeqNode\\n\\nSequence"]; @@ -606,7 +606,7 @@ def test_convert_to_dot_switch(self, ast_switch): content = dot_converter._create_dot() assert ( content - == """strict digraph { + == """digraph { 0 [style="filled", fillcolor="#e6f5c9", label="0. SeqNode\\n\\nSequence"]; 1 [style="filled", fillcolor="#fdcdac", label="1. SwitchNode\\n\\nswitch (0x29)"]; 2 [style="filled", fillcolor="#e6f5c9", label="2. CaseNode\\n\\ncase 0x0:"];
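For context on the updated expected strings in these decoration tests: GraphViz merges duplicate edges between the same node pair in a `strict` digraph, so the converter now emits a plain `digraph` header and whitelists the per-edge `dir` attribute. A minimal dot snippet illustrating the difference, independent of the project code:

# With a `strict digraph` header GraphViz would keep only one 0 -> 1 edge;
# without `strict`, both parallel edges (and their `dir` attributes) survive.
parallel_edges_dot = "\n".join(
    [
        "digraph {",
        '0 -> 1 [dir="forward"];',
        '0 -> 1 [dir="back"];',
        "}",
    ]
)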