Skip to content

Commit

Permalink
Merge branch 'main' into conditional-out-of-ssa
Browse files Browse the repository at this point in the history
  • Loading branch information
ebehner authored Jun 20, 2024
2 parents 658ad35 + 1e9c13c commit 67e1cb4
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 21 deletions.
6 changes: 3 additions & 3 deletions decompiler/frontend/binaryninja/handlers/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,13 @@ def _get_unknown_value(self, variable: DataVariable):
type = PseudoArrayType(self._lifter.lift(data[1]), len(data[0]))
data = ConstantComposition([Constant(x, type.type) for x in data[0]], type)
else:
data, type = get_raw_bytes(variable.address, self._view), Pointer(CustomType.void())
data, type = get_raw_bytes(variable.address, self._view), Pointer(CustomType.void(), self._view.address_size * BYTE_SIZE)
return data, type

def _get_unknown_pointer_value(self, variable: DataVariable, callers: list[int] = None):
"""Return symbol, datavariable, address, string or raw bytes for a value of a datavariable(!) (dv should be a pointer)."""
if not addr_in_section(self._view, variable.value):
type = Pointer(CustomType.void())
type = Pointer(CustomType.void(), self._view.address_size * BYTE_SIZE)
return Constant(variable.value, type), type

if datavariable := self._view.get_data_var_at(variable.value):
Expand Down Expand Up @@ -288,7 +288,7 @@ def _get_unknown_pointer_value(self, variable: DataVariable, callers: list[int]
type,
)
else:
data, type = get_raw_bytes(variable.value, self._view), Pointer(CustomType.void())
data, type = get_raw_bytes(variable.value, self._view), Pointer(CustomType.void(), self._view.address_size * BYTE_SIZE)
return data, type


Expand Down
6 changes: 5 additions & 1 deletion decompiler/frontend/binaryninja/handlers/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@ def __init__(self, lifter: ObserverLifter):
SymbolType.ImportedDataSymbol: Symbol,
SymbolType.ExternalSymbol: ImportedFunctionSymbol,
SymbolType.LibraryFunctionSymbol: Symbol,
SymbolType.SymbolicFunctionSymbol: FunctionSymbol,
}
# SymbolicFunctionSymbol is not available for Binary Ninja < 4
try:
self.SYMBOL_MAP[SymbolType.SymbolicFunctionSymbol] = FunctionSymbol
except AttributeError:
pass

def register(self):
"""Register the handler at the parent lifter."""
Expand Down
22 changes: 18 additions & 4 deletions decompiler/pipeline/commons/expressionpropagationcommons.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Operation,
OperationType,
Phi,
Relation,
Return,
UnaryOperation,
UnknownExpression,
Expand Down Expand Up @@ -265,15 +266,17 @@ def _definition_value_could_be_modified_via_memory_access_between_definition_and
) -> bool:
"""
Tests for definition containing aliased if a modification of the aliased value is possible, i.e.
via its pointer (ptr = &aliased) or via use of its reference (aka address) in function calls.
via its pointer (ptr = &aliased) or via use of its reference (aka address) in function calls
or if a relation is in between.
:return: true if a modification of the aliased value is possible (hence, the propagation should be avoided) false otherwise
"""
for aliased_variable in set(self._iter_aliased_variables(definition)):
dangerous_address_uses = self._get_dangerous_uses_of_variable_address(aliased_variable)
dangerous_pointer_uses = self._get_dangerous_uses_of_pointer_to_variable(aliased_variable)
if dangerous_address_uses or dangerous_pointer_uses:
dangerous_uses = dangerous_pointer_uses.union(dangerous_address_uses)
dangerous_alias_uses = self._get_dangerous_relations_between_definition_and_target(aliased_variable)
dangerous_uses = dangerous_pointer_uses | dangerous_address_uses | dangerous_alias_uses
if dangerous_uses:
if self._has_any_of_dangerous_uses_between_definition_and_target(definition, target, dangerous_uses):
return True
return False
Expand Down Expand Up @@ -353,6 +356,17 @@ def _get_dangerous_uses_of_pointer_to_variable(self, var: Variable) -> Set[Instr
dangerous_uses.update(self._get_dangerous_uses_of_pointer(pointer))
return dangerous_uses

def _get_dangerous_relations_between_definition_and_target(self, alias_variable: Variable) -> Set[Relation]:
    """
    Collect every Relation in the cfg whose destination has the same name as the given aliased variable.

    The comparison is done on the variable name only, so the SSA label is ignored and relations
    of all SSA versions of the variable are returned.
    """
    return {
        instruction
        for basic_block in self._cfg
        for instruction in basic_block
        if isinstance(instruction, Relation) and instruction.destination.name == alias_variable.name
    }

def _get_dangerous_uses_of_pointer(self, pointer: Variable) -> Set[Instruction]:
"""
:param pointer to a variable
Expand Down Expand Up @@ -438,7 +452,7 @@ def _is_aliased_variable(expression: Expression) -> bool:
def _contains_writeable_global_variable(expression: Assignment) -> bool:
"""
:param expression: Assignment expression to be tested
:return: true if any requirement of expression is a GlobalVariable
:return: true if any requirement of expression is a writeable GlobalVariable
"""
for expr in expression.destination.requirements:
if isinstance(expr, GlobalVariable) and not expr.is_constant:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,12 @@ def constant_fold(operation: OperationType, constants: list[Constant], result_ty
)


def _constant_fold_arithmetic_binary(constants: list[Constant], fun: Callable[[int, int], int], norm_sign: Optional[bool] = None) -> int:
def _constant_fold_arithmetic_binary(
constants: list[Constant],
fun: Callable[[int, int], int],
norm_sign: Optional[bool] = None,
allow_mismatched_sizes: bool = False,
) -> int:
"""
Fold an arithmetic binary operation with constants as operands.
Expand All @@ -84,7 +89,7 @@ def _constant_fold_arithmetic_binary(constants: list[Constant], fun: Callable[[i

if len(constants) != 2:
raise IncompatibleOperandCount(f"Expected exactly 2 constants to fold, got {len(constants)}.")
if not all(constant.type.size == constants[0].type.size for constant in constants):
if not allow_mismatched_sizes and not all(constant.type.size == constants[0].type.size for constant in constants):
raise UnsupportedMismatchedSizes(f"Can not fold constants with different sizes: {[constant.type for constant in constants]}")

left, right = constants
Expand Down Expand Up @@ -137,13 +142,19 @@ def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], in
return fun(normalize_int(left.value, left.type.size, norm_signed), right.value)


def remainder(n: int, d: int) -> int:
    """
    Return the truncating (C-style) remainder of n divided by d.

    Python's % operator is a floored modulo whose result takes the sign of the divisor,
    whereas the signed remainder being folded must satisfy n == trunc(n / d) * d + r,
    i.e. the result takes the sign of the dividend n.

    :param n: the dividend
    :param d: the divisor
    :return: the remainder r with abs(r) < abs(d) and the sign of n (or zero)
    :raises ZeroDivisionError: if d is zero
    """
    # Work on magnitudes so that the floored % coincides with the truncated remainder,
    # then re-apply the sign of the dividend.
    magnitude = abs(n) % abs(d)
    return -magnitude if n < 0 else magnitude


_OPERATION_TO_FOLD_FUNCTION: dict[OperationType, Callable[[list[Constant]], int]] = {
OperationType.minus: partial(_constant_fold_arithmetic_binary, fun=operator.sub),
OperationType.plus: partial(_constant_fold_arithmetic_binary, fun=operator.add),
OperationType.multiply: partial(_constant_fold_arithmetic_binary, fun=operator.mul, norm_sign=True),
OperationType.multiply_us: partial(_constant_fold_arithmetic_binary, fun=operator.mul, norm_sign=False),
OperationType.divide: partial(_constant_fold_arithmetic_binary, fun=operator.floordiv, norm_sign=True),
OperationType.divide_us: partial(_constant_fold_arithmetic_binary, fun=operator.floordiv, norm_sign=False),
OperationType.modulo: partial(_constant_fold_arithmetic_binary, fun=remainder, norm_sign=True, allow_mismatched_sizes=True),
OperationType.modulo_us: partial(_constant_fold_arithmetic_binary, fun=operator.mod, norm_sign=False, allow_mismatched_sizes=True),
OperationType.negate: partial(_constant_fold_arithmetic_unary, fun=operator.neg),
OperationType.left_shift: partial(_constant_fold_shift, fun=operator.lshift, signed=True),
OperationType.right_shift: partial(_constant_fold_shift, fun=operator.rshift, signed=True),
Expand Down
22 changes: 11 additions & 11 deletions decompiler/structures/ast/ast_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def is_empty(self) -> bool:

def copy(self) -> VirtualRootNode:
"""Return a copy of the ast node."""
return VirtualRootNode(self.reaching_condition)
return VirtualRootNode(self.reaching_condition.copy())

def accept(self, visitor: ASTVisitorInterface[T]) -> T:
return visitor.visit_root_node(self)
Expand All @@ -288,7 +288,7 @@ def __repr__(self) -> str:

def copy(self) -> SeqNode:
"""Return a copy of the ast node."""
return SeqNode(self.reaching_condition)
return SeqNode(self.reaching_condition.copy())

@property
def children(self) -> Tuple[AbstractSyntaxTreeNode, ...]:
Expand Down Expand Up @@ -375,7 +375,7 @@ def __repr__(self) -> str:

def copy(self) -> CodeNode:
"""Return a copy of the ast node."""
return CodeNode(self.instructions.copy(), self.reaching_condition)
return CodeNode([i.copy() for i in self.instructions], self.reaching_condition.copy())

@property
def children(self) -> Tuple[AbstractSyntaxTreeNode, ...]:
Expand Down Expand Up @@ -508,7 +508,7 @@ def __repr__(self) -> str:

def copy(self) -> ConditionNode:
"""Return a copy of the ast node."""
return ConditionNode(self.condition, self.reaching_condition)
return ConditionNode(self.condition.copy(), self.reaching_condition.copy())

@property
def children(self) -> Tuple[Union[TrueNode, FalseNode], ...]:
Expand Down Expand Up @@ -655,7 +655,7 @@ def __repr__(self) -> str:

def copy(self) -> TrueNode:
"""Return a copy of the ast node."""
return TrueNode(self.reaching_condition)
return TrueNode(self.reaching_condition.copy())

@property
def branch_condition(self) -> LogicCondition:
Expand All @@ -680,7 +680,7 @@ def __repr__(self) -> str:

def copy(self) -> FalseNode:
"""Return a copy of the ast node."""
return FalseNode(self.reaching_condition)
return FalseNode(self.reaching_condition.copy())

@property
def branch_condition(self) -> LogicCondition:
Expand Down Expand Up @@ -810,7 +810,7 @@ class WhileLoopNode(LoopNode):

def copy(self) -> WhileLoopNode:
"""Return a copy of the ast node."""
return WhileLoopNode(self.condition, self.reaching_condition)
return WhileLoopNode(self.condition.copy(), self.reaching_condition.copy())

@property
def loop_type(self) -> LoopType:
Expand All @@ -837,7 +837,7 @@ class DoWhileLoopNode(LoopNode):

def copy(self) -> DoWhileLoopNode:
"""Return a copy of the ast node."""
return DoWhileLoopNode(self.condition, self.reaching_condition)
return DoWhileLoopNode(self.condition.copy(), self.reaching_condition.copy())

@property
def loop_type(self) -> LoopType:
Expand Down Expand Up @@ -882,7 +882,7 @@ def __str__(self) -> str:

def copy(self) -> ForLoopNode:
"""Return a copy of the ast node."""
return ForLoopNode(self.declaration, self.condition, self.modification, self.reaching_condition)
return ForLoopNode(self.declaration.copy(), self.condition.copy(), self.modification.copy(), self.reaching_condition.copy())

@property
def loop_type(self) -> LoopType:
Expand Down Expand Up @@ -950,7 +950,7 @@ def __repr__(self) -> str:

def copy(self) -> SwitchNode:
"""Return a copy of the ast node."""
return SwitchNode(self.expression, self.reaching_condition)
return SwitchNode(self.expression.copy(), self.reaching_condition.copy())

@property
def children(self) -> Tuple[CaseNode]:
Expand Down Expand Up @@ -1084,7 +1084,7 @@ def __repr__(self) -> str:

def copy(self) -> CaseNode:
"""Return a copy of the ast node."""
return CaseNode(self.expression, self.constant, self.reaching_condition, self.break_case)
return CaseNode(self.expression.copy(), self.constant.copy(), self.reaching_condition.copy(), self.break_case)

@property
def does_end_with_break(self) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ def test_constant_fold_invalid_value_type(
(OperationType.divide_us, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)),
(OperationType.divide_us, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.divide_us, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.modulo, [_c_i32(13), _c_i32(4)], Integer.int32_t(), _c_i32(1), nullcontext()),
(OperationType.modulo, [_c_i32(-2147483647), _c_i32(2)], Integer.int32_t(), _c_i32(-1), nullcontext()),
(OperationType.modulo, [_c_u32(4), _c_i32(3)], Integer.int32_t(), _c_i32(1), nullcontext()),
(OperationType.modulo, [_c_i32(4), _c_i16(3)], Integer.int32_t(), _c_i32(1), nullcontext()),
(OperationType.modulo, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.modulo, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.modulo_us, [_c_i32(13), _c_i32(4)], Integer.int32_t(), _c_i32(1), nullcontext()),
(OperationType.modulo_us, [_c_i32(-2147483647), _c_i32(2)], Integer.int32_t(), _c_i32(1), nullcontext()),
(OperationType.modulo_us, [_c_u32(4), _c_i32(3)], Integer.int32_t(), _c_i32(1), nullcontext()),
(OperationType.modulo_us, [_c_i32(4), _c_i16(3)], Integer.int32_t(), _c_i32(1), nullcontext()),
(OperationType.modulo_us, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.modulo_us, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.negate, [_c_i32(3)], Integer.int32_t(), _c_i32(-3), nullcontext()),
(OperationType.negate, [_c_i32(-2147483648)], Integer.int32_t(), _c_i32(-2147483648), nullcontext()),
(OperationType.negate, [], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
Expand Down
58 changes: 58 additions & 0 deletions tests/pipeline/dataflowanalysis/test_expression_propagation_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,64 @@ def test_dangerous_reference_use_in_single_block_graph():
assert _graphs_equal(in_cfg, out_cfg)


def test_dangerous_relation_in_between():
    """
    A relation between definition and use must block propagation of the aliased variable.

    y#0 must not be propagated into rand(x#0), because the relation y#1 -> y#0 in between
    indicates a possible change. The graph is expected to stay unchanged.

    Input:
    +-----------------+
    |       0.        |
    |    x#0 = y#0    |
    |   memset(y#0)   |
    |   y#1 -> y#0    |
    | z#0 = rand(x#0) |
    |   return z#0    |
    +-----------------+

    Expected output:
    +-----------------+
    |       0.        |
    |    x#0 = y#0    |
    |   memset(y#0)   |
    |   y#1 -> y#0    |
    | z#0 = rand(x#0) |
    |   return z#0    |
    +-----------------+
    """
    cfg_under_test, expected_cfg = _graph_with_dangerous_relation_between()
    _run_expression_propagation(cfg_under_test)
    assert _graphs_equal(cfg_under_test, expected_cfg)


def _graph_with_dangerous_relation_between() -> Tuple[ControlFlowGraph, ControlFlowGraph]:
    """
    Build the input graph and the expected output graph for the dangerous-relation test.

    Both graphs consist of one basic block with the same instruction sequence: x#0 is defined
    from the aliased y#0, followed by a call using y#0, a relation y#1 -> y#0, a call to rand
    using x#0 and a return.

    :return: tuple (input cfg, expected output cfg)
    """
    x = vars("x", 2, aliased=False)
    y = vars("y", 2, aliased=True)
    z = vars("z", 1, aliased=False)

    def _single_block() -> BasicBlock:
        # Each graph gets its own instruction objects, so modifying one graph cannot alter the other.
        return BasicBlock(
            0,
            [
                _assign(x[0], y[0]),
                _call("memset", [], [y[0]]),
                Relation(y[1], y[0]),
                _call("rand", [z[0]], [x[0]]),
                _ret(z[0]),
            ],
        )

    in_cfg = ControlFlowGraph()
    in_cfg.add_node(_single_block())
    out_cfg = ControlFlowGraph()
    out_cfg.add_node(_single_block())
    return in_cfg, out_cfg


def _graphs_with_dangerous_reference_use() -> Tuple[ControlFlowGraph, ControlFlowGraph]:
in_cfg = ControlFlowGraph()
x = vars("x", 2, aliased=False)
Expand Down

0 comments on commit 67e1cb4

Please sign in to comment.