Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Expression simplifcation] Improve error handling #352

Merged
merged 17 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,91 @@
from functools import partial
from typing import Callable, Optional

from decompiler.structures.pseudo import Constant, Integer, OperationType
from decompiler.structures.pseudo import Constant, Integer, OperationType, Type
from decompiler.util.integer_util import normalize_int

# The first three exception types indicate that an operation is not suitable for constant folding.
# They do NOT indicate that the input was malformed in any way.
# The idea is that the caller of constant_fold does not need to verify that folding is possible.

def constant_fold(operation: OperationType, constants: list[Constant]) -> Constant:

class UnsupportedOperationType(Exception):
rihi marked this conversation as resolved.
Show resolved Hide resolved
"""Indicates that the specified Operation is not supported"""

pass


class UnsupportedValueType(Exception):

"""Indicates that the value type of one constant is not supported."""

pass


class UnsupportedMismatchedSizes(Exception):
"""Indicates that folding of different sized constants is not supported for the specified operation."""

pass


class IncompatibleOperandCount(Exception):
"""Indicates that the specified operation type is not defined for the number of constants specified"""

pass


def constant_fold(operation: OperationType, constants: list[Constant], result_type: Type) -> Constant:
"""
Fold operation with constants as operands.

:param operation: The operation.
:param constants: All constant operands of the operation.
Count of operands must be compatible with the specified operation type.
:param result_type: What type the folded constant should have.
:return: A constant representing the result of the operation.
:raises:
UnsupportedOperationType: Thrown if the specified operation is not supported.
UnsupportedValueType: Thrown if constants contain value of types not supported. Currently only ints are supported.
UnsupportedMismatchedSizes: Thrown if constants types have different sizes and folding of different sized
constants is not supported for the specified operation.
rihi marked this conversation as resolved.
Show resolved Hide resolved
IncompatibleOperandCount: Thrown if the specified operation type is not defined for the number of constants in `constants`.
"""

if operation not in _OPERATION_TO_FOLD_FUNCTION:
raise ValueError(f"Constant folding not implemented for operation '{operation}'.")
raise UnsupportedOperationType(f"Constant folding not implemented for operation '{operation}'.")

if not all(isinstance(v, int) for v in [c.value for c in constants]): # For now we only support integer value folding
raise UnsupportedValueType(f"Constant folding is not implemented for non int constant values: {[c.value for c in constants]}")

return _OPERATION_TO_FOLD_FUNCTION[operation](constants)
return Constant(
normalize_int(
_OPERATION_TO_FOLD_FUNCTION[operation](constants), result_type.size, isinstance(result_type, Integer) and result_type.signed
),
result_type,
)


def _constant_fold_arithmetic_binary(
constants: list[Constant], fun: Callable[[int, int], int], norm_sign: Optional[bool] = None
) -> Constant:
def _constant_fold_arithmetic_binary(constants: list[Constant], fun: Callable[[int, int], int], norm_sign: Optional[bool] = None) -> int:
"""
Fold an arithmetic binary operation with constants as operands.

:param constants: A list of exactly 2 constant operands.
:param constants: A list of exactly 2 constant values.
:param fun: The binary function to perform on the constants.
:param norm_sign: Optional boolean flag to indicate if/how to normalize the input constants to 'fun':
- None (default): no normalization
- True: normalize inputs, interpreted as signed values
- True: normalize inputs, interpreted as signed values
- False: normalize inputs, interpreted as unsigned values
:return: A constant representing the result of the operation.
:return: The result of the operation.
rihi marked this conversation as resolved.
Show resolved Hide resolved
:raises:
UnsupportedMismatchedSizes: Thrown if constants types have different sizes and folding of different sized
constants is not supported for the specified operation.
IncompatibleOperandCount: Thrown if the number of constants is not equal to 2.
"""

if len(constants) != 2:
raise ValueError(f"Expected exactly 2 constants to fold, got {len(constants)}.")
if not all(constant.type == constants[0].type for constant in constants):
raise ValueError(f"Can not fold constants with different types: {(constant.type for constant in constants)}")
if not all(isinstance(constant.type, Integer) for constant in constants):
raise ValueError(f"All constants must be integers, got {list(constant.type for constant in constants)}.")
raise IncompatibleOperandCount(f"Expected exactly 2 constants to fold, got {len(constants)}.")
if not all(constant.type.size == constants[0].type.size for constant in constants):
raise UnsupportedMismatchedSizes(f"Can not fold constants with different sizes: {[constant.type for constant in constants]}")

left, right = constants

Expand All @@ -51,49 +96,48 @@ def _constant_fold_arithmetic_binary(
left_value = normalize_int(left_value, left.type.size, norm_sign)
right_value = normalize_int(right_value, right.type.size, norm_sign)

return Constant(normalize_int(fun(left_value, right_value), left.type.size, left.type.signed), left.type)
return fun(left_value, right_value)


def _constant_fold_arithmetic_unary(constants: list[Constant], fun: Callable[[int], int]) -> Constant:
def _constant_fold_arithmetic_unary(constants: list[Constant], fun: Callable[[int], int]) -> int:
"""
Fold an arithmetic unary operation with a constant operand.

:param constants: A list containing a single constant operand.
:param fun: The unary function to perform on the constant.
:return: A constant representing the result of the operation.
:return: The result of the operation.
rihi marked this conversation as resolved.
Show resolved Hide resolved
:raises:
IncompatibleOperandCount: Thrown if the number of constants is not equal to 1.
"""

if len(constants) != 1:
raise ValueError("Expected exactly 1 constant to fold")
if not isinstance(constants[0].type, Integer):
raise ValueError(f"Constant must be of type integer: {constants[0].type}")
raise IncompatibleOperandCount("Expected exactly 1 constant to fold")

return Constant(normalize_int(fun(constants[0].value), constants[0].type.size, constants[0].type.signed), constants[0].type)
return fun(constants[0].value)


def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], int], signed: bool) -> Constant:
def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], int], signed: bool) -> int:
rihi marked this conversation as resolved.
Show resolved Hide resolved
"""
Fold a shift operation with constants as operands.

:param constants: A list of exactly 2 constant operands.
:param fun: The shift function to perform on the constants.
:param signed: Boolean flag indicating whether the shift is signed.
This is used to normalize the sign of the input constant to simulate unsigned shifts.
:return: A constant representing the result of the operation.
:return: The result of the operation.
:raises:
IncompatibleOperandCount: Thrown if the number of constants is not equal to 2.
"""

if len(constants) != 2:
raise ValueError("Expected exactly 2 constants to fold")
if not all(isinstance(constant.type, Integer) for constant in constants):
raise ValueError("All constants must be integers")
raise IncompatibleOperandCount("Expected exactly 2 constants to fold")

left, right = constants

shifted_value = fun(normalize_int(left.value, left.type.size, left.type.signed and signed), right.value)
return Constant(normalize_int(shifted_value, left.type.size, left.type.signed), left.type)
return fun(normalize_int(left.value, left.type.size, left.type.signed and signed), right.value)


_OPERATION_TO_FOLD_FUNCTION: dict[OperationType, Callable[[list[Constant]], Constant]] = {
_OPERATION_TO_FOLD_FUNCTION: dict[OperationType, Callable[[list[Constant]], int]] = {
OperationType.minus: partial(_constant_fold_arithmetic_binary, fun=operator.sub),
OperationType.plus: partial(_constant_fold_arithmetic_binary, fun=operator.add),
OperationType.multiply: partial(_constant_fold_arithmetic_binary, fun=operator.mul, norm_sign=True),
Expand All @@ -110,5 +154,4 @@ def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], in
OperationType.bitwise_not: partial(_constant_fold_arithmetic_unary, fun=operator.inv),
}


FOLDABLE_OPERATIONS = _OPERATION_TO_FOLD_FUNCTION.keys()
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import FOLDABLE_OPERATIONS, constant_fold
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule
from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import (
IncompatibleOperandCount,
UnsupportedMismatchedSizes,
UnsupportedOperationType,
UnsupportedValueType,
constant_fold,
)
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import MalformedData, SimplificationRule
from decompiler.structures.pseudo import Constant, Expression, Operation


Expand All @@ -11,7 +17,12 @@ class CollapseConstants(SimplificationRule):
def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
if not all(isinstance(o, Constant) for o in operation.operands):
rihi marked this conversation as resolved.
Show resolved Hide resolved
return []
if operation.operation not in FOLDABLE_OPERATIONS:

try:
folded_constant = constant_fold(operation.operation, operation.operands, operation.type)
except (UnsupportedOperationType, UnsupportedValueType, UnsupportedMismatchedSizes):
return []
except IncompatibleOperandCount as e:
raise MalformedData() from e

return [(operation, constant_fold(operation.operation, operation.operands))]
return [(operation, folded_constant)]
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from functools import reduce
from typing import Iterator

from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import FOLDABLE_OPERATIONS, constant_fold
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule
from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import (
FOLDABLE_OPERATIONS,
IncompatibleOperandCount,
UnsupportedValueType,
constant_fold,
)
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import MalformedData, SimplificationRule
from decompiler.structures.pseudo import Constant, Expression, Operation, OperationType, Type
from decompiler.structures.pseudo.operations import COMMUTATIVE_OPERATIONS

Expand All @@ -19,16 +24,21 @@ class CollapseNestedConstants(SimplificationRule):
def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
rihi marked this conversation as resolved.
Show resolved Hide resolved
if operation.operation not in _COLLAPSIBLE_OPERATIONS:
return []
if not isinstance(operation, Operation):
raise TypeError(f"Expected Operation, got {type(operation)}")

constants = list(_collect_constants(operation))
if len(constants) <= 1:
return []

first, *rest = constants

folded_constant = reduce(lambda c0, c1: constant_fold(operation.operation, [c0, c1]), rest, first)
# We don't need to catch UnsupportedOperationType, because check that operation is in _COLLAPSIBLE_OPERATIONS
# We don't need to catch UnsupportedMismatchedSizes, because '_collect_constants' only returns constants of the same type
try:
folded_constant = reduce(lambda c0, c1: constant_fold(operation.operation, [c0, c1], operation.type), rest, first)
except UnsupportedValueType:
return []
except IncompatibleOperandCount as e:
raise MalformedData() from e

identity_constant = _identity_constant(operation.operation, operation.type)
return [(first, folded_constant), *((constant, identity_constant) for constant in rest)]
Expand All @@ -51,7 +61,7 @@ def _collect_constants(operation: Operation) -> Iterator[Constant]:
current_operation = context_stack.pop()

for i, operand in enumerate(current_operation.operands):
if operand.type != operand_type:
if operand.type != operand_type: # This check could potentially be relaxed to only check for equal size
continue

if isinstance(operand, Operation):
Expand All @@ -72,6 +82,11 @@ def _identity_constant(operation: OperationType, var_type: Type) -> Constant:
case OperationType.multiply | OperationType.multiply_us:
return Constant(1, var_type)
case OperationType.bitwise_and:
return constant_fold(OperationType.bitwise_not, [Constant(0, var_type)])
# Should not throw any exception because:
# - OperationType.bitwise_not is foldable (UnsupportedOperationType)
# - constant has integer value, which is supported (UnsupportedValueType)
# - with only 1 constant there cant be mismatched sizes (UnsupportedMismatchedSizes)
# - bitwise_not has exactly one operand (IncompatibleOperandCount)
return constant_fold(OperationType.bitwise_not, [Constant(0, var_type)], var_type)
rihi marked this conversation as resolved.
Show resolved Hide resolved
case _:
raise NotImplementedError()
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,13 @@ def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
:param operation: The operation to which the simplification rule should be applied.
:return: A list of tuples, each containing a pair of expressions representing the original
and simplified versions resulting from applying the simplification rule to the given operation.
:raises:
MalformedData: Thrown inf malformed data, like a dereference operation with two operands, is encountered.
"""
pass


class MalformedData(Exception):
"""Used to indicate that malformed data was encountered"""

pass
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
if operation.operation not in COMMUTATIVE_OPERATIONS:
return []
if not isinstance(operation, BinaryOperation):
raise ValueError(f"Expected BinaryOperation, got {operation}")
raise TypeError(f"Expected BinaryOperation, got {type(operation)}")

if isinstance(operation.left, Constant) and not isinstance(operation.right, Constant):
return [(operation, BinaryOperation(operation.operation, [operation.right, operation.left], operation.type, operation.tags))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,28 @@
class _ExpressionSimplificationBase(PipelineStage, ABC):
def run(self, task: DecompilerTask):
max_iterations = task.options.getint("expression-simplification.max_iterations")
self._simplify_instructions(self._get_instructions(task), max_iterations)
debug = task.options.getboolean("pipeline.debug", fallback=False)

self._simplify_instructions(self._get_instructions(task), max_iterations, debug)

@abstractmethod
def _get_instructions(self, task: DecompilerTask) -> list[Instruction]:
pass

@classmethod
def _simplify_instructions(cls, instructions: list[Instruction], max_iterations: int):
def _simplify_instructions(cls, instructions: list[Instruction], max_iterations: int, debug: bool):
rule_sets = [("pre-rules", _pre_rules), ("rules", _rules), ("post-rules", _post_rules)]
for rule_name, rule_set in rule_sets:
# max_iterations is counted per rule_set
iteration_count = cls._simplify_instructions_with_rule_set(instructions, rule_set, max_iterations)
iteration_count = cls._simplify_instructions_with_rule_set(instructions, rule_set, max_iterations, debug)
if iteration_count <= max_iterations:
logging.info(f"Expression simplification took {iteration_count} iterations for {rule_name}")
else:
logging.warning(f"Exceeded max iteration count for {rule_name}")

@classmethod
def _simplify_instructions_with_rule_set(
cls, instructions: list[Instruction], rule_set: list[SimplificationRule], max_iterations: int
cls, instructions: list[Instruction], rule_set: list[SimplificationRule], max_iterations: int, debug: bool
) -> int:
iteration_count = 0

Expand All @@ -57,7 +59,7 @@ def _simplify_instructions_with_rule_set(

for rule in rule_set:
for instruction in instructions:
additional_iterations = cls._simplify_instruction_with_rule(instruction, rule, max_iterations - iteration_count)
additional_iterations = cls._simplify_instruction_with_rule(instruction, rule, max_iterations - iteration_count, debug)
if additional_iterations > 0:
changes = True

Expand All @@ -68,7 +70,7 @@ def _simplify_instructions_with_rule_set(
return iteration_count

@classmethod
def _simplify_instruction_with_rule(cls, instruction: Instruction, rule: SimplificationRule, max_iterations: int) -> int:
def _simplify_instruction_with_rule(cls, instruction: Instruction, rule: SimplificationRule, max_iterations: int, debug: bool) -> int:
iteration_count = 0
for expression in instruction.subexpressions():
while True:
Expand All @@ -78,9 +80,17 @@ def _simplify_instruction_with_rule(cls, instruction: Instruction, rule: Simplif
if not isinstance(expression, Operation):
break

substitutions = rule.apply(expression)
try:
substitutions = rule.apply(expression)
except Exception as e:
if debug:
raise # re-raise the exception
else:
logging.exception(f"An unexpected error occurred while simplifying: {e}")
break # continue with next subexpression

if not substitutions:
break
break # continue with next subexpression

iteration_count += 1

Expand Down
Loading
Loading