From d38f3ecda8b9fda8a651fdffedbc9b6690c068de Mon Sep 17 00:00:00 2001 From: Michel Steuwer Date: Wed, 5 May 2021 11:28:17 +0100 Subject: [PATCH 1/5] Implement sequence syntax as discussed in https://github.com/rise-lang/shine/issues/168 --- .../scala/meta/generator/DPIAPrimitives.scala | 156 +++++++++++--- .../scala/meta/generator/RisePrimitives.scala | 68 ++++-- .../main/scala/meta/parser/DPIA/Kind.scala | 12 +- .../main/scala/meta/parser/DPIA/Type.scala | 8 +- .../parser/DPIA/isWellKindedDefinition.scala | 40 ++-- .../main/scala/meta/parser/rise/Type.scala | 19 +- .../meta/parser/rise/isWellKindedType.scala | 200 +++++++++++------- .../scala/rise/Cuda/primitives/.gitignore | 1 - .../rise/Cuda/primitives/asFragment.scala | 24 +++ .../scala/rise/Cuda/primitives/asMatrix.scala | 24 +++ .../Cuda/primitives/generateFragment.scala | 24 +++ .../rise/Cuda/primitives/globalToShared.scala | 24 +++ .../scala/rise/Cuda/primitives/mapBlock.scala | 37 ++++ .../rise/Cuda/primitives/mapFragment.scala | 24 +++ .../rise/Cuda/primitives/mapGlobal.scala | 37 ++++ .../scala/rise/Cuda/primitives/mapLane.scala | 37 ++++ .../rise/Cuda/primitives/mapThreads.scala | 37 ++++ .../scala/rise/Cuda/primitives/mapWarp.scala | 37 ++++ .../rise/Cuda/primitives/tensorMMA.scala | 24 +++ src/main/scala/rise/core/DSL/package.scala | 38 ++-- .../scala/rise/core/ForeignFunction.scala | 13 -- src/main/scala/rise/core/package.scala | 1 - .../scala/rise/core/primitives/.gitignore | 1 - src/main/scala/rise/core/primitives/add.scala | 24 +++ .../scala/rise/core/primitives/asScalar.scala | 24 +++ .../scala/rise/core/primitives/asVector.scala | 24 +++ .../core/primitives/asVectorAligned.scala | 24 +++ .../scala/rise/core/primitives/cast.scala | 24 +++ .../rise/core/primitives/circularBuffer.scala | 24 +++ .../scala/rise/core/primitives/concat.scala | 24 +++ .../scala/rise/core/primitives/depJoin.scala | 24 +++ .../rise/core/primitives/depMapSeq.scala | 24 +++ .../scala/rise/core/primitives/depZip.scala | 24 +++ src/main/scala/rise/core/primitives/div.scala | 24 +++ .../scala/rise/core/primitives/dmatch.scala | 24 +++ .../scala/rise/core/primitives/drop.scala | 24 +++ .../scala/rise/core/primitives/equal.scala | 24 +++ .../core/primitives/foreignFunction.scala | 46 ++++ src/main/scala/rise/core/primitives/fst.scala | 24 +++ .../scala/rise/core/primitives/gather.scala | 24 +++ .../scala/rise/core/primitives/generate.scala | 24 +++ src/main/scala/rise/core/primitives/gt.scala | 24 +++ src/main/scala/rise/core/primitives/id.scala | 24 +++ src/main/scala/rise/core/primitives/idx.scala | 24 +++ .../rise/core/primitives/indexAsNat.scala | 24 +++ .../scala/rise/core/primitives/iterate.scala | 24 +++ .../rise/core/primitives/iterateStream.scala | 24 +++ .../scala/rise/core/primitives/join.scala | 24 +++ src/main/scala/rise/core/primitives/let.scala | 24 +++ src/main/scala/rise/core/primitives/lt.scala | 24 +++ .../rise/core/primitives/makeArray.scala | 56 +++-- .../rise/core/primitives/makeDepPair.scala | 24 +++ .../scala/rise/core/primitives/makePair.scala | 24 +++ src/main/scala/rise/core/primitives/map.scala | 24 +++ .../scala/rise/core/primitives/mapFst.scala | 24 +++ .../scala/rise/core/primitives/mapSeq.scala | 24 +++ .../rise/core/primitives/mapSeqUnroll.scala | 24 +++ .../scala/rise/core/primitives/mapSnd.scala | 24 +++ .../rise/core/primitives/mapStream.scala | 24 +++ src/main/scala/rise/core/primitives/mod.scala | 24 +++ src/main/scala/rise/core/primitives/mul.scala | 24 +++ .../rise/core/primitives/natAsIndex.scala | 24 +++ src/main/scala/rise/core/primitives/neg.scala | 24 +++ src/main/scala/rise/core/primitives/not.scala | 24 +++ .../scala/rise/core/primitives/padClamp.scala | 24 +++ .../scala/rise/core/primitives/padCst.scala | 24 +++ .../scala/rise/core/primitives/padEmpty.scala | 24 +++ .../rise/core/primitives/partition.scala | 24 +++ .../rise/core/primitives/primitives.rise | 11 +- .../scala/rise/core/primitives/reduce.scala | 24 +++ .../rise/core/primitives/reduceSeq.scala | 24 +++ .../core/primitives/reduceSeqUnroll.scala | 24 +++ .../scala/rise/core/primitives/reorder.scala | 24 +++ .../rise/core/primitives/rotateValues.scala | 24 +++ .../scala/rise/core/primitives/scanSeq.scala | 24 +++ .../scala/rise/core/primitives/scatter.scala | 24 +++ .../scala/rise/core/primitives/select.scala | 24 +++ .../scala/rise/core/primitives/slide.scala | 24 +++ src/main/scala/rise/core/primitives/snd.scala | 24 +++ .../scala/rise/core/primitives/split.scala | 24 +++ src/main/scala/rise/core/primitives/sub.scala | 24 +++ .../scala/rise/core/primitives/take.scala | 24 +++ .../scala/rise/core/primitives/toMem.scala | 24 +++ .../rise/core/primitives/transpose.scala | 24 +++ .../scala/rise/core/primitives/unzip.scala | 24 +++ .../core/primitives/vectorFromScalar.scala | 24 +++ src/main/scala/rise/core/primitives/zip.scala | 24 +++ .../scala/rise/elevate/rules/lowering.scala | 2 +- .../scala/rise/elevate/rules/traversal.scala | 1 - .../rise/elevate/strategies/lowering.scala | 1 - .../scala/rise/openCL/primitives/.gitignore | 1 - .../rise/openCL/primitives/mapGlobal.scala | 37 ++++ .../rise/openCL/primitives/mapLocal.scala | 37 ++++ .../rise/openCL/primitives/mapWorkGroup.scala | 37 ++++ .../openCL/primitives/oclCircularBuffer.scala | 24 +++ .../rise/openCL/primitives/oclIterate.scala | 24 +++ .../rise/openCL/primitives/oclReduceSeq.scala | 24 +++ .../primitives/oclReduceSeqUnroll.scala | 24 +++ .../openCL/primitives/oclRotateValues.scala | 24 +++ .../openCL/primitives/oclRunPrimitive.scala | 24 +++ .../rise/openCL/primitives/oclToMem.scala | 24 +++ .../scala/rise/openMP/primitives/.gitignore | 1 - .../scala/rise/openMP/primitives/mapPar.scala | 24 +++ .../rise/openMP/primitives/reducePar.scala | 24 +++ .../shine/C/Compilation/CodeGenerator.scala | 10 +- .../Compilation/AcceptorTranslation.scala | 12 +- .../Compilation/ContinuationTranslation.scala | 22 +- src/main/scala/shine/DPIA/Data.scala | 2 + .../shine/DPIA/InferAccessAnnotation.scala | 7 +- src/main/scala/shine/DPIA/fromRise.scala | 39 ++-- src/main/scala/shine/DPIA/package.scala | 6 + .../functional/ForeignFunctionCall.scala | 12 +- .../primitives/functional/MakeArray.scala | 10 +- .../primitives/functional/primitives.dpia | 9 +- .../Compilation/HostCodeGenerator.scala | 4 +- .../Compilation/HostManagedBuffers.scala | 24 ++- .../Compilation/KernelCodeGenerator.scala | 4 +- .../Passes/InsertMemoryBarriers.scala | 8 +- .../SeparateHostAndKernelCode.scala | 7 +- .../primitives/functional/KernelCall.scala | 12 +- .../functional/OpenCLFunctionCall.scala | 12 +- .../primitives/functional/primitives.dpia | 18 +- .../primitives/imperative/KernelCallCmd.scala | 10 +- .../primitives/imperative/primitives.dpia | 7 +- .../scala/shine/OpenMP/CodeGenerator.scala | 4 +- .../separableConvolution2DNaiveEqsat.scala | 1 - .../shine/DPIA/Primitives/Generate.scala | 4 +- 127 files changed, 2720 insertions(+), 320 deletions(-) delete mode 100644 src/main/scala/rise/Cuda/primitives/.gitignore create mode 100644 src/main/scala/rise/Cuda/primitives/asFragment.scala create mode 100644 src/main/scala/rise/Cuda/primitives/asMatrix.scala create mode 100644 src/main/scala/rise/Cuda/primitives/generateFragment.scala create mode 100644 src/main/scala/rise/Cuda/primitives/globalToShared.scala create mode 100644 src/main/scala/rise/Cuda/primitives/mapBlock.scala create mode 100644 src/main/scala/rise/Cuda/primitives/mapFragment.scala create mode 100644 src/main/scala/rise/Cuda/primitives/mapGlobal.scala create mode 100644 src/main/scala/rise/Cuda/primitives/mapLane.scala create mode 100644 src/main/scala/rise/Cuda/primitives/mapThreads.scala create mode 100644 src/main/scala/rise/Cuda/primitives/mapWarp.scala create mode 100644 src/main/scala/rise/Cuda/primitives/tensorMMA.scala delete mode 100644 src/main/scala/rise/core/primitives/.gitignore create mode 100644 src/main/scala/rise/core/primitives/add.scala create mode 100644 src/main/scala/rise/core/primitives/asScalar.scala create mode 100644 src/main/scala/rise/core/primitives/asVector.scala create mode 100644 src/main/scala/rise/core/primitives/asVectorAligned.scala create mode 100644 src/main/scala/rise/core/primitives/cast.scala create mode 100644 src/main/scala/rise/core/primitives/circularBuffer.scala create mode 100644 src/main/scala/rise/core/primitives/concat.scala create mode 100644 src/main/scala/rise/core/primitives/depJoin.scala create mode 100644 src/main/scala/rise/core/primitives/depMapSeq.scala create mode 100644 src/main/scala/rise/core/primitives/depZip.scala create mode 100644 src/main/scala/rise/core/primitives/div.scala create mode 100644 src/main/scala/rise/core/primitives/dmatch.scala create mode 100644 src/main/scala/rise/core/primitives/drop.scala create mode 100644 src/main/scala/rise/core/primitives/equal.scala create mode 100644 src/main/scala/rise/core/primitives/foreignFunction.scala create mode 100644 src/main/scala/rise/core/primitives/fst.scala create mode 100644 src/main/scala/rise/core/primitives/gather.scala create mode 100644 src/main/scala/rise/core/primitives/generate.scala create mode 100644 src/main/scala/rise/core/primitives/gt.scala create mode 100644 src/main/scala/rise/core/primitives/id.scala create mode 100644 src/main/scala/rise/core/primitives/idx.scala create mode 100644 src/main/scala/rise/core/primitives/indexAsNat.scala create mode 100644 src/main/scala/rise/core/primitives/iterate.scala create mode 100644 src/main/scala/rise/core/primitives/iterateStream.scala create mode 100644 src/main/scala/rise/core/primitives/join.scala create mode 100644 src/main/scala/rise/core/primitives/let.scala create mode 100644 src/main/scala/rise/core/primitives/lt.scala create mode 100644 src/main/scala/rise/core/primitives/makeDepPair.scala create mode 100644 src/main/scala/rise/core/primitives/makePair.scala create mode 100644 src/main/scala/rise/core/primitives/map.scala create mode 100644 src/main/scala/rise/core/primitives/mapFst.scala create mode 100644 src/main/scala/rise/core/primitives/mapSeq.scala create mode 100644 src/main/scala/rise/core/primitives/mapSeqUnroll.scala create mode 100644 src/main/scala/rise/core/primitives/mapSnd.scala create mode 100644 src/main/scala/rise/core/primitives/mapStream.scala create mode 100644 src/main/scala/rise/core/primitives/mod.scala create mode 100644 src/main/scala/rise/core/primitives/mul.scala create mode 100644 src/main/scala/rise/core/primitives/natAsIndex.scala create mode 100644 src/main/scala/rise/core/primitives/neg.scala create mode 100644 src/main/scala/rise/core/primitives/not.scala create mode 100644 src/main/scala/rise/core/primitives/padClamp.scala create mode 100644 src/main/scala/rise/core/primitives/padCst.scala create mode 100644 src/main/scala/rise/core/primitives/padEmpty.scala create mode 100644 src/main/scala/rise/core/primitives/partition.scala create mode 100644 src/main/scala/rise/core/primitives/reduce.scala create mode 100644 src/main/scala/rise/core/primitives/reduceSeq.scala create mode 100644 src/main/scala/rise/core/primitives/reduceSeqUnroll.scala create mode 100644 src/main/scala/rise/core/primitives/reorder.scala create mode 100644 src/main/scala/rise/core/primitives/rotateValues.scala create mode 100644 src/main/scala/rise/core/primitives/scanSeq.scala create mode 100644 src/main/scala/rise/core/primitives/scatter.scala create mode 100644 src/main/scala/rise/core/primitives/select.scala create mode 100644 src/main/scala/rise/core/primitives/slide.scala create mode 100644 src/main/scala/rise/core/primitives/snd.scala create mode 100644 src/main/scala/rise/core/primitives/split.scala create mode 100644 src/main/scala/rise/core/primitives/sub.scala create mode 100644 src/main/scala/rise/core/primitives/take.scala create mode 100644 src/main/scala/rise/core/primitives/toMem.scala create mode 100644 src/main/scala/rise/core/primitives/transpose.scala create mode 100644 src/main/scala/rise/core/primitives/unzip.scala create mode 100644 src/main/scala/rise/core/primitives/vectorFromScalar.scala create mode 100644 src/main/scala/rise/core/primitives/zip.scala delete mode 100644 src/main/scala/rise/openCL/primitives/.gitignore create mode 100644 src/main/scala/rise/openCL/primitives/mapGlobal.scala create mode 100644 src/main/scala/rise/openCL/primitives/mapLocal.scala create mode 100644 src/main/scala/rise/openCL/primitives/mapWorkGroup.scala create mode 100644 src/main/scala/rise/openCL/primitives/oclCircularBuffer.scala create mode 100644 src/main/scala/rise/openCL/primitives/oclIterate.scala create mode 100644 src/main/scala/rise/openCL/primitives/oclReduceSeq.scala create mode 100644 src/main/scala/rise/openCL/primitives/oclReduceSeqUnroll.scala create mode 100644 src/main/scala/rise/openCL/primitives/oclRotateValues.scala create mode 100644 src/main/scala/rise/openCL/primitives/oclRunPrimitive.scala create mode 100644 src/main/scala/rise/openCL/primitives/oclToMem.scala delete mode 100644 src/main/scala/rise/openMP/primitives/.gitignore create mode 100644 src/main/scala/rise/openMP/primitives/mapPar.scala create mode 100644 src/main/scala/rise/openMP/primitives/reducePar.scala diff --git a/meta/src/main/scala/meta/generator/DPIAPrimitives.scala b/meta/src/main/scala/meta/generator/DPIAPrimitives.scala index 7e6f14729..387e19494 100644 --- a/meta/src/main/scala/meta/generator/DPIAPrimitives.scala +++ b/meta/src/main/scala/meta/generator/DPIAPrimitives.scala @@ -19,7 +19,7 @@ object DPIAPrimitives { case Parsed.Success(seq, _) => seq.foreach { case PrimitiveDeclaration(Identifier(originalName), scalaParams, params, returnType) - if DPIA.isWellKindedDefinition(params, returnType) => + if DPIA.isWellKindedDefinition(toParamList(definition, scalaParams), params, returnType) => val name = originalName.capitalize val outputPath = (path / os.up) / s"$name.scala" @@ -29,10 +29,7 @@ object DPIAPrimitives { val packageName = path.relativeTo(shinePath).segments.dropRight(1).foldLeft[Term.Ref](Term.Name("shine")) { case (t, name) => Term.Select(t, Term.Name(name)) } - val scalaParamsString = scalaParams match { - case Some((start, end)) => definition.substring(start, end) - case None => "" - } + val code = s"""// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // |// This file is automatically generated and should not be changed manually // |// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // @@ -45,7 +42,7 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)} +${generateCaseClass(Type.Name(name), toParamList(definition, scalaParams), params, returnType)} }""".toString()} |""".stripMargin @@ -59,7 +56,7 @@ ${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)} } def generateCaseClass(name: scala.meta.Type.Name, - scalaParamsString: String, + scalaParams: Option[List[scala.meta.Term.Param]], params: Seq[DPIA.Decl.AST.Param], returnType: DPIA.Type.AST): scala.meta.Defn.Class = { import scala.meta._ @@ -70,7 +67,7 @@ ${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)} case AST.CommType => (t"CommType", init"CommandPrimitive") case _ => throw new Exception(s"Expected `exp', `acc' or `comm' as return type for ${name.value}") } - val generatedParams = generateParams(scalaParamsString, params) + val generatedParams = generateParams(scalaParams, params) q""" final case class $name(...$generatedParams) extends $superClass { { @@ -83,26 +80,32 @@ ${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)} ${generateVisitAndRebuild(name, generatedParams)} - ..${if (scalaParamsString.nonEmpty && generatedParams.last.size > 1) { + ..${if (scalaParams.nonEmpty && generatedParams.last.size > 1) { List(generateUnwrap(generatedParams.last)) } else List() } } """ } - def generateParams(scalaParamsString: String, - params: Seq[DPIA.Decl.AST.Param]): List[List[scala.meta.Term.Param]] = { + def toParamList(definition: String, scalaParams: Option[(Int, Int)]): Option[List[scala.meta.Term.Param]] = { import scala.meta._ - - val scalaParams = if (scalaParamsString.nonEmpty) { - s"def foo($scalaParamsString)".parse[Stat].get match { - case declDef: Decl.Def => declDef.paramss + scalaParams.map { case (start, end) => + s"def foo(${definition.substring(start, end)})".parse[Stat].get match { + case declDef: Decl.Def => declDef.paramss.head } - } else { - List() } + } + + def generateParams(scalaParams: Option[List[scala.meta.Term.Param]], + params: Seq[DPIA.Decl.AST.Param]): List[List[scala.meta.Term.Param]] = { + import scala.meta._ - scalaParams ++ List(params.map(generateParam).toList) + scalaParams match { + case Some(scalaParamList) => + List(scalaParamList) ++ List(params.map(generateParam).toList) + case None => + List(params.map(generateParam).toList) + } } def generateParam(param: DPIA.Decl.AST.Param): scala.meta.Term.Param = { @@ -111,6 +114,8 @@ ${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)} param"val ${Term.Name(param.id.name)}: ${ param.ty match { case Left(kindAST) => generateType(kindAST) + case Right(meta.parser.DPIA.Type.AST.VariadicType(_, typeAST)) => + t"Seq[Phrase[${generatePhraseType(typeAST)}]]" case Right(typeAST) => t"Phrase[${generatePhraseType(typeAST)}]" } }" @@ -122,25 +127,102 @@ ${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)} ..${params. filter(param => param.ty.isRight). // only check types for parameters with phrase types map(param => - q"${Term.Name(param.id.name)} :: ${ - param.ty match { - case Right(typeAST@DPIA.Type.AST.DepFunType(id, kind, t)) => - q"""{ - ${Defn.Val( - mods = Nil, - pats = List(Pat.Var(name = Term.Name(id.name))), - decltpe = None, - rhs = q"${Term.Name(param.id.name)}.t.x" - )} - ${generateTerm(typeAST)} - }""" - case Right(typeAST) => generateTerm(typeAST) - case Left(kindAST) => throw new Exception("This should not happen") - }}" + param.ty match { + case Right(typeAST@DPIA.Type.AST.DepFunType(id, _, _)) => + q"""${Term.Name(param.id.name)} :: { + ${Defn.Val( + mods = Nil, + pats = List(Pat.Var(name = Term.Name(id.name))), + decltpe = None, + rhs = q"${Term.Name(param.id.name)}.t.x" + )} + ${generateTerm(typeAST)} + }""" + case Right(DPIA.Type.AST.VariadicType(n, typeAST)) => + getUnrolledIds(typeAST) match { + case Some(unrolledIds) => + val zips = unrolledIds.foldLeft[Term](Term.Name(param.id.name)) { + case (term, id) => q"$term.zip(${Term.Name(id.name)})" + } + val p = Term.Name(param.id.name) + val pattern = unrolledIds.foldRight[Pat](Pat.Var(p)) { + case (id, pattern) => p"($pattern, ${Pat.Var(Term.Name(id.name))})" + } + q"""$zips.foreach { + case ($pattern) => $p :: ${generateTerm(typeAST)} + }""" + case None => + q"${Term.Name(param.id.name)}.foreach(_ :: ${generateTerm(typeAST)})" + } + case Right(typeAST) => + q"${Term.Name(param.id.name)} :: ${generateTerm(typeAST)}" + case Left(_) => throw new Exception("Generation of type checks not necessary for kinds") + } ).toList} }""" } + def getUnrolledIds(typeAST: DPIA.Type.AST): Option[Seq[rise.Type.AST.UnrolledIdentifier]] = { + import DPIA.Type.AST + typeAST match { + case AST.ExpType(dataType, _) => + getUnrolledIds(dataType) + case AST.AccType(dataType) => + getUnrolledIds(dataType) + case AST.CommType => None + case AST.PairType(lhs, rhs) => + for { + ids1 <- getUnrolledIds(lhs) + ids2 <- getUnrolledIds(rhs) + } yield ids1 concat ids2 + case AST.FunType(inT, outT) => + for { + ids1 <- getUnrolledIds(inT) + ids2 <- getUnrolledIds(outT) + } yield ids1 concat ids2 + case AST.DepFunType(_, _, t) => + getUnrolledIds(t) + case AST.Identifier(_) => None + case AST.VariadicType(n, ty) => throw new Exception("This function should not be called on a variadic type") + } + } + + def getUnrolledIds(typeAST: rise.Type.AST): Option[Seq[rise.Type.AST.UnrolledIdentifier]] = { + import rise.Type.AST + typeAST match { + case AST.Identifier(_) => None + case id@AST.UnrolledIdentifier(_) => Some(Seq(id)) + case AST.FunType(inT, outT) => + for { + ids1 <- getUnrolledIds(inT) + ids2 <- getUnrolledIds(outT) + } yield ids1 concat ids2 + case AST.DepFunType(_, _, t) => + getUnrolledIds(t) + case AST.ImplicitDepFunType(_, _, t) => + getUnrolledIds(t) + case AST.VariadicFunType(_, _, _) => throw new Exception("This function should not be called on a variadic type") + case AST.VariadicDepFunType(_, _, _, _) => throw new Exception("This function should not be called on a variadic type") + case AST.ScalarType(_) => None + case AST.NatType => None + case AST.OpaqueType(_) => None + case AST.VectorType(_, elemType) => getUnrolledIds(elemType) + case AST.IndexType(_) => None + case AST.PairType(lhs, rhs) => + for { + ids1 <- getUnrolledIds(lhs) + ids2 <- getUnrolledIds(rhs) + } yield ids1 concat ids2 + case AST.DepPairType(_, _, t) => getUnrolledIds(t) + case AST.NatToDataApply(_, _) => None + case AST.NatToDataLambda(_, t) => getUnrolledIds(t) + case AST.ArrayType(_, elemType) => getUnrolledIds(elemType) + case AST.DepArrayType(_, _) => None + case AST.FragmentType(_, _, _, dt, _, _) => getUnrolledIds(dt) + case AST.ManagedBufferType(t) => getUnrolledIds(t) + } + } + def generatePhraseType(typeAST: DPIA.Type.AST): scala.meta.Type = { import scala.meta._ import meta.parser.DPIA.Type.AST @@ -152,6 +234,7 @@ ${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)} case AST.FunType(inT, outT) => t"FunType[${generatePhraseType(inT)}, ${generatePhraseType(outT)}]" case AST.DepFunType(id, kind, t) => t"DepFunType[${generateKindType(kind)}, ${generatePhraseType(t)}]" case AST.Identifier(name) => Type.Name(name) + case AST.VariadicType(_, _) => throw new Exception("Can not generate Phrase Type for Variadic Type") } } @@ -171,6 +254,7 @@ ${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)} case AST.MatrixLayout => Type.Name("MatrixLayout") } case AST.Access => Type.Name("AccessType") + case AST.VariadicKind(_, kind) => t"Seq[${generateType(kind)}]" } } @@ -186,10 +270,11 @@ ${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)} case AST.Nat2Nat => Type.Name("NatToNatKind") case AST.Nat2Data => Type.Name("NatToDataKind") case AST.Nat => Type.Name("NatKind") - case AST.Fragment => ??? - case AST.MatrixLayout => ??? + case AST.Fragment => throw new Exception("Can not generate Kind for Fragment") + case AST.MatrixLayout => throw new Exception("Can not generate Kind for Matrix Layout") } case AST.Access => Type.Name("AccessKind") + case AST.VariadicKind(_, _) => throw new Exception("Can not generate Kind for Variadic Kind") } } @@ -210,6 +295,7 @@ ${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)} case AST.DepFunType(id, kind, t) => q"DepFunType[${generateKindType(kind)}, PhraseType](${Term.Name(id.name)}, ${generateTerm(t)})" case AST.Identifier(name) => Term.Name(name) + case AST.VariadicType(_, _) => throw new Exception("Can not generate Term for Variadic Type") } } diff --git a/meta/src/main/scala/meta/generator/RisePrimitives.scala b/meta/src/main/scala/meta/generator/RisePrimitives.scala index f48f7e0a0..d85e2d7f3 100644 --- a/meta/src/main/scala/meta/generator/RisePrimitives.scala +++ b/meta/src/main/scala/meta/generator/RisePrimitives.scala @@ -3,6 +3,9 @@ package meta.generator import fastparse.{Parsed, parse} import meta.parser._ import meta.parser.rise.Kind +import meta.parser.rise.Kind.AST + +import scala.meta.Term object RisePrimitives { def main(args: Array[String]): Unit = { @@ -18,16 +21,11 @@ object RisePrimitives { case Parsed.Success(seq, _) => seq.foreach { case rise.Decl.AST.PrimitiveDeclaration(rise.Decl.AST.Identifier(name), scalaParams, typeSignature) - if rise.isWellKindedType(typeSignature) => + if rise.isWellKindedType(toParamList(definition, scalaParams), typeSignature) => val outputPath = (path / os.up) / s"$name.scala" println(s"Generate $outputPath") - val generatedDef = scalaParams match { - case None => - generateObject(name, typeSignature) - case Some((start, end)) => - generateCaseClass(name, definition.substring(start, end), typeSignature) - } + val generatedDef = generate(name, toParamList(definition, scalaParams), typeSignature) import scala.meta._ val packageName = path.relativeTo(risePath).segments.dropRight(1).foldLeft[Term.Ref](Term.Name("rise")) { @@ -58,6 +56,22 @@ import arithexpr.arithmetic._ }) } + def toParamList(definition: String, scalaParams: Option[(Int, Int)]): Option[List[Term.Param]] = { + import scala.meta._ + scalaParams.map { case (start, end) => + s"def foo(${definition.substring(start, end)})".parse[Stat].get match { + case declDef: Decl.Def => declDef.paramss.head + } + } + } + + def generate(name: String, params: Option[List[Term.Param]], typeSignature: rise.Type.AST): scala.meta.Term.Block = + params match { + case None => generateObject(name, typeSignature) + case Some(params) => generateCaseClass(name, params, typeSignature) + } + + def generateObject(name: String, typeSignature: rise.Type.AST): scala.meta.Term.Block = { import scala.meta._ val generated = q"""{ @@ -83,12 +97,9 @@ import arithexpr.arithmetic._ generated } - def generateCaseClass(name: String, paramsString: String, typeSignature: rise.Type.AST): scala.meta.Term.Block = { + def generateCaseClass(name: String, params: List[Term.Param], typeSignature: rise.Type.AST): scala.meta.Term.Block = { import scala.meta._ - val params = s"def foo($paramsString)".parse[Stat].get match { - case declDef: Decl.Def => declDef.paramss.head - } val args: List[Term.Name] = params.map(p => Term.Name(p.name.value)) val types: List[Type] = params.map(p => p.decltpe.get) @@ -119,7 +130,11 @@ import arithexpr.arithmetic._ } } - def unapply(arg: rise.core.Expr): Option[..$types] = arg match { + def unapply(arg: rise.core.Expr): ${if (types.length > 1) { + t"Option[(..$types)]" + } else { + t"Option[..$types]"} + } = arg match { case p: Primitive => Some(..${generateMemberAccesses(args)}) case _ => None @@ -138,6 +153,30 @@ import arithexpr.arithmetic._ q"expl((${Term.Name(id.name)}: ${Type.Name(kindName(kind))}) => ${generateTypeScheme(t)})" case rise.Type.AST.ImplicitDepFunType(id, kind, t) => q"impl((${Term.Name(id.name)}: ${Type.Name(kindName(kind))}) => ${generateTypeScheme(t)})" + case rise.Type.AST.VariadicFunType(_, rise.Type.AST.UnrolledIdentifier(inTs), outT) => + q"""${Term.Name(inTs)}.foldRight(${generateTypeScheme(outT)}: Type) { + case (lhsT, rhsT) => lhsT ->: rhsT + }""" + case rise.Type.AST.VariadicFunType(rise.Type.AST.Identifier(n), inTs, outT) => + q"""Seq.fill(${Term.Name(n)})(${generateDataType(inTs)}).foldRight(${generateTypeScheme(outT)}: Type) { + case (lhsT, rhsT) => lhsT ->: rhsT + }""" + case rise.Type.AST.VariadicDepFunType(n, id, kind, t) => + val (createIds, typeName) = kind match { + case AST.Data => (q"""DataTypeIdentifier(freshName("dt"), isExplicit = true)""", Type.Name("DataKind")) + case AST.Address => ??? + case AST.Nat2Nat => ??? + case AST.Nat2Data => ??? + case AST.Nat => ??? + case AST.Fragment => ??? + case AST.MatrixLayout => ??? + } + q"""{ + val ${Pat.Var(Term.Name(id.name))} = Seq.fill(${Term.Name(n.name)})($createIds) + ${Term.Name(id.name)}.foldRight(${generateTypeScheme(t)}: Type) { + case (id, t) => DepFunType[$typeName, Type](id, t) + } + }""" case _ => generateDataType(typeAST) } } @@ -147,6 +186,8 @@ import arithexpr.arithmetic._ typeAST match { case rise.Type.AST.Identifier(name) => Term.Name(name) + case rise.Type.AST.UnrolledIdentifier(name) => + Term.Name(name) case rise.Type.AST.ScalarType(t) => t.parse[Term].get case rise.Type.AST.NatType => @@ -177,7 +218,8 @@ import arithexpr.arithmetic._ case rise.Type.AST.ManagedBufferType(dt) => q"ManagedBufferType(${generateDataType(dt)})" case rise.Type.AST.FunType(_, _) | rise.Type.AST.DepFunType(_, _, _) | - rise.Type.AST.ImplicitDepFunType(_, _, _) => ??? + rise.Type.AST.ImplicitDepFunType(_, _, _) | rise.Type.AST.VariadicFunType(_, _, _) | + rise.Type.AST.VariadicDepFunType(_, _, _, _) => ??? } } diff --git a/meta/src/main/scala/meta/parser/DPIA/Kind.scala b/meta/src/main/scala/meta/parser/DPIA/Kind.scala index 6e3c0544e..2b4b8dc9e 100644 --- a/meta/src/main/scala/meta/parser/DPIA/Kind.scala +++ b/meta/src/main/scala/meta/parser/DPIA/Kind.scala @@ -1,17 +1,21 @@ package meta.parser.DPIA +import fastparse.ScalaWhitespace._ import fastparse._ import meta.parser.rise +import meta.parser.shared.Identifier object Kind { sealed trait AST object AST { case class RiseKind(riseKind: rise.Kind.AST) extends AST case object Access extends AST + case class VariadicKind(n: String, kind: AST) extends AST } - def Kind[_: P]: P[AST] = P( - rise.Kind.Kind.map(AST.RiseKind) | - "access".!.map(_ => AST.Access) - ) + def Kind[_: P]: P[AST] = { + def OnlyKind: P[AST] = P(rise.Kind.Kind.map(AST.RiseKind) | "access".!.map(_ => AST.Access)) + + OnlyKind | (Identifier ~ "*" ~ OnlyKind).map(AST.VariadicKind.tupled) + } } diff --git a/meta/src/main/scala/meta/parser/DPIA/Type.scala b/meta/src/main/scala/meta/parser/DPIA/Type.scala index c998ed5d4..a72e9448c 100644 --- a/meta/src/main/scala/meta/parser/DPIA/Type.scala +++ b/meta/src/main/scala/meta/parser/DPIA/Type.scala @@ -16,6 +16,8 @@ object Type { case class FunType(inT: AST, outT: AST) extends AST case class DepFunType(id: Identifier, kind: Kind.AST, t: AST) extends AST case class Identifier(name: String) extends AST + + case class VariadicType(n: Identifier, ty: AST) extends AST } object Access { @@ -60,6 +62,10 @@ object Type { def NonFunPhraseType: P[AST] = P( ExpType | AccType | VarType | CommType | PairType | DepFunType ) - P( FunType | NonFunPhraseType ) + def OnlyPhraseType: P[AST] = P( FunType | NonFunPhraseType ) + + def VariadicType: P[AST] = P(Identifier.map(AST.Identifier) ~ "*" ~ OnlyPhraseType).map(AST.VariadicType.tupled) + + OnlyPhraseType | VariadicType } } diff --git a/meta/src/main/scala/meta/parser/DPIA/isWellKindedDefinition.scala b/meta/src/main/scala/meta/parser/DPIA/isWellKindedDefinition.scala index 297978ea2..ca38195bb 100644 --- a/meta/src/main/scala/meta/parser/DPIA/isWellKindedDefinition.scala +++ b/meta/src/main/scala/meta/parser/DPIA/isWellKindedDefinition.scala @@ -4,51 +4,67 @@ import meta.parser._ object isWellKindedDefinition { - def apply(params: Seq[Decl.AST.Param], returnType: Type.AST): Boolean = { + def apply(scalaParams: Option[List[scala.meta.Term.Param]], + params: Seq[Decl.AST.Param], + returnType: Type.AST): Boolean = { import Decl.AST._ var isWellKindedFlag = true - val env = params.foldLeft(Map.empty[String, Kind.AST]) { + val incompleteEnv = scalaParams.map(_.foldLeft(Map.empty[String, Kind.AST]) { + case (env, param) => param.decltpe match { + case Some(scala.meta.Type.Name("Int")) => + env.updated(param.name.value,Kind.AST.RiseKind(rise.Kind.AST.Nat)) + case _ => env + } + }).getOrElse(Map.empty[String, Kind.AST]) + val env = params.foldLeft(incompleteEnv) { case (env, Param(Identifier(name), Left(kind))) => env.updated(name, kind) case (env, Param(Identifier(_), Right(typeAST))) => - if (!isWellKinded(typeAST, env)) isWellKindedFlag = false + if (!isWellKinded(typeAST, env, None)) + isWellKindedFlag = false env } - isWellKindedFlag && isWellKinded(returnType, env) + isWellKindedFlag && isWellKinded(returnType, env, None) } - def isWellKinded(typeAST: Type.AST, env: Map[String, Kind.AST]): Boolean = { + def isWellKinded(typeAST: Type.AST, + env: Map[String, Kind.AST], + variadicN: Option[String]): Boolean = { import Type._ import rise.isWellKindedType._ typeAST match { case AST.ExpType(dataType, access) => val nenv = env.flatMap { case (string, DPIA.Kind.AST.RiseKind(riseKind)) => - Some((string, riseKind)) + Some((string, rise.isWellKindedType.DataTypeKind(riseKind))) + case (string, DPIA.Kind.AST.VariadicKind(n, DPIA.Kind.AST.RiseKind(riseKind))) => + Some((string, rise.isWellKindedType.VariadicKind(n, riseKind))) case _ => None } - kindOf(dataType, nenv).isDefined && isWellKinded(access, env) + kindOf(dataType, nenv, variadicN).isDefined && isWellKinded(access, env) case AST.AccType(dataType) => val nenv = env.flatMap { case (string, DPIA.Kind.AST.RiseKind(riseKind)) => - Some((string, riseKind)) + Some((string, rise.isWellKindedType.DataTypeKind(riseKind))) case _ => None } - kindOf(dataType, nenv).isDefined + kindOf(dataType, nenv, variadicN).isDefined case AST.CommType => true case AST.PairType(lhs, rhs) => - isWellKinded(lhs, env) && isWellKinded(rhs, env) + isWellKinded(lhs, env, variadicN) && isWellKinded(rhs, env, variadicN) case AST.FunType(inT, outT) => - isWellKinded(inT, env) && isWellKinded(outT, env) + isWellKinded(inT, env, variadicN) && isWellKinded(outT, env, variadicN) case AST.DepFunType(id, kind, t) => if (env.isDefinedAt(id.name)) { // we forbid shadowing false } else { - isWellKinded(t, env.updated(id.name, kind)) + isWellKinded(t, env.updated(id.name, kind), variadicN) } case AST.Identifier(name) => env.contains(name) + case AST.VariadicType(n, ty) => + isWellKinded(ty, env, Some(n.name)) } } diff --git a/meta/src/main/scala/meta/parser/rise/Type.scala b/meta/src/main/scala/meta/parser/rise/Type.scala index 6cd066db8..8172fba63 100644 --- a/meta/src/main/scala/meta/parser/rise/Type.scala +++ b/meta/src/main/scala/meta/parser/rise/Type.scala @@ -9,9 +9,13 @@ object Type { sealed trait AST object AST { case class Identifier(name: String) extends AST + case class UnrolledIdentifier(name: String) extends AST + case class FunType(inT: AST, outT: AST) extends AST case class DepFunType(id: Identifier, kind: Kind.AST, t: AST) extends AST case class ImplicitDepFunType(id: Identifier, kind: Kind.AST, t: AST) extends AST + case class VariadicFunType(n: Identifier, inTs: AST, outT: AST) extends AST + case class VariadicDepFunType(n: Identifier, id: Identifier, kind: Kind.AST, t: AST) extends AST case class ScalarType(t: String) extends AST case object NatType extends AST @@ -53,6 +57,11 @@ object Type { def DepFunType: P[AST.DepFunType] = P("(" ~ IdentifierKindPair ~ ")" ~ "->" ~/ TypeSignature).map(AST.DepFunType.tupled) + def VariadicDepFunType: P[AST.VariadicDepFunType] = + P(Identifier.map(AST.Identifier) ~ "*" ~ "(" ~ "(" ~ IdentifierKindPair ~ ")" ~ "->" ~ ")" ~/ TypeSignature).map { + case (n, (id, kind), t) => AST.VariadicDepFunType(n, id, kind, t) + } + def ImplicitDepFunType: P[AST.ImplicitDepFunType] = P("{" ~ IdentifierKindPair ~ "}" ~ "->" ~/ TypeSignature). map(AST.ImplicitDepFunType.tupled) @@ -60,10 +69,14 @@ object Type { def FunType: P[AST.FunType] = P(NoCut(LeftTypeSignature) ~ "->" ~/ TypeSignature).map(AST.FunType.tupled) + def VariadicFunType: P[AST.VariadicFunType] = + P(Identifier.map(AST.Identifier) ~ "*" ~ "(" ~ NoCut(LeftTypeSignature) ~ "->" ~ ")" ~/ TypeSignature).map( + AST.VariadicFunType.tupled) + // Types that can appear at the left of an function arrow def LeftTypeSignature: P[AST] = P(DataType.DataType | ("(" ~ TypeSignature ~ ")")) - P(DepFunType | ImplicitDepFunType | FunType | LeftTypeSignature) + P(DepFunType | VariadicDepFunType | ImplicitDepFunType | FunType | VariadicFunType | LeftTypeSignature) } def TypeIdentifier[_: P]: P[AST.Identifier] = P(Identifier).map(AST.Identifier) @@ -71,6 +84,8 @@ object Type { def IdentifierKindPair[_: P]: P[(AST.Identifier, Kind.AST)] = P(Identifier.map(AST.Identifier) ~ ":" ~ Kind.Kind) + def UnrolledTypeIdentifier[_: P]: P[AST.UnrolledIdentifier] = P("*" ~ Identifier).map(AST.UnrolledIdentifier) + object DataType { def ScalarType[_: P]: P[AST.ScalarType] = P("bool".! | "int".! | @@ -127,7 +142,7 @@ object Type { def DataType[_: P]: P[AST] = P(ScalarType | NatType | OpaqueType | IndexType | VectorType | FragmentType | ManagedBufferType | DepArrayType | ArrayType | DepPairType | NatToDataApply | - PairType | TypeIdentifier | ("(" ~ DataType ~ ")")) + PairType | UnrolledTypeIdentifier | TypeIdentifier | ("(" ~ DataType ~ ")")) def TypeName[_: P]: P[Unit] = P(ScalarType | NatType | "idx" | "vec" | "fragment" | "matrixLayout") diff --git a/meta/src/main/scala/meta/parser/rise/isWellKindedType.scala b/meta/src/main/scala/meta/parser/rise/isWellKindedType.scala index afb868305..351a2865b 100644 --- a/meta/src/main/scala/meta/parser/rise/isWellKindedType.scala +++ b/meta/src/main/scala/meta/parser/rise/isWellKindedType.scala @@ -1,34 +1,57 @@ package meta.parser.rise import meta.parser._ +import meta.parser.rise.Kind.{AST => KindAST} object isWellKindedType { - def apply(typeAST: Type.AST): Boolean = { - kindOf(typeAST, Map.empty).isDefined + def apply(scalaParams: Option[List[scala.meta.Term.Param]], + typeAST: Type.AST): Boolean = { + import scala.meta._ + val env = scalaParams.getOrElse(List.empty).foldLeft[Map[String, Kind]](Map.empty){ + case (env, param) => param.decltpe match { + case Some(t"Int") => env.updated(param.name.value, DataTypeKind(KindAST.Nat)) + case _ => env + } + } + kindOf(typeAST, env, None).isDefined } - sealed trait DataTypeOrFunctionKind - case class DataTypeKind(kind: Kind.AST) extends DataTypeOrFunctionKind - case object FunctionKind extends DataTypeOrFunctionKind + sealed trait Kind + case class DataTypeKind(kind: KindAST) extends Kind + case object FunctionKind extends Kind + case class VariadicKind(n: String, kind: KindAST) extends Kind def kindOf(typeAST: Type.AST, - env: Map[String, Kind.AST]): Option[DataTypeOrFunctionKind] = { + env: Map[String, Kind], + variadicN: Option[String]): Option[Kind] = { import Type._ typeAST match { case AST.Identifier(name) => - env.get(name).map(DataTypeKind) + env.get(name) + case AST.UnrolledIdentifier(name) => + env.get(name).flatMap { + case VariadicKind(n, kind) => + variadicN.flatMap(m => + if (n == m) { + Some(VariadicKind(n, kind)) + } else { + None + }) + case DataTypeKind(_) => None + case FunctionKind => None + } case AST.FunType(inT, outT) => for { - _ <- kindOf(inT, env) - _ <- kindOf(outT, env) + _ <- kindOf(inT, env, variadicN) + _ <- kindOf(outT, env, variadicN) } yield FunctionKind case AST.DepFunType(id, kind, t) => if (env.isDefinedAt(id.name)) { None // we forbid shadowing } else { for { - _ <- kindOf(t, env.updated(id.name, kind)) + _ <- kindOf(t, env.updated(id.name, DataTypeKind(kind)), variadicN) } yield FunctionKind } case AST.ImplicitDepFunType(id, kind, t) => @@ -36,141 +59,162 @@ object isWellKindedType { None // we forbid shadowing } else { for { - _ <- kindOf(t, env.updated(id.name, kind)) + _ <- kindOf(t, env.updated(id.name, DataTypeKind(kind)), variadicN) + } yield FunctionKind + } + case AST.VariadicFunType(n, inTs, outT) => + for { + k1 <- kindOf(n, env, None) + _ <- kindOf(inTs, env, Some(n.name)) + _ <- kindOf(outT, env, variadicN) + if k1 == DataTypeKind(KindAST.Nat) + } yield FunctionKind + case AST.VariadicDepFunType(n, id, kind, t) => + if (env.isDefinedAt(id.name)) { + None // we forbid shadowing + } else { + for { + k1 <- kindOf(n, env, None) + _ <- kindOf(t, env.updated(id.name, VariadicKind(n.name, kind)), variadicN) + if k1 == DataTypeKind(KindAST.Nat) } yield FunctionKind } case AST.VectorType(size, elemType) => for { - k1 <- kindOf(size, env) - k2 <- kindOf(elemType, env) - if k1 == Kind.AST.Nat && k2 == DataTypeKind(Kind.AST.Data) - } yield DataTypeKind(Kind.AST.Data) + k1 <- kindOf(size, env, variadicN) + k2 <- kindOf(elemType, env, variadicN) + if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Data) + } yield DataTypeKind(KindAST.Data) case AST.IndexType(size) => for { - k <- kindOf(size, env) - if k == Kind.AST.Nat - } yield DataTypeKind(Kind.AST.Data) + k <- kindOf(size, env, variadicN) + if k == DataTypeKind(KindAST.Nat) + } yield DataTypeKind(KindAST.Data) case AST.PairType(lhs, rhs) => for { - k1 <- kindOf(lhs, env) - k2 <- kindOf(rhs, env) - if k1 == DataTypeKind(Kind.AST.Data) && k2 == DataTypeKind(Kind.AST.Data) - } yield DataTypeKind(Kind.AST.Data) + k1 <- kindOf(lhs, env, variadicN) + k2 <- kindOf(rhs, env, variadicN) + if k1 == DataTypeKind(KindAST.Data) && k2 == DataTypeKind(KindAST.Data) + } yield DataTypeKind(KindAST.Data) case AST.DepPairType(id, kind, t) => if (env.isDefinedAt(id.name)) { None // we forbid shadowing } else { - kindOf(t, env.updated(id.name, kind)) + kindOf(t, env.updated(id.name, DataTypeKind(kind)), variadicN) } case AST.NatToDataApply(f, n) => for { - k1 <- kindOf(f, env) - k2 <- kindOf(n, env) - if k1 == DataTypeKind(Kind.AST.Nat2Data) && k2 == Kind.AST.Nat - } yield DataTypeKind(Kind.AST.Data) + k1 <- kindOf(f, env, variadicN) + k2 <- kindOf(n, env, variadicN) + if k1 == DataTypeKind(KindAST.Nat2Data) && k2 == DataTypeKind(KindAST.Nat) + } yield DataTypeKind(KindAST.Data) case AST.NatToDataLambda(id, t) => if (env.isDefinedAt(id.name)) { None // we forbid shadowing } else { for { - k <- kindOf(t, env.updated(id.name, Kind.AST.Nat)) - if k == DataTypeKind(Kind.AST.Data) - } yield DataTypeKind(Kind.AST.Nat2Data) + k <- kindOf(t, env.updated(id.name, DataTypeKind(KindAST.Nat)), variadicN) + if k == DataTypeKind(KindAST.Data) + } yield DataTypeKind(KindAST.Nat2Data) } case AST.ArrayType(size, elemType) => for { - k1 <- kindOf(size, env) - k2 <- kindOf(elemType, env) - if k1 == Kind.AST.Nat && k2 == DataTypeKind(Kind.AST.Data) - } yield DataTypeKind(Kind.AST.Data) + k1 <- kindOf(size, env, variadicN) + k2 <- kindOf(elemType, env, variadicN) + if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Data) + } yield DataTypeKind(KindAST.Data) case AST.DepArrayType(size, fdt) => for { - k1 <- kindOf(size, env) - k2 <- kindOf(fdt, env) - if k1 == Kind.AST.Nat && k2 == DataTypeKind(Kind.AST.Nat2Data) - } yield DataTypeKind(Kind.AST.Data) + k1 <- kindOf(size, env, variadicN) + k2 <- kindOf(fdt, env, variadicN) + if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Nat2Data) + } yield DataTypeKind(KindAST.Data) case AST.FragmentType(n, m, k, elemType, fKind, mLayout) => for { - k1 <- kindOf(n, env) - k2 <- kindOf(m, env) - k3 <- kindOf(k, env) - k4 <- kindOf(elemType, env) - k5 <- kindOf(fKind, env) - k6 <- kindOf(mLayout, env) - if k1 == Kind.AST.Nat && k2 == Kind.AST.Nat && k3 == Kind.AST.Nat && k4 == DataTypeKind(Kind.AST.Data) && - k5 == Kind.AST.Fragment && k6 == Kind.AST.MatrixLayout - } yield DataTypeKind(Kind.AST.Data) + k1 <- kindOf(n, env, variadicN) + k2 <- kindOf(m, env, variadicN) + k3 <- kindOf(k, env, variadicN) + k4 <- kindOf(elemType, env, variadicN) + k5 <- kindOf(fKind, env, variadicN) + k6 <- kindOf(mLayout, env, variadicN) + if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Nat) && + k3 == DataTypeKind(KindAST.Nat) && k4 == DataTypeKind(KindAST.Data) && + k5 == DataTypeKind(KindAST.Fragment) && k6 == DataTypeKind(KindAST.MatrixLayout) + } yield DataTypeKind(KindAST.Data) case AST.ManagedBufferType(dt) => for { - k1 <- kindOf(dt, env) - if k1 == DataTypeKind(Kind.AST.Data) - } yield DataTypeKind(Kind.AST.Data) + k1 <- kindOf(dt, env, variadicN) + if k1 == DataTypeKind(KindAST.Data) + } yield DataTypeKind(KindAST.Data) case _: AST.ScalarType | AST.NatType | _: AST.OpaqueType => - Some(DataTypeKind(Kind.AST.Data)) + Some(DataTypeKind(KindAST.Data)) } } def kindOf(natAST: Nat.AST, - env: Map[String, Kind.AST] - ): Option[Kind.AST] = { + env: Map[String, Kind], + variadicN: Option[String] + ): Option[Kind] = { natAST match { case Nat.AST.Identifier(id) => env.get(id) case Nat.AST.Number(_) => - Some(Kind.AST.Nat) + Some(DataTypeKind(KindAST.Nat)) case Nat.AST.BinaryOp(lhs, _, rhs) => for { - k1 <- kindOf(lhs, env) - k2 <- kindOf(rhs, env) - if k1 == Kind.AST.Nat && k2 == Kind.AST.Nat - } yield Kind.AST.Nat + k1 <- kindOf(lhs, env, variadicN) + k2 <- kindOf(rhs, env, variadicN) + if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Nat) + } yield DataTypeKind(KindAST.Nat) case Nat.AST.TernaryOp(_, thenN, elseN) => for { - k1 <- kindOf(thenN, env) - k2 <- kindOf(elseN, env) - if k1 == Kind.AST.Nat && k2 == Kind.AST.Nat - } yield Kind.AST.Nat + k1 <- kindOf(thenN, env, variadicN) + k2 <- kindOf(elseN, env, variadicN) + if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Nat) + } yield DataTypeKind(KindAST.Nat) case Nat.AST.Nat2NatApply(f, n) => for { - k1 <- kindOf(f, env) - k2 <- kindOf(n, env) - if k1 == DataTypeKind(Kind.AST.Nat2Nat) && k2 == Kind.AST.Nat - } yield Kind.AST.Nat + k1 <- kindOf(f, env, variadicN) + k2 <- kindOf(n, env, variadicN) + if k1 == DataTypeKind(KindAST.Nat2Nat) && k2 == DataTypeKind(KindAST.Nat) + } yield DataTypeKind(KindAST.Nat) case Nat.AST.Sum(id, from, upTo, body) => - val nEnv = env.updated(id.name, Kind.AST.Nat) + val nEnv = env.updated(id.name, DataTypeKind(KindAST.Nat)) for { - k1 <- kindOf(from, nEnv) - k2 <- kindOf(upTo, nEnv) - k3 <- kindOf(body, nEnv) - if k1 == Kind.AST.Nat && k2 == Kind.AST.Nat && k3 == Kind.AST.Nat - } yield Kind.AST.Nat + k1 <- kindOf(from, nEnv, variadicN) + k2 <- kindOf(upTo, nEnv, variadicN) + k3 <- kindOf(body, nEnv, variadicN) + if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Nat) && k3 == DataTypeKind(KindAST.Nat) + } yield DataTypeKind(KindAST.Nat) } } def kindOf(fragmentAST: Type.Fragment.AST, - env: Map[String, Kind.AST] - ): Option[Kind.AST] = { + env: Map[String, Kind], + variadicN: Option[String] + ): Option[Kind] = { import Type._ fragmentAST match { case Fragment.AST.Identifier(id) => env.get(id) - case Fragment.AST.ACC | Fragment.AST.A | Fragment.AST.B => Some(Kind.AST.Fragment) + case Fragment.AST.ACC | Fragment.AST.A | Fragment.AST.B => Some(DataTypeKind(KindAST.Fragment)) } } def kindOf(matrixLayout: Type.MatrixLayout.AST, - env: Map[String, Kind.AST] - ): Option[Kind.AST] = { + env: Map[String, Kind], + variadicN: Option[String] + ): Option[Kind] = { import Type._ matrixLayout match { case MatrixLayout.AST.Identifier(id) => env.get(id) case MatrixLayout.AST.ROW_MAJOR | MatrixLayout.AST.COL_MAJOR | - MatrixLayout.AST.NONE => Some(Kind.AST.MatrixLayout) + MatrixLayout.AST.NONE => Some(DataTypeKind(KindAST.MatrixLayout)) } } } diff --git a/src/main/scala/rise/Cuda/primitives/.gitignore b/src/main/scala/rise/Cuda/primitives/.gitignore deleted file mode 100644 index 2d916b240..000000000 --- a/src/main/scala/rise/Cuda/primitives/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*.scala \ No newline at end of file diff --git a/src/main/scala/rise/Cuda/primitives/asFragment.scala b/src/main/scala/rise/Cuda/primitives/asFragment.scala new file mode 100644 index 000000000..9805b1d8a --- /dev/null +++ b/src/main/scala/rise/Cuda/primitives/asFragment.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.Cuda.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object asFragment extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "asFragment" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (m: Nat) => impl { (k: Nat) => impl { (s: DataType) => impl { (f: FragmentKind) => impl { (l: MatrixLayout) => ArrayType(n, ArrayType(m, s)) ->: FragmentType(n, m, k, s, f, l) } } } } } } + } + override def toString: String = "asFragment" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/Cuda/primitives/asMatrix.scala b/src/main/scala/rise/Cuda/primitives/asMatrix.scala new file mode 100644 index 000000000..f54d15b2d --- /dev/null +++ b/src/main/scala/rise/Cuda/primitives/asMatrix.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.Cuda.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object asMatrix extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "asMatrix" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (m: Nat) => impl { (k: Nat) => impl { (s: DataType) => FragmentType(n, m, k, s, FragmentKind.Accumulator, MatrixLayout.Row_Major) ->: ArrayType(n, ArrayType(m, s)) } } } } + } + override def toString: String = "asMatrix" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/Cuda/primitives/generateFragment.scala b/src/main/scala/rise/Cuda/primitives/generateFragment.scala new file mode 100644 index 000000000..9aae2c6cd --- /dev/null +++ b/src/main/scala/rise/Cuda/primitives/generateFragment.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.Cuda.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object generateFragment extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "generateFragment" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (m: Nat) => impl { (k: Nat) => impl { (s: DataType) => impl { (f: FragmentKind) => s ->: FragmentType(n, m, k, s, f, MatrixLayout.Row_Major) } } } } } + } + override def toString: String = "generateFragment" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/Cuda/primitives/globalToShared.scala b/src/main/scala/rise/Cuda/primitives/globalToShared.scala new file mode 100644 index 000000000..8fd0bfe26 --- /dev/null +++ b/src/main/scala/rise/Cuda/primitives/globalToShared.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.Cuda.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object globalToShared extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "globalToShared" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (t: DataType) => t ->: t } + } + override def toString: String = "globalToShared" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/Cuda/primitives/mapBlock.scala b/src/main/scala/rise/Cuda/primitives/mapBlock.scala new file mode 100644 index 000000000..9769c8ff8 --- /dev/null +++ b/src/main/scala/rise/Cuda/primitives/mapBlock.scala @@ -0,0 +1,37 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.Cuda.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +final case class mapBlock(dim: Int) extends Builder { + override def toString: String = "mapBlock" + override def primitive: rise.core.Primitive = mapBlock.Primitive(dim)() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(mapBlock.Primitive(dim)()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} +object mapBlock { + private final case class Primitive(dim: Int)(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "mapBlock" + override def setType(ty: Type): Primitive = Primitive(dim)(ty) + override def typeScheme: Type = impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t) ->: ArrayType(n, s) ->: ArrayType(n, t) } } } + override def primEq(obj: rise.core.Primitive): Boolean = obj match { + case p: Primitive => + p.dim == dim && true + case _ => + false + } + } + def unapply(arg: rise.core.Expr): Option[Int] = arg match { + case p: Primitive => + Some(p.dim) + case _ => + None + } +} diff --git a/src/main/scala/rise/Cuda/primitives/mapFragment.scala b/src/main/scala/rise/Cuda/primitives/mapFragment.scala new file mode 100644 index 000000000..2b12f1a7a --- /dev/null +++ b/src/main/scala/rise/Cuda/primitives/mapFragment.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.Cuda.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object mapFragment extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "mapFragment" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (m: Nat) => impl { (n: Nat) => impl { (k: Nat) => impl { (s: DataType) => impl { (f: FragmentKind) => impl { (l: MatrixLayout) => (s ->: s) ->: FragmentType(n, m, k, s, f, l) ->: FragmentType(n, m, k, s, f, l) } } } } } } + } + override def toString: String = "mapFragment" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/Cuda/primitives/mapGlobal.scala b/src/main/scala/rise/Cuda/primitives/mapGlobal.scala new file mode 100644 index 000000000..cef517f5c --- /dev/null +++ b/src/main/scala/rise/Cuda/primitives/mapGlobal.scala @@ -0,0 +1,37 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.Cuda.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +final case class mapGlobal(dim: Int) extends Builder { + override def toString: String = "mapGlobal" + override def primitive: rise.core.Primitive = mapGlobal.Primitive(dim)() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(mapGlobal.Primitive(dim)()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} +object mapGlobal { + private final case class Primitive(dim: Int)(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "mapGlobal" + override def setType(ty: Type): Primitive = Primitive(dim)(ty) + override def typeScheme: Type = impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t) ->: ArrayType(n, s) ->: ArrayType(n, t) } } } + override def primEq(obj: rise.core.Primitive): Boolean = obj match { + case p: Primitive => + p.dim == dim && true + case _ => + false + } + } + def unapply(arg: rise.core.Expr): Option[Int] = arg match { + case p: Primitive => + Some(p.dim) + case _ => + None + } +} diff --git a/src/main/scala/rise/Cuda/primitives/mapLane.scala b/src/main/scala/rise/Cuda/primitives/mapLane.scala new file mode 100644 index 000000000..da44ffad8 --- /dev/null +++ b/src/main/scala/rise/Cuda/primitives/mapLane.scala @@ -0,0 +1,37 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.Cuda.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +final case class mapLane(dim: Int) extends Builder { + override def toString: String = "mapLane" + override def primitive: rise.core.Primitive = mapLane.Primitive(dim)() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(mapLane.Primitive(dim)()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} +object mapLane { + private final case class Primitive(dim: Int)(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "mapLane" + override def setType(ty: Type): Primitive = Primitive(dim)(ty) + override def typeScheme: Type = impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t) ->: ArrayType(n, s) ->: ArrayType(n, t) } } } + override def primEq(obj: rise.core.Primitive): Boolean = obj match { + case p: Primitive => + p.dim == dim && true + case _ => + false + } + } + def unapply(arg: rise.core.Expr): Option[Int] = arg match { + case p: Primitive => + Some(p.dim) + case _ => + None + } +} diff --git a/src/main/scala/rise/Cuda/primitives/mapThreads.scala b/src/main/scala/rise/Cuda/primitives/mapThreads.scala new file mode 100644 index 000000000..0bd566d7c --- /dev/null +++ b/src/main/scala/rise/Cuda/primitives/mapThreads.scala @@ -0,0 +1,37 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.Cuda.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +final case class mapThreads(dim: Int) extends Builder { + override def toString: String = "mapThreads" + override def primitive: rise.core.Primitive = mapThreads.Primitive(dim)() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(mapThreads.Primitive(dim)()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} +object mapThreads { + private final case class Primitive(dim: Int)(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "mapThreads" + override def setType(ty: Type): Primitive = Primitive(dim)(ty) + override def typeScheme: Type = impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t) ->: ArrayType(n, s) ->: ArrayType(n, t) } } } + override def primEq(obj: rise.core.Primitive): Boolean = obj match { + case p: Primitive => + p.dim == dim && true + case _ => + false + } + } + def unapply(arg: rise.core.Expr): Option[Int] = arg match { + case p: Primitive => + Some(p.dim) + case _ => + None + } +} diff --git a/src/main/scala/rise/Cuda/primitives/mapWarp.scala b/src/main/scala/rise/Cuda/primitives/mapWarp.scala new file mode 100644 index 000000000..b6cecc046 --- /dev/null +++ b/src/main/scala/rise/Cuda/primitives/mapWarp.scala @@ -0,0 +1,37 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.Cuda.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +final case class mapWarp(dim: Int) extends Builder { + override def toString: String = "mapWarp" + override def primitive: rise.core.Primitive = mapWarp.Primitive(dim)() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(mapWarp.Primitive(dim)()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} +object mapWarp { + private final case class Primitive(dim: Int)(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "mapWarp" + override def setType(ty: Type): Primitive = Primitive(dim)(ty) + override def typeScheme: Type = impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t) ->: ArrayType(n, s) ->: ArrayType(n, t) } } } + override def primEq(obj: rise.core.Primitive): Boolean = obj match { + case p: Primitive => + p.dim == dim && true + case _ => + false + } + } + def unapply(arg: rise.core.Expr): Option[Int] = arg match { + case p: Primitive => + Some(p.dim) + case _ => + None + } +} diff --git a/src/main/scala/rise/Cuda/primitives/tensorMMA.scala b/src/main/scala/rise/Cuda/primitives/tensorMMA.scala new file mode 100644 index 000000000..1c9ea6e62 --- /dev/null +++ b/src/main/scala/rise/Cuda/primitives/tensorMMA.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.Cuda.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object tensorMMA extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "tensorMMA" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (la: MatrixLayout) => impl { (lb: MatrixLayout) => impl { (m: Nat) => impl { (n: Nat) => impl { (k: Nat) => impl { (s: DataType) => impl { (t: DataType) => FragmentType(m, k, n, s, FragmentKind.AMatrix, la) ->: FragmentType(k, n, m, s, FragmentKind.BMatrix, lb) ->: FragmentType(m, n, k, t, FragmentKind.Accumulator, MatrixLayout.Row_Major) ->: FragmentType(m, n, k, t, FragmentKind.Accumulator, MatrixLayout.Row_Major) } } } } } } } + } + override def toString: String = "tensorMMA" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/DSL/package.scala b/src/main/scala/rise/core/DSL/package.scala index 370be51f0..b629b4c84 100644 --- a/src/main/scala/rise/core/DSL/package.scala +++ b/src/main/scala/rise/core/DSL/package.scala @@ -564,21 +564,33 @@ package object DSL { def lu8(v: Int): ToBeTyped[Expr] = cast(l(v)) :: u8 object foreignFun { - def apply(name: String, t: Type): ToBeTyped[ForeignFunction] = { - toBeTyped(ForeignFunction(ForeignFunction.Decl(name, None))(t)) + def apply(name: String, t: Type): ToBeTyped[Expr] = { + apply(ForeignFunction.Decl(name, None), t) } - def apply( - name: String, - params: Seq[String], - body: String, - t: Type - ): ToBeTyped[ForeignFunction] = { - toBeTyped( - ForeignFunction( - ForeignFunction.Decl(name, Some(ForeignFunction.Def(params, body))) - )(t) - ) + def apply(name: String, + params: Seq[String], + body: String, + t: Type + ): ToBeTyped[Expr] = { + apply(ForeignFunction.Decl(name, Some(ForeignFunction.Def(params, body))), t) + } + + def apply(decl: ForeignFunction.Decl, t: Type): ToBeTyped[Expr] = { + def collectTypes(t: Type): (Seq[DataType], DataType) = { + t match { + case dt: DataType => (Vector(), dt) + case FunType(dt: DataType, out) => + val (i, o) = collectTypes(out) + (dt +: i, o) + case _ => throw new Exception("This should not be possible") + } + } + val (inTs, outT) = collectTypes(t) + val ff: ToBeTyped[Expr] = foreignFunction(decl, inTs.size).apply + inTs.foldLeft(ff) { + case (f, t) => f.apply(t) + }.apply(outT) } } diff --git a/src/main/scala/rise/core/ForeignFunction.scala b/src/main/scala/rise/core/ForeignFunction.scala index 8ceb6d9e2..525a5f1fd 100644 --- a/src/main/scala/rise/core/ForeignFunction.scala +++ b/src/main/scala/rise/core/ForeignFunction.scala @@ -1,18 +1,5 @@ package rise.core -import rise.core.types.Type - -case class ForeignFunction(decl: ForeignFunction.Decl)(override val t: Type) - extends Primitive { - override def primEq(obj: Primitive): Boolean = obj match { - case ForeignFunction(decl1) => decl == decl1 - case _ => false - } - override def typeScheme: Type = t - override def setType(t: Type): ForeignFunction = ForeignFunction(decl)(t) - override val name: String = decl.name -} - object ForeignFunction { case class Decl(name: String, definition: Option[Def]) case class Def(params: Seq[String], body: String) diff --git a/src/main/scala/rise/core/package.scala b/src/main/scala/rise/core/package.scala index ba0bce6ba..43cc436dc 100644 --- a/src/main/scala/rise/core/package.scala +++ b/src/main/scala/rise/core/package.scala @@ -38,7 +38,6 @@ package object core { s"DepApply[DataKind](${toEvaluableString(f)}, $dt)" } case Literal(d) => s"Literal($d)" - case ff: ForeignFunction => ff.toString case TypeAnnotation(e, t) => s"TypeAnnotation(${toEvaluableString(e)}, $t)" case TypeAssertion(e, t) => s"TypeAssertion(${toEvaluableString(e)}, $t)" case Opaque(e, t) => s"Opaque(${toEvaluableString(e)}, $t)" diff --git a/src/main/scala/rise/core/primitives/.gitignore b/src/main/scala/rise/core/primitives/.gitignore deleted file mode 100644 index afcdcd193..000000000 --- a/src/main/scala/rise/core/primitives/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*.scala diff --git a/src/main/scala/rise/core/primitives/add.scala b/src/main/scala/rise/core/primitives/add.scala new file mode 100644 index 000000000..67f6acd40 --- /dev/null +++ b/src/main/scala/rise/core/primitives/add.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object add extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "add" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (t: DataType) => t ->: t ->: t } + } + override def toString: String = "add" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/asScalar.scala b/src/main/scala/rise/core/primitives/asScalar.scala new file mode 100644 index 000000000..2fb9f4399 --- /dev/null +++ b/src/main/scala/rise/core/primitives/asScalar.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object asScalar extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "asScalar" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (m: Nat) => impl { (t: DataType) => ArrayType(m, VectorType(n, t)) ->: ArrayType(m * n, t) } } } + } + override def toString: String = "asScalar" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/asVector.scala b/src/main/scala/rise/core/primitives/asVector.scala new file mode 100644 index 000000000..263ee8012 --- /dev/null +++ b/src/main/scala/rise/core/primitives/asVector.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object asVector extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "asVector" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = expl { (n: Nat) => impl { (m: Nat) => impl { (t: DataType) => ArrayType(m * n, t) ->: ArrayType(m, VectorType(n, t)) } } } + } + override def toString: String = "asVector" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/asVectorAligned.scala b/src/main/scala/rise/core/primitives/asVectorAligned.scala new file mode 100644 index 000000000..e77ee5246 --- /dev/null +++ b/src/main/scala/rise/core/primitives/asVectorAligned.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object asVectorAligned extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "asVectorAligned" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = expl { (n: Nat) => impl { (m: Nat) => impl { (t: DataType) => ArrayType(m * n, t) ->: ArrayType(m, VectorType(n, t)) } } } + } + override def toString: String = "asVectorAligned" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/cast.scala b/src/main/scala/rise/core/primitives/cast.scala new file mode 100644 index 000000000..563faf90b --- /dev/null +++ b/src/main/scala/rise/core/primitives/cast.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object cast extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "cast" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (s: DataType) => impl { (t: DataType) => s ->: t } } + } + override def toString: String = "cast" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/circularBuffer.scala b/src/main/scala/rise/core/primitives/circularBuffer.scala new file mode 100644 index 000000000..995a55f90 --- /dev/null +++ b/src/main/scala/rise/core/primitives/circularBuffer.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object circularBuffer extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "circularBuffer" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => expl { (alloc: Nat) => expl { (sz: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t) ->: ArrayType(n - 1 + sz, s) ->: ArrayType(n, ArrayType(sz, t)) } } } } } + } + override def toString: String = "circularBuffer" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/concat.scala b/src/main/scala/rise/core/primitives/concat.scala new file mode 100644 index 000000000..461bda813 --- /dev/null +++ b/src/main/scala/rise/core/primitives/concat.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object concat extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "concat" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (m: Nat) => impl { (t: DataType) => ArrayType(n, t) ->: ArrayType(m, t) ->: ArrayType(n + m, t) } } } + } + override def toString: String = "concat" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/depJoin.scala b/src/main/scala/rise/core/primitives/depJoin.scala new file mode 100644 index 000000000..5ef34e954 --- /dev/null +++ b/src/main/scala/rise/core/primitives/depJoin.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object depJoin extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "depJoin" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (lenF: NatToNat) => impl { (t: DataType) => DepArrayType(n, n2dtFun { (i: NatIdentifier) => ArrayType(lenF(i), t) }) ->: ArrayType(BigSum(from = 0, upTo = n - 1, (i: Nat) => lenF(i)), t) } } } + } + override def toString: String = "depJoin" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/depMapSeq.scala b/src/main/scala/rise/core/primitives/depMapSeq.scala new file mode 100644 index 000000000..5ef16a835 --- /dev/null +++ b/src/main/scala/rise/core/primitives/depMapSeq.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object depMapSeq extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "depMapSeq" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (ft1: NatToData) => impl { (ft2: NatToData) => expl { (k: Nat) => NatToDataApply(ft1, k) ->: NatToDataApply(ft2, k) } ->: DepArrayType(n, ft1) ->: DepArrayType(n, ft2) } } } + } + override def toString: String = "depMapSeq" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/depZip.scala b/src/main/scala/rise/core/primitives/depZip.scala new file mode 100644 index 000000000..317b91c57 --- /dev/null +++ b/src/main/scala/rise/core/primitives/depZip.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object depZip extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "depZip" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (ft1: NatToData) => impl { (ft2: NatToData) => DepArrayType(n, ft1) ->: DepArrayType(n, ft2) ->: DepArrayType(n, n2dtFun { (i: NatIdentifier) => PairType(NatToDataApply(ft1, i), NatToDataApply(ft2, i)) }) } } } + } + override def toString: String = "depZip" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/div.scala b/src/main/scala/rise/core/primitives/div.scala new file mode 100644 index 000000000..e1c2387a2 --- /dev/null +++ b/src/main/scala/rise/core/primitives/div.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object div extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "div" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (t: DataType) => t ->: t ->: t } + } + override def toString: String = "div" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/dmatch.scala b/src/main/scala/rise/core/primitives/dmatch.scala new file mode 100644 index 000000000..0fb9ad5ff --- /dev/null +++ b/src/main/scala/rise/core/primitives/dmatch.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object dmatch extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "dmatch" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (fdt: NatToData) => impl { (t: DataType) => (Nat ** ((n: Nat) => NatToDataApply(fdt, n))) ->: expl { (m: Nat) => NatToDataApply(fdt, m) ->: t } ->: t } } + } + override def toString: String = "dmatch" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/drop.scala b/src/main/scala/rise/core/primitives/drop.scala new file mode 100644 index 000000000..cdf2151b8 --- /dev/null +++ b/src/main/scala/rise/core/primitives/drop.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object drop extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "drop" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = expl { (n: Nat) => impl { (m: Nat) => impl { (t: DataType) => ArrayType(n + m, t) ->: ArrayType(m, t) } } } + } + override def toString: String = "drop" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/equal.scala b/src/main/scala/rise/core/primitives/equal.scala new file mode 100644 index 000000000..7aba6a815 --- /dev/null +++ b/src/main/scala/rise/core/primitives/equal.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object equal extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "equal" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (t: DataType) => t ->: t ->: bool } + } + override def toString: String = "equal" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/foreignFunction.scala b/src/main/scala/rise/core/primitives/foreignFunction.scala new file mode 100644 index 000000000..f6fe871b3 --- /dev/null +++ b/src/main/scala/rise/core/primitives/foreignFunction.scala @@ -0,0 +1,46 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +final case class foreignFunction(funDecl: rise.core.ForeignFunction.Decl, n: Int) extends Builder { + override def toString: String = "foreignFunction" + override def primitive: rise.core.Primitive = foreignFunction.Primitive(funDecl, n)() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(foreignFunction.Primitive(funDecl, n)()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} +object foreignFunction { + private final case class Primitive(funDecl: rise.core.ForeignFunction.Decl, n: Int)(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "foreignFunction" + override def setType(ty: Type): Primitive = Primitive(funDecl, n)(ty) + override def typeScheme: Type = { + val inTs = Seq.fill(n)(DataTypeIdentifier(freshName("dt"), isExplicit = true)) + inTs.foldRight(expl { (outT: DataType) => inTs.foldRight(outT: Type)({ + case (lhsT, rhsT) => + lhsT ->: rhsT + }) }: Type)({ + case (id, t) => + DepFunType[DataKind, Type](id, t) + }) + } + override def primEq(obj: rise.core.Primitive): Boolean = obj match { + case p: Primitive => + p.funDecl == funDecl && (p.n == n && true) + case _ => + false + } + } + def unapply(arg: rise.core.Expr): Option[(rise.core.ForeignFunction.Decl, Int)] = arg match { + case p: Primitive => + Some(p.funDecl, p.n) + case _ => + None + } +} diff --git a/src/main/scala/rise/core/primitives/fst.scala b/src/main/scala/rise/core/primitives/fst.scala new file mode 100644 index 000000000..4fce44116 --- /dev/null +++ b/src/main/scala/rise/core/primitives/fst.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object fst extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "fst" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (s: DataType) => impl { (t: DataType) => PairType(s, t) ->: s } } + } + override def toString: String = "fst" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/gather.scala b/src/main/scala/rise/core/primitives/gather.scala new file mode 100644 index 000000000..20f7e8f17 --- /dev/null +++ b/src/main/scala/rise/core/primitives/gather.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object gather extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "gather" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (m: Nat) => impl { (t: DataType) => ArrayType(m, IndexType(n)) ->: ArrayType(n, t) ->: ArrayType(m, t) } } } + } + override def toString: String = "gather" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/generate.scala b/src/main/scala/rise/core/primitives/generate.scala new file mode 100644 index 000000000..5b8da8f50 --- /dev/null +++ b/src/main/scala/rise/core/primitives/generate.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object generate extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "generate" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (t: DataType) => (IndexType(n) ->: t) ->: ArrayType(n, t) } } + } + override def toString: String = "generate" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/gt.scala b/src/main/scala/rise/core/primitives/gt.scala new file mode 100644 index 000000000..3fbb51ec9 --- /dev/null +++ b/src/main/scala/rise/core/primitives/gt.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object gt extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "gt" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (t: DataType) => t ->: t ->: bool } + } + override def toString: String = "gt" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/id.scala b/src/main/scala/rise/core/primitives/id.scala new file mode 100644 index 000000000..a2a02dc74 --- /dev/null +++ b/src/main/scala/rise/core/primitives/id.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object id extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "id" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (t: DataType) => t ->: t } + } + override def toString: String = "id" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/idx.scala b/src/main/scala/rise/core/primitives/idx.scala new file mode 100644 index 000000000..52eae70c1 --- /dev/null +++ b/src/main/scala/rise/core/primitives/idx.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object idx extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "idx" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (t: DataType) => IndexType(n) ->: ArrayType(n, t) ->: t } } + } + override def toString: String = "idx" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/indexAsNat.scala b/src/main/scala/rise/core/primitives/indexAsNat.scala new file mode 100644 index 000000000..3a0614a0d --- /dev/null +++ b/src/main/scala/rise/core/primitives/indexAsNat.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object indexAsNat extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "indexAsNat" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => IndexType(n) ->: NatType } + } + override def toString: String = "indexAsNat" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/iterate.scala b/src/main/scala/rise/core/primitives/iterate.scala new file mode 100644 index 000000000..2c1ef1d28 --- /dev/null +++ b/src/main/scala/rise/core/primitives/iterate.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object iterate extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "iterate" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (m: Nat) => expl { (k: Nat) => impl { (t: DataType) => expl { (l: Nat) => ArrayType(l * n, t) ->: ArrayType(l, t) } ->: ArrayType(m * n.pow(k), t) ->: ArrayType(m, t) } } } } + } + override def toString: String = "iterate" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/iterateStream.scala b/src/main/scala/rise/core/primitives/iterateStream.scala new file mode 100644 index 000000000..e83dbdad0 --- /dev/null +++ b/src/main/scala/rise/core/primitives/iterateStream.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object iterateStream extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "iterateStream" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t) ->: ArrayType(n, s) ->: ArrayType(n, t) } } } + } + override def toString: String = "iterateStream" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/join.scala b/src/main/scala/rise/core/primitives/join.scala new file mode 100644 index 000000000..4f8be31b0 --- /dev/null +++ b/src/main/scala/rise/core/primitives/join.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object join extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "join" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (m: Nat) => impl { (t: DataType) => ArrayType(n, ArrayType(m, t)) ->: ArrayType(n * m, t) } } } + } + override def toString: String = "join" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/let.scala b/src/main/scala/rise/core/primitives/let.scala new file mode 100644 index 000000000..ddad64004 --- /dev/null +++ b/src/main/scala/rise/core/primitives/let.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object let extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "let" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (s: DataType) => impl { (t: DataType) => s ->: (s ->: t) ->: t } } + } + override def toString: String = "let" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/lt.scala b/src/main/scala/rise/core/primitives/lt.scala new file mode 100644 index 000000000..fd1c8e59f --- /dev/null +++ b/src/main/scala/rise/core/primitives/lt.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object lt extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "lt" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (t: DataType) => t ->: t ->: bool } + } + override def toString: String = "lt" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/makeArray.scala b/src/main/scala/rise/core/primitives/makeArray.scala index a479308bb..df31ff402 100644 --- a/src/main/scala/rise/core/primitives/makeArray.scala +++ b/src/main/scala/rise/core/primitives/makeArray.scala @@ -1,46 +1,40 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package rise.core.primitives - -import rise.core.DSL.ToBeTyped +import rise.core.DSL._ import rise.core.DSL.Type._ import rise.core._ import rise.core.types._ - +import arithexpr.arithmetic._ final case class makeArray(n: Int) extends Builder { - override def apply: ToBeTyped[Primitive] = ToBeTyped(makeArray.Primitive(n)()) - - override def primitive: Primitive = makeArray.Primitive(n)() - + override def toString: String = "makeArray" + override def primitive: rise.core.Primitive = makeArray.Primitive(n)() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(makeArray.Primitive(n)()) override def unapply(arg: Expr): Boolean = arg match { - case _: makeArray.Primitive => true + case _: Primitive => true case _ => false } } - object makeArray { - private final case class Primitive(n: Int) - (override val t: Type = TypePlaceholder) - extends rise.core.Primitive - { - override def name: String = "makeArray" + private final case class Primitive(n: Int)(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "makeArray" + override def setType(ty: Type): Primitive = Primitive(n)(ty) + override def typeScheme: Type = impl { (dt: DataType) => Seq.fill(n)(dt).foldRight(ArrayType(n, dt): Type)({ + case (lhsT, rhsT) => + lhsT ->: rhsT + }) } override def primEq(obj: rise.core.Primitive): Boolean = obj match { - case Primitive(n1) => n == n1 - case _ => false + case p: Primitive => + p.n == n && true + case _ => + false } - override def setType(t: Type): Primitive = Primitive(n)(t) - override def typeScheme: Type = - impl { t: DataType => { - def tRec(m: Int, dt: DataType): Type = - if (m <= 0) { - ArrayType(n, dt) - } else { - dt ->: tRec(m - 1, dt) - } - tRec(n, t) - }} } - - def unapply(arg: Expr): Option[Int] = arg match { - case p: makeArray.Primitive => Some(p.n) - case _ => None + def unapply(arg: rise.core.Expr): Option[Int] = arg match { + case p: Primitive => + Some(p.n) + case _ => + None } } diff --git a/src/main/scala/rise/core/primitives/makeDepPair.scala b/src/main/scala/rise/core/primitives/makeDepPair.scala new file mode 100644 index 000000000..40959b65e --- /dev/null +++ b/src/main/scala/rise/core/primitives/makeDepPair.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object makeDepPair extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "makeDepPair" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (fdt: NatToData) => expl { (n: Nat) => NatToDataApply(fdt, n) ->: (Nat ** ((m: Nat) => NatToDataApply(fdt, m))) } } + } + override def toString: String = "makeDepPair" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/makePair.scala b/src/main/scala/rise/core/primitives/makePair.scala new file mode 100644 index 000000000..84cf68008 --- /dev/null +++ b/src/main/scala/rise/core/primitives/makePair.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object makePair extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "makePair" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (s: DataType) => impl { (t: DataType) => s ->: t ->: PairType(s, t) } } + } + override def toString: String = "makePair" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/map.scala b/src/main/scala/rise/core/primitives/map.scala new file mode 100644 index 000000000..b3d5a60e2 --- /dev/null +++ b/src/main/scala/rise/core/primitives/map.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object map extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "map" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t) ->: ArrayType(n, s) ->: ArrayType(n, t) } } } + } + override def toString: String = "map" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/mapFst.scala b/src/main/scala/rise/core/primitives/mapFst.scala new file mode 100644 index 000000000..9913b1899 --- /dev/null +++ b/src/main/scala/rise/core/primitives/mapFst.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object mapFst extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "mapFst" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (s1: DataType) => impl { (t: DataType) => impl { (s2: DataType) => (s1 ->: s2) ->: PairType(s1, t) ->: PairType(s2, t) } } } + } + override def toString: String = "mapFst" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/mapSeq.scala b/src/main/scala/rise/core/primitives/mapSeq.scala new file mode 100644 index 000000000..6e804b80e --- /dev/null +++ b/src/main/scala/rise/core/primitives/mapSeq.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object mapSeq extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "mapSeq" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t) ->: ArrayType(n, s) ->: ArrayType(n, t) } } } + } + override def toString: String = "mapSeq" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/mapSeqUnroll.scala b/src/main/scala/rise/core/primitives/mapSeqUnroll.scala new file mode 100644 index 000000000..036972bfd --- /dev/null +++ b/src/main/scala/rise/core/primitives/mapSeqUnroll.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object mapSeqUnroll extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "mapSeqUnroll" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t) ->: ArrayType(n, s) ->: ArrayType(n, t) } } } + } + override def toString: String = "mapSeqUnroll" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/mapSnd.scala b/src/main/scala/rise/core/primitives/mapSnd.scala new file mode 100644 index 000000000..c0cf65b0a --- /dev/null +++ b/src/main/scala/rise/core/primitives/mapSnd.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object mapSnd extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "mapSnd" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (s: DataType) => impl { (t1: DataType) => impl { (t2: DataType) => (t1 ->: t2) ->: PairType(s, t1) ->: PairType(s, t2) } } } + } + override def toString: String = "mapSnd" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/mapStream.scala b/src/main/scala/rise/core/primitives/mapStream.scala new file mode 100644 index 000000000..837f0e4f1 --- /dev/null +++ b/src/main/scala/rise/core/primitives/mapStream.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object mapStream extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "mapStream" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t) ->: ArrayType(n, s) ->: ArrayType(n, t) } } } + } + override def toString: String = "mapStream" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/mod.scala b/src/main/scala/rise/core/primitives/mod.scala new file mode 100644 index 000000000..743522619 --- /dev/null +++ b/src/main/scala/rise/core/primitives/mod.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object mod extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "mod" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (t: DataType) => t ->: t ->: t } + } + override def toString: String = "mod" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/mul.scala b/src/main/scala/rise/core/primitives/mul.scala new file mode 100644 index 000000000..f8bebc05d --- /dev/null +++ b/src/main/scala/rise/core/primitives/mul.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object mul extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "mul" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (t: DataType) => t ->: t ->: t } + } + override def toString: String = "mul" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/natAsIndex.scala b/src/main/scala/rise/core/primitives/natAsIndex.scala new file mode 100644 index 000000000..59f62d82c --- /dev/null +++ b/src/main/scala/rise/core/primitives/natAsIndex.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object natAsIndex extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "natAsIndex" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = expl { (n: Nat) => NatType ->: IndexType(n) } + } + override def toString: String = "natAsIndex" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/neg.scala b/src/main/scala/rise/core/primitives/neg.scala new file mode 100644 index 000000000..16fd8ad7a --- /dev/null +++ b/src/main/scala/rise/core/primitives/neg.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object neg extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "neg" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (t: DataType) => t ->: t } + } + override def toString: String = "neg" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/not.scala b/src/main/scala/rise/core/primitives/not.scala new file mode 100644 index 000000000..56d6b35b9 --- /dev/null +++ b/src/main/scala/rise/core/primitives/not.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object not extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "not" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = bool ->: bool + } + override def toString: String = "not" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/padClamp.scala b/src/main/scala/rise/core/primitives/padClamp.scala new file mode 100644 index 000000000..0c5bd621f --- /dev/null +++ b/src/main/scala/rise/core/primitives/padClamp.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object padClamp extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "padClamp" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => expl { (l: Nat) => expl { (r: Nat) => impl { (t: DataType) => ArrayType(n, t) ->: ArrayType(l + n + r, t) } } } } + } + override def toString: String = "padClamp" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/padCst.scala b/src/main/scala/rise/core/primitives/padCst.scala new file mode 100644 index 000000000..f973c39e6 --- /dev/null +++ b/src/main/scala/rise/core/primitives/padCst.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object padCst extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "padCst" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => expl { (l: Nat) => expl { (r: Nat) => impl { (t: DataType) => t ->: ArrayType(n, t) ->: ArrayType(l + n + r, t) } } } } + } + override def toString: String = "padCst" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/padEmpty.scala b/src/main/scala/rise/core/primitives/padEmpty.scala new file mode 100644 index 000000000..00b3fe793 --- /dev/null +++ b/src/main/scala/rise/core/primitives/padEmpty.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object padEmpty extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "padEmpty" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => expl { (r: Nat) => impl { (t: DataType) => ArrayType(n, t) ->: ArrayType(n + r, t) } } } + } + override def toString: String = "padEmpty" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/partition.scala b/src/main/scala/rise/core/primitives/partition.scala new file mode 100644 index 000000000..b4bafb869 --- /dev/null +++ b/src/main/scala/rise/core/primitives/partition.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object partition extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "partition" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (t: DataType) => expl { (m: Nat) => expl { (lenF: NatToNat) => ArrayType(n, t) ->: DepArrayType(m, n2dtFun { (i: NatIdentifier) => ArrayType(lenF(i), t) }) } } } } + } + override def toString: String = "partition" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/primitives.rise b/src/main/scala/rise/core/primitives/primitives.rise index 5e7895adc..a53996885 100644 --- a/src/main/scala/rise/core/primitives/primitives.rise +++ b/src/main/scala/rise/core/primitives/primitives.rise @@ -27,12 +27,13 @@ def natAsIndex: (n: nat) -> natType -> idx[n] def let: {s: data} -> {t: data} -> s -> (s -> t) -> t def toMem: {t: data} -> t -> t +// foreign functions +def foreignFunction(funDecl: rise.core.ForeignFunction.Decl, n: Int): + n*((inTs: data) ->) (outT: data) -> n*(*inTs ->) outT + // array ops -// def makeArray(1): {t: data} -> t -> 1.t -// def makeArray(2): {t: data} -> t -> t -> 2.t -// def makeArray(3): {t: data} -> t -> t -> t -> 3.t -// ... -// def makeArray(n): {t: data} -> t -> t -> t -> ... -> t -> n.t +def makeArray(n: Int): {dt: data} -> n*(dt ->) n.dt + def generate: {n: nat} -> {t: data} -> (idx[n] -> t) -> n.t def idx: {n: nat} -> {t: data} -> idx[n] -> n.t -> t diff --git a/src/main/scala/rise/core/primitives/reduce.scala b/src/main/scala/rise/core/primitives/reduce.scala new file mode 100644 index 000000000..e00ddf2ee --- /dev/null +++ b/src/main/scala/rise/core/primitives/reduce.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object reduce extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "reduce" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (t: DataType) => (t ->: t ->: t) ->: t ->: ArrayType(n, t) ->: t } } + } + override def toString: String = "reduce" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/reduceSeq.scala b/src/main/scala/rise/core/primitives/reduceSeq.scala new file mode 100644 index 000000000..6f33868ec --- /dev/null +++ b/src/main/scala/rise/core/primitives/reduceSeq.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object reduceSeq extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "reduceSeq" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (t ->: s ->: t) ->: t ->: ArrayType(n, s) ->: t } } } + } + override def toString: String = "reduceSeq" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/reduceSeqUnroll.scala b/src/main/scala/rise/core/primitives/reduceSeqUnroll.scala new file mode 100644 index 000000000..7d8340820 --- /dev/null +++ b/src/main/scala/rise/core/primitives/reduceSeqUnroll.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object reduceSeqUnroll extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "reduceSeqUnroll" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (t ->: s ->: t) ->: t ->: ArrayType(n, s) ->: t } } } + } + override def toString: String = "reduceSeqUnroll" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/reorder.scala b/src/main/scala/rise/core/primitives/reorder.scala new file mode 100644 index 000000000..adf54ea18 --- /dev/null +++ b/src/main/scala/rise/core/primitives/reorder.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object reorder extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "reorder" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (t: DataType) => expl { (n: Nat) => expl { (idxF: NatToNat) => expl { (idxFinv: NatToNat) => ArrayType(n, t) ->: ArrayType(n, t) } } } } + } + override def toString: String = "reorder" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/rotateValues.scala b/src/main/scala/rise/core/primitives/rotateValues.scala new file mode 100644 index 000000000..7d22ce213 --- /dev/null +++ b/src/main/scala/rise/core/primitives/rotateValues.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object rotateValues extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "rotateValues" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => expl { (sz: Nat) => impl { (t: DataType) => (t ->: t) ->: ArrayType(n - 1 + sz, t) ->: ArrayType(n, ArrayType(sz, t)) } } } + } + override def toString: String = "rotateValues" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/scanSeq.scala b/src/main/scala/rise/core/primitives/scanSeq.scala new file mode 100644 index 000000000..444b542df --- /dev/null +++ b/src/main/scala/rise/core/primitives/scanSeq.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object scanSeq extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "scanSeq" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t ->: t) ->: t ->: ArrayType(n, s) ->: ArrayType(n, t) } } } + } + override def toString: String = "scanSeq" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/scatter.scala b/src/main/scala/rise/core/primitives/scatter.scala new file mode 100644 index 000000000..89388453d --- /dev/null +++ b/src/main/scala/rise/core/primitives/scatter.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object scatter extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "scatter" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (m: Nat) => impl { (t: DataType) => ArrayType(n, IndexType(m)) ->: ArrayType(n, t) ->: ArrayType(m, t) } } } + } + override def toString: String = "scatter" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/select.scala b/src/main/scala/rise/core/primitives/select.scala new file mode 100644 index 000000000..688daee52 --- /dev/null +++ b/src/main/scala/rise/core/primitives/select.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object select extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "select" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (t: DataType) => bool ->: t ->: t ->: t } + } + override def toString: String = "select" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/slide.scala b/src/main/scala/rise/core/primitives/slide.scala new file mode 100644 index 000000000..44cc0d621 --- /dev/null +++ b/src/main/scala/rise/core/primitives/slide.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object slide extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "slide" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => expl { (sz: Nat) => expl { (sp: Nat) => impl { (t: DataType) => ArrayType(sp * n + sz, t) ->: ArrayType(1 + n, ArrayType(sz, t)) } } } } + } + override def toString: String = "slide" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/snd.scala b/src/main/scala/rise/core/primitives/snd.scala new file mode 100644 index 000000000..4a7140e85 --- /dev/null +++ b/src/main/scala/rise/core/primitives/snd.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object snd extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "snd" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (s: DataType) => impl { (t: DataType) => PairType(s, t) ->: t } } + } + override def toString: String = "snd" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/split.scala b/src/main/scala/rise/core/primitives/split.scala new file mode 100644 index 000000000..661832469 --- /dev/null +++ b/src/main/scala/rise/core/primitives/split.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object split extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "split" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = expl { (n: Nat) => impl { (m: Nat) => impl { (t: DataType) => ArrayType(m * n, t) ->: ArrayType(m, ArrayType(n, t)) } } } + } + override def toString: String = "split" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/sub.scala b/src/main/scala/rise/core/primitives/sub.scala new file mode 100644 index 000000000..2abb751a2 --- /dev/null +++ b/src/main/scala/rise/core/primitives/sub.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object sub extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "sub" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (t: DataType) => t ->: t ->: t } + } + override def toString: String = "sub" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/take.scala b/src/main/scala/rise/core/primitives/take.scala new file mode 100644 index 000000000..f8f5d5b37 --- /dev/null +++ b/src/main/scala/rise/core/primitives/take.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object take extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "take" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = expl { (n: Nat) => impl { (m: Nat) => impl { (t: DataType) => ArrayType(n + m, t) ->: ArrayType(n, t) } } } + } + override def toString: String = "take" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/toMem.scala b/src/main/scala/rise/core/primitives/toMem.scala new file mode 100644 index 000000000..98b12bc48 --- /dev/null +++ b/src/main/scala/rise/core/primitives/toMem.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object toMem extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "toMem" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (t: DataType) => t ->: t } + } + override def toString: String = "toMem" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/transpose.scala b/src/main/scala/rise/core/primitives/transpose.scala new file mode 100644 index 000000000..0ffd8cc03 --- /dev/null +++ b/src/main/scala/rise/core/primitives/transpose.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object transpose extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "transpose" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (m: Nat) => impl { (t: DataType) => ArrayType(n, ArrayType(m, t)) ->: ArrayType(m, ArrayType(n, t)) } } } + } + override def toString: String = "transpose" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/unzip.scala b/src/main/scala/rise/core/primitives/unzip.scala new file mode 100644 index 000000000..15043eda6 --- /dev/null +++ b/src/main/scala/rise/core/primitives/unzip.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object unzip extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "unzip" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => ArrayType(n, PairType(s, t)) ->: PairType(ArrayType(n, s), ArrayType(n, t)) } } } + } + override def toString: String = "unzip" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/vectorFromScalar.scala b/src/main/scala/rise/core/primitives/vectorFromScalar.scala new file mode 100644 index 000000000..cc2d1d910 --- /dev/null +++ b/src/main/scala/rise/core/primitives/vectorFromScalar.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object vectorFromScalar extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "vectorFromScalar" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (t: DataType) => t ->: VectorType(n, t) } } + } + override def toString: String = "vectorFromScalar" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/core/primitives/zip.scala b/src/main/scala/rise/core/primitives/zip.scala new file mode 100644 index 000000000..52d06a550 --- /dev/null +++ b/src/main/scala/rise/core/primitives/zip.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.core.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object zip extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "zip" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => ArrayType(n, s) ->: ArrayType(n, t) ->: ArrayType(n, PairType(s, t)) } } } + } + override def toString: String = "zip" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/elevate/rules/lowering.scala b/src/main/scala/rise/elevate/rules/lowering.scala index 3e363807e..395a05851 100644 --- a/src/main/scala/rise/elevate/rules/lowering.scala +++ b/src/main/scala/rise/elevate/rules/lowering.scala @@ -119,7 +119,7 @@ object lowering { isPairOrBasicType(out) => Success(l) case _ => Failure(containsComputation()) } - case f@ForeignFunction(_) => Success(f) + case f@foreignFunction(_, _) => Success(f) case _ => Failure(containsComputation()) } } diff --git a/src/main/scala/rise/elevate/rules/traversal.scala b/src/main/scala/rise/elevate/rules/traversal.scala index fa2ef94b8..f4da6675d 100644 --- a/src/main/scala/rise/elevate/rules/traversal.scala +++ b/src/main/scala/rise/elevate/rules/traversal.scala @@ -192,7 +192,6 @@ object traversal { case _: TypeAnnotation => throw new Exception("Type annotations should be gone.") case _: TypeAssertion => throw new Exception("Type assertions should be gone.") case _: Opaque => throw new Exception("Opaque expressions should be gone.") - case _: ForeignFunction => None case _: Primitive => None } } diff --git a/src/main/scala/rise/elevate/strategies/lowering.scala b/src/main/scala/rise/elevate/strategies/lowering.scala index c1e1cd1b6..c83280cce 100644 --- a/src/main/scala/rise/elevate/strategies/lowering.scala +++ b/src/main/scala/rise/elevate/strategies/lowering.scala @@ -39,7 +39,6 @@ object lowering { case DepApp(_, _) => Failure(extract(what)) case _: Identifier => Failure(extract(what)) case _: Literal => Failure(extract(what)) - case _: ForeignFunction => Failure(extract(what)) case _: TypeAnnotation => throw new Exception("Type annotations should be gone.") case _: TypeAssertion => throw new Exception("Type assertions should be gone.") case _: Opaque => throw new Exception("Opaque expressions should be gone.") diff --git a/src/main/scala/rise/openCL/primitives/.gitignore b/src/main/scala/rise/openCL/primitives/.gitignore deleted file mode 100644 index 2d916b240..000000000 --- a/src/main/scala/rise/openCL/primitives/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*.scala \ No newline at end of file diff --git a/src/main/scala/rise/openCL/primitives/mapGlobal.scala b/src/main/scala/rise/openCL/primitives/mapGlobal.scala new file mode 100644 index 000000000..d3ca39bc7 --- /dev/null +++ b/src/main/scala/rise/openCL/primitives/mapGlobal.scala @@ -0,0 +1,37 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.openCL.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +final case class mapGlobal(dim: Int) extends Builder { + override def toString: String = "mapGlobal" + override def primitive: rise.core.Primitive = mapGlobal.Primitive(dim)() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(mapGlobal.Primitive(dim)()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} +object mapGlobal { + private final case class Primitive(dim: Int)(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "mapGlobal" + override def setType(ty: Type): Primitive = Primitive(dim)(ty) + override def typeScheme: Type = impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t) ->: ArrayType(n, s) ->: ArrayType(n, t) } } } + override def primEq(obj: rise.core.Primitive): Boolean = obj match { + case p: Primitive => + p.dim == dim && true + case _ => + false + } + } + def unapply(arg: rise.core.Expr): Option[Int] = arg match { + case p: Primitive => + Some(p.dim) + case _ => + None + } +} diff --git a/src/main/scala/rise/openCL/primitives/mapLocal.scala b/src/main/scala/rise/openCL/primitives/mapLocal.scala new file mode 100644 index 000000000..93955713c --- /dev/null +++ b/src/main/scala/rise/openCL/primitives/mapLocal.scala @@ -0,0 +1,37 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.openCL.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +final case class mapLocal(dim: Int) extends Builder { + override def toString: String = "mapLocal" + override def primitive: rise.core.Primitive = mapLocal.Primitive(dim)() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(mapLocal.Primitive(dim)()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} +object mapLocal { + private final case class Primitive(dim: Int)(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "mapLocal" + override def setType(ty: Type): Primitive = Primitive(dim)(ty) + override def typeScheme: Type = impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t) ->: ArrayType(n, s) ->: ArrayType(n, t) } } } + override def primEq(obj: rise.core.Primitive): Boolean = obj match { + case p: Primitive => + p.dim == dim && true + case _ => + false + } + } + def unapply(arg: rise.core.Expr): Option[Int] = arg match { + case p: Primitive => + Some(p.dim) + case _ => + None + } +} diff --git a/src/main/scala/rise/openCL/primitives/mapWorkGroup.scala b/src/main/scala/rise/openCL/primitives/mapWorkGroup.scala new file mode 100644 index 000000000..28ba25ff2 --- /dev/null +++ b/src/main/scala/rise/openCL/primitives/mapWorkGroup.scala @@ -0,0 +1,37 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.openCL.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +final case class mapWorkGroup(dim: Int) extends Builder { + override def toString: String = "mapWorkGroup" + override def primitive: rise.core.Primitive = mapWorkGroup.Primitive(dim)() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(mapWorkGroup.Primitive(dim)()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} +object mapWorkGroup { + private final case class Primitive(dim: Int)(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "mapWorkGroup" + override def setType(ty: Type): Primitive = Primitive(dim)(ty) + override def typeScheme: Type = impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t) ->: ArrayType(n, s) ->: ArrayType(n, t) } } } + override def primEq(obj: rise.core.Primitive): Boolean = obj match { + case p: Primitive => + p.dim == dim && true + case _ => + false + } + } + def unapply(arg: rise.core.Expr): Option[Int] = arg match { + case p: Primitive => + Some(p.dim) + case _ => + None + } +} diff --git a/src/main/scala/rise/openCL/primitives/oclCircularBuffer.scala b/src/main/scala/rise/openCL/primitives/oclCircularBuffer.scala new file mode 100644 index 000000000..2cb0238c1 --- /dev/null +++ b/src/main/scala/rise/openCL/primitives/oclCircularBuffer.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.openCL.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object oclCircularBuffer extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "oclCircularBuffer" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = expl { (a: AddressSpace) => impl { (n: Nat) => expl { (alloc: Nat) => expl { (sz: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t) ->: ArrayType(n + sz, s) ->: ArrayType(1 + n, ArrayType(sz, t)) } } } } } } + } + override def toString: String = "oclCircularBuffer" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/openCL/primitives/oclIterate.scala b/src/main/scala/rise/openCL/primitives/oclIterate.scala new file mode 100644 index 000000000..928124b78 --- /dev/null +++ b/src/main/scala/rise/openCL/primitives/oclIterate.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.openCL.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object oclIterate extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "oclIterate" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = expl { (a: AddressSpace) => impl { (n: Nat) => impl { (m: Nat) => expl { (k: Nat) => impl { (t: DataType) => expl { (l: Nat) => ArrayType(l * n, t) ->: ArrayType(l, t) } ->: ArrayType(m * n.pow(k), t) ->: ArrayType(m, t) } } } } } + } + override def toString: String = "oclIterate" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/openCL/primitives/oclReduceSeq.scala b/src/main/scala/rise/openCL/primitives/oclReduceSeq.scala new file mode 100644 index 000000000..6cb8b28cd --- /dev/null +++ b/src/main/scala/rise/openCL/primitives/oclReduceSeq.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.openCL.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object oclReduceSeq extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "oclReduceSeq" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = expl { (a: AddressSpace) => impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (t ->: s ->: t) ->: t ->: ArrayType(n, s) ->: t } } } } + } + override def toString: String = "oclReduceSeq" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/openCL/primitives/oclReduceSeqUnroll.scala b/src/main/scala/rise/openCL/primitives/oclReduceSeqUnroll.scala new file mode 100644 index 000000000..391210119 --- /dev/null +++ b/src/main/scala/rise/openCL/primitives/oclReduceSeqUnroll.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.openCL.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object oclReduceSeqUnroll extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "oclReduceSeqUnroll" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = expl { (a: AddressSpace) => impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (t ->: s ->: t) ->: t ->: ArrayType(n, s) ->: t } } } } + } + override def toString: String = "oclReduceSeqUnroll" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/openCL/primitives/oclRotateValues.scala b/src/main/scala/rise/openCL/primitives/oclRotateValues.scala new file mode 100644 index 000000000..72ebcb14c --- /dev/null +++ b/src/main/scala/rise/openCL/primitives/oclRotateValues.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.openCL.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object oclRotateValues extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "oclRotateValues" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = expl { (a: AddressSpace) => impl { (n: Nat) => expl { (sz: Nat) => impl { (s: DataType) => (s ->: s) ->: ArrayType(n + sz, s) ->: ArrayType(1 + n, ArrayType(sz, s)) } } } } + } + override def toString: String = "oclRotateValues" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/openCL/primitives/oclRunPrimitive.scala b/src/main/scala/rise/openCL/primitives/oclRunPrimitive.scala new file mode 100644 index 000000000..f71c19340 --- /dev/null +++ b/src/main/scala/rise/openCL/primitives/oclRunPrimitive.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.openCL.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object oclRunPrimitive extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "oclRunPrimitive" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = expl { (ls1: Nat) => expl { (ls2: Nat) => expl { (ls3: Nat) => expl { (gs1: Nat) => expl { (gs2: Nat) => expl { (gs3: Nat) => impl { (t: DataType) => t ->: t } } } } } } } + } + override def toString: String = "oclRunPrimitive" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/openCL/primitives/oclToMem.scala b/src/main/scala/rise/openCL/primitives/oclToMem.scala new file mode 100644 index 000000000..747482e55 --- /dev/null +++ b/src/main/scala/rise/openCL/primitives/oclToMem.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.openCL.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object oclToMem extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "oclToMem" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (t: DataType) => expl { (a: AddressSpace) => t ->: t } } + } + override def toString: String = "oclToMem" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/openMP/primitives/.gitignore b/src/main/scala/rise/openMP/primitives/.gitignore deleted file mode 100644 index 2d916b240..000000000 --- a/src/main/scala/rise/openMP/primitives/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*.scala \ No newline at end of file diff --git a/src/main/scala/rise/openMP/primitives/mapPar.scala b/src/main/scala/rise/openMP/primitives/mapPar.scala new file mode 100644 index 000000000..17bd42b75 --- /dev/null +++ b/src/main/scala/rise/openMP/primitives/mapPar.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.openMP.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object mapPar extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "mapPar" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t) ->: ArrayType(n, s) ->: ArrayType(n, t) } } } + } + override def toString: String = "mapPar" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/rise/openMP/primitives/reducePar.scala b/src/main/scala/rise/openMP/primitives/reducePar.scala new file mode 100644 index 000000000..fb8d123b9 --- /dev/null +++ b/src/main/scala/rise/openMP/primitives/reducePar.scala @@ -0,0 +1,24 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package rise.openMP.primitives +import rise.core.DSL._ +import rise.core.DSL.Type._ +import rise.core._ +import rise.core.types._ +import arithexpr.arithmetic._ +object reducePar extends Builder { + private final case class Primitive()(override val t: Type = TypePlaceholder) extends rise.core.Primitive { + override val name: String = "reducePar" + override def setType(ty: Type): Primitive = Primitive()(ty) + override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass + override def typeScheme: Type = impl { (n: Nat) => impl { (t: DataType) => (t ->: t ->: t) ->: t ->: ArrayType(n, t) ->: t } } + } + override def toString: String = "reducePar" + override def primitive: rise.core.Primitive = Primitive()() + override def apply: ToBeTyped[rise.core.Primitive] = toBeTyped(Primitive()()) + override def unapply(arg: Expr): Boolean = arg match { + case _: Primitive => true + case _ => false + } +} diff --git a/src/main/scala/shine/C/Compilation/CodeGenerator.scala b/src/main/scala/shine/C/Compilation/CodeGenerator.scala index e5577d20b..05eb775d0 100644 --- a/src/main/scala/shine/C/Compilation/CodeGenerator.scala +++ b/src/main/scala/shine/C/Compilation/CodeGenerator.scala @@ -496,9 +496,9 @@ class CodeGenerator(val decls: CodeGenerator.Declarations, case _ => error(s"Expected path to be not empty") } - case MakeArray(elems) => path match { + case m@MakeArray(_) => path match { case (i: CIntExpr) :: ps => try { - elems(i.eval) |> exp(env, ps, cont) + m.elements(i.eval) |> exp(env, ps, cont) } catch { case NotEvaluableException() => error(s"could not evaluate $i") } @@ -509,8 +509,8 @@ class CodeGenerator(val decls: CodeGenerator.Declarations, case DepIdx(_, _, i, e) => e |> exp(env, CIntExpr(i) :: path, cont) - case ffc@ForeignFunctionCall(f, inTs, args) => - CCodeGen.codeGenForeignFunctionCall(f, inTs, ffc.outT, args, env, fe => + case ffc@ForeignFunctionCall(f, _) => + CCodeGen.codeGenForeignFunctionCall(f, ffc.inTs, ffc.outT, ffc.args, env, fe => generateAccess(ffc.outT, fe, path, env, cont) ) @@ -830,6 +830,8 @@ class CodeGenerator(val decls: CodeGenerator.Declarations, case IndexData(i, _) => C.AST.ArithmeticExpr(i) case _: IntData | _: FloatData | _: DoubleData | _: BoolData => C.AST.Literal(d.toString) + case NatAsIntData(n) => + C.AST.Literal(n.toString) case ArrayData(a) => d.dataType match { case ArrayType(_, ArrayType(_, _)) => codeGenLiteral(ArrayData(a.flatten(d => d.asInstanceOf[ArrayData].a))) diff --git a/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala b/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala index 32994bf80..ed9f10d40 100644 --- a/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala +++ b/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala @@ -283,18 +283,18 @@ object AcceptorTranslation { λ(expT({l * n}`.`dt, read))(x => acc(f(l)(x))(o)))), x))) - case ocl.KernelCall(name, localSize, globalSize, _, args) => + case kc@ocl.KernelCall(name, localSize, globalSize, n) => def rec(ts: Seq[Phrase[ExpType]], es: Seq[Phrase[ExpType]]): Phrase[CommType] = { ts match { case Nil => - oclImp.KernelCallCmd(name, localSize, globalSize, es)(A.t.dataType, A) + oclImp.KernelCallCmd(name, localSize, globalSize, n)(kc.inTs, kc.outT, kc.args, A) case Seq(arg, tail@_*) => con(arg)(λ(expT(arg.t.dataType, read))(e => rec(tail, es :+ e))) } } - rec(args, Seq()) + rec(kc.args, Seq()) case map@ocl.Map(level, dim) => val (n, dt1, dt2, f, array) = map.unwrap @@ -303,7 +303,7 @@ object AcceptorTranslation { λ(expT(dt1, read))(x => λ(accT(dt2))(o => acc(f(x))(o))), x, A))) - case fc@ocl.OpenCLFunctionCall(name, inTs, args) => + case fc@ocl.OpenCLFunctionCall(name, n) => def rec(ts: Seq[(Phrase[ExpType], DataType)], exps: Seq[Phrase[ExpType]], inTs: Seq[DataType]): Phrase[CommType] = { @@ -311,14 +311,14 @@ object AcceptorTranslation { // with only one argument left to process return the assignment of the OpenCLFunction call case Seq( (arg, inT) ) => con(arg)(λ(expT(inT, read))(e => - A :=|fc.outT| ocl.OpenCLFunctionCall(name, inTs :+ inT, exps :+ e)(fc.outT) )) + A :=|fc.outT| ocl.OpenCLFunctionCall(name, n)(inTs :+ inT, fc.outT, exps :+ e) )) // with a `tail` of arguments left, recurse case Seq( (arg, inT), tail@_* ) => con(arg)(λ(expT(inT, read))(e => rec(tail, exps :+ e, inTs :+ inT) )) } } - rec(args zip inTs, Seq(), Seq()) + rec(fc.args zip fc.inTs, Seq(), Seq()) // CUDA case cuda.AsFragment(rows, columns, layers, dataType, fragmentKind, layout, matrix) => diff --git a/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala b/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala index 2617c22a7..1153d4997 100644 --- a/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala +++ b/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala @@ -119,7 +119,7 @@ object ContinuationTranslation { con(array)(λ(expT((n + m)`.` dt, read))(x => C(Drop(n, m, dt, x)))) - case ffc@ForeignFunctionCall(funDecl, inTs, args) => + case ffc@ForeignFunctionCall(funDecl, n) => def rec(ts: Seq[(Phrase[ExpType], DataType)], exps: Seq[Phrase[ExpType]], inTs: Seq[DataType]): Phrase[CommType] = { @@ -141,9 +141,9 @@ object ContinuationTranslation { `new`.apply } backendNew(ffc.outT, tmp => - Assign(ffc.outT, tmp.wr, ForeignFunctionCall(funDecl, inTs :+ inT, exps :+ e)(ffc.outT)) `;` + Assign(ffc.outT, tmp.wr, ForeignFunctionCall(funDecl, n)( inTs :+ inT, ffc.outT, exps :+ e)) `;` C(tmp.rd)) - case _ => C( ForeignFunctionCall(funDecl, inTs :+ inT, exps :+ e)(ffc.outT) ) + case _ => C( ForeignFunctionCall(funDecl, n)(inTs :+ inT, ffc.outT, exps :+ e) ) })) // with a `tail` of arguments left, rec case Seq( (arg, inT), tail@_* ) => @@ -152,7 +152,7 @@ object ContinuationTranslation { } } - rec(args zip inTs, Seq(), Seq()) + rec(ffc.args zip ffc.inTs, Seq(), Seq()) case Fst(dt1, dt2, pair) => con(pair)(λ(expT(dt1 x dt2, read))(x => @@ -191,17 +191,17 @@ object ContinuationTranslation { con(value)(fun(value.t)(x => con(f(x))(C))) - case ma@MakeArray(elements) => - def rec(func: Vector[Phrase[ExpType]], imp: Vector[Phrase[ExpType]]): Phrase[CommType] = { + case ma@MakeArray(_) => + def rec(func: Seq[Phrase[ExpType]], imp: Seq[Phrase[ExpType]]): Phrase[CommType] = { func match { case xf +: func => con(xf)(fun(expT(ma.dt, read))(xi => rec(func, imp :+ xi) )) - case _ => C(MakeArray(imp)(ma.n, ma.dt)) + case _ => C(MakeArray(ma.n)(ma.dt, imp)) } } - rec(elements, Vector()) + rec(ma.elements, Seq()) case makeDepPair@MakeDepPair(a, fst, sndT, snd) => // Allocate for the resulting dependent pair, @@ -362,7 +362,7 @@ object ContinuationTranslation { `new`(map.n `.` map.dt2, λ(varT(map.n `.` map.dt2))(tmp => acc(map)(tmp.wr) `;` C(tmp.rd))) - case fc@ocl.OpenCLFunctionCall(name, inTs, args) => + case fc@ocl.OpenCLFunctionCall(name, n) => def rec(ts: Seq[(Phrase[ExpType], DataType)], es: Seq[Phrase[ExpType]], inTs: Seq[DataType]): Phrase[CommType] = { @@ -370,14 +370,14 @@ object ContinuationTranslation { // with only one argument left to process continue with the OpenCLFunction call case Seq( (arg, inT) ) => con(arg)(λ(expT(inT, read))(e => - C(ocl.OpenCLFunctionCall(name, inTs :+ inT, es :+ e)(fc.outT)) )) + C(ocl.OpenCLFunctionCall(name, n)(inTs :+ inT, fc.outT, es :+ e)) )) // with a `tail` of arguments left, rec case Seq( (arg, inT), tail@_* ) => con(arg)(λ(expT(inT, read))(e => rec(tail, es :+ e, inTs :+ inT) )) } } - rec(args zip inTs, Seq(), Seq()) + rec(fc.args zip fc.inTs, Seq(), Seq()) case reduceSeq@ocl.ReduceSeq(unroll) => val (n, a, dt1, dt2, f, init, array) = reduceSeq.unwrap diff --git a/src/main/scala/shine/DPIA/Data.scala b/src/main/scala/shine/DPIA/Data.scala index 6ba75f437..ec245940d 100644 --- a/src/main/scala/shine/DPIA/Data.scala +++ b/src/main/scala/shine/DPIA/Data.scala @@ -14,6 +14,8 @@ final case class IntData(i: Int) extends Data(int) { override def toString: String = i.toString } +final case class NatAsIntData(n: Nat) extends Data(int) + final case class FloatData(f: Float) extends Data(f32) { override def toString: String = f.toString + "f" } diff --git a/src/main/scala/shine/DPIA/InferAccessAnnotation.scala b/src/main/scala/shine/DPIA/InferAccessAnnotation.scala index 3d2831d1a..20a797770 100644 --- a/src/main/scala/shine/DPIA/InferAccessAnnotation.scala +++ b/src/main/scala/shine/DPIA/InferAccessAnnotation.scala @@ -532,14 +532,15 @@ private class InferAccessAnnotation { case _ => error() } - case r.ForeignFunction(_) => + case rp.foreignFunction(_, _) => def buildType(t: rt.Type): PhraseType = t match { case dt: rt.DataType => expT(dataType(dt), read) case rt.FunType(in: rt.DataType, out) => expT(in, read) ->: buildType(out) - case _ => - throw Exception("This should not happen") + case rt.DepFunType(d: rt.DataTypeIdentifier, t) => + dFunT(d, buildType(t)) + case _ => throw Exception("This should not happen") } buildType(p.t) diff --git a/src/main/scala/shine/DPIA/fromRise.scala b/src/main/scala/shine/DPIA/fromRise.scala index a9c487b83..32be7922a 100644 --- a/src/main/scala/shine/DPIA/fromRise.scala +++ b/src/main/scala/shine/DPIA/fromRise.scala @@ -673,31 +673,32 @@ object fromRise { Let(s, t, a, x, f))) } - case r.ForeignFunction(decl) => - def collectTypes(t: PhraseType): (Vector[DataType], DataType) = { + case core.foreignFunction(decl, n) => + def collectTypes(t: PhraseType): (Seq[DataTypeIdentifier], DataTypeIdentifier) = { t match { - case ExpType(dt: DataType, `read`) => - (Vector(), dt) - case FunType(ExpType(dt: DataType, `read`), out) => + case ExpType(dt: DataTypeIdentifier, `read`) => (Seq(), dt) + case FunType(ExpType(dt: DataTypeIdentifier, `read`), out) => val (i, o) = collectTypes(out) (dt +: i, o) - case _ => - throw new Exception("This should not be possible") + case DepFunType(_, t) => collectTypes(t) + case _ => throw new Exception("This should not be possible") } } val (inTs, outT) = collectTypes(t) - - def buildFFCall(args: Vector[Phrase[ExpType]] - ): Phrase[_ <: PhraseType] = { - val i = args.length - if (i < inTs.length) { - fun[ExpType](ExpType(inTs(i), read), a => - buildFFCall(args :+ a)) - } else { - ForeignFunctionCall(decl, inTs, args)(outT) - } + assert(inTs.length == n) + + inTs.foldRight[Phrase[_ <: PhraseType]]( + depFun[DataKind](outT)({ + val args = Seq.tabulate(n)(i => Identifier(freshName("x"), ExpType(inTs(i), read))) + args.foldRight[Phrase[_ <: PhraseType]]( + ForeignFunctionCall(decl, n)(inTs, outT, args) + ) { + case (arg, f) => Lambda(arg, f) + } + }) + ) { + case (t, f) => depFun[DataKind](t)(f) } - buildFFCall(Vector()) case core.generate() => fromType { case (expT(IndexType(n), `read`) ->: expT(t, `read`)) ->: @@ -713,7 +714,7 @@ object fromRise { ): Phrase[_ <: PhraseType] = t match { case FunType(in: ExpType, out) => fun[ExpType](in, e => buildArrayPrimitive(out, elements :+ e)) - case ExpType(ArrayType(_, et), _) => MakeArray(elements)(elements.size, et) + case ExpType(ArrayType(_, et), _) => MakeArray(elements.size)(et, elements) case _ => error(s"did not expect $t") } buildArrayPrimitive(t, Vector()) diff --git a/src/main/scala/shine/DPIA/package.scala b/src/main/scala/shine/DPIA/package.scala index 55a6515a1..b9a1b3d44 100644 --- a/src/main/scala/shine/DPIA/package.scala +++ b/src/main/scala/shine/DPIA/package.scala @@ -142,6 +142,12 @@ package object DPIA { } } + object dFunT { + def apply(d: rt.DataTypeIdentifier, t: PhraseType): PhraseType = { + DepFunType[DataKind, PhraseType](fromRise.dataTypeIdentifier(d), t) + } + } + object aFunT { def apply(a: rt.AddressSpaceIdentifier, t: PhraseType): PhraseType = { DepFunType[AddressSpaceKind, PhraseType]( diff --git a/src/main/scala/shine/DPIA/primitives/functional/ForeignFunctionCall.scala b/src/main/scala/shine/DPIA/primitives/functional/ForeignFunctionCall.scala index 714cd3b23..2b1e560cf 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/ForeignFunctionCall.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/ForeignFunctionCall.scala @@ -7,8 +7,14 @@ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -final case class ForeignFunctionCall(funDecl: rise.core.ForeignFunction.Decl, inTs: Seq[DataType], args: Seq[Phrase[ExpType]])(val outT: DataType) extends ExpPrimitive { - {} +final case class ForeignFunctionCall(funDecl: rise.core.ForeignFunction.Decl, n: Int)(val inTs: Seq[DataType], val outT: DataType, val args: Seq[Phrase[ExpType]]) extends ExpPrimitive { + { + args.zip(inTs).foreach({ + case (args, inTs) => + args :: expT(inTs, read) + }) + } override val t: ExpType = expT(outT, read) - override def visitAndRebuild(v: VisitAndRebuild.Visitor): ForeignFunctionCall = new ForeignFunctionCall(funDecl, inTs.map(v.data), args.map(VisitAndRebuild(_, v)))(v.data(outT)) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ForeignFunctionCall = new ForeignFunctionCall(funDecl, n)(inTs.map(v.data), v.data(outT), args.map(VisitAndRebuild(_, v))) + def unwrap: (Seq[DataType], DataType, Seq[Phrase[ExpType]]) = (inTs, outT, args) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/MakeArray.scala b/src/main/scala/shine/DPIA/primitives/functional/MakeArray.scala index 528f9a149..e887fd51a 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/MakeArray.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/MakeArray.scala @@ -7,9 +7,11 @@ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -final case class MakeArray(elements: Vector[Phrase[ExpType]])(val n: Nat, val dt: DataType) extends ExpPrimitive { - {} +final case class MakeArray(n: Int)(val dt: DataType, val elements: Seq[Phrase[ExpType]]) extends ExpPrimitive { + { + elements.foreach(_ :: expT(dt, read)) + } override val t: ExpType = expT(ArrayType(n, dt), read) - override def visitAndRebuild(v: VisitAndRebuild.Visitor): MakeArray = new MakeArray(elements.map(VisitAndRebuild(_, v)))(v.nat(n), v.data(dt)) - def unwrap: (Nat, DataType) = (n, dt) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): MakeArray = new MakeArray(n)(v.data(dt), elements.map(VisitAndRebuild(_, v))) + def unwrap: (DataType, Seq[Phrase[ExpType]]) = (dt, elements) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/primitives.dpia b/src/main/scala/shine/DPIA/primitives/functional/primitives.dpia index 37e8ede1c..078144ea0 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/primitives.dpia +++ b/src/main/scala/shine/DPIA/primitives/functional/primitives.dpia @@ -29,9 +29,10 @@ def depZip(n: nat, ft1: nat2data, ft2: nat2data, def drop(n: nat, m: nat, dt: data, array: exp[n+m.dt, read]): exp[m.dt, read] -def foreignFunctionCall{funDecl: rise.core.ForeignFunction.Decl, - inTs: Seq[DataType], - args: Seq[Phrase[ExpType]]}(outT: data): exp[outT, read] +def foreignFunctionCall{funDecl: rise.core.ForeignFunction.Decl, n: Int} + (inTs: n*data, + outT: data, + args: n*exp[*inTs, read]): exp[outT, read] def fst(dt1: data, dt2: data, pair: exp[(dt1, dt2), read]): exp[dt1, read] @@ -58,7 +59,7 @@ def join(n: nat, m: nat, a: access, dt: data, array: exp[n.m.dt, a]): exp[(n*m). def let(dt1: data, dt2: data, a: access, value: exp[dt1, read], f: exp[dt1, read] -> exp[dt2, a]): exp[dt2, a] -def makeArray{elements: Vector[Phrase[ExpType]]}(n: nat, dt: data): exp[n.dt, read] +def makeArray{n: Int}(dt: data, elements: n*exp[dt, read]): exp[n.dt, read] // def makeDepPair(...) diff --git a/src/main/scala/shine/OpenCL/Compilation/HostCodeGenerator.scala b/src/main/scala/shine/OpenCL/Compilation/HostCodeGenerator.scala index a2bb26544..f93a1e3bd 100644 --- a/src/main/scala/shine/OpenCL/Compilation/HostCodeGenerator.scala +++ b/src/main/scala/shine/OpenCL/Compilation/HostCodeGenerator.scala @@ -25,8 +25,8 @@ case class HostCodeGenerator(override val decls: C.Compilation.CodeGenerator.Dec override def name: String = "OpenCL Host" override def cmd(env: Environment): Phrase[CommType] => Stmt = { - case k@KernelCallCmd(name, LocalSize(ls), GlobalSize(gs), args) => - kernelCallCmd(name, ls, gs, k.output, args, env) + case k@KernelCallCmd(name, LocalSize(ls), GlobalSize(gs), n) => + kernelCallCmd(name, ls, gs, k.output, k.args, env) case n@NewManagedBuffer(access) => val (dt, Lambda(v, p)) = n.unwrap newManagedBuffer(dt, access, v, p, env) diff --git a/src/main/scala/shine/OpenCL/Compilation/HostManagedBuffers.scala b/src/main/scala/shine/OpenCL/Compilation/HostManagedBuffers.scala index 51a2f8e8f..6c3e81f11 100644 --- a/src/main/scala/shine/OpenCL/Compilation/HostManagedBuffers.scala +++ b/src/main/scala/shine/OpenCL/Compilation/HostManagedBuffers.scala @@ -97,14 +97,15 @@ object HostManagedBuffers { collectWrites(lhs, metadata.host_writes) collectReads(rhs, allocs, metadata.host_reads) Stop(p) - case k@ocl.KernelCallCmd(_, _, _, in) => - in.foreach(collectReads(_, allocs, metadata.device_reads)) + case k@ocl.KernelCallCmd(_, _, _, _) => + k.args.foreach(collectReads(_, allocs, metadata.device_reads)) collectWrites(k.output, metadata.device_writes) - ((k.output, DEVICE_WRITE) +: in.map(_ -> DEVICE_READ)).foreach { + ((k.output, DEVICE_WRITE) +: k.args.map(_ -> DEVICE_READ)).foreach { case (i: Identifier[_], a) => recordManagedAccess(managed, i, a) case (Proj1(i: Identifier[_]), a) => recordManagedAccess(managed, i, a) case (Proj2(i: Identifier[_]), a) => recordManagedAccess(managed, i, a) case (Natural(_), _) => + case (Literal(NatAsIntData(_)), _) => case (unexpected, _) => throw new Exception(s"did not expect $unexpected") } Stop(p) @@ -154,12 +155,13 @@ object HostManagedBuffers { val x2 = managed(x)._2.asInstanceOf[Identifier[VarType]] Continue(ocl.NewManagedBuffer(access)(dt, Lambda(x2, body)), this) case _: dpia.New | _: Lambda[_, _] | _: dpia.Seq | - _: Proj2[_, _] | _: Proj1[_, _] | Natural(_) => + _: Proj2[_, _] | _: Proj1[_, _] | Natural(_) | Literal(NatAsIntData(_)) => Continue(p, this) - case k@ocl.KernelCallCmd(name, ls, gs, args) => + case k@ocl.KernelCallCmd(name, ls, gs, n) => val newOutput = VisitAndRebuild(k.output, this) - Stop(ocl.KernelCallCmd(name, ls, gs, args.map(VisitAndRebuild(_, this)))( - newOutput.t.dataType, newOutput)) + val newArgs = k.args.map(VisitAndRebuild(_, this)) + Stop(ocl.KernelCallCmd(name, ls, gs, n)( + newArgs.map(_.t.dataType), newOutput.t.dataType, k.args.map(VisitAndRebuild(_, this)), newOutput)) case _: HostExecution => Stop(p) case unexpected => throw new Exception(s"did not expect $unexpected") } @@ -234,8 +236,8 @@ object HostManagedBuffers { case PadClamp(_, _, _, _, e) => collectReads(e, allocs, reads) case Cast(_, _, e) => collectReads(e, allocs, reads) - case ForeignFunctionCall(_, _, es) => - es.foreach { + case ffc@ForeignFunctionCall(_, _) => + ffc.args.foreach { collectReads(_, allocs, reads) } case NatAsIndex(_, e) => collectReads(e, allocs, reads) @@ -245,8 +247,8 @@ object HostManagedBuffers { case MakePair(_, _, _, e1, e2) => collectReads(e1, allocs, reads); collectReads(e2, allocs, reads) case Reorder(_, _, _, _, _, e) => collectReads(e, allocs, reads) - case MakeArray(es) => - es.foreach { + case m@MakeArray(_) => + m.elements.foreach { collectReads(_, allocs, reads) } case Gather(_, _, _, e1, e2) => diff --git a/src/main/scala/shine/OpenCL/Compilation/KernelCodeGenerator.scala b/src/main/scala/shine/OpenCL/Compilation/KernelCodeGenerator.scala index af5ede7dd..0445506dc 100644 --- a/src/main/scala/shine/OpenCL/Compilation/KernelCodeGenerator.scala +++ b/src/main/scala/shine/OpenCL/Compilation/KernelCodeGenerator.scala @@ -191,8 +191,8 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations, case IdxVec(_, _, i, e) => CCodeGen.codeGenIdx(i, e, env, path, cont) - case OpenCLFunctionCall(name, _, args) => - CCodeGen.codeGenForeignCall(name, args, env, Nil, cont) + case fc@OpenCLFunctionCall(name, _) => + CCodeGen.codeGenForeignCall(name, fc.args, env, Nil, cont) case idx: ocl.IdxDistribute => path match { // TODO: ensure that i % stride == init ? diff --git a/src/main/scala/shine/OpenCL/Compilation/Passes/InsertMemoryBarriers.scala b/src/main/scala/shine/OpenCL/Compilation/Passes/InsertMemoryBarriers.scala index e772d1d50..fbee659d0 100644 --- a/src/main/scala/shine/OpenCL/Compilation/Passes/InsertMemoryBarriers.scala +++ b/src/main/scala/shine/OpenCL/Compilation/Passes/InsertMemoryBarriers.scala @@ -212,8 +212,8 @@ object InsertMemoryBarriers { case PadClamp(_, _, _, _, e) => collectReads(e, allocs, reads) case Cast(_, _, e) => collectReads(e, allocs, reads) - case ForeignFunctionCall(_, _, es) => - es.foreach { + case ffc@ForeignFunctionCall(_, _) => + ffc.args.foreach { collectReads(_, allocs, reads) } case NatAsIndex(_, e) => collectReads(e, allocs, reads) @@ -223,8 +223,8 @@ object InsertMemoryBarriers { case MakePair(_, _, _, e1, e2) => collectReads(e1, allocs, reads); collectReads(e2, allocs, reads) case Reorder(_, _, _, _, _, e) => collectReads(e, allocs, reads) - case MakeArray(es) => - es.foreach { + case m@MakeArray(_) => + m.elements.foreach { collectReads(_, allocs, reads) } case Gather(_, _, _, e1, e2) => diff --git a/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala b/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala index d6622ff82..9d2a6dcb9 100644 --- a/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala +++ b/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala @@ -25,9 +25,10 @@ object SeparateHostAndKernelCode { val (closedDef, args) = closeDefinition(r.input) val kernelDef = KernelDef(name, closedDef, localSize, globalSize) kernelDefinitions += kernelDef - Stop(KernelCall(name, localSize, globalSize, + Stop(KernelCall(name, localSize, globalSize, args.length)( kernelDef.paramTypes.map(_.dataType), - args)(kernelDef.returnType.dataType).asInstanceOf[Phrase[T]]) + kernelDef.returnType.dataType, + args).asInstanceOf[Phrase[T]]) // on the fly beta-reduction case Apply(fun, arg) => @@ -59,7 +60,7 @@ object SeparateHostAndKernelCode { freeNats match { case v +: rest => iterNats( DepLambda[NatKind](NatIdentifier(v.name, v.range))(definition), - Natural(v) +: args, rest) + Literal(NatAsIntData(v)) +: args, rest) case Nil => (definition, args) } } diff --git a/src/main/scala/shine/OpenCL/primitives/functional/KernelCall.scala b/src/main/scala/shine/OpenCL/primitives/functional/KernelCall.scala index d2f150c9a..0f66f31cc 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/KernelCall.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/KernelCall.scala @@ -7,8 +7,14 @@ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -final case class KernelCall(name: String, localSize: shine.OpenCL.LocalSize, globalSize: shine.OpenCL.GlobalSize, inTs: Seq[DataType], args: Seq[Phrase[ExpType]])(val outT: DataType) extends ExpPrimitive { - {} +final case class KernelCall(name: String, localSize: shine.OpenCL.LocalSize, globalSize: shine.OpenCL.GlobalSize, n: Int)(val inTs: Seq[DataType], val outT: DataType, val args: Seq[Phrase[ExpType]]) extends ExpPrimitive { + { + args.zip(inTs).foreach({ + case (args, inTs) => + args :: expT(inTs, read) + }) + } override val t: ExpType = expT(outT, write) - override def visitAndRebuild(v: VisitAndRebuild.Visitor): KernelCall = new KernelCall(name, localSize.visitAndRebuild(v), globalSize.visitAndRebuild(v), inTs.map(v.data), args.map(VisitAndRebuild(_, v)))(v.data(outT)) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): KernelCall = new KernelCall(name, localSize.visitAndRebuild(v), globalSize.visitAndRebuild(v), n)(inTs.map(v.data), v.data(outT), args.map(VisitAndRebuild(_, v))) + def unwrap: (Seq[DataType], DataType, Seq[Phrase[ExpType]]) = (inTs, outT, args) } diff --git a/src/main/scala/shine/OpenCL/primitives/functional/OpenCLFunctionCall.scala b/src/main/scala/shine/OpenCL/primitives/functional/OpenCLFunctionCall.scala index a3f2ec8a0..363c6d792 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/OpenCLFunctionCall.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/OpenCLFunctionCall.scala @@ -7,8 +7,14 @@ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -final case class OpenCLFunctionCall(name: String, inTs: Seq[DataType], args: Seq[Phrase[ExpType]])(val outT: DataType) extends ExpPrimitive { - {} +final case class OpenCLFunctionCall(name: String, n: Int)(val inTs: Seq[DataType], val outT: DataType, val args: Seq[Phrase[ExpType]]) extends ExpPrimitive { + { + args.zip(inTs).foreach({ + case (args, inTs) => + args :: expT(inTs, read) + }) + } override val t: ExpType = expT(outT, write) - override def visitAndRebuild(v: VisitAndRebuild.Visitor): OpenCLFunctionCall = new OpenCLFunctionCall(name, inTs.map(v.data), args.map(VisitAndRebuild(_, v)))(v.data(outT)) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): OpenCLFunctionCall = new OpenCLFunctionCall(name, n)(inTs.map(v.data), v.data(outT), args.map(VisitAndRebuild(_, v))) + def unwrap: (Seq[DataType], DataType, Seq[Phrase[ExpType]]) = (inTs, outT, args) } diff --git a/src/main/scala/shine/OpenCL/primitives/functional/primitives.dpia b/src/main/scala/shine/OpenCL/primitives/functional/primitives.dpia index 7f24dbfd8..94f67520a 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/primitives.dpia +++ b/src/main/scala/shine/OpenCL/primitives/functional/primitives.dpia @@ -11,22 +11,20 @@ def iterate(a: address, n: nat, m: nat, k: nat, dt: data, f: (l: nat) -> exp[(l*n).dt, read] -> exp[l.dt, write], array: exp[(m*(n^k)).dt, read]): exp[m.dt, write] -def kernelCall{name: String, - localSize: shine.OpenCL.LocalSize, - globalSize: shine.OpenCL.GlobalSize, - inTs: Seq[DataType], - args: Seq[Phrase[ExpType]]} - (outT: data): exp[outT, write] +def kernelCall{name: String, localSize: shine.OpenCL.LocalSize, globalSize: shine.OpenCL.GlobalSize, n: Int} + (inTs: n*data, + outT: data, + args: n*exp[*inTs, read]): exp[outT, write] def map{level: shine.OpenCL.ParallelismLevel, dim: Int} (n: nat, dt1: data, dt2: data, f: exp[dt1, read] -> exp[dt2, write], array: exp[n.dt1, read]): exp[n.dt2, write] -def openCLFunctionCall{name: String, - inTs: Seq[DataType], - args: Seq[Phrase[ExpType]]} - (outT: data): exp[outT, write] +def openCLFunctionCall{name: String, n: Int} + (inTs: n*data, + outT: data, + args: n*exp[*inTs, read]): exp[outT, write] def reduceSeq{unroll: Boolean} (n: nat, a: address, dt1: data, dt2: data, diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/KernelCallCmd.scala b/src/main/scala/shine/OpenCL/primitives/imperative/KernelCallCmd.scala index 2e420b356..24e26716d 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/KernelCallCmd.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/KernelCallCmd.scala @@ -7,11 +7,15 @@ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -final case class KernelCallCmd(name: String, localSize: shine.OpenCL.LocalSize, globalSize: shine.OpenCL.GlobalSize, args: Seq[Phrase[ExpType]])(val dt: DataType, val output: Phrase[AccType]) extends CommandPrimitive { +final case class KernelCallCmd(name: String, localSize: shine.OpenCL.LocalSize, globalSize: shine.OpenCL.GlobalSize, n: Int)(val inTs: Seq[DataType], val dt: DataType, val args: Seq[Phrase[ExpType]], val output: Phrase[AccType]) extends CommandPrimitive { { + args.zip(inTs).foreach({ + case (args, inTs) => + args :: expT(inTs, read) + }) output :: accT(dt) } override val t: CommType = comm - override def visitAndRebuild(v: VisitAndRebuild.Visitor): KernelCallCmd = new KernelCallCmd(name, localSize.visitAndRebuild(v), globalSize.visitAndRebuild(v), args.map(VisitAndRebuild(_, v)))(v.data(dt), VisitAndRebuild(output, v)) - def unwrap: (DataType, Phrase[AccType]) = (dt, output) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): KernelCallCmd = new KernelCallCmd(name, localSize.visitAndRebuild(v), globalSize.visitAndRebuild(v), n)(inTs.map(v.data), v.data(dt), args.map(VisitAndRebuild(_, v)), VisitAndRebuild(output, v)) + def unwrap: (Seq[DataType], DataType, Seq[Phrase[ExpType]], Phrase[AccType]) = (inTs, dt, args, output) } diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/primitives.dpia b/src/main/scala/shine/OpenCL/primitives/imperative/primitives.dpia index 5b78dbaa0..de85a4e0c 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/primitives.dpia +++ b/src/main/scala/shine/OpenCL/primitives/imperative/primitives.dpia @@ -11,11 +11,8 @@ def idxDistribute{parallelismLevel: shine.OpenCL.ParallelismLevel} def idxDistributeAcc{parallelismLevel: shine.OpenCL.ParallelismLevel} (m: nat, n: nat, stride: nat, dt: data, array: acc[m.dt]): acc[n.dt] -def kernelCallCmd{name: String, - localSize: shine.OpenCL.LocalSize, - globalSize: shine.OpenCL.GlobalSize, - args: Seq[Phrase[ExpType]]} - (dt: data, output: acc[dt]): comm +def kernelCallCmd{name: String, localSize: shine.OpenCL.LocalSize, globalSize: shine.OpenCL.GlobalSize, n: Int} + (inTs: n*data, dt: data, args: n*exp[*inTs, read], output: acc[dt]): comm def new(a: address, dt: data, f: var[dt] -> comm): comm diff --git a/src/main/scala/shine/OpenMP/CodeGenerator.scala b/src/main/scala/shine/OpenMP/CodeGenerator.scala index e567531c9..1ccb6af92 100644 --- a/src/main/scala/shine/OpenMP/CodeGenerator.scala +++ b/src/main/scala/shine/OpenMP/CodeGenerator.scala @@ -93,8 +93,8 @@ class CodeGenerator(override val decls: CCodeGenerator.Declarations, } case _ => phrase |> super.exp(env, path, cont) } - case ffc@ForeignFunctionCall(f, inTs, args) => - OpenMPCodeGen.codeGenForeignFunctionCall(f, inTs, ffc.outT, args, env, path, cont) + case ffc@ForeignFunctionCall(f, _) => + OpenMPCodeGen.codeGenForeignFunctionCall(f, ffc.inTs, ffc.outT, ffc.args, env, path, cont) case AsVectorAligned(n, _, dt, _, e) => path match { case (i: CIntExpr) :: (j: CIntExpr) :: ps => e |> exp(env, CIntExpr((i * n) + j) :: ps, cont) diff --git a/src/test/scala/apps/separableConvolution2DNaiveEqsat.scala b/src/test/scala/apps/separableConvolution2DNaiveEqsat.scala index 9d00f52df..0cc00d395 100644 --- a/src/test/scala/apps/separableConvolution2DNaiveEqsat.scala +++ b/src/test/scala/apps/separableConvolution2DNaiveEqsat.scala @@ -62,7 +62,6 @@ class separableConvolution2DNaiveEqsat extends test_util.Tests { } case DepApp(f, x) => everywhere(s)(f).map(DepApp(_, x)(p.t)) case Literal(_) => Nil - case _: ForeignFunction => Nil case _: TypeAnnotation => throw new Exception("Type annotations should be gone.") case _: TypeAssertion => throw new Exception("Type assertions should be gone.") case _: Opaque => throw new Exception("Opaque expressions should be gone.") diff --git a/src/test/scala/shine/DPIA/Primitives/Generate.scala b/src/test/scala/shine/DPIA/Primitives/Generate.scala index 26808d595..a65b95293 100644 --- a/src/test/scala/shine/DPIA/Primitives/Generate.scala +++ b/src/test/scala/shine/DPIA/Primitives/Generate.scala @@ -1,6 +1,6 @@ package shine.DPIA.Primitives -import rise.core.{ForeignFunction, Lambda} +import rise.core.{Expr, Lambda} import rise.core.DSL._ import rise.core.primitives._ import Type._ @@ -11,7 +11,7 @@ import util.gen.c.function class Generate extends test_util.Tests { val id: ToBeTyped[Lambda] = fun(x => x) val addT: ToBeTyped[Lambda] = fun(x => fst(x) + snd(x)) - val cos: ToBeTyped[ForeignFunction] = foreignFun("callCos", Seq("x"), "{ return cos(x); }", f64 ->: f64) + val cos: ToBeTyped[Expr] = foreignFun("callCos", Seq("x"), "{ return cos(x); }", f64 ->: f64) test("Very simple one-dimensional generate generates syntactically correct code in C.") { val e = depFun((n: Nat) => generate(fun(IndexType(n))(i => cast(i) + lf64(1.0))) |> mapSeq(id)) From e4a63a005a8c0639e457a3c41df225fa1faee097 Mon Sep 17 00:00:00 2001 From: Michel Steuwer Date: Fri, 7 May 2021 10:55:10 +0100 Subject: [PATCH 2/5] Cleanup and added comments --- .../scala/meta/generator/DPIAPrimitives.scala | 50 +++++--------- .../scala/meta/generator/RisePrimitives.scala | 69 ++++++++++++------- .../main/scala/meta/parser/DPIA/Decl.scala | 2 - .../main/scala/meta/parser/rise/Type.scala | 3 +- 4 files changed, 65 insertions(+), 59 deletions(-) diff --git a/meta/src/main/scala/meta/generator/DPIAPrimitives.scala b/meta/src/main/scala/meta/generator/DPIAPrimitives.scala index 387e19494..ca1269664 100644 --- a/meta/src/main/scala/meta/generator/DPIAPrimitives.scala +++ b/meta/src/main/scala/meta/generator/DPIAPrimitives.scala @@ -140,7 +140,9 @@ ${generateCaseClass(Type.Name(name), toParamList(definition, scalaParams), param }""" case Right(DPIA.Type.AST.VariadicType(n, typeAST)) => getUnrolledIds(typeAST) match { - case Some(unrolledIds) => + case Seq() => + q"${Term.Name(param.id.name)}.foreach(_ :: ${generateTerm(typeAST)})" + case unrolledIds => val zips = unrolledIds.foldLeft[Term](Term.Name(param.id.name)) { case (term, id) => q"$term.zip(${Term.Name(id.name)})" } @@ -151,8 +153,6 @@ ${generateCaseClass(Type.Name(name), toParamList(definition, scalaParams), param q"""$zips.foreach { case ($pattern) => $p :: ${generateTerm(typeAST)} }""" - case None => - q"${Term.Name(param.id.name)}.foreach(_ :: ${generateTerm(typeAST)})" } case Right(typeAST) => q"${Term.Name(param.id.name)} :: ${generateTerm(typeAST)}" @@ -162,62 +162,50 @@ ${generateCaseClass(Type.Name(name), toParamList(definition, scalaParams), param }""" } - def getUnrolledIds(typeAST: DPIA.Type.AST): Option[Seq[rise.Type.AST.UnrolledIdentifier]] = { + def getUnrolledIds(typeAST: DPIA.Type.AST): Seq[rise.Type.AST.UnrolledIdentifier] = { import DPIA.Type.AST typeAST match { case AST.ExpType(dataType, _) => getUnrolledIds(dataType) case AST.AccType(dataType) => getUnrolledIds(dataType) - case AST.CommType => None + case AST.CommType => Seq() case AST.PairType(lhs, rhs) => - for { - ids1 <- getUnrolledIds(lhs) - ids2 <- getUnrolledIds(rhs) - } yield ids1 concat ids2 + getUnrolledIds(lhs) concat getUnrolledIds(rhs) case AST.FunType(inT, outT) => - for { - ids1 <- getUnrolledIds(inT) - ids2 <- getUnrolledIds(outT) - } yield ids1 concat ids2 + getUnrolledIds(inT) concat getUnrolledIds(outT) case AST.DepFunType(_, _, t) => getUnrolledIds(t) - case AST.Identifier(_) => None + case AST.Identifier(_) => Seq() case AST.VariadicType(n, ty) => throw new Exception("This function should not be called on a variadic type") } } - def getUnrolledIds(typeAST: rise.Type.AST): Option[Seq[rise.Type.AST.UnrolledIdentifier]] = { + def getUnrolledIds(typeAST: rise.Type.AST): Seq[rise.Type.AST.UnrolledIdentifier] = { import rise.Type.AST typeAST match { - case AST.Identifier(_) => None - case id@AST.UnrolledIdentifier(_) => Some(Seq(id)) + case AST.Identifier(_) => Seq() + case id@AST.UnrolledIdentifier(_) => Seq(id) case AST.FunType(inT, outT) => - for { - ids1 <- getUnrolledIds(inT) - ids2 <- getUnrolledIds(outT) - } yield ids1 concat ids2 + getUnrolledIds(inT) concat getUnrolledIds(outT) case AST.DepFunType(_, _, t) => getUnrolledIds(t) case AST.ImplicitDepFunType(_, _, t) => getUnrolledIds(t) case AST.VariadicFunType(_, _, _) => throw new Exception("This function should not be called on a variadic type") case AST.VariadicDepFunType(_, _, _, _) => throw new Exception("This function should not be called on a variadic type") - case AST.ScalarType(_) => None - case AST.NatType => None - case AST.OpaqueType(_) => None + case AST.ScalarType(_) => Seq() + case AST.NatType => Seq() + case AST.OpaqueType(_) => Seq() case AST.VectorType(_, elemType) => getUnrolledIds(elemType) - case AST.IndexType(_) => None + case AST.IndexType(_) => Seq() case AST.PairType(lhs, rhs) => - for { - ids1 <- getUnrolledIds(lhs) - ids2 <- getUnrolledIds(rhs) - } yield ids1 concat ids2 + getUnrolledIds(lhs) concat getUnrolledIds(rhs) case AST.DepPairType(_, _, t) => getUnrolledIds(t) - case AST.NatToDataApply(_, _) => None + case AST.NatToDataApply(_, _) => Seq() case AST.NatToDataLambda(_, t) => getUnrolledIds(t) case AST.ArrayType(_, elemType) => getUnrolledIds(elemType) - case AST.DepArrayType(_, _) => None + case AST.DepArrayType(_, _) => Seq() case AST.FragmentType(_, _, _, dt, _, _) => getUnrolledIds(dt) case AST.ManagedBufferType(t) => getUnrolledIds(t) } diff --git a/meta/src/main/scala/meta/generator/RisePrimitives.scala b/meta/src/main/scala/meta/generator/RisePrimitives.scala index d85e2d7f3..7523d0046 100644 --- a/meta/src/main/scala/meta/generator/RisePrimitives.scala +++ b/meta/src/main/scala/meta/generator/RisePrimitives.scala @@ -9,28 +9,30 @@ import scala.meta.Term object RisePrimitives { def main(args: Array[String]): Unit = { - val sourceDir = args.head + val sourceDir = args.head // excepts one argument that is the source directory of the rise repo (i.e. 'rise/src') val risePath = os.Path(sourceDir) / "rise" - os.walk.stream(risePath).filter(_.ext == "rise").foreach(path => { + os.walk.stream(risePath).filter(_.ext == "rise").foreach(path => { // for each file with the `.rise` extension ... - val definition = os.read(path) - parse(definition, rise.Decl.PrimitiveDeclarations(_)) match { + val definition = os.read(path) // ... read the file content + parse(definition, rise.Decl.PrimitiveDeclarations(_)) match { // ... and parse it case failure: Parsed.Failure => println(s"Failed to parse `${failure.extra.input}'") println(s" $failure") - case Parsed.Success(seq, _) => - seq.foreach { + case Parsed.Success(seq, _) => seq.foreach { // ... if successful go over all case rise.Decl.AST.PrimitiveDeclaration(rise.Decl.AST.Identifier(name), scalaParams, typeSignature) - if rise.isWellKindedType(toParamList(definition, scalaParams), typeSignature) => + if rise.isWellKindedType( // ... well kinded declarations + toParamList(definition, scalaParams), typeSignature) => val outputPath = (path / os.up) / s"$name.scala" println(s"Generate $outputPath") + // ... and generate class or object definition val generatedDef = generate(name, toParamList(definition, scalaParams), typeSignature) import scala.meta._ val packageName = path.relativeTo(risePath).segments.dropRight(1).foldLeft[Term.Ref](Term.Name("rise")) { case (t, name) => Term.Select(t, Term.Name(name)) } + // ... combine into the generated source code val code = s"""// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // |// This file is automatically generated and should not be changed manually // |// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // @@ -48,7 +50,7 @@ import arithexpr.arithmetic._ }""".toString()} |""".stripMargin - os.write.over(outputPath, code) + os.write.over(outputPath, code) // ... and finally write out to disk. case rise.Decl.AST.PrimitiveDeclaration(name, _, typeSignature) => println(s"Could not generate code for `$name' as type signature `$typeSignature' is not well kinded.") } @@ -56,6 +58,7 @@ import arithexpr.arithmetic._ }) } + // parse scala parameters (i.e. parameters that are not part of the rise language) into a list of parameters def toParamList(definition: String, scalaParams: Option[(Int, Int)]): Option[List[Term.Param]] = { import scala.meta._ scalaParams.map { case (start, end) => @@ -65,6 +68,7 @@ import arithexpr.arithmetic._ } } + // generate either an object (if there are no scala parameters) or a case class (if there are some) def generate(name: String, params: Option[List[Term.Param]], typeSignature: rise.Type.AST): scala.meta.Term.Block = params match { case None => generateObject(name, typeSignature) @@ -131,12 +135,11 @@ import arithexpr.arithmetic._ } def unapply(arg: rise.core.Expr): ${if (types.length > 1) { - t"Option[(..$types)]" + t"Option[(..$types)]" // create tuple if there are multiple type parameters } else { t"Option[..$types]"} } = arg match { - case p: Primitive => - Some(..${generateMemberAccesses(args)}) + case p: Primitive => Some(..${generateMemberAccesses(args)}) case _ => None } } @@ -144,6 +147,8 @@ import arithexpr.arithmetic._ generated } + // generate a scala term representing the type scheme, i.e. something like this for map: + // impl { (n: Nat) => impl { (s: DataType) => impl { (t: DataType) => (s ->: t) ->: n`.`s ->: n`.`t } } } def generateTypeScheme(typeAST: rise.Type.AST): scala.meta.Term = { import scala.meta._ typeAST match { @@ -154,26 +159,37 @@ import arithexpr.arithmetic._ case rise.Type.AST.ImplicitDepFunType(id, kind, t) => q"impl((${Term.Name(id.name)}: ${Type.Name(kindName(kind))}) => ${generateTypeScheme(t)})" case rise.Type.AST.VariadicFunType(_, rise.Type.AST.UnrolledIdentifier(inTs), outT) => + // variadic function type with unrolled identifier, e.g.: n*(*inTs ->) outT + // generates: inTs.foldRight(outT){ case (lhsT, rhsT) => lhs ->: rhs } + // to represent n-many function types: inT0 ->: inT1 ->: ... ->: outT q"""${Term.Name(inTs)}.foldRight(${generateTypeScheme(outT)}: Type) { case (lhsT, rhsT) => lhsT ->: rhsT }""" - case rise.Type.AST.VariadicFunType(rise.Type.AST.Identifier(n), inTs, outT) => - q"""Seq.fill(${Term.Name(n)})(${generateDataType(inTs)}).foldRight(${generateTypeScheme(outT)}: Type) { + case rise.Type.AST.VariadicFunType(rise.Type.AST.Identifier(n), dt, outT) => + // variadic function type without an unrolled identifier, e.g: n*(dt ->) outT + // generated: Seq.fill(n)(dt).foldRight(outT){ case (lhsT, rhsT) => lhs ->: rhs } + // to represent n-many function types: dt -> dt -> ... -> outT + q"""Seq.fill(${Term.Name(n)})(${generateDataType(dt)}).foldRight(${generateTypeScheme(outT)}: Type) { case (lhsT, rhsT) => lhsT ->: rhsT }""" - case rise.Type.AST.VariadicDepFunType(n, id, kind, t) => + case rise.Type.AST.VariadicDepFunType(n, ids, kind, t) => + // variadic dependent function type, e.g.: n*((ids: kind) ->) t + // generates: + // val ids = Seq.fill(n)(DataTypeIdentifier(freshName("dt"), isExplicit = true)) + // ids.foldRight(t){ case (id, t) => DepFunType[DataKind](id, t) } + // to represent n-many dependent function types: (id0: kind) -> (id1: kind) -> ... -> t val (createIds, typeName) = kind match { case AST.Data => (q"""DataTypeIdentifier(freshName("dt"), isExplicit = true)""", Type.Name("DataKind")) - case AST.Address => ??? - case AST.Nat2Nat => ??? - case AST.Nat2Data => ??? - case AST.Nat => ??? - case AST.Fragment => ??? - case AST.MatrixLayout => ??? + case AST.Address => (q"""AddressSpaceIdentifier(freshName("a"), isExplicit = true)""", Type.Name("AddressSpaceKind")) + case AST.Nat2Nat => (q"""NatToNatIdentifier(freshName("n2n"), isExplicit = true)""", Type.Name("NatToNatKind")) + case AST.Nat2Data => (q"""NatToDataIdentifier(freshName("n2d"), isExplicit = true)""", Type.Name("NatToDataKind")) + case AST.Nat => (q"""NatIdentifier(freshName("n"), isExplicit = true)""", Type.Name("NatKind")) + case AST.Fragment => throw new Exception("No support for Fragment Kind yet") + case AST.MatrixLayout => throw new Exception("No support for Matrix Layout Kind yet") } q"""{ - val ${Pat.Var(Term.Name(id.name))} = Seq.fill(${Term.Name(n.name)})($createIds) - ${Term.Name(id.name)}.foldRight(${generateTypeScheme(t)}: Type) { + val ${Pat.Var(Term.Name(ids.name))} = Seq.fill(${Term.Name(n.name)})($createIds) + ${Term.Name(ids.name)}.foldRight(${generateTypeScheme(t)}: Type) { case (id, t) => DepFunType[$typeName, Type](id, t) } }""" @@ -181,6 +197,7 @@ import arithexpr.arithmetic._ } } + // generate a scala term representing a rise data type def generateDataType(typeAST: rise.Type.AST): scala.meta.Term = { import scala.meta._ typeAST match { @@ -203,7 +220,7 @@ import arithexpr.arithmetic._ case rise.Type.AST.DepPairType(id, kind, t) => kind match { case Kind.AST.Nat => q"Nat `**` ((${Term.Name(id.name)}: Nat) => ${generateDataType(t)})" - case _ => ??? + case _ => throw new Exception("DepPair types currently only support Nat Kind") } case rise.Type.AST.NatToDataApply(f, n) => q"NatToDataApply(${generateDataType(f)}, ${generateNat(n)})" @@ -214,12 +231,14 @@ import arithexpr.arithmetic._ case rise.Type.AST.DepArrayType(size, fdt) => q"DepArrayType(${generateNat(size)}, ${generateDataType(fdt)})" case rise.Type.AST.FragmentType(n, m, k, elemType, fKind, mLayout) => - q"FragmentType(${generateNat(n)}, ${generateNat(m)}, ${generateNat(k)}, ${generateDataType(elemType)}, ${generateFragment(fKind)}, ${generateMatrixLayout(mLayout)})" + q"""FragmentType(${generateNat(n)}, ${generateNat(m)}, ${generateNat(k)}, ${generateDataType(elemType)}, + ${generateFragment(fKind)}, ${generateMatrixLayout(mLayout)})""" case rise.Type.AST.ManagedBufferType(dt) => q"ManagedBufferType(${generateDataType(dt)})" case rise.Type.AST.FunType(_, _) | rise.Type.AST.DepFunType(_, _, _) | rise.Type.AST.ImplicitDepFunType(_, _, _) | rise.Type.AST.VariadicFunType(_, _, _) | - rise.Type.AST.VariadicDepFunType(_, _, _, _) => ??? + rise.Type.AST.VariadicDepFunType(_, _, _, _) => + throw new Exception("This should not happen, there are not data types") } } diff --git a/meta/src/main/scala/meta/parser/DPIA/Decl.scala b/meta/src/main/scala/meta/parser/DPIA/Decl.scala index 505a508c9..f5a105e39 100644 --- a/meta/src/main/scala/meta/parser/DPIA/Decl.scala +++ b/meta/src/main/scala/meta/parser/DPIA/Decl.scala @@ -18,8 +18,6 @@ object Decl { def PrimitiveDeclarations[_: P]: P[Seq[AST.PrimitiveDeclaration]] = P(Start ~ PrimitiveDeclaration.rep(1) ~ End) - // def drop(n: nat, m: nat, t: data, input: exp[n+m.t, read]): exp[m.t, read] - // def mapGlobal[dim: Int](n: nat, s: data, t: data, f: exp[s, read] -> exp[t, read], array: exp[n.s, read]): exp[n.t, read] def PrimitiveDeclaration[_: P]: P[AST.PrimitiveDeclaration] = { import scalaparse.Scala.TrailingCommaOps def ScalaParams: P[(Int, Int)] = { diff --git a/meta/src/main/scala/meta/parser/rise/Type.scala b/meta/src/main/scala/meta/parser/rise/Type.scala index 8172fba63..fa3cfa80e 100644 --- a/meta/src/main/scala/meta/parser/rise/Type.scala +++ b/meta/src/main/scala/meta/parser/rise/Type.scala @@ -84,7 +84,8 @@ object Type { def IdentifierKindPair[_: P]: P[(AST.Identifier, Kind.AST)] = P(Identifier.map(AST.Identifier) ~ ":" ~ Kind.Kind) - def UnrolledTypeIdentifier[_: P]: P[AST.UnrolledIdentifier] = P("*" ~ Identifier).map(AST.UnrolledIdentifier) + def UnrolledTypeIdentifier[_: P]: P[AST.UnrolledIdentifier] = + P("*" ~ Identifier).map(AST.UnrolledIdentifier) object DataType { def ScalarType[_: P]: P[AST.ScalarType] = From 0d9bfa02174df5c59169b84cd2ee426ecf00d8a9 Mon Sep 17 00:00:00 2001 From: Michel Steuwer Date: Mon, 10 May 2021 09:49:06 +0100 Subject: [PATCH 3/5] Cleanup and added more comments --- .../scala/meta/generator/DPIAPrimitives.scala | 455 +++++++++--------- .../scala/meta/generator/RisePrimitives.scala | 29 +- .../main/scala/meta/parser/DPIA/Decl.scala | 14 +- .../parser/DPIA/isWellKindedDefinition.scala | 5 +- 4 files changed, 249 insertions(+), 254 deletions(-) diff --git a/meta/src/main/scala/meta/generator/DPIAPrimitives.scala b/meta/src/main/scala/meta/generator/DPIAPrimitives.scala index ca1269664..b18280f40 100644 --- a/meta/src/main/scala/meta/generator/DPIAPrimitives.scala +++ b/meta/src/main/scala/meta/generator/DPIAPrimitives.scala @@ -1,35 +1,37 @@ package meta.generator import fastparse.{Parsed, parse} +import meta.parser.DPIA.Decl.KindOrType import meta.parser._ +import scala.meta._ object DPIAPrimitives { def main(args: Array[String]): Unit = { - val sourceDir = args.head + val sourceDir = args.head // excepts one argument that is the source directory of the rise repo (i.e. 'rise/src') val shinePath = os.Path(sourceDir) / "shine" - os.walk.stream(shinePath).filter(_.ext == "dpia").foreach(path => { + os.walk.stream(shinePath).filter(_.ext == "dpia").foreach(path => { // for each file with the `.dpia` extension ... import DPIA.Decl.AST._ - val definition = os.read(path) - parse(definition, DPIA.Decl.PrimitiveDeclarations(_)) match { + val definition = os.read(path) // ... read the file content + parse(definition, DPIA.Decl.PrimitiveDeclarations(_)) match { // ... and parse it case failure: Parsed.Failure => println(s"Failed to parse `${failure.extra.input}'") println(s" $failure") - case Parsed.Success(seq, _) => - seq.foreach { + case Parsed.Success(seq, _) => seq.foreach { // ... if successful go over all case PrimitiveDeclaration(Identifier(originalName), scalaParams, params, returnType) - if DPIA.isWellKindedDefinition(toParamList(definition, scalaParams), params, returnType) => + if DPIA.isWellKindedDefinition( // ... well kinded declarations + toParamList(definition, scalaParams), params, returnType) => val name = originalName.capitalize val outputPath = (path / os.up) / s"$name.scala" println(s"Generate $outputPath") - import scala.meta._ val packageName = path.relativeTo(shinePath).segments.dropRight(1).foldLeft[Term.Ref](Term.Name("shine")) { case (t, name) => Term.Select(t, Term.Name(name)) } + // ... and generate a case class definition with some imports val code = s"""// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // |// This file is automatically generated and should not be changed manually // |// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // @@ -47,24 +49,33 @@ ${generateCaseClass(Type.Name(name), toParamList(definition, scalaParams), param }""".toString()} |""".stripMargin - os.write.over(outputPath, code) + os.write.over(outputPath, code) // ... and finally write out to disk. case PrimitiveDeclaration(Identifier(name), _, params, returnType) => - println(s"Could not generate code for `$name' as parameters `$params' and/or `$returnType' are not well kinded.") + println(s"Could not generate code for `$name' as parameters " + + s"`$params' and/or `$returnType' are not well kinded.") } } }) } + // parse scala parameters (i.e. parameters that are not part of the DPIA language) into a list of parameters + def toParamList(definition: String, scalaParams: Option[(Int, Int)]): Option[List[scala.meta.Term.Param]] = { + scalaParams.map { case (start, end) => + s"def foo(${definition.substring(start, end)})".parse[Stat].get match { + case declDef: Decl.Def => declDef.paramss.head + } + } + } + def generateCaseClass(name: scala.meta.Type.Name, scalaParams: Option[List[scala.meta.Term.Param]], params: Seq[DPIA.Decl.AST.Param], returnType: DPIA.Type.AST): scala.meta.Defn.Class = { - import scala.meta._ - import meta.parser.DPIA.Type.AST + // determine the super class and the scale type of the `t` member val (scalaReturnType, superClass) = returnType match { - case AST.ExpType(_, _) => (t"ExpType", init"ExpPrimitive") - case AST.AccType(_) => (t"AccType", init"AccPrimitive") - case AST.CommType => (t"CommType", init"CommandPrimitive") + case DPIA.Type.AST.ExpType(_, _) => (t"ExpType", init"ExpPrimitive") + case DPIA.Type.AST.AccType(_) => (t"AccType", init"AccPrimitive") + case DPIA.Type.AST.CommType => (t"CommType", init"CommandPrimitive") case _ => throw new Exception(s"Expected `exp', `acc' or `comm' as return type for ${name.value}") } val generatedParams = generateParams(scalaParams, params) @@ -81,226 +92,193 @@ ${generateCaseClass(Type.Name(name), toParamList(definition, scalaParams), param ${generateVisitAndRebuild(name, generatedParams)} ..${if (scalaParams.nonEmpty && generatedParams.last.size > 1) { + // generate a unwrap function, that returns the second parameter list as a tuple List(generateUnwrap(generatedParams.last)) } else List() } } """ } - def toParamList(definition: String, scalaParams: Option[(Int, Int)]): Option[List[scala.meta.Term.Param]] = { - import scala.meta._ - scalaParams.map { case (start, end) => - s"def foo(${definition.substring(start, end)})".parse[Stat].get match { - case declDef: Decl.Def => declDef.paramss.head - } - } - } - + // generates lists of scala parameters to the generated case class + // if there are scala params there will be two parameter list, otherwise there will be one def generateParams(scalaParams: Option[List[scala.meta.Term.Param]], - params: Seq[DPIA.Decl.AST.Param]): List[List[scala.meta.Term.Param]] = { - import scala.meta._ - - scalaParams match { - case Some(scalaParamList) => - List(scalaParamList) ++ List(params.map(generateParam).toList) - case None => - List(params.map(generateParam).toList) - } + params: Seq[DPIA.Decl.AST.Param]): List[List[scala.meta.Term.Param]] = scalaParams match { + case Some(scalaParamList) => + List(scalaParamList) ++ List(params.map(generateParam).toList) + case None => + List(params.map(generateParam).toList) } - def generateParam(param: DPIA.Decl.AST.Param): scala.meta.Term.Param = { - import scala.meta._ - import _root_.meta.parser.DPIA.Kind - param"val ${Term.Name(param.id.name)}: ${ + def generateParam(param: DPIA.Decl.AST.Param): scala.meta.Term.Param = + param"""val ${Term.Name(param.id.name)}: ${ // generates something like param.ty match { - case Left(kindAST) => generateType(kindAST) - case Right(meta.parser.DPIA.Type.AST.VariadicType(_, typeAST)) => - t"Seq[Phrase[${generatePhraseType(typeAST)}]]" - case Right(typeAST) => t"Phrase[${generatePhraseType(typeAST)}]" + case KindOrType.Kind(kindAST) => generateType(kindAST) // e.g. val dt: DataType + case KindOrType.Type(DPIA.Type.AST.VariadicType(_, typeAST)) => + t"Seq[Phrase[${generatePhraseType(typeAST)}]]" // e.g. val args: Seq[Phrase[ExpType]] + case KindOrType.Type(typeAST) => + t"Phrase[${generatePhraseType(typeAST)}]" // e.g. val input: Phrase[ExpType] } - }" - } - - def generateTypeChecks(params: Seq[DPIA.Decl.AST.Param]): scala.meta.Term.Block = { - import scala.meta._ - q"""{ - ..${params. - filter(param => param.ty.isRight). // only check types for parameters with phrase types - map(param => - param.ty match { - case Right(typeAST@DPIA.Type.AST.DepFunType(id, _, _)) => - q"""${Term.Name(param.id.name)} :: { - ${Defn.Val( - mods = Nil, - pats = List(Pat.Var(name = Term.Name(id.name))), - decltpe = None, - rhs = q"${Term.Name(param.id.name)}.t.x" - )} - ${generateTerm(typeAST)} - }""" - case Right(DPIA.Type.AST.VariadicType(n, typeAST)) => - getUnrolledIds(typeAST) match { - case Seq() => - q"${Term.Name(param.id.name)}.foreach(_ :: ${generateTerm(typeAST)})" - case unrolledIds => - val zips = unrolledIds.foldLeft[Term](Term.Name(param.id.name)) { - case (term, id) => q"$term.zip(${Term.Name(id.name)})" - } - val p = Term.Name(param.id.name) - val pattern = unrolledIds.foldRight[Pat](Pat.Var(p)) { - case (id, pattern) => p"($pattern, ${Pat.Var(Term.Name(id.name))})" - } - q"""$zips.foreach { - case ($pattern) => $p :: ${generateTerm(typeAST)} - }""" - } - case Right(typeAST) => - q"${Term.Name(param.id.name)} :: ${generateTerm(typeAST)}" - case Left(_) => throw new Exception("Generation of type checks not necessary for kinds") - } - ).toList} }""" - } - def getUnrolledIds(typeAST: DPIA.Type.AST): Seq[rise.Type.AST.UnrolledIdentifier] = { - import DPIA.Type.AST - typeAST match { - case AST.ExpType(dataType, _) => - getUnrolledIds(dataType) - case AST.AccType(dataType) => - getUnrolledIds(dataType) - case AST.CommType => Seq() - case AST.PairType(lhs, rhs) => - getUnrolledIds(lhs) concat getUnrolledIds(rhs) - case AST.FunType(inT, outT) => - getUnrolledIds(inT) concat getUnrolledIds(outT) - case AST.DepFunType(_, _, t) => - getUnrolledIds(t) - case AST.Identifier(_) => Seq() - case AST.VariadicType(n, ty) => throw new Exception("This function should not be called on a variadic type") - } + // generate the Scala type for representing DPIA/Rise types of different kinds + def generateType(kindAST: DPIA.Kind.AST): scala.meta.Type = kindAST match { + case DPIA.Kind.AST.RiseKind(riseKind) => riseKind match { + case rise.Kind.AST.Data => Type.Name("DataType") + case rise.Kind.AST.Address => Type.Name("AddressSpace") + case rise.Kind.AST.Nat2Nat => Type.Name("NatToNat") + case rise.Kind.AST.Nat2Data => Type.Name("NatToData") + case rise.Kind.AST.Nat => Type.Name("Nat") + case rise.Kind.AST.Fragment => Type.Name("FragmentKind") + case rise.Kind.AST.MatrixLayout => Type.Name("MatrixLayout") + } + case DPIA.Kind.AST.Access => Type.Name("AccessType") + case DPIA.Kind.AST.VariadicKind(_, kind) => t"Seq[${generateType(kind)}]" } - def getUnrolledIds(typeAST: rise.Type.AST): Seq[rise.Type.AST.UnrolledIdentifier] = { - import rise.Type.AST - typeAST match { - case AST.Identifier(_) => Seq() - case id@AST.UnrolledIdentifier(_) => Seq(id) - case AST.FunType(inT, outT) => - getUnrolledIds(inT) concat getUnrolledIds(outT) - case AST.DepFunType(_, _, t) => - getUnrolledIds(t) - case AST.ImplicitDepFunType(_, _, t) => - getUnrolledIds(t) - case AST.VariadicFunType(_, _, _) => throw new Exception("This function should not be called on a variadic type") - case AST.VariadicDepFunType(_, _, _, _) => throw new Exception("This function should not be called on a variadic type") - case AST.ScalarType(_) => Seq() - case AST.NatType => Seq() - case AST.OpaqueType(_) => Seq() - case AST.VectorType(_, elemType) => getUnrolledIds(elemType) - case AST.IndexType(_) => Seq() - case AST.PairType(lhs, rhs) => - getUnrolledIds(lhs) concat getUnrolledIds(rhs) - case AST.DepPairType(_, _, t) => getUnrolledIds(t) - case AST.NatToDataApply(_, _) => Seq() - case AST.NatToDataLambda(_, t) => getUnrolledIds(t) - case AST.ArrayType(_, elemType) => getUnrolledIds(elemType) - case AST.DepArrayType(_, _) => Seq() - case AST.FragmentType(_, _, _, dt, _, _) => getUnrolledIds(dt) - case AST.ManagedBufferType(t) => getUnrolledIds(t) - } + // generate Scala type for different phrase types + def generatePhraseType(typeAST: DPIA.Type.AST): scala.meta.Type = typeAST match { + case DPIA.Type.AST.ExpType(_, _) => t"ExpType" + case DPIA.Type.AST.AccType(_) => t"AccType" + case DPIA.Type.AST.CommType => t"CommType" + case DPIA.Type.AST.PairType(lhs, rhs) => t"PhrasePairType[${generatePhraseType(lhs)}, ${generatePhraseType(rhs)}]" + case DPIA.Type.AST.FunType(inT, outT) => t"FunType[${generatePhraseType(inT)}, ${generatePhraseType(outT)}]" + case DPIA.Type.AST.DepFunType(id, kind, t) => t"DepFunType[${generateKindType(kind)}, ${generatePhraseType(t)}]" + case DPIA.Type.AST.Identifier(name) => Type.Name(name) + case DPIA.Type.AST.VariadicType(_, _) => throw new Exception("Can not generate Phrase Type for Variadic Type") } - def generatePhraseType(typeAST: DPIA.Type.AST): scala.meta.Type = { - import scala.meta._ - import meta.parser.DPIA.Type.AST - typeAST match { - case AST.ExpType(_, _) => t"ExpType" - case AST.AccType(_) => t"AccType" - case AST.CommType => t"CommType" - case AST.PairType(lhs, rhs) => t"PhrasePairType[${generatePhraseType(lhs)}, ${generatePhraseType(rhs)}]" - case AST.FunType(inT, outT) => t"FunType[${generatePhraseType(inT)}, ${generatePhraseType(outT)}]" - case AST.DepFunType(id, kind, t) => t"DepFunType[${generateKindType(kind)}, ${generatePhraseType(t)}]" - case AST.Identifier(name) => Type.Name(name) - case AST.VariadicType(_, _) => throw new Exception("Can not generate Phrase Type for Variadic Type") - } + // generate Scala type for representing the DPIA/rise kinds themselves + def generateKindType(kindAST: DPIA.Kind.AST): scala.meta.Type = kindAST match { + case DPIA.Kind.AST.RiseKind(riseKind) => riseKind match { + case rise.Kind.AST.Data => Type.Name("DataKind") + case rise.Kind.AST.Address => Type.Name("AddressSpaceKind") + case rise.Kind.AST.Nat2Nat => Type.Name("NatToNatKind") + case rise.Kind.AST.Nat2Data => Type.Name("NatToDataKind") + case rise.Kind.AST.Nat => Type.Name("NatKind") + case rise.Kind.AST.Fragment => throw new Exception("Can not generate Kind for Fragment") + case rise.Kind.AST.MatrixLayout => throw new Exception("Can not generate Kind for Matrix Layout") + } + case DPIA.Kind.AST.Access => Type.Name("AccessKind") + case DPIA.Kind.AST.VariadicKind(_, _) => throw new Exception("Can not generate Kind for Variadic Kind") } - def generateType(kindAST: DPIA.Kind.AST): scala.meta.Type = { - import scala.meta._ - import meta.parser.DPIA.Kind.AST - kindAST match { - case AST.RiseKind(riseKind) => - import meta.parser.rise.Kind.AST - riseKind match { - case AST.Data => Type.Name("DataType") - case AST.Address => Type.Name("AddressSpace") - case AST.Nat2Nat => Type.Name("NatToNat") - case AST.Nat2Data => Type.Name("NatToData") - case AST.Nat => Type.Name("Nat") - case AST.Fragment => Type.Name("FragmentKind") - case AST.MatrixLayout => Type.Name("MatrixLayout") + // generate type checks in the body of the generated case classes, e.g. for map: + // f :: FunType(expT(dt1, a), expT(dt2, a)) + // array :: expT(ArrayType(n, dt1), a) + def generateTypeChecks(params: Seq[DPIA.Decl.AST.Param]): scala.meta.Term.Block = + q"""{ + ..${params. + filter(param => param.ty.isInstanceOf[KindOrType.Type]). // only check types for params with phrase types + map(param => param.ty match { + // special treatment for params with variadic type + case KindOrType.Type(DPIA.Type.AST.VariadicType(n, typeAST)) => + generateVariadicTypeCheck(param, n, typeAST) + // for params with dependent function type we generate something like this, e.g. for iterate: + // f :: ({ + // val l = f.t.x + // DepFunType[NatKind, PhraseType](l, + // FunType(expT(ArrayType(l * n, dt), read), expT(ArrayType(l, dt), write))) + // }) + case KindOrType.Type(typeAST@DPIA.Type.AST.DepFunType(id, _, _)) => + q"""${Term.Name(param.id.name)} :: { + val ${Pat.Var(name = Term.Name(id.name))} = ${Term.Name(param.id.name)}.t.x + ${generateTerm(typeAST)} + }""" + // for all other parameters we generate the simple form: e :: t + case KindOrType.Type(typeAST) => + q"${Term.Name(param.id.name)} :: ${generateTerm(typeAST)}" + case KindOrType.Kind(_) => throw new Exception("Generation of type checks not necessary for kinds") + }).toList} + }""" + + def generateVariadicTypeCheck(param: DPIA.Decl.AST.Param, + n: DPIA.Type.AST.Identifier, + typeAST: DPIA.Type.AST): scala.meta.Term = + // we check if there are unrolled ids (i.e. *id) in the expression + getUnrolledIds(typeAST) match { + case Seq() => // ... if there are none, then we can generate the simple form with an foreach + q"${Term.Name(param.id.name)}.foreach(_ :: ${generateTerm(typeAST)})" + case unrolledIds => // ... if there are, then we have to zip the unrolled ids with the param + // we generate something like this, e.g. for foreign function call: + // args.zip(inTs).foreach({ case (args, inTs) => args :: expT(inTs, read) }) + + // generate a chain of nested zip calls (in case there are multiple unrolled ids) + val zips = unrolledIds.foldLeft[Term](Term.Name(param.id.name)) { + case (term, id) => q"$term.zip(${Term.Name(id.name)})" + } + val p = Term.Name(param.id.name) + // ... generate the case pattern reusing the variable names of the param and unrolled ids + // (this avoids the need to rename the variables in the `typeAST`) + val pattern = unrolledIds.foldRight[Pat](Pat.Var(p)) { + case (id, pattern) => p"($pattern, ${Pat.Var(Term.Name(id.name))})" } - case AST.Access => Type.Name("AccessType") - case AST.VariadicKind(_, kind) => t"Seq[${generateType(kind)}]" + // ... putting everything together with a foreach into the final form + q"""$zips.foreach { + case ($pattern) => $p :: ${generateTerm(typeAST)} + }""" } + + def getUnrolledIds(typeAST: DPIA.Type.AST): Seq[rise.Type.AST.UnrolledIdentifier] = typeAST match { + case DPIA.Type.AST.ExpType(dataType, _) => getUnrolledIds(dataType) + case DPIA.Type.AST.AccType(dataType) => getUnrolledIds(dataType) + case DPIA.Type.AST.CommType => Seq() + case DPIA.Type.AST.PairType(lhs, rhs) => getUnrolledIds(lhs) concat getUnrolledIds(rhs) + case DPIA.Type.AST.FunType(inT, outT) => getUnrolledIds(inT) concat getUnrolledIds(outT) + case DPIA.Type.AST.DepFunType(_, _, t) => getUnrolledIds(t) + case DPIA.Type.AST.Identifier(_) => Seq() + case DPIA.Type.AST.VariadicType(_, _) => throw new Exception("This function should not be called on a variadic type") } - def generateKindType(kindAST: DPIA.Kind.AST): scala.meta.Type = { - import scala.meta._ - import meta.parser.DPIA.Kind.AST - kindAST match { - case AST.RiseKind(riseKind) => - import meta.parser.rise.Kind.AST - riseKind match { - case AST.Data => Type.Name("DataKind") - case AST.Address => Type.Name("AddressSpaceKind") - case AST.Nat2Nat => Type.Name("NatToNatKind") - case AST.Nat2Data => Type.Name("NatToDataKind") - case AST.Nat => Type.Name("NatKind") - case AST.Fragment => throw new Exception("Can not generate Kind for Fragment") - case AST.MatrixLayout => throw new Exception("Can not generate Kind for Matrix Layout") - } - case AST.Access => Type.Name("AccessKind") - case AST.VariadicKind(_, _) => throw new Exception("Can not generate Kind for Variadic Kind") - } + def getUnrolledIds(typeAST: rise.Type.AST): Seq[rise.Type.AST.UnrolledIdentifier] = typeAST match { + case rise.Type.AST.Identifier(_) => Seq() + case id@rise.Type.AST.UnrolledIdentifier(_) => Seq(id) + case rise.Type.AST.FunType(inT, outT) => getUnrolledIds(inT) concat getUnrolledIds(outT) + case rise.Type.AST.DepFunType(_, _, t) => getUnrolledIds(t) + case rise.Type.AST.ImplicitDepFunType(_, _, t) => getUnrolledIds(t) + case rise.Type.AST.ScalarType(_) => Seq() + case rise.Type.AST.NatType => Seq() + case rise.Type.AST.OpaqueType(_) => Seq() + case rise.Type.AST.VectorType(_, elemType) => getUnrolledIds(elemType) + case rise.Type.AST.IndexType(_) => Seq() + case rise.Type.AST.PairType(lhs, rhs) => getUnrolledIds(lhs) concat getUnrolledIds(rhs) + case rise.Type.AST.DepPairType(_, _, t) => getUnrolledIds(t) + case rise.Type.AST.NatToDataApply(_, _) => Seq() + case rise.Type.AST.NatToDataLambda(_, t) => getUnrolledIds(t) + case rise.Type.AST.ArrayType(_, elemType) => getUnrolledIds(elemType) + case rise.Type.AST.DepArrayType(_, _) => Seq() + case rise.Type.AST.FragmentType(_, _, _, dt, _, _) => getUnrolledIds(dt) + case rise.Type.AST.ManagedBufferType(t) => getUnrolledIds(t) + case rise.Type.AST.VariadicFunType(_, _, _) | rise.Type.AST.VariadicDepFunType(_, _, _, _) => + throw new Exception("This function should not be called on a variadic type") } - def generateTerm(typeAST: DPIA.Type.AST): scala.meta.Term = { - import scala.meta._ - import meta.parser.DPIA.Type.AST - typeAST match { - case AST.ExpType(dataType, access) => - q"expT(${RisePrimitives.generateDataType(dataType)}, ${generateTerm(access)})" - case AST.AccType(dataType) => - q"accT(${RisePrimitives.generateDataType(dataType)})" - case AST.CommType => - q"comm" - case AST.PairType(lhs, rhs) => - q"PhrasePairType(${generateTerm(lhs)}, ${generateTerm(rhs)})" - case AST.FunType(inT, outT) => - q"FunType(${generateTerm(inT)}, ${generateTerm(outT)})" - case AST.DepFunType(id, kind, t) => - q"DepFunType[${generateKindType(kind)}, PhraseType](${Term.Name(id.name)}, ${generateTerm(t)})" - case AST.Identifier(name) => Term.Name(name) - case AST.VariadicType(_, _) => throw new Exception("Can not generate Term for Variadic Type") - } + // generate Scala term that represents the given DPIA/rise type + def generateTerm(typeAST: DPIA.Type.AST): scala.meta.Term = typeAST match { + case DPIA.Type.AST.Identifier(name) => Term.Name(name) + case DPIA.Type.AST.ExpType(dataType, access) => + q"expT(${RisePrimitives.generateDataType(dataType)}, ${generateTerm(access)})" + case DPIA.Type.AST.AccType(dataType) => + q"accT(${RisePrimitives.generateDataType(dataType)})" + case DPIA.Type.AST.CommType => + q"comm" + case DPIA.Type.AST.PairType(lhs, rhs) => + q"PhrasePairType(${generateTerm(lhs)}, ${generateTerm(rhs)})" + case DPIA.Type.AST.FunType(inT, outT) => + q"FunType(${generateTerm(inT)}, ${generateTerm(outT)})" + case DPIA.Type.AST.DepFunType(id, kind, t) => + q"DepFunType[${generateKindType(kind)}, PhraseType](${Term.Name(id.name)}, ${generateTerm(t)})" + case DPIA.Type.AST.VariadicType(_, _) => throw new Exception("Can not generate Term for Variadic Type") } - def generateTerm(accessAST: DPIA.Type.Access.AST): scala.meta.Term = { - import scala.meta._ - import meta.parser.DPIA.Type.Access.AST - accessAST match { - case AST.Identifier(name) => Term.Name(name) - case AST.Read => Term.Name("read") - case AST.Write =>Term.Name("write") - } + def generateTerm(accessAST: DPIA.Type.Access.AST): scala.meta.Term = accessAST match { + case DPIA.Type.Access.AST.Identifier(name) => Term.Name(name) + case DPIA.Type.Access.AST.Read => Term.Name("read") + case DPIA.Type.Access.AST.Write =>Term.Name("write") } def generateVisitAndRebuild(name: scala.meta.Type.Name, paramLists: List[List[scala.meta.Term.Param]]): scala.meta.Defn.Def = { - import scala.meta._ - + // little pattern matching helper that ignores if a type name is written with a package prefix object TypeIs { def unapply(ty: Type): Option[String] = ty match { case Type.Name(name) => Some(name) @@ -309,40 +287,42 @@ ${generateCaseClass(Type.Name(name), toParamList(definition, scalaParams), param } } - def injectVisitCall(param: Term.Param): Term = { - param.decltpe match { - case Some(ty) => ty match { - case TypeIs("Nat") | TypeIs("NatIdentifier") => - q"v.nat(${Term.Name(param.name.value)})" - case TypeIs("DataType") | TypeIs("ScalarType") | TypeIs("BasicType") => - q"v.data(${Term.Name(param.name.value)})" - case TypeIs("NatToNat") => - q"v.natToNat(${Term.Name(param.name.value)})" - case TypeIs("NatToData") => - q"v.natToData(${Term.Name(param.name.value)})" - case TypeIs("AccessType") => - q"v.access(${Term.Name(param.name.value)})" - case TypeIs("AddressSpace") => - q"v.addressSpace(${Term.Name(param.name.value)})" - case TypeIs("LocalSize") | TypeIs("GlobalSize") => - q"${Term.Name(param.name.value)}.visitAndRebuild(v)" - case t"Phrase[$_]" => q"VisitAndRebuild(${Term.Name(param.name.value)}, v)" - case t"Vector[Phrase[$_]]" => q"${Term.Name(param.name.value)}.map(VisitAndRebuild(_, v))" - case t"Seq[Phrase[$_]]" => q"${Term.Name(param.name.value)}.map(VisitAndRebuild(_, v))" - - case t"Map[Identifier[_ <: PhraseType], $_]" => - q"""${Term.Name(param.name.value)}.map{ case (key, value) => - VisitAndRebuild(key, v).asInstanceOf[Identifier[_ <: PhraseType]] -> value - }""" - - case Type.Apply(Type.Name("Vector"), List(TypeIs("DataType"))) // Vector[DataType] - | Type.Apply(Type.Name("Seq"), List(TypeIs("DataType"))) => // Seq[DataType] - q"${Term.Name(param.name.value)}.map(v.data)" - case _ => - Term.Name(param.name.value) - } - case None => throw new Exception(s"Expected type declaration") + // inject a call to the visitor depending on the param's generated scala type + def injectVisitCall(param: Term.Param): Term = param.decltpe match { + case Some(ty) => ty match { + // the different kinds ... + case TypeIs("Nat") | TypeIs("NatIdentifier") => + q"v.nat(${Term.Name(param.name.value)})" + case TypeIs("DataType") | TypeIs("ScalarType") | TypeIs("BasicType") => + q"v.data(${Term.Name(param.name.value)})" + case TypeIs("NatToNat") => + q"v.natToNat(${Term.Name(param.name.value)})" + case TypeIs("NatToData") => + q"v.natToData(${Term.Name(param.name.value)})" + case TypeIs("AccessType") => + q"v.access(${Term.Name(param.name.value)})" + case TypeIs("AddressSpace") => + q"v.addressSpace(${Term.Name(param.name.value)})" + case Type.Apply(Type.Name("Vector"), List(TypeIs("DataType"))) // Vector[DataType] + | Type.Apply(Type.Name("Seq"), List(TypeIs("DataType"))) => // Seq[DataType] + q"${Term.Name(param.name.value)}.map(v.data)" + + // ... phrases ... + case t"Phrase[$_]" => q"VisitAndRebuild(${Term.Name(param.name.value)}, v)" + case t"Seq[Phrase[$_]]" => q"${Term.Name(param.name.value)}.map(VisitAndRebuild(_, v))" + + // ... and finally special cases, that maybe should be eliminated in the future + case TypeIs("LocalSize") | TypeIs("GlobalSize") => + q"${Term.Name(param.name.value)}.visitAndRebuild(v)" + case t"Map[Identifier[_ <: PhraseType], $_]" => + q"""${Term.Name(param.name.value)}.map{ case (key, value) => + VisitAndRebuild(key, v).asInstanceOf[Identifier[_ <: PhraseType]] -> value + }""" + + case _ => + Term.Name(param.name.value) } + case None => throw new Exception(s"Expected type declaration") } q"""override def visitAndRebuild(v: VisitAndRebuild.Visitor): $name = @@ -351,7 +331,6 @@ ${generateCaseClass(Type.Name(name), toParamList(definition, scalaParams), param } def generateUnwrap(paramList: List[scala.meta.Term.Param]): scala.meta.Defn.Def = { - import scala.meta._ val (types, names) = paramList.map({ case Term.Param(_, name, Some(typ), _) => (typ, Term.Name(name.value)) }).unzip diff --git a/meta/src/main/scala/meta/generator/RisePrimitives.scala b/meta/src/main/scala/meta/generator/RisePrimitives.scala index 7523d0046..eeafb6f54 100644 --- a/meta/src/main/scala/meta/generator/RisePrimitives.scala +++ b/meta/src/main/scala/meta/generator/RisePrimitives.scala @@ -5,8 +5,6 @@ import meta.parser._ import meta.parser.rise.Kind import meta.parser.rise.Kind.AST -import scala.meta.Term - object RisePrimitives { def main(args: Array[String]): Unit = { val sourceDir = args.head // excepts one argument that is the source directory of the rise repo (i.e. 'rise/src') @@ -59,7 +57,7 @@ import arithexpr.arithmetic._ } // parse scala parameters (i.e. parameters that are not part of the rise language) into a list of parameters - def toParamList(definition: String, scalaParams: Option[(Int, Int)]): Option[List[Term.Param]] = { + def toParamList(definition: String, scalaParams: Option[(Int, Int)]): Option[List[scala.meta.Term.Param]] = { import scala.meta._ scalaParams.map { case (start, end) => s"def foo(${definition.substring(start, end)})".parse[Stat].get match { @@ -69,7 +67,9 @@ import arithexpr.arithmetic._ } // generate either an object (if there are no scala parameters) or a case class (if there are some) - def generate(name: String, params: Option[List[Term.Param]], typeSignature: rise.Type.AST): scala.meta.Term.Block = + def generate(name: String, + params: Option[List[scala.meta.Term.Param]], + typeSignature: rise.Type.AST): scala.meta.Term.Block = params match { case None => generateObject(name, typeSignature) case Some(params) => generateCaseClass(name, params, typeSignature) @@ -101,7 +101,9 @@ import arithexpr.arithmetic._ generated } - def generateCaseClass(name: String, params: List[Term.Param], typeSignature: rise.Type.AST): scala.meta.Term.Block = { + def generateCaseClass(name: String, + params: List[scala.meta.Term.Param], + typeSignature: rise.Type.AST): scala.meta.Term.Block = { import scala.meta._ val args: List[Term.Name] = params.map(p => Term.Name(p.name.value)) @@ -179,11 +181,16 @@ import arithexpr.arithmetic._ // ids.foldRight(t){ case (id, t) => DepFunType[DataKind](id, t) } // to represent n-many dependent function types: (id0: kind) -> (id1: kind) -> ... -> t val (createIds, typeName) = kind match { - case AST.Data => (q"""DataTypeIdentifier(freshName("dt"), isExplicit = true)""", Type.Name("DataKind")) - case AST.Address => (q"""AddressSpaceIdentifier(freshName("a"), isExplicit = true)""", Type.Name("AddressSpaceKind")) - case AST.Nat2Nat => (q"""NatToNatIdentifier(freshName("n2n"), isExplicit = true)""", Type.Name("NatToNatKind")) - case AST.Nat2Data => (q"""NatToDataIdentifier(freshName("n2d"), isExplicit = true)""", Type.Name("NatToDataKind")) - case AST.Nat => (q"""NatIdentifier(freshName("n"), isExplicit = true)""", Type.Name("NatKind")) + case AST.Data => + (q"""DataTypeIdentifier(freshName("dt"), isExplicit = true)""", Type.Name("DataKind")) + case AST.Address => + (q"""AddressSpaceIdentifier(freshName("a"), isExplicit = true)""", Type.Name("AddressSpaceKind")) + case AST.Nat2Nat => + (q"""NatToNatIdentifier(freshName("n2n"), isExplicit = true)""", Type.Name("NatToNatKind")) + case AST.Nat2Data => + (q"""NatToDataIdentifier(freshName("n2d"), isExplicit = true)""", Type.Name("NatToDataKind")) + case AST.Nat => + (q"""NatIdentifier(freshName("n"), isExplicit = true)""", Type.Name("NatKind")) case AST.Fragment => throw new Exception("No support for Fragment Kind yet") case AST.MatrixLayout => throw new Exception("No support for Matrix Layout Kind yet") } @@ -210,7 +217,7 @@ import arithexpr.arithmetic._ case rise.Type.AST.NatType => q"NatType" case rise.Type.AST.OpaqueType(name) => - q"OpaqueType($name)" + q"OpaqueType(${Lit.String(name)})" case rise.Type.AST.VectorType(size, elemType) => q"VectorType(${generateNat(size)}, ${generateDataType(elemType)})" case rise.Type.AST.IndexType(size) => diff --git a/meta/src/main/scala/meta/parser/DPIA/Decl.scala b/meta/src/main/scala/meta/parser/DPIA/Decl.scala index f5a105e39..b829c6f95 100644 --- a/meta/src/main/scala/meta/parser/DPIA/Decl.scala +++ b/meta/src/main/scala/meta/parser/DPIA/Decl.scala @@ -8,13 +8,19 @@ object Decl { sealed trait AST object AST { case class Identifier(name: String) extends AST - case class Param(id: Identifier, ty: Either[Kind.AST, Type.AST]) extends AST + case class Param(id: Identifier, ty: KindOrType) extends AST case class PrimitiveDeclaration(id: Identifier, scalaParams: Option[(Int, Int)], params: Seq[Param], returnType: Type.AST) extends AST } + sealed trait KindOrType + object KindOrType { + final case class Kind(kind: meta.parser.DPIA.Kind.AST) extends KindOrType + final case class Type(typeAST: meta.parser.DPIA.Type.AST) extends KindOrType + } + def PrimitiveDeclarations[_: P]: P[Seq[AST.PrimitiveDeclaration]] = P(Start ~ PrimitiveDeclaration.rep(1) ~ End) @@ -27,8 +33,10 @@ object Decl { } def Param: P[AST.Param] = ( - (Identifier.map(AST.Identifier) ~ ":" ~ Kind.Kind).map(pair => AST.Param(pair._1, Left(pair._2))) - | (Identifier.map(AST.Identifier) ~ ":" ~ Type.PhraseType).map(pair => AST.Param(pair._1, Right(pair._2))) + (Identifier.map(AST.Identifier) ~ ":" ~ Kind.Kind).map( + pair => AST.Param(pair._1, KindOrType.Kind(pair._2))) + | (Identifier.map(AST.Identifier) ~ ":" ~ Type.PhraseType).map( + pair => AST.Param(pair._1, KindOrType.Type(pair._2))) ) def Params: P[Seq[AST.Param]] = Param.repTC(0) diff --git a/meta/src/main/scala/meta/parser/DPIA/isWellKindedDefinition.scala b/meta/src/main/scala/meta/parser/DPIA/isWellKindedDefinition.scala index ca38195bb..bbefc2674 100644 --- a/meta/src/main/scala/meta/parser/DPIA/isWellKindedDefinition.scala +++ b/meta/src/main/scala/meta/parser/DPIA/isWellKindedDefinition.scala @@ -1,5 +1,6 @@ package meta.parser.DPIA +import meta.parser.DPIA.Decl.KindOrType import meta.parser._ object isWellKindedDefinition { @@ -17,9 +18,9 @@ object isWellKindedDefinition { } }).getOrElse(Map.empty[String, Kind.AST]) val env = params.foldLeft(incompleteEnv) { - case (env, Param(Identifier(name), Left(kind))) => + case (env, Param(Identifier(name), KindOrType.Kind(kind))) => env.updated(name, kind) - case (env, Param(Identifier(_), Right(typeAST))) => + case (env, Param(Identifier(_), KindOrType.Type(typeAST))) => if (!isWellKinded(typeAST, env, None)) isWellKindedFlag = false env From 09b5a008e111f919346dcdc69eace2b20c349218 Mon Sep 17 00:00:00 2001 From: Michel Steuwer Date: Mon, 10 May 2021 19:03:06 +0100 Subject: [PATCH 4/5] Added more documentation to the well kindness implementations. --- .../parser/DPIA/isWellKindedDefinition.scala | 102 +++--- .../meta/parser/rise/isWellKindedType.scala | 337 +++++++++--------- 2 files changed, 211 insertions(+), 228 deletions(-) diff --git a/meta/src/main/scala/meta/parser/DPIA/isWellKindedDefinition.scala b/meta/src/main/scala/meta/parser/DPIA/isWellKindedDefinition.scala index bbefc2674..a8f7ea363 100644 --- a/meta/src/main/scala/meta/parser/DPIA/isWellKindedDefinition.scala +++ b/meta/src/main/scala/meta/parser/DPIA/isWellKindedDefinition.scala @@ -2,79 +2,77 @@ package meta.parser.DPIA import meta.parser.DPIA.Decl.KindOrType import meta.parser._ +import meta.parser.DPIA.Kind.{AST => KindAST} +import meta.parser.rise.isWellKindedType object isWellKindedDefinition { - + // check if given parameters and return type are well kinded, i.e. + // that type-level parameters are declared with the correct kind before they are used def apply(scalaParams: Option[List[scala.meta.Term.Param]], params: Seq[Decl.AST.Param], returnType: Type.AST): Boolean = { import Decl.AST._ var isWellKindedFlag = true - val incompleteEnv = scalaParams.map(_.foldLeft(Map.empty[String, Kind.AST]) { + // prepopulate the environment with all integers in the scala parameters + val prepopulatedEnv = scalaParams.map(_.foldLeft(Map.empty[String, KindAST]) { case (env, param) => param.decltpe match { case Some(scala.meta.Type.Name("Int")) => - env.updated(param.name.value,Kind.AST.RiseKind(rise.Kind.AST.Nat)) + env.updated(param.name.value, KindAST.RiseKind(rise.Kind.AST.Nat)) case _ => env } - }).getOrElse(Map.empty[String, Kind.AST]) - val env = params.foldLeft(incompleteEnv) { + }).getOrElse(Map.empty[String, KindAST]) + // go left-to-right over the parameters ... + val env = params.foldLeft(prepopulatedEnv) { + // ... and add identifiers to the environment case (env, Param(Identifier(name), KindOrType.Kind(kind))) => env.updated(name, kind) + // ... or check if a parameter is well-kinded case (env, Param(Identifier(_), KindOrType.Type(typeAST))) => if (!isWellKinded(typeAST, env, None)) - isWellKindedFlag = false + isWellKindedFlag = false // ... set flag if param is not well-kinded env } + // ... finally check return type isWellKindedFlag && isWellKinded(returnType, env, None) } - def isWellKinded(typeAST: Type.AST, - env: Map[String, Kind.AST], - variadicN: Option[String]): Boolean = { - import Type._ - import rise.isWellKindedType._ - typeAST match { - case AST.ExpType(dataType, access) => - val nenv = env.flatMap { - case (string, DPIA.Kind.AST.RiseKind(riseKind)) => - Some((string, rise.isWellKindedType.DataTypeKind(riseKind))) - case (string, DPIA.Kind.AST.VariadicKind(n, DPIA.Kind.AST.RiseKind(riseKind))) => - Some((string, rise.isWellKindedType.VariadicKind(n, riseKind))) - case _ => None - } - kindOf(dataType, nenv, variadicN).isDefined && isWellKinded(access, env) - case AST.AccType(dataType) => - val nenv = env.flatMap { - case (string, DPIA.Kind.AST.RiseKind(riseKind)) => - Some((string, rise.isWellKindedType.DataTypeKind(riseKind))) - case _ => None - } - kindOf(dataType, nenv, variadicN).isDefined - case AST.CommType => true - case AST.PairType(lhs, rhs) => - isWellKinded(lhs, env, variadicN) && isWellKinded(rhs, env, variadicN) - case AST.FunType(inT, outT) => - isWellKinded(inT, env, variadicN) && isWellKinded(outT, env, variadicN) - case AST.DepFunType(id, kind, t) => - if (env.isDefinedAt(id.name)) { - // we forbid shadowing - false - } else { - isWellKinded(t, env.updated(id.name, kind), variadicN) - } - case AST.Identifier(name) => - env.contains(name) - case AST.VariadicType(n, ty) => - isWellKinded(ty, env, Some(n.name)) - } + def isWellKinded(typeAST: Type.AST, // the AST of the type that is checked + env: Map[String, KindAST], // ... the kinding environment + variadicN: Option[String] // ... an optional string for the number of types in a variadic type + ): Boolean = typeAST match { + case Type.AST.ExpType(dataType, access) => + rise.isWellKindedType.kindOf(dataType, convertEnv(env), variadicN).isDefined && isWellKinded(access, env) + case Type.AST.AccType(dataType) => + rise.isWellKindedType.kindOf(dataType, convertEnv(env), variadicN).isDefined + case Type.AST.CommType => true + case Type.AST.PairType(lhs, rhs) => + isWellKinded(lhs, env, variadicN) && isWellKinded(rhs, env, variadicN) + case Type.AST.FunType(inT, outT) => + isWellKinded(inT, env, variadicN) && isWellKinded(outT, env, variadicN) + case Type.AST.DepFunType(id, kind, t) => + if (env.isDefinedAt(id.name)) { + // we forbid shadowing + false + } else { + isWellKinded(t, env.updated(id.name, kind), variadicN) + } + case Type.AST.Identifier(name) => + env.contains(name) + case Type.AST.VariadicType(n, ty) => + isWellKinded(ty, env, Some(n.name)) } - def isWellKinded(accessAST: Type.Access.AST, env: Map[String, Kind.AST]): Boolean = { - import Type.Access._ - accessAST match { - case AST.Identifier(name) => env.isDefinedAt(name) - case AST.Read => true - case AST.Write => true - } + def isWellKinded(accessAST: Type.Access.AST, env: Map[String, KindAST]): Boolean = accessAST match { + case Type.Access.AST.Identifier(name) => env.isDefinedAt(name) + case Type.Access.AST.Read => true + case Type.Access.AST.Write => true + } + + def convertEnv(env: Map[String, KindAST]): Map[String, isWellKindedType.Kind] = env.flatMap { + case (string, KindAST.RiseKind(riseKind)) => + Some((string, rise.isWellKindedType.DataTypeKind(riseKind))) + case (string, KindAST.VariadicKind(n, KindAST.RiseKind(riseKind))) => + Some((string, rise.isWellKindedType.VariadicKind(n, riseKind))) + case _ => None } } diff --git a/meta/src/main/scala/meta/parser/rise/isWellKindedType.scala b/meta/src/main/scala/meta/parser/rise/isWellKindedType.scala index 351a2865b..fec2e6e4a 100644 --- a/meta/src/main/scala/meta/parser/rise/isWellKindedType.scala +++ b/meta/src/main/scala/meta/parser/rise/isWellKindedType.scala @@ -4,10 +4,12 @@ import meta.parser._ import meta.parser.rise.Kind.{AST => KindAST} object isWellKindedType { - + // check if given type AST is well kinded, i.e. + // that type-level parameters are declared with the correct kind before they are used def apply(scalaParams: Option[List[scala.meta.Term.Param]], typeAST: Type.AST): Boolean = { import scala.meta._ + // prepopulate the environment with all integers in the scala parameters val env = scalaParams.getOrElse(List.empty).foldLeft[Map[String, Kind]](Map.empty){ case (env, param) => param.decltpe match { case Some(t"Int") => env.updated(param.name.value, DataTypeKind(KindAST.Nat)) @@ -17,204 +19,187 @@ object isWellKindedType { kindOf(typeAST, env, None).isDefined } + // possible kinds used in well kindness check sealed trait Kind case class DataTypeKind(kind: KindAST) extends Kind case object FunctionKind extends Kind case class VariadicKind(n: String, kind: KindAST) extends Kind - def kindOf(typeAST: Type.AST, - env: Map[String, Kind], - variadicN: Option[String]): Option[Kind] = { - import Type._ - typeAST match { - case AST.Identifier(name) => - env.get(name) - case AST.UnrolledIdentifier(name) => - env.get(name).flatMap { - case VariadicKind(n, kind) => - variadicN.flatMap(m => - if (n == m) { - Some(VariadicKind(n, kind)) - } else { - None - }) - case DataTypeKind(_) => None - case FunctionKind => None - } - case AST.FunType(inT, outT) => + def kindOf(typeAST: Type.AST, // the AST of the type that is checked + env: Map[String, Kind], // ... the kinding environment + variadicN: Option[String] // ... an optional string for the number of types in a variadic type + ): Option[Kind] = typeAST match { + case Type.AST.Identifier(name) => + env.get(name) + case Type.AST.UnrolledIdentifier(name) => + env.get(name).flatMap { // the kind of the identifier must be a ... + case VariadicKind(n, kind) => // ... variadic kind + variadicN.flatMap(m => // Then check that a variadic number is given ... + if (n == m) { // ... and that it is matching + Some(VariadicKind(n, kind)) + } else { + None + }) + case DataTypeKind(_) => None + case FunctionKind => None + } + case Type.AST.FunType(inT, outT) => + for { + _ <- kindOf(inT, env, variadicN) + _ <- kindOf(outT, env, variadicN) + } yield FunctionKind + case Type.AST.DepFunType(id, kind, t) => + if (env.isDefinedAt(id.name)) { + None // we forbid shadowing + } else { + for { + _ <- kindOf(t, env.updated(id.name, DataTypeKind(kind)), variadicN) + } yield FunctionKind + } + case Type.AST.ImplicitDepFunType(id, kind, t) => + if (env.isDefinedAt(id.name)) { + None // we forbid shadowing + } else { for { - _ <- kindOf(inT, env, variadicN) - _ <- kindOf(outT, env, variadicN) + _ <- kindOf(t, env.updated(id.name, DataTypeKind(kind)), variadicN) } yield FunctionKind - case AST.DepFunType(id, kind, t) => - if (env.isDefinedAt(id.name)) { - None // we forbid shadowing - } else { - for { - _ <- kindOf(t, env.updated(id.name, DataTypeKind(kind)), variadicN) - } yield FunctionKind - } - case AST.ImplicitDepFunType(id, kind, t) => - if (env.isDefinedAt(id.name)) { - None // we forbid shadowing - } else { - for { - _ <- kindOf(t, env.updated(id.name, DataTypeKind(kind)), variadicN) - } yield FunctionKind - } - case AST.VariadicFunType(n, inTs, outT) => + } + case Type.AST.VariadicFunType(n, inTs, outT) => + for { + k1 <- kindOf(n, env, None) + _ <- kindOf(inTs, env, Some(n.name)) // pass the variadic number down when checking the kind of inTs + _ <- kindOf(outT, env, variadicN) + if k1 == DataTypeKind(KindAST.Nat) + } yield FunctionKind + case Type.AST.VariadicDepFunType(n, id, kind, t) => + if (env.isDefinedAt(id.name)) { + None // we forbid shadowing + } else { for { k1 <- kindOf(n, env, None) - _ <- kindOf(inTs, env, Some(n.name)) - _ <- kindOf(outT, env, variadicN) + _ <- kindOf(t, env.updated(id.name, VariadicKind(n.name, kind)), variadicN) if k1 == DataTypeKind(KindAST.Nat) } yield FunctionKind - case AST.VariadicDepFunType(n, id, kind, t) => - if (env.isDefinedAt(id.name)) { - None // we forbid shadowing - } else { - for { - k1 <- kindOf(n, env, None) - _ <- kindOf(t, env.updated(id.name, VariadicKind(n.name, kind)), variadicN) - if k1 == DataTypeKind(KindAST.Nat) - } yield FunctionKind - } - case AST.VectorType(size, elemType) => - for { - k1 <- kindOf(size, env, variadicN) - k2 <- kindOf(elemType, env, variadicN) - if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Data) - } yield DataTypeKind(KindAST.Data) - case AST.IndexType(size) => - for { - k <- kindOf(size, env, variadicN) - if k == DataTypeKind(KindAST.Nat) - } yield DataTypeKind(KindAST.Data) - case AST.PairType(lhs, rhs) => - for { - k1 <- kindOf(lhs, env, variadicN) - k2 <- kindOf(rhs, env, variadicN) - if k1 == DataTypeKind(KindAST.Data) && k2 == DataTypeKind(KindAST.Data) - } yield DataTypeKind(KindAST.Data) - case AST.DepPairType(id, kind, t) => - if (env.isDefinedAt(id.name)) { - None // we forbid shadowing - } else { - kindOf(t, env.updated(id.name, DataTypeKind(kind)), variadicN) - } - case AST.NatToDataApply(f, n) => - for { - k1 <- kindOf(f, env, variadicN) - k2 <- kindOf(n, env, variadicN) - if k1 == DataTypeKind(KindAST.Nat2Data) && k2 == DataTypeKind(KindAST.Nat) - } yield DataTypeKind(KindAST.Data) - case AST.NatToDataLambda(id, t) => - if (env.isDefinedAt(id.name)) { - None // we forbid shadowing - } else { - for { - k <- kindOf(t, env.updated(id.name, DataTypeKind(KindAST.Nat)), variadicN) - if k == DataTypeKind(KindAST.Data) - } yield DataTypeKind(KindAST.Nat2Data) - } - case AST.ArrayType(size, elemType) => - for { - k1 <- kindOf(size, env, variadicN) - k2 <- kindOf(elemType, env, variadicN) - if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Data) - } yield DataTypeKind(KindAST.Data) - case AST.DepArrayType(size, fdt) => - for { - k1 <- kindOf(size, env, variadicN) - k2 <- kindOf(fdt, env, variadicN) - if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Nat2Data) - } yield DataTypeKind(KindAST.Data) - case AST.FragmentType(n, m, k, elemType, fKind, mLayout) => - for { - k1 <- kindOf(n, env, variadicN) - k2 <- kindOf(m, env, variadicN) - k3 <- kindOf(k, env, variadicN) - k4 <- kindOf(elemType, env, variadicN) - k5 <- kindOf(fKind, env, variadicN) - k6 <- kindOf(mLayout, env, variadicN) - if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Nat) && - k3 == DataTypeKind(KindAST.Nat) && k4 == DataTypeKind(KindAST.Data) && - k5 == DataTypeKind(KindAST.Fragment) && k6 == DataTypeKind(KindAST.MatrixLayout) - } yield DataTypeKind(KindAST.Data) - case AST.ManagedBufferType(dt) => + } + case Type.AST.VectorType(size, elemType) => + for { + k1 <- kindOf(size, env, variadicN) + k2 <- kindOf(elemType, env, variadicN) + if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Data) + } yield DataTypeKind(KindAST.Data) + case Type.AST.IndexType(size) => + for { + k <- kindOf(size, env, variadicN) + if k == DataTypeKind(KindAST.Nat) + } yield DataTypeKind(KindAST.Data) + case Type.AST.PairType(lhs, rhs) => + for { + k1 <- kindOf(lhs, env, variadicN) + k2 <- kindOf(rhs, env, variadicN) + if k1 == DataTypeKind(KindAST.Data) && k2 == DataTypeKind(KindAST.Data) + } yield DataTypeKind(KindAST.Data) + case Type.AST.DepPairType(id, kind, t) => + if (env.isDefinedAt(id.name)) { + None // we forbid shadowing + } else { + kindOf(t, env.updated(id.name, DataTypeKind(kind)), variadicN) + } + case Type.AST.NatToDataApply(f, n) => + for { + k1 <- kindOf(f, env, variadicN) + k2 <- kindOf(n, env, variadicN) + if k1 == DataTypeKind(KindAST.Nat2Data) && k2 == DataTypeKind(KindAST.Nat) + } yield DataTypeKind(KindAST.Data) + case Type.AST.NatToDataLambda(id, t) => + if (env.isDefinedAt(id.name)) { + None // we forbid shadowing + } else { for { - k1 <- kindOf(dt, env, variadicN) - if k1 == DataTypeKind(KindAST.Data) - } yield DataTypeKind(KindAST.Data) - case _: AST.ScalarType | AST.NatType | _: AST.OpaqueType => - Some(DataTypeKind(KindAST.Data)) - } + k <- kindOf(t, env.updated(id.name, DataTypeKind(KindAST.Nat)), variadicN) + if k == DataTypeKind(KindAST.Data) + } yield DataTypeKind(KindAST.Nat2Data) + } + case Type.AST.ArrayType(size, elemType) => + for { + k1 <- kindOf(size, env, variadicN) + k2 <- kindOf(elemType, env, variadicN) + if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Data) + } yield DataTypeKind(KindAST.Data) + case Type.AST.DepArrayType(size, fdt) => + for { + k1 <- kindOf(size, env, variadicN) + k2 <- kindOf(fdt, env, variadicN) + if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Nat2Data) + } yield DataTypeKind(KindAST.Data) + case Type.AST.FragmentType(n, m, k, elemType, fKind, mLayout) => + for { + k1 <- kindOf(n, env, variadicN) + k2 <- kindOf(m, env, variadicN) + k3 <- kindOf(k, env, variadicN) + k4 <- kindOf(elemType, env, variadicN) + k5 <- kindOf(fKind, env) + k6 <- kindOf(mLayout, env) + if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Nat) && + k3 == DataTypeKind(KindAST.Nat) && k4 == DataTypeKind(KindAST.Data) && + k5 == DataTypeKind(KindAST.Fragment) && k6 == DataTypeKind(KindAST.MatrixLayout) + } yield DataTypeKind(KindAST.Data) + case Type.AST.ManagedBufferType(dt) => + for { + k1 <- kindOf(dt, env, variadicN) + if k1 == DataTypeKind(KindAST.Data) + } yield DataTypeKind(KindAST.Data) + case _: Type.AST.ScalarType | Type.AST.NatType | _: Type.AST.OpaqueType => + Some(DataTypeKind(KindAST.Data)) } def kindOf(natAST: Nat.AST, env: Map[String, Kind], variadicN: Option[String] - ): Option[Kind] = { - natAST match { - case Nat.AST.Identifier(id) => - env.get(id) - case Nat.AST.Number(_) => - Some(DataTypeKind(KindAST.Nat)) - case Nat.AST.BinaryOp(lhs, _, rhs) => - for { - k1 <- kindOf(lhs, env, variadicN) - k2 <- kindOf(rhs, env, variadicN) - if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Nat) - } yield DataTypeKind(KindAST.Nat) + ): Option[Kind] = natAST match { + case Nat.AST.Identifier(id) => + env.get(id) + case Nat.AST.Number(_) => + Some(DataTypeKind(KindAST.Nat)) + case Nat.AST.BinaryOp(lhs, _, rhs) => + for { + k1 <- kindOf(lhs, env, variadicN) + k2 <- kindOf(rhs, env, variadicN) + if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Nat) + } yield DataTypeKind(KindAST.Nat) - case Nat.AST.TernaryOp(_, thenN, elseN) => - for { - k1 <- kindOf(thenN, env, variadicN) - k2 <- kindOf(elseN, env, variadicN) - if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Nat) - } yield DataTypeKind(KindAST.Nat) + case Nat.AST.TernaryOp(_, thenN, elseN) => + for { + k1 <- kindOf(thenN, env, variadicN) + k2 <- kindOf(elseN, env, variadicN) + if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Nat) + } yield DataTypeKind(KindAST.Nat) - case Nat.AST.Nat2NatApply(f, n) => - for { - k1 <- kindOf(f, env, variadicN) - k2 <- kindOf(n, env, variadicN) - if k1 == DataTypeKind(KindAST.Nat2Nat) && k2 == DataTypeKind(KindAST.Nat) - } yield DataTypeKind(KindAST.Nat) + case Nat.AST.Nat2NatApply(f, n) => + for { + k1 <- kindOf(f, env, variadicN) + k2 <- kindOf(n, env, variadicN) + if k1 == DataTypeKind(KindAST.Nat2Nat) && k2 == DataTypeKind(KindAST.Nat) + } yield DataTypeKind(KindAST.Nat) - case Nat.AST.Sum(id, from, upTo, body) => - val nEnv = env.updated(id.name, DataTypeKind(KindAST.Nat)) - for { - k1 <- kindOf(from, nEnv, variadicN) - k2 <- kindOf(upTo, nEnv, variadicN) - k3 <- kindOf(body, nEnv, variadicN) - if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Nat) && k3 == DataTypeKind(KindAST.Nat) - } yield DataTypeKind(KindAST.Nat) - } + case Nat.AST.Sum(id, from, upTo, body) => + val nEnv = env.updated(id.name, DataTypeKind(KindAST.Nat)) + for { + k1 <- kindOf(from, nEnv, variadicN) + k2 <- kindOf(upTo, nEnv, variadicN) + k3 <- kindOf(body, nEnv, variadicN) + if k1 == DataTypeKind(KindAST.Nat) && k2 == DataTypeKind(KindAST.Nat) && k3 == DataTypeKind(KindAST.Nat) + } yield DataTypeKind(KindAST.Nat) } - def kindOf(fragmentAST: Type.Fragment.AST, - env: Map[String, Kind], - variadicN: Option[String] - ): Option[Kind] = { - import Type._ - fragmentAST match { - case Fragment.AST.Identifier(id) => - env.get(id) - case Fragment.AST.ACC | Fragment.AST.A | Fragment.AST.B => Some(DataTypeKind(KindAST.Fragment)) - } + def kindOf(fragmentAST: Type.Fragment.AST, env: Map[String, Kind]): Option[Kind] = fragmentAST match { + case Type.Fragment.AST.Identifier(id) => env.get(id) + case Type.Fragment.AST.ACC | Type.Fragment.AST.A | Type.Fragment.AST.B => Some(DataTypeKind(KindAST.Fragment)) } - def kindOf(matrixLayout: Type.MatrixLayout.AST, - env: Map[String, Kind], - variadicN: Option[String] - ): Option[Kind] = { - import Type._ - matrixLayout match { - case MatrixLayout.AST.Identifier(id) => - env.get(id) - case MatrixLayout.AST.ROW_MAJOR | - MatrixLayout.AST.COL_MAJOR | - MatrixLayout.AST.NONE => Some(DataTypeKind(KindAST.MatrixLayout)) - } + def kindOf(matrixLayout: Type.MatrixLayout.AST, env: Map[String, Kind]): Option[Kind] = matrixLayout match { + case Type.MatrixLayout.AST.Identifier(id) => env.get(id) + case Type.MatrixLayout.AST.ROW_MAJOR | + Type.MatrixLayout.AST.COL_MAJOR | + Type.MatrixLayout.AST.NONE => Some(DataTypeKind(KindAST.MatrixLayout)) } } From 1395b373ff7f1eb036a22493b62f96f65931ab8f Mon Sep 17 00:00:00 2001 From: Michel Steuwer Date: Tue, 11 May 2021 13:09:13 +0100 Subject: [PATCH 5/5] Added type assertions. --- .../scala/meta/generator/DPIAPrimitives.scala | 85 ++++++++++++------- src/main/scala/shine/DPIA/Types/package.scala | 25 ++++-- .../DPIA/primitives/functional/AsScalar.scala | 3 +- .../DPIA/primitives/functional/AsVector.scala | 3 +- .../functional/AsVectorAligned.scala | 3 +- .../DPIA/primitives/functional/Cast.scala | 3 +- .../functional/CircularBuffer.scala | 3 +- .../DPIA/primitives/functional/Cycle.scala | 3 +- .../DPIA/primitives/functional/DepIdx.scala | 3 +- .../DPIA/primitives/functional/DepJoin.scala | 3 +- .../primitives/functional/DepMapSeq.scala | 3 +- .../DPIA/primitives/functional/DepZip.scala | 3 +- .../DPIA/primitives/functional/Drop.scala | 3 +- .../functional/ForeignFunctionCall.scala | 15 ++-- .../DPIA/primitives/functional/Fst.scala | 3 +- .../DPIA/primitives/functional/Gather.scala | 3 +- .../DPIA/primitives/functional/Generate.scala | 3 +- .../DPIA/primitives/functional/Idx.scala | 3 +- .../DPIA/primitives/functional/IdxVec.scala | 3 +- .../primitives/functional/IndexAsNat.scala | 3 +- .../DPIA/primitives/functional/Iterate.scala | 3 +- .../primitives/functional/IterateStream.scala | 3 +- .../DPIA/primitives/functional/Join.scala | 3 +- .../DPIA/primitives/functional/Let.scala | 3 +- .../primitives/functional/MakeArray.scala | 8 +- .../DPIA/primitives/functional/MakePair.scala | 3 +- .../DPIA/primitives/functional/Map.scala | 3 +- .../DPIA/primitives/functional/MapFst.scala | 3 +- .../DPIA/primitives/functional/MapSeq.scala | 3 +- .../DPIA/primitives/functional/MapSnd.scala | 3 +- .../primitives/functional/MapStream.scala | 3 +- .../DPIA/primitives/functional/MapVec.scala | 3 +- .../primitives/functional/NatAsIndex.scala | 3 +- .../DPIA/primitives/functional/PadClamp.scala | 3 +- .../DPIA/primitives/functional/PadCst.scala | 3 +- .../DPIA/primitives/functional/PadEmpty.scala | 3 +- .../primitives/functional/Partition.scala | 3 +- .../primitives/functional/ReduceSeq.scala | 3 +- .../DPIA/primitives/functional/Reorder.scala | 3 +- .../primitives/functional/RotateValues.scala | 3 +- .../DPIA/primitives/functional/ScanSeq.scala | 3 +- .../DPIA/primitives/functional/Scatter.scala | 3 +- .../DPIA/primitives/functional/Slide.scala | 3 +- .../DPIA/primitives/functional/Snd.scala | 3 +- .../DPIA/primitives/functional/Split.scala | 3 +- .../DPIA/primitives/functional/Take.scala | 3 +- .../DPIA/primitives/functional/ToMem.scala | 3 +- .../primitives/functional/Transpose.scala | 3 +- .../functional/TransposeDepArray.scala | 3 +- .../DPIA/primitives/functional/Unzip.scala | 3 +- .../functional/VectorFromScalar.scala | 3 +- .../DPIA/primitives/functional/Zip.scala | 3 +- .../primitives/imperative/AsScalarAcc.scala | 3 +- .../primitives/imperative/AsVectorAcc.scala | 3 +- .../DPIA/primitives/imperative/Assign.scala | 3 +- .../DPIA/primitives/imperative/Comment.scala | 1 - .../DPIA/primitives/imperative/CycleAcc.scala | 3 +- .../primitives/imperative/DepIdxAcc.scala | 3 +- .../primitives/imperative/DepJoinAcc.scala | 3 +- .../DPIA/primitives/imperative/DropAcc.scala | 3 +- .../DPIA/primitives/imperative/For.scala | 3 +- .../DPIA/primitives/imperative/ForNat.scala | 3 +- .../DPIA/primitives/imperative/ForVec.scala | 3 +- .../primitives/imperative/GenerateCont.scala | 3 +- .../DPIA/primitives/imperative/IdxAcc.scala | 3 +- .../primitives/imperative/IdxVecAcc.scala | 3 +- .../DPIA/primitives/imperative/JoinAcc.scala | 3 +- .../DPIA/primitives/imperative/MapAcc.scala | 3 +- .../primitives/imperative/MapFstAcc.scala | 3 +- .../DPIA/primitives/imperative/MapRead.scala | 3 +- .../primitives/imperative/MapSndAcc.scala | 3 +- .../DPIA/primitives/imperative/New.scala | 3 +- .../imperative/NewDoubleBuffer.scala | 3 +- .../DPIA/primitives/imperative/PairAcc.scala | 3 +- .../DPIA/primitives/imperative/PairAcc1.scala | 3 +- .../DPIA/primitives/imperative/PairAcc2.scala | 3 +- .../primitives/imperative/ReorderAcc.scala | 3 +- .../primitives/imperative/ScatterAcc.scala | 3 +- .../DPIA/primitives/imperative/Seq.scala | 3 +- .../DPIA/primitives/imperative/Skip.scala | 1 - .../DPIA/primitives/imperative/SplitAcc.scala | 3 +- .../DPIA/primitives/imperative/TakeAcc.scala | 3 +- .../primitives/imperative/TransposeAcc.scala | 3 +- .../DPIA/primitives/imperative/UnzipAcc.scala | 3 +- .../DPIA/primitives/imperative/ZipAcc1.scala | 3 +- .../DPIA/primitives/imperative/ZipAcc2.scala | 3 +- .../functional/CircularBuffer.scala | 3 +- .../OpenCL/primitives/functional/DepMap.scala | 3 +- .../primitives/functional/Iterate.scala | 3 +- .../primitives/functional/KernelCall.scala | 15 ++-- .../OpenCL/primitives/functional/Map.scala | 3 +- .../functional/OpenCLFunctionCall.scala | 15 ++-- .../primitives/functional/ReduceSeq.scala | 3 +- .../primitives/functional/RotateValues.scala | 3 +- .../OpenCL/primitives/functional/Run.scala | 3 +- .../OpenCL/primitives/functional/ToMem.scala | 3 +- .../primitives/imperative/Barrier.scala | 1 - .../primitives/imperative/HostExecution.scala | 3 +- .../primitives/imperative/IdxDistribute.scala | 3 +- .../imperative/IdxDistributeAcc.scala | 3 +- .../primitives/imperative/KernelCallCmd.scala | 15 ++-- .../OpenCL/primitives/imperative/New.scala | 3 +- .../imperative/NewDoubleBuffer.scala | 3 +- .../imperative/NewManagedBuffer.scala | 3 +- .../OpenCL/primitives/imperative/ParFor.scala | 3 +- .../primitives/imperative/ParForNat.scala | 3 +- .../primitives/functional/DepMapPar.scala | 3 +- .../OpenMP/primitives/functional/MapPar.scala | 3 +- .../primitives/functional/ReducePar.scala | 3 +- .../OpenMP/primitives/imperative/ParFor.scala | 3 +- .../primitives/imperative/ParForNat.scala | 3 +- .../primitives/functional/AsFragment.scala | 3 +- .../cuda/primitives/functional/AsMatrix.scala | 3 +- .../functional/GenerateFragment.scala | 3 +- .../functional/GlobalToShared.scala | 3 +- .../cuda/primitives/functional/Map.scala | 3 +- .../primitives/functional/MapFragment.scala | 3 +- .../functional/TensorMatMultAdd.scala | 3 +- .../primitives/imperative/ForFragment.scala | 3 +- .../imperative/GlobalToSharedAcc.scala | 3 +- .../cuda/primitives/imperative/ParFor.scala | 3 +- .../primitives/imperative/SyncPipeline.scala | 3 +- .../primitives/imperative/SyncThreads.scala | 1 - .../cuda/primitives/imperative/SyncWarp.scala | 1 - .../cuda/primitives/imperative/WmmaFill.scala | 3 +- .../cuda/primitives/imperative/WmmaLoad.scala | 3 +- .../cuda/primitives/imperative/WmmaMMA.scala | 3 +- .../primitives/imperative/WmmaStore.scala | 3 +- 128 files changed, 348 insertions(+), 183 deletions(-) diff --git a/meta/src/main/scala/meta/generator/DPIAPrimitives.scala b/meta/src/main/scala/meta/generator/DPIAPrimitives.scala index b18280f40..8c77dcafb 100644 --- a/meta/src/main/scala/meta/generator/DPIAPrimitives.scala +++ b/meta/src/main/scala/meta/generator/DPIAPrimitives.scala @@ -81,9 +81,7 @@ ${generateCaseClass(Type.Name(name), toParamList(definition, scalaParams), param val generatedParams = generateParams(scalaParams, params) q""" final case class $name(...$generatedParams) extends $superClass { - { - ..${generateTypeChecks(params).stats} - } + ..${generateTypeChecks(params).stats} ..${if (scalaReturnType != t"CommType") { List(q"override val t: $scalaReturnType = ${generateTerm(returnType)}") @@ -165,44 +163,64 @@ ${generateCaseClass(Type.Name(name), toParamList(definition, scalaParams), param // generate type checks in the body of the generated case classes, e.g. for map: // f :: FunType(expT(dt1, a), expT(dt2, a)) // array :: expT(ArrayType(n, dt1), a) - def generateTypeChecks(params: Seq[DPIA.Decl.AST.Param]): scala.meta.Term.Block = - q"""{ - ..${params. - filter(param => param.ty.isInstanceOf[KindOrType.Type]). // only check types for params with phrase types - map(param => param.ty match { - // special treatment for params with variadic type - case KindOrType.Type(DPIA.Type.AST.VariadicType(n, typeAST)) => - generateVariadicTypeCheck(param, n, typeAST) - // for params with dependent function type we generate something like this, e.g. for iterate: - // f :: ({ - // val l = f.t.x - // DepFunType[NatKind, PhraseType](l, - // FunType(expT(ArrayType(l * n, dt), read), expT(ArrayType(l, dt), write))) - // }) - case KindOrType.Type(typeAST@DPIA.Type.AST.DepFunType(id, _, _)) => - q"""${Term.Name(param.id.name)} :: { + def generateTypeChecks(params: Seq[DPIA.Decl.AST.Param]): scala.meta.Term.Block = { + val checks = params. + filter(param => param.ty.isInstanceOf[KindOrType.Type]). // only check types for params with phrase types + map(param => param.ty match { + // special treatment for params with variadic type + case KindOrType.Type(DPIA.Type.AST.VariadicType(n, typeAST)) => + generateVariadicTypeCheck(param, n, typeAST) + // for params with dependent function type we generate something like this, e.g. for iterate: + // f :: ({ + // val l = f.t.x + // DepFunType[NatKind, PhraseType](l, + // FunType(expT(ArrayType(l * n, dt), read), expT(ArrayType(l, dt), write))) + // }) + case KindOrType.Type(typeAST@DPIA.Type.AST.DepFunType(id, _, _)) => + q"""${Term.Name(param.id.name)} :: { val ${Pat.Var(name = Term.Name(id.name))} = ${Term.Name(param.id.name)}.t.x ${generateTerm(typeAST)} }""" - // for all other parameters we generate the simple form: e :: t - case KindOrType.Type(typeAST) => - q"${Term.Name(param.id.name)} :: ${generateTerm(typeAST)}" - case KindOrType.Kind(_) => throw new Exception("Generation of type checks not necessary for kinds") - }).toList} - }""" + // for all other parameters we generate the simple form: e :: t + case KindOrType.Type(typeAST) => + q"${Term.Name(param.id.name)} :: ${generateTerm(typeAST)}" + case KindOrType.Kind(_) => throw new Exception("Generation of type checks not necessary for kinds") + }).toList + // add assertion around the checks to avoid them in release builds + if (checks.nonEmpty) { q"{ assert {..$checks; true} }" } else { q"{}" } + } def generateVariadicTypeCheck(param: DPIA.Decl.AST.Param, n: DPIA.Type.AST.Identifier, - typeAST: DPIA.Type.AST): scala.meta.Term = + typeAST: DPIA.Type.AST): scala.meta.Term.Block = // we check if there are unrolled ids (i.e. *id) in the expression getUnrolledIds(typeAST) match { - case Seq() => // ... if there are none, then we can generate the simple form with an foreach - q"${Term.Name(param.id.name)}.foreach(_ :: ${generateTerm(typeAST)})" + case Seq() => // ... if there are none, then we generate a length check and a simple checking form with an foreach + q"""{ + typeAssert(${Term.Name(param.id.name)}.length == ${Term.Name(n.name)}, + ${Lit.String(param.id.name)} + ".length == " + ${Lit.String(n.name)} + " is not true") + ${Term.Name(param.id.name)}.foreach(_ :: ${generateTerm(typeAST)}) + }""" case unrolledIds => // ... if there are, then we have to zip the unrolled ids with the param // we generate something like this, e.g. for foreign function call: + // typeAssert(args.length == n) + // typeAssert(inTs.length == n) // args.zip(inTs).foreach({ case (args, inTs) => args :: expT(inTs, read) }) - // generate a chain of nested zip calls (in case there are multiple unrolled ids) + // first generate the assertions ... + val firstAssertion = q"""{ + typeAssert(${Term.Name(param.id.name)}.length == ${Term.Name(n.name)}, + ${Lit.String(param.id.name)} + ".length == " + ${Lit.String(n.name)} + " is not true") + }""" + val allAssertions = unrolledIds.foldLeft[Term.Block](firstAssertion) { + case (block, id) => q"""{ + ..${block.stats} + typeAssert(${Term.Name(id.name)}.length == ${Term.Name(n.name)}, + ${Lit.String(id.name)} + ".length == " + ${Lit.String(n.name)} + " is not true") + }""" + } + + // ... then generate a chain of nested zip calls (in case there are multiple unrolled ids) val zips = unrolledIds.foldLeft[Term](Term.Name(param.id.name)) { case (term, id) => q"$term.zip(${Term.Name(id.name)})" } @@ -212,9 +230,12 @@ ${generateCaseClass(Type.Name(name), toParamList(definition, scalaParams), param val pattern = unrolledIds.foldRight[Pat](Pat.Var(p)) { case (id, pattern) => p"($pattern, ${Pat.Var(Term.Name(id.name))})" } - // ... putting everything together with a foreach into the final form - q"""$zips.foreach { - case ($pattern) => $p :: ${generateTerm(typeAST)} + // ... putting everything together with the assertions and the check with a foreach into the final form + q"""{ + ..${allAssertions.stats} + $zips.foreach { + case ($pattern) => $p :: ${generateTerm(typeAST)} + } }""" } diff --git a/src/main/scala/shine/DPIA/Types/package.scala b/src/main/scala/shine/DPIA/Types/package.scala index 5d1a7b630..0d5505364 100644 --- a/src/main/scala/shine/DPIA/Types/package.scala +++ b/src/main/scala/shine/DPIA/Types/package.scala @@ -4,17 +4,26 @@ import arithexpr.arithmetic.RangeAdd import shine.DPIA.Phrases.Phrase import shine.DPIA.Types.TypeCheck._ +import scala.annotation.elidable +import scala.annotation.elidable.ASSERTION + package object Types { + @elidable(ASSERTION) + def typeAssert[T <: PhraseType](p: Phrase[T], pt: PhraseType): Unit = { + if (!(p checkTypeEqOrSubtype pt)) + throw new java.lang.AssertionError(s"Type error: found ${p.t}, expected $pt") + } + + @elidable(ASSERTION) + def typeAssert(boolean: Boolean, msg: => String): Unit = { + if (!boolean) + throw new java.lang.AssertionError(s"Type error: $msg") + } + implicit class ReverseInferenceHelper(pt: PhraseType) { - def ::[T <: PhraseType](p: Phrase[T]): Unit = - if (!(p checkTypeEqOrSubtype pt)) { - throw new Exception(s"Type error: found ${p.t}, expected $pt") - } - def `:`[T <: PhraseType](p: Phrase[T]): Unit = - if (!(p checkTypeEqOrSubtype pt)) { - throw new Exception(s"Type error: found ${p.t}, expected $pt") - } + def ::[T <: PhraseType](p: Phrase[T]): Unit = typeAssert(p, pt) + def `:`[T <: PhraseType](p: Phrase[T]): Unit = typeAssert(p, pt) } type NatDependentFunctionType[T <: PhraseType] = DepFunType[NatKind, T] diff --git a/src/main/scala/shine/DPIA/primitives/functional/AsScalar.scala b/src/main/scala/shine/DPIA/primitives/functional/AsScalar.scala index 9d62a3b9c..8b7216d3e 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/AsScalar.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/AsScalar.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class AsScalar(val n: Nat, val m: Nat, val dt: DataType, val a: AccessType, val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { array :: expT(ArrayType(m, VectorType(n, dt)), a) + true } override val t: ExpType = expT(ArrayType(m * n, dt), a) override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsScalar = new AsScalar(v.nat(n), v.nat(m), v.data(dt), v.access(a), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/AsVector.scala b/src/main/scala/shine/DPIA/primitives/functional/AsVector.scala index dbe82e1ac..782d5fd84 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/AsVector.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/AsVector.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class AsVector(val n: Nat, val m: Nat, val dt: DataType, val a: AccessType, val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { array :: expT(ArrayType(m * n, dt), a) + true } override val t: ExpType = expT(ArrayType(m, VectorType(n, dt)), a) override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsVector = new AsVector(v.nat(n), v.nat(m), v.data(dt), v.access(a), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/AsVectorAligned.scala b/src/main/scala/shine/DPIA/primitives/functional/AsVectorAligned.scala index f84602328..84caf5271 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/AsVectorAligned.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/AsVectorAligned.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class AsVectorAligned(val n: Nat, val m: Nat, val dt: DataType, val a: AccessType, val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { array :: expT(ArrayType(m * n, dt), a) + true } override val t: ExpType = expT(ArrayType(m, VectorType(n, dt)), a) override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsVectorAligned = new AsVectorAligned(v.nat(n), v.nat(m), v.data(dt), v.access(a), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Cast.scala b/src/main/scala/shine/DPIA/primitives/functional/Cast.scala index 52fc46cb2..faffa0801 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Cast.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Cast.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Cast(val dt1: DataType, val dt2: DataType, val e: Phrase[ExpType]) extends ExpPrimitive { - { + assert { e :: expT(dt1, read) + true } override val t: ExpType = expT(dt2, read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Cast = new Cast(v.data(dt1), v.data(dt2), VisitAndRebuild(e, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/CircularBuffer.scala b/src/main/scala/shine/DPIA/primitives/functional/CircularBuffer.scala index 2154283c7..4e90533e0 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/CircularBuffer.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/CircularBuffer.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class CircularBuffer(val n: Nat, val alloc: Nat, val sz: Nat, val dt1: DataType, val dt2: DataType, val load: Phrase[FunType[ExpType, ExpType]], val input: Phrase[ExpType]) extends ExpPrimitive { - { + assert { load :: FunType(expT(dt1, read), expT(dt2, write)) input :: expT(ArrayType(n - 1 + sz, dt1), read) + true } override val t: ExpType = expT(ArrayType(n, ArrayType(sz, dt2)), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): CircularBuffer = new CircularBuffer(v.nat(n), v.nat(alloc), v.nat(sz), v.data(dt1), v.data(dt2), VisitAndRebuild(load, v), VisitAndRebuild(input, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Cycle.scala b/src/main/scala/shine/DPIA/primitives/functional/Cycle.scala index 995185f97..6e5242a96 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Cycle.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Cycle.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Cycle(val n: Nat, val m: Nat, val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive { - { + assert { input :: expT(ArrayType(m, dt), read) + true } override val t: ExpType = expT(ArrayType(n, dt), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Cycle = new Cycle(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(input, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/DepIdx.scala b/src/main/scala/shine/DPIA/primitives/functional/DepIdx.scala index be23be676..1a0dd9fb0 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/DepIdx.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/DepIdx.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class DepIdx(val n: Nat, val ft: NatToData, val index: Nat, val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { array :: expT(DepArrayType(n, ft), read) + true } override val t: ExpType = expT(NatToDataApply(ft, index), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepIdx = new DepIdx(v.nat(n), v.natToData(ft), v.nat(index), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/DepJoin.scala b/src/main/scala/shine/DPIA/primitives/functional/DepJoin.scala index a853edac8..5a80c653d 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/DepJoin.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/DepJoin.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class DepJoin(val n: Nat, val lenF: NatToNat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { array :: expT(DepArrayType(n, n2dtFun { (i: NatIdentifier) => ArrayType(lenF(i), dt) }), read) + true } override val t: ExpType = expT(ArrayType(BigSum(from = 0, upTo = n - 1, (i: Nat) => lenF(i)), dt), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepJoin = new DepJoin(v.nat(n), v.natToNat(lenF), v.data(dt), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/DepMapSeq.scala b/src/main/scala/shine/DPIA/primitives/functional/DepMapSeq.scala index 2dede1eb0..9e025b2a0 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/DepMapSeq.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/DepMapSeq.scala @@ -8,12 +8,13 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class DepMapSeq(unroll: Boolean)(val n: Nat, val ft1: NatToData, val ft2: NatToData, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: ({ val k = f.t.x DepFunType[NatKind, PhraseType](k, FunType(expT(NatToDataApply(ft1, k), read), expT(NatToDataApply(ft2, k), write))) }) array :: expT(DepArrayType(n, ft1), read) + true } override val t: ExpType = expT(DepArrayType(n, ft2), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepMapSeq = new DepMapSeq(unroll)(v.nat(n), v.natToData(ft1), v.natToData(ft2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/DepZip.scala b/src/main/scala/shine/DPIA/primitives/functional/DepZip.scala index 8864e226e..850d589ac 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/DepZip.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/DepZip.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class DepZip(val n: Nat, val ft1: NatToData, val ft2: NatToData, val e1: Phrase[ExpType], val e2: Phrase[ExpType]) extends ExpPrimitive { - { + assert { e1 :: expT(DepArrayType(n, ft1), read) e2 :: expT(DepArrayType(n, ft2), read) + true } override val t: ExpType = expT(DepArrayType(n, n2dtFun { (i: NatIdentifier) => PairType(NatToDataApply(ft1, i), NatToDataApply(ft2, i)) }), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepZip = new DepZip(v.nat(n), v.natToData(ft1), v.natToData(ft2), VisitAndRebuild(e1, v), VisitAndRebuild(e2, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Drop.scala b/src/main/scala/shine/DPIA/primitives/functional/Drop.scala index c90942c95..162cc5b0a 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Drop.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Drop.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Drop(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { array :: expT(ArrayType(n + m, dt), read) + true } override val t: ExpType = expT(ArrayType(m, dt), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Drop = new Drop(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/ForeignFunctionCall.scala b/src/main/scala/shine/DPIA/primitives/functional/ForeignFunctionCall.scala index 2b1e560cf..14b4845c9 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/ForeignFunctionCall.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/ForeignFunctionCall.scala @@ -8,11 +8,16 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class ForeignFunctionCall(funDecl: rise.core.ForeignFunction.Decl, n: Int)(val inTs: Seq[DataType], val outT: DataType, val args: Seq[Phrase[ExpType]]) extends ExpPrimitive { - { - args.zip(inTs).foreach({ - case (args, inTs) => - args :: expT(inTs, read) - }) + assert { + { + typeAssert(args.length == n, "args" + ".length == " + "n" + " is not true") + typeAssert(inTs.length == n, "inTs" + ".length == " + "n" + " is not true") + args.zip(inTs).foreach({ + case (args, inTs) => + args :: expT(inTs, read) + }) + } + true } override val t: ExpType = expT(outT, read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): ForeignFunctionCall = new ForeignFunctionCall(funDecl, n)(inTs.map(v.data), v.data(outT), args.map(VisitAndRebuild(_, v))) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Fst.scala b/src/main/scala/shine/DPIA/primitives/functional/Fst.scala index 8230c972f..54ff742e4 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Fst.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Fst.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Fst(val dt1: DataType, val dt2: DataType, val pair: Phrase[ExpType]) extends ExpPrimitive { - { + assert { pair :: expT(PairType(dt1, dt2), read) + true } override val t: ExpType = expT(dt1, read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Fst = new Fst(v.data(dt1), v.data(dt2), VisitAndRebuild(pair, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Gather.scala b/src/main/scala/shine/DPIA/primitives/functional/Gather.scala index 343491501..1a2022c41 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Gather.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Gather.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Gather(val n: Nat, val m: Nat, val dt: DataType, val indices: Phrase[ExpType], val input: Phrase[ExpType]) extends ExpPrimitive { - { + assert { indices :: expT(ArrayType(m, IndexType(n)), read) input :: expT(ArrayType(n, dt), read) + true } override val t: ExpType = expT(ArrayType(m, dt), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Gather = new Gather(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(indices, v), VisitAndRebuild(input, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Generate.scala b/src/main/scala/shine/DPIA/primitives/functional/Generate.scala index 7cb88033c..5eecb7004 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Generate.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Generate.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Generate(val n: Nat, val dt: DataType, val f: Phrase[FunType[ExpType, ExpType]]) extends ExpPrimitive { - { + assert { f :: FunType(expT(IndexType(n), read), expT(dt, read)) + true } override val t: ExpType = expT(ArrayType(n, dt), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Generate = new Generate(v.nat(n), v.data(dt), VisitAndRebuild(f, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Idx.scala b/src/main/scala/shine/DPIA/primitives/functional/Idx.scala index 8ccfab1ea..dd7b10d36 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Idx.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Idx.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Idx(val n: Nat, val dt: DataType, val index: Phrase[ExpType], val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { index :: expT(IndexType(n), read) array :: expT(ArrayType(n, dt), read) + true } override val t: ExpType = expT(dt, read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Idx = new Idx(v.nat(n), v.data(dt), VisitAndRebuild(index, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/IdxVec.scala b/src/main/scala/shine/DPIA/primitives/functional/IdxVec.scala index d2e4c6229..330efc2a8 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/IdxVec.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/IdxVec.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class IdxVec(val n: Nat, val dt: DataType, val index: Phrase[ExpType], val vector: Phrase[ExpType]) extends ExpPrimitive { - { + assert { index :: expT(IndexType(n), read) vector :: expT(VectorType(n, dt), read) + true } override val t: ExpType = expT(dt, read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): IdxVec = new IdxVec(v.nat(n), v.data(dt), VisitAndRebuild(index, v), VisitAndRebuild(vector, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/IndexAsNat.scala b/src/main/scala/shine/DPIA/primitives/functional/IndexAsNat.scala index 94528a350..f0eb60959 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/IndexAsNat.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/IndexAsNat.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class IndexAsNat(val n: Nat, val e: Phrase[ExpType]) extends ExpPrimitive { - { + assert { e :: expT(IndexType(n), read) + true } override val t: ExpType = expT(NatType, read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): IndexAsNat = new IndexAsNat(v.nat(n), VisitAndRebuild(e, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Iterate.scala b/src/main/scala/shine/DPIA/primitives/functional/Iterate.scala index 6b20d3e4a..c1a153413 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Iterate.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Iterate.scala @@ -8,12 +8,13 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Iterate(val n: Nat, val m: Nat, val k: Nat, val dt: DataType, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: ({ val l = f.t.x DepFunType[NatKind, PhraseType](l, FunType(expT(ArrayType(l * n, dt), read), expT(ArrayType(l, dt), write))) }) array :: expT(ArrayType(m * n.pow(k), dt), read) + true } override val t: ExpType = expT(ArrayType(m, dt), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Iterate = new Iterate(v.nat(n), v.nat(m), v.nat(k), v.data(dt), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/IterateStream.scala b/src/main/scala/shine/DPIA/primitives/functional/IterateStream.scala index 531bca06f..bf8d58733 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/IterateStream.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/IterateStream.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class IterateStream(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: FunType(expT(dt1, read), expT(dt2, write)) array :: expT(ArrayType(n, dt1), read) + true } override val t: ExpType = expT(ArrayType(n, dt2), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): IterateStream = new IterateStream(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Join.scala b/src/main/scala/shine/DPIA/primitives/functional/Join.scala index c42697619..453cb7dc6 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Join.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Join.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Join(val n: Nat, val m: Nat, val a: AccessType, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { array :: expT(ArrayType(n, ArrayType(m, dt)), a) + true } override val t: ExpType = expT(ArrayType(n * m, dt), a) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Join = new Join(v.nat(n), v.nat(m), v.access(a), v.data(dt), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Let.scala b/src/main/scala/shine/DPIA/primitives/functional/Let.scala index cf627e092..820e3977e 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Let.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Let.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Let(val dt1: DataType, val dt2: DataType, val a: AccessType, val value: Phrase[ExpType], val f: Phrase[FunType[ExpType, ExpType]]) extends ExpPrimitive { - { + assert { value :: expT(dt1, read) f :: FunType(expT(dt1, read), expT(dt2, a)) + true } override val t: ExpType = expT(dt2, a) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Let = new Let(v.data(dt1), v.data(dt2), v.access(a), VisitAndRebuild(value, v), VisitAndRebuild(f, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/MakeArray.scala b/src/main/scala/shine/DPIA/primitives/functional/MakeArray.scala index e887fd51a..6ccd084cc 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/MakeArray.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/MakeArray.scala @@ -8,8 +8,12 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class MakeArray(n: Int)(val dt: DataType, val elements: Seq[Phrase[ExpType]]) extends ExpPrimitive { - { - elements.foreach(_ :: expT(dt, read)) + assert { + { + typeAssert(elements.length == n, "elements" + ".length == " + "n" + " is not true") + elements.foreach(_ :: expT(dt, read)) + } + true } override val t: ExpType = expT(ArrayType(n, dt), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): MakeArray = new MakeArray(n)(v.data(dt), elements.map(VisitAndRebuild(_, v))) diff --git a/src/main/scala/shine/DPIA/primitives/functional/MakePair.scala b/src/main/scala/shine/DPIA/primitives/functional/MakePair.scala index 38b54d25d..f4b221b5d 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/MakePair.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/MakePair.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class MakePair(val dt1: DataType, val dt2: DataType, val a: AccessType, val fst: Phrase[ExpType], val snd: Phrase[ExpType]) extends ExpPrimitive { - { + assert { fst :: expT(dt1, a) snd :: expT(dt2, a) + true } override val t: ExpType = expT(PairType(dt1, dt2), a) override def visitAndRebuild(v: VisitAndRebuild.Visitor): MakePair = new MakePair(v.data(dt1), v.data(dt2), v.access(a), VisitAndRebuild(fst, v), VisitAndRebuild(snd, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Map.scala b/src/main/scala/shine/DPIA/primitives/functional/Map.scala index 64896fd85..d91e96c3a 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Map.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Map.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Map(val n: Nat, val dt1: DataType, val dt2: DataType, val a: AccessType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: FunType(expT(dt1, a), expT(dt2, a)) array :: expT(ArrayType(n, dt1), a) + true } override val t: ExpType = expT(ArrayType(n, dt2), a) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Map = new Map(v.nat(n), v.data(dt1), v.data(dt2), v.access(a), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/MapFst.scala b/src/main/scala/shine/DPIA/primitives/functional/MapFst.scala index 5005b6e07..08a5fba8a 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/MapFst.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/MapFst.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class MapFst(val a: AccessType, val dt1: DataType, val dt2: DataType, val dt3: DataType, val f: Phrase[FunType[ExpType, ExpType]], val pair: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: FunType(expT(dt1, a), expT(dt3, a)) pair :: expT(PairType(dt1, dt2), a) + true } override val t: ExpType = expT(PairType(dt3, dt2), a) override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapFst = new MapFst(v.access(a), v.data(dt1), v.data(dt2), v.data(dt3), VisitAndRebuild(f, v), VisitAndRebuild(pair, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/MapSeq.scala b/src/main/scala/shine/DPIA/primitives/functional/MapSeq.scala index e18bf1a2b..2f9d4894e 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/MapSeq.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/MapSeq.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class MapSeq(unroll: Boolean)(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: FunType(expT(dt1, read), expT(dt2, write)) array :: expT(ArrayType(n, dt1), read) + true } override val t: ExpType = expT(ArrayType(n, dt2), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapSeq = new MapSeq(unroll)(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/MapSnd.scala b/src/main/scala/shine/DPIA/primitives/functional/MapSnd.scala index 410b3dc86..abbb7f440 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/MapSnd.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/MapSnd.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class MapSnd(val a: AccessType, val dt1: DataType, val dt2: DataType, val dt3: DataType, val f: Phrase[FunType[ExpType, ExpType]], val pair: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: FunType(expT(dt2, a), expT(dt3, a)) pair :: expT(PairType(dt1, dt2), a) + true } override val t: ExpType = expT(PairType(dt1, dt3), a) override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapSnd = new MapSnd(v.access(a), v.data(dt1), v.data(dt2), v.data(dt3), VisitAndRebuild(f, v), VisitAndRebuild(pair, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/MapStream.scala b/src/main/scala/shine/DPIA/primitives/functional/MapStream.scala index c79d554c6..14e98ee19 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/MapStream.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/MapStream.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class MapStream(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: FunType(expT(dt1, read), expT(dt2, write)) array :: expT(ArrayType(n, dt1), read) + true } override val t: ExpType = expT(ArrayType(n, dt2), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapStream = new MapStream(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/MapVec.scala b/src/main/scala/shine/DPIA/primitives/functional/MapVec.scala index 8583c637d..1071a8583 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/MapVec.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/MapVec.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class MapVec(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: FunType(expT(dt1, read), expT(dt2, write)) array :: expT(VectorType(n, dt1), read) + true } override val t: ExpType = expT(VectorType(n, dt2), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapVec = new MapVec(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/NatAsIndex.scala b/src/main/scala/shine/DPIA/primitives/functional/NatAsIndex.scala index a7a69904d..e1ccc91ea 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/NatAsIndex.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/NatAsIndex.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class NatAsIndex(val n: Nat, val e: Phrase[ExpType]) extends ExpPrimitive { - { + assert { e :: expT(NatType, read) + true } override val t: ExpType = expT(IndexType(n), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): NatAsIndex = new NatAsIndex(v.nat(n), VisitAndRebuild(e, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/PadClamp.scala b/src/main/scala/shine/DPIA/primitives/functional/PadClamp.scala index d8f44a1b4..fcf88752f 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/PadClamp.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/PadClamp.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class PadClamp(val n: Nat, val l: Nat, val r: Nat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { array :: expT(ArrayType(n, dt), read) + true } override val t: ExpType = expT(ArrayType(l + n + r, dt), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): PadClamp = new PadClamp(v.nat(n), v.nat(l), v.nat(r), v.data(dt), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/PadCst.scala b/src/main/scala/shine/DPIA/primitives/functional/PadCst.scala index 18cf95c23..be4f5d53c 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/PadCst.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/PadCst.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class PadCst(val n: Nat, val l: Nat, val r: Nat, val dt: DataType, val padExp: Phrase[ExpType], val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { padExp :: expT(dt, read) array :: expT(ArrayType(n, dt), read) + true } override val t: ExpType = expT(ArrayType(l + n + r, dt), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): PadCst = new PadCst(v.nat(n), v.nat(l), v.nat(r), v.data(dt), VisitAndRebuild(padExp, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/PadEmpty.scala b/src/main/scala/shine/DPIA/primitives/functional/PadEmpty.scala index ed4d9d75e..11326cfd4 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/PadEmpty.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/PadEmpty.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class PadEmpty(val n: Nat, val r: Nat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { array :: expT(ArrayType(n, dt), write) + true } override val t: ExpType = expT(ArrayType(n + r, dt), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): PadEmpty = new PadEmpty(v.nat(n), v.nat(r), v.data(dt), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Partition.scala b/src/main/scala/shine/DPIA/primitives/functional/Partition.scala index 74bdf5622..594dd8810 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Partition.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Partition.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Partition(val n: Nat, val m: Nat, val lenF: NatToNat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { array :: expT(ArrayType(n, dt), read) + true } override val t: ExpType = expT(DepArrayType(m, n2dtFun { (i: NatIdentifier) => ArrayType(lenF(i), dt) }), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Partition = new Partition(v.nat(n), v.nat(m), v.natToNat(lenF), v.data(dt), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/ReduceSeq.scala b/src/main/scala/shine/DPIA/primitives/functional/ReduceSeq.scala index bb07ef410..69c4934dc 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/ReduceSeq.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/ReduceSeq.scala @@ -8,10 +8,11 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class ReduceSeq(unroll: Boolean)(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, FunType[ExpType, ExpType]]], val init: Phrase[ExpType], val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: FunType(expT(dt2, read), FunType(expT(dt1, read), expT(dt2, write))) init :: expT(dt2, write) array :: expT(ArrayType(n, dt1), read) + true } override val t: ExpType = expT(dt2, read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): ReduceSeq = new ReduceSeq(unroll)(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(init, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Reorder.scala b/src/main/scala/shine/DPIA/primitives/functional/Reorder.scala index 9c9457bf1..7a64488a0 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Reorder.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Reorder.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Reorder(val n: Nat, val dt: DataType, val a: AccessType, val idxF: NatToNat, val idxFiv: NatToNat, val input: Phrase[ExpType]) extends ExpPrimitive { - { + assert { input :: expT(ArrayType(n, dt), a) + true } override val t: ExpType = expT(ArrayType(n, dt), a) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Reorder = new Reorder(v.nat(n), v.data(dt), v.access(a), v.natToNat(idxF), v.natToNat(idxFiv), VisitAndRebuild(input, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/RotateValues.scala b/src/main/scala/shine/DPIA/primitives/functional/RotateValues.scala index 04cbf6714..6ce75ced5 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/RotateValues.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/RotateValues.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class RotateValues(val n: Nat, val sz: Nat, val dt: DataType, val wrt: Phrase[FunType[ExpType, ExpType]], val input: Phrase[ExpType]) extends ExpPrimitive { - { + assert { wrt :: FunType(expT(dt, read), expT(dt, write)) input :: expT(ArrayType(n - 1 + sz, dt), read) + true } override val t: ExpType = expT(ArrayType(n, ArrayType(sz, dt)), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): RotateValues = new RotateValues(v.nat(n), v.nat(sz), v.data(dt), VisitAndRebuild(wrt, v), VisitAndRebuild(input, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/ScanSeq.scala b/src/main/scala/shine/DPIA/primitives/functional/ScanSeq.scala index 6be6db44b..b77245581 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/ScanSeq.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/ScanSeq.scala @@ -8,10 +8,11 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class ScanSeq(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, FunType[ExpType, ExpType]]], val init: Phrase[ExpType], val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: FunType(expT(dt1, read), FunType(expT(dt2, read), expT(dt2, write))) init :: expT(dt2, write) array :: expT(ArrayType(n, dt1), read) + true } override val t: ExpType = expT(ArrayType(n, dt2), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): ScanSeq = new ScanSeq(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(init, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Scatter.scala b/src/main/scala/shine/DPIA/primitives/functional/Scatter.scala index 85c2bbbf7..9b6f0e411 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Scatter.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Scatter.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Scatter(val n: Nat, val m: Nat, val dt: DataType, val indices: Phrase[ExpType], val input: Phrase[ExpType]) extends ExpPrimitive { - { + assert { indices :: expT(ArrayType(n, IndexType(m)), read) input :: expT(ArrayType(n, dt), write) + true } override val t: ExpType = expT(ArrayType(m, dt), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Scatter = new Scatter(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(indices, v), VisitAndRebuild(input, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Slide.scala b/src/main/scala/shine/DPIA/primitives/functional/Slide.scala index 92fdf0bbf..3498d138b 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Slide.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Slide.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Slide(val n: Nat, val sz: Nat, val sp: Nat, val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive { - { + assert { input :: expT(ArrayType(sp * n + sz, dt), read) + true } override val t: ExpType = expT(ArrayType(1 + n, ArrayType(sz, dt)), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Slide = new Slide(v.nat(n), v.nat(sz), v.nat(sp), v.data(dt), VisitAndRebuild(input, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Snd.scala b/src/main/scala/shine/DPIA/primitives/functional/Snd.scala index 77a2e9488..03361ef0b 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Snd.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Snd.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Snd(val dt1: DataType, val dt2: DataType, val pair: Phrase[ExpType]) extends ExpPrimitive { - { + assert { pair :: expT(PairType(dt1, dt2), read) + true } override val t: ExpType = expT(dt2, read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Snd = new Snd(v.data(dt1), v.data(dt2), VisitAndRebuild(pair, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Split.scala b/src/main/scala/shine/DPIA/primitives/functional/Split.scala index 426e6af3a..251c20400 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Split.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Split.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Split(val n: Nat, val m: Nat, val a: AccessType, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { array :: expT(ArrayType(m * n, dt), a) + true } override val t: ExpType = expT(ArrayType(m, ArrayType(n, dt)), a) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Split = new Split(v.nat(n), v.nat(m), v.access(a), v.data(dt), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Take.scala b/src/main/scala/shine/DPIA/primitives/functional/Take.scala index ffd625320..cc9d314cc 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Take.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Take.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Take(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { array :: expT(ArrayType(n + m, dt), read) + true } override val t: ExpType = expT(ArrayType(n, dt), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Take = new Take(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/ToMem.scala b/src/main/scala/shine/DPIA/primitives/functional/ToMem.scala index afd8d1bdf..168c2053d 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/ToMem.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/ToMem.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class ToMem(val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive { - { + assert { input :: expT(dt, write) + true } override val t: ExpType = expT(dt, read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): ToMem = new ToMem(v.data(dt), VisitAndRebuild(input, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Transpose.scala b/src/main/scala/shine/DPIA/primitives/functional/Transpose.scala index ecdf22dc1..f343da8d2 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Transpose.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Transpose.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Transpose(val n: Nat, val m: Nat, val dt: DataType, val a: AccessType, val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { array :: expT(ArrayType(n, ArrayType(m, dt)), a) + true } override val t: ExpType = expT(ArrayType(m, ArrayType(n, dt)), a) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Transpose = new Transpose(v.nat(n), v.nat(m), v.data(dt), v.access(a), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/TransposeDepArray.scala b/src/main/scala/shine/DPIA/primitives/functional/TransposeDepArray.scala index 276a3f223..cbc14110f 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/TransposeDepArray.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/TransposeDepArray.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class TransposeDepArray(val n: Nat, val m: Nat, val ft: NatToData, val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { array :: expT(ArrayType(n, DepArrayType(m, ft)), read) + true } override val t: ExpType = expT(DepArrayType(n, n2dtFun { (i: NatIdentifier) => ArrayType(n, NatToDataApply(ft, i)) }), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): TransposeDepArray = new TransposeDepArray(v.nat(n), v.nat(m), v.natToData(ft), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Unzip.scala b/src/main/scala/shine/DPIA/primitives/functional/Unzip.scala index 12d39f4bf..ad47ab8ef 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Unzip.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Unzip.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Unzip(val n: Nat, val dt1: DataType, val dt2: DataType, val a: AccessType, val e: Phrase[ExpType]) extends ExpPrimitive { - { + assert { e :: expT(ArrayType(n, PairType(dt1, dt2)), a) + true } override val t: ExpType = expT(PairType(ArrayType(n, dt1), ArrayType(n, dt2)), a) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Unzip = new Unzip(v.nat(n), v.data(dt1), v.data(dt2), v.access(a), VisitAndRebuild(e, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/VectorFromScalar.scala b/src/main/scala/shine/DPIA/primitives/functional/VectorFromScalar.scala index 91b183b13..315e2d08e 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/VectorFromScalar.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/VectorFromScalar.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class VectorFromScalar(val n: Nat, val dt: DataType, val arg: Phrase[ExpType]) extends ExpPrimitive { - { + assert { arg :: expT(dt, read) + true } override val t: ExpType = expT(VectorType(n, dt), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): VectorFromScalar = new VectorFromScalar(v.nat(n), v.data(dt), VisitAndRebuild(arg, v)) diff --git a/src/main/scala/shine/DPIA/primitives/functional/Zip.scala b/src/main/scala/shine/DPIA/primitives/functional/Zip.scala index 3f0c328df..532af46c8 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Zip.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Zip.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Zip(val n: Nat, val dt1: DataType, val dt2: DataType, val a: AccessType, val e1: Phrase[ExpType], val e2: Phrase[ExpType]) extends ExpPrimitive { - { + assert { e1 :: expT(ArrayType(n, dt1), a) e2 :: expT(ArrayType(n, dt2), a) + true } override val t: ExpType = expT(ArrayType(n, PairType(dt1, dt2)), a) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Zip = new Zip(v.nat(n), v.data(dt1), v.data(dt2), v.access(a), VisitAndRebuild(e1, v), VisitAndRebuild(e2, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/AsScalarAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/AsScalarAcc.scala index 7c064f934..57f0c4fb2 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/AsScalarAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/AsScalarAcc.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class AsScalarAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive { - { + assert { array :: accT(ArrayType(m * n, dt)) + true } override val t: AccType = accT(ArrayType(m, VectorType(n, dt))) override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsScalarAcc = new AsScalarAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/AsVectorAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/AsVectorAcc.scala index d68e07768..6e79440fd 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/AsVectorAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/AsVectorAcc.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class AsVectorAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive { - { + assert { array :: accT(ArrayType(m, VectorType(n, dt))) + true } override val t: AccType = accT(ArrayType(n * m, dt)) override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsVectorAcc = new AsVectorAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/Assign.scala b/src/main/scala/shine/DPIA/primitives/imperative/Assign.scala index 0451dcffa..76dfd7c3d 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/Assign.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/Assign.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Assign(val dt: DataType, val lhs: Phrase[AccType], val rhs: Phrase[ExpType]) extends CommandPrimitive { - { + assert { lhs :: accT(dt) rhs :: expT(dt, read) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): Assign = new Assign(v.data(dt), VisitAndRebuild(lhs, v), VisitAndRebuild(rhs, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/Comment.scala b/src/main/scala/shine/DPIA/primitives/imperative/Comment.scala index c1416c259..8a6dab26e 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/Comment.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/Comment.scala @@ -8,7 +8,6 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Comment(comment: String)() extends CommandPrimitive { - {} override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): Comment = new Comment(comment)() } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/CycleAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/CycleAcc.scala index 06be5aec8..c762a87b8 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/CycleAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/CycleAcc.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class CycleAcc(val n: Nat, val m: Nat, val dt: DataType, val input: Phrase[AccType]) extends AccPrimitive { - { + assert { input :: accT(ArrayType(m, dt)) + true } override val t: AccType = accT(ArrayType(n, dt)) override def visitAndRebuild(v: VisitAndRebuild.Visitor): CycleAcc = new CycleAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(input, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/DepIdxAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/DepIdxAcc.scala index a887fef02..677b2ec02 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/DepIdxAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/DepIdxAcc.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class DepIdxAcc(val n: Nat, val ft: NatToData, val index: Nat, val array: Phrase[AccType]) extends AccPrimitive { - { + assert { array :: accT(DepArrayType(n, ft)) + true } override val t: AccType = accT(NatToDataApply(ft, index)) override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepIdxAcc = new DepIdxAcc(v.nat(n), v.natToData(ft), v.nat(index), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/DepJoinAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/DepJoinAcc.scala index 011ead569..c178ce56c 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/DepJoinAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/DepJoinAcc.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class DepJoinAcc(val n: Nat, val lenF: NatToNat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive { - { + assert { array :: accT(ArrayType(BigSum(from = 0, upTo = n - 1, (i: Nat) => lenF(i)), dt)) + true } override val t: AccType = accT(DepArrayType(n, n2dtFun { (i: NatIdentifier) => ArrayType(lenF(i), dt) })) override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepJoinAcc = new DepJoinAcc(v.nat(n), v.natToNat(lenF), v.data(dt), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/DropAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/DropAcc.scala index a09617ffc..ea453417d 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/DropAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/DropAcc.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class DropAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive { - { + assert { array :: accT(ArrayType(n + m, dt)) + true } override val t: AccType = accT(ArrayType(m - n, dt)) override def visitAndRebuild(v: VisitAndRebuild.Visitor): DropAcc = new DropAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/For.scala b/src/main/scala/shine/DPIA/primitives/imperative/For.scala index 3d5d4ef1c..c9634fd3d 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/For.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/For.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class For(unroll: Boolean)(val n: Nat, val loopBody: Phrase[FunType[ExpType, CommType]]) extends CommandPrimitive { - { + assert { loopBody :: FunType(expT(IndexType(n), read), comm) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): For = new For(unroll)(v.nat(n), VisitAndRebuild(loopBody, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/ForNat.scala b/src/main/scala/shine/DPIA/primitives/imperative/ForNat.scala index 6bc08c265..b4615241c 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/ForNat.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/ForNat.scala @@ -8,11 +8,12 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class ForNat(unroll: Boolean)(val n: Nat, val loopBody: Phrase[DepFunType[NatKind, CommType]]) extends CommandPrimitive { - { + assert { loopBody :: ({ val i = loopBody.t.x DepFunType[NatKind, PhraseType](i, comm) }) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): ForNat = new ForNat(unroll)(v.nat(n), VisitAndRebuild(loopBody, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/ForVec.scala b/src/main/scala/shine/DPIA/primitives/imperative/ForVec.scala index 2d4967f8a..5ae17f755 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/ForVec.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/ForVec.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class ForVec(val n: Nat, val dt: DataType, val out: Phrase[AccType], val loopBody: Phrase[FunType[ExpType, FunType[AccType, CommType]]]) extends CommandPrimitive { - { + assert { out :: accT(VectorType(n, dt)) loopBody :: FunType(expT(IndexType(n), read), FunType(accT(dt), comm)) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): ForVec = new ForVec(v.nat(n), v.data(dt), VisitAndRebuild(out, v), VisitAndRebuild(loopBody, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/GenerateCont.scala b/src/main/scala/shine/DPIA/primitives/imperative/GenerateCont.scala index 029fc9810..8a705214f 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/GenerateCont.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/GenerateCont.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class GenerateCont(val n: Nat, val dt: DataType, val f: Phrase[FunType[ExpType, FunType[FunType[ExpType, CommType], CommType]]]) extends ExpPrimitive { - { + assert { f :: FunType(expT(IndexType(n), read), FunType(FunType(expT(dt, read), comm), comm)) + true } override val t: ExpType = expT(ArrayType(n, dt), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): GenerateCont = new GenerateCont(v.nat(n), v.data(dt), VisitAndRebuild(f, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/IdxAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/IdxAcc.scala index 09a1942a5..b3e09836b 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/IdxAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/IdxAcc.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class IdxAcc(val n: Nat, val dt: DataType, val index: Phrase[ExpType], val array: Phrase[AccType]) extends AccPrimitive { - { + assert { index :: expT(IndexType(n), read) array :: accT(ArrayType(n, dt)) + true } override val t: AccType = accT(dt) override def visitAndRebuild(v: VisitAndRebuild.Visitor): IdxAcc = new IdxAcc(v.nat(n), v.data(dt), VisitAndRebuild(index, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/IdxVecAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/IdxVecAcc.scala index 0566caa11..f2097a274 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/IdxVecAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/IdxVecAcc.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class IdxVecAcc(val n: Nat, val dt: DataType, val index: Phrase[ExpType], val vector: Phrase[AccType]) extends AccPrimitive { - { + assert { index :: expT(IndexType(n), read) vector :: accT(VectorType(n, dt)) + true } override val t: AccType = accT(dt) override def visitAndRebuild(v: VisitAndRebuild.Visitor): IdxVecAcc = new IdxVecAcc(v.nat(n), v.data(dt), VisitAndRebuild(index, v), VisitAndRebuild(vector, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/JoinAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/JoinAcc.scala index fe0923988..eac2c66a5 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/JoinAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/JoinAcc.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class JoinAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive { - { + assert { array :: accT(ArrayType(n * m, dt)) + true } override val t: AccType = accT(ArrayType(n, ArrayType(m, dt))) override def visitAndRebuild(v: VisitAndRebuild.Visitor): JoinAcc = new JoinAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/MapAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/MapAcc.scala index b4cd17cec..43da5f66b 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/MapAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/MapAcc.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class MapAcc(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[AccType, AccType]], val array: Phrase[AccType]) extends AccPrimitive { - { + assert { f :: FunType(accT(dt1), accT(dt2)) array :: accT(ArrayType(n, dt1)) + true } override val t: AccType = accT(ArrayType(n, dt2)) override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapAcc = new MapAcc(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/MapFstAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/MapFstAcc.scala index 4e90c0d42..2d71702e3 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/MapFstAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/MapFstAcc.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class MapFstAcc(val dt1: DataType, val dt2: DataType, val dt3: DataType, val f: Phrase[FunType[AccType, AccType]], val record: Phrase[AccType]) extends AccPrimitive { - { + assert { f :: FunType(accT(dt3), accT(dt1)) record :: accT(PairType(dt3, dt2)) + true } override val t: AccType = accT(PairType(dt1, dt2)) override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapFstAcc = new MapFstAcc(v.data(dt1), v.data(dt2), v.data(dt3), VisitAndRebuild(f, v), VisitAndRebuild(record, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/MapRead.scala b/src/main/scala/shine/DPIA/primitives/imperative/MapRead.scala index 67fe5ad67..b6a4ba214 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/MapRead.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/MapRead.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class MapRead(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, FunType[FunType[ExpType, CommType], CommType]]], val input: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: FunType(expT(dt1, read), FunType(FunType(expT(dt2, read), comm), comm)) input :: expT(ArrayType(n, dt1), read) + true } override val t: ExpType = expT(ArrayType(n, dt2), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapRead = new MapRead(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(input, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/MapSndAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/MapSndAcc.scala index 3c163d1cf..7257504ee 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/MapSndAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/MapSndAcc.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class MapSndAcc(val dt1: DataType, val dt2: DataType, val dt3: DataType, val f: Phrase[FunType[AccType, AccType]], val record: Phrase[AccType]) extends AccPrimitive { - { + assert { f :: FunType(accT(dt3), accT(dt2)) record :: accT(PairType(dt1, dt3)) + true } override val t: AccType = accT(PairType(dt1, dt2)) override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapSndAcc = new MapSndAcc(v.data(dt1), v.data(dt2), v.data(dt3), VisitAndRebuild(f, v), VisitAndRebuild(record, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/New.scala b/src/main/scala/shine/DPIA/primitives/imperative/New.scala index b376a80a9..4184e9933 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/New.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/New.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class New(val dt: DataType, val f: Phrase[FunType[PhrasePairType[ExpType, AccType], CommType]]) extends CommandPrimitive { - { + assert { f :: FunType(PhrasePairType(expT(dt, read), accT(dt)), comm) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): New = new New(v.data(dt), VisitAndRebuild(f, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/NewDoubleBuffer.scala b/src/main/scala/shine/DPIA/primitives/imperative/NewDoubleBuffer.scala index 7001ed75c..2616c91e7 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/NewDoubleBuffer.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/NewDoubleBuffer.scala @@ -8,10 +8,11 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class NewDoubleBuffer(val dt1: DataType, val dt2: DataType, val dt3: DataType, val n: Nat, val in: Phrase[ExpType], val out: Phrase[AccType], val f: Phrase[FunType[PhrasePairType[PhrasePairType[PhrasePairType[ExpType, AccType], CommType], CommType], CommType]]) extends CommandPrimitive { - { + assert { in :: expT(dt1, read) out :: accT(dt2) f :: FunType(PhrasePairType(PhrasePairType(PhrasePairType(expT(ArrayType(n, dt3), read), accT(ArrayType(n, dt3))), comm), comm), comm) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): NewDoubleBuffer = new NewDoubleBuffer(v.data(dt1), v.data(dt2), v.data(dt3), v.nat(n), VisitAndRebuild(in, v), VisitAndRebuild(out, v), VisitAndRebuild(f, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/PairAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/PairAcc.scala index b6f03c5c7..c672e5fa4 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/PairAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/PairAcc.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class PairAcc(val dt1: DataType, val dt2: DataType, val fst: Phrase[AccType], val snd: Phrase[AccType]) extends AccPrimitive { - { + assert { fst :: accT(dt1) snd :: accT(dt2) + true } override val t: AccType = accT(PairType(dt1, dt2)) override def visitAndRebuild(v: VisitAndRebuild.Visitor): PairAcc = new PairAcc(v.data(dt1), v.data(dt2), VisitAndRebuild(fst, v), VisitAndRebuild(snd, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/PairAcc1.scala b/src/main/scala/shine/DPIA/primitives/imperative/PairAcc1.scala index 924410406..d21f85322 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/PairAcc1.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/PairAcc1.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class PairAcc1(val dt1: DataType, val dt2: DataType, val pair: Phrase[AccType]) extends AccPrimitive { - { + assert { pair :: accT(PairType(dt1, dt2)) + true } override val t: AccType = accT(dt1) override def visitAndRebuild(v: VisitAndRebuild.Visitor): PairAcc1 = new PairAcc1(v.data(dt1), v.data(dt2), VisitAndRebuild(pair, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/PairAcc2.scala b/src/main/scala/shine/DPIA/primitives/imperative/PairAcc2.scala index f49ccd845..ee1fa5924 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/PairAcc2.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/PairAcc2.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class PairAcc2(val dt1: DataType, val dt2: DataType, val pair: Phrase[AccType]) extends AccPrimitive { - { + assert { pair :: accT(PairType(dt1, dt2)) + true } override val t: AccType = accT(dt2) override def visitAndRebuild(v: VisitAndRebuild.Visitor): PairAcc2 = new PairAcc2(v.data(dt1), v.data(dt2), VisitAndRebuild(pair, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/ReorderAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/ReorderAcc.scala index 7820bfcf0..0924593d4 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/ReorderAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/ReorderAcc.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class ReorderAcc(val n: Nat, val dt: DataType, val idxF: NatToNat, val array: Phrase[AccType]) extends AccPrimitive { - { + assert { array :: accT(ArrayType(n, dt)) + true } override val t: AccType = accT(ArrayType(n, dt)) override def visitAndRebuild(v: VisitAndRebuild.Visitor): ReorderAcc = new ReorderAcc(v.nat(n), v.data(dt), v.natToNat(idxF), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/ScatterAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/ScatterAcc.scala index 483b94e19..193a376d2 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/ScatterAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/ScatterAcc.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class ScatterAcc(val n: Nat, val m: Nat, val dt: DataType, val indices: Phrase[ExpType], val array: Phrase[AccType]) extends AccPrimitive { - { + assert { indices :: expT(ArrayType(n, IndexType(m)), read) array :: accT(ArrayType(m, dt)) + true } override val t: AccType = accT(ArrayType(n, dt)) override def visitAndRebuild(v: VisitAndRebuild.Visitor): ScatterAcc = new ScatterAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(indices, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/Seq.scala b/src/main/scala/shine/DPIA/primitives/imperative/Seq.scala index 15500208b..52d0af3e6 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/Seq.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/Seq.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Seq(val c1: Phrase[CommType], val c2: Phrase[CommType]) extends CommandPrimitive { - { + assert { c1 :: comm c2 :: comm + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): Seq = new Seq(VisitAndRebuild(c1, v), VisitAndRebuild(c2, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/Skip.scala b/src/main/scala/shine/DPIA/primitives/imperative/Skip.scala index 1d8ac8dec..eb3f6db60 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/Skip.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/Skip.scala @@ -8,7 +8,6 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Skip() extends CommandPrimitive { - {} override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): Skip = new Skip() } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/SplitAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/SplitAcc.scala index 4ad7b97c6..c398e7208 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/SplitAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/SplitAcc.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class SplitAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive { - { + assert { array :: accT(ArrayType(m, ArrayType(n, dt))) + true } override val t: AccType = accT(ArrayType(n * m, dt)) override def visitAndRebuild(v: VisitAndRebuild.Visitor): SplitAcc = new SplitAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/TakeAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/TakeAcc.scala index 923cb6ded..33c4d1ee6 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/TakeAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/TakeAcc.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class TakeAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive { - { + assert { array :: accT(ArrayType(n + m, dt)) + true } override val t: AccType = accT(ArrayType(n, dt)) override def visitAndRebuild(v: VisitAndRebuild.Visitor): TakeAcc = new TakeAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/TransposeAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/TransposeAcc.scala index 3944e3d8a..b7ecbdd12 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/TransposeAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/TransposeAcc.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class TransposeAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive { - { + assert { array :: accT(ArrayType(m, ArrayType(n, dt))) + true } override val t: AccType = accT(ArrayType(n, ArrayType(m, dt))) override def visitAndRebuild(v: VisitAndRebuild.Visitor): TransposeAcc = new TransposeAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/UnzipAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/UnzipAcc.scala index 9e0069f74..3001610e9 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/UnzipAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/UnzipAcc.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class UnzipAcc(val n: Nat, val dt1: DataType, val dt2: DataType, val a: Phrase[AccType]) extends AccPrimitive { - { + assert { a :: accT(PairType(ArrayType(n, dt1), ArrayType(n, dt2))) + true } override val t: AccType = accT(ArrayType(n, PairType(dt1, dt2))) override def visitAndRebuild(v: VisitAndRebuild.Visitor): UnzipAcc = new UnzipAcc(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(a, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc1.scala b/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc1.scala index 28343629d..99151c48b 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc1.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc1.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class ZipAcc1(val n: Nat, val dt1: DataType, val dt2: DataType, val array: Phrase[AccType]) extends AccPrimitive { - { + assert { array :: accT(ArrayType(n, PairType(dt1, dt2))) + true } override val t: AccType = accT(ArrayType(n, dt1)) override def visitAndRebuild(v: VisitAndRebuild.Visitor): ZipAcc1 = new ZipAcc1(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc2.scala b/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc2.scala index f84abc66e..39dc674c2 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc2.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc2.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class ZipAcc2(val n: Nat, val dt1: DataType, val dt2: DataType, val array: Phrase[AccType]) extends AccPrimitive { - { + assert { array :: accT(ArrayType(n, PairType(dt1, dt2))) + true } override val t: AccType = accT(ArrayType(n, dt2)) override def visitAndRebuild(v: VisitAndRebuild.Visitor): ZipAcc2 = new ZipAcc2(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/OpenCL/primitives/functional/CircularBuffer.scala b/src/main/scala/shine/OpenCL/primitives/functional/CircularBuffer.scala index f4b206cbf..c23940e09 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/CircularBuffer.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/CircularBuffer.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class CircularBuffer(val a: AddressSpace, val n: Nat, val alloc: Nat, val sz: Nat, val dt1: DataType, val dt2: DataType, val load: Phrase[FunType[ExpType, ExpType]], val input: Phrase[ExpType]) extends ExpPrimitive { - { + assert { load :: FunType(expT(dt1, read), expT(dt2, write)) input :: expT(ArrayType(n - 1 + sz, dt1), read) + true } override val t: ExpType = expT(ArrayType(n, ArrayType(sz, dt2)), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): CircularBuffer = new CircularBuffer(v.addressSpace(a), v.nat(n), v.nat(alloc), v.nat(sz), v.data(dt1), v.data(dt2), VisitAndRebuild(load, v), VisitAndRebuild(input, v)) diff --git a/src/main/scala/shine/OpenCL/primitives/functional/DepMap.scala b/src/main/scala/shine/OpenCL/primitives/functional/DepMap.scala index adfe41170..8fa77c246 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/DepMap.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/DepMap.scala @@ -8,12 +8,13 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class DepMap(level: shine.OpenCL.ParallelismLevel, dim: Int)(val n: Nat, val ft1: NatToData, val ft2: NatToData, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: ({ val m = f.t.x DepFunType[NatKind, PhraseType](m, FunType(expT(NatToDataApply(ft1, m), read), expT(NatToDataApply(ft2, m), write))) }) array :: expT(DepArrayType(n, ft1), read) + true } override val t: ExpType = expT(DepArrayType(n, ft2), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepMap = new DepMap(level, dim)(v.nat(n), v.natToData(ft1), v.natToData(ft2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/OpenCL/primitives/functional/Iterate.scala b/src/main/scala/shine/OpenCL/primitives/functional/Iterate.scala index 0387931b5..3fa11ab8b 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/Iterate.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/Iterate.scala @@ -8,12 +8,13 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Iterate(val a: AddressSpace, val n: Nat, val m: Nat, val k: Nat, val dt: DataType, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: ({ val l = f.t.x DepFunType[NatKind, PhraseType](l, FunType(expT(ArrayType(l * n, dt), read), expT(ArrayType(l, dt), write))) }) array :: expT(ArrayType(m * n.pow(k), dt), read) + true } override val t: ExpType = expT(ArrayType(m, dt), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Iterate = new Iterate(v.addressSpace(a), v.nat(n), v.nat(m), v.nat(k), v.data(dt), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/OpenCL/primitives/functional/KernelCall.scala b/src/main/scala/shine/OpenCL/primitives/functional/KernelCall.scala index 0f66f31cc..d31bce538 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/KernelCall.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/KernelCall.scala @@ -8,11 +8,16 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class KernelCall(name: String, localSize: shine.OpenCL.LocalSize, globalSize: shine.OpenCL.GlobalSize, n: Int)(val inTs: Seq[DataType], val outT: DataType, val args: Seq[Phrase[ExpType]]) extends ExpPrimitive { - { - args.zip(inTs).foreach({ - case (args, inTs) => - args :: expT(inTs, read) - }) + assert { + { + typeAssert(args.length == n, "args" + ".length == " + "n" + " is not true") + typeAssert(inTs.length == n, "inTs" + ".length == " + "n" + " is not true") + args.zip(inTs).foreach({ + case (args, inTs) => + args :: expT(inTs, read) + }) + } + true } override val t: ExpType = expT(outT, write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): KernelCall = new KernelCall(name, localSize.visitAndRebuild(v), globalSize.visitAndRebuild(v), n)(inTs.map(v.data), v.data(outT), args.map(VisitAndRebuild(_, v))) diff --git a/src/main/scala/shine/OpenCL/primitives/functional/Map.scala b/src/main/scala/shine/OpenCL/primitives/functional/Map.scala index a99170788..e5d1c1c24 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/Map.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/Map.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Map(level: shine.OpenCL.ParallelismLevel, dim: Int)(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: FunType(expT(dt1, read), expT(dt2, write)) array :: expT(ArrayType(n, dt1), read) + true } override val t: ExpType = expT(ArrayType(n, dt2), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Map = new Map(level, dim)(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/OpenCL/primitives/functional/OpenCLFunctionCall.scala b/src/main/scala/shine/OpenCL/primitives/functional/OpenCLFunctionCall.scala index 363c6d792..3f2330c31 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/OpenCLFunctionCall.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/OpenCLFunctionCall.scala @@ -8,11 +8,16 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class OpenCLFunctionCall(name: String, n: Int)(val inTs: Seq[DataType], val outT: DataType, val args: Seq[Phrase[ExpType]]) extends ExpPrimitive { - { - args.zip(inTs).foreach({ - case (args, inTs) => - args :: expT(inTs, read) - }) + assert { + { + typeAssert(args.length == n, "args" + ".length == " + "n" + " is not true") + typeAssert(inTs.length == n, "inTs" + ".length == " + "n" + " is not true") + args.zip(inTs).foreach({ + case (args, inTs) => + args :: expT(inTs, read) + }) + } + true } override val t: ExpType = expT(outT, write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): OpenCLFunctionCall = new OpenCLFunctionCall(name, n)(inTs.map(v.data), v.data(outT), args.map(VisitAndRebuild(_, v))) diff --git a/src/main/scala/shine/OpenCL/primitives/functional/ReduceSeq.scala b/src/main/scala/shine/OpenCL/primitives/functional/ReduceSeq.scala index f00544913..b22858864 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/ReduceSeq.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/ReduceSeq.scala @@ -8,10 +8,11 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class ReduceSeq(unroll: Boolean)(val n: Nat, val a: AddressSpace, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, FunType[ExpType, ExpType]]], val init: Phrase[ExpType], val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: FunType(expT(dt2, read), FunType(expT(dt1, read), expT(dt2, write))) init :: expT(dt2, write) array :: expT(ArrayType(n, dt1), read) + true } override val t: ExpType = expT(dt2, read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): ReduceSeq = new ReduceSeq(unroll)(v.nat(n), v.addressSpace(a), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(init, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/OpenCL/primitives/functional/RotateValues.scala b/src/main/scala/shine/OpenCL/primitives/functional/RotateValues.scala index 093afa594..12d34e141 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/RotateValues.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/RotateValues.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class RotateValues(val a: AddressSpace, val n: Nat, val sz: Nat, val dt: DataType, val wrt: Phrase[FunType[ExpType, ExpType]], val input: Phrase[ExpType]) extends ExpPrimitive { - { + assert { wrt :: FunType(expT(dt, read), expT(dt, write)) input :: expT(ArrayType(n - 1 + sz, dt), read) + true } override val t: ExpType = expT(ArrayType(n, ArrayType(sz, dt)), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): RotateValues = new RotateValues(v.addressSpace(a), v.nat(n), v.nat(sz), v.data(dt), VisitAndRebuild(wrt, v), VisitAndRebuild(input, v)) diff --git a/src/main/scala/shine/OpenCL/primitives/functional/Run.scala b/src/main/scala/shine/OpenCL/primitives/functional/Run.scala index 6dc810e2c..7434f65a9 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/Run.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/Run.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Run(localSize: shine.OpenCL.LocalSize, globalSize: shine.OpenCL.GlobalSize)(val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive { - { + assert { input :: expT(dt, write) + true } override val t: ExpType = expT(dt, write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Run = new Run(localSize.visitAndRebuild(v), globalSize.visitAndRebuild(v))(v.data(dt), VisitAndRebuild(input, v)) diff --git a/src/main/scala/shine/OpenCL/primitives/functional/ToMem.scala b/src/main/scala/shine/OpenCL/primitives/functional/ToMem.scala index 26cd17bb2..8d5bf9e58 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/ToMem.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/ToMem.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class ToMem(val a: AddressSpace, val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive { - { + assert { input :: expT(dt, write) + true } override val t: ExpType = expT(dt, read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): ToMem = new ToMem(v.addressSpace(a), v.data(dt), VisitAndRebuild(input, v)) diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/Barrier.scala b/src/main/scala/shine/OpenCL/primitives/imperative/Barrier.scala index cf222f8b7..217098eb3 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/Barrier.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/Barrier.scala @@ -8,7 +8,6 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Barrier(local: Boolean, global: Boolean)() extends CommandPrimitive { - {} override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): Barrier = new Barrier(local, global)() } diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/HostExecution.scala b/src/main/scala/shine/OpenCL/primitives/imperative/HostExecution.scala index 54f977635..5e2ed6f0a 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/HostExecution.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/HostExecution.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class HostExecution(params: Map[Identifier[_ <: PhraseType], shine.OpenCL.AccessFlags])(val body: Phrase[CommType]) extends CommandPrimitive { - { + assert { body :: comm + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): HostExecution = new HostExecution(params.map({ diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistribute.scala b/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistribute.scala index c046d1597..8b1835c34 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistribute.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistribute.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class IdxDistribute(parallelismLevel: shine.OpenCL.ParallelismLevel)(val m: Nat, val n: Nat, val stride: Nat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { array :: expT(ArrayType(m, dt), read) + true } override val t: ExpType = expT(ArrayType(n, dt), read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): IdxDistribute = new IdxDistribute(parallelismLevel)(v.nat(m), v.nat(n), v.nat(stride), v.data(dt), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistributeAcc.scala b/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistributeAcc.scala index 0de4a9e08..b57908d4f 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistributeAcc.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistributeAcc.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class IdxDistributeAcc(parallelismLevel: shine.OpenCL.ParallelismLevel)(val m: Nat, val n: Nat, val stride: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive { - { + assert { array :: accT(ArrayType(m, dt)) + true } override val t: AccType = accT(ArrayType(n, dt)) override def visitAndRebuild(v: VisitAndRebuild.Visitor): IdxDistributeAcc = new IdxDistributeAcc(parallelismLevel)(v.nat(m), v.nat(n), v.nat(stride), v.data(dt), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/KernelCallCmd.scala b/src/main/scala/shine/OpenCL/primitives/imperative/KernelCallCmd.scala index 24e26716d..247dd63a8 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/KernelCallCmd.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/KernelCallCmd.scala @@ -8,12 +8,17 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class KernelCallCmd(name: String, localSize: shine.OpenCL.LocalSize, globalSize: shine.OpenCL.GlobalSize, n: Int)(val inTs: Seq[DataType], val dt: DataType, val args: Seq[Phrase[ExpType]], val output: Phrase[AccType]) extends CommandPrimitive { - { - args.zip(inTs).foreach({ - case (args, inTs) => - args :: expT(inTs, read) - }) + assert { + { + typeAssert(args.length == n, "args" + ".length == " + "n" + " is not true") + typeAssert(inTs.length == n, "inTs" + ".length == " + "n" + " is not true") + args.zip(inTs).foreach({ + case (args, inTs) => + args :: expT(inTs, read) + }) + } output :: accT(dt) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): KernelCallCmd = new KernelCallCmd(name, localSize.visitAndRebuild(v), globalSize.visitAndRebuild(v), n)(inTs.map(v.data), v.data(dt), args.map(VisitAndRebuild(_, v)), VisitAndRebuild(output, v)) diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/New.scala b/src/main/scala/shine/OpenCL/primitives/imperative/New.scala index 4b455cb27..0f5b2da08 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/New.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/New.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class New(val a: AddressSpace, val dt: DataType, val f: Phrase[FunType[PhrasePairType[ExpType, AccType], CommType]]) extends CommandPrimitive { - { + assert { f :: FunType(PhrasePairType(expT(dt, read), accT(dt)), comm) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): New = new New(v.addressSpace(a), v.data(dt), VisitAndRebuild(f, v)) diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/NewDoubleBuffer.scala b/src/main/scala/shine/OpenCL/primitives/imperative/NewDoubleBuffer.scala index 9e50f0b6a..8e1b64a67 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/NewDoubleBuffer.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/NewDoubleBuffer.scala @@ -8,10 +8,11 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class NewDoubleBuffer(val a: AddressSpace, val dt1: DataType, val dt2: DataType, val dt3: DataType, val n: Nat, val in: Phrase[ExpType], val out: Phrase[AccType], val f: Phrase[FunType[PhrasePairType[PhrasePairType[PhrasePairType[ExpType, AccType], CommType], CommType], CommType]]) extends CommandPrimitive { - { + assert { in :: expT(dt1, read) out :: accT(dt2) f :: FunType(PhrasePairType(PhrasePairType(PhrasePairType(expT(ArrayType(n, dt3), read), accT(ArrayType(n, dt3))), comm), comm), comm) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): NewDoubleBuffer = new NewDoubleBuffer(v.addressSpace(a), v.data(dt1), v.data(dt2), v.data(dt3), v.nat(n), VisitAndRebuild(in, v), VisitAndRebuild(out, v), VisitAndRebuild(f, v)) diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/NewManagedBuffer.scala b/src/main/scala/shine/OpenCL/primitives/imperative/NewManagedBuffer.scala index 5c439ae17..0f1261ba8 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/NewManagedBuffer.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/NewManagedBuffer.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class NewManagedBuffer(access: shine.OpenCL.AccessFlags)(val dt: DataType, val k: Phrase[FunType[PhrasePairType[ExpType, AccType], CommType]]) extends CommandPrimitive { - { + assert { k :: FunType(PhrasePairType(expT(ManagedBufferType(dt), read), accT(ManagedBufferType(dt))), comm) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): NewManagedBuffer = new NewManagedBuffer(access)(v.data(dt), VisitAndRebuild(k, v)) diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/ParFor.scala b/src/main/scala/shine/OpenCL/primitives/imperative/ParFor.scala index 47efa67bb..a46e0bfc6 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/ParFor.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/ParFor.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class ParFor(level: shine.OpenCL.ParallelismLevel, dim: Int, unroll: Boolean, prefix: String)(val init: Nat, val n: Nat, val step: Nat, val dt: DataType, val out: Phrase[AccType], val body: Phrase[FunType[ExpType, FunType[AccType, CommType]]]) extends CommandPrimitive { - { + assert { out :: accT(ArrayType(n, dt)) body :: FunType(expT(IndexType(n), read), FunType(accT(dt), comm)) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): ParFor = new ParFor(level, dim, unroll, prefix)(v.nat(init), v.nat(n), v.nat(step), v.data(dt), VisitAndRebuild(out, v), VisitAndRebuild(body, v)) diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/ParForNat.scala b/src/main/scala/shine/OpenCL/primitives/imperative/ParForNat.scala index bd45c3fd0..3f0b066b4 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/ParForNat.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/ParForNat.scala @@ -8,12 +8,13 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class ParForNat(level: shine.OpenCL.ParallelismLevel, dim: Int, unroll: Boolean, prefix: String)(val init: Nat, val n: Nat, val step: Nat, val ft: NatToData, val out: Phrase[AccType], val body: Phrase[DepFunType[NatKind, FunType[AccType, CommType]]]) extends CommandPrimitive { - { + assert { out :: accT(DepArrayType(n, ft)) body :: ({ val i = body.t.x DepFunType[NatKind, PhraseType](i, FunType(accT(NatToDataApply(ft, i)), comm)) }) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): ParForNat = new ParForNat(level, dim, unroll, prefix)(v.nat(init), v.nat(n), v.nat(step), v.natToData(ft), VisitAndRebuild(out, v), VisitAndRebuild(body, v)) diff --git a/src/main/scala/shine/OpenMP/primitives/functional/DepMapPar.scala b/src/main/scala/shine/OpenMP/primitives/functional/DepMapPar.scala index 9035e541e..7d5550d82 100644 --- a/src/main/scala/shine/OpenMP/primitives/functional/DepMapPar.scala +++ b/src/main/scala/shine/OpenMP/primitives/functional/DepMapPar.scala @@ -8,12 +8,13 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class DepMapPar(val n: Nat, val ft1: NatToData, val ft2: NatToData, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: ({ val m = f.t.x DepFunType[NatKind, PhraseType](m, FunType(expT(NatToDataApply(ft1, m), read), expT(NatToDataApply(ft2, m), write))) }) array :: expT(DepArrayType(n, ft1), read) + true } override val t: ExpType = expT(DepArrayType(n, ft2), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepMapPar = new DepMapPar(v.nat(n), v.natToData(ft1), v.natToData(ft2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/OpenMP/primitives/functional/MapPar.scala b/src/main/scala/shine/OpenMP/primitives/functional/MapPar.scala index 1b622ffd4..77aa340ea 100644 --- a/src/main/scala/shine/OpenMP/primitives/functional/MapPar.scala +++ b/src/main/scala/shine/OpenMP/primitives/functional/MapPar.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class MapPar(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: FunType(expT(dt1, read), expT(dt2, write)) array :: expT(ArrayType(n, dt1), read) + true } override val t: ExpType = expT(ArrayType(n, dt2), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapPar = new MapPar(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/OpenMP/primitives/functional/ReducePar.scala b/src/main/scala/shine/OpenMP/primitives/functional/ReducePar.scala index 9cc4bf236..c8176ed9b 100644 --- a/src/main/scala/shine/OpenMP/primitives/functional/ReducePar.scala +++ b/src/main/scala/shine/OpenMP/primitives/functional/ReducePar.scala @@ -8,10 +8,11 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class ReducePar(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, FunType[ExpType, ExpType]]], val init: Phrase[ExpType], val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: FunType(expT(dt2, read), FunType(expT(dt1, read), expT(dt2, write))) init :: expT(dt2, write) array :: expT(ArrayType(n, dt1), read) + true } override val t: ExpType = expT(dt2, read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): ReducePar = new ReducePar(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(init, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/OpenMP/primitives/imperative/ParFor.scala b/src/main/scala/shine/OpenMP/primitives/imperative/ParFor.scala index d467a4c29..29b147716 100644 --- a/src/main/scala/shine/OpenMP/primitives/imperative/ParFor.scala +++ b/src/main/scala/shine/OpenMP/primitives/imperative/ParFor.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class ParFor(val n: Nat, val dt: DataType, val out: Phrase[AccType], val body: Phrase[FunType[ExpType, FunType[AccType, CommType]]]) extends CommandPrimitive { - { + assert { out :: accT(ArrayType(n, dt)) body :: FunType(expT(IndexType(n), read), FunType(accT(dt), comm)) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): ParFor = new ParFor(v.nat(n), v.data(dt), VisitAndRebuild(out, v), VisitAndRebuild(body, v)) diff --git a/src/main/scala/shine/OpenMP/primitives/imperative/ParForNat.scala b/src/main/scala/shine/OpenMP/primitives/imperative/ParForNat.scala index 92b4f87b3..8925bb322 100644 --- a/src/main/scala/shine/OpenMP/primitives/imperative/ParForNat.scala +++ b/src/main/scala/shine/OpenMP/primitives/imperative/ParForNat.scala @@ -8,12 +8,13 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class ParForNat(val n: Nat, val ft: NatToData, val out: Phrase[AccType], val body: Phrase[DepFunType[NatKind, FunType[AccType, CommType]]]) extends CommandPrimitive { - { + assert { out :: accT(DepArrayType(n, ft)) body :: ({ val i = body.t.x DepFunType[NatKind, PhraseType](i, FunType(accT(NatToDataApply(ft, i)), comm)) }) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): ParForNat = new ParForNat(v.nat(n), v.natToData(ft), VisitAndRebuild(out, v), VisitAndRebuild(body, v)) diff --git a/src/main/scala/shine/cuda/primitives/functional/AsFragment.scala b/src/main/scala/shine/cuda/primitives/functional/AsFragment.scala index 31d151cf2..f4eb2ff1b 100644 --- a/src/main/scala/shine/cuda/primitives/functional/AsFragment.scala +++ b/src/main/scala/shine/cuda/primitives/functional/AsFragment.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class AsFragment(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val frag: FragmentKind, val layout: MatrixLayout, val input: Phrase[ExpType]) extends ExpPrimitive { - { + assert { input :: expT(ArrayType(rows, ArrayType(columns, dt)), read) + true } override val t: ExpType = expT(FragmentType(rows, columns, layers, dt, frag, layout), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsFragment = new AsFragment(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), frag, layout, VisitAndRebuild(input, v)) diff --git a/src/main/scala/shine/cuda/primitives/functional/AsMatrix.scala b/src/main/scala/shine/cuda/primitives/functional/AsMatrix.scala index c3ea35b06..d3ff60e68 100644 --- a/src/main/scala/shine/cuda/primitives/functional/AsMatrix.scala +++ b/src/main/scala/shine/cuda/primitives/functional/AsMatrix.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class AsMatrix(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive { - { + assert { input :: expT(FragmentType(rows, columns, layers, dt, FragmentKind.Accumulator, MatrixLayout.None), read) + true } override val t: ExpType = expT(ArrayType(rows, ArrayType(columns, dt)), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsMatrix = new AsMatrix(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), VisitAndRebuild(input, v)) diff --git a/src/main/scala/shine/cuda/primitives/functional/GenerateFragment.scala b/src/main/scala/shine/cuda/primitives/functional/GenerateFragment.scala index fa0144660..5c0eea3db 100644 --- a/src/main/scala/shine/cuda/primitives/functional/GenerateFragment.scala +++ b/src/main/scala/shine/cuda/primitives/functional/GenerateFragment.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class GenerateFragment(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val frag: FragmentKind, val layout: MatrixLayout, val fill: Phrase[ExpType]) extends ExpPrimitive { - { + assert { fill :: expT(dt, read) + true } override val t: ExpType = expT(FragmentType(rows, columns, layers, dt, frag, layout), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): GenerateFragment = new GenerateFragment(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), frag, layout, VisitAndRebuild(fill, v)) diff --git a/src/main/scala/shine/cuda/primitives/functional/GlobalToShared.scala b/src/main/scala/shine/cuda/primitives/functional/GlobalToShared.scala index cc4f06188..d5cb3b4e4 100644 --- a/src/main/scala/shine/cuda/primitives/functional/GlobalToShared.scala +++ b/src/main/scala/shine/cuda/primitives/functional/GlobalToShared.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class GlobalToShared(val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive { - { + assert { input :: expT(dt, write) + true } override val t: ExpType = expT(dt, read) override def visitAndRebuild(v: VisitAndRebuild.Visitor): GlobalToShared = new GlobalToShared(v.data(dt), VisitAndRebuild(input, v)) diff --git a/src/main/scala/shine/cuda/primitives/functional/Map.scala b/src/main/scala/shine/cuda/primitives/functional/Map.scala index 7695393cd..342b64121 100644 --- a/src/main/scala/shine/cuda/primitives/functional/Map.scala +++ b/src/main/scala/shine/cuda/primitives/functional/Map.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class Map(level: shine.OpenCL.ParallelismLevel, dim: Int)(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: FunType(expT(dt1, read), expT(dt2, write)) array :: expT(ArrayType(n, dt1), read) + true } override val t: ExpType = expT(ArrayType(n, dt2), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): Map = new Map(level, dim)(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) diff --git a/src/main/scala/shine/cuda/primitives/functional/MapFragment.scala b/src/main/scala/shine/cuda/primitives/functional/MapFragment.scala index 9757cf271..5d4128c67 100644 --- a/src/main/scala/shine/cuda/primitives/functional/MapFragment.scala +++ b/src/main/scala/shine/cuda/primitives/functional/MapFragment.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class MapFragment(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val frag: FragmentKind, val layout: MatrixLayout, val f: Phrase[FunType[ExpType, ExpType]], val input: Phrase[ExpType]) extends ExpPrimitive { - { + assert { f :: FunType(expT(dt, read), expT(dt, write)) input :: expT(FragmentType(rows, columns, layers, dt, frag, layout), read) + true } override val t: ExpType = expT(FragmentType(rows, columns, layers, dt, frag, layout), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapFragment = new MapFragment(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), frag, layout, VisitAndRebuild(f, v), VisitAndRebuild(input, v)) diff --git a/src/main/scala/shine/cuda/primitives/functional/TensorMatMultAdd.scala b/src/main/scala/shine/cuda/primitives/functional/TensorMatMultAdd.scala index c35cf09a7..a5ef39e01 100644 --- a/src/main/scala/shine/cuda/primitives/functional/TensorMatMultAdd.scala +++ b/src/main/scala/shine/cuda/primitives/functional/TensorMatMultAdd.scala @@ -8,10 +8,11 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class TensorMatMultAdd(val m: Nat, val n: Nat, val k: Nat, val layoutA: MatrixLayout, val layoutB: MatrixLayout, val dt1: DataType, val dt2: DataType, val aMatrix: Phrase[ExpType], val bMatrix: Phrase[ExpType], val cMatrix: Phrase[ExpType]) extends ExpPrimitive { - { + assert { aMatrix :: expT(FragmentType(m, k, n, dt1, FragmentKind.AMatrix, layoutA), read) bMatrix :: expT(FragmentType(k, n, m, dt1, FragmentKind.BMatrix, layoutB), read) cMatrix :: expT(FragmentType(m, n, k, dt2, FragmentKind.Accumulator, MatrixLayout.None), read) + true } override val t: ExpType = expT(FragmentType(m, n, k, dt2, FragmentKind.Accumulator, MatrixLayout.None), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): TensorMatMultAdd = new TensorMatMultAdd(v.nat(m), v.nat(n), v.nat(k), layoutA, layoutB, v.data(dt1), v.data(dt2), VisitAndRebuild(aMatrix, v), VisitAndRebuild(bMatrix, v), VisitAndRebuild(cMatrix, v)) diff --git a/src/main/scala/shine/cuda/primitives/imperative/ForFragment.scala b/src/main/scala/shine/cuda/primitives/imperative/ForFragment.scala index 283476433..0f1eb321d 100644 --- a/src/main/scala/shine/cuda/primitives/imperative/ForFragment.scala +++ b/src/main/scala/shine/cuda/primitives/imperative/ForFragment.scala @@ -8,10 +8,11 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class ForFragment(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val frag: FragmentKind, val layout: MatrixLayout, val in: Phrase[ExpType], val out: Phrase[AccType], val fun: Phrase[FunType[ExpType, FunType[AccType, CommType]]]) extends CommandPrimitive { - { + assert { in :: expT(FragmentType(rows, columns, layers, dt, frag, layout), read) out :: accT(FragmentType(rows, columns, layers, dt, frag, layout)) fun :: FunType(expT(dt, read), FunType(accT(dt), comm)) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): ForFragment = new ForFragment(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), frag, layout, VisitAndRebuild(in, v), VisitAndRebuild(out, v), VisitAndRebuild(fun, v)) diff --git a/src/main/scala/shine/cuda/primitives/imperative/GlobalToSharedAcc.scala b/src/main/scala/shine/cuda/primitives/imperative/GlobalToSharedAcc.scala index dba030d80..d82ceb63f 100644 --- a/src/main/scala/shine/cuda/primitives/imperative/GlobalToSharedAcc.scala +++ b/src/main/scala/shine/cuda/primitives/imperative/GlobalToSharedAcc.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class GlobalToSharedAcc(val dt: DataType, val pipe: Phrase[ExpType], val outputShared: Phrase[AccType]) extends AccPrimitive { - { + assert { pipe :: expT(OpaqueType("pipeline"), read) outputShared :: accT(dt) + true } override val t: AccType = accT(dt) override def visitAndRebuild(v: VisitAndRebuild.Visitor): GlobalToSharedAcc = new GlobalToSharedAcc(v.data(dt), VisitAndRebuild(pipe, v), VisitAndRebuild(outputShared, v)) diff --git a/src/main/scala/shine/cuda/primitives/imperative/ParFor.scala b/src/main/scala/shine/cuda/primitives/imperative/ParFor.scala index 8a471f68d..52208cd3e 100644 --- a/src/main/scala/shine/cuda/primitives/imperative/ParFor.scala +++ b/src/main/scala/shine/cuda/primitives/imperative/ParFor.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class ParFor(level: shine.OpenCL.ParallelismLevel, dim: Int, unroll: Boolean, prefix: String)(val init: Nat, val n: Nat, val step: Nat, val dt: DataType, val out: Phrase[AccType], val body: Phrase[FunType[ExpType, FunType[AccType, CommType]]]) extends CommandPrimitive { - { + assert { out :: accT(ArrayType(n, dt)) body :: FunType(expT(IndexType(n), read), FunType(accT(dt), comm)) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): ParFor = new ParFor(level, dim, unroll, prefix)(v.nat(init), v.nat(n), v.nat(step), v.data(dt), VisitAndRebuild(out, v), VisitAndRebuild(body, v)) diff --git a/src/main/scala/shine/cuda/primitives/imperative/SyncPipeline.scala b/src/main/scala/shine/cuda/primitives/imperative/SyncPipeline.scala index c1f06a010..51d65f9b3 100644 --- a/src/main/scala/shine/cuda/primitives/imperative/SyncPipeline.scala +++ b/src/main/scala/shine/cuda/primitives/imperative/SyncPipeline.scala @@ -8,8 +8,9 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class SyncPipeline(val pipe: Phrase[ExpType]) extends CommandPrimitive { - { + assert { pipe :: expT(OpaqueType("pipeline"), read) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): SyncPipeline = new SyncPipeline(VisitAndRebuild(pipe, v)) diff --git a/src/main/scala/shine/cuda/primitives/imperative/SyncThreads.scala b/src/main/scala/shine/cuda/primitives/imperative/SyncThreads.scala index d7d7acafc..1f618726c 100644 --- a/src/main/scala/shine/cuda/primitives/imperative/SyncThreads.scala +++ b/src/main/scala/shine/cuda/primitives/imperative/SyncThreads.scala @@ -8,7 +8,6 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class SyncThreads() extends CommandPrimitive { - {} override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): SyncThreads = new SyncThreads() } diff --git a/src/main/scala/shine/cuda/primitives/imperative/SyncWarp.scala b/src/main/scala/shine/cuda/primitives/imperative/SyncWarp.scala index 31718237a..083e5d292 100644 --- a/src/main/scala/shine/cuda/primitives/imperative/SyncWarp.scala +++ b/src/main/scala/shine/cuda/primitives/imperative/SyncWarp.scala @@ -8,7 +8,6 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class SyncWarp() extends CommandPrimitive { - {} override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): SyncWarp = new SyncWarp() } diff --git a/src/main/scala/shine/cuda/primitives/imperative/WmmaFill.scala b/src/main/scala/shine/cuda/primitives/imperative/WmmaFill.scala index cf250936a..c45e8e894 100644 --- a/src/main/scala/shine/cuda/primitives/imperative/WmmaFill.scala +++ b/src/main/scala/shine/cuda/primitives/imperative/WmmaFill.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class WmmaFill(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val frag: FragmentKind, val layout: MatrixLayout, val fill: Phrase[ExpType], val target: Phrase[AccType]) extends CommandPrimitive { - { + assert { fill :: expT(dt, read) target :: accT(FragmentType(rows, columns, layers, dt, frag, layout)) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): WmmaFill = new WmmaFill(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), frag, layout, VisitAndRebuild(fill, v), VisitAndRebuild(target, v)) diff --git a/src/main/scala/shine/cuda/primitives/imperative/WmmaLoad.scala b/src/main/scala/shine/cuda/primitives/imperative/WmmaLoad.scala index 61b33bcba..da04106be 100644 --- a/src/main/scala/shine/cuda/primitives/imperative/WmmaLoad.scala +++ b/src/main/scala/shine/cuda/primitives/imperative/WmmaLoad.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class WmmaLoad(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val frag: FragmentKind, val layout: MatrixLayout, val matrixTile: Phrase[ExpType], val target: Phrase[AccType]) extends CommandPrimitive { - { + assert { matrixTile :: expT(ArrayType(rows, ArrayType(columns, dt)), read) target :: accT(FragmentType(rows, columns, layers, dt, frag, layout)) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): WmmaLoad = new WmmaLoad(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), frag, layout, VisitAndRebuild(matrixTile, v), VisitAndRebuild(target, v)) diff --git a/src/main/scala/shine/cuda/primitives/imperative/WmmaMMA.scala b/src/main/scala/shine/cuda/primitives/imperative/WmmaMMA.scala index 793a562f5..06316c51f 100644 --- a/src/main/scala/shine/cuda/primitives/imperative/WmmaMMA.scala +++ b/src/main/scala/shine/cuda/primitives/imperative/WmmaMMA.scala @@ -8,11 +8,12 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class WmmaMMA(val m: Nat, val n: Nat, val k: Nat, val layoutA: MatrixLayout, val layoutB: MatrixLayout, val dt1: DataType, val dt2: DataType, val aMatrix: Phrase[ExpType], val bMatrix: Phrase[ExpType], val cMatrix: Phrase[ExpType], val resultMatrix: Phrase[AccType]) extends CommandPrimitive { - { + assert { aMatrix :: expT(FragmentType(m, k, n, dt1, FragmentKind.AMatrix, layoutA), read) bMatrix :: expT(FragmentType(k, n, m, dt1, FragmentKind.BMatrix, layoutB), read) cMatrix :: expT(FragmentType(m, n, k, dt2, FragmentKind.Accumulator, MatrixLayout.None), read) resultMatrix :: accT(FragmentType(m, n, k, dt2, FragmentKind.Accumulator, MatrixLayout.None)) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): WmmaMMA = new WmmaMMA(v.nat(m), v.nat(n), v.nat(k), layoutA, layoutB, v.data(dt1), v.data(dt2), VisitAndRebuild(aMatrix, v), VisitAndRebuild(bMatrix, v), VisitAndRebuild(cMatrix, v), VisitAndRebuild(resultMatrix, v)) diff --git a/src/main/scala/shine/cuda/primitives/imperative/WmmaStore.scala b/src/main/scala/shine/cuda/primitives/imperative/WmmaStore.scala index fb5e7ffd6..26bc6a6e8 100644 --- a/src/main/scala/shine/cuda/primitives/imperative/WmmaStore.scala +++ b/src/main/scala/shine/cuda/primitives/imperative/WmmaStore.scala @@ -8,9 +8,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ final case class WmmaStore(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val value: Phrase[ExpType], val matrixTile: Phrase[AccType]) extends CommandPrimitive { - { + assert { value :: expT(FragmentType(rows, columns, layers, dt, FragmentKind.Accumulator, MatrixLayout.None), read) matrixTile :: accT(ArrayType(rows, ArrayType(columns, dt))) + true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): WmmaStore = new WmmaStore(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), VisitAndRebuild(value, v), VisitAndRebuild(matrixTile, v))