From fe2916e665beb6afc8cd5fb1c287bb06cc5ecc5d Mon Sep 17 00:00:00 2001 From: Scott Guest Date: Tue, 5 Dec 2023 20:31:50 -0500 Subject: [PATCH] Implement new sort inference algorithm (#3673) Closes #3601. This PR introduces a new sort inference algorithm, aiming to eventually replace the Z3-based approach and pave the way for parametric rules and sorts. The design is heavily inspired by [The Simple Essence of Algebraic Subtyping: Principal Type Inference with Subtyping Made Easy](https://infoscience.epfl.ch/record/278576) by Lionel Parreaux. A high-level description explaining the relevant background can be found in `docs/developers/sort_inference.md`. The current PR implements a limited form of the proposed algorithm, still falling back to the Z3-based algorithm for any terms containing: - ambiguities - strict casts - parametric sorts For reviewers, I would begin by reading - `docs/developers/sort_inference.md` for a high-level conceptual understanding - the comment at the top of `SortInferencer.java` briefly explaining how the high-level design is actually implemented with our data structures - the paper if anything is unclear --- docs/developers/sort_inference.md | 216 +++++++ .../include/kframework/ktest-fail.mak | 2 +- k-distribution/include/kframework/ktest.mak | 2 +- .../regression-new/checks/wideningMacro.k.out | 2 +- .../kframework/kompile/DefinitionParsing.java | 26 +- .../kframework/kompile/KompileOptions.java | 20 +- .../parser/inner/ParseInModule.java | 57 +- .../parser/inner/RuleGrammarGenerator.java | 70 ++- .../disambiguation/inference/BoundedSort.java | 28 + .../disambiguation/inference/CompactSort.java | 125 ++++ .../inference/InferenceDriver.java | 136 ++++ .../disambiguation/inference/ParamId.java | 33 + .../inference/SortInferenceError.java | 84 +++ .../inference/SortInferencer.java | 594 ++++++++++++++++++ .../inference/SortVariable.java | 14 + .../disambiguation/inference/TermSort.java | 38 ++ .../disambiguation/inference/VariableId.java | 42 
++ .../utils/errorsystem/KEMException.java | 7 +- .../scala/org/kframework/builtin/Sorts.scala | 1 + 19 files changed, 1465 insertions(+), 32 deletions(-) create mode 100644 docs/developers/sort_inference.md create mode 100644 kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/BoundedSort.java create mode 100644 kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/CompactSort.java create mode 100644 kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/InferenceDriver.java create mode 100644 kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/ParamId.java create mode 100644 kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortInferenceError.java create mode 100644 kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortInferencer.java create mode 100644 kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortVariable.java create mode 100644 kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/TermSort.java create mode 100644 kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/VariableId.java diff --git a/docs/developers/sort_inference.md b/docs/developers/sort_inference.md new file mode 100644 index 00000000000..b8539a39fa4 --- /dev/null +++ b/docs/developers/sort_inference.md @@ -0,0 +1,216 @@ +--- +copyright: Copyright (c) Runtime Verification, Inc. All Rights Reserved. +--- + +# Design +The overall design of the new sort inference algorithm (`SortInferencer.java`) is based on the paper [The Simple Essence of Algebraic Subtyping: Principal Type Inference with Subtyping Made Easy](https://infoscience.epfl.ch/record/278576) by Lionel Parreaux. We summarize the relevant parts below, but it's a short and enlightening paper that's worth reading if you want a deeper understanding. 
+ +## SimpleSub Summary +The SimpleSub paper describes a type inference algorithm for a lambda calculus with sub-typing. The type system is akin in expressiveness to Java generics, i.e., permitting type variables `𝛼` with bounds `L <: 𝛼 <: U` (`super` and `extends` in Java). + +Notably though, it captures this expressiveness with a very simple type semantics, enabling inferred subtyping constraints to be efficiently reduced to constraints on type variables (e.g. by mutating a type variable's sub-/super-type bounds throughout the inference process). Additionally, the results are expressed using "set-theoretic" types, which allow the type constraints to be recorded in an indirect, compact form while also making certain type simplifications more obvious. + +The inferred types have the following syntax +``` +𝜏 ::= primitive // built-ins + | 𝜏 → 𝜏 // functions + | { 𝑙0 : 𝜏 ; ...; 𝑙𝑛 : 𝜏 } // structurally-typed records (won't be relevant for us) + | 𝛼 // type variables + | ⊤ // top type which is a supertype of all others + | ⊥ // bottom type which is a subtype of all others + | 𝜏 ⊔ 𝜏 // type unions / joins + | 𝜏 ⊓ 𝜏 // type intersections / meets + | 𝜇𝛼.𝜏 // recursive types (won't be relevant for us) +``` +which is additionally subject to a *polarity* restriction.
Informally, for a type `𝜏` which is a syntactic subcomponent of some other type `T`, the polarity of `𝜏` is +- *negative*, if `𝜏` describes a value given as an input to a term with type `T` +- *positive*, if `𝜏` describes a value produced by a term with type `T` + +As a concrete example, in a type like `(𝜏0 → 𝜏1) → 𝜏2`, +- `𝜏2` is positive, as the term produces the value of type `𝜏2` as an output +- `𝜏0 → 𝜏1` is negative, as the value of this type is given as an input to the term +- `𝜏1` is negative, as the value of this type is also given as an input to the term (indirectly via the input function `𝜏0 → 𝜏1`) +- `𝜏0` is positive, as the term itself must produce this value in order to call the input function `𝜏0 → 𝜏1` + +More formally, we define the type as a whole to be positive, and say +- `𝜏` is negative if either + - it occurs as the left part of an arrow `𝜏 → 𝜏'` where `𝜏 → 𝜏'` is itself positive, or + - it occurs as the right part of an arrow `𝜏' → 𝜏` where `𝜏' → 𝜏` is itself negative +- `𝜏` is positive otherwise + +The polarity restriction on our type syntax then requires that +- type intersections `𝜏 ⊓ 𝜏` may only occur in negative position +- type unions `𝜏 ⊔ 𝜏` may only occur in positive position + +To understand the motivation for this restriction, consider the subtyping constraints in a program and observe that +- if a type `𝜏` is negative, then it corresponds to an upper bound - the term requires a `𝜏` as input, and therefore can accept any sub-type of `𝜏` +- if a type `𝜏` is positive, then it corresponds to a lower bound - the term produces a `𝜏` as output, which can then be used at any place where some supertype of `𝜏` is expected + +Informally then, the polarity restriction enforces that type intersections can only be used for upper bounds and type unions can only be used for lower bounds.
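The formal rules above are mechanical enough to state as code. Below is a small, self-contained sketch (purely illustrative; the `Polarity`, `Prim`, and `Arrow` names are invented here and are not part of SimpleSub or K) that assigns a polarity to every primitive occurrence in a function type by flipping polarity on the left of each arrow and preserving it on the right:

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical sketch, not part of the K codebase: computing the polarity
// of each primitive occurrence in a function type.
public class Polarity {
  public sealed interface Type permits Prim, Arrow {}

  public record Prim(String name) implements Type {}

  public record Arrow(Type lhs, Type rhs) implements Type {}

  // polarity: true = positive, false = negative
  static void polarities(Type t, boolean polarity, Map<String, Boolean> out) {
    if (t instanceof Prim p) {
      out.put(p.name(), polarity);
    } else if (t instanceof Arrow a) {
      polarities(a.lhs(), !polarity, out); // left of an arrow flips polarity
      polarities(a.rhs(), polarity, out);  // right of an arrow preserves it
    }
  }

  public static Map<String, Boolean> polarities(Type t) {
    Map<String, Boolean> out = new LinkedHashMap<>();
    polarities(t, true, out); // the type as a whole is positive
    return out;
  }

  public static void main(String[] args) {
    // (t0 -> t1) -> t2, the example from the text
    Type t = new Arrow(new Arrow(new Prim("t0"), new Prim("t1")), new Prim("t2"));
    System.out.println(polarities(t)); // {t0=true, t1=false, t2=true}
  }
}
```

Running it on the example type `(𝜏0 → 𝜏1) → 𝜏2` reproduces the assignments listed above: `𝜏0` and `𝜏2` positive, `𝜏1` negative.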
In fact, there is an exact correspondence, as conversely any upper/lower bounds can always be encoded by a type intersection/union: +- `𝜏 <: 𝜏1 and 𝜏 <: 𝜏2 iff 𝜏 <: 𝜏1 ⊓ 𝜏2` +- `𝜏1 <: 𝜏 and 𝜏2 <: 𝜏 iff 𝜏1 ⊔ 𝜏2 <: 𝜏` + +In total then, any type variable with bounds `L <: 𝛼 <: U` can be encoded as a set-theoretic type by +- replacing every negative instance of `𝛼` with `𝛼 ⊓ U` +- replacing every positive instance of `𝛼` with `𝛼 ⊔ L` + +Conversely, any set-theoretic type subject to the polarity restriction can be converted back to type variables with bounds by iteratively applying this process in reverse, i.e., +- replacing every intersection involving a type variable `𝛼 ⊓ U` with `𝛼` and recording the bound `𝛼 <: U` (introducing a fresh type variable for intersections involving only concrete types) +- replacing every union involving a type variable `𝛼 ⊔ L` with `𝛼` and recording the bound `L <: 𝛼` (introducing a fresh type variable for unions involving only concrete types) + +For example, consider a term like +``` +𝜆𝑥 . { L = 𝑥 − 1 ; R = if 𝑥 < 0 then 0 else 𝑥 } +``` +where `nat <: int`, `- : int → int → int`, and `0 : nat`. + +Prior to some simplification passes, SimpleSub will infer the type +``` +𝛼 ⊓ int → { L : int ; R : 𝛽 ⊔ nat ⊔ 𝛼 } +``` +which corresponds to the Java-esque type +``` +⟨𝛼 extends int, 𝛽 super nat | 𝛼⟩(𝛼) → { L : int ; R : 𝛽 } +``` +After simplification, SimpleSub will produce the final type +``` +𝛼 ⊓ int → { L : int ; R : nat ⊔ 𝛼 } +``` +which corresponds to the Java-esque type +``` +⟨𝛼 super nat extends int⟩(𝛼) → { L : int ; R : 𝛼 } +``` + +### Inference Algorithm +With this background understood, the actual algorithm is quite simple. Below, I provide the algorithms from the paper in Scala with all the parts that are irrelevant to us removed (namely records, let-bindings, and recursive types).
+ +Take our AST as follows: +``` +enum Term { +case Lit (value: Int) +case Var (name: String) +case Lam (name: String, rhs: Term) +case App (lhs: Term, rhs: Term) +} +``` +The first step is to produce a `SimpleType` which directly records bounds for each variable: +``` +enum SimpleType { +case Variable (st: VariableState) +case Primitive (name: String) +case Function (lhs: SimpleType, rhs: SimpleType) +} + +class VariableState(var lowerBounds: List[SimpleType], + var upperBounds: List[SimpleType]) +``` +The algorithm proceeds straightforwardly, noting type constraints at each function application: +``` +def typeTerm(term: Term)(implicit ctx: Map[String, SimpleType]): SimpleType = term match { +case Lit(n) => Primitive("int") +case Var(name) => ctx.getOrElse(name, err("not found: " + name)) +case Lam(name, body) => + val param = freshVar + Function(param, typeTerm(body)(ctx + (name -> param))) +case App(f, a) => + val res = freshVar + constrain(typeTerm(f), Function(typeTerm(a), res)) + res +} +``` +The constrain function propagates the newly found sub-typing constraint in a manner that ensures coherence with all previously recorded constraints: +``` +def constrain(lhs: SimpleType, rhs: SimpleType) + (implicit cache: MutSet[(SimpleType, SimpleType)]): Unit = { +if (cache.contains(lhs -> rhs)) return () else cache += lhs -> rhs +(lhs, rhs) match { + case (Primitive(n0), Primitive(n1)) if subtype(n0, n1) => () // nothing to do + case (Function(l0, r0), Function(l1, r1)) => + constrain(l1, l0); constrain(r0, r1) + case (Variable(lhs), rhs) => + lhs.upperBounds = rhs :: lhs.upperBounds + lhs.lowerBounds.foreach(constrain(_, rhs)) + case (lhs, Variable(rhs)) => + rhs.lowerBounds = lhs :: rhs.lowerBounds + rhs.upperBounds.foreach(constrain(lhs, _)) + case _ => err("cannot constrain " + lhs + " <: " + rhs) +}} +``` +Once the `SimpleType` is inferred, we then go through the process described before of encoding the typing constraints as unions or intersections of the bounds. 
The end result is a `CompactType`, which denotes either a union or intersection of its contained types depending on the polarity where it occurs: + +``` +case class CompactType(vars: Set[TypeVariable], + prims: Set[PrimType], + fun: Option[(CompactType, CompactType)]) +``` +The implementation of the compact function is straightforward and omitted here. + +We then perform two simplifications on these `CompactTypes` to remove unnecessary parametricity. We say that two type variables *co-occur* if they appear in the same union or intersection. The two simplifications are then as follows: + +- If a type variable `𝛼` always co-occurs positively with some other type variable `β` and vice-versa, then `𝛼` and `β` can be unified. The same applies for negative occurrences. For example, the obvious type for an `if-then-else` function `bool → 𝛼 → β → 𝛼 ⊔ β` is in fact equivalent to `bool → 𝛼 → 𝛼 → 𝛼`. + +- If a type variable `𝛼` always co-occurs both positively and negatively with some other type `T`, then `𝛼` can be removed. For example, `𝛼 ⊓ Int → 𝛼 ⊔ Int` is the same as `Int → Int`. + +See the paper for an explanation of why these simplifications are justified. + +The final step of the algorithm is to coalesce each `CompactType` into the final set-theoretic type syntax described initially. Ignoring recursive types, this is just an easy syntactic transformation, so we omit the implementation here (it's also unneeded for our use cases). + +## From SimpleSub To Sort Inference + +To begin, consider the simplified setting of terms without ambiguities or parametric sorts. + +The conceptual translation from K `Term`s to the SimpleSub language is straightforward. + +First, given `pr` a `ProductionReference` whose id is `prId`, sort is `S`, non-terminal sorts are `S1, ..., SK`, and items are `t1, ..., tK`, we note that the typing constraints it induces are the same as any function symbol, so we let +``` +translateBody(pr) = prId translateBody(t1) ...
translateBody(tK) +``` +where +``` +prId : S1 → ... → SK → S +``` +is taken as a built-in in the SimpleSub language. + +Then, given `t` any `Term` containing variables `x1, ..., xN`, define the translation +``` +translate(t) = 𝜆x1. ... 𝜆xN. translateBody(t) +``` + +When we perform type inference on `translate(t)`, we will obtain a function type of the form `T1 → ... → TN → T`. This tells us that `x1 : T1`, ..., `xN : TN`, and `term : T`, exactly as needed for sort inference! + +There are a few final caveats, which can be easily addressed: +- In K, we do not actually have a complete type lattice, and the intersections and unions of sorts may not actually exist. If we produce a type with such non-existent sorts, it is simply a type error. +- We do not yet support parametric rules throughout K, and may not ever support bounds on parametric sort variables (although this is something to consider!). We can address this by just monomorphizing with every sort in the poset that satisfies the appropriate bounds for each variable. These different monomorphizations should correspond one-to-one with the different maximal models found in the Z3-based sort inference engine. +- Function, macro, and anywhere rules must be special-cased. Conceptually, we can just think of these as a `ProductionReference` to some specialized `KRewrite` whose return sort is exactly the declared sort of the LHS. +- Strict casts require us to unify rather than introduce subtyping constraints. + +### Ambiguities +To extend the idea above to handle ambiguities, we note that any parse tree can be divided into unambiguous sub-trees (which we call *slices* in the implementation). The high-level idea is to infer a type for each slice parametric over its contained ambiguities. + +Explicitly, given `t` a `Term` thought of as a tree/AST, cut `t` at each ambiguity node.
This will produce a number of sub-trees, each of whose roots is either the root of `t` or the direct child of some `Ambiguity` occurring in `t`, and each of whose leaves is either a `Constant` or an `Ambiguity`. + +Given such a slice `t` containing `Ambiguity` nodes with which we associate identifiers `amb1, ..., ambK`, and presuming our overall term contains the variables `x1, ..., xN`, let +``` +translateSlice(t) = 𝜆amb1. ... 𝜆ambK. 𝜆x1. ... 𝜆xN. translateBody(t) +``` +and extend `translateBody` as above with the rule +``` +translateBody(ambi) = ambi x1 ... xN +``` +The intuition here is that our translation shows that any `Term` corresponds to a SimpleSub function abstracted over all the constrained variables, so an ambiguity can just be represented as some function variable (whose concrete types are not yet known) which is applied to all these variables. + +When we perform inference on `translateSlice(t)`, the resulting type will have the form +``` +(A11 → ... → A1N → A1) → ... → (AK1 → ... → AKN → AK) → T1 → ... → TN → T +``` +Here, each `Ai` indicates the expected type of the `Ambiguity` associated to the identifier `ambi`. + +If an ambiguity's child does not itself contain ambiguities (as will always be true far enough down in the parse tree), we can "fold" that particular child into the type of the parent slice by simple function application, corresponding to choosing that child of the ambiguity for our final parse. + +Specifically, let `t` be the term as above and `c` its child slice. We can infer the type of the function application `translateSlice(t) translateSlice(c)` and will be left with either a type error or a type of the form +``` +(A21 → ... → A2N → A2) → ... → (AK1 → ... → AKN → AK) → T1' → ... → TN' → T' +``` +which is the specialization of `translateSlice(t)`'s type when picking `c` as the value of the ambiguity.
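As a toy illustration of this fold-and-prune step (a heavily simplified sketch, not the actual `SortInferencer` implementation; the `SliceFold` class, the sort names, and the reduction of a slice's type to one expected sort per ambiguity slot are all invented for the example), choosing a child for an ambiguity slot either specializes the parent slice's type or fails with a type error and is pruned:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

// Toy sketch, not the real algorithm: a slice's inferred type is reduced to an
// expected sort per ambiguity slot plus a result sort, and folding a child
// parse into a slot keeps the parent only when the child's result sort fits.
public class SliceFold {
  // Assumed subsort declarations for the example: Int <: Exp, Bool <: Exp.
  static final Map<String, Set<String>> SUPERS =
      Map.of(
          "Int", Set.of("Int", "Exp"),
          "Bool", Set.of("Bool", "Exp"),
          "Exp", Set.of("Exp"));

  public static boolean subsort(String sub, String sup) {
    return SUPERS.getOrDefault(sub, Set.of(sub)).contains(sup);
  }

  // Result sort of a slice plus the expected sort of each ambiguity child slot.
  public record SliceType(List<String> expected, String result) {}

  // Fold one child parse into slot `slot`: this mirrors the function
  // application in the text, pruning the choice on a sort mismatch.
  public static Optional<SliceType> fold(SliceType parent, int slot, String childResult) {
    if (!subsort(childResult, parent.expected().get(slot))) {
      return Optional.empty(); // type error: this choice of child is pruned
    }
    List<String> remaining = new ArrayList<>(parent.expected());
    remaining.remove(slot);
    return Optional.of(new SliceType(remaining, parent.result()));
  }

  public static void main(String[] args) {
    // The parent slice expects an Exp in its single ambiguity slot.
    SliceType parent = new SliceType(List.of("Exp"), "Bool");
    System.out.println(fold(parent, 0, "Int"));   // kept: Int <: Exp
    System.out.println(fold(parent, 0, "KItem")); // pruned: KItem is not a subsort of Exp
  }
}
```

Collapsing the whole tree of slices then amounts to repeating this fold for every ambiguity slot and every candidate child, keeping only the combinations that never hit a mismatch.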
+ +We can use this process to then iteratively collapse the tree of slices along all possible paths, starting from the leaves upward. Those paths that encounter type errors during substitution are pruned, while the others are collected into a set of all possible parses along with their inferred types. Unfortunately, this is `O(2^N)` in the worst case, but the hope is that the actual factors involved will be small. diff --git a/k-distribution/include/kframework/ktest-fail.mak b/k-distribution/include/kframework/ktest-fail.mak index 0d06850fab9..151558f0f7c 100644 --- a/k-distribution/include/kframework/ktest-fail.mak +++ b/k-distribution/include/kframework/ktest-fail.mak @@ -14,7 +14,7 @@ TESTS?=$(wildcard $(DEFDIR)/*.md) $(wildcard $(DEFDIR)/*.k) KOMPILE_BACKEND?=llvm KAST_TESTS?=$(wildcard ./*.kast) -KOMPILE_FLAGS+=--no-exc-wrap +KOMPILE_FLAGS+=--no-exc-wrap --type-inference-mode checked KPROVE_FLAGS+=--no-exc-wrap KRUN_FLAGS+=--no-exc-wrap diff --git a/k-distribution/include/kframework/ktest.mak b/k-distribution/include/kframework/ktest.mak index 5d5a69b3860..d76f3f03a9f 100644 --- a/k-distribution/include/kframework/ktest.mak +++ b/k-distribution/include/kframework/ktest.mak @@ -52,7 +52,7 @@ ifeq ($(UNAME), Darwin) KOMPILE_FLAGS+=--no-haskell-binary endif -KOMPILE_FLAGS+=--no-exc-wrap +KOMPILE_FLAGS+=--no-exc-wrap --type-inference-mode checked KPROVE_FLAGS+=--no-exc-wrap KRUN_FLAGS+=--no-exc-wrap diff --git a/k-distribution/tests/regression-new/checks/wideningMacro.k.out b/k-distribution/tests/regression-new/checks/wideningMacro.k.out index 15be25e7c42..aa35d9d3655 100644 --- a/k-distribution/tests/regression-new/checks/wideningMacro.k.out +++ b/k-distribution/tests/regression-new/checks/wideningMacro.k.out @@ -1,4 +1,4 @@ -[Error] Inner Parser: Unexpected sort Int for term parsed as production syntax {Sort} Sort ::= Sort "=>" Sort [klabel(#KRewrite), symbol]. 
Expected: Bool +[Error] Inner Parser: Sort of variable X inferred as greatest lower bound of [Bool, Int], but no such bound exists. Source(wideningMacro.k) Location(8,10,8,25) 8 | rule #bar ( X ) => X diff --git a/kernel/src/main/java/org/kframework/kompile/DefinitionParsing.java b/kernel/src/main/java/org/kframework/kompile/DefinitionParsing.java index 8a9e2c29e28..ba8136a20e6 100644 --- a/kernel/src/main/java/org/kframework/kompile/DefinitionParsing.java +++ b/kernel/src/main/java/org/kframework/kompile/DefinitionParsing.java @@ -441,7 +441,11 @@ private Definition resolveConfigBubbles(Definition def) { ParseCache cache = loadCache(configParserModule); try (ParseInModule parser = RuleGrammarGenerator.getCombinedGrammar( - cache.module(), profileRules, files, options.debugTypeInference)) { + cache.module(), + profileRules, + files, + options.debugTypeInference, + options.typeInferenceMode)) { // each parser gets its own scanner because config labels can conflict with user // tokens parser.getScanner(globalOptions); @@ -509,7 +513,8 @@ private Definition resolveConfigBubbles(Definition def) { gen.getConfigGrammar(module), profileRules, files, - options.debugTypeInference) + options.debugTypeInference, + options.typeInferenceMode) .getExtensionModule(); Set configDeclProductions = stream(module.localSentences()) @@ -550,7 +555,14 @@ private Definition resolveNonConfigBubbles( ParseCache cache = loadCache(ruleParserModule); try (ParseInModule parser = RuleGrammarGenerator.getCombinedGrammar( - cache.module(), profileRules, false, true, files, options.debugTypeInference, false)) { + cache.module(), + profileRules, + false, + true, + files, + options.debugTypeInference, + options.typeInferenceMode, + false)) { Scanner scanner; if (deserializeScanner) { scanner = new Scanner(parser, globalOptions, files.resolveKompiled("scanner")); @@ -582,7 +594,11 @@ private Module resolveNonConfigBubbles(Module module, Scanner scanner, RuleGramm try (ParseInModule parser = 
needNewScanner ? RuleGrammarGenerator.getCombinedGrammar( - cache.module(), profileRules, files, options.debugTypeInference) + cache.module(), + profileRules, + files, + options.debugTypeInference, + options.typeInferenceMode) : RuleGrammarGenerator.getCombinedGrammar( cache.module(), scanner, @@ -590,6 +606,7 @@ private Module resolveNonConfigBubbles(Module module, Scanner scanner, RuleGramm false, files, options.debugTypeInference, + options.typeInferenceMode, false)) { if (needNewScanner) parser.getScanner(globalOptions); parser.initialize(); @@ -789,6 +806,7 @@ public Rule parseRule(CompiledDefinition compiledDef, String contents, Source so true, files, options.debugTypeInference, + options.typeInferenceMode, false)) { parser.setScanner(new Scanner(parser, globalOptions, files.resolveKompiled("scanner"))); java.util.Set res = diff --git a/kernel/src/main/java/org/kframework/kompile/KompileOptions.java b/kernel/src/main/java/org/kframework/kompile/KompileOptions.java index fb5cbd39b8e..23aaf4be5e3 100644 --- a/kernel/src/main/java/org/kframework/kompile/KompileOptions.java +++ b/kernel/src/main/java/org/kframework/kompile/KompileOptions.java @@ -241,7 +241,25 @@ public String syntaxModule(FileUtil files) { @Parameter( names = "--enable-kore-antileft", - description = "Enable generation of legacy antileft priority predicates ", + description = "Enable generation of legacy antileft priority predicates.", hidden = true) public boolean enableKoreAntileft; + + public enum TypeInferenceMode { + Z3, + SIMPLESUB, + CHECKED, + // We use an explicit DEFAULT option here so that ParseInModule can set a default which + // applies even for those codepaths that don't rely on KompileOptions + DEFAULT, + } + + @Parameter( + names = "--type-inference-mode", + description = + "Choose between the Z3-based and SimpleSub-based type inference algorithms, or run both" + + " and check that their results are equal. 
Must be one of " + + "[z3|simplesub|checked|default].", + hidden = true) + public TypeInferenceMode typeInferenceMode = TypeInferenceMode.DEFAULT; } diff --git a/kernel/src/main/java/org/kframework/parser/inner/ParseInModule.java b/kernel/src/main/java/org/kframework/parser/inner/ParseInModule.java index 38283092a65..cab42970e4b 100644 --- a/kernel/src/main/java/org/kframework/parser/inner/ParseInModule.java +++ b/kernel/src/main/java/org/kframework/parser/inner/ParseInModule.java @@ -15,12 +15,14 @@ import org.kframework.definition.Module; import org.kframework.definition.Terminal; import org.kframework.definition.TerminalLike; +import org.kframework.kompile.KompileOptions; import org.kframework.kore.K; import org.kframework.kore.Sort; import org.kframework.main.GlobalOptions; import org.kframework.parser.Term; import org.kframework.parser.TreeNodesToKORE; import org.kframework.parser.inner.disambiguation.*; +import org.kframework.parser.inner.disambiguation.inference.SortInferencer; import org.kframework.parser.inner.kernel.EarleyParser; import org.kframework.parser.inner.kernel.Scanner; import org.kframework.parser.outer.Outer; @@ -39,6 +41,7 @@ */ public class ParseInModule implements Serializable, AutoCloseable { private final Module seedModule; + private Module extensionModule; /** @@ -62,6 +65,7 @@ public class ParseInModule implements Serializable, AutoCloseable { private final boolean forGlobalScanner; private final FileUtil files; private final String typeInferenceDebug; + private final KompileOptions.TypeInferenceMode typeInferenceMode; private final boolean partialParseDebug; ParseInModule( @@ -71,6 +75,7 @@ public class ParseInModule implements Serializable, AutoCloseable { boolean forGlobalScanner, FileUtil files, String typeInferenceDebug, + KompileOptions.TypeInferenceMode typeInferenceMode, boolean partialParseDebug) { this( seedModule, @@ -83,6 +88,7 @@ public class ParseInModule implements Serializable, AutoCloseable { forGlobalScanner, files, 
typeInferenceDebug, + typeInferenceMode, partialParseDebug); } @@ -94,6 +100,7 @@ public class ParseInModule implements Serializable, AutoCloseable { boolean forGlobalScanner, FileUtil files, String typeInferenceDebug, + KompileOptions.TypeInferenceMode typeInferenceMode, boolean partialParseDebug) { this( seedModule, @@ -106,6 +113,7 @@ public class ParseInModule implements Serializable, AutoCloseable { forGlobalScanner, files, typeInferenceDebug, + typeInferenceMode, partialParseDebug); } @@ -120,6 +128,7 @@ private ParseInModule( boolean forGlobalScanner, FileUtil files, String typeInferenceDebug, + KompileOptions.TypeInferenceMode typeInferenceMode, boolean partialParseDebug) { this.seedModule = seedModule; this.extensionModule = extensionModule; @@ -131,6 +140,10 @@ private ParseInModule( this.forGlobalScanner = forGlobalScanner; this.files = files; this.typeInferenceDebug = typeInferenceDebug; + this.typeInferenceMode = + typeInferenceMode == KompileOptions.TypeInferenceMode.DEFAULT + ? KompileOptions.TypeInferenceMode.Z3 + : typeInferenceMode; this.partialParseDebug = partialParseDebug; } @@ -401,20 +414,44 @@ private Tuple2, Term>, Set> parseStringTe rez3 = new PushTopAmbiguityUp().apply(rez3); startTypeInf = profileRules ? System.currentTimeMillis() : 0; - TypeInferencer currentInferencer; - if (isDebug(source, startLine)) { - currentInferencer = new TypeInferencer(disambModule, true); - inferencers.add(currentInferencer); - } else { - currentInferencer = inferencer.get(); - if (currentInferencer == null) { - currentInferencer = new TypeInferencer(disambModule, isDebug(source, startLine)); - inferencer.set(currentInferencer); + KompileOptions.TypeInferenceMode infModeForTerm = + SortInferencer.isSupported(rez3) + ? 
typeInferenceMode + : KompileOptions.TypeInferenceMode.Z3; + + if (infModeForTerm == KompileOptions.TypeInferenceMode.SIMPLESUB + || infModeForTerm == KompileOptions.TypeInferenceMode.CHECKED) { + rez = new SortInferencer(disambModule).apply(rez3, startSymbol, isAnywhere); + } + if (infModeForTerm == KompileOptions.TypeInferenceMode.Z3 + || infModeForTerm == KompileOptions.TypeInferenceMode.CHECKED) { + + TypeInferencer currentInferencer; + if (isDebug(source, startLine)) { + currentInferencer = new TypeInferencer(disambModule, true); inferencers.add(currentInferencer); + } else { + currentInferencer = inferencer.get(); + if (currentInferencer == null) { + currentInferencer = new TypeInferencer(disambModule, isDebug(source, startLine)); + inferencer.set(currentInferencer); + inferencers.add(currentInferencer); + } + } + Either, Term> z3Rez = + new TypeInferenceVisitor(currentInferencer, startSymbol, isAnywhere).apply(rez3); + if (infModeForTerm == KompileOptions.TypeInferenceMode.CHECKED) { + boolean bothLeft = rez.isLeft() && z3Rez.isLeft(); + boolean equalRight = + rez.isRight() && z3Rez.isRight() && rez.right().get().equals(z3Rez.right().get()); + if (!(bothLeft || equalRight)) { + throw KEMException.criticalError("Z3 and SimpleSub algorithms differ!"); + } + } else { + rez = z3Rez; } } - rez = new TypeInferenceVisitor(currentInferencer, startSymbol, isAnywhere).apply(rez3); if (rez.isLeft()) return new Tuple2<>(rez, warn); endTypeInf = profileRules ? 
System.currentTimeMillis() : 0; diff --git a/kernel/src/main/java/org/kframework/parser/inner/RuleGrammarGenerator.java b/kernel/src/main/java/org/kframework/parser/inner/RuleGrammarGenerator.java index bf15b74d1a1..130ed27f176 100644 --- a/kernel/src/main/java/org/kframework/parser/inner/RuleGrammarGenerator.java +++ b/kernel/src/main/java/org/kframework/parser/inner/RuleGrammarGenerator.java @@ -33,6 +33,7 @@ import org.kframework.definition.Terminal; import org.kframework.definition.UidProvider; import org.kframework.definition.UserList; +import org.kframework.kompile.KompileOptions; import org.kframework.kore.Sort; import org.kframework.kore.SortHead; import org.kframework.parser.inner.kernel.Scanner; @@ -201,36 +202,68 @@ public static boolean isParserSort(Sort s) { /* use this overload if you don't need to profile rule parse times. */ public static ParseInModule getCombinedGrammar(Module mod, FileUtil files) { - return getCombinedGrammar(mod, false, false, false, files, null, false); + return getCombinedGrammar( + mod, false, false, false, files, null, KompileOptions.TypeInferenceMode.DEFAULT, false); } public static ParseInModule getCombinedGrammar( Module mod, FileUtil files, boolean partialParseDebug) { - return getCombinedGrammar(mod, false, false, false, files, null, partialParseDebug); + return getCombinedGrammar( + mod, + false, + false, + false, + files, + null, + KompileOptions.TypeInferenceMode.DEFAULT, + partialParseDebug); } public static ParseInModule getCombinedGrammar(Module mod, boolean timing, FileUtil files) { - return getCombinedGrammar(mod, timing, false, false, files, null, false); + return getCombinedGrammar( + mod, timing, false, false, files, null, KompileOptions.TypeInferenceMode.DEFAULT, false); } public static ParseInModule getCombinedGrammar( - Module mod, boolean timing, FileUtil files, String debugTypeInference) { - return getCombinedGrammar(mod, timing, false, false, files, debugTypeInference, false); + Module mod, + boolean 
timing, + FileUtil files, + String debugTypeInference, + KompileOptions.TypeInferenceMode typeInferenceMode) { + return getCombinedGrammar( + mod, timing, false, false, files, debugTypeInference, typeInferenceMode, false); } public static ParseInModule getCombinedGrammar( Module mod, boolean timing, boolean isBison, FileUtil files) { - return getCombinedGrammar(mod, timing, isBison, false, files, null, false); + return getCombinedGrammar( + mod, timing, isBison, false, files, null, KompileOptions.TypeInferenceMode.DEFAULT, false); } public static ParseInModule getCombinedGrammar( Module mod, boolean timing, boolean isBison, boolean forGlobalScanner, FileUtil files) { - return getCombinedGrammar(mod, timing, isBison, forGlobalScanner, files, null, false); + return getCombinedGrammar( + mod, + timing, + isBison, + forGlobalScanner, + files, + null, + KompileOptions.TypeInferenceMode.DEFAULT, + false); } public static ParseInModule getCombinedGrammar( Module mod, Scanner scanner, boolean timing, boolean isBison, FileUtil files) { - return getCombinedGrammar(mod, scanner, timing, isBison, files, null, false); + return getCombinedGrammar( + mod, + scanner, + timing, + isBison, + files, + null, + KompileOptions.TypeInferenceMode.DEFAULT, + false); } // the forGlobalScanner flag tells the ParseInModule class not to exclude @@ -258,9 +291,17 @@ public static ParseInModule getCombinedGrammar( boolean forGlobalScanner, FileUtil files, String debugTypeInference, + KompileOptions.TypeInferenceMode typeInferenceMode, boolean partialParseDebug) { return new ParseInModule( - mod, timing, isBison, forGlobalScanner, files, debugTypeInference, partialParseDebug); + mod, + timing, + isBison, + forGlobalScanner, + files, + debugTypeInference, + typeInferenceMode, + partialParseDebug); } public static ParseInModule getCombinedGrammar( @@ -270,9 +311,18 @@ public static ParseInModule getCombinedGrammar( boolean isBison, FileUtil files, String debugTypeInference, + 
KompileOptions.TypeInferenceMode typeInferenceMode, boolean partialParseDebug) { return new ParseInModule( - mod, scanner, timing, isBison, false, files, debugTypeInference, partialParseDebug); + mod, + scanner, + timing, + isBison, + false, + files, + debugTypeInference, + typeInferenceMode, + partialParseDebug); } public static Tuple3 getCombinedGrammarImpl( diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/BoundedSort.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/BoundedSort.java new file mode 100644 index 00000000000..96ed9ed45e3 --- /dev/null +++ b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/BoundedSort.java @@ -0,0 +1,28 @@ +// Copyright (c) K Team. All Rights Reserved. +package org.kframework.parser.inner.disambiguation.inference; + +import java.util.HashSet; +import java.util.Set; +import org.kframework.kore.SortHead; + +/** An unsimplified sort analogous to SimpleSub's SimpleType. */ +public sealed interface BoundedSort { + /** A primitive sort */ + record Constructor(SortHead head) implements BoundedSort {} + + /** + * A sort variable with sub- and super-type constraints. + * + * @param sortVar - The underlying SortVariable. This holds no real information, but is needed to + * prevent distinct Variables with the same bounds from comparing equal. 
+ * @param lowerBounds - All those sorts which must be a sub-type of this variable + * @param upperBounds - All those sorts which must be a super-type of this variable + */ + record Variable(SortVariable sortVar, Set<BoundedSort> lowerBounds, Set<BoundedSort> upperBounds) + implements BoundedSort { + + Variable() { + this(new SortVariable(), new HashSet<>(), new HashSet<>()); + } + } +} diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/CompactSort.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/CompactSort.java new file mode 100644 index 00000000000..8bd2177e62c --- /dev/null +++ b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/CompactSort.java @@ -0,0 +1,125 @@ +// Copyright (c) K Team. All Rights Reserved. +package org.kframework.parser.inner.disambiguation.inference; + +import static org.kframework.Collections.*; + +import java.util.HashSet; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.kframework.POSet; +import org.kframework.builtin.Sorts; +import org.kframework.kore.Sort; +import org.kframework.kore.SortHead; +import scala.util.Either; +import scala.util.Left; +import scala.util.Right; + +/** + * A sort which, depending on the polarity where it occurs, represents either an intersection or a + * union of component sorts (variables and constructors). Analogous to CompactType from SimpleSub. + */ +public record CompactSort(Set<SortVariable> vars, Set<SortHead> ctors) { + + public CompactSort(SortVariable var) { + this(Stream.of(var).collect(Collectors.toCollection(HashSet::new)), new HashSet<>()); + } + + /** + * Compact a BoundedSort, converting all transitive bounds into a single intersection / union. + * + * @param sort - The BoundedSort to make compact + * @param polarity - The polarity where sort occurs. True for positive, false for negative.
+ * @return A CompactSort containing all bounds represented by sort + */ + public static CompactSort makeCompact(BoundedSort sort, boolean polarity) { + if (sort instanceof BoundedSort.Constructor ctor) { + if (ctor.head().params() == 0) { + Set ctors = new HashSet<>(); + ctors.add(ctor.head()); + return new CompactSort(new HashSet<>(), ctors); + } + throw new AssertionError("Parametric sorts are not yet supported!"); + } + + BoundedSort.Variable var = (BoundedSort.Variable) sort; + Set bounds = polarity ? var.lowerBounds() : var.upperBounds(); + + Set vars = new HashSet<>(); + Set ctors = new HashSet<>(); + vars.add(var.sortVar()); + for (BoundedSort bound : bounds) { + CompactSort compactBound = makeCompact(bound, polarity); + vars.addAll(compactBound.vars()); + ctors.addAll(compactBound.ctors()); + } + return new CompactSort(vars, ctors); + } + + /** + * Substitute variables for CompactSorts. + * + * @param subst - A map where an entry v |-> Optional.of(t) indicates that the variable v should + * be replaced by t, and an entry v |-> Optional.empty() indicates that v should be removed + * entirely (effectively, replacing it with top or bottom depending on polarity). + * @return A new CompactSort with the substitution applied + */ + public CompactSort substitute(Map> subst) { + Set newVars = new HashSet<>(); + Set newCtors = new HashSet<>(ctors); + for (SortVariable var : vars) { + if (!subst.containsKey(var)) { + newVars.add(var); + continue; + } + if (subst.get(var).isPresent()) { + CompactSort newSort = subst.get(var).get(); + newVars.addAll(newSort.vars()); + newCtors.addAll(newSort.ctors()); + } + } + return new CompactSort(newVars, newCtors); + } + + /** + * An error indicating that we could not compute a type meet or join. + * + * @param sorts - The set of sorts we are trying to meet/join. + * @param candidates - The set of minimal upper bounds / maximal lower bounds of sorts. 
+ * @param polarity - True for positive, false for negative + */ + public record LatticeOpError(Set sorts, Set candidates, boolean polarity) {} + + /** + * Convert to an equivalent Sort, instantiating variables and actually computing the type + * join/meet on the provided subsort. + * + * @param polarity - The polarity where this CompactSort occurs. + * @param instantiation - A map indicating how variables should be instantiated + * @param subsorts - The Sort poset + * @return An equivalent Sort + */ + public Either asSort( + boolean polarity, Map instantiation, POSet subsorts) { + Set sorts = vars.stream().map(instantiation::get).collect(Collectors.toSet()); + sorts.addAll( + ctors.stream() + .map(h -> new org.kframework.kore.ADT.Sort(h.name(), Seq())) + .collect(Collectors.toSet())); + // WLOG upper/lower bound of a set is the upper/lower bound of its maximal/minimal elements. + // This is useful to prune the search space as the POSet operations can be quite expensive + sorts = polarity ? subsorts.maximal(sorts) : subsorts.minimal(sorts); + if (sorts.size() == 1) { + return Right.apply(sorts.iterator().next()); + } + Set bounds = polarity ? subsorts.upperBounds(sorts) : subsorts.lowerBounds(sorts); + bounds.removeIf(s -> subsorts.lessThanEq(s, Sorts.KBott())); + Set candidates = polarity ? subsorts.minimal(bounds) : subsorts.maximal(bounds); + if (candidates.size() != 1) { + return Left.apply(new LatticeOpError(sorts, candidates, polarity)); + } + return Right.apply(candidates.iterator().next()); + } +} diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/InferenceDriver.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/InferenceDriver.java new file mode 100644 index 00000000000..c038c4620ef --- /dev/null +++ b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/InferenceDriver.java @@ -0,0 +1,136 @@ +// Copyright (c) K Team. All Rights Reserved. 
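The meet/join computation in `asSort` can be illustrated on a toy subsort poset. The sketch below is plain Java with hypothetical sort names, not the actual K `POSet` API: it intersects the upper closures of the input sorts and then demands a unique minimal candidate, returning empty in the `LatticeOpError`-style cases where no bound exists or the candidates are incomparable.

```java
import java.util.*;

public class JoinSketch {
  // Direct-subsort edges, child -> parents. Sort names are hypothetical.
  static final Map<String, Set<String>> PARENTS =
      Map.of(
          "Int", Set.of("Exp"),
          "Bool", Set.of("Exp"),
          "Exp", Set.of("K"),
          "Stmt", Set.of("K"),
          "K", Set.of());

  // Reflexive-transitive set of supersorts of s.
  static Set<String> upperClosure(String s) {
    Set<String> seen = new HashSet<>();
    Deque<String> todo = new ArrayDeque<>(List.of(s));
    while (!todo.isEmpty()) {
      String cur = todo.pop();
      if (seen.add(cur)) {
        todo.addAll(PARENTS.get(cur));
      }
    }
    return seen;
  }

  static boolean lessThanEq(String a, String b) {
    return upperClosure(a).contains(b);
  }

  // Join = the unique minimal common upper bound; empty if none exists or the
  // minimal candidates are incomparable (the LatticeOpError case).
  static Optional<String> join(Set<String> sorts) {
    Set<String> common = new HashSet<>();
    boolean first = true;
    for (String s : sorts) {
      if (first) {
        common.addAll(upperClosure(s));
        first = false;
      } else {
        common.retainAll(upperClosure(s));
      }
    }
    Set<String> candidates = new HashSet<>(common);
    candidates.removeIf(c -> common.stream().anyMatch(o -> !o.equals(c) && lessThanEq(o, c)));
    return candidates.size() == 1 ? Optional.of(candidates.iterator().next()) : Optional.empty();
  }

  public static void main(String[] args) {
    System.out.println(join(Set.of("Int", "Bool"))); // Optional[Exp]
    System.out.println(join(Set.of("Int", "Stmt"))); // Optional[K]
  }
}
```

The meet for negative positions is the dual: intersect lower closures and keep the unique maximal candidate.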
+package org.kframework.parser.inner.disambiguation.inference; + +import static org.kframework.Collections.*; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import org.kframework.POSet; +import org.kframework.builtin.Sorts; +import org.kframework.kore.Sort; +import org.kframework.parser.Constant; +import org.kframework.parser.ProductionReference; +import org.kframework.parser.Term; +import scala.Tuple2; + +/** + * A driver which manages all state during the initial phase of SortInferencer where we infer + * unsimplified BoundedSorts. + */ +public final class InferenceDriver { + /** The POSet of sorts ordered by the sub-sort relation. */ + private final POSet subsorts; + + /** A unique sort variable for every variable that inference encounters. */ + private final Map varSorts = new HashMap<>(); + + /** A unique sort variable for every parameter that inference encounters. */ + private final Map paramSorts = new HashMap<>(); + + /** + * A cache of all subtyping constraints that have already been processed. Pairs (t1, t2) represent + * a constraint t1 <: t2. + */ + private final Set> constraintCache = new HashSet<>(); + + public InferenceDriver(POSet subsorts) { + this.subsorts = subsorts; + } + + /** + * Get the BoundedSort.Variable instance for a variable. Unlike SimpleSub, we don't have explicit + * binders at the top of the Term introducing every variable, so these BoundedSort.Variable + * instances must be created on the fly when each variable is first encountered during inference. 
+ */ + public BoundedSort varSort(Constant var) { + VariableId varId = VariableId.apply(var); + if (!varSorts.containsKey(varId)) { + // Every variable must be assigned a user sort <=K + BoundedSort.Variable sort = new BoundedSort.Variable(); + sort.upperBounds().add(sortToBoundedSort(Sorts.K(), null)); + varSorts.put(varId, sort); + } + return varSorts.get(varId); + } + + /** + * Convert a Sort to a BoundedSort, which may be a sort parameter from the provided + * ProductionReference. + * + * @param sort - The Sort to convert + * @param prOrNull - The ProductionReference where this Sort occurs, or null if sort is guaranteed + * not to be a sort parameter. + * @return A BoundedSort representing sort + */ + public BoundedSort sortToBoundedSort(Sort sort, ProductionReference prOrNull) { + if (prOrNull != null && prOrNull.production().params().contains(sort)) { + ParamId paramId = new ParamId(prOrNull, sort); + if (!paramSorts.containsKey(paramId)) { + paramSorts.put(paramId, new BoundedSort.Variable()); + } + return paramSorts.get(paramId); + } + return new BoundedSort.Constructor(sort.head()); + } + + /** + * Update sub-/super-type constraints to record the fact that lhs <: rhs. + * + * @param pr - The ProductionReference where this constraint originated. This is only necessary + * for error reporting. + * @throws ConstraintError - An error if lhs <: rhs induces some subsort relation which is invalid + * based on the subsort poset. + */ + public void constrain(BoundedSort lhs, BoundedSort rhs, ProductionReference pr) + throws ConstraintError { + // This cache is necessary to prevent exponential blow-up and avoid + // loops like a <: b <: a <: b ...
+ if (lhs.equals(rhs) || constraintCache.contains(Tuple2.apply(lhs, rhs))) { + return; + } + + if (lhs instanceof BoundedSort.Variable lhsVar) { + constraintCache.add(Tuple2.apply(lhs, rhs)); + lhsVar.upperBounds().add(rhs); + for (BoundedSort lhsLower : lhsVar.lowerBounds()) { + constrain(lhsLower, rhs, pr); + } + return; + } + + if (rhs instanceof BoundedSort.Variable rhsVar) { + constraintCache.add(Tuple2.apply(lhs, rhs)); + rhsVar.lowerBounds().add(lhs); + for (BoundedSort rhsUpper : rhsVar.upperBounds()) { + constrain(lhs, rhsUpper, pr); + } + return; + } + + // If they are primitive sorts, we can check the sort poset directly + BoundedSort.Constructor lhsCtor = (BoundedSort.Constructor) lhs; + BoundedSort.Constructor rhsCtor = (BoundedSort.Constructor) rhs; + if (lhsCtor.head().params() == 0 && rhsCtor.head().params() == 0) { + Sort lhsSort = new org.kframework.kore.ADT.Sort(lhsCtor.head().name(), Seq()); + Sort rhsSort = new org.kframework.kore.ADT.Sort(rhsCtor.head().name(), Seq()); + if (subsorts.lessThanEq(lhsSort, rhsSort)) { + return; + } + throw new ConstraintError(lhsSort, rhsSort, pr); + } + + throw new AssertionError("Parametric sorts are not yet supported!"); + } + + /** + * After inference is complete, get the final TermSort result recording the sorts of variables. + * + * @param term - The term that we ran the driver on. + * @param sort - The inferred sort of the overall term. + */ + public TermSort getResult(Term term, BoundedSort sort) { + return new TermSort<>(term, sort, varSorts); + } +} diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/ParamId.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/ParamId.java new file mode 100644 index 00000000000..1b51fdbe3d8 --- /dev/null +++ b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/ParamId.java @@ -0,0 +1,33 @@ +// Copyright (c) K Team. All Rights Reserved. 
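The mutually recursive constraint propagation in `constrain` can be sketched in isolation. This toy version uses hypothetical `Prim`/`Var` types with a total order of integer ranks standing in for the subsort poset; it is illustrative only, not the actual K classes. It shows how a new bound on a variable is pushed through its existing opposite bounds, and how the cache keeps cyclic constraints like `a <: b <: a` from looping.

```java
import java.util.*;

public class ConstrainSketch {
  sealed interface S permits Prim, Var {}

  // A rank in a toy total order stands in for a primitive sort in the poset.
  record Prim(int rank) implements S {}

  static final class Var implements S {
    final Set<S> lower = new HashSet<>();
    final Set<S> upper = new HashSet<>();
  }

  // Already-processed (lhs, rhs) pairs, preventing exponential re-work and
  // non-termination on cyclic constraints.
  final Set<List<S>> cache = new HashSet<>();

  void constrain(S lhs, S rhs) {
    if (lhs.equals(rhs) || !cache.add(List.of(lhs, rhs))) {
      return;
    }
    if (lhs instanceof Var v) {
      v.upper.add(rhs);
      for (S low : new ArrayList<>(v.lower)) {
        constrain(low, rhs); // push the new upper bound to existing lower bounds
      }
    } else if (rhs instanceof Var v) {
      v.lower.add(lhs);
      for (S up : new ArrayList<>(v.upper)) {
        constrain(lhs, up);
      }
    } else if (((Prim) lhs).rank() > ((Prim) rhs).rank()) {
      throw new IllegalStateException("unsatisfiable: " + lhs + " <: " + rhs);
    }
  }

  public static void main(String[] args) {
    ConstrainSketch d = new ConstrainSketch();
    Var x = new Var();
    Var y = new Var();
    d.constrain(new Prim(1), x); // 1 <: x
    d.constrain(x, new Prim(2)); // x <: 2, and transitively 1 <: 2 is checked
    d.constrain(x, y); // cyclic constraints terminate thanks to the cache
    d.constrain(y, x);
    System.out.println(x.lower.size() + " " + x.upper.size()); // 1 2
  }
}
```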
+package org.kframework.parser.inner.disambiguation.inference; + +import java.util.Objects; +import org.kframework.kore.Sort; +import org.kframework.parser.ProductionReference; + +/** + * A class representing a particular usage of a production's parameter. Effectively, a pair + * (ProductionReference, Sort) with reference semantics for the ProductionReference. + */ +public final class ParamId { + private final ProductionReference pr; + private final Sort param; + + public ParamId(ProductionReference pr, Sort param) { + this.pr = pr; + this.param = param; + } + + @Override + public boolean equals(Object o) { + if (o instanceof ParamId p) { + return this.pr == p.pr && this.param.equals(p.param); + } + return false; + } + + @Override + public int hashCode() { + return Objects.hash(System.identityHashCode(pr), param); + } +} diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortInferenceError.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortInferenceError.java new file mode 100644 index 00000000000..1d2ab4d07c0 --- /dev/null +++ b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortInferenceError.java @@ -0,0 +1,84 @@ +// Copyright (c) K Team. All Rights Reserved. +package org.kframework.parser.inner.disambiguation.inference; + +import java.util.Optional; +import org.kframework.attributes.HasLocation; +import org.kframework.kore.Sort; +import org.kframework.parser.ProductionReference; +import org.kframework.utils.errorsystem.KEMException; + +/** + * The parent class of all errors thrown by SortInferencer. We use our own exceptions here rather + * than KEMException because a SortInferenceError may not indicate an actual error by the user, + * e.g., it may be thrown for a type error in one branch of an Ambiguity to indicate that it should + * be pruned. 
+ */ +abstract sealed class SortInferenceError extends Exception { + private final Optional loc; + + public SortInferenceError(String message, HasLocation loc) { + super(message); + this.loc = Optional.of(loc); + } + + public KEMException asInnerParseError(HasLocation defaultLoc) { + return KEMException.innerParserError(getMessage(), loc.orElse(defaultLoc)); + } +} + +/** An error indicating that we could not compute some type join / meet. */ +final class LatticeOpError extends SortInferenceError { + public LatticeOpError(CompactSort.LatticeOpError err, HasLocation loc, Optional name) { + super( + "Sort" + + name.map(n -> " of " + n + " ").orElse(" ") + + "inferred as " + + (err.polarity() ? "least upper bound" : "greatest lower bound") + + " of " + + err.sorts() + + ", but " + + (err.candidates().isEmpty() + ? "no such bound exists." + : ("candidate bounds are " + "incomparable: " + err.candidates() + ".")), + loc); + } +} + +/** An error indicating that a sub-typing constraint is invalid. */ +final class ConstraintError extends SortInferenceError { + public ConstraintError(Sort lhs, Sort rhs, ProductionReference pr) { + super( + "Unexpected sort " + + lhs + + " for term parsed as production " + + pr.production() + + ". Expected: " + + rhs, + pr); + } +} + +/** An error indicating that some type variable cannot be monomorphized as an actual K sort. */ +final class MonomorphizationError extends SortInferenceError { + // TODO: Produce better error messages! + // + // Type variables can originate from three places: + // - variables + // - sort parameters + // - as a generalization of the sort of a production + // + // For the first two cases, we could provide nicer error messages by pointing to the + // location where the type variable originated, reporting all bounds on the + // variable, and stating that they cannot be satisfied. + // + // However, it's unclear how to easily explain errors in the third case. 
+ // + // Additionally, there are cases where two type variables may be individually but + // not mutually monomorphized, so it's not clear how to report the root cause. + public MonomorphizationError(HasLocation loc) { + super( + "Term is not well-sorted due to monomorphization failure. Add sort annotations to " + + "produce a better error message.", + loc); + } +} diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortInferencer.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortInferencer.java new file mode 100644 index 00000000000..2de3668a55a --- /dev/null +++ b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortInferencer.java @@ -0,0 +1,594 @@ +// Copyright (c) K Team. All Rights Reserved. +package org.kframework.parser.inner.disambiguation.inference; + +import static org.kframework.Collections.*; +import static org.kframework.kore.KORE.*; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import org.kframework.attributes.Att; +import org.kframework.builtin.KLabels; +import org.kframework.builtin.Sorts; +import org.kframework.definition.Module; +import org.kframework.definition.NonTerminal; +import org.kframework.definition.Production; +import org.kframework.kore.KLabel; +import org.kframework.kore.Sort; +import org.kframework.kore.SortHead; +import org.kframework.parser.Ambiguity; +import org.kframework.parser.Constant; +import org.kframework.parser.ProductionReference; +import org.kframework.parser.Term; +import org.kframework.parser.TermCons; +import org.kframework.utils.errorsystem.KEMException; +import org.pcollections.ConsPStack; +import scala.Tuple2; +import scala.util.Either; +import scala.util.Left; +import scala.util.Right; + +/** + * Disambiguation
transformer which performs type checking and infers the sorts of variables.
+ *
+ * <p>The overall design is heavily inspired by the algorithm described in "The Simple Essence of
+ * Algebraic Subtyping: Principal Type Inference with Subtyping Made Easy" by Lionel Parreaux.
+ *
+ * <p>Each Term can be viewed as a SimpleSub-esque term with equivalent subtyping constraints:
+ *
+ * <ul>
+ *   <li>Each Production is a function symbol whose input types are the sorts of its non-terminals
+ *       and whose output type is the sort of the overall Production.
+ *   <li>Each ProductionReference is an application of the corresponding Production function symbol
+ *       to its arguments.
+ *   <li>The Term as a whole is a function, treating each variable in the Term as a variable bound
+ *       at the top level of the Term.
+ * </ul>
+ *
+ * Inferring the SimpleSub-esque type is then equivalent to performing SortInference. That is, we
+ * infer a type a1 -> ... aN -> b telling us that the variables x1, ..., xN have sorts a1, ..., aN
+ * and the overall Term has sort b.
+ *
+ * <p>Explicitly, the algorithm proceeds as follows:
+ *
+ * <ol>
+ *   <li>Infer a BoundedSort for the input Term as well as all of its variables, recording each
+ *       subtype constraint as lower and upper bounds on sort variables.
+ *       <ul>
+ *         <li>BoundedSort is directly analogous to SimpleType from SimpleSub, except that we only
+ *             have primitive sorts (BoundedSort.Constructor) and variables (BoundedSort.Variable).
+ *         <li>TermSort represents the "function type" of the overall Term.
+ *       </ul>
+ *   <li>Constrain the inferred BoundedSort of the overall Term as a subsort of the expected
+ *       topSort.
+ *   <li>Compactify then simplify the TermSort to produce a CompactSort (analogous to producing the
+ *       CompactType in SimpleSub).
+ *   <li>Convert the inferred CompactSort into a normal K Sort.
+ *       <ul>
+ *         <li>Monomorphize each sort variable, allowing it to take any value between its recorded
+ *             bounds and possibly producing multiple valid monomorphizations.
+ *         <li>For each type intersection/union, actually compute the corresponding meet/join on the
+ *             subsort poset, erroring if no such meet/join exists.
+ *       </ul>
+ *   <li>Insert a SemanticCast around every variable in the Term to record the results.
+ * </ol>
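Step 3's compaction can be sketched with toy stand-ins for BoundedSort. The names below are illustrative, not the actual K classes, and a visited set is added to guard against cyclic bounds (a minor addition over the structure of makeCompact):

```java
import java.util.*;

public class CompactSketch {
  sealed interface S permits Prim, Var {}

  record Prim(String name) implements S {}

  static final class Var implements S {
    final Set<S> lower = new HashSet<>();
    final Set<S> upper = new HashSet<>();
  }

  record Compact(Set<Var> vars, Set<String> prims) {}

  // In positive position a variable stands for the union of its transitive lower
  // bounds; in negative position, the intersection of its upper bounds.
  static Compact compact(S sort, boolean positive) {
    Set<Var> vars = new HashSet<>();
    Set<String> prims = new HashSet<>();
    collect(sort, positive, vars, prims);
    return new Compact(vars, prims);
  }

  private static void collect(S sort, boolean positive, Set<Var> vars, Set<String> prims) {
    if (sort instanceof Prim p) {
      prims.add(p.name());
      return;
    }
    Var v = (Var) sort;
    if (!vars.add(v)) {
      return; // already visited; guards against cyclic bounds
    }
    for (S bound : (positive ? v.lower : v.upper)) {
      collect(bound, positive, vars, prims);
    }
  }

  public static void main(String[] args) {
    Var x = new Var();
    Var y = new Var();
    x.lower.add(new Prim("Int"));
    x.lower.add(y); // y's bounds are folded in transitively
    y.lower.add(new Prim("Bool"));
    System.out.println(compact(x, true).prims()); // Int and Bool, in some order
  }
}
```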
+ */ +public class SortInferencer { + private final Module mod; + + public SortInferencer(Module mod) { + this.mod = mod; + } + + /** + * Determine whether a Term is supported by the current SortInferencer algorithm. Supported terms + * can contain neither ambiguities, strict casts, nor parametric sorts. + */ + public static boolean isSupported(Term t) { + return !hasAmbiguity(t) && !hasStrictCast(t) && !hasParametricSorts(t); + } + + private static boolean hasAmbiguity(Term t) { + if (t instanceof Ambiguity) { + return true; + } + if (t instanceof Constant) { + return false; + } + return ((TermCons) t).items().stream().anyMatch(SortInferencer::hasAmbiguity); + } + + private static boolean hasStrictCast(Term t) { + if (t instanceof Ambiguity amb) { + return amb.items().stream().anyMatch(SortInferencer::hasStrictCast); + } + ProductionReference pr = (ProductionReference) t; + if (pr.production().klabel().isDefined()) { + KLabel klabel = pr.production().klabel().get(); + String label = klabel.name(); + if (label.equals("#SyntacticCast") || label.equals("#InnerCast")) { + return true; + } + } + if (t instanceof Constant) { + return false; + } + return ((TermCons) t).items().stream().anyMatch(SortInferencer::hasStrictCast); + } + + private static boolean hasParametricSorts(Term t) { + if (t instanceof Ambiguity amb) { + return amb.items().stream().anyMatch(SortInferencer::hasParametricSorts); + } + ProductionReference pr = (ProductionReference) t; + if (stream(pr.production().items()) + .filter(pi -> pi instanceof NonTerminal) + .map(pi -> ((NonTerminal) pi).sort()) + .anyMatch(s -> !s.params().isEmpty())) { + return true; + } + if (!pr.production().sort().params().isEmpty()) { + return true; + } + if (pr instanceof Constant) { + return false; + } + return ((TermCons) t).items().stream().anyMatch(SortInferencer::hasParametricSorts); + } + + /** + * Determine if a term is a rule which can be applied anywhere in a configuration, and thus does + * not permit the RHS sort 
to be wider than the LHS. Such a rule is either explicitly marked with + * an attribute, or is a function or macro rule. + * + * @param isAnywhere - Whether the Term was explicitly marked with an attribute such as anywhere, + * simplification, macro, etc. indicating that it is a rule which applies anywhere + */ + private static boolean isAnywhereRule(Term t, boolean isAnywhere) { + if (t instanceof Ambiguity) { + throw new AssertionError("Ambiguities are not yet supported!"); + } + t = stripBrackets(t); + if (t instanceof Constant) { + return false; + } + TermCons tc = (TermCons) t; + // For every #RuleContent production, the first non-terminal holds a #RuleBody + if (tc.production().sort().equals(Sorts.RuleContent())) { + assert tc.production().nonterminals().size() >= 1 + && tc.production().nonterminal(0).sort().equals(Sorts.RuleBody()); + return isAnywhereRule(tc.get(0), isAnywhere); + } + // For every #RuleBody production, the first non-terminal holds the actual K term + if (tc.production().sort().equals(Sorts.RuleBody())) { + assert tc.production().nonterminals().size() >= 1 + && tc.production().nonterminal(0).sort().equals(Sorts.K()); + return isAnywhereRule(tc.get(0), isAnywhere); + } + // This is the first actual K term we encounter after stripping away rule syntax, + // and should be a rewrite if this is anywhere rule. + if (tc.production().klabel().filter(k -> k.head().equals(KLabels.KREWRITE)).isDefined()) { + Term lhs = stripBrackets(tc.get(0)); + if (lhs instanceof Ambiguity) { + throw new AssertionError("Ambiguities are not yet supported!"); + } + ProductionReference lhsPr = (ProductionReference) lhs; + return isAnywhere + || lhsPr.production().att().contains(Att.FUNCTION()) + || lhsPr.production().att().getMacro().isDefined(); + } + return false; + } + + /** + * The main entry point of SortInferencer, inferring the sort of the input's variables and + * recording the results by inserting casts. 
+ * + * @param t - The Term to infer the sort of + * @param topSort - The expected sort of t + * @param isAnywhere - Whether t is a rule with an attribute indicating that the rule applies + * anywhere in a configuration (e.g. macro, simplification, anywhere, ...). + * @return If t is not well-sorted, then a set of errors. If t is well-sorted, then a new Term + * which is the same as t, but with each variable wrapped in a SemanticCast to its inferred + * type (returning an Ambiguity of all solutions when there are multiple possible sorts). + */ + public Either, Term> apply(Term t, Sort topSort, boolean isAnywhere) { + Set> monoRes; + try { + InferenceDriver driver = new InferenceDriver(mod.subsorts()); + BoundedSort itemSort = infer(t, isAnywhereRule(t, isAnywhere), driver); + BoundedSort topBoundedSort = driver.sortToBoundedSort(topSort, null); + driver.constrain(itemSort, topBoundedSort, (ProductionReference) t); + TermSort unsimplifiedRes = driver.getResult(t, topBoundedSort); + TermSort res = simplify(unsimplifiedRes.mapSorts(CompactSort::makeCompact)); + monoRes = monomorphize(res); + } catch (SortInferenceError e) { + Set errs = new HashSet<>(); + errs.add(e.asInnerParseError(t)); + return Left.apply(errs); + } + + Set items = new HashSet<>(); + for (TermSort mono : monoRes) { + items.add(insertCasts(t, mono, false)); + } + if (items.size() == 1) { + return Right.apply(items.iterator().next()); + } else { + return Right.apply(Ambiguity.apply(items)); + } + } + + /** + * Infer an unsimplified BoundedSort for a term. + * + * @param t - The term we want to infer the type of + * @param isAnywhereRule - Whether t is a rule which can be applied anywhere in a configuration + * @param driver - A driver maintaining all state during inference, including the sort of all + * variables as they are encountered. 
+ * @return The unsimplified sort of the input term + * @throws SortInferenceError - an exception indicating that the term is not well-typed + */ + private BoundedSort infer(Term t, boolean isAnywhereRule, InferenceDriver driver) + throws SortInferenceError { + if (t instanceof Ambiguity) { + throw new AssertionError("Ambiguities are not yet supported!"); + } + + ProductionReference pr = (ProductionReference) t; + if (pr instanceof Constant c) { + if (c.production().sort().equals(Sorts.KVariable()) + || c.production().sort().equals(Sorts.KConfigVar())) { + return driver.varSort(c); + } + return driver.sortToBoundedSort(c.production().sort(), pr); + } + + TermCons tc = (TermCons) pr; + if (isAnywhereRule + && tc.production().klabel().filter(k -> k.head().equals(KLabels.KREWRITE)).isDefined()) { + BoundedSort lhsSort = infer(tc.get(0), false, driver); + // To prevent widening, we constrain RHS's inferred sort <: LHS's declared sort. + // + // Note that we do actually need the LHS's declared sort. The LHS's inferred sort + // is a variable X with a bound L <: X, and constraining against X would just add a + // new lower bound aka permit widening. 
+ ProductionReference lhsDeclaredPr = (ProductionReference) stripBrackets(tc.get(0)); + BoundedSort lhsDeclaredSort = + driver.sortToBoundedSort(lhsDeclaredPr.production().sort(), lhsDeclaredPr); + BoundedSort rhsSort = infer(tc.get(1), false, driver); + driver.constrain(rhsSort, lhsDeclaredSort, (ProductionReference) tc.get(1)); + return lhsSort; + } + + for (int prodI = 0, tcI = 0; prodI < tc.production().items().size(); prodI++) { + if (!(tc.production().items().apply(prodI) instanceof NonTerminal nt)) { + continue; + } + BoundedSort expectedSort = driver.sortToBoundedSort(nt.sort(), pr); + BoundedSort childSort = infer(tc.get(tcI), isAnywhereRule, driver); + driver.constrain(childSort, expectedSort, pr); + tcI++; + } + BoundedSort resSort = new BoundedSort.Variable(); + driver.constrain(driver.sortToBoundedSort(tc.production().sort(), pr), resSort, pr); + return resSort; + } + + /** + * Perform co-occurrence analysis to remove redundant type variables and unify those that are + * indistinguishable. + */ + private TermSort simplify(TermSort res) { + Map, CoOccurrences> coOccurrences = + analyzeCoOccurrences(res, CoOccurMode.ALWAYS); + Map> varSubst = new HashMap<>(); + // Simplify away all those variables that only occur in negative (resp. positive) position. 
+ Set allVars = + coOccurrences.keySet().stream().map(Tuple2::_1).collect(Collectors.toSet()); + allVars.forEach( + (v) -> { + boolean negative = coOccurrences.containsKey(Tuple2.apply(v, false)); + boolean positive = coOccurrences.containsKey(Tuple2.apply(v, true)); + if ((negative && !positive) || (!negative && positive)) { + varSubst.put(v, Optional.empty()); + } + }); + + List pols = new ArrayList<>(); + pols.add(false); + pols.add(true); + + for (SortVariable v : allVars) { + if (varSubst.containsKey(v)) { + continue; + } + for (Boolean pol : pols) { + if (!coOccurrences.containsKey(Tuple2.apply(v, pol))) { + continue; + } + CoOccurrences vCoOccurs = coOccurrences.get(Tuple2.apply(v, pol)); + // v is not in varSubst already, so it must occur in both polarities + // thus this access is valid + CoOccurrences vOpCoOccurs = coOccurrences.get(Tuple2.apply(v, !pol)); + for (SortVariable w : vCoOccurs.vars()) { + if (v.equals(w) || varSubst.containsKey(w)) { + continue; + } + if (coOccurrences.containsKey(Tuple2.apply(w, pol)) + && coOccurrences.get(Tuple2.apply(w, pol)).vars().contains(v)) { + // v and w always co-occur in the given polarity, so we unify w into v + varSubst.put(w, Optional.of(new CompactSort(v))); + // we also need to update v's co-occurrences correspondingly + // (intersecting with w's) + CoOccurrences wOpCoOccurs = coOccurrences.get(Tuple2.apply(w, !pol)); + vOpCoOccurs.vars().retainAll(wOpCoOccurs.vars()); + vOpCoOccurs.ctors().retainAll(wOpCoOccurs.ctors()); + vOpCoOccurs.vars().add(v); + } + } + for (SortHead ctor : vCoOccurs.ctors()) { + // This is not a variable, so check if we have a sandwich ctor <: v <: ctor + // and can thus simplify away v + if (vOpCoOccurs.ctors().contains(ctor)) { + varSubst.put(v, Optional.empty()); + } + } + } + } + + return res.mapSorts((c, p) -> c.substitute(varSubst)); + } + + /** + * Modes for the co-occurrence analysis. A variable is said to co-occur positively (resp. 
negatively) with another type if they occur in the same type union (resp. intersection). + */ + private enum CoOccurMode { + /** + * For each variable and polarity, record only those sorts which always co-occur with the variable + * in every single position. This is the co-occurrence analysis described in SimpleSub. + */ + ALWAYS, + /** + * For each variable and polarity, record any sort that ever co-occurs with the variable in at + * least one position. In effect, this records all the bounds on the given variable. + */ + EVER + } + + private record CoOccurrences(Set<SortVariable> vars, Set<SortHead> ctors) {} + + /** + * Compute the co-occurrences within a TermSort based on the given mode. + * + * @param res - The TermSort to analyze + * @param mode - Mode indicating what type of analysis to perform. See documentation for the + * CoOccurMode enum above. + * @return The result of the co-occurrence analysis. + */ + private Map<Tuple2<SortVariable, Boolean>, CoOccurrences> analyzeCoOccurrences( + TermSort<CompactSort> res, CoOccurMode mode) { + Map<Tuple2<SortVariable, Boolean>, CoOccurrences> coOccurrences = new HashMap<>(); + res.forEachSort((s, pol) -> updateCoOccurrences(s, pol, mode, coOccurrences)); + return coOccurrences; + } + + /** + * Update the co-occurrence analysis results so far to account for the occurrences within a sort. + * + * @param sort - The sort which we are processing + * @param polarity - The polarity of the provided sort + * @param coOccurrences - mutated to record all co-occurrences for each variable occurring in sort + */ + private void updateCoOccurrences( + CompactSort sort, + boolean polarity, + CoOccurMode mode, + Map<Tuple2<SortVariable, Boolean>, CoOccurrences> coOccurrences) { + for (SortVariable var : sort.vars()) { + Tuple2<SortVariable, Boolean> polVar = Tuple2.apply(var, polarity); + if (coOccurrences.containsKey(polVar)) { + CoOccurrences coOccurs = coOccurrences.get(polVar); + switch (mode) { + case ALWAYS -> { + coOccurs.vars().retainAll(sort.vars()); + coOccurs.ctors().retainAll(sort.ctors()); + } + case EVER -> { + coOccurs.vars().addAll(sort.vars()); +
coOccurs.ctors().addAll(sort.ctors()); + } + } + } else { + coOccurrences.put( + polVar, new CoOccurrences(new HashSet<>(sort.vars()), new HashSet<>(sort.ctors()))); + } + } + } + + /** + * Monomorphize a TermSort. + * + * @param res - The result to monomorphize + * @return A set of all possible monomorphizations of the input result + * @throws SortInferenceError - An error if there are no monomorphizations which can actually be + * produced from the subsort lattice. + */ + private Set> monomorphize(TermSort res) throws SortInferenceError { + Map, CoOccurrences> bounds = + analyzeCoOccurrences(res, CoOccurMode.EVER); + + // Produce all valid instantiations by monomorphizing one variable at a time + Set allVars = + bounds.keySet().stream().map(Tuple2::_1).collect(Collectors.toSet()); + Set> instantiations = new HashSet<>(); + instantiations.add(new HashMap<>()); + for (SortVariable var : allVars) { + Set> newInstantiations = new HashSet<>(); + for (Map instant : instantiations) { + newInstantiations.addAll(monomorphizeInVar(instant, var, bounds)); + } + if (newInstantiations.isEmpty()) { + throw new MonomorphizationError(res.term()); + } + instantiations = newInstantiations; + } + + Set> monos = new HashSet<>(); + SortInferenceError lastError = null; + for (Map inst : instantiations) { + Either> monoRes = realizeTermSort(res, inst); + if (monoRes.isLeft()) { + lastError = monoRes.left().get(); + } else { + monos.add(monoRes.right().get()); + } + } + if (monos.isEmpty()) { + assert lastError != null; + throw lastError; + } + return monos; + } + + /** + * Update the instantiation of variables so far to also include all possible instantiations of the + * provided sort variable. + * + * @param instantiation - A particular instantiation for some subset of the SortVariables + * occurring in our term. 
+ * @param var - A SortVariable that we wish to add to our instantiation + * @param bounds - A map of entries (v, p) |-> C where C records every sort that ever co-occurs + * with the Variable v in the polarity p in our term. + * @return A set of all possible new instantiations accounting for var + */ + private Set> monomorphizeInVar( + Map instantiation, + SortVariable var, + Map, CoOccurrences> bounds) { + + // Record the bounds in each polarity, then search the subsort poset for all solutions + // that satisfy these bounds + Map> polBounds = new HashMap<>(); + polBounds.put(true, new HashSet<>()); + polBounds.put(false, new HashSet<>()); + + for (Entry> polBound : polBounds.entrySet()) { + Tuple2 polVar = Tuple2.apply(var, polBound.getKey()); + if (!bounds.containsKey(polVar)) { + continue; + } + CoOccurrences bound = bounds.get(polVar); + for (SortVariable bVar : bound.vars()) { + // If bVar hasn't been instantiated yet, we can simply ignore it. If any instantiation we + // produce here is invalid when considering bVar, we will prune it later when bVar + // is actually instantiated. + if (instantiation.containsKey(bVar)) { + polBound.getValue().add(instantiation.get(bVar)); + } + } + for (SortHead bCtor : bound.ctors()) { + polBound.getValue().add(new org.kframework.kore.ADT.Sort(bCtor.name(), Seq())); + } + } + + Set range = mod.subsorts().upperBounds(polBounds.get(true)); + range.retainAll(mod.subsorts().lowerBounds(polBounds.get(false))); + + Set> insts = new HashSet<>(); + for (Sort sort : range) { + Map inst = new HashMap<>(instantiation); + inst.put(var, sort); + insts.add(inst); + } + return insts; + } + + /** + * Apply an instantiation to a {@code TermSort}, then compute type joins/meets to + * collapse it to a {@code TermSort}. 
+ * + * @param res - The provided TermSort to realize + * @param instantiation - A concrete value for each variable occurring in res + * @return An equivalent {@code TermSort} + */ + private Either> realizeTermSort( + TermSort res, Map instantiation) { + Either sortRes = + res.sort().asSort(true, instantiation, mod.subsorts()); + if (sortRes.isLeft()) { + return Left.apply(new LatticeOpError(sortRes.left().get(), res.term(), Optional.empty())); + } + Sort sort = sortRes.right().get(); + Map varSorts = new HashMap<>(); + for (Entry entry : res.varSorts().entrySet()) { + Either varRes = + entry.getValue().asSort(false, instantiation, mod.subsorts()); + if (varRes.isLeft()) { + CompactSort.LatticeOpError latticeErr = varRes.left().get(); + if (entry.getKey() instanceof VariableId.Anon anon) { + return Left.apply( + new LatticeOpError(latticeErr, anon.constant(), Optional.of("variable"))); + } + if (entry.getKey() instanceof VariableId.Named named) { + return Left.apply( + new LatticeOpError(latticeErr, res.term(), Optional.of("variable " + named.name()))); + } + throw new AssertionError("VariableId should be either Anon or Named"); + } + varSorts.put(entry.getKey(), varRes.right().get()); + } + return Right.apply(new TermSort<>(res.term(), sort, varSorts)); + } + + /** + * Insert SemanticCasts around each variable casting it to the appropriate Sort. 
+ * + * @param t - The term to insert casts on + * @param sorts - The inferred sorts of t + * @param existingCast - Whether t is already wrapped in an existing cast + */ + private Term insertCasts(Term t, TermSort sorts, boolean existingCast) { + if (t instanceof Ambiguity) { + throw new AssertionError("Ambiguities are not yet supported!"); + } + + ProductionReference pr = (ProductionReference) t; + if (pr instanceof Constant c) { + if (c.production().sort().equals(Sorts.KVariable()) + || c.production().sort().equals(Sorts.KConfigVar())) { + Sort inferred = sorts.varSorts().get(VariableId.apply(c)); + if (!existingCast) { + Production cast = + mod.productionsFor().apply(KLabel("#SemanticCastTo" + inferred.toString())).head(); + return TermCons.apply(ConsPStack.singleton(t), cast, t.location(), t.source()); + } + } + return c; + } + + TermCons tc = (TermCons) pr; + boolean isCast = + tc.production().klabel().filter(k -> k.name().startsWith("#SemanticCastTo")).isDefined(); + for (int i = 0; i < tc.items().size(); i++) { + tc = tc.with(i, insertCasts(tc.get(i), sorts, isCast)); + } + return tc; + } + + private static Term stripBrackets(Term tc) { + Term child = tc; + while (child instanceof TermCons + && ((TermCons) child).production().att().contains(Att.BRACKET())) { + child = ((TermCons) child).get(0); + } + return child; + } +} diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortVariable.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortVariable.java new file mode 100644 index 00000000000..a1c5dddbd8e --- /dev/null +++ b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortVariable.java @@ -0,0 +1,14 @@ +// Copyright (c) K Team. All Rights Reserved. +package org.kframework.parser.inner.disambiguation.inference; + +/** + * A variable which could be instantiated with any Sort. + * + *

+ * <p>This class is necessary to ensure that all distinct variable instances compare unequal. In
+ * particular, we can't just use {@code Sort("X")} to represent a sort variable, since {@code new
+ * Sort("X").equals(new Sort("X"))} would hold even when the {@code X} arises from two different
+ * ProductionReferences.
+ */
+public class SortVariable {
+  public SortVariable() {}
+}
diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/TermSort.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/TermSort.java
new file mode 100644
index 00000000000..0da17e28c4e
--- /dev/null
+++ b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/TermSort.java
@@ -0,0 +1,38 @@
+// Copyright (c) K Team. All Rights Reserved.
+package org.kframework.parser.inner.disambiguation.inference;
+
+import java.util.Map;
+import java.util.function.BiConsumer;
+import java.util.function.BiFunction;
+import java.util.stream.Collectors;
+import org.kframework.parser.Term;
+
+/**
+ * The sort of a Term and all its variables.
+ *
+ * @param term - The term which has this particular sort
+ * @param sort - The top level sort of the overall term
+ * @param varSorts - The sort of each variable occurring in the term
+ * @param <S> The particular type of sort we are considering (will be BoundedSort, CompactSort, and
+ *     Sort depending on the stage of inference).
+ */
+public record TermSort<S>(Term term, S sort, Map<VariableId, S> varSorts) {
+  /**
+   * Map over all contained sorts in their polarity.
+   *
+   * @param func - A map taking both a sort and a polarity then producing a new sort.
+   */
+  public <T> TermSort<T> mapSorts(BiFunction<S, Boolean, T> func) {
+    T newSort = func.apply(sort, true);
+    Map<VariableId, T> newVarSorts =
+        varSorts().entrySet().stream()
+            .collect(Collectors.toMap((Map.Entry::getKey), (e) -> func.apply(e.getValue(), false)));
+    return new TermSort<>(term, newSort, newVarSorts);
+  }
+
+  /** Apply an action to all contained sorts in their polarity.
+   */
+  public void forEachSort(BiConsumer<S, Boolean> action) {
+    action.accept(sort, true);
+    varSorts().values().forEach((v) -> action.accept(v, false));
+  }
+}
diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/VariableId.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/VariableId.java
new file mode 100644
index 00000000000..6d2db775797
--- /dev/null
+++ b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/VariableId.java
@@ -0,0 +1,42 @@
+// Copyright (c) K Team. All Rights Reserved.
+package org.kframework.parser.inner.disambiguation.inference;
+
+import org.kframework.attributes.Att;
+import org.kframework.compile.ResolveAnonVar;
+import org.kframework.kore.ADT.KVariable;
+import org.kframework.parser.Constant;
+
+/**
+ * A type representing a particular variable. Specifically, either
+ *
+ * <ul>
+ *   <li>Named, which just records the variable's name
+ *   <li>Anon, which wraps the particular Constant of an anonymous variable in order to provide it
+ *       reference semantics (lest all vars named "_" compare equal).
+ * </ul>
+ */
+public sealed interface VariableId {
+  static VariableId apply(Constant var) {
+    if (ResolveAnonVar.isAnonVarOrNamedAnonVar(new KVariable(var.value(), Att.empty()))) {
+      return new Anon(var);
+    }
+    return new Named(var.value());
+  }
+
+  record Named(String name) implements VariableId {}
+
+  record Anon(Constant constant) implements VariableId {
+    @Override
+    public boolean equals(Object o) {
+      if (o instanceof Anon a) {
+        return this.constant == a.constant;
+      }
+      return false;
+    }
+
+    @Override
+    public int hashCode() {
+      return System.identityHashCode(constant);
+    }
+  }
+}
diff --git a/kore/src/main/java/org/kframework/utils/errorsystem/KEMException.java b/kore/src/main/java/org/kframework/utils/errorsystem/KEMException.java
index fcb85e7cd25..745aed6fb64 100644
--- a/kore/src/main/java/org/kframework/utils/errorsystem/KEMException.java
+++ b/kore/src/main/java/org/kframework/utils/errorsystem/KEMException.java
@@ -5,7 +5,6 @@
 import org.kframework.attributes.HasLocation;
 import org.kframework.attributes.Location;
 import org.kframework.attributes.Source;
-import org.kframework.parser.Term;
 import org.kframework.utils.errorsystem.KException.ExceptionType;
 import org.kframework.utils.errorsystem.KException.KExceptionGroup;

@@ -138,14 +137,14 @@ public static KEMException innerParserError(String message, Source source, Locat
         ExceptionType.ERROR, KExceptionGroup.INNER_PARSER, message, null, location, source);
   }

-  public static KEMException innerParserError(String message, Term t) {
+  public static KEMException innerParserError(String message, HasLocation node) {
     return create(
         ExceptionType.ERROR,
         KExceptionGroup.INNER_PARSER,
         message,
         null,
-        t.location().orElse(null),
-        t.source().orElse(null));
+        node.location().orElse(null),
+        node.source().orElse(null));
   }

   public static KEMException innerParserError(
diff --git a/kore/src/main/scala/org/kframework/builtin/Sorts.scala b/kore/src/main/scala/org/kframework/builtin/Sorts.scala
index 68de9ee337f..61d1843e877
100644
--- a/kore/src/main/scala/org/kframework/builtin/Sorts.scala
+++ b/kore/src/main/scala/org/kframework/builtin/Sorts.scala
@@ -14,6 +14,7 @@ object Sorts {
   val KString = Sort("KString")
   val KBool = Sort("KBool")
   val RuleContent = Sort("#RuleContent")
+  val RuleBody = Sort("#RuleBody")
   val Bool = Sort("Bool")
   val Int = Sort("Int")
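
Reviewer note: the reference-equality trick in `VariableId.Anon` above can be exercised in isolation. The sketch below is a hedged, self-contained approximation: `Constant` here is a hypothetical stand-in record for `org.kframework.parser.Constant`, and the class name `AnonIdentityDemo` is invented for the example; only `Anon`'s overridden `equals`/`hashCode` mirror the patch.

```java
import java.util.HashMap;
import java.util.Map;

public class AnonIdentityDemo {
  // Hypothetical stand-in for org.kframework.parser.Constant: two "_" constants
  // are value-equal as records, yet denote different occurrences in a term.
  record Constant(String value) {}

  // Mirrors VariableId.Anon from the patch: equality is object identity.
  record Anon(Constant constant) {
    @Override
    public boolean equals(Object o) {
      return o instanceof Anon a && this.constant == a.constant;
    }

    @Override
    public int hashCode() {
      return System.identityHashCode(constant);
    }
  }

  public static void main(String[] args) {
    Constant c1 = new Constant("_");
    Constant c2 = new Constant("_");

    // Plain record equality is structural, so the two "_" constants compare equal.
    System.out.println(c1.equals(c2)); // true

    // Wrapped in Anon they stay distinct, so each anonymous variable keeps
    // its own entry in a varSorts-style map.
    Map<Anon, String> varSorts = new HashMap<>();
    varSorts.put(new Anon(c1), "Int");
    varSorts.put(new Anon(c2), "Bool");
    System.out.println(varSorts.size()); // 2
  }
}
```

This is why the patch cannot rely on record-generated equality for `Anon`: all anonymous variables share the name `"_"`, and identity on the wrapped `Constant` is what keeps their inferred sorts separate.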