Skip to content

Commit

Permalink
Fix websocket frame fragmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
Konstantin Kolmogortsev committed Oct 27, 2024
1 parent 4d10390 commit 238b2da
Showing 1 changed file with 39 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import scala.util.chaining.given

import cats.{MonadThrow, Parallel}
import cats.data.NonEmptyList
import cats.effect.{Sync, Temporal}
import cats.effect.Async
import cats.syntax.all.given
import fs2.*

Expand All @@ -22,7 +22,7 @@ import muffin.http.*
import muffin.internal.syntax.*
import muffin.model.websocket.domain.*

class SttpClient[F[_]: Temporal: Parallel, To[_], From[_]](
class SttpClient[F[_]: Async: Parallel, To[_], From[_]](
backend: SttpBackend[F, Fs2Streams[F] & WebSockets],
codec: CodecSupport[To, From]
) extends HttpClient[F, To, From] {
Expand Down Expand Up @@ -118,33 +118,42 @@ class SttpClient[F[_]: Temporal: Parallel, To[_], From[_]](
listeners: List[EventListener[F]] = Nil
): F[Unit] = {
val websocketEventProcessing: Pipe[F, WebSocketFrame.Data[?], WebSocketFrame] = { input =>
input.flatMap {
case WebSocketFrame.Text(payload, _, _) =>
Stream.eval(
Decode[Event[RawJson]].apply(payload).liftTo[F] >>= {
event =>
listeners
.parTraverse(
_.onEvent(event)
.attempt
.map(
_.leftMap(err =>
MuffinError.Websockets.ListenerError(err.getMessage, event.eventType, err)
input
.evalMapAccumulate(StringBuilder()) {
case (fragmentedPayload, frame: WebSocketFrame.Text) if !frame.finalFragment =>
Async[F].delay(fragmentedPayload.append(frame.payload)).map(_ -> frame)

case (fragmentedPayload, frame: WebSocketFrame.Text) =>
Decode[Event[RawJson]]
.apply {
val res = fragmentedPayload.append(frame.payload).result()
fragmentedPayload.clear()
res
}
.liftTo[F]
.flatMap {
event =>
listeners
.parTraverse(
_.onEvent(event)
.attempt
.map(
_.leftMap(err =>
MuffinError.Websockets.ListenerError(err.getMessage, event.eventType, err)
)
)
) >>= {
_.collect { case Left(err) => err }
.pipe(NonEmptyList.fromList)
.traverse_(
MuffinError.Websockets.FailedWebsocketProcessing(_).raiseError[F, Unit]
)
) >>= {
_.collect { case Left(err) => err }
.pipe(NonEmptyList.fromList)
.traverse_(
MuffinError.Websockets.FailedWebsocketProcessing(_).raiseError[F, Unit]
)
}
}
) *>
Stream.empty

case _ => Stream.empty
}
}
}
.as((fragmentedPayload, frame))

case otherwise => otherwise.pure[F]
} *> Stream.empty
}

val request = basicRequest
Expand Down Expand Up @@ -173,7 +182,7 @@ class SttpClient[F[_]: Temporal: Parallel, To[_], From[_]](
case _: SttpClientException.ConnectException |
_: SttpClientException.TimeoutException |
_: SttpClientException.ReadException =>
Temporal[F].sleep(
Async[F].sleep(
backoffSettings.initialDelay min backoffSettings.maxDelayThreshold
) *> retryWithBackoff(
f,
Expand All @@ -189,9 +198,9 @@ class SttpClient[F[_]: Temporal: Parallel, To[_], From[_]](

object SttpClient {

def apply[I[_]: Sync, F[_]: Temporal: Parallel, To[_], From[_]](
def apply[I[_]: Async, F[_]: Async: Parallel, To[_], From[_]](
backend: SttpBackend[F, Fs2Streams[F] & WebSockets],
codec: CodecSupport[To, From]
): I[SttpClient[F, To, From]] = Sync[I].delay(new SttpClient[F, To, From](backend, codec))
): I[SttpClient[F, To, From]] = Async[I].delay(new SttpClient[F, To, From](backend, codec))

}

0 comments on commit 238b2da

Please sign in to comment.