diff --git a/build.sbt b/build.sbt index 31e89f3fe..82f178a21 100644 --- a/build.sbt +++ b/build.sbt @@ -1,5 +1,5 @@ val dottyVersion = "3.3.3" -//val dottyVersion = "3.5.0-RC1-bin-SNAPSHOT" +//val dottyVersion = "3.5.1-RC1-bin-SNAPSHOT" ThisBuild/version := "0.9.22-SNAPSHOT" @@ -126,8 +126,8 @@ lazy val compilerPlugin = project.in(file("compiler-plugin")) "org.scala-lang" %% "scala3-compiler" % scalaVersion.value % "provided", "com.github.sbt" % "junit-interface" % "0.13.3" % "test", ("org.scala-js" %% "scalajs-linker" % "1.16.0").cross(CrossVersion.for3Use2_13) % "test", - ("org.scala-js" %% "scalajs-ir" % "1.16.0").cross(CrossVersion.for3Use2_13) % "test", - ("org.scala-js" %% "scalajs-library" % "1.16.0").cross(CrossVersion.for3Use2_13) % "test", + //("org.scala-js" %% "scalajs-ir" % "1.16.0").cross(CrossVersion.for3Use2_13) % "test", + //("org.scala-js" %% "scalajs-library" % "1.16.0").cross(CrossVersion.for3Use2_13) % "test", ("org.scala-js" %% "scalajs-env-nodejs" % "1.4.0").cross(CrossVersion.for3Use2_13) % "test", ), // TODO: split test into subdirectories. diff --git a/compiler-plugin/src/main/scala/cps/plugin/CpsPlugin.scala b/compiler-plugin/src/main/scala/cps/plugin/CpsPlugin.scala index fdc32fe66..514cb5bc9 100644 --- a/compiler-plugin/src/main/scala/cps/plugin/CpsPlugin.scala +++ b/compiler-plugin/src/main/scala/cps/plugin/CpsPlugin.scala @@ -19,7 +19,7 @@ class CpsPlugin extends StandardPlugin { // - def init(options: List[String]): List[PluginPhase] = { + override def init(options: List[String]): List[PluginPhase] = { val settings = parseOptions(options) val selectedNodes = new SelectedNodes() List( @@ -35,6 +35,10 @@ class CpsPlugin extends StandardPlugin { if (option.startsWith("debugLevel=")) { val level = option.substring("debugLevel=".length).toInt settings.debugLevel = level + } else if (option == "printTree") { + settings.printTree = true + } else if (option == "printCode") { + settings.printCode = true } else if (option == "useLoom") { settings.useLoom = true } else { diff --git a/compiler-plugin/src/main/scala/cps/plugin/CpsPluginSettings.scala b/compiler-plugin/src/main/scala/cps/plugin/CpsPluginSettings.scala index 59e8e2302..0ebcb5782 100644 --- a/compiler-plugin/src/main/scala/cps/plugin/CpsPluginSettings.scala +++ b/compiler-plugin/src/main/scala/cps/plugin/CpsPluginSettings.scala @@ -2,6 +2,8 @@ package cps.plugin class CpsPluginSettings(var useLoom: Boolean = true, var debugLevel: Int = 0, + var printTree: Boolean = false, + var printCode: Boolean = false, var withShiftReplaceStages: Boolean = false, var transformDirectContextLambda: Boolean = false, diff --git a/compiler-plugin/src/main/scala/cps/plugin/DebugSettings.scala b/compiler-plugin/src/main/scala/cps/plugin/DebugSettings.scala index a36cf8457..a963660ce 100644 --- a/compiler-plugin/src/main/scala/cps/plugin/DebugSettings.scala +++ b/compiler-plugin/src/main/scala/cps/plugin/DebugSettings.scala @@ -24,21 +24,25 @@ case class DebugSettings( object DebugSettings { - def make(from:Tree)(using Context):DebugSettings = { + def make(from:Tree, pluginSettings: CpsPluginSettings)(using Context):DebugSettings = { val debugLevelAnSym = Symbols.requiredClass("cps.plugin.annotation.CpsDebugLevel") val scopeOwner = if (from.symbol != NoSymbol) then from.symbol else summon[Context].owner - val debugLevel: Int = findEnclosingAnnotation(scopeOwner, debugLevelAnSym) match - case Some(an) => - an.argument(0) match - case Some(Literal(Constant(v:Int))) => v - case other => - throw CpsTransformException(s"CpsDebugLevelAnnotation should have literal constant as argument, we have $other",an.tree.srcPos) - case None => - 0 + val debugLevel: Int = + if (pluginSettings.debugLevel != 0) + pluginSettings.debugLevel + else + findEnclosingAnnotation(scopeOwner, debugLevelAnSym) match + case Some(an) => + an.argument(0) match + case Some(Literal(Constant(v:Int))) => v + case other => + throw CpsTransformException(s"CpsDebugLevelAnnotation should have literal constant as argument, we have $other",an.tree.srcPos) + case None => + 0 if (false) { //don't work after 'cc' stage (always show name) //TODO: look after inlinging or ask user to enable retain-tree @@ -56,10 +60,24 @@ object DebugSettings { case None => 0 } - val printCodeTpe = Symbols.requiredClass("cps.plugin.settings.PrintCode").typeRef - val printCode = CpsTransformHelper.findImplicitInstance(printCodeTpe, summon[Context].tree.span).isDefined - val printTreeTpe = Symbols.requiredClass("cps.plugin.settings.PrintTree").typeRef - val printTree = CpsTransformHelper.findImplicitInstance(printTreeTpe, summon[Context].tree.span).isDefined + val printCode = { + pluginSettings.printCode || { + val oldPrintCodeTpe = Symbols.requiredClass("cps.macros.flags.PrintCode").typeRef + val printCodeTpe = Symbols.requiredClass("cps.plugin.settings.PrintCode").typeRef + println(s"context = ${summon[Context].tree.show}, phase = ${summon[Context].phase}") + CpsTransformHelper.findImplicitInstance(printCodeTpe, summon[Context].tree.span).isDefined || + CpsTransformHelper.findImplicitInstance(oldPrintCodeTpe, summon[Context].tree.span).isDefined + } + } + val printTree = { + pluginSettings.printTree || { + val oldPrintTreeTpe = Symbols.requiredClass("cps.macros.flags.PrintTree").typeRef + val printTreeTpe = Symbols.requiredClass("cps.plugin.settings.PrintTree").typeRef + CpsTransformHelper.findImplicitInstance(printTreeTpe, summon[Context].tree.span).isDefined || + CpsTransformHelper.findImplicitInstance(oldPrintTreeTpe, summon[Context].tree.span).isDefined + + } + } DebugSettings(debugLevel = debugLevel, printTree = printTree, printCode = printCode) } diff --git a/compiler-plugin/src/main/scala/cps/plugin/PhaseCps.scala b/compiler-plugin/src/main/scala/cps/plugin/PhaseCps.scala index 54128ff34..c4d04180c 100644 --- a/compiler-plugin/src/main/scala/cps/plugin/PhaseCps.scala +++ b/compiler-plugin/src/main/scala/cps/plugin/PhaseCps.scala @@ -65,10 +65,8 @@ class PhaseCps(settings: CpsPluginSettings, def transformDefDefInternal(tree: DefDef, selectRecord: DefDefSelectRecord, optTopLevelContext:Option[CpsTopLevelContext]=None)(using Context): DefDef = { - val debugSettings = optTopLevelContext.map(_.debugSettings).getOrElse(DebugSettings.make(tree)) + val debugSettings = optTopLevelContext.map(_.debugSettings).getOrElse(DebugSettings.make(tree, settings)) selectRecord.debugLevel = debugSettings.debugLevel - if (debugSettings.printCode) then - report.log("transforming tree:", tree.srcPos) val retval = selectRecord.kind match case USING_CONTEXT_PARAM(cpsMonadContextArg) => val cpsMonadContext = ref(cpsMonadContextArg.symbol) @@ -82,10 +80,14 @@ class PhaseCps(settings: CpsPluginSettings, //selectRecord.changedReturnType = nTpt given CpsTopLevelContext = tc val ctx1: Context = summon[Context].withOwner(tree.symbol) + if (debugSettings.printCode) then + Log.info(s"transformDefDefIntenal: ${tree.show}",0, tree.srcPos) val transformedRhs = RootTransform(tree.rhs,tree.symbol, 0)(using ctx1, tc).transformed val nRhs = Block(monadValDef::Nil,transformedRhs)(using ctx1) val adoptedRhs = Scaffolding.adoptUncpsedRhs(nRhs, tree.tpt.tpe, tc.monadType) val retval = cpy.DefDef(tree)(tree.name, tree.paramss, tree.tpt, adoptedRhs) + if (debugSettings.printCode) then + Log.info(s"transformDefDefInternal: transformed: ${retval.show}",0, tree.srcPos) retval case RETURN_CONTEXT_FUN(internalKind) => val cpsDirectContext = ref(selectRecord.kind.getCpsDirectContext.symbol) @@ -104,12 +106,17 @@ class PhaseCps(settings: CpsPluginSettings, tree case _ => throw CpsTransformException("Lambda function was expected, we have $tree",tree.srcPos) - if (debugSettings.printCode) then - report.log(s"transforned: ${retval.show}",tree.srcPos) - report.log(s"transforned: ${retval}",tree.srcPos) retval } + override def prepareForApply(tree: tpd.Apply)(using Context): Context = { + if (summon[Context].phase != this) { + println(s"PhaseCps::prepareForApply, invalid phase = ${summon[Context].phase}") + summon[Context].withPhase(this) + } else { + summon[Context] + } + } override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree = { try @@ -187,11 +194,11 @@ class PhaseCps(settings: CpsPluginSettings, case Some(tree) => val (tc, monadValDef) = tree match case CpsDirectHelper.ByInclusionCall(tf, tg, fctx, fginc) => - makeCpsTopLevelContext(fctx, summon[Context].owner, a.srcPos, DebugSettings.make(a), CpsTransformHelper.cpsMonadContextClassSymbol) + makeCpsTopLevelContext(fctx, summon[Context].owner, a.srcPos, DebugSettings.make(a, settings), CpsTransformHelper.cpsMonadContextClassSymbol) case CpsDirectHelper.NewCall(fctx) => - makeCpsTopLevelContext(fctx, summon[Context].owner, a.srcPos, DebugSettings.make(a), CpsTransformHelper.cpsMonadContextClassSymbol) + makeCpsTopLevelContext(fctx, summon[Context].owner, a.srcPos, DebugSettings.make(a, settings), CpsTransformHelper.cpsMonadContextClassSymbol) case other => - makeCpsTopLevelContext(other, summon[Context].owner, a.srcPos, DebugSettings.make(a), CpsTransformHelper.cpsDirectAliasSymbol) + makeCpsTopLevelContext(other, summon[Context].owner, a.srcPos, DebugSettings.make(a, settings), CpsTransformHelper.cpsDirectAliasSymbol) val nTree = { given CpsTopLevelContext = tc RootTransform(a, summon[Context].owner, 0).transformed @@ -206,7 +213,8 @@ class PhaseCps(settings: CpsPluginSettings, super.transformApply(tree) case Apply(Apply(TypeApply(deferredAsyncCn, List(tp,mtp,mctp)), List(applyTerm)), List(ctx)) if (deferredAsyncCn.symbol == Symbols.requiredMethod("cps.plugin.scaffolding.deferredAsync")) => - val (tc, monadValDef) = makeCpsTopLevelContext(ctx,summon[Context].owner, tree.srcPos, DebugSettings.make(tree), CpsTransformHelper.cpsMonadContextClassSymbol) + val (tc, monadValDef) = makeCpsTopLevelContext(ctx,summon[Context].owner, tree.srcPos, + DebugSettings.make(tree, settings), CpsTransformHelper.cpsMonadContextClassSymbol) val nApplyTerm = { given CpsTopLevelContext = tc RootTransform(applyTerm, summon[Context].owner, 0).transformed @@ -234,7 +242,12 @@ class PhaseCps(settings: CpsPluginSettings, val contextParam = cpsMonadContext match case vd: ValDef => ref(vd.symbol) case _ => throw CpsTransformException(s"excepted that cpsMonadContext is ValDef, but we have ${cpsMonadContext.show}", asyncCallTree.srcPos) - val (tctx, monadValDef) = makeCpsTopLevelContext(contextParam, ddef.symbol, asyncCallTree.srcPos, DebugSettings.make(asyncCallTree), CpsTransformHelper.cpsMonadContextClassSymbol) + val (tctx, monadValDef) = makeCpsTopLevelContext(contextParam, ddef.symbol, asyncCallTree.srcPos, + DebugSettings.make(ddef, settings), CpsTransformHelper.cpsMonadContextClassSymbol) + if (tctx.debugSettings.printCode) { + Log.info(s"transformDefDefInsideAsync: ${ddef.show}", 0, ddef.srcPos)(using ctx, tctx) + Log.info(s"transformDefDefInsideAsync: body: ${ddef.rhs.show}", 0, ddef.srcPos)(using ctx, tctx) + } val ddefCtx = ctx.withOwner(ddef.symbol) val nRhsCps = RootTransform(ddef.rhs, ddef.symbol, 0)(using ddefCtx, tctx) val nRhsTerm = wrapTopLevelCpsTree(nRhsCps)(using ddefCtx, tctx) @@ -259,7 +272,10 @@ class PhaseCps(settings: CpsPluginSettings, // case Some(tree) => println(s"err::symbol still have old owner: ${tree.show}") // case None => //} - + if (tctx.debugSettings.printCode) { + Log.info(s"transformDefDefInsideAsync: transformed: ${nDefDef.show}", 0, ddef.srcPos)(using ctx, tctx) + Log.info(s"transformDefDefInsideAsync: transformed body: ${nDefDef.rhs.show}", 0, ddef.srcPos)(using ctx, tctx) + } nDefDef } @@ -313,7 +329,8 @@ class PhaseCps(settings: CpsPluginSettings, Inlined(call, env, transformDefDef2InsideCpsAsyncStreamApply(body, ctxRef)) case Block((ddef: DefDef)::Nil, closure: Closure) if (ddef.symbol == closure.meth.symbol) => //val monadType = CpsTransformHelper.extractMonadType(cpsMonadContext.tpe.widen, CpsTransformHelper.cpsMonadContextClassSymbol, asyncCallTree.srcPos) - val (tctx, monadValDef) = makeCpsTopLevelContext(ctxRef, ddef.symbol, ddef.rhs.srcPos, DebugSettings.make(ddef), CpsTransformHelper.cpsMonadContextClassSymbol) + val (tctx, monadValDef) = makeCpsTopLevelContext(ctxRef, ddef.symbol, ddef.rhs.srcPos, + DebugSettings.make(ddef, settings), CpsTransformHelper.cpsMonadContextClassSymbol) val ddefContext = ctx.withOwner(ddef.symbol) val nRhsCps = RootTransform(ddef.rhs, ddef.symbol, 0)(using ddefContext, tctx) val nRhs = Block(monadValDef.changeOwner(monadValDef.symbol.owner,ddef.symbol)::Nil, nRhsCps.transformed(using ddefContext, tctx)) diff --git a/compiler-plugin/src/main/scala/cps/plugin/forest/ApplyTransform.scala b/compiler-plugin/src/main/scala/cps/plugin/forest/ApplyTransform.scala index 9f7e6003c..f628926dc 100644 --- a/compiler-plugin/src/main/scala/cps/plugin/forest/ApplyTransform.scala +++ b/compiler-plugin/src/main/scala/cps/plugin/forest/ApplyTransform.scala @@ -373,13 +373,25 @@ object ApplyTransform { } def adoptCallMode(origin: Tree, plainTree: Tree, funSym: Symbol, owner: Symbol, argss: List[ApplyArgList], callMode: FunCallMode, nesting: Int)(using Context, CpsTopLevelContext): CpsTree = { - Log.trace(s"adoptCallMode: plainTree=${plainTree.show}, callMode=${callMode} funSym=${funSym}", nesting) - if (argss.exists(_.containsDirectContext) && !funSym.hasAnnotation(Symbols.requiredClass("cps.plugin.annotation.CpsNotChange"))) { + Log.trace(s"adoptCallMode: plainTree=${plainTree.show}, callMode=${callMode} funSym=${funSym}, haveDirectContext=${argss.exists(_.containsDirectContext)}", nesting) + if (argss.exists(_.containsDirectContext)) + //if (funSym.hasAnnotation(Symbols.requiredClass("cps.plugin.annotation.CpsNotChange"))) { + // throw CpsTransformException(s"DirectContext function can't be annotated with @CpsNotChange", origin.srcPos) + //} + val noChangeAnnotation = funSym.hasAnnotation(Symbols.requiredClass("cps.plugin.annotation.CpsNotChange")) val directContextArg = argss.find(_.containsDirectContext).flatMap(_.findDirectContext).get + Log.trace(s"adoptCallMode: directContextArg=${directContextArg.show}", nesting) val adoptedTree = directContextArg match case dc@CpsDirectHelper.ByInclusionCall(tf,tg,fctx,fgincl) => + Log.trace("!!!adoptCallMode: ByInclusionCall", nesting) + if (funSym.hasAnnotation(Symbols.requiredClass("cps.plugin.annotation.CpsNotChange"))) { + println("byInclusdionCall with CpsNotChange") + println(s" tf.tpe=${tf.tpe.show}") + throw CpsTransformException(s"DirectContext function can't be annotated with @CpsNotChange", origin.srcPos) + } val callArgs = CpsDirectHelper.ByInclusionCallArgs(tf,tg,fctx,fgincl) if (tf.tpe =:= tg.tpe) then + Log.trace(s"!!!adoptCallMode: ByInclusionCall, tf.tpe =:= tg.tpe = ${tf.tpe.show}", nesting) val nCpsDirectArg = CpsDirectHelper.genCpsDirectDefaultConstructor(TypeTree(tf.tpe),fctx,dc.span) val tree = CpsDirectHelper.substituteCpsDirectArgInCall(plainTree, callArgs, nCpsDirectArg).getOrElse( throw CpsTransformException("Internal error: can't find direct context argument in call", origin.srcPos) @@ -411,7 +423,14 @@ object ApplyTransform { }) CpsDirectHelper.genConventionCall(fctx,fgincl,origin.tpe.widen,lambda,origin.span) case other => - Scaffolding.adoptCpsedCall(plainTree, plainTree.tpe.widen, summon[CpsTopLevelContext].monadType) + if (funSym.hasAnnotation(Symbols.requiredClass("cps.plugin.annotation.CpsNotChange"))) { + // no adoption + plainTree + //println(s"directContextArg=${directContextArg.show}") + //throw CpsTransformException(s"DirectContext function can't be annotated with @CpsNotChange, fun=${funSym.fullName}", origin.srcPos) + } else { + Scaffolding.adoptCpsedCall(plainTree, plainTree.tpe.widen, summon[CpsTopLevelContext].monadType) + } //if (isImpure) { // TODO: such situationis possible when we pass lamba with context parameters (can be inline) // TODO: separate this case. @@ -425,15 +444,20 @@ object ApplyTransform { CpsTree.impure(origin, owner, adoptedTree, internalKind) */ //CpsTree.impure(origin, owner, adoptedTree, AsyncKind.Sync) - adoptResultKind(origin, adoptedTree, owner, callMode, true) - } else { - adoptResultKind(origin, plainTree, owner, callMode, false) - } + val useDirectContext = ! noChangeAnnotation + adoptResultKind(origin, plainTree, adoptedTree, owner, callMode, useDirectContext) + else + adoptResultKind(origin, plainTree, plainTree, owner, callMode, false) } - def adoptResultKind(origin:Tree, newApply: Tree, owner: Symbol, callMode: FunCallMode, usingDirectContext: Boolean)(using Context, CpsTopLevelContext): CpsTree = { + def adoptResultKind(origin:Tree, + newApply: Tree, // changed apply, but not wrapped in adoptCpsedCall + wrappedNewApply: Tree, // if usignDirectContext then adoptCpsedCall(newApply) otherwis - same as newApply + owner: Symbol, + callMode: FunCallMode, + usingDirectContext: Boolean)(using Context, CpsTopLevelContext): CpsTree = { if (callMode.argCallMode == ApplyArgCallMode.ASYNC_SHIFT || callMode.fromCallChain) { if (newApply.tpe.baseType(Symbols.requiredClass("cps.runtime.CallChainAsyncShiftSubst"))!=NoType) { @@ -447,22 +471,25 @@ object ApplyTransform { val newType = newApply.tpe.widen if (originType =:= newType) { if (usingDirectContext) then - CpsTree.impure(origin,owner,newApply,AsyncKind.Sync) + CpsTree.impure(origin,owner,wrappedNewApply,AsyncKind.Sync) else CpsTree.pure(origin,owner,newApply) } else if (newType <:< summon[CpsTopLevelContext].monadType.appliedTo(WildcardType)) { val adoptedApply = if (usingDirectContext) { - report.warning("async-shifted function with direct context return wrapped type", origin.srcPos) + report.warning(s"async-shifted function with direct context return wrapped type", origin.srcPos) + report.warning(s"origin tree: ${origin.show}", origin.srcPos) + report.warning(s"transformed tree: ${newApply.show}", origin.srcPos) + report.warning(s"debugLevel: ${summon[CpsTopLevelContext].pluginSettings.debugLevel}", origin.srcPos) Apply( TypeApply( Select(summon[CpsTopLevelContext].cpsMonadRef, "flatten".toTermName), List(TypeTree(originType.widen)) ), - List(newApply) + List(wrappedNewApply) ) } else { - newApply + wrappedNewApply } CpsTree.impure(origin, owner, adoptedApply, AsyncKind.Sync) } else if (callMode.asyncLambdaApplication.isDefined) { @@ -473,9 +500,9 @@ object ApplyTransform { } else { // TODO: warn about possible unsafe result type if (usingDirectContext) { - CpsTree.impure(origin,owner,newApply,AsyncKind.Sync) + CpsTree.impure(origin,owner,wrappedNewApply,AsyncKind.Sync) } else { - CpsTree.pure(origin, owner, newApply) + CpsTree.pure(origin, owner, wrappedNewApply) } } } @@ -487,9 +514,9 @@ object ApplyTransform { callMode.asyncLambdaApplication.get match case AsyncKind.Sync => if (usingDirectContext && transformDirectContextLambdaCall) { - CpsTree.impure(origin,owner,newApply,AsyncKind.Sync) + CpsTree.impure(origin,owner,wrappedNewApply,AsyncKind.Sync) } else { - CpsTree.pure(origin,owner,newApply) + CpsTree.pure(origin,owner,wrappedNewApply) } case AsyncKind.Async(internalKind) => if (usingDirectContext && transformDirectContextLambdaCall) { @@ -498,22 +525,22 @@ object ApplyTransform { Select(summon[CpsTopLevelContext].cpsMonadRef, "flatten".toTermName), List(TypeTree(newApply.tpe.widen)) ), - List(newApply) + List(wrappedNewApply) ) CpsTree.impure(origin,owner,flattenedNewApply,internalKind) } else { - CpsTree.impure(origin, owner, newApply, internalKind) + CpsTree.impure(origin, owner, wrappedNewApply, internalKind) } case AsyncKind.AsyncLambda(bodyKind) => if (usingDirectContext && transformDirectContextLambdaCall) { // it's why better to keep transformDirectContextLambdaCall = false throw CpsTransformException("Unsuppored use of lamba application as output of direct context lambda", origin.srcPos) } - CpsTree.opaqueAsyncLambda(origin,owner,newApply,bodyKind) + CpsTree.opaqueAsyncLambda(origin,owner,wrappedNewApply,bodyKind) } else if (usingDirectContext) { - CpsTree.impure(origin,owner,newApply, AsyncKind.Sync) + CpsTree.impure(origin,owner,wrappedNewApply, AsyncKind.Sync) } else { - CpsTree.pure(origin,owner,newApply) + CpsTree.pure(origin,owner,wrappedNewApply) } @@ -732,6 +759,7 @@ object ApplyTransform { case ShiftedFun(origin, obj, method, targs, additionalArgs, canBeOverloaded, callShouldBeInlined, shape) => obj.tpe.member(method).symbol val lastCpsTree = adoptCallMode(fullOrigin, pureReply, funSymbol, owner, argss, callMode, nesting) + // to preserve order of evaluation, evaluate all arguments before call. val nApplyCpsTree = genPrefixes(argss, lastCpsTree) val retval = nApplyCpsTree Log.trace(s"genApplication result: ${retval.show}", nesting) diff --git a/compiler-plugin/src/main/scala/cps/plugin/forest/application/ApplyArgList.scala b/compiler-plugin/src/main/scala/cps/plugin/forest/application/ApplyArgList.scala index 5a31b8712..57bc26d56 100644 --- a/compiler-plugin/src/main/scala/cps/plugin/forest/application/ApplyArgList.scala +++ b/compiler-plugin/src/main/scala/cps/plugin/forest/application/ApplyArgList.scala @@ -76,12 +76,20 @@ object ApplyTermArgList { val (aExpr, named) = a match case NamedArg(name, expr) => (expr, Some(name.toTermName)) case expr => (expr, None) - val depResult = DependencyCheck.run(aExpr,s.symbols) + val isDirectContext = mt.isDirectContext(s.index, a.srcPos) + val depResult = if (isDirectContext) { + // direct context is always syntetic and substituted by real context in the code + // (if we will generate temporary val for it, we broke this substitution) + // (mb add contextual argument for byIdentityCall to set). + DependencyCheck.Result(false, Set.empty) + } else { + DependencyCheck.run(aExpr,s.symbols) + } val nApplyArg = ApplyArg( aExpr, mt.paramName(s.index, a.srcPos).toTermName, mt.paramType(s.index, a.srcPos), mt.isByName(s.index, a.srcPos), - mt.isDirectContext(s.index, a.srcPos), + isDirectContext, owner, depResult.canBeDependent, named, diff --git a/compiler-plugin/src/test/scala/cc/DotcInvocations.scala b/compiler-plugin/src/test/scala/cc/DotcInvocations.scala index 8f869609c..1f7b51f02 100644 --- a/compiler-plugin/src/test/scala/cc/DotcInvocations.scala +++ b/compiler-plugin/src/test/scala/cc/DotcInvocations.scala @@ -19,6 +19,7 @@ class DotcInvocations(silent: Boolean = true, scalaJs: Boolean = false) { compilerClasspathOption ++ extraArgs ++ DotcInvocations.defaultCompileOpts ++ + (if (scalaJs) List("-scalajs") else List.empty) ++ (if (checkAll) List("-Ycheck:all") else List.empty) println(s"compile args: ${args}, usePlugin=${usePlugin}") compileFilesWithFullArgs(files, outDir, args) diff --git a/compiler-plugin/src/test/scala/cc/Test22cc.scala b/compiler-plugin/src/test/scala/cc/Test22cc.scala index eec3118a1..95c939586 100644 --- a/compiler-plugin/src/test/scala/cc/Test22cc.scala +++ b/compiler-plugin/src/test/scala/cc/Test22cc.scala @@ -6,7 +6,7 @@ import org.junit.{Ignore, Test} class Test22cc { @Test - //@Ignore // yet not working + @Ignore def testCompileContextExtractor() = { val inDir = "testdata/set22cc/m1" @@ -27,4 +27,53 @@ class Test22cc { } + @Test + @Ignore // In process now + def testCompileJSAsyncWithInternalCpsAsync() = { + + val inDir = "testdata/set22cc/m2" + val outDir = "testdata/set22cc/m2-out" + val jsLinkOut = "testdata/set22cc/m2-linkout" + + val dotcInvocations = new DotcInvocations(silent = false, scalaJs = true) + + val reporter = dotcInvocations.compileFilesInDirs(List(inDir), outDir, checkAll = true, + extraArgs = List("-Vprint:erasure,rssh.cps", "-experimental", "-P:rssh.cps:printCode"), + usePlugin = true + ) + + if (reporter.hasErrors) { + println("errors:") + for(error <- reporter.allErrors) + val pos = error.pos + println(s"${pos.source}:${pos.startLine}:${pos.startColumn} ${error.msg}") + println(reporter.summary) + } + + assert(!reporter.hasErrors, "compilation failed") + + } + + @Test + @Ignore // solved, not in focus. + def testCompileJSResolveReject() = { + + val inDir = "testdata/set22cc/m3" + val outDir = "testdata/set22cc/m3-out" + val jsLinkOut = "testdata/set22cc/m3-linkout" + + val dotcInvocations = new DotcInvocations(silent = false, scalaJs = true) + + val reporter = dotcInvocations.compileFilesInDirs(List(inDir), outDir, checkAll = true, + extraArgs = List("-Vprint:erasure,rssh.cps", "-Yprint-syms", "-experimental" ), + usePlugin = false + ) + + if (reporter.hasErrors) { + println(reporter.summary) + } + + assert(!reporter.hasErrors, "compilation failed") + + } } diff --git a/compiler-plugin/testdata/set22cc/m2/Async.scala b/compiler-plugin/testdata/set22cc/m2/Async.scala new file mode 100644 index 000000000..a2db46d8d --- /dev/null +++ b/compiler-plugin/testdata/set22cc/m2/Async.scala @@ -0,0 +1,118 @@ +package gears.async + +import cps.* +import cps.plugin.* +import gears.async.JSAsync.JSAsyncCpsMonad + + +trait Async extends CpsTryMonadContext[JSAsync]: + + val support: AsyncSupport + + val scheduler: support.Scheduler + + def await[T](src: Async.Source[T])(using AsyncContext): T + + def group: CompletionGroup + + def withGroup(group: CompletionGroup): Async + + override val monad: CpsTryMonad[JSAsync] = new JSAsync.JSAsyncCpsMonad(this) + +end Async + + +object Async: + + // + inline def apply[T](inline body: Async ?=> T ): JSAsync[T] = + cpsAsyncApply[JSAsync,T, Async](JSAsyncCpsMonad(Async.empty), body(using summon[Async]) ) + + + inline def current(using async: Async): Async = async + + opaque type Spawn <: Async = Async + + def group[T](body: Async.Spawn ?=> T)(using AsyncContext): T = + withNewCompletionGroup(CompletionGroup().link())(body) + + def group_async[T](body: Async.Spawn => JSAsync[T])(using AsyncContext): T = + withNewCompletionGroupAsync(CompletionGroup().link())(body) + + case class DefaultAsync(group: CompletionGroup = CompletionGroup.Unlinked) extends Async: + override val support: JSAsyncSupport.type = JSAsyncSupport + override val scheduler = support.scheduler + override def await[T](src: Async.Source[T])(using AsyncContext): T = ??? + override def withGroup(group: CompletionGroup): Async = copy(group = group) + + private[async] def empty: Async = DefaultAsync() + + /** Runs a body within another completion group. When the body returns, the group is cancelled and its completion + * awaited with the `Unlinked` group. + */ + private[async] def withNewCompletionGroup[T](group: CompletionGroup)(body: Async.Spawn ?=> T)(using ac: AsyncContext): T = { + withNewCompletionGroupAsync(group)( spawn => JSAsync.Pure(body(using spawn)) ) + } + + + private[async] def withNewCompletionGroupAsync[T](group: CompletionGroup)(body: Async.Spawn => JSAsync[T])(using ac: AsyncContext): T = { + val completionAsync = + if CompletionGroup.Unlinked == ac.async.group + then ac.async + else ac.async.withGroup(CompletionGroup.Unlinked) + + try + // compiler crash here + await(body(summon[Async].withGroup(group))) + finally + group.cancel() + given Async = completionAsync + group.waitCompletion() + } + + + /** An asynchronous data source. Sources can be persistent or ephemeral. A persistent source will always pass same + * data to calls of [[Source!.poll]] and [[Source!.onComplete]]. An ephemeral source can pass new data in every call. + * + * @see + * An example of a persistent source is [[gears.async.Future]]. + * @see + * An example of an ephemeral source is [[gears.async.Channel]]. + */ + trait Source[+T]: + + /**Checks whether data is available at present and pass it to `k` if so. Calls to `poll` are always synchronous and + * non-blocking. + */ + def poll(k: Listener[T]): Boolean + + /** + * Once data is available, pass it to the listener `k`. `onComplete` is always non-blocking. + */ + def onComplete(k: Listener[T]): Unit + + /** Signal that listener `k` is dead (i.e. will always fail to acquire locks from now on), and should be removed + * from `onComplete` queues. + */ + def dropListener(k: Listener[T]): Unit + + def poll(): Option[T] = ??? + + end Source + + abstract class OriginalSource[+T] extends Source[T]: + + /** Add `k` to the listener set of this source. */ + protected def addListener(k: Listener[T]): Unit + + def onComplete(k: Listener[T]): Unit = synchronized: + if !poll(k) then addListener(k) + + end OriginalSource + + +end Async + + + + diff --git a/compiler-plugin/testdata/set22cc/m2/AsyncSupport.scala b/compiler-plugin/testdata/set22cc/m2/AsyncSupport.scala new file mode 100644 index 000000000..e4e512dc1 --- /dev/null +++ b/compiler-plugin/testdata/set22cc/m2/AsyncSupport.scala @@ -0,0 +1,55 @@ +package gears.async + +import scala.concurrent.* +import scala.concurrent.duration.* +import cps.* + +type AsyncContext = CpsDirect[JSAsync] + +extension (actx: AsyncContext) + + def group: CompletionGroup = + async.group + + inline def async: Async = + actx.context.asInstanceOf[Async] + +@cps.plugin.annotation.CpsNotChange() +given asyncFromAsyncContext(using actx: AsyncContext): Async = actx.context.asInstanceOf[Async] + +given cps.macros.flags.UseCompilerPlugin.type = cps.macros.flags.UseCompilerPlugin + + +/** + * Analog of suspension in monadic is arrow. + **/ +trait Suspension[-T, +R]: + + def resume(arg:T)(using CpsDirect[JSAsync]): R + + +/** + * A scheduler implementation, with the ability to execute a computation immediately or after a delay. + **/ +trait Scheduler: + def execute(body: Runnable)(using CpsDirect[JSAsync]): Unit + def schedule(delay: FiniteDuration, body: Runnable): Cancellable + + +trait SuspendSupport: + + type Label[R] + + type Suspension[-T, +R] <: gears.async.Suspension[T,R] + + def boundary[R](body: Label[R] ?=> R)(using CpsDirect[JSAsync]): R + + + +trait AsyncSupport extends SuspendSupport: + + type Scheduler <: gears.async.Scheduler + + + + diff --git a/compiler-plugin/testdata/set22cc/m2/Cancellable.scala b/compiler-plugin/testdata/set22cc/m2/Cancellable.scala new file mode 100644 index 000000000..48c710190 --- /dev/null +++ b/compiler-plugin/testdata/set22cc/m2/Cancellable.scala @@ -0,0 +1,51 @@ +package gears.async + +trait Cancellable: + + private[gears] var optGroup: Option[CompletionGroup] = None + + // bug in scalajs: stack overflow when use unlinked. + private def group: CompletionGroup = optGroup.getOrElse(CompletionGroup.Unlinked) + private def group_=(group: CompletionGroup): Unit = optGroup = Some(group) + + def cancel(): Unit + + def link(group: CompletionGroup): this.type = synchronized: + this.group.drop(this) + this.group = group + this.group.add(this) + this + + /** Link this cancellable to the cancellable group of the current async context. + */ + def link()(using async: Async): this.type = + link(async.group) + + /** Unlink this cancellable from its group. */ + def unlink(): this.type = + link(CompletionGroup.Unlinked) + + +end Cancellable + + +object Cancellable: + + trait Tracking extends Cancellable: + + def isCancelled: Boolean + + + object Tracking: + + def apply() = new Tracking: + private var cancelled: Boolean = false + + def cancel(): Unit = + cancelled = true + + def isCancelled = cancelled + + end Tracking + +end Cancellable diff --git a/compiler-plugin/testdata/set22cc/m2/CompletionGroup.scala b/compiler-plugin/testdata/set22cc/m2/CompletionGroup.scala new file mode 100644 index 000000000..f430acb38 --- /dev/null +++ b/compiler-plugin/testdata/set22cc/m2/CompletionGroup.scala @@ -0,0 +1,37 @@ +package gears.async + +import cps.* + +class CompletionGroup extends Cancellable.Tracking: + + def isCancelled: Boolean = ??? + + def cancel(): Unit = ??? + + /** Wait for all members of the group to complete and unlink themselves. */ + private[async] def waitCompletion()(using AsyncContext): Unit = { + ??? + } + + /** Add given member to the members set. If the group has already been cancelled, cancels that member immediately. */ + def add(member: Cancellable): Unit = + ??? + + /** Remove given member from the members set if it is an element */ + def drop(member: Cancellable): Unit = + ??? + + +object CompletionGroup: + + /** A sentinel group of cancellables that are in fact not linked to any real group. `cancel`, `add`, and `drop` do + * nothing when called on this group. + */ + object Unlinked extends CompletionGroup: + override def cancel(): Unit = () + override def waitCompletion()(using CpsDirect[JSAsync]): Unit = () + override def add(member: Cancellable): Unit = () + override def drop(member: Cancellable): Unit = () + end Unlinked + +end CompletionGroup diff --git a/compiler-plugin/testdata/set22cc/m2/JSAsyncCpsMonad.scala b/compiler-plugin/testdata/set22cc/m2/JSAsyncCpsMonad.scala new file mode 100644 index 000000000..230c95a91 --- /dev/null +++ b/compiler-plugin/testdata/set22cc/m2/JSAsyncCpsMonad.scala @@ -0,0 +1,356 @@ +package gears.async + +import cps.* + +import scala.concurrent.* +import scala.scalajs.js.{JavaScriptException, Thenable} +import scala.util.* +import scala.util.control.NonFatal +import scalajs.* + + +/** + * AsyncMonad + **/ +sealed trait JSAsync[+A] { + + def map[B](f: A => B): JSAsync[B] + + def flatMap[B](f: A => JSAsync[B]): JSAsync[B] + + def mapTry[B](f: Try[A] => B): JSAsync[B] = + flatMapTry((ta: Try[A]) => JSAsync.Pure(f(ta))) + + def flatMapTry[B](f: Try[A] => JSAsync[B]): JSAsync[B] + + def toPromise: js.Promise[A] + + def toFuture: Future[A] + + def onComplete(f: Try[A] => Unit): Unit + +} + + + + +object JSAsync { + + + + case class JSPromiseWrapper[A](val jsPromise: js.Promise[Either[A,JSAsync[A]]]) extends JSAsync[A] { + + override def map[B](f: A=>B): JSAsync[B] = { + JSPromiseWrapper( + jsPromise.`then`[Either[B,JSAsync[B]]]( la => + la match + case Left(a) => + try + Left(f(a)) + catch + case NonFatal(ex) => + Right(Error(ex)) + case Right(next) => + Right(next.map(f)) + ) + ) + } + + override def flatMap[B](f: A => JSAsync[B]): JSAsync[B] = { + import scala.concurrent.ExecutionContext.Implicits.global + JSPromiseWrapper( + jsPromise.`then`[Either[B,JSAsync[B]]] { + case Left(a) => + evalOp(f(a)) + case Right(lPrev) => + // TODO: mb deque of task to prevent quadratic complexity of flatMap chains. + Right(lPrev.flatMap(f)) + } + ) + } + + override def flatMapTry[B](f: Try[A] => JSAsync[B]): JSAsync[B] = { + import scala.concurrent.ExecutionContext.Implicits.global + + JSPromiseWrapper( + jsPromise.`then`[Either[B,JSAsync[B]]]( + onFulfilled = { + case Left(a) => + evalOp(f(Success(a))) + case Right(wa) => + wa match + case Error(ex) => + evalOp(f(Failure(ex))) + case other => + Right(other.flatMapTry(f)) + }, + onRejected = { (e: Any) => + e match + case ex: Throwable => + evalOp(f(Failure(ex))) + case other => + evalOp(f(Failure(JavaScriptException(other)))) + } + ) + ) + } + + override def toPromise: js.Promise[A] = { + jsPromise.`then`[A]({ + case Left(a) => a + case Right(next) => + next.toPromise + }) + } + + override def toFuture: Future[A] = { + import scala.concurrent.ExecutionContext.Implicits.global + jsPromise.toFuture.flatMap({ + case Left(a) => Future successful a + case Right(next) => next.toFuture + }) + } + + override def onComplete(f: Try[A] => Unit): Unit = { + jsPromise.`then`[Unit]({ + case Left(a) => f(Success(a)) + case Right(next) => next.onComplete(f) + }) + } + + } + + case class ScalaFutureWrapper[A](future: Future[Either[A,JSAsync[A]]]) extends JSAsync[A] { + import scala.concurrent.ExecutionContext.Implicits.global + + override def map[B](f: A=>B): JSAsync[B] = { + ScalaFutureWrapper(future.map({ + case Left(a) => evalOp(Pure(f(a))) + case Right(next) => Right(next.map(f)) + })) + } + + override def flatMap[B](f: A=>JSAsync[B]) = { + val next:Future[Either[B,JSAsync[B]]] = future.flatMap{ + case Left(a) => + Future successful evalOp[B](f(a)) + case Right(wa) => + wa match + case Pure(a) => + Future successful evalOp(f(a)) + case Error(ex) => + Future failed ex + case ScalaFutureWrapper(fw) => + fw.flatMap{ + case Left(a) => Future successful evalOp[B](f(a)) + case Right(next) => Future successful Right(next.flatMap(f)) + } + case other => + Future successful( Right(other.flatMap(f)) ) + } + ScalaFutureWrapper(next) + } + + + override def flatMapTry[B](f: Try[A] => JSAsync[B]): JSAsync[B] = { + val next: Future[Either[B,JSAsync[B]]] = future.flatMap{ + case Left(a) => + Future successful evalOp(f(Success(a))) + case Right(wa) => + wa match + case Pure(a) => + Future successful evalOp(f(Success(a))) + case Error(ex) => + Future successful evalOp(f(Failure(ex))) + case other => + Future successful Right(other.flatMapTry(f)) + } + ScalaFutureWrapper(next) + } + + override def toPromise: js.Promise[A] = { + val p = js.Promise[A]((resolve, reject) => { + future.onComplete({ + case Success(Left(a)) => + resolve(a) + case Success(Right(next)) => + resolve(next.toPromise) + case Failure(ex) => + reject(ex) + }) + }) + p + } + + override def toFuture: Future[A] = { + future.flatMap({ + case Left(a) => Future successful a + case Right(next) => next.toFuture + }) + } + + override def onComplete(f: Try[A] => Unit): Unit = { + future.onComplete({ + case Success(Left(a)) => f(Success(a)) + case Success(Right(next)) => next.onComplete(f) + case Failure(ex) => f(Failure(ex)) + }) + } + + } + + case class Pure[A](val value: A) extends JSAsync[A] { + + override def map[B](f: A=>B): JSAsync[B] = { + try + Pure(f(value)) + catch + case NonFatal(ex) => + Error(ex) + } + + override def flatMap[B](f: A=>JSAsync[B]): JSAsync[B] = { + try + f(value) + catch + case NonFatal(ex) => Error(ex) + ScalaFutureWrapper(Future.failed(ex)) + } + + override def flatMapTry[B](f: Try[A] => JSAsync[B]): JSAsync[B] = { + try + f(Success(value)) + catch + case NonFatal(ex) => + Error(ex) + } + + override def toPromise: js.Promise[A] = { + js.Promise.resolve(value) + } + + override def toFuture: Future[A] = { + Future successful value + } + + override def onComplete(f: Try[A] => Unit): Unit = { + f(Success(value)) + } + + + } + + object Pure { + + val unit: Pure[Unit] = new Pure(()) + + } + + case class Error(ex: Throwable) extends JSAsync[Nothing] { + + override def map[B](f: Nothing=>B): JSAsync[B] = { + this + } + + override def flatMap[B](f: Nothing=>JSAsync[B]): JSAsync[B] = { + this + } + + override def flatMapTry[B](f: Try[Nothing] => JSAsync[B]): JSAsync[B] = { + try + f(Failure(ex)) + catch + case NonFatal(ex1) => + Error(ex1) + } + + override def toPromise: js.Promise[Nothing] = { + js.Promise.reject(ex) + } + + override def toFuture: Future[Nothing] = { + Future failed ex + } + + override def onComplete(f: Try[Nothing] => Unit): Unit = { + f(Failure(ex)) + } + + } + + + + class JSAsyncCpsMonad(jctx: Async) extends CpsTryContextMonad[JSAsync, Async] { + + override def pure[A](a:A): JSAsync[A] = JSAsync.Pure(a) + + override def map[A,B](fa:JSAsync[A])(f: A => B): JSAsync[B] = + fa.map(f) + + override def flatMap[A,B](fa:JSAsync[A])(f: A => JSAsync[B]): JSAsync[B] = + fa.flatMap(f) + + override def error[A](e: Throwable): JSAsync[A] = + Error(e) + + override def flatMapTry[A,B](fa:JSAsync[A])(f: Try[A] => JSAsync[B]): JSAsync[B] = + fa.flatMapTry(f) + + override def applyContext[A](op: Async => JSAsync[A]): JSAsync[A] = { + try + op(jctx) + catch + case NonFatal(ex) => + ScalaFutureWrapper(Future.failed(ex)) + } + + } + + given emptyJSAsyncMonad: JSAsyncCpsMonad = new JSAsyncCpsMonad(Async.empty) + + given CpsMonadConversion[js.Promise, JSAsync] with { + + def apply[T](p: js.Promise[T]): JSAsync[T] = + JSPromiseWrapper(p.`then`(v => Left(v))) + + } + + given CpsMonadConversion[Future, JSAsync] with { + + def apply[T](f: Future[T]): JSAsync[T] = + import scala.concurrent.ExecutionContext.Implicits.global + ScalaFutureWrapper(f.map(Left(_))) + + } + + def eval[A](f: =>A): JSAsync[A] = { + try + Pure(f) + catch + case NonFatal(ex) => + Error(ex) + } + + + def evalOp[T](op: => JSAsync[T]): Either[T, JSAsync[T]] = { + try + val r = op + r match + case Pure(v) => + Left(v) + case other => + Right(other) + catch + case NonFatal(ex) => + Right(Error(ex)) + } + + def failure(ex: Throwable): JSAsync[Nothing] = + Error(ex) + + def success[A](a: A): JSAsync[A] = Pure(a) + + +} + + diff --git a/compiler-plugin/testdata/set22cc/m2/JSAsyncSupport.scala b/compiler-plugin/testdata/set22cc/m2/JSAsyncSupport.scala new file mode 100644 index 000000000..09ba37636 --- /dev/null +++ b/compiler-plugin/testdata/set22cc/m2/JSAsyncSupport.scala @@ -0,0 +1,129 @@ +package gears.async + +import scala.concurrent.duration.FiniteDuration +import scala.util.* +import scala.util.control.NonFatal +import scala.scalajs.* +import scala.scalajs.js.timers.* +import cps.* + +import scala.scalajs.js.{JavaScriptException, Promise, Thenable} +import scala.util.boundary.Break + + +object JSScheduler extends Scheduler: + + val resolvedUnitPromise = js.Promise.resolve(()) + + def execute(body: Runnable)(using AsyncContext): Unit = { + ??? + /* + val p = js.Promise[Unit]((resolve, reject) => { + resolvedUnitPromise.`then`( (_: Unit) => { + try + body.run() + catch + case NonFatal(e) => + e match + case e: JavaScriptException => + reject(e.exception) + case other => + reject(e) + }) + }) + await(p) + + */ + } + + def schedule(delay: FiniteDuration, body: Runnable): Cancellable = { + //val handle = setTimeout(delay.toMillis.toDouble)(body.run()) + //() => clearTimeout(handle) + ??? + } + + +end JSScheduler + + +object JSAsyncSupport extends AsyncSupport: + + override type Scheduler = JSScheduler.type + + val scheduler: Scheduler = JSScheduler + + case class JSLabel[R]( + var wasSuspended: Boolean, + bodyRrp: JSResolveRejectWithPromise[R], + suspendRrp: JSResolveRejectWithPromise[R]) + + + + type Label[R] = JSLabel[R] + + case class JSSuspension[-T, +R](trr: JSResolveReject[T], bp: js.Promise[R]) extends gears.async.Suspension[T, R] { + + def resume(t: T)(using AsyncContext): R = { + trr.resolve(t) + await(bp) + } + + } + + type Suspension[-T,+R] = JSSuspension[T,R] + + def boundary[R](body: Label[R] ?=> R)(using AsyncContext): R = { + boundary_async[R](label => JSAsync.eval(body(using label))) + } + + def boundary_async[R](body: Label[R] => JSAsync[R])(using AsyncContext): R = { + val bodyRrp = await(JSResolveReject.withPromise[R]()) + val suspendRrp = await(JSResolveReject.withPromise[R]()) + val label: Label[R] = JSLabel(false, bodyRrp, suspendRrp) + try + body(label).onComplete { + case Success(x) => + label.bodyRrp.resolveReject.resolve(x) + if (!label.wasSuspended) then + label.suspendRrp.resolveReject.resolve(x) + case Failure(ex) => + label.bodyRrp.resolveReject.reject(ex) + if (!label.wasSuspended) then + label.suspendRrp.resolveReject.reject(ex) + } + catch + case NonFatal(e) => + label.bodyRrp.resolveReject.reject(e) + if (!label.wasSuspended) { + label.suspendRrp.resolveReject.reject(e) + } + + await(label.suspendRrp.promise) + } + + def suspend[T, R](body: Suspension[T,R] => R)(using Label[R], CpsDirect[JSAsync]): T = { + suspend_async[T,R](s => JSAsync.eval(body(s))) + } + + def suspend_async[T, R](body: Suspension[T,R] => JSAsync[R])(using Label[R], CpsDirect[JSAsync]): T = { + val tRrp = await(JSResolveReject.withPromise[T]()) + summon[Label[R]].wasSuspended = true + + val suspension = JSSuspension(tRrp.resolveReject,summon[Label[R]].bodyRrp.promise) + try + body(suspension).onComplete{ + case Success(r) => + summon[Label[R]].suspendRrp.resolveReject.resolve(r) + case Failure(ex) => + summon[Label[R]].suspendRrp.resolveReject.reject(ex) + } + catch + case NonFatal(e) => + summon[Label[R]].suspendRrp.resolveReject.reject(e) + + await(tRrp.promise) + } + + +end JSAsyncSupport + diff --git a/compiler-plugin/testdata/set22cc/m2/JSBoundaryTest.scala b/compiler-plugin/testdata/set22cc/m2/JSBoundaryTest.scala new file mode 100644 index 000000000..3b3960f6d --- /dev/null +++ b/compiler-plugin/testdata/set22cc/m2/JSBoundaryTest.scala @@ -0,0 +1,67 @@ +package jsasync + +import cps.* +import gears.async.{JSAsync, JSAsyncSupport} +import scala.concurrent.* +import scala.concurrent.duration.* +import scala.concurrent.ExecutionContext.Implicits.global + +import gears.async.given +import cps.plugin.annotation.CpsDebugLevel + + +@CpsDebugLevel(20) +class JSBoundaryTest { + + + @CpsDebugLevel(20) + def simpleFlowWithSuspension = { + println("simple flow with suspension:start") + val r = async[JSAsync] { + println("in async") + var q = 1 + println("before boundary") + JSAsyncSupport.boundary{ + println("in boundary") + q = 2 + println("before suspend") + val k = JSAsyncSupport.suspend[Int,Int]{ s => + println("before resume") + val r = s.resume(3)+1 + println("after resume") + r + } + println("after suspend") + if (k == 3) then + q = 10 + println("Ok") + q+k + } + } + r.toFuture.map{ x => + println(s"result = $x") + assert(x == 14) + } + } + + + /* + def simpleFloWithBoundaryButWithoutSuspension = { + var r = async[JSAsync] { + var q = 1 + JSAsyncSupport.boundary{ + q = 2 + val k = 3 + if (k == 3) then + q = 10 + q+k + } + } + r.toFuture.map{ x => + assert(x == 13) + } + } + + */ + +} diff --git a/compiler-plugin/testdata/set22cc/m2/JSResolveReject.scala b/compiler-plugin/testdata/set22cc/m2/JSResolveReject.scala new file mode 100644 index 000000000..2cee30a5b --- /dev/null +++ b/compiler-plugin/testdata/set22cc/m2/JSResolveReject.scala @@ -0,0 +1,37 @@ +package gears.async + +import scala.scalajs.js +import scala.scalajs.js.Promise + +case class JSResolveReject[-R](resolve: (R) => Unit, reject: Throwable => Unit) { + + + +} + +case class JSResolveRejectWithPromise[R](resolveReject: JSResolveReject[R], promise: js.Promise[R]) + + +case object JSResolveReject { + + + + def withPromise[R](): js.Promise[JSResolveRejectWithPromise[R]] = { + var outResolve: (R) => Unit = null + var outReject: (Throwable) => Unit = null + val p = new js.Promise[R]((resolve, reject) => { + outResolve = (x) => resolve(x) + outReject = (ex) => reject(ex) + }) + // now we want be sure, that assigments was done + // now we hope that second promise with evaluated after first, but it is not guaranteed + // TODO: in JS2024 exosts Promise.withResolvers which can be used here, but it is not yet available in scala.js. + // TODO: research and maybe wrap in interval + val wrappedPromise = new Promise[JSResolveRejectWithPromise[R]]((resultResolve, resultReject) => { + resultResolve(JSResolveRejectWithPromise(JSResolveReject(outResolve, outReject), p)) + }) + wrappedPromise + } + + +} diff --git a/compiler-plugin/testdata/set22cc/m2/Listener.scala b/compiler-plugin/testdata/set22cc/m2/Listener.scala new file mode 100644 index 000000000..3d05b1d5a --- /dev/null +++ b/compiler-plugin/testdata/set22cc/m2/Listener.scala @@ -0,0 +1,101 @@ +package gears.async + +import java.util.concurrent.locks.ReentrantLock + +trait Listener[-T]: + + /** Complete the listener with the given item, from the given source. **If the listener exposes a + * [[Listener.ListenerLock]]**, it is required to acquire this lock before calling [[complete]]. This can also be + * done conveniently with [[completeNow]]. For performance reasons, this condition is usually not checked and will + * end up causing unexpected behavior if not satisfied. + * + * The listener must automatically release its own lock upon completion. + */ + def complete(data: T, source: Async.Source[T]): Unit + + /** Represents the exposed API for synchronization on listeners at receiving time. If the listener does not have any + * form of synchronization, [[lock]] should be `null`. + */ + val lock: Listener.ListenerLock | Null + + /** Attempts to acquire locks and then calling [[complete]] with the given item and source. If locking fails, + * [[releaseLock]] is automatically called. + */ + def completeNow(data: T, source: Async.Source[T]): Boolean = + if acquireLock() then + this.complete(data, source) + true + else false + + /** Release the listener's lock if it exists. */ + inline final def releaseLock(): Unit = if lock != null then lock.release() + + /** Attempts to lock the listener, if such a lock exists. Succeeds with `true` immediately if [[lock]] is `null`. + */ + inline final def acquireLock(): Boolean = + if lock != null then lock.acquire() else true + + +object Listener: + + /** A simple [[Listener]] that always accepts the item and sends it to the consumer. */ + inline def acceptingListener[T](inline consumer: (T, Async.Source[T]) => Unit) = + new Listener[T]: + val lock = null + def complete(data: T, source: Async.Source[T]) = consumer(data, source) + + /** Returns a simple [[Listener]] that always accepts the item and sends it to the consumer. */ + inline def apply[T](consumer: (T, Async.Source[T]) => Unit): Listener[T] = acceptingListener(consumer) + + /** A special class of listener that forwards the inner listener through the given source. For purposes of + * [[Async.Source.dropListener]] these listeners are compared for equality by the hash of the source and the inner + * listener. + */ + abstract case class ForwardingListener[T](src: Async.Source[?], inner: Listener[?]) extends Listener[T] + + object ForwardingListener: + /** Create an empty [[ForwardingListener]] for equality comparison. */ + def empty[T](src: Async.Source[?], inner: Listener[?]) = new ForwardingListener[T](src, inner): + val lock = null + override def complete(data: T, source: Async.Source[T]) = ??? + + /** A lock required by a listener to be acquired before accepting values. Should there be multiple listeners that + * needs to be locked at the same time, they should be locked by larger-number-first. + * + * Some implementations are provided for ease of implementations: + * - For custom listener implementations involving locks: [[NumberedLock]] provides uniquely numbered locks. + * - For source transformation implementations: [[withLock]] is a convenient `.map` for `[[ListenerLock]] | Null`. + */ + trait ListenerLock: + /** The assigned number of the lock. It is required that listeners that can be locked together to have different + * [[selfNumber numbers]]. This requirement can be simply done by using a lock created using [[NumberedLock]]. + */ + val selfNumber: Long + + /** Attempt to lock the current [[ListenerLock]]. Locks are guaranteed to be held as short as possible. + */ + def acquire(): Boolean + + /** Release the current lock. */ + def release(): Unit + end ListenerLock + + /** Maps the lock of a listener, if it exists. */ + inline def withLock[T](listener: Listener[?])(inline body: ListenerLock => T): T | Null = + listener.lock match + case null => null + case l: ListenerLock => body(l) + + /** A helper instance that provides an uniquely numbered mutex. */ + trait NumberedLock: + import NumberedLock._ + + val number = listenerNumber.getAndIncrement() + private val lock0 = ReentrantLock() + + protected def acquireLock() = lock0.lock() + protected def releaseLock() = lock0.unlock() + + object NumberedLock: + private val listenerNumber = java.util.concurrent.atomic.AtomicLong() + diff --git a/compiler-plugin/testdata/set22cc/m2/M.scala b/compiler-plugin/testdata/set22cc/m2/M.scala deleted file mode 100644 index 8c7a7367c..000000000 --- a/compiler-plugin/testdata/set22cc/m2/M.scala +++ /dev/null @@ -1,14 +0,0 @@ - - -trait Monad[F[_]]: - def flatMap[T,S](ft:F[T])(f: T=>F[S]): F[S] - def pure[T](t:T): F[T] - - def map(T,S)(ft:F[T])(f: T=>S): F[S] = flatMap(ft)(t => pure(f(t))) - def flatten(T)(ft:F[F[T]]): F[T] = flatMap(ft)(identity) - - def flatMap1(T,S)(ft:F[T])(f: T=>F[S]): F[S] = { - // define via map and flatten - map(f) - flatten(map(ft)(f)) - } \ No newline at end of file diff --git a/compiler-plugin/testdata/set22cc/m2/futures.scala b/compiler-plugin/testdata/set22cc/m2/futures.scala new file mode 100644 index 000000000..2ec9182fe --- /dev/null +++ b/compiler-plugin/testdata/set22cc/m2/futures.scala @@ -0,0 +1,78 @@ +package gears.async + +import java.util.concurrent.atomic.AtomicBoolean +import scala.annotation.unchecked.uncheckedVariance +import scala.collection.mutable +import scala.compiletime.uninitialized +import scala.util.Try + +trait Future[+T] extends Async.OriginalSource[Try[T]], Cancellable + +object Future: + + private class CoreFuture[+T] extends Future[T]: + + @volatile protected var hasCompleted: Boolean = false + protected var cancelRequest = AtomicBoolean(false) + private var result: Try[T] = uninitialized // guaranteed to be set if hasCompleted = true + private val waiting: mutable.Set[Listener[Try[T]]] = mutable.Set() + + // Async.Source method implementations + + def poll(k: Listener[Try[T]]): Boolean = + if hasCompleted then + k.completeNow(result, this) + true + else false + + def addListener(k: Listener[Try[T]]): Unit = synchronized: + waiting += k + + def dropListener(k: Listener[Try[T]]): Unit = synchronized: + waiting -= k + + // Cancellable method implementations + + def cancel(): Unit = + setCancelled() + + override def link(group: CompletionGroup): this.type = + // though hasCompleted is accessible without "synchronized", + // we want it not to be run while the future was trying to complete. + synchronized: + if !hasCompleted || group == CompletionGroup.Unlinked then super.link(group) + else this + + /** Sets the cancellation state and returns `true` if the future has not been completed and cancelled before. */ + protected final def setCancelled(): Boolean = + !hasCompleted && cancelRequest.compareAndSet(false, true) + + /** Complete future with result. If future was cancelled in the meantime, return a CancellationException failure + * instead. Note: @uncheckedVariance is safe here since `complete` is called from only two places: + * - from the initializer of RunnableFuture, where we are sure that `T` is exactly the type with which the future + * was created, and + * - from Promise.complete, where we are sure the type `T` is exactly the type with which the future was created + * since `Promise` is invariant. + */ + private[Future] def complete(result: Try[T] @uncheckedVariance): Unit = + val toNotify = synchronized: + if hasCompleted then Nil + else + this.result = result + hasCompleted = true + val ws = waiting.toList + waiting.clear() + unlink() + ws + for listener <- toNotify do listener.completeNow(result, this) + + end CoreFuture + + private class RunnableFuture[+T](body: Async.Spawn ?=> T)(using ac: Async) extends CoreFuture[T]: + + + + end RunnableFuture + + +end Future \ No newline at end of file diff --git a/compiler-plugin/testdata/set22cc/m2/package.scala b/compiler-plugin/testdata/set22cc/m2/package.scala new file mode 100644 index 000000000..01ed1a58a --- /dev/null +++ b/compiler-plugin/testdata/set22cc/m2/package.scala @@ -0,0 +1,8 @@ +package gears + +package object async { + + type CancellationException = java.util.concurrent.CancellationException + + +} diff --git a/compiler-plugin/testdata/set22cc/m3/JSResolveReject.scala b/compiler-plugin/testdata/set22cc/m3/JSResolveReject.scala new file mode 100644 index 000000000..2cee30a5b --- /dev/null +++ b/compiler-plugin/testdata/set22cc/m3/JSResolveReject.scala @@ -0,0 +1,37 @@ +package gears.async + +import scala.scalajs.js +import scala.scalajs.js.Promise + +case class JSResolveReject[-R](resolve: (R) => Unit, reject: Throwable => Unit) { + + + +} + +case class JSResolveRejectWithPromise[R](resolveReject: JSResolveReject[R], promise: js.Promise[R]) + + +case object JSResolveReject { + + + + def withPromise[R](): js.Promise[JSResolveRejectWithPromise[R]] = { + var outResolve: (R) => Unit = null + var outReject: (Throwable) => Unit = null + val p = new js.Promise[R]((resolve, reject) => { + outResolve = (x) => resolve(x) + outReject = (ex) => reject(ex) + }) + // now we want be sure, that assigments was done + // now we hope that second promise with evaluated after first, but it is not guaranteed + // TODO: in JS2024 exosts Promise.withResolvers which can be used here, but it is not yet available in scala.js. + // TODO: research and maybe wrap in interval + val wrappedPromise = new Promise[JSResolveRejectWithPromise[R]]((resultResolve, resultReject) => { + resultResolve(JSResolveRejectWithPromise(JSResolveReject(outResolve, outReject), p)) + }) + wrappedPromise + } + + +} diff --git a/js/src/main/scala/cps/monads/jsfuture/JSFuture.scala b/js/src/main/scala/cps/monads/jsfuture/JSFuture.scala index 9f044cc7d..66c681d1f 100644 --- a/js/src/main/scala/cps/monads/jsfuture/JSFuture.scala +++ b/js/src/main/scala/cps/monads/jsfuture/JSFuture.scala @@ -1,6 +1,6 @@ package cps.monads.jsfuture -import cps._ +import cps.* import scala.language.implicitConversions @@ -51,7 +51,7 @@ object JSFutureExecutor: * and inside async use usual async/await monadic API. **/ class JSFuture[T]( - val executorOrUndef: js.UndefOr[js.Function2[js.Function1[T | js.Thenable[T], _], js.Function1[scala.Any, _], _]], + val executorOrUndef: js.UndefOr[js.Function2[js.Function1[T | js.Thenable[T], ?], js.Function1[scala.Any, ?], ?]], val futureOrUndef: js.UndefOr[Future[T]]) extends js.Promise[T]( if (js.isUndefined(executorOrUndef)) then if (js.isUndefined(futureOrUndef)) then diff --git a/shared/src/main/scala/cps/macros/forest/application/ApplyArgRecordScope.scala b/shared/src/main/scala/cps/macros/forest/application/ApplyArgRecordScope.scala index 8f6d19c0e..dab15e71d 100644 --- a/shared/src/main/scala/cps/macros/forest/application/ApplyArgRecordScope.scala +++ b/shared/src/main/scala/cps/macros/forest/application/ApplyArgRecordScope.scala @@ -284,6 +284,7 @@ trait ApplyArgRecordScope[F[_], CT, CC<:CpsMonadContext[F]]: case Some(syncBody) => if (cpsBody.isChanged) then if (term.tpe.isContextFunctionType && !allowUncontext) then + println(s"context cpsBody = ${cpsBody}, ") throw MacroError("Can't transform context function: TastyAPI don;t support this yet",posExpr(term)) val mt = MethodType(paramNames)(_ => paramTypes, _ => syncBody.tpe.widen) Lambda(owner, mt, diff --git a/shared/src/test/scala/snippets/gsoc2023/Snippets.scala b/shared/src/test/scala/snippets/gsoc2023/Snippets.scala index 7fb7f9680..66247e50f 100644 --- a/shared/src/test/scala/snippets/gsoc2023/Snippets.scala +++ b/shared/src/test/scala/snippets/gsoc2023/Snippets.scala @@ -1,5 +1,7 @@ package snippets.gsoc2024 + + /* import cps.*