Skip to content

Commit

Permalink
Add lowerBounds and upperBounds methods to POSet (#3733)
Browse files Browse the repository at this point in the history
This PR does the following:
- Cleans up all the warnings in `POSet.scala`
- Adds methods to `POSet` to compute all lower bounds or upper bounds of
a given set of elements.
- Modifies the least upper bound computation in `AddSortInjections.java`
to use these new methods

This is just to remove some duplication when computing joins and meets
in the new type inference engine.

---------

Co-authored-by: rv-jenkins <[email protected]>
  • Loading branch information
Scott-Guest and rv-jenkins authored Oct 20, 2023
1 parent bdde319 commit d18ea8f
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 75 deletions.
47 changes: 15 additions & 32 deletions kernel/src/main/java/org/kframework/compile/AddSortInjections.java
Original file line number Diff line number Diff line change
Expand Up @@ -453,48 +453,31 @@ private static Sort lub(Collection<Sort> entries, Sort expectedSort, HasLocation
if (filteredEntries.isEmpty()) { // if all sorts are parameters, take the first
return entries.iterator().next();
}
Set<Sort> bounds = upperBounds(filteredEntries, mod);

Set<Sort> nonParametric =
filteredEntries.stream().filter(s -> s.params().isEmpty()).collect(Collectors.toSet());
Set<Sort> bounds = mutable(mod.subsorts().upperBounds(immutable(nonParametric)));
// Anything less than KBott or greater than K is a syntactic sort from kast.md which should not be considered
bounds.removeIf(s -> mod.subsorts().lessThanEq(s, Sorts.KBott()) || mod.subsorts().greaterThan(s, Sorts.K()));
if (expectedSort != null && !expectedSort.name().equals(SORTPARAM_NAME)) {
bounds.removeIf(s -> !mod.subsorts().lessThanEq(s, expectedSort));
}

// For parametric sorts, each bound must bound at least one instantiation
Set<Sort> parametric =
filteredEntries.stream().filter(s -> ! s.params().isEmpty()).collect(Collectors.toSet());
bounds.removeIf(bound ->
parametric.stream().anyMatch(param ->
stream(mod.definedInstantiations().apply(param.head()))
.noneMatch(inst -> mod.subsorts().lessThanEq(inst, bound))));

Set<Sort> lub = mod.subsorts().minimal(bounds);
if (lub.size() != 1) {
throw KEMException.internalError("Could not compute least upper bound for rewrite sort. Possible candidates: " + lub, loc);
}
return lub.iterator().next();
}

private static Set<Sort> upperBounds(Collection<Sort> bounds, Module mod) {
Set<Sort> maxs = new HashSet<>();
nextsort:
for (Sort sort : iterable(mod.allSorts())) { // for every declared sort
// Sorts at or below KBott, or above K, are assumed to be
// sorts from kast.k representing meta-syntax that is not a real sort.
// This is done to prevent variables from being inferred as KBott or
// as KList.
if (mod.subsorts().lessThanEq(sort, Sorts.KBott()))
continue;
if (mod.subsorts().greaterThan(sort, Sorts.K()))
continue;
for (Sort bound : bounds)
if (bound.params().isEmpty()) {
if (!mod.subsorts().lessThanEq(bound, sort))
continue nextsort;
} else {
boolean any = false;
for (Sort instantiation : iterable(mod.definedInstantiations().apply(bound.head()))) {
if (mod.subsorts().lessThanEq(instantiation, sort)) {
any = true;
}
}
if (!any)
continue nextsort;
}
maxs.add(sort);
}
return maxs;
}

private Sort freshSortParam() {
return Sort(SORTPARAM_NAME, Sort("Q" + freshSortParamCounter++));
}
Expand Down
97 changes: 54 additions & 43 deletions kore/src/main/scala/org/kframework/POSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@ package org.kframework
import org.kframework.utils.errorsystem.KEMException

import java.util
import java.util.Optional
import collection._
import scala.annotation.tailrec

/**
* A partially ordered set based on an initial set of direct relations.
*/
class POSet[T](val directRelations: Set[(T, T)]) extends Serializable {

// convert the input set of relations to Map form for performance
private val directRelationsMap: Map[T, Set[T]] = directRelations groupBy { _._1 } mapValues { _ map { _._2 } toSet } map identity

Expand All @@ -26,6 +25,7 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable {
* The implementation is simple. It links each element to the successors of its successors.
* TODO: there may be a more efficient algorithm (low priority)
*/
@tailrec
private def transitiveClosure(relations: Map[T, Set[T]]): Map[T, Set[T]] = {
val newRelations = relations map {
case (start, succ) =>
Expand All @@ -44,23 +44,31 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable {
* @param current element
* @param path so far
*/
private def constructAndThrowCycleException(start: T, current: T, path: Seq[T]) {
private def constructAndThrowCycleException(start: T, current: T, path: Seq[T]): Unit = {
val currentPath = path :+ current
val succs = directRelationsMap.getOrElse(current, Set())
if (succs.contains(start)) {
throw KEMException.compilerError("Illegal circular relation: " + (currentPath :+ start).mkString(" < "));
throw KEMException.compilerError("Illegal circular relation: " + (currentPath :+ start).mkString(" < "))
}
succs foreach { constructAndThrowCycleException(start, _, currentPath) }
}

/**
* All the relations of the POSet, including the transitive ones.
*
* Concretely, a map from each element of the poset to the set of elements greater than it.
*/
val relations: Map[T, Set[T]] = transitiveClosure(directRelationsMap)

/**
* A map from each element of the poset to the set of elements less than it.
*/
val relations = transitiveClosure(directRelationsMap)
lazy val relationsOp: Map[T, Set[T]] =
relations.toSet[(T, Set[T])].flatMap { case (x, ys) => ys.map(_ -> x) }.groupBy(_._1).mapValues(_.map(_._2))

def <(x: T, y: T): Boolean = relations.get(x).exists(_.contains(y))
def >(x: T, y: T): Boolean = relations.get(y).exists(_.contains(x))
def ~(x: T, y: T) = <(x, y) || <(y, x)
def ~(x: T, y: T): Boolean = <(x, y) || <(y, x)

/**
* Returns true if x < y
Expand All @@ -77,34 +85,29 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable {
/**
* Returns true if y < x or y < x
*/
def inSomeRelation(x: T, y: T) = this.~(x, y)
def inSomeRelationEq(x: T, y: T) = x == y || this.~(x, y)
def inSomeRelation(x: T, y: T): Boolean = this.~(x, y)
def inSomeRelationEq(x: T, y: T): Boolean = x == y || this.~(x, y)

/**
* Returns an Optional of the least upper bound if it exists, or an empty Optional otherwise.
* Return the set of all upper bounds of the input.
*/
lazy val leastUpperBound: Optional[T] = lub match {
case Some(x) => Optional.of(x)
case None => Optional.empty()
}
def upperBounds(sorts: Iterable[T]): Set[T] =
if (sorts.isEmpty) elements else POSet.upperBounds(sorts, relations)

/**
* Return the set of all lower bounds of the input.
*/
def lowerBounds(sorts: Iterable[T]): Set[T] =
if (sorts.isEmpty) elements else POSet.upperBounds(sorts, relationsOp)

lazy val lub: Option[T] = {
val candidates = relations.values reduce { (a, b) => a & b }

if (candidates.size == 0)
None
else if (candidates.size == 1)
Some(candidates.head)
else {
val allPairs = for (a <- candidates; b <- candidates) yield { (a, b) }
if (allPairs exists { case (a, b) => ! ~(a, b) })
None
else
Some(
candidates.min(new Ordering[T]() {
def compare(x: T, y: T) = if (x < y) -1 else if (x > y) 1 else 0
}))
}
val mins = minimal(upperBounds(elements))
if (mins.size == 1) Some(mins.head) else None
}

lazy val glb: Option[T] = {
val maxs = maximal(lowerBounds(elements))
if (maxs.size == 1) Some(maxs.head) else None
}

lazy val asOrdering: Ordering[T] = (x: T, y: T) => if (lessThanEq(x, y)) -1 else if (lessThanEq(y, x)) 1 else 0
Expand All @@ -113,33 +116,33 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable {
* Return the subset of items from the argument which are not
* less than any other item.
*/
def maximal(sorts : Iterable[T]) : Set[T] =
def maximal(sorts: Iterable[T]): Set[T] =
sorts.filter(s1 => !sorts.exists(s2 => lessThan(s1,s2))).toSet

def maximal(sorts : util.Collection[T]) : util.Set[T] = {
import scala.collection.JavaConversions._
maximal(sorts : Iterable[T])
def maximal(sorts: util.Collection[T]): util.Set[T] = {
import scala.collection.JavaConverters._
maximal(sorts.asScala).asJava
}

/**
* Return the subset of items from the argument which are not
* greater than any other item.
*/
def minimal(sorts : Iterable[T]) : Set[T] =
def minimal(sorts: Iterable[T]): Set[T] =
sorts.filter(s1 => !sorts.exists(s2 => >(s1,s2))).toSet

def minimal(sorts : util.Collection[T]) : util.Set[T] = {
import scala.collection.JavaConversions._
minimal(sorts : Iterable[T])
def minimal(sorts: util.Collection[T]): util.Set[T] = {
import scala.collection.JavaConverters._
minimal(sorts.asScala).asJava
}

override def toString() = {
"POSet(" + (relations flatMap { case (from, tos) => tos map { case to => from + "<" + to } }).mkString(",") + ")"
override def toString: String = {
"POSet(" + (relations flatMap { case (from, tos) => tos map { to => from + "<" + to } }).mkString(",") + ")"
}

override def hashCode = relations.hashCode()
override def hashCode: Int = relations.hashCode()

override def equals(that: Any) = that match {
override def equals(that: Any): Boolean = that match {
case that: POSet[_] => relations == that.relations
case _ => false
}
Expand All @@ -153,7 +156,15 @@ object POSet {
* Import this for Scala syntactic sugar.
*/
implicit class PO[T](x: T)(implicit val po: POSet[T]) {
def <(y: T) = po.<(x, y)
def >(y: T) = po.>(x, y)
def <(y: T): Boolean = po.<(x, y)
def >(y: T): Boolean = po.>(x, y)
}

/**
* Return the set of all elements which are greater than or equal to each element of the input,
* using the provided relations map. Input must be non-empty.
*/
private def upperBounds[T](sorts: Iterable[T], relations: Map[T, Set[T]]): Set[T] =
(((sorts filterNot relations.keys.toSet[T]) map {Set.empty + _}) ++
((relations filterKeys sorts.toSet) map { case (k, v) => v + k })) reduce { (a, b) => a & b }
}
8 changes: 8 additions & 0 deletions kore/src/test/scala/org/kframework/POSetTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,12 @@ class POSetTest {
assertEquals(None, POSet(b1 -> b2, b2 -> b3, b4 -> b5).lub)
assertEquals(None, POSet(b1 -> b2, b2 -> b3, b2 -> b4).lub)
}

@Test def glb() {
assertEquals(Some(b2), POSet(b2 -> b1).glb)
assertEquals(Some(b3), POSet(b3 -> b1, b3 -> b2).glb)
assertEquals(Some(b4), POSet(b3 -> b1, b3 -> b2, b4 -> b3).glb)
assertEquals(None, POSet(b2 -> b1, b3 -> b2, b5 -> b4).glb)
assertEquals(None, POSet(b2 -> b1, b3 -> b2, b4 -> b2).glb)
}
}

0 comments on commit d18ea8f

Please sign in to comment.