Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement sequence syntax #170

Merged
merged 7 commits into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
420 changes: 247 additions & 173 deletions meta/src/main/scala/meta/generator/DPIAPrimitives.scala

Large diffs are not rendered by default.

118 changes: 93 additions & 25 deletions meta/src/main/scala/meta/generator/RisePrimitives.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,34 @@ package meta.generator
import fastparse.{Parsed, parse}
import meta.parser._
import meta.parser.rise.Kind
import meta.parser.rise.Kind.AST

object RisePrimitives {
def main(args: Array[String]): Unit = {
val sourceDir = args.head
val sourceDir = args.head // excepts one argument that is the source directory of the rise repo (i.e. 'rise/src')
val risePath = os.Path(sourceDir) / "rise"
os.walk.stream(risePath).filter(_.ext == "rise").foreach(path => {
os.walk.stream(risePath).filter(_.ext == "rise").foreach(path => { // for each file with the `.rise` extension ...

val definition = os.read(path)
parse(definition, rise.Decl.PrimitiveDeclarations(_)) match {
val definition = os.read(path) // ... read the file content
parse(definition, rise.Decl.PrimitiveDeclarations(_)) match { // ... and parse it
case failure: Parsed.Failure =>
println(s"Failed to parse `${failure.extra.input}'")
println(s" $failure")
case Parsed.Success(seq, _) =>
seq.foreach {
case Parsed.Success(seq, _) => seq.foreach { // ... if successful go over all
case rise.Decl.AST.PrimitiveDeclaration(rise.Decl.AST.Identifier(name), scalaParams, typeSignature)
if rise.isWellKindedType(typeSignature) =>
if rise.isWellKindedType( // ... well kinded declarations
toParamList(definition, scalaParams), typeSignature) =>
val outputPath = (path / os.up) / s"$name.scala"
println(s"Generate $outputPath")

val generatedDef = scalaParams match {
case None =>
generateObject(name, typeSignature)
case Some((start, end)) =>
generateCaseClass(name, definition.substring(start, end), typeSignature)
}
// ... and generate class or object definition
val generatedDef = generate(name, toParamList(definition, scalaParams), typeSignature)

import scala.meta._
val packageName = path.relativeTo(risePath).segments.dropRight(1).foldLeft[Term.Ref](Term.Name("rise")) {
case (t, name) => Term.Select(t, Term.Name(name))
}
// ... combine into the generated source code
val code = s"""// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
|// This file is automatically generated and should not be changed manually //
|// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
Expand All @@ -50,14 +48,34 @@ import arithexpr.arithmetic._
}""".toString()}
|""".stripMargin

os.write.over(outputPath, code)
os.write.over(outputPath, code) // ... and finally write out to disk.
case rise.Decl.AST.PrimitiveDeclaration(name, _, typeSignature) =>
println(s"Could not generate code for `$name' as type signature `$typeSignature' is not well kinded.")
}
}
})
}

// parse scala parameters (i.e. parameters that are not part of the rise language) into a list of parameters
def toParamList(definition: String, scalaParams: Option[(Int, Int)]): Option[List[scala.meta.Term.Param]] = {
import scala.meta._
scalaParams.map { case (start, end) =>
s"def foo(${definition.substring(start, end)})".parse[Stat].get match {
case declDef: Decl.Def => declDef.paramss.head
}
}
}

// generate either an object (if there are no scala parameters) or a case class (if there are some)
def generate(name: String,
params: Option[List[scala.meta.Term.Param]],
typeSignature: rise.Type.AST): scala.meta.Term.Block =
params match {
case None => generateObject(name, typeSignature)
case Some(params) => generateCaseClass(name, params, typeSignature)
}


def generateObject(name: String, typeSignature: rise.Type.AST): scala.meta.Term.Block = {
import scala.meta._
val generated = q"""{
Expand All @@ -83,12 +101,11 @@ import arithexpr.arithmetic._
generated
}

def generateCaseClass(name: String, paramsString: String, typeSignature: rise.Type.AST): scala.meta.Term.Block = {
def generateCaseClass(name: String,
params: List[scala.meta.Term.Param],
typeSignature: rise.Type.AST): scala.meta.Term.Block = {
import scala.meta._

val params = s"def foo($paramsString)".parse[Stat].get match {
case declDef: Decl.Def => declDef.paramss.head
}
val args: List[Term.Name] = params.map(p => Term.Name(p.name.value))
val types: List[Type] = params.map(p => p.decltpe.get)

Expand Down Expand Up @@ -119,16 +136,21 @@ import arithexpr.arithmetic._
}
}

def unapply(arg: rise.core.Expr): Option[..$types] = arg match {
case p: Primitive =>
Some(..${generateMemberAccesses(args)})
def unapply(arg: rise.core.Expr): ${if (types.length > 1) {
t"Option[(..$types)]" // create tuple if there are multiple type parameters
} else {
t"Option[..$types]"}
} = arg match {
case p: Primitive => Some(..${generateMemberAccesses(args)})
case _ => None
}
}
"""
generated
}

// generate a scala term representing the type scheme, i.e. something like this for map:
// impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t) ->: n`.`s ->: n`.`t } } }
def generateTypeScheme(typeAST: rise.Type.AST): scala.meta.Term = {
import scala.meta._
typeAST match {
Expand All @@ -138,21 +160,64 @@ import arithexpr.arithmetic._
q"expl((${Term.Name(id.name)}: ${Type.Name(kindName(kind))}) => ${generateTypeScheme(t)})"
case rise.Type.AST.ImplicitDepFunType(id, kind, t) =>
q"impl((${Term.Name(id.name)}: ${Type.Name(kindName(kind))}) => ${generateTypeScheme(t)})"
case rise.Type.AST.VariadicFunType(_, rise.Type.AST.UnrolledIdentifier(inTs), outT) =>
// variadic function type with unrolled identifier, e.g.: n*(*inTs ->) outT
// generates: inTs.foldRight(outT){ case (lhsT, rhsT) => lhs ->: rhs }
// to represent n-many function types: inT0 ->: inT1 ->: ... ->: outT
q"""${Term.Name(inTs)}.foldRight(${generateTypeScheme(outT)}: Type) {
case (lhsT, rhsT) => lhsT ->: rhsT
}"""
case rise.Type.AST.VariadicFunType(rise.Type.AST.Identifier(n), dt, outT) =>
// variadic function type without an unrolled identifier, e.g: n*(dt ->) outT
// generated: Seq.fill(n)(dt).foldRight(outT){ case (lhsT, rhsT) => lhs ->: rhs }
// to represent n-many function types: dt -> dt -> ... -> outT
q"""Seq.fill(${Term.Name(n)})(${generateDataType(dt)}).foldRight(${generateTypeScheme(outT)}: Type) {
case (lhsT, rhsT) => lhsT ->: rhsT
}"""
case rise.Type.AST.VariadicDepFunType(n, ids, kind, t) =>
// variadic dependent function type, e.g.: n*((ids: kind) ->) t
// generates:
// val ids = Seq.fill(n)(DataTypeIdentifier(freshName("dt"), isExplicit = true))
// ids.foldRight(t){ case (id, t) => DepFunType[DataKind](id, t) }
// to represent n-many dependent function types: (id0: kind) -> (id1: kind) -> ... -> t
val (createIds, typeName) = kind match {
case AST.Data =>
(q"""DataTypeIdentifier(freshName("dt"), isExplicit = true)""", Type.Name("DataKind"))
case AST.Address =>
(q"""AddressSpaceIdentifier(freshName("a"), isExplicit = true)""", Type.Name("AddressSpaceKind"))
case AST.Nat2Nat =>
(q"""NatToNatIdentifier(freshName("n2n"), isExplicit = true)""", Type.Name("NatToNatKind"))
case AST.Nat2Data =>
(q"""NatToDataIdentifier(freshName("n2d"), isExplicit = true)""", Type.Name("NatToDataKind"))
case AST.Nat =>
(q"""NatIdentifier(freshName("n"), isExplicit = true)""", Type.Name("NatKind"))
case AST.Fragment => throw new Exception("No support for Fragment Kind yet")
case AST.MatrixLayout => throw new Exception("No support for Matrix Layout Kind yet")
}
q"""{
val ${Pat.Var(Term.Name(ids.name))} = Seq.fill(${Term.Name(n.name)})($createIds)
${Term.Name(ids.name)}.foldRight(${generateTypeScheme(t)}: Type) {
case (id, t) => DepFunType[$typeName, Type](id, t)
}
}"""
case _ => generateDataType(typeAST)
}
}

// generate a scala term representing a rise data type
def generateDataType(typeAST: rise.Type.AST): scala.meta.Term = {
import scala.meta._
typeAST match {
case rise.Type.AST.Identifier(name) =>
Term.Name(name)
case rise.Type.AST.UnrolledIdentifier(name) =>
Term.Name(name)
case rise.Type.AST.ScalarType(t) =>
t.parse[Term].get
case rise.Type.AST.NatType =>
q"NatType"
case rise.Type.AST.OpaqueType(name) =>
q"OpaqueType($name)"
q"OpaqueType(${Lit.String(name)})"
case rise.Type.AST.VectorType(size, elemType) =>
q"VectorType(${generateNat(size)}, ${generateDataType(elemType)})"
case rise.Type.AST.IndexType(size) =>
Expand All @@ -162,7 +227,7 @@ import arithexpr.arithmetic._
case rise.Type.AST.DepPairType(id, kind, t) => kind match {
case Kind.AST.Nat =>
q"Nat `**` ((${Term.Name(id.name)}: Nat) => ${generateDataType(t)})"
case _ => ???
case _ => throw new Exception("DepPair types currently only support Nat Kind")
}
case rise.Type.AST.NatToDataApply(f, n) =>
q"NatToDataApply(${generateDataType(f)}, ${generateNat(n)})"
Expand All @@ -173,11 +238,14 @@ import arithexpr.arithmetic._
case rise.Type.AST.DepArrayType(size, fdt) =>
q"DepArrayType(${generateNat(size)}, ${generateDataType(fdt)})"
case rise.Type.AST.FragmentType(n, m, k, elemType, fKind, mLayout) =>
q"FragmentType(${generateNat(n)}, ${generateNat(m)}, ${generateNat(k)}, ${generateDataType(elemType)}, ${generateFragment(fKind)}, ${generateMatrixLayout(mLayout)})"
q"""FragmentType(${generateNat(n)}, ${generateNat(m)}, ${generateNat(k)}, ${generateDataType(elemType)},
${generateFragment(fKind)}, ${generateMatrixLayout(mLayout)})"""
case rise.Type.AST.ManagedBufferType(dt) =>
q"ManagedBufferType(${generateDataType(dt)})"
case rise.Type.AST.FunType(_, _) | rise.Type.AST.DepFunType(_, _, _) |
rise.Type.AST.ImplicitDepFunType(_, _, _) => ???
rise.Type.AST.ImplicitDepFunType(_, _, _) | rise.Type.AST.VariadicFunType(_, _, _) |
rise.Type.AST.VariadicDepFunType(_, _, _, _) =>
throw new Exception("This should not happen, there are not data types")
}
}

Expand Down
16 changes: 11 additions & 5 deletions meta/src/main/scala/meta/parser/DPIA/Decl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,22 @@ object Decl {
sealed trait AST
object AST {
case class Identifier(name: String) extends AST
case class Param(id: Identifier, ty: Either[Kind.AST, Type.AST]) extends AST
case class Param(id: Identifier, ty: KindOrType) extends AST
case class PrimitiveDeclaration(id: Identifier,
scalaParams: Option[(Int, Int)],
params: Seq[Param],
returnType: Type.AST) extends AST
}

sealed trait KindOrType
object KindOrType {
final case class Kind(kind: meta.parser.DPIA.Kind.AST) extends KindOrType
final case class Type(typeAST: meta.parser.DPIA.Type.AST) extends KindOrType
}

def PrimitiveDeclarations[_: P]: P[Seq[AST.PrimitiveDeclaration]] =
P(Start ~ PrimitiveDeclaration.rep(1) ~ End)

// def drop(n: nat, m: nat, t: data, input: exp[n+m.t, read]): exp[m.t, read]
// def mapGlobal[dim: Int](n: nat, s: data, t: data, f: exp[s, read] -> exp[t, read], array: exp[n.s, read]): exp[n.t, read]
def PrimitiveDeclaration[_: P]: P[AST.PrimitiveDeclaration] = {
import scalaparse.Scala.TrailingCommaOps
def ScalaParams: P[(Int, Int)] = {
Expand All @@ -29,8 +33,10 @@ object Decl {
}

def Param: P[AST.Param] = (
(Identifier.map(AST.Identifier) ~ ":" ~ Kind.Kind).map(pair => AST.Param(pair._1, Left(pair._2)))
| (Identifier.map(AST.Identifier) ~ ":" ~ Type.PhraseType).map(pair => AST.Param(pair._1, Right(pair._2)))
(Identifier.map(AST.Identifier) ~ ":" ~ Kind.Kind).map(
pair => AST.Param(pair._1, KindOrType.Kind(pair._2)))
| (Identifier.map(AST.Identifier) ~ ":" ~ Type.PhraseType).map(
pair => AST.Param(pair._1, KindOrType.Type(pair._2)))
)

def Params: P[Seq[AST.Param]] = Param.repTC(0)
Expand Down
12 changes: 8 additions & 4 deletions meta/src/main/scala/meta/parser/DPIA/Kind.scala
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
package meta.parser.DPIA

import fastparse.ScalaWhitespace._
import fastparse._
import meta.parser.rise
import meta.parser.shared.Identifier

object Kind {
sealed trait AST
object AST {
case class RiseKind(riseKind: rise.Kind.AST) extends AST
case object Access extends AST
case class VariadicKind(n: String, kind: AST) extends AST
}

def Kind[_: P]: P[AST] = P(
rise.Kind.Kind.map(AST.RiseKind) |
"access".!.map(_ => AST.Access)
)
def Kind[_: P]: P[AST] = {
def OnlyKind: P[AST] = P(rise.Kind.Kind.map(AST.RiseKind) | "access".!.map(_ => AST.Access))

OnlyKind | (Identifier ~ "*" ~ OnlyKind).map(AST.VariadicKind.tupled)
}
}
8 changes: 7 additions & 1 deletion meta/src/main/scala/meta/parser/DPIA/Type.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ object Type {
case class FunType(inT: AST, outT: AST) extends AST
case class DepFunType(id: Identifier, kind: Kind.AST, t: AST) extends AST
case class Identifier(name: String) extends AST

case class VariadicType(n: Identifier, ty: AST) extends AST
}

object Access {
Expand Down Expand Up @@ -60,6 +62,10 @@ object Type {

def NonFunPhraseType: P[AST] = P( ExpType | AccType | VarType | CommType | PairType | DepFunType )

P( FunType | NonFunPhraseType )
def OnlyPhraseType: P[AST] = P( FunType | NonFunPhraseType )

def VariadicType: P[AST] = P(Identifier.map(AST.Identifier) ~ "*" ~ OnlyPhraseType).map(AST.VariadicType.tupled)

OnlyPhraseType | VariadicType
}
}
Loading