Skip to content

Commit

Permalink
Merge pull request #170 from rise-lang/sequence-syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
michel-steuwer authored May 11, 2021
2 parents 91a3d2a + 1395b37 commit 0d358ed
Show file tree
Hide file tree
Showing 250 changed files with 3,262 additions and 713 deletions.
420 changes: 247 additions & 173 deletions meta/src/main/scala/meta/generator/DPIAPrimitives.scala

Large diffs are not rendered by default.

118 changes: 93 additions & 25 deletions meta/src/main/scala/meta/generator/RisePrimitives.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,34 @@ package meta.generator
import fastparse.{Parsed, parse}
import meta.parser._
import meta.parser.rise.Kind
import meta.parser.rise.Kind.AST

object RisePrimitives {
def main(args: Array[String]): Unit = {
val sourceDir = args.head
val sourceDir = args.head // excepts one argument that is the source directory of the rise repo (i.e. 'rise/src')
val risePath = os.Path(sourceDir) / "rise"
os.walk.stream(risePath).filter(_.ext == "rise").foreach(path => {
os.walk.stream(risePath).filter(_.ext == "rise").foreach(path => { // for each file with the `.rise` extension ...

val definition = os.read(path)
parse(definition, rise.Decl.PrimitiveDeclarations(_)) match {
val definition = os.read(path) // ... read the file content
parse(definition, rise.Decl.PrimitiveDeclarations(_)) match { // ... and parse it
case failure: Parsed.Failure =>
println(s"Failed to parse `${failure.extra.input}'")
println(s" $failure")
case Parsed.Success(seq, _) =>
seq.foreach {
case Parsed.Success(seq, _) => seq.foreach { // ... if successful go over all
case rise.Decl.AST.PrimitiveDeclaration(rise.Decl.AST.Identifier(name), scalaParams, typeSignature)
if rise.isWellKindedType(typeSignature) =>
if rise.isWellKindedType( // ... well kinded declarations
toParamList(definition, scalaParams), typeSignature) =>
val outputPath = (path / os.up) / s"$name.scala"
println(s"Generate $outputPath")

val generatedDef = scalaParams match {
case None =>
generateObject(name, typeSignature)
case Some((start, end)) =>
generateCaseClass(name, definition.substring(start, end), typeSignature)
}
// ... and generate class or object definition
val generatedDef = generate(name, toParamList(definition, scalaParams), typeSignature)

import scala.meta._
val packageName = path.relativeTo(risePath).segments.dropRight(1).foldLeft[Term.Ref](Term.Name("rise")) {
case (t, name) => Term.Select(t, Term.Name(name))
}
// ... combine into the generated source code
val code = s"""// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
|// This file is automatically generated and should not be changed manually //
|// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
Expand All @@ -50,14 +48,34 @@ import arithexpr.arithmetic._
}""".toString()}
|""".stripMargin

os.write.over(outputPath, code)
os.write.over(outputPath, code) // ... and finally write out to disk.
case rise.Decl.AST.PrimitiveDeclaration(name, _, typeSignature) =>
println(s"Could not generate code for `$name' as type signature `$typeSignature' is not well kinded.")
}
}
})
}

// parse scala parameters (i.e. parameters that are not part of the rise language) into a list of parameters
def toParamList(definition: String, scalaParams: Option[(Int, Int)]): Option[List[scala.meta.Term.Param]] = {
import scala.meta._
scalaParams.map { case (start, end) =>
s"def foo(${definition.substring(start, end)})".parse[Stat].get match {
case declDef: Decl.Def => declDef.paramss.head
}
}
}

// generate either an object (if there are no scala parameters) or a case class (if there are some)
def generate(name: String,
params: Option[List[scala.meta.Term.Param]],
typeSignature: rise.Type.AST): scala.meta.Term.Block =
params match {
case None => generateObject(name, typeSignature)
case Some(params) => generateCaseClass(name, params, typeSignature)
}


def generateObject(name: String, typeSignature: rise.Type.AST): scala.meta.Term.Block = {
import scala.meta._
val generated = q"""{
Expand All @@ -83,12 +101,11 @@ import arithexpr.arithmetic._
generated
}

def generateCaseClass(name: String, paramsString: String, typeSignature: rise.Type.AST): scala.meta.Term.Block = {
def generateCaseClass(name: String,
params: List[scala.meta.Term.Param],
typeSignature: rise.Type.AST): scala.meta.Term.Block = {
import scala.meta._

val params = s"def foo($paramsString)".parse[Stat].get match {
case declDef: Decl.Def => declDef.paramss.head
}
val args: List[Term.Name] = params.map(p => Term.Name(p.name.value))
val types: List[Type] = params.map(p => p.decltpe.get)

Expand Down Expand Up @@ -119,16 +136,21 @@ import arithexpr.arithmetic._
}
}

def unapply(arg: rise.core.Expr): Option[..$types] = arg match {
case p: Primitive =>
Some(..${generateMemberAccesses(args)})
def unapply(arg: rise.core.Expr): ${if (types.length > 1) {
t"Option[(..$types)]" // create tuple if there are multiple type parameters
} else {
t"Option[..$types]"}
} = arg match {
case p: Primitive => Some(..${generateMemberAccesses(args)})
case _ => None
}
}
"""
generated
}

// generate a scala term representing the type scheme, i.e. something like this for map:
// impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t) ->: n`.`s ->: n`.`t } } }
def generateTypeScheme(typeAST: rise.Type.AST): scala.meta.Term = {
import scala.meta._
typeAST match {
Expand All @@ -138,21 +160,64 @@ import arithexpr.arithmetic._
q"expl((${Term.Name(id.name)}: ${Type.Name(kindName(kind))}) => ${generateTypeScheme(t)})"
case rise.Type.AST.ImplicitDepFunType(id, kind, t) =>
q"impl((${Term.Name(id.name)}: ${Type.Name(kindName(kind))}) => ${generateTypeScheme(t)})"
case rise.Type.AST.VariadicFunType(_, rise.Type.AST.UnrolledIdentifier(inTs), outT) =>
// variadic function type with unrolled identifier, e.g.: n*(*inTs ->) outT
// generates: inTs.foldRight(outT){ case (lhsT, rhsT) => lhs ->: rhs }
// to represent n-many function types: inT0 ->: inT1 ->: ... ->: outT
q"""${Term.Name(inTs)}.foldRight(${generateTypeScheme(outT)}: Type) {
case (lhsT, rhsT) => lhsT ->: rhsT
}"""
case rise.Type.AST.VariadicFunType(rise.Type.AST.Identifier(n), dt, outT) =>
// variadic function type without an unrolled identifier, e.g: n*(dt ->) outT
// generated: Seq.fill(n)(dt).foldRight(outT){ case (lhsT, rhsT) => lhs ->: rhs }
// to represent n-many function types: dt -> dt -> ... -> outT
q"""Seq.fill(${Term.Name(n)})(${generateDataType(dt)}).foldRight(${generateTypeScheme(outT)}: Type) {
case (lhsT, rhsT) => lhsT ->: rhsT
}"""
case rise.Type.AST.VariadicDepFunType(n, ids, kind, t) =>
// variadic dependent function type, e.g.: n*((ids: kind) ->) t
// generates:
// val ids = Seq.fill(n)(DataTypeIdentifier(freshName("dt"), isExplicit = true))
// ids.foldRight(t){ case (id, t) => DepFunType[DataKind](id, t) }
// to represent n-many dependent function types: (id0: kind) -> (id1: kind) -> ... -> t
val (createIds, typeName) = kind match {
case AST.Data =>
(q"""DataTypeIdentifier(freshName("dt"), isExplicit = true)""", Type.Name("DataKind"))
case AST.Address =>
(q"""AddressSpaceIdentifier(freshName("a"), isExplicit = true)""", Type.Name("AddressSpaceKind"))
case AST.Nat2Nat =>
(q"""NatToNatIdentifier(freshName("n2n"), isExplicit = true)""", Type.Name("NatToNatKind"))
case AST.Nat2Data =>
(q"""NatToDataIdentifier(freshName("n2d"), isExplicit = true)""", Type.Name("NatToDataKind"))
case AST.Nat =>
(q"""NatIdentifier(freshName("n"), isExplicit = true)""", Type.Name("NatKind"))
case AST.Fragment => throw new Exception("No support for Fragment Kind yet")
case AST.MatrixLayout => throw new Exception("No support for Matrix Layout Kind yet")
}
q"""{
val ${Pat.Var(Term.Name(ids.name))} = Seq.fill(${Term.Name(n.name)})($createIds)
${Term.Name(ids.name)}.foldRight(${generateTypeScheme(t)}: Type) {
case (id, t) => DepFunType[$typeName, Type](id, t)
}
}"""
case _ => generateDataType(typeAST)
}
}

// generate a scala term representing a rise data type
def generateDataType(typeAST: rise.Type.AST): scala.meta.Term = {
import scala.meta._
typeAST match {
case rise.Type.AST.Identifier(name) =>
Term.Name(name)
case rise.Type.AST.UnrolledIdentifier(name) =>
Term.Name(name)
case rise.Type.AST.ScalarType(t) =>
t.parse[Term].get
case rise.Type.AST.NatType =>
q"NatType"
case rise.Type.AST.OpaqueType(name) =>
q"OpaqueType($name)"
q"OpaqueType(${Lit.String(name)})"
case rise.Type.AST.VectorType(size, elemType) =>
q"VectorType(${generateNat(size)}, ${generateDataType(elemType)})"
case rise.Type.AST.IndexType(size) =>
Expand All @@ -162,7 +227,7 @@ import arithexpr.arithmetic._
case rise.Type.AST.DepPairType(id, kind, t) => kind match {
case Kind.AST.Nat =>
q"Nat `**` ((${Term.Name(id.name)}: Nat) => ${generateDataType(t)})"
case _ => ???
case _ => throw new Exception("DepPair types currently only support Nat Kind")
}
case rise.Type.AST.NatToDataApply(f, n) =>
q"NatToDataApply(${generateDataType(f)}, ${generateNat(n)})"
Expand All @@ -173,11 +238,14 @@ import arithexpr.arithmetic._
case rise.Type.AST.DepArrayType(size, fdt) =>
q"DepArrayType(${generateNat(size)}, ${generateDataType(fdt)})"
case rise.Type.AST.FragmentType(n, m, k, elemType, fKind, mLayout) =>
q"FragmentType(${generateNat(n)}, ${generateNat(m)}, ${generateNat(k)}, ${generateDataType(elemType)}, ${generateFragment(fKind)}, ${generateMatrixLayout(mLayout)})"
q"""FragmentType(${generateNat(n)}, ${generateNat(m)}, ${generateNat(k)}, ${generateDataType(elemType)},
${generateFragment(fKind)}, ${generateMatrixLayout(mLayout)})"""
case rise.Type.AST.ManagedBufferType(dt) =>
q"ManagedBufferType(${generateDataType(dt)})"
case rise.Type.AST.FunType(_, _) | rise.Type.AST.DepFunType(_, _, _) |
rise.Type.AST.ImplicitDepFunType(_, _, _) => ???
rise.Type.AST.ImplicitDepFunType(_, _, _) | rise.Type.AST.VariadicFunType(_, _, _) |
rise.Type.AST.VariadicDepFunType(_, _, _, _) =>
throw new Exception("This should not happen, there are not data types")
}
}

Expand Down
16 changes: 11 additions & 5 deletions meta/src/main/scala/meta/parser/DPIA/Decl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,22 @@ object Decl {
sealed trait AST
object AST {
case class Identifier(name: String) extends AST
case class Param(id: Identifier, ty: Either[Kind.AST, Type.AST]) extends AST
case class Param(id: Identifier, ty: KindOrType) extends AST
case class PrimitiveDeclaration(id: Identifier,
scalaParams: Option[(Int, Int)],
params: Seq[Param],
returnType: Type.AST) extends AST
}

sealed trait KindOrType
object KindOrType {
final case class Kind(kind: meta.parser.DPIA.Kind.AST) extends KindOrType
final case class Type(typeAST: meta.parser.DPIA.Type.AST) extends KindOrType
}

def PrimitiveDeclarations[_: P]: P[Seq[AST.PrimitiveDeclaration]] =
P(Start ~ PrimitiveDeclaration.rep(1) ~ End)

// def drop(n: nat, m: nat, t: data, input: exp[n+m.t, read]): exp[m.t, read]
// def mapGlobal[dim: Int](n: nat, s: data, t: data, f: exp[s, read] -> exp[t, read], array: exp[n.s, read]): exp[n.t, read]
def PrimitiveDeclaration[_: P]: P[AST.PrimitiveDeclaration] = {
import scalaparse.Scala.TrailingCommaOps
def ScalaParams: P[(Int, Int)] = {
Expand All @@ -29,8 +33,10 @@ object Decl {
}

def Param: P[AST.Param] = (
(Identifier.map(AST.Identifier) ~ ":" ~ Kind.Kind).map(pair => AST.Param(pair._1, Left(pair._2)))
| (Identifier.map(AST.Identifier) ~ ":" ~ Type.PhraseType).map(pair => AST.Param(pair._1, Right(pair._2)))
(Identifier.map(AST.Identifier) ~ ":" ~ Kind.Kind).map(
pair => AST.Param(pair._1, KindOrType.Kind(pair._2)))
| (Identifier.map(AST.Identifier) ~ ":" ~ Type.PhraseType).map(
pair => AST.Param(pair._1, KindOrType.Type(pair._2)))
)

def Params: P[Seq[AST.Param]] = Param.repTC(0)
Expand Down
12 changes: 8 additions & 4 deletions meta/src/main/scala/meta/parser/DPIA/Kind.scala
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
package meta.parser.DPIA

import fastparse.ScalaWhitespace._
import fastparse._
import meta.parser.rise
import meta.parser.shared.Identifier

object Kind {
sealed trait AST
object AST {
case class RiseKind(riseKind: rise.Kind.AST) extends AST
case object Access extends AST
case class VariadicKind(n: String, kind: AST) extends AST
}

def Kind[_: P]: P[AST] = P(
rise.Kind.Kind.map(AST.RiseKind) |
"access".!.map(_ => AST.Access)
)
def Kind[_: P]: P[AST] = {
def OnlyKind: P[AST] = P(rise.Kind.Kind.map(AST.RiseKind) | "access".!.map(_ => AST.Access))

OnlyKind | (Identifier ~ "*" ~ OnlyKind).map(AST.VariadicKind.tupled)
}
}
8 changes: 7 additions & 1 deletion meta/src/main/scala/meta/parser/DPIA/Type.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ object Type {
case class FunType(inT: AST, outT: AST) extends AST
case class DepFunType(id: Identifier, kind: Kind.AST, t: AST) extends AST
case class Identifier(name: String) extends AST

case class VariadicType(n: Identifier, ty: AST) extends AST
}

object Access {
Expand Down Expand Up @@ -60,6 +62,10 @@ object Type {

def NonFunPhraseType: P[AST] = P( ExpType | AccType | VarType | CommType | PairType | DepFunType )

P( FunType | NonFunPhraseType )
def OnlyPhraseType: P[AST] = P( FunType | NonFunPhraseType )

def VariadicType: P[AST] = P(Identifier.map(AST.Identifier) ~ "*" ~ OnlyPhraseType).map(AST.VariadicType.tupled)

OnlyPhraseType | VariadicType
}
}
Loading

0 comments on commit 0d358ed

Please sign in to comment.