Skip to content
This repository was archived by the owner on Sep 8, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions p2p/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,18 @@
to_list,
)

from eth_utils import (
int_to_big_endian,
big_endian_to_int,
)

from eth_keys import keys
from eth_keys import datatypes

from eth_hash.auto import keccak

from p2p.cancel_token import CancelToken
from p2p import kademlia
from evm.utils.numeric import (
big_endian_to_int,
int_to_big_endian,
)

# UDP packet constants.
MAC_SIZE = 256 // 8 # 32
Expand Down
194 changes: 135 additions & 59 deletions p2p/kademlia.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
import asyncio
import bisect
import collections
import contextlib
from functools import total_ordering
import ipaddress
import logging
import bisect
import operator
import random
import struct
import time
from urllib import parse as urlparse
from functools import total_ordering
from typing import (
Any,
Callable,
Dict,
Generator,
Hashable,
List,
Set,
Sized,
Tuple,
TYPE_CHECKING,
)
from urllib import parse as urlparse

from eth_utils import (
big_endian_to_int,
Expand Down Expand Up @@ -237,6 +240,8 @@ def get_random_nodes(self, count: int) -> Generator[Node, None, None]:
# to nodes.
while len(seen) < count:
bucket = random.choice(self.buckets)
if not bucket.nodes:
continue
node = random.choice(bucket.nodes)
if node not in seen:
yield node
Expand Down Expand Up @@ -320,16 +325,64 @@ def binary_get_bucket_for_node(buckets: List[KBucket], node: Node) -> KBucket:
raise ValueError("No bucket found for node with id {}".format(node.id))


class CallbackLock:
def __init__(self,
callback: Callable,
timeout: int=300) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about making the default here something like 2 * k_request_timeout as by then we can be sure nobody should be waiting for the callback anyway?

self.callback = callback
self.timeout = timeout
self.created_at = time.time()

@property
def is_expired(self) -> bool:
    """Whether more than ``timeout`` seconds have passed since creation.

    Compares the wall-clock time elapsed since ``created_at`` against
    ``timeout``; the lock counts as expired only once the elapsed time
    strictly exceeds it.
    """
    elapsed = time.time() - self.created_at
    return elapsed > self.timeout


class CallbackManager(collections.UserDict):
    """Mapping of keys to the ``CallbackLock`` currently registered for them.

    An entry whose lock reports ``is_expired`` is treated as free: ``locked()``
    returns False for it and ``acquire()`` will replace it.
    """

    @contextlib.contextmanager
    def acquire(self,
                key: Hashable,
                callback: Callable[..., Any]) -> Generator['CallbackLock', None, None]:
        """Register ``callback`` under ``key`` for the duration of a ``with`` block.

        :param key: identifier the callback is stored under.
        :param callback: callable later retrieved through ``get_callback(key)``.
        :raises AlreadyWaiting: if ``key`` is held by a lock that has not expired.

        Yields the newly created lock. On exit the entry is removed, but only
        if it is still the lock created by this call.
        """
        if self.locked(key):
            raise AlreadyWaiting("Already waiting on callback for: {0}".format(key))

        # Whatever is still stored under ``key`` here is an expired lock; evict it.
        self.pop(key, None)

        lock = CallbackLock(callback)
        self[key] = lock

        try:
            yield lock
        finally:
            self._release(key, lock)

    def _release(self, key: Hashable, lock: Any) -> None:
        """Remove ``key`` only while it still maps to ``lock``.

        If our lock expired and another caller took the key over, the entry
        now belongs to them: deleting it unconditionally would drop their
        callback and make their own release fail with ``KeyError``.
        """
        if self.get(key) is lock:
            del self[key]

    def get_callback(self, key: Hashable) -> Callable[..., Any]:
        """Return the callback registered under ``key``.

        :raises KeyError: if nothing is registered under ``key``.
        """
        return self[key].callback

    def locked(self, key: Hashable) -> bool:
        """Return True if ``key`` is registered and its lock has not expired."""
        try:
            lock = self[key]
        except KeyError:
            return False
        return not lock.is_expired


class KademliaProtocol:
logger = logging.getLogger("p2p.kademlia.KademliaProtocol")

def __init__(self, node: Node, wire: 'DiscoveryProtocol') -> None:
self.this_node = node
self.wire = wire
self.routing = RoutingTable(node)
self.pong_callbacks: Dict[bytes, Callable[[], None]] = {}
self.ping_callbacks: Dict[Node, Callable[[], None]] = {}
self.neighbours_callbacks: Dict[Node, Callable[[List[Node]], None]] = {}

self.pong_callbacks = CallbackManager()
self.ping_callbacks = CallbackManager()
self.neighbours_callbacks = CallbackManager()

def recv_neighbours(self, remote: Node, neighbours: List[Node]) -> None:
"""Process a neighbours response.
Expand All @@ -340,12 +393,13 @@ def recv_neighbours(self, remote: Node, neighbours: List[Node]) -> None:
wait_neighbours().
"""
self.logger.debug('<<< neighbours from %s: %s', remote, neighbours)
callback = self.neighbours_callbacks.get(remote)
if callback is not None:
callback(neighbours)
else:
try:
callback = self.neighbours_callbacks.get_callback(remote)
except KeyError:
self.logger.debug(
'unexpected neighbours from %s, probably came too late', remote)
else:
callback(neighbours)

def recv_pong(self, remote: Node, token: bytes) -> None:
"""Process a pong packet.
Expand All @@ -356,11 +410,13 @@ def recv_pong(self, remote: Node, token: bytes) -> None:
"""
self.logger.debug('<<< pong from %s (token == %s)', remote, encode_hex(token))
pingid = self._mkpingid(token, remote)
callback = self.pong_callbacks.get(pingid)
if callback is not None:
callback()
else:

try:
callback = self.pong_callbacks.get_callback(pingid)
except KeyError:
self.logger.debug('unexpected pong from %s (token == %s)', remote, encode_hex(token))
else:
callback()

def recv_ping(self, remote: Node, hash_: bytes) -> None:
"""Process a received ping packet.
Expand All @@ -372,10 +428,14 @@ def recv_ping(self, remote: Node, hash_: bytes) -> None:
self.logger.debug('<<< ping from %s', remote)
self.update_routing_table(remote)
self.wire.send_pong(remote, hash_)
# Sometimes a ping will be sent to us as part of the bond()ing performed the first time we
# see a node, and it is in those cases that a callback will exist.
callback = self.ping_callbacks.get(remote)
if callback is not None:
# Sometimes a ping will be sent to us as part of the bonding
# performed the first time we see a node, and it is in those cases that
# a callback will exist.
try:
callback = self.ping_callbacks.get_callback(remote)
except KeyError:
pass
else:
callback()

def recv_find_node(self, remote: Node, targetid: int) -> None:
Expand Down Expand Up @@ -411,16 +471,16 @@ async def wait_ping(self, remote: Node, cancel_token: CancelToken) -> bool:
"There's another coroutine waiting for a ping packet from {}".format(remote))

event = asyncio.Event()
self.ping_callbacks[remote] = event.set
got_ping = False
try:
got_ping = await wait_with_token(
event.wait(), token=cancel_token, timeout=k_request_timeout)
self.logger.debug('got expected ping from %s', remote)
except TimeoutError:
self.logger.debug('timed out waiting for ping from %s', remote)
# TODO: Use a contextmanager to ensure we always delete the callback from the list.
del self.ping_callbacks[remote]

with self.ping_callbacks.acquire(remote, event.set):
got_ping = False
try:
got_ping = await wait_with_token(
event.wait(), token=cancel_token, timeout=k_request_timeout)
self.logger.debug('got expected ping from %s', remote)
except TimeoutError:
self.logger.debug('timed out waiting for ping from %s', remote)

return got_ping

async def wait_pong(self, remote: Node, token: bytes, cancel_token: CancelToken) -> bool:
Expand All @@ -433,23 +493,28 @@ async def wait_pong(self, remote: Node, token: bytes, cancel_token: CancelToken)
pingid = self._mkpingid(token, remote)
if pingid in self.pong_callbacks:
raise AlreadyWaiting(
"There's another coroutine waiting for a pong packet with id {}".format(pingid))
"There's another coroutine waiting for a pong packet with id "
"{}".format(pingid)
)

event = asyncio.Event()
self.pong_callbacks[pingid] = event.set
got_pong = False
try:
got_pong = await wait_with_token(
event.wait(), token=cancel_token, timeout=k_request_timeout)
self.logger.debug('got expected pong with token %s', encode_hex(token))
except TimeoutError:
self.logger.debug(
'timed out waiting for pong from %s (token == %s)', remote, encode_hex(token))
# TODO: Use a contextmanager to ensure we always delete the callback from the list.
del self.pong_callbacks[pingid]

with self.pong_callbacks.acquire(pingid, event.set):
got_pong = False
try:
got_pong = await wait_with_token(
event.wait(), token=cancel_token, timeout=k_request_timeout)
self.logger.debug('got expected pong with token %s', encode_hex(token))
except TimeoutError:
self.logger.debug(
'timed out waiting for pong from %s (token == %s)',
remote,
encode_hex(token),
)

return got_pong

async def wait_neighbours(self, remote: Node, cancel_token: CancelToken) -> List[Node]:
async def wait_neighbours(self, remote: Node, cancel_token: CancelToken) -> Tuple[Node, ...]:
"""Wait for a neighbours packet from the given node.

Returns a tuple of the neighbours received.
Expand All @@ -469,17 +534,15 @@ def process(response):
if len(neighbours) == k_bucket_size:
event.set()

self.neighbours_callbacks[remote] = process
try:
await wait_with_token(
event.wait(), token=cancel_token, timeout=k_request_timeout)
self.logger.debug('got expected neighbours response from %s', remote)
except TimeoutError:
self.logger.debug('timed out waiting for neighbours response from %s', remote)
with self.neighbours_callbacks.acquire(remote, process):
try:
await wait_with_token(
event.wait(), token=cancel_token, timeout=k_request_timeout)
self.logger.debug('got expected neighbours response from %s', remote)
except TimeoutError:
self.logger.debug('timed out waiting for neighbours response from %s', remote)

# TODO: Use a contextmanager to ensure we always delete the callback from the list.
del self.neighbours_callbacks[remote]
return [n for n in neighbours if n != self.this_node]
return tuple(n for n in neighbours if n != self.this_node)

def ping(self, node: Node) -> bytes:
if node == self.this_node:
Expand Down Expand Up @@ -516,7 +579,12 @@ async def bond(self, node: Node, cancel_token: CancelToken) -> bool:
return True

async def bootstrap(self, bootstrap_nodes: List[Node], cancel_token: CancelToken) -> None:
bonded = await asyncio.gather(*[self.bond(n, cancel_token) for n in bootstrap_nodes])
bonded = await asyncio.gather(*(
self.bond(n, cancel_token)
for n
in bootstrap_nodes
if (not self.ping_callbacks.locked(n) and not self.pong_callbacks.locked(n))
))
if not any(bonded):
self.logger.info("Failed to bond with bootstrap nodes %s", bootstrap_nodes)
return
Expand All @@ -539,15 +607,19 @@ async def _find_node(node_id, remote):
candidates = await self.wait_neighbours(remote, cancel_token)
if not candidates:
self.logger.debug("got no candidates from %s, returning", remote)
return candidates
candidates = [c for c in candidates if c not in nodes_seen]
return tuple()
all_candidates = tuple(c for c in candidates if c not in nodes_seen)
candidates = tuple(
c for c in all_candidates
if (not self.ping_callbacks.locked(c) and not self.pong_callbacks.locked(c))
)
self.logger.debug("got %s new candidates", len(candidates))
# Add new candidates to nodes_seen so that we don't attempt to bond with failing ones
# in the future.
nodes_seen.update(candidates)
bonded = await asyncio.gather(*[self.bond(c, cancel_token) for c in candidates])
bonded = await asyncio.gather(*(self.bond(c, cancel_token) for c in candidates))
self.logger.debug("bonded with %s candidates", bonded.count(True))
return [c for c in candidates if bonded[candidates.index(c)]]
return tuple(c for c in candidates if bonded[candidates.index(c)])

def _exclude_if_asked(nodes):
nodes_to_ask = list(set(nodes).difference(nodes_asked))
Expand All @@ -559,8 +631,12 @@ def _exclude_if_asked(nodes):
while nodes_to_ask:
self.logger.debug("node lookup; querying %s", nodes_to_ask)
nodes_asked.update(nodes_to_ask)
results = await asyncio.gather(
*[_find_node(node_id, n) for n in nodes_to_ask])
results = await asyncio.gather(*(
_find_node(node_id, n)
for n
in nodes_to_ask
if not self.neighbours_callbacks.locked(n)
))
for candidates in results:
closest.extend(candidates)
closest = sort_by_distance(closest, node_id)[:k_bucket_size]
Expand Down
4 changes: 2 additions & 2 deletions tests/p2p/test_kademlia.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async def test_wait_neighbours(cancel_token):

# Schedule a call to proto.recv_neighbours() simulating a neighbours response from the node we
# expect.
neighbours = [random_node(), random_node(), random_node()]
neighbours = (random_node(), random_node(), random_node())
recv_neighbours_coroutine = asyncio.coroutine(lambda: proto.recv_neighbours(node, neighbours))
asyncio.ensure_future(recv_neighbours_coroutine())

Expand All @@ -125,7 +125,7 @@ async def test_wait_neighbours(cancel_token):
# If wait_neighbours() times out, we get an empty tuple of neighbours.
received_neighbours = await proto.wait_neighbours(node, cancel_token)

assert received_neighbours == []
assert received_neighbours == tuple()
assert node not in proto.neighbours_callbacks


Expand Down