diff --git a/deps/llvm-backend_release b/deps/llvm-backend_release index 5950146bb3b..841597f02d9 100644 --- a/deps/llvm-backend_release +++ b/deps/llvm-backend_release @@ -1 +1 @@ -0.1.103 +0.1.119 diff --git a/flake.lock b/flake.lock index 0680f738f2b..75cee391e60 100644 --- a/flake.lock +++ b/flake.lock @@ -112,16 +112,16 @@ "utils": "utils" }, "locked": { - "lastModified": 1730229432, - "narHash": "sha256-2Y4U7TCmSf9NAZCBmvXiHLOXrHxpiRgIpw5ERYDdNSM=", + "lastModified": 1734123780, + "narHash": "sha256-cR/a2NpIyRL6kb4zA0JCbufLITCRfyHpcdzWcnhDTe8=", "owner": "runtimeverification", "repo": "llvm-backend", - "rev": "d5eab4b0f0e610bc60843ebb482f79c043b92702", + "rev": "1cd3319755df0c53d437313aea0d71b6cfdd9de5", "type": "github" }, "original": { "owner": "runtimeverification", - "ref": "v0.1.103", + "ref": "v0.1.119", "repo": "llvm-backend", "type": "github" } diff --git a/flake.nix b/flake.nix index 77aacd96ac3..df959ff6152 100644 --- a/flake.nix +++ b/flake.nix @@ -1,7 +1,7 @@ { description = "K Framework"; inputs = { - llvm-backend.url = "github:runtimeverification/llvm-backend/v0.1.103"; + llvm-backend.url = "github:runtimeverification/llvm-backend/v0.1.119"; haskell-backend = { url = "github:runtimeverification/haskell-backend/v0.1.105"; inputs.rv-utils.follows = "llvm-backend/rv-utils"; diff --git a/k-distribution/tests/regression-new/proof-instrumentation-debug/input.test.out b/k-distribution/tests/regression-new/proof-instrumentation-debug/input.test.out index 4706f33c253..ef03e1e235d 100644 Binary files a/k-distribution/tests/regression-new/proof-instrumentation-debug/input.test.out and b/k-distribution/tests/regression-new/proof-instrumentation-debug/input.test.out differ diff --git a/k-distribution/tests/regression-new/proof-instrumentation/input.test.out b/k-distribution/tests/regression-new/proof-instrumentation/input.test.out index 46c83b90bc1..24aac4eacce 100644 Binary files a/k-distribution/tests/regression-new/proof-instrumentation/input.test.out and b/k-distribution/tests/regression-new/proof-instrumentation/input.test.out differ diff --git a/llvm-backend/src/main/native/llvm-backend b/llvm-backend/src/main/native/llvm-backend index d5eab4b0f0e..1cd3319755d 160000 --- a/llvm-backend/src/main/native/llvm-backend +++ b/llvm-backend/src/main/native/llvm-backend @@ -1 +1 @@ -Subproject commit d5eab4b0f0e610bc60843ebb482f79c043b92702 +Subproject commit 1cd3319755df0c53d437313aea0d71b6cfdd9de5 diff --git a/pyk/src/pyk/kast/att.py b/pyk/src/pyk/kast/att.py index 5840146ba0a..b428e79a2c5 100644 --- a/pyk/src/pyk/kast/att.py +++ b/pyk/src/pyk/kast/att.py @@ -290,6 +290,7 @@ class Atts: ALIAS_REC: Final = AttKey('alias-rec', type=_NONE) ANYWHERE: Final = AttKey('anywhere', type=_NONE) ASSOC: Final = AttKey('assoc', type=_NONE) + AVOID: Final = AttKey('avoid', type=_NONE) BRACKET: Final = AttKey('bracket', type=_NONE) BRACKET_LABEL: Final = AttKey('bracketLabel', type=_ANY) CIRCULARITY: Final = AttKey('circularity', type=_NONE) @@ -307,6 +308,7 @@ class Atts: DEPENDS: Final = AttKey('depends', type=_ANY) DIGEST: Final = AttKey('digest', type=_ANY) ELEMENT: Final = AttKey('element', type=_ANY) + EXIT: Final = AttKey('exit', type=_ANY) FORMAT: Final = AttKey('format', type=FormatType()) FRESH_GENERATOR: Final = AttKey('freshGenerator', type=_NONE) FUNCTION: Final = AttKey('function', type=_NONE) @@ -325,6 +327,8 @@ class Atts: MACRO: Final = AttKey('macro', type=_NONE) MACRO_REC: Final = AttKey('macro-rec', type=_NONE) MAINCELL: Final = AttKey('maincell', type=_NONE) + MULTIPLICITY: 
Final = AttKey('multiplicity', type=_ANY) + NO_EVALUATORS: Final = AttKey('no-evaluators', type=_NONE) OVERLOAD: Final = AttKey('overload', type=_STR) OWISE: Final = AttKey('owise', type=_NONE) PREDICATE: Final = AttKey('predicate', type=_ANY) @@ -335,6 +339,7 @@ class Atts: PRODUCTION: Final = AttKey('org.kframework.definition.Production', type=_ANY) PROJECTION: Final = AttKey('projection', type=_NONE) RIGHT: Final = AttKey('right', type=_ANY) # RIGHT and RIGHT_INTERNAL on the Frontend + RETURNS_UNIT: Final = AttKey('returnsUnit', type=_NONE) SIMPLIFICATION: Final = AttKey('simplification', type=_ANY) SEQSTRICT: Final = AttKey('seqstrict', type=_ANY) SORT: Final = AttKey('org.kframework.kore.Sort', type=_ANY) @@ -345,9 +350,11 @@ class Atts: SYNTAX_MODULE: Final = AttKey('syntaxModule', type=_STR) SYMBOLIC: Final = AttKey('symbolic', type=OptionalType(_STR)) TERMINALS: Final = AttKey('terminals', type=_STR) + TERMINATOR_SYMBOL: Final = AttKey('terminator-symbol', type=_ANY) TOKEN: Final = AttKey('token', type=_NONE) TOTAL: Final = AttKey('total', type=_NONE) TRUSTED: Final = AttKey('trusted', type=_NONE) + TYPE: Final = AttKey('type', type=_ANY) UNIT: Final = AttKey('unit', type=_STR) UNIQUE_ID: Final = AttKey('UNIQUE_ID', type=_ANY) UNPARSE_AVOID: Final = AttKey('unparseAvoid', type=_NONE) diff --git a/pyk/src/pyk/kast/inner.py b/pyk/src/pyk/kast/inner.py index e034c048e7a..957981fd67f 100644 --- a/pyk/src/pyk/kast/inner.py +++ b/pyk/src/pyk/kast/inner.py @@ -877,20 +877,18 @@ def _var_occurence(_term: KInner) -> None: return _var_occurrences -# TODO replace by method that does not reconstruct the AST def collect(callback: Callable[[KInner], None], kinner: KInner) -> None: - """Collect information about a given term traversing it bottom-up using a function with side effects. + """Collect information about a given term traversing it top-down using a function with side effects. Args: callback: Function with the side effect of collecting desired information at each AST node. kinner: The term to traverse. 
""" - - def f(kinner: KInner) -> KInner: - callback(kinner) - return kinner - - bottom_up(f, kinner) + subterms = [kinner] + while subterms: + term = subterms.pop() + subterms.extend(reversed(term.terms)) + callback(term) def build_assoc(unit: KInner, label: str | KLabel, terms: Iterable[KInner]) -> KInner: diff --git a/pyk/src/pyk/kcfg/kcfg.py b/pyk/src/pyk/kcfg/kcfg.py index 825dd7c3289..57c0e95dc97 100644 --- a/pyk/src/pyk/kcfg/kcfg.py +++ b/pyk/src/pyk/kcfg/kcfg.py @@ -559,6 +559,8 @@ def extend( extend_result: KCFGExtendResult, node: KCFG.Node, logs: dict[int, tuple[LogEntry, ...]], + *, + optimize_kcfg: bool, ) -> None: def log(message: str, *, warning: bool = False) -> None: @@ -584,10 +586,25 @@ def log(message: str, *, warning: bool = False) -> None: log(f'abstraction node: {node.id}') case Step(cterm, depth, next_node_logs, rule_labels, _): + node_id = node.id next_node = self.create_node(cterm) + # Optimization for steps consists of on-the-fly merging of consecutive edges and can + # be performed only if the current node has a single predecessor connected by an Edge + if ( + optimize_kcfg + and (len(predecessors := self.predecessors(target_id=node.id)) == 1) + and isinstance(in_edge := predecessors[0], KCFG.Edge) + ): + # The existing edge is removed and the step parameters are updated accordingly + self.remove_edge(in_edge.source.id, node.id) + node_id = in_edge.source.id + depth += in_edge.depth + rule_labels = list(in_edge.rules) + rule_labels + next_node_logs = logs[node.id] + next_node_logs if node.id in logs else next_node_logs + self.remove_node(node.id) + self.create_edge(node_id, next_node.id, depth, rule_labels) logs[next_node.id] = next_node_logs - self.create_edge(node.id, next_node.id, depth, rules=rule_labels) - log(f'basic block at depth {depth}: {node.id} --> {next_node.id}') + log(f'basic block at depth {depth}: {node_id} --> {next_node.id}') case Branch(branches, _): branch_node_ids = self.split_on_constraints(node.id, branches) diff --git a/pyk/src/pyk/kllvm/hints/prooftrace.py b/pyk/src/pyk/kllvm/hints/prooftrace.py index fabe1d2f64d..36028cd99f2 100644 --- a/pyk/src/pyk/kllvm/hints/prooftrace.py +++ b/pyk/src/pyk/kllvm/hints/prooftrace.py @@ -9,6 +9,7 @@ kore_header, llvm_rewrite_event, llvm_function_event, + llvm_function_exit_event, llvm_hook_event, llvm_rewrite_trace, llvm_rule_event, @@ -245,6 +246,43 @@ def args(self) -> list[LLVMArgument]: return [LLVMArgument(arg) for arg in self._function_event.args] +@final +class LLVMFunctionExitEvent(LLVMStepEvent): + """Represent an LLVM function exit event in a proof trace. + + Attributes: + _function_exit_event (llvm_function_exit_event): The underlying LLVM function exit event object. + """ + + _function_exit_event: llvm_function_exit_event + + def __init__(self, function_exit_event: llvm_function_exit_event) -> None: + """Initialize a new instance of the LLVMFunctionExitEvent class. + + Args: + function_exit_event (llvm_function_exit_event): The LLVM function exit event object. + """ + self._function_exit_event = function_exit_event + + def __repr__(self) -> str: + """Return a string representation of the object. + + Returns: + A string representation of the LLVMFunctionExitEvent object using the AST printing method. 
+ """ + return self._function_exit_event.__repr__() + + @property + def rule_ordinal(self) -> int: + """Return the axiom ordinal number associated with the function exit event.""" + return self._function_exit_event.rule_ordinal + + @property + def is_tail(self) -> bool: + """Return True if the function exit event is a tail call.""" + return self._function_exit_event.is_tail + + @final class LLVMHookEvent(LLVMStepEvent): """Represents a hook event in LLVM execution. @@ -330,6 +368,8 @@ def step_event(self) -> LLVMStepEvent: return LLVMSideConditionEventExit(self._argument.step_event) elif isinstance(self._argument.step_event, llvm_function_event): return LLVMFunctionEvent(self._argument.step_event) + elif isinstance(self._argument.step_event, llvm_function_exit_event): + return LLVMFunctionExitEvent(self._argument.step_event) elif isinstance(self._argument.step_event, llvm_hook_event): return LLVMHookEvent(self._argument.step_event) elif isinstance(self._argument.step_event, llvm_pattern_matching_failure_event): diff --git a/pyk/src/pyk/konvert/_module_to_kore.py b/pyk/src/pyk/konvert/_module_to_kore.py index 8ab509a2501..ccd19ff3f25 100644 --- a/pyk/src/pyk/konvert/_module_to_kore.py +++ b/pyk/src/pyk/konvert/_module_to_kore.py @@ -4,14 +4,14 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from functools import reduce -from itertools import repeat +from itertools import chain, repeat from pathlib import Path from typing import ClassVar # noqa: TC003 from typing import TYPE_CHECKING, NamedTuple, final -from ..kast import EMPTY_ATT, Atts, KInner -from ..kast.att import Format -from ..kast.inner import KApply, KRewrite, KSort +from ..kast import AttKey, Atts, KAtt, KInner +from ..kast.att import Format, NoneType +from ..kast.inner import KApply, KLabel, KRewrite, KSort, collect from ..kast.manip import extract_lhs, extract_rhs from ..kast.outer import KDefinition, KNonTerminal, KProduction, KRegexTerminal, KRule, KSyntaxSort, KTerminal from ..kore.prelude import inj @@ -37,7 +37,7 @@ Top, ) from ..prelude.k import K_ITEM, K -from ..utils import FrozenDict, intersperse +from ..utils import FrozenDict, intersperse, not_none from ._kast_to_kore import _kast_to_kore from ._utils import munge @@ -45,8 +45,7 @@ from collections.abc import Callable, Iterable, Mapping from typing import Any, Final - from ..kast import AttEntry, AttKey, KAtt - from ..kast.inner import KLabel + from ..kast import AttEntry from ..kast.outer import KFlatModule, KSentence from ..kore.syntax import Pattern, Sentence, Sort @@ -83,6 +82,9 @@ } +_INTERNAL_CONSTRUCTOR: Final = AttKey('internal-constructor', type=NoneType()) + + def module_to_kore(definition: KDefinition) -> Module: """Convert the main module of a kompiled KAST definition to KORE format.""" module = simplified_module(definition) @@ -98,9 +100,9 @@ def module_to_kore(definition: KDefinition) -> Module: if syntax_sort.sort.name not in [K.name, K_ITEM.name] ] symbol_decls = [ - symbol_prod_to_kore(sentence) - for sentence in module.sentences - if isinstance(sentence, KProduction) and sentence.klabel and sentence.klabel.name not in BUILTIN_LABELS + symbol_prod_to_kore(prod) + for prod in module.productions + if prod.klabel and prod.klabel.name not in BUILTIN_LABELS ] sentences: list[Sentence] = [] @@ -117,6 +119,19 @@ def module_to_kore(definition: KDefinition) -> Module: sentences += _overload_axioms(defn) res = Module(name=name, sentences=sentences, attrs=attrs) + + # Filter the assoc, and _internal_constructor attribute + res = res.let( 
+ sentences=( + ( + sent.let_attrs(attr for attr in sent.attrs if attr.symbol not in ['assoc', 'internal-constructor']) + if isinstance(sent, SymbolDecl) + else sent + ) + for sent in res.sentences + ) + ) + # Filter the overload attribute res = res.let( sentences=(sent.let_attrs(attr for attr in sent.attrs if attr.symbol != 'overload') for sent in res.sentences) @@ -265,11 +280,8 @@ def subsort_axiom(subsort: Sort, supersort: Sort) -> Axiom: ) res: list[Axiom] = [] - for sentence in module.sentences: - if not isinstance(sentence, KProduction): - continue - - subsort_res = sentence.as_subsort + for prod in module.productions: + subsort_res = prod.as_subsort if not subsort_res: continue @@ -327,17 +339,15 @@ def app(left: Pattern, right: Pattern) -> App: module = defn.modules[0] res: list[Axiom] = [] - for sentence in module.sentences: - if not isinstance(sentence, KProduction): + for prod in module.productions: + if not prod.klabel: continue - if not sentence.klabel: + if prod.klabel.name in BUILTIN_LABELS: continue - if sentence.klabel.name in BUILTIN_LABELS: - continue - if not Atts.ASSOC in sentence.att: + if not Atts.ASSOC in prod.att: continue - res.append(assoc_axiom(sentence)) + res.append(assoc_axiom(prod)) return res @@ -373,16 +383,14 @@ def check_is_prod_sort(sort: KSort) -> None: ) res: list[Axiom] = [] - for sentence in module.sentences: - if not isinstance(sentence, KProduction): - continue - if not sentence.klabel: + for prod in module.productions: + if not prod.klabel: continue - if sentence.klabel.name in BUILTIN_LABELS: + if prod.klabel.name in BUILTIN_LABELS: continue - if not Atts.IDEM in sentence.att: + if not Atts.IDEM in prod.att: continue - res.append(idem_axiom(sentence)) + res.append(idem_axiom(prod)) return res @@ -425,18 +433,16 @@ def check_is_prod_sort(sort: KSort) -> None: return left_unit, right_unit res: list[Axiom] = [] - for sentence in module.sentences: - if not isinstance(sentence, KProduction): - continue - if not sentence.klabel: + for prod in module.productions: + if not prod.klabel: continue - if sentence.klabel.name in BUILTIN_LABELS: + if prod.klabel.name in BUILTIN_LABELS: continue - if not Atts.FUNCTION in sentence.att: + if not Atts.FUNCTION in prod.att: continue - if not Atts.UNIT in sentence.att: + if not Atts.UNIT in prod.att: continue - res.extend(unit_axioms(sentence)) + res.extend(unit_axioms(prod)) return res @@ -463,16 +469,14 @@ def functional_axiom(production: KProduction) -> Axiom: ) res: list[Axiom] = [] - for sentence in module.sentences: - if not isinstance(sentence, KProduction): + for prod in module.productions: + if not prod.klabel: continue - if not sentence.klabel: + if prod.klabel.name in BUILTIN_LABELS: continue - if sentence.klabel.name in BUILTIN_LABELS: + if not Atts.FUNCTIONAL in prod.att: continue - if not Atts.FUNCTIONAL in sentence.att: - continue - res.append(functional_axiom(sentence)) + res.append(functional_axiom(prod)) return res @@ -534,21 +538,26 @@ def axiom_for_diff_constr(prod1: KProduction, prod2: KProduction) -> Axiom: ) prods = [ - sent - for sent in module.sentences - if isinstance(sent, KProduction) - and sent.klabel - and sent.klabel.name not in BUILTIN_LABELS - and Atts.CONSTRUCTOR in sent.att + prod + for prod in module.productions + if prod.klabel and prod.klabel.name not in BUILTIN_LABELS and _INTERNAL_CONSTRUCTOR in prod.att ] res: list[Axiom] = [] res += (axiom_for_same_constr(p) for p in prods if p.non_terminals) + + prods_by_sort: dict[KSort, list[KProduction]] = {} + for prod in prods: + 
prods_by_sort.setdefault(prod.sort, []).append(prod) + + for _, prods in prods_by_sort.items(): + prods.sort(key=lambda p: not_none(p.klabel).name) # type: ignore [attr-defined] + + res += ( - axiom_for_diff_constr(p1, p2) - for p1 in prods - for p2 in prods - if p1.sort == p2.sort and p1.klabel and p2.klabel and p1.klabel.name < p2.klabel.name + axiom_for_diff_constr(prods[i], prods[j]) + for prods in prods_by_sort.values() + for i in range(len(prods)) + for j in range(i + 1, len(prods)) ) return res @@ -606,7 +615,10 @@ def key(production: KProduction) -> str: ( prod for prod in productions_for_sort - if prod.klabel and prod.klabel not in BUILTIN_LABELS and Atts.FUNCTION not in prod.att + if prod.klabel + and prod.klabel not in BUILTIN_LABELS + and Atts.FUNCTION not in prod.att + and Atts.MACRO not in prod.att ), key=key, ) @@ -702,6 +714,11 @@ def simplified_module(definition: KDefinition, module_name: str | None = None) - pipeline = ( FlattenDefinition(module_name), # sorts + DiscardSyntaxSortAtts( + [ + Atts.CELL_COLLECTION, + ], + ), AddSyntaxSorts(), AddCollectionAtts(), AddDomainValueAtts(), @@ -709,36 +726,48 @@ def simplified_module(definition: KDefinition, module_name: str | None = None) - PullUpRewrites(), DiscardSymbolAtts( [ - Atts.ASSOC, + Atts.AVOID, + Atts.CELL_COLLECTION, Atts.CELL_FRAGMENT, Atts.CELL_NAME, Atts.CELL_OPT_ABSENT, Atts.COLOR, Atts.COLORS, Atts.COMM, + Atts.EXIT, Atts.FORMAT, Atts.GROUP, - Atts.IMPURE, Atts.INDEX, Atts.INITIALIZER, Atts.LEFT, Atts.MAINCELL, + Atts.MULTIPLICITY, Atts.PREDICATE, Atts.PREFER, Atts.PRIVATE, Atts.PRODUCTION, Atts.PROJECTION, + Atts.RETURNS_UNIT, Atts.RIGHT, Atts.SEQSTRICT, Atts.STRICT, + Atts.TYPE, + Atts.TERMINATOR_SYMBOL, Atts.USER_LIST, + Atts.WRAP_ELEMENT, ], ), DiscardHookAtts(), - AddAnywhereAtts(), + AddImpureAtts(), AddSymbolAtts(Atts.MACRO(None), _is_macro), AddSymbolAtts(Atts.FUNCTIONAL(None), _is_functional), AddSymbolAtts(Atts.INJECTIVE(None), _is_injective), + AddAnywhereAttsFromRules(), + # Mark symbols that require constructor axioms with an internal attribute. + # Has to precede `AddAnywhereAttsFromOverloads`: symbols that would be considered constructors without + # the extra `anywhere` require a constructor axiom.
+ AddSymbolAtts(_INTERNAL_CONSTRUCTOR(None), _is_constructor), + AddAnywhereAttsFromOverloads(), AddSymbolAtts(Atts.CONSTRUCTOR(None), _is_constructor), ) definition = reduce(lambda defn, step: step.execute(defn), pipeline, definition) @@ -804,6 +833,22 @@ def _imported_sentences(definition: KDefinition, module_name: str) -> list[KSent return res +@dataclass +class DiscardSyntaxSortAtts(SingleModulePass): + """Remove certain attributes from syntax sorts.""" + + keys: frozenset[AttKey] + + def __init__(self, keys: Iterable[AttKey]): + self.keys = frozenset(keys) + + def _transform_module(self, module: KFlatModule) -> KFlatModule: + return module.map_sentences(self._update, of_type=KSyntaxSort) + + def _update(self, syntax_sort: KSyntaxSort) -> KSyntaxSort: + return syntax_sort.let(att=syntax_sort.att.discard(self.keys)) + + @dataclass class AddSyntaxSorts(SingleModulePass): """Return a definition with explicit syntax declarations: each sort is declared with the union of its attributes.""" @@ -816,31 +861,33 @@ def _transform_module(self, module: KFlatModule) -> KFlatModule: @staticmethod def _syntax_sorts(module: KFlatModule) -> list[KSyntaxSort]: """Return a declaration for each sort in the module.""" - declarations: dict[KSort, KAtt] = {} def is_higher_order(production: KProduction) -> bool: # Example: syntax {Sort} Sort ::= Sort "#as" Sort return production.sort in production.params + def merge_atts(atts: list[KAtt]) -> KAtt: + grouped: dict[AttKey, set[Any]] = {} + for att, value in chain.from_iterable(att.items() for att in atts): + grouped.setdefault(att, set()).add(value) + + entries = [att(next(iter(values))) for att, values in grouped.items() if len(values) == 1] + return KAtt(entries) + + declarations: dict[KSort, list[KAtt]] = {} + # Merge attributes from KSyntaxSort instances for syntax_sort in module.syntax_sorts: - sort = syntax_sort.sort - if sort not in declarations: - declarations[sort] = syntax_sort.att - else: - assert declarations[sort].keys().isdisjoint(syntax_sort.att) - declarations[sort] = declarations[sort].update(syntax_sort.att.entries()) + declarations.setdefault(syntax_sort.sort, []).append(syntax_sort.att) # Also consider production sorts for production in module.productions: if is_higher_order(production): continue - sort = production.sort - if sort not in declarations: - declarations[sort] = EMPTY_ATT + declarations.setdefault(production.sort, []) - return [KSyntaxSort(sort, att=att) for sort, att in declarations.items()] + return [KSyntaxSort(sort, att=merge_atts(atts)) for sort, atts in declarations.items()] @dataclass @@ -857,38 +904,36 @@ class AddCollectionAtts(SingleModulePass): ) def _transform_module(self, module: KFlatModule) -> KFlatModule: - # Example: syntax Map ::= Map Map [..., klabel(_Map_), element(_|->_), unit(.Map), ...] - concat_atts = {prod.sort: prod.att for prod in module.productions if Atts.ELEMENT in prod.att} + # Example: syntax Map ::= Map Map [..., element(_|->_), unit(.Map), ...] 
+ concat_prods = {prod.sort: prod for prod in module.productions if Atts.ELEMENT in prod.att} assert all( - Atts.UNIT in att for _, att in concat_atts.items() + Atts.UNIT in prod.att for _, prod in concat_prods.items() ) # TODO Could be saved with a different attribute structure: concat(Element, Unit) - return module.map_sentences(lambda syntax_sort: self._update(syntax_sort, concat_atts), of_type=KSyntaxSort) + return module.map_sentences(lambda syntax_sort: self._update(syntax_sort, concat_prods), of_type=KSyntaxSort) @staticmethod - def _update(syntax_sort: KSyntaxSort, concat_atts: Mapping[KSort, KAtt]) -> KSyntaxSort: + def _update(syntax_sort: KSyntaxSort, concat_prods: Mapping[KSort, KProduction]) -> KSyntaxSort: if syntax_sort.att.get(Atts.HOOK) not in AddCollectionAtts.COLLECTION_HOOKS: return syntax_sort - assert syntax_sort.sort in concat_atts - concat_att = concat_atts[syntax_sort.sort] + assert syntax_sort.sort in concat_prods + concat_prod = concat_prods[syntax_sort.sort] - # Workaround until zero-argument symbol is removed, rather than - # deprecated. - symbol = concat_att[Atts.SYMBOL] - assert symbol is not None + klabel = concat_prod.klabel + assert klabel is not None return syntax_sort.let( att=syntax_sort.att.update( [ # TODO Here, the attriubte is stored as dict, but ultimately we should parse known attributes in KAtt.from_dict - Atts.CONCAT(KApply(symbol).to_dict()), + Atts.CONCAT(KApply(klabel).to_dict()), # TODO Here, we keep the format from the frontend so that the attributes on SyntaxSort and Production are of the same type. - Atts.ELEMENT(concat_att[Atts.ELEMENT]), - Atts.UNIT(concat_att[Atts.UNIT]), + Atts.ELEMENT(concat_prod.att[Atts.ELEMENT]), + Atts.UNIT(concat_prod.att[Atts.UNIT]), ] - + ([Atts.UPDATE(concat_att[Atts.UPDATE])] if Atts.UPDATE in concat_att else []) + + ([Atts.UPDATE(concat_prod.att[Atts.UPDATE])] if Atts.UPDATE in concat_prod.att else []) ) ) @@ -940,13 +985,80 @@ def _transform_rule(self, rule: KRule) -> KRule: @dataclass -class AddAnywhereAtts(KompilerPass): - """Add the anywhere attribute to all symbol productions that are overloads or have a corresponding anywhere rule.""" +class AddImpureAtts(SingleModulePass): + """Add the `impure` attribute to all function symbol productions whose definition transitively contains `impure`.""" - def execute(self, definition: KDefinition) -> KDefinition: - if len(definition.modules) > 1: - raise ValueError('Expected a single module') - module = definition.modules[0] + def _transform_module(self, module: KFlatModule) -> KFlatModule: + impurities = AddImpureAtts._impurities(module) + + def update(production: KProduction) -> KProduction: + if not production.klabel: + return production + + klabel = production.klabel + + if klabel.name in impurities: + return production.let(att=production.att.update([Atts.IMPURE(None)])) + + return production + + module = module.map_sentences(update, of_type=KProduction) + return module + + @staticmethod + def _impurities(module: KFlatModule) -> set[str]: + callers = AddImpureAtts._callers(module) + + res: set[str] = set() + pending = [ + prod.klabel.name for prod in module.productions if prod.klabel is not None and Atts.IMPURE in prod.att + ] + while pending: + label = pending.pop() + if label in res: + continue + res.add(label) + pending.extend(callers.get(label, [])) + return res + + @staticmethod + def _callers(module: KFlatModule) -> dict[str, set[str]]: + function_labels = {prod.klabel.name for prod in module.productions if prod.klabel and Atts.FUNCTION in prod.att} + + 
res: dict[str, set[str]] = {} + for rule in module.rules: + assert isinstance(rule.body, KRewrite) + + match rule.body: + case KRewrite(KApply(KLabel(label)), rhs): + if label in function_labels: + rhs_labels = AddImpureAtts._labels(rhs) + for called in rhs_labels: + res.setdefault(called, set()).add(label) + case _: + pass + return res + + @staticmethod + def _labels(inner: KInner) -> set[str]: + res: set[str] = set() + + def add_label(inner: KInner) -> None: + match inner: + case KApply(KLabel(label)): + res.add(label) + case _: + pass + + collect(add_label, inner) + return res + + +@dataclass +class AddAnywhereAttsFromRules(SingleModulePass): + """Add the anywhere attribute to all symbol productions that have a corresponding anywhere rule.""" + + def _transform_module(self, module: KFlatModule) -> KFlatModule: rules = self._rules_by_klabel(module) def update(production: KProduction) -> KProduction: @@ -958,13 +1070,10 @@ def update(production: KProduction) -> KProduction: if any(Atts.ANYWHERE in rule.att for rule in rules.get(klabel, [])): return production.let(att=production.att.update([Atts.ANYWHERE(None)])) - if klabel.name in definition.overloads: - return production.let(att=production.att.update([Atts.ANYWHERE(None)])) - return production module = module.map_sentences(update, of_type=KProduction) - return KDefinition(module.name, (module,)) + return module @staticmethod def _rules_by_klabel(module: KFlatModule) -> dict[KLabel, list[KRule]]: @@ -983,6 +1092,30 @@ def _rules_by_klabel(module: KFlatModule) -> dict[KLabel, list[KRule]]: return res +@dataclass +class AddAnywhereAttsFromOverloads(KompilerPass): + """Add the anywhere attribute to all symbol productions that are overloads.""" + + def execute(self, definition: KDefinition) -> KDefinition: + if len(definition.modules) > 1: + raise ValueError('Expected a single module') + module = definition.modules[0] + + def update(production: KProduction) -> KProduction: + if not production.klabel: + return production + + klabel = production.klabel + + if klabel.name in definition.overloads: + return production.let(att=production.att.update([Atts.ANYWHERE(None)])) + + return production + + module = module.map_sentences(update, of_type=KProduction) + return KDefinition(module.name, (module,)) + + @dataclass class AddSymbolAtts(SingleModulePass): """Add attribute to symbol productions based on a predicate.""" @@ -1063,6 +1196,7 @@ class DiscardHookAtts(SingleModulePass): 'SET', 'STRING', 'SUBSTITUTION', + 'TIMER', 'UNIFICATION', ) diff --git a/pyk/src/pyk/kore/rule.py b/pyk/src/pyk/kore/rule.py index 8bc7fb965ec..9ac88616bfb 100644 --- a/pyk/src/pyk/kore/rule.py +++ b/pyk/src/pyk/kore/rule.py @@ -1,11 +1,12 @@ from __future__ import annotations import logging -from abc import ABC +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Generic, TypeVar, final +from functools import reduce +from typing import TYPE_CHECKING, Generic, TypeVar, cast, final -from .prelude import inj +from .prelude import BOOL, SORT_GENERATED_TOP_CELL, TRUE, inj from .syntax import ( DV, And, @@ -28,7 +29,7 @@ if TYPE_CHECKING: from typing import Final - from .syntax import Definition + from .syntax import Definition, Sort Attrs = dict[str, tuple[Pattern, ...]] @@ -68,8 +69,12 @@ class Rule(ABC): rhs: Pattern req: Pattern | None ens: Pattern | None + sort: Sort priority: int + @abstractmethod + def to_axiom(self) -> Axiom: ... 
+ @staticmethod def from_axiom(axiom: Axiom) -> Rule: if isinstance(axiom.pattern, Rewrites): @@ -89,22 +94,25 @@ def from_axiom(axiom: Axiom) -> Rule: raise ValueError(f'Cannot parse simplification rule: {axiom.text}') @staticmethod - def extract_all(defn: Definition) -> list[Rule]: - def is_rule(axiom: Axiom) -> bool: - if axiom == _INJ_AXIOM: - return False + def is_rule(axiom: Axiom) -> bool: + if axiom == _INJ_AXIOM: + return False - if any(attr in axiom.attrs_by_key for attr in _SKIPPED_ATTRS): - return False + if any(attr in axiom.attrs_by_key for attr in _SKIPPED_ATTRS): + return False - return True + return True - return [Rule.from_axiom(axiom) for axiom in defn.axioms if is_rule(axiom)] + @staticmethod + def extract_all(defn: Definition) -> list[Rule]: + return [Rule.from_axiom(axiom) for axiom in defn.axioms if Rule.is_rule(axiom)] @final @dataclass(frozen=True) class RewriteRule(Rule): + sort = SORT_GENERATED_TOP_CELL + lhs: App rhs: App req: Pattern | None @@ -114,6 +122,19 @@ class RewriteRule(Rule): uid: str label: str | None + def to_axiom(self) -> Axiom: + lhs = self.lhs if self.ctx is None else And(self.sort, (self.lhs, self.ctx)) + req = _to_ml_pred(self.req, self.sort) + ens = _to_ml_pred(self.ens, self.sort) + return Axiom( + (), + Rewrites( + self.sort, + And(self.sort, (lhs, req)), + And(self.sort, (self.rhs, ens)), + ), + ) + @staticmethod def from_axiom(axiom: Axiom) -> RewriteRule: lhs, rhs, req, ens, ctx = RewriteRule._extract(axiom) @@ -166,60 +187,125 @@ class FunctionRule(Rule): rhs: Pattern req: Pattern | None ens: Pattern | None + sort: Sort + arg_sorts: tuple[Sort, ...] + anti_left: Pattern | None priority: int + def to_axiom(self) -> Axiom: + R = SortVar('R') # noqa N806 + + def arg_list(rest: Pattern, arg_pair: tuple[EVar, Pattern]) -> Pattern: + var, arg = arg_pair + return And(R, (In(var.sort, R, var, arg), rest)) + + vars = tuple(EVar(f'X{i}', sort) for i, sort in enumerate(self.arg_sorts)) + + # \and{R}(\in{S1, R}(X1 : S1, Arg1), \and{R}(\in{S2, R}(X2 : S2, Arg2), \top{R}())) etc. 
+ _args = reduce( + arg_list, + reversed(tuple(zip(vars, self.lhs.args, strict=True))), + cast('Pattern', Top(R)), + ) + + _req = _to_ml_pred(self.req, R) + req = And(R, (_req, _args)) + if self.anti_left: + req = And(R, (Not(R, self.anti_left), req)) + + app = self.lhs.let(args=vars) + ens = _to_ml_pred(self.ens, self.sort) + + return Axiom( + (R,), + Implies( + R, + req, + Equals(self.sort, R, app, And(self.sort, (self.rhs, ens))), + ), + ) + @staticmethod def from_axiom(axiom: Axiom) -> FunctionRule: - lhs, rhs, req, ens = FunctionRule._extract(axiom) + anti_left: Pattern | None = None + match axiom.pattern: + case Implies( + left=And(ops=(Not(pattern=anti_left), And(ops=(_req, _args)))), + right=Equals(op_sort=sort, left=App() as app, right=_rhs), + ): + pass + case Implies( + left=And(ops=(_req, _args)), + right=Equals(op_sort=sort, left=App() as app, right=_rhs), + ): + pass + case _: + raise ValueError(f'Cannot extract function rule from axiom: {axiom.text}') + + arg_sorts, args = FunctionRule._extract_args(_args) + lhs = app.let(args=args) + req = _extract_condition(_req) + rhs, ens = _extract_rhs(_rhs) + priority = _extract_priority(axiom) return FunctionRule( lhs=lhs, rhs=rhs, req=req, ens=ens, + sort=sort, + arg_sorts=arg_sorts, + anti_left=anti_left, priority=priority, ) @staticmethod - def _extract(axiom: Axiom) -> tuple[App, Pattern, Pattern | None, Pattern | None]: - match axiom.pattern: - case Implies( - left=And( - ops=(Not(), And(ops=(_req, _args))) | (_req, _args), - ), - right=Equals(left=App() as app, right=_rhs), - ): - args = FunctionRule._extract_args(_args) - lhs = app.let(args=args) - req = _extract_condition(_req) - rhs, ens = _extract_rhs(_rhs) - return lhs, rhs, req, ens - case _: - raise ValueError(f'Cannot extract function rule from axiom: {axiom.text}') - - @staticmethod - def _extract_args(pattern: Pattern) -> tuple[Pattern, ...]: + def _extract_args(pattern: Pattern) -> tuple[tuple[Sort, ...], tuple[Pattern, ...]]: match pattern: case Top(): - return () - case And(ops=(In(left=EVar(), right=arg), rest)): - return (arg,) + FunctionRule._extract_args(rest) + return (), () + case And(ops=(In(left=EVar(sort=sort), right=arg), rest)): + sorts, args = FunctionRule._extract_args(rest) + return (sort,) + sorts, (arg,) + args case _: raise ValueError(f'Cannot extract argument list from pattern: {pattern.text}') class SimpliRule(Rule, Generic[P], ABC): lhs: P + sort: Sort + + def to_axiom(self) -> Axiom: + R = SortVar('R') # noqa N806 + + vars = (R, self.sort) if isinstance(self.sort, SortVar) else (R,) + req = _to_ml_pred(self.req, R) + ens = _to_ml_pred(self.ens, self.sort) + + return Axiom( + vars, + Implies( + R, + req, + Equals(self.sort, R, self.lhs, And(self.sort, (self.rhs, ens))), + ), + attrs=( + App( + 'simplification', + args=() if self.priority == 50 else (String(str(self.priority)),), + ), + ), + ) @staticmethod - def _extract(axiom: Axiom, lhs_type: type[P]) -> tuple[P, Pattern, Pattern | None, Pattern | None]: + def _extract(axiom: Axiom, lhs_type: type[P]) -> tuple[P, Pattern, Pattern | None, Pattern | None, Sort]: match axiom.pattern: - case Implies(left=_req, right=Equals(left=lhs, right=_rhs)): + case Implies(left=_req, right=Equals(op_sort=sort, left=lhs, right=_rhs)): req = _extract_condition(_req) rhs, ens = _extract_rhs(_rhs) if not isinstance(lhs, lhs_type): raise ValueError(f'Invalid LHS type from simplification axiom: {axiom.text}') - return lhs, rhs, req, ens + return lhs, rhs, req, ens, sort case _: raise ValueError(f'Cannot extract 
simplification rule from axiom: {axiom.text}') @@ -231,63 +317,67 @@ class AppRule(SimpliRule[App]): rhs: Pattern req: Pattern | None ens: Pattern | None + sort: Sort priority: int @staticmethod def from_axiom(axiom: Axiom) -> AppRule: - lhs, rhs, req, ens = SimpliRule._extract(axiom, App) - priority = _extract_priority(axiom) + lhs, rhs, req, ens, sort = SimpliRule._extract(axiom, App) + priority = _extract_simpl_priority(axiom) return AppRule( lhs=lhs, rhs=rhs, req=req, ens=ens, + sort=sort, priority=priority, ) @final @dataclass(frozen=True) -class CeilRule(SimpliRule): +class CeilRule(SimpliRule[Ceil]): lhs: Ceil rhs: Pattern req: Pattern | None ens: Pattern | None + sort: Sort priority: int @staticmethod def from_axiom(axiom: Axiom) -> CeilRule: - lhs, rhs, req, ens = SimpliRule._extract(axiom, Ceil) - priority = _extract_priority(axiom) + lhs, rhs, req, ens, sort = SimpliRule._extract(axiom, Ceil) + priority = _extract_simpl_priority(axiom) return CeilRule( lhs=lhs, rhs=rhs, req=req, ens=ens, + sort=sort, priority=priority, ) @final @dataclass(frozen=True) -class EqualsRule(SimpliRule): +class EqualsRule(SimpliRule[Equals]): lhs: Equals rhs: Pattern req: Pattern | None ens: Pattern | None + sort: Sort priority: int @staticmethod def from_axiom(axiom: Axiom) -> EqualsRule: - lhs, rhs, req, ens = SimpliRule._extract(axiom, Equals) - if not isinstance(lhs, Equals): - raise ValueError(f'Cannot extract LHS as Equals from axiom: {axiom.text}') - priority = _extract_priority(axiom) + lhs, rhs, req, ens, sort = SimpliRule._extract(axiom, Equals) + priority = _extract_simpl_priority(axiom) return EqualsRule( lhs=lhs, rhs=rhs, req=req, ens=ens, + sort=sort, priority=priority, ) @@ -340,3 +430,21 @@ def _extract_priority(axiom: Axiom) -> int: return 200 if 'owise' in attrs else 50 case _: raise ValueError(f'Cannot extract priority from axiom: {axiom.text}') + + +def _extract_simpl_priority(axiom: Axiom) -> int: + attrs = axiom.attrs_by_key + match attrs['simplification']: + case App(args=() | (String(''),)): + return 50 + case App(args=(String(p),)): + return int(p) + case _: + raise ValueError(f'Cannot extract simplification priority from axiom: {axiom.text}') + + +def _to_ml_pred(pattern: Pattern | None, sort: Sort) -> Pattern: + if pattern is None: + return Top(sort) + + return Equals(BOOL, sort, pattern, TRUE) diff --git a/pyk/src/pyk/proof/reachability.py b/pyk/src/pyk/proof/reachability.py index d7e706ca939..1ce8cd56582 100644 --- a/pyk/src/pyk/proof/reachability.py +++ b/pyk/src/pyk/proof/reachability.py @@ -40,6 +40,7 @@ class APRProofResult: node_id: int prior_loops_cache_update: tuple[int, ...] 
+ optimize_kcfg: bool @dataclass @@ -220,6 +221,7 @@ def commit(self, result: APRProofResult) -> None: assert result.cached_node_id in self._next_steps self.kcfg.extend( extend_result=self._next_steps.pop(result.cached_node_id), + optimize_kcfg=result.optimize_kcfg, node=self.kcfg.node(result.node_id), logs=self.logs, ) @@ -230,6 +232,7 @@ def commit(self, result: APRProofResult) -> None: self._next_steps[result.node_id] = result.extension_to_cache self.kcfg.extend( extend_result=result.extension_to_apply, + optimize_kcfg=result.optimize_kcfg, node=self.kcfg.node(result.node_id), logs=self.logs, ) @@ -715,6 +718,7 @@ class APRProver(Prover[APRProof, APRProofStep, APRProofResult]): assume_defined: bool kcfg_explore: KCFGExplore extra_module: KFlatModule | None + optimize_kcfg: bool def __init__( self, @@ -727,6 +731,7 @@ def __init__( direct_subproof_rules: bool = False, assume_defined: bool = False, extra_module: KFlatModule | None = None, + optimize_kcfg: bool = False, ) -> None: self.kcfg_explore = kcfg_explore @@ -739,6 +744,7 @@ def __init__( self.direct_subproof_rules = direct_subproof_rules self.assume_defined = assume_defined self.extra_module = extra_module + self.optimize_kcfg = optimize_kcfg def close(self) -> None: self.kcfg_explore.cterm_symbolic._kore_client.close() @@ -808,14 +814,24 @@ def step_proof(self, step: APRProofStep) -> list[APRProofResult]: _LOGGER.info(f'Prior loop heads for node {step.node.id}: {(step.node.id, prior_loops)}') if len(prior_loops) > step.bmc_depth: _LOGGER.warning(f'Bounded node {step.proof_id}: {step.node.id} at bmc depth {step.bmc_depth}') - return [APRProofBoundedResult(node_id=step.node.id, prior_loops_cache_update=prior_loops)] + return [ + APRProofBoundedResult( + node_id=step.node.id, optimize_kcfg=self.optimize_kcfg, prior_loops_cache_update=prior_loops + ) + ] # Check if the current node and target are terminal is_terminal = self.kcfg_explore.kcfg_semantics.is_terminal(step.node.cterm) target_is_terminal = self.kcfg_explore.kcfg_semantics.is_terminal(step.target.cterm) terminal_result: list[APRProofResult] = ( - [APRProofTerminalResult(node_id=step.node.id, prior_loops_cache_update=prior_loops)] if is_terminal else [] + [ + APRProofTerminalResult( + node_id=step.node.id, optimize_kcfg=self.optimize_kcfg, prior_loops_cache_update=prior_loops + ) + ] + if is_terminal + else [] ) # Subsumption is checked if and only if the target node @@ -826,7 +842,12 @@ def step_proof(self, step: APRProofStep) -> list[APRProofResult]: # Information about the subsumed node being terminal must be returned # so that the set of terminal nodes is correctly updated return terminal_result + [ - APRProofSubsumeResult(csubst=csubst, node_id=step.node.id, prior_loops_cache_update=prior_loops) + APRProofSubsumeResult( + csubst=csubst, + optimize_kcfg=self.optimize_kcfg, + node_id=step.node.id, + prior_loops_cache_update=prior_loops, + ) ] if is_terminal: @@ -849,6 +870,7 @@ def step_proof(self, step: APRProofStep) -> list[APRProofResult]: APRProofUseCacheResult( node_id=step.node.id, cached_node_id=step.use_cache, + optimize_kcfg=self.optimize_kcfg, prior_loops_cache_update=prior_loops, ) ] @@ -876,6 +898,7 @@ def step_proof(self, step: APRProofStep) -> list[APRProofResult]: extension_to_apply=extend_results[0], extension_to_cache=extend_results[1], prior_loops_cache_update=prior_loops, + optimize_kcfg=self.optimize_kcfg, ) ] @@ -885,6 +908,7 @@ def step_proof(self, step: APRProofStep) -> list[APRProofResult]: node_id=step.node.id, 
extension_to_apply=extend_results[0], prior_loops_cache_update=prior_loops, + optimize_kcfg=self.optimize_kcfg, ) ] diff --git a/pyk/src/tests/integration/kllvm/test_prooftrace.py b/pyk/src/tests/integration/kllvm/test_prooftrace.py index 8ef5cb94956..ee54383fd7a 100644 --- a/pyk/src/tests/integration/kllvm/test_prooftrace.py +++ b/pyk/src/tests/integration/kllvm/test_prooftrace.py @@ -103,7 +103,7 @@ def test_streaming_parser_iter(self, header: prooftrace.KoreHeader, hints_file: list_of_events = list(it) # Test length of the list - assert len(list_of_events) == 13 + assert len(list_of_events) == 17 # Test the type of the events for event in list_of_events: @@ -190,7 +190,7 @@ def test_parse_proof_hint_single_rewrite( assert pt is not None # 10 initialization events - assert len(pt.pre_trace) == 10 + assert len(pt.pre_trace) == 14 # 2 post-initial-configuration events assert len(pt.trace) == 2 @@ -260,10 +260,10 @@ def test_parse_proof_hint_reverse_no_ints( assert pt is not None # 10 initialization events - assert len(pt.pre_trace) == 10 + assert len(pt.pre_trace) == 14 # 9 post-initial-configuration events - assert len(pt.trace) == 9 + assert len(pt.trace) == 12 # Contents of the k cell in the initial configuration kore_pattern = llvm_to_pattern(pt.initial_config.kore_pattern) @@ -321,8 +321,14 @@ def test_parse_proof_hint_reverse_no_ints( assert axiom == axiom_expected assert len(rule_event.substitution) == 1 + # Function exit event (no tail) + function_exit_event = pt.trace[6].step_event + assert isinstance(function_exit_event, prooftrace.LLVMFunctionExitEvent) + assert function_exit_event.rule_ordinal == 158 + assert function_exit_event.is_tail == False + # Function event - rule_event = pt.trace[6].step_event + rule_event = pt.trace[7].step_event assert isinstance(rule_event, prooftrace.LLVMFunctionEvent) assert rule_event.name == "Lblreverse'LParUndsRParUnds'TREE-REVERSE-SYNTAX'Unds'Tree'Unds'Tree{}" assert rule_event.relative_position == '1' @@ -330,16 +336,28 @@ def test_parse_proof_hint_reverse_no_ints( assert len(rule_event.args) == 0 # Simplification rule - rule_event = pt.trace[7].step_event + rule_event = pt.trace[8].step_event assert isinstance(rule_event, prooftrace.LLVMRuleEvent) axiom = repr(definition.get_axiom_by_ordinal(rule_event.rule_ordinal)) axiom_expected = get_pattern_from_ordinal(definition_text, rule_event.rule_ordinal) assert axiom == axiom_expected assert len(rule_event.substitution) == 1 + # Function exit event (no tail) + function_exit_event = pt.trace[9].step_event + assert isinstance(function_exit_event, prooftrace.LLVMFunctionExitEvent) + assert function_exit_event.rule_ordinal == 157 + assert function_exit_event.is_tail == False + + # Function exit event (no tail) + function_exit_event = pt.trace[10].step_event + assert isinstance(function_exit_event, prooftrace.LLVMFunctionExitEvent) + assert function_exit_event.rule_ordinal == 160 + assert function_exit_event.is_tail == False + # Then pattern - assert pt.trace[8].is_kore_pattern() - kore_pattern = llvm_to_pattern(pt.trace[8].kore_pattern) + assert pt.trace[11].is_kore_pattern() + kore_pattern = llvm_to_pattern(pt.trace[11].kore_pattern) k_cell = kore_pattern.patterns[0].dict['args'][0] assert k_cell['name'] == 'kseq' assert ( @@ -385,10 +403,10 @@ def test_parse_proof_hint_non_rec_function( assert pt is not None # 10 initialization events - assert len(pt.pre_trace) == 10 + assert len(pt.pre_trace) == 14 # 6 post-initial-configuration events - assert len(pt.trace) == 6 + assert len(pt.trace) == 8 # 
Contents of the k cell in the initial configuration kore_pattern = llvm_to_pattern(pt.initial_config.kore_pattern) @@ -415,24 +433,36 @@ def test_parse_proof_hint_non_rec_function( inner_rule_event = pt.trace[2].step_event assert isinstance(inner_rule_event, prooftrace.LLVMRuleEvent) + # Function exit event (no tail) + function_exit_event = pt.trace[3].step_event + assert isinstance(function_exit_event, prooftrace.LLVMFunctionExitEvent) + assert function_exit_event.rule_ordinal == 103 + assert function_exit_event.is_tail == False + # Functional event - fun_event = pt.trace[3].step_event + fun_event = pt.trace[4].step_event assert isinstance(fun_event, prooftrace.LLVMFunctionEvent) assert fun_event.name == "Lblid'LParUndsRParUnds'NON-REC-FUNCTION-SYNTAX'Unds'Foo'Unds'Foo{}" assert fun_event.relative_position == '0:0:0' assert len(fun_event.args) == 0 # Then rule - rule_event = pt.trace[4].step_event + rule_event = pt.trace[5].step_event assert isinstance(rule_event, prooftrace.LLVMRuleEvent) axiom = repr(definition.get_axiom_by_ordinal(rule_event.rule_ordinal)) axiom_expected = get_pattern_from_ordinal(definition_text, rule_event.rule_ordinal) assert axiom == axiom_expected assert len(rule_event.substitution) == 1 + # Function exit event (no tail) + function_exit_event = pt.trace[6].step_event + assert isinstance(function_exit_event, prooftrace.LLVMFunctionExitEvent) + assert function_exit_event.rule_ordinal == 103 + assert function_exit_event.is_tail == False + # Then pattern - assert pt.trace[5].is_kore_pattern() - kore_pattern = llvm_to_pattern(pt.trace[5].kore_pattern) + assert pt.trace[7].is_kore_pattern() + kore_pattern = llvm_to_pattern(pt.trace[7].kore_pattern) k_cell = kore_pattern.patterns[0].dict['args'][0] assert k_cell['name'] == 'kseq' assert ( @@ -468,7 +498,7 @@ def test_parse_proof_hint_dv(self, hints: bytes, header: prooftrace.KoreHeader, assert pt is not None # 10 initialization events - assert len(pt.pre_trace) == 10 + assert len(pt.pre_trace) == 14 # 3 post-initial-configuration events assert len(pt.trace) == 3 @@ -549,7 +579,7 @@ def test_parse_concurrent_counters(self, hints: bytes, header: prooftrace.KoreHe assert pt is not None # 10 initialization events - assert len(pt.pre_trace) == 10 + assert len(pt.pre_trace) == 14 # 37 post-initial-configuration events assert len(pt.trace) == 37 @@ -709,7 +739,7 @@ def test_parse_proof_hint_0_decrement(self, hints: bytes, header: prooftrace.Kor assert pt is not None # 10 initialization events - assert len(pt.pre_trace) == 10 + assert len(pt.pre_trace) == 14 # 1 post-initial-configuration event assert len(pt.trace) == 1 @@ -738,7 +768,7 @@ def test_parse_proof_hint_1_decrement(self, hints: bytes, header: prooftrace.Kor assert pt is not None # 10 initialization events - assert len(pt.pre_trace) == 10 + assert len(pt.pre_trace) == 14 # 2 post-initial-configuration events assert len(pt.trace) == 2 @@ -767,7 +797,7 @@ def test_parse_proof_hint_2_decrement(self, hints: bytes, header: prooftrace.Kor assert pt is not None # 10 initialization events - assert len(pt.pre_trace) == 10 + assert len(pt.pre_trace) == 14 # 3 post-initial-configuration events assert len(pt.trace) == 3 @@ -806,14 +836,14 @@ def test_parse_proof_hint_peano(self, hints: bytes, header: prooftrace.KoreHeade assert pt is not None # 10 initialization events - assert len(pt.pre_trace) == 10 + assert len(pt.pre_trace) == 14 # 776 post-initial-configuration events - assert len(pt.trace) == 776 + assert len(pt.trace) == 916 # Assert that we have a pattern matching failure 
as the 135th event - assert pt.trace[135].is_step_event() and isinstance( - pt.trace[135].step_event, prooftrace.LLVMPatternMatchingFailureEvent + assert pt.trace[160].is_step_event() and isinstance( + pt.trace[160].step_event, prooftrace.LLVMPatternMatchingFailureEvent ) @@ -915,7 +945,7 @@ def test_parse_proof_hint_imp5(self, hints: bytes, header: prooftrace.KoreHeader assert pt is not None # 14 initialization events - assert len(pt.pre_trace) == 14 + assert len(pt.pre_trace) == 20 # 2 post-initial-configuration events assert len(pt.trace) == 2 @@ -954,7 +984,7 @@ def test_parse_proof_hint_builtin_hook_events( assert pt is not None # 10 initialization events - assert len(pt.pre_trace) == 10 + assert len(pt.pre_trace) == 14 # 4 post-initial-configuration events assert len(pt.trace) == 4 diff --git a/pyk/src/tests/integration/konvert/test_module_to_kore.py b/pyk/src/tests/integration/konvert/test_module_to_kore.py index deb4790aa21..70484c1f86b 100644 --- a/pyk/src/tests/integration/konvert/test_module_to_kore.py +++ b/pyk/src/tests/integration/konvert/test_module_to_kore.py @@ -7,7 +7,8 @@ from pyk.kast.outer import read_kast_definition from pyk.konvert import module_to_kore from pyk.kore.parser import KoreParser -from pyk.kore.syntax import SortDecl, Symbol, SymbolDecl +from pyk.kore.rule import Rule +from pyk.kore.syntax import Axiom, SortDecl, Symbol, SymbolDecl from pyk.ktool.kompile import DefinitionInfo from ..utils import TEST_DATA_DIR @@ -70,8 +71,8 @@ def test_module_to_kore(kast_defn: KDefinition, kore_module: Module) -> None: # Then # Check module name and attributes - assert actual.name == expected.name - assert actual.attrs == expected.attrs + assert expected.name == actual.name + assert expected.attrs == actual.attrs check_generated_sentences(actual, expected) check_missing_sentences(actual, expected) @@ -120,15 +121,15 @@ def find_expected_sentence(sentence: Sentence, expected: Module) -> Sentence | N pytest.fail(f'Invalid sentence: {sent.text}') # Fail with diff - assert sent.text == expected_sent.text + assert expected_sent.text == sent.text def check_missing_sentences(actual: Module, expected: Module) -> None: actual_sentences = set(actual.sentences) for sent in expected.sentences: # TODO remove - # Filter for SortDecl and SymbolDecl for now - if not isinstance(sent, (SortDecl, SymbolDecl)): + # Do not check rule axioms for now + if isinstance(sent, Axiom) and Rule.is_rule(sent): continue if sent not in actual_sentences: pytest.fail(f'Missing sentence: {sent.text}') diff --git a/pyk/src/tests/integration/kore/test_rule.py b/pyk/src/tests/integration/kore/test_rule.py index fadb530548d..27fd56b91e7 100644 --- a/pyk/src/tests/integration/kore/test_rule.py +++ b/pyk/src/tests/integration/kore/test_rule.py @@ -7,11 +7,12 @@ from pyk.kore.parser import KoreParser from pyk.kore.rule import Rule +from pyk.kore.syntax import App, String from ..utils import K_FILES if TYPE_CHECKING: - from pyk.kore.syntax import Definition + from pyk.kore.syntax import Axiom, Definition from pyk.testing import Kompiler @@ -36,3 +37,28 @@ def test_extract_all(definition: Definition) -> None: assert cnt['AppRule'] assert cnt['CeilRule'] assert cnt['EqualsRule'] + + +def test_to_axiom(definition: Definition) -> None: + def adjust_atts(axiom: Axiom) -> Axiom: + match axiom.attrs_by_key.get('simplification'): + case None: + return axiom.let(attrs=()) + case App(args=(String('' | '50'),)): + return axiom.let(attrs=(App('simplification'),)) + case attr: + return axiom.let(attrs=(attr,)) + + for axiom in 
definition.axioms: + if not Rule.is_rule(axiom): + continue + + # Given + expected = adjust_atts(axiom) + + # When + rule = Rule.from_axiom(axiom) + actual = rule.to_axiom() + + # Then + assert expected == actual diff --git a/pyk/src/tests/integration/proof/test_imp.py b/pyk/src/tests/integration/proof/test_imp.py index 3ad17c599c9..517051b001b 100644 --- a/pyk/src/tests/integration/proof/test_imp.py +++ b/pyk/src/tests/integration/proof/test_imp.py @@ -566,6 +566,35 @@ def same_loop(self, c1: CTerm, c2: CTerm) -> bool: ), ) +APR_PROVE_WITH_KCFG_OPTIMS_TEST_DATA: Iterable[ + tuple[str, Path, str, str, int | None, int | None, Iterable[str], bool, ProofStatus, int] +] = ( + ( + 'imp-simple-sum-100', + K_FILES / 'imp-simple-spec.k', + 'IMP-SIMPLE-SPEC', + 'sum-100', + None, + None, + [], + True, + ProofStatus.PASSED, + 3, + ), + ( + 'imp-simple-long-branches', + K_FILES / 'imp-simple-spec.k', + 'IMP-SIMPLE-SPEC', + 'long-branches', + None, + None, + [], + True, + ProofStatus.PASSED, + 7, + ), +) + PATH_CONSTRAINTS_TEST_DATA: Iterable[ tuple[str, Path, str, str, int | None, int | None, Iterable[str], Iterable[str], str] ] = ( @@ -918,6 +947,55 @@ def test_all_path_reachability_prove( assert proof.status == proof_status assert leaf_number(proof) == expected_leaf_number + @pytest.mark.parametrize( + 'test_id,spec_file,spec_module,claim_id,max_iterations,max_depth,cut_rules,admit_deps,proof_status,expected_nodes', + APR_PROVE_WITH_KCFG_OPTIMS_TEST_DATA, + ids=[test_id for test_id, *_ in APR_PROVE_WITH_KCFG_OPTIMS_TEST_DATA], + ) + def test_all_path_reachability_prove_with_kcfg_optims( + self, + kprove: KProve, + kcfg_explore: KCFGExplore, + test_id: str, + spec_file: str, + spec_module: str, + claim_id: str, + max_iterations: int | None, + max_depth: int | None, + cut_rules: Iterable[str], + admit_deps: bool, + proof_status: ProofStatus, + expected_nodes: int, + tmp_path_factory: TempPathFactory, + ) -> None: + proof_dir = tmp_path_factory.mktemp(f'apr_tmp_proofs-{test_id}') + spec_modules = kprove.parse_modules(Path(spec_file), module_name=spec_module) + spec_label = f'{spec_module}.{claim_id}' + proofs = APRProof.from_spec_modules( + kprove.definition, + spec_modules, + spec_labels=[spec_label], + logs={}, + proof_dir=proof_dir, + ) + proof = single([p for p in proofs if p.id == spec_label]) + if admit_deps: + for subproof in proof.subproofs: + subproof.admit() + subproof.write_proof_data() + + prover = APRProver( + kcfg_explore=kcfg_explore, execute_depth=max_depth, cut_point_rules=cut_rules, optimize_kcfg=True + ) + prover.advance_proof(proof, max_iterations=max_iterations) + + kcfg_show = KCFGShow(kprove, node_printer=APRProofNodePrinter(proof, kprove, full_printer=True)) + cfg_lines = kcfg_show.show(proof.kcfg) + _LOGGER.info('\n'.join(cfg_lines)) + + assert proof.status == proof_status + assert len(proof.kcfg._nodes) == expected_nodes + def test_terminal_node_subsumption( self, kprove: KProve, diff --git a/pyk/src/tests/integration/test-data/module-to-kore/impure.k b/pyk/src/tests/integration/test-data/module-to-kore/impure.k new file mode 100644 index 00000000000..01bf9961146 --- /dev/null +++ b/pyk/src/tests/integration/test-data/module-to-kore/impure.k @@ -0,0 +1,14 @@ +module IMPURE-SYNTAX + syntax Foo ::= "foo" [token] + | bar() [function, impure] + | baz() [function] + | qux() [function] +endmodule + +module IMPURE + imports IMPURE-SYNTAX + + rule bar() => foo + rule baz() => bar() + rule qux() => foo +endmodule diff --git 
a/pyk/src/tests/integration/test-data/module-to-kore/syntax-sort.k b/pyk/src/tests/integration/test-data/module-to-kore/syntax-sort.k new file mode 100644 index 00000000000..d7837bfbb4a --- /dev/null +++ b/pyk/src/tests/integration/test-data/module-to-kore/syntax-sort.k @@ -0,0 +1,18 @@ +module FOO + syntax Foo ::= "foo" +endmodule + +module BAR + syntax Foo + syntax Foo +endmodule + +module SYNTAX-SORT-SYNTAX + syntax Foo +endmodule + +module SYNTAX-SORT + imports SYNTAX-SORT-SYNTAX + imports FOO + imports BAR +endmodule diff --git a/pyk/src/tests/integration/test_krun_proof_hints.py b/pyk/src/tests/integration/test_krun_proof_hints.py index 008f6fd64f9..fa9ca09b0b7 100644 --- a/pyk/src/tests/integration/test_krun_proof_hints.py +++ b/pyk/src/tests/integration/test_krun_proof_hints.py @@ -48,8 +48,12 @@ class Test0Decrement(KRunTest, ProofTraceTest): function: Lblproject'Coln'KItem{} (0:0) rule: 139 1 VarK = kore[Lbl0'Unds'DECREMENT-SYNTAX'Unds'Nat{}()] +function exit: 139 notail +function exit: 100 notail function: LblinitGeneratedCounterCell{} (1) rule: 98 0 +function exit: 98 notail +function exit: 99 notail config: kore[Lbl'-LT-'generatedTop'-GT-'{}(Lbl'-LT-'k'-GT-'{}(kseq{}(Lbl0'Unds'DECREMENT-SYNTAX'Unds'Nat{}(),dotk{}())),Lbl'-LT-'generatedCounter'-GT-'{}(\\dv{SortInt{}}("0")))] config: kore[Lbl'-LT-'generatedTop'-GT-'{}(Lbl'-LT-'k'-GT-'{}(kseq{}(Lbl0'Unds'DECREMENT-SYNTAX'Unds'Nat{}(),dotk{}())),Lbl'-LT-'generatedCounter'-GT-'{}(\\dv{SortInt{}}("0")))] """ @@ -68,7 +72,7 @@ def test_parse_proof_hint_0_decrement(self, krun: KRun, header: prooftrace.KoreH assert pt is not None # 10 initialization events - assert len(pt.pre_trace) == 10 + assert len(pt.pre_trace) == 14 # 1 post-initial-configuration event assert len(pt.trace) == 1