Skip to content

Commit

Permalink
fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
blattm committed Oct 25, 2023
1 parent e753a10 commit f5cf44e
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,26 @@

class UnsupportedOperationType(Exception):
"""Indicates that the specified Operation is not supported"""

pass


class UnsupportedValueType(Exception):

"""Indicates that the value type of one constant is not supported."""

pass


class UnsupportedMismatchedSizes(Exception):
"""Indicates that folding of different sized constants is not supported for the specified operation."""

pass


class IncompatibleOperandCount(Exception):
"""Indicates that the specified operation type is not defined for the number of constants specified"""

pass


Expand Down Expand Up @@ -55,19 +60,13 @@ def constant_fold(operation: OperationType, constants: list[Constant], result_ty

return Constant(
normalize_int(
_OPERATION_TO_FOLD_FUNCTION[operation](constants),
result_type.size,
isinstance(result_type, Integer) and result_type.signed
_OPERATION_TO_FOLD_FUNCTION[operation](constants), result_type.size, isinstance(result_type, Integer) and result_type.signed
),
result_type
result_type,
)


def _constant_fold_arithmetic_binary(
constants: list[Constant],
fun: Callable[[int, int], int],
norm_sign: Optional[bool] = None
) -> int:
def _constant_fold_arithmetic_binary(constants: list[Constant], fun: Callable[[int, int], int], norm_sign: Optional[bool] = None) -> int:
"""
Fold an arithmetic binary operation with constants as operands.
Expand Down Expand Up @@ -135,10 +134,7 @@ def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], in

left, right = constants

return fun(
normalize_int(left.value, left.type.size, left.type.signed and signed),
right.value
)
return fun(normalize_int(left.value, left.type.size, left.type.signed and signed), right.value)


_OPERATION_TO_FOLD_FUNCTION: dict[OperationType, Callable[[list[Constant]], int]] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@ def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:
# We don't need to catch UnsupportedOperationType, because check that operation is in _COLLAPSIBLE_OPERATIONS
# We don't need to catch UnsupportedMismatchedSizes, because '_collect_constants' only returns constants of the same type
try:
folded_constant = reduce(
lambda c0, c1: constant_fold(operation.operation, [c0, c1], operation.type),
rest,
first
)
folded_constant = reduce(lambda c0, c1: constant_fold(operation.operation, [c0, c1], operation.type), rest, first)
except UnsupportedValueType:
return []
except IncompatibleOperandCount as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]:

class MalformedData(Exception):
"""Used to indicate that malformed data was encountered"""

pass
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,7 @@ def _get_instructions(self, task: DecompilerTask) -> list[Instruction]:

@classmethod
def _simplify_instructions(cls, instructions: list[Instruction], max_iterations: int, debug: bool):
rule_sets = [
("pre-rules", _pre_rules),
("rules", _rules),
("post-rules", _post_rules)
]
rule_sets = [("pre-rules", _pre_rules), ("rules", _rules), ("post-rules", _post_rules)]
for rule_name, rule_set in rule_sets:
# max_iterations is counted per rule_set
iteration_count = cls._simplify_instructions_with_rule_set(instructions, rule_set, max_iterations, debug)
Expand All @@ -53,11 +49,7 @@ def _simplify_instructions(cls, instructions: list[Instruction], max_iterations:

@classmethod
def _simplify_instructions_with_rule_set(
cls,
instructions: list[Instruction],
rule_set: list[SimplificationRule],
max_iterations: int,
debug: bool
cls, instructions: list[Instruction], rule_set: list[SimplificationRule], max_iterations: int, debug: bool
) -> int:
iteration_count = 0

Expand All @@ -78,13 +70,7 @@ def _simplify_instructions_with_rule_set(
return iteration_count

@classmethod
def _simplify_instruction_with_rule(
cls,
instruction: Instruction,
rule: SimplificationRule,
max_iterations: int,
debug: bool
) -> int:
def _simplify_instruction_with_rule(cls, instruction: Instruction, rule: SimplificationRule, max_iterations: int, debug: bool) -> int:
iteration_count = 0
for expression in instruction.subexpressions():
while True:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@ def _c_float(value: float) -> Constant:
return Constant(value, Float.float())


@pytest.mark.parametrize(
["operation"],
[(operation,) for operation in OperationType if operation not in FOLDABLE_OPERATIONS]
)
@pytest.mark.parametrize(["operation"], [(operation,) for operation in OperationType if operation not in FOLDABLE_OPERATIONS])
def test_constant_fold_invalid_operations(operation: OperationType):
with pytest.raises(UnsupportedOperationType):
constant_fold(operation, [], Integer.int32_t())
Expand All @@ -44,14 +41,10 @@ def test_constant_fold_invalid_operations(operation: OperationType):
(OperationType.plus, [_c_i32(0), _c_i32(0)], Integer.int32_t(), _c_i32(0), nullcontext()),
(OperationType.plus, [_c_float(0.0), _c_float(0.0)], Float.float(), _c_float(0.0), pytest.raises(UnsupportedValueType)),
(OperationType.plus, [_c_i32(0), _c_float(0.0)], Integer.int32_t(), _c_i32(0), pytest.raises(UnsupportedValueType)),
]
],
)
def test_constant_fold_invalid_value_type(
operation: OperationType,
constants: list[Constant],
result_type: Type,
expected_result: Optional[Constant],
context
operation: OperationType, constants: list[Constant], result_type: Type, expected_result: Optional[Constant], context
):
with context:
assert constant_fold(operation, constants, result_type) == expected_result
Expand All @@ -67,92 +60,79 @@ def test_constant_fold_invalid_value_type(
(OperationType.plus, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)),
(OperationType.plus, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.plus, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.minus, [_c_i32(3), _c_i32(4)], Integer.int32_t(), _c_i32(-1), nullcontext()),
(OperationType.minus, [_c_i32(-2147483648), _c_i32(1)], Integer.int32_t(), _c_i32(2147483647), nullcontext()),
(OperationType.minus, [_c_u32(3), _c_u32(4)], Integer.uint32_t(), _c_u32(4294967295), nullcontext()),
(OperationType.minus, [_c_u32(3), _c_i32(4)], Integer.int32_t(), _c_i32(-1), nullcontext()),
(OperationType.minus, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)),
(OperationType.minus, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.minus, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.multiply, [_c_i32(3), _c_i32(4)], Integer.int32_t(), _c_i32(12), nullcontext()),
(OperationType.multiply, [_c_i32(-1073741824), _c_i32(2)], Integer.int32_t(), _c_i32(-2147483648), nullcontext()),
(OperationType.multiply, [_c_u32(3221225472), _c_u32(2)], Integer.uint32_t(), _c_u32(2147483648), nullcontext()),
(OperationType.multiply, [_c_u32(3), _c_i32(4)], Integer.int32_t(), _c_i32(12), nullcontext()),
(OperationType.multiply, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)),
(OperationType.multiply, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.multiply, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.multiply_us, [_c_i32(3), _c_i32(4)], Integer.int32_t(), _c_i32(12), nullcontext()),
(OperationType.multiply_us, [_c_i32(-1073741824), _c_i32(2)], Integer.int32_t(), _c_i32(-2147483648), nullcontext()),
(OperationType.multiply_us, [_c_u32(3221225472), _c_u32(2)], Integer.uint32_t(), _c_u32(2147483648), nullcontext()),
(OperationType.multiply_us, [_c_u32(3), _c_i32(4)], Integer.int32_t(), _c_i32(12), nullcontext()),
(OperationType.multiply_us, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)),
(OperationType.multiply_us, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.multiply_us, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.divide, [_c_i32(12), _c_i32(4)], Integer.int32_t(), _c_i32(3), nullcontext()),
(OperationType.divide, [_c_i32(-2147483648), _c_i32(2)], Integer.int32_t(), _c_i32(-1073741824), nullcontext()),
(OperationType.divide, [_c_u32(3), _c_i32(4)], Integer.int32_t(), _c_i32(0), nullcontext()),
(OperationType.divide, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)),
(OperationType.divide, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.divide, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.divide_us, [_c_i32(12), _c_i32(4)], Integer.int32_t(), _c_i32(3), nullcontext()),
(OperationType.divide_us, [_c_i32(-2147483648), _c_i32(2)], Integer.int32_t(), _c_i32(1073741824), nullcontext()),
(OperationType.divide_us, [_c_u32(3), _c_i32(4)], Integer.int32_t(), _c_i32(0), nullcontext()),
(OperationType.divide_us, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)),
(OperationType.divide_us, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.divide_us, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.negate, [_c_i32(3)], Integer.int32_t(), _c_i32(-3), nullcontext()),
(OperationType.negate, [_c_i32(-2147483648)], Integer.int32_t(), _c_i32(-2147483648), nullcontext()),
(OperationType.negate, [], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.negate, [_c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.left_shift, [_c_i32(3), _c_i32(4)], Integer.int32_t(), _c_i32(48), nullcontext()),
(OperationType.left_shift, [_c_i32(1073741824), _c_i32(1)], Integer.int32_t(), _c_i32(-2147483648), nullcontext()),
(OperationType.left_shift, [_c_u32(1073741824), _c_u32(1)], Integer.uint32_t(), _c_u32(2147483648), nullcontext()),
(OperationType.left_shift, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.left_shift, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.right_shift, [_c_i32(32), _c_i32(4)], Integer.int32_t(), _c_i32(2), nullcontext()),
(OperationType.right_shift, [_c_i32(-2147483648), _c_i32(1)], Integer.int32_t(), _c_i32(-1073741824), nullcontext()),
(OperationType.right_shift, [_c_u32(2147483648), _c_u32(1)], Integer.uint32_t(), _c_u32(1073741824), nullcontext()),
(OperationType.right_shift, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.right_shift, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.right_shift_us, [_c_i32(32), _c_i32(4)], Integer.int32_t(), _c_i32(2), nullcontext()),
(OperationType.right_shift_us, [_c_i32(-2147483648), _c_i32(1)], Integer.int32_t(), _c_i32(1073741824), nullcontext()),
(OperationType.right_shift_us, [_c_u32(2147483648), _c_u32(1)], Integer.uint32_t(), _c_u32(1073741824), nullcontext()),
(OperationType.right_shift_us, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.right_shift_us, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.bitwise_or, [_c_i32(85), _c_i32(34)], Integer.int32_t(), _c_i32(119), nullcontext()),
(OperationType.bitwise_or, [_c_i32(-2147483648), _c_i32(1)], Integer.int32_t(), _c_i32(-2147483647), nullcontext()),
(OperationType.bitwise_or, [_c_u32(2147483648), _c_u32(1)], Integer.uint32_t(), _c_u32(2147483649), nullcontext()),
(OperationType.bitwise_or, [_c_u32(3), _c_i32(4)], Integer.int32_t(), _c_i32(7), nullcontext()),
(OperationType.bitwise_or, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)),
(OperationType.bitwise_or, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.bitwise_or, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.bitwise_and, [_c_i32(85), _c_i32(51)], Integer.int32_t(), _c_i32(17), nullcontext()),
(OperationType.bitwise_and, [_c_i32(-2147483647), _c_i32(3)], Integer.int32_t(), _c_i32(1), nullcontext()),
(OperationType.bitwise_and, [_c_u32(2147483649), _c_u32(3)], Integer.uint32_t(), _c_u32(1), nullcontext()),
(OperationType.bitwise_and, [_c_u32(3), _c_i32(4)], Integer.int32_t(), _c_i32(0), nullcontext()),
(OperationType.bitwise_and, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)),
(OperationType.bitwise_and, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.bitwise_and, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.bitwise_xor, [_c_i32(85), _c_i32(51)], Integer.int32_t(), _c_i32(102), nullcontext()),
(OperationType.bitwise_xor, [_c_i32(-2147483647), _c_i32(-2147483646)], Integer.int32_t(), _c_i32(3), nullcontext()),
(OperationType.bitwise_xor, [_c_u32(2147483649), _c_u32(2147483650)], Integer.uint32_t(), _c_u32(3), nullcontext()),
(OperationType.bitwise_xor, [_c_u32(3), _c_i32(4)], Integer.int32_t(), _c_i32(7), nullcontext()),
(OperationType.bitwise_xor, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)),
(OperationType.bitwise_xor, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.bitwise_xor, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.bitwise_not, [_c_i32(6)], Integer.int32_t(), _c_i32(-7), nullcontext()),
(OperationType.bitwise_not, [_c_i32(-2147483648)], Integer.int32_t(), _c_i32(2147483647), nullcontext()),
(OperationType.bitwise_not, [_c_u32(2147483648)], Integer.uint32_t(), _c_u32(2147483647), nullcontext()),
Expand All @@ -161,11 +141,7 @@ def test_constant_fold_invalid_value_type(
],
)
def test_constant_fold(
operation: OperationType,
constants: list[Constant],
result_type: Type,
expected_result: Optional[Constant],
context
operation: OperationType, constants: list[Constant], result_type: Type, expected_result: Optional[Constant], context
):
with context:
assert constant_fold(operation, constants, result_type) == expected_result
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,8 @@ def _v_i32(name: str) -> Variable:
),
],
)
def test_simplify_instructions_with_rule_set(
rule_set: list[SimplificationRule],
instruction: Instruction,
expected_result: Instruction
):
_ExpressionSimplificationBase._simplify_instructions_with_rule_set(
[instruction],
rule_set,
100,
True
)
def test_simplify_instructions_with_rule_set(rule_set: list[SimplificationRule], instruction: Instruction, expected_result: Instruction):
_ExpressionSimplificationBase._simplify_instructions_with_rule_set([instruction], rule_set, 100, True)
assert instruction == expected_result


Expand All @@ -81,10 +72,5 @@ def test_simplify_instructions_with_rule_set(
def test_simplify_instructions_with_rule_set_max_iterations(
rule_set: list[SimplificationRule], instruction: Instruction, max_iterations: int, expect_exceed_max_iterations: bool
):
iterations = _ExpressionSimplificationBase._simplify_instructions_with_rule_set(
[instruction],
rule_set,
max_iterations,
True
)
iterations = _ExpressionSimplificationBase._simplify_instructions_with_rule_set([instruction], rule_set, max_iterations, True)
assert (iterations > max_iterations) == expect_exceed_max_iterations

0 comments on commit f5cf44e

Please sign in to comment.