Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added handling for default parameters #326

Merged
merged 4 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.virtuslab.yaml
import org.virtuslab.yaml.Node.*

import scala.compiletime.*
import scala.quoted.*
import scala.deriving.Mirror

private[yaml] trait YamlDecoderCompanionCrossCompat extends DecoderMacros {
Expand All @@ -12,70 +13,11 @@ private[yaml] trait YamlDecoderCompanionCrossCompat extends DecoderMacros {
}

private[yaml] trait DecoderMacros {
protected def extractKeyValues(
mappings: Map[Node, Node]
): Either[ConstructError, Map[String, Node]] = {
val keyValueMap = mappings
.map { (k, v) =>
k match {
case ScalarNode(scalarKey, _) => Right((scalarKey, v))
case node =>
Left(ConstructError.from(s"Parameter of a class must be a scalar value", node))
}
}
val (error, valuesSeq) = keyValueMap.partitionMap(identity)

if (error.nonEmpty) Left(error.head)
else Right(valuesSeq.toMap)
}

protected def constructValues[T](
elemLabels: List[String],
instances: List[YamlDecoder[_]],
optionalTypes: List[Boolean],
valuesMap: Map[String, Node],
p: Mirror.ProductOf[T],
parentNode: Node
) = {
val values = elemLabels.zip(instances).zip(optionalTypes).map { case ((label, c), isOptional) =>
valuesMap.get(label) match
case Some(value) => c.construct(value)
case None =>
if (isOptional) Right(None)
else Left(ConstructError.from(s"Key $label doesn't exist in parsed document", parentNode))
}
val (left, right) = values.partitionMap(identity)
if left.nonEmpty then Left(left.head)
else Right(p.fromProduct(Tuple.fromArray(right.toArray)))
protected inline def deriveProduct[T](p: Mirror.ProductOf[T]) = ${
DecoderMacros.deriveProductImpl[T]('p)
}

protected inline def deriveProduct[T](p: Mirror.ProductOf[T]) =
val instances = summonAll[p.MirroredElemTypes]
val elemLabels = getElemLabels[p.MirroredElemLabels]
val optionalTypes = getOptionalTypes[p.MirroredElemTypes]
new YamlDecoder[T] {
override def construct(node: Node)(using
constructor: LoadSettings = LoadSettings.empty
): Either[ConstructError, T] =
node match
case Node.MappingNode(mappings, _) =>
for {
valuesMap <- extractKeyValues(mappings)
constructedValues <- constructValues(
elemLabels,
instances,
optionalTypes,
valuesMap,
p,
node
)
} yield (constructedValues)
case _ =>
Left(
ConstructError.from(s"Expected MappingNode, got ${node.getClass.getSimpleName}", node)
)
}

protected inline def sumOf[T](s: Mirror.SumOf[T]) =
val instances = summonSumOf[s.MirroredElemTypes].asInstanceOf[List[YamlDecoder[T]]]
new YamlDecoder[T]:
Expand All @@ -94,17 +36,164 @@ private[yaml] trait DecoderMacros {
}
case _: EmptyTuple => Nil

protected inline def summonAll[T <: Tuple]: List[YamlDecoder[_]] = inline erasedValue[T] match
case _: EmptyTuple => Nil
case _: (t *: ts) => summonInline[YamlDecoder[t]] :: summonAll[ts]
}

object DecoderMacros {

protected def constructValues[T](
instances: List[(String, YamlDecoder[?], Boolean)],
valuesMap: Map[String, Node],
defaultParams: Map[String, () => Any],
p: Mirror.ProductOf[T],
parentNode: Node
): Either[ConstructError, T] = {
val values = instances.map { case (label, c, isOptional) =>
valuesMap.get(label) match
case Some(value) => c.construct(value)
case None =>
if (isOptional) Right(None)
else if (defaultParams.contains(label))
val defaultParamCreator = defaultParams(label)
val defaultParamValue = defaultParamCreator()
Right(defaultParamValue)
else Left(ConstructError.from(s"Key $label doesn't exist in parsed document", parentNode))
}
val (left, right) = values.partitionMap(identity)
if left.nonEmpty then Left(left.head)
else Right(p.fromProduct(Tuple.fromArray(right.toArray)))
}

private def extractKeyValues(
mappings: Map[Node, Node]
): Either[ConstructError, Map[String, Node]] = {
val keyValueMap = mappings
.map { (k, v) =>
k match {
case ScalarNode(scalarKey, _) => Right((scalarKey, v))
case node =>
Left(ConstructError.from(s"Parameter of a class must be a scalar value", node))
}
}
val (error, valuesSeq) = keyValueMap.partitionMap(identity)

if (error.nonEmpty) Left(error.head)
else Right(valuesSeq.toMap)
}

def deriveProductImpl[T: Type](p: Expr[Mirror.ProductOf[T]])(using
Quotes
): Expr[YamlDecoder[T]] =

// returns a list of tuples of label, instance, isOptional
def prepareInstances(
elemLabels: Type[?],
elemTypes: Type[?]
): List[Expr[(String, YamlDecoder[?], Boolean)]] =
(elemLabels, elemTypes) match
case ('[EmptyTuple], '[EmptyTuple]) => Nil
case ('[label *: labelsTail], '[tpe *: tpesTail]) =>
val label = Type.valueOfConstant[label].get.asInstanceOf[String]
val isOption = Type.of[tpe] match
case '[Option[?]] => Expr(true)
case _ => Expr(false)

val fieldName = Expr(label)
val fieldFormat = Expr.summon[YamlDecoder[tpe]].getOrElse {
quotes.reflect.report
.errorAndAbort("Missing given instance of YamlDecoder[" ++ Type.show[tpe] ++ "]")
}
val namedInstance = '{ (${ fieldName }, $fieldFormat, ${ isOption }) }
namedInstance :: prepareInstances(Type.of[labelsTail], Type.of[tpesTail])

p match
case '{
$m: Mirror.ProductOf[T] {
type MirroredElemLabels = elementLabels; type MirroredElemTypes = elementTypes
}
} =>
val allInstancesExpr =
Expr.ofList(prepareInstances(Type.of[elementLabels], Type.of[elementTypes]))
val defaultParamsExpr = findDefaultParams[T]

protected inline def getElemLabels[T <: Tuple]: List[String] = inline erasedValue[T] match
case _: EmptyTuple => Nil
case _: (head *: tail) => constValue[head].toString :: getElemLabels[tail]
'{
new YamlDecoder[T] {
private val allInstances = $allInstancesExpr
private val defaultParams = $defaultParamsExpr
private val mirror = $p

protected inline def getOptionalTypes[T <: Tuple]: List[Boolean] = inline erasedValue[T] match
case _: EmptyTuple => Nil
case _: (Option[_] *: tail) => true :: getOptionalTypes[tail]
case _: (_ *: tail) => false :: getOptionalTypes[tail]
override def construct(node: Node)(using
constructor: LoadSettings = LoadSettings.empty
): Either[ConstructError, T] =
node match
case Node.MappingNode(mappings, _) =>
for {
valuesMap <- extractKeyValues(mappings)
constructedValues <- constructValues(
allInstances,
valuesMap,
defaultParams,
mirror,
node
)
} yield (constructedValues)
case _ =>
Left(
ConstructError.from(
s"Expected MappingNode, got ${node.getClass.getSimpleName}",
node
)
)
}
}

private val DefaultParamPrefix = "$lessinit$greater$default$"

protected def findDefaultParams[T](using
quotes: Quotes,
tpe: Type[T]
): Expr[Map[String, () => Any]] =
import quotes.reflect.*

TypeRepr.of[T].classSymbol match
case None => '{ Map.empty[String, () => Any] }
case Some(sym: Symbol) =>
try
val comp = sym.companionClass
val mod = Ref(sym.companionModule)
val paramWithDefaultMeta =
for (p, idx) <- sym.caseFields.zipWithIndex if p.flags.is(Flags.HasDefault)
// +1 because the names are generated starting from 1
yield (p.name, idx + 1)

val idents: List[(String, Ref)] =
for (paramName, idx) <- paramWithDefaultMeta
yield paramName -> mod.select(
// head is safe here because we know there has to be a getter for the default value
// because we checked for HasDefault flag
comp.methodMember(DefaultParamPrefix + idx.toString).head
)

val typeArgs = TypeRepr.of[T].typeArgs

// we create an expression of a list of tuples of name and thunks that return the default value for a given parameter
val defaultsThunksExpr: Expr[List[(String, () => Any)]] =
if typeArgs.isEmpty then
Expr.ofList(
idents.map { case (name, ref) => name -> ref.asExpr }.map { case (name, '{ $x }) =>
'{ (${ Expr(name) }, () => $x) }
}
)
else // if there are type parameters, we need to apply the type parameters to accessors
Expr.ofList(
idents.map { case (name, ref) => name -> ref.appliedToTypes(typeArgs).asExpr }.map {
case (name, '{ $x }) => '{ (${ Expr(name) }, () => $x) }
}
)

'{ $defaultsThunksExpr.toMap }
catch // TODO drop after https://github.com/lampepfl/dotty/issues/19732 (after bump to 3.3.4)
case cce: ClassCastException =>
'{
Map.empty[String, () => Any]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -502,3 +502,66 @@ class DecoderSuite extends munit.FunSuite:
assert(error.msg.contains("Could't construct int from null (tag:yaml.org,2002:null)"))
case Right(data) => fail(s"expected failure, but got: $data")
}

test("default parameters for case classes can be used when decoding") {
case class Foo(a: Int = 1, b: String = "test", c: Option[Int] = None, d: Double)
derives YamlCodec

val yaml = """d: 1.0""".stripMargin

yaml.as[Foo] match
case Left(error: YamlError) =>
fail(s"failed with YamlError: $error")
case Right(foo) =>
assertEquals(foo.a, 1)
assertEquals(foo.b, "test")
assertEquals(foo.c, None)
assertEquals(foo.d, 1.0)
}

test("default parameters for case classes are evaluated lazily") {
var times = 0
def createB = {
times += 1
s"test-${times}"
}
case class Foo(a: Int, b: String = createB) derives YamlCodec

val yaml = """a: 1""".stripMargin

yaml.as[Foo] match
case Left(error: YamlError) =>
fail(s"failed with YamlError: $error")
case Right(foo) =>
assertEquals(foo.a, 1)
assertEquals(foo.b, "test-1")

yaml.as[Foo] // skip test-2

yaml.as[Foo] match
case Left(error: YamlError) =>
fail(s"failed with YamlError: $error")
case Right(foo) =>
assertEquals(foo.a, 1)
assertEquals(foo.b, "test-3")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about one more case with b specified to assert that the default is not invoked (so times is not incremented) when an argument is given explicitly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}

test("default parameters are not evaluated when they are provided in yaml") {
var evaluated = false
def createB = {
evaluated = true
"default"
}
case class Foo(a: Int, b: String = createB) derives YamlCodec

val yaml = """a: 1
|b: from yaml""".stripMargin

yaml.as[Foo] match
case Left(error: YamlError) =>
fail(s"failed with YamlError: $error")
case Right(foo) =>
assertEquals(foo.a, 1)
assertEquals(foo.b, "from yaml")
assert(!evaluated)
}