diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/EncodedNodeId.scala b/eclair-core/src/main/scala/fr/acinq/eclair/EncodedNodeId.scala new file mode 100644 index 0000000000..9cfe9311c3 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/EncodedNodeId.scala @@ -0,0 +1,19 @@ +package fr.acinq.eclair + +import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey + +sealed trait EncodedNodeId + +object EncodedNodeId { + /** Nodes are usually identified by their public key. */ + case class Plain(publicKey: PublicKey) extends EncodedNodeId { + override def toString: String = publicKey.toString + } + + /** For compactness, nodes may be identified by the shortChannelId of one of their public channels. */ + case class ShortChannelIdDir(isNode1: Boolean, scid: RealShortChannelId) extends EncodedNodeId { + override def toString: String = if (isNode1) s"<-$scid" else s"$scid->" + } + + def apply(publicKey: PublicKey): EncodedNodeId = Plain(publicKey) +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala index ab04b5f70b..0e7767014a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala @@ -366,7 +366,7 @@ class Setup(val datadir: File, txPublisherFactory = Channel.SimpleTxPublisherFactory(nodeParams, watcher, bitcoinClient) channelFactory = Peer.SimpleChannelFactory(nodeParams, watcher, relayer, bitcoinClient, txPublisherFactory) pendingChannelsRateLimiter = system.spawn(Behaviors.supervise(PendingChannelsRateLimiter(nodeParams, router.toTyped, channels)).onFailure(typed.SupervisorStrategy.resume), name = "pending-channels-rate-limiter") - peerFactory = Switchboard.SimplePeerFactory(nodeParams, bitcoinClient, channelFactory, pendingChannelsRateLimiter, register) + peerFactory = Switchboard.SimplePeerFactory(nodeParams, bitcoinClient, channelFactory, pendingChannelsRateLimiter, register, router.toTyped) switchboard = system.actorOf(SimpleSupervisor.props(Switchboard.props(nodeParams, peerFactory), "switchboard", SupervisorStrategy.Resume)) _ = switchboard ! Switchboard.Init(channels) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala index 29e3322938..52fe8f2821 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala @@ -17,25 +17,27 @@ package fr.acinq.eclair.io import akka.actor.typed.Behavior +import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import akka.actor.typed.scaladsl.{ActorContext, Behaviors} import akka.actor.{ActorRef, typed} import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey -import fr.acinq.eclair.ShortChannelId import fr.acinq.eclair.channel.Register import fr.acinq.eclair.io.Peer.{PeerInfo, PeerInfoResponse} import fr.acinq.eclair.io.Switchboard.GetPeerInfo +import fr.acinq.eclair.message.OnionMessages +import fr.acinq.eclair.message.OnionMessages.DropReason +import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.OnionMessage +import fr.acinq.eclair.{EncodedNodeId, NodeParams, ShortChannelId} object MessageRelay { // @formatter:off sealed trait Command case class RelayMessage(messageId: ByteVector32, - switchboard: ActorRef, - register: ActorRef, prevNodeId: PublicKey, - nextNode: Either[ShortChannelId, PublicKey], + nextNode: Either[ShortChannelId, EncodedNodeId], msg: OnionMessage, policy: RelayPolicy, replyTo_opt: Option[typed.ActorRef[Status]]) extends Command @@ -60,66 +62,101 @@ object MessageRelay { case class UnknownOutgoingChannel(messageId: ByteVector32, outgoingChannelId: ShortChannelId) extends Failure { override def toString: String = s"Unknown outgoing channel: $outgoingChannelId" } + case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure { + override def toString: String = s"Message dropped: $reason" + } sealed trait RelayPolicy case object RelayChannelsOnly extends RelayPolicy case object RelayAll extends RelayPolicy // @formatter:on - def apply(): Behavior[Command] = { - Behaviors.receivePartial { - case (context, RelayMessage(messageId, switchboard, register, prevNodeId, Left(outgoingChannelId), msg, policy, replyTo_opt)) => + def apply(nodeParams: NodeParams, + switchboard: ActorRef, + register: ActorRef, + router: typed.ActorRef[Router.GetNodeId]): Behavior[Command] = { + Behaviors.setup { context => + Behaviors.receiveMessagePartial { + case RelayMessage(messageId, prevNodeId, nextNode, msg, policy, replyTo_opt) => + val relay = new MessageRelay(nodeParams, messageId, prevNodeId, policy, switchboard, register, router, replyTo_opt, context) + relay.queryNextNodeId(msg, nextNode) + } + } + } +} + +private class MessageRelay(nodeParams: NodeParams, + messageId: ByteVector32, + prevNodeId: PublicKey, + policy: MessageRelay.RelayPolicy, + switchboard: ActorRef, + register: ActorRef, + router: typed.ActorRef[Router.GetNodeId], + replyTo_opt: Option[typed.ActorRef[MessageRelay.Status]], + context: ActorContext[MessageRelay.Command]) { + + import MessageRelay._ + + def queryNextNodeId(msg: OnionMessage, nextNode: Either[ShortChannelId, EncodedNodeId]): Behavior[Command] = { + nextNode match { + case Left(outgoingChannelId) => register ! Register.GetNextNodeId(context.messageAdapter(WrappedOptionalNodeId), outgoingChannelId) - waitForNextNodeId(messageId, switchboard, prevNodeId, outgoingChannelId, msg, policy, replyTo_opt) - case (context, RelayMessage(messageId, switchboard, _, prevNodeId, Right(nextNodeId), msg, policy, replyTo_opt)) => - withNextNodeId(context, messageId, switchboard, prevNodeId, nextNodeId, msg, policy, replyTo_opt) + waitForNextNodeId(msg, outgoingChannelId) + case Right(EncodedNodeId.ShortChannelIdDir(isNode1, scid)) => + router ! Router.GetNodeId(context.messageAdapter(WrappedOptionalNodeId), scid, isNode1) + waitForNextNodeId(msg, scid) + case Right(EncodedNodeId.Plain(nextNodeId)) => + withNextNodeId(msg, nextNodeId) } } - def waitForNextNodeId(messageId: ByteVector32, - switchboard: ActorRef, - prevNodeId: PublicKey, - outgoingChannelId: ShortChannelId, - msg: OnionMessage, - policy: RelayPolicy, - replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] = - Behaviors.receivePartial { - case (_, WrappedOptionalNodeId(None)) => + private def waitForNextNodeId(msg: OnionMessage, outgoingChannelId: ShortChannelId): Behavior[Command] = + Behaviors.receiveMessagePartial { + case WrappedOptionalNodeId(None) => replyTo_opt.foreach(_ ! UnknownOutgoingChannel(messageId, outgoingChannelId)) Behaviors.stopped - case (context, WrappedOptionalNodeId(Some(nextNodeId))) => - withNextNodeId(context, messageId, switchboard, prevNodeId, nextNodeId, msg, policy, replyTo_opt) + case WrappedOptionalNodeId(Some(nextNodeId)) => + withNextNodeId(msg, nextNodeId) } - def withNextNodeId(context: ActorContext[Command], - messageId: ByteVector32, - switchboard: ActorRef, - prevNodeId: PublicKey, - nextNodeId: PublicKey, - msg: OnionMessage, - policy: RelayPolicy, - replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] = - policy match { - case RelayChannelsOnly => - switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId) - waitForPreviousPeer(messageId, switchboard, nextNodeId, msg, replyTo_opt) - case RelayAll => - switchboard ! Peer.Connect(nextNodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false) - waitForConnection(messageId, msg, replyTo_opt) - } + private def withNextNodeId(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = { + if (nextNodeId == nodeParams.nodeId) { + OnionMessages.process(nodeParams.privateKey, msg) match { + case OnionMessages.DropMessage(reason) => + replyTo_opt.foreach(_ ! DroppedMessage(messageId, reason)) + Behaviors.stopped + case OnionMessages.SendMessage(nextNode, nextMessage) => + // We need to repeat the process until we identify the (real) next node, or find out that we're the recipient. + queryNextNodeId(nextMessage, nextNode) + case received: OnionMessages.ReceiveMessage => + context.system.eventStream ! EventStream.Publish(received) + replyTo_opt.foreach(_ ! Sent(messageId)) + Behaviors.stopped + } + } else { + policy match { + case RelayChannelsOnly => + switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId) + waitForPreviousPeerForPolicyCheck(msg, nextNodeId) + case RelayAll => + switchboard ! Peer.Connect(nextNodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false) + waitForConnection(msg) + } + } + } - def waitForPreviousPeer(messageId: ByteVector32, switchboard: ActorRef, nextNodeId: PublicKey, msg: OnionMessage, replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] = { - Behaviors.receivePartial { - case (context, WrappedPeerInfo(PeerInfo(_, _, _, _, channels))) if channels.nonEmpty => + private def waitForPreviousPeerForPolicyCheck(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = { + Behaviors.receiveMessagePartial { + case WrappedPeerInfo(PeerInfo(_, _, _, _, channels)) if channels.nonEmpty => switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), nextNodeId) - waitForNextPeer(messageId, msg, replyTo_opt) + waitForNextPeerForPolicyCheck(msg) case _ => replyTo_opt.foreach(_ ! AgainstPolicy(messageId, RelayChannelsOnly)) Behaviors.stopped } } - def waitForNextPeer(messageId: ByteVector32, msg: OnionMessage, replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] = { + private def waitForNextPeerForPolicyCheck(msg: OnionMessage): Behavior[Command] = { Behaviors.receiveMessagePartial { case WrappedPeerInfo(PeerInfo(peer, _, _, _, channels)) if channels.nonEmpty => peer ! Peer.RelayOnionMessage(messageId, msg, replyTo_opt) @@ -130,7 +167,7 @@ object MessageRelay { } } - def waitForConnection(messageId: ByteVector32, msg: OnionMessage, replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] = { + private def waitForConnection(msg: OnionMessage): Behavior[Command] = { Behaviors.receiveMessagePartial { case WrappedConnectionResult(r: PeerConnection.ConnectionResult.HasConnection) => r.peer ! Peer.RelayOnionMessage(messageId, msg, replyTo_opt) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala index 4a885b1ddc..5cb90f4fd0 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala @@ -38,6 +38,7 @@ import fr.acinq.eclair.io.OpenChannelInterceptor.{OpenChannelInitiator, OpenChan import fr.acinq.eclair.io.PeerConnection.KillReason import fr.acinq.eclair.message.OnionMessages import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes +import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol import fr.acinq.eclair.wire.protocol.{Error, HasChannelId, HasTemporaryChannelId, LightningMessage, NodeAddress, OnionMessage, RoutingMessage, UnknownMessage, Warning} @@ -51,7 +52,14 @@ import fr.acinq.eclair.wire.protocol.{Error, HasChannelId, HasTemporaryChannelId * * Created by PM on 26/08/2016. */ -class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainPubkeyCache, channelFactory: Peer.ChannelFactory, switchboard: ActorRef, register: ActorRef, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]) extends FSMDiagnosticActorLogging[Peer.State, Peer.Data] { +class Peer(val nodeParams: NodeParams, + remoteNodeId: PublicKey, + wallet: OnchainPubkeyCache, + channelFactory: Peer.ChannelFactory, + switchboard: ActorRef, + register: ActorRef, + router: typed.ActorRef[Router.GetNodeId], + pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]) extends FSMDiagnosticActorLogging[Peer.State, Peer.Data] { import Peer._ @@ -279,8 +287,8 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainP log.debug("dropping message from {}: {}", remoteNodeId.value.toHex, reason.toString) case OnionMessages.SendMessage(nextNode, message) if nodeParams.features.hasFeature(Features.OnionMessages) => val messageId = randomBytes32() - val relay = context.spawn(Behaviors.supervise(MessageRelay()).onFailure(typed.SupervisorStrategy.stop), s"relay-message-$messageId") - relay ! MessageRelay.RelayMessage(messageId, switchboard, register, remoteNodeId, nextNode, message, nodeParams.onionMessageConfig.relayPolicy, None) + val relay = context.spawn(Behaviors.supervise(MessageRelay(nodeParams, switchboard, register, router)).onFailure(typed.SupervisorStrategy.stop), s"relay-message-$messageId") + relay ! MessageRelay.RelayMessage(messageId, remoteNodeId, nextNode, message, nodeParams.onionMessageConfig.relayPolicy, None) case OnionMessages.SendMessage(_, _) => log.debug("dropping message from {}: relaying onion messages is disabled", remoteNodeId.value.toHex) case received: OnionMessages.ReceiveMessage => @@ -458,7 +466,8 @@ object Peer { context.actorOf(Channel.props(nodeParams, wallet, remoteNodeId, watcher, relayer, txPublisherFactory)) } - def props(nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainPubkeyCache, channelFactory: ChannelFactory, switchboard: ActorRef, register: ActorRef, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]): Props = Props(new Peer(nodeParams, remoteNodeId, wallet, channelFactory, switchboard, register, pendingChannelsRateLimiter)) + def props(nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainPubkeyCache, channelFactory: ChannelFactory, switchboard: ActorRef, register: ActorRef, router: typed.ActorRef[Router.GetNodeId], pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]): Props = + Props(new Peer(nodeParams, remoteNodeId, wallet, channelFactory, switchboard, register, router, pendingChannelsRateLimiter)) // @formatter:off diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala index ca49947034..0c5d4bb3d0 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala @@ -27,6 +27,7 @@ import fr.acinq.eclair.channel._ import fr.acinq.eclair.io.IncomingConnectionsTracker.TrackIncomingConnection import fr.acinq.eclair.io.Peer.{PeerInfoResponse, PeerNotFound} import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes +import fr.acinq.eclair.router.Router import fr.acinq.eclair.router.Router.RouterConf import fr.acinq.eclair.{NodeParams, SubscriptionsComplete} @@ -159,9 +160,9 @@ object Switchboard { def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef } - case class SimplePeerFactory(nodeParams: NodeParams, wallet: OnchainPubkeyCache, channelFactory: Peer.ChannelFactory, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command], register: ActorRef) extends PeerFactory { + case class SimplePeerFactory(nodeParams: NodeParams, wallet: OnchainPubkeyCache, channelFactory: Peer.ChannelFactory, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command], register: ActorRef, router: typed.ActorRef[Router.GetNodeId]) extends PeerFactory { override def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef = - context.actorOf(Peer.props(nodeParams, remoteNodeId, wallet, channelFactory, context.self, register, pendingChannelsRateLimiter), name = peerActorName(remoteNodeId)) + context.actorOf(Peer.props(nodeParams, remoteNodeId, wallet, channelFactory, context.self, register, router, pendingChannelsRateLimiter), name = peerActorName(remoteNodeId)) } def props(nodeParams: NodeParams, peerFactory: PeerFactory) = Props(new Switchboard(nodeParams, peerFactory)) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala b/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala index 6db3d2eba7..05f599c13b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala @@ -17,7 +17,7 @@ package fr.acinq.eclair.message import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} -import fr.acinq.eclair.ShortChannelId +import fr.acinq.eclair.{EncodedNodeId, ShortChannelId} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.io.MessageRelay.RelayPolicy import fr.acinq.eclair.wire.protocol.MessageOnion.{FinalPayload, IntermediatePayload} @@ -105,9 +105,9 @@ object OnionMessages { case Left(_) => None case Right(decoded) => decoded.tlvs.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId] match { - case None => None - case Some(RouteBlindingEncryptedDataTlv.OutgoingNodeId(nextNodeId)) => + case Some(RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.Plain(nextNodeId))) => Some(Sphinx.RouteBlinding.BlindedRoute(nextNodeId, decoded.nextBlinding, route.blindedNodes.tail)) + case _ => None // TODO: allow compact node id and OutgoingChannelId } } case BlindedPath(route) if intermediateNodes.isEmpty => Some(route) @@ -165,7 +165,7 @@ object OnionMessages { // @formatter:off sealed trait Action case class DropMessage(reason: DropReason) extends Action - case class SendMessage(nextNode: Either[ShortChannelId, PublicKey], message: OnionMessage) extends Action + case class SendMessage(nextNode: Either[ShortChannelId, EncodedNodeId], message: OnionMessage) extends Action case class ReceiveMessage(finalPayload: FinalPayload) extends Action sealed trait DropReason @@ -211,8 +211,8 @@ object OnionMessages { case Left(f) => DropMessage(f) case Right(DecodedEncryptedData(blindedPayload, nextBlinding)) => nextPacket_opt match { case Some(nextPacket) => validateRelayPayload(payload, blindedPayload, nextBlinding, nextPacket) match { - case SendMessage(Right(nextNodeId), nextMsg) if nextNodeId == privateKey.publicKey => process(privateKey, nextMsg) - case SendMessage(Left(outgoingChannelId), nextMsg) if outgoingChannelId == ShortChannelId.toSelf => process(privateKey, nextMsg) + case SendMessage(Right(EncodedNodeId.Plain(publicKey)), nextMsg) if publicKey == privateKey.publicKey => process(privateKey, nextMsg) // TODO: remove and rely on MessageRelay + case SendMessage(Left(outgoingChannelId), nextMsg) if outgoingChannelId == ShortChannelId.toSelf => process(privateKey, nextMsg) // TODO: remove and rely on MessageRelay case action => action } case None => validateFinalPayload(payload, blindedPayload) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/message/Postman.scala b/eclair-core/src/main/scala/fr/acinq/eclair/message/Postman.scala index 2f805cebd0..e25f73fa79 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/message/Postman.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/message/Postman.scala @@ -31,7 +31,7 @@ import fr.acinq.eclair.router.Router.{MessageRoute, MessageRouteNotFound, Messag import fr.acinq.eclair.wire.protocol.MessageOnion.{FinalPayload, InvoiceRequestPayload} import fr.acinq.eclair.wire.protocol.OfferTypes.{CompactBlindedPath, ContactInfo} import fr.acinq.eclair.wire.protocol.{OfferTypes, OnionMessagePayloadTlv, TlvStream} -import fr.acinq.eclair.{NodeParams, randomBytes32, randomKey} +import fr.acinq.eclair.{EncodedNodeId, NodeParams, randomBytes32, randomKey} import scala.collection.mutable @@ -214,8 +214,8 @@ private class SendingMessage(nodeParams: NodeParams, replyTo ! Postman.MessageFailed(failure.toString) Behaviors.stopped case Right((nextNodeId, message)) => - val relay = context.spawn(Behaviors.supervise(MessageRelay()).onFailure(typed.SupervisorStrategy.stop), s"relay-message-$messageId") - relay ! MessageRelay.RelayMessage(messageId, switchboard, register, nodeParams.nodeId, Right(nextNodeId), message, MessageRelay.RelayAll, Some(context.messageAdapter[MessageRelay.Status](SendingStatus))) + val relay = context.spawn(Behaviors.supervise(MessageRelay(nodeParams, switchboard, register, router)).onFailure(typed.SupervisorStrategy.stop), s"relay-message-$messageId") + relay ! MessageRelay.RelayMessage(messageId, nodeParams.nodeId, Right(EncodedNodeId(nextNodeId)), message, MessageRelay.RelayAll, Some(context.messageAdapter[MessageRelay.Status](SendingStatus))) waitForSent() } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala index 7764f11128..9a0f730f66 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala @@ -27,6 +27,7 @@ import akka.util.Timeout import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto} import fr.acinq.eclair.Logs.LogCategory +import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, RES_SUCCESS} import fr.acinq.eclair.db._ import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop @@ -36,7 +37,7 @@ import fr.acinq.eclair.payment.offer.OfferManager import fr.acinq.eclair.router.BlindedRouteCreation.{aggregatePaymentInfo, createBlindedRouteFromHops, createBlindedRouteWithoutHops} import fr.acinq.eclair.router.Router import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams} -import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, InvoiceTlv, ShortChannelIdDir} +import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, InvoiceTlv} import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{Bolt11Feature, CltvExpiryDelta, FeatureSupport, Features, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, ShortChannelId, TimestampMilli, randomBytes32} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala index 618583120c..021ae6c3b4 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala @@ -17,7 +17,7 @@ package fr.acinq.eclair.wire.protocol import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey -import fr.acinq.eclair.{ShortChannelId, UInt64} +import fr.acinq.eclair.{EncodedNodeId, ShortChannelId, UInt64} import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.{BlindedNode, BlindedRoute} import fr.acinq.eclair.payment.Bolt12Invoice import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} @@ -73,7 +73,7 @@ object MessageOnion { /** Per-hop payload for an intermediate node. */ case class IntermediatePayload(records: TlvStream[OnionMessagePayloadTlv], blindedRecords: TlvStream[RouteBlindingEncryptedDataTlv], nextBlinding: PublicKey) extends PerHopPayload { - val nextNode: Either[ShortChannelId, PublicKey] = + val nextNode: Either[ShortChannelId, EncodedNodeId] = blindedRecords.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId].map(outgoingNodeId => Right(outgoingNodeId.nodeId)) .getOrElse(Left(blindedRecords.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId].get.shortChannelId)) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala index 5bee31ab1e..0c1a42ede3 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala @@ -17,11 +17,13 @@ package fr.acinq.eclair.wire.protocol import fr.acinq.bitcoin.scalacompat.BlockHash +import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey +import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.{BlindedNode, BlindedRoute} import fr.acinq.eclair.wire.protocol.CommonCodecs._ import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequestChain, InvoiceRequestPayerNote, InvoiceRequestQuantity, _} import fr.acinq.eclair.wire.protocol.TlvCodecs.{tlvField, tmillisatoshi, tu32, tu64overflow} -import fr.acinq.eclair.{TimestampSecond, UInt64} +import fr.acinq.eclair.{EncodedNodeId, TimestampSecond, UInt64} import scodec.{Attempt, Codec, Err} import scodec.codecs._ @@ -60,6 +62,8 @@ object OfferCodecs { (("isNode1" | isNode1) :: ("scid" | realshortchannelid)).as[ShortChannelIdDir] + val encodedNodeIdCodec: Codec[EncodedNodeId] = choice(shortChannelIdDirCodec.upcast[EncodedNodeId], publicKey.as[EncodedNodeId.Plain].upcast[EncodedNodeId]) + private val compactBlindedPathCodec: Codec[CompactBlindedPath] = (("introductionNode" | shortChannelIdDirCodec) :: ("blinding" | publicKey) :: diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala index 3d7fa51479..489aba3f01 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala @@ -19,6 +19,7 @@ package fr.acinq.eclair.wire.protocol import fr.acinq.bitcoin.Bech32 import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey, XonlyPublicKey} import fr.acinq.bitcoin.scalacompat.{Block, BlockHash, ByteVector32, ByteVector64, Crypto, LexicographicalOrdering} +import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.{BlindedNode, BlindedRoute} import fr.acinq.eclair.wire.protocol.CommonCodecs.varint import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} @@ -35,8 +36,6 @@ import scala.util.{Failure, Try} * see https://github.com/lightning/bolts/blob/master/12-offer-encoding.md */ object OfferTypes { - case class ShortChannelIdDir(isNode1: Boolean, scid: RealShortChannelId) - // @formatter:off /** Data provided to reach the issuer of an offer or invoice. */ sealed trait ContactInfo diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala index 178901f3e6..2d07944500 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala @@ -21,7 +21,7 @@ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.wire.protocol.CommonCodecs.{cltvExpiry, cltvExpiryDelta, featuresCodec} import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.TlvCodecs.{fixedLengthTlvField, tlvField, tmillisatoshi, tmillisatoshi32} -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, ShortChannelId, UInt64} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, EncodedNodeId, ShortChannelId, UInt64} import scodec.bits.ByteVector import scala.util.{Failure, Success} @@ -40,8 +40,14 @@ object RouteBlindingEncryptedDataTlv { /** Id of the outgoing channel, used to identify the next node. */ case class OutgoingChannelId(shortChannelId: ShortChannelId) extends RouteBlindingEncryptedDataTlv - /** Id of the next node. */ - case class OutgoingNodeId(nodeId: PublicKey) extends RouteBlindingEncryptedDataTlv + /** + * Id of the next node. + * Warning: the spec only allows a public key here. We allow reading a ShortChannelIdDir for phoenix but we should never write one. + */ + case class OutgoingNodeId(nodeId: EncodedNodeId) extends RouteBlindingEncryptedDataTlv + object OutgoingNodeId { + def apply(publicKey: PublicKey): OutgoingNodeId = OutgoingNodeId(EncodedNodeId(publicKey)) + } /** * The final recipient may store some data in the encrypted payload for itself to avoid storing it locally. @@ -109,12 +115,13 @@ object RouteBlindingEncryptedDataCodecs { import RouteBlindingEncryptedDataTlv._ import fr.acinq.eclair.wire.protocol.CommonCodecs.{publicKey, shortchannelid, varint} + import fr.acinq.eclair.wire.protocol.OfferCodecs.encodedNodeIdCodec import scodec.codecs._ import scodec.{Attempt, Codec, DecodeResult} private val padding: Codec[Padding] = tlvField(bytes) private val outgoingChannelId: Codec[OutgoingChannelId] = tlvField(shortchannelid) - private val outgoingNodeId: Codec[OutgoingNodeId] = fixedLengthTlvField(33, publicKey) + private val outgoingNodeId: Codec[OutgoingNodeId] = tlvField(encodedNodeIdCodec) private val pathId: Codec[PathId] = tlvField(bytes) private val nextBlinding: Codec[NextBlinding] = fixedLengthTlvField(33, publicKey) private val paymentRelay: Codec[PaymentRelay] = tlvField(("cltv_expiry_delta" | cltvExpiryDelta) :: ("fee_proportional_millionths" | uint32) :: ("fee_base_msat" | tmillisatoshi32)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala index 93cb1785e7..e5159d9356 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala @@ -23,6 +23,7 @@ import akka.testkit.TestProbe import com.typesafe.config.ConfigFactory import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, Crypto, SatoshiLong} +import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir import fr.acinq.eclair.blockchain.bitcoind.BitcoindService.BitcoinReq import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher.{Watch, WatchFundingConfirmed} @@ -43,7 +44,7 @@ import fr.acinq.eclair.payment.send.PaymentInitiator.{SendPaymentToNode, SendTra import fr.acinq.eclair.router.Graph.WeightRatios import fr.acinq.eclair.router.Router.{GossipDecision, PublicChannel} import fr.acinq.eclair.router.{Announcements, AnnouncementsBatchValidationSpec, Router} -import fr.acinq.eclair.wire.protocol.OfferTypes.{CompactBlindedPath, Offer, OfferPaths, ShortChannelIdDir} +import fr.acinq.eclair.wire.protocol.OfferTypes.{CompactBlindedPath, Offer, OfferPaths} import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, IncorrectOrUnknownPaymentDetails, OfferTypes} import fr.acinq.eclair.{CltvExpiryDelta, EclairImpl, Features, Kit, MilliSatoshiLong, ShortChannelId, TimestampMilli, randomBytes32, randomKey} import org.json4s.JsonAST.{JString, JValue} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala index 9dc831eb6d..df0c6961ad 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala @@ -100,7 +100,7 @@ object MinimalNodeFixture extends Assertions with Eventually with IntegrationPat val txPublisherFactory = Channel.SimpleTxPublisherFactory(nodeParams, watcherTyped, bitcoinClient) val channelFactory = Peer.SimpleChannelFactory(nodeParams, watcherTyped, relayer, wallet, txPublisherFactory) val pendingChannelsRateLimiter = system.spawnAnonymous(Behaviors.supervise(PendingChannelsRateLimiter(nodeParams, router.toTyped, Seq())).onFailure(typed.SupervisorStrategy.resume)) - val peerFactory = Switchboard.SimplePeerFactory(nodeParams, wallet, channelFactory, pendingChannelsRateLimiter, register) + val peerFactory = Switchboard.SimplePeerFactory(nodeParams, wallet, channelFactory, pendingChannelsRateLimiter, register, router.toTyped) val switchboard = system.actorOf(Switchboard.props(nodeParams, peerFactory), "switchboard") val paymentFactory = PaymentInitiator.SimplePaymentFactory(nodeParams, router, register) val paymentInitiator = system.actorOf(PaymentInitiator.props(nodeParams, paymentFactory), "payment-initiator") diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala index d85ffc782a..1f9a150db1 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala @@ -29,29 +29,31 @@ import fr.acinq.eclair.io.Peer.{PeerInfo, PeerNotFound} import fr.acinq.eclair.io.Switchboard.GetPeerInfo import fr.acinq.eclair.message.OnionMessages import fr.acinq.eclair.message.OnionMessages.{IntermediateNode, Recipient} -import fr.acinq.eclair.wire.protocol.TlvStream -import fr.acinq.eclair.{ShortChannelId, randomBytes32, randomKey} +import fr.acinq.eclair.router.Router +import fr.acinq.eclair.wire.protocol.{GenericTlv, OnionMessagePayloadTlv, TlvStream} +import fr.acinq.eclair.{EncodedNodeId, RealShortChannelId, ShortChannelId, UInt64, randomBytes32, randomKey} import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike +import scodec.bits.HexStringSyntax import scala.concurrent.duration.DurationInt -import scala.util.Success class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike { val aliceId: PublicKey = Alice.nodeParams.nodeId val bobId: PublicKey = Bob.nodeParams.nodeId - case class FixtureParam(relay: ActorRef[Command], switchboard: TestProbe, register: TestProbe, peerConnection: TypedProbe[Nothing], peer: TypedProbe[Peer.RelayOnionMessage], probe: TypedProbe[Status]) + case class FixtureParam(relay: ActorRef[Command], switchboard: TestProbe, register: TestProbe, router: TypedProbe[Router.GetNodeId], peerConnection: TypedProbe[Nothing], peer: TypedProbe[Peer.RelayOnionMessage], probe: TypedProbe[Status]) override def withFixture(test: OneArgTest): Outcome = { val switchboard = TestProbe("switchboard")(system.classicSystem) val register = TestProbe("register")(system.classicSystem) + val router = TypedProbe[Router.GetNodeId]("router") val peerConnection = TypedProbe[Nothing]("peerConnection") val peer = TypedProbe[Peer.RelayOnionMessage]("peer") val probe = TypedProbe[Status]("probe") - val relay = testKit.spawn(MessageRelay()) + val relay = testKit.spawn(MessageRelay(Alice.nodeParams, switchboard.ref, register.ref, router.ref)) try { - withFixture(test.toNoArgTest(FixtureParam(relay, switchboard, register, peerConnection, peer, probe))) + withFixture(test.toNoArgTest(FixtureParam(relay, switchboard, register, router, peerConnection, peer, probe))) } finally { testKit.stop(relay) } @@ -60,9 +62,10 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app test("relay with new connection") { f => import f._ - val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty) + val Right((nextNode, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) + assert(nextNode == bobId) val messageId = randomBytes32() - relay ! RelayMessage(messageId, switchboard.ref, register.ref, randomKey().publicKey, Right(bobId), message, RelayAll, None) + relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId(bobId)), message, RelayAll, None) val connectToNextPeer = switchboard.expectMsgType[Peer.Connect] assert(connectToNextPeer.nodeId == bobId) @@ -73,9 +76,10 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app test("relay with existing peer") { f => import f._ - val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty) + val Right((nextNode, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) + assert(nextNode == bobId) val messageId = randomBytes32() - relay ! RelayMessage(messageId, switchboard.ref, register.ref, randomKey().publicKey, Right(bobId), message, RelayAll, None) + relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId(bobId)), message, RelayAll, None) val connectToNextPeer = switchboard.expectMsgType[Peer.Connect] assert(connectToNextPeer.nodeId == bobId) @@ -86,9 +90,10 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app test("can't open new connection") { f => import f._ - val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty) + val Right((nextNode, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) + assert(nextNode == bobId) val messageId = randomBytes32() - relay ! RelayMessage(messageId, switchboard.ref, register.ref, randomKey().publicKey, Right(bobId), message, RelayAll, Some(probe.ref)) + relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId(bobId)), message, RelayAll, Some(probe.ref)) val connectToNextPeer = switchboard.expectMsgType[Peer.Connect] assert(connectToNextPeer.nodeId == bobId) @@ -99,10 +104,11 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app test("no channel with previous node") { f => import f._ - val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty) + val Right((nextNode, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) + assert(nextNode == bobId) val messageId = randomBytes32() val previousNodeId = randomKey().publicKey - relay ! RelayMessage(messageId, switchboard.ref, register.ref, previousNodeId, Right(bobId), message, RelayChannelsOnly, Some(probe.ref)) + relay ! RelayMessage(messageId, previousNodeId, Right(EncodedNodeId(bobId)), message, RelayChannelsOnly, Some(probe.ref)) val getPeerInfo = switchboard.expectMsgType[GetPeerInfo] assert(getPeerInfo.remoteNodeId == previousNodeId) @@ -115,10 +121,11 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app test("no channel with next node") { f => import f._ - val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty) + val Right((nextNode, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) + assert(nextNode == bobId) val messageId = randomBytes32() val previousNodeId = randomKey().publicKey - relay ! RelayMessage(messageId, switchboard.ref, register.ref, previousNodeId, Right(bobId), message, RelayChannelsOnly, Some(probe.ref)) + relay ! RelayMessage(messageId, previousNodeId, Right(EncodedNodeId(bobId)), message, RelayChannelsOnly, Some(probe.ref)) val getPeerInfo1 = switchboard.expectMsgType[GetPeerInfo] assert(getPeerInfo1.remoteNodeId == previousNodeId) @@ -135,10 +142,11 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app test("channels on both ends") { f => import f._ - val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty) + val Right((nextNode, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) + assert(nextNode == bobId) val messageId = randomBytes32() val previousNodeId = randomKey().publicKey - relay ! RelayMessage(messageId, switchboard.ref, register.ref, previousNodeId, Right(bobId), message, RelayChannelsOnly, None) + relay ! RelayMessage(messageId, previousNodeId, Right(EncodedNodeId(bobId)), message, RelayChannelsOnly, None) val getPeerInfo1 = switchboard.expectMsgType[GetPeerInfo] assert(getPeerInfo1.remoteNodeId == previousNodeId) @@ -154,10 +162,11 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app test("next node specified with channel id") { f => import f._ - val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty) + val Right((nextNode, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) + assert(nextNode == bobId) val messageId = randomBytes32() val scid = ShortChannelId(123456L) - relay ! RelayMessage(messageId, switchboard.ref, register.ref, randomKey().publicKey, Left(scid), message, RelayAll, None) + relay ! RelayMessage(messageId, randomKey().publicKey, Left(scid), message, RelayAll, None) val getNextNodeId = register.expectMsgType[Register.GetNextNodeId] assert(getNextNodeId.shortChannelId == scid) @@ -168,4 +177,46 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app connectToNextPeer.replyTo ! PeerConnection.ConnectionResult.AlreadyConnected(peerConnection.ref.toClassic, peer.ref.toClassic) assert(peer.expectMessageType[Peer.RelayOnionMessage].msg == message) } + + test("next node is compact node id") { f => + import f._ + + val Right((nextNode, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) + assert(nextNode == bobId) + val messageId = randomBytes32() + val scid = RealShortChannelId(234567L) + relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId.ShortChannelIdDir(isNode1 = false, scid)), message, RelayAll, None) + + val getNodeId = router.expectMessageType[Router.GetNodeId] + assert(getNodeId.isNode1 == false) + assert(getNodeId.shortChannelId == scid) + getNodeId.replyTo ! Some(bobId) + + val connectToNextPeer = switchboard.expectMsgType[Peer.Connect] + assert(connectToNextPeer.nodeId == bobId) + connectToNextPeer.replyTo ! PeerConnection.ConnectionResult.AlreadyConnected(peerConnection.ref.toClassic, peer.ref.toClassic) + assert(peer.expectMessageType[Peer.RelayOnionMessage].msg == message) + } + + test("next node is us as compact node id") { f => + import f._ + + val Right((nextNode, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(31), hex"f3ed")))) + assert(nextNode == aliceId) + val messageId = randomBytes32() + val scid = RealShortChannelId(345678L) + relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId.ShortChannelIdDir(isNode1 = true, scid)), message, RelayAll, None) + + val getNodeId = router.expectMessageType[Router.GetNodeId] + assert(getNodeId.isNode1 == true) + assert(getNodeId.shortChannelId == scid) + getNodeId.replyTo ! Some(aliceId) + + val connectToNextPeer = switchboard.expectMsgType[Peer.Connect] + assert(connectToNextPeer.nodeId == bobId) + connectToNextPeer.replyTo ! PeerConnection.ConnectionResult.AlreadyConnected(peerConnection.ref.toClassic, peer.ref.toClassic) + val messageToBob = peer.expectMessageType[Peer.RelayOnionMessage].msg + val OnionMessages.ReceiveMessage(payload) = OnionMessages.process(Bob.nodeParams.privateKey, messageToBob) + assert(payload.records.unknown == Set(GenericTlv(UInt64(31), hex"f3ed"))) + } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala index 722f141d74..622f952309 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala @@ -64,6 +64,7 @@ class PeerSpec extends FixtureSpec { val channel = TestProbe() val switchboard = TestProbe() val register = TestProbe() + val router = TestProbe() import com.softwaremill.quicklens._ val aliceParams = TestConstants.Alice.nodeParams @@ -99,7 +100,7 @@ class PeerSpec extends FixtureSpec { case _ => KeepRunning }) - val peer: TestFSMRef[Peer.State, Peer.Data, Peer] = TestFSMRef(new Peer(aliceParams, remoteNodeId, wallet, FakeChannelFactory(channel), switchboard.ref, register.ref, mockLimiter.ref)) + val peer: TestFSMRef[Peer.State, Peer.Data, Peer] = TestFSMRef(new Peer(aliceParams, remoteNodeId, wallet, FakeChannelFactory(channel), switchboard.ref, register.ref, router.ref, mockLimiter.ref)) FixtureParam(aliceParams, remoteNodeId, system, peer, peerConnection, channel, switchboard, register, mockLimiter.ref) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala index 7f2101ab31..73f356ccee 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala @@ -27,7 +27,7 @@ import fr.acinq.eclair.wire.protocol.OnionMessagePayloadTlv.EncryptedData import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv._ import fr.acinq.eclair.wire.protocol.{GenericTlv, OnionMessage, OnionMessagePayloadTlv, OnionRoutingCodecs, RouteBlindingEncryptedDataCodecs, RouteBlindingEncryptedDataTlv, TlvStream} -import fr.acinq.eclair.{ShortChannelId, UInt64, randomBytes, randomKey} +import fr.acinq.eclair.{EncodedNodeId, ShortChannelId, UInt64, randomBytes, randomKey} import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.scalatest.funsuite.AnyFunSuite @@ -116,13 +116,13 @@ class OnionMessagesSpec extends AnyFunSuite { // Checking that the onion is relayed properly process(alice, onionForAlice) match { case SendMessage(Right(nextNodeId), onionForBob) => - assert(nextNodeId == bob.publicKey) + assert(nextNodeId == EncodedNodeId(bob.publicKey)) process(bob, onionForBob) match { case SendMessage(Right(nextNodeId), onionForCarol) => - assert(nextNodeId == carol.publicKey) + assert(nextNodeId == EncodedNodeId(carol.publicKey)) process(carol, onionForCarol) match { case SendMessage(Right(nextNodeId), onionForDave) => - assert(nextNodeId == dave.publicKey) + assert(nextNodeId == EncodedNodeId(dave.publicKey)) process(dave, onionForDave) match { case ReceiveMessage(finalPayload) => assert(finalPayload.pathId_opt.contains(hex"01234567")) case x => fail(x.toString) @@ -235,10 +235,10 @@ class OnionMessagesSpec extends AnyFunSuite { // Checking that the onion is relayed properly process(alice, messageForAlice) match { case SendMessage(Right(nextNodeId), onionForBob) => - assert(nextNodeId == bob.publicKey) + assert(nextNodeId == EncodedNodeId(bob.publicKey)) process(bob, onionForBob) match { case SendMessage(Right(nextNodeId), onionForCarol) => - assert(nextNodeId == carol.publicKey) + assert(nextNodeId == EncodedNodeId(carol.publicKey)) process(carol, onionForCarol) match { case ReceiveMessage(finalPayload) => assert(finalPayload.pathId_opt.contains(pathId)) case x => fail(x.toString) @@ -329,13 +329,13 @@ class OnionMessagesSpec extends AnyFunSuite { // Checking that the onion is relayed properly process(alice, message) match { case SendMessage(Right(nextNodeId), onionForBob) => - assert(nextNodeId == bob.publicKey) + assert(nextNodeId == EncodedNodeId(bob.publicKey)) process(bob, onionForBob) match { case SendMessage(Right(nextNodeId), onionForCarol) => - assert(nextNodeId == carol.publicKey) + assert(nextNodeId == EncodedNodeId(carol.publicKey)) process(carol, onionForCarol) match { case SendMessage(Right(nextNodeId), onionForDave) => - assert(nextNodeId == dave.publicKey) + assert(nextNodeId == EncodedNodeId(dave.publicKey)) process(dave, onionForDave) match { case ReceiveMessage(finalPayload) => assert(finalPayload.pathId_opt.contains(pathId)) case x => fail(x.toString) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/message/PostmanSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/message/PostmanSpec.scala index 57c946d124..e68c52d9fa 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/message/PostmanSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/message/PostmanSpec.scala @@ -36,7 +36,7 @@ import fr.acinq.eclair.router.Router.{MessageRoute, MessageRouteRequest} import fr.acinq.eclair.wire.protocol.OnionMessagePayloadTlv.{InvoiceRequest, ReplyPath} import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.PathId import fr.acinq.eclair.wire.protocol.{GenericTlv, MessageOnion, OfferTypes, OnionMessagePayloadTlv, TlvStream} -import fr.acinq.eclair.{Features, MilliSatoshiLong, NodeParams, RealShortChannelId, TestConstants, UInt64, randomKey} +import fr.acinq.eclair.{Features, MilliSatoshiLong, EncodedNodeId, NodeParams, RealShortChannelId, TestConstants, UInt64, randomKey} import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike import scodec.bits.HexStringSyntax @@ -213,11 +213,11 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val Peer.RelayOnionMessage(messageId, message1, Some(replyTo)) = expectRelayToConnected(switchboard, a.publicKey) replyTo ! Sent(messageId) val OnionMessages.SendMessage(Right(next2), message2) = OnionMessages.process(a, message1) - assert(next2 == b.publicKey) + assert(next2 == EncodedNodeId(b.publicKey)) val OnionMessages.SendMessage(Right(next3), message3) = OnionMessages.process(b, message2) - assert(next3 == c.publicKey) + assert(next3 == EncodedNodeId(c.publicKey)) val OnionMessages.SendMessage(Right(next4), message4) = OnionMessages.process(c, message3) - assert(next4 == d.publicKey) + assert(next4 == EncodedNodeId(d.publicKey)) val OnionMessages.ReceiveMessage(payload) = OnionMessages.process(d, message4) assert(payload.records.unknown == Set(GenericTlv(UInt64(11), hex"012345"))) assert(payload.records.get[ReplyPath].nonEmpty) @@ -229,11 +229,11 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val Right((next5, reply)) = OnionMessages.buildMessage(d, randomKey(), randomKey(), Nil, OnionMessages.BlindedPath(replyPath), TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(13), hex"6789")))) assert(next5 == c.publicKey) val OnionMessages.SendMessage(Right(next6), message6) = OnionMessages.process(c, reply) - assert(next6 == b.publicKey) + assert(next6 == EncodedNodeId(b.publicKey)) val OnionMessages.SendMessage(Right(next7), message7) = OnionMessages.process(b, message6) - assert(next7 == a.publicKey) + assert(next7 == EncodedNodeId(a.publicKey)) val OnionMessages.SendMessage(Right(next8), message8) = OnionMessages.process(a, message7) - assert(next8 == nodeParams.nodeId) + assert(next8 == EncodedNodeId(nodeParams.nodeId)) val OnionMessages.ReceiveMessage(replyPayload) = OnionMessages.process(nodeParams.privateKey, message8) postman ! WrappedMessage(replyPayload) @@ -246,7 +246,7 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val recipientKey = randomKey() val route = buildRoute(randomKey(), Seq(), Recipient(recipientKey.publicKey, None)) - val compactRoute = OfferTypes.CompactBlindedPath(OfferTypes.ShortChannelIdDir(isNode1 = false, RealShortChannelId(1234)), route.blindingKey, route.blindedNodes) + val compactRoute = OfferTypes.CompactBlindedPath(EncodedNodeId.ShortChannelIdDir(isNode1 = false, RealShortChannelId(1234)), route.blindingKey, route.blindedNodes) postman ! SendMessage(compactRoute, FindRoute, TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(33), hex"abcd"))), expectsReply = false, messageSender.ref) val getNodeId = router.expectMessageType[Router.GetNodeId] @@ -280,7 +280,7 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val recipientKey = randomKey() val route = buildRoute(randomKey(), Seq(IntermediateNode(nodeParams.nodeId)), Recipient(recipientKey.publicKey, None)) - val compactRoute = OfferTypes.CompactBlindedPath(OfferTypes.ShortChannelIdDir(isNode1 = true, RealShortChannelId(1234)), route.blindingKey, route.blindedNodes) + val compactRoute = OfferTypes.CompactBlindedPath(EncodedNodeId.ShortChannelIdDir(isNode1 = true, RealShortChannelId(1234)), route.blindingKey, route.blindedNodes) postman ! SendMessage(compactRoute, FindRoute, TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(33), hex"abcd"))), expectsReply = false, messageSender.ref) val getNodeId = router.expectMessageType[Router.GetNodeId] diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala index 18e1a9556e..00bd63127f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala @@ -28,6 +28,7 @@ import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{Block, BlockHash, ByteVector32, Crypto} import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.Features.{AsyncPaymentPrototype, BasicMultiPartPayment, PaymentSecret, VariableLengthOnion} +import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Register} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/OfferPaymentSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/OfferPaymentSpec.scala index b7b0f5a2f3..4f76d7a5d9 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/OfferPaymentSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/OfferPaymentSpec.scala @@ -21,6 +21,7 @@ import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe = import akka.actor.typed.ActorRef import akka.testkit.TestProbe import com.typesafe.config.ConfigFactory +import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir import fr.acinq.eclair.crypto.Sphinx.RouteBlinding import fr.acinq.eclair.message.OnionMessages.RoutingStrategy.FindRoute import fr.acinq.eclair.message.Postman @@ -30,7 +31,7 @@ import fr.acinq.eclair.payment.{Bolt12Invoice, PaymentBlindedContactInfo} import fr.acinq.eclair.router.Router import fr.acinq.eclair.router.Router.RouteParams import fr.acinq.eclair.wire.protocol.MessageOnion.InvoicePayload -import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer, PaymentInfo, ShortChannelIdDir} +import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer, PaymentInfo} import fr.acinq.eclair.wire.protocol.{OfferTypes, OnionMessagePayloadTlv, TlvStream} import fr.acinq.eclair.{CltvExpiryDelta, Features, MilliSatoshiLong, NodeParams, RealShortChannelId, TestConstants, randomBytes, randomBytes32, randomKey} import org.scalatest.Outcome diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/MessageOnionCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/MessageOnionCodecsSpec.scala index 84c186298b..009dec657d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/MessageOnionCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/MessageOnionCodecsSpec.scala @@ -12,7 +12,7 @@ import fr.acinq.eclair.wire.protocol.OnionMessagePayloadTlv._ import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{AllowedFeatures, OutgoingNodeId, PathId, PaymentConstraints, PaymentRelay} -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshiLong, UInt64, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshiLong, EncodedNodeId, UInt64, randomBytes32, randomKey} import org.scalatest.funsuite.AnyFunSuiteLike import scodec.bits.{ByteVector, HexStringSyntax} @@ -30,7 +30,7 @@ class MessageOnionCodecsSpec extends AnyFunSuiteLike { assert(decoded == expected) val nextNodeId = randomKey().publicKey val Right(payload) = IntermediatePayload.validate(decoded, TlvStream(RouteBlindingEncryptedDataTlv.OutgoingNodeId(nextNodeId)), randomKey().publicKey) - assert(payload.nextNode == Right(nextNodeId)) + assert(payload.nextNode == Right(EncodedNodeId(nextNodeId))) val encoded = perHopPayloadCodec.encode(expected).require.bytes assert(encoded == bin) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/OfferTypesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/OfferTypesSpec.scala index 623a54c384..0f21ba07e2 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/OfferTypesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/OfferTypesSpec.scala @@ -19,12 +19,13 @@ package fr.acinq.eclair.wire.protocol import fr.acinq.bitcoin.Bech32 import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{Block, BlockHash, ByteVector32} +import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.Features.BasicMultiPartPayment import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.{BlindedNode, BlindedRoute} import fr.acinq.eclair.wire.protocol.OfferCodecs.{invoiceRequestTlvCodec, offerTlvCodec} import fr.acinq.eclair.wire.protocol.OfferTypes._ -import fr.acinq.eclair.{Features, MilliSatoshiLong, RealShortChannelId, randomBytes32, randomKey} +import fr.acinq.eclair.{BlockHeight, EncodedNodeId, Features, MilliSatoshiLong, RealShortChannelId, randomBytes32, randomKey} import org.scalatest.funsuite.AnyFunSuite import scodec.bits.{ByteVector, HexStringSyntax} @@ -282,4 +283,22 @@ class OfferTypesSpec extends AnyFunSuite { assert(OfferCodecs.pathCodec.decode(encoded.bits).require.value == decoded) } } + + test("encoded node id") { + val testCases = Map( + hex"00 0d950b0001c80000" -> + EncodedNodeId.ShortChannelIdDir(isNode1 = true, RealShortChannelId(BlockHeight(890123), 456, 0)), + hex"01 0c0a14000d800005" -> + EncodedNodeId.ShortChannelIdDir(isNode1 = false, RealShortChannelId(BlockHeight(789012), 3456, 5)), + hex"022d3b15cea00ee4a8e710b082bef18f0f3409cc4e7aff41c26eb0a4d3ab20dd73" -> + EncodedNodeId.Plain(PublicKey(hex"022d3b15cea00ee4a8e710b082bef18f0f3409cc4e7aff41c26eb0a4d3ab20dd73")), + hex"03ba3c458e3299eb19d2e07ae86453f4290bcdf8689707f0862f35194397c45922" -> + EncodedNodeId.Plain(PublicKey(hex"03ba3c458e3299eb19d2e07ae86453f4290bcdf8689707f0862f35194397c45922")), + ) + + for ((encoded, decoded) <- testCases) { + assert(OfferCodecs.encodedNodeIdCodec.encode(decoded).require.bytes == encoded) + assert(OfferCodecs.encodedNodeIdCodec.decode(encoded.bits).require.value == decoded) + } + } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala index 5707cc6a1f..c3a32d30dc 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala @@ -26,7 +26,7 @@ import fr.acinq.eclair.wire.protocol.OnionPaymentPayloadTlv._ import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.PaymentOnion._ import fr.acinq.eclair.wire.protocol.PaymentOnionCodecs._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, FeatureSupport, Features, MilliSatoshiLong, RealShortChannelId, ShortChannelId, UInt64, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, FeatureSupport, Features, MilliSatoshiLong, EncodedNodeId, RealShortChannelId, ShortChannelId, UInt64, randomKey} import org.scalatest.funsuite.AnyFunSuite import scodec.bits.{ByteVector, HexStringSyntax} @@ -167,7 +167,7 @@ class PaymentOnionSpec extends AnyFunSuite { test("encode/decode node relay to blinded paths per-hop payload") { val features = Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional).toByteVector val blindedRoute = OfferTypes.CompactBlindedPath( - OfferTypes.ShortChannelIdDir(isNode1 = false, RealShortChannelId(468)), + EncodedNodeId.ShortChannelIdDir(isNode1 = false, RealShortChannelId(468)), PublicKey(hex"0232882c4982576e00f0d6bd4998f5b3e92d47ecc8fbad5b6a5e7521819d891d9e"), Seq(RouteBlinding.BlindedNode(PublicKey(hex"03823aa560d631e9d7b686be4a9227e577009afb5173023b458a6a6aff056ac980"), hex"")) )