diff --git a/README.md b/README.md index 9a490ae..c521fd9 100644 --- a/README.md +++ b/README.md @@ -168,34 +168,33 @@ Use `.checkErrorThrowIO()` on the result of a libuv function to throw an `IOExce uv_listen(serverTcpHandle, 128, onNewConnection).checkErrorThrowIO() ``` -When using a callback-based library like libuv, it is common that when everything works, cleanup like freeing memory must be done in a different callback function. However if something fails, we need to immediately cleanup anything we've allocated already. We can use `.checkErrorThrowIO()` with `try`/`catch` to do this, but it's a little verbose as everything that might need to be cleaned up needs to be declared as a `var` outside the `try` block: +When using a callback-based library like libuv, it is common that when everything works, cleanup like freeing memory must be done in a different callback function. However if something fails, we need to immediately cleanup anything we've allocated already. We can use `.checkErrorThrowIO()` with `try`/`catch` to do this, but we need to mainain some `var`s to keep track of how far we got: ```scala +def onClose: CloseCallback = (_: Handle).free() + def onNewConnection: ConnectionCallback = { (handle: StreamHandle, status: ErrorCode) => val loop = uv_handle_get_loop(handle) var clientTcpHandle: TcpHandle = null + var initialized = false try { status.checkErrorThrowIO() clientTcpHandle = TcpHandle.malloc() - println("New connection") uv_tcp_init(loop, clientTcpHandle).checkErrorThrowIO() + initialized = true uv_handle_set_data(clientTcpHandle, handle.toPtr) uv_accept(handle, clientTcpHandle).checkErrorThrowIO() - try { - uv_read_start(clientTcpHandle, allocBuffer, onRead) - .checkErrorThrowIO() - } catch { - case e: IOException => - uv_close(clientTcpHandle, onClose) - throw e - } + uv_read_start(clientTcpHandle, allocBuffer, onRead) + .checkErrorThrowIO() () } catch { case e: IOException => - if (clientTcpHandle != null) { + if (initialized) + // note the onClose callback will free the handle + uv_close(clientTcpHandle, onClose) + else if (clientTcpHandle != null) clientTcpHandle.free() - } setFailed(exception.getMessage()) } } @@ -205,26 +204,26 @@ As an alternative, scala-uv provides `UvUtils.attemptCatch` to make scenarios su ```scala +def onClose: CloseCallback = (_: Handle).free() + def onNewConnection: ConnectionCallback = { (handle: StreamHandle, status: ErrorCode) => val loop = uv_handle_get_loop(handle) UvUtils.attemptCatch { status.checkErrorThrowIO() val clientTcpHandle = TcpHandle.malloc() - UvUtils.onFail(clientTcpHandle.free()) - println("New connection") - uv_tcp_init(loop, clientTcpHandle).checkErrorThrowIO() + uv_tcp_init(loop, clientTcpHandle) + .onFail(clientTcpHandle.free()) + .checkErrorThrowIO() + UvUtils.onFail(uv_close(clientTcpHandle, onClose)) uv_handle_set_data(clientTcpHandle, handle.toPtr) uv_accept(handle, clientTcpHandle).checkErrorThrowIO() - UvUtils.onFail(uv_close(clientTcpHandle, onClose)) uv_read_start(clientTcpHandle, allocBuffer, onRead) .checkErrorThrowIO() () } { exception => setFailed(exception.getMessage()) } - // if `uv_read_start` failed, then `uv_close` followed by `clientTcpHandle.free()` - // have been run in that order by this point } ``` diff --git a/src/main/scala/scalauv/UvUtils.scala b/src/main/scala/scalauv/UvUtils.scala index f603a94..30d608a 100644 --- a/src/main/scala/scalauv/UvUtils.scala +++ b/src/main/scala/scalauv/UvUtils.scala @@ -121,6 +121,11 @@ object UvUtils { ): Unit = onCompleteActions.popAll().foreach(_(maybeFailure)) + private[UvUtils] inline def dropLast(): Unit = { + onCompleteActions.pop() + () + } + } /** Attempts to run a block of code, allowing cleanup operations to be @@ -202,6 +207,8 @@ object UvUtils { inline def onComplete(f: => Unit)(using cleanup: Cleanup): Unit = cleanup.addOnCompleteAction(_ => f) + inline def dropLast()(using cleanup: Cleanup): Unit = + cleanup.dropLast() } extension (uvResult: CInt) { diff --git a/src/test/scala/scalauv/TcpSpec.scala b/src/test/scala/scalauv/TcpSpec.scala index dddda93..4be97ab 100644 --- a/src/test/scala/scalauv/TcpSpec.scala +++ b/src/test/scala/scalauv/TcpSpec.scala @@ -32,12 +32,13 @@ final class TcpSpec { UvUtils.attemptCatch { status.checkErrorThrowIO() val clientTcpHandle = TcpHandle.malloc() - UvUtils.onFail(clientTcpHandle.free()) println("New connection") - uv_tcp_init(loop, clientTcpHandle).checkErrorThrowIO() + uv_tcp_init(loop, clientTcpHandle) + .onFail(clientTcpHandle.free()) + .checkErrorThrowIO() + UvUtils.onFail(uv_close(clientTcpHandle, onClose)) uv_handle_set_data(clientTcpHandle, handle.toPtr) uv_accept(handle, clientTcpHandle).checkErrorThrowIO() - UvUtils.onFail(uv_close(clientTcpHandle, onClose)) uv_read_start(clientTcpHandle, allocBuffer, onRead) .checkErrorThrowIO() () @@ -125,7 +126,7 @@ object TcpSpec { failed = Some(msg) } - def onClose: CloseCallback = (h: Handle) => stdlib.free(h.toPtr) + def onClose: CloseCallback = (_: Handle).free() def onRead: StreamReadCallback = { (handle: StreamHandle, numRead: CSSize, buf: Buffer) =>