Skip to content

Commit

Permalink
[Expression simplifcation] Improve error handling (#352)
Browse files Browse the repository at this point in the history
* Fix incorrect error type in term_order.py

* Improve error handling of constant folding

* Catch exception caused by simplification rules if in debug

* Catch empty operands operation in collapse_constants.py

* Update documentation of constant_folding.py

* Create MalformedInput exception

* Handle malformed input in collapse_constants.py

* Handle errors in collapse_nested_constants.py

* Add clarifying comment in collapse_nested_constants.py

* Catch exception in expression simplification earlier

* Fix documentation in constant_folding.py

* Change how exceptions are propagated

* Update outdated comment

* Remove unnecessary check in collapse_constants.py

* fix broken merge

---------

Co-authored-by: Manuel Blatt <[email protected]>
  • Loading branch information
rihi and blattm authored Oct 25, 2023
1 parent 5ed982a commit 1dc15d4
Show file tree
Hide file tree
Showing 8 changed files with 257 additions and 145 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,91 @@
from functools import partial
from typing import Callable, Optional

from decompiler.structures.pseudo import Constant, Integer, OperationType
from decompiler.structures.pseudo import Constant, Integer, OperationType, Type
from decompiler.util.integer_util import normalize_int

# The first three exception types indicate that an operation is not suitable for constant folding.
# They do NOT indicate that the input was malformed in any way.
# The idea is that the caller of constant_fold does not need to verify that folding is possible.

def constant_fold(operation: OperationType, constants: list[Constant]) -> Constant:

class UnsupportedOperationType(Exception):
"""Indicates that the specified Operation is not supported"""

pass


class UnsupportedValueType(Exception):

"""Indicates that the value type of one constant is not supported."""

pass


class UnsupportedMismatchedSizes(Exception):
"""Indicates that folding of different sized constants is not supported for the specified operation."""

pass


class IncompatibleOperandCount(Exception):
"""Indicates that the specified operation type is not defined for the number of constants specified"""

pass


def constant_fold(operation: OperationType, constants: list[Constant], result_type: Type) -> Constant:
"""
Fold operation with constants as operands.
:param operation: The operation.
:param constants: All constant operands of the operation.
Count of operands must be compatible with the specified operation type.
:param result_type: What type the folded constant should have.
:return: A constant representing the result of the operation.
:raises:
UnsupportedOperationType: Thrown if the specified operation is not supported.
UnsupportedValueType: Thrown if constants contain value of types not supported. Currently only ints are supported.
UnsupportedMismatchedSizes: Thrown if constants types have different sizes and folding of different sized
constants is not supported for the specified operation.
IncompatibleOperandCount: Thrown if the specified operation type is not defined for the number of constants in `constants`.
"""

if operation not in _OPERATION_TO_FOLD_FUNCTION:
raise ValueError(f"Constant folding not implemented for operation '{operation}'.")
raise UnsupportedOperationType(f"Constant folding not implemented for operation '{operation}'.")

if not all(isinstance(v, int) for v in [c.value for c in constants]): # For now we only support integer value folding
raise UnsupportedValueType(f"Constant folding is not implemented for non int constant values: {[c.value for c in constants]}")

return _OPERATION_TO_FOLD_FUNCTION[operation](constants)
return Constant(
normalize_int(
_OPERATION_TO_FOLD_FUNCTION[operation](constants), result_type.size, isinstance(result_type, Integer) and result_type.signed
),
result_type,
)


def _constant_fold_arithmetic_binary(
constants: list[Constant], fun: Callable[[int, int], int], norm_sign: Optional[bool] = None
) -> Constant:
def _constant_fold_arithmetic_binary(constants: list[Constant], fun: Callable[[int, int], int], norm_sign: Optional[bool] = None) -> int:
"""
Fold an arithmetic binary operation with constants as operands.
:param constants: A list of exactly 2 constant operands.
:param constants: A list of exactly 2 constant values.
:param fun: The binary function to perform on the constants.
:param norm_sign: Optional boolean flag to indicate if/how to normalize the input constants to 'fun':
- None (default): no normalization
- True: normalize inputs, interpreted as signed values
- True: normalize inputs, interpreted as signed values
- False: normalize inputs, interpreted as unsigned values
:return: A constant representing the result of the operation.
:return: The result of the operation.
:raises:
UnsupportedMismatchedSizes: Thrown if constants types have different sizes and folding of different sized
constants is not supported for the specified operation.
IncompatibleOperandCount: Thrown if the number of constants is not equal to 2.
"""

if len(constants) != 2:
raise ValueError(f"Expected exactly 2 constants to fold, got {len(constants)}.")
if not all(constant.type == constants[0].type for constant in constants):
raise ValueError(f"Can not fold constants with different types: {(constant.type for constant in constants)}")
if not all(isinstance(constant.type, Integer) for constant in constants):
raise ValueError(f"All constants must be integers, got {list(constant.type for constant in constants)}.")
raise IncompatibleOperandCount(f"Expected exactly 2 constants to fold, got {len(constants)}.")
if not all(constant.type.size == constants[0].type.size for constant in constants):
raise UnsupportedMismatchedSizes(f"Can not fold constants with different sizes: {[constant.type for constant in constants]}")

left, right = constants

Expand All @@ -51,49 +96,48 @@ def _constant_fold_arithmetic_binary(
left_value = normalize_int(left_value, left.type.size, norm_sign)
right_value = normalize_int(right_value, right.type.size, norm_sign)

return Constant(normalize_int(fun(left_value, right_value), left.type.size, left.type.signed), left.type)
return fun(left_value, right_value)


def _constant_fold_arithmetic_unary(constants: list[Constant], fun: Callable[[int], int]) -> Constant:
def _constant_fold_arithmetic_unary(constants: list[Constant], fun: Callable[[int], int]) -> int:
"""
Fold an arithmetic unary operation with a constant operand.
:param constants: A list containing a single constant operand.
:param fun: The unary function to perform on the constant.
:return: A constant representing the result of the operation.
:return: The result of the operation.
:raises:
IncompatibleOperandCount: Thrown if the number of constants is not equal to 1.
"""

if len(constants) != 1:
raise ValueError("Expected exactly 1 constant to fold")
if not isinstance(constants[0].type, Integer):
raise ValueError(f"Constant must be of type integer: {constants[0].type}")
raise IncompatibleOperandCount("Expected exactly 1 constant to fold")

return Constant(normalize_int(fun(constants[0].value), constants[0].type.size, constants[0].type.signed), constants[0].type)
return fun(constants[0].value)


def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], int], signed: bool) -> Constant:
def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], int], signed: bool) -> int:
"""
Fold a shift operation with constants as operands.
:param constants: A list of exactly 2 constant operands.
:param fun: The shift function to perform on the constants.
:param signed: Boolean flag indicating whether the shift is signed.
This is used to normalize the sign of the input constant to simulate unsigned shifts.
:return: A constant representing the result of the operation.
:return: The result of the operation.
:raises:
IncompatibleOperandCount: Thrown if the number of constants is not equal to 2.
"""

if len(constants) != 2:
raise ValueError("Expected exactly 2 constants to fold")
if not all(isinstance(constant.type, Integer) for constant in constants):
raise ValueError("All constants must be integers")
raise IncompatibleOperandCount("Expected exactly 2 constants to fold")

left, right = constants

shifted_value = fun(normalize_int(left.value, left.type.size, left.type.signed and signed), right.value)
return Constant(normalize_int(shifted_value, left.type.size, left.type.signed), left.type)
return fun(normalize_int(left.value, left.type.size, left.type.signed and signed), right.value)


_OPERATION_TO_FOLD_FUNCTION: dict[OperationType, Callable[[list[Constant]], Constant]] = {
_OPERATION_TO_FOLD_FUNCTION: dict[OperationType, Callable[[list[Constant]], int]] = {
OperationType.minus: partial(_constant_fold_arithmetic_binary, fun=operator.sub),
OperationType.plus: partial(_constant_fold_arithmetic_binary, fun=operator.add),
OperationType.multiply: partial(_constant_fold_arithmetic_binary, fun=operator.mul, norm_sign=True),
Expand All @@ -110,5 +154,4 @@ def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], in
OperationType.bitwise_not: partial(_constant_fold_arithmetic_unary, fun=operator.inv),
}


FOLDABLE_OPERATIONS = _OPERATION_TO_FOLD_FUNCTION.keys()
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import FOLDABLE_OPERATIONS, constant_fold
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule
from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import (
IncompatibleOperandCount,
UnsupportedMismatchedSizes,
UnsupportedOperationType,
UnsupportedValueType,
constant_fold,
)
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import MalformedData, SimplificationRule
from decompiler.structures.pseudo import Constant, Expression, Operation


Expand All @@ -11,7 +17,12 @@ class CollapseConstants(SimplificationRule):
def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
if not all(isinstance(o, Constant) for o in operation.operands):
return []
if operation.operation not in FOLDABLE_OPERATIONS:

try:
folded_constant = constant_fold(operation.operation, operation.operands, operation.type)
except (UnsupportedOperationType, UnsupportedValueType, UnsupportedMismatchedSizes):
return []
except IncompatibleOperandCount as e:
raise MalformedData() from e

return [(operation, constant_fold(operation.operation, operation.operands))]
return [(operation, folded_constant)]
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from functools import reduce
from typing import Iterator

from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import FOLDABLE_OPERATIONS, constant_fold
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule
from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import (
FOLDABLE_OPERATIONS,
IncompatibleOperandCount,
UnsupportedValueType,
constant_fold,
)
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import MalformedData, SimplificationRule
from decompiler.structures.pseudo import Constant, Expression, Operation, OperationType, Type
from decompiler.structures.pseudo.operations import COMMUTATIVE_OPERATIONS

Expand All @@ -19,16 +24,21 @@ class CollapseNestedConstants(SimplificationRule):
def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
if operation.operation not in _COLLAPSIBLE_OPERATIONS:
return []
if not isinstance(operation, Operation):
raise TypeError(f"Expected Operation, got {type(operation)}")

constants = list(_collect_constants(operation))
if len(constants) <= 1:
return []

first, *rest = constants

folded_constant = reduce(lambda c0, c1: constant_fold(operation.operation, [c0, c1]), rest, first)
# We don't need to catch UnsupportedOperationType, because check that operation is in _COLLAPSIBLE_OPERATIONS
# We don't need to catch UnsupportedMismatchedSizes, because '_collect_constants' only returns constants of the same type
try:
folded_constant = reduce(lambda c0, c1: constant_fold(operation.operation, [c0, c1], operation.type), rest, first)
except UnsupportedValueType:
return []
except IncompatibleOperandCount as e:
raise MalformedData() from e

identity_constant = _identity_constant(operation.operation, operation.type)
return [(first, folded_constant), *((constant, identity_constant) for constant in rest)]
Expand All @@ -51,7 +61,7 @@ def _collect_constants(operation: Operation) -> Iterator[Constant]:
current_operation = context_stack.pop()

for i, operand in enumerate(current_operation.operands):
if operand.type != operand_type:
if operand.type != operand_type: # This check could potentially be relaxed to only check for equal size
continue

if isinstance(operand, Operation):
Expand All @@ -72,6 +82,11 @@ def _identity_constant(operation: OperationType, var_type: Type) -> Constant:
case OperationType.multiply | OperationType.multiply_us:
return Constant(1, var_type)
case OperationType.bitwise_and:
return constant_fold(OperationType.bitwise_not, [Constant(0, var_type)])
# Should not throw any exception because:
# - OperationType.bitwise_not is foldable (UnsupportedOperationType)
# - constant has integer value, which is supported (UnsupportedValueType)
# - with only 1 constant there cant be mismatched sizes (UnsupportedMismatchedSizes)
# - bitwise_not has exactly one operand (IncompatibleOperandCount)
return constant_fold(OperationType.bitwise_not, [Constant(0, var_type)], var_type)
case _:
raise NotImplementedError()
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,13 @@ def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
:param operation: The operation to which the simplification rule should be applied.
:return: A list of tuples, each containing a pair of expressions representing the original
and simplified versions resulting from applying the simplification rule to the given operation.
:raises:
MalformedData: Thrown inf malformed data, like a dereference operation with two operands, is encountered.
"""
pass


class MalformedData(Exception):
"""Used to indicate that malformed data was encountered"""

pass
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
if operation.operation not in COMMUTATIVE_OPERATIONS:
return []
if not isinstance(operation, BinaryOperation):
raise ValueError(f"Expected BinaryOperation, got {operation}")
raise TypeError(f"Expected BinaryOperation, got {type(operation)}")

if isinstance(operation.left, Constant) and not isinstance(operation.right, Constant):
return [(operation, BinaryOperation(operation.operation, [operation.right, operation.left], operation.type, operation.tags))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,28 @@
class _ExpressionSimplificationBase(PipelineStage, ABC):
def run(self, task: DecompilerTask):
max_iterations = task.options.getint("expression-simplification.max_iterations")
self._simplify_instructions(self._get_instructions(task), max_iterations)
debug = task.options.getboolean("pipeline.debug", fallback=False)

self._simplify_instructions(self._get_instructions(task), max_iterations, debug)

@abstractmethod
def _get_instructions(self, task: DecompilerTask) -> list[Instruction]:
pass

@classmethod
def _simplify_instructions(cls, instructions: list[Instruction], max_iterations: int):
def _simplify_instructions(cls, instructions: list[Instruction], max_iterations: int, debug: bool):
rule_sets = [("pre-rules", _pre_rules), ("rules", _rules), ("post-rules", _post_rules)]
for rule_name, rule_set in rule_sets:
# max_iterations is counted per rule_set
iteration_count = cls._simplify_instructions_with_rule_set(instructions, rule_set, max_iterations)
iteration_count = cls._simplify_instructions_with_rule_set(instructions, rule_set, max_iterations, debug)
if iteration_count <= max_iterations:
logging.info(f"Expression simplification took {iteration_count} iterations for {rule_name}")
else:
logging.warning(f"Exceeded max iteration count for {rule_name}")

@classmethod
def _simplify_instructions_with_rule_set(
cls, instructions: list[Instruction], rule_set: list[SimplificationRule], max_iterations: int
cls, instructions: list[Instruction], rule_set: list[SimplificationRule], max_iterations: int, debug: bool
) -> int:
iteration_count = 0

Expand All @@ -57,7 +59,7 @@ def _simplify_instructions_with_rule_set(

for rule in rule_set:
for instruction in instructions:
additional_iterations = cls._simplify_instruction_with_rule(instruction, rule, max_iterations - iteration_count)
additional_iterations = cls._simplify_instruction_with_rule(instruction, rule, max_iterations - iteration_count, debug)
if additional_iterations > 0:
changes = True

Expand All @@ -68,7 +70,7 @@ def _simplify_instructions_with_rule_set(
return iteration_count

@classmethod
def _simplify_instruction_with_rule(cls, instruction: Instruction, rule: SimplificationRule, max_iterations: int) -> int:
def _simplify_instruction_with_rule(cls, instruction: Instruction, rule: SimplificationRule, max_iterations: int, debug: bool) -> int:
iteration_count = 0
for expression in instruction.subexpressions():
while True:
Expand All @@ -78,9 +80,17 @@ def _simplify_instruction_with_rule(cls, instruction: Instruction, rule: Simplif
if not isinstance(expression, Operation):
break

substitutions = rule.apply(expression)
try:
substitutions = rule.apply(expression)
except Exception as e:
if debug:
raise # re-raise the exception
else:
logging.exception(f"An unexpected error occurred while simplifying: {e}")
break # continue with next subexpression

if not substitutions:
break
break # continue with next subexpression

iteration_count += 1

Expand Down
Loading

0 comments on commit 1dc15d4

Please sign in to comment.