Skip to content

Commit

Permalink
Add deriveSingletonJSON
Browse files Browse the repository at this point in the history
  • Loading branch information
benko-ct committed Jun 18, 2024
1 parent 76e0073 commit 62a6b76
Show file tree
Hide file tree
Showing 6 changed files with 265 additions and 165 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,15 @@ class AnnotationReader(using q: Quotes):
'{ Map($subtypes*) }

end AnnotationReader

object AnnotationReader:
inline def readCaseClassMetaData[T]: CaseClassMetaData = ${ readCaseClassMetaDataImpl[T] }

inline def readTraitMetaData[T]: TraitMetaData = ${ readTraitMetaDataImpl[T] }

private def readCaseClassMetaDataImpl[T: Type](using Quotes): Expr[CaseClassMetaData] =
AnnotationReader().readCaseClassMetaData[T]

private def readTraitMetaDataImpl[T: Type](using Quotes): Expr[TraitMetaData] =
AnnotationReader().readTraitMetaData[T]
end AnnotationReader
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ import cats.implicits.*
import io.sphere.json.{JSON, JSONParseError, JValidation}
import org.json4s.DefaultJsonFormats.given
import org.json4s.JsonAST.JValue
import org.json4s.{DefaultJsonFormats, JNull, JObject, JString, jvalue2monadic, jvalue2readerSyntax}
import org.json4s.{DefaultJsonFormats, JObject, JString, jvalue2monadic, jvalue2readerSyntax}

import scala.deriving.Mirror
import scala.quoted.*

inline def deriveJSON[A](using Mirror.Of[A]): JSON[A] = JSON.derived

Expand All @@ -17,14 +16,6 @@ object JSON:

inline def apply[A: JSON]: JSON[A] = summon[JSON[A]]
inline given derived[A](using Mirror.Of[A]): JSON[A] = Derivation.derived[A]
private inline def readCaseClassMetaData[T]: CaseClassMetaData = ${ readCaseClassMetaDataImpl[T] }
private inline def readTraitMetaData[T]: TraitMetaData = ${ readTraitMetaDataImpl[T] }

private def readCaseClassMetaDataImpl[T: Type](using Quotes): Expr[CaseClassMetaData] =
AnnotationReader().readCaseClassMetaData[T]

private def readTraitMetaDataImpl[T: Type](using Quotes): Expr[TraitMetaData] =
AnnotationReader().readTraitMetaData[T]

private def addField(jObject: JObject, field: Field, jValue: JValue): JValue =
jValue match
Expand All @@ -44,7 +35,7 @@ object JSON:

inline private def deriveTrait[A](mirrorOfSum: Mirror.SumOf[A]): JSON[A] =
new JSON[A]:
private val traitMetaData: TraitMetaData = readTraitMetaData[A]
private val traitMetaData: TraitMetaData = AnnotationReader.readTraitMetaData[A]
private val typeHintMap: Map[String, String] = traitMetaData.subtypes.collect {
case (name, classMeta) if classMeta.typeHint.isDefined =>
name -> classMeta.typeHint.get
Expand All @@ -61,11 +52,9 @@ object JSON:
case jObject: JObject =>
val typeName = (jObject \ traitMetaData.typeDiscriminator).as[String]
val originalTypeName = reverseTypeHintMap.getOrElse(typeName, typeName)
val parsed = jsonsByNames(originalTypeName).read(jObject)
parsed.map(_.asInstanceOf[A])
jsonsByNames(originalTypeName).read(jObject).map(_.asInstanceOf[A])
case x =>
Validated.invalidNel(
JSONParseError(s"JSON object expected. >>> trait >>> $jValue >>> ${traitMetaData}"))
Validated.invalidNel(JSONParseError(s"JSON object expected. Got: '$jValue'"))

override def write(value: A): JValue =
// we never get a trait here, only classes, it's safe to assume Product
Expand All @@ -79,7 +68,7 @@ object JSON:

inline private def deriveCaseClass[A](mirrorOfProduct: Mirror.ProductOf[A]): JSON[A] =
new JSON[A]:
private val caseClassMetaData: CaseClassMetaData = readCaseClassMetaData[A]
private val caseClassMetaData: CaseClassMetaData = AnnotationReader.readCaseClassMetaData[A]
private val jsons: Vector[JSON[Any]] = summonFormatters[mirrorOfProduct.MirroredElemTypes]
private val fieldsAndJsons: Vector[(Field, JSON[Any])] = caseClassMetaData.fields.zip(jsons)

Expand Down Expand Up @@ -112,7 +101,7 @@ object JSON:
fieldsAsTuple.asInstanceOf[mirrorOfProduct.MirroredElemTypes])

case x =>
Validated.invalidNel(JSONParseError(s"JSON object expected. ${x}"))
Validated.invalidNel(JSONParseError(s"JSON object expected. $x"))

private def readField(field: Field, json: JSON[Any], jObject: JObject): JValidation[Any] =
if (field.embedded) json.read(jObject)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package io.sphere.json.generic

import cats.data.Validated
import io.sphere.json.{JSON, JSONParseError, JValidation}
import org.json4s.{JNull, JString, JValue}

import scala.deriving.Mirror

inline def deriveSingletonJSON[A](using Mirror.Of[A]): JSON[A] = DeriveSingleton.derived

object DeriveSingleton {

inline given derived[A](using Mirror.Of[A]): JSON[A] = Derivation.derived[A]

private object Derivation {

import scala.compiletime.{constValue, constValueTuple, erasedValue, summonInline}

inline def derived[A](using m: Mirror.Of[A]): JSON[A] =
inline m match
case s: Mirror.SumOf[A] => deriveTrait(s)
case p: Mirror.ProductOf[A] => deriveObject(p)

inline private def deriveTrait[A](mirrorOfSum: Mirror.SumOf[A]): JSON[A] =
new JSON[A]:
private val traitMetaData: TraitMetaData = AnnotationReader.readTraitMetaData[A]
private val typeHintMap: Map[String, String] = traitMetaData.subtypes.collect {
case (name, classMeta) if classMeta.typeHint.isDefined =>
name -> classMeta.typeHint.get
}
private val reverseTypeHintMap: Map[String, String] = typeHintMap.map((on, n) => (n, on))
private val jsons: Seq[JSON[Any]] = summonFormatters[mirrorOfSum.MirroredElemTypes]
private val names: Seq[String] =
constValueTuple[mirrorOfSum.MirroredElemLabels].productIterator.toVector
.asInstanceOf[Vector[String]]
private val jsonsByNames: Map[String, JSON[Any]] = names.zip(jsons).toMap

override def read(jValue: JValue): JValidation[A] =
jValue match
case JString(typeName) =>
val originalTypeName = reverseTypeHintMap.getOrElse(typeName, typeName)
jsonsByNames.get(originalTypeName) match
case Some(json) =>
json.read(JNull).map(_.asInstanceOf[A])
case None =>
Validated.invalidNel(JSONParseError(s"'$typeName' is not a valid value"))

case x =>
Validated.invalidNel(JSONParseError(s"JSON string expected. Got >>> $jValue"))

override def write(value: A): JValue =
val originalTypeName = value.asInstanceOf[Product].productPrefix
val typeName = typeHintMap.getOrElse(originalTypeName, originalTypeName)
JString(typeName)

end deriveTrait

inline private def deriveObject[A](mirrorOfProduct: Mirror.ProductOf[A]): JSON[A] =
new JSON[A]:
override def write(value: A): JValue = ??? // This is already taken care of in `deriveTrait`
override def read(jValue: JValue): JValidation[A] =
// Just create the object instance, no need to do anything else
val tuple = Tuple.fromArray(Array.empty[Any])
val obj = mirrorOfProduct.fromTuple(tuple.asInstanceOf[mirrorOfProduct.MirroredElemTypes])
Validated.Valid(obj)
end deriveObject

inline private def summonFormatters[T <: Tuple]: Vector[JSON[Any]] =
inline erasedValue[T] match
case _: EmptyTuple => Vector.empty
case _: (t *: ts) =>
summonInline[JSON[t]]
.asInstanceOf[JSON[Any]] +: summonFormatters[ts]

}

}
Loading

0 comments on commit 62a6b76

Please sign in to comment.