Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of common subexpression elimination and parallel testing #380

Merged
merged 11 commits into from
Jan 31, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from __future__ import annotations

from collections import defaultdict, deque
from collections import Counter, defaultdict, deque
from dataclasses import dataclass
from itertools import chain
from logging import info, warning
from typing import DefaultDict, Deque, Dict, Iterator, List, Optional, Set, Tuple
from typing import DefaultDict, Deque, Dict, Iterable, Iterator, List, Optional, Set, Tuple

from decompiler.pipeline.stage import PipelineStage
from decompiler.structures.graphs.cfg import BasicBlock, ControlFlowGraph
Expand All @@ -18,12 +18,20 @@
from networkx import dfs_postorder_nodes


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class CfgInstruction:
"""
dataclass in charge of tracking the location of Instruction objects in the cfg

-> The considered instruction, where block is the basic block where it is contained and index the position in the basic block.

Note: Two instances with the same data will not be equal (because of eq=False).
This way, eq and hash are way more performant, because at the time of writing this, eq and hash are very
expensive on big instructions.

eq=True would probably be nicer to use, but we don't actually create instances with the same data
multiple times. (Rationale: initially just one instance is created per (block, index) pair.
All further instances with the same (block, index) will have a less complex instruction than before)
"""

instruction: Instruction
Expand Down Expand Up @@ -200,7 +208,7 @@ class DefinitionGenerator:

def __init__(
self,
expression_usages: DefaultDict[Expression, List[CfgInstruction]],
expression_usages: DefaultDict[Expression, Counter[CfgInstruction]],
dominator_tree: NetworkXGraph,
):
"""Generate a new instance based on data parsed from a cfg."""
Expand All @@ -210,16 +218,16 @@ def __init__(
@classmethod
def from_cfg(cls, cfg: ControlFlowGraph) -> DefinitionGenerator:
"""Initialize a DefinitionGenerator utilizing the data of the given cfg."""
usages: DefaultDict[Expression, List[CfgInstruction]] = defaultdict(list)
usages: DefaultDict[Expression, Counter[CfgInstruction]] = defaultdict(Counter)
for basic_block in cfg:
for index, instruction in enumerate(basic_block.instructions):
instruction_with_position = CfgInstruction(instruction, basic_block, index)
for subexpression in _subexpression_dfs(instruction):
usages[subexpression].append(instruction_with_position)
usages[subexpression][instruction_with_position] += 1
return cls(usages, cfg.dominator_tree)

@property
def usages(self) -> DefaultDict[Expression, List[CfgInstruction]]:
def usages(self) -> DefaultDict[Expression, Counter[CfgInstruction]]:
"""Return a mapping from expressions to a set of instructions using them."""
return self._usages

Expand All @@ -241,7 +249,7 @@ def _find_location_for_insertion(self, expression) -> Tuple[BasicBlock, int]:
candidate: BasicBlock = next(iter(usage_blocks))
while not self._is_common_dominator(candidate, usage_blocks) or self._is_invalid_dominator(candidate, expression):
candidate = self._dominator_tree.get_predecessors(candidate)[0]
return candidate, self._find_insertion_index(candidate, set(self._usages[expression]))
return candidate, self._find_insertion_index(candidate, self._usages[expression].keys())

def _is_common_dominator(self, candidate: BasicBlock, basic_blocks: Set[BasicBlock]) -> bool:
"""Check if the given candidate is the common dominator all of given basic blocks."""
Expand All @@ -261,10 +269,10 @@ def _insert_definition(self, definition: CfgInstruction):
"""Insert a new intermediate definition for the given expression at the given location."""
definition.block.instructions.insert(definition.index, definition.instruction)
for subexpression in _subexpression_dfs(definition.instruction):
self._usages[subexpression].append(definition)
self._usages[subexpression][definition] += 1

@staticmethod
def _find_insertion_index(basic_block: BasicBlock, usages: Set[CfgInstruction]) -> int:
def _find_insertion_index(basic_block: BasicBlock, usages: Iterable[CfgInstruction]) -> int:
"""Find the first index in the given basic block where a definition could be inserted."""
usage = min((usage for usage in usages if usage.block == basic_block), default=None, key=lambda x: x.index)
if usage:
Expand Down Expand Up @@ -316,7 +324,7 @@ def eliminate_common_subexpressions(self, definition_generator: DefinitionGenera
except StopIteration:
warning(f"[{self.name}] No dominating basic block could be found for {replacee}")

def _find_elimination_candidates(self, usages: DefaultDict[Expression, List[CfgInstruction]]) -> Iterator[Expression]:
def _find_elimination_candidates(self, usages: DefaultDict[Expression, Counter[CfgInstruction]]) -> Iterator[Expression]:
"""
Iterate all expressions, yielding the expressions which should be eliminated.

Expand All @@ -325,11 +333,12 @@ def _find_elimination_candidates(self, usages: DefaultDict[Expression, List[CfgI
expressions_by_complexity = sorted(usages.keys(), reverse=True, key=lambda x: x.complexity)
for expression in expressions_by_complexity:
if self._is_cse_candidate(expression, usages):
expression_usage = usages[expression]
for subexpression in _subexpression_dfs(expression):
usages[subexpression] = [x for x in usages[subexpression] if x not in usages[expression]]
usages[subexpression].subtract(expression_usage)
yield expression

def _is_cse_candidate(self, expression: Expression, usages: DefaultDict[Expression, List[CfgInstruction]]):
def _is_cse_candidate(self, expression: Expression, usages: DefaultDict[Expression, Counter[CfgInstruction]]):
"""Checks that we can add a common subexpression for the given expression."""
return (
self._is_elimination_candidate(expression, usages[expression])
Expand All @@ -347,15 +356,16 @@ def _is_complex_string(self, expression: Expression) -> bool:
return isinstance(expression.value, str) and len(expression.value) >= self._min_string_length
return False

def _check_inter_instruction(self, expression: Expression, instructions: List[CfgInstruction]) -> bool:
def _check_inter_instruction(self, expression: Expression, instructions: Counter[CfgInstruction]) -> bool:
"""Check if the given expressions should be eliminated based on its global occurrences."""
referencing_instructions_count = len(set(instructions))
referencing_instructions_count = sum(1 for _, count in instructions.items() if count > 0)
return (expression.complexity >= 2 and referencing_instructions_count >= self._threshold) or (
self._is_complex_string(expression) and referencing_instructions_count >= self._string_threshold
)

def _check_intra_instruction(self, expression: Expression, instructions: List[CfgInstruction]) -> bool:
def _check_intra_instruction(self, expression: Expression, instructions: Counter[CfgInstruction]) -> bool:
"""Check if this expression should be eliminated based on the amount of unique instructions utilizing it."""
return (expression.complexity >= 2 and len(instructions) >= self._threshold) or (
self._is_complex_string(expression) and len(instructions) >= self._string_threshold
referencing_count = instructions.total()
return (expression.complexity >= 2 and referencing_count >= self._threshold) or (
self._is_complex_string(expression) and referencing_count >= self._string_threshold
)
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[pytest]
addopts = -n auto
python_files = test-*.py test_*.py
markers = coreutils
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ networkx != 2.8.4
pydot
pygments
pytest !=5.3.4
pytest-xdist
z3-solver == 4.8.10
8 changes: 6 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,19 @@ def pytest_generate_tests(metafunc):
test_cases = _discover_full_tests()
else:
test_cases = _discover_system_tests()

params = list()
ids = list()
for sample_name, functions in test_cases.items():
for f in functions:
params.append((sample_name, f))
metafunc.parametrize("test_cases", params)
ids.append(f"{sample_name}::{f}")

metafunc.parametrize("test_cases", params, ids=ids)

if "coreutils_tests" in metafunc.fixturenames:
coreutils_tests = _discover_coreutils_tests()
metafunc.parametrize("coreutils_tests", coreutils_tests)
metafunc.parametrize("coreutils_tests", coreutils_tests, ids=map(lambda t: f"{t[0]}::{t[1]}", coreutils_tests))


def _discover_full_tests() -> Dict[pathlib.Path, List[str]]:
Expand Down
Loading