From 2092ead1f19ae358b04edaefc8cf8e06444c3143 Mon Sep 17 00:00:00 2001 From: t-bast Date: Fri, 10 Jan 2025 12:03:29 +0100 Subject: [PATCH] Store incoming peers with channels in `PeersDb` Once we have a channel with a peer that connected to us, we store their details in our DB. We don't store the address they're connecting from, because we don't know if we will be able to connect to them using this address, but we store their features. --- .../main/scala/fr/acinq/eclair/io/Peer.scala | 56 +++++++++++++---- .../scala/fr/acinq/eclair/io/PeerSpec.scala | 60 ++++++++++++++++++- .../eclair/io/ReconnectionTaskSpec.scala | 29 ++++----- 3 files changed, 114 insertions(+), 31 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala index 7974a86dc3..d4fbd7f1d1 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala @@ -92,7 +92,9 @@ class Peer(val nodeParams: NodeParams, } else { None } - goto(DISCONNECTED) using DisconnectedData(channels, activeChannels = Set.empty, PeerStorage(peerStorageData, written = true)) // when we restart, we will attempt to reconnect right away, but then we'll wait + // When we restart, we will attempt to reconnect right away, but then we'll wait. + // We don't fetch our peer's features from the DB: if the connection succeeds, we will get them from their init message, which saves a DB call. + goto(DISCONNECTED) using DisconnectedData(channels, activeChannels = Set.empty, PeerStorage(peerStorageData, written = true), remoteFeatures_opt = None) } when(DISCONNECTED) { @@ -150,7 +152,14 @@ class Peer(val nodeParams: NodeParams, if (!d.peerStorage.written && !isTimerActive(WritePeerStorageTimerKey)) { startSingleTimer(WritePeerStorageTimerKey, WritePeerStorage, nodeParams.peerStorageConfig.writeDelay) } - stay() using d.copy(activeChannels = d.activeChannels + e.channelId) + val remoteFeatures_opt = d.remoteFeatures_opt match { + case Some(remoteFeatures) if !remoteFeatures.written => + // We have a channel, so we can write to the DB without any DoS risk. + nodeParams.db.peers.addOrUpdatePeer(remoteNodeId, NodeInfo(remoteFeatures.features, None)) + Some(remoteFeatures.copy(written = true)) + case _ => d.remoteFeatures_opt + } + stay() using d.copy(activeChannels = d.activeChannels + e.channelId, remoteFeatures_opt = remoteFeatures_opt) case Event(e: LocalChannelDown, d: DisconnectedData) => stay() using d.copy(activeChannels = d.activeChannels - e.channelId) @@ -447,7 +456,11 @@ class Peer(val nodeParams: NodeParams, if (!d.peerStorage.written && !isTimerActive(WritePeerStorageTimerKey)) { startSingleTimer(WritePeerStorageTimerKey, WritePeerStorage, nodeParams.peerStorageConfig.writeDelay) } - stay() using d.copy(activeChannels = d.activeChannels + e.channelId) + if (!d.remoteFeaturesWritten) { + // We have a channel, so we can write to the DB without any DoS risk. + nodeParams.db.peers.addOrUpdatePeer(remoteNodeId, NodeInfo(d.remoteFeatures, None)) + } + stay() using d.copy(activeChannels = d.activeChannels + e.channelId, remoteFeaturesWritten = true) case Event(e: LocalChannelDown, d: ConnectedData) => stay() using d.copy(activeChannels = d.activeChannels - e.channelId) @@ -492,7 +505,8 @@ class Peer(val nodeParams: NodeParams, stopPeer(d.peerStorage) } else { d.channels.values.toSet[ActorRef].foreach(_ ! INPUT_DISCONNECTED) // we deduplicate with toSet because there might be two entries per channel (tmp id and final id) - goto(DISCONNECTED) using DisconnectedData(d.channels.collect { case (k: FinalChannelId, v) => (k, v) }, d.activeChannels, d.peerStorage) + val lastRemoteFeatures = LastRemoteFeatures(d.remoteFeatures, d.remoteFeaturesWritten) + goto(DISCONNECTED) using DisconnectedData(d.channels.collect { case (k: FinalChannelId, v) => (k, v) }, d.activeChannels, d.peerStorage, Some(lastRemoteFeatures)) } case Event(Terminated(actor), d: ConnectedData) if d.channels.values.toSet.contains(actor) => @@ -587,12 +601,22 @@ class Peer(val nodeParams: NodeParams, case Event(r: GetPeerInfo, d) => val replyTo = r.replyTo.getOrElse(sender().toTyped) - val peerInfo = d match { - case c: ConnectedData => PeerInfo(self, remoteNodeId, stateName, Some(c.remoteFeatures), Some(c.address), c.channels.values.toSet) - case _ => PeerInfo(self, remoteNodeId, stateName, None, None, d.channels.values.toSet) + d match { + case c: ConnectedData => + replyTo ! PeerInfo(self, remoteNodeId, stateName, Some(c.remoteFeatures), Some(c.address), c.channels.values.toSet) + stay() + case d: DisconnectedData => + // If we haven't reconnected since our last restart, we fetch the latest remote features from our DB. + val remoteFeatures_opt = d.remoteFeatures_opt match { + case Some(remoteFeatures) => Some(remoteFeatures) + case None => nodeParams.db.peers.getPeer(remoteNodeId).map(nodeInfo => LastRemoteFeatures(nodeInfo.features, written = true)) + } + replyTo ! PeerInfo(self, remoteNodeId, stateName, remoteFeatures_opt.map(_.features), None, d.channels.values.toSet) + stay() using d.copy(remoteFeatures_opt = remoteFeatures_opt) + case _ => + replyTo ! PeerInfo(self, remoteNodeId, stateName, None, None, d.channels.values.toSet) + stay() } - replyTo ! peerInfo - stay() case Event(r: GetPeerChannels, d) => if (d.channels.isEmpty) { @@ -804,7 +828,13 @@ class Peer(val nodeParams: NodeParams, // We store the node address and features upon successful outgoing connection, so we can reconnect later. // The previous address is overwritten: we don't need it since the current one works. nodeParams.db.peers.addOrUpdatePeer(remoteNodeId, NodeInfo(connectionReady.remoteInit.features, Some(connectionReady.address))) + } else if (channels.nonEmpty) { + // If this is an incoming connection, we only store the peer details in our DB if we have channels with them. + // Otherwise nodes could DoS by simply connecting to us to force us to store data in our DB. + // We don't update the remote address, we don't know if we would successfully connect using the current one. + nodeParams.db.peers.addOrUpdatePeer(remoteNodeId, NodeInfo(connectionReady.remoteInit.features, None)) } + val remoteFeaturesWritten = connectionReady.outgoing || channels.nonEmpty // If we have some data stored from our peer, we send it to them before doing anything else. peerStorage.data.foreach(connectionReady.peerConnection ! PeerStorageRetrieval(_)) @@ -826,7 +856,7 @@ class Peer(val nodeParams: NodeParams, connectionReady.peerConnection ! CurrentFeeCredit(nodeParams.chainHash, feeCredit.getOrElse(0 msat)) } - goto(CONNECTED) using ConnectedData(connectionReady.address, connectionReady.peerConnection, connectionReady.localInit, connectionReady.remoteInit, channels, activeChannels, feerates, None, peerStorage) + goto(CONNECTED) using ConnectedData(connectionReady.address, connectionReady.peerConnection, connectionReady.localInit, connectionReady.remoteInit, channels, activeChannels, feerates, None, peerStorage, remoteFeaturesWritten) } /** @@ -967,6 +997,8 @@ object Peer { case class PeerStorage(data: Option[ByteVector], written: Boolean) + case class LastRemoteFeatures(features: Features[InitFeature], written: Boolean) + sealed trait Data { def channels: Map[_ <: ChannelId, ActorRef] // will be overridden by Map[FinalChannelId, ActorRef] or Map[ChannelId, ActorRef] def activeChannels: Set[ByteVector32] // channels that are available to process payments @@ -977,8 +1009,8 @@ object Peer { override def activeChannels: Set[ByteVector32] = Set.empty override def peerStorage: PeerStorage = PeerStorage(None, written = true) } - case class DisconnectedData(channels: Map[FinalChannelId, ActorRef], activeChannels: Set[ByteVector32], peerStorage: PeerStorage) extends Data - case class ConnectedData(address: NodeAddress, peerConnection: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init, channels: Map[ChannelId, ActorRef], activeChannels: Set[ByteVector32], currentFeerates: RecommendedFeerates, previousFeerates_opt: Option[RecommendedFeerates], peerStorage: PeerStorage) extends Data { + case class DisconnectedData(channels: Map[FinalChannelId, ActorRef], activeChannels: Set[ByteVector32], peerStorage: PeerStorage, remoteFeatures_opt: Option[LastRemoteFeatures]) extends Data + case class ConnectedData(address: NodeAddress, peerConnection: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init, channels: Map[ChannelId, ActorRef], activeChannels: Set[ByteVector32], currentFeerates: RecommendedFeerates, previousFeerates_opt: Option[RecommendedFeerates], peerStorage: PeerStorage, remoteFeaturesWritten: Boolean) extends Data { val connectionInfo: ConnectionInfo = ConnectionInfo(address, peerConnection, localInit, remoteInit) def localFeatures: Features[InitFeature] = localInit.features def remoteFeatures: Features[InitFeature] = remoteInit.features diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala index ac64eb146a..1d9eca3131 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala @@ -755,7 +755,7 @@ class PeerSpec extends FixtureSpec { channel.expectMsg(open) } - test("peer storage") { f => + test("store remote peer storage once we have channels") { f => import f._ // We connect with a previous backup. @@ -768,7 +768,6 @@ class PeerSpec extends FixtureSpec { peerConnection1.send(peer, PeerStorageStore(hex"0123456789")) // We disconnect and reconnect, sending the last backup we received. - peer ! Peer.Disconnect(f.remoteNodeId) val peerConnection2 = TestProbe() connect(remoteNodeId, peer, peerConnection2, switchboard, channels = Set(ChannelCodecsSpec.normal), initializePeer = false, peerStorage = Some(hex"0123456789")) peerConnection2.send(peer, PeerStorageStore(hex"1111")) @@ -788,6 +787,63 @@ class PeerSpec extends FixtureSpec { assert(nodeParams.db.peers.getStorage(remoteNodeId).contains(hex"1111")) } + test("store remote features when channel confirms") { f => + import f._ + + // When we make an outgoing connection, we store the peer details in our DB. + assert(nodeParams.db.peers.getPeer(remoteNodeId).isEmpty) + connect(remoteNodeId, peer, peerConnection, switchboard) + val Some(nodeInfo1) = nodeParams.db.peers.getPeer(remoteNodeId) + assert(nodeInfo1.features == TestConstants.Bob.nodeParams.features.initFeatures()) + assert(nodeInfo1.address_opt.contains(fakeIPAddress)) + + // We disconnect and our peer connects to us: we don't have any channel, so we don't update the DB entry. + val peerConnection2 = TestProbe() + val address2 = Tor3("of7husrflx7sforh3fw6yqlpwstee3wg5imvvmkp4bz6rbjxtg5nljad", 9735) + val remoteFeatures2 = Features(Features.ChannelType -> FeatureSupport.Mandatory).initFeatures() + switchboard.send(peer, PeerConnection.ConnectionReady(peerConnection2.ref, remoteNodeId, address2, outgoing = false, protocol.Init(Features.empty), protocol.Init(remoteFeatures2))) + val probe = TestProbe() + probe.send(peer, Peer.GetPeerInfo(Some(probe.ref.toTyped))) + assert(probe.expectMsgType[Peer.PeerInfo].address.contains(address2)) + assert(nodeParams.db.peers.getPeer(remoteNodeId).contains(nodeInfo1)) + + // A channel is created, so we update the remote features in our DB. + // We don't update the address because this was an incoming connection. + peer ! ChannelReadyForPayments(ActorRef.noSender, remoteNodeId, randomBytes32(), 0) + probe.send(peer, Peer.GetPeerInfo(Some(probe.ref.toTyped))) + assert(probe.expectMsgType[Peer.PeerInfo].features.contains(remoteFeatures2)) + assert(nodeParams.db.peers.getPeer(remoteNodeId).contains(nodeInfo1.copy(features = remoteFeatures2))) + } + + test("store remote features when channel confirms while disconnected") { f => + import f._ + + // When we receive an incoming connection, we don't store the peer details in our DB. + assert(nodeParams.db.peers.getPeer(remoteNodeId).isEmpty) + switchboard.send(peer, Peer.Init(Set.empty, Map.empty)) + val localInit = protocol.Init(peer.underlyingActor.nodeParams.features.initFeatures()) + val remoteInit = protocol.Init(TestConstants.Bob.nodeParams.features.initFeatures()) + switchboard.send(peer, PeerConnection.ConnectionReady(peerConnection.ref, remoteNodeId, fakeIPAddress, outgoing = false, localInit, remoteInit)) + val probe = TestProbe() + probe.send(peer, Peer.GetPeerInfo(Some(probe.ref.toTyped))) + assert(probe.expectMsgType[Peer.PeerInfo].state == Peer.CONNECTED) + assert(nodeParams.db.peers.getPeer(remoteNodeId).isEmpty) + + // Our peer wants to open a channel to us, but we disconnect before we have a confirmed channel. + peer ! SpawnChannelNonInitiator(Left(createOpenChannelMessage()), ChannelConfig.standard, ChannelTypes.Standard(), None, localParams, peerConnection.ref) + peer ! Peer.ConnectionDown(peerConnection.ref) + probe.send(peer, Peer.GetPeerInfo(Some(probe.ref.toTyped))) + assert(probe.expectMsgType[Peer.PeerInfo].state == Peer.DISCONNECTED) + assert(nodeParams.db.peers.getPeer(remoteNodeId).isEmpty) + + // The channel confirms, so we store the remote features in our DB. + // We don't store the remote address because this was an incoming connection. + peer ! ChannelReadyForPayments(ActorRef.noSender, remoteNodeId, randomBytes32(), 0) + probe.send(peer, Peer.GetPeerInfo(Some(probe.ref.toTyped))) + assert(probe.expectMsgType[Peer.PeerInfo].state == Peer.DISCONNECTED) + assert(nodeParams.db.peers.getPeer(remoteNodeId).contains(NodeInfo(remoteInit.features, None))) + } + } object PeerSpec { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala index d4ba5401c2..06ca2b13e9 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala @@ -38,8 +38,8 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike private val recommendedFeerates = RecommendedFeerates(Block.RegtestGenesisBlock.hash, TestConstants.feeratePerKw, TestConstants.anchorOutputsFeeratePerKw) private val PeerNothingData = Peer.Nothing - private val PeerDisconnectedData = Peer.DisconnectedData(channels, activeChannels = Set.empty, PeerStorage(None, written = true)) - private val PeerConnectedData = Peer.ConnectedData(fakeIPAddress, system.deadLetters, null, null, channels.map { case (k: ChannelId, v) => (k, v) }, activeChannels = Set.empty, recommendedFeerates, None, PeerStorage(None, written = true)) + private val PeerDisconnectedData = Peer.DisconnectedData(channels, activeChannels = Set.empty, PeerStorage(None, written = true), remoteFeatures_opt = None) + private val PeerConnectedData = Peer.ConnectedData(fakeIPAddress, system.deadLetters, null, null, channels.map { case (k: ChannelId, v) => (k, v) }, activeChannels = Set.empty, recommendedFeerates, None, PeerStorage(None, written = true), remoteFeaturesWritten = true) case class FixtureParam(nodeParams: NodeParams, remoteNodeId: PublicKey, reconnectionTask: TestFSMRef[ReconnectionTask.State, ReconnectionTask.Data, ReconnectionTask], monitor: TestProbe) @@ -82,7 +82,7 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike import f._ val peer = TestProbe() - peer.send(reconnectionTask, Peer.Transition(PeerNothingData, Peer.DisconnectedData(Map.empty, activeChannels = Set.empty, PeerStorage(None, written = true)))) + peer.send(reconnectionTask, Peer.Transition(PeerNothingData, Peer.DisconnectedData(Map.empty, activeChannels = Set.empty, PeerStorage(None, written = true), None))) monitor.expectNoMessage() } @@ -205,7 +205,6 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike peer.send(reconnectionTask, Peer.Transition(PeerDisconnectedData, PeerConnectedData)) // we cancel the reconnection and go to idle state val TransitionWithData(ReconnectionTask.WAITING, ReconnectionTask.IDLE, _, _) = monitor.expectMsgType[TransitionWithData] - } test("reconnect using the address from node_announcement") { f => @@ -232,15 +231,13 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val tor = NodeAddress.fromParts("iq7zhmhck54vcax2vlrdcavq2m32wao7ekh6jyeglmnuuvv3js57r4id.onion", 9735).get // NB: we don't test randomization here, but it makes tests unnecessary more complex for little value - { // tor not supported: always return clearnet addresses nodeParams.socksProxy_opt returns None - assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet)) == Some(clearnet)) - assert(ReconnectionTask.selectNodeAddress(nodeParams, List(tor)) == None) - assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet, tor)) == Some(clearnet)) + assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet)).contains(clearnet)) + assert(ReconnectionTask.selectNodeAddress(nodeParams, List(tor)).isEmpty) + assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet, tor)).contains(clearnet)) } - { // tor supported but not enabled for clearnet addresses: return clearnet addresses when available val socksParams = mock[Socks5ProxyParams] @@ -248,11 +245,10 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike socksParams.useForIPv4 returns false socksParams.useForIPv6 returns false nodeParams.socksProxy_opt returns Some(socksParams) - assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet)) == Some(clearnet)) - assert(ReconnectionTask.selectNodeAddress(nodeParams, List(tor)) == Some(tor)) - assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet, tor)) == Some(clearnet)) + assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet)).contains(clearnet)) + assert(ReconnectionTask.selectNodeAddress(nodeParams, List(tor)).contains(tor)) + assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet, tor)).contains(clearnet)) } - { // tor supported and enabled for clearnet addresses: return tor addresses when available val socksParams = mock[Socks5ProxyParams] @@ -260,11 +256,10 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike socksParams.useForIPv4 returns true socksParams.useForIPv6 returns true nodeParams.socksProxy_opt returns Some(socksParams) - assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet)) == Some(clearnet)) - assert(ReconnectionTask.selectNodeAddress(nodeParams, List(tor)) == Some(tor)) - assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet, tor)) == Some(tor)) + assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet)).contains(clearnet)) + assert(ReconnectionTask.selectNodeAddress(nodeParams, List(tor)).contains(tor)) + assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet, tor)).contains(tor)) } - } }