diff --git a/hyper/common/connection.py b/hyper/common/connection.py index e225852e..855994f8 100644 --- a/hyper/common/connection.py +++ b/hyper/common/connection.py @@ -58,6 +58,7 @@ def __init__(self, proxy_host=None, proxy_port=None, proxy_headers=None, + timeout=None, **kwargs): self._host = host @@ -65,13 +66,15 @@ def __init__(self, self._h1_kwargs = { 'secure': secure, 'ssl_context': ssl_context, 'proxy_host': proxy_host, 'proxy_port': proxy_port, - 'proxy_headers': proxy_headers, 'enable_push': enable_push + 'proxy_headers': proxy_headers, 'enable_push': enable_push, + 'timeout': timeout } self._h2_kwargs = { 'window_manager': window_manager, 'enable_push': enable_push, 'secure': secure, 'ssl_context': ssl_context, 'proxy_host': proxy_host, 'proxy_port': proxy_port, - 'proxy_headers': proxy_headers + 'proxy_headers': proxy_headers, + 'timeout': timeout } # Add any unexpected kwargs to both dictionaries. diff --git a/hyper/contrib.py b/hyper/contrib.py index c269ab90..5a580f29 100644 --- a/hyper/contrib.py +++ b/hyper/contrib.py @@ -33,7 +33,7 @@ def __init__(self, *args, **kwargs): self.connections = {} def get_connection(self, host, port, scheme, cert=None, verify=True, - proxy=None): + proxy=None, timeout=None): """ Gets an appropriate HTTP/2 connection object based on host/port/scheme/cert tuples. @@ -77,13 +77,14 @@ def get_connection(self, host, port, scheme, cert=None, verify=True, secure=secure, ssl_context=ssl_context, proxy_host=proxy_netloc, - proxy_headers=proxy_headers) + proxy_headers=proxy_headers, + timeout=timeout) self.connections[connection_key] = conn return conn def send(self, request, stream=False, cert=None, verify=True, proxies=None, - **kwargs): + timeout=None, **kwargs): """ Sends a HTTP message to the server. """ @@ -98,7 +99,8 @@ def send(self, request, stream=False, cert=None, verify=True, proxies=None, parsed.scheme, cert=cert, verify=verify, - proxy=proxy) + proxy=proxy, + timeout=timeout) # Build the selector. selector = parsed.path diff --git a/hyper/http11/connection.py b/hyper/http11/connection.py index 225ef98e..4311d307 100644 --- a/hyper/http11/connection.py +++ b/hyper/http11/connection.py @@ -39,14 +39,14 @@ def _create_tunnel(proxy_host, proxy_port, target_host, target_port, - proxy_headers=None): + proxy_headers=None, timeout=None): """ Sends CONNECT method to a proxy and returns a socket with established connection to the target. :returns: socket """ - conn = HTTP11Connection(proxy_host, proxy_port) + conn = HTTP11Connection(proxy_host, proxy_port, timeout=timeout) conn.request('CONNECT', '%s:%d' % (target_host, target_port), headers=proxy_headers) @@ -101,7 +101,7 @@ class HTTP11Connection(object): def __init__(self, host, port=None, secure=None, ssl_context=None, proxy_host=None, proxy_port=None, proxy_headers=None, - **kwargs): + timeout=None, **kwargs): if port is None: self.host, self.port = to_host_port_tuple(host, default_port=80) else: @@ -150,6 +150,9 @@ def __init__(self, host, port=None, secure=None, ssl_context=None, #: the standard hyper parsing interface. self.parser = Parser() + # timeout + self._timeout = timeout + def connect(self): """ Connect to the server specified when the object was created. This is a @@ -159,6 +162,13 @@ def connect(self): """ if self._sock is None: + if isinstance(self._timeout, tuple): + connect_timeout = self._timeout[0] + read_timeout = self._timeout[1] + else: + connect_timeout = self._timeout + read_timeout = self._timeout + if self.proxy_host and self.secure: # Send http CONNECT method to a proxy and acquire the socket sock = _create_tunnel( @@ -166,16 +176,18 @@ def connect(self): self.proxy_port, self.host, self.port, - proxy_headers=self.proxy_headers + proxy_headers=self.proxy_headers, + timeout=self._timeout ) elif self.proxy_host: # Simple http proxy sock = socket.create_connection( (self.proxy_host, self.proxy_port), - 5 + timeout=connect_timeout ) else: - sock = socket.create_connection((self.host, self.port), 5) + sock = socket.create_connection((self.host, self.port), + timeout=connect_timeout) proto = None if self.secure: @@ -184,6 +196,9 @@ def connect(self): log.debug("Selected protocol: %s", proto) sock = BufferedSocket(sock, self.network_buffer_size) + # Set read timeout + sock.settimeout(read_timeout) + if proto not in ('http/1.1', None): raise TLSUpgrade(proto, sock) diff --git a/hyper/http20/connection.py b/hyper/http20/connection.py index 2451c3fe..b8be292b 100644 --- a/hyper/http20/connection.py +++ b/hyper/http20/connection.py @@ -102,7 +102,7 @@ class HTTP20Connection(object): def __init__(self, host, port=None, secure=None, window_manager=None, enable_push=False, ssl_context=None, proxy_host=None, proxy_port=None, force_proto=None, proxy_headers=None, - **kwargs): + timeout=None, **kwargs): """ Creates an HTTP/2 connection to a specific server. """ @@ -151,6 +151,9 @@ def __init__(self, host, port=None, secure=None, window_manager=None, self.__wm_class = window_manager or FlowControlManager self.__init_state() + # timeout + self._timeout = timeout + return def __init_state(self): @@ -343,6 +346,13 @@ def connect(self): if self._sock is not None: return + if isinstance(self._timeout, tuple): + connect_timeout = self._timeout[0] + read_timeout = self._timeout[1] + else: + connect_timeout = self._timeout + read_timeout = self._timeout + if self.proxy_host and self.secure: # Send http CONNECT method to a proxy and acquire the socket sock = _create_tunnel( @@ -350,15 +360,18 @@ def connect(self): self.proxy_port, self.host, self.port, - proxy_headers=self.proxy_headers + proxy_headers=self.proxy_headers, + timeout=self._timeout ) elif self.proxy_host: # Simple http proxy sock = socket.create_connection( - (self.proxy_host, self.proxy_port) + (self.proxy_host, self.proxy_port), + timeout=connect_timeout ) else: - sock = socket.create_connection((self.host, self.port)) + sock = socket.create_connection((self.host, self.port), + timeout=connect_timeout) if self.secure: sock, proto = wrap_socket(sock, self.host, self.ssl_context, @@ -374,6 +387,9 @@ def connect(self): self._sock = BufferedSocket(sock, self.network_buffer_size) + # Set read timeout + self._sock.settimeout(read_timeout) + self._send_preamble() def _connect_upgrade(self, sock): diff --git a/test/server.py b/test/server.py index 482bf734..edc28755 100644 --- a/test/server.py +++ b/test/server.py @@ -108,12 +108,13 @@ class SocketLevelTest(object): A test-class that defines a few helper methods for running socket-level tests. """ - def set_up(self, secure=True, proxy=False): + def set_up(self, secure=True, proxy=False, timeout=None): self.host = None self.port = None self.socket_security = SocketSecuritySetting(secure) self.proxy = proxy self.server_thread = None + self.timeout = timeout def _start_server(self, socket_handler): """ @@ -146,18 +147,22 @@ def secure(self, value): def get_connection(self): if self.h2: if not self.proxy: - return HTTP20Connection(self.host, self.port, self.secure) + return HTTP20Connection(self.host, self.port, self.secure, + timeout=self.timeout) else: return HTTP20Connection('http2bin.org', secure=self.secure, proxy_host=self.host, - proxy_port=self.port) + proxy_port=self.port, + timeout=self.timeout) else: if not self.proxy: - return HTTP11Connection(self.host, self.port, self.secure) + return HTTP11Connection(self.host, self.port, self.secure, + timeout=self.timeout) else: return HTTP11Connection('httpbin.org', secure=self.secure, proxy_host=self.host, - proxy_port=self.port) + proxy_port=self.port, + timeout=self.timeout) def get_encoder(self): """ diff --git a/test/test_abstraction.py b/test/test_abstraction.py index d48b3954..00ee16ec 100644 --- a/test/test_abstraction.py +++ b/test/test_abstraction.py @@ -10,7 +10,7 @@ def test_h1_kwargs(self): c = HTTPConnection( 'test', 443, secure=False, window_manager=True, enable_push=True, ssl_context=False, proxy_host=False, proxy_port=False, - proxy_headers=False, other_kwarg=True + proxy_headers=False, other_kwarg=True, timeout=5 ) assert c._h1_kwargs == { @@ -21,13 +21,14 @@ def test_h1_kwargs(self): 'proxy_headers': False, 'other_kwarg': True, 'enable_push': True, + 'timeout': 5, } def test_h2_kwargs(self): c = HTTPConnection( 'test', 443, secure=False, window_manager=True, enable_push=True, ssl_context=True, proxy_host=False, proxy_port=False, - proxy_headers=False, other_kwarg=True + proxy_headers=False, other_kwarg=True, timeout=(10, 30) ) assert c._h2_kwargs == { @@ -39,6 +40,7 @@ def test_h2_kwargs(self): 'proxy_port': False, 'proxy_headers': False, 'other_kwarg': True, + 'timeout': (10, 30), } def test_tls_upgrade(self, monkeypatch): diff --git a/test/test_http11.py b/test/test_http11.py index 40fea8a9..21dd7f70 100644 --- a/test/test_http11.py +++ b/test/test_http11.py @@ -110,6 +110,16 @@ def test_initialization_with_ipv6_addresses_proxy_inline_port(self): assert c.proxy_host == 'ffff:aaaa::1' assert c.proxy_port == 8443 + def test_initialization_timeout(self): + c = HTTP11Connection('httpbin.org', timeout=30) + + assert c._timeout == 30 + + def test_initialization_tuple_timeout(self): + c = HTTP11Connection('httpbin.org', timeout=(5, 60)) + + assert c._timeout == (5, 60) + def test_basic_request(self): c = HTTP11Connection('httpbin.org') c._sock = sock = DummySocket() diff --git a/test/test_hyper.py b/test/test_hyper.py index 76a68cfe..f4a5994d 100644 --- a/test/test_hyper.py +++ b/test/test_hyper.py @@ -98,6 +98,16 @@ def test_connection_version(self): c = HTTP20Connection('www.google.com') assert c.version is HTTPVersion.http20 + def test_connection_timeout(self): + c = HTTP20Connection('httpbin.org', timeout=30) + + assert c._timeout == 30 + + def test_connection_tuple_timeout(self): + c = HTTP20Connection('httpbin.org', timeout=(5, 60)) + + assert c._timeout == (5, 60) + def test_ping(self, frame_buffer): def data_callback(chunk, **kwargs): frame_buffer.add_data(chunk) diff --git a/test/test_integration.py b/test/test_integration.py index e1c87673..bde7d393 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -13,6 +13,7 @@ import hyper import hyper.http11.connection import pytest +from socket import timeout as SocketTimeout from contextlib import contextmanager from mock import patch from concurrent.futures import ThreadPoolExecutor, TimeoutError @@ -1230,6 +1231,110 @@ def do_connect(conn): self.tear_down() + def test_connection_timeout(self): + self.set_up(timeout=0.5) + + def socket_handler(listener): + time.sleep(1) + + self._start_server(socket_handler) + conn = self.get_connection() + + with pytest.raises((SocketTimeout, ssl.SSLError)): + # Py2 raises this as a BaseSSLError, + # Py3 raises it as socket timeout. + conn.connect() + + self.tear_down() + + def test_hyper_connection_timeout(self): + self.set_up(timeout=0.5) + + def socket_handler(listener): + time.sleep(1) + + self._start_server(socket_handler) + conn = hyper.HTTPConnection(self.host, self.port, self.secure, + timeout=self.timeout) + + with pytest.raises((SocketTimeout, ssl.SSLError)): + # Py2 raises this as a BaseSSLError, + # Py3 raises it as socket timeout. + conn.request('GET', '/') + + self.tear_down() + + def test_read_timeout(self): + self.set_up(timeout=(10, 0.5)) + + req_event = threading.Event() + + def socket_handler(listener): + sock = listener.accept()[0] + + # We get two messages for the connection open and then a HEADERS + # frame. + receive_preamble(sock) + sock.recv(65535) + + # Wait for request + req_event.wait(5) + + # Sleep wait for read timeout + time.sleep(1) + + sock.close() + + self._start_server(socket_handler) + conn = self.get_connection() + conn.request('GET', '/') + req_event.set() + + with pytest.raises((SocketTimeout, ssl.SSLError)): + # Py2 raises this as a BaseSSLError, + # Py3 raises it as socket timeout. + conn.get_response() + + self.tear_down() + + def test_default_connection_timeout(self): + self.set_up(timeout=None) + + # Confirm that we send the connection upgrade string and the initial + # SettingsFrame. + data = [] + send_event = threading.Event() + + def socket_handler(listener): + time.sleep(1) + sock = listener.accept()[0] + + # We should get one big chunk. + first = sock.recv(65535) + data.append(first) + + # We need to send back a SettingsFrame. + f = SettingsFrame(0) + sock.send(f.serialize()) + + send_event.set() + sock.close() + + self._start_server(socket_handler) + conn = self.get_connection() + try: + conn.connect() + except (SocketTimeout, ssl.SSLError): + # Py2 raises this as a BaseSSLError, + # Py3 raises it as socket timeout. + pytest.fail() + + send_event.wait(5) + + assert data[0].startswith(b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n') + + self.tear_down() + @patch('hyper.http20.connection.H2_NPN_PROTOCOLS', PROTOCOLS) class TestRequestsAdapter(SocketLevelTest): @@ -1537,3 +1642,78 @@ def socket_handler(listener): assert r.content == b'' self.tear_down() + + def test_adapter_connection_timeout(self, monkeypatch, frame_buffer): + self.set_up() + + # We need to patch the ssl_wrap_socket method to ensure that we + # forcefully upgrade. + old_wrap_socket = hyper.http11.connection.wrap_socket + + def wrap(*args): + sock, _ = old_wrap_socket(*args) + return sock, 'h2' + + monkeypatch.setattr(hyper.http11.connection, 'wrap_socket', wrap) + + def socket_handler(listener): + time.sleep(1) + + self._start_server(socket_handler) + + s = requests.Session() + s.mount('https://%s' % self.host, HTTP20Adapter()) + + with pytest.raises((SocketTimeout, ssl.SSLError)): + # Py2 raises this as a BaseSSLError, + # Py3 raises it as socket timeout. + s.get('https://%s:%s/some/path' % (self.host, self.port), + timeout=0.5) + + self.tear_down() + + def test_adapter_read_timeout(self, monkeypatch, frame_buffer): + self.set_up() + + # We need to patch the ssl_wrap_socket method to ensure that we + # forcefully upgrade. + old_wrap_socket = hyper.http11.connection.wrap_socket + + def wrap(*args): + sock, _ = old_wrap_socket(*args) + return sock, 'h2' + + monkeypatch.setattr(hyper.http11.connection, 'wrap_socket', wrap) + + def socket_handler(listener): + sock = listener.accept()[0] + + # Do the handshake: conn header, settings, send settings, recv ack. + frame_buffer.add_data(receive_preamble(sock)) + + # Now expect some data. One headers frame. + req_wait = True + while req_wait: + frame_buffer.add_data(sock.recv(65535)) + with reusable_frame_buffer(frame_buffer) as fr: + for f in fr: + if isinstance(f, HeadersFrame): + req_wait = False + + # Sleep wait for read timeout + time.sleep(1) + + sock.close() + + self._start_server(socket_handler) + + s = requests.Session() + s.mount('https://%s' % self.host, HTTP20Adapter()) + + with pytest.raises((SocketTimeout, ssl.SSLError)): + # Py2 raises this as a BaseSSLError, + # Py3 raises it as socket timeout. + s.get('https://%s:%s/some/path' % (self.host, self.port), + timeout=(10, 0.5)) + + self.tear_down() diff --git a/test/test_integration_http11.py b/test/test_integration_http11.py index ee318797..7ec3846a 100644 --- a/test/test_integration_http11.py +++ b/test/test_integration_http11.py @@ -9,6 +9,8 @@ import hyper import threading import pytest +import time +from socket import timeout as SocketTimeout from hyper.compat import ssl from server import SocketLevelTest, SocketSecuritySetting @@ -442,3 +444,68 @@ def socket_handler(listener): with pytest.raises(HTTPUpgrade): c.get_response() + + def test_connection_timeout(self): + self.set_up(timeout=0.5) + + def socket_handler(listener): + time.sleep(1) + + self._start_server(socket_handler) + conn = self.get_connection() + + with pytest.raises((SocketTimeout, ssl.SSLError)): + # Py2 raises this as a BaseSSLError, + # Py3 raises it as socket timeout. + conn.connect() + + self.tear_down() + + def test_hyper_connection_timeout(self): + self.set_up(timeout=0.5) + + def socket_handler(listener): + time.sleep(1) + + self._start_server(socket_handler) + conn = hyper.HTTPConnection(self.host, self.port, self.secure, + timeout=self.timeout) + + with pytest.raises((SocketTimeout, ssl.SSLError)): + # Py2 raises this as a BaseSSLError, + # Py3 raises it as socket timeout. + conn.request('GET', '/') + + self.tear_down() + + def test_read_timeout(self): + self.set_up(timeout=(10, 0.5)) + + send_event = threading.Event() + + def socket_handler(listener): + sock = listener.accept()[0] + + # We should get the initial request. + data = b'' + while not data.endswith(b'\r\n\r\n'): + data += sock.recv(65535) + + send_event.wait() + + # Sleep wait for read timeout + time.sleep(1) + + sock.close() + + self._start_server(socket_handler) + conn = self.get_connection() + conn.request('GET', '/') + send_event.set() + + with pytest.raises((SocketTimeout, ssl.SSLError)): + # Py2 raises this as a BaseSSLError, + # Py3 raises it as socket timeout. + conn.get_response() + + self.tear_down()