diff --git a/mongo/mongo-3/src/main/scala/io/sphere/mongo/format/MongoFormat.scala b/mongo/mongo-3/src/main/scala/io/sphere/mongo/format/MongoFormat.scala index a1ee657..ffcf985 100644 --- a/mongo/mongo-3/src/main/scala/io/sphere/mongo/format/MongoFormat.scala +++ b/mongo/mongo-3/src/main/scala/io/sphere/mongo/format/MongoFormat.scala @@ -1,7 +1,7 @@ package io.sphere.mongo.format import com.mongodb.BasicDBObject -import io.sphere.mongo.generic.{AnnotationReader, Field, TraitMetaData} +import io.sphere.mongo.generic.{AnnotationReader, Field} import org.bson.types.ObjectId import java.util.UUID @@ -56,17 +56,20 @@ object MongoFormat { case p: Mirror.ProductOf[A] => deriveCaseClass(p) } - inline private def deriveTrait[A](mirrorOfSum: Mirror.SumOf[A]): MongoFormat[A] = { - val traitMetaData = AnnotationReader.readTraitMetaData[A] - val typeHintMap = traitMetaData.subTypeTypeHints - val reverseTypeHintMap = typeHintMap.map((on, n) => (n, on)) - val formatters = summonFormatters[mirrorOfSum.MirroredElemTypes] - val names = constValueTuple[mirrorOfSum.MirroredElemLabels].productIterator.toVector - .asInstanceOf[Vector[String]] - val formattersByTypeName = names.zip(formatters).toMap - - MongoFormat.create[A]( - toMongo = { a => + inline private def deriveTrait[A](mirrorOfSum: Mirror.SumOf[A]): MongoFormat[A] = + new MongoFormat[A] { + private val traitMetaData = AnnotationReader.readTraitMetaData[A] + private val typeHintMap = traitMetaData.subtypes.collect { + case (name, classMeta) if classMeta.typeHint.isDefined => + name -> classMeta.typeHint.get + } + private val reverseTypeHintMap = typeHintMap.map((on, n) => (n, on)) + private val formatters = summonFormatters[mirrorOfSum.MirroredElemTypes] + private val names = constValueTuple[mirrorOfSum.MirroredElemLabels].productIterator.toVector + .asInstanceOf[Vector[String]] + private val formattersByTypeName = names.zip(formatters).toMap + + override def toMongoValue(a: A): Any = { // we never get a trait here, only classes, it's safe to assume Product val originalTypeName = a.asInstanceOf[Product].productPrefix val typeName = typeHintMap.getOrElse(originalTypeName, originalTypeName) @@ -74,28 +77,30 @@ object MongoFormat { formattersByTypeName(originalTypeName).toMongoValue(a).asInstanceOf[BasicDBObject] bson.put(traitMetaData.typeDiscriminator, typeName) bson - }, - fromMongo = { - case bson: BasicDBObject => - val typeName = bson.get(traitMetaData.typeDiscriminator).asInstanceOf[String] - val originalTypeName = reverseTypeHintMap.getOrElse(typeName, typeName) - formattersByTypeName(originalTypeName).fromMongoValue(bson).asInstanceOf[A] - case x => - throw new Exception(s"BsonObject is expected for a Trait subtype, instead got $x") } - ) - } - inline private def deriveCaseClass[A](mirrorOfProduct: Mirror.ProductOf[A]): MongoFormat[A] = { - val caseClassMetaData = AnnotationReader.readCaseClassMetaData[A] - val formatters = summonFormatters[mirrorOfProduct.MirroredElemTypes] - val fieldsAndFormatters = caseClassMetaData.fields.zip(formatters) + override def fromMongoValue(bson: Any): A = + bson match { + case bson: BasicDBObject => + val typeName = bson.get(traitMetaData.typeDiscriminator).asInstanceOf[String] + val originalTypeName = reverseTypeHintMap.getOrElse(typeName, typeName) + formattersByTypeName(originalTypeName).fromMongoValue(bson).asInstanceOf[A] + case x => + throw new Exception(s"BsonObject is expected for a Trait subtype, instead got $x") + } + } + + inline private def deriveCaseClass[A](mirrorOfProduct: Mirror.ProductOf[A]): MongoFormat[A] = + new MongoFormat[A] { + private val caseClassMetaData = AnnotationReader.readCaseClassMetaData[A] + private val formatters = summonFormatters[mirrorOfProduct.MirroredElemTypes] + private val fieldsAndFormatters = caseClassMetaData.fields.zip(formatters) - MongoFormat.create( - fields = fieldsAndFormatters.flatMap((field, formatter) => + override val fieldNames: Vector[String] = fieldsAndFormatters.flatMap((field, formatter) => if (field.embedded) formatter.fieldNames :+ field.rawName - else Vector(field.rawName)), - toMongo = { a => + else Vector(field.rawName)) + + override def toMongoValue(a: A): Any = { val bson = new BasicDBObject() val values = a.asInstanceOf[Product].productIterator formatters.zip(values).zip(caseClassMetaData.fields).foreach { @@ -103,36 +108,37 @@ object MongoFormat { addField(bson, field, format.toMongoValue(value)) } bson - }, - fromMongo = { - case bson: BasicDBObject => - val fields = fieldsAndFormatters - .map { (field, format) => - def defaultValue = field.defaultArgument.orElse(format.default) - - if (field.ignored) - defaultValue.getOrElse { - throw new Exception( - s"Missing default parameter value for ignored field `${field.name}` on deserialization.") - } - else if (field.embedded) format.fromMongoValue(bson) - else { - val value = bson.get(field.name) - if (value ne null) format.fromMongoValue(value.asInstanceOf[Any]) - else + } + + override def fromMongoValue(mongoType: Any): A = + mongoType match { + case bson: BasicDBObject => + val fields = fieldsAndFormatters + .map { case (field, format) => + def defaultValue = field.defaultArgument.orElse(format.default) + + if (field.ignored) defaultValue.getOrElse { throw new Exception( - s"Missing required field '${field.name}' on deserialization.") + s"Missing default parameter value for ignored field `${field.name}` on deserialization.") } + else if (field.embedded) format.fromMongoValue(bson) + else { + val value = bson.get(field.name) + if (value ne null) format.fromMongoValue(value.asInstanceOf[Any]) + else + defaultValue.getOrElse { + throw new Exception( + s"Missing required field '${field.name}' on deserialization.") + } + } } - } - val tuple = Tuple.fromArray(fields.toArray) - mirrorOfProduct.fromTuple(tuple.asInstanceOf[mirrorOfProduct.MirroredElemTypes]) + val tuple = Tuple.fromArray(fields.toArray) + mirrorOfProduct.fromTuple(tuple.asInstanceOf[mirrorOfProduct.MirroredElemTypes]) - case x => throw new Exception(s"BasicDBObject is expected for a class, instead got: $x") - } - ) - } + case x => throw new Exception(s"BasicDBObject is expected for a class, instead got: $x") + } + } inline private def summonFormatters[T <: Tuple]: Vector[MongoFormat[Any]] = inline erasedValue[T] match { @@ -143,17 +149,4 @@ object MongoFormat { } } - - // This is needed to remove the "New anonymous class definition will be duplicated at each inline site" warnings - private def create[A]( - toMongo: A => Any, - fromMongo: Any => A, - fields: Vector[String] = MongoFormat.emptyFields): MongoFormat[A] = - new MongoFormat[A] { - override def toMongoValue(a: A): Any = toMongo(a) - - override def fromMongoValue(mongoType: Any): A = fromMongo(mongoType) - - override val fieldNames: Vector[String] = fields - } } diff --git a/mongo/mongo-3/src/main/scala/io/sphere/mongo/generic/AnnotationReader.scala b/mongo/mongo-3/src/main/scala/io/sphere/mongo/generic/AnnotationReader.scala index cc991d9..7b5c5a9 100644 --- a/mongo/mongo-3/src/main/scala/io/sphere/mongo/generic/AnnotationReader.scala +++ b/mongo/mongo-3/src/main/scala/io/sphere/mongo/generic/AnnotationReader.scala @@ -28,10 +28,6 @@ case class TraitMetaData( subtypes: Map[String, CaseClassMetaData] ) { val typeDiscriminator: String = typeHintFieldRaw.map(_.value).getOrElse("type") - - val subTypeTypeHints: Map[String, String] = subtypes.collect { - case (name, classMeta) if classMeta.typeHint.isDefined => name -> classMeta.typeHint.get - } } object AnnotationReader {