Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement sequence syntax #170

Merged
merged 7 commits into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 121 additions & 35 deletions meta/src/main/scala/meta/generator/DPIAPrimitives.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ object DPIAPrimitives {
case Parsed.Success(seq, _) =>
seq.foreach {
case PrimitiveDeclaration(Identifier(originalName), scalaParams, params, returnType)
if DPIA.isWellKindedDefinition(params, returnType) =>
if DPIA.isWellKindedDefinition(toParamList(definition, scalaParams), params, returnType) =>
val name = originalName.capitalize

val outputPath = (path / os.up) / s"$name.scala"
Expand All @@ -29,10 +29,7 @@ object DPIAPrimitives {
val packageName = path.relativeTo(shinePath).segments.dropRight(1).foldLeft[Term.Ref](Term.Name("shine")) {
case (t, name) => Term.Select(t, Term.Name(name))
}
val scalaParamsString = scalaParams match {
case Some((start, end)) => definition.substring(start, end)
case None => ""
}

val code = s"""// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
|// This file is automatically generated and should not be changed manually //
|// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
Expand All @@ -45,7 +42,7 @@ import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._

${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)}
${generateCaseClass(Type.Name(name), toParamList(definition, scalaParams), params, returnType)}

}""".toString()}
|""".stripMargin
Expand All @@ -59,7 +56,7 @@ ${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)}
}

def generateCaseClass(name: scala.meta.Type.Name,
scalaParamsString: String,
scalaParams: Option[List[scala.meta.Term.Param]],
params: Seq[DPIA.Decl.AST.Param],
returnType: DPIA.Type.AST): scala.meta.Defn.Class = {
import scala.meta._
Expand All @@ -70,7 +67,7 @@ ${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)}
case AST.CommType => (t"CommType", init"CommandPrimitive")
case _ => throw new Exception(s"Expected `exp', `acc' or `comm' as return type for ${name.value}")
}
val generatedParams = generateParams(scalaParamsString, params)
val generatedParams = generateParams(scalaParams, params)
q"""
final case class $name(...$generatedParams) extends $superClass {
{
Expand All @@ -83,26 +80,32 @@ ${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)}

${generateVisitAndRebuild(name, generatedParams)}

..${if (scalaParamsString.nonEmpty && generatedParams.last.size > 1) {
..${if (scalaParams.nonEmpty && generatedParams.last.size > 1) {
List(generateUnwrap(generatedParams.last))
} else List() }
}
"""
}

def generateParams(scalaParamsString: String,
params: Seq[DPIA.Decl.AST.Param]): List[List[scala.meta.Term.Param]] = {
def toParamList(definition: String, scalaParams: Option[(Int, Int)]): Option[List[scala.meta.Term.Param]] = {
import scala.meta._

val scalaParams = if (scalaParamsString.nonEmpty) {
s"def foo($scalaParamsString)".parse[Stat].get match {
case declDef: Decl.Def => declDef.paramss
scalaParams.map { case (start, end) =>
s"def foo(${definition.substring(start, end)})".parse[Stat].get match {
case declDef: Decl.Def => declDef.paramss.head
}
} else {
List()
}
}

def generateParams(scalaParams: Option[List[scala.meta.Term.Param]],
params: Seq[DPIA.Decl.AST.Param]): List[List[scala.meta.Term.Param]] = {
import scala.meta._

scalaParams ++ List(params.map(generateParam).toList)
scalaParams match {
case Some(scalaParamList) =>
List(scalaParamList) ++ List(params.map(generateParam).toList)
case None =>
List(params.map(generateParam).toList)
}
}

def generateParam(param: DPIA.Decl.AST.Param): scala.meta.Term.Param = {
Expand All @@ -111,6 +114,8 @@ ${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)}
param"val ${Term.Name(param.id.name)}: ${
param.ty match {
case Left(kindAST) => generateType(kindAST)
case Right(meta.parser.DPIA.Type.AST.VariadicType(_, typeAST)) =>
t"Seq[Phrase[${generatePhraseType(typeAST)}]]"
case Right(typeAST) => t"Phrase[${generatePhraseType(typeAST)}]"
}
}"
Expand All @@ -122,25 +127,102 @@ ${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)}
..${params.
filter(param => param.ty.isRight). // only check types for parameters with phrase types
map(param =>
q"${Term.Name(param.id.name)} :: ${
param.ty match {
case Right([email protected](id, kind, t)) =>
q"""{
${Defn.Val(
mods = Nil,
pats = List(Pat.Var(name = Term.Name(id.name))),
decltpe = None,
rhs = q"${Term.Name(param.id.name)}.t.x"
)}
${generateTerm(typeAST)}
}"""
case Right(typeAST) => generateTerm(typeAST)
case Left(kindAST) => throw new Exception("This should not happen")
}}"
param.ty match {
case Right([email protected](id, _, _)) =>
q"""${Term.Name(param.id.name)} :: {
${Defn.Val(
mods = Nil,
pats = List(Pat.Var(name = Term.Name(id.name))),
decltpe = None,
rhs = q"${Term.Name(param.id.name)}.t.x"
)}
${generateTerm(typeAST)}
}"""
case Right(DPIA.Type.AST.VariadicType(n, typeAST)) =>
getUnrolledIds(typeAST) match {
case Some(unrolledIds) =>
val zips = unrolledIds.foldLeft[Term](Term.Name(param.id.name)) {
case (term, id) => q"$term.zip(${Term.Name(id.name)})"
}
val p = Term.Name(param.id.name)
val pattern = unrolledIds.foldRight[Pat](Pat.Var(p)) {
case (id, pattern) => p"($pattern, ${Pat.Var(Term.Name(id.name))})"
}
q"""$zips.foreach {
case ($pattern) => $p :: ${generateTerm(typeAST)}
}"""
case None =>
q"${Term.Name(param.id.name)}.foreach(_ :: ${generateTerm(typeAST)})"
}
case Right(typeAST) =>
q"${Term.Name(param.id.name)} :: ${generateTerm(typeAST)}"
case Left(_) => throw new Exception("Generation of type checks not necessary for kinds")
}
).toList}
}"""
}

def getUnrolledIds(typeAST: DPIA.Type.AST): Option[Seq[rise.Type.AST.UnrolledIdentifier]] = {
import DPIA.Type.AST
typeAST match {
case AST.ExpType(dataType, _) =>
getUnrolledIds(dataType)
case AST.AccType(dataType) =>
getUnrolledIds(dataType)
case AST.CommType => None
case AST.PairType(lhs, rhs) =>
for {
ids1 <- getUnrolledIds(lhs)
ids2 <- getUnrolledIds(rhs)
} yield ids1 concat ids2
case AST.FunType(inT, outT) =>
for {
ids1 <- getUnrolledIds(inT)
ids2 <- getUnrolledIds(outT)
} yield ids1 concat ids2
case AST.DepFunType(_, _, t) =>
getUnrolledIds(t)
case AST.Identifier(_) => None
case AST.VariadicType(n, ty) => throw new Exception("This function should not be called on a variadic type")
}
}

def getUnrolledIds(typeAST: rise.Type.AST): Option[Seq[rise.Type.AST.UnrolledIdentifier]] = {
michel-steuwer marked this conversation as resolved.
Show resolved Hide resolved
import rise.Type.AST
typeAST match {
case AST.Identifier(_) => None
case [email protected](_) => Some(Seq(id))
case AST.FunType(inT, outT) =>
for {
ids1 <- getUnrolledIds(inT)
ids2 <- getUnrolledIds(outT)
} yield ids1 concat ids2
case AST.DepFunType(_, _, t) =>
getUnrolledIds(t)
case AST.ImplicitDepFunType(_, _, t) =>
getUnrolledIds(t)
case AST.VariadicFunType(_, _, _) => throw new Exception("This function should not be called on a variadic type")
case AST.VariadicDepFunType(_, _, _, _) => throw new Exception("This function should not be called on a variadic type")
case AST.ScalarType(_) => None
case AST.NatType => None
case AST.OpaqueType(_) => None
case AST.VectorType(_, elemType) => getUnrolledIds(elemType)
case AST.IndexType(_) => None
case AST.PairType(lhs, rhs) =>
for {
ids1 <- getUnrolledIds(lhs)
ids2 <- getUnrolledIds(rhs)
} yield ids1 concat ids2
case AST.DepPairType(_, _, t) => getUnrolledIds(t)
case AST.NatToDataApply(_, _) => None
case AST.NatToDataLambda(_, t) => getUnrolledIds(t)
case AST.ArrayType(_, elemType) => getUnrolledIds(elemType)
case AST.DepArrayType(_, _) => None
case AST.FragmentType(_, _, _, dt, _, _) => getUnrolledIds(dt)
case AST.ManagedBufferType(t) => getUnrolledIds(t)
}
}

def generatePhraseType(typeAST: DPIA.Type.AST): scala.meta.Type = {
import scala.meta._
import meta.parser.DPIA.Type.AST
Expand All @@ -152,6 +234,7 @@ ${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)}
case AST.FunType(inT, outT) => t"FunType[${generatePhraseType(inT)}, ${generatePhraseType(outT)}]"
case AST.DepFunType(id, kind, t) => t"DepFunType[${generateKindType(kind)}, ${generatePhraseType(t)}]"
case AST.Identifier(name) => Type.Name(name)
case AST.VariadicType(_, _) => throw new Exception("Can not generate Phrase Type for Variadic Type")
}
}

Expand All @@ -171,6 +254,7 @@ ${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)}
case AST.MatrixLayout => Type.Name("MatrixLayout")
}
case AST.Access => Type.Name("AccessType")
case AST.VariadicKind(_, kind) => t"Seq[${generateType(kind)}]"
}
}

Expand All @@ -186,10 +270,11 @@ ${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)}
case AST.Nat2Nat => Type.Name("NatToNatKind")
case AST.Nat2Data => Type.Name("NatToDataKind")
case AST.Nat => Type.Name("NatKind")
case AST.Fragment => ???
case AST.MatrixLayout => ???
case AST.Fragment => throw new Exception("Can not generate Kind for Fragment")
case AST.MatrixLayout => throw new Exception("Can not generate Kind for Matrix Layout")
}
case AST.Access => Type.Name("AccessKind")
case AST.VariadicKind(_, _) => throw new Exception("Can not generate Kind for Variadic Kind")
}
}

Expand All @@ -210,6 +295,7 @@ ${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)}
case AST.DepFunType(id, kind, t) =>
q"DepFunType[${generateKindType(kind)}, PhraseType](${Term.Name(id.name)}, ${generateTerm(t)})"
case AST.Identifier(name) => Term.Name(name)
case AST.VariadicType(_, _) => throw new Exception("Can not generate Term for Variadic Type")
}
}

Expand Down
68 changes: 55 additions & 13 deletions meta/src/main/scala/meta/generator/RisePrimitives.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package meta.generator
import fastparse.{Parsed, parse}
import meta.parser._
import meta.parser.rise.Kind
import meta.parser.rise.Kind.AST

import scala.meta.Term

object RisePrimitives {
def main(args: Array[String]): Unit = {
Expand All @@ -18,16 +21,11 @@ object RisePrimitives {
case Parsed.Success(seq, _) =>
seq.foreach {
case rise.Decl.AST.PrimitiveDeclaration(rise.Decl.AST.Identifier(name), scalaParams, typeSignature)
if rise.isWellKindedType(typeSignature) =>
if rise.isWellKindedType(toParamList(definition, scalaParams), typeSignature) =>
val outputPath = (path / os.up) / s"$name.scala"
println(s"Generate $outputPath")

val generatedDef = scalaParams match {
case None =>
generateObject(name, typeSignature)
case Some((start, end)) =>
generateCaseClass(name, definition.substring(start, end), typeSignature)
}
val generatedDef = generate(name, toParamList(definition, scalaParams), typeSignature)

import scala.meta._
val packageName = path.relativeTo(risePath).segments.dropRight(1).foldLeft[Term.Ref](Term.Name("rise")) {
Expand Down Expand Up @@ -58,6 +56,22 @@ import arithexpr.arithmetic._
})
}

def toParamList(definition: String, scalaParams: Option[(Int, Int)]): Option[List[Term.Param]] = {
import scala.meta._
scalaParams.map { case (start, end) =>
s"def foo(${definition.substring(start, end)})".parse[Stat].get match {
case declDef: Decl.Def => declDef.paramss.head
}
}
}

def generate(name: String, params: Option[List[Term.Param]], typeSignature: rise.Type.AST): scala.meta.Term.Block =
params match {
case None => generateObject(name, typeSignature)
case Some(params) => generateCaseClass(name, params, typeSignature)
}


def generateObject(name: String, typeSignature: rise.Type.AST): scala.meta.Term.Block = {
import scala.meta._
val generated = q"""{
Expand All @@ -83,12 +97,9 @@ import arithexpr.arithmetic._
generated
}

def generateCaseClass(name: String, paramsString: String, typeSignature: rise.Type.AST): scala.meta.Term.Block = {
def generateCaseClass(name: String, params: List[Term.Param], typeSignature: rise.Type.AST): scala.meta.Term.Block = {
import scala.meta._

val params = s"def foo($paramsString)".parse[Stat].get match {
case declDef: Decl.Def => declDef.paramss.head
}
val args: List[Term.Name] = params.map(p => Term.Name(p.name.value))
val types: List[Type] = params.map(p => p.decltpe.get)

Expand Down Expand Up @@ -119,7 +130,11 @@ import arithexpr.arithmetic._
}
}

def unapply(arg: rise.core.Expr): Option[..$types] = arg match {
def unapply(arg: rise.core.Expr): ${if (types.length > 1) {
t"Option[(..$types)]"
} else {
t"Option[..$types]"}
} = arg match {
case p: Primitive =>
Some(..${generateMemberAccesses(args)})
case _ => None
Expand All @@ -138,6 +153,30 @@ import arithexpr.arithmetic._
q"expl((${Term.Name(id.name)}: ${Type.Name(kindName(kind))}) => ${generateTypeScheme(t)})"
case rise.Type.AST.ImplicitDepFunType(id, kind, t) =>
q"impl((${Term.Name(id.name)}: ${Type.Name(kindName(kind))}) => ${generateTypeScheme(t)})"
case rise.Type.AST.VariadicFunType(_, rise.Type.AST.UnrolledIdentifier(inTs), outT) =>
q"""${Term.Name(inTs)}.foldRight(${generateTypeScheme(outT)}: Type) {
case (lhsT, rhsT) => lhsT ->: rhsT
}"""
case rise.Type.AST.VariadicFunType(rise.Type.AST.Identifier(n), inTs, outT) =>
q"""Seq.fill(${Term.Name(n)})(${generateDataType(inTs)}).foldRight(${generateTypeScheme(outT)}: Type) {
case (lhsT, rhsT) => lhsT ->: rhsT
}"""
case rise.Type.AST.VariadicDepFunType(n, id, kind, t) =>
val (createIds, typeName) = kind match {
case AST.Data => (q"""DataTypeIdentifier(freshName("dt"), isExplicit = true)""", Type.Name("DataKind"))
case AST.Address => ???
case AST.Nat2Nat => ???
case AST.Nat2Data => ???
case AST.Nat => ???
case AST.Fragment => ???
case AST.MatrixLayout => ???
}
q"""{
val ${Pat.Var(Term.Name(id.name))} = Seq.fill(${Term.Name(n.name)})($createIds)
${Term.Name(id.name)}.foldRight(${generateTypeScheme(t)}: Type) {
case (id, t) => DepFunType[$typeName, Type](id, t)
}
}"""
case _ => generateDataType(typeAST)
}
}
Expand All @@ -147,6 +186,8 @@ import arithexpr.arithmetic._
typeAST match {
case rise.Type.AST.Identifier(name) =>
Term.Name(name)
case rise.Type.AST.UnrolledIdentifier(name) =>
Term.Name(name)
case rise.Type.AST.ScalarType(t) =>
t.parse[Term].get
case rise.Type.AST.NatType =>
Expand Down Expand Up @@ -177,7 +218,8 @@ import arithexpr.arithmetic._
case rise.Type.AST.ManagedBufferType(dt) =>
q"ManagedBufferType(${generateDataType(dt)})"
case rise.Type.AST.FunType(_, _) | rise.Type.AST.DepFunType(_, _, _) |
rise.Type.AST.ImplicitDepFunType(_, _, _) => ???
rise.Type.AST.ImplicitDepFunType(_, _, _) | rise.Type.AST.VariadicFunType(_, _, _) |
rise.Type.AST.VariadicDepFunType(_, _, _, _) => ???
}
}

Expand Down
Loading