Skip to content

Commit

Permalink
Simplify dereference operations containing address-of operations (#17)
Browse files Browse the repository at this point in the history
* refactored expression_simplification.py and added dereference operation simplification
  • Loading branch information
0x6e62 authored Mar 4, 2022
1 parent 1a02645 commit a76fc52
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 120 deletions.
233 changes: 120 additions & 113 deletions decompiler/pipeline/controlflowanalysis/expression_simplification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,135 +3,142 @@

from decompiler.pipeline.stage import PipelineStage
from decompiler.structures.ast.ast_nodes import CodeNode
from decompiler.structures.pseudo.expressions import Constant, DataflowObject, Expression
from decompiler.structures.pseudo.expressions import Constant, Expression
from decompiler.structures.pseudo.instructions import Instruction
from decompiler.structures.pseudo.operations import BinaryOperation, OperationType, UnaryOperation
from decompiler.structures.pseudo.operations import BinaryOperation, Operation, OperationType, UnaryOperation
from decompiler.structures.pseudo.typing import Integer
from decompiler.task import DecompilerTask


def simplify(expression: DataflowObject, parent: Optional[Instruction] = None):
"""
Simplifies the given instruction
a + 0 -> a, a - 0 -> a, 0 - a -> -a, a*0 -> 0, a* 1 -> a, a* -1 -> -a, a / 1 -> a, a / -1 -> -a
"""
parent = expression if parent is None else parent
assert isinstance(parent, Instruction), f"The parent {parent} must be an instruction."
for sub_expr in expression:
simplify(sub_expr, parent)
if isinstance(expression, BinaryOperation) and expression.operation in SIMPLIFICATION_FOR:
SIMPLIFICATION_FOR[expression.operation](expression, parent)


def _simplify_addition(binary_operation: BinaryOperation, instruction: Instruction):
"""
Simplifies the given addition in the given instruction.
-> Simplifies a+0, 0+a, a-0 and -0 + a to a
"""
if any(is_zero_constant(zero := op) for op in binary_operation.operands):
non_zero = get_other_operand(binary_operation, zero)
instruction.substitute(binary_operation, non_zero)


def _simplify_multiplication(binary_operation: BinaryOperation, instruction: Instruction):
"""
Simplifies the given multiplication in the given instruction.
-> Simplifies a*0, 0*a, a*(-0) and (-0) * a to 0
-> Simplifies a*1, 1*a to a
-> Simplifies a*(-1), (-1)*a to -a
"""
if any(is_zero_constant(zero := op) for op in binary_operation.operands):
instruction.substitute(binary_operation, zero)
elif any(is_one_constant(one := op) for op in binary_operation.operands):
non_one = get_other_operand(binary_operation, one)
instruction.substitute(binary_operation, non_one)
elif any(is_minus_one_constant(minus_one := op) for op in binary_operation.operands):
negated_expression = negate_expression(get_other_operand(binary_operation, minus_one))
instruction.substitute(binary_operation, negated_expression)


def _simplify_subtraction(binary_operation: BinaryOperation, instruction: Instruction):
"""
Simplifies the given subtraction in the given instruction.
-> Simplifies a-0, a-(-0) to a
-> Simplifies 0-a, -0-a to -a
"""
if is_zero_constant(binary_operation.operands[1]):
instruction.substitute(binary_operation, binary_operation.operands[0])
elif is_zero_constant(binary_operation.operands[0]):
instruction.substitute(binary_operation, negate_expression(binary_operation.operands[1]))


def _simplify_division(binary_operation: BinaryOperation, instruction: Instruction):
"""
Simplifies the given division in the given instruction.
-> Simplifies a/1 to a and a/(-1) to -a
"""
if is_one_constant(binary_operation.operands[1]):
instruction.substitute(binary_operation, binary_operation.operands[0])
elif is_minus_one_constant(binary_operation.operands[1]):
instruction.substitute(binary_operation, negate_expression(binary_operation.operands[0]))


# This translator maps the operations to their simplification method
SIMPLIFICATION_FOR = {
OperationType.plus: _simplify_addition,
OperationType.multiply: _simplify_multiplication,
OperationType.minus: _simplify_subtraction,
OperationType.divide: _simplify_division,
}


def is_zero_constant(expression: Expression) -> bool:
"""Checks whether the given expression is 0."""
return isinstance(expression, Constant) and expression.value == 0


def is_one_constant(expression: Expression) -> bool:
"""Checks whether the given expression is 1."""
return isinstance(expression, Constant) and expression.value == 1


def is_minus_one_constant(expression: Expression) -> bool:
"""Checks whether the given expression is -1."""
return isinstance(expression, Constant) and expression.value == -1


def negate_expression(expression: Expression) -> Expression:
"""Negate the given expression and return it."""
if isinstance(expression, Constant) and expression.value == 0:
return expression
if isinstance(expression, UnaryOperation) and expression.operation == OperationType.negate:
return expression.operand
if isinstance(expression, Constant) and isinstance(expression.type, Integer) and expression.type.is_signed:
return Constant(-expression.value, expression.type)
return UnaryOperation(OperationType.negate, [expression])


def get_other_operand(binary_operation: BinaryOperation, expression: Expression) -> Expression:
"""Returns the operand that is not equal to expression."""
if binary_operation.operands[0] == expression:
return binary_operation.operands[1]
return binary_operation.operands[0]


class ExpressionSimplification(PipelineStage):
"""The ExpressionSimplification makes various simplifications to expressions on the AST, like a + 0 = a."""

name = "expression-simplification"

def __init__(self):
self.HANDLERS = {
OperationType.plus: self._simplify_addition,
OperationType.minus: self._simplify_subtraction,
OperationType.multiply: self._simplify_multiplication,
OperationType.divide: self._simplify_division,
OperationType.divide_us: self._simplify_division,
OperationType.divide_float: self._simplify_division,
OperationType.dereference: self._simplify_dereference,
}

def run(self, task: DecompilerTask):
"""Run the task expression simplification on each instruction of the AST."""
if task.syntax_tree is None:
for instruction in task.graph.instructions:
simplify(instruction)
self.simplify(instruction)
else:
for node in task.syntax_tree.topological_order():
if not isinstance(node, CodeNode):
continue
for instruction in node.instructions:
simplify(instruction)
self.simplify(instruction)

def simplify(self, instruction: Instruction):
"""Simplify all subexpressions of the given instruction recursively."""
todo = list(instruction)
while todo and (expression := todo.pop()):
if self.simplify_expression(expression, instruction):
todo = list(instruction)
else:
todo.extend(expression)

def simplify_expression(self, expression: Expression, parent: Instruction) -> Optional[Expression]:
"""Simplify the given instruction utilizing the registered OperationType handlers."""
if isinstance(expression, Operation) and expression.operation in self.HANDLERS:
if simplified := self.HANDLERS[expression.operation](expression):
parent.substitute(expression, simplified)
return simplified

def _simplify_addition(self, expression: BinaryOperation) -> Optional[Expression]:
"""
Simplifies the given addition in the given instruction.
-> Simplifies a+0, 0+a, a-0 and -0 + a to a
"""
if any(self.is_zero_constant(zero := op) for op in expression.operands):
return self.get_other_operand(expression, zero).copy()

def _simplify_subtraction(self, expression: BinaryOperation) -> Optional[Expression]:
"""
Simplifies the given subtraction in the given instruction.
-> Simplifies a-0, a-(-0) to a
-> Simplifies 0-a, -0-a to -a
"""
if self.is_zero_constant(expression.operands[1]):
return expression.operands[0].copy()
if self.is_zero_constant(expression.operands[0]):
return self.negate_expression(expression.operands[1])

def _simplify_multiplication(self, expression: BinaryOperation) -> Optional[Expression]:
"""
Simplifies the given multiplication in the given instruction.
-> Simplifies a*0, 0*a, a*(-0) and (-0) * a to 0
-> Simplifies a*1, 1*a to a
-> Simplifies a*(-1), (-1)*a to -a
"""
if any(self.is_zero_constant(zero := op) for op in expression.operands):
return zero.copy()
if any(self.is_one_constant(one := op) for op in expression.operands):
return self.get_other_operand(expression, one).copy()
if any(self.is_minus_one_constant(minus_one := op) for op in expression.operands):
return self.negate_expression(self.get_other_operand(expression, minus_one))

def _simplify_division(self, expression: BinaryOperation) -> Optional[Expression]:
"""
Simplifies the given division in the given instruction.
-> Simplifies a/1 to a and a/(-1) to -a
"""
if self.is_one_constant(expression.operands[1]):
return expression.operands[0].copy()
if self.is_minus_one_constant(expression.operands[1]):
return self.negate_expression(expression.operands[0])

def _simplify_dereference(self, expression: UnaryOperation) -> Optional[Expression]:
"""
Simplifies dereference expression with nested address-of expressions.
-> Simplifies *(&(x)) to x
"""
if isinstance(expression.operand, UnaryOperation) and expression.operand.operation == OperationType.address:
return expression.operand.operand.copy()

@staticmethod
def is_zero_constant(expression: Expression) -> bool:
"""Checks whether the given expression is 0."""
return isinstance(expression, Constant) and expression.value == 0

@staticmethod
def is_one_constant(expression: Expression) -> bool:
"""Checks whether the given expression is 1."""
return isinstance(expression, Constant) and expression.value == 1

@staticmethod
def is_minus_one_constant(expression: Expression) -> bool:
"""Checks whether the given expression is -1."""
return isinstance(expression, Constant) and expression.value == -1

@staticmethod
def negate_expression(expression: Expression) -> Expression:
"""Negate the given expression and return it."""
if isinstance(expression, Constant) and expression.value == 0:
return expression
if isinstance(expression, UnaryOperation) and expression.operation == OperationType.negate:
return expression.operand
if isinstance(expression, Constant) and isinstance(expression.type, Integer) and expression.type.is_signed:
return Constant(-expression.value, expression.type)
return UnaryOperation(OperationType.negate, [expression])

@staticmethod
def get_other_operand(binary_operation: BinaryOperation, expression: Expression) -> Expression:
"""Returns the operand that is not equal to expression."""
if binary_operation.operands[0] == expression:
return binary_operation.operands[1]
return binary_operation.operands[0]
4 changes: 3 additions & 1 deletion decompiler/pipeline/dataflowanalysis/deadcodeelimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from decompiler.pipeline.stage import PipelineStage
from decompiler.structures.graphs.cfg import ControlFlowGraph
from decompiler.structures.pseudo.expressions import Variable
from decompiler.structures.pseudo.expressions import GlobalVariable, Variable
from decompiler.structures.pseudo.instructions import Assignment, BaseAssignment, Instruction, Relation
from decompiler.structures.pseudo.operations import BinaryOperation, Call, ListOperation, OperationType, UnaryOperation
from decompiler.task import DecompilerTask
Expand Down Expand Up @@ -78,6 +78,8 @@ def _add_assignment(self, assignment: BaseAssignment, position: CfgPosition):
self.add_node(str(defined_variable), instruction=assignment, position=position)
for required_variable in assignment.requirements:
self.add_edge(str(defined_variable), str(required_variable))
if isinstance(defined_variable, GlobalVariable):
self.add_edge(self.SINK_LABEL, str(defined_variable))

def find_dead_variables(self) -> Set[str]:
"""Iterate all dead variables in the graph based on their name to prevent type mismatches."""
Expand Down
4 changes: 2 additions & 2 deletions decompiler/pipeline/dataflowanalysis/identity_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from decompiler.pipeline.stage import PipelineStage
from decompiler.structures.graphs.cfg import BasicBlock
from decompiler.structures.pseudo.expressions import Constant, UnknownExpression, Variable
from decompiler.structures.pseudo.expressions import Constant, GlobalVariable, UnknownExpression, Variable
from decompiler.structures.pseudo.instructions import Assignment, Instruction, Phi, Relation
from decompiler.task import DecompilerTask
from networkx import DiGraph, node_disjoint_paths, weakly_connected_components
Expand Down Expand Up @@ -68,7 +68,7 @@ def add_assignment(self, assignment: Assignment, basic_block: BasicBlock) -> Non
- First check that the assignments defines exactly one variable.
- Then compute the set of required variables and add the according edges to the identity graph.
"""
if not isinstance(defined_value := assignment.destination, Variable):
if not isinstance(defined_value := assignment.destination, Variable) or isinstance(defined_value, GlobalVariable):
return
required_values = self._get_variables_utilized_for_direct_assignment(assignment)
self.add_node(defined_value, definition=assignment, block=basic_block, is_phi=isinstance(assignment, Phi))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree
from decompiler.structures.graphs.cfg import BasicBlock, ControlFlowGraph
from decompiler.structures.logic.logic_condition import LogicCondition
from decompiler.structures.pseudo.expressions import Constant, Variable
from decompiler.structures.pseudo.instructions import Assignment
from decompiler.structures.pseudo.operations import BinaryOperation, OperationType, UnaryOperation
from decompiler.structures.pseudo.expressions import Constant, ImportedFunctionSymbol, Variable
from decompiler.structures.pseudo.instructions import Assignment, Instruction, Return
from decompiler.structures.pseudo.operations import BinaryOperation, Call, ListOperation, OperationType, UnaryOperation
from decompiler.structures.pseudo.typing import Integer
from decompiler.task import DecompilerTask

Expand Down Expand Up @@ -179,6 +179,46 @@ def test_simplification_with_one_division(instruction, result):
assert task.syntax_tree.root == CodeNode([result], true_value.copy())


@pytest.mark.parametrize(
"instruction, result",
[
(
Assignment(UnaryOperation(OperationType.dereference, [UnaryOperation(OperationType.address, [Variable("x")])]), Constant(0)),
Assignment(Variable("x"), Constant(0)),
),
(
Assignment(
ListOperation([]),
Call(
ImportedFunctionSymbol("foo", 0x42),
[
BinaryOperation(
OperationType.minus,
[
UnaryOperation(OperationType.dereference, [UnaryOperation(OperationType.address, [Variable("x")])]),
Constant(2),
],
)
],
),
),
Assignment(
ListOperation([]),
Call(ImportedFunctionSymbol("foo", 0x42), [BinaryOperation(OperationType.minus, [Variable("x"), Constant(2)])]),
),
),
(
Return([UnaryOperation(OperationType.dereference, [UnaryOperation(OperationType.address, [Variable("y")])])]),
Return([Variable("y")]),
),
],
)
def test_simplification_of_dereference_operations(instruction: Instruction, result: Instruction):
"""Check if dereference operations with address-of operands are simplified correctly."""
ExpressionSimplification().simplify(instruction)
assert instruction == result


@pytest.mark.parametrize(
"instruction, result",
[
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sample_binaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_global_strings_and_tables():
# Make sure the global string contains the string hello world.
assert output2.count('"Hello World"') == 1
# Ensure that string is referenced correctly
assert output2.count("*&hello_string") == 1
assert output2.count("hello_string") == 1


@pytest.mark.skip(reason="global lifting not yet implemented in the new lifter")
Expand Down

0 comments on commit a76fc52

Please sign in to comment.