Skip to content

Commit

Permalink
Implement expression simplification rules
Browse files Browse the repository at this point in the history
  • Loading branch information
rihi committed Aug 24, 2023
1 parent a3e060e commit 58ef1af
Show file tree
Hide file tree
Showing 28 changed files with 1,085 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from typing import Callable, Optional

from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, OperationType


def multiply_int_with_constant(expression: Expression, constant: Constant) -> Expression:
"""
Multiply an integer expression with an integer constant.
:param expression: The integer expression to be multiplied.
:param constant: The constant value to multiply the expression by.
:return: A simplified expression representing the multiplication result.
"""

if not isinstance(expression.type, Integer):
raise ValueError(f"Expression must have integer type, got {expression.type}.")
if not isinstance(constant.type, Integer):
raise ValueError(f"Constant must have integer type, got {constant.type}.")
if expression.type != constant.type:
raise ValueError(f"Expression and constant type must equal. {expression.type} != {constant.type}")

if isinstance(expression, Constant):
return constant_fold(OperationType.multiply, [expression, constant])
else:
return BinaryOperation(OperationType.multiply, [expression, constant])


_FOLD_HANDLER: dict[OperationType, Callable[[list[Constant]], Constant]] = {
OperationType.minus: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x - y),
OperationType.plus: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x + y),
OperationType.multiply: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x * y, True),
OperationType.multiply_us: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x * y, False),
OperationType.divide: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x // y, True),
OperationType.divide_us: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x // y, False),
OperationType.negate: lambda constants: _constant_fold_arithmetic_unary(constants, lambda value: -value),
OperationType.left_shift: lambda constants: _constant_fold_shift(constants, lambda value, shift, size: value << shift),
OperationType.right_shift: lambda constants: _constant_fold_shift(constants, lambda value, shift, size: value >> shift),
OperationType.right_shift_us: lambda constants: _constant_fold_shift(
constants, lambda value, shift, size: normalize_int(value >> shift, size - shift, False)
),
OperationType.bitwise_or: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x | y),
OperationType.bitwise_and: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x & y),
OperationType.bitwise_xor: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x ^ y),
OperationType.bitwise_not: lambda constants: _constant_fold_arithmetic_unary(constants, lambda x: ~x),
}


FOLDABLE_CONSTANTS = _FOLD_HANDLER.keys()


def constant_fold(operation: OperationType, constants: list[Constant]) -> Constant:
"""
Fold operation with constants as operands.
:param operation: The operation.
:param constants: All constant operands of the operation.
:return: A constant representing the result of the operation.
"""

if operation not in _FOLD_HANDLER:
raise ValueError(f"Constant folding not implemented for operation '{operation}'.")

return _FOLD_HANDLER[operation](constants)


def _constant_fold_arithmetic_binary(
constants: list[Constant],
fun: Callable[[int, int], int],
norm_sign: Optional[bool] = None
) -> Constant:
if len(constants) != 2:
raise ValueError(f"Expected exactly 2 constants to fold, got {len(constants)}.")
if not all(constant.type == constants[0].type for constant in constants):
raise ValueError(f"Can not fold constants with different types: {(constant.type for constant in constants)}")
if not all(isinstance(constant.type, Integer) for constant in constants):
raise ValueError(f"All constants must be integers, got {list(constant.type for constant in constants)}.")

left, right = constants

left_value = left.value
right_value = right.value
if norm_sign is not None:
left_value = normalize_int(left_value, left.type.size, norm_sign)
right_value = normalize_int(right_value, right.type.size, norm_sign)

return Constant(
normalize_int(fun(left_value, right_value), left.type.size, left.type.signed),
left.type
)


def _constant_fold_arithmetic_unary(constants: list[Constant], fun: Callable[[int], int]) -> Constant:
if len(constants) != 1:
raise ValueError("Expected exactly 1 constant to fold")
if not isinstance(constants[0].type, Integer):
raise ValueError(f"Constant must be of type integer: {constants[0].type}")

return Constant(normalize_int(fun(constants[0].value), constants[0].type.size, constants[0].type.signed), constants[0].type)


def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int, int], int]) -> Constant:
if len(constants) != 2:
raise ValueError("Expected exactly 2 constants to fold")
if not all(isinstance(constant.type, Integer) for constant in constants):
raise ValueError("All constants must be integers")

left, right = constants

return Constant(normalize_int(fun(left.value, right.value, left.type.size), left.type.size, left.type.signed), left.type)


def normalize_int(v: int, size: int, signed: bool) -> int:
value = v & ((1 << size) - 1)
if signed and value & (1 << (size - 1)):
return value - (1 << size)
else:
return value
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule
from decompiler.structures.pseudo import BinaryOperation, Expression, Operation, OperationType, UnaryOperation


class CollapseAddNeg(SimplificationRule):
"""
Simplifies additions/subtraction with negated expression.
- `e0 + -(e1) -> e0 - e1`
- `e0 - -(e1) -> e0 + e1`
"""

def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
if operation.operation not in [OperationType.plus, OperationType.minus]:
return []
if not isinstance(operation, BinaryOperation):
raise TypeError(f"Expected BinaryOperation, got {type(operation)}")

right = operation.right
if not isinstance(right, UnaryOperation) or right.operation != OperationType.negate:
return []

return [(
operation,
BinaryOperation(
OperationType.minus if operation.operation == OperationType.plus else OperationType.plus,
[operation.left, right.operand],
operation.type
)
)]
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from decompiler.pipeline.controlflowanalysis.expression_simplification.modification import FOLDABLE_CONSTANTS, constant_fold
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule
from decompiler.structures.pseudo import Constant, Expression, Operation


class CollapseConstants(SimplificationRule):
"""
Fold operations with only constants as operands:
"""

def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
if not all(isinstance(o, Constant) for o in operation.operands):
return []
if operation.operation not in FOLDABLE_CONSTANTS:
return []

return [(
operation,
constant_fold(operation.operation, operation.operands)
)]
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule
from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType, UnaryOperation


class CollapseMultNegOne(SimplificationRule):
"""
Simplifies expressions multiplied with -1.
`e0 * -1 -> -(e0)`
"""

def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
if operation.operation != OperationType.multiply:
return []
if not isinstance(operation, BinaryOperation):
raise TypeError(f"Expected BinaryOperation, got {type(operation)}")

right = operation.right
if not isinstance(right, Constant) or right.value != -1:
return []

return [(
operation,
UnaryOperation(
OperationType.negate,
[operation.left],
operation.type
)
)]
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from functools import reduce
from typing import Iterator

from decompiler.pipeline.controlflowanalysis.expression_simplification.modification import constant_fold
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule
from decompiler.structures.pseudo import Constant, Expression, Operation, OperationType, Type
from decompiler.structures.pseudo.operations import COMMUTATIVE_OPERATIONS


class CollectTerms(SimplificationRule):
def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
if operation.operation not in COMMUTATIVE_OPERATIONS:
return []
if not isinstance(operation, Operation):
raise TypeError(f"Expected Operation, got {type(operation)}")

operands = list(_collect_constants(operation))
if len(operands) <= 1:
return []

first, *rest = operands

folded_constant = reduce(
lambda c0, c1: constant_fold(operation.operation, [c0, c1]),
rest,
first
)

identity_constant = _identity_constant(operation.operation, operation.type)
return [
(first, folded_constant),
*((o, identity_constant) for o in rest)
]


def _collect_constants(operation: Operation) -> Iterator[Constant]:
operation_type = operation.operation
operand_type = operation.type

context_stack: list[Operation] = [operation]
while context_stack:
current_operation = context_stack.pop()

for i, operand in enumerate(current_operation.operands):
if operand.type != operand_type:
continue

if isinstance(operand, Operation):
if operand.operation == operation_type:
context_stack.append(operand)
continue
elif isinstance(operand, Constant) and _identity_constant(operation_type, operand_type).value != operand.value:
yield operand


def _identity_constant(operation: OperationType, var_type: Type) -> Constant:
match operation:
case OperationType.plus | OperationType.bitwise_xor | OperationType.bitwise_or:
return Constant(0, var_type)
case OperationType.multiply | OperationType.multiply_us:
return Constant(1, var_type)
case OperationType.bitwise_and:
return constant_fold(OperationType.bitwise_not, [Constant(0, var_type)])
case _:
raise NotImplementedError()
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from decompiler.pipeline.controlflowanalysis.expression_simplification.modification import normalize_int
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule
from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType


class FixAddSubSign(SimplificationRule):
"""
Changes add/sub when variable type is signed.
- `V - a -> E + (-a)` when signed(a) < 0
- `V + a -> E - (-a)` when signed(a) < 0
"""

def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
if operation.operation not in (OperationType.plus, OperationType.minus):
return []
if not isinstance(operation, BinaryOperation):
raise TypeError(f"Expected BinaryOperation, got {type(operation)}")

right = operation.right
if not isinstance(right, Constant):
return []

con_type = right.type
if not isinstance(con_type, Integer):
return []

a = normalize_int(right.value, con_type.size, True)
if a >= 0:
return []

neg_a = Constant(
normalize_int(-a, con_type.size, con_type.signed),
con_type
)
return [(
operation,
BinaryOperation(
OperationType.plus if operation.operation == OperationType.minus else OperationType.minus,
[operation.left, neg_a]
)
)]
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule
from decompiler.structures.pseudo import Expression, Operation, OperationType, UnaryOperation


class SimplifyRedundantReference(SimplificationRule):
"""
Removes redundant nesting of referencing, immediately followed by referencing.
`*(&(e0)) -> e0`
"""

def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
match operation:
case UnaryOperation(
operation=OperationType.dereference,
operand=UnaryOperation(operation=OperationType.address, operand=operand)
):
return [(operation, operand)]
case _:
return []
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule
from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType, UnaryOperation


class SimplifyTrivialArithmetic(SimplificationRule):
"""
Simplifies trivial arithmetic:
- `e + 0 -> e`
- `e - 0 -> e`
- `e * 1 -> e`
- `e u* 1 -> e`
- `e * -1 -> -e`
- `e u* -1 -> -e`
- `e / 1 -> e`
- `e u/ 1 -> e`
- `e / -1 -> -e`
"""

def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
match operation:
case BinaryOperation(operation=OperationType.plus | OperationType.minus, right=Constant(value=0)):
return [(operation, operation.left)]
case BinaryOperation(
operation=OperationType.multiply | OperationType.multiply_us | OperationType.divide | OperationType.divide_us,
right=Constant(value=1),
):
return [(operation, operation.left)]
case BinaryOperation(operation=OperationType.multiply, right=Constant(value=0)):
return [(operation, Constant(0, operation.type))]
case BinaryOperation(
operation=OperationType.multiply | OperationType.multiply_us | OperationType.divide,
right=Constant(value=-1)
):
return [(operation, UnaryOperation(OperationType.negate, [operation.left]))]
case _:
return []
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule
from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType


class SimplifyTrivialBitArithmetic(SimplificationRule):
"""
Simplifies trivial bit arithmetic:
- `e | 0 -> e`
- `e | e -> e`
- `e & 0 -> 0`
- `e & e -> e`
- `e ^ 0 -> e`
- `e ^ e -> 0`
"""

def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
match operation:
case BinaryOperation(operation=OperationType.bitwise_or | OperationType.bitwise_xor, right=Constant(value=0)):
return [(operation, operation.left)]
case BinaryOperation(operation=OperationType.bitwise_and, right=Constant(value=0)):
return [(operation, Constant(0, operation.type))]
case BinaryOperation(operation=OperationType.bitwise_or | OperationType.bitwise_and, left=left, right=right) if left == right:
return [(operation, operation.left)]
case BinaryOperation(operation=OperationType.bitwise_xor, left=left, right=right) if left == right:
return [(operation, Constant(0, operation.type))]
case _:
return []
Loading

0 comments on commit 58ef1af

Please sign in to comment.