Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Expression simplification] Improve error handling #352

Merged
merged 17 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,80 @@
from functools import partial
from typing import Callable, Optional

from decompiler.structures.pseudo import Constant, Integer, OperationType
from decompiler.structures.pseudo import Constant, Integer, OperationType, Type
from decompiler.util.integer_util import normalize_int


def constant_fold(operation: OperationType, constants: list[Constant]) -> Constant:
class UnsupportedOperationType(Exception):
    """Indicates that constant folding is not implemented for the specified operation."""


class UnsupportedValueType(Exception):
    """Indicates that at least one constant holds a value whose type cannot be folded (only int is supported)."""


class UnsupportedMismatchedSizes(Exception):
    """Indicates that the specified operation cannot fold constants whose types differ in size."""


def constant_fold(operation: OperationType, constants: list[Constant], result_type: Type) -> Constant:
    """
    Fold operation with constants as operands.

    :param operation: The operation.
    :param constants: All constant operands of the operation.
    :param result_type: What type the folded constant should have.
    :return: A constant representing the result of the operation.
    :raises:
        ValueError: Thrown if constants is empty. The operation specific fold functions may raise it as well,
            e.g. for a wrong number of operands.
        UnsupportedOperationType: Thrown if the specified operation is not supported.
        UnsupportedValueType: Thrown if constants contain value of types not supported. Currently only ints are supported.
        UnsupportedMismatchedSizes: Thrown if constants types have different sizes and folding of different sized
            constants is not supported for the specified operation.
    """

    if not constants:
        raise ValueError("Constants list may not be empty")

    if operation not in _OPERATION_TO_FOLD_FUNCTION:
        raise UnsupportedOperationType(f"Constant folding not implemented for operation '{operation}'.")

    values = [c.value for c in constants]
    if not all(isinstance(v, int) for v in values):  # For now we only support integer value folding
        raise UnsupportedValueType(f"Constant folding is not implemented for non int constant values: {values}")

    # The fold functions return a plain int. It is normalized against the size of the result type;
    # a result type that is not an Integer is normalized as unsigned.
    folded_value = _OPERATION_TO_FOLD_FUNCTION[operation](constants)
    return Constant(
        normalize_int(folded_value, result_type.size, isinstance(result_type, Integer) and result_type.signed),
        result_type,
    )


def _constant_fold_arithmetic_binary(
constants: list[Constant],
fun: Callable[[int, int], int],
norm_sign: Optional[bool] = None
) -> Constant:
) -> int:
"""
Fold an arithmetic binary operation with constants as operands.

:param constants: A list of exactly 2 constant operands.
:param constants: A list of exactly 2 constant values.
:param fun: The binary function to perform on the constants.
:param norm_sign: Optional boolean flag to indicate if/how to normalize the input constants to 'fun':
- None (default): no normalization
- True: normalize inputs, interpreted as signed values
- True: normalize inputs, interpreted as signed values
- False: normalize inputs, interpreted as unsigned values
:return: A constant representing the result of the operation.
:return: The result of the operation.
rihi marked this conversation as resolved.
Show resolved Hide resolved
"""

if len(constants) != 2:
raise ValueError(f"Expected exactly 2 constants to fold, got {len(constants)}.")
if not all(constant.type == constants[0].type for constant in constants):
raise ValueError(f"Can not fold constants with different types: {(constant.type for constant in constants)}")
if not all(isinstance(constant.type, Integer) for constant in constants):
raise ValueError(f"All constants must be integers, got {list(constant.type for constant in constants)}.")
if not all(constant.type.size == constants[0].type.size for constant in constants):
raise UnsupportedMismatchedSizes(f"Can not fold constants with different sizes: {[constant.type for constant in constants]}")

left, right = constants

Expand All @@ -53,58 +85,47 @@ def _constant_fold_arithmetic_binary(
left_value = normalize_int(left_value, left.type.size, norm_sign)
right_value = normalize_int(right_value, right.type.size, norm_sign)

return Constant(
normalize_int(fun(left_value, right_value), left.type.size, left.type.signed),
left.type
)
return fun(left_value, right_value)


def _constant_fold_arithmetic_unary(constants: list[Constant], fun: Callable[[int], int]) -> int:
    """
    Fold an arithmetic unary operation with a constant operand.

    :param constants: A list containing a single constant operand.
    :param fun: The unary function to perform on the constant.
    :return: The result of the operation. The value is not normalized here.
    :raises ValueError: Thrown if not exactly one constant is passed or if its type is not an Integer.
    """

    if len(constants) != 1:
        raise ValueError("Expected exactly 1 constant to fold")

    (operand,) = constants
    if not isinstance(operand.type, Integer):
        raise ValueError(f"Constant must be of type integer: {operand.type}")

    return fun(operand.value)


def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], int], signed: bool) -> int:
    """
    Fold a shift operation with constants as operands.

    :param constants: A list of exactly 2 constant operands.
    :param fun: The shift function to perform on the constants.
    :param signed: Boolean flag indicating whether the shift is signed.
    This is used to normalize the sign of the input constant to simulate unsigned shifts.
    :return: The result of the operation. The value is not normalized here.
    :raises ValueError: Thrown if not exactly two constants are passed or if any constant type is not an Integer.
    """

    if len(constants) != 2:
        raise ValueError("Expected exactly 2 constants to fold")
    if not all(isinstance(constant.type, Integer) for constant in constants):
        raise ValueError("All constants must be integers")

    left, right = constants

    # Only the shifted value is normalized; the shift amount is passed on as is.
    return fun(
        normalize_int(left.value, left.type.size, left.type.signed and signed),
        right.value
    )


_OPERATION_TO_FOLD_FUNCTION: dict[OperationType, Callable[[list[Constant]], Constant]] = {
_OPERATION_TO_FOLD_FUNCTION: dict[OperationType, Callable[[list[Constant]], int]] = {
OperationType.minus: partial(_constant_fold_arithmetic_binary, fun=operator.sub),
OperationType.plus: partial(_constant_fold_arithmetic_binary, fun=operator.add),
OperationType.multiply: partial(_constant_fold_arithmetic_binary, fun=operator.mul, norm_sign=True),
Expand All @@ -121,5 +142,4 @@ def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], in
OperationType.bitwise_not: partial(_constant_fold_arithmetic_unary, fun=operator.inv),
}


FOLDABLE_OPERATIONS = _OPERATION_TO_FOLD_FUNCTION.keys()
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import FOLDABLE_OPERATIONS, constant_fold
from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import (
UnsupportedMismatchedSizes,
UnsupportedOperationType,
UnsupportedValueType,
constant_fold,
)
from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule
from decompiler.structures.pseudo import Constant, Expression, Operation

Expand All @@ -9,12 +14,14 @@ class CollapseConstants(SimplificationRule):
"""

def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
    """
    Replace an operation whose operands are all constants by the folded constant.

    :param operation: The operation to simplify.
    :return: A single (operation, folded constant) substitution, or an empty list if the operation has no
    operands, has a non-constant operand, or can not be folded by 'constant_fold'.
    """
    if not operation.operands:
        return []  # Is this even allowed?
    if not all(isinstance(o, Constant) for o in operation.operands):
        return []

    # Unsupported folds simply mean that this rule does not apply; any other error is propagated.
    try:
        folded_constant = constant_fold(operation.operation, operation.operands, operation.type)
    except (UnsupportedOperationType, UnsupportedValueType, UnsupportedMismatchedSizes):
        return []

    return [(operation, folded_constant)]
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

_COLLAPSIBLE_OPERATIONS = COMMUTATIVE_OPERATIONS & FOLDABLE_OPERATIONS


class CollapseNestedConstants(SimplificationRule):
"""
This rule walks the dataflow tree and collects and folds constants in commutative operations.
Expand All @@ -17,17 +18,17 @@ class CollapseNestedConstants(SimplificationRule):
def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
rihi marked this conversation as resolved.
Show resolved Hide resolved
if operation.operation not in _COLLAPSIBLE_OPERATIONS:
return []
if not isinstance(operation, Operation):
raise TypeError(f"Expected Operation, got {type(operation)}")

constants = list(_collect_constants(operation))
if len(constants) <= 1:
return []

first, *rest = constants

# We don't need to catch any errors of 'constant_fold', because '_collect_constants' only returns constants
# which have the same type as 'operation.type' and we check that operation.operation is foldable.
rihi marked this conversation as resolved.
Show resolved Hide resolved
folded_constant = reduce(
lambda c0, c1: constant_fold(operation.operation, [c0, c1]),
lambda c0, c1: constant_fold(operation.operation, [c0, c1], operation.type),
rest,
first
)
Expand Down Expand Up @@ -56,7 +57,7 @@ def _collect_constants(operation: Operation) -> Iterator[Constant]:
current_operation = context_stack.pop()

for i, operand in enumerate(current_operation.operands):
if operand.type != operand_type:
if operand.type != operand_type: # This check could potentially be relaxed to only check for equal size
continue

if isinstance(operand, Operation):
Expand All @@ -77,6 +78,6 @@ def _identity_constant(operation: OperationType, var_type: Type) -> Constant:
case OperationType.multiply | OperationType.multiply_us:
return Constant(1, var_type)
case OperationType.bitwise_and:
return constant_fold(OperationType.bitwise_not, [Constant(0, var_type)])
return constant_fold(OperationType.bitwise_not, [Constant(0, var_type)], var_type)
rihi marked this conversation as resolved.
Show resolved Hide resolved
case _:
raise NotImplementedError()
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
if operation.operation not in COMMUTATIVE_OPERATIONS:
return []
if not isinstance(operation, BinaryOperation):
raise ValueError(f"Expected BinaryOperation, got {operation}")
raise TypeError(f"Expected BinaryOperation, got {type(operation)}")

if isinstance(operation.left, Constant) and not isinstance(operation.right, Constant):
return [(operation, BinaryOperation(operation.operation, [operation.right, operation.left], operation.type, operation.tags))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,42 @@
from decompiler.task import DecompilerTask


class _SimplificationException(Exception):
    """Wraps an exception that was raised while applying a simplification rule."""


class _ExpressionSimplificationBase(PipelineStage, ABC):

def run(self, task: DecompilerTask):
    """
    Simplify the instructions selected by '_get_instructions' for the given task.

    Reads 'expression-simplification.max_iterations' and 'pipeline.debug' (False if absent) from the task options.
    """
    max_iterations = task.options.getint("expression-simplification.max_iterations")
    debug = task.options.getboolean("pipeline.debug", fallback=False)

    self._simplify_instructions(self._get_instructions(task), max_iterations, debug)

@abstractmethod
def _get_instructions(self, task: DecompilerTask) -> list[Instruction]:
    """Return the instructions of the given task that 'run' should simplify."""

@classmethod
def _simplify_instructions(cls, instructions: list[Instruction], max_iterations: int, debug: bool):
    """
    Apply the pre-rules, rules and post-rules to the instructions, one rule set after the other.

    :param instructions: The instructions to simplify.
    :param max_iterations: Iteration limit, counted separately for each rule set.
    :param debug: If True, a _SimplificationException is re-raised; otherwise it is logged and the
    remaining rule sets are skipped.
    """
    rule_sets = [
        ("pre-rules", _pre_rules),
        ("rules", _rules),
        ("post-rules", _post_rules)
    ]
    try:
        for rule_name, rule_set in rule_sets:
            # max_iterations is counted per rule_set
            iteration_count = cls._simplify_instructions_with_rule_set(instructions, rule_set, max_iterations)
            if iteration_count <= max_iterations:
                logging.info(f"Expression simplification took {iteration_count} iterations for {rule_name}")
            else:
                logging.warning(f"Exceeded max iteration count for {rule_name}")
    except _SimplificationException as e:
        if debug:
            raise  # re-raises the exception
        logging.exception(f"An unexpected error occurred while simplifying: {e}")

@classmethod
def _simplify_instructions_with_rule_set(
Expand Down Expand Up @@ -91,7 +103,11 @@ def _simplify_instruction_with_rule(
if not isinstance(expression, Operation):
break

substitutions = rule.apply(expression)
try:
substitutions = rule.apply(expression)
except Exception as e:
raise _SimplificationException(e)
rihi marked this conversation as resolved.
Show resolved Hide resolved

if not substitutions:
break

Expand Down
Loading
Loading