diff --git a/p2p/chain.py b/p2p/chain.py
index d9f8d97474..4d4094d306 100644
--- a/p2p/chain.py
+++ b/p2p/chain.py
@@ -43,9 +43,9 @@
 from p2p import les
 from p2p.cancellable import CancellableMixin
 from p2p.constants import MAX_REORG_DEPTH, SEAL_CHECK_RANDOM_SAMPLE_RATE
-from p2p.exceptions import NoEligiblePeers
+from p2p.exceptions import NoEligiblePeers, ValidationError
 from p2p.p2p_proto import DisconnectReason
-from p2p.peer import BasePeer, ETHPeer, LESPeer, HeaderRequest, PeerPool, PeerSubscriber
+from p2p.peer import BasePeer, ETHPeer, LESPeer, PeerPool, PeerSubscriber
 from p2p.rlp import BlockBody
 from p2p.service import BaseService
 from p2p.utils import (
@@ -91,7 +91,6 @@ def __init__(self,
         self._syncing = False
         self._sync_complete = asyncio.Event()
         self._sync_requests: asyncio.Queue[HeaderRequestingPeer] = asyncio.Queue()
-        self._new_headers: asyncio.Queue[Tuple[BlockHeader, ...]] = asyncio.Queue()
         self._executor = get_asyncio_executor()
 
     @property
@@ -207,7 +206,7 @@ async def _sync(self, peer: HeaderRequestingPeer) -> None:
                 self.logger.warn("Timeout waiting for header batch from %s, aborting sync", peer)
                 await peer.disconnect(DisconnectReason.timeout)
                 break
-            except ValueError as err:
+            except ValidationError as err:
                 self.logger.warn(
                     "Invalid header response sent by peer %s disconnecting: %s",
                     peer, err,
@@ -253,47 +252,37 @@ async def _fetch_missing_headers(
             self, peer: HeaderRequestingPeer, start_at: int) -> Tuple[BlockHeader, ...]:
         """Fetch a batch of headers starting at start_at and return the ones we're missing."""
         self.logger.debug("Fetching chain segment starting at #%d", start_at)
-        request = peer.request_block_headers(
+
+        headers = await peer.get_block_headers(
             start_at,
             peer.max_headers_fetch,
             skip=0,
             reverse=False,
         )
-        # Pass the peer's token to self.wait() because we want to abort if either we
-        # or the peer terminates.
-        headers = tuple(await self.wait(
-            self._new_headers.get(),
-            token=peer.cancel_token,
-            timeout=self._reply_timeout))
-
-        # check that the response headers are a valid match for our
-        # requested headers.
-        request.validate_headers(headers)
-
-        # the inner list comprehension is required to get python to evaluate
-        # the asynchronous comprehension
-        missing_headers = tuple([
-            header
-            for header
-            in headers
-            if not (await self.wait(self.db.coro_header_exists(header.hash)))
-        ])
-        if len(missing_headers) != len(headers):
-            self.logger.debug(
-                "Discarding %d / %d headers that we already have",
-                len(headers) - len(missing_headers),
-                len(headers),
-            )
-        return headers
-
-    def _handle_block_headers(self, headers: Tuple[BlockHeader, ...]) -> None:
-        if not headers:
-            self.logger.warn("Got an empty BlockHeaders msg")
-            return
-        self.logger.debug(
-            "Got BlockHeaders from %d to %d", headers[0].block_number, headers[-1].block_number)
-        self._new_headers.put_nowait(headers)
+        # We only want headers that are missing, so we iterate over the list
+        # until we find the first missing header, after which we return all of
+        # the remaining headers.
+        async def get_missing_tail(self: 'BaseHeaderChainSyncer',
+                                   headers: Tuple[BlockHeader, ...]
+                                   ) -> AsyncGenerator[BlockHeader, None]:
+            iter_headers = iter(headers)
+            for header in iter_headers:
+                is_missing = not await self.wait(self.db.coro_header_exists(header.hash))
+                if is_missing:
+                    yield header
+                    break
+                else:
+                    self.logger.debug("Discarding header that we already have: %s", header)
+
+            for header in iter_headers:
+                yield header
+
+        # The inner list comprehension is needed because async_generators
+        # cannot be cast to a tuple.
+        tail_headers = tuple([header async for header in get_missing_tail(self, headers)])
+
+        return tail_headers
 
     @abstractmethod
     async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command,
@@ -313,26 +302,27 @@ async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command,
                           msg: protocol._DecodedMsgType) -> None:
         if isinstance(cmd, les.Announce):
             self._sync_requests.put_nowait(peer)
-        elif isinstance(cmd, les.BlockHeaders):
-            msg = cast(Dict[str, Any], msg)
-            self._handle_block_headers(tuple(cast(Tuple[BlockHeader, ...], msg['headers'])))
         elif isinstance(cmd, les.GetBlockHeaders):
             msg = cast(Dict[str, Any], msg)
             await self._handle_get_block_headers(cast(LESPeer, peer), msg)
+        elif isinstance(cmd, les.BlockHeaders):
+            # `BlockHeaders` messages are handled at the peer level.
+            pass
         else:
             self.logger.debug("Ignoring %s message from %s", cmd, peer)
 
     async def _handle_get_block_headers(self, peer: LESPeer, msg: Dict[str, Any]) -> None:
         self.logger.debug("Peer %s made header request: %s", peer, msg)
-        request = HeaderRequest(
+        request = les.HeaderRequest(
             msg['query'].block_number_or_hash,
             msg['query'].max_headers,
             msg['query'].skip,
             msg['query'].reverse,
+            msg['request_id'],
         )
         headers = await self._handler.lookup_headers(request)
         self.logger.trace("Replying to %s with %d headers", peer, len(headers))
-        peer.sub_proto.send_block_headers(headers, buffer_value=0, request_id=msg['request_id'])
+        peer.sub_proto.send_block_headers(headers, buffer_value=0, request_id=request.request_id)
 
     async def _process_headers(
             self, peer: HeaderRequestingPeer, headers: Tuple[BlockHeader, ...]) -> int:
@@ -538,9 +528,7 @@ def request_receipts(self, target_td: int, headers: List[BlockHeader]) -> int:
 
     async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command,
                           msg: protocol._DecodedMsgType) -> None:
         peer = cast(ETHPeer, peer)
-        if isinstance(cmd, eth.BlockHeaders):
-            self._handle_block_headers(tuple(cast(Tuple[BlockHeader, ...], msg)))
-        elif isinstance(cmd, eth.BlockBodies):
+        if isinstance(cmd, eth.BlockBodies):
             await self._handle_block_bodies(peer, list(cast(Tuple[BlockBody], msg)))
         elif isinstance(cmd, eth.Receipts):
             await self._handle_block_receipts(peer, cast(List[List[Receipt]], msg))
@@ -548,6 +536,9 @@ async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command,
             await self._handle_new_block(peer, cast(Dict[str, Any], msg))
         elif isinstance(cmd, eth.GetBlockHeaders):
             await self._handle_get_block_headers(peer, cast(Dict[str, Any], msg))
+        elif isinstance(cmd, eth.BlockHeaders):
+            # `BlockHeaders` messages are handled at the peer level.
+            pass
         elif isinstance(cmd, eth.GetBlockBodies):
             # Only serve up to eth.MAX_BODIES_FETCH items in every request.
             block_hashes = cast(List[Hash32], msg)[:eth.MAX_BODIES_FETCH]
@@ -613,7 +604,7 @@ async def _handle_get_block_headers(
             peer: ETHPeer,
             query: Dict[str, Any]) -> None:
         self.logger.debug("Peer %s made header request: %s", peer, query)
-        request = HeaderRequest(
+        request = eth.HeaderRequest(
             query['block_number_or_hash'],
             query['max_headers'],
             query['skip'],
@@ -732,7 +723,7 @@ async def handle_get_node_data(self, peer: ETHPeer, node_hashes: List[Hash32]) -
         peer.sub_proto.send_node_data(nodes)
 
     async def lookup_headers(self,
-                             request: HeaderRequest) -> Tuple[BlockHeader, ...]:
+                             request: protocol.BaseHeaderRequest) -> Tuple[BlockHeader, ...]:
         """
         Lookup :max_headers: headers starting at :block_number_or_hash:, skipping
         :skip: items between each, in reverse order if :reverse: is True.
@@ -753,7 +744,8 @@ async def lookup_headers,
         return headers
 
     async def _get_block_numbers_for_request(self,
-                                             request: HeaderRequest) -> Tuple[BlockNumber, ...]:
+                                             request: protocol.BaseHeaderRequest
+                                             ) -> Tuple[BlockNumber, ...]:
         """
         Generate the block numbers for a given `HeaderRequest`.
         """
diff --git a/p2p/eth.py b/p2p/eth.py
index c60380915d..69e02a851e 100644
--- a/p2p/eth.py
+++ b/p2p/eth.py
@@ -1,5 +1,6 @@
 import logging
 from typing import (
+    Any,
     cast,
     List,
     Tuple,
@@ -16,8 +17,10 @@
 from eth.rlp.receipts import Receipt
 from eth.rlp.transactions import BaseTransactionFields
 
+from p2p.exceptions import ValidationError
 from p2p.protocol import (
     BaseBlockHeaders,
+    BaseHeaderRequest,
     Command,
     Protocol,
     _DecodedMsgType,
@@ -70,6 +73,31 @@ class GetBlockHeaders(Command):
     ]
 
 
+class HeaderRequest(BaseHeaderRequest):
+    max_size = MAX_HEADERS_FETCH
+
+    def __init__(self,
+                 block_number_or_hash: BlockIdentifier,
+                 max_headers: int,
+                 skip: int,
+                 reverse: bool) -> None:
+        self.block_number_or_hash = block_number_or_hash
+        self.max_headers = max_headers
+        self.skip = skip
+        self.reverse = reverse
+
+    def validate_response(self, response: Any) -> None:
+        """
+        Core `Request` API used for validation.
+        """
+        if not isinstance(response, tuple):
+            raise ValidationError("Response to `HeaderRequest` must be a tuple")
+        elif not all(isinstance(item, BlockHeader) for item in response):
+            raise ValidationError("Response must be a tuple of `BlockHeader` objects")
+
+        return self.validate_headers(cast(Tuple[BlockHeader, ...], response))
+
+
 class BlockHeaders(BaseBlockHeaders):
     _cmd_id = 4
     structure = sedes.CountableList(BlockHeader)
diff --git a/p2p/exceptions.py b/p2p/exceptions.py
index e5b5db61f3..05a3ff1c0a 100644
--- a/p2p/exceptions.py
+++ b/p2p/exceptions.py
@@ -159,3 +159,10 @@ class NoInternalAddressMatchesDevice(BaseP2PError):
     def __init__(self, *args: Any, device_hostname: str=None, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self.device_hostname = device_hostname
+
+
+class ValidationError(BaseP2PError):
+    """
+    Raised when something does not pass a validation check.
+ """ + pass diff --git a/p2p/les.py b/p2p/les.py index 1a903f8346..7748ec57d6 100644 --- a/p2p/les.py +++ b/p2p/les.py @@ -16,8 +16,12 @@ from eth.rlp.headers import BlockHeader from eth.rlp.receipts import Receipt +from p2p.exceptions import ( + ValidationError, +) from p2p.protocol import ( BaseBlockHeaders, + BaseHeaderRequest, Command, Protocol, _DecodedMsgType, @@ -163,6 +167,44 @@ class GetBlockHeadersQuery(rlp.Serializable): ] +class HeaderRequest(BaseHeaderRequest): + request_id: int + + max_size = MAX_HEADERS_FETCH + + def __init__(self, + block_number_or_hash: BlockIdentifier, + max_headers: int, + skip: int, + reverse: bool, + request_id: int) -> None: + self.block_number_or_hash = block_number_or_hash + self.max_headers = max_headers + self.skip = skip + self.reverse = reverse + self.request_id = request_id + + def validate_response(self, response: Any) -> None: + """ + Core `Request` API used for validation. + """ + if not isinstance(response, dict): + raise ValidationError("Response to `HeaderRequest` must be a dict") + + request_id = response['request_id'] + if request_id != self.request_id: + raise ValidationError( + "Response `request_id` does not match. expected: %s | got: %s".format( + self.request_id, + request_id, + ) + ) + elif not all(isinstance(item, BlockHeader) for item in response['headers']): + raise ValidationError("Response must be a tuple of `BlockHeader` objects") + + return self.validate_headers(cast(Tuple[BlockHeader, ...], response['headers'])) + + class GetBlockHeaders(Command): _cmd_id = 2 structure = [ diff --git a/p2p/peer.py b/p2p/peer.py index 135be20ba7..aeae5e12c3 100644 --- a/p2p/peer.py +++ b/p2p/peer.py @@ -20,7 +20,6 @@ Dict, Iterator, List, - NamedTuple, TYPE_CHECKING, Tuple, Type, @@ -60,8 +59,10 @@ from p2p import auth from p2p import ecies -from p2p.kademlia import Address, Node +from p2p import eth +from p2p import les from p2p import protocol +from p2p.kademlia import Address, Node from p2p.exceptions import ( BadAckMessage, DAOForkCheckFailure, @@ -75,17 +76,16 @@ UnexpectedMessage, UnknownProtocolCommand, UnreachablePeer, + ValidationError, ) from p2p.service import BaseService from p2p.utils import ( - gen_request_id, + gen_request_id as _gen_request_id, get_devp2p_cmd_id, roundup_16, sxor, time_since, ) -from p2p import eth -from p2p import les from p2p.p2p_proto import ( Disconnect, DisconnectReason, @@ -143,105 +143,6 @@ async def handshake(remote: Node, return peer -class HeaderRequest(NamedTuple): - block_number_or_hash: BlockIdentifier - max_headers: int - skip: int - reverse: bool - - def generate_block_numbers(self, - block_number: BlockNumber=None) -> Tuple[BlockNumber, ...]: - if block_number is None and not self.is_numbered: - raise TypeError( - "A `block_number` must be supplied to generate block numbers " - "for hash based header requests" - ) - elif block_number is not None and self.is_numbered: - raise TypeError( - "The `block_number` parameter may not be used for number based " - "header requests" - ) - elif block_number is None: - block_number = cast(BlockNumber, self.block_number_or_hash) - - max_headers = min(eth.MAX_HEADERS_FETCH, self.max_headers) - - # inline import until this module is moved to `trinity` - from trinity.utils.headers import sequence_builder - return sequence_builder( - block_number, - max_headers, - self.skip, - self.reverse, - ) - - @property - def is_numbered(self) -> bool: - return isinstance(self.block_number_or_hash, int) - - def validate_headers(self, - headers: Tuple[BlockHeader, ...]) -> 
-        if not headers:
-            # An empty response is always valid
-            return
-        elif not self.is_numbered:
-            first_header = headers[0]
-            if first_header.hash != self.block_number_or_hash:
-                raise ValueError(
-                    "Returned headers cannot be matched to header request. "
-                    "Expected first header to have hash of {0} but instead got "
-                    "{1}.".format(
-                        encode_hex(self.block_number_or_hash),
-                        encode_hex(first_header.hash),
-                    )
-                )
-
-        block_numbers: Tuple[BlockNumber, ...] = tuple(
-            header.block_number for header in headers
-        )
-        return self.validate_sequence(block_numbers)
-
-    def validate_sequence(self, block_numbers: Tuple[BlockNumber, ...]) -> None:
-        if not block_numbers:
-            return
-        elif self.is_numbered:
-            expected_numbers = self.generate_block_numbers()
-        else:
-            expected_numbers = self.generate_block_numbers(block_numbers[0])
-
-        # check for numbers that should not be present.
-        unexpected_numbers = set(block_numbers).difference(expected_numbers)
-        if unexpected_numbers:
-            raise ValueError(
-                'Unexpected numbers: {0}'.format(unexpected_numbers))
-
-        # check that the numbers are correctly ordered.
-        expected_order = tuple(sorted(
-            block_numbers,
-            reverse=self.reverse,
-        ))
-        if block_numbers != expected_order:
-            raise ValueError(
-                'Returned headers are not correctly ordered.\n'
-                'Expected: {0}\n'
-                'Got     : {1}\n'.format(expected_order, block_numbers)
-            )
-
-        # check that all provided numbers are an ordered subset of the master
-        # sequence.
-        iter_expected = iter(expected_numbers)
-        for number in block_numbers:
-            for value in iter_expected:
-                if value == number:
-                    break
-            else:
-                raise ValueError(
-                    'Returned headers contain an unexpected block number.\n'
-                    'Unexpected Number: {0}\n'
-                    'Expected Numbers : {1}'.format(number, expected_numbers)
-                )
-
-
 class BasePeer(BaseService):
     conn_idle_timeout = CONN_IDLE_TIMEOUT
     # Must be defined in subclasses. All items here must be Protocol classes representing
@@ -254,6 +155,14 @@ class BasePeer(BaseService):
     head_td: int = None
     head_hash: Hash32 = None
 
+    # TODO: Instead of a fixed timeout, we should instead monitor response
+    # times for the peer and adjust our timeout accordingly
+    _response_timeout = 60
+    pending_requests: Dict[
+        Type[protocol.Command],
+        Tuple[protocol.BaseRequest, 'asyncio.Future[protocol._DecodedMsgType]'],
+    ]
+
     def __init__(self,
                  remote: Node,
                  privkey: datatypes.PrivateKey,
@@ -280,6 +189,8 @@ def __init__(self,
         self.start_time = datetime.datetime.now()
         self.received_msgs: Dict[protocol.Command, int] = collections.defaultdict(int)
 
+        self.pending_requests = {}
+
         self.egress_mac = egress_mac
         self.ingress_mac = ingress_mac
         # FIXME: Yes, the encryption is insecure, see: https://github.com/ethereum/devp2p/issues/32
@@ -470,6 +381,24 @@ def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgT
         else:
             self.logger.warn("Peer %s has no subscribers, discarding %s msg", self, cmd)
 
+        cmd_type = type(cmd)
+
+        if cmd_type in self.pending_requests:
+            request, future = self.pending_requests[cmd_type]
+            try:
+                request.validate_response(msg)
+            except ValidationError as err:
+                self.logger.debug(
+                    "Response validation failure for pending %s request from peer %s: %s",
+                    cmd_type.__name__,
+                    self,
+                    err,
+                )
+                pass
+            else:
+                future.set_result(msg)
+                self.pending_requests.pop(cmd_type)
+
     def process_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
         if cmd.is_base_protocol:
             self.handle_p2p_msg(cmd, msg)
@@ -624,6 +553,7 @@ def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgT
             self.head_info = cmd.as_head_info(msg)
             self.head_td = self.head_info.total_difficulty
             self.head_hash = self.head_info.block_hash
+
         super().handle_sub_proto_msg(cmd, msg)
 
     async def send_sub_proto_handshake(self) -> None:
@@ -652,19 +582,23 @@ async def process_sub_proto_handshake(
         self.head_td = self.head_info.total_difficulty
         self.head_hash = self.head_info.block_hash
 
+    def gen_request_id(self) -> int:
+        return _gen_request_id()
+
     def request_block_headers(self,
                               block_number_or_hash: BlockIdentifier,
                               max_headers: int = None,
                               skip: int = 0,
-                              reverse: bool = False) -> HeaderRequest:
+                              reverse: bool = False) -> les.HeaderRequest:
         if max_headers is None:
             max_headers = self.max_headers_fetch
-        request_id = gen_request_id()
-        request = HeaderRequest(
+        request_id = self.gen_request_id()
+        request = les.HeaderRequest(
             block_number_or_hash,
             max_headers,
             skip,
             reverse,
+            request_id,
         )
         self.sub_proto.send_get_block_headers(
             request.block_number_or_hash,
@@ -675,6 +609,38 @@ def request_block_headers,
         )
         return request
 
+    async def wait_for_block_headers(self, request: les.HeaderRequest) -> Tuple[BlockHeader, ...]:
+        future: 'asyncio.Future[protocol._DecodedMsgType]' = asyncio.Future()
+        if les.BlockHeaders in self.pending_requests:
+            # the `finally` block below should prevent this from happening, but
+            # were two requests to the same peer to be fired off at the same
+            # time, this will prevent us from overwriting the first one.
+            raise ValueError(
+                "There is already a pending `BlockHeaders` request for peer {0}".format(self)
+            )
+        self.pending_requests[les.BlockHeaders] = cast(
+            Tuple[protocol.BaseRequest, 'asyncio.Future[protocol._DecodedMsgType]'],
+            (request, future),
+        )
+        try:
+            response = cast(
+                Dict[str, Any],
+                await self.wait(future, timeout=self._response_timeout),
+            )
+        finally:
+            # We always want to be sure that this method cleans up the
+            # `pending_requests` entry so that a timed-out request cannot block later ones.
+            self.pending_requests.pop(les.BlockHeaders, None)
+        return cast(Tuple[BlockHeader, ...], response['headers'])
+
+    async def get_block_headers(self,
+                                block_number_or_hash: BlockIdentifier,
+                                max_headers: int = None,
+                                skip: int = 0,
+                                reverse: bool = True) -> Tuple[BlockHeader, ...]:
+        request = self.request_block_headers(block_number_or_hash, max_headers, skip, reverse)
+        return await self.wait_for_block_headers(request)
+
 
 class ETHPeer(BasePeer):
     _supported_sub_protocols = [eth.ETHProtocol]
@@ -693,6 +659,7 @@ def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgT
         if actual_td > self.head_td:
             self.head_hash = actual_head
             self.head_td = actual_td
+
         super().handle_sub_proto_msg(cmd, msg)
 
     async def send_sub_proto_handshake(self) -> None:
@@ -723,10 +690,10 @@ def request_block_headers,
                               block_number_or_hash: BlockIdentifier,
                               max_headers: int = None,
                               skip: int = 0,
-                              reverse: bool = True) -> HeaderRequest:
+                              reverse: bool = True) -> eth.HeaderRequest:
         if max_headers is None:
             max_headers = self.max_headers_fetch
-        request = HeaderRequest(
+        request = eth.HeaderRequest(
             block_number_or_hash,
             max_headers,
             skip,
@@ -740,6 +707,23 @@ def request_block_headers,
         )
         return request
 
+    async def wait_for_block_headers(self, request: eth.HeaderRequest) -> Tuple[BlockHeader, ...]:
+        future: 'asyncio.Future[Tuple[BlockHeader, ...]]' = asyncio.Future()
+        self.pending_requests[eth.BlockHeaders] = cast(
+            Tuple[protocol.BaseRequest, 'asyncio.Future[protocol._DecodedMsgType]'],
+            (request, future),
+        )
+        response = await self.wait(future, timeout=self._response_timeout)
+        return response
+
+    async def get_block_headers(self,
+                                block_number_or_hash: BlockIdentifier,
+                                max_headers: int = None,
+                                skip: int = 0,
+                                reverse: bool = True) -> Tuple[BlockHeader, ...]:
+        request = self.request_block_headers(block_number_or_hash, max_headers, skip, reverse)
+        return await self.wait_for_block_headers(request)
+
 
 class PeerSubscriber(ABC):
     _msg_queue: 'asyncio.Queue[PEER_MSG_TYPE]' = None
@@ -984,7 +968,7 @@ async def ensure_same_side_on_dao_fork(
 
     try:
         request.validate_headers(headers)
-    except ValueError as err:
+    except ValidationError as err:
         raise DAOForkCheckFailure(
             "Invalid header response during DAO fork check: {}".format(err)
         )
diff --git a/p2p/protocol.py b/p2p/protocol.py
index ccdac454b2..6928a3f468 100644
--- a/p2p/protocol.py
+++ b/p2p/protocol.py
@@ -3,6 +3,7 @@
 from abc import ABC, abstractmethod
 from typing import (
     Any,
+    cast,
     Dict,
     List,
     Tuple,
@@ -14,9 +15,14 @@
 import rlp
 from rlp import sedes
 
+from eth_utils import encode_hex
+
+from eth_typing import BlockIdentifier, BlockNumber
+
 from eth.constants import NULL_BYTE
 from eth.rlp.headers import BlockHeader
 
+from p2p.exceptions import ValidationError
 from p2p.utils import get_devp2p_cmd_id
 
 
@@ -116,6 +122,120 @@ def encode(self, data: _DecodedMsgType) -> Tuple[bytes, bytes]:
         return header, body
 
 
+class BaseRequest(ABC):
+    """
+    Base representation of a *request* to a connected peer which has a matching
+    *response*.
+ """ + @abstractmethod + def validate_response(self, response: Any) -> None: + pass + + +class BaseHeaderRequest(BaseRequest): + block_number_or_hash: BlockIdentifier + max_headers: int + skip: int + reverse: bool + + @property + @abstractmethod + def max_size(self) -> int: + pass + + def generate_block_numbers(self, + block_number: BlockNumber=None) -> Tuple[BlockNumber, ...]: + if block_number is None and not self.is_numbered: + raise TypeError( + "A `block_number` must be supplied to generate block numbers " + "for hash based header requests" + ) + elif block_number is not None and self.is_numbered: + raise TypeError( + "The `block_number` parameter may not be used for number based " + "header requests" + ) + elif block_number is None: + block_number = cast(BlockNumber, self.block_number_or_hash) + + max_headers = min(self.max_size, self.max_headers) + + # inline import until this module is moved to `trinity` + from trinity.utils.headers import sequence_builder + return sequence_builder( + block_number, + max_headers, + self.skip, + self.reverse, + ) + + @property + def is_numbered(self) -> bool: + return isinstance(self.block_number_or_hash, int) + + def validate_headers(self, + headers: Tuple[BlockHeader, ...]) -> None: + if not headers: + # An empty response is always valid + return + elif not self.is_numbered: + first_header = headers[0] + if first_header.hash != self.block_number_or_hash: + raise ValidationError( + "Returned headers cannot be matched to header request. " + "Expected first header to have hash of {0} but instead got " + "{1}.".format( + encode_hex(self.block_number_or_hash), + encode_hex(first_header.hash), + ) + ) + + block_numbers: Tuple[BlockNumber, ...] = tuple( + header.block_number for header in headers + ) + return self.validate_sequence(block_numbers) + + def validate_sequence(self, block_numbers: Tuple[BlockNumber, ...]) -> None: + if not block_numbers: + return + elif self.is_numbered: + expected_numbers = self.generate_block_numbers() + else: + expected_numbers = self.generate_block_numbers(block_numbers[0]) + + # check for numbers that should not be present. + unexpected_numbers = set(block_numbers).difference(expected_numbers) + if unexpected_numbers: + raise ValidationError( + 'Unexpected numbers: {0}'.format(unexpected_numbers)) + + # check that the numbers are correctly ordered. + expected_order = tuple(sorted( + block_numbers, + reverse=self.reverse, + )) + if block_numbers != expected_order: + raise ValidationError( + 'Returned headers are not correctly ordered.\n' + 'Expected: {0}\n' + 'Got : {1}\n'.format(expected_order, block_numbers) + ) + + # check that all provided numbers are an ordered subset of the master + # sequence. 
+        iter_expected = iter(expected_numbers)
+        for number in block_numbers:
+            for value in iter_expected:
+                if value == number:
+                    break
+            else:
+                raise ValidationError(
+                    'Returned headers contain an unexpected block number.\n'
+                    'Unexpected Number: {0}\n'
+                    'Expected Numbers : {1}'.format(number, expected_numbers)
+                )
+
+
 class Protocol:
     logger = logging.getLogger("p2p.protocol.Protocol")
     name: str = None
diff --git a/p2p/state.py b/p2p/state.py
index 3b7531667f..adb3f25c31 100644
--- a/p2p/state.py
+++ b/p2p/state.py
@@ -49,7 +49,7 @@
 from p2p import protocol
 from p2p.chain import PeerRequestHandler
 from p2p.exceptions import NoEligiblePeers, NoIdlePeers
-from p2p.peer import BasePeer, ETHPeer, HeaderRequest, PeerPool, PeerSubscriber
+from p2p.peer import BasePeer, ETHPeer, PeerPool, PeerSubscriber
 from p2p.service import BaseService
 from p2p.utils import get_asyncio_executor, Timer
@@ -177,7 +177,7 @@ async def _handle_msg(
             await self._process_nodes(zip(node_keys, msg))
         elif isinstance(cmd, eth.GetBlockHeaders):
             query = cast(Dict[Any, Union[bool, int]], msg)
-            request = HeaderRequest(
+            request = eth.HeaderRequest(
                 query['block_number_or_hash'],
                 query['max_headers'],
                 query['skip'],
@@ -199,7 +199,7 @@ async def _handle_msg(
         else:
             self.logger.warn("%s not handled during StateSync, must be implemented", cmd)
 
-    async def _handle_get_block_headers(self, peer: ETHPeer, request: HeaderRequest) -> None:
+    async def _handle_get_block_headers(self, peer: ETHPeer, request: eth.HeaderRequest) -> None:
         headers = await self._handler.lookup_headers(request)
         peer.sub_proto.send_block_headers(headers)
 
diff --git a/tests/p2p/test_header_request_object.py b/tests/p2p/test_header_request_object.py
index 372bb5aa38..13d54434bd 100644
--- a/tests/p2p/test_header_request_object.py
+++ b/tests/p2p/test_header_request_object.py
@@ -1,6 +1,7 @@
 import pytest
 
-from p2p.chain import HeaderRequest
+from p2p.exceptions import ValidationError
+from p2p.protocol import BaseHeaderRequest
 
 
 FORWARD_0_to_5 = (0, 6, 0, False)
@@ -13,6 +14,23 @@
 BLOCK_HASH = b'\x01' * 32
 
 
+class HeaderRequest(BaseHeaderRequest):
+    max_size = 192
+
+    def __init__(self,
+                 block_number_or_hash,
+                 max_headers,
+                 skip,
+                 reverse):
+        self.block_number_or_hash = block_number_or_hash
+        self.max_headers = max_headers
+        self.skip = skip
+        self.reverse = reverse
+
+    def validate_response(self, response):
+        pass
+
+
 @pytest.mark.parametrize(
     'block_number_or_hash,expected',
     (
@@ -111,5 +129,5 @@ def test_header_request_sequence_matching(
     if is_match:
         request.validate_sequence(sequence)
     else:
-        with pytest.raises(ValueError):
+        with pytest.raises(ValidationError):
             request.validate_sequence(sequence)
diff --git a/tests/p2p/test_peer_block_header_request_and_response_api.py b/tests/p2p/test_peer_block_header_request_and_response_api.py
new file mode 100644
index 0000000000..890107acaa
--- /dev/null
+++ b/tests/p2p/test_peer_block_header_request_and_response_api.py
@@ -0,0 +1,153 @@
+import asyncio
+
+import pytest
+
+from eth_utils import to_tuple
+
+from eth.rlp.headers import BlockHeader
+
+from p2p.peer import ETHPeer, LESPeer
+from peer_helpers import (
+    get_directly_linked_peers,
+)
+
+
+@to_tuple
+def mk_header_chain(length):
+    assert length >= 1
+    genesis = BlockHeader(difficulty=100, block_number=0, gas_limit=3000000)
+    yield genesis
+    parent = genesis
+    if length == 1:
+        return
+
+    for i in range(length - 1):
+        header = BlockHeader(
+            difficulty=100,
+            block_number=parent.block_number + 1,
+            parent_hash=parent.hash,
+            gas_limit=3000000,
+        )
+        yield header
+        parent = header
+
+
+@pytest.fixture
+async def eth_peer_and_remote(request, event_loop):
+    peer, remote = await get_directly_linked_peers(
+        request,
+        event_loop,
+        peer1_class=ETHPeer,
+        peer2_class=ETHPeer,
+    )
+    return peer, remote
+
+
+@pytest.fixture
+async def les_peer_and_remote(request, event_loop):
+    peer, remote = await get_directly_linked_peers(
+        request,
+        event_loop,
+        peer1_class=LESPeer,
+        peer2_class=LESPeer,
+    )
+    return peer, remote
+
+
+@pytest.mark.parametrize(
+    'params,headers',
+    (
+        ((0, 1, 0, False), mk_header_chain(1)),
+        ((0, 10, 0, False), mk_header_chain(10)),
+        ((3, 5, 0, False), mk_header_chain(10)[3:8]),
+    )
+)
+@pytest.mark.asyncio
+async def test_eth_peer_get_headers_round_trip(eth_peer_and_remote,
+                                               params,
+                                               headers):
+    peer, remote = eth_peer_and_remote
+
+    async def send_headers():
+        remote.sub_proto.send_block_headers(headers)
+        await asyncio.sleep(0)
+
+    asyncio.ensure_future(send_headers())
+    response = await peer.get_block_headers(*params)
+
+    assert len(response) == len(headers)
+    for expected, actual in zip(headers, response):
+        assert expected == actual
+
+
+@pytest.mark.parametrize(
+    'params,headers',
+    (
+        ((0, 1, 0, False), mk_header_chain(1)),
+        ((0, 10, 0, False), mk_header_chain(10)),
+        ((3, 5, 0, False), mk_header_chain(10)[3:8]),
+    )
+)
+@pytest.mark.asyncio
+async def test_les_peer_get_headers_round_trip(les_peer_and_remote,
+                                               params,
+                                               headers):
+    peer, remote = les_peer_and_remote
+    request_id = 1234
+
+    peer.gen_request_id = lambda: request_id
+
+    async def send_headers():
+        remote.sub_proto.send_block_headers(headers, 0, request_id)
+        await asyncio.sleep(0)
+
+    asyncio.ensure_future(send_headers())
+    response = await peer.get_block_headers(*params)
+
+    assert len(response) == len(headers)
+    for expected, actual in zip(headers, response):
+        assert expected == actual
+
+
+@pytest.mark.asyncio
+async def test_eth_peer_get_headers_round_trip_with_noise(eth_peer_and_remote):
+    peer, remote = eth_peer_and_remote
+
+    headers = mk_header_chain(10)
+
+    async def send_responses():
+        remote.sub_proto.send_node_data([b'arst', b'tsra'])
+        await asyncio.sleep(0)
+        remote.sub_proto.send_block_headers(headers)
+        await asyncio.sleep(0)
+
+    asyncio.ensure_future(send_responses())
+    response = await peer.get_block_headers(0, 10, 0, False)
+
+    assert len(response) == len(headers)
+    for expected, actual in zip(headers, response):
+        assert expected == actual
+
+
+@pytest.mark.asyncio
+async def test_eth_peer_get_headers_round_trip_does_not_match_invalid_response(eth_peer_and_remote):
+    peer, remote = eth_peer_and_remote
+
+    headers = mk_header_chain(5)
+
+    wrong_headers = mk_header_chain(10)[3:8]
+
+    async def send_responses():
+        remote.sub_proto.send_node_data([b'arst', b'tsra'])
+        await asyncio.sleep(0)
+        remote.sub_proto.send_block_headers(wrong_headers)
+        await asyncio.sleep(0)
+        remote.sub_proto.send_block_headers(headers)
+        await asyncio.sleep(0)
+
+    asyncio.ensure_future(send_responses())
+    response = await peer.get_block_headers(0, 5, 0, False)
+
+    assert len(response) == len(headers)
+    for expected, actual in zip(headers, response):
+        assert expected == actual
diff --git a/trinity/utils/headers.py b/trinity/utils/headers.py
index ea662fc780..255855cf56 100644
--- a/trinity/utils/headers.py
+++ b/trinity/utils/headers.py
@@ -1,15 +1,11 @@
 from typing import (
     Iterator,
-    TYPE_CHECKING,
 )
 
 from eth_utils import to_tuple
 
 from eth.constants import UINT_256_MAX
 
-if TYPE_CHECKING:
-    from p2p.peer import HeaderRequest  # noqa: F401
-
 
 @to_tuple
 def sequence_builder(start_number: int,
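
Usage note: taken together, the diff replaces the syncer's old queue-based `BlockHeaders` handling with a per-peer request/response API. A minimal sketch of the resulting call flow, assuming an already-connected `ETHPeer` (the helper name `fetch_first_headers` is hypothetical; `request_block_headers`, `wait_for_block_headers`, `get_block_headers`, and `pending_requests` are the names introduced above):

    from typing import Tuple

    from eth.rlp.headers import BlockHeader

    from p2p.peer import ETHPeer


    async def fetch_first_headers(peer: ETHPeer) -> Tuple[BlockHeader, ...]:
        # request_block_headers() sends a GetBlockHeaders message and returns an
        # eth.HeaderRequest describing what was asked for.
        request = peer.request_block_headers(0, max_headers=10, skip=0, reverse=False)
        # wait_for_block_headers() parks a future in peer.pending_requests keyed by the
        # eth.BlockHeaders command; BasePeer.handle_sub_proto_msg() resolves that future
        # once a response arrives that passes request.validate_response().
        return await peer.wait_for_block_headers(request)

The one-step helper added by the diff wraps both calls, so the same fetch can also be written as `headers = await peer.get_block_headers(0, 10, skip=0, reverse=False)`.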