Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dpia traversals #161

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 116 additions & 12 deletions macros/src/main/scala/shine/macros/Primitive.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,62 @@ object Primitive {
class Impl(val c: blackbox.Context) {
import c.universe._

def expPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees)
def accPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees)
def comPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees)
def expPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(getClassInfo(c)))(annottees)
def accPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(getClassInfo(c)))(annottees)
def comPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(getClassInfo(c)))(annottees)

def primitive(transform : ClassDef => ClassDef)(annottees: Seq[c.Expr[Any]]): c.Expr[Any] = {
annottees.map(_.tree) match {
case (cdef: ClassDef) :: Nil =>
c.Expr(transform(cdef))
case (cdef: ClassDef) :: (md: ModuleDef) :: Nil =>
c.Expr(q"{${transform(cdef)}; $md}")
case (cdef: ClassDef) :: Nil => c.Expr(transform(cdef))
case (cdef: ClassDef) :: (md: ModuleDef) :: Nil => c.Expr(q"{${transform(cdef)}; $md}")
case _ => c.abort(c.enclosingPosition, "expected a class definition")
}
}

def makeLowerCaseName(s: String): String =
s"${Character.toLowerCase(s.charAt(0))}${s.substring(1)}"

def makeTraverseCall(v : Tree, name : TermName) : Tree => Option[Tree] = {
case Ident(TypeName("DataType")) | Ident(TypeName("ScalarType")) |
Ident(TypeName("BasicType")) => Some(fq"${name} <- $v.datatype($name)")
case Ident(TypeName("Data")) => Some(fq"${name} <- $v.data($name)")
case Ident(TypeName("Nat")) => Some(fq"${name} <- $v.natDispatch($name)")
case Ident(TypeName("NatIdentifier")) => Some(fq"${name} <- $v.typeIdentifierDispatch(shine.DPIA.Phrases.traverse.Reference)($name)")
case Ident(TypeName("NatToNat")) => Some(fq"${name} <- $v.natToNat($name)")
case Ident(TypeName("NatToData")) => Some(fq"${name} <- $v.natToData($name)")
case Ident(TypeName("AccessType")) => Some(fq"${name} <- $v.accessType($name)")
case Ident(TypeName("AddressSpace")) => Some(fq"${name} <- $v.addressSpace($name)")
// Phrase[ExpType]
case AppliedTypeTree((Ident(TypeName("Phrase")), _)) => Some(fq"${name} <- $v.phrase($name)")
// Vector[Phrase[ExpType]]
case AppliedTypeTree((Ident(TypeName("Vector")),
List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) => Some(fq"${name} <- monad.traverseV($name.map($v.phrase(_)))")
case AppliedTypeTree((Ident(TypeName("Seq")),
List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) => Some(fq"${name} <- monad.traverse($name.map($v.phrase(_)))")
case _ => None
}

def makeTraverse(name: TypeName, additionalParams: List[ValDef], params: List[ValDef], parent : Tree): Tree = {
val v = q"v"
val paramNames = params.map { case ValDef(_, name, _, _) => q"$name" }
val additionalParamNames = additionalParams.map { case ValDef(_, name, _, _) => q"$name" }
val forLoopBindings : List[Tree] = params.flatMap {
case ValDef(_, name, tpt, _) => makeTraverseCall(v, name)(tpt)
}
val construct = if (additionalParamNames.isEmpty) q"new $name(..$paramNames)"
else q"new $name(..$additionalParamNames)(..$paramNames)"
val forloop = if (forLoopBindings.isEmpty) q"monad.return_($construct)"
else q"for (..${forLoopBindings}) yield ${construct}"

q"""
override def traverse[M[+_]]($v: shine.DPIA.Phrases.traverse.Traversal[M]): M[$name] = {
import util.monads._
implicit val monad: Monad[M] = implicitly($v.monad)
$forloop
}
"""
}

def makeVisitAndRebuild(name: TypeName,
additionalParams: List[ValDef],
params: List[ValDef]): Tree = {
Expand Down Expand Up @@ -81,13 +120,77 @@ object Primitive {
"""
}

def makeXMLPrinter(name: TypeName,
additionalParams: List[ValDef],
params: List[ValDef]): Tree = {
def makeAttributes(params: List[ValDef]): (List[ValDef], Tree) = {
if (params.isEmpty) return (params, q"scala.xml.Null")
params.head match {
case ValDef(_, name, tpt, _) => tpt match {
case Ident(TypeName("DataType")) | Ident(TypeName("ScalarType")) |
Ident(TypeName("BasicType")) | Ident(TypeName("Nat")) |
Ident(TypeName("NatToNat")) | Ident(TypeName("NatToData")) |
Ident(TypeName("AccessType")) | Ident(TypeName("AddressSpace"))
=>
val (list, next) = makeAttributes(params.tail)
(list, q"""
scala.xml.Attribute(${name.toString},
scala.xml.Text(
shine.DPIA.Phrases.ToString($name)),
$next)
""")
case _ => (params, q"scala.xml.Null")
}
}
}

def makeBody(params: List[ValDef]): List[Tree] = {
params.map {
case ValDef(_, name, tpt, _) =>

val body = tpt match {
// Phrase[ExpType]
case AppliedTypeTree((Ident(TypeName("Phrase")), _)) =>
q"shine.DPIA.Phrases.xmlPrinter($name)"
// Vector[Phrase[ExpType]]
case AppliedTypeTree((Ident(TypeName("Vector")),
List(AppliedTypeTree((Ident(TypeName("Phrase")), _)))))
| AppliedTypeTree((Ident(TypeName("Seq")),
List(AppliedTypeTree((Ident(TypeName("Phrase")), _)))))
=>
q"$name.flatMap(shine.DPIA.Phrases.xmlPrinter(_)):_*"
case _ =>
q"scala.xml.Text(shine.DPIA.Phrases.ToString($name))"
}
q"""
scala.xml.Elem(null, ${name.toString},
scala.xml.Null, scala.xml.TopScope,
minimizeEmpty = false, $body)
"""
}
}

val lowerCaseName = makeLowerCaseName(name.toString)
val (rest, attributes) = makeAttributes(params)
val body = makeBody(rest)

q"""
override def xmlPrinter: scala.xml.Elem = {
val attributes_ = $attributes
val body_ = $body
scala.xml.Elem(null, $lowerCaseName, attributes_, scala.xml.TopScope,
minimizeEmpty = false, (body_):_*)
}
"""
}

case class ClassInfo(name: TypeName,
additionalParams: List[ValDef],
params: List[ValDef],
body: List[Tree],
parents: List[Tree])

def primitivesFromClassDef: ClassDef => ClassInfo = {
def getClassInfo: ClassDef => ClassInfo = {
case q"case class $name(..$params) extends { ..$_ } with ..$parents {..$body} " =>
ClassInfo(
name.asInstanceOf[c.TypeName],
Expand Down Expand Up @@ -123,14 +226,15 @@ object Primitive {
}

def makePrimitiveClass : ClassInfo => ClassDef = { case ClassInfo(name, additionalParams, params, body, parents) =>
val traverseMissing =
body.collectFirst({ case DefDef(_, TermName("traverse"), _, _, _, _) => ()}).isEmpty
val visitAndRebuildMissing =
body.collectFirst({ case DefDef(_, TermName("visitAndRebuild"), _, _, _, _) => ()}).isEmpty

val generated = q"""
${if (visitAndRebuildMissing)
makeVisitAndRebuild(name, additionalParams, params)
else q""}
"""
${if (traverseMissing) makeTraverse(name, additionalParams, params, parents(0)) else q""}
${if (visitAndRebuildMissing) makeVisitAndRebuild(name, additionalParams, params) else q""}
"""

val expClass = (additionalParams match {
case List() =>
Expand Down
10 changes: 5 additions & 5 deletions src/main/scala/rise/core/traverse.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ object traverse {
case object Binding extends VarType
case object Reference extends VarType

trait Traversal[M[_]] {
trait Traversal[M[+_]] {
protected[this] implicit def monad : Monad[M]
def return_[T] : T => M[T] = monad.return_
def bind[T,S] : M[T] => (T => M[S]) => M[S] = monad.bind
Expand Down Expand Up @@ -187,13 +187,13 @@ object traverse {
}
}

trait ExprTraversal[M[_]] extends Traversal[M] {
trait ExprTraversal[M[+_]] extends Traversal[M] {
override def `type`[T <: Type] : T => M[T] = return_
}

trait PureTraversal extends Traversal[Pure] {override def monad : PureMonad.type = PureMonad }
trait PureExprTraversal extends PureTraversal with ExprTraversal[Pure]
trait AccumulatorTraversal[F,M[_]] extends Traversal[InMonad[M]#SetFst[F]#Type] {
trait AccumulatorTraversal[F,M[+_]] extends Traversal[InMonad[M]#SetFst[F]#Type] {
type Pair[T] = InMonad[M]#SetFst[F]#Type[T]
implicit val accumulator : Monoid[F]
implicit val wrapperMonad : Monad[M]
Expand All @@ -211,6 +211,6 @@ object traverse {
def traverse[T <: Type](t: T, f: PureTraversal): T = f.`type`(t).unwrap
def traverse[F](e: Expr, f: PureAccumulatorTraversal[F]): (F, Expr) = f.expr(e).unwrap
def traverse[F,T <: Type](t: T, f: PureAccumulatorTraversal[F]): (F, T) = f.`type`(t).unwrap
def traverse[M[_]](e: Expr, f: Traversal[M]): M[Expr] = f.expr(e)
def traverse[T <: Type, M[_]](e: T, f: Traversal[M]): M[T] = f.`type`(e)
def traverse[M[+_]](e: Expr, f: Traversal[M]): M[Expr] = f.expr(e)
def traverse[T <: Type, M[+_]](e: T, f: Traversal[M]): M[T] = f.`type`(e)
}
30 changes: 17 additions & 13 deletions src/main/scala/shine/DPIA/Phrases/Phrase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package shine.DPIA.Phrases

import arithexpr.arithmetic.{NamedVar, RangeAdd}
import shine.DPIA.Lifting.{liftDependentFunction, liftFunction, liftPair}
import shine.DPIA.Phrases.traverse._
import shine.DPIA.Types._
import shine.DPIA.Types.TypeCheck._
import shine.DPIA._
import shine.DPIA.primitives.functional.NatAsIndex
import util.monads.Pure

sealed trait Phrase[T <: PhraseType] {
val t: T
Expand Down Expand Up @@ -43,6 +45,7 @@ final case class DepLambda[K <: Kind, T <: PhraseType](x: K#I, body: Phrase[T])
extends Phrase[K `()->:` T] {
override val t: DepFunType[K, T] = DepFunType[K, T](x, body.t)
override def toString: String = s"Λ(${x.name} : ${kn.get}). $body"
val kindName : KindName[K] = implicitly(kn)
}

object DepLambda {
Expand Down Expand Up @@ -134,7 +137,7 @@ object Phrase {
`for`: Phrase[T1],
in: Phrase[T2]): Phrase[T2] = {
var substCounter = 0
object Visitor extends VisitAndRebuild.Visitor {
object Visitor extends PureTraversal {
def renaming[X <: PhraseType](p: Phrase[X]): Phrase[X] = {
case class Renaming(idMap: Map[String, String]) extends VisitAndRebuild.Visitor {
override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = p match {
Expand All @@ -151,8 +154,7 @@ object Phrase {
}

override def nat[N <: Nat](n: N): N = n.visitAndRebuild({
case i: NatIdentifier =>
NatIdentifier(idMap.getOrElse(i.name, i.name))
case i: NatIdentifier => NatIdentifier(idMap.getOrElse(i.name, i.name))
case ae => ae
}).asInstanceOf[N]

Expand All @@ -164,33 +166,32 @@ object Phrase {
}
VisitAndRebuild(p, Renaming(Map()))
}
override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = {
p match {

override def phrase[T <: PhraseType]: Phrase[T] => Pure[Phrase[T]] = {
case `for` =>
val newPh = if (substCounter == 0) ph else renaming(ph)
substCounter += 1
Stop(newPh.asInstanceOf[Phrase[T]])
case Natural(n) =>
return_(newPh.asInstanceOf[Phrase[T]])
case p@Natural(n) =>
val v = NatIdentifier(`for` match {
case Identifier(name, _) => name
case _ => throw new Exception("This should never happen")
})

ph.t match {
case ExpType(NatType, _) =>
Stop(Natural(Nat.substitute(
return_(Natural(Nat.substitute(
Internal.transientNatFromExpr(ph.asInstanceOf[Phrase[ExpType]]).n, v, n)).asInstanceOf[Phrase[T]])
case ExpType(IndexType(_), _) =>
Stop(Natural(Nat.substitute(
return_(Natural(Nat.substitute(
Internal.transientNatFromExpr(ph.asInstanceOf[Phrase[ExpType]]).n, v, n)).asInstanceOf[Phrase[T]])
case _ => Continue(p, this)
case _ => super.phrase(p)
}
case _ => Continue(p, this)
}
case p => super.phrase(p)
}
}

VisitAndRebuild(in, Visitor)
Visitor.phrase(in).unwrap
}

def substitute[T2 <: PhraseType](substitutionMap: Map[Phrase[_], Phrase[_]],
Expand Down Expand Up @@ -361,6 +362,9 @@ object Phrase {
sealed trait Primitive[T <: PhraseType] extends Phrase[T] {
def prettyPrint: String = this.toString

def traverse[M[+_]](f: Traversal[M]): M[Phrase[T]] =
throw new Exception("traverse should be implemented by a macro")

def visitAndRebuild(f: VisitAndRebuild.Visitor): Phrase[T] =
throw new Exception("visitAndRebuild should be implemented by a macro")
}
Expand Down
Loading