Skip to content

Commit

Permalink
Merge branch 'master' into delay-headers
Browse files Browse the repository at this point in the history
  • Loading branch information
joroKr21 committed Jul 23, 2024
2 parents d7df9b9 + aa3e67e commit 0b807b6
Show file tree
Hide file tree
Showing 11 changed files with 123 additions and 81 deletions.
7 changes: 7 additions & 0 deletions .scala-steward.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
updates.pin = [
# Scala 3.3 is the LTS release
{ groupId = "org.scala-lang", artifactId = "scala3-compiler", version = "3.3." }
{ groupId = "org.scala-lang", artifactId = "scala3-library", version = "3.3." }
{ groupId = "org.scala-lang", artifactId = "scala3-library_sjs1", version = "3.3." }
]

1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## 0.6.3
* Delay sending response headers until the first message is ready.
* Add configurable backpressure to streaming clients.

## 0.6.1
* Only buffer stream if queue size is positive (#578, #580)
Expand Down
8 changes: 4 additions & 4 deletions build.sbt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import Settings.stdSettings
import org.scalajs.linker.interface.ModuleInitializer

val Scala3 = "3.4.1"
val Scala3 = "3.3.3"

val Scala213 = "2.13.14"

Expand Down Expand Up @@ -64,7 +64,7 @@ lazy val core = projectMatrix
.settings(
libraryDependencies ++= Seq(
"com.thesamet.scalapb.grpcweb" %%% "scalapb-grpcweb" % "0.7.0",
"io.github.cquiroz" %%% "scala-java-time" % "2.5.0" % "test"
"io.github.cquiroz" %%% "scala-java-time" % "2.6.0" % "test"
),
Compile / npmDependencies += "grpc-web" -> "1.4.2"
)
Expand Down Expand Up @@ -118,7 +118,7 @@ lazy val e2eProtos =
settings = Seq(
libraryDependencies ++= Seq(
"com.thesamet.scalapb.grpcweb" %%% "scalapb-grpcweb" % "0.7.0",
"io.github.cquiroz" %%% "scala-java-time" % "2.5.0"
"io.github.cquiroz" %%% "scala-java-time" % "2.6.0"
)
)
)
Expand Down Expand Up @@ -187,7 +187,7 @@ lazy val e2eWeb =
scalaVersions = ScalaVersions,
settings = Seq(
libraryDependencies ++= Seq(
"com.microsoft.playwright" % "playwright" % "1.44.0" % Test,
"com.microsoft.playwright" % "playwright" % "1.45.0" % Test,
"dev.zio" %%% "zio-test" % Version.zio % Test,
"dev.zio" %% "zio-test-sbt" % Version.zio % Test
),
Expand Down
4 changes: 4 additions & 0 deletions core/src/main/scalajs/scalapb/zio_grpc/ZChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ import zio.Task

class ZChannel(
private[zio_grpc] val channel: ManagedChannel,
private[zio_grpc] val prefetch: Option[Int],
interceptors: Seq[ZClientInterceptor]
) {
def this(channel: ManagedChannel, interceptors: Seq[ZClientInterceptor]) =
this(channel, None, interceptors)

def shutdown(): Task[Unit] = ZIO.unit
}
30 changes: 26 additions & 4 deletions core/src/main/scalajvm/scalapb/zio_grpc/ZChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@ import zio._

class ZChannel(
private[zio_grpc] val channel: ManagedChannel,
private[zio_grpc] val prefetch: Option[Int],
interceptors: Seq[ZClientInterceptor]
) {
def this(channel: ManagedChannel, interceptors: Seq[ZClientInterceptor]) =
this(channel, None, interceptors)

def newCall[Req, Res](
methodDescriptor: MethodDescriptor[Req, Res],
options: CallOptions
Expand Down Expand Up @@ -44,8 +48,26 @@ object ZChannel {
interceptors: Seq[ZClientInterceptor],
timeout: Duration
): RIO[Scope, ZChannel] =
ZIO
.acquireRelease(
ZIO.attempt(new ZChannel(builder.build(), interceptors))
)(channel => channel.shutdown().ignore *> channel.awaitTermination(timeout).ignore)
scoped(builder, interceptors, timeout, None)

/** Create a scoped channel that will be shutdown when the scope closes.
*
* @param builder
* The channel builder to use to create the channel.
* @param interceptors
* The client call interceptors to use.
* @param timeout
* The maximum amount of time to wait for the channel to shutdown.
* @param prefetch
* Enables backpressure for streaming responses and sets the number of messages to prefetch.
* @return
*/
def scoped(
builder: => ManagedChannelBuilder[_],
interceptors: Seq[ZClientInterceptor],
timeout: Duration,
prefetch: Option[Int]
): RIO[Scope, ZChannel] = ZIO.acquireRelease(
ZIO.attempt(new ZChannel(builder.build(), prefetch.map(_.max(1)), interceptors))
)(channel => channel.shutdown().ignore *> channel.awaitTermination(timeout).ignore)
}
16 changes: 10 additions & 6 deletions core/src/main/scalajvm/scalapb/zio_grpc/ZManagedChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@ import zio.ZIO

object ZManagedChannel {
def apply(
builder: => ManagedChannelBuilder[_],
builder: => ManagedChannelBuilder[?],
prefetch: Option[Int],
interceptors: Seq[ZClientInterceptor]
): ZManagedChannel =
ZIO.acquireRelease(ZIO.attempt(new ZChannel(builder.build(), interceptors)))(
_.shutdown().ignore
)
): ZManagedChannel = ZIO.acquireRelease(
ZIO.attempt(new ZChannel(builder.build(), prefetch.map(_.max(1)), interceptors))
)(_.shutdown().ignore)

def apply(builder: ManagedChannelBuilder[_]): ZManagedChannel = apply(builder, Nil)
def apply(builder: => ManagedChannelBuilder[?], interceptors: Seq[ZClientInterceptor]): ZManagedChannel =
apply(builder, None, interceptors)

def apply(builder: ManagedChannelBuilder[?]): ZManagedChannel =
apply(builder, None, Nil)
}
32 changes: 13 additions & 19 deletions core/src/main/scalajvm/scalapb/zio_grpc/client/ClientCalls.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,17 @@ object ClientCalls {
private def serverStreamingCall[Req, Res](
call: ZClientCall[Req, Res],
headers: SafeMetadata,
prefetch: Option[Int],
req: Req
): ZStream[Any, StatusException, ResponseFrame[Res]] =
ZStream
.acquireReleaseExitWith(
StreamingClientCallListener.make[Res](call)
StreamingClientCallListener.make[Res](call, prefetch)
)(anyExitHandler[Req, Res](call))
.flatMap { (listener: StreamingClientCallListener[Res]) =>
ZStream.unwrap(
(call.start(listener, headers) *>
call.request(1) *>
call.request(prefetch.getOrElse(1)) *>
call.sendMessage(req) *>
call.halfClose()).as(listener.stream)
)
Expand All @@ -57,11 +58,7 @@ object ClientCalls {
headers: SafeMetadata,
req: Req
): ZStream[Any, StatusException, ResponseFrame[Res]] =
ZStream.unwrap(
channel
.newCall(method, options)
.map(serverStreamingCall(_, headers, req))
)
ZStream.unwrap(channel.newCall(method, options).map(serverStreamingCall(_, headers, channel.prefetch, req)))

private def clientStreamingCall[Req, Res](
call: ZClientCall[Req, Res],
Expand Down Expand Up @@ -92,14 +89,15 @@ object ClientCalls {
private def bidiCall[Req, Res](
call: ZClientCall[Req, Res],
headers: SafeMetadata,
prefetch: Option[Int],
req: ZStream[Any, StatusException, Req]
): ZStream[Any, StatusException, ResponseFrame[Res]] =
ZStream
.acquireReleaseExitWith(
StreamingClientCallListener.make[Res](call)
StreamingClientCallListener.make[Res](call, prefetch)
)(anyExitHandler(call))
.flatMap { (listener: StreamingClientCallListener[Res]) =>
val init = call.start(listener, headers) *> call.request(1)
val init = call.start(listener, headers) *> call.request(prefetch.getOrElse(1))
val process = req.runForeach(call.sendMessage)
val finish = call.halfClose()
val sendRequestStream = ZStream.execute(init *> process *> finish)
Expand All @@ -113,25 +111,21 @@ object ClientCalls {
headers: SafeMetadata,
req: ZStream[Any, StatusException, Req]
): ZStream[Any, StatusException, ResponseFrame[Res]] =
ZStream.unwrap(
channel
.newCall(method, options)
.map(bidiCall(_, headers, req))
)
ZStream.unwrap(channel.newCall(method, options).map(bidiCall(_, headers, channel.prefetch, req)))
}

def exitHandler[Req, Res](
call: ZClientCall[Req, Res]
)(l: Any, ex: Exit[StatusException, Any]) = anyExitHandler(call)(l, ex)
)(l: Any, ex: Exit[StatusException, Any]) =
ZIO.when(!ex.isSuccess) {
anyExitHandler(call)(l, ex)
}

// less type safe
def anyExitHandler[Req, Res](
call: ZClientCall[Req, Res]
) =
(_: Any, ex: Exit[Any, Any]) =>
ZIO.when(!ex.isSuccess) {
call.cancel("Interrupted").ignore
}
(_: Any, ex: Exit[Any, Any]) => call.cancel("Interrupted").ignore

def unaryCall[Req, Res](
channel: ZChannel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,48 +6,54 @@ import zio.stream.ZStream
import zio._

class StreamingClientCallListener[Res](
prefetch: Option[Int],
runtime: Runtime[Any],
call: ZClientCall[_, Res],
queue: Queue[ResponseFrame[Res]]
call: ZClientCall[?, Res],
queue: Queue[ResponseFrame[Res]],
buffered: Ref[Int]
) extends ClientCall.Listener[Res] {
private val increment = if (prefetch.isDefined) buffered.update(_ + 1) else ZIO.unit
private val fetchOne = if (prefetch.isDefined) ZIO.unit else call.request(1)
private val fetchMore = prefetch match {
case None => ZIO.unit
case Some(n) => buffered.get.flatMap(b => call.request(n - b).when(n > b))
}

private def unsafeRun(task: IO[Any, Unit]): Unit =
Unsafe.unsafe(implicit u => runtime.unsafe.run(task).getOrThrowFiberFailure())

private def handle(promise: Promise[StatusException, Unit])(
chunk: Chunk[ResponseFrame[Res]]
) = (chunk.lastOption match {
case Some(ResponseFrame.Trailers(status, trailers)) =>
val exit = if (status.isOk) Exit.unit else Exit.fail(new StatusException(status, trailers))
promise.done(exit) *> queue.shutdown
case _ =>
buffered.update(_ - chunk.size) *> fetchMore
}).as(chunk)

override def onHeaders(headers: Metadata): Unit =
Unsafe.unsafe { implicit u =>
runtime.unsafe
.run(
queue
.offer(ResponseFrame.Headers(headers))
.unit
)
.getOrThrowFiberFailure()
}
unsafeRun(queue.offer(ResponseFrame.Headers(headers)) *> increment)

override def onMessage(message: Res): Unit =
Unsafe.unsafe { implicit u =>
runtime.unsafe.run(queue.offer(ResponseFrame.Message(message)) *> call.request(1)).getOrThrowFiberFailure()
}
unsafeRun(queue.offer(ResponseFrame.Message(message)) *> increment *> fetchOne)

override def onClose(status: Status, trailers: Metadata): Unit =
Unsafe.unsafe { implicit u =>
runtime.unsafe.run(queue.offer(ResponseFrame.Trailers(status, trailers)).unit).getOrThrowFiberFailure()
}
unsafeRun(queue.offer(ResponseFrame.Trailers(status, trailers)).unit)

def stream: ZStream[Any, StatusException, ResponseFrame[Res]] =
ZStream
.fromQueue(queue)
.tap {
case ResponseFrame.Trailers(status, trailers) =>
queue.shutdown *> ZIO.when(!status.isOk)(ZIO.fail(new StatusException(status, trailers)))
case _ => ZIO.unit
}
ZStream.fromZIO(Promise.make[StatusException, Unit]).flatMap { promise =>
ZStream
.fromQueue(queue, prefetch.getOrElse(ZStream.DefaultChunkSize))
.mapChunksZIO(handle(promise))
.concat(ZStream.execute(promise.await))
}
}

object StreamingClientCallListener {
def make[Res](
call: ZClientCall[_, Res]
): UIO[StreamingClientCallListener[Res]] =
for {
runtime <- zio.ZIO.runtime[Any]
queue <- Queue.unbounded[ResponseFrame[Res]]
} yield new StreamingClientCallListener(runtime, call, queue)
def make[Res](call: ZClientCall[?, Res], prefetch: Option[Int]): UIO[StreamingClientCallListener[Res]] = for {
runtime <- ZIO.runtime[Any]
queue <- Queue.unbounded[ResponseFrame[Res]]
buffered <- Ref.make(0)
} yield new StreamingClientCallListener(prefetch, runtime, call, queue, buffered)
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package scalapb.zio_grpc

import scalapb.zio_grpc.testservice.ZioTestservice.TestServiceClient
import scalapb.zio_grpc.ZManagedChannel
import scalapb.grpc.Channels
import zio.test._
import org.scalajs.dom
Expand Down
35 changes: 20 additions & 15 deletions e2e/src/test/scalajvm/scalapb/zio_grpc/TestServiceSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,13 @@ object TestServiceSpec extends ZIOSpecDefault with CommonTestServiceSpec {
).asJava
).asJava

val clientLayer: ZLayer[Server, Nothing, TestServiceClient] =
ZLayer.scoped[Server] {
for {
ss <- ZIO.service[Server]
port <- ss.port.orDie
ch = ManagedChannelBuilder.forAddress("localhost", port).defaultServiceConfig(serviceConfig).usePlaintext()
client <- TestServiceClient.scoped(ZManagedChannel(ch)).orDie
} yield client
}
def clientLayer(prefetch: Option[Int]): ZLayer[Server, Nothing, TestServiceClient] =
ZLayer.scoped[Server](for {
ss <- ZIO.service[Server]
port <- ss.port.orDie
ch = ManagedChannelBuilder.forAddress("localhost", port).usePlaintext()
client <- TestServiceClient.scoped(ZManagedChannel(ch, prefetch, Nil)).orDie
} yield client)

def unarySuiteJVM =
suite("unary request")(
Expand Down Expand Up @@ -260,17 +258,24 @@ object TestServiceSpec extends ZIOSpecDefault with CommonTestServiceSpec {
}
)

val layers = TestServiceImpl.live >>>
(TestServiceImpl.any ++ serverLayer) >>>
(TestServiceImpl.any ++ clientLayer ++ Annotations.live)
type Deps = Server with TestServiceImpl with Annotations

def spec =
suite("TestServiceSpec")(
def spec = suite("TestServiceSpec")(
suite("without prefetch")(
unarySuite,
unarySuiteJVM,
serverStreamingSuite,
serverStreamingSuiteJVM,
clientStreamingSuite,
bidiStreamingSuite
).provideSomeLayer[Deps](clientLayer(None)),
suite("with prefetch = 2")(
unarySuite,
unarySuiteJVM,
serverStreamingSuite,
serverStreamingSuiteJVM,
clientStreamingSuite,
bidiStreamingSuite
).provideLayer(layers.orDie)
).provideSomeLayer[Deps](clientLayer(Some(2)))
).provide(serverLayer, TestServiceImpl.live, Annotations.live)
}
2 changes: 1 addition & 1 deletion project/plugins.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ addSbtPlugin("ch.epfl.scala" % "sbt-scalajs-bundler" % "0.21.1")

addSbtPlugin("com.thesamet" % "sbt-protoc-gen-project" % "0.1.8")

addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.5.2")
addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.5.3")

addSbtPlugin("com.eed3si9n" % "sbt-projectmatrix" % "0.10.0")

0 comments on commit 0b807b6

Please sign in to comment.