diff --git a/mongo/mongo-derivation/src/main/scala/io/sphere/mongo/generic/MongoFormatMacros.scala b/mongo/mongo-derivation/src/main/scala/io/sphere/mongo/generic/MongoFormatMacros.scala index 0390ecd6..571146bb 100644 --- a/mongo/mongo-derivation/src/main/scala/io/sphere/mongo/generic/MongoFormatMacros.scala +++ b/mongo/mongo-derivation/src/main/scala/io/sphere/mongo/generic/MongoFormatMacros.scala @@ -107,6 +107,7 @@ private[generic] object MongoFormatMacros { ) else { val subtypes = collectKnownSubtypes(c)(symbol) + val subtypesWithNoFormatter = subtypes.filterNot(mongoFormatExists(c)) val idents = Ident(symbol.name) :: subtypes.map { s => if (s.isModuleClass) New(TypeTree(s.asClass.toType)) else Ident(s.name) }.toList @@ -114,7 +115,7 @@ private[generic] object MongoFormatMacros { if (idents.size == 1) c.abort(c.enclosingPosition, "Subtypes not found.") else { - val instanceDefs = subtypes.zipWithIndex.collect { + val instanceDefs = subtypesWithNoFormatter.zipWithIndex.collect { case (symbol, i) if symbol.isClass && symbol.asClass.isCaseClass => if (symbol.asClass.typeParams.nonEmpty) { c.abort( @@ -152,4 +153,14 @@ private[generic] object MongoFormatMacros { } } } + + private def mongoFormatExists(c: blackbox.Context)(s: c.universe.Symbol) = { + val typeName = s.asType.toType + try { + c.typecheck(c.parse(s"implicitly[MongoFormat[$typeName]]")) + true + } catch { + case _: Throwable => false + } + } } diff --git a/mongo/mongo-derivation/src/test/scala/io/sphere/mongo/generic/SumTypesDerivingSpec.scala b/mongo/mongo-derivation/src/test/scala/io/sphere/mongo/generic/SumTypesDerivingSpec.scala index 37c8ce9e..c919647a 100644 --- a/mongo/mongo-derivation/src/test/scala/io/sphere/mongo/generic/SumTypesDerivingSpec.scala +++ b/mongo/mongo-derivation/src/test/scala/io/sphere/mongo/generic/SumTypesDerivingSpec.scala @@ -6,6 +6,7 @@ import org.scalatest.wordspec.AnyWordSpec import io.sphere.mongo.MongoUtils.dbObj import io.sphere.mongo.format.DefaultMongoFormats._ import io.sphere.mongo.format.MongoFormat +import org.bson.BSONObject import org.scalatest.Assertion class SumTypesDerivingSpec extends AnyWordSpec with Matchers { @@ -73,6 +74,22 @@ class SumTypesDerivingSpec extends AnyWordSpec with Matchers { check(Color9.format, Color9.Custom("2356"), dbObj("type" -> "Custom", "rgb" -> "2356")) } + + "Formatters could be overridden for objects" in { + check(Color10.format, Color10.Red, dbObj("type" -> "Red", "extraField" -> "panda")) + + check(Color10.format, Color10.Custom("2356"), dbObj("type" -> "Custom", "rgb" -> "2356")) + } + + "Formatters could be overridden for classes" in { + check(Color11.format, Color11.Red, dbObj("type" -> "Red")) + + check( + Color11.format, + Color11.Custom("2356"), + dbObj("type" -> "Custom", "rgb" -> "2356", "extraField" -> "panda")) + } + } } @@ -171,4 +188,30 @@ object SumTypesDerivingSpec { val format = deriveMongoFormat[Color9] } + sealed trait Color10 + object Color10 { + case object Red extends Color10 + case class Custom(rgb: String) extends Color10 + + implicit val redFormatter: MongoFormat[Red.type] = new MongoFormat[Red.type] { + override def toMongoValue(a: Red.type): Any = dbObj("type" -> "Red", "extraField" -> "panda") + override def fromMongoValue(any: Any): Red.type = Red + } + val format = deriveMongoFormat[Color10] + } + + sealed trait Color11 + object Color11 { + case object Red extends Color11 + case class Custom(rgb: String) extends Color11 + + implicit val customFormatter: MongoFormat[Custom] = new MongoFormat[Custom] { + override def toMongoValue(a: Custom): Any = + dbObj("type" -> "Custom", "rgb" -> a.rgb, "extraField" -> "panda") + override def fromMongoValue(any: Any): Custom = + Custom(any.asInstanceOf[BSONObject].get("rgb").asInstanceOf[String]) + } + val format = deriveMongoFormat[Color11] + } + }