From 6896ea24de8caab99652e9f2048773b1004f697d Mon Sep 17 00:00:00 2001 From: odersky Date: Mon, 18 Nov 2024 17:54:00 +0100 Subject: [PATCH] Attempt to beta reduce only if parameter and argument lists have same shape It's possible to define Functions with wrong apply methods by hand which will give an error but pass on a function that does fails beta reduction. Fixes #21952 --- .../tools/dotc/transform/BetaReduce.scala | 57 ++++++++++++------- .../tools/dotc/transform/InlinePatterns.scala | 8 ++- tests/neg/i21952.scala | 1 + 3 files changed, 43 insertions(+), 23 deletions(-) create mode 100644 tests/neg/i21952.scala diff --git a/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala b/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala index 60c1bc7c61bb..16219055b8c0 100644 --- a/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala +++ b/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala @@ -76,10 +76,10 @@ object BetaReduce: val bindingsBuf = new ListBuffer[DefTree] def recur(fn: Tree, argss: List[List[Tree]]): Option[Tree] = fn match case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol => - Some(reduceApplication(ddef, argss, bindingsBuf)) + reduceApplication(ddef, argss, bindingsBuf) case Block((TypeDef(_, template: Template)) :: Nil, Typed(Apply(Select(New(_), _), _), _)) if template.constr.rhs.isEmpty => template.body match - case (ddef: DefDef) :: Nil => Some(reduceApplication(ddef, argss, bindingsBuf)) + case (ddef: DefDef) :: Nil => reduceApplication(ddef, argss, bindingsBuf) case _ => None case Block(stats, expr) if stats.forall(isPureBinding) => recur(expr, argss).map(cpy.Block(fn)(stats, _)) @@ -106,12 +106,22 @@ object BetaReduce: case _ => tree - /** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings */ - def reduceApplication(ddef: DefDef, argss: List[List[Tree]], bindings: ListBuffer[DefTree])(using Context): Tree = + /** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings. + * @return optionally, the expanded call, or none if the actual argument + * lists do not match in shape the formal parameters + */ + def reduceApplication(ddef: DefDef, argss: List[List[Tree]], bindings: ListBuffer[DefTree]) + (using Context): Option[Tree] = val (targs, args) = argss.flatten.partition(_.isType) val tparams = ddef.leadingTypeParams val vparams = ddef.termParamss.flatten + def shapeMatch(paramss: List[ParamClause], argss: List[List[Tree]]): Boolean = (paramss, argss) match + case (params :: paramss1, args :: argss1) if params.length == args.length => + shapeMatch(paramss1, argss1) + case (Nil, Nil) => true + case _ => false + val targSyms = for (targ, tparam) <- targs.zip(tparams) yield targ.tpe.dealias match @@ -143,19 +153,26 @@ object BetaReduce: bindings += binding.withSpan(arg.span) bindingSymbol - val expansion = TreeTypeMap( - oldOwners = ddef.symbol :: Nil, - newOwners = ctx.owner :: Nil, - substFrom = (tparams ::: vparams).map(_.symbol), - substTo = targSyms ::: argSyms - ).transform(ddef.rhs) - - val expansion1 = new TreeMap { - override def transform(tree: Tree)(using Context) = tree.tpe.widenTermRefExpr match - case ConstantType(const) if isPureExpr(tree) => cpy.Literal(tree)(const) - case tpe: TypeRef if tree.isTerm && tpe.derivesFrom(defn.UnitClass) && isPureExpr(tree) => - cpy.Literal(tree)(Constant(())) - case _ => super.transform(tree) - }.transform(expansion) - - expansion1 + if shapeMatch(ddef.paramss, argss) then + // We can't assume arguments always match. It's possible to construct a + // function with wrong apply method by hand which causes `shapeMatch` to fail. + // See neg/i21952.scala + val expansion = TreeTypeMap( + oldOwners = ddef.symbol :: Nil, + newOwners = ctx.owner :: Nil, + substFrom = (tparams ::: vparams).map(_.symbol), + substTo = targSyms ::: argSyms + ).transform(ddef.rhs) + + val expansion1 = new TreeMap { + override def transform(tree: Tree)(using Context) = tree.tpe.widenTermRefExpr match + case ConstantType(const) if isPureExpr(tree) => cpy.Literal(tree)(const) + case tpe: TypeRef if tree.isTerm && tpe.derivesFrom(defn.UnitClass) && isPureExpr(tree) => + cpy.Literal(tree)(Constant(())) + case _ => super.transform(tree) + }.transform(expansion) + + Some(expansion1) + else None + end reduceApplication +end BetaReduce \ No newline at end of file diff --git a/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala b/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala index 18333ae506fd..d2a72e10fcfc 100644 --- a/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala +++ b/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala @@ -60,9 +60,11 @@ class InlinePatterns extends MiniPhase: template.body match case List(ddef @ DefDef(`name`, _, _, _)) => val bindings = new ListBuffer[DefTree]() - val expansion1 = BetaReduce.reduceApplication(ddef, argss, bindings) - val bindings1 = bindings.result() - seq(bindings1, expansion1) + BetaReduce.reduceApplication(ddef, argss, bindings) match + case Some(expansion1) => + val bindings1 = bindings.result() + seq(bindings1, expansion1) + case None => tree case _ => tree case _ => tree diff --git a/tests/neg/i21952.scala b/tests/neg/i21952.scala new file mode 100644 index 000000000000..0365d82463c0 --- /dev/null +++ b/tests/neg/i21952.scala @@ -0,0 +1 @@ +val _ = (new Function[(Int, Int), Int] {def apply(a: Int, b: Int): Int = a * b})(2, 3) // error