diff --git a/eth/p2p/discoveryv5/routing_table.nim b/eth/p2p/discoveryv5/routing_table.nim index 7385c089..8d10621b 100644 --- a/eth/p2p/discoveryv5/routing_table.nim +++ b/eth/p2p/discoveryv5/routing_table.nim @@ -14,7 +14,7 @@ import ../../net/utils, "."/[node, random2, enr] -export results +export results, chronos.timer declareGauge routing_table_nodes, "Discovery routing table nodes", labels = ["state"] @@ -47,6 +47,10 @@ type ## replacement caches. distanceCalculator: DistanceCalculator rng: ref HmacDrbgContext + bannedNodes: Table[NodeId, chronos.Moment] ## Nodes can be banned from the + ## routing table for a period until the timeout is reached. Banned nodes + ## are removed from the routing table and not allowed to be included again + ## until the timeout expires. KBucket = ref object istart, iend: NodeId ## Range of NodeIds this KBucket covers. This is not a @@ -95,6 +99,7 @@ type ReplacementAdded ReplacementExisting NoAddress + Banned # xor distance functions func distance*(a, b: NodeId): UInt256 = @@ -189,6 +194,51 @@ func ipLimitDec(r: var RoutingTable, b: KBucket, n: Node) = b.ipLimits.dec(ip) r.ipLimits.dec(ip) +func getNode*(r: RoutingTable, id: NodeId): Opt[Node] +proc replaceNode*(r: var RoutingTable, n: Node) + +proc banNode*(r: var RoutingTable, nodeId: NodeId, period: chronos.Duration) = + ## Ban a node from the routing table for the given period. The node is removed + ## from the routing table and replaced using a node from the replacement cache. + let banTimeout = now(chronos.Moment) + period + + if r.bannedNodes.contains(nodeId): + let existingTimeout = r.bannedNodes.getOrDefault(nodeId) + if existingTimeout < banTimeout: + r.bannedNodes[nodeId] = banTimeout + return # node is already banned so we don't need to try replacing it because + # it should have already been replaced when it was initially banned + + # NodeId doesn't yet exist in the banned nodes table + r.bannedNodes[nodeId] = banTimeout + + # Remove the node from the routing table + let node = r.getNode(nodeId) + if node.isSome(): + r.replaceNode(node.get()) + +proc isBanned*(r: RoutingTable, nodeId: NodeId): bool = + if not r.bannedNodes.contains(nodeId): + return false + + let + currentTime = now(chronos.Moment) + banTimeout = r.bannedNodes.getOrDefault(nodeId) + + currentTime < banTimeout + +proc cleanupExpiredBans*(r: var RoutingTable) = + ## Remove all expired bans from the banned nodes table + let currentTime = now(chronos.Moment) + + var expiredIds = newSeq[NodeId]() + for id, timeout in r.bannedNodes: + if currentTime >= timeout: + expiredIds.add(id) + + for id in expiredIds: + r.bannedNodes.del(id) + proc add(k: KBucket, n: Node) = k.nodes.add(n) routing_table_nodes.inc() @@ -274,7 +324,8 @@ func init*(T: type RoutingTable, localNode: Node, bitsPerHop = DefaultBitsPerHop bitsPerHop: bitsPerHop, ipLimits: IpLimits(limit: ipLimits.tableIpLimit), distanceCalculator: distanceCalculator, - rng: rng) + rng: rng, + bannedNodes: initTable[NodeId, chronos.Moment]()) func splitBucket(r: var RoutingTable, index: int) = let bucket = r.buckets[index] @@ -343,6 +394,9 @@ proc addNode*(r: var RoutingTable, n: Node): NodeStatus = if n == r.localNode: return LocalNode + if r.isBanned(n.id): + return Banned + let bucket = r.bucketForNode(n.id) ## Check if the node is already present. If so, check if the record requires diff --git a/tests/p2p/test_routing_table.nim b/tests/p2p/test_routing_table.nim index 54aa99d4..652262ce 100644 --- a/tests/p2p/test_routing_table.nim +++ b/tests/p2p/test_routing_table.nim @@ -1,6 +1,7 @@ {.used.} import + std/os, unittest2, ../../eth/common/keys, ../../eth/p2p/discoveryv5/[routing_table, node, enr], ./discv5_test_helper @@ -561,3 +562,127 @@ suite "Routing Table Tests": # there may be more than one node at provided distance check len(neighboursAtLogDist) >= 1 check neighboursAtLogDist.contains(n) + + test "Banned nodes: banned node cannot be added": + let + localNode = generateNode(PrivateKey.random(rng[])) + node1 = generateNode(PrivateKey.random(rng[])) + node2 = generateNode(PrivateKey.random(rng[])) + + var table = RoutingTable.init(localNode, 1, DefaultTableIpLimits, rng = rng) + + # Can add a node that is not banned + check: + table.contains(node1) == false + table.isBanned(node1.id) == false + table.addNode(node1) == Added + table.contains(node1) == true + table.isBanned(node1.id) == false + + # Can ban a node that exists in the routing table + table.banNode(node1.id, 1.minutes) + check: + table.contains(node1) == false # the node is removed when banned + table.isBanned(node1.id) == true + table.addNode(node1) == Banned # the node cannot be added while banned + table.contains(node1) == false + table.getNode(node1.id).isNone() + table.isBanned(node1.id) == true + + # Can ban a node that doesn't yet exist in the routing table + check: + table.contains(node2) == false + table.isBanned(node2.id) == false + + table.banNode(node2.id, 1.minutes) + check: + table.contains(node2) == false + table.isBanned(node2.id) == true + table.addNode(node2) == Banned # the node cannot be added while banned + table.contains(node2) == false + table.getNode(node2.id).isNone() + table.isBanned(node2.id) == true + + test "Banned nodes: nodes with expired bans can be added": + let + localNode = generateNode(PrivateKey.random(rng[])) + node1 = generateNode(PrivateKey.random(rng[])) + node2 = generateNode(PrivateKey.random(rng[])) + + var table = RoutingTable.init(localNode, 1, DefaultTableIpLimits, rng = rng) + + check table.addNode(node1) == Added + table.banNode(node1.id, 1.nanoseconds) + table.banNode(node2.id, 1.nanoseconds) + + sleep(1) + + # Can add nodes for which the ban has expired + check: + table.contains(node1) == false + table.isBanned(node1.id) == false + table.addNode(node1) == Added + table.contains(node1) == true + table.isBanned(node1.id) == false + + table.contains(node2) == false + table.isBanned(node2.id) == false + table.addNode(node2) == Added + table.contains(node2) == true + table.isBanned(node2.id) == false + + test "Banned nodes: ban nodes with existing bans": + let + localNode = generateNode(PrivateKey.random(rng[])) + node1 = generateNode(PrivateKey.random(rng[])) + node2 = generateNode(PrivateKey.random(rng[])) + + var table = RoutingTable.init(localNode, 1, DefaultTableIpLimits, rng = rng) + + check: + table.addNode(node1) == Added + table.addNode(node2) == Added + + table.banNode(node1.id, 1.nanoseconds) + sleep(1) # node1's ban is expired + table.banNode(node2.id, 1.minutes) + + check: + table.isBanned(node1.id) == false + table.isBanned(node2.id) == true + + # Can ban nodes which were previously banned + table.banNode(node1.id, 1.minutes) + table.banNode(node2.id, 1.minutes) + + check: + table.contains(node1) == false + table.isBanned(node1.id) == true + table.contains(node2) == false + table.isBanned(node2.id) == true + + test "Banned nodes: cleanup expired bans": + let + localNode = generateNode(PrivateKey.random(rng[])) + node1 = generateNode(PrivateKey.random(rng[])) + node2 = generateNode(PrivateKey.random(rng[])) + + var table = RoutingTable.init(localNode, 1, DefaultTableIpLimits, rng = rng) + + table.banNode(node1.id, 1.nanoseconds) + sleep(1) # node1's ban is expired + table.banNode(node2.id, 1.minutes) + + check: + table.isBanned(node1.id) == false + table.isBanned(node2.id) == true + table.addNode(node1) == Added + table.addNode(node2) == Banned + + table.cleanupExpiredBans() + + check: + table.isBanned(node1.id) == false + table.isBanned(node2.id) == true + table.addNode(node1) == Existing + table.addNode(node2) == Banned