diff --git a/modules/shared/channels/src/main/scala/almond/channels/zeromq/ZeromqConnection.scala b/modules/shared/channels/src/main/scala/almond/channels/zeromq/ZeromqConnection.scala index 29346efbf..6f63c1dec 100644 --- a/modules/shared/channels/src/main/scala/almond/channels/zeromq/ZeromqConnection.scala +++ b/modules/shared/channels/src/main/scala/almond/channels/zeromq/ZeromqConnection.scala @@ -1,6 +1,6 @@ package almond.channels.zeromq -import java.nio.channels.{ClosedByInterruptException, Selector} +import java.nio.channels.{ClosedByInterruptException, ClosedSelectorException, Selector} import java.nio.charset.StandardCharsets.UTF_8 import almond.channels._ @@ -137,7 +137,8 @@ final class ZeromqConnection( case Channel.Input => stdin0 } - @volatile private var selectorOpt = Option.empty[Selector] + @volatile private var selectorClosing = false + @volatile private var selectorOpt = Option.empty[Selector] private def withSelector[T](f: Selector => T): T = selectorOpt match { @@ -190,18 +191,29 @@ final class ZeromqConnection( (channel, new PollItem(socket.channel, Poller.POLLIN)) } - withSelector { selector => - ZMQ.poll(selector, pollItems.map(_._2).toArray, pollingDelay.toMillis) + val doRead = withSelector { selector => + try { + ZMQ.poll(selector, pollItems.map(_._2).toArray, pollingDelay.toMillis) + true + } + catch { + case _: ClosedSelectorException if selectorClosing || selectorOpt.isEmpty => + // channel was closed + false + } } - pollItems - .collectFirst { - case (channel, pi) if pi.isReadable => - channelSocket0(channel) - .read - .map(_.map((channel, _))) - } - .getOrElse(IO.pure(None)) + if (doRead) + pollItems + .collectFirst { + case (channel, pi) if pi.isReadable => + channelSocket0(channel) + .read + .map(_.map((channel, _))) + } + .getOrElse(IO.pure(None)) + else + IO.pure(None) }.evalOn(threads.pollingEc).flatMap(identity) def close(partial: Boolean, lingerDuration: Duration): IO[Unit] = { @@ -222,8 +234,13 @@ final class ZeromqConnection( if (!partial) heartBeatThreadOpt.foreach(_.interrupt()) - selectorOpt.foreach(_.close()) - selectorOpt = None + try { + selectorClosing = true + selectorOpt.foreach(_.close()) + selectorOpt = None + } + finally + selectorClosing = false log.debug(s"Closed channels for $params" + (if (partial) " (partial)" else "")) }.evalOn(threads.selectorOpenCloseEc)