diff --git a/decompiler/pipeline/dataflowanalysis/expressionpropagationfunctioncall.py b/decompiler/pipeline/dataflowanalysis/expressionpropagationfunctioncall.py index 4efd7cb3c..3572d67a5 100644 --- a/decompiler/pipeline/dataflowanalysis/expressionpropagationfunctioncall.py +++ b/decompiler/pipeline/dataflowanalysis/expressionpropagationfunctioncall.py @@ -60,9 +60,25 @@ def _is_call_value_used_exactly_once(self, definition: Assignment) -> bool: True on exactly one use. False otherwise, or Call has more than one return value. """ - if len(return_values := definition.destination.requirements) == 1: - return len(self._use_map.get(return_values[0])) == 1 - return False + if len(return_values := definition.destination.requirements) != 1: + return False + + [required_variable] = return_values + requiring_instructions = self._use_map.get(required_variable) + + if len(requiring_instructions) != 1: + return False + + [requiring_instruction] = requiring_instructions + + usages = 0 + for variable in requiring_instruction.requirements_iter: + if variable == required_variable: + usages += 1 + if usages > 1: + return False + + return usages == 1 def _definition_can_be_propagated_into_target(self, definition: Assignment, target: Instruction): """Tests if propagation is allowed based on set of rules, namely diff --git a/decompiler/structures/pseudo/expressions.py b/decompiler/structures/pseudo/expressions.py index 4f9774ac6..582af3002 100644 --- a/decompiler/structures/pseudo/expressions.py +++ b/decompiler/structures/pseudo/expressions.py @@ -30,8 +30,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Generic, Iterator, List, Optional, Tuple, TypeVar, Union +from typing import TYPE_CHECKING, Generic, Iterator, List, Optional, Tuple, TypeVar, Union, final +from ...util.insertion_ordered_set import InsertionOrderedSet from .complextypes import Enum from .typing import CustomType, Type, UnknownType @@ -85,10 +86,16 @@ def complexity(self) -> int: pass @property - @abstractmethod + def requirements_iter(self) -> Iterator[Variable]: + """Return an iterator of required variables.""" + return + yield + + @property + @final def requirements(self) -> List[Variable]: - """Return a list of required variables.""" - pass + """Return a list of unique required variables.""" + return list(InsertionOrderedSet(self.requirements_iter)) def copy(self): """Generate a copy of the object.""" @@ -122,11 +129,6 @@ def complexity(self) -> int: """Simple expressions like constants and variables have complexity 1""" return 1 - @property - def requirements(self) -> List[Variable]: - """Requirements are empty list by default if not redefined by concrete Expression implementation""" - return [] - def substitute(self, replacee: Expression, replacement: Expression) -> None: """Do nothing: default behavior for simple expressions, like Variables and Constants""" pass @@ -382,9 +384,9 @@ def name(self) -> str: return self._name @property - def requirements(self) -> List["Variable"]: + def requirements_iter(self) -> Iterator["Variable"]: """A variable depends on itself""" - return [self] + yield self @property def type(self) -> DecompiledType: @@ -510,12 +512,14 @@ def low(self) -> Variable: return self._low @property - def requirements(self) -> List[Variable]: + def requirements_iter(self) -> Iterator[Variable]: """Pairs depend on their components and itself in case when being used as a single variable e.g. 0: (eax:edx) = 0x666667 * ebx 1: edx = (eax:edx) - 2 """ - return [self, self._high, self._low] + yield self + yield self._high + yield self._low @property def type(self) -> Type: diff --git a/decompiler/structures/pseudo/instructions.py b/decompiler/structures/pseudo/instructions.py index d7dfdc8fe..1d89e76e0 100644 --- a/decompiler/structures/pseudo/instructions.py +++ b/decompiler/structures/pseudo/instructions.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, Generic, Iterator, List, Optional, Sequence, Set, Tuple, TypeVar, Union, final -from .expressions import Constant, DataflowObject, Expression, GlobalVariable, Tag, Variable +from .expressions import Constant, DataflowObject, Expression, Tag, Variable from .operations import BinaryOperation, Call, Condition, ListOperation, OperationType, UnaryOperation E = TypeVar("E", bound=Expression) @@ -74,11 +74,6 @@ def complexity(self) -> int: """Return 0, since comment should not add complexity.""" return 0 - @property - def requirements(self) -> List["Variable"]: - """Return [] since comment has no requirements.""" - return [] - def copy(self) -> Comment: """Return a Comment with same str parameters.""" return Comment(self._comment, self._comment_style, self.tags) @@ -131,15 +126,16 @@ def value(self) -> F: return self._value @property - def requirements(self) -> List[Variable]: + def requirements_iter(self) -> Iterator[Variable]: """Return the values necessary for evaluation.""" if ( - isinstance(self._destination, Variable) - or isinstance(self._destination, ListOperation) - or self._is_contraction(self._destination) + not isinstance(self.destination, Variable) + and not isinstance(self.destination, ListOperation) + and not self._is_contraction(self.destination) ): - return self._value.requirements - return self._destination.requirements + self._value.requirements + yield from self._destination.requirements_iter + + yield from self._value.requirements_iter @property def writes_memory(self) -> Optional[int]: @@ -260,9 +256,9 @@ def complexity(self) -> int: return self.condition.complexity @property - def requirements(self) -> List[Variable]: + def requirements_iter(self) -> Iterator[Variable]: """Return the conditions dependencies.""" - return self.condition.requirements + return self.condition.requirements_iter def substitute(self, replacee: Expression, replacement: Expression) -> None: """Substitutes condition directly (in case of condition is a variable) @@ -372,9 +368,9 @@ def complexity(self) -> int: return self._values.complexity @property - def requirements(self) -> List[Variable]: + def requirements_iter(self) -> Iterator[Variable]: """All returned values are required by the return statement.""" - return self._values.requirements + return self._values.requirements_iter @property def values(self) -> ListOperation: @@ -405,10 +401,6 @@ def __str__(self) -> str: def complexity(self) -> int: return 0 - @property - def requirements(self) -> List[Variable]: - return [] - @final def copy(self) -> Break: return Break() @@ -431,10 +423,6 @@ def __str__(self) -> str: def complexity(self) -> int: return 0 - @property - def requirements(self) -> List[Variable]: - return [] - @final def copy(self) -> Continue: return Continue() diff --git a/decompiler/structures/pseudo/operations.py b/decompiler/structures/pseudo/operations.py index c21002068..b7c5b169e 100644 --- a/decompiler/structures/pseudo/operations.py +++ b/decompiler/structures/pseudo/operations.py @@ -231,9 +231,10 @@ def operation(self) -> OperationType: return self._operation @property - def requirements(self) -> List[Variable]: + def requirements_iter(self) -> Iterator[Variable]: """Operation requires a list of all unique variables required by each of its operands""" - return self._collect_required_variables(self.operands) + for operand in self._operands: + yield from operand.requirements_iter def substitute(self, replacee: Expression, replacement: Expression) -> None: """Substitutes operand directly if possible, then recursively substitutes replacee in operands""" @@ -355,9 +356,9 @@ def complexity(self) -> int: return self.operand.complexity @property - def requirements(self): + def requirements_iter(self) -> Iterator[Variable]: """Return the requirements of the single operand.""" - return self.operand.requirements + return self.operand.requirements_iter @property def writes_memory(self) -> Optional[int]: @@ -516,10 +517,9 @@ def meta_data(self) -> Dict[str, List[str]]: return self._meta_data @property - def requirements(self) -> List[Variable]: - if isinstance(self._function, Variable): - return self._collect_required_variables(self.operands + [self.function]) - return super().requirements + def requirements_iter(self) -> Iterator[Variable]: + yield from self._function.requirements_iter + yield from super().requirements_iter @property def function(self) -> Union[FunctionSymbol, ImportedFunctionSymbol, IntrinsicSymbol, Variable]: diff --git a/tests/pipeline/dataflowanalysis/test_expression_propagation_function_call.py b/tests/pipeline/dataflowanalysis/test_expression_propagation_function_call.py index f4b90754b..50f24d0f8 100644 --- a/tests/pipeline/dataflowanalysis/test_expression_propagation_function_call.py +++ b/tests/pipeline/dataflowanalysis/test_expression_propagation_function_call.py @@ -1,3 +1,4 @@ +import copy from typing import List from decompiler.pipeline.dataflowanalysis import ExpressionPropagationFunctionCall @@ -175,6 +176,38 @@ def test_multiple_propagations(): assert node.instructions == [_assign(return_x, Constant(0x0)), _assign(return_y, Constant(0x0)), Return([_func("g", [_func("f", [])])])] +def test_single_instruction_multiple_propagations(): + """ + Test that functions are not propagated if they occur multiple times in a single instruction. + + +-------------+ + | 0. | + | x = f() | + | y = g(x, x) | + | return y | + +-------------+ + + + +-------------+ + | 0. | + | x = f() | + | y = g(x, x) | + | return y | + +-------------+ + """ + + instructions = [ + _assign(x, _func("f", [])), + Return([_func("g", [x, x])]) + ] + cfg = ControlFlowGraph() + cfg.add_node(BasicBlock(0, copy.deepcopy(instructions))) + + _run_expression_propagation(cfg) + + assert cfg.nodes[0].instructions == instructions + + def _func(name: str, parameters: List): return Call(FunctionSymbol(name, 0), parameters, writes_memory=1) diff --git a/tests/structures/pseudo/test_requirements.py b/tests/structures/pseudo/test_requirements.py new file mode 100644 index 000000000..dd0a181c3 --- /dev/null +++ b/tests/structures/pseudo/test_requirements.py @@ -0,0 +1,39 @@ +import pytest +from decompiler.structures.pseudo import ( + Assignment, + BinaryOperation, + Call, + DataflowObject, + IndirectBranch, + Integer, + ListOperation, + OperationType, + RegisterPair, + Return, + UnaryOperation, + Variable, +) + +_a = Variable("a", Integer.int32_t(), 0) +_b = Variable("b", Integer.int32_t(), 1) + + +@pytest.mark.parametrize( + ["obj", "expected_requirements"], + [ + (_a, [_a]), + (_r := RegisterPair(_a, _b), [_r, _a, _b]), + (Assignment(_a, _b), [_b]), + (Assignment(ListOperation([_a]), _b), [_b]), + (Assignment(UnaryOperation(OperationType.cast, [_a], contraction=True), _b), [_b]), + (Assignment(UnaryOperation(OperationType.dereference, [_a]), _b), [_a, _b]), + (IndirectBranch(_a), [_a]), + (Return([_a, _b]), [_a, _b]), + (ListOperation([_a, _b]), [_a, _b]), + (BinaryOperation(OperationType.plus, [_a, _b]), [_a, _b]), + (Call(_a, [_b]), [_a, _b]), + (BinaryOperation(OperationType.plus, [_a, _a]), [_a, _a]) + ] +) +def test_requirements(obj: DataflowObject, expected_requirements: list[Variable]): + assert list(obj.requirements_iter) == expected_requirements