diff --git a/core/shared/src/main/scala-3/org/virtuslab/yaml/YamlDecoderCrossCompat.scala b/core/shared/src/main/scala-3/org/virtuslab/yaml/YamlDecoderCrossCompat.scala index 2fea94c2e..40af17bc7 100644 --- a/core/shared/src/main/scala-3/org/virtuslab/yaml/YamlDecoderCrossCompat.scala +++ b/core/shared/src/main/scala-3/org/virtuslab/yaml/YamlDecoderCrossCompat.scala @@ -3,6 +3,7 @@ package org.virtuslab.yaml import org.virtuslab.yaml.Node.* import scala.compiletime.* +import scala.quoted.* import scala.deriving.Mirror private[yaml] trait YamlDecoderCompanionCrossCompat extends DecoderMacros { @@ -12,70 +13,11 @@ private[yaml] trait YamlDecoderCompanionCrossCompat extends DecoderMacros { } private[yaml] trait DecoderMacros { - protected def extractKeyValues( - mappings: Map[Node, Node] - ): Either[ConstructError, Map[String, Node]] = { - val keyValueMap = mappings - .map { (k, v) => - k match { - case ScalarNode(scalarKey, _) => Right((scalarKey, v)) - case node => - Left(ConstructError.from(s"Parameter of a class must be a scalar value", node)) - } - } - val (error, valuesSeq) = keyValueMap.partitionMap(identity) - - if (error.nonEmpty) Left(error.head) - else Right(valuesSeq.toMap) - } - protected def constructValues[T]( - elemLabels: List[String], - instances: List[YamlDecoder[_]], - optionalTypes: List[Boolean], - valuesMap: Map[String, Node], - p: Mirror.ProductOf[T], - parentNode: Node - ) = { - val values = elemLabels.zip(instances).zip(optionalTypes).map { case ((label, c), isOptional) => - valuesMap.get(label) match - case Some(value) => c.construct(value) - case None => - if (isOptional) Right(None) - else Left(ConstructError.from(s"Key $label doesn't exist in parsed document", parentNode)) - } - val (left, right) = values.partitionMap(identity) - if left.nonEmpty then Left(left.head) - else Right(p.fromProduct(Tuple.fromArray(right.toArray))) + protected inline def deriveProduct[T](p: Mirror.ProductOf[T]) = ${ + DecoderMacros.deriveProductImpl[T]('p) } - protected inline def deriveProduct[T](p: Mirror.ProductOf[T]) = - val instances = summonAll[p.MirroredElemTypes] - val elemLabels = getElemLabels[p.MirroredElemLabels] - val optionalTypes = getOptionalTypes[p.MirroredElemTypes] - new YamlDecoder[T] { - override def construct(node: Node)(using - constructor: LoadSettings = LoadSettings.empty - ): Either[ConstructError, T] = - node match - case Node.MappingNode(mappings, _) => - for { - valuesMap <- extractKeyValues(mappings) - constructedValues <- constructValues( - elemLabels, - instances, - optionalTypes, - valuesMap, - p, - node - ) - } yield (constructedValues) - case _ => - Left( - ConstructError.from(s"Expected MappingNode, got ${node.getClass.getSimpleName}", node) - ) - } - protected inline def sumOf[T](s: Mirror.SumOf[T]) = val instances = summonSumOf[s.MirroredElemTypes].asInstanceOf[List[YamlDecoder[T]]] new YamlDecoder[T]: @@ -94,17 +36,164 @@ private[yaml] trait DecoderMacros { } case _: EmptyTuple => Nil - protected inline def summonAll[T <: Tuple]: List[YamlDecoder[_]] = inline erasedValue[T] match - case _: EmptyTuple => Nil - case _: (t *: ts) => summonInline[YamlDecoder[t]] :: summonAll[ts] +} + +object DecoderMacros { + + protected def constructValues[T]( + instances: List[(String, YamlDecoder[?], Boolean)], + valuesMap: Map[String, Node], + defaultParams: Map[String, () => Any], + p: Mirror.ProductOf[T], + parentNode: Node + ): Either[ConstructError, T] = { + val values = instances.map { case (label, c, isOptional) => + valuesMap.get(label) match + case Some(value) => c.construct(value) + case None => + if (isOptional) Right(None) + else if (defaultParams.contains(label)) + val defaultParamCreator = defaultParams(label) + val defaultParamValue = defaultParamCreator() + Right(defaultParamValue) + else Left(ConstructError.from(s"Key $label doesn't exist in parsed document", parentNode)) + } + val (left, right) = values.partitionMap(identity) + if left.nonEmpty then Left(left.head) + else Right(p.fromProduct(Tuple.fromArray(right.toArray))) + } + + private def extractKeyValues( + mappings: Map[Node, Node] + ): Either[ConstructError, Map[String, Node]] = { + val keyValueMap = mappings + .map { (k, v) => + k match { + case ScalarNode(scalarKey, _) => Right((scalarKey, v)) + case node => + Left(ConstructError.from(s"Parameter of a class must be a scalar value", node)) + } + } + val (error, valuesSeq) = keyValueMap.partitionMap(identity) + + if (error.nonEmpty) Left(error.head) + else Right(valuesSeq.toMap) + } + + def deriveProductImpl[T: Type](p: Expr[Mirror.ProductOf[T]])(using + Quotes + ): Expr[YamlDecoder[T]] = + + // returns a list of tuples of label, instance, isOptional + def prepareInstances( + elemLabels: Type[?], + elemTypes: Type[?] + ): List[Expr[(String, YamlDecoder[?], Boolean)]] = + (elemLabels, elemTypes) match + case ('[EmptyTuple], '[EmptyTuple]) => Nil + case ('[label *: labelsTail], '[tpe *: tpesTail]) => + val label = Type.valueOfConstant[label].get.asInstanceOf[String] + val isOption = Type.of[tpe] match + case '[Option[?]] => Expr(true) + case _ => Expr(false) + + val fieldName = Expr(label) + val fieldFormat = Expr.summon[YamlDecoder[tpe]].getOrElse { + quotes.reflect.report + .errorAndAbort("Missing given instance of YamlDecoder[" ++ Type.show[tpe] ++ "]") + } + val namedInstance = '{ (${ fieldName }, $fieldFormat, ${ isOption }) } + namedInstance :: prepareInstances(Type.of[labelsTail], Type.of[tpesTail]) + + p match + case '{ + $m: Mirror.ProductOf[T] { + type MirroredElemLabels = elementLabels; type MirroredElemTypes = elementTypes + } + } => + val allInstancesExpr = + Expr.ofList(prepareInstances(Type.of[elementLabels], Type.of[elementTypes])) + val defaultParamsExpr = findDefaultParams[T] - protected inline def getElemLabels[T <: Tuple]: List[String] = inline erasedValue[T] match - case _: EmptyTuple => Nil - case _: (head *: tail) => constValue[head].toString :: getElemLabels[tail] + '{ + new YamlDecoder[T] { + private val allInstances = $allInstancesExpr + private val defaultParams = $defaultParamsExpr + private val mirror = $p - protected inline def getOptionalTypes[T <: Tuple]: List[Boolean] = inline erasedValue[T] match - case _: EmptyTuple => Nil - case _: (Option[_] *: tail) => true :: getOptionalTypes[tail] - case _: (_ *: tail) => false :: getOptionalTypes[tail] + override def construct(node: Node)(using + constructor: LoadSettings = LoadSettings.empty + ): Either[ConstructError, T] = + node match + case Node.MappingNode(mappings, _) => + for { + valuesMap <- extractKeyValues(mappings) + constructedValues <- constructValues( + allInstances, + valuesMap, + defaultParams, + mirror, + node + ) + } yield (constructedValues) + case _ => + Left( + ConstructError.from( + s"Expected MappingNode, got ${node.getClass.getSimpleName}", + node + ) + ) + } + } + + private val DefaultParamPrefix = "$lessinit$greater$default$" + + protected def findDefaultParams[T](using + quotes: Quotes, + tpe: Type[T] + ): Expr[Map[String, () => Any]] = + import quotes.reflect.* + + TypeRepr.of[T].classSymbol match + case None => '{ Map.empty[String, () => Any] } + case Some(sym: Symbol) => + try + val comp = sym.companionClass + val mod = Ref(sym.companionModule) + val paramWithDefaultMeta = + for (p, idx) <- sym.caseFields.zipWithIndex if p.flags.is(Flags.HasDefault) + // +1 because the names are generated starting from 1 + yield (p.name, idx + 1) + + val idents: List[(String, Ref)] = + for (paramName, idx) <- paramWithDefaultMeta + yield paramName -> mod.select( + // head is safe here because we know there has to be a getter for the default value + // because we checked for HasDefault flag + comp.methodMember(DefaultParamPrefix + idx.toString).head + ) + + val typeArgs = TypeRepr.of[T].typeArgs + + // we create an expression of a list of tuples of name and thunks that return the default value for a given parameter + val defaultsThunksExpr: Expr[List[(String, () => Any)]] = + if typeArgs.isEmpty then + Expr.ofList( + idents.map { case (name, ref) => name -> ref.asExpr }.map { case (name, '{ $x }) => + '{ (${ Expr(name) }, () => $x) } + } + ) + else // if there are type parameters, we need to apply the type parameters to accessors + Expr.ofList( + idents.map { case (name, ref) => name -> ref.appliedToTypes(typeArgs).asExpr }.map { + case (name, '{ $x }) => '{ (${ Expr(name) }, () => $x) } + } + ) + '{ $defaultsThunksExpr.toMap } + catch // TODO drop after https://github.com/lampepfl/dotty/issues/19732 (after bump to 3.3.4) + case cce: ClassCastException => + '{ + Map.empty[String, () => Any] + } } diff --git a/core/shared/src/test/scala-3/org/virtuslab/yaml/decoder/DecoderSuite.scala b/core/shared/src/test/scala-3/org/virtuslab/yaml/decoder/DecoderSuite.scala index 59a883f7f..f7b9f6a85 100644 --- a/core/shared/src/test/scala-3/org/virtuslab/yaml/decoder/DecoderSuite.scala +++ b/core/shared/src/test/scala-3/org/virtuslab/yaml/decoder/DecoderSuite.scala @@ -502,3 +502,66 @@ class DecoderSuite extends munit.FunSuite: assert(error.msg.contains("Could't construct int from null (tag:yaml.org,2002:null)")) case Right(data) => fail(s"expected failure, but got: $data") } + + test("default parameters for case classes can be used when decoding") { + case class Foo(a: Int = 1, b: String = "test", c: Option[Int] = None, d: Double) + derives YamlCodec + + val yaml = """d: 1.0""".stripMargin + + yaml.as[Foo] match + case Left(error: YamlError) => + fail(s"failed with YamlError: $error") + case Right(foo) => + assertEquals(foo.a, 1) + assertEquals(foo.b, "test") + assertEquals(foo.c, None) + assertEquals(foo.d, 1.0) + } + + test("default parameters for case classes are evaluated lazily") { + var times = 0 + def createB = { + times += 1 + s"test-${times}" + } + case class Foo(a: Int, b: String = createB) derives YamlCodec + + val yaml = """a: 1""".stripMargin + + yaml.as[Foo] match + case Left(error: YamlError) => + fail(s"failed with YamlError: $error") + case Right(foo) => + assertEquals(foo.a, 1) + assertEquals(foo.b, "test-1") + + yaml.as[Foo] // skip test-2 + + yaml.as[Foo] match + case Left(error: YamlError) => + fail(s"failed with YamlError: $error") + case Right(foo) => + assertEquals(foo.a, 1) + assertEquals(foo.b, "test-3") + } + + test("default parameters are not evaluated when they are provided in yaml") { + var evaluated = false + def createB = { + evaluated = true + "default" + } + case class Foo(a: Int, b: String = createB) derives YamlCodec + + val yaml = """a: 1 + |b: from yaml""".stripMargin + + yaml.as[Foo] match + case Left(error: YamlError) => + fail(s"failed with YamlError: $error") + case Right(foo) => + assertEquals(foo.a, 1) + assertEquals(foo.b, "from yaml") + assert(!evaluated) + }