From a3452180b108ffe6f828c4f122bf46d7e430085a Mon Sep 17 00:00:00 2001 From: Damian Reeves <957246+DamianReeves@users.noreply.github.com> Date: Sun, 14 Feb 2021 18:29:10 -0500 Subject: [PATCH] Morphir Core and IR work (#91) * Further application of recursion schemes for working with the Type AST * Rollback un-needed changes * Change over to using scala.BigDecimal to back the Decimal type * Ensure better ergonomics with the Decimal type --- build.sc | 83 ++++++++++---- morphir/ir/src/morphir/ir/recursions.scala | 99 ++++++++++++++-- .../morphir/sdk/DecimalModuleCompat.scala | 14 +++ .../morphir/sdk/DecimalModuleCompat.scala | 3 + .../morphir/sdk/DecimalModuleCompat.scala | 3 + .../sdk/core/src/morphir/sdk/Decimal.scala | 41 ++++--- .../test/src/morphir/sdk/DecimalSpec.scala | 12 ++ .../morphir/sdk/SdkTypesEncodingSpec.scala | 23 ++++ .../morphir/sdk/spark/testing/SparkSpec.scala | 50 ++++++++ .../sdk/spark/testing/sparkModule.scala | 107 ++++++++++++++++++ 10 files changed, 384 insertions(+), 51 deletions(-) create mode 100644 morphir/sdk/core/src-2.11/morphir/sdk/DecimalModuleCompat.scala create mode 100644 morphir/sdk/core/src-2.12/morphir/sdk/DecimalModuleCompat.scala create mode 100644 morphir/sdk/core/src-2.13/morphir/sdk/DecimalModuleCompat.scala create mode 100644 morphir/sdk/core/test/src/morphir/sdk/DecimalSpec.scala create mode 100644 morphir/sdk/spark/test/src/morphir/sdk/SdkTypesEncodingSpec.scala create mode 100644 morphir/sdk/spark/test/src/morphir/sdk/spark/testing/SparkSpec.scala create mode 100644 morphir/sdk/spark/test/src/morphir/sdk/spark/testing/sparkModule.scala diff --git a/build.sc b/build.sc index e7abf776..6cec7e87 100644 --- a/build.sc +++ b/build.sc @@ -36,7 +36,7 @@ object Deps { val zioLogging = "0.5.6" val zioMagic = "0.1.8" val zioNio = "1.0.0-RC10" - val zioPrelude = "1.0.0-RC1" + val zioPrelude = "1.0.0-RC2" val zioProcess = "0.2.0" val newtype = "0.4.4" def decline(scalaVersion: String) = scalaVersion match { @@ -52,6 +52,7 @@ object Deps { val slf4zio = "1.0.0" val scalactic = "3.1.2" val scalaUri = "2.2.2" + val spark = "2.4.7" val oslib = "0.6.2" val quill = "3.6.0-RC3" } @@ -220,28 +221,28 @@ object morphir extends Module { } } } - - object scala extends Module { - - object jvm extends Cross[JvmMorphirScalaModule](Versions.scala213) - class JvmMorphirScalaModule(val crossScalaVersion: String) - extends CrossScalaModule - with CommonJvmModule - with ScalaMacroModule - with MorphirPublishModule { self => - def artifactName = "morphir-scala" - def moduleDeps = Seq(morphir.ir.jvm(crossScalaVersion)) - - def ivyDeps = Agg( - ivy"org.scalameta::scalameta:${Versions.scalameta}" - ) - - object test extends Tests { - def platformSegment: String = self.platformSegment - def crossScalaVersion = JvmMorphirScalaModule.this.crossScalaVersion - } - } - } +// +// object scala extends Module { +// +// object jvm extends Cross[JvmMorphirScalaModule](Versions.scala213) +// class JvmMorphirScalaModule(val crossScalaVersion: String) +// extends CrossScalaModule +// with CommonJvmModule +// with ScalaMacroModule +// with MorphirPublishModule { self => +// def artifactName = "morphir-scala" +// def moduleDeps = Seq(morphir.ir.jvm(crossScalaVersion)) +// +// def ivyDeps = Agg( +// ivy"org.scalameta::scalameta:${Versions.scalameta}" +// ) +// +// object test extends Tests { +// def platformSegment: String = self.platformSegment +// def crossScalaVersion = JvmMorphirScalaModule.this.crossScalaVersion +// } +// } +// } object sdk extends Module { object core extends Module { @@ 
-263,6 +264,42 @@ object morphir extends Module { } } } + + object spark extends Module { + object jvm + extends Cross[JvmMorphirSdkSpark]( + Versions.scala212, + Versions.scala211 + ) + class JvmMorphirSdkSpark(val crossScalaVersion: String) + extends CrossScalaModule + with CommonJvmModule + with MorphirPublishModule { self => + + def artifactName = "morphir-sdk-spark" + def compileIvyDeps = Agg( + ivy"org.apache.spark::spark-sql:${Versions.spark}", + ivy"com.github.ghik:::silencer-lib:${Versions.silencer}" + ) + def ivyDeps = Agg( + ivy"dev.zio::zio-prelude:${Versions.zioPrelude}" + ) + def scalacPluginIvyDeps = Agg(ivy"com.github.ghik:::silencer-plugin:${Versions.silencer}") + def moduleDeps = Seq(morphir.sdk.core.jvm(crossScalaVersion)) + + object test extends Tests { + def platformSegment: String = self.platformSegment + def crossScalaVersion = JvmMorphirSdkSpark.this.crossScalaVersion + + override def ivyDeps = super.ivyDeps() ++ + Agg( + ivy"dev.zio::zio-logging:${Versions.zioLogging}", + ivy"dev.zio::zio-logging-slf4j:${Versions.zioLogging}", + ivy"org.apache.spark::spark-sql:${Versions.spark}" + ) + } + } + } } object flowz extends Module { diff --git a/morphir/ir/src/morphir/ir/recursions.scala b/morphir/ir/src/morphir/ir/recursions.scala index 80c70b30..296c356f 100644 --- a/morphir/ir/src/morphir/ir/recursions.scala +++ b/morphir/ir/src/morphir/ir/recursions.scala @@ -1,5 +1,7 @@ package morphir.ir +import zio.ZIO + /** * In looking at how we can potentially model the Morphir IR, research has shown that recursion schemes are a good fit for working with the Type AST. */ @@ -24,27 +26,88 @@ object recursions { type Name = String type FQName = String - final case class Field[+Self](name: Name, value: Self) + final case class Field[+Self](name: Name, value: Self) { + def map[Self2](fn: Self => Self2): Field[Self2] = Field(name, fn(value)) + def mapM[R, E, Self2](fn: Self => ZIO[R, E, Self2]): ZIO[R, E, Field[Self2]] = fn(value).map(Field(name, _)) + } sealed trait TypeCase[+Self, +Attrib] { - def map[Self2](fn: Self => Self2): TypeCase[Self2, Attrib] = ???
+ def map[Self2](fn: Self => Self2): TypeCase[Self2, Attrib] + def mapAttrib[Attrib2](f: Attrib => Attrib2): TypeCase[Self, Attrib2] + def mapM[R, E, Self2](fn: Self => ZIO[R, E, Self2]): ZIO[R, E, TypeCase[Self2, Attrib]] } object TypeCase { - final case class Variable[Attrib](a: Attrib) extends TypeCase[Nothing, Attrib] + final case class Variable[Attrib](a: Attrib) extends TypeCase[Nothing, Attrib] { self => + def map[Self2](fn: Nothing => Self2): TypeCase[Self2, Attrib] = self + + def mapAttrib[Attrib2](f: Attrib => Attrib2): TypeCase[Nothing, Attrib2] = copy(a = f(a)) + + def mapM[R, E, Self2](fn: Nothing => ZIO[R, E, Self2]): ZIO[R, E, TypeCase[Self2, Attrib]] = ZIO.succeed(self) + } final case class Reference[Self, Attrib](attribute: Attrib, name: FQName, types: List[Self]) - extends TypeCase[Self, Attrib] - final case class Tuple[Self, Attrib](attribute: Attrib, types: List[Self]) extends TypeCase[Self, Attrib] - final case class Record[Self, Attrib](attributes: Attrib, fields: List[Field[Self]]) extends TypeCase[Self, Attrib] - final case class ExtensibleRecord[Self, Attrib](attributes: Attrib, name: Name, fields: List[Field[Self]]) - extends TypeCase[Self, Attrib] - final case class Function[Self, Attrib](attribute: Attrib, input: Self, output: Self) extends TypeCase[Self, Attrib] - final case class Unit[+Attrib](attribute: Attrib) extends TypeCase[Nothing, Attrib] + extends TypeCase[Self, Attrib] { + def map[Self2](fn: Self => Self2): TypeCase[Self2, Attrib] = Reference(attribute, name, types.map(fn)) + + def mapAttrib[Attrib2](f: Attrib => Attrib2): TypeCase[Self, Attrib2] = copy(attribute = f(attribute)) + + def mapM[R, E, Self2](fn: Self => ZIO[R, E, Self2]): ZIO[R, E, TypeCase[Self2, Attrib]] = + ZIO.foreach(types)(fn).map(Reference(attribute, name, _)) + } + final case class Tuple[Self, Attrib](attribute: Attrib, types: List[Self]) extends TypeCase[Self, Attrib] { + def map[Self2](fn: Self => Self2): TypeCase[Self2, Attrib] = Tuple(attribute, types.map(fn)) + + def mapAttrib[Attrib2](f: Attrib => Attrib2): TypeCase[Self, Attrib2] = copy(attribute = f(attribute)) + + def mapM[R, E, Self2](fn: Self => ZIO[R, E, Self2]): ZIO[R, E, TypeCase[Self2, Attrib]] = + ZIO.foreach(types)(fn).map(Tuple(attribute, _)) + } + final case class Record[Self, Attrib](attribute: Attrib, fields: List[Field[Self]]) extends TypeCase[Self, Attrib] { + def map[Self2](fn: Self => Self2): TypeCase[Self2, Attrib] = Record(attribute, fields.map(_.map(fn))) + def mapAttrib[Attrib2](f: Attrib => Attrib2): TypeCase[Self, Attrib2] = copy(attribute = f(attribute)) + + def mapM[R, E, Self2](fn: Self => ZIO[R, E, Self2]): ZIO[R, E, TypeCase[Self2, Attrib]] = + ZIO.foreach(fields)(_.mapM(fn)).map(Record(attribute, _)) + } + final case class ExtensibleRecord[Self, Attrib](attribute: Attrib, name: Name, fields: List[Field[Self]]) + extends TypeCase[Self, Attrib] { + def map[Self2](fn: Self => Self2): TypeCase[Self2, Attrib] = + ExtensibleRecord(attribute, name, fields.map(_.map(fn))) + + def mapAttrib[Attrib2](f: Attrib => Attrib2): TypeCase[Self, Attrib2] = copy(attribute = f(attribute)) + + def mapM[R, E, Self2](fn: Self => ZIO[R, E, Self2]): ZIO[R, E, TypeCase[Self2, Attrib]] = + ZIO.foreach(fields)(_.mapM(fn)).map(ExtensibleRecord(attribute, name, _)) + } + final case class Function[Self, Attrib](attribute: Attrib, input: Self, output: Self) + extends TypeCase[Self, Attrib] { + def map[Self2](fn: Self => Self2): TypeCase[Self2, Attrib] = Function(attribute, fn(input), fn(output)) + def mapAttrib[Attrib2](f: 
Attrib => Attrib2): TypeCase[Self, Attrib2] = copy(attribute = f(attribute)) + + def mapM[R, E, Self2](fn: Self => ZIO[R, E, Self2]): ZIO[R, E, TypeCase[Self2, Attrib]] = + ZIO.mapN(fn(input), fn(output)) { case (input1, output1) => + Function(attribute, input1, output1) + } + } + final case class Unit[+Attrib](attribute: Attrib) extends TypeCase[Nothing, Attrib] { self => + def map[Self2](fn: Nothing => Self2): TypeCase[Self2, Attrib] = self + def mapAttrib[Attrib2](f: Attrib => Attrib2): TypeCase[Nothing, Attrib2] = copy(attribute = f(attribute)) + + def mapM[R, E, Self2](fn: Nothing => ZIO[R, E, Self2]): ZIO[R, E, TypeCase[Self2, Attrib]] = ZIO.succeed(self) + } } final case class Type[+Attrib](value: TypeCase[Type[Attrib], Attrib]) { self => + + // NOTE: Useful, for example, for bottom-up type inference. + def annotate[Attrib2](f: TypeCase[Type[Attrib2], Attrib] => Attrib2): Type[Attrib2] = + transformUpRecursive[Attrib2](value => value.mapAttrib(_ => f(value))) + def fold[Z](f: TypeCase[Z, Attrib] => Z): Z = f(value.map(_.fold(f))) + def foldM[R, E, Z](f: TypeCase[Z, Attrib] => ZIO[R, E, Z]): ZIO[R, E, Z] = + value.mapM(_.foldM(f)).flatMap(f) + /** * Transform the whole tree. * Top down @@ -58,6 +121,16 @@ object recursions { def transformUp[Attrib2](f: TypeMapper[Attrib, Attrib2]): Type[Attrib2] = Type(f(value.map(_.transformUp(f)))) + def transformDownRecursive[Attrib1 >: Attrib, Attrib2]( + f: TypeCase[Type[Attrib1], Attrib1] => TypeCase[Type[Attrib1], Attrib2] + ): Type[Attrib2] = + Type(f(value).map(_.transformDownRecursive(f))) + + def transformUpRecursive[Attrib2]( + f: TypeCase[Type[Attrib2], Attrib] => TypeCase[Type[Attrib2], Attrib2] + ): Type[Attrib2] = + Type(f(value.map(_.transformUpRecursive(f)))) + } object Type { @@ -65,6 +138,12 @@ Type(f(initial).map(unfold(_)(f))) } + final case class TypeMapperRec[A, B, C, D](fn: TypeCase[Type[A], B] => TypeCase[Type[C], D]) + extends Function[TypeCase[Type[A], B], TypeCase[Type[C], D]] { + override def apply(v1: TypeCase[Type[A], B]): TypeCase[Type[C], D] = fn(v1) + } + trait TypeMapper[-AttribIn, +AttribOut] { def apply[Self](value: TypeCase[Self, AttribIn]): TypeCase[Self, AttribOut] } diff --git a/morphir/sdk/core/src-2.11/morphir/sdk/DecimalModuleCompat.scala b/morphir/sdk/core/src-2.11/morphir/sdk/DecimalModuleCompat.scala new file mode 100644 index 00000000..92bdc043 --- /dev/null +++ b/morphir/sdk/core/src-2.11/morphir/sdk/DecimalModuleCompat.scala @@ -0,0 +1,14 @@ +package morphir.sdk + +trait DecimalModuleCompat { + import DecimalModuleCompat._ + implicit def toBigDecimalOps(value: BigDecimal): BigDecimalOps = + new BigDecimalOps(value) +} + +object DecimalModuleCompat { + class BigDecimalOps(private val self: BigDecimal) extends AnyVal { + def compareTo(that: BigDecimal): Int = + self.compare(that) + } +} diff --git a/morphir/sdk/core/src-2.12/morphir/sdk/DecimalModuleCompat.scala b/morphir/sdk/core/src-2.12/morphir/sdk/DecimalModuleCompat.scala new file mode 100644 index 00000000..b13c37c4 --- /dev/null +++ b/morphir/sdk/core/src-2.12/morphir/sdk/DecimalModuleCompat.scala @@ -0,0 +1,3 @@ +package morphir.sdk + +trait DecimalModuleCompat {} diff --git a/morphir/sdk/core/src-2.13/morphir/sdk/DecimalModuleCompat.scala b/morphir/sdk/core/src-2.13/morphir/sdk/DecimalModuleCompat.scala new file mode 100644 index 00000000..b13c37c4 --- /dev/null +++ b/morphir/sdk/core/src-2.13/morphir/sdk/DecimalModuleCompat.scala @@ -0,0 +1,3 @@ +package morphir.sdk + +trait DecimalModuleCompat {}
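Aside: the fold added to Type above is a standard catamorphism, so whole-tree queries reduce to a single function on TypeCase. Below is a minimal sketch of how it can be used — not part of this patch, and the names (FoldExample, size, the sample tree) are illustrative only — counting the nodes of a small type tree:

import morphir.ir.recursions._

object FoldExample {
  // The attribute type is plain scala.Unit here; TypeCase.Unit is the IR's unit type.
  val unitType: Type[Unit] = Type(TypeCase.Unit(()))
  val varType: Type[Unit]  = Type(TypeCase.Variable(()))
  val funcType: Type[Unit] = Type(TypeCase.Function((), varType, unitType))

  // A single TypeCase[Int, _] => Int algebra covers every node shape; fold
  // applies it bottom-up, so child nodes arrive already reduced to Ints.
  def size(tpe: Type[Unit]): Int =
    tpe.fold[Int] {
      case TypeCase.Function(_, in, out)           => 1 + in + out
      case TypeCase.Tuple(_, types)                => 1 + types.sum
      case TypeCase.Reference(_, _, types)         => 1 + types.sum
      case TypeCase.Record(_, fields)              => 1 + fields.map(_.value).sum
      case TypeCase.ExtensibleRecord(_, _, fields) => 1 + fields.map(_.value).sum
      case _                                       => 1 // Variable and Unit are leaves
    }

  assert(size(funcType) == 3) // Function + Variable + Unit
}

foldM has the same shape with the algebra returning ZIO, which lets effectful traversals (validation, annotation) line up with the pure ones.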
diff --git a/morphir/sdk/core/src/morphir/sdk/Decimal.scala b/morphir/sdk/core/src/morphir/sdk/Decimal.scala index ff32e3d1..e1735502 100644 --- a/morphir/sdk/core/src/morphir/sdk/Decimal.scala +++ b/morphir/sdk/core/src/morphir/sdk/Decimal.scala @@ -2,27 +2,32 @@ package morphir.sdk import morphir.sdk.Maybe.Maybe import morphir.sdk.Basics.Order -import java.math.{ BigDecimal => BigDec, RoundingMode } +import java.math.{ BigDecimal => BigDec } import scala.util.control.NonFatal -object Decimal { +object Decimal extends DecimalModuleCompat { - type Decimal = BigDec + type Decimal = BigDecimal object Decimal { - def apply(value: BigDec): Decimal = value - def apply(value: scala.BigDecimal): Decimal = value.bigDecimal + def apply(value: BigDec): Decimal = BigDecimal(value) + def apply(value: scala.BigDecimal): Decimal = value def apply(value: morphir.sdk.Float.Float): Decimal = BigDecimal.exact(value).bigDecimal def apply(value: morphir.sdk.Int.Int): Decimal = BigDecimal(value).bigDecimal } + @inline def apply(value: BigDec): Decimal = Decimal.apply(value) + @inline def apply(value: scala.BigDecimal): Decimal = Decimal.apply(value) + @inline def apply(value: morphir.sdk.Float.Float): Decimal = Decimal.apply(value) + @inline def apply(value: morphir.sdk.Int.Int): Decimal = Decimal.apply(value) + /** * Absolute value (sets the sign as positive) */ - def abs(value: Decimal): Decimal = value.abs() + def abs(value: Decimal): Decimal = value.abs - def add(a: Decimal)(b: Decimal): Decimal = a.add(b) + def add(a: Decimal)(b: Decimal): Decimal = a + b def bps(n: morphir.sdk.Int.Int): Decimal = Decimal(n * 0.0001) @@ -37,7 +42,7 @@ object Decimal { if (b.compareTo(zero) == 0) Maybe.nothing else try { - Maybe.just(a.divide(b)) + Maybe.just(a / b) } catch { case NonFatal(_) => Maybe.nothing } @@ -83,39 +88,39 @@ object Decimal { def millionth(n: morphir.sdk.Int.Int): Decimal = Decimal(n * 0.000001) - def mul(a: Decimal)(b: Decimal): Decimal = a.multiply(b) + def mul(a: Decimal)(b: Decimal): Decimal = a * b @inline def ne(a: Decimal)(b: Decimal): morphir.sdk.Bool.Bool = neq(a)(b) def neq(a: Decimal)(b: Decimal): morphir.sdk.Bool.Bool = a.compareTo(b) != 0 - def negate(value: Decimal): Decimal = value.negate() + def negate(value: Decimal): Decimal = -value def round(decimal: Decimal): Decimal = { - val scale = decimal.scale() - decimal.setScale(scale, RoundingMode.HALF_EVEN) + decimal.setScale(0, BigDecimal.RoundingMode.HALF_EVEN) // scale 0 rounds to a whole number; re-applying the current scale was a no-op } def shiftDecimalLeft(n: morphir.sdk.Int.Int)(value: Decimal): Decimal = - value.scaleByPowerOfTen(-n.intValue()) //TODO: When we align Int to Int this should settle in correctly + value.bigDecimal.scaleByPowerOfTen(-n.intValue()) //TODO: When we align Int to Int this should settle in correctly def shiftDecimalRight(n: morphir.sdk.Int.Int)(value: Decimal): Decimal = - value.scaleByPowerOfTen(n.intValue()) //TODO: When we align Int to Int this should settle in correctly + value.bigDecimal.scaleByPowerOfTen(n.intValue()) //TODO: When we align Int to Int this should settle in correctly - def sub(a: Decimal)(b: Decimal): Decimal = a.subtract(b) + def sub(a: Decimal)(b: Decimal): Decimal = a - b def thousand(n: morphir.sdk.Int.Int): Decimal = Decimal(n * 1000) def toFloat(value: Decimal): morphir.sdk.Float.Float = - morphir.sdk.Float.Float(value.doubleValue()) + morphir.sdk.Float.Float(value.toDouble) //TODO: Make sure the Elm call and this call return the same value def toString(value: Decimal): morphir.sdk.String.String = value.toString
def truncate(decimal: Decimal): Decimal = { // Since morphir's Int is actually a Long this isn't really safe - val scale = decimal.scale() - decimal.setScale(scale, RoundingMode.DOWN) + decimal.setScale(0, BigDecimal.RoundingMode.DOWN) // scale 0 drops the fractional part; re-applying the current scale was a no-op } /** diff --git a/morphir/sdk/core/test/src/morphir/sdk/DecimalSpec.scala b/morphir/sdk/core/test/src/morphir/sdk/DecimalSpec.scala new file mode 100644 index 00000000..6511302a --- /dev/null +++ b/morphir/sdk/core/test/src/morphir/sdk/DecimalSpec.scala @@ -0,0 +1,12 @@ +package morphir.sdk + +import zio.test._ +import zio.test.Assertion._ + +object DecimalSpec extends DefaultRunnableSpec { + def spec = suite("Decimal Spec")( + test("It should be possible to assign an int value to a Decimal") { + assert(Decimal(42))(equalTo(Decimal.fromInt(42))) + } + ) +} diff --git a/morphir/sdk/spark/test/src/morphir/sdk/SdkTypesEncodingSpec.scala b/morphir/sdk/spark/test/src/morphir/sdk/SdkTypesEncodingSpec.scala new file mode 100644 index 00000000..ba62faee --- /dev/null +++ b/morphir/sdk/spark/test/src/morphir/sdk/SdkTypesEncodingSpec.scala @@ -0,0 +1,23 @@ +package morphir.sdk + +import zio.test._ +import zio.test.Assertion._ +import zio.{ console, ZIO } +import morphir.sdk.spark.testing.SparkSpec +import morphir.sdk.spark.testing.sparkModule object SdkTypesEncodingSpec extends SparkSpec { + def spec = suite("SdkTypesEncoding Spec")( + testM("Encoding should work for a row with a Decimal")( + for { + data <- ZIO.succeed(List(MyRow(1.1), MyRow(3.14))) + dataset <- sparkModule.createDataset(data) + actual <- sparkModule { _ => + dataset.collect().toList + } + _ <- console.putStrLn(s"Actual: $actual") + } yield assert(actual)(equalTo(data)) + ) + ) + + final case class MyRow(amount: Decimal.Decimal) } diff --git a/morphir/sdk/spark/test/src/morphir/sdk/spark/testing/SparkSpec.scala b/morphir/sdk/spark/test/src/morphir/sdk/spark/testing/SparkSpec.scala new file mode 100644 index 00000000..30ed3deb --- /dev/null +++ b/morphir/sdk/spark/test/src/morphir/sdk/spark/testing/SparkSpec.scala @@ -0,0 +1,50 @@ +package morphir.sdk.spark.testing + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession +import zio.duration.durationInt +import zio.logging._ +import zio.logging.slf4j.Slf4jLogger +import zio.test.{ RunnableSpec, TestAspect, TestExecutor, TestRunner } +import zio.test.environment.{ TestEnvironment, testEnvironment } +import zio.{ ULayer, ZLayer } +import morphir.sdk.spark.testing.sparkModule.SparkModule +import SparkSpec.SparkTestingEnv + +abstract class SparkSpec extends RunnableSpec[SparkTestingEnv, Any] { + val sparkSessionBuilder: SparkSession.Builder = SparkSpec.sparkSessionBuilder + + val sparkTestingLayer: ULayer[SparkTestingEnv] = { + val logFormat = "[correlation-id = %s] %s" + val logging = Slf4jLogger.make { (context, message) => + val correlationId = LogAnnotation.CorrelationId.render( + context.get(LogAnnotation.CorrelationId) + ) + logFormat.format(correlationId, message) + } + + val spark = + SparkModule.buildLayer(sparkSessionBuilder, newSession = true) + + testEnvironment ++ logging ++ spark + }.orDie + + override def aspects: List[TestAspect[Nothing, SparkTestingEnv, Nothing, Any]] = + List(TestAspect.timeout(60.seconds)) + + override def runner: TestRunner[SparkTestingEnv, Any] = + TestRunner(TestExecutor.default(sparkTestingLayer)) +} object SparkSpec { + val sparkSessionBuilder: SparkSession.Builder = {
val sparkConf = new SparkConf() + .setMaster("local[*]") + .setAppName("spark-spec") + .set("spark.ui.enabled", "false") + .set("spark.driver.host", "localhost") + SparkSession.builder().config(sparkConf) + } + type SparkTestingEnv = TestEnvironment with Logging with SparkModule + +} diff --git a/morphir/sdk/spark/test/src/morphir/sdk/spark/testing/sparkModule.scala b/morphir/sdk/spark/test/src/morphir/sdk/spark/testing/sparkModule.scala new file mode 100644 index 00000000..aa483a66 --- /dev/null +++ b/morphir/sdk/spark/test/src/morphir/sdk/spark/testing/sparkModule.scala @@ -0,0 +1,107 @@ +package morphir.sdk.spark.testing + +import zio._ +import org.apache.spark.sql._ + +import scala.annotation.nowarn +import scala.reflect.runtime.universe.TypeTag + +object sparkModule { + type SparkModule = Has[SparkModule.Service] + + def apply[A](f: SparkSession => A): ZIO[SparkModule, Throwable, A] = withSpark(f) + + def createDataset[A <: Product: TypeTag](data: Seq[A]): ZIO[SparkModule, Throwable, Dataset[A]] = + ZIO.accessM[SparkModule](_.get.createDataset(data)) + + def createDatasetOf[A <: Product: TypeTag](data: A*): ZIO[SparkModule, Throwable, Dataset[A]] = + ZIO.accessM[SparkModule](_.get.createDatasetOf(data: _*)) + + def getConfigValue(key: String): ZIO[SparkModule, NoSuchElementException, String] = + ZIO.accessM[SparkModule](_.get.getConfigValue(key)) + + def makeDataset[A](func: SparkSession => Dataset[A]): RIO[SparkModule, Dataset[A]] = + ZIO.accessM(_.get.makeDataset(func)) + + def printSchema(dataFrame: DataFrame): URIO[SparkModule, Unit] = + ZIO.accessM(_.get.printSchema(dataFrame)) + + val sparkSession: URIO[SparkModule, SparkSession] = + ZIO.access(_.get.sparkSession) + + def withSpark[A](func: SparkSession => A): ZIO[SparkModule, Throwable, A] = + ZIO.accessM[SparkModule](_.get.withSpark(func)) + + object SparkModule { + trait Service extends Serializable { + + /** + * Get access to an instance of the `SparkSession`. 
+ */ + def sparkSession: SparkSession + + def createDataset[A <: Product: TypeTag](data: Seq[A]): Task[Dataset[A]] + def createDatasetOf[A <: Product: TypeTag](items: A*): Task[Dataset[A]] + def getConfigValue(key: String): IO[NoSuchElementException, String] + def makeDataset[A](func: SparkSession => Dataset[A]): Task[Dataset[A]] + + def printSchema(dataFrame: DataFrame): UIO[Unit] + def withSpark[A](func: SparkSession => A): Task[A] + } + + object Service { + def live(sparkSession: SparkSession): Service = + Live(sparkSession) + + def makeLive(builder: SparkSession.Builder): Task[Service] = Task.effect { + live(builder.getOrCreate().newSession()) + } + + final case class Live(sparkSession: SparkSession) extends Service { + def createDataset[A <: Product: TypeTag](data: Seq[A]): Task[Dataset[A]] = Task.effect { + import sparkSession.implicits._ + sparkSession.createDataset(data) + } + + def createDatasetOf[A <: Product: TypeTag](items: A*): Task[Dataset[A]] = Task.effect { + import sparkSession.implicits._ + sparkSession.createDataset(items) + } + + def getConfigValue(key: String): IO[NoSuchElementException, String] = Task.effect { + sparkSession.conf.get(key) + }.refineToOrDie[NoSuchElementException] + + def makeDataset[A](func: SparkSession => Dataset[A]): Task[Dataset[A]] = Task.effect { + func(sparkSession) + } + + override def printSchema(dataFrame: DataFrame): UIO[Unit] = UIO { + dataFrame.printSchema() + } + + def withSpark[A](func: SparkSession => A): Task[A] = Task.effect(func(sparkSession)) + } + } + + def buildLayer(builder: SparkSession.Builder, newSession: Boolean = false): ZLayer[Any, Throwable, SparkModule] = + ZLayer.suspend { + val sparkSessionLayer = Task.effect { + if (newSession) builder.getOrCreate().newSession() + else builder.getOrCreate() + }.toLayer + sparkSessionLayer >>> fromSession + } + + @nowarn + val fromSparkSessionBuilder: ZLayer[Has[SparkSession.Builder], Throwable, SparkModule] = + ZLayer.fromServiceM { builder: SparkSession.Builder => + Service.makeLive(builder) + } + + val fromSession: ZLayer[Has[SparkSession], Nothing, SparkModule] = + ZLayer.fromService { spark: SparkSession => + Service.live(spark) + } + } +}
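To make the Decimal changes concrete, here is a hedged sketch of the call-site ergonomics this patch is after — the object DecimalExample and its values are hypothetical, but the apply overloads, add, negate, and toString all appear in the diff above:

import morphir.sdk.Decimal.Decimal // the type alias, now backed by scala.BigDecimal

object DecimalExample {
  // The new top-level apply overloads forward to the inner Decimal object,
  // so both java.math.BigDecimal and scala.BigDecimal construct directly.
  val a: Decimal = morphir.sdk.Decimal(new java.math.BigDecimal("1.25"))
  val b: Decimal = morphir.sdk.Decimal(scala.BigDecimal("1.25"))

  // Arithmetic now delegates to scala.BigDecimal operators (+, -, *, unary_-).
  val sum: Decimal = morphir.sdk.Decimal.add(a)(b)

  assert(morphir.sdk.Decimal.toString(sum) == "2.50")
  assert(morphir.sdk.Decimal.negate(sum) == morphir.sdk.Decimal(scala.BigDecimal("-2.50")))
}

Backing the alias with scala.BigDecimal is what lets the operator forms (a + b, -value) replace the java.math method calls (add, negate) throughout the module.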
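Finally, a hedged sketch of wiring sparkModule outside the SparkSpec test runner — SparkModuleDemo and Person are hypothetical names, and this assumes the ZIO 1.x App trait that the rest of the patch targets:

import org.apache.spark.sql.SparkSession
import zio._
import morphir.sdk.spark.testing.sparkModule
import morphir.sdk.spark.testing.sparkModule.SparkModule

object SparkModuleDemo extends App {
  final case class Person(name: String, age: Long)

  // buildLayer goes through getOrCreate; newSession = false reuses the shared session.
  val layer: ZLayer[Any, Throwable, SparkModule] =
    SparkModule.buildLayer(
      SparkSession.builder().master("local[*]").appName("spark-module-demo"),
      newSession = false
    )

  val program: ZIO[SparkModule, Throwable, Long] =
    for {
      ds <- sparkModule.createDatasetOf(Person("Ada", 36), Person("Alan", 41))
      n  <- sparkModule.withSpark(_ => ds.count())
    } yield n

  def run(args: List[String]): URIO[ZEnv, ExitCode] =
    program.provideLayer(layer).exitCode
}

Keeping the SparkSession behind a ZLayer is the design choice that lets SparkSpec swap in a throwaway local[*] session for tests while other code provides its own.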