Skip to content

Commit

Permalink
Initial version of the DHT
Browse files Browse the repository at this point in the history
  • Loading branch information
jensenbox committed Oct 3, 2017
1 parent d6501ea commit 33bf3ab
Show file tree
Hide file tree
Showing 12 changed files with 618 additions and 28 deletions.
59 changes: 59 additions & 0 deletions agent/sn_agent/network/dht/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import logging

from sn_agent.network.dht.dht import DHT

from sn_agent.agent.base import AgentABC
from sn_agent.network.base import NetworkABC
from sn_agent.network.enum import NetworkStatus
from sn_agent.ontology.service_descriptor import ServiceDescriptor

logger = logging.getLogger(__name__)


class DHTNetwork(NetworkABC):
def __init__(self, app):
super().__init__(app)
import nacl.signing



host1, port1 = 'localhost', 3000
dht1 = DHT(key1)

key2 = nacl.signing.SigningKey.generate()

host2, port2 = 'localhost', 3001
dht2 = DHT(host2, port2, key2, boot_host=host1, boot_port=port1)

dht1["test2"] = ["My", "json-serializable", "Object"]
print(dht2["test2"])

def update_ontology(self):
super().update_ontology()

def is_agent_a_member(self, agent: AgentABC) -> bool:
return super().is_agent_a_member(agent)

def leave_network(self) -> bool:
return super().leave_network()

def join_network(self) -> bool:
return super().join_network()

def get_network_status(self) -> NetworkStatus:
return super().get_network_status()

def advertise_service(self, service: ServiceDescriptor):
super().advertise_service(service)

def logon_network(self) -> bool:
return super().logon_network()

def find_service_providers(self, service: ServiceDescriptor) -> list:
return super().find_service_providers(service)

def remove_service_advertisement(self, service: ServiceDescriptor):
super().remove_service_advertisement(service)

def logoff_network(self) -> bool:
return super().logoff_network()
43 changes: 43 additions & 0 deletions agent/sn_agent/network/dht/bucketset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import heapq
import threading

from .peer import Peer


def largest_differing_bit(value1, value2):
distance = value1 ^ value2
length = -1
while distance:
distance >>= 1
length += 1
return max(0, length)


class BucketSet(object):
def __init__(self, bucket_size, buckets, id):
self.id = id
self.bucket_size = bucket_size
self.buckets = [list() for _ in range(buckets)]
self.lock = threading.Lock()

def insert(self, peer):
if peer.id != self.id:
bucket_number = largest_differing_bit(self.id, peer.id)
peer_triple = peer.astriple()
with self.lock:
bucket = self.buckets[bucket_number]
if peer_triple in bucket:
bucket.pop(bucket.index(peer_triple))
elif len(bucket) >= self.bucket_size:
bucket.pop(0)
bucket.append(peer_triple)

def nearest_nodes(self, key):

with self.lock:
def keyfunction(peer):
return key ^ peer[2] # ideally there would be a better way with names? Instead of storing triples it would be nice to have a dict

peers = (peer for bucket in self.buckets for peer in bucket)
best_peers = heapq.nsmallest(self.bucket_size, peers, keyfunction)
return [Peer(*peer) for peer in best_peers]
124 changes: 124 additions & 0 deletions agent/sn_agent/network/dht/dht.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import logging
import random
import threading
import time

import nacl.encoding
import nacl.signing

from .handler import DHTRequestHandler
from .server import DHTServer
from .settings import DHTSettings
from .bucketset import BucketSet
from .hashing import hash_function, random_id
from .peer import Peer
from .shortlist import Shortlist

logger = logging.getLogger(__name__)


class DHT(object):
def __init__(self, host=None, port=None, key=None, my_id=None):
self.settings = DHTSettings()

if not my_id:
my_id = random_id()

if not host:
host = "0.0.0.0"

if not port:
port = random.randint(5000, 10000)

if not key:
key = nacl.signing.SigningKey.generate()
self.my_key = key

self.peer = Peer(host, port, my_id)
self.data = {}
self.buckets = BucketSet(self.settings.K, self.settings.ID_BITS, self.peer.id)
self.rpc_ids = {} # should probably have a lock for this

self.server = DHTServer(self.peer.address(), DHTRequestHandler)
self.server.dht = self
self.server_thread = threading.Thread(target=self.server.serve_forever)
self.server_thread.daemon = True
self.server_thread.start()

if self.settings.USE_UPNP:
self.server.try_upnp_portmap(port)

if self.settings.NEEDS_BOOTING:
self.is_boot_node = False
self.bootstrap(self.settings.BOOT_HOST, self.settings.BOOT_PORT)
else:
self.is_boot_node = True

logger.debug('DHT Server started')

def iterative_find_nodes(self, key, boot_peer=None):
logger.debug('Finding nearest nodes...')
shortlist = Shortlist(self.settings.K, key)
shortlist.update(self.buckets.nearest_nodes(key))

if boot_peer:
logger.debug('This node a boot node: %s', boot_peer)
rpc_id = random.getrandbits(self.settings.ID_BITS)
self.rpc_ids[rpc_id] = shortlist
boot_peer.find_node(key, rpc_id, socket=self.server.socket, peer_id=self.peer.id)

while (not shortlist.complete()) or boot_peer:
nearest_nodes = shortlist.get_next_iteration(self.settings.ALPHA)
for peer in nearest_nodes:
logger.debug('Nearest Node: %s', peer)
shortlist.mark(peer)
rpc_id = random.getrandbits(self.settings.ID_BITS)
self.rpc_ids[rpc_id] = shortlist
peer.find_node(key, rpc_id, socket=self.server.socket, peer_id=self.peer.id)
time.sleep(self.settings.ITERATION_SLEEP)
boot_peer = None

return shortlist.results()

def iterative_find_value(self, key):
shortlist = Shortlist(self.settings.K, key)
shortlist.update(self.buckets.nearest_nodes(key))
while not shortlist.complete():
nearest_nodes = shortlist.get_next_iteration(self.settings.ALPHA)
for peer in nearest_nodes:
shortlist.mark(peer)
rpc_id = random.getrandbits(self.settings.ID_BITS)
self.rpc_ids[rpc_id] = shortlist
peer.find_value(key, rpc_id, socket=self.server.socket, peer_id=self.peer.id)
time.sleep(self.settings.ITERATION_SLEEP)
return shortlist.completion_result()

def bootstrap(self, boot_host, boot_port):
boot_peer = Peer(boot_host, boot_port, 0)
self.iterative_find_nodes(self.peer.id, boot_peer=boot_peer)

def __getitem__(self, key, bypass=0):
hashed_key = hash_function(key.encode("ascii"))
if hashed_key in self.data:
return self.data[hashed_key]["content"]
result = self.iterative_find_value(hashed_key)
if result:
return result["content"]

raise KeyError

def __setitem__(self, key, content):
content = str(content)
hashed_key = hash_function(key.encode("ascii"))
nearest_nodes = self.iterative_find_nodes(hashed_key)
value = {
"content": content,
"key": self.my_key.verify_key.encode(encoder=nacl.encoding.Base64Encoder).decode("utf-8"),
"signature": nacl.encoding.Base64Encoder.encode(self.my_key.sign(content.encode("ascii"))).decode("utf-8")
}

if not nearest_nodes:
self.data[hashed_key] = value

for node in nearest_nodes:
node.store(hashed_key, value, socket=self.server.socket, peer_id=self.peer.id)
102 changes: 102 additions & 0 deletions agent/sn_agent/network/dht/handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import json
import logging
import socketserver

import nacl.encoding
import nacl.signing

from .peer import Peer

logger = logging.getLogger(__name__)


class DHTRequestHandler(socketserver.BaseRequestHandler):
def handle(self):
try:
message = json.loads(self.request[0].decode("utf-8").strip())
message_type = message["message_type"]

logger.debug('Handling message type" %s', message_type)

if message_type == "ping":
self.handle_ping(message)
elif message_type == "pong":
self.handle_pong(message)
elif message_type == "find_node":
self.handle_find(message)
elif message_type == "find_value":
self.handle_find(message, find_value=True)
elif message_type == "found_nodes":
self.handle_found_nodes(message)
elif message_type == "found_value":
self.handle_found_value(message)
elif message_type == "store":
self.handle_store(message)
elif message_type == "push":
self.handle_push(message)
except:
return

client_host, client_port = self.client_address
peer_id = message["peer_id"]
new_peer = Peer(client_host, client_port, peer_id)
self.server.dht.buckets.insert(new_peer)

def handle_ping(self, message):
client_host, client_port = self.client_address
id = message["peer_id"]
peer = Peer(client_host, client_port, id)
peer.pong(socket=self.server.socket, peer_id=self.server.dht.peer.id, lock=self.server.send_lock)

def handle_pong(self, message):
pass

def handle_find(self, message, find_value=False):
key = message["id"]
id = message["peer_id"]
client_host, client_port = self.client_address
peer = Peer(client_host, client_port, id)
response_socket = self.request[1]
if find_value and (key in self.server.dht.data):
value = self.server.dht.data[key]
peer.found_value(id, value, message["rpc_id"], socket=response_socket, peer_id=self.server.dht.peer.id, lock=self.server.send_lock)
else:
nearest_nodes = self.server.dht.buckets.nearest_nodes(id)
if not nearest_nodes:
nearest_nodes.append(self.server.dht.peer)
nearest_nodes = [nearest_peer.astriple() for nearest_peer in nearest_nodes]
peer.found_nodes(id, nearest_nodes, message["rpc_id"], socket=response_socket, peer_id=self.server.dht.peer.id, lock=self.server.send_lock)

def handle_found_nodes(self, message):
rpc_id = message["rpc_id"]
shortlist = self.server.dht.rpc_ids[rpc_id]
del self.server.dht.rpc_ids[rpc_id]
nearest_nodes = [Peer(*peer) for peer in message["nearest_nodes"]]
shortlist.update(nearest_nodes)

def handle_found_value(self, message):
rpc_id = message["rpc_id"]
shortlist = self.server.dht.rpc_ids[rpc_id]
del self.server.dht.rpc_ids[rpc_id]
shortlist.set_complete(message["value"])

def handle_store(self, message):
key = message["id"]

# Verify updated message is signed with same key.
if key in self.server.dht.data:
# Signature is valid.
# (Raises exception if not.)
ret = nacl.signing.VerifyKey(self.server.dht.data[key]["key"], encoder=nacl.encoding.Base64Encoder).verify(nacl.encoding.Base64Encoder.decode(message["value"]["signature"]))
if type(ret) == bytes:
ret = ret.decode("utf-8")

# Check that the signature corresponds to this message.
message_content = message["value"]["content"]
if ret != message_content:
return

self.server.dht.data[key] = message["value"]

def handle_push(self, message):
pass
14 changes: 14 additions & 0 deletions agent/sn_agent/network/dht/hashing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import hashlib
import random

id_bits = 128


def hash_function(data):
return int(hashlib.md5(data).hexdigest(), 16)


def random_id(seed=None):
if seed:
random.seed(seed)
return random.randint(0, (2 ** id_bits) - 1)
Empty file.
Loading

0 comments on commit 33bf3ab

Please sign in to comment.