Improve performance of common subexpression elimination and parallel testing (#380)

* Optimize common_subexpression_elimination

* Parallelize pytest

* Add comment to explain the motivation behind equality

* Improve pytest logging

* Remove -v again

* Fix typing

* Add further clarification on comment

* Add rationale for why the code is still correct

* Add deterministic ordering to conftests

* Fix types

---------

Co-authored-by: Manuel Blatt <[email protected]>
3 people authored Jan 31, 2024
1 parent d394012 commit e80b8ca
Showing 4 changed files with 67 additions and 43 deletions.
@@ -2,11 +2,11 @@

from __future__ import annotations

-from collections import defaultdict, deque
+from collections import Counter, defaultdict, deque
from dataclasses import dataclass
from itertools import chain
from logging import info, warning
-from typing import DefaultDict, Deque, Dict, Iterator, List, Optional, Set, Tuple
+from typing import DefaultDict, Deque, Dict, Iterable, Iterator, List, Optional, Set, Tuple

from decompiler.pipeline.stage import PipelineStage
from decompiler.structures.graphs.cfg import BasicBlock, ControlFlowGraph
@@ -18,12 +18,20 @@
from networkx import dfs_postorder_nodes


-@dataclass(frozen=True)
+@dataclass(frozen=True, eq=False)
class CfgInstruction:
"""
dataclass in charge of tracking the location of Instruction objects in the cfg
-> The considered instruction, where block is the basic block where it is contained and index the position in the basic block.
+Note: Two instances with the same data will not be equal (because of eq=False).
+This way, eq and hash are way more performant, because at the time of writing this, eq and hash are very
+expensive on big instructions.
+eq=True would probably be nicer to use, but we don't actually create instances with the same data
+multiple times. (Rationale: initially just one instance is created per (block, index) pair.
+All further instances with the same (block, index) will have a less complex instruction than before)
"""

instruction: Instruction
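Why eq=False pays off here: with eq=False, the frozen dataclass inherits __eq__ and __hash__ from object, so both become O(1) identity operations instead of deep comparisons over a potentially huge Instruction. A minimal, self-contained sketch (Tagged and its tuple payload are illustrative stand-ins, not part of this commit):

from dataclasses import dataclass

@dataclass(frozen=True, eq=False)
class Tagged:
    payload: tuple  # stands in for an arbitrarily large Instruction

a = Tagged((1, 2, 3))
b = Tagged((1, 2, 3))
assert a == a        # every object equals itself
assert a != b        # same data, different instances -> not equal
assert b not in {a}  # hashing is identity-based as well, so sets and dicts stay fast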
@@ -200,7 +208,7 @@ class DefinitionGenerator:

def __init__(
self,
-expression_usages: DefaultDict[Expression, List[CfgInstruction]],
+expression_usages: DefaultDict[Expression, Counter[CfgInstruction]],
dominator_tree: NetworkXGraph,
):
"""Generate a new instance based on data parsed from a cfg."""
@@ -210,16 +218,16 @@ def __init__(
@classmethod
def from_cfg(cls, cfg: ControlFlowGraph) -> DefinitionGenerator:
"""Initialize a DefinitionGenerator utilizing the data of the given cfg."""
-usages: DefaultDict[Expression, List[CfgInstruction]] = defaultdict(list)
+usages: DefaultDict[Expression, Counter[CfgInstruction]] = defaultdict(Counter)
for basic_block in cfg:
for index, instruction in enumerate(basic_block.instructions):
instruction_with_position = CfgInstruction(instruction, basic_block, index)
for subexpression in _subexpression_dfs(instruction):
-usages[subexpression].append(instruction_with_position)
+usages[subexpression][instruction_with_position] += 1
return cls(usages, cfg.dominator_tree)

@property
-def usages(self) -> DefaultDict[Expression, List[CfgInstruction]]:
+def usages(self) -> DefaultDict[Expression, Counter[CfgInstruction]]:
"""Return a mapping from expressions to a set of instructions using them."""
return self._usages
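The shape of the new bookkeeping, as a hedged sketch (the strings standing in for Expression and CfgInstruction are illustrative): each expression maps to a Counter of its using instructions, so a subexpression occurring twice in one instruction is counted rather than stored twice.

from collections import Counter, defaultdict

usages: defaultdict[str, Counter[str]] = defaultdict(Counter)
usages["x + 1"]["y = (x + 1) * (x + 1)"] += 2  # two occurrences in one instruction
usages["x + 1"]["z = x + 1"] += 1

assert len(usages["x + 1"]) == 2     # two distinct using instructions
assert usages["x + 1"].total() == 3  # three occurrences overall (Python >= 3.10)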

@@ -241,7 +249,7 @@ def _find_location_for_insertion(self, expression) -> Tuple[BasicBlock, int]:
candidate: BasicBlock = next(iter(usage_blocks))
while not self._is_common_dominator(candidate, usage_blocks) or self._is_invalid_dominator(candidate, expression):
candidate = self._dominator_tree.get_predecessors(candidate)[0]
-return candidate, self._find_insertion_index(candidate, set(self._usages[expression]))
+return candidate, self._find_insertion_index(candidate, self._usages[expression].keys())

def _is_common_dominator(self, candidate: BasicBlock, basic_blocks: Set[BasicBlock]) -> bool:
"""Check if the given candidate is the common dominator all of given basic blocks."""
@@ -261,10 +269,10 @@ def _insert_definition(self, definition: CfgInstruction):
"""Insert a new intermediate definition for the given expression at the given location."""
definition.block.instructions.insert(definition.index, definition.instruction)
for subexpression in _subexpression_dfs(definition.instruction):
-self._usages[subexpression].append(definition)
+self._usages[subexpression][definition] += 1

@staticmethod
-def _find_insertion_index(basic_block: BasicBlock, usages: Set[CfgInstruction]) -> int:
+def _find_insertion_index(basic_block: BasicBlock, usages: Iterable[CfgInstruction]) -> int:
"""Find the first index in the given basic block where a definition could be inserted."""
usage = min((usage for usage in usages if usage.block == basic_block), default=None, key=lambda x: x.index)
if usage:
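The selection pattern behind _find_insertion_index, sketched with an illustrative Usage stand-in: among the usages living in the target block, take the one with the smallest index; default=None covers blocks that contain no usage at all.

from dataclasses import dataclass

@dataclass
class Usage:
    block: str
    index: int

usages = [Usage("B1", 4), Usage("B2", 0), Usage("B1", 2)]
first = min((u for u in usages if u.block == "B1"), default=None, key=lambda u: u.index)
assert first is not None and first.index == 2  # earliest usage within B1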
@@ -316,7 +324,7 @@ def eliminate_common_subexpressions(self, definition_generator: DefinitionGenera
except StopIteration:
warning(f"[{self.name}] No dominating basic block could be found for {replacee}")

-def _find_elimination_candidates(self, usages: DefaultDict[Expression, List[CfgInstruction]]) -> Iterator[Expression]:
+def _find_elimination_candidates(self, usages: DefaultDict[Expression, Counter[CfgInstruction]]) -> Iterator[Expression]:
"""
Iterate all expressions, yielding the expressions which should be eliminated.
@@ -325,11 +333,12 @@ def _find_elimination_candidates(self, usages: DefaultDict[Expression, List[CfgI
expressions_by_complexity = sorted(usages.keys(), reverse=True, key=lambda x: x.complexity)
for expression in expressions_by_complexity:
if self._is_cse_candidate(expression, usages):
+expression_usage = usages[expression]
for subexpression in _subexpression_dfs(expression):
-usages[subexpression] = [x for x in usages[subexpression] if x not in usages[expression]]
+usages[subexpression].subtract(expression_usage)
yield expression

-def _is_cse_candidate(self, expression: Expression, usages: DefaultDict[Expression, List[CfgInstruction]]):
+def _is_cse_candidate(self, expression: Expression, usages: DefaultDict[Expression, Counter[CfgInstruction]]):
"""Checks that we can add a common subexpression for the given expression."""
return (
self._is_elimination_candidate(expression, usages[expression])
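One subtlety worth noting about the switch to Counter.subtract() above: subtraction is in place and keeps keys whose count drops to zero (or below), which is exactly why the occurrence checks further down must ignore non-positive counts. A small sketch with illustrative keys:

from collections import Counter

sub_usages = Counter({"insn_a": 2, "insn_b": 1})
sub_usages.subtract(Counter({"insn_a": 2}))
assert sub_usages["insn_a"] == 0 and "insn_a" in sub_usages  # key survives at zero
alive = sum(1 for _, count in sub_usages.items() if count > 0)
assert alive == 1  # only insn_b still references the subexpression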
@@ -347,15 +356,16 @@ def _is_complex_string(self, expression: Expression) -> bool:
return isinstance(expression.value, str) and len(expression.value) >= self._min_string_length
return False

-def _check_inter_instruction(self, expression: Expression, instructions: List[CfgInstruction]) -> bool:
+def _check_inter_instruction(self, expression: Expression, instructions: Counter[CfgInstruction]) -> bool:
"""Check if the given expressions should be eliminated based on its global occurrences."""
-referencing_instructions_count = len(set(instructions))
+referencing_instructions_count = sum(1 for _, count in instructions.items() if count > 0)
return (expression.complexity >= 2 and referencing_instructions_count >= self._threshold) or (
self._is_complex_string(expression) and referencing_instructions_count >= self._string_threshold
)

-def _check_intra_instruction(self, expression: Expression, instructions: List[CfgInstruction]) -> bool:
+def _check_intra_instruction(self, expression: Expression, instructions: Counter[CfgInstruction]) -> bool:
"""Check if this expression should be eliminated based on the amount of unique instructions utilizing it."""
-return (expression.complexity >= 2 and len(instructions) >= self._threshold) or (
-self._is_complex_string(expression) and len(instructions) >= self._string_threshold
+referencing_count = instructions.total()
+return (expression.complexity >= 2 and referencing_count >= self._threshold) or (
+self._is_complex_string(expression) and referencing_count >= self._string_threshold
)
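The two rewritten checks measure different things: _check_inter_instruction counts distinct instructions that still use the expression (hence the count > 0 filter), while _check_intra_instruction uses Counter.total() (added in Python 3.10), so repeated occurrences inside a single instruction also count toward the threshold. An illustrative sketch:

from collections import Counter

instructions = Counter({"y = (x + 1) * (x + 1)": 2, "z = x + 1": 1, "stale": 0})
inter = sum(1 for _, count in instructions.items() if count > 0)
assert inter == 2                 # two live referencing instructions
assert instructions.total() == 3  # three occurrences in total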
1 change: 1 addition & 0 deletions pytest.ini
@@ -1,3 +1,4 @@
[pytest]
+addopts = -n auto
python_files = test-*.py test_*.py
markers = coreutils
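With pytest-xdist installed, the new addopts line makes every plain pytest invocation distribute tests across one worker per CPU core. A sketch of the programmatic equivalent (the tests/ path is illustrative):

import pytest

# equivalent to running "pytest -n auto tests/" from the shell;
# requires the pytest-xdist plugin added in requirements.txt below
exit_code = pytest.main(["-n", "auto", "tests/"])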
1 change: 1 addition & 0 deletions requirements.txt
@@ -6,4 +6,5 @@ networkx != 2.8.4
pydot
pygments
pytest !=5.3.4
+pytest-xdist
z3-solver == 4.8.10
62 changes: 37 additions & 25 deletions tests/conftest.py
@@ -1,8 +1,11 @@
import pathlib
import re
-from typing import Dict, List, Tuple
+from itertools import chain
+from typing import Iterator

import pytest
+from _pytest.mark import ParameterSet
+from _pytest.python import Metafunc


def pytest_addoption(parser):
@@ -23,7 +26,7 @@ def pytest_configure(config):
setattr(config.option, "markexpr", "not coreutils")


-def pytest_generate_tests(metafunc):
+def pytest_generate_tests(metafunc: Metafunc):
"""Generates test_cases based on command line options
the resulting fixture test_cases can then be used to parametrize our test_sample function
@@ -34,53 +37,62 @@ def pytest_generate_tests(metafunc):
test_cases = _discover_full_tests()
else:
test_cases = _discover_system_tests()
-params = list()
-for sample_name, functions in test_cases.items():
-for f in functions:
-params.append((sample_name, f))
-metafunc.parametrize("test_cases", params)
+metafunc.parametrize("test_cases", _create_params(test_cases))

if "coreutils_tests" in metafunc.fixturenames:
coreutils_tests = _discover_coreutils_tests()
metafunc.parametrize("coreutils_tests", coreutils_tests)
metafunc.parametrize("coreutils_tests", _create_params(coreutils_tests))


+def _create_params(cases: Iterator[tuple[pathlib.Path, str]]) -> list[ParameterSet]:
+"""
+Accepts an iterator of sample binaries paired with a function name to test.
+Returns a list of ParameterSet objects to be used with metafunc.parametrize.
+Note that we sort all test cases by their id so that we have a deterministic/consistent ordering of tests.
+This is needed by pytest-xdist to function properly.
+See https://pytest-xdist.readthedocs.io/en/stable/known-limitations.html#order-and-amount-of-test-must-be-consistent
+"""
+test_cases = map(lambda i: pytest.param((i[0], i[1]), id=f"{i[0]}::{i[1]}"), cases)
+return sorted(test_cases, key=lambda p: p.id)
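A usage sketch of the determinism this function provides (paths and function names are illustrative): pytest-xdist requires every worker to collect the same tests in the same order, so each parameter gets a stable id and the list is sorted by it.

import pytest

cases = [("bin/b", "test2"), ("bin/a", "test1")]
params = [pytest.param(case, id=f"{case[0]}::{case[1]}") for case in cases]
params.sort(key=lambda p: p.id)
assert [p.id for p in params] == ["bin/a::test1", "bin/b::test2"]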


-def _discover_full_tests() -> Dict[pathlib.Path, List[str]]:
+def _discover_full_tests() -> Iterator[tuple[pathlib.Path, str]]:
"""Discover test source files and the test functions in these files.
All files with a .c extension that contain at least one test function are considered as test files.
"""
makefile = _parse_makefile()
test_cases = _discover_tests_in_directory_tree(makefile["system_tests_src_path"], makefile["system_tests_bin_path"])
extended_test_cases = _discover_tests_in_directory_tree(makefile["extended_tests_src_path"], makefile["extended_tests_bin_path"])
-test_cases.update(extended_test_cases)
-return test_cases
+for sample_path, functions in chain(test_cases.items(), extended_test_cases.items()):
+for function in functions:
+yield sample_path, function


-def _discover_system_tests() -> Dict[pathlib.Path, List[str]]:
+def _discover_system_tests() -> Iterator[tuple[pathlib.Path, str]]:
"""Returns a mapping of system tests binaries to the lists of function names contained in those binaries"""
-test_cases = dict()
makefile = _parse_makefile()
test_code_files = makefile["system_tests_src_path"].glob("*.c")
for test_code_file in test_code_files:
-if test_functions := _discover_test_functions_in_sample_code(test_code_file):
-test_cases[makefile["system_tests_bin_path"] / "32" / "0" / test_code_file.stem] = test_functions
-return test_cases
+sample_path = makefile["system_tests_bin_path"] / "32" / "0" / test_code_file.stem
+for function_name in _discover_test_functions_in_sample_code(test_code_file):
+yield sample_path, function_name


-def _discover_coreutils_tests() -> List[Tuple[pathlib.Path, str]]:
+def _discover_coreutils_tests() -> Iterator[tuple[pathlib.Path, str]]:
"""Returns list of (binary, func_name) from a text file for the coreutils binaries."""
with pathlib.Path("tests/coreutils/functions.txt").open("r", encoding="utf-8") as f:
funcs_contents = f.readlines()
-files = []

for line in funcs_contents:
-f = line.split()
-path = pathlib.Path(f"tests/coreutils/binaries/{f[0]}")
-files.append(pytest.param((path, f[1]), id=f"{f[0]}:{f[1]}"))
-return files
+(sample_name, function_name) = line.split()
+yield pathlib.Path(f"tests/coreutils/binaries/{sample_name}"), function_name


-def _discover_tests_in_directory_tree(src_path, bin_path) -> Dict[pathlib.Path, List[str]]:
+def _discover_tests_in_directory_tree(src_path: pathlib.Path, bin_path: pathlib.Path) -> dict[pathlib.Path, list[str]]:
"""Return a mapping of binaries collected recursively in the bin_path to function names contained in those binaries"""
test_cases = dict()
test_code_files = src_path.glob("*.c")
@@ -94,7 +106,7 @@ def _discover_tests_in_directory_tree(src_path, bin_path) -> Dict[pathlib.Path,
return test_cases


-def _discover_test_functions_in_sample_code(sample: pathlib.Path) -> List[str]:
+def _discover_test_functions_in_sample_code(sample: pathlib.Path) -> list[str]:
"""Discover test functions in the given source file.
Test function to be included have to be named 'testN' where 'N' has to be an integer."""
test_names = list()
@@ -105,7 +117,7 @@ def _discover_test_functions_in_sample_code(sample: pathlib.Path) -> List[str]:
return test_names


-def _parse_makefile() -> Dict[str, pathlib.Path]:
+def _parse_makefile() -> dict[str, pathlib.Path]:
"""Parse from Makefile path to systemtests sources and binaries as well as
path to extended tests sources and binaries"""
makefile = dict()
