Skip to content

Commit

Permalink
Implement streaming client backpressure with configurable batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
joroKr21 committed Jun 7, 2024
1 parent dc2ce2a commit 6fc4b28
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 71 deletions.
4 changes: 4 additions & 0 deletions core/src/main/scalajs/scalapb/zio_grpc/ZChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ import zio.Task

class ZChannel(
private[zio_grpc] val channel: ManagedChannel,
private[zio_grpc] val prefetch: Option[Int],
interceptors: Seq[ZClientInterceptor]
) {
def this(channel: ManagedChannel, interceptors: Seq[ZClientInterceptor]) =
this(channel, None, interceptors)

def shutdown(): Task[Unit] = ZIO.unit
}
30 changes: 26 additions & 4 deletions core/src/main/scalajvm/scalapb/zio_grpc/ZChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@ import zio._

class ZChannel(
private[zio_grpc] val channel: ManagedChannel,
private[zio_grpc] val prefetch: Option[Int],
interceptors: Seq[ZClientInterceptor]
) {
def this(channel: ManagedChannel, interceptors: Seq[ZClientInterceptor]) =
this(channel, None, interceptors)

def newCall[Req, Res](
methodDescriptor: MethodDescriptor[Req, Res],
options: CallOptions
Expand Down Expand Up @@ -44,8 +48,26 @@ object ZChannel {
interceptors: Seq[ZClientInterceptor],
timeout: Duration
): RIO[Scope, ZChannel] =
ZIO
.acquireRelease(
ZIO.attempt(new ZChannel(builder.build(), interceptors))
)(channel => channel.shutdown().ignore *> channel.awaitTermination(timeout).ignore)
scoped(builder, interceptors, timeout, None)

/** Create a scoped channel that will be shutdown when the scope closes.
*
* @param builder
* The channel builder to use to create the channel.
* @param interceptors
* The client call interceptors to use.
* @param timeout
* The maximum amount of time to wait for the channel to shutdown.
* @param prefetch
* Enables backpressure for streaming responses and sets the number of messages to prefetch.
* @return
*/
def scoped(
builder: => ManagedChannelBuilder[_],
interceptors: Seq[ZClientInterceptor],
timeout: Duration,
prefetch: Option[Int]
): RIO[Scope, ZChannel] = ZIO.acquireRelease(
ZIO.attempt(new ZChannel(builder.build(), prefetch.map(_.max(1)), interceptors))
)(channel => channel.shutdown().ignore *> channel.awaitTermination(timeout).ignore)
}
16 changes: 10 additions & 6 deletions core/src/main/scalajvm/scalapb/zio_grpc/ZManagedChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@ import zio.ZIO

object ZManagedChannel {
def apply(
builder: => ManagedChannelBuilder[_],
builder: => ManagedChannelBuilder[?],
prefetch: Option[Int],
interceptors: Seq[ZClientInterceptor]
): ZManagedChannel =
ZIO.acquireRelease(ZIO.attempt(new ZChannel(builder.build(), interceptors)))(
_.shutdown().ignore
)
): ZManagedChannel = ZIO.acquireRelease(
ZIO.attempt(new ZChannel(builder.build(), prefetch.map(_.max(1)), interceptors))
)(_.shutdown().ignore)

def apply(builder: ManagedChannelBuilder[_]): ZManagedChannel = apply(builder, Nil)
def apply(builder: => ManagedChannelBuilder[?], interceptors: Seq[ZClientInterceptor]): ZManagedChannel =
apply(builder, None, interceptors)

def apply(builder: ManagedChannelBuilder[?]): ZManagedChannel =
apply(builder, None, Nil)
}
22 changes: 8 additions & 14 deletions core/src/main/scalajvm/scalapb/zio_grpc/client/ClientCalls.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,17 @@ object ClientCalls {
private def serverStreamingCall[Req, Res](
call: ZClientCall[Req, Res],
headers: SafeMetadata,
prefetch: Option[Int],
req: Req
): ZStream[Any, StatusException, ResponseFrame[Res]] =
ZStream
.acquireReleaseExitWith(
StreamingClientCallListener.make[Res](call)
StreamingClientCallListener.make[Res](call, prefetch)
)(anyExitHandler[Req, Res](call))
.flatMap { (listener: StreamingClientCallListener[Res]) =>
ZStream.unwrap(
(call.start(listener, headers) *>
call.request(1) *>
call.request(prefetch.getOrElse(1)) *>
call.sendMessage(req) *>
call.halfClose()).as(listener.stream)
)
Expand All @@ -57,11 +58,7 @@ object ClientCalls {
headers: SafeMetadata,
req: Req
): ZStream[Any, StatusException, ResponseFrame[Res]] =
ZStream.unwrap(
channel
.newCall(method, options)
.map(serverStreamingCall(_, headers, req))
)
ZStream.unwrap(channel.newCall(method, options).map(serverStreamingCall(_, headers, channel.prefetch, req)))

private def clientStreamingCall[Req, Res](
call: ZClientCall[Req, Res],
Expand Down Expand Up @@ -92,14 +89,15 @@ object ClientCalls {
private def bidiCall[Req, Res](
call: ZClientCall[Req, Res],
headers: SafeMetadata,
prefetch: Option[Int],
req: ZStream[Any, StatusException, Req]
): ZStream[Any, StatusException, ResponseFrame[Res]] =
ZStream
.acquireReleaseExitWith(
StreamingClientCallListener.make[Res](call)
StreamingClientCallListener.make[Res](call, prefetch)
)(anyExitHandler(call))
.flatMap { (listener: StreamingClientCallListener[Res]) =>
val init = call.start(listener, headers) *> call.request(1)
val init = call.start(listener, headers) *> call.request(prefetch.getOrElse(1))
val process = req.runForeach(call.sendMessage)
val finish = call.halfClose()
val sendRequestStream = ZStream.execute(init *> process *> finish)
Expand All @@ -113,11 +111,7 @@ object ClientCalls {
headers: SafeMetadata,
req: ZStream[Any, StatusException, Req]
): ZStream[Any, StatusException, ResponseFrame[Res]] =
ZStream.unwrap(
channel
.newCall(method, options)
.map(bidiCall(_, headers, req))
)
ZStream.unwrap(channel.newCall(method, options).map(bidiCall(_, headers, channel.prefetch, req)))
}

def exitHandler[Req, Res](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,48 +6,54 @@ import zio.stream.ZStream
import zio._

class StreamingClientCallListener[Res](
prefetch: Option[Int],
runtime: Runtime[Any],
call: ZClientCall[_, Res],
queue: Queue[ResponseFrame[Res]]
call: ZClientCall[?, Res],
queue: Queue[ResponseFrame[Res]],
buffered: Ref[Int]
) extends ClientCall.Listener[Res] {
private val increment = if (prefetch.isDefined) buffered.update(_ + 1) else ZIO.unit
private val fetchOne = if (prefetch.isDefined) ZIO.unit else call.request(1)
private val fetchMore = prefetch match {
case None => ZIO.unit
case Some(n) => buffered.get.flatMap(b => call.request(n - b).when(n > b))
}

private def unsafeRun(task: IO[Any, Unit]): Unit =
Unsafe.unsafe(implicit u => runtime.unsafe.run(task).getOrThrowFiberFailure())

private def handle(promise: Promise[StatusException, Unit])(
chunk: Chunk[ResponseFrame[Res]]
) = (chunk.lastOption match {
case Some(ResponseFrame.Trailers(status, trailers)) =>
val exit = if (status.isOk) Exit.unit else Exit.fail(new StatusException(status, trailers))
promise.done(exit) *> queue.shutdown
case _ =>
buffered.update(_ - chunk.size) *> fetchMore
}).as(chunk)

override def onHeaders(headers: Metadata): Unit =
Unsafe.unsafe { implicit u =>
runtime.unsafe
.run(
queue
.offer(ResponseFrame.Headers(headers))
.unit
)
.getOrThrowFiberFailure()
}
unsafeRun(queue.offer(ResponseFrame.Headers(headers)) *> increment)

override def onMessage(message: Res): Unit =
Unsafe.unsafe { implicit u =>
runtime.unsafe.run(queue.offer(ResponseFrame.Message(message)) *> call.request(1)).getOrThrowFiberFailure()
}
unsafeRun(queue.offer(ResponseFrame.Message(message)) *> increment *> fetchOne)

override def onClose(status: Status, trailers: Metadata): Unit =
Unsafe.unsafe { implicit u =>
runtime.unsafe.run(queue.offer(ResponseFrame.Trailers(status, trailers)).unit).getOrThrowFiberFailure()
}
unsafeRun(queue.offer(ResponseFrame.Trailers(status, trailers)).unit)

def stream: ZStream[Any, StatusException, ResponseFrame[Res]] =
ZStream
.fromQueue(queue)
.tap {
case ResponseFrame.Trailers(status, trailers) =>
queue.shutdown *> ZIO.when(!status.isOk)(ZIO.fail(new StatusException(status, trailers)))
case _ => ZIO.unit
}
ZStream.fromZIO(Promise.make[StatusException, Unit]).flatMap { promise =>
ZStream
.fromQueue(queue, prefetch.getOrElse(ZStream.DefaultChunkSize))
.mapChunksZIO(handle(promise))
.concat(ZStream.execute(promise.await))
}
}

object StreamingClientCallListener {
def make[Res](
call: ZClientCall[_, Res]
): UIO[StreamingClientCallListener[Res]] =
for {
runtime <- zio.ZIO.runtime[Any]
queue <- Queue.unbounded[ResponseFrame[Res]]
} yield new StreamingClientCallListener(runtime, call, queue)
def make[Res](call: ZClientCall[?, Res], prefetch: Option[Int]): UIO[StreamingClientCallListener[Res]] = for {
runtime <- ZIO.runtime[Any]
queue <- Queue.unbounded[ResponseFrame[Res]]
buffered <- Ref.make(0)
} yield new StreamingClientCallListener(prefetch, runtime, call, queue, buffered)
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package scalapb.zio_grpc

import scalapb.zio_grpc.testservice.ZioTestservice.TestServiceClient
import scalapb.zio_grpc.ZManagedChannel
import scalapb.grpc.Channels
import zio.test._
import org.scalajs.dom
Expand Down
35 changes: 20 additions & 15 deletions e2e/src/test/scalajvm/scalapb/zio_grpc/TestServiceSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,13 @@ object TestServiceSpec extends ZIOSpecDefault with CommonTestServiceSpec {
val serverLayer: ZLayer[TestServiceImpl, Throwable, Server] =
ServerLayer.fromEnvironment[TestServiceImpl.Service](ServerBuilder.forPort(0))

val clientLayer: ZLayer[Server, Nothing, TestServiceClient] =
ZLayer.scoped[Server] {
for {
ss <- ZIO.service[Server]
port <- ss.port.orDie
ch = ManagedChannelBuilder.forAddress("localhost", port).usePlaintext()
client <- TestServiceClient.scoped(ZManagedChannel(ch)).orDie
} yield client
}
def clientLayer(prefetch: Option[Int]): ZLayer[Server, Nothing, TestServiceClient] =
ZLayer.scoped[Server](for {
ss <- ZIO.service[Server]
port <- ss.port.orDie
ch = ManagedChannelBuilder.forAddress("localhost", port).usePlaintext()
client <- TestServiceClient.scoped(ZManagedChannel(ch, prefetch, Nil)).orDie
} yield client)

def unarySuiteJVM =
suite("unary request")(
Expand Down Expand Up @@ -238,17 +236,24 @@ object TestServiceSpec extends ZIOSpecDefault with CommonTestServiceSpec {
}
)

val layers = TestServiceImpl.live >>>
(TestServiceImpl.any ++ serverLayer) >>>
(TestServiceImpl.any ++ clientLayer ++ Annotations.live)
type Deps = Server with TestServiceImpl with Annotations

def spec =
suite("TestServiceSpec")(
def spec = suite("TestServiceSpec")(
suite("without prefetch")(
unarySuite,
unarySuiteJVM,
serverStreamingSuite,
serverStreamingSuiteJVM,
clientStreamingSuite,
bidiStreamingSuite
).provideSomeLayer[Deps](clientLayer(None)),
suite("with prefetch = 2")(
unarySuite,
unarySuiteJVM,
serverStreamingSuite,
serverStreamingSuiteJVM,
clientStreamingSuite,
bidiStreamingSuite
).provideLayer(layers.orDie)
).provideSomeLayer[Deps](clientLayer(Some(2)))
).provide(serverLayer, TestServiceImpl.live, Annotations.live)
}

0 comments on commit 6fc4b28

Please sign in to comment.