Skip to content

Commit

Permalink
added rewriting of curreid bridges, adopted to 3.5.x line
Browse files Browse the repository at this point in the history
  • Loading branch information
rssh committed Apr 22, 2024
1 parent c465876 commit 36f2e3b
Show file tree
Hide file tree
Showing 25 changed files with 297 additions and 101 deletions.
5 changes: 1 addition & 4 deletions build.sbt
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
//val dottyVersion = "3.4.0-RC1-bin-SNAPSHOT"
//val dottyVersion = "3.3.2-RC1-bin-SNAPSHOT"
//val dottyVersion = "3.3.1-RC4"
val dottyVersion = "3.3.3"
//val dottyVersion = "3.4.2-RC1-bin-SNAPSHOT"
//val dottyVersion = "3.5.0-RC1-bin-SNAPSHOT"


ThisBuild/version := "0.9.22-SNAPSHOT"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ trait CpsChangeSymbols {
try
val ntp = CpsTransformHelper.cpsTransformedErasedType(sym.info, monadType, sym.symbol.srcPos)
selectRecord.changedType = ntp
sym.copySymDenotation(info = ntp)
val retval = sym.copySymDenotation(info = ntp)
retval
catch
case ex:CpsTransformException =>
ex.printStackTrace()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import ast.tpd.*
import core.*
import core.Symbols.*
import core.Types.*
import core.Decorators.toTermName
import plugins.*
import cps.plugin.DefDefSelectKind.USING_CONTEXT_PARAM
import dotty.tools.dotc.core.Annotations.ConcreteAnnotation
import dotty.tools.dotc.core.DenotTransformers.{DenotTransformer, IdentityDenotTransformer, InfoTransformer}
import dotty.tools.dotc.transform.{Pickler, SetRootTree}

Expand All @@ -25,9 +27,10 @@ class PhaseSelectAndGenerateShiftedMethods(selectedNodes: SelectedNodes) extends
override def changesParents: Boolean = true





override def transformDefDef(tree: tpd.DefDef)(using Context): tpd.Tree = {
def transformDefDefDisabled(tree: tpd.DefDef)(using Context): tpd.Tree = {

lazy val cpsTransformedAnnot = Symbols.requiredClass("cps.plugin.annotation.CpsTransformed")

Expand Down Expand Up @@ -98,8 +101,15 @@ class PhaseSelectAndGenerateShiftedMethods(selectedNodes: SelectedNodes) extends
super.transformAssign(tree)
}

// generate shifted version for hight-order functions annotated by makeCPS
override def transformTemplate(tree: tpd.Template)(using Context): tpd.Tree = {

override def transformTemplate(tree: tpd.Template)(using Context): tpd.Tree = {

// annotated selected methds with CpsTransform
for(m <- tree.body) {
val changed = annotateTopMethodWithSelectKind(m)
}

// add shifted methods for @makeCPS annotated high-order members
val makeCpsAnnot = Symbols.requiredClass("cps.plugin.annotation.makeCPS")
val shiftedMethods = tree.body.filter(_.symbol.annotations.exists(_.symbol == makeCpsAnnot))
.flatMap { m =>
Expand All @@ -121,8 +131,25 @@ class PhaseSelectAndGenerateShiftedMethods(selectedNodes: SelectedNodes) extends
}
cpy.Template(tree)(body = tree.body ++ shiftedMethods)
}



def annotateTopMethodWithSelectKind(tree: tpd.Tree)(using Context): Boolean = {
lazy val cpsTransformedAnnot = Symbols.requiredClass("cps.plugin.annotation.CpsTransformed")
tree match
case dd: DefDef =>
val optKind = SelectedNodes.detectDefDefSelectKind(dd)
optKind match
case Some(kind) =>
val monadType = CpsTransformHelper.extractMonadType(kind.getCpsDirectContext.tpe, CpsTransformHelper.cpsDirectAliasSymbol, dd.srcPos)
val annotExpr = New(cpsTransformedAnnot.typeRef.appliedTo(monadType), Nil)
val initAnnotExpr = Apply(TypeApply(Select(annotExpr, "<init>".toTermName),List(TypeTree(monadType))),Nil)
dd.symbol.addAnnotation(ConcreteAnnotation(initAnnotExpr))
selectedNodes.addDefDef(dd.symbol,kind)
true
case None => false
case _ =>
false
}

}

object PhaseSelectAndGenerateShiftedMethods {
Expand Down
187 changes: 140 additions & 47 deletions compiler-plugin/src/main/scala/cps/plugin/RemoveScaffolding.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,72 +32,165 @@ trait RemoveScaffolding {
report.error(s"plain tree: ${tree}", tree.srcPos)
}

selectedNodes.getDefDefRecord(tree.symbol) match
case Some(selectRecord) =>
//if (true || selectRecord.debugLevel > 10) {
// log(s"changeSymbol for ${tree.symbol} ")
// log(s"rhs ${tree.rhs.show} ")
// log(s"plain rhs ${tree.rhs} ")
//}
// here we see our defdefs with next changes:
// - erased type
// - type params are removed
// - all argument lists are merged into one
// - box/unbox for primitive types are inserted
tree.rhs match
case Scaffolding.Uncpsed(nRhs) =>
val changedDdefType = if (selectRecord.changedType != Types.NoType) {
selectRecord.changedType
} else {
CpsTransformHelper.cpsTransformedErasedType(tree.symbol.info, selectRecord.monadType, tree.srcPos)
}
val nTpt = retrieveReturnType(changedDdefType)
val typedNRhs = if (nRhs.tpe.widen <:< nTpt) {
nRhs
} else {
// we know that monads are not primitive types.
// (potentially, we can have monadic value classes in future)
TypeApply(Select(nRhs, "asInstanceOf".toTermName), List(TypeTree(nTpt)))
}
// TODO: insert asInstanceOf ?
cpy.DefDef(tree)(rhs = typedNRhs, tpt = TypeTree(nTpt))
case EmptyTree =>
tree
case _ =>
reportErrorWithTree(s"not found uncpsed scaffolding: for ${tree.symbol} (${tree.symbol.id})", tree.rhs)
tree
case None =>
tree
if (tree.symbol.is(Flags.Bridge)) then
// search for bridge which is generated for CpsDirect methods
// (it created in erasure after uncarrying (preserve carrying form of method))
tree.rhs match
case treeBlock@Block(List(ddef:DefDef), Closure(env,meth,tpe)) if meth.symbol == ddef.symbol =>
println(s"closure found, ddef.rhs=${ddef.rhs}")
ddef.rhs match
case Apply(fn, args) =>
fn.symbol.getAnnotation(Symbols.requiredClass("cps.plugin.annotation.CpsTransformed")) match
case Some(transformedAnnotation) =>
println(s"fn in closure has CpsTransformed annotation, ddef.rhs.tpe.widen=${ddef.rhs.tpe.widen.show}")
println(s"ddef.tpe.widen=${ddef.tpe.widen.show}, tree=${ddef.tpe.widen}")
ddef.tpe.widen match
case mt: MethodOrPoly =>
val nType = mt.derivedLambdaType(resType = ddef.rhs.tpe.widen)
//val nDdef = ddef.withType( mt.derivedLambdaType(resType = ddef.rhs.tpe.widen ))
//val nDdef = cpy.DefDef(ddef)(tpt = TypeTree(ddef.rhs.tpe.widen))
val newDdefSymbol = Symbols.newSymbol(ddef.symbol.owner, ddef.name, ddef.symbol.flags, nType)
val nDdef = DefDef(newDdefSymbol, paramss => {
val paramsMap = (ddef.paramss zip paramss).foldLeft(Map.empty[Symbol,Tree]){ case (s,(psOld,psNew)) =>
(psOld zip psNew).foldLeft(s){ case (s,(pOld,pNew)) =>
s.updated(pOld.symbol, ref(pNew.symbol).withSpan(pOld.span))
}
}
TransformUtil.substParamsMap(ddef.rhs, paramsMap)
})
println(s"ddef.symbol.hashCode=${ddef.symbol.hashCode()} nDdef.symbol.hashCode=${nDdef.symbol.hashCode()}")
println(s"fn=${fn}, args=${args}")
cpy.DefDef(tree)(rhs = Block(List(nDdef), Closure(env,ref(newDdefSymbol),tpe)).withSpan(treeBlock.span))
case _ =>
throw CpsTransformException("Assumed that ddef.tpe.widen is MethodOrPoly", ddef.srcPos)
case None =>
println(s"fn in closure has no annotation")
tree
case _ =>
tree
else
selectedNodes.getDefDefRecord(tree.symbol) match
case Some(selectRecord) =>
//if (true || selectRecord.debugLevel > 10) {
// log(s"changeSymbol for ${tree.symbol} ")
// log(s"rhs ${tree.rhs.show} ")
// log(s"plain rhs ${tree.rhs} ")
//}
// here we see our defdefs with next changes:
// - erased type
// - type params are removed
// - all argument lists are merged into one
// - box/unbox for primitive types are inserted
tree.rhs match
case Scaffolding.Uncpsed(nRhs) =>
val changedDdefType = if (selectRecord.changedType != Types.NoType) {
selectRecord.changedType
} else {
CpsTransformHelper.cpsTransformedErasedType(tree.symbol.info, selectRecord.monadType, tree.srcPos)
}
val nTpt = retrieveReturnType(changedDdefType)
val typedNRhs = if (nRhs.tpe.widen <:< nTpt) {
nRhs
} else {
// we know that monads are not primitive types.
// (potentially, we can have monadic value classes in future)
TypeApply(Select(nRhs, "asInstanceOf".toTermName), List(TypeTree(nTpt)))
}
// TODO: insert asInstanceOf ?
cpy.DefDef(tree)(rhs = typedNRhs, tpt = TypeTree(nTpt))
case EmptyTree =>
tree
case _ =>
reportErrorWithTree(s"not found uncpsed scaffolding: for ${tree.symbol} (${tree.symbol.id})", tree.rhs)
tree
case None =>
tree
}




override def transformApply(tree: Apply)(using ctx: Context): Tree = {

def retypeFn(fn: Tree):Tree = {
val runRetype = false

def retypeFn(fn: Tree) :Tree = {
fn match
case id: Ident =>
if (id.symbol.hasAnnotation(Symbols.requiredClass("cps.plugin.annotation.CpsTransformed"))) then
println(s"fn has annotation: ${id.symbol.showFullName}")
else
println(s"fn has no annotation")
selectedNodes.getDefDefRecord(id.symbol) match
case Some(selectRecord) =>
println("fn in selectRecord")
case None =>
println(s"fn not in selectRecord, id=${id.show}, id.symbol=${id.symbol.showFullName}")
val retval = ref(id.symbol).withSpan(id.span) // here this will be symbol after phase CpsChangeSymbols
retval
case sel: Select =>
val retval = Select(sel.qualifier,sel.name).withSpan(sel.span)
retval
//case sel: Select =>
// val retval = Select(sel.qualifier,sel.name).withSpan(sel.span)
// retval
case _ =>
fn
}


tree match
case Scaffolding.Cpsed(cpsedCall) =>
// here we need to retype arg because we change the type of symbols.
val cpsedCallRetyped = cpsedCall match
case Apply(fn, args) =>
Apply(retypeFn(fn), args).withSpan(cpsedCall.span)
case _ => cpsedCall
cpsedCallRetyped
if (runRetype) then
val cpsedCallRetyped = cpsedCall match
case Apply(fn, args) =>
val retval =
try
val fnRetyped = retypeFn(fn)
Apply(fnRetyped, args).withSpan(cpsedCall.span)
catch
case ex: Throwable =>
println(s"RemoveScaffolding error: fn=${fn.show}, args=${args.map(_.show).mkString(",")}")
throw ex
retval
case _ =>
cpsedCall
cpsedCallRetyped
else
cpsedCall
case Apply(fn, args) =>
if (fn.symbol.hasAnnotation(Symbols.requiredClass("cps.plugin.annotation.CpsTransformed"))) then
println(s"RemoveScaffolding::Apply, ${tree.show} fn has CpsTransformed annotation: ${fn.symbol.showFullName}")
println(s"fn.tpe.widen=${fn.tpe.widen.show}")
println(s"tree.tpe.widen=${tree.tpe.widen.show}")
tree
else
tree
// selectedNodes.getDefDefRecord(fn.symbol) match
// case Some(selectRecord) =>
// println(s"RemoveScaffolding: foudn apply with selectRecord, tree.tpe=${tree.tpe.show}, ")
// ???
// case None =>
// tree
case _ =>
tree
}

override def transformIdent(tree: Ident)(using Context): Tree = {
if (tree.symbol.hasAnnotation(Symbols.requiredClass("cps.plugin.annotation.CpsTransformed"))) then
println(s"RemoveScaffolding::Ident, ${tree.show} has CpsTransformed annotation: ${tree.symbol.showFullName}")
println(s"tree.tpe.widen=${tree.tpe.widen.show}, tree.symbol.info.widen=${tree.symbol.info.widen}")
ref(tree.symbol).withSpan(tree.span)
else
tree
}

override def transformSelect(tree: Select)(using Context): Tree = {
if (tree.symbol.hasAnnotation(Symbols.requiredClass("cps.plugin.annotation.CpsTransformed"))) then
println(s"RemoveScaffolding::Select, ${tree.show} has CpsTransformed annotation: ${tree.symbol.showFullName}")
println(s"sel.tpe.widen=${tree.tpe.widen.show}, sel.symbol.info.widen=${tree.symbol.info.widen}")
println(s"sel.qualifier.tpe.widen=${tree.qualifier.tpe.widen.show}, ${tree.tpe.show}")
println(s"sel.qualifier.symbol.infos=${tree.qualifier.symbol.info.show}")
tree
else
tree
}


def retrieveReturnType(ddefType: Type)(using Context): Type = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ enum DefDefSelectKind {
case USING_CONTEXT_PARAM(cmc) => cmc
case RETURN_CONTEXT_FUN(internal) => internal.getCpsDirectContext


}

/**
Expand Down Expand Up @@ -98,7 +99,7 @@ object SelectedNodes {
/**
*
* @param tree: tree to process
* @param f: (defDef(selected function), Tree: (CpsMonadContext parameter)) => Option[A] is called
* @param f: (defDef(selected function), Tree: (CpsMonadContext parameter), index:Int) => Option[A] is called
* when we find parameter with type CpsMonadContext.
* @param acc
* @param Context
Expand Down
14 changes: 9 additions & 5 deletions compiler-plugin/src/test/scala/cc/DotcInvocations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ object DotcInvocations {

def compileFilesInDir(dir: String, invocationArgs: DotcInvocationArgs = DotcInvocationArgs()): Reporter = {
val dotcInvocations = new DotcInvocations(invocationArgs.silent)
dotcInvocations.compileFilesInDir(dir, invocationArgs.outDir.getOrElse(dir),
dotcInvocations.compileFilesInDir(dir, invocationArgs.outDir.getOrElse(s"${dir}-classes"),
invocationArgs.extraDotcArgs, invocationArgs.checkAll, invocationArgs.usePlugin)
dotcInvocations.reporter
}
Expand All @@ -187,7 +187,7 @@ object DotcInvocations {
): Unit = {
val dotcInvocations = new DotcInvocations(invocationArgs.silent)

val (code, output) = dotcInvocations.compileAndRunFilesInDirJVM(dir,invocationArgs.outDir.getOrElse(dir),
val (code, output) = dotcInvocations.compileAndRunFilesInDirJVM(dir,invocationArgs.outDir.getOrElse(s"${dir}-classes"),
mainClass,invocationArgs.extraDotcArgs,invocationArgs.checkAll,invocationArgs.usePlugin)

val reporter = dotcInvocations.reporter
Expand Down Expand Up @@ -220,7 +220,7 @@ object DotcInvocations {
sourceDir: String,
compiledFlag: IsAlreadyCompiledFlag
) {
def outDir = sourceDir
def outDir = s"${sourceDir}-classes"
}


Expand All @@ -236,9 +236,13 @@ object DotcInvocations {
}
val baseClassPath = if (invocationArgs.useScalaJsLib) currentJsClasspath else System.getProperty("java.class.path")
val classpath1 = s"${dependency.outDir}:${baseClassPath}"
val secondInvokationArgs = invocationArgs.copy(extraDotcArgs = List("-classpath", classpath1) ++ invocationArgs.extraDotcArgs)
val secondOutDir = s"${dirname}-classes"
val secondInvokationArgs = invocationArgs.copy(
extraDotcArgs = List("-classpath", classpath1) ++ invocationArgs.extraDotcArgs,
outDir = Some(secondOutDir)
)
DotcInvocations.succesfullyCompileFilesInDir(dirname, secondInvokationArgs)
val classpath2 = s"${dirname}:${classpath1}"
val classpath2 = s"${secondOutDir}:${classpath1}"
val mainClass = "testUtil.JunitMain"
val cmd = s"java -cp $classpath2 $mainClass $testClassName"
println(s"Running $cmd")
Expand Down
4 changes: 2 additions & 2 deletions compiler-plugin/src/test/scala/cc/Test12.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Test12 {
def testCompileAndRunM1(): Unit = {
val dotcInvocations = new DotcInvocations()
val (codeOrErrors, output) =
dotcInvocations.compileAndRunFilesInDirJVM("testdata/set12/m1", "testdata/set12/m1", "cpstest.Test12m1")
dotcInvocations.compileAndRunFilesInDirJVM("testdata/set12/m1", "testdata/set12/m1-classes", "cpstest.Test12m1")
val reporter = dotcInvocations.reporter
if (reporter.errorCount == 0) then
println(s"output=${output}")
Expand Down Expand Up @@ -50,7 +50,7 @@ class Test12 {
def testCompileAndRunM3(): Unit =
val dotcInvocations = new DotcInvocations()
val (code, output) =
dotcInvocations.compileAndRunFilesInDirJVM("testdata/set12/m3", "testdata/set12/m3", "cpstest.Test12m3")
dotcInvocations.compileAndRunFilesInDirJVM("testdata/set12/m3", "testdata/set12/m3-classes", "cpstest.Test12m3")
val reporter = dotcInvocations.reporter
//println("summary: " + reporter.summary)
//println(s"output=${output}")
Expand Down
2 changes: 1 addition & 1 deletion compiler-plugin/src/test/scala/cc/Test13TestCases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Test13TestCases {

def compileAfterCommon(dirname: String): Unit = {
compileCommon()
val classpath = s"testdata/set13TestCases/common:${System.getProperty("java.class.path")}"
val classpath = s"testdata/set13TestCases/common-classes:${System.getProperty("java.class.path")}"
val secondInvokationArgs = DotcInvocationArgs(extraDotcArgs = List("-classpath", classpath))
DotcInvocations.succesfullyCompileFilesInDir(dirname, secondInvokationArgs)
}
Expand Down
Loading

0 comments on commit 36f2e3b

Please sign in to comment.