-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement expression simplification rules
- Loading branch information
Showing
28 changed files
with
1,085 additions
and
2 deletions.
There are no files selected for viewing
117 changes: 117 additions & 0 deletions
117
decompiler/pipeline/controlflowanalysis/expression_simplification/modification.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
from typing import Callable, Optional | ||
|
||
from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, OperationType | ||
|
||
|
||
def multiply_int_with_constant(expression: Expression, constant: Constant) -> Expression:
    """
    Build the product of an integer expression and an integer constant.

    If the expression is itself a constant, the product is folded into a single constant;
    otherwise a multiplication operation over both operands is returned.

    :param expression: The integer expression to be multiplied.
    :param constant: The integer constant to multiply the expression by.
    :return: The folded constant or a multiplication of expression and constant.
    :raises ValueError: If either operand is not of integer type or their types differ.
    """
    lhs_type = expression.type
    if not isinstance(lhs_type, Integer):
        raise ValueError(f"Expression must have integer type, got {lhs_type}.")

    rhs_type = constant.type
    if not isinstance(rhs_type, Integer):
        raise ValueError(f"Constant must have integer type, got {rhs_type}.")

    if lhs_type != rhs_type:
        raise ValueError(f"Expression and constant type must equal. {lhs_type} != {rhs_type}")

    factors = [expression, constant]
    if not isinstance(expression, Constant):
        return BinaryOperation(OperationType.multiply, factors)
    return constant_fold(OperationType.multiply, factors)
|
||
|
||
def _truncating_divide(dividend: int, divisor: int) -> int:
    """
    Divide two integers, rounding the quotient toward zero.

    Python's `//` rounds toward negative infinity, so `-7 // 2 == -4`, whereas division in C
    truncates and yields `-3`. Folding a signed division has to follow the truncating behaviour.

    :param dividend: The value to be divided.
    :param divisor: The value to divide by.
    :return: The quotient rounded toward zero.
    :raises ZeroDivisionError: If divisor is zero.
    """
    quotient = abs(dividend) // abs(divisor)
    return quotient if (dividend < 0) == (divisor < 0) else -quotient


# Maps every foldable operation to a handler that computes the folded constant from its operands.
# The optional third argument of _constant_fold_arithmetic_binary selects whether the operand values
# are normalized as signed (True) or unsigned (False) integers before the operation is applied.
_FOLD_HANDLER: dict[OperationType, Callable[[list[Constant]], Constant]] = {
    OperationType.minus: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x - y),
    OperationType.plus: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x + y),
    OperationType.multiply: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x * y, True),
    OperationType.multiply_us: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x * y, False),
    # Signed operands may be negative, so the quotient must be truncated instead of floored.
    OperationType.divide: lambda constants: _constant_fold_arithmetic_binary(constants, _truncating_divide, True),
    # Unsigned operands are normalized to non-negative values, for which floor division equals truncation.
    OperationType.divide_us: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x // y, False),
    OperationType.negate: lambda constants: _constant_fold_arithmetic_unary(constants, lambda value: -value),
    OperationType.left_shift: lambda constants: _constant_fold_shift(constants, lambda value, shift, size: value << shift),
    OperationType.right_shift: lambda constants: _constant_fold_shift(constants, lambda value, shift, size: value >> shift),
    # Logical shift: only the lower (size - shift) bits of the shifted value are kept, clearing the vacated upper bits.
    OperationType.right_shift_us: lambda constants: _constant_fold_shift(
        constants, lambda value, shift, size: normalize_int(value >> shift, size - shift, False)
    ),
    OperationType.bitwise_or: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x | y),
    OperationType.bitwise_and: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x & y),
    OperationType.bitwise_xor: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x ^ y),
    OperationType.bitwise_not: lambda constants: _constant_fold_arithmetic_unary(constants, lambda x: ~x),
}


# View of all operation types for which a fold handler is registered.
FOLDABLE_CONSTANTS = _FOLD_HANDLER.keys()
|
||
|
||
def constant_fold(operation: OperationType, constants: list[Constant]) -> Constant:
    """
    Evaluate an operation whose operands are all constants.

    :param operation: The operation to evaluate.
    :param constants: All constant operands of the operation.
    :return: A constant representing the result of the operation.
    :raises ValueError: If no fold handler is registered for the operation.
    """
    handler = _FOLD_HANDLER.get(operation)
    if handler is None:
        raise ValueError(f"Constant folding not implemented for operation '{operation}'.")
    return handler(constants)
|
||
|
||
def _constant_fold_arithmetic_binary(
        constants: list[Constant],
        fun: Callable[[int, int], int],
        norm_sign: Optional[bool] = None
) -> Constant:
    """
    Fold a binary arithmetic operation over two integer constants of the same type.

    :param constants: Exactly two constant operands, in the order (left, right).
    :param fun: The function computing the raw result from the two operand values.
    :param norm_sign: Controls normalization of the operand values before fun is applied:
        - None (default): the values are used as stored
        - True: the values are normalized as signed integers of their type's size
        - False: the values are normalized as unsigned integers of their type's size
    :return: A constant of the left operand's type, holding the result normalized to that type's size and signedness.
    :raises ValueError: If not exactly two constants are given, their types differ, or they are not integers.
    """
    if len(constants) != 2:
        raise ValueError(f"Expected exactly 2 constants to fold, got {len(constants)}.")
    if not all(constant.type == constants[0].type for constant in constants):
        # A list is interpolated here; a bare generator expression would only print its repr.
        raise ValueError(f"Can not fold constants with different types: {[constant.type for constant in constants]}")
    if not all(isinstance(constant.type, Integer) for constant in constants):
        raise ValueError(f"All constants must be integers, got {list(constant.type for constant in constants)}.")

    left, right = constants

    left_value = left.value
    right_value = right.value
    if norm_sign is not None:
        left_value = normalize_int(left_value, left.type.size, norm_sign)
        right_value = normalize_int(right_value, right.type.size, norm_sign)

    return Constant(
        normalize_int(fun(left_value, right_value), left.type.size, left.type.signed),
        left.type
    )
|
||
|
||
def _constant_fold_arithmetic_unary(constants: list[Constant], fun: Callable[[int], int]) -> Constant:
    """
    Fold a unary arithmetic operation over a single integer constant.

    :param constants: Exactly one constant operand.
    :param fun: The function computing the raw result from the operand value.
    :return: A constant of the operand's type, holding the result normalized to that type's size and signedness.
    :raises ValueError: If not exactly one constant is given or it is not an integer.
    """
    if len(constants) != 1:
        raise ValueError("Expected exactly 1 constant to fold")

    operand = constants[0]
    operand_type = operand.type
    if not isinstance(operand_type, Integer):
        raise ValueError(f"Constant must be of type integer: {operand_type}")

    raw_result = fun(operand.value)
    folded_value = normalize_int(raw_result, operand_type.size, operand_type.signed)
    return Constant(folded_value, operand_type)
|
||
|
||
def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int, int], int]) -> Constant:
    """
    Fold a shift operation over two integer constants.

    Unlike binary arithmetic folding, the two operands are not required to share the same type.

    :param constants: Exactly two constant operands, in the order (shifted value, shift amount).
    :param fun: The function computing the raw result from (value, shift amount, bit size of the value's type).
    :return: A constant of the left operand's type, holding the result normalized to that type's size and signedness.
    :raises ValueError: If not exactly two constants are given or any of them is not an integer.
    """
    if len(constants) != 2:
        raise ValueError("Expected exactly 2 constants to fold")
    for constant in constants:
        if not isinstance(constant.type, Integer):
            raise ValueError("All constants must be integers")

    shifted, amount = constants
    result_type = shifted.type
    bit_size = result_type.size

    raw_result = fun(shifted.value, amount.value, bit_size)
    return Constant(normalize_int(raw_result, bit_size, result_type.signed), result_type)
|
||
|
||
def normalize_int(v: int, size: int, signed: bool) -> int:
    """
    Wrap an integer into the value range of a fixed-width integer.

    :param v: The value to normalize.
    :param size: The width of the target integer in bits.
    :param signed: Whether the lower `size` bits are interpreted as a two's complement signed value.
    :return: The lower `size` bits of v, interpreted as signed or unsigned.
    """
    modulus = 1 << size
    wrapped = v & (modulus - 1)
    if not signed:
        return wrapped

    # A set top bit marks a negative value in two's complement.
    sign_bit = 1 << (size - 1)
    return wrapped - modulus if wrapped & sign_bit else wrapped
30 changes: 30 additions & 0 deletions
30
decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_add_neg.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule | ||
from decompiler.structures.pseudo import BinaryOperation, Expression, Operation, OperationType, UnaryOperation | ||
|
||
|
||
class CollapseAddNeg(SimplificationRule):
    """
    Folds a negated right operand of an addition/subtraction into the operation itself.
    - `e0 + -(e1) -> e0 - e1`
    - `e0 - -(e1) -> e0 + e1`
    """

    def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
        current = operation.operation
        if current == OperationType.plus:
            flipped = OperationType.minus
        elif current == OperationType.minus:
            flipped = OperationType.plus
        else:
            return []

        if not isinstance(operation, BinaryOperation):
            raise TypeError(f"Expected BinaryOperation, got {type(operation)}")

        negated = operation.right
        if not (isinstance(negated, UnaryOperation) and negated.operation == OperationType.negate):
            return []

        # Dropping the negation is compensated by flipping plus <-> minus.
        replacement = BinaryOperation(flipped, [operation.left, negated.operand], operation.type)
        return [(operation, replacement)]
20 changes: 20 additions & 0 deletions
20
...mpiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_constants.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from decompiler.pipeline.controlflowanalysis.expression_simplification.modification import FOLDABLE_CONSTANTS, constant_fold | ||
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule | ||
from decompiler.structures.pseudo import Constant, Expression, Operation | ||
|
||
|
||
class CollapseConstants(SimplificationRule):
    """
    Fold operations with only constants as operands, e.g. `2 + 3 -> 5`.

    Operations for which folding is not implemented or not possible are left unchanged.
    """

    def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
        if not all(isinstance(o, Constant) for o in operation.operands):
            return []
        if operation.operation not in FOLDABLE_CONSTANTS:
            return []

        try:
            folded = constant_fold(operation.operation, operation.operands)
        except (ValueError, ZeroDivisionError):
            # constant_fold raises ValueError for operands it cannot handle (e.g. non-integer constants
            # or constants of different types), and folding a division by a zero constant raises
            # ZeroDivisionError. Such an operation is simply not simplified.
            return []

        return [(operation, folded)]
29 changes: 29 additions & 0 deletions
29
...ler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_mult_neg_one.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule | ||
from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType, UnaryOperation | ||
|
||
|
||
class CollapseMultNegOne(SimplificationRule):
    """
    Rewrites a multiplication by the constant -1 as a negation.
    `e0 * -1 -> -(e0)`
    """

    def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
        if operation.operation != OperationType.multiply:
            return []
        if not isinstance(operation, BinaryOperation):
            raise TypeError(f"Expected BinaryOperation, got {type(operation)}")

        factor = operation.right
        # Only a constant -1 as right operand is rewritten.
        if not isinstance(factor, Constant) or factor.value != -1:
            return []

        negation = UnaryOperation(OperationType.negate, [operation.left], operation.type)
        return [(operation, negation)]
65 changes: 65 additions & 0 deletions
65
decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collect_terms.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from functools import reduce | ||
from typing import Iterator | ||
|
||
from decompiler.pipeline.controlflowanalysis.expression_simplification.modification import constant_fold | ||
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule | ||
from decompiler.structures.pseudo import Constant, Expression, Operation, OperationType, Type | ||
from decompiler.structures.pseudo.operations import COMMUTATIVE_OPERATIONS | ||
|
||
|
||
class CollectTerms(SimplificationRule):
    """
    Combines the constants of nested commutative operations of the same kind into one constant.

    The first collected constant is replaced by the folded value of all collected constants, and
    every other collected constant is replaced by the identity constant of the operation.
    """

    def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
        if operation.operation not in COMMUTATIVE_OPERATIONS:
            return []
        if not isinstance(operation, Operation):
            raise TypeError(f"Expected Operation, got {type(operation)}")

        constants = list(_collect_constants(operation))
        if len(constants) <= 1:
            return []

        anchor = constants[0]
        remaining = constants[1:]

        # Fold all constants from left to right into a single value.
        accumulated = anchor
        for constant in remaining:
            accumulated = constant_fold(operation.operation, [accumulated, constant])

        # NOTE: the same identity constant instance is used for every replaced operand.
        neutral = _identity_constant(operation.operation, operation.type)

        replacements = [(anchor, accumulated)]
        replacements.extend((constant, neutral) for constant in remaining)
        return replacements
|
||
|
||
def _collect_constants(operation: Operation) -> Iterator[Constant]:
    """
    Yield the non-identity constants of an operation and of its nested operations of the same kind.

    Operands whose type differs from the type of the given operation are skipped, including any
    operation nested below them. Nested operations are only descended into if they use the same
    operation type as the given operation.

    :param operation: The root operation whose constants are collected.
    :return: An iterator over all collected constants.
    """
    operation_type = operation.operation
    operand_type = operation.type

    # Explicit stack instead of recursion to traverse the nested operations.
    context_stack: list[Operation] = [operation]
    while context_stack:
        current_operation = context_stack.pop()

        for operand in current_operation.operands:
            if operand.type != operand_type:
                continue

            if isinstance(operand, Operation):
                if operand.operation == operation_type:
                    context_stack.append(operand)
            # The identity constant is looked up only once a constant operand is found, so operations
            # without a defined identity do not fail as long as they contain no matching constant.
            elif isinstance(operand, Constant) and _identity_constant(operation_type, operand_type).value != operand.value:
                yield operand
|
||
|
||
def _identity_constant(operation: OperationType, var_type: Type) -> Constant:
    """
    Create the identity element of an operation, i.e. the constant c with `e <op> c == e`.

    :param operation: The operation to get the identity element for.
    :param var_type: The type of the created constant.
    :return: The identity constant of the operation.
    :raises NotImplementedError: If no identity element is defined for the operation.
    """
    if operation in (OperationType.plus, OperationType.bitwise_xor, OperationType.bitwise_or):
        return Constant(0, var_type)
    if operation in (OperationType.multiply, OperationType.multiply_us):
        return Constant(1, var_type)
    if operation == OperationType.bitwise_and:
        # All bits set: obtained by folding the bitwise negation of zero for the given type.
        return constant_fold(OperationType.bitwise_not, [Constant(0, var_type)])
    raise NotImplementedError()
42 changes: 42 additions & 0 deletions
42
decompiler/pipeline/controlflowanalysis/expression_simplification/rules/fix_add_sub_sign.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from decompiler.pipeline.controlflowanalysis.expression_simplification.modification import normalize_int | ||
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule | ||
from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType | ||
|
||
|
||
class FixAddSubSign(SimplificationRule):
    """
    Changes add/sub when the constant operand is negative if interpreted as a signed value.
    - `e - a -> e + (-a)` when signed(a) < 0
    - `e + a -> e - (-a)` when signed(a) < 0

    The minimum signed value of the constant's type is left untouched, because its negation
    is not representable and wraps around to the same bit pattern.
    """

    def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
        if operation.operation not in (OperationType.plus, OperationType.minus):
            return []
        if not isinstance(operation, BinaryOperation):
            raise TypeError(f"Expected BinaryOperation, got {type(operation)}")

        right = operation.right
        if not isinstance(right, Constant):
            return []

        con_type = right.type
        if not isinstance(con_type, Integer):
            return []

        # Interpret the constant as signed, regardless of the signedness of its type.
        a = normalize_int(right.value, con_type.size, True)
        if a >= 0:
            return []
        if a == -(1 << (con_type.size - 1)):
            # Negating the minimum value yields the same bit pattern again, so the rewritten
            # operation would match this rule once more and flip back and forth indefinitely.
            return []

        neg_a = Constant(
            normalize_int(-a, con_type.size, con_type.signed),
            con_type
        )
        return [(
            operation,
            BinaryOperation(
                OperationType.plus if operation.operation == OperationType.minus else OperationType.minus,
                [operation.left, neg_a],
                # Keep the type of the replaced operation.
                operation.type
            )
        )]
20 changes: 20 additions & 0 deletions
20
...eline/controlflowanalysis/expression_simplification/rules/simplify_redundant_reference.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule | ||
from decompiler.structures.pseudo import Expression, Operation, OperationType, UnaryOperation | ||
|
||
|
||
class SimplifyRedundantReference(SimplificationRule):
    """
    Removes a dereference that is applied directly to an address-of operation.
    `*(&(e0)) -> e0`
    """

    def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
        is_dereference = isinstance(operation, UnaryOperation) and operation.operation == OperationType.dereference
        if not is_dereference:
            return []

        inner = operation.operand
        if isinstance(inner, UnaryOperation) and inner.operation == OperationType.address:
            return [(operation, inner.operand)]
        return []
37 changes: 37 additions & 0 deletions
37
...peline/controlflowanalysis/expression_simplification/rules/simplify_trivial_arithmetic.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule | ||
from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType, UnaryOperation | ||
|
||
|
||
class SimplifyTrivialArithmetic(SimplificationRule): | ||
""" | ||
Simplifies trivial arithmetic: | ||
- `e + 0 -> e` | ||
- `e - 0 -> e` | ||
- `e * 1 -> e` | ||
- `e u* 1 -> e` | ||
- `e * -1 -> -e` | ||
- `e u* -1 -> -e` | ||
- `e / 1 -> e` | ||
- `e u/ 1 -> e` | ||
- `e / -1 -> -e` | ||
""" | ||
|
||
def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: | ||
match operation: | ||
case BinaryOperation(operation=OperationType.plus | OperationType.minus, right=Constant(value=0)): | ||
return [(operation, operation.left)] | ||
case BinaryOperation( | ||
operation=OperationType.multiply | OperationType.multiply_us | OperationType.divide | OperationType.divide_us, | ||
right=Constant(value=1), | ||
): | ||
return [(operation, operation.left)] | ||
case BinaryOperation(operation=OperationType.multiply, right=Constant(value=0)): | ||
return [(operation, Constant(0, operation.type))] | ||
case BinaryOperation( | ||
operation=OperationType.multiply | OperationType.multiply_us | OperationType.divide, | ||
right=Constant(value=-1) | ||
): | ||
return [(operation, UnaryOperation(OperationType.negate, [operation.left]))] | ||
case _: | ||
return [] |
28 changes: 28 additions & 0 deletions
28
...ne/controlflowanalysis/expression_simplification/rules/simplify_trivial_bit_arithmetic.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule | ||
from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType | ||
|
||
|
||
class SimplifyTrivialBitArithmetic(SimplificationRule):
    """
    Simplifies trivial bit arithmetic:
    - `e | 0 -> e`
    - `e | e -> e`
    - `e & 0 -> 0`
    - `e & e -> e`
    - `e ^ 0 -> e`
    - `e ^ e -> 0`
    """

    def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
        if not isinstance(operation, BinaryOperation):
            return []

        kind = operation.operation
        left = operation.left
        right = operation.right

        # Rules with a zero constant on the right side take precedence over the `e <op> e` rules.
        if isinstance(right, Constant) and right.value == 0:
            if kind in (OperationType.bitwise_or, OperationType.bitwise_xor):
                return [(operation, left)]
            if kind == OperationType.bitwise_and:
                return [(operation, Constant(0, operation.type))]

        if kind in (OperationType.bitwise_or, OperationType.bitwise_and) and left == right:
            return [(operation, left)]
        if kind == OperationType.bitwise_xor and left == right:
            return [(operation, Constant(0, operation.type))]
        return []
Oops, something went wrong.