@@ -77,8 +77,7 @@ def pytest_collection_modifyitems(config, items):
77
77
item .add_marker (skip_transport )
78
78
79
79
80
- @pytest .fixture
81
- async def aiohttp_server ():
80
+ async def aiohttp_server_base (with_ssl = False ):
82
81
"""Factory to create a TestServer instance, given an app.
83
82
84
83
aiohttp_server(app, **kwargs)
@@ -89,7 +88,13 @@ async def aiohttp_server():
89
88
90
89
async def go (app , * , port = None , ** kwargs ): # type: ignore
91
90
server = AIOHTTPTestServer (app , port = port )
92
- await server .start_server (** kwargs )
91
+
92
+ start_server_args = {** kwargs }
93
+ if with_ssl :
94
+ testcert , ssl_context = get_localhost_ssl_context ()
95
+ start_server_args ["ssl" ] = ssl_context
96
+
97
+ await server .start_server (** start_server_args )
93
98
servers .append (server )
94
99
return server
95
100
@@ -99,6 +104,18 @@ async def go(app, *, port=None, **kwargs): # type: ignore
99
104
await servers .pop ().close ()
100
105
101
106
107
+ @pytest .fixture
108
+ async def aiohttp_server ():
109
+ async for server in aiohttp_server_base ():
110
+ yield server
111
+
112
+
113
+ @pytest .fixture
114
+ async def ssl_aiohttp_server ():
115
+ async for server in aiohttp_server_base (with_ssl = True ):
116
+ yield server
117
+
118
+
102
119
# Adding debug logs to websocket tests
103
120
for name in [
104
121
"websockets.legacy.server" ,
@@ -121,6 +138,22 @@ async def go(app, *, port=None, **kwargs): # type: ignore
121
138
MS = 0.001 * int (os .environ .get ("GQL_TESTS_TIMEOUT_FACTOR" , 1 ))
122
139
123
140
141
+ def get_localhost_ssl_context ():
142
+ # This is a copy of certificate from websockets tests folder
143
+ #
144
+ # Generate TLS certificate with:
145
+ # $ openssl req -x509 -config test_localhost.cnf \
146
+ # -days 15340 -newkey rsa:2048 \
147
+ # -out test_localhost.crt -keyout test_localhost.key
148
+ # $ cat test_localhost.key test_localhost.crt > test_localhost.pem
149
+ # $ rm test_localhost.key test_localhost.crt
150
+ testcert = bytes (pathlib .Path (__file__ ).with_name ("test_localhost.pem" ))
151
+ ssl_context = ssl .SSLContext (ssl .PROTOCOL_TLS_SERVER )
152
+ ssl_context .load_cert_chain (testcert )
153
+
154
+ return (testcert , ssl_context )
155
+
156
+
124
157
class WebSocketServer :
125
158
"""Websocket server on localhost on a free port.
126
159
@@ -141,20 +174,7 @@ async def start(self, handler, extra_serve_args=None):
141
174
extra_serve_args = {}
142
175
143
176
if self .with_ssl :
144
- # This is a copy of certificate from websockets tests folder
145
- #
146
- # Generate TLS certificate with:
147
- # $ openssl req -x509 -config test_localhost.cnf \
148
- # -days 15340 -newkey rsa:2048 \
149
- # -out test_localhost.crt -keyout test_localhost.key
150
- # $ cat test_localhost.key test_localhost.crt > test_localhost.pem
151
- # $ rm test_localhost.key test_localhost.crt
152
- self .testcert = bytes (
153
- pathlib .Path (__file__ ).with_name ("test_localhost.pem" )
154
- )
155
- ssl_context = ssl .SSLContext (ssl .PROTOCOL_TLS_SERVER )
156
- ssl_context .load_cert_chain (self .testcert )
157
-
177
+ self .testcert , ssl_context = get_localhost_ssl_context ()
158
178
extra_serve_args ["ssl" ] = ssl_context
159
179
160
180
# Start a server with a random open port
0 commit comments