Skip to content

Commit

Permalink
[Code Generator] Simplify Expressions (#318)
Browse files Browse the repository at this point in the history
* Create draft PR for #25

* Implement substitute_visitor

* Replace expression_simplification with expression_simplification_rules

* Implement expression simplification rules

* Add documentation to visit_mem_phi in substitute_visitor.py

Co-authored-by: Manuel Blatt <[email protected]>

* Add docstring to modification.py/normalize_int(int,int,bool)

* Rename modification.py/FOLDABLE_CONSTANTS to FOLDABLE_OPERATIONS

* Rename modification.py/_FOLD_HANDLER to _OPERATION_TO_FOLD_FUNCTION

* Move _simplify_instructions and _simplify_instructions_with_rule_set into _ExpressionSimplificationRulesBase

* Extract code from _simplify_instructions_with_rule_set to _simplify_instruction_with_rule

* Improve documentation in default.json for expression simplification max_iterations

* Add documentation to substitute_visitor.py

* Add additional test cases to test_substitute_visitor.py

- Assignment lhs
- UnaryOperation
- UnaryOperation ArrayInfo
- Phi

* Improve readability of _visit_operation(self,Operation) in substitute_visitor.py

* Improve readability of visit_unary_operation(self,UnaryOperation) in substitute_visitor.py

* Use visit instead of mapper function in visit_phi(self,Phi)

* Lift constraints of SubstituteVisitor that tried to uphold validity of the dataflow tree

Lift some constraints of what can be substituted. These constraints where in place in an attempt to prevent erroneous modifications to the dataflow tree, resulting in invalid states.
This commit shifts the responsibility of keeping the dataflow graph in a valid state to the user allowing for a more flexible implementation.

* Slight syntax changes to test_substitute_visitor.py

* Improve syntax in _OPERATION_TO_FOLD_FUNCTION

* Assert that instructions are not substituted in expression_simplification_rules.py

* Add clarifying comment to _simplify_instruction_with_rule

* Rename ExpressionSimplificationRules to ExpressionSimplification

* Create copies while substituting with SubstituteVisitor.equality

* Remove unused function multiply_int_with_constant

* Add documentation to modification.py

* Rename modification.py to constant_folding.py

* Extract method from visit_unary_operation

* Add documentation to collect_terms.py

* Add tests for ExpressionSimplification stage

* Update documentation for '_constant_fold_arithmetic_binary'

Co-authored-by: Manuel Blatt <[email protected]>

* Update documentation for 'normalize_int'

Co-authored-by: Manuel Blatt <[email protected]>

* Add comment explaining 'max_iterations'

Co-authored-by: Manuel Blatt <[email protected]>

* Improve documentation of CollectTerms

Co-authored-by: Manuel Blatt <[email protected]>

* Rename variable in collect_terms.py for readability

Co-authored-by: Manuel Blatt <[email protected]>

* Rename variable in simplify_redundant_reference.py for readability

Co-authored-by: Manuel Blatt <[email protected]>

* Improve documentation of TermOrder

Co-authored-by: Manuel Blatt <[email protected]>

* Remove unused collapse_mult_neg_one.py

* Fix test in test_substitute_visitor.py

* Rename simplification rule FixAddSubSign to PositiveConstants

* Rename variables in positive_constants.py

* Rename operands to constant in collect_terms.py

* Rename CollectTerms to CollapseNestedConstants

* Add missing unsigned multiply case to simplify_trivial_arithmetic.py

* Use 'neg' operation in SubToAdd

* Delete test_collapse_mult_neg_one.py

* Fix test_stage.py

* Slight change to test_stage.py

* Ignore MemPhi in SubstituteVisitor

---------

Co-authored-by: rihi <[email protected]>
Co-authored-by: rihi <[email protected]>
Co-authored-by: Manuel Blatt <[email protected]>
Co-authored-by: Manuel Blatt <[email protected]>
  • Loading branch information
5 people authored Sep 21, 2023
1 parent e1ff6e4 commit a7d49d2
Show file tree
Hide file tree
Showing 35 changed files with 1,774 additions and 392 deletions.
2 changes: 1 addition & 1 deletion decompiler/pipeline/controlflowanalysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .expression_simplification import ExpressionSimplification
from .expression_simplification.stages import ExpressionSimplificationAst, ExpressionSimplificationCfg
from .instruction_length_handler import InstructionLengthHandler
from .readability_based_refinement import ReadabilityBasedRefinement
from .variable_name_generation import VariableNameGeneration
142 changes: 0 additions & 142 deletions decompiler/pipeline/controlflowanalysis/expression_simplification.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import operator
from functools import partial
from typing import Callable, Optional

from decompiler.structures.pseudo import Constant, Integer, OperationType


def constant_fold(operation: OperationType, constants: list[Constant]) -> Constant:
"""
Fold operation with constants as operands.
:param operation: The operation.
:param constants: All constant operands of the operation.
:return: A constant representing the result of the operation.
"""

if operation not in _OPERATION_TO_FOLD_FUNCTION:
raise ValueError(f"Constant folding not implemented for operation '{operation}'.")

return _OPERATION_TO_FOLD_FUNCTION[operation](constants)


def _constant_fold_arithmetic_binary(
constants: list[Constant],
fun: Callable[[int, int], int],
norm_sign: Optional[bool] = None
) -> Constant:
"""
Fold an arithmetic binary operation with constants as operands.
:param constants: A list of exactly 2 constant operands.
:param fun: The binary function to perform on the constants.
:param norm_sign: Optional boolean flag to indicate if/how to normalize the input constants to 'fun':
- None (default): no normalization
- True: normalize inputs, interpreted as signed values
- False: normalize inputs, interpreted as unsigned values
:return: A constant representing the result of the operation.
"""

if len(constants) != 2:
raise ValueError(f"Expected exactly 2 constants to fold, got {len(constants)}.")
if not all(constant.type == constants[0].type for constant in constants):
raise ValueError(f"Can not fold constants with different types: {(constant.type for constant in constants)}")
if not all(isinstance(constant.type, Integer) for constant in constants):
raise ValueError(f"All constants must be integers, got {list(constant.type for constant in constants)}.")

left, right = constants

left_value = left.value
right_value = right.value
if norm_sign is not None:
left_value = normalize_int(left_value, left.type.size, norm_sign)
right_value = normalize_int(right_value, right.type.size, norm_sign)

return Constant(
normalize_int(fun(left_value, right_value), left.type.size, left.type.signed),
left.type
)


def _constant_fold_arithmetic_unary(constants: list[Constant], fun: Callable[[int], int]) -> Constant:
"""
Fold an arithmetic unary operation with a constant operand.
:param constants: A list containing a single constant operand.
:param fun: The unary function to perform on the constant.
:return: A constant representing the result of the operation.
"""

if len(constants) != 1:
raise ValueError("Expected exactly 1 constant to fold")
if not isinstance(constants[0].type, Integer):
raise ValueError(f"Constant must be of type integer: {constants[0].type}")

return Constant(normalize_int(fun(constants[0].value), constants[0].type.size, constants[0].type.signed), constants[0].type)


def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], int], signed: bool) -> Constant:
"""
Fold a shift operation with constants as operands.
:param constants: A list of exactly 2 constant operands.
:param fun: The shift function to perform on the constants.
:param signed: Boolean flag indicating whether the shift is signed.
This is used to normalize the sign of the input constant to simulate unsigned shifts.
:return: A constant representing the result of the operation.
"""

if len(constants) != 2:
raise ValueError("Expected exactly 2 constants to fold")
if not all(isinstance(constant.type, Integer) for constant in constants):
raise ValueError("All constants must be integers")

left, right = constants

shifted_value = fun(
normalize_int(left.value, left.type.size, left.type.signed and signed),
right.value
)
return Constant(
normalize_int(shifted_value, left.type.size, left.type.signed),
left.type
)


def normalize_int(v: int, size: int, signed: bool) -> int:
"""
Normalizes an integer value to a specific size and signedness.
This function takes an integer value 'v' and normalizes it to fit within
the specified 'size' in bits by discarding overflowing bits. If 'signed' is
true, the value is treated as a signed integer, i.e. interpreted as a two's complement.
Therefore the return value will be negative iff 'signed' is true and the most-significant bit is set.
:param v: The value to be normalized.
:param size: The desired bit size for the normalized integer.
:param signed: True if the integer should be treated as signed.
:return: The normalized integer value.
"""
value = v & ((1 << size) - 1)
if signed and value & (1 << (size - 1)):
return value - (1 << size)
else:
return value


_OPERATION_TO_FOLD_FUNCTION: dict[OperationType, Callable[[list[Constant]], Constant]] = {
OperationType.minus: partial(_constant_fold_arithmetic_binary, fun=operator.sub),
OperationType.plus: partial(_constant_fold_arithmetic_binary, fun=operator.add),
OperationType.multiply: partial(_constant_fold_arithmetic_binary, fun=operator.mul, norm_sign=True),
OperationType.multiply_us: partial(_constant_fold_arithmetic_binary, fun=operator.mul, norm_sign=False),
OperationType.divide: partial(_constant_fold_arithmetic_binary, fun=operator.floordiv, norm_sign=True),
OperationType.divide_us: partial(_constant_fold_arithmetic_binary, fun=operator.floordiv, norm_sign=False),
OperationType.negate: partial(_constant_fold_arithmetic_unary, fun=operator.neg),
OperationType.left_shift: partial(_constant_fold_shift, fun=operator.lshift, signed=True),
OperationType.right_shift: partial(_constant_fold_shift, fun=operator.rshift, signed=True),
OperationType.right_shift_us: partial(_constant_fold_shift, fun=operator.rshift, signed=False),
OperationType.bitwise_or: partial(_constant_fold_arithmetic_binary, fun=operator.or_),
OperationType.bitwise_and: partial(_constant_fold_arithmetic_binary, fun=operator.and_),
OperationType.bitwise_xor: partial(_constant_fold_arithmetic_binary, fun=operator.xor),
OperationType.bitwise_not: partial(_constant_fold_arithmetic_unary, fun=operator.inv),
}


FOLDABLE_OPERATIONS = _OPERATION_TO_FOLD_FUNCTION.keys()
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule
from decompiler.structures.pseudo import BinaryOperation, Expression, Operation, OperationType, UnaryOperation


class CollapseAddNeg(SimplificationRule):
"""
Simplifies additions/subtraction with negated expression.
- `e0 + -(e1) -> e0 - e1`
- `e0 - -(e1) -> e0 + e1`
"""

def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
if operation.operation not in [OperationType.plus, OperationType.minus]:
return []
if not isinstance(operation, BinaryOperation):
raise TypeError(f"Expected BinaryOperation, got {type(operation)}")

right = operation.right
if not isinstance(right, UnaryOperation) or right.operation != OperationType.negate:
return []

return [(
operation,
BinaryOperation(
OperationType.minus if operation.operation == OperationType.plus else OperationType.plus,
[operation.left, right.operand],
operation.type
)
)]
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import FOLDABLE_OPERATIONS, constant_fold
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule
from decompiler.structures.pseudo import Constant, Expression, Operation


class CollapseConstants(SimplificationRule):
"""
Fold operations with only constants as operands:
"""

def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
if not all(isinstance(o, Constant) for o in operation.operands):
return []
if operation.operation not in FOLDABLE_OPERATIONS:
return []

return [(
operation,
constant_fold(operation.operation, operation.operands)
)]
Loading

0 comments on commit a7d49d2

Please sign in to comment.