From 9d8bdc0feff81dea09585f23d0f2307d4d664890 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Thu, 11 Apr 2024 21:54:20 +0900 Subject: [PATCH] Optimize client calls (#618) * Optimize client calls * Fix * Use ZStream.execute --- .../scalapb/zio_grpc/client/ClientCalls.scala | 62 ++++++++----------- 1 file changed, 25 insertions(+), 37 deletions(-) diff --git a/core/src/main/scalajvm/scalapb/zio_grpc/client/ClientCalls.scala b/core/src/main/scalajvm/scalapb/zio_grpc/client/ClientCalls.scala index d785a9d80..cbdd1fde9 100644 --- a/core/src/main/scalajvm/scalapb/zio_grpc/client/ClientCalls.scala +++ b/core/src/main/scalajvm/scalapb/zio_grpc/client/ClientCalls.scala @@ -42,14 +42,12 @@ object ClientCalls { StreamingClientCallListener.make[Res](call) )(anyExitHandler[Req, Res](call)) .flatMap { (listener: StreamingClientCallListener[Res]) => - ZStream - .fromZIO( - call.start(listener, headers) *> - call.request(1) *> - call.sendMessage(req) *> - call.halfClose() - ) - .drain ++ listener.stream + ZStream.unwrap( + (call.start(listener, headers) *> + call.request(1) *> + call.sendMessage(req) *> + call.halfClose()).as(listener.stream) + ) } def serverStreamingCall[Req, Res]( @@ -59,9 +57,11 @@ object ClientCalls { headers: SafeMetadata, req: Req ): ZStream[Any, StatusException, ResponseFrame[Res]] = - ZStream - .fromZIO(channel.newCall(method, options)) - .flatMap(serverStreamingCall(_, headers, req)) + ZStream.unwrap( + channel + .newCall(method, options) + .map(serverStreamingCall(_, headers, req)) + ) private def clientStreamingCall[Req, Res]( call: ZClientCall[Req, Res], @@ -69,15 +69,13 @@ object ClientCalls { req: ZStream[Any, StatusException, Req] ): IO[StatusException, ResponseContext[Res]] = ZIO.acquireReleaseExitWith(UnaryClientCallListener.make[Res])(exitHandler(call)) { listener => - val callStream = req.tap(call.sendMessage).drain ++ ZStream.fromZIO(call.halfClose()).drain - val resultStream = ZStream.fromZIO(listener.getValue) + val processRequestStream = req.runForeach(call.sendMessage) *> call.halfClose() + val getResult = listener.getValue call.start(listener, headers) *> call.request(1) *> - callStream - .merge(resultStream) - .runCollect - .map(res => res.last) + processRequestStream &> + getResult } def clientStreamingCall[Req, Res]( @@ -89,13 +87,7 @@ object ClientCalls { ): IO[StatusException, ResponseContext[Res]] = channel .newCall(method, options) - .flatMap( - clientStreamingCall( - _, - headers, - req - ) - ) + .flatMap(clientStreamingCall(_, headers, req)) private def bidiCall[Req, Res]( call: ZClientCall[Req, Res], @@ -107,14 +99,10 @@ object ClientCalls { StreamingClientCallListener.make[Res](call) )(anyExitHandler(call)) .flatMap { (listener: StreamingClientCallListener[Res]) => - val init = - ZStream - .fromZIO( - call.start(listener, headers) *> - call.request(1) - ) - val finish = ZStream.fromZIO(call.halfClose()) - val sendRequestStream = (init ++ req.tap(call.sendMessage) ++ finish).drain + val init = call.start(listener, headers) *> call.request(1) + val process = req.runForeach(call.sendMessage) + val finish = call.halfClose() + val sendRequestStream = ZStream.execute(init *> process *> finish) sendRequestStream.merge(listener.stream, ZStream.HaltStrategy.Right) } @@ -125,11 +113,11 @@ object ClientCalls { headers: SafeMetadata, req: ZStream[Any, StatusException, Req] ): ZStream[Any, StatusException, ResponseFrame[Res]] = - ZStream - .fromZIO( - channel.newCall(method, options) - ) - .flatMap(bidiCall(_, headers, req)) + ZStream.unwrap( + channel + .newCall(method, options) + .map(bidiCall(_, headers, req)) + ) } def exitHandler[Req, Res](