Skip to content

Commit

Permalink
unwrap exceptions from CompletableFuture
Browse files Browse the repository at this point in the history
  • Loading branch information
rssh committed Nov 18, 2024
1 parent 876990d commit 3fb01cb
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 14 deletions.
54 changes: 40 additions & 14 deletions jvm/src/main/scala/cps/monads/CompletableFutureCpsMonad.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package cps.monads

import cps._
import java.util.concurrent.CompletableFuture
import scala.util.Try
import scala.util.Failure
import scala.util.Success
import cps.*

import java.util.concurrent.{CompletableFuture, CompletionException}
import scala.concurrent.Future
import scala.util.{Failure, NotGiven, Success, Try}
import scala.util.control.NonFatal


Expand All @@ -30,7 +30,7 @@ given CompletableFutureCpsMonad: CpsSchedulingMonad[CompletableFuture] with CpsT
if (e == null) then
f(Success(v.nn))
else
f(Failure(e.nn))
f(Failure(unwrapCompletableException(e.nn)))
}.nn.toCompletableFuture.nn

override def flatMapTry[A,B](fa:CompletableFuture[A])(f: Try[A]=>CompletableFuture[B]):CompletableFuture[B] =
Expand All @@ -42,18 +42,18 @@ given CompletableFutureCpsMonad: CpsSchedulingMonad[CompletableFuture] with CpsT
if (e1 == null) then
retval.complete(v1.nn)
else
retval.completeExceptionally(e1.nn)
retval.completeExceptionally(unwrapCompletableException(e1))
}
catch
case NonFatal(ex) =>
retval.completeExceptionally(ex)
retval.completeExceptionally(unwrapCompletableException(ex))
else
try
f(Failure(e.nn)).handle{ (v1,e1) =>
f(Failure(unwrapCompletableException(e.nn))).handle{ (v1,e1) =>
if (e1 == null) then
retval.complete(v1.nn)
else
retval.completeExceptionally(e1.nn)
retval.completeExceptionally(unwrapCompletableException(e1.nn))
}
catch
case NonFatal(ex) =>
Expand All @@ -69,11 +69,11 @@ given CompletableFutureCpsMonad: CpsSchedulingMonad[CompletableFuture] with CpsT
retval.complete(v.nn)
else
try
fx(e).handle{ (v1,e1) =>
fx(unwrapCompletableException(e)).handle{ (v1,e1) =>
if (e1 == null) then
retval.complete(v1.nn)
else
retval.completeExceptionally(e1.nn)
retval.completeExceptionally(unwrapCompletableException(e1.nn))
}
catch
case NonFatal(ex) =>
Expand All @@ -98,20 +98,46 @@ given CompletableFutureCpsMonad: CpsSchedulingMonad[CompletableFuture] with CpsT
if (e == null)
r.complete(v)
else
r.completeExceptionally(e)
r.completeExceptionally(unwrapCompletableException(e))
}
catch
case NonFatal(e) =>
r.completeExceptionally(e)
}
r

def tryCancel[A](op: CompletableFuture[A]): CompletableFuture[Unit] =
def tryCancel[A](op: CompletableFuture[A]): CompletableFuture[Unit] =
if (op.cancel(true)) then
CompletableFuture.completedFuture(()).nn
else
CompletableFuture.failedFuture(new IllegalStateException("CompletableFuture is not cancelled")).nn


private def unwrapCompletableException(ex: Throwable): Throwable =
if (ex.isInstanceOf[CompletionException] && ex.getCause() != null) then
ex.getCause().nn
else
ex.nn


given fromCompletableFutureConversion[G[_], T](using CpsAsyncMonad[G], CpsMonadContext[G]): CpsMonadConversion[CompletableFuture, G] with

def apply[T](ft: CompletableFuture[T]): G[T] =
summon[CpsAsyncMonad[G]].adoptCallbackStyle(listener =>
val _unused = ft.whenComplete(
(v, e) =>
if (e == null) then
listener(Success(v))
else
if (e.isInstanceOf[CompletionException] && e.getCause() != null) then
listener(Failure(e.getCause()))
else
listener(Failure(e))
)
)


}



Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package cpstest

import org.junit.{Ignore, Test}
import org.junit.Assert.*

import java.util.concurrent.CompletableFuture
import scala.concurrent.*
import scala.concurrent.duration.*
import scala.util.*
import scala.util.control.*
import scala.concurrent.ExecutionContext.Implicits.global

import cps.*
import cps.monads.{*, given}

class TestAsyncExceptionInCompletableFuture {

object X {

def completableFutureMethod(): CompletableFuture[Int] = {
val cf = new CompletableFuture[Int]()
cf.completeExceptionally(new IllegalStateException("test exception"))
cf
}

}

@Test
def testExceptinInCompletableFuture(): Unit = {
val f = async[Future] {
try {
val x = X.completableFutureMethod().await
x
} catch {
case ex: IllegalStateException =>
-1
case NonFatal(ex) =>
println(s"unexpected exception: $ex")
throw ex
}
}
val x = Await.ready(f, 1.second)
assert(f.value.get.get == -1)
}

}

0 comments on commit 3fb01cb

Please sign in to comment.