diff --git a/decompiler/pipeline/dataflowanalysis/common_subexpression_elimination.py b/decompiler/pipeline/dataflowanalysis/common_subexpression_elimination.py index 03c2ce2ca..91848bbbb 100644 --- a/decompiler/pipeline/dataflowanalysis/common_subexpression_elimination.py +++ b/decompiler/pipeline/dataflowanalysis/common_subexpression_elimination.py @@ -2,11 +2,11 @@ from __future__ import annotations -from collections import defaultdict, deque +from collections import Counter, defaultdict, deque from dataclasses import dataclass from itertools import chain from logging import info, warning -from typing import DefaultDict, Deque, Dict, Iterator, List, Optional, Set, Tuple +from typing import DefaultDict, Deque, Dict, Iterable, Iterator, List, Optional, Set, Tuple from decompiler.pipeline.stage import PipelineStage from decompiler.structures.graphs.cfg import BasicBlock, ControlFlowGraph @@ -18,12 +18,20 @@ from networkx import dfs_postorder_nodes -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class CfgInstruction: """ dataclass in charge of tracking the location of Instruction objects in the cfg -> The considered instruction, where block is the basic block where it is contained and index the position in the basic block. + + Note: Two instances with the same data will not be equal (because of eq=False). + This way, eq and hash are way more performant, because at the time of writing this, eq and hash are very + expensive on big instructions. + + eq=True would probably be nicer to use, but we don't actually create instances with the same data + multiple times. (Rationale: initially just one instance is created per (block, index) pair. + All further instances with the same (block, index) will have a less complex instruction than before) """ instruction: Instruction @@ -200,7 +208,7 @@ class DefinitionGenerator: def __init__( self, - expression_usages: DefaultDict[Expression, List[CfgInstruction]], + expression_usages: DefaultDict[Expression, Counter[CfgInstruction]], dominator_tree: NetworkXGraph, ): """Generate a new instance based on data parsed from a cfg.""" @@ -210,16 +218,16 @@ def __init__( @classmethod def from_cfg(cls, cfg: ControlFlowGraph) -> DefinitionGenerator: """Initialize a DefinitionGenerator utilizing the data of the given cfg.""" - usages: DefaultDict[Expression, List[CfgInstruction]] = defaultdict(list) + usages: DefaultDict[Expression, Counter[CfgInstruction]] = defaultdict(Counter) for basic_block in cfg: for index, instruction in enumerate(basic_block.instructions): instruction_with_position = CfgInstruction(instruction, basic_block, index) for subexpression in _subexpression_dfs(instruction): - usages[subexpression].append(instruction_with_position) + usages[subexpression][instruction_with_position] += 1 return cls(usages, cfg.dominator_tree) @property - def usages(self) -> DefaultDict[Expression, List[CfgInstruction]]: + def usages(self) -> DefaultDict[Expression, Counter[CfgInstruction]]: """Return a mapping from expressions to a set of instructions using them.""" return self._usages @@ -241,7 +249,7 @@ def _find_location_for_insertion(self, expression) -> Tuple[BasicBlock, int]: candidate: BasicBlock = next(iter(usage_blocks)) while not self._is_common_dominator(candidate, usage_blocks) or self._is_invalid_dominator(candidate, expression): candidate = self._dominator_tree.get_predecessors(candidate)[0] - return candidate, self._find_insertion_index(candidate, set(self._usages[expression])) + return candidate, self._find_insertion_index(candidate, self._usages[expression].keys()) def _is_common_dominator(self, candidate: BasicBlock, basic_blocks: Set[BasicBlock]) -> bool: """Check if the given candidate is the common dominator all of given basic blocks.""" @@ -261,10 +269,10 @@ def _insert_definition(self, definition: CfgInstruction): """Insert a new intermediate definition for the given expression at the given location.""" definition.block.instructions.insert(definition.index, definition.instruction) for subexpression in _subexpression_dfs(definition.instruction): - self._usages[subexpression].append(definition) + self._usages[subexpression][definition] += 1 @staticmethod - def _find_insertion_index(basic_block: BasicBlock, usages: Set[CfgInstruction]) -> int: + def _find_insertion_index(basic_block: BasicBlock, usages: Iterable[CfgInstruction]) -> int: """Find the first index in the given basic block where a definition could be inserted.""" usage = min((usage for usage in usages if usage.block == basic_block), default=None, key=lambda x: x.index) if usage: @@ -316,7 +324,7 @@ def eliminate_common_subexpressions(self, definition_generator: DefinitionGenera except StopIteration: warning(f"[{self.name}] No dominating basic block could be found for {replacee}") - def _find_elimination_candidates(self, usages: DefaultDict[Expression, List[CfgInstruction]]) -> Iterator[Expression]: + def _find_elimination_candidates(self, usages: DefaultDict[Expression, Counter[CfgInstruction]]) -> Iterator[Expression]: """ Iterate all expressions, yielding the expressions which should be eliminated. @@ -325,11 +333,12 @@ def _find_elimination_candidates(self, usages: DefaultDict[Expression, List[CfgI expressions_by_complexity = sorted(usages.keys(), reverse=True, key=lambda x: x.complexity) for expression in expressions_by_complexity: if self._is_cse_candidate(expression, usages): + expression_usage = usages[expression] for subexpression in _subexpression_dfs(expression): - usages[subexpression] = [x for x in usages[subexpression] if x not in usages[expression]] + usages[subexpression].subtract(expression_usage) yield expression - def _is_cse_candidate(self, expression: Expression, usages: DefaultDict[Expression, List[CfgInstruction]]): + def _is_cse_candidate(self, expression: Expression, usages: DefaultDict[Expression, Counter[CfgInstruction]]): """Checks that we can add a common subexpression for the given expression.""" return ( self._is_elimination_candidate(expression, usages[expression]) @@ -347,15 +356,16 @@ def _is_complex_string(self, expression: Expression) -> bool: return isinstance(expression.value, str) and len(expression.value) >= self._min_string_length return False - def _check_inter_instruction(self, expression: Expression, instructions: List[CfgInstruction]) -> bool: + def _check_inter_instruction(self, expression: Expression, instructions: Counter[CfgInstruction]) -> bool: """Check if the given expressions should be eliminated based on its global occurrences.""" - referencing_instructions_count = len(set(instructions)) + referencing_instructions_count = sum(1 for _, count in instructions.items() if count > 0) return (expression.complexity >= 2 and referencing_instructions_count >= self._threshold) or ( self._is_complex_string(expression) and referencing_instructions_count >= self._string_threshold ) - def _check_intra_instruction(self, expression: Expression, instructions: List[CfgInstruction]) -> bool: + def _check_intra_instruction(self, expression: Expression, instructions: Counter[CfgInstruction]) -> bool: """Check if this expression should be eliminated based on the amount of unique instructions utilizing it.""" - return (expression.complexity >= 2 and len(instructions) >= self._threshold) or ( - self._is_complex_string(expression) and len(instructions) >= self._string_threshold + referencing_count = instructions.total() + return (expression.complexity >= 2 and referencing_count >= self._threshold) or ( + self._is_complex_string(expression) and referencing_count >= self._string_threshold ) diff --git a/pytest.ini b/pytest.ini index fc9ce7a1f..6c00089ab 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,4 @@ [pytest] +addopts = -n auto python_files = test-*.py test_*.py markers = coreutils diff --git a/requirements.txt b/requirements.txt index 3fca302ff..b4209be4e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ networkx != 2.8.4 pydot pygments pytest !=5.3.4 +pytest-xdist z3-solver == 4.8.10 diff --git a/tests/conftest.py b/tests/conftest.py index 24bcbdd6b..a31d0ccef 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,11 @@ import pathlib import re -from typing import Dict, List, Tuple +from itertools import chain +from typing import Iterator import pytest +from _pytest.mark import ParameterSet +from _pytest.python import Metafunc def pytest_addoption(parser): @@ -23,7 +26,7 @@ def pytest_configure(config): setattr(config.option, "markexpr", "not coreutils") -def pytest_generate_tests(metafunc): +def pytest_generate_tests(metafunc: Metafunc): """Generates test_cases based on command line options the resulting fixture test_cases can then be used to parametrize our test_sample function @@ -34,18 +37,28 @@ def pytest_generate_tests(metafunc): test_cases = _discover_full_tests() else: test_cases = _discover_system_tests() - params = list() - for sample_name, functions in test_cases.items(): - for f in functions: - params.append((sample_name, f)) - metafunc.parametrize("test_cases", params) + + metafunc.parametrize("test_cases", _create_params(test_cases)) if "coreutils_tests" in metafunc.fixturenames: coreutils_tests = _discover_coreutils_tests() - metafunc.parametrize("coreutils_tests", coreutils_tests) + metafunc.parametrize("coreutils_tests", _create_params(coreutils_tests)) + + +def _create_params(cases: Iterator[tuple[pathlib.Path, str]]) -> list[ParameterSet]: + """ + Accepts an iterator of sample binaries paired with a function name to test. + Returns a list of ParameterSet objects to be used with metafunc.parametrize. + + Note that we sort all test cases by their id so that we have a deterministic/consistent ordering of tests. + This is needed by pytest-xdist to function properly. + See https://pytest-xdist.readthedocs.io/en/stable/known-limitations.html#order-and-amount-of-test-must-be-consistent + """ + test_cases = map(lambda i: pytest.param((i[0], i[1]), id=f"{i[0]}::{i[1]}"), cases) + return sorted(test_cases, key=lambda p: p.id) -def _discover_full_tests() -> Dict[pathlib.Path, List[str]]: +def _discover_full_tests() -> Iterator[tuple[pathlib.Path, str]]: """Discover test source files and the test functions in these files. All files with a .c extension that contain at least one test function are considered as test files. @@ -53,34 +66,33 @@ def _discover_full_tests() -> Dict[pathlib.Path, List[str]]: makefile = _parse_makefile() test_cases = _discover_tests_in_directory_tree(makefile["system_tests_src_path"], makefile["system_tests_bin_path"]) extended_test_cases = _discover_tests_in_directory_tree(makefile["extended_tests_src_path"], makefile["extended_tests_bin_path"]) - test_cases.update(extended_test_cases) - return test_cases + + for sample_path, functions in chain(test_cases.items(), extended_test_cases.items()): + for function in functions: + yield sample_path, function -def _discover_system_tests() -> Dict[pathlib.Path, List[str]]: +def _discover_system_tests() -> Iterator[tuple[pathlib.Path, str]]: """Returns a mapping of system tests binaries to the lists of function names contained in those binaries""" - test_cases = dict() makefile = _parse_makefile() test_code_files = makefile["system_tests_src_path"].glob("*.c") for test_code_file in test_code_files: - if test_functions := _discover_test_functions_in_sample_code(test_code_file): - test_cases[makefile["system_tests_bin_path"] / "32" / "0" / test_code_file.stem] = test_functions - return test_cases + sample_path = makefile["system_tests_bin_path"] / "32" / "0" / test_code_file.stem + for function_name in _discover_test_functions_in_sample_code(test_code_file): + yield sample_path, function_name -def _discover_coreutils_tests() -> List[Tuple[pathlib.Path, str]]: +def _discover_coreutils_tests() -> Iterator[tuple[pathlib.Path, str]]: """Returns list of (binary, func_name) from a text file for the coreutils binaries.""" with pathlib.Path("tests/coreutils/functions.txt").open("r", encoding="utf-8") as f: funcs_contents = f.readlines() - files = [] + for line in funcs_contents: - f = line.split() - path = pathlib.Path(f"tests/coreutils/binaries/{f[0]}") - files.append(pytest.param((path, f[1]), id=f"{f[0]}:{f[1]}")) - return files + (sample_name, function_name) = line.split() + yield pathlib.Path(f"tests/coreutils/binaries/{sample_name}"), function_name -def _discover_tests_in_directory_tree(src_path, bin_path) -> Dict[pathlib.Path, List[str]]: +def _discover_tests_in_directory_tree(src_path: pathlib.Path, bin_path: pathlib.Path) -> dict[pathlib.Path, list[str]]: """Return a mapping of binaries collected recursively in the bin_path to function names contained in those binaries""" test_cases = dict() test_code_files = src_path.glob("*.c") @@ -94,7 +106,7 @@ def _discover_tests_in_directory_tree(src_path, bin_path) -> Dict[pathlib.Path, return test_cases -def _discover_test_functions_in_sample_code(sample: pathlib.Path) -> List[str]: +def _discover_test_functions_in_sample_code(sample: pathlib.Path) -> list[str]: """Discover test functions in the given source file. Test function to be included have to be named 'testN' where 'N' has to be an integer.""" test_names = list() @@ -105,7 +117,7 @@ def _discover_test_functions_in_sample_code(sample: pathlib.Path) -> List[str]: return test_names -def _parse_makefile() -> Dict[str, pathlib.Path]: +def _parse_makefile() -> dict[str, pathlib.Path]: """Parse from Makefile path to systemtests sources and binaries as well as path to extended tests sources and binaries""" makefile = dict()