diff --git a/p2p/discovery.py b/p2p/discovery.py index 02fb7ceae0..5b4eab55b9 100644 --- a/p2p/discovery.py +++ b/p2p/discovery.py @@ -24,6 +24,11 @@ to_list, ) +from eth_utils import ( + int_to_big_endian, + big_endian_to_int, +) + from eth_keys import keys from eth_keys import datatypes @@ -31,10 +36,6 @@ from p2p.cancel_token import CancelToken from p2p import kademlia -from evm.utils.numeric import ( - big_endian_to_int, - int_to_big_endian, -) # UDP packet constants. MAC_SIZE = 256 // 8 # 32 diff --git a/p2p/kademlia.py b/p2p/kademlia.py index 104c19883f..a4c894ab5f 100644 --- a/p2p/kademlia.py +++ b/p2p/kademlia.py @@ -1,23 +1,26 @@ import asyncio +import bisect +import collections +import contextlib +from functools import total_ordering import ipaddress import logging -import bisect import operator import random import struct import time -from urllib import parse as urlparse -from functools import total_ordering from typing import ( + Any, Callable, - Dict, Generator, + Hashable, List, Set, Sized, Tuple, TYPE_CHECKING, ) +from urllib import parse as urlparse from eth_utils import ( big_endian_to_int, @@ -237,6 +240,8 @@ def get_random_nodes(self, count: int) -> Generator[Node, None, None]: # to nodes. while len(seen) < count: bucket = random.choice(self.buckets) + if not bucket.nodes: + continue node = random.choice(bucket.nodes) if node not in seen: yield node @@ -320,6 +325,53 @@ def binary_get_bucket_for_node(buckets: List[KBucket], node: Node) -> KBucket: raise ValueError("No bucket found for node with id {}".format(node.id)) +class CallbackLock: + def __init__(self, + callback: Callable, + timeout: float=2 * k_request_timeout) -> None: + self.callback = callback + self.timeout = timeout + self.created_at = time.time() + + @property + def is_expired(self): + return time.time() - self.created_at > self.timeout + + +class CallbackManager(collections.UserDict): + @contextlib.contextmanager + def acquire(self, + key: Hashable, + callback: Callable[..., Any]) -> Generator[CallbackLock, None, None]: + if key in self: + if not self.locked(key): + del self[key] + else: + raise AlreadyWaiting("Already waiting on callback for: {0}".format(key)) + + lock = CallbackLock(callback) + self[key] = lock + + try: + yield lock + finally: + del self[key] + + def get_callback(self, key: Hashable) -> Callable[..., Any]: + return self[key].callback + + def locked(self, key: Hashable) -> bool: + try: + lock = self[key] + except KeyError: + return False + else: + if lock.is_expired: + return False + else: + return True + + class KademliaProtocol: logger = logging.getLogger("p2p.kademlia.KademliaProtocol") @@ -327,9 +379,10 @@ def __init__(self, node: Node, wire: 'DiscoveryProtocol') -> None: self.this_node = node self.wire = wire self.routing = RoutingTable(node) - self.pong_callbacks: Dict[bytes, Callable[[], None]] = {} - self.ping_callbacks: Dict[Node, Callable[[], None]] = {} - self.neighbours_callbacks: Dict[Node, Callable[[List[Node]], None]] = {} + + self.pong_callbacks = CallbackManager() + self.ping_callbacks = CallbackManager() + self.neighbours_callbacks = CallbackManager() def recv_neighbours(self, remote: Node, neighbours: List[Node]) -> None: """Process a neighbours response. @@ -340,12 +393,13 @@ def recv_neighbours(self, remote: Node, neighbours: List[Node]) -> None: wait_neighbours(). """ self.logger.debug('<<< neighbours from %s: %s', remote, neighbours) - callback = self.neighbours_callbacks.get(remote) - if callback is not None: - callback(neighbours) - else: + try: + callback = self.neighbours_callbacks.get_callback(remote) + except KeyError: self.logger.debug( 'unexpected neighbours from %s, probably came too late', remote) + else: + callback(neighbours) def recv_pong(self, remote: Node, token: bytes) -> None: """Process a pong packet. @@ -356,11 +410,13 @@ def recv_pong(self, remote: Node, token: bytes) -> None: """ self.logger.debug('<<< pong from %s (token == %s)', remote, encode_hex(token)) pingid = self._mkpingid(token, remote) - callback = self.pong_callbacks.get(pingid) - if callback is not None: - callback() - else: + + try: + callback = self.pong_callbacks.get_callback(pingid) + except KeyError: self.logger.debug('unexpected pong from %s (token == %s)', remote, encode_hex(token)) + else: + callback() def recv_ping(self, remote: Node, hash_: bytes) -> None: """Process a received ping packet. @@ -372,10 +428,14 @@ def recv_ping(self, remote: Node, hash_: bytes) -> None: self.logger.debug('<<< ping from %s', remote) self.update_routing_table(remote) self.wire.send_pong(remote, hash_) - # Sometimes a ping will be sent to us as part of the bond()ing performed the first time we - # see a node, and it is in those cases that a callback will exist. - callback = self.ping_callbacks.get(remote) - if callback is not None: + # Sometimes a ping will be sent to us as part of the bonding + # performed the first time we see a node, and it is in those cases that + # a callback will exist. + try: + callback = self.ping_callbacks.get_callback(remote) + except KeyError: + pass + else: callback() def recv_find_node(self, remote: Node, targetid: int) -> None: @@ -406,21 +466,17 @@ async def wait_ping(self, remote: Node, cancel_token: CancelToken) -> bool: called or a timeout (k_request_timeout) occurs. At that point it returns whether or not a ping was received from the given node. """ - if remote in self.ping_callbacks: - raise AlreadyWaiting( - "There's another coroutine waiting for a ping packet from {}".format(remote)) - event = asyncio.Event() - self.ping_callbacks[remote] = event.set - got_ping = False - try: - got_ping = await wait_with_token( - event.wait(), token=cancel_token, timeout=k_request_timeout) - self.logger.debug('got expected ping from %s', remote) - except TimeoutError: - self.logger.debug('timed out waiting for ping from %s', remote) - # TODO: Use a contextmanager to ensure we always delete the callback from the list. - del self.ping_callbacks[remote] + + with self.ping_callbacks.acquire(remote, event.set): + got_ping = False + try: + got_ping = await wait_with_token( + event.wait(), token=cancel_token, timeout=k_request_timeout) + self.logger.debug('got expected ping from %s', remote) + except TimeoutError: + self.logger.debug('timed out waiting for ping from %s', remote) + return got_ping async def wait_pong(self, remote: Node, token: bytes, cancel_token: CancelToken) -> bool: @@ -431,33 +487,28 @@ async def wait_pong(self, remote: Node, token: bytes, cancel_token: CancelToken) a pong was received with the given pingid. """ pingid = self._mkpingid(token, remote) - if pingid in self.pong_callbacks: - raise AlreadyWaiting( - "There's another coroutine waiting for a pong packet with id {}".format(pingid)) - event = asyncio.Event() - self.pong_callbacks[pingid] = event.set - got_pong = False - try: - got_pong = await wait_with_token( - event.wait(), token=cancel_token, timeout=k_request_timeout) - self.logger.debug('got expected pong with token %s', encode_hex(token)) - except TimeoutError: - self.logger.debug( - 'timed out waiting for pong from %s (token == %s)', remote, encode_hex(token)) - # TODO: Use a contextmanager to ensure we always delete the callback from the list. - del self.pong_callbacks[pingid] + + with self.pong_callbacks.acquire(pingid, event.set): + got_pong = False + try: + got_pong = await wait_with_token( + event.wait(), token=cancel_token, timeout=k_request_timeout) + self.logger.debug('got expected pong with token %s', encode_hex(token)) + except TimeoutError: + self.logger.debug( + 'timed out waiting for pong from %s (token == %s)', + remote, + encode_hex(token), + ) + return got_pong - async def wait_neighbours(self, remote: Node, cancel_token: CancelToken) -> List[Node]: + async def wait_neighbours(self, remote: Node, cancel_token: CancelToken) -> Tuple[Node, ...]: """Wait for a neihgbours packet from the given node. Returns the list of neighbours received. """ - if remote in self.neighbours_callbacks: - raise AlreadyWaiting( - "There's another coroutine waiting for a neighbours packet from {}".format(remote)) - event = asyncio.Event() neighbours: List[Node] = [] @@ -469,17 +520,15 @@ def process(response): if len(neighbours) == k_bucket_size: event.set() - self.neighbours_callbacks[remote] = process - try: - await wait_with_token( - event.wait(), token=cancel_token, timeout=k_request_timeout) - self.logger.debug('got expected neighbours response from %s', remote) - except TimeoutError: - self.logger.debug('timed out waiting for neighbours response from %s', remote) + with self.neighbours_callbacks.acquire(remote, process): + try: + await wait_with_token( + event.wait(), token=cancel_token, timeout=k_request_timeout) + self.logger.debug('got expected neighbours response from %s', remote) + except TimeoutError: + self.logger.debug('timed out waiting for neighbours response from %s', remote) - # TODO: Use a contextmanager to ensure we always delete the callback from the list. - del self.neighbours_callbacks[remote] - return [n for n in neighbours if n != self.this_node] + return tuple(n for n in neighbours if n != self.this_node) def ping(self, node: Node) -> bytes: if node == self.this_node: @@ -494,10 +543,17 @@ async def bond(self, node: Node, cancel_token: CancelToken) -> bool: """ if node in self.routing: return True + elif node == self.this_node: + return False token = self.ping(node) - got_pong = await self.wait_pong(node, token, cancel_token) + try: + got_pong = await self.wait_pong(node, token, cancel_token) + except AlreadyWaiting: + self.logger.debug("binding failed, already waiting for pong") + return False + if not got_pong: self.logger.debug("bonding failed, didn't receive pong from %s", node) # Drop the failing node and schedule a populate_not_full_buckets() call to try and @@ -506,17 +562,27 @@ async def bond(self, node: Node, cancel_token: CancelToken) -> bool: asyncio.ensure_future(self.populate_not_full_buckets()) return False - # Give the remote node a chance to ping us before we move on and start sending find_node - # requests. It is ok for wait_ping() to timeout and return false here as that just means - # the remote remembers us. - await self.wait_ping(node, cancel_token) + try: + # Give the remote node a chance to ping us before we move on and + # start sending find_node requests. It is ok for wait_ping() to + # timeout and return false here as that just means the remote + # remembers us. + await self.wait_ping(node, cancel_token) + except AlreadyWaiting: + self.logger.debug("binding failed, already waiting for ping") + return False self.logger.debug("bonding completed successfully with %s", node) self.update_routing_table(node) return True async def bootstrap(self, bootstrap_nodes: List[Node], cancel_token: CancelToken) -> None: - bonded = await asyncio.gather(*[self.bond(n, cancel_token) for n in bootstrap_nodes]) + bonded = await asyncio.gather(*( + self.bond(n, cancel_token) + for n + in bootstrap_nodes + if (not self.ping_callbacks.locked(n) and not self.pong_callbacks.locked(n)) + )) if not any(bonded): self.logger.info("Failed to bond with bootstrap nodes %s", bootstrap_nodes) return @@ -539,15 +605,19 @@ async def _find_node(node_id, remote): candidates = await self.wait_neighbours(remote, cancel_token) if not candidates: self.logger.debug("got no candidates from %s, returning", remote) - return candidates - candidates = [c for c in candidates if c not in nodes_seen] + return tuple() + all_candidates = tuple(c for c in candidates if c not in nodes_seen) + candidates = tuple( + c for c in all_candidates + if (not self.ping_callbacks.locked(c) and not self.pong_callbacks.locked(c)) + ) self.logger.debug("got %s new candidates", len(candidates)) # Add new candidates to nodes_seen so that we don't attempt to bond with failing ones # in the future. nodes_seen.update(candidates) - bonded = await asyncio.gather(*[self.bond(c, cancel_token) for c in candidates]) + bonded = await asyncio.gather(*(self.bond(c, cancel_token) for c in candidates)) self.logger.debug("bonded with %s candidates", bonded.count(True)) - return [c for c in candidates if bonded[candidates.index(c)]] + return tuple(c for c in candidates if bonded[candidates.index(c)]) def _exclude_if_asked(nodes): nodes_to_ask = list(set(nodes).difference(nodes_asked)) @@ -559,8 +629,12 @@ def _exclude_if_asked(nodes): while nodes_to_ask: self.logger.debug("node lookup; querying %s", nodes_to_ask) nodes_asked.update(nodes_to_ask) - results = await asyncio.gather( - *[_find_node(node_id, n) for n in nodes_to_ask]) + results = await asyncio.gather(*( + _find_node(node_id, n) + for n + in nodes_to_ask + if not self.neighbours_callbacks.locked(n) + )) for candidates in results: closest.extend(candidates) closest = sort_by_distance(closest, node_id)[:k_bucket_size] diff --git a/tests/p2p/test_kademlia.py b/tests/p2p/test_kademlia.py index 7e4fe67fef..bef1ae701d 100644 --- a/tests/p2p/test_kademlia.py +++ b/tests/p2p/test_kademlia.py @@ -112,7 +112,7 @@ async def test_wait_neighbours(cancel_token): # Schedule a call to proto.recv_neighbours() simulating a neighbours response from the node we # expect. - neighbours = [random_node(), random_node(), random_node()] + neighbours = (random_node(), random_node(), random_node()) recv_neighbours_coroutine = asyncio.coroutine(lambda: proto.recv_neighbours(node, neighbours)) asyncio.ensure_future(recv_neighbours_coroutine()) @@ -125,7 +125,7 @@ async def test_wait_neighbours(cancel_token): # If wait_neighbours() times out, we get an empty list of neighbours. received_neighbours = await proto.wait_neighbours(node, cancel_token) - assert received_neighbours == [] + assert received_neighbours == tuple() assert node not in proto.neighbours_callbacks