Skip to content

Commit

Permalink
Fix issue 527 - Add to StackReference, methods converting to typed Ou…
Browse files Browse the repository at this point in the history
…tputs (#528)

* Fix issue 527 - Add methods to stack reference

* Fix issue 527 - Create TypedStackReference

* Add all changes necessary for nice typed StackRef api - @lbialy 

---------

Co-authored-by: Łukasz Biały <[email protected]>
  • Loading branch information
polkx and lbialy authored Aug 26, 2024
1 parent 4c20a7e commit fffd9c9
Show file tree
Hide file tree
Showing 11 changed files with 288 additions and 35 deletions.
2 changes: 1 addition & 1 deletion besom-json/src/main/scala/besom/json/JsonFormat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ object JsonReader {
def read(json: JsValue) = f(json)
}

inline def derived[T <: Product](using JsonProtocol): JsonReader[T] = summon[JsonProtocol].jsonFormatN[T]
inline def derived[T <: Product](using JsonProtocol): JsonReader[T] = summon[JsonProtocol].jsonReaderN[T]
}

/** Provides the JSON serialization for type T.
Expand Down
89 changes: 72 additions & 17 deletions besom-json/src/main/scala/besom/json/ProductFormats.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ trait ProductFormats:
def requireNullsForOptions: Boolean = false

inline def jsonFormatN[T <: Product]: RootJsonFormat[T] = ${ ProductFormatsMacro.jsonFormatImpl[T]('self) }
inline def jsonReaderN[T <: Product]: RootJsonReader[T] = ${ ProductFormatsMacro.jsonReaderImpl[T]('self) }

object ProductFormatsMacro:
import scala.deriving.*
Expand Down Expand Up @@ -57,29 +58,45 @@ object ProductFormatsMacro:
'{ $namesExpr.zip($identsExpr).toMap }
catch case cce: ClassCastException => '{ Map.empty[String, Any] } // TODO drop after https://github.com/lampepfl/dotty/issues/19732

private def prepareFormatInstances(elemLabels: Type[?], elemTypes: Type[?])(using Quotes): List[Expr[(String, JsonFormat[?], Boolean)]] =
(elemLabels, elemTypes) match
case ('[EmptyTuple], '[EmptyTuple]) => Nil
case ('[label *: labelsTail], '[tpe *: tpesTail]) =>
val label = Type.valueOfConstant[label].get.asInstanceOf[String]
val isOption = Type.of[tpe] match
case '[Option[?]] => Expr(true)
case _ => Expr(false)

val fieldName = Expr(label)
val fieldFormat = Expr.summon[JsonFormat[tpe]].getOrElse {
quotes.reflect.report.errorAndAbort("Missing given instance of JsonFormat[" ++ Type.show[tpe] ++ "]")
} // TODO: Handle missing instance
val namedInstance = '{ (${ fieldName }, $fieldFormat, ${ isOption }) }
namedInstance :: prepareFormatInstances(Type.of[labelsTail], Type.of[tpesTail])

private def prepareReaderInstances(elemLabels: Type[?], elemTypes: Type[?])(using Quotes): List[Expr[(String, JsonReader[?], Boolean)]] =
(elemLabels, elemTypes) match
case ('[EmptyTuple], '[EmptyTuple]) => Nil
case ('[label *: labelsTail], '[tpe *: tpesTail]) =>
val label = Type.valueOfConstant[label].get.asInstanceOf[String]
val isOption = Type.of[tpe] match
case '[Option[?]] => Expr(true)
case _ => Expr(false)

val fieldName = Expr(label)
val fieldFormat = Expr.summon[JsonReader[tpe]].getOrElse {
quotes.reflect.report.errorAndAbort("Missing given instance of JsonFormat[" ++ Type.show[tpe] ++ "]")
} // TODO: Handle missing instance
val namedInstance = '{ (${ fieldName }, $fieldFormat, ${ isOption }) }
namedInstance :: prepareReaderInstances(Type.of[labelsTail], Type.of[tpesTail])

def jsonFormatImpl[T <: Product: Type](prodFormats: Expr[ProductFormats])(using Quotes): Expr[RootJsonFormat[T]] =
Expr.summon[Mirror.Of[T]].get match
case '{
$m: Mirror.ProductOf[T] { type MirroredElemLabels = elementLabels; type MirroredElemTypes = elementTypes }
} =>
def prepareInstances(elemLabels: Type[?], elemTypes: Type[?]): List[Expr[(String, JsonFormat[?], Boolean)]] =
(elemLabels, elemTypes) match
case ('[EmptyTuple], '[EmptyTuple]) => Nil
case ('[label *: labelsTail], '[tpe *: tpesTail]) =>
val label = Type.valueOfConstant[label].get.asInstanceOf[String]
val isOption = Type.of[tpe] match
case '[Option[?]] => Expr(true)
case _ => Expr(false)

val fieldName = Expr(label)
val fieldFormat = Expr.summon[JsonFormat[tpe]].getOrElse {
quotes.reflect.report.errorAndAbort("Missing given instance of JsonFormat[" ++ Type.show[tpe] ++ "]")
} // TODO: Handle missing instance
val namedInstance = '{ (${ fieldName }, $fieldFormat, ${ isOption }) }
namedInstance :: prepareInstances(Type.of[labelsTail], Type.of[tpesTail])

// instances are in correct order of fields of the product
val allInstancesExpr = Expr.ofList(prepareInstances(Type.of[elementLabels], Type.of[elementTypes]))
val allInstancesExpr = Expr.ofList(prepareFormatInstances(Type.of[elementLabels], Type.of[elementTypes]))
val defaultArguments = findDefaultParams[T]

'{
Expand Down Expand Up @@ -121,6 +138,44 @@ object ProductFormatsMacro:

JsObject(fields.toMap)
}

def jsonReaderImpl[T <: Product: Type](prodFormats: Expr[ProductFormats])(using Quotes): Expr[RootJsonReader[T]] =
Expr.summon[Mirror.Of[T]].get match
case '{
$m: Mirror.ProductOf[T] { type MirroredElemLabels = elementLabels; type MirroredElemTypes = elementTypes }
} =>
// instances are in correct order of fields of the product
val allInstancesExpr = Expr.ofList(prepareReaderInstances(Type.of[elementLabels], Type.of[elementTypes]))
val defaultArguments = findDefaultParams[T]

'{
new RootJsonReader[T]:
private val allInstances = ${ allInstancesExpr }
private val fmts = ${ prodFormats }
private val defaultArgs = ${ defaultArguments }

def read(json: JsValue): T = json match
case JsObject(fields) =>
val values = allInstances.map { case (fieldName, fieldFormat, isOption) =>
try fieldFormat.read(fields(fieldName))
catch
case e: NoSuchElementException =>
// if field has a default value, use it, we didn't find anything in the JSON
if defaultArgs.contains(fieldName) then defaultArgs(fieldName)
// if field is optional and requireNullsForOptions is disabled, return None
// otherwise we require an explicit null value
else if isOption && !fmts.requireNullsForOptions then None
// it's missing so we throw an exception
else throw DeserializationException("Object is missing required member '" ++ fieldName ++ "'", null, fieldName :: Nil)
case DeserializationException(msg, cause, fieldNames) =>
throw DeserializationException(msg, cause, fieldName :: fieldNames)
}
$m.fromProduct(Tuple.fromArray(values.toArray))

case _ => throw DeserializationException("Object expected", null, allInstances.map(_._1))

}

end ProductFormatsMacro

/** This trait supplies an alternative rendering mode for optional case class members. Normally optional members that are undefined (`None`)
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/besom/aliases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ object aliases:
type ComponentResource = besom.internal.ComponentResource
type RegistersOutputs[A <: ComponentResource & Product] = besom.internal.RegistersOutputs[A]
type StackReference = besom.internal.StackReference
type TypedStackReference[A] = besom.internal.TypedStackReference[A]
object StackReference extends besom.internal.StackReferenceFactory
type StackReferenceArgs = besom.internal.StackReferenceArgs
object StackReferenceArgs extends besom.internal.StackReferenceArgsFactory
Expand All @@ -52,4 +53,5 @@ object aliases:
object CustomTimeouts extends besom.internal.CustomTimeoutsFactory

export besom.internal.InvokeOptions
export besom.util.JsonReaderInstances.*
end aliases
51 changes: 49 additions & 2 deletions core/src/main/scala/besom/internal/StackReference.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,44 @@ case class StackReference(
end StackReference

trait StackReferenceFactory:
def apply(using Context)(
sealed trait StackReferenceType[T]:
type Out[T]
def transform(stackReference: StackReference)(using Context): Output[Out[T]]

object StackReferenceType:
given untyped: UntypedStackReferenceType = UntypedStackReferenceType()

given typed[T: JsonReader]: TypedStackReferenceType[T] = TypedStackReferenceType[T]

class TypedStackReferenceType[T](using JsonReader[T]) extends StackReferenceType[T]:
type Out[T] = TypedStackReference[T]
def transform(stackReference: StackReference)(using Context): Output[Out[T]] =
val objectOutput: Output[T] =
requireObject(stackReference.outputs, stackReference.secretOutputNames)

objectOutput.map(t =>
TypedStackReference(
urn = stackReference.urn,
id = stackReference.id,
name = stackReference.name,
outputs = t,
secretOutputNames = stackReference.secretOutputNames
)
)

class UntypedStackReferenceType extends StackReferenceType[Any]:
type Out[T] = StackReference
def transform(stackReference: StackReference)(using Context): Output[StackReference] = Output(stackReference)

def untypedStackReference(using Context): StackReferenceType[Any] = UntypedStackReferenceType()

def typedStackReference[T: JsonReader]: TypedStackReferenceType[T] = TypedStackReferenceType()

def apply[T](using stackRefType: StackReferenceType[T], ctx: Context)(
name: NonEmptyString,
args: Input.Optional[StackReferenceArgs] = None,
opts: StackReferenceResourceOptions = StackReferenceResourceOptions()
): Output[StackReference] =
): Output[stackRefType.Out[T]] =
args
.asOptionOutput(false)
.flatMap {
Expand All @@ -76,3 +109,17 @@ trait StackReferenceFactory:

Context().readOrRegisterResource[StackReference, StackReferenceArgs]("pulumi:pulumi:StackReference", name, stackRefArgs, mergedOpts)
}
.flatMap(stackRefType.transform)

private[internal] def requireObject[T: JsonReader](
outputs: Output[Map[String, JsValue]],
secretOutputNames: Output[Set[String]]
): Output[T] =
outputs
.map(JsObject(_).convertTo[T])
.withIsSecret(
secretOutputNames
.map(_.nonEmpty)
.getValueOrElse(false)
)
end StackReferenceFactory
11 changes: 11 additions & 0 deletions core/src/main/scala/besom/internal/TypedStackReference.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package besom.internal

import besom.types.*

case class TypedStackReference[T](
urn: Output[URN],
id: Output[ResourceId],
name: Output[String],
outputs: T,
secretOutputNames: Output[Set[String]]
) extends CustomResource
2 changes: 1 addition & 1 deletion core/src/main/scala/besom/internal/codecs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ object Decoder extends DecoderInstancesLowPrio1:
.foldLeft[ValidatedResult[DecodingError, Vector[OutputData[A]]]](ValidatedResult.valid(Vector.empty))(
accumulatedOutputDataOrErrors(_, _, "iterable", label)
)
.map(_.toIterable)
.map(_.toVector)
.map(OutputData.sequence)
end if
}
Expand Down
33 changes: 33 additions & 0 deletions core/src/main/scala/besom/util/JsonReaderInstances.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package besom.util

import besom.json.*
import besom.internal.{Output, Context}
import besom.internal.Constants, Constants.SpecialSig

object JsonReaderInstances:
implicit def outputJsonReader[A](using jsonReader: JsonReader[A], ctx: Context): JsonReader[Output[A]] =
new JsonReader[Output[A]]:
def read(json: JsValue): Output[A] = json match
case JsObject(fields) =>
fields.get(SpecialSig.Key) match
case Some(JsString(sig)) if SpecialSig.fromString(sig) == Some(SpecialSig.OutputSig) =>
val maybeInnerValue = fields.get(Constants.ValueName)
maybeInnerValue
.map { innerValue =>
try Output(jsonReader.read(innerValue))
catch case e: Throwable => Output.fail(e)
}
.getOrElse(Output.fail(Exception("Invalid JSON")))

case Some(JsString(sig)) if SpecialSig.fromString(sig) == Some(SpecialSig.SecretSig) =>
val maybeInnerValue = fields.get(Constants.ValueName)
maybeInnerValue
.map { innerValue =>
try Output.secret(jsonReader.read(innerValue))
catch case e: Throwable => Output.fail(e)
}
.getOrElse(Output.fail(Exception("Invalid JSON")))

case _ => Output.fail(Exception("Invalid JSON"))

case _ => Output.fail(Exception("Invalid JSON"))
61 changes: 61 additions & 0 deletions core/src/test/scala/besom/internal/StackReferenceTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package besom.internal

import besom.*
import besom.json.*
import RunResult.{*, given}

class StackReferenceTest extends munit.FunSuite:

test("convert stack reference to case class") {
given Context = DummyContext().unsafeRunSync()
case class Test(s: String, i: Int) derives JsonReader
val expected = Test("value1", 2)
val outputs = Map("s" -> JsString("value1"), "i" -> JsNumber(2))

val requireObject = StackReference.requireObject[Test](Output(outputs), Output(Set.empty))
assertEquals(requireObject.getData.unsafeRunSync(), OutputData(expected))
}

test("fail when convert stack reference to case class with missing data") {
given Context = DummyContext().unsafeRunSync()
case class Test(s: String, i: Int) derives JsonReader
val outputs = Map("s" -> JsString("value1"))

val requireObject = StackReference.requireObject[Test](Output(outputs), Output(Set.empty))
intercept[besom.json.DeserializationException](requireObject.getData.unsafeRunSync())
}

test("convert stack reference to case class with secret field") {
given Context = DummyContext().unsafeRunSync()
case class Test(s: String, i: Int) derives JsonReader
val expected = Test("value1", 2)
val outputs = Map("s" -> JsString("value1"), "i" -> JsNumber(2))
val secretOutputNames = Set("i")

val requireObject = StackReference.requireObject[Test](Output(outputs), Output(secretOutputNames))
assertEquals(requireObject.getData.unsafeRunSync(), OutputData(expected).withIsSecret(true))
}

test("propagate secret field to whole typed stack reference") {
given Context = DummyContext().unsafeRunSync()

case class Test(s: String, i: Int) derives JsonReader
val outputs = Map("s" -> JsString("value1"), "i" -> JsNumber(2))
val secretOutputNames = Set("i")

val typedStackReference =
StackReference
.requireObject[Test](Output(outputs), Output(secretOutputNames))
.map(test =>
TypedStackReference(
urn = Output(URN.empty),
id = Output(ResourceId.empty),
name = Output(""),
outputs = test,
secretOutputNames = Output(secretOutputNames)
)
)

assertEquals(typedStackReference.getData.unsafeRunSync().secret, true)
}
end StackReferenceTest
Binary file removed cs
Binary file not shown.
12 changes: 6 additions & 6 deletions integration-tests/CoreTests.test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,19 @@ class CoreTests extends munit.FunSuite {
case pulumi.FixtureMultiContext(ctx, Vector(ctx1, ctx2)) =>
println(s"Source stack name: ${ctx1.stackName}, pulumi home: ${ctx.home}")
pulumi.up(ctx1.stackName).call(cwd = ctx1.programDir, env = ctx1.env)
val outputs1 = upickle.default.read[Map[String, ujson.Value]](
val expected = upickle.default.read[Map[String, ujson.Value]](
pulumi.outputs(ctx1.stackName, "--show-secrets").call(cwd = ctx1.programDir, env = ctx1.env).out.text()
)

println(s"Target stack name: ${ctx2.stackName}, pulumi home: ${ctx.home}")
pulumi
.up(ctx2.stackName, "--config", s"sourceStack=organization/source-stack-test/${ctx1.stackName}")
.call(cwd = ctx2.programDir, env = ctx2.env)
val outputs2 = upickle.default.read[Map[String, ujson.Value]](
val obtained = upickle.default.read[Map[String, ujson.Value]](
pulumi.outputs(ctx2.stackName, "--show-secrets").call(cwd = ctx2.programDir, env = ctx2.env).out.text()
)

assertEquals(outputs1, outputs2)
assertEquals(obtained, expected)

case _ => throw new Exception("Invalid number of contexts")
}
Expand Down Expand Up @@ -182,19 +182,19 @@ class CoreTests extends munit.FunSuite {
case pulumi.FixtureMultiContext(ctx, Vector(ctx1, ctx2)) =>
println(s"Source stack name: ${ctx1.stackName}, pulumi home: ${ctx.home}")
pulumi.up(ctx1.stackName).call(cwd = ctx1.programDir, env = ctx1.env)
val outputs1 = upickle.default.read[Map[String, ujson.Value]](
val expected = upickle.default.read[Map[String, ujson.Value]](
pulumi.outputs(ctx1.stackName, "--show-secrets").call(cwd = ctx1.programDir, env = ctx1.env).out.text()
)

println(s"Target stack name: ${ctx2.stackName}, pulumi home: ${ctx.home}")
pulumi
.up(ctx2.stackName, "--config", s"sourceStack=organization/source-stack-test/${ctx1.stackName}")
.call(cwd = ctx2.programDir, env = ctx2.env)
val outputs2 = upickle.default.read[Map[String, ujson.Value]](
val obtained = upickle.default.read[Map[String, ujson.Value]](
pulumi.outputs(ctx2.stackName, "--show-secrets").call(cwd = ctx2.programDir, env = ctx2.env).out.text()
)

assertEquals(outputs1, outputs2)
assertEquals(obtained, expected)

case _ => throw new Exception("Invalid number of contexts")
}
Expand Down
Loading

0 comments on commit fffd9c9

Please sign in to comment.