diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification.py b/decompiler/pipeline/controlflowanalysis/expression_simplification.py index 9453a7780..c7f132840 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification.py @@ -3,135 +3,142 @@ from decompiler.pipeline.stage import PipelineStage from decompiler.structures.ast.ast_nodes import CodeNode -from decompiler.structures.pseudo.expressions import Constant, DataflowObject, Expression +from decompiler.structures.pseudo.expressions import Constant, Expression from decompiler.structures.pseudo.instructions import Instruction -from decompiler.structures.pseudo.operations import BinaryOperation, OperationType, UnaryOperation +from decompiler.structures.pseudo.operations import BinaryOperation, Operation, OperationType, UnaryOperation from decompiler.structures.pseudo.typing import Integer from decompiler.task import DecompilerTask -def simplify(expression: DataflowObject, parent: Optional[Instruction] = None): - """ - Simplifies the given instruction - a + 0 -> a, a - 0 -> a, 0 - a -> -a, a*0 -> 0, a* 1 -> a, a* -1 -> -a, a / 1 -> a, a / -1 -> -a - """ - parent = expression if parent is None else parent - assert isinstance(parent, Instruction), f"The parent {parent} must be an instruction." - for sub_expr in expression: - simplify(sub_expr, parent) - if isinstance(expression, BinaryOperation) and expression.operation in SIMPLIFICATION_FOR: - SIMPLIFICATION_FOR[expression.operation](expression, parent) - - -def _simplify_addition(binary_operation: BinaryOperation, instruction: Instruction): - """ - Simplifies the given addition in the given instruction. - - -> Simplifies a+0, 0+a, a-0 and -0 + a to a - """ - if any(is_zero_constant(zero := op) for op in binary_operation.operands): - non_zero = get_other_operand(binary_operation, zero) - instruction.substitute(binary_operation, non_zero) - - -def _simplify_multiplication(binary_operation: BinaryOperation, instruction: Instruction): - """ - Simplifies the given multiplication in the given instruction. - - -> Simplifies a*0, 0*a, a*(-0) and (-0) * a to 0 - -> Simplifies a*1, 1*a to a - -> Simplifies a*(-1), (-1)*a to -a - """ - if any(is_zero_constant(zero := op) for op in binary_operation.operands): - instruction.substitute(binary_operation, zero) - elif any(is_one_constant(one := op) for op in binary_operation.operands): - non_one = get_other_operand(binary_operation, one) - instruction.substitute(binary_operation, non_one) - elif any(is_minus_one_constant(minus_one := op) for op in binary_operation.operands): - negated_expression = negate_expression(get_other_operand(binary_operation, minus_one)) - instruction.substitute(binary_operation, negated_expression) - - -def _simplify_subtraction(binary_operation: BinaryOperation, instruction: Instruction): - """ - Simplifies the given subtraction in the given instruction. - - -> Simplifies a-0, a-(-0) to a - -> Simplifies 0-a, -0-a to -a - """ - if is_zero_constant(binary_operation.operands[1]): - instruction.substitute(binary_operation, binary_operation.operands[0]) - elif is_zero_constant(binary_operation.operands[0]): - instruction.substitute(binary_operation, negate_expression(binary_operation.operands[1])) - - -def _simplify_division(binary_operation: BinaryOperation, instruction: Instruction): - """ - Simplifies the given division in the given instruction. - - -> Simplifies a/1 to a and a/(-1) to -a - """ - if is_one_constant(binary_operation.operands[1]): - instruction.substitute(binary_operation, binary_operation.operands[0]) - elif is_minus_one_constant(binary_operation.operands[1]): - instruction.substitute(binary_operation, negate_expression(binary_operation.operands[0])) - - -# This translator maps the operations to their simplification method -SIMPLIFICATION_FOR = { - OperationType.plus: _simplify_addition, - OperationType.multiply: _simplify_multiplication, - OperationType.minus: _simplify_subtraction, - OperationType.divide: _simplify_division, -} - - -def is_zero_constant(expression: Expression) -> bool: - """Checks whether the given expression is 0.""" - return isinstance(expression, Constant) and expression.value == 0 - - -def is_one_constant(expression: Expression) -> bool: - """Checks whether the given expression is 1.""" - return isinstance(expression, Constant) and expression.value == 1 - - -def is_minus_one_constant(expression: Expression) -> bool: - """Checks whether the given expression is -1.""" - return isinstance(expression, Constant) and expression.value == -1 - - -def negate_expression(expression: Expression) -> Expression: - """Negate the given expression and return it.""" - if isinstance(expression, Constant) and expression.value == 0: - return expression - if isinstance(expression, UnaryOperation) and expression.operation == OperationType.negate: - return expression.operand - if isinstance(expression, Constant) and isinstance(expression.type, Integer) and expression.type.is_signed: - return Constant(-expression.value, expression.type) - return UnaryOperation(OperationType.negate, [expression]) - - -def get_other_operand(binary_operation: BinaryOperation, expression: Expression) -> Expression: - """Returns the operand that is not equal to expression.""" - if binary_operation.operands[0] == expression: - return binary_operation.operands[1] - return binary_operation.operands[0] - - class ExpressionSimplification(PipelineStage): """The ExpressionSimplification makes various simplifications to expressions on the AST, like a + 0 = a.""" name = "expression-simplification" + def __init__(self): + self.HANDLERS = { + OperationType.plus: self._simplify_addition, + OperationType.minus: self._simplify_subtraction, + OperationType.multiply: self._simplify_multiplication, + OperationType.divide: self._simplify_division, + OperationType.divide_us: self._simplify_division, + OperationType.divide_float: self._simplify_division, + OperationType.dereference: self._simplify_dereference, + } + def run(self, task: DecompilerTask): """Run the task expression simplification on each instruction of the AST.""" if task.syntax_tree is None: for instruction in task.graph.instructions: - simplify(instruction) + self.simplify(instruction) else: for node in task.syntax_tree.topological_order(): if not isinstance(node, CodeNode): continue for instruction in node.instructions: - simplify(instruction) + self.simplify(instruction) + + def simplify(self, instruction: Instruction): + """Simplify all subexpressions of the given instruction recursively.""" + todo = list(instruction) + while todo and (expression := todo.pop()): + if self.simplify_expression(expression, instruction): + todo = list(instruction) + else: + todo.extend(expression) + + def simplify_expression(self, expression: Expression, parent: Instruction) -> Optional[Expression]: + """Simplify the given instruction utilizing the registered OperationType handlers.""" + if isinstance(expression, Operation) and expression.operation in self.HANDLERS: + if simplified := self.HANDLERS[expression.operation](expression): + parent.substitute(expression, simplified) + return simplified + + def _simplify_addition(self, expression: BinaryOperation) -> Optional[Expression]: + """ + Simplifies the given addition in the given instruction. + + -> Simplifies a+0, 0+a, a-0 and -0 + a to a + """ + if any(self.is_zero_constant(zero := op) for op in expression.operands): + return self.get_other_operand(expression, zero).copy() + + def _simplify_subtraction(self, expression: BinaryOperation) -> Optional[Expression]: + """ + Simplifies the given subtraction in the given instruction. + + -> Simplifies a-0, a-(-0) to a + -> Simplifies 0-a, -0-a to -a + """ + if self.is_zero_constant(expression.operands[1]): + return expression.operands[0].copy() + if self.is_zero_constant(expression.operands[0]): + return self.negate_expression(expression.operands[1]) + + def _simplify_multiplication(self, expression: BinaryOperation) -> Optional[Expression]: + """ + Simplifies the given multiplication in the given instruction. + + -> Simplifies a*0, 0*a, a*(-0) and (-0) * a to 0 + -> Simplifies a*1, 1*a to a + -> Simplifies a*(-1), (-1)*a to -a + """ + if any(self.is_zero_constant(zero := op) for op in expression.operands): + return zero.copy() + if any(self.is_one_constant(one := op) for op in expression.operands): + return self.get_other_operand(expression, one).copy() + if any(self.is_minus_one_constant(minus_one := op) for op in expression.operands): + return self.negate_expression(self.get_other_operand(expression, minus_one)) + + def _simplify_division(self, expression: BinaryOperation) -> Optional[Expression]: + """ + Simplifies the given division in the given instruction. + + -> Simplifies a/1 to a and a/(-1) to -a + """ + if self.is_one_constant(expression.operands[1]): + return expression.operands[0].copy() + if self.is_minus_one_constant(expression.operands[1]): + return self.negate_expression(expression.operands[0]) + + def _simplify_dereference(self, expression: UnaryOperation) -> Optional[Expression]: + """ + Simplifies dereference expression with nested address-of expressions. + + -> Simplifies *(&(x)) to x + """ + if isinstance(expression.operand, UnaryOperation) and expression.operand.operation == OperationType.address: + return expression.operand.operand.copy() + + @staticmethod + def is_zero_constant(expression: Expression) -> bool: + """Checks whether the given expression is 0.""" + return isinstance(expression, Constant) and expression.value == 0 + + @staticmethod + def is_one_constant(expression: Expression) -> bool: + """Checks whether the given expression is 1.""" + return isinstance(expression, Constant) and expression.value == 1 + + @staticmethod + def is_minus_one_constant(expression: Expression) -> bool: + """Checks whether the given expression is -1.""" + return isinstance(expression, Constant) and expression.value == -1 + + @staticmethod + def negate_expression(expression: Expression) -> Expression: + """Negate the given expression and return it.""" + if isinstance(expression, Constant) and expression.value == 0: + return expression + if isinstance(expression, UnaryOperation) and expression.operation == OperationType.negate: + return expression.operand + if isinstance(expression, Constant) and isinstance(expression.type, Integer) and expression.type.is_signed: + return Constant(-expression.value, expression.type) + return UnaryOperation(OperationType.negate, [expression]) + + @staticmethod + def get_other_operand(binary_operation: BinaryOperation, expression: Expression) -> Expression: + """Returns the operand that is not equal to expression.""" + if binary_operation.operands[0] == expression: + return binary_operation.operands[1] + return binary_operation.operands[0] diff --git a/decompiler/pipeline/dataflowanalysis/deadcodeelimination.py b/decompiler/pipeline/dataflowanalysis/deadcodeelimination.py index d54012804..75d820590 100644 --- a/decompiler/pipeline/dataflowanalysis/deadcodeelimination.py +++ b/decompiler/pipeline/dataflowanalysis/deadcodeelimination.py @@ -4,7 +4,7 @@ from decompiler.pipeline.stage import PipelineStage from decompiler.structures.graphs.cfg import ControlFlowGraph -from decompiler.structures.pseudo.expressions import Variable +from decompiler.structures.pseudo.expressions import GlobalVariable, Variable from decompiler.structures.pseudo.instructions import Assignment, BaseAssignment, Instruction, Relation from decompiler.structures.pseudo.operations import BinaryOperation, Call, ListOperation, OperationType, UnaryOperation from decompiler.task import DecompilerTask @@ -78,6 +78,8 @@ def _add_assignment(self, assignment: BaseAssignment, position: CfgPosition): self.add_node(str(defined_variable), instruction=assignment, position=position) for required_variable in assignment.requirements: self.add_edge(str(defined_variable), str(required_variable)) + if isinstance(defined_variable, GlobalVariable): + self.add_edge(self.SINK_LABEL, str(defined_variable)) def find_dead_variables(self) -> Set[str]: """Iterate all dead variables in the graph based on their name to prevent type mismatches.""" diff --git a/decompiler/pipeline/dataflowanalysis/identity_elimination.py b/decompiler/pipeline/dataflowanalysis/identity_elimination.py index 1d4d4549f..687835635 100644 --- a/decompiler/pipeline/dataflowanalysis/identity_elimination.py +++ b/decompiler/pipeline/dataflowanalysis/identity_elimination.py @@ -7,7 +7,7 @@ from decompiler.pipeline.stage import PipelineStage from decompiler.structures.graphs.cfg import BasicBlock -from decompiler.structures.pseudo.expressions import Constant, UnknownExpression, Variable +from decompiler.structures.pseudo.expressions import Constant, GlobalVariable, UnknownExpression, Variable from decompiler.structures.pseudo.instructions import Assignment, Instruction, Phi, Relation from decompiler.task import DecompilerTask from networkx import DiGraph, node_disjoint_paths, weakly_connected_components @@ -68,7 +68,7 @@ def add_assignment(self, assignment: Assignment, basic_block: BasicBlock) -> Non - First check that the assignments defines exactly one variable. - Then compute the set of required variables and add the according edges to the identity graph. """ - if not isinstance(defined_value := assignment.destination, Variable): + if not isinstance(defined_value := assignment.destination, Variable) or isinstance(defined_value, GlobalVariable): return required_values = self._get_variables_utilized_for_direct_assignment(assignment) self.add_node(defined_value, definition=assignment, block=basic_block, is_phi=isinstance(assignment, Phi)) diff --git a/tests/pipeline/controlflowanalysis/test_expression_simplification.py b/tests/pipeline/controlflowanalysis/test_expression_simplification.py index fded455d1..11a9a7042 100644 --- a/tests/pipeline/controlflowanalysis/test_expression_simplification.py +++ b/tests/pipeline/controlflowanalysis/test_expression_simplification.py @@ -6,9 +6,9 @@ from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree from decompiler.structures.graphs.cfg import BasicBlock, ControlFlowGraph from decompiler.structures.logic.logic_condition import LogicCondition -from decompiler.structures.pseudo.expressions import Constant, Variable -from decompiler.structures.pseudo.instructions import Assignment -from decompiler.structures.pseudo.operations import BinaryOperation, OperationType, UnaryOperation +from decompiler.structures.pseudo.expressions import Constant, ImportedFunctionSymbol, Variable +from decompiler.structures.pseudo.instructions import Assignment, Instruction, Return +from decompiler.structures.pseudo.operations import BinaryOperation, Call, ListOperation, OperationType, UnaryOperation from decompiler.structures.pseudo.typing import Integer from decompiler.task import DecompilerTask @@ -179,6 +179,46 @@ def test_simplification_with_one_division(instruction, result): assert task.syntax_tree.root == CodeNode([result], true_value.copy()) +@pytest.mark.parametrize( + "instruction, result", + [ + ( + Assignment(UnaryOperation(OperationType.dereference, [UnaryOperation(OperationType.address, [Variable("x")])]), Constant(0)), + Assignment(Variable("x"), Constant(0)), + ), + ( + Assignment( + ListOperation([]), + Call( + ImportedFunctionSymbol("foo", 0x42), + [ + BinaryOperation( + OperationType.minus, + [ + UnaryOperation(OperationType.dereference, [UnaryOperation(OperationType.address, [Variable("x")])]), + Constant(2), + ], + ) + ], + ), + ), + Assignment( + ListOperation([]), + Call(ImportedFunctionSymbol("foo", 0x42), [BinaryOperation(OperationType.minus, [Variable("x"), Constant(2)])]), + ), + ), + ( + Return([UnaryOperation(OperationType.dereference, [UnaryOperation(OperationType.address, [Variable("y")])])]), + Return([Variable("y")]), + ), + ], +) +def test_simplification_of_dereference_operations(instruction: Instruction, result: Instruction): + """Check if dereference operations with address-of operands are simplified correctly.""" + ExpressionSimplification().simplify(instruction) + assert instruction == result + + @pytest.mark.parametrize( "instruction, result", [ diff --git a/tests/test_sample_binaries.py b/tests/test_sample_binaries.py index eae4010a8..0cb165598 100644 --- a/tests/test_sample_binaries.py +++ b/tests/test_sample_binaries.py @@ -56,7 +56,7 @@ def test_global_strings_and_tables(): # Make sure the global string contains the string hello world. assert output2.count('"Hello World"') == 1 # Ensure that string is referenced correctly - assert output2.count("*&hello_string") == 1 + assert output2.count("hello_string") == 1 @pytest.mark.skip(reason="global lifting not yet implemented in the new lifter")