Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attempt to beta reduce only if parameters and arguments have same shape #21970

Merged
merged 1 commit into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 37 additions & 20 deletions compiler/src/dotty/tools/dotc/transform/BetaReduce.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ object BetaReduce:
val bindingsBuf = new ListBuffer[DefTree]
def recur(fn: Tree, argss: List[List[Tree]]): Option[Tree] = fn match
case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol =>
Some(reduceApplication(ddef, argss, bindingsBuf))
reduceApplication(ddef, argss, bindingsBuf)
case Block((TypeDef(_, template: Template)) :: Nil, Typed(Apply(Select(New(_), _), _), _)) if template.constr.rhs.isEmpty =>
template.body match
case (ddef: DefDef) :: Nil => Some(reduceApplication(ddef, argss, bindingsBuf))
case (ddef: DefDef) :: Nil => reduceApplication(ddef, argss, bindingsBuf)
case _ => None
case Block(stats, expr) if stats.forall(isPureBinding) =>
recur(expr, argss).map(cpy.Block(fn)(stats, _))
Expand All @@ -106,12 +106,22 @@ object BetaReduce:
case _ =>
tree

/** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings */
def reduceApplication(ddef: DefDef, argss: List[List[Tree]], bindings: ListBuffer[DefTree])(using Context): Tree =
/** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings.
* @return optionally, the expanded call, or none if the actual argument
* lists do not match in shape the formal parameters
*/
def reduceApplication(ddef: DefDef, argss: List[List[Tree]], bindings: ListBuffer[DefTree])
(using Context): Option[Tree] =
val (targs, args) = argss.flatten.partition(_.isType)
val tparams = ddef.leadingTypeParams
val vparams = ddef.termParamss.flatten

def shapeMatch(paramss: List[ParamClause], argss: List[List[Tree]]): Boolean = (paramss, argss) match
case (params :: paramss1, args :: argss1) if params.length == args.length =>
shapeMatch(paramss1, argss1)
case (Nil, Nil) => true
case _ => false

val targSyms =
for (targ, tparam) <- targs.zip(tparams) yield
targ.tpe.dealias match
Expand Down Expand Up @@ -143,19 +153,26 @@ object BetaReduce:
bindings += binding.withSpan(arg.span)
bindingSymbol

val expansion = TreeTypeMap(
oldOwners = ddef.symbol :: Nil,
newOwners = ctx.owner :: Nil,
substFrom = (tparams ::: vparams).map(_.symbol),
substTo = targSyms ::: argSyms
).transform(ddef.rhs)

val expansion1 = new TreeMap {
override def transform(tree: Tree)(using Context) = tree.tpe.widenTermRefExpr match
case ConstantType(const) if isPureExpr(tree) => cpy.Literal(tree)(const)
case tpe: TypeRef if tree.isTerm && tpe.derivesFrom(defn.UnitClass) && isPureExpr(tree) =>
cpy.Literal(tree)(Constant(()))
case _ => super.transform(tree)
}.transform(expansion)

expansion1
if shapeMatch(ddef.paramss, argss) then
// We can't assume arguments always match. It's possible to construct a
// function with wrong apply method by hand which causes `shapeMatch` to fail.
// See neg/i21952.scala
val expansion = TreeTypeMap(
oldOwners = ddef.symbol :: Nil,
newOwners = ctx.owner :: Nil,
substFrom = (tparams ::: vparams).map(_.symbol),
substTo = targSyms ::: argSyms
).transform(ddef.rhs)

val expansion1 = new TreeMap {
override def transform(tree: Tree)(using Context) = tree.tpe.widenTermRefExpr match
case ConstantType(const) if isPureExpr(tree) => cpy.Literal(tree)(const)
case tpe: TypeRef if tree.isTerm && tpe.derivesFrom(defn.UnitClass) && isPureExpr(tree) =>
cpy.Literal(tree)(Constant(()))
case _ => super.transform(tree)
}.transform(expansion)

Some(expansion1)
else None
end reduceApplication
end BetaReduce
8 changes: 5 additions & 3 deletions compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,11 @@ class InlinePatterns extends MiniPhase:
template.body match
case List(ddef @ DefDef(`name`, _, _, _)) =>
val bindings = new ListBuffer[DefTree]()
val expansion1 = BetaReduce.reduceApplication(ddef, argss, bindings)
val bindings1 = bindings.result()
seq(bindings1, expansion1)
BetaReduce.reduceApplication(ddef, argss, bindings) match
case Some(expansion1) =>
val bindings1 = bindings.result()
seq(bindings1, expansion1)
case None => tree
case _ => tree
case _ => tree

Expand Down
1 change: 1 addition & 0 deletions tests/neg/i21952.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
val _ = (new Function[(Int, Int), Int] {def apply(a: Int, b: Int): Int = a * b})(2, 3) // error
Loading