Skip to content

Commit

Permalink
Use a different rule for NotNullInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
noti0na1 committed Oct 11, 2024
1 parent 404098f commit 273c867
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 47 deletions.
44 changes: 14 additions & 30 deletions compiler/src/dotty/tools/dotc/typer/Nullables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,48 +53,35 @@ object Nullables:
TypeBoundsTree(lo, hiTree, alias)

/** A set of val or var references that are known to be not null,
* a set of variable references that are not known (anymore) to be not null,
* plus a set of variables that are known to be not null at any point.
* plus a set of variable references that are once assigned to null.
*/
case class NotNullInfo(asserted: Set[TermRef], retracted: Set[TermRef], onceRetracted: Set[TermRef]):
assert((asserted & retracted).isEmpty)
assert(retracted.subsetOf(onceRetracted))

case class NotNullInfo(asserted: Set[TermRef], retracted: Set[TermRef]):
def isEmpty = this eq NotNullInfo.empty

def retractedInfo = NotNullInfo(Set(), retracted, onceRetracted)

def onceRetractedInfo = NotNullInfo(Set(), onceRetracted, onceRetracted)
def retractedInfo = NotNullInfo(Set(), retracted)

/** The sequential combination with another not-null info */
def seq(that: NotNullInfo): NotNullInfo =
if this.isEmpty then that
else if that.isEmpty then this
else NotNullInfo(
this.asserted.union(that.asserted).diff(that.retracted),
this.retracted.union(that.retracted).diff(that.asserted),
this.onceRetracted.union(that.onceRetracted))
this.asserted.diff(that.retracted).union(that.asserted),
this.retracted.union(that.retracted))

/** The alternative path combination with another not-null info. Used to merge
* the nullability info of the two branches of an if.
*/
def alt(that: NotNullInfo): NotNullInfo =
NotNullInfo(
this.asserted.intersect(that.asserted),
this.retracted.union(that.retracted),
this.onceRetracted.union(that.onceRetracted))
NotNullInfo(this.asserted.intersect(that.asserted), this.retracted.union(that.retracted))

def withOnceRetracted(that: NotNullInfo): NotNullInfo =
if that.isEmpty then this
else NotNullInfo(this.asserted, this.retracted, this.onceRetracted.union(that.onceRetracted))
def withRetracted(that: NotNullInfo): NotNullInfo =
NotNullInfo(this.asserted, this.retracted.union(that.retracted))

object NotNullInfo:
val empty = new NotNullInfo(Set(), Set(), Set())
val empty = new NotNullInfo(Set(), Set())
def apply(asserted: Set[TermRef], retracted: Set[TermRef]): NotNullInfo =
apply(asserted, retracted, retracted)
def apply(asserted: Set[TermRef], retracted: Set[TermRef], onceRetracted: Set[TermRef]): NotNullInfo =
if asserted.isEmpty && onceRetracted.isEmpty then empty
else new NotNullInfo(asserted, retracted, onceRetracted)
if asserted.isEmpty && retracted.isEmpty then empty
else new NotNullInfo(asserted, retracted)
end NotNullInfo

/** A pair of not-null sets, depending on whether a condition is `true` or `false` */
Expand Down Expand Up @@ -247,16 +234,13 @@ object Nullables:
* or retractions in `info` supersede infos in existing entries of `infos`.
*/
def extendWith(info: NotNullInfo) =
if info.isEmpty
|| info.asserted.forall(infos.impliesNotNull(_))
&& !info.retracted.exists(infos.impliesNotNull(_))
then infos
if info.isEmpty then infos
else info :: infos

/** Retract all references to mutable variables */
def retractMutables(using Context) =
val mutables = infos.foldLeft(Set[TermRef]())((ms, info) =>
ms.union(info.asserted.filter(_.symbol.is(Mutable))))
val mutables = infos.foldLeft(Set[TermRef]()):
(ms, info) => ms.union(info.asserted.filter(_.symbol.is(Mutable)))
infos.extendWith(NotNullInfo(Set(), mutables))

end extension
Expand Down
34 changes: 18 additions & 16 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1544,9 +1544,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def elsePathInfo = cond1.notNullInfoIf(false).seq(result.elsep.notNullInfo)
result.withNotNullInfo(
if result.thenp.tpe.isRef(defn.NothingClass) then
elsePathInfo.withOnceRetracted(thenPathInfo)
elsePathInfo.withRetracted(thenPathInfo)
else if result.elsep.tpe.isRef(defn.NothingClass) then
thenPathInfo.withOnceRetracted(elsePathInfo)
thenPathInfo.withRetracted(elsePathInfo)
else thenPathInfo.alt(elsePathInfo)
)
end typedIf
Expand Down Expand Up @@ -2143,9 +2143,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedMatchFinish(tree: untpd.Match, sel: Tree, wideSelType: Type, cases: List[untpd.CaseDef], pt: Type)(using Context): Tree = {
val cases1 = harmonic(harmonize, pt)(typedCases(cases, sel, wideSelType, pt.dropIfProto))
.asInstanceOf[List[CaseDef]]
var nni = sel.notNullInfo
if cases1.nonEmpty then nni = nni.seq(cases1.map(_.notNullInfo).reduce(_.alt(_)))
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).withNotNullInfo(nni)
var nnInfo = sel.notNullInfo
if cases1.nonEmpty then nnInfo = nnInfo.seq(cases1.map(_.notNullInfo).reduce(_.alt(_)))
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).withNotNullInfo(nnInfo)
}

def typedCases(cases: List[untpd.CaseDef], sel: Tree, wideSelType0: Type, pt: Type)(using Context): List[CaseDef] =
Expand Down Expand Up @@ -2327,7 +2327,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val capabilityProof = caughtExceptions.reduce(OrType(_, _, true))
untpd.Block(makeCanThrow(capabilityProof), expr)

def typedTry(tree: untpd.Try, pt: Type)(using Context): Try = {
def typedTry(tree: untpd.Try, pt: Type)(using Context): Try =
var nnInfo = NotNullInfo.empty
val expr2 :: cases2x = harmonic(harmonize, pt) {
// We want to type check tree.expr first to comput NotNullInfo, but `addCanThrowCapabilities`
// uses the types of patterns in `tree.cases` to determine the capabilities.
Expand All @@ -2339,25 +2340,26 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val casesEmptyBody1 = tree.cases.mapconserve(cpy.CaseDef(_)(body = EmptyTree))
val casesEmptyBody2 = typedCases(casesEmptyBody1, EmptyTree, defn.ThrowableType, WildcardType)
val expr1 = typed(addCanThrowCapabilities(tree.expr, casesEmptyBody2), pt.dropIfProto)
val casesCtx = ctx.addNotNullInfo(expr1.notNullInfo.retractedInfo)

// Since we don't know at which point the the exception is thrown in the body,
// we have to collect any reference that is once retracted.
nnInfo = expr1.notNullInfo.retractedInfo

val casesCtx = ctx.addNotNullInfo(nnInfo)
val cases1 = typedCases(tree.cases, EmptyTree, defn.ThrowableType, pt.dropIfProto)(using casesCtx)
expr1 :: cases1
}: @unchecked
val cases2 = cases2x.asInstanceOf[List[CaseDef]]

// Since we don't know at which point the the exception is thrown in the body,
// we have to collect any reference that is once retracted.
var nni = expr2.notNullInfo.onceRetractedInfo
// It is possible to have non-exhaustive cases, and some exceptions are thrown and not caught.
// Therefore, the code in the finallizer and after the try block can only rely on the retracted
// info from the cases' body.
if cases2.nonEmpty then nni = nni.seq(cases2.map(_.notNullInfo.retractedInfo).reduce(_.alt(_)))

val finalizer1 = typed(tree.finalizer, defn.UnitType)(using ctx.addNotNullInfo(nni))
nni = nni.seq(finalizer1.notNullInfo)
if cases2.nonEmpty then
nnInfo = nnInfo.seq(cases2.map(_.notNullInfo.retractedInfo).reduce(_.alt(_)))

assignType(cpy.Try(tree)(expr2, cases2, finalizer1), expr2, cases2).withNotNullInfo(nni)
}
val finalizer1 = typed(tree.finalizer, defn.UnitType)(using ctx.addNotNullInfo(nnInfo))
nnInfo = nnInfo.seq(finalizer1.notNullInfo)
assignType(cpy.Try(tree)(expr2, cases2, finalizer1), expr2, cases2).withNotNullInfo(nnInfo)

def typedTry(tree: untpd.ParsedTry, pt: Type)(using Context): Try =
val cases: List[untpd.CaseDef] = tree.handler match
Expand Down
19 changes: 18 additions & 1 deletion tests/explicit-nulls/neg/i21619.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,21 @@ def test4: String =
x = ""
catch
case _ =>
x.replace("", "") // error
x.replace("", "") // error

def test5: Unit =
var x: String | Null = null
var y: String | Null = null
x = ""
y = ""
var i: Int = 1
try
i match
case _ =>
x = null
throw new Exception()
x = ""
catch
case _ =>
val z1: String = x.replace("", "") // error
val z2: String = y.replace("", "")

0 comments on commit 273c867

Please sign in to comment.