Skip to content

Commit 6568184

Browse files
Support the WebSocket Denial Response ASGI extension (#1916)
Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent 7d274ed commit 6568184

File tree

5 files changed

+447
-15
lines changed

5 files changed

+447
-15
lines changed

tests/protocols/test_websocket.py

Lines changed: 345 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,16 @@
1010
from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory
1111
from websockets.typing import Subprotocol
1212

13+
from tests.response import Response
1314
from tests.utils import run_server
1415
from uvicorn._types import (
1516
ASGIReceiveCallable,
17+
ASGIReceiveEvent,
1618
ASGISendCallable,
1719
Scope,
1820
WebSocketCloseEvent,
1921
WebSocketDisconnectEvent,
22+
WebSocketResponseStartEvent,
2023
)
2124
from uvicorn.config import Config
2225
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
@@ -55,6 +58,21 @@ async def asgi(self):
5558
break
5659

5760

61+
async def wsresponse(url):
62+
"""
63+
A simple websocket connection request and response helper
64+
"""
65+
url = url.replace("ws:", "http:")
66+
headers = {
67+
"connection": "upgrade",
68+
"upgrade": "websocket",
69+
"Sec-WebSocket-Key": "x3JJHMbDL1EzLkh9GBhXDw==",
70+
"Sec-WebSocket-Version": "13",
71+
}
72+
async with httpx.AsyncClient() as client:
73+
return await client.get(url, headers=headers)
74+
75+
5876
@pytest.mark.anyio
5977
async def test_invalid_upgrade(
6078
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
@@ -942,7 +960,10 @@ async def test_server_reject_connection(
942960
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
943961
unused_tcp_port: int,
944962
):
963+
disconnected_message: ASGIReceiveEvent = {} # type: ignore
964+
945965
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
966+
nonlocal disconnected_message
946967
assert scope["type"] == "websocket"
947968

948969
# Pull up first recv message.
@@ -955,15 +976,241 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
955976

956977
# This doesn't raise `TypeError`:
957978
# See https://github.com/encode/uvicorn/issues/244
979+
disconnected_message = await receive()
980+
981+
async def websocket_session(url):
982+
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
983+
async with websockets.client.connect(url):
984+
pass # pragma: no cover
985+
assert exc_info.value.status_code == 403
986+
987+
config = Config(
988+
app=app,
989+
ws=ws_protocol_cls,
990+
http=http_protocol_cls,
991+
lifespan="off",
992+
port=unused_tcp_port,
993+
)
994+
async with run_server(config):
995+
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
996+
997+
assert disconnected_message == {"type": "websocket.disconnect", "code": 1006}
998+
999+
1000+
@pytest.mark.anyio
1001+
async def test_server_reject_connection_with_response(
1002+
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
1003+
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
1004+
unused_tcp_port: int,
1005+
):
1006+
disconnected_message = {}
1007+
1008+
async def app(scope, receive, send):
1009+
nonlocal disconnected_message
1010+
assert scope["type"] == "websocket"
1011+
assert "websocket.http.response" in scope["extensions"]
1012+
1013+
# Pull up first recv message.
9581014
message = await receive()
959-
assert message["type"] == "websocket.disconnect"
1015+
assert message["type"] == "websocket.connect"
1016+
1017+
# Reject the connection with a response
1018+
response = Response(b"goodbye", status_code=400)
1019+
await response(scope, receive, send)
1020+
disconnected_message = await receive()
1021+
1022+
async def websocket_session(url):
1023+
response = await wsresponse(url)
1024+
assert response.status_code == 400
1025+
assert response.content == b"goodbye"
1026+
1027+
config = Config(
1028+
app=app,
1029+
ws=ws_protocol_cls,
1030+
http=http_protocol_cls,
1031+
lifespan="off",
1032+
port=unused_tcp_port,
1033+
)
1034+
async with run_server(config):
1035+
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
1036+
1037+
assert disconnected_message == {"type": "websocket.disconnect", "code": 1006}
1038+
1039+
1040+
@pytest.mark.anyio
1041+
async def test_server_reject_connection_with_multibody_response(
1042+
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
1043+
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
1044+
unused_tcp_port: int,
1045+
):
1046+
disconnected_message: ASGIReceiveEvent = {} # type: ignore
1047+
1048+
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
1049+
nonlocal disconnected_message
1050+
assert scope["type"] == "websocket"
1051+
assert "extensions" in scope
1052+
assert "websocket.http.response" in scope["extensions"]
1053+
1054+
# Pull up first recv message.
1055+
message = await receive()
1056+
assert message["type"] == "websocket.connect"
1057+
await send(
1058+
{
1059+
"type": "websocket.http.response.start",
1060+
"status": 400,
1061+
"headers": [
1062+
(b"Content-Length", b"20"),
1063+
(b"Content-Type", b"text/plain"),
1064+
],
1065+
}
1066+
)
1067+
await send(
1068+
{
1069+
"type": "websocket.http.response.body",
1070+
"body": b"x" * 10,
1071+
"more_body": True,
1072+
}
1073+
)
1074+
await send({"type": "websocket.http.response.body", "body": b"y" * 10})
1075+
disconnected_message = await receive()
9601076

9611077
async def websocket_session(url: str):
962-
try:
1078+
response = await wsresponse(url)
1079+
assert response.status_code == 400
1080+
assert response.content == (b"x" * 10) + (b"y" * 10)
1081+
1082+
config = Config(
1083+
app=app,
1084+
ws=ws_protocol_cls,
1085+
http=http_protocol_cls,
1086+
lifespan="off",
1087+
port=unused_tcp_port,
1088+
)
1089+
async with run_server(config):
1090+
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
1091+
1092+
assert disconnected_message == {"type": "websocket.disconnect", "code": 1006}
1093+
1094+
1095+
@pytest.mark.anyio
1096+
async def test_server_reject_connection_with_invalid_status(
1097+
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
1098+
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
1099+
unused_tcp_port: int,
1100+
):
1101+
# this test checks that even if there is an error in the response, the server
1102+
# can successfully send a 500 error back to the client
1103+
async def app(scope, receive, send):
1104+
assert scope["type"] == "websocket"
1105+
assert "websocket.http.response" in scope["extensions"]
1106+
1107+
# Pull up first recv message.
1108+
message = await receive()
1109+
assert message["type"] == "websocket.connect"
1110+
1111+
message = {
1112+
"type": "websocket.http.response.start",
1113+
"status": 700, # invalid status code
1114+
"headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
1115+
}
1116+
await send(message)
1117+
message = {
1118+
"type": "websocket.http.response.body",
1119+
"body": b"",
1120+
}
1121+
await send(message)
1122+
1123+
async def websocket_session(url):
1124+
response = await wsresponse(url)
1125+
assert response.status_code == 500
1126+
assert response.content == b"Internal Server Error"
1127+
1128+
config = Config(
1129+
app=app,
1130+
ws=ws_protocol_cls,
1131+
http=http_protocol_cls,
1132+
lifespan="off",
1133+
port=unused_tcp_port,
1134+
)
1135+
async with run_server(config):
1136+
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
1137+
1138+
1139+
@pytest.mark.anyio
1140+
async def test_server_reject_connection_with_body_nolength(
1141+
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
1142+
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
1143+
unused_tcp_port: int,
1144+
):
1145+
# test that the server can send a response with a body but no content-length
1146+
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
1147+
assert scope["type"] == "websocket"
1148+
assert "extensions" in scope
1149+
assert "websocket.http.response" in scope["extensions"]
1150+
1151+
# Pull up first recv message.
1152+
message = await receive()
1153+
assert message["type"] == "websocket.connect"
1154+
1155+
await send(
1156+
{
1157+
"type": "websocket.http.response.start",
1158+
"status": 403,
1159+
"headers": [],
1160+
}
1161+
)
1162+
await send({"type": "websocket.http.response.body", "body": b"hardbody"})
1163+
1164+
async def websocket_session(url):
1165+
response = await wsresponse(url)
1166+
assert response.status_code == 403
1167+
assert response.content == b"hardbody"
1168+
if ws_protocol_cls == WSProtocol: # pragma: no cover
1169+
# wsproto automatically makes the message chunked
1170+
assert response.headers["transfer-encoding"] == "chunked"
1171+
else: # pragma: no cover
1172+
# websockets automatically adds a content-length
1173+
assert response.headers["content-length"] == "8"
1174+
1175+
config = Config(
1176+
app=app,
1177+
ws=ws_protocol_cls,
1178+
http=http_protocol_cls,
1179+
lifespan="off",
1180+
port=unused_tcp_port,
1181+
)
1182+
async with run_server(config):
1183+
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
1184+
1185+
1186+
@pytest.mark.anyio
1187+
async def test_server_reject_connection_with_invalid_msg(
1188+
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
1189+
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
1190+
unused_tcp_port: int,
1191+
):
1192+
async def app(scope, receive, send):
1193+
assert scope["type"] == "websocket"
1194+
assert "websocket.http.response" in scope["extensions"]
1195+
1196+
# Pull up first recv message.
1197+
message = await receive()
1198+
assert message["type"] == "websocket.connect"
1199+
1200+
message = {
1201+
"type": "websocket.http.response.start",
1202+
"status": 404,
1203+
"headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
1204+
}
1205+
await send(message)
1206+
# send invalid message. This will raise an exception here
1207+
await send(message)
1208+
1209+
async def websocket_session(url):
1210+
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
9631211
async with websockets.client.connect(url):
9641212
pass # pragma: no cover
965-
except Exception:
966-
pass
1213+
assert exc_info.value.status_code == 404
9671214

9681215
config = Config(
9691216
app=app,
@@ -976,6 +1223,100 @@ async def websocket_session(url: str):
9761223
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
9771224

9781225

1226+
@pytest.mark.anyio
1227+
async def test_server_reject_connection_with_missing_body(
1228+
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
1229+
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
1230+
unused_tcp_port: int,
1231+
):
1232+
async def app(scope, receive, send):
1233+
assert scope["type"] == "websocket"
1234+
assert "websocket.http.response" in scope["extensions"]
1235+
1236+
# Pull up first recv message.
1237+
message = await receive()
1238+
assert message["type"] == "websocket.connect"
1239+
1240+
message = {
1241+
"type": "websocket.http.response.start",
1242+
"status": 404,
1243+
"headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
1244+
}
1245+
await send(message)
1246+
# no further message
1247+
1248+
async def websocket_session(url):
1249+
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
1250+
async with websockets.client.connect(url):
1251+
pass # pragma: no cover
1252+
assert exc_info.value.status_code == 404
1253+
1254+
config = Config(
1255+
app=app,
1256+
ws=ws_protocol_cls,
1257+
http=http_protocol_cls,
1258+
lifespan="off",
1259+
port=unused_tcp_port,
1260+
)
1261+
async with run_server(config):
1262+
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
1263+
1264+
1265+
@pytest.mark.anyio
1266+
async def test_server_multiple_websocket_http_response_start_events(
1267+
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
1268+
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
1269+
unused_tcp_port: int,
1270+
):
1271+
"""
1272+
The server should raise an exception if it sends multiple
1273+
websocket.http.response.start events.
1274+
"""
1275+
exception_message: typing.Optional[str] = None
1276+
1277+
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
1278+
nonlocal exception_message
1279+
assert scope["type"] == "websocket"
1280+
assert "extensions" in scope
1281+
assert "websocket.http.response" in scope["extensions"]
1282+
1283+
# Pull up first recv message.
1284+
message = await receive()
1285+
assert message["type"] == "websocket.connect"
1286+
1287+
start_event: WebSocketResponseStartEvent = {
1288+
"type": "websocket.http.response.start",
1289+
"status": 404,
1290+
"headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
1291+
}
1292+
await send(start_event)
1293+
try:
1294+
await send(start_event)
1295+
except Exception as exc:
1296+
exception_message = str(exc)
1297+
1298+
async def websocket_session(url: str):
1299+
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
1300+
async with websockets.client.connect(url):
1301+
pass
1302+
assert exc_info.value.status_code == 404
1303+
1304+
config = Config(
1305+
app=app,
1306+
ws=ws_protocol_cls,
1307+
http=http_protocol_cls,
1308+
lifespan="off",
1309+
port=unused_tcp_port,
1310+
)
1311+
async with run_server(config):
1312+
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
1313+
1314+
assert exception_message == (
1315+
"Expected ASGI message 'websocket.http.response.body' but got "
1316+
"'websocket.http.response.start'."
1317+
)
1318+
1319+
9791320
@pytest.mark.anyio
9801321
async def test_server_can_read_messages_in_buffer_after_close(
9811322
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",

0 commit comments

Comments
 (0)