10
10
from websockets .extensions .permessage_deflate import ClientPerMessageDeflateFactory
11
11
from websockets .typing import Subprotocol
12
12
13
+ from tests .response import Response
13
14
from tests .utils import run_server
14
15
from uvicorn ._types import (
15
16
ASGIReceiveCallable ,
17
+ ASGIReceiveEvent ,
16
18
ASGISendCallable ,
17
19
Scope ,
18
20
WebSocketCloseEvent ,
19
21
WebSocketDisconnectEvent ,
22
+ WebSocketResponseStartEvent ,
20
23
)
21
24
from uvicorn .config import Config
22
25
from uvicorn .protocols .websockets .websockets_impl import WebSocketProtocol
@@ -55,6 +58,21 @@ async def asgi(self):
55
58
break
56
59
57
60
61
+ async def wsresponse (url ):
62
+ """
63
+ A simple websocket connection request and response helper
64
+ """
65
+ url = url .replace ("ws:" , "http:" )
66
+ headers = {
67
+ "connection" : "upgrade" ,
68
+ "upgrade" : "websocket" ,
69
+ "Sec-WebSocket-Key" : "x3JJHMbDL1EzLkh9GBhXDw==" ,
70
+ "Sec-WebSocket-Version" : "13" ,
71
+ }
72
+ async with httpx .AsyncClient () as client :
73
+ return await client .get (url , headers = headers )
74
+
75
+
58
76
@pytest .mark .anyio
59
77
async def test_invalid_upgrade (
60
78
ws_protocol_cls : "typing.Type[WSProtocol | WebSocketProtocol]" ,
@@ -942,7 +960,10 @@ async def test_server_reject_connection(
942
960
http_protocol_cls : "typing.Type[H11Protocol | HttpToolsProtocol]" ,
943
961
unused_tcp_port : int ,
944
962
):
963
+ disconnected_message : ASGIReceiveEvent = {} # type: ignore
964
+
945
965
async def app (scope : Scope , receive : ASGIReceiveCallable , send : ASGISendCallable ):
966
+ nonlocal disconnected_message
946
967
assert scope ["type" ] == "websocket"
947
968
948
969
# Pull up first recv message.
@@ -955,15 +976,241 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
955
976
956
977
# This doesn't raise `TypeError`:
957
978
# See https://github.com/encode/uvicorn/issues/244
979
+ disconnected_message = await receive ()
980
+
981
+ async def websocket_session (url ):
982
+ with pytest .raises (websockets .exceptions .InvalidStatusCode ) as exc_info :
983
+ async with websockets .client .connect (url ):
984
+ pass # pragma: no cover
985
+ assert exc_info .value .status_code == 403
986
+
987
+ config = Config (
988
+ app = app ,
989
+ ws = ws_protocol_cls ,
990
+ http = http_protocol_cls ,
991
+ lifespan = "off" ,
992
+ port = unused_tcp_port ,
993
+ )
994
+ async with run_server (config ):
995
+ await websocket_session (f"ws://127.0.0.1:{ unused_tcp_port } " )
996
+
997
+ assert disconnected_message == {"type" : "websocket.disconnect" , "code" : 1006 }
998
+
999
+
1000
+ @pytest .mark .anyio
1001
+ async def test_server_reject_connection_with_response (
1002
+ ws_protocol_cls : "typing.Type[WSProtocol | WebSocketProtocol]" ,
1003
+ http_protocol_cls : "typing.Type[H11Protocol | HttpToolsProtocol]" ,
1004
+ unused_tcp_port : int ,
1005
+ ):
1006
+ disconnected_message = {}
1007
+
1008
+ async def app (scope , receive , send ):
1009
+ nonlocal disconnected_message
1010
+ assert scope ["type" ] == "websocket"
1011
+ assert "websocket.http.response" in scope ["extensions" ]
1012
+
1013
+ # Pull up first recv message.
958
1014
message = await receive ()
959
- assert message ["type" ] == "websocket.disconnect"
1015
+ assert message ["type" ] == "websocket.connect"
1016
+
1017
+ # Reject the connection with a response
1018
+ response = Response (b"goodbye" , status_code = 400 )
1019
+ await response (scope , receive , send )
1020
+ disconnected_message = await receive ()
1021
+
1022
+ async def websocket_session (url ):
1023
+ response = await wsresponse (url )
1024
+ assert response .status_code == 400
1025
+ assert response .content == b"goodbye"
1026
+
1027
+ config = Config (
1028
+ app = app ,
1029
+ ws = ws_protocol_cls ,
1030
+ http = http_protocol_cls ,
1031
+ lifespan = "off" ,
1032
+ port = unused_tcp_port ,
1033
+ )
1034
+ async with run_server (config ):
1035
+ await websocket_session (f"ws://127.0.0.1:{ unused_tcp_port } " )
1036
+
1037
+ assert disconnected_message == {"type" : "websocket.disconnect" , "code" : 1006 }
1038
+
1039
+
1040
+ @pytest .mark .anyio
1041
+ async def test_server_reject_connection_with_multibody_response (
1042
+ ws_protocol_cls : "typing.Type[WSProtocol | WebSocketProtocol]" ,
1043
+ http_protocol_cls : "typing.Type[H11Protocol | HttpToolsProtocol]" ,
1044
+ unused_tcp_port : int ,
1045
+ ):
1046
+ disconnected_message : ASGIReceiveEvent = {} # type: ignore
1047
+
1048
+ async def app (scope : Scope , receive : ASGIReceiveCallable , send : ASGISendCallable ):
1049
+ nonlocal disconnected_message
1050
+ assert scope ["type" ] == "websocket"
1051
+ assert "extensions" in scope
1052
+ assert "websocket.http.response" in scope ["extensions" ]
1053
+
1054
+ # Pull up first recv message.
1055
+ message = await receive ()
1056
+ assert message ["type" ] == "websocket.connect"
1057
+ await send (
1058
+ {
1059
+ "type" : "websocket.http.response.start" ,
1060
+ "status" : 400 ,
1061
+ "headers" : [
1062
+ (b"Content-Length" , b"20" ),
1063
+ (b"Content-Type" , b"text/plain" ),
1064
+ ],
1065
+ }
1066
+ )
1067
+ await send (
1068
+ {
1069
+ "type" : "websocket.http.response.body" ,
1070
+ "body" : b"x" * 10 ,
1071
+ "more_body" : True ,
1072
+ }
1073
+ )
1074
+ await send ({"type" : "websocket.http.response.body" , "body" : b"y" * 10 })
1075
+ disconnected_message = await receive ()
960
1076
961
1077
async def websocket_session (url : str ):
962
- try :
1078
+ response = await wsresponse (url )
1079
+ assert response .status_code == 400
1080
+ assert response .content == (b"x" * 10 ) + (b"y" * 10 )
1081
+
1082
+ config = Config (
1083
+ app = app ,
1084
+ ws = ws_protocol_cls ,
1085
+ http = http_protocol_cls ,
1086
+ lifespan = "off" ,
1087
+ port = unused_tcp_port ,
1088
+ )
1089
+ async with run_server (config ):
1090
+ await websocket_session (f"ws://127.0.0.1:{ unused_tcp_port } " )
1091
+
1092
+ assert disconnected_message == {"type" : "websocket.disconnect" , "code" : 1006 }
1093
+
1094
+
1095
+ @pytest .mark .anyio
1096
+ async def test_server_reject_connection_with_invalid_status (
1097
+ ws_protocol_cls : "typing.Type[WSProtocol | WebSocketProtocol]" ,
1098
+ http_protocol_cls : "typing.Type[H11Protocol | HttpToolsProtocol]" ,
1099
+ unused_tcp_port : int ,
1100
+ ):
1101
+ # this test checks that even if there is an error in the response, the server
1102
+ # can successfully send a 500 error back to the client
1103
+ async def app (scope , receive , send ):
1104
+ assert scope ["type" ] == "websocket"
1105
+ assert "websocket.http.response" in scope ["extensions" ]
1106
+
1107
+ # Pull up first recv message.
1108
+ message = await receive ()
1109
+ assert message ["type" ] == "websocket.connect"
1110
+
1111
+ message = {
1112
+ "type" : "websocket.http.response.start" ,
1113
+ "status" : 700 , # invalid status code
1114
+ "headers" : [(b"Content-Length" , b"0" ), (b"Content-Type" , b"text/plain" )],
1115
+ }
1116
+ await send (message )
1117
+ message = {
1118
+ "type" : "websocket.http.response.body" ,
1119
+ "body" : b"" ,
1120
+ }
1121
+ await send (message )
1122
+
1123
+ async def websocket_session (url ):
1124
+ response = await wsresponse (url )
1125
+ assert response .status_code == 500
1126
+ assert response .content == b"Internal Server Error"
1127
+
1128
+ config = Config (
1129
+ app = app ,
1130
+ ws = ws_protocol_cls ,
1131
+ http = http_protocol_cls ,
1132
+ lifespan = "off" ,
1133
+ port = unused_tcp_port ,
1134
+ )
1135
+ async with run_server (config ):
1136
+ await websocket_session (f"ws://127.0.0.1:{ unused_tcp_port } " )
1137
+
1138
+
1139
+ @pytest .mark .anyio
1140
+ async def test_server_reject_connection_with_body_nolength (
1141
+ ws_protocol_cls : "typing.Type[WSProtocol | WebSocketProtocol]" ,
1142
+ http_protocol_cls : "typing.Type[H11Protocol | HttpToolsProtocol]" ,
1143
+ unused_tcp_port : int ,
1144
+ ):
1145
+ # test that the server can send a response with a body but no content-length
1146
+ async def app (scope : Scope , receive : ASGIReceiveCallable , send : ASGISendCallable ):
1147
+ assert scope ["type" ] == "websocket"
1148
+ assert "extensions" in scope
1149
+ assert "websocket.http.response" in scope ["extensions" ]
1150
+
1151
+ # Pull up first recv message.
1152
+ message = await receive ()
1153
+ assert message ["type" ] == "websocket.connect"
1154
+
1155
+ await send (
1156
+ {
1157
+ "type" : "websocket.http.response.start" ,
1158
+ "status" : 403 ,
1159
+ "headers" : [],
1160
+ }
1161
+ )
1162
+ await send ({"type" : "websocket.http.response.body" , "body" : b"hardbody" })
1163
+
1164
+ async def websocket_session (url ):
1165
+ response = await wsresponse (url )
1166
+ assert response .status_code == 403
1167
+ assert response .content == b"hardbody"
1168
+ if ws_protocol_cls == WSProtocol : # pragma: no cover
1169
+ # wsproto automatically makes the message chunked
1170
+ assert response .headers ["transfer-encoding" ] == "chunked"
1171
+ else : # pragma: no cover
1172
+ # websockets automatically adds a content-length
1173
+ assert response .headers ["content-length" ] == "8"
1174
+
1175
+ config = Config (
1176
+ app = app ,
1177
+ ws = ws_protocol_cls ,
1178
+ http = http_protocol_cls ,
1179
+ lifespan = "off" ,
1180
+ port = unused_tcp_port ,
1181
+ )
1182
+ async with run_server (config ):
1183
+ await websocket_session (f"ws://127.0.0.1:{ unused_tcp_port } " )
1184
+
1185
+
1186
+ @pytest .mark .anyio
1187
+ async def test_server_reject_connection_with_invalid_msg (
1188
+ ws_protocol_cls : "typing.Type[WSProtocol | WebSocketProtocol]" ,
1189
+ http_protocol_cls : "typing.Type[H11Protocol | HttpToolsProtocol]" ,
1190
+ unused_tcp_port : int ,
1191
+ ):
1192
+ async def app (scope , receive , send ):
1193
+ assert scope ["type" ] == "websocket"
1194
+ assert "websocket.http.response" in scope ["extensions" ]
1195
+
1196
+ # Pull up first recv message.
1197
+ message = await receive ()
1198
+ assert message ["type" ] == "websocket.connect"
1199
+
1200
+ message = {
1201
+ "type" : "websocket.http.response.start" ,
1202
+ "status" : 404 ,
1203
+ "headers" : [(b"Content-Length" , b"0" ), (b"Content-Type" , b"text/plain" )],
1204
+ }
1205
+ await send (message )
1206
+ # send invalid message. This will raise an exception here
1207
+ await send (message )
1208
+
1209
+ async def websocket_session (url ):
1210
+ with pytest .raises (websockets .exceptions .InvalidStatusCode ) as exc_info :
963
1211
async with websockets .client .connect (url ):
964
1212
pass # pragma: no cover
965
- except Exception :
966
- pass
1213
+ assert exc_info .value .status_code == 404
967
1214
968
1215
config = Config (
969
1216
app = app ,
@@ -976,6 +1223,100 @@ async def websocket_session(url: str):
976
1223
await websocket_session (f"ws://127.0.0.1:{ unused_tcp_port } " )
977
1224
978
1225
1226
+ @pytest .mark .anyio
1227
+ async def test_server_reject_connection_with_missing_body (
1228
+ ws_protocol_cls : "typing.Type[WSProtocol | WebSocketProtocol]" ,
1229
+ http_protocol_cls : "typing.Type[H11Protocol | HttpToolsProtocol]" ,
1230
+ unused_tcp_port : int ,
1231
+ ):
1232
+ async def app (scope , receive , send ):
1233
+ assert scope ["type" ] == "websocket"
1234
+ assert "websocket.http.response" in scope ["extensions" ]
1235
+
1236
+ # Pull up first recv message.
1237
+ message = await receive ()
1238
+ assert message ["type" ] == "websocket.connect"
1239
+
1240
+ message = {
1241
+ "type" : "websocket.http.response.start" ,
1242
+ "status" : 404 ,
1243
+ "headers" : [(b"Content-Length" , b"0" ), (b"Content-Type" , b"text/plain" )],
1244
+ }
1245
+ await send (message )
1246
+ # no further message
1247
+
1248
+ async def websocket_session (url ):
1249
+ with pytest .raises (websockets .exceptions .InvalidStatusCode ) as exc_info :
1250
+ async with websockets .client .connect (url ):
1251
+ pass # pragma: no cover
1252
+ assert exc_info .value .status_code == 404
1253
+
1254
+ config = Config (
1255
+ app = app ,
1256
+ ws = ws_protocol_cls ,
1257
+ http = http_protocol_cls ,
1258
+ lifespan = "off" ,
1259
+ port = unused_tcp_port ,
1260
+ )
1261
+ async with run_server (config ):
1262
+ await websocket_session (f"ws://127.0.0.1:{ unused_tcp_port } " )
1263
+
1264
+
1265
+ @pytest .mark .anyio
1266
+ async def test_server_multiple_websocket_http_response_start_events (
1267
+ ws_protocol_cls : "typing.Type[WSProtocol | WebSocketProtocol]" ,
1268
+ http_protocol_cls : "typing.Type[H11Protocol | HttpToolsProtocol]" ,
1269
+ unused_tcp_port : int ,
1270
+ ):
1271
+ """
1272
+ The server should raise an exception if it sends multiple
1273
+ websocket.http.response.start events.
1274
+ """
1275
+ exception_message : typing .Optional [str ] = None
1276
+
1277
+ async def app (scope : Scope , receive : ASGIReceiveCallable , send : ASGISendCallable ):
1278
+ nonlocal exception_message
1279
+ assert scope ["type" ] == "websocket"
1280
+ assert "extensions" in scope
1281
+ assert "websocket.http.response" in scope ["extensions" ]
1282
+
1283
+ # Pull up first recv message.
1284
+ message = await receive ()
1285
+ assert message ["type" ] == "websocket.connect"
1286
+
1287
+ start_event : WebSocketResponseStartEvent = {
1288
+ "type" : "websocket.http.response.start" ,
1289
+ "status" : 404 ,
1290
+ "headers" : [(b"Content-Length" , b"0" ), (b"Content-Type" , b"text/plain" )],
1291
+ }
1292
+ await send (start_event )
1293
+ try :
1294
+ await send (start_event )
1295
+ except Exception as exc :
1296
+ exception_message = str (exc )
1297
+
1298
+ async def websocket_session (url : str ):
1299
+ with pytest .raises (websockets .exceptions .InvalidStatusCode ) as exc_info :
1300
+ async with websockets .client .connect (url ):
1301
+ pass
1302
+ assert exc_info .value .status_code == 404
1303
+
1304
+ config = Config (
1305
+ app = app ,
1306
+ ws = ws_protocol_cls ,
1307
+ http = http_protocol_cls ,
1308
+ lifespan = "off" ,
1309
+ port = unused_tcp_port ,
1310
+ )
1311
+ async with run_server (config ):
1312
+ await websocket_session (f"ws://127.0.0.1:{ unused_tcp_port } " )
1313
+
1314
+ assert exception_message == (
1315
+ "Expected ASGI message 'websocket.http.response.body' but got "
1316
+ "'websocket.http.response.start'."
1317
+ )
1318
+
1319
+
979
1320
@pytest .mark .anyio
980
1321
async def test_server_can_read_messages_in_buffer_after_close (
981
1322
ws_protocol_cls : "typing.Type[WSProtocol | WebSocketProtocol]" ,
0 commit comments