diff --git a/LICENSE b/LICENSE index d24c351e..965d0b55 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2014 Cory Benfield +Copyright (c) 2014 Cory Benfield, Google Inc. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/hyper/compat.py b/hyper/compat.py index 2b41357e..150bf3e8 100644 --- a/hyper/compat.py +++ b/hyper/compat.py @@ -15,12 +15,19 @@ # TODO log? ssl_compat = None +try: + import google.appengine + is_appengine = True +except ImportError: + is_appengine = False + _ver = sys.version_info is_py2 = _ver[0] == 2 is_py2_7_9_or_later = _ver[0] >= 2 and _ver[1] >= 7 and _ver[2] >= 9 is_py3 = _ver[0] == 3 is_py3_3 = is_py3 and _ver[1] == 3 + @contextmanager def ignore_missing(): try: @@ -29,7 +36,10 @@ def ignore_missing(): pass if is_py2: - if is_py2_7_9_or_later: + if is_appengine: + from . import ssl_compat_appengine + ssl = ssl_compat_appengine + elif is_py2_7_9_or_later: import ssl else: ssl = ssl_compat diff --git a/hyper/ssl_compat_appengine.py b/hyper/ssl_compat_appengine.py new file mode 100644 index 00000000..61bef518 --- /dev/null +++ b/hyper/ssl_compat_appengine.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +""" +hyper/ssl_compat_appengine +~~~~~~~~~ + +Provides the ssl module interface which hyper assumes, based on AppEngine ssl. + +This module complements some constants and classes which don't exist on +AppEngine's SSL module, to be used by hyper smoothly. +See: https://cloud.google.com/appengine/docs/python/sockets/ssl_support +""" + +import ssl + +from ssl import PROTOCOL_TLSv1, PROTOCOL_SSLv23, PROTOCOL_SSLv3 +from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED +from ssl import match_hostname + +OP_NO_COMPRESSION = 0 + + +class SSLContext(object): + """A SSL context which implements the methods used by hyper.""" + + def __init__(self, proto): + self.protocol_version = proto + self.verify_mode = CERT_REQUIRED + self.check_hostname = True + self.options = 0 + self.custom_ca_cert = None + + def set_default_verify_paths(self): + # Intentionally do nothing. + pass + + def load_verify_locations(self, cafile): + self.custom_ca_cert = cafile + + def wrap_socket(self, sock, server_side=False, do_handshake_on_connect=True, + suppress_ragged_eofs=True, server_hostname=None): + sock = ssl.wrap_socket( + sock, server_side=server_side, cert_reqs=self.verify_mode, + ssl_version=self.protocol_version, ca_certs=self.custom_ca_cert, + suppress_ragged_eofs=suppress_ragged_eofs, + do_handshake_on_connect=do_handshake_on_connect) + # AppEngine SSLSocket does not have selected_npn_protocol, a hacky + # solution to return a dummy data. + sock.selected_npn_protocol = lambda: 'h2' + return sock diff --git a/test/test_appengine.py b/test/test_appengine.py new file mode 100644 index 00000000..fc1ac188 --- /dev/null +++ b/test/test_appengine.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +""" +Tests the ssl compatibility module for appengine. +""" +import hyper +from hyper import ssl_compat_appengine +from server import SocketLevelTest +import socket +import pytest + + +class TestAppengine(object): + """ + Test cases for ssl_compat_appengine module. + """ + + def test_field_existences(self): + assert ssl_compat_appengine.PROTOCOL_TLSv1 is not None + assert ssl_compat_appengine.PROTOCOL_SSLv23 is not None + assert ssl_compat_appengine.PROTOCOL_SSLv3 is not None + assert ssl_compat_appengine.CERT_NONE is not None + assert ssl_compat_appengine.CERT_OPTIONAL is not None + assert ssl_compat_appengine.CERT_REQUIRED is not None + assert ssl_compat_appengine.OP_NO_COMPRESSION is not None + assert ssl_compat_appengine.match_hostname is not None + assert ssl_compat_appengine.SSLContext is not None + + def test_SSLContext(self): + context = ssl_compat_appengine.SSLContext( + ssl_compat_appengine.PROTOCOL_SSLv23) + context.set_default_verify_paths() + assert context.protocol_version == ssl_compat_appengine.PROTOCOL_SSLv23 + + +class TestAppengineSocket(SocketLevelTest): + """ + Test case for wrap_socket. + """ + h2 = False + + def socket_handler(self, listener): + sock = listener.accept()[0] + sock.do_handshake() + sock.close() + + def test_wrap_socket(self): + self.set_up() + self._start_server(self.socket_handler) + context = ssl_compat_appengine.SSLContext( + ssl_compat_appengine.PROTOCOL_SSLv23) + context.set_default_verify_paths() + context.verify_mode = ssl_compat_appengine.CERT_NONE + context.verify_hostname = False + # This invocation makes sure it does not fail. + context.load_verify_locations('test/certs/server.crt') + sock = socket.create_connection( + (self.server_thread.host, self.server_thread.port)) + ssl_sock = context.wrap_socket(sock) + assert ssl_sock.selected_npn_protocol() == 'h2' + ssl_sock.close() + self.tear_down()