9 changes: 5 additions & 4 deletions p2p/discovery.py
@@ -24,17 +24,18 @@
to_list,
)

from eth_utils import (
int_to_big_endian,
big_endian_to_int,
)

from eth_keys import keys
from eth_keys import datatypes

from eth_hash.auto import keccak

from p2p.cancel_token import CancelToken
from p2p import kademlia
from evm.utils.numeric import (
big_endian_to_int,
int_to_big_endian,
)

# UDP packet constants.
MAC_SIZE = 256 // 8 # 32
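Aside (illustration only, not part of the diff): the change above swaps the evm.utils.numeric imports for the equivalent helpers in eth_utils. A minimal sketch of what those two functions do, assuming only that eth_utils is installed:

    from eth_utils import big_endian_to_int, int_to_big_endian

    value = 0x1f90  # 8080
    encoded = int_to_big_endian(value)        # b'\x1f\x90'
    assert big_endian_to_int(encoded) == value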
224 changes: 149 additions & 75 deletions p2p/kademlia.py
@@ -1,23 +1,26 @@
import asyncio
import bisect
import collections
import contextlib
from functools import total_ordering
import ipaddress
import logging
import bisect
import operator
import random
import struct
import time
from urllib import parse as urlparse
from functools import total_ordering
from typing import (
Any,
Callable,
Dict,
Generator,
Hashable,
List,
Set,
Sized,
Tuple,
TYPE_CHECKING,
)
from urllib import parse as urlparse

from eth_utils import (
big_endian_to_int,
@@ -237,6 +240,8 @@ def get_random_nodes(self, count: int) -> Generator[Node, None, None]:
# to nodes.
while len(seen) < count:
bucket = random.choice(self.buckets)
if not bucket.nodes:
continue
node = random.choice(bucket.nodes)
if node not in seen:
yield node
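For context (illustration only, not part of the diff): the new guard that skips empty buckets is needed because random.choice() raises IndexError on an empty sequence, so a bucket must hold at least one node before sampling:

    import random

    empty_bucket_nodes: list = []
    try:
        random.choice(empty_bucket_nodes)
    except IndexError:
        pass  # the failure mode the 'if not bucket.nodes: continue' guard avoids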
@@ -320,16 +325,64 @@ def binary_get_bucket_for_node(buckets: List[KBucket], node: Node) -> KBucket:
raise ValueError("No bucket found for node with id {}".format(node.id))


class CallbackLock:
def __init__(self,
callback: Callable,
timeout: float=2 * k_request_timeout) -> None:
self.callback = callback
self.timeout = timeout
self.created_at = time.time()

@property
def is_expired(self) -> bool:
return time.time() - self.created_at > self.timeout


class CallbackManager(collections.UserDict):
@contextlib.contextmanager
def acquire(self,
key: Hashable,
callback: Callable[..., Any]) -> Generator[CallbackLock, None, None]:
if key in self:
if not self.locked(key):
del self[key]
else:
raise AlreadyWaiting("Already waiting on callback for: {0}".format(key))

lock = CallbackLock(callback)
self[key] = lock

try:
yield lock
finally:
del self[key]

def get_callback(self, key: Hashable) -> Callable[..., Any]:
return self[key].callback

def locked(self, key: Hashable) -> bool:
try:
lock = self[key]
except KeyError:
return False
else:
if lock.is_expired:
return False
else:
return True
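Illustrative usage sketch (not part of the diff) of the acquire() context manager, assuming the CallbackLock/CallbackManager definitions above and this module's k_request_timeout constant; the key and callback here are placeholders:

    import asyncio

    manager = CallbackManager()
    event = asyncio.Event()

    with manager.acquire('some-key', event.set):
        assert manager.locked('some-key')
        manager.get_callback('some-key')()  # invokes event.set()
    assert 'some-key' not in manager        # entry is removed on exit

    # Re-entering acquire() for the same key while the first lock is still
    # live (and not expired) raises AlreadyWaiting.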


class KademliaProtocol:
logger = logging.getLogger("p2p.kademlia.KademliaProtocol")

def __init__(self, node: Node, wire: 'DiscoveryProtocol') -> None:
self.this_node = node
self.wire = wire
self.routing = RoutingTable(node)
self.pong_callbacks: Dict[bytes, Callable[[], None]] = {}
self.ping_callbacks: Dict[Node, Callable[[], None]] = {}
self.neighbours_callbacks: Dict[Node, Callable[[List[Node]], None]] = {}

self.pong_callbacks = CallbackManager()
self.ping_callbacks = CallbackManager()
self.neighbours_callbacks = CallbackManager()

def recv_neighbours(self, remote: Node, neighbours: List[Node]) -> None:
"""Process a neighbours response.
@@ -340,12 +393,13 @@ def recv_neighbours(self, remote: Node, neighbours: List[Node]) -> None:
wait_neighbours().
"""
self.logger.debug('<<< neighbours from %s: %s', remote, neighbours)
callback = self.neighbours_callbacks.get(remote)
if callback is not None:
callback(neighbours)
else:
try:
callback = self.neighbours_callbacks.get_callback(remote)
except KeyError:
self.logger.debug(
'unexpected neighbours from %s, probably came too late', remote)
else:
callback(neighbours)

def recv_pong(self, remote: Node, token: bytes) -> None:
"""Process a pong packet.
@@ -356,11 +410,13 @@ def recv_pong(self, remote: Node, token: bytes) -> None:
"""
self.logger.debug('<<< pong from %s (token == %s)', remote, encode_hex(token))
pingid = self._mkpingid(token, remote)
callback = self.pong_callbacks.get(pingid)
if callback is not None:
callback()
else:

try:
callback = self.pong_callbacks.get_callback(pingid)
except KeyError:
self.logger.debug('unexpected pong from %s (token == %s)', remote, encode_hex(token))
else:
callback()

def recv_ping(self, remote: Node, hash_: bytes) -> None:
"""Process a received ping packet.
@@ -372,10 +428,14 @@ def recv_ping(self, remote: Node, hash_: bytes) -> None:
self.logger.debug('<<< ping from %s', remote)
self.update_routing_table(remote)
self.wire.send_pong(remote, hash_)
# Sometimes a ping will be sent to us as part of the bond()ing performed the first time we
# see a node, and it is in those cases that a callback will exist.
callback = self.ping_callbacks.get(remote)
if callback is not None:
# Sometimes a ping will be sent to us as part of the bonding
# performed the first time we see a node, and it is in those cases that
# a callback will exist.
try:
callback = self.ping_callbacks.get_callback(remote)
except KeyError:
pass
else:
callback()

def recv_find_node(self, remote: Node, targetid: int) -> None:
@@ -406,21 +466,17 @@ async def wait_ping(self, remote: Node, cancel_token: CancelToken) -> bool:
called or a timeout (k_request_timeout) occurs. At that point it returns whether or not
a ping was received from the given node.
"""
if remote in self.ping_callbacks:
raise AlreadyWaiting(
"There's another coroutine waiting for a ping packet from {}".format(remote))

event = asyncio.Event()
self.ping_callbacks[remote] = event.set
got_ping = False
try:
got_ping = await wait_with_token(
event.wait(), token=cancel_token, timeout=k_request_timeout)
self.logger.debug('got expected ping from %s', remote)
except TimeoutError:
self.logger.debug('timed out waiting for ping from %s', remote)
# TODO: Use a contextmanager to ensure we always delete the callback from the list.
del self.ping_callbacks[remote]

with self.ping_callbacks.acquire(remote, event.set):
got_ping = False
try:
got_ping = await wait_with_token(
event.wait(), token=cancel_token, timeout=k_request_timeout)
self.logger.debug('got expected ping from %s', remote)
except TimeoutError:
self.logger.debug('timed out waiting for ping from %s', remote)

return got_ping

async def wait_pong(self, remote: Node, token: bytes, cancel_token: CancelToken) -> bool:
@@ -431,33 +487,28 @@ async def wait_pong(self, remote: Node, token: bytes, cancel_token: CancelToken)
a pong was received with the given pingid.
"""
pingid = self._mkpingid(token, remote)
if pingid in self.pong_callbacks:
raise AlreadyWaiting(
"There's another coroutine waiting for a pong packet with id {}".format(pingid))

event = asyncio.Event()
self.pong_callbacks[pingid] = event.set
got_pong = False
try:
got_pong = await wait_with_token(
event.wait(), token=cancel_token, timeout=k_request_timeout)
self.logger.debug('got expected pong with token %s', encode_hex(token))
except TimeoutError:
self.logger.debug(
'timed out waiting for pong from %s (token == %s)', remote, encode_hex(token))
# TODO: Use a contextmanager to ensure we always delete the callback from the list.
del self.pong_callbacks[pingid]

with self.pong_callbacks.acquire(pingid, event.set):
got_pong = False
try:
got_pong = await wait_with_token(
event.wait(), token=cancel_token, timeout=k_request_timeout)
self.logger.debug('got expected pong with token %s', encode_hex(token))
except TimeoutError:
self.logger.debug(
'timed out waiting for pong from %s (token == %s)',
remote,
encode_hex(token),
)

return got_pong

async def wait_neighbours(self, remote: Node, cancel_token: CancelToken) -> List[Node]:
async def wait_neighbours(self, remote: Node, cancel_token: CancelToken) -> Tuple[Node, ...]:
"""Wait for a neihgbours packet from the given node.

Returns the neighbours received.
"""
if remote in self.neighbours_callbacks:
raise AlreadyWaiting(
"There's another coroutine waiting for a neighbours packet from {}".format(remote))

event = asyncio.Event()
neighbours: List[Node] = []

@@ -469,17 +520,15 @@ def process(response):
if len(neighbours) == k_bucket_size:
event.set()

self.neighbours_callbacks[remote] = process
try:
await wait_with_token(
event.wait(), token=cancel_token, timeout=k_request_timeout)
self.logger.debug('got expected neighbours response from %s', remote)
except TimeoutError:
self.logger.debug('timed out waiting for neighbours response from %s', remote)
with self.neighbours_callbacks.acquire(remote, process):
try:
await wait_with_token(
event.wait(), token=cancel_token, timeout=k_request_timeout)
self.logger.debug('got expected neighbours response from %s', remote)
except TimeoutError:
self.logger.debug('timed out waiting for neighbours response from %s', remote)

# TODO: Use a contextmanager to ensure we always delete the callback from the list.
del self.neighbours_callbacks[remote]
return [n for n in neighbours if n != self.this_node]
return tuple(n for n in neighbours if n != self.this_node)

def ping(self, node: Node) -> bytes:
if node == self.this_node:
@@ -494,10 +543,17 @@ async def bond(self, node: Node, cancel_token: CancelToken) -> bool:
"""
if node in self.routing:
return True
elif node == self.this_node:
return False

token = self.ping(node)

got_pong = await self.wait_pong(node, token, cancel_token)
try:
got_pong = await self.wait_pong(node, token, cancel_token)
except AlreadyWaiting:
self.logger.debug("binding failed, already waiting for pong")
return False

if not got_pong:
self.logger.debug("bonding failed, didn't receive pong from %s", node)
# Drop the failing node and schedule a populate_not_full_buckets() call to try and
Expand All @@ -506,17 +562,27 @@ async def bond(self, node: Node, cancel_token: CancelToken) -> bool:
asyncio.ensure_future(self.populate_not_full_buckets())
return False

# Give the remote node a chance to ping us before we move on and start sending find_node
# requests. It is ok for wait_ping() to timeout and return false here as that just means
# the remote remembers us.
await self.wait_ping(node, cancel_token)
try:
# Give the remote node a chance to ping us before we move on and
# start sending find_node requests. It is ok for wait_ping() to
# timeout and return false here as that just means the remote
# remembers us.
await self.wait_ping(node, cancel_token)
except AlreadyWaiting:
self.logger.debug("binding failed, already waiting for ping")
return False

self.logger.debug("bonding completed successfully with %s", node)
self.update_routing_table(node)
return True

async def bootstrap(self, bootstrap_nodes: List[Node], cancel_token: CancelToken) -> None:
bonded = await asyncio.gather(*[self.bond(n, cancel_token) for n in bootstrap_nodes])
bonded = await asyncio.gather(*(
self.bond(n, cancel_token)
for n
in bootstrap_nodes
if (not self.ping_callbacks.locked(n) and not self.pong_callbacks.locked(n))
))
if not any(bonded):
self.logger.info("Failed to bond with bootstrap nodes %s", bootstrap_nodes)
return
@@ -539,15 +605,19 @@ async def _find_node(node_id, remote):
candidates = await self.wait_neighbours(remote, cancel_token)
if not candidates:
self.logger.debug("got no candidates from %s, returning", remote)
return candidates
candidates = [c for c in candidates if c not in nodes_seen]
return tuple()
all_candidates = tuple(c for c in candidates if c not in nodes_seen)
candidates = tuple(
c for c in all_candidates
if (not self.ping_callbacks.locked(c) and not self.pong_callbacks.locked(c))
)
self.logger.debug("got %s new candidates", len(candidates))
# Add new candidates to nodes_seen so that we don't attempt to bond with failing ones
# in the future.
nodes_seen.update(candidates)
bonded = await asyncio.gather(*[self.bond(c, cancel_token) for c in candidates])
bonded = await asyncio.gather(*(self.bond(c, cancel_token) for c in candidates))
self.logger.debug("bonded with %s candidates", bonded.count(True))
return [c for c in candidates if bonded[candidates.index(c)]]
return tuple(c for c in candidates if bonded[candidates.index(c)])

def _exclude_if_asked(nodes):
nodes_to_ask = list(set(nodes).difference(nodes_asked))
@@ -559,8 +629,12 @@ def _exclude_if_asked(nodes):
while nodes_to_ask:
self.logger.debug("node lookup; querying %s", nodes_to_ask)
nodes_asked.update(nodes_to_ask)
results = await asyncio.gather(
*[_find_node(node_id, n) for n in nodes_to_ask])
results = await asyncio.gather(*(
_find_node(node_id, n)
for n
in nodes_to_ask
if not self.neighbours_callbacks.locked(n)
))
for candidates in results:
closest.extend(candidates)
closest = sort_by_distance(closest, node_id)[:k_bucket_size]
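A caller-side sketch (illustrative only, not part of the diff) of the bonding handshake shown above: bond() pings the remote, waits for its pong, then gives the remote a chance to ping us back before adding it to the routing table, and with this change it returns False instead of raising when another coroutine already holds the ping or pong callback lock. The import paths follow the ones used in this PR; bond_and_check is a hypothetical helper:

    from p2p.cancel_token import CancelToken
    from p2p.kademlia import KademliaProtocol, Node

    async def bond_and_check(proto: KademliaProtocol,
                             node: Node,
                             cancel_token: CancelToken) -> bool:
        # AlreadyWaiting is caught inside bond(), so a concurrent handshake
        # with the same node simply yields False here.
        success = await proto.bond(node, cancel_token)
        if success:
            # A successful handshake puts the node in the routing table.
            assert node in proto.routing
        return success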