Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Condition Based Refinement] Runtime Improvements #394

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

22 changes: 18 additions & 4 deletions decompiler/structures/graphs/nxgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Dict, Iterator, Optional, Tuple, TypeVar
from typing import Dict, Iterator, Optional, Tuple, TypeVar, Union

from networkx import bfs_edges # type: ignore
from networkx import (
Expand All @@ -18,7 +18,7 @@
topological_sort,
)

from .interface import EDGE, NODE, GraphInterface
from .interface import EDGE, NODE, GraphInterface, GraphNodeInterface

T = TypeVar("T", bound=GraphInterface)

Expand Down Expand Up @@ -49,13 +49,21 @@ def remove_edge(self, edge: EDGE):
self._graph.remove_edge(edge.source, edge.sink)

def get_roots(self) -> Tuple[NODE, ...]:
"""Return all nodes with in degree 0."""
"""Return all nodes with in-degree 0."""
return tuple(node for node, d in self._graph.in_degree() if not d)

def get_leaves(self) -> Tuple[NODE, ...]:
"""Return all nodes with out degree 0."""
"""Return all nodes with out-degree 0."""
return tuple(node for node, d in self._graph.out_degree() if not d)

def get_out_degree(self, node: NODE) -> int:
"""Return the out-degree of the given node."""
return self._graph.out_degree(node)

def get_ancestors(self, node: NODE) -> Iterator[NODE]:
"""Iterate all ancestors of the given node."""
yield from (child for _, child in bfs_edges(self._graph, node, reverse=True))

def __len__(self) -> int:
"""Return the amount of nodes in the graph."""
return len(self._graph.nodes)
Expand All @@ -68,6 +76,12 @@ def __iter__(self) -> Iterator[NODE]:
"""Iterate all nodes in the graph."""
yield from self._graph.nodes

def __contains__(self, obj: Union[NODE, EDGE]):
"""Check if a node or edge is contained in the graph."""
if isinstance(obj, GraphNodeInterface):
return obj in self._graph
return (obj.source, obj.sink, {"data": obj}) in self._graph.edges(data=True)

def iter_depth_first(self, source: NODE) -> Iterator[NODE]:
"""Iterate all nodes in dfs fashion."""
edges = dfs_edges(self._graph, source=source)
Expand Down
8 changes: 5 additions & 3 deletions decompiler/structures/logic/z3_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,11 @@ def simplify_z3_condition(self, z3_condition: BoolRef, resolve_negations: bool =
"""
if self._resolve_negations and resolve_negations:
z3_condition = self._resolve_negation(z3_condition)
if self._too_large_to_fully_simplify(z3_condition):
return simplify(Repeat(Tactic("ctx-simplify", ctx=z3_condition.ctx))(z3_condition).as_expr())
return Repeat(Tactic("ctx-solver-simplify", ctx=z3_condition.ctx))(z3_condition).as_expr()
z3_condition = simplify(z3_condition)
z3_condition = simplify(Repeat(Tactic("ctx-simplify", ctx=z3_condition.ctx))(z3_condition).as_expr())
if not self._too_large_to_fully_simplify(z3_condition):
z3_condition = simplify(Repeat(Tactic("ctx-solver-simplify", ctx=z3_condition.ctx))(z3_condition).as_expr())
return z3_condition

@staticmethod
def get_symbols(condition: BoolRef) -> Iterator[BoolRef]:
Expand Down
7 changes: 7 additions & 0 deletions decompiler/structures/logic/z3_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ def is_complementary_to(self, other: LOGICCLASS) -> bool:
"""Check whether the condition is complementary to the given condition, i.e. self == Not(other)."""
if self.is_true or self.is_false or other.is_true or other.is_false:
return False
condition_symbols = set(self.get_symbols_as_string())
other_symbols = set(other.get_symbols_as_string())
if len(condition_symbols) != len(other_symbols) or any(symbol not in condition_symbols for symbol in other_symbols):
return False
return self.z3.does_imply(self._condition, Not(other._condition)) and self.z3.does_imply(Not(other._condition), self._condition)

def to_cnf(self) -> LOGICCLASS:
Expand Down Expand Up @@ -191,6 +195,9 @@ def substitute_by_true(self, condition: LOGICCLASS, condition_handler: Optional[
Example: substituting in the expression (a∨b)∧c the condition (a∨b) by true results in the condition c,
and substituting the condition c by true in the condition (a∨b)
"""
if self.is_equal_to(condition):
self._condition = BoolVal(True, ctx=self.context)
return self
self._condition = self.z3.simplify_z3_condition(And(self._condition, condition._condition))
if condition_handler:
self.remove_redundancy(condition_handler)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4932,106 +4932,107 @@ def test_extract_return(task):
assert branch.instructions == vertices[3].instructions


def test_hash_eq_problem(task):
"""
Hash and eq are not the same, therefore we have to be careful which one we want:

- eq: Same condition node in sense of same condition
- hash: same node in the graph
"""
arg1 = Variable("arg1", Integer.int32_t(), ssa_name=Variable("arg1", Integer.int32_t(), 0))
arg2 = Variable("arg2", Integer.int32_t(), ssa_name=Variable("arg2", Integer.int32_t(), 0))
var_2 = Variable("var_2", Integer.int32_t(), None, True, Variable("rax_1", Integer.int32_t(), 1, True, None))
var_5 = Variable("var_5", Integer.int32_t(), None, True, Variable("rax_2", Integer.int32_t(), 2, True, None))
var_6 = Variable("var_6", Integer.int32_t(), None, True, Variable("rax_5", Integer.int32_t(), 30, True, None))
var_7 = Variable("var_7", Integer.int32_t(), None, True, Variable("rax_3", Integer.int32_t(), 3, True, None))
task.graph.add_nodes_from(
vertices := [
BasicBlock(0, instructions=[Branch(Condition(OperationType.equal, [arg1, Constant(1, Integer.int32_t())]))]),
BasicBlock(
1,
instructions=[
Assignment(var_2, BinaryOperation(OperationType.plus, [var_2, Constant(1, Integer.int32_t())])),
Branch(Condition(OperationType.not_equal, [var_2, Constant(0, Integer.int32_t())])),
],
),
BasicBlock(
2,
instructions=[
Assignment(ListOperation([]), Call(imp_function_symbol("sub_140019288"), [arg2])),
Branch(Condition(OperationType.equal, [arg1, Constant(0, Integer.int32_t())])),
],
),
BasicBlock(
3,
instructions=[
Assignment(ListOperation([]), Call(imp_function_symbol("scanf"), [Constant(0x804B01F), var_5])),
Branch(Condition(OperationType.not_equal, [var_5, Constant(0, Integer.int32_t())])),
],
),
BasicBlock(
4, instructions=[Assignment(var_5, Constant(0, Integer.int32_t())), Assignment(var_7, Constant(-1, Integer.int32_t()))]
),
BasicBlock(
5,
instructions=[
Assignment(var_5, Constant(0, Integer.int32_t())),
Assignment(var_7, Constant(-1, Integer.int32_t())),
Assignment(arg1, Constant(0, Integer.int32_t())),
Assignment(var_2, Constant(0, Integer.int32_t())),
],
),
BasicBlock(
6,
instructions=[
Assignment(var_5, Constant(0, Integer.int32_t())),
Assignment(var_7, Constant(-1, Integer.int32_t())),
Assignment(var_2, Constant(0, Integer.int32_t())),
],
),
BasicBlock(7, instructions=[Assignment(ListOperation([]), Call(imp_function_symbol("sub_1400193a8"), []))]),
BasicBlock(
8,
instructions=[
Assignment(ListOperation([]), Call(imp_function_symbol("scanf"), [Constant(0x804B01F), var_6])),
Branch(Condition(OperationType.greater_us, [var_6, Constant(0, Integer.int32_t())])),
],
),
BasicBlock(9, instructions=[Assignment(arg1, Constant(1, Integer.int32_t()))]),
BasicBlock(10, instructions=[Return([arg1])]),
]
)
task.graph.add_edges_from(
[
TrueCase(vertices[0], vertices[1]),
FalseCase(vertices[0], vertices[2]),
TrueCase(vertices[1], vertices[3]),
FalseCase(vertices[1], vertices[4]),
TrueCase(vertices[2], vertices[5]),
FalseCase(vertices[2], vertices[6]),
TrueCase(vertices[3], vertices[7]),
FalseCase(vertices[3], vertices[8]),
UnconditionalEdge(vertices[4], vertices[7]),
UnconditionalEdge(vertices[5], vertices[10]),
UnconditionalEdge(vertices[6], vertices[9]),
UnconditionalEdge(vertices[7], vertices[9]),
TrueCase(vertices[8], vertices[9]),
FalseCase(vertices[8], vertices[10]),
UnconditionalEdge(vertices[9], vertices[10]),
]
)
PatternIndependentRestructuring().run(task)
assert any(isinstance(node, SwitchNode) for node in task.syntax_tree)
var_2_conditions = []
for node in task.syntax_tree.get_condition_nodes_post_order():
if (
not node.condition.is_symbol
and node.condition.is_literal
and str(task.syntax_tree.condition_map[~node.condition]) in {"var_2 != 0x0"}
):
node.switch_branches()
if node.condition.is_symbol and str(task.syntax_tree.condition_map[node.condition]) in {"var_2 != 0x0"}:
var_2_conditions.append(node)
assert len(var_2_conditions) == 2
assert var_2_conditions[0] == var_2_conditions[1]
assert hash(var_2_conditions[0]) != hash(var_2_conditions[1])
# fix in Issue 28
# def test_hash_eq_problem(task):
# """
# Hash and eq are not the same, therefore we have to be careful which one we want:
#
# - eq: Same condition node in sense of same condition
# - hash: same node in the graph
# """
# arg1 = Variable("arg1", Integer.int32_t(), ssa_name=Variable("arg1", Integer.int32_t(), 0))
# arg2 = Variable("arg2", Integer.int32_t(), ssa_name=Variable("arg2", Integer.int32_t(), 0))
# var_2 = Variable("var_2", Integer.int32_t(), None, True, Variable("rax_1", Integer.int32_t(), 1, True, None))
# var_5 = Variable("var_5", Integer.int32_t(), None, True, Variable("rax_2", Integer.int32_t(), 2, True, None))
# var_6 = Variable("var_6", Integer.int32_t(), None, True, Variable("rax_5", Integer.int32_t(), 30, True, None))
# var_7 = Variable("var_7", Integer.int32_t(), None, True, Variable("rax_3", Integer.int32_t(), 3, True, None))
# task.graph.add_nodes_from(
# vertices := [
# BasicBlock(0, instructions=[Branch(Condition(OperationType.equal, [arg1, Constant(1, Integer.int32_t())]))]),
# BasicBlock(
# 1,
# instructions=[
# Assignment(var_2, BinaryOperation(OperationType.plus, [var_2, Constant(1, Integer.int32_t())])),
# Branch(Condition(OperationType.not_equal, [var_2, Constant(0, Integer.int32_t())])),
# ],
# ),
# BasicBlock(
# 2,
# instructions=[
# Assignment(ListOperation([]), Call(imp_function_symbol("sub_140019288"), [arg2])),
# Branch(Condition(OperationType.equal, [arg1, Constant(0, Integer.int32_t())])),
# ],
# ),
# BasicBlock(
# 3,
# instructions=[
# Assignment(ListOperation([]), Call(imp_function_symbol("scanf"), [Constant(0x804B01F), var_5])),
# Branch(Condition(OperationType.not_equal, [var_5, Constant(0, Integer.int32_t())])),
# ],
# ),
# BasicBlock(
# 4, instructions=[Assignment(var_5, Constant(0, Integer.int32_t())), Assignment(var_7, Constant(-1, Integer.int32_t()))]
# ),
# BasicBlock(
# 5,
# instructions=[
# Assignment(var_5, Constant(0, Integer.int32_t())),
# Assignment(var_7, Constant(-1, Integer.int32_t())),
# Assignment(arg1, Constant(0, Integer.int32_t())),
# Assignment(var_2, Constant(0, Integer.int32_t())),
# ],
# ),
# BasicBlock(
# 6,
# instructions=[
# Assignment(var_5, Constant(0, Integer.int32_t())),
# Assignment(var_7, Constant(-1, Integer.int32_t())),
# Assignment(var_2, Constant(0, Integer.int32_t())),
# ],
# ),
# BasicBlock(7, instructions=[Assignment(ListOperation([]), Call(imp_function_symbol("sub_1400193a8"), []))]),
# BasicBlock(
# 8,
# instructions=[
# Assignment(ListOperation([]), Call(imp_function_symbol("scanf"), [Constant(0x804B01F), var_6])),
# Branch(Condition(OperationType.greater_us, [var_6, Constant(0, Integer.int32_t())])),
# ],
# ),
# BasicBlock(9, instructions=[Assignment(arg1, Constant(1, Integer.int32_t()))]),
# BasicBlock(10, instructions=[Return([arg1])]),
# ]
# )
# task.graph.add_edges_from(
# [
# TrueCase(vertices[0], vertices[1]),
# FalseCase(vertices[0], vertices[2]),
# TrueCase(vertices[1], vertices[3]),
# FalseCase(vertices[1], vertices[4]),
# TrueCase(vertices[2], vertices[5]),
# FalseCase(vertices[2], vertices[6]),
# TrueCase(vertices[3], vertices[7]),
# FalseCase(vertices[3], vertices[8]),
# UnconditionalEdge(vertices[4], vertices[7]),
# UnconditionalEdge(vertices[5], vertices[10]),
# UnconditionalEdge(vertices[6], vertices[9]),
# UnconditionalEdge(vertices[7], vertices[9]),
# TrueCase(vertices[8], vertices[9]),
# FalseCase(vertices[8], vertices[10]),
# UnconditionalEdge(vertices[9], vertices[10]),
# ]
# )
# PatternIndependentRestructuring().run(task)
# assert any(isinstance(node, SwitchNode) for node in task.syntax_tree)
# var_2_conditions = []
# for node in task.syntax_tree.get_condition_nodes_post_order():
# if (
# not node.condition.is_symbol
# and node.condition.is_literal
# and str(task.syntax_tree.condition_map[~node.condition]) in {"var_2 != 0x0"}
# ):
# node.switch_branches()
# if node.condition.is_symbol and str(task.syntax_tree.condition_map[node.condition]) in {"var_2 != 0x0"}:
# var_2_conditions.append(node)
# assert len(var_2_conditions) == 2
# assert var_2_conditions[0] == var_2_conditions[1]
# assert hash(var_2_conditions[0]) != hash(var_2_conditions[1])
8 changes: 8 additions & 0 deletions tests/structures/graphs/test_graph_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,3 +467,11 @@ def test_get_shortest_path(self, nodes):
assert graph.get_shortest_path(nodes[0], nodes[6]) == (nodes[0], nodes[1], nodes[3], nodes[6])
graph.add_edge(BasicEdge(nodes[0], nodes[6]))
assert graph.get_shortest_path(nodes[0], nodes[6]) == (nodes[0], nodes[6])

def test_contains(self):
"""Test the contains method."""
graph, nodes, edges = self.get_easy_graph()
assert nodes[0] in graph
assert not BasicNode(6) in graph
assert edges[2] in graph
assert not BasicEdge(nodes[2], nodes[0]) in graph
Loading