Skip to content

Commit

Permalink
Context Bounds for Polymorphic Functions (#21643)
Browse files Browse the repository at this point in the history
Implement the `#6` point form SIP-64 i.e.
---

### 6. Context Bounds for Polymorphic Functions

Currently, context bounds can be used in methods, but not in function
types or function literals. It would be nice propose to drop this
irregularity and allow context bounds also in these places. Example:

```scala
type Comparer = [X: Ord] => (x: X, y: X) => Boolean
val less: Comparer = [X: Ord as ord] => (x: X, y: X) =>
  ord.compare(x, y) < 0
```

The expansion of such context bounds is analogous to the expansion in
method types, except that instead of adding a using clause in a method,
we insert a context function type.

For instance, the `type` and `val` definitions above would expand to
```scala
type Comparer = [X] => (x: X, y: X) => Ord[X] ?=> Boolean
val less: Comparer = [X] => (x: X, y: X) => (ord: Ord[X]) ?=>
  ord.compare(x, y) < 0
```

The expansion of using clauses does look inside alias types. For
instance,
here is a variation of the previous example that uses a parameterized
type alias:
```scala
type Cmp[X] = (x: X, y: X) => Boolean
type Comparer2 = [X: Ord] => Cmp[X]
```
The expansion of the right hand side of `Comparer2` expands the `Cmp[X]`
alias
and then inserts the context function at the same place as what's done
for `Comparer`.
  • Loading branch information
WojciechMazur authored Nov 19, 2024
2 parents cc4a324 + 952eff7 commit 5611522
Show file tree
Hide file tree
Showing 6 changed files with 274 additions and 68 deletions.
202 changes: 139 additions & 63 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ object desugar {
*/
val ContextBoundParam: Property.Key[Unit] = Property.StickyKey()

/** Marks a poly fcuntion apply method, so that we can handle adding evidence parameters to them in a special way
*/
val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey()

/** What static check should be applied to a Match? */
enum MatchCheck {
case None, Exhaustive, IrrefutablePatDef, IrrefutableGenFrom
Expand Down Expand Up @@ -242,7 +246,7 @@ object desugar {
* def f$default$2[T](x: Int) = x + "m"
*/
private def defDef(meth: DefDef, isPrimaryConstructor: Boolean = false)(using Context): Tree =
addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor))
addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor).asInstanceOf[DefDef])

/** Drop context bounds in given TypeDef, replacing them with evidence ValDefs that
* get added to a buffer.
Expand Down Expand Up @@ -304,10 +308,8 @@ object desugar {
tdef1
end desugarContextBounds

private def elimContextBounds(meth: DefDef, isPrimaryConstructor: Boolean)(using Context): DefDef =
val DefDef(_, paramss, tpt, rhs) = meth
def elimContextBounds(meth: Tree, isPrimaryConstructor: Boolean = false)(using Context): Tree =
val evidenceParamBuf = mutable.ListBuffer[ValDef]()

var seenContextBounds: Int = 0
def freshName(unused: Tree) =
seenContextBounds += 1 // Start at 1 like FreshNameCreator.
Expand All @@ -317,7 +319,7 @@ object desugar {
// parameters of the method since shadowing does not affect
// implicit resolution in Scala 3.

val paramssNoContextBounds =
def paramssNoContextBounds(paramss: List[ParamClause]): List[ParamClause] =
val iflag = paramss.lastOption.flatMap(_.headOption) match
case Some(param) if param.mods.isOneOf(GivenOrImplicit) =>
param.mods.flags & GivenOrImplicit
Expand All @@ -329,15 +331,32 @@ object desugar {
tparam => desugarContextBounds(tparam, evidenceParamBuf, flags, freshName, paramss)
}(identity)

rhs match
case MacroTree(call) =>
cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased)
case _ =>
addEvidenceParams(
cpy.DefDef(meth)(
name = normalizeName(meth, tpt).asTermName,
paramss = paramssNoContextBounds),
evidenceParamBuf.toList)
meth match
case meth @ DefDef(_, paramss, tpt, rhs) =>
val newParamss = paramssNoContextBounds(paramss)
rhs match
case MacroTree(call) =>
cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased)
case _ =>
addEvidenceParams(
cpy.DefDef(meth)(
name = normalizeName(meth, tpt).asTermName,
paramss = newParamss
),
evidenceParamBuf.toList
)
case meth @ PolyFunction(tparams, fun) =>
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = meth: @unchecked
val Function(vparams: List[untpd.ValDef] @unchecked, rhs) = fun: @unchecked
val newParamss = paramssNoContextBounds(tparams :: vparams :: Nil)
val params = evidenceParamBuf.toList
if params.isEmpty then
meth
else
val boundNames = getBoundNames(params, newParamss)
val recur = fitEvidenceParams(params, nme.apply, boundNames)
val (paramsFst, paramsSnd) = recur(newParamss)
functionsOf((paramsFst ++ paramsSnd).filter(_.nonEmpty), rhs)
end elimContextBounds

def addDefaultGetters(meth: DefDef)(using Context): Tree =
Expand Down Expand Up @@ -465,6 +484,74 @@ object desugar {
case _ =>
(Nil, tree)

private def referencesName(vdef: ValDef, names: Set[TermName])(using Context): Boolean =
vdef.tpt.existsSubTree:
case Ident(name: TermName) => names.contains(name)
case _ => false

/** Fit evidence `params` into the `mparamss` parameter lists, making sure
* that all parameters referencing `params` are after them.
* - for methods the final parameter lists are := result._1 ++ result._2
* - for poly functions, each element of the pair contains at most one term
* parameter list
*
* @param params the evidence parameters list that should fit into `mparamss`
* @param methName the name of the method that `mparamss` belongs to
* @param boundNames the names of the evidence parameters
* @param mparamss the original parameter lists of the method
* @return a pair of parameter lists containing all parameter lists in a
* reference-correct order; make sure that `params` is always at the
* intersection of the pair elements; this is relevant, for poly functions
* where `mparamss` is guaranteed to have exectly one term parameter list,
* then each pair element will have at most one term parameter list
*/
private def fitEvidenceParams(
params: List[ValDef],
methName: Name,
boundNames: Set[TermName]
)(mparamss: List[ParamClause])(using Context): (List[ParamClause], List[ParamClause]) = mparamss match
case ValDefs(mparams) :: _ if mparams.exists(referencesName(_, boundNames)) =>
(params :: Nil) -> mparamss
case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) =>
val normParams =
if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then
params.map: param =>
val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit))
param.withMods(param.mods.withFlags(normFlags))
.showing(i"adapted param $result ${result.mods.flags} for ${methName}", Printers.desugar)
else params
((normParams ++ mparams) :: Nil) -> Nil
case mparams :: mparamss1 =>
val (fst, snd) = fitEvidenceParams(params, methName, boundNames)(mparamss1)
(mparams :: fst) -> snd
case Nil =>
Nil -> (params :: Nil)

/** Create a chain of possibly contextual functions from the parameter lists */
private def functionsOf(paramss: List[ParamClause], rhs: Tree)(using Context): Tree = paramss match
case Nil => rhs
case ValDefs(head @ (fst :: _)) :: rest if fst.mods.isOneOf(GivenOrImplicit) =>
val paramTpts = head.map(_.tpt)
val paramNames = head.map(_.name)
val paramsErased = head.map(_.mods.flags.is(Erased))
makeContextualFunction(paramTpts, paramNames, functionsOf(rest, rhs), paramsErased).withSpan(rhs.span)
case ValDefs(head) :: rest =>
Function(head, functionsOf(rest, rhs))
case TypeDefs(head) :: rest =>
PolyFunction(head, functionsOf(rest, rhs))
case _ =>
assert(false, i"unexpected paramss $paramss")
EmptyTree

private def getBoundNames(params: List[ValDef], paramss: List[ParamClause])(using Context): Set[TermName] =
var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names
for mparams <- paramss; mparam <- mparams do
mparam match
case tparam: TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot.unapply(_).isDefined) =>
boundNames += tparam.name.toTermName
case _ =>
boundNames

/** Add all evidence parameters in `params` as implicit parameters to `meth`.
* The position of the added parameters is determined as follows:
*
Expand All @@ -479,36 +566,23 @@ object desugar {
private def addEvidenceParams(meth: DefDef, params: List[ValDef])(using Context): DefDef =
if params.isEmpty then return meth

var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names
for mparams <- meth.paramss; mparam <- mparams do
mparam match
case tparam: TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot.unapply(_).isDefined) =>
boundNames += tparam.name.toTermName
case _ =>
val boundNames = getBoundNames(params, meth.paramss)

def referencesBoundName(vdef: ValDef): Boolean =
vdef.tpt.existsSubTree:
case Ident(name: TermName) => boundNames.contains(name)
case _ => false
val fitParams = fitEvidenceParams(params, meth.name, boundNames)

def recur(mparamss: List[ParamClause]): List[ParamClause] = mparamss match
case ValDefs(mparams) :: _ if mparams.exists(referencesBoundName) =>
params :: mparamss
case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) =>
val normParams =
if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then
params.map: param =>
val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit))
param.withMods(param.mods.withFlags(normFlags))
.showing(i"adapted param $result ${result.mods.flags} for ${meth.name}", Printers.desugar)
else params
(normParams ++ mparams) :: Nil
case mparams :: mparamss1 =>
mparams :: recur(mparamss1)
case Nil =>
params :: Nil

cpy.DefDef(meth)(paramss = recur(meth.paramss))
if meth.removeAttachment(PolyFunctionApply).isDefined then
// for PolyFunctions we are limited to a single term param list, so we
// reuse the fitEvidenceParams logic to compute the new parameter lists
// and then we add the other parameter lists as function types to the
// return type
val (paramsFst, paramsSnd) = fitParams(meth.paramss)
if ctx.mode.is(Mode.Type) then
cpy.DefDef(meth)(paramss = paramsFst, tpt = functionsOf(paramsSnd, meth.tpt))
else
cpy.DefDef(meth)(paramss = paramsFst, rhs = functionsOf(paramsSnd, meth.rhs))
else
val (paramsFst, paramsSnd) = fitParams(meth.paramss)
cpy.DefDef(meth)(paramss = paramsFst ++ paramsSnd)
end addEvidenceParams

/** The parameters generated from the contextual bounds of `meth`, as generated by `desugar.defDef` */
Expand Down Expand Up @@ -1224,27 +1298,29 @@ object desugar {
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
*/
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
val paramFlags = fun match
case fun: FunctionWithMods =>
// TODO: make use of this in the desugaring when pureFuns is enabled.
// val isImpure = funFlags.is(Impure)

// Function flags to be propagated to each parameter in the desugared method type.
val givenFlag = fun.mods.flags.toTermFlags & Given
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
case _ =>
vparamTypes.map(_ => EmptyFlags)

val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
}.toList

RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic)
)).withSpan(tree.span)
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = (tree: @unchecked) match
case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) =>
val paramFlags = fun match
case fun: FunctionWithMods =>
// TODO: make use of this in the desugaring when pureFuns is enabled.
// val isImpure = funFlags.is(Impure)

// Function flags to be propagated to each parameter in the desugared method type.
val givenFlag = fun.mods.flags.toTermFlags & Given
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
case _ =>
vparamTypes.map(_ => EmptyFlags)

val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
}.toList

RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
.withFlags(Synthetic)
.withAttachment(PolyFunctionApply, ())
)).withSpan(tree.span)
end makePolyFunctionType

/** Invent a name for an anonympus given of type or template `impl`. */
Expand Down
6 changes: 4 additions & 2 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3459,7 +3459,7 @@ object Parsers {
*
* TypTypeParamClause::= ‘[’ TypTypeParam {‘,’ TypTypeParam} ‘]’
* TypTypeParam ::= {Annotation}
* (id | ‘_’) [HkTypeParamClause] TypeBounds
* (id | ‘_’) [HkTypeParamClause] TypeAndCtxBounds
*
* HkTypeParamClause ::= ‘[’ HkTypeParam {‘,’ HkTypeParam} ‘]’
* HkTypeParam ::= {Annotation} [‘+’ | ‘-’]
Expand Down Expand Up @@ -3490,7 +3490,9 @@ object Parsers {
else ident().toTypeName
val hkparams = typeParamClauseOpt(ParamOwner.Hk)
val bounds =
if paramOwner.acceptsCtxBounds then typeAndCtxBounds(name) else typeBounds()
if paramOwner.acceptsCtxBounds then typeAndCtxBounds(name)
else if in.featureEnabled(Feature.modularity) && paramOwner == ParamOwner.Type then typeAndCtxBounds(name)
else typeBounds()
TypeDef(name, lambdaAbstract(hkparams, bounds)).withMods(mods)
}
}
Expand Down
5 changes: 2 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1920,7 +1920,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
val tree1 = desugar.normalizePolyFunction(tree)
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt)
else typedPolyFunctionValue(tree1, pt)
else typedPolyFunctionValue(desugar.elimContextBounds(tree1).asInstanceOf[untpd.PolyFunction], pt)

def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked
Expand Down Expand Up @@ -2474,7 +2474,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val TypeDef(_, impl: Template) = typed(refineClsDef): @unchecked
val refinements1 = impl.body
val seen = mutable.Set[Symbol]()
for (refinement <- refinements1) { // TODO: get clarity whether we want to enforce these conditions
for refinement <- refinements1 do // TODO: get clarity whether we want to enforce these conditions
typr.println(s"adding refinement $refinement")
checkRefinementNonCyclic(refinement, refineCls, seen)
val rsym = refinement.symbol
Expand All @@ -2488,7 +2488,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val member = refineCls.info.member(rsym.name)
if (member.isOverloaded)
report.error(OverloadInRefinement(rsym), refinement.srcPos)
}
assignType(cpy.RefinedTypeTree(tree)(tpt1, refinements1), tpt1, refinements1, refineCls)
}

Expand Down
Loading

0 comments on commit 5611522

Please sign in to comment.