diff --git a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PassPersistance.java b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PassPersistance.java index 5e7ed95963a4..e781c58ceb83 100644 --- a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PassPersistance.java +++ b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PassPersistance.java @@ -104,10 +104,8 @@ protected void writeObject(TailCall.TailPosition obj, Output out) throws IOExcep @Override protected TailCall.TailPosition readObject(Input in) throws IOException, ClassNotFoundException { - var b = in.readBoolean(); - return b - ? org.enso.compiler.pass.analyse.TailCall$TailPosition$Tail$.MODULE$ - : org.enso.compiler.pass.analyse.TailCall$TailPosition$NotTail$.MODULE$; + in.readBoolean(); + return org.enso.compiler.pass.analyse.TailCall$TailPosition$Tail$.MODULE$; } } diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/TailCall.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/TailCall.scala index bd62febf20ad..35d9d10f2d05 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/TailCall.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/TailCall.scala @@ -9,7 +9,6 @@ import org.enso.compiler.core.ir.expression.{ Case, Comment, Error, - Foreign, Operator } import org.enso.compiler.core.ir.module.scope.Definition @@ -18,7 +17,6 @@ import org.enso.compiler.core.ir.{ CallArgument, DefinitionArgument, Diagnostic, - Empty, Expression, Function, Literal, @@ -28,7 +26,7 @@ import org.enso.compiler.core.ir.{ Type, Warning } -import org.enso.compiler.core.CompilerError +import org.enso.compiler.core.{CompilerError, IR} import org.enso.compiler.pass.IRPass import org.enso.compiler.pass.desugar._ import org.enso.compiler.pass.resolve.{ExpressionAnnotations, GlobalNames} @@ -36,8 +34,10 @@ import org.enso.compiler.pass.resolve.{ExpressionAnnotations, GlobalNames} /** This pass performs tail call analysis on the Enso IR. * * It is responsible for marking every single expression with whether it is in - * tail position or not. This allows the code generator to correctly create the + * tail position. This allows the code generator to correctly create the * Truffle nodes. + * If the expression is in tail position, [[TailPosition.Tail]] metadata is attached + * to it, otherwise, nothing is attached. * * This pass requires the context to provide: * @@ -61,6 +61,8 @@ case object TailCall extends IRPass { override lazy val invalidatedPasses: Seq[IRPass] = List() + private lazy val TAIL_META = new MetadataPair(this, TailPosition.Tail) + /** Analyses tail call state for expressions in a module. * * @param ir the Enso IR to process @@ -112,14 +114,14 @@ case object TailCall extends IRPass { .copy( body = analyseExpression(method.body, isInTailPosition = true) ) - .updateMetadata(new MetadataPair(this, TailPosition.Tail)) + .updateMetadata(TAIL_META) case method @ definition.Method .Explicit(_, body, _, _, _) => method .copy( body = analyseExpression(body, isInTailPosition = true) ) - .updateMetadata(new MetadataPair(this, TailPosition.Tail)) + .updateMetadata(TAIL_META) case _: definition.Method.Binding => throw new CompilerError( "Sugared method definitions should not occur during tail call " + @@ -127,7 +129,7 @@ case object TailCall extends IRPass { ) case _: Definition.Type => moduleDefinition.updateMetadata( - new MetadataPair(this, TailPosition.Tail) + TAIL_META ) case _: Definition.SugaredType => throw new CompilerError( @@ -153,7 +155,7 @@ case object TailCall extends IRPass { .copy(expression = analyseExpression(ann.expression, isInTailPosition = true) ) - .updateMetadata(new MetadataPair(this, TailPosition.Tail)) + .updateMetadata(TAIL_META) case err: Error => err } } @@ -176,44 +178,49 @@ case object TailCall extends IRPass { ) else expression expressionWithWarning match { - case empty: Empty => - empty.updateMetadata(new MetadataPair(this, TailPosition.NotTail)) case function: Function => analyseFunction(function, isInTailPosition) case caseExpr: Case => analyseCase(caseExpr, isInTailPosition) case typ: Type => analyseType(typ, isInTailPosition) case app: Application => analyseApplication(app, isInTailPosition) case name: Name => analyseName(name, isInTailPosition) - case foreign: Foreign => - foreign.updateMetadata(new MetadataPair(this, TailPosition.NotTail)) case literal: Literal => analyseLiteral(literal, isInTailPosition) case _: Comment => throw new CompilerError( "Comments should not be present during tail call analysis." ) case block @ Expression.Block(expressions, returnValue, _, _, _) => - block - .copy( - expressions = expressions.map( - analyseExpression(_, isInTailPosition = false) - ), - returnValue = analyseExpression(returnValue, isInTailPosition) - ) - .updateMetadata( - new MetadataPair(this, TailPosition.fromBool(isInTailPosition)) - ) + updateMetaIfInTailPosition( + isInTailPosition, + block + .copy( + expressions = expressions.map( + analyseExpression(_, isInTailPosition = false) + ), + returnValue = analyseExpression(returnValue, isInTailPosition) + ) + ) case binding @ Expression.Binding(_, expression, _, _) => - binding - .copy( - expression = analyseExpression(expression, isInTailPosition = false) - ) - .updateMetadata( - new MetadataPair(this, TailPosition.fromBool(isInTailPosition)) - ) - case err: Diagnostic => - err.updateMetadata( - new MetadataPair(this, TailPosition.fromBool(isInTailPosition)) + updateMetaIfInTailPosition( + isInTailPosition, + binding + .copy( + expression = + analyseExpression(expression, isInTailPosition = false) + ) ) + case err: Diagnostic => updateMetaIfInTailPosition(isInTailPosition, err) + } + } + + private def updateMetaIfInTailPosition[T <: IR]( + isInTailPosition: Boolean, + ir: T + ): T = { + if (isInTailPosition) { + ir.updateMetadata(TAIL_META) + } else { + ir } } @@ -224,9 +231,7 @@ case object TailCall extends IRPass { * @return `name`, annotated with tail position metadata */ def analyseName(name: Name, isInTailPosition: Boolean): Name = { - name.updateMetadata( - new MetadataPair(this, TailPosition.fromBool(isInTailPosition)) - ) + updateMetaIfInTailPosition(isInTailPosition, name) } /** Performs tail call analysis on a literal. @@ -240,9 +245,7 @@ case object TailCall extends IRPass { literal: Literal, isInTailPosition: Boolean ): Literal = { - literal.updateMetadata( - new MetadataPair(this, TailPosition.fromBool(isInTailPosition)) - ) + updateMetaIfInTailPosition(isInTailPosition, literal) } /** Performs tail call analysis on an application. @@ -256,43 +259,32 @@ case object TailCall extends IRPass { application: Application, isInTailPosition: Boolean ): Application = { - application match { + val newApp = application match { case app @ Application.Prefix(fn, args, _, _, _) => app .copy( function = analyseExpression(fn, isInTailPosition = false), arguments = args.map(analyseCallArg) ) - .updateMetadata( - new MetadataPair(this, TailPosition.fromBool(isInTailPosition)) - ) case force @ Application.Force(target, _, _) => force .copy( target = analyseExpression(target, isInTailPosition) ) - .updateMetadata( - new MetadataPair(this, TailPosition.fromBool(isInTailPosition)) - ) case vector @ Application.Sequence(items, _, _) => vector .copy(items = items.map(analyseExpression(_, isInTailPosition = false)) ) - .updateMetadata( - new MetadataPair(this, TailPosition.fromBool(isInTailPosition)) - ) case tSet @ Application.Typeset(expr, _, _) => tSet .copy(expression = expr.map(analyseExpression(_, isInTailPosition = false)) ) - .updateMetadata( - new MetadataPair(this, TailPosition.fromBool(isInTailPosition)) - ) case _: Operator => throw new CompilerError("Unexpected binary operator.") } + updateMetaIfInTailPosition(isInTailPosition, newApp) } /** Performs tail call analysis on a call site argument. @@ -308,7 +300,7 @@ case object TailCall extends IRPass { // Note [Call Argument Tail Position] value = analyseExpression(expr, isInTailPosition = true) ) - .updateMetadata(new MetadataPair(this, TailPosition.Tail)) + .updateMetadata(TAIL_META) } } @@ -343,11 +335,11 @@ case object TailCall extends IRPass { * @return `value`, annotated with tail position metadata */ def analyseType(value: Type, isInTailPosition: Boolean): Type = { - value - .mapExpressions(analyseExpression(_, isInTailPosition = false)) - .updateMetadata( - new MetadataPair(this, TailPosition.fromBool(isInTailPosition)) - ) + updateMetaIfInTailPosition( + isInTailPosition, + value + .mapExpressions(analyseExpression(_, isInTailPosition = false)) + ) } /** Performs tail call analysis on a case expression. @@ -358,7 +350,7 @@ case object TailCall extends IRPass { * @return `caseExpr`, annotated with tail position metadata */ def analyseCase(caseExpr: Case, isInTailPosition: Boolean): Case = { - caseExpr match { + val newCaseExpr = caseExpr match { case caseExpr @ Case.Expr(scrutinee, branches, _, _, _) => caseExpr .copy( @@ -366,12 +358,10 @@ case object TailCall extends IRPass { // Note [Analysing Branches in Case Expressions] branches = branches.map(analyseCaseBranch(_, isInTailPosition)) ) - .updateMetadata( - new MetadataPair(this, TailPosition.fromBool(isInTailPosition)) - ) case _: Case.Branch => throw new CompilerError("Unexpected case branch.") } + updateMetaIfInTailPosition(isInTailPosition, newCaseExpr) } /* Note [Analysing Branches in Case Expressions] @@ -396,17 +386,17 @@ case object TailCall extends IRPass { branch: Case.Branch, isInTailPosition: Boolean ): Case.Branch = { - branch - .copy( - pattern = analysePattern(branch.pattern), - expression = analyseExpression( - branch.expression, - isInTailPosition + updateMetaIfInTailPosition( + isInTailPosition, + branch + .copy( + pattern = analysePattern(branch.pattern), + expression = analyseExpression( + branch.expression, + isInTailPosition + ) ) - ) - .updateMetadata( - new MetadataPair(this, TailPosition.fromBool(isInTailPosition)) - ) + ) } /** Performs tail call analysis on a pattern. @@ -423,25 +413,20 @@ case object TailCall extends IRPass { .copy( name = analyseName(name, isInTailPosition = false) ) - .updateMetadata(new MetadataPair(this, TailPosition.NotTail)) case cons @ Pattern.Constructor(constructor, fields, _, _) => cons .copy( constructor = analyseName(constructor, isInTailPosition = false), fields = fields.map(analysePattern) ) - .updateMetadata(new MetadataPair(this, TailPosition.NotTail)) - case literal: Pattern.Literal => - literal - .updateMetadata(new MetadataPair(this, TailPosition.NotTail)) + case literal: Pattern.Literal => literal case tpePattern @ Pattern.Type(name, tpe, _, _) => tpePattern .copy( name = analyseName(name, isInTailPosition = false), tpe = analyseName(tpe, isInTailPosition = false) ) - case err: errors.Pattern => - err.updateMetadata(new MetadataPair(this, TailPosition.NotTail)) + case err: errors.Pattern => err case _: Pattern.Documentation => throw new CompilerError( "Branch documentation should be desugared at an earlier stage." @@ -474,10 +459,7 @@ case object TailCall extends IRPass { "Function sugar should not be present during tail call analysis." ) } - - resultFunction.updateMetadata( - new MetadataPair(this, TailPosition.fromBool(isInTailPosition)) - ) + updateMetaIfInTailPosition(isInTailPosition, resultFunction) } /** Performs tail call analysis on a function definition argument. @@ -492,12 +474,9 @@ case object TailCall extends IRPass { case arg @ DefinitionArgument.Specified(_, _, default, _, _, _) => arg .copy( - defaultValue = default.map(x => - analyseExpression(x, isInTailPosition = false) - .updateMetadata(new MetadataPair(this, TailPosition.NotTail)) - ) + defaultValue = + default.map(x => analyseExpression(x, isInTailPosition = false)) ) - .updateMetadata(new MetadataPair(this, TailPosition.NotTail)) } } @@ -509,7 +488,9 @@ case object TailCall extends IRPass { } object TailPosition { - /** The expression is in a tail position and can be tail call optimised. */ + /** The expression is in a tail position and can be tail call optimised. + * If the expression is not in tail-call position, it has no metadata attached. + */ final case object Tail extends TailPosition { override val metadataName: String = "TailCall.TailPosition.Tail" override def isTail: Boolean = true @@ -525,34 +506,6 @@ case object TailCall extends IRPass { ): Option[Tail.type] = Some(this) } - /** The expression is not in a tail position and cannot be tail call - * optimised. - */ - final case object NotTail extends TailPosition { - override val metadataName: String = "TailCall.TailPosition.NotTail" - override def isTail: Boolean = false - - override def duplicate(): Option[IRPass.IRMetadata] = Some(NotTail) - - /** @inheritdoc */ - override def prepareForSerialization(compiler: Compiler): NotTail.type = - this - - /** @inheritdoc */ - override def restoreFromSerialization( - compiler: Compiler - ): Option[NotTail.type] = Some(this) - } - - /** Implicitly converts a boolean to a [[TailPosition]] value. - * - * @param isTail the boolean - * @return the tail position value corresponding to `bool` - */ - implicit def fromBool(isTail: Boolean): TailPosition = { - if (isTail) TailPosition.Tail else TailPosition.NotTail - } - /** Implicitly converts the tail position data into a boolean. * * @param tailPosition the tail position value diff --git a/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/analyse/TailCallTest.scala b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/analyse/TailCallTest.scala index e6828b07946a..3a8d2dc2c18a 100644 --- a/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/analyse/TailCallTest.scala +++ b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/analyse/TailCallTest.scala @@ -135,7 +135,7 @@ class TailCallTest extends CompilerTest { implicit val ctx: InlineContext = mkNoTailContext val ir = code.preprocessExpression.get.analyse - ir.getMetadata(TailCall) shouldEqual Some(TailPosition.NotTail) + ir.getMetadata(TailCall) shouldEqual None } "mark the value of a tail assignment as non-tail" in { @@ -146,9 +146,7 @@ class TailCallTest extends CompilerTest { |""".stripMargin.preprocessExpression.get.analyse .asInstanceOf[Expression.Binding] binding.getMetadata(TailCall) shouldEqual Some(TailPosition.Tail) - binding.expression.getMetadata(TailCall) shouldEqual Some( - TailPosition.NotTail - ) + binding.expression.getMetadata(TailCall) shouldEqual None } } @@ -175,9 +173,7 @@ class TailCallTest extends CompilerTest { "mark the other expressions in the function as not tail" in { fnBody.expressions.foreach(expr => - expr.getMetadata(TailCall) shouldEqual Some( - TailPosition.NotTail - ) + expr.getMetadata(TailCall) shouldEqual None ) } @@ -254,16 +250,12 @@ class TailCallTest extends CompilerTest { .returnValue .asInstanceOf[Case.Expr] - caseExpr.getMetadata(TailCall) shouldEqual Some( - TailPosition.NotTail - ) + caseExpr.getMetadata(TailCall) shouldEqual None caseExpr.branches.foreach(branch => { val branchExpression = branch.expression.asInstanceOf[Application.Prefix] - branchExpression.getMetadata(TailCall) shouldEqual Some( - TailPosition.NotTail - ) + branchExpression.getMetadata(TailCall) shouldEqual None }) } @@ -317,16 +309,14 @@ class TailCallTest extends CompilerTest { val pattern = caseBranch.pattern.asInstanceOf[Pattern.Constructor] val patternConstructor = pattern.constructor - pattern.getMetadata(TailCall) shouldEqual Some(TailPosition.NotTail) - patternConstructor.getMetadata(TailCall) shouldEqual Some( - TailPosition.NotTail - ) + pattern.getMetadata(TailCall) shouldEqual None + patternConstructor.getMetadata(TailCall) shouldEqual None pattern.fields.foreach(f => { - f.getMetadata(TailCall) shouldEqual Some(TailPosition.NotTail) + f.getMetadata(TailCall) shouldEqual None f.asInstanceOf[Pattern.Name] .name - .getMetadata(TailCall) shouldEqual Some(TailPosition.NotTail) + .getMetadata(TailCall) shouldEqual None }) } } @@ -389,7 +379,7 @@ class TailCallTest extends CompilerTest { nonTailCallBody.expressions.head .asInstanceOf[Expression.Binding] .expression - .getMetadata(TailCall) shouldEqual Some(TailPosition.NotTail) + .getMetadata(TailCall) shouldEqual None } } @@ -422,9 +412,7 @@ class TailCallTest extends CompilerTest { "mark the block expressions as not tail" in { block.expressions.foreach(expr => - expr.getMetadata(TailCall) shouldEqual Some( - TailPosition.NotTail - ) + expr.getMetadata(TailCall) shouldEqual None ) }