Skip to content

Commit c91b4c2

Browse files
committed
Use constant-time comparison for passwords.
1 parent 920207c commit c91b4c2

File tree

3 files changed

+27
-16
lines changed

3 files changed

+27
-16
lines changed

docs/project/changelog.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ They may change at any time.
4545

4646
* Optimized default compression settings to reduce memory usage.
4747

48+
* Protected against timing attacks on HTTP Basic Auth.
49+
4850
* Made it easier to customize authentication with
4951
:meth:`~auth.BasicAuthWebSocketServerProtocol.check_credentials`.
5052

src/websockets/legacy/auth.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77

88
import functools
9+
import hmac
910
import http
1011
from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Union, cast
1112

@@ -154,24 +155,23 @@ def basic_auth_protocol_factory(
154155

155156
if credentials is not None:
156157
if is_credentials(credentials):
157-
158-
async def check_credentials(username: str, password: str) -> bool:
159-
return (username, password) == credentials
160-
158+
credentials_list = [cast(Credentials, credentials)]
161159
elif isinstance(credentials, Iterable):
162160
credentials_list = list(credentials)
163-
if all(is_credentials(item) for item in credentials_list):
164-
credentials_dict = dict(credentials_list)
165-
166-
async def check_credentials(username: str, password: str) -> bool:
167-
return credentials_dict.get(username) == password
168-
169-
else:
161+
if not all(is_credentials(item) for item in credentials_list):
170162
raise TypeError(f"invalid credentials argument: {credentials}")
171-
172163
else:
173164
raise TypeError(f"invalid credentials argument: {credentials}")
174165

166+
credentials_dict = dict(credentials_list)
167+
168+
async def check_credentials(username: str, password: str) -> bool:
169+
try:
170+
expected_password = credentials_dict[username]
171+
except KeyError:
172+
return False
173+
return hmac.compare_digest(expected_password, password)
174+
175175
if create_protocol is None:
176176
# Not sure why mypy cannot figure this out.
177177
create_protocol = cast(
@@ -180,5 +180,7 @@ async def check_credentials(username: str, password: str) -> bool:
180180
)
181181

182182
return functools.partial(
183-
create_protocol, realm=realm, check_credentials=check_credentials
183+
create_protocol,
184+
realm=realm,
185+
check_credentials=check_credentials,
184186
)

tests/legacy/test_auth.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import hmac
12
import unittest
23
import urllib.error
34

@@ -27,7 +28,7 @@ async def process_request(self, path, request_headers):
2728

2829
class CheckWebSocketServerProtocol(BasicAuthWebSocketServerProtocol):
2930
async def check_credentials(self, username, password):
30-
return password == "letmein"
31+
return hmac.compare_digest(password, "letmein")
3132

3233

3334
class AuthClientServerTests(ClientServerTestsMixin, AsyncioTestCase):
@@ -81,7 +82,7 @@ def test_basic_auth_bad_multiple_credentials(self):
8182
)
8283

8384
async def check_credentials(username, password):
84-
return password == "iloveyou"
85+
return hmac.compare_digest(password, "iloveyou")
8586

8687
create_protocol_check_credentials = basic_auth_protocol_factory(
8788
realm="auth-tests",
@@ -158,7 +159,13 @@ def test_basic_auth_unsupported_credentials_details(self):
158159
self.assertEqual(raised.exception.read().decode(), "Unsupported credentials\n")
159160

160161
@with_server(create_protocol=create_protocol)
161-
def test_basic_auth_invalid_credentials(self):
162+
def test_basic_auth_invalid_username(self):
163+
with self.assertRaises(InvalidStatusCode) as raised:
164+
self.start_client(user_info=("goodbye", "iloveyou"))
165+
self.assertEqual(raised.exception.status_code, 401)
166+
167+
@with_server(create_protocol=create_protocol)
168+
def test_basic_auth_invalid_password(self):
162169
with self.assertRaises(InvalidStatusCode) as raised:
163170
self.start_client(user_info=("hello", "ihateyou"))
164171
self.assertEqual(raised.exception.status_code, 401)

0 commit comments

Comments
 (0)