diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_add_neg.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_add_neg.py index 70cc75408..89d142c46 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_add_neg.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_add_neg.py @@ -1,3 +1,5 @@ +from typing import Optional + from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule from decompiler.structures.pseudo import BinaryOperation, Expression, Operation, OperationType, UnaryOperation @@ -8,25 +10,26 @@ class CollapseAddNeg(SimplificationRule): - `e0 + -(e1) -> e0 - e1` - `e0 - -(e1) -> e0 + e1` + - `-(e0) + e1 -> e1 - e0` + - `-(e0) - e1 -> -(e0 + e1)` """ def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: - if operation.operation not in [OperationType.plus, OperationType.minus]: - return [] - if not isinstance(operation, BinaryOperation): - raise TypeError(f"Expected BinaryOperation, got {type(operation)}") + replacement: Optional[Expression] = None + match operation: + case BinaryOperation(operation=OperationType.plus, left=e0, right=UnaryOperation(operation=OperationType.negate, operand=e1)): + replacement = BinaryOperation(OperationType.minus, [e0, e1], operation.type) + + case BinaryOperation(operation=OperationType.minus, left=e0, right=UnaryOperation(operation=OperationType.negate, operand=e1)): + replacement = BinaryOperation(OperationType.plus, [e0, e1], operation.type) + + case BinaryOperation(operation=OperationType.plus, left=UnaryOperation(operation=OperationType.negate, operand=e0), right=e1): + replacement = BinaryOperation(OperationType.minus, [e1, e0], operation.type) + + case BinaryOperation(operation=OperationType.minus, left=UnaryOperation(operation=OperationType.negate, operand=e0), right=e1): + replacement = UnaryOperation(OperationType.negate, [BinaryOperation(OperationType.plus, [e0, e1], operation.type)]) - right = operation.right - if not isinstance(right, UnaryOperation) or right.operation != OperationType.negate: + if replacement is None: return [] - return [ - ( - operation, - BinaryOperation( - OperationType.minus if operation.operation == OperationType.plus else OperationType.plus, - [operation.left, right.operand], - operation.type, - ), - ) - ] + return [(operation, replacement)] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_add_neg.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_add_neg.py index d7300107f..e8dedad6d 100644 --- a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_add_neg.py +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_add_neg.py @@ -17,6 +17,14 @@ BinaryOperation(OperationType.minus, [var_x, UnaryOperation(OperationType.negate, [var_y])]), [BinaryOperation(OperationType.plus, [var_x, var_y])], ), + ( + BinaryOperation(OperationType.plus, [UnaryOperation(OperationType.negate, [var_x]), var_y]), + [BinaryOperation(OperationType.minus, [var_y, var_x])], + ), + ( + BinaryOperation(OperationType.minus, [UnaryOperation(OperationType.negate, [var_x]), var_y]), + [UnaryOperation(OperationType.negate, [BinaryOperation(OperationType.plus, [var_x, var_y])])], + ), ], ) def test_collapse_add_neg(operation: Operation, result: list[Expression]):