Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 9e06e22

Browse files
authored
Add type hints to more tests files. (#12240)
1 parent 3f7cfbc commit 9e06e22

File tree

6 files changed

+66
-47
lines changed

6 files changed

+66
-47
lines changed

changelog.d/12240.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add type hints to tests files.

mypy.ini

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,6 @@ exclude = (?x)
6666
|tests/federation/test_federation_server.py
6767
|tests/federation/transport/test_knocking.py
6868
|tests/federation/transport/test_server.py
69-
|tests/handlers/test_cas.py
70-
|tests/handlers/test_federation.py
71-
|tests/handlers/test_presence.py
7269
|tests/handlers/test_typing.py
7370
|tests/http/federation/test_matrix_federation_agent.py
7471
|tests/http/federation/test_srv_resolver.py
@@ -80,7 +77,6 @@ exclude = (?x)
8077
|tests/logging/test_terse_json.py
8178
|tests/module_api/test_api.py
8279
|tests/push/test_email.py
83-
|tests/push/test_http.py
8480
|tests/push/test_presentable_names.py
8581
|tests/push/test_push_rule_evaluator.py
8682
|tests/rest/client/test_transactions.py

tests/handlers/test_cas.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,14 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from typing import Any, Dict
1415
from unittest.mock import Mock
1516

17+
from twisted.test.proto_helpers import MemoryReactor
18+
1619
from synapse.handlers.cas import CasResponse
20+
from synapse.server import HomeServer
21+
from synapse.util import Clock
1722

1823
from tests.test_utils import simple_async_mock
1924
from tests.unittest import HomeserverTestCase, override_config
@@ -24,7 +29,7 @@
2429

2530

2631
class CasHandlerTestCase(HomeserverTestCase):
27-
def default_config(self):
32+
def default_config(self) -> Dict[str, Any]:
2833
config = super().default_config()
2934
config["public_baseurl"] = BASE_URL
3035
cas_config = {
@@ -40,7 +45,7 @@ def default_config(self):
4045

4146
return config
4247

43-
def make_homeserver(self, reactor, clock):
48+
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
4449
hs = self.setup_test_homeserver()
4550

4651
self.handler = hs.get_cas_handler()
@@ -51,7 +56,7 @@ def make_homeserver(self, reactor, clock):
5156

5257
return hs
5358

54-
def test_map_cas_user_to_user(self):
59+
def test_map_cas_user_to_user(self) -> None:
5560
"""Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
5661

5762
# stub out the auth handler
@@ -75,7 +80,7 @@ def test_map_cas_user_to_user(self):
7580
auth_provider_session_id=None,
7681
)
7782

78-
def test_map_cas_user_to_existing_user(self):
83+
def test_map_cas_user_to_existing_user(self) -> None:
7984
"""Existing users can log in with CAS account."""
8085
store = self.hs.get_datastores().main
8186
self.get_success(
@@ -119,7 +124,7 @@ def test_map_cas_user_to_existing_user(self):
119124
auth_provider_session_id=None,
120125
)
121126

122-
def test_map_cas_user_to_invalid_localpart(self):
127+
def test_map_cas_user_to_invalid_localpart(self) -> None:
123128
"""CAS automaps invalid characters to base-64 encoding."""
124129

125130
# stub out the auth handler
@@ -150,7 +155,7 @@ def test_map_cas_user_to_invalid_localpart(self):
150155
}
151156
}
152157
)
153-
def test_required_attributes(self):
158+
def test_required_attributes(self) -> None:
154159
"""The required attributes must be met from the CAS response."""
155160

156161
# stub out the auth handler
@@ -166,7 +171,7 @@ def test_required_attributes(self):
166171
auth_handler.complete_sso_login.assert_not_called()
167172

168173
# The response doesn't have any department.
169-
cas_response = CasResponse("test_user", {"userGroup": "staff"})
174+
cas_response = CasResponse("test_user", {"userGroup": ["staff"]})
170175
request.reset_mock()
171176
self.get_success(
172177
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")

tests/handlers/test_federation.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15-
from typing import List
15+
from typing import List, cast
1616
from unittest import TestCase
1717

18+
from twisted.test.proto_helpers import MemoryReactor
19+
1820
from synapse.api.constants import EventTypes
1921
from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
2022
from synapse.api.room_versions import RoomVersions
@@ -23,7 +25,9 @@
2325
from synapse.logging.context import LoggingContext, run_in_background
2426
from synapse.rest import admin
2527
from synapse.rest.client import login, room
28+
from synapse.server import HomeServer
2629
from synapse.types import create_requester
30+
from synapse.util import Clock
2731
from synapse.util.stringutils import random_string
2832

2933
from tests import unittest
@@ -42,15 +46,15 @@ class FederationTestCase(unittest.HomeserverTestCase):
4246
room.register_servlets,
4347
]
4448

45-
def make_homeserver(self, reactor, clock):
49+
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
4650
hs = self.setup_test_homeserver(federation_http_client=None)
4751
self.handler = hs.get_federation_handler()
4852
self.store = hs.get_datastores().main
4953
self.state_store = hs.get_storage().state
5054
self._event_auth_handler = hs.get_event_auth_handler()
5155
return hs
5256

53-
def test_exchange_revoked_invite(self):
57+
def test_exchange_revoked_invite(self) -> None:
5458
user_id = self.register_user("kermit", "test")
5559
tok = self.login("kermit", "test")
5660

@@ -96,7 +100,7 @@ def test_exchange_revoked_invite(self):
96100
self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
97101
self.assertEqual(failure.msg, "You are not invited to this room.")
98102

99-
def test_rejected_message_event_state(self):
103+
def test_rejected_message_event_state(self) -> None:
100104
"""
101105
Check that we store the state group correctly for rejected non-state events.
102106
@@ -126,7 +130,7 @@ def test_rejected_message_event_state(self):
126130
"content": {},
127131
"room_id": room_id,
128132
"sender": "@yetanotheruser:" + OTHER_SERVER,
129-
"depth": join_event["depth"] + 1,
133+
"depth": cast(int, join_event["depth"]) + 1,
130134
"prev_events": [join_event.event_id],
131135
"auth_events": [],
132136
"origin_server_ts": self.clock.time_msec(),
@@ -149,7 +153,7 @@ def test_rejected_message_event_state(self):
149153

150154
self.assertEqual(sg, sg2)
151155

152-
def test_rejected_state_event_state(self):
156+
def test_rejected_state_event_state(self) -> None:
153157
"""
154158
Check that we store the state group correctly for rejected state events.
155159
@@ -180,7 +184,7 @@ def test_rejected_state_event_state(self):
180184
"content": {},
181185
"room_id": room_id,
182186
"sender": "@yetanotheruser:" + OTHER_SERVER,
183-
"depth": join_event["depth"] + 1,
187+
"depth": cast(int, join_event["depth"]) + 1,
184188
"prev_events": [join_event.event_id],
185189
"auth_events": [],
186190
"origin_server_ts": self.clock.time_msec(),
@@ -203,7 +207,7 @@ def test_rejected_state_event_state(self):
203207

204208
self.assertEqual(sg, sg2)
205209

206-
def test_backfill_with_many_backward_extremities(self):
210+
def test_backfill_with_many_backward_extremities(self) -> None:
207211
"""
208212
Check that we can backfill with many backward extremities.
209213
The goal is to make sure that when we only use a portion
@@ -262,7 +266,7 @@ def test_backfill_with_many_backward_extremities(self):
262266
)
263267
self.get_success(d)
264268

265-
def test_backfill_floating_outlier_membership_auth(self):
269+
def test_backfill_floating_outlier_membership_auth(self) -> None:
266270
"""
267271
As the local homeserver, check that we can properly process a federated
268272
event from the OTHER_SERVER with auth_events that include a floating
@@ -377,7 +381,7 @@ async def get_event_auth(
377381
for ae in auth_events
378382
]
379383

380-
self.handler.federation_client.get_event_auth = get_event_auth
384+
self.handler.federation_client.get_event_auth = get_event_auth # type: ignore[assignment]
381385

382386
with LoggingContext("receive_pdu"):
383387
# Fake the OTHER_SERVER federating the message event over to our local homeserver
@@ -397,7 +401,7 @@ async def get_event_auth(
397401
@unittest.override_config(
398402
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
399403
)
400-
def test_invite_by_user_ratelimit(self):
404+
def test_invite_by_user_ratelimit(self) -> None:
401405
"""Tests that invites from federation to a particular user are
402406
actually rate-limited.
403407
"""
@@ -446,7 +450,9 @@ def create_invite():
446450
exc=LimitExceededError,
447451
)
448452

449-
def _build_and_send_join_event(self, other_server, other_user, room_id):
453+
def _build_and_send_join_event(
454+
self, other_server: str, other_user: str, room_id: str
455+
) -> EventBase:
450456
join_event = self.get_success(
451457
self.handler.on_make_join_request(other_server, room_id, other_user)
452458
)
@@ -469,7 +475,7 @@ def _build_and_send_join_event(self, other_server, other_user, room_id):
469475

470476

471477
class EventFromPduTestCase(TestCase):
472-
def test_valid_json(self):
478+
def test_valid_json(self) -> None:
473479
"""Valid JSON should be turned into an event."""
474480
ev = event_from_pdu_json(
475481
{
@@ -487,7 +493,7 @@ def test_valid_json(self):
487493

488494
self.assertIsInstance(ev, EventBase)
489495

490-
def test_invalid_numbers(self):
496+
def test_invalid_numbers(self) -> None:
491497
"""Invalid values for an integer should be rejected, all floats should be rejected."""
492498
for value in [
493499
-(2 ** 53),
@@ -512,7 +518,7 @@ def test_invalid_numbers(self):
512518
RoomVersions.V6,
513519
)
514520

515-
def test_invalid_nested(self):
521+
def test_invalid_nested(self) -> None:
516522
"""List and dictionaries are recursively searched."""
517523
with self.assertRaises(SynapseError):
518524
event_from_pdu_json(

tests/handlers/test_presence.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -331,11 +331,11 @@ def test_persisting_presence_updates(self):
331331

332332
# Extract presence update user ID and state information into lists of tuples
333333
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
334-
presence_states = [(ps.user_id, ps.state) for ps in presence_states]
334+
presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]
335335

336336
# Compare what we put into the storage with what we got out.
337337
# They should be identical.
338-
self.assertEqual(presence_states, db_presence_states)
338+
self.assertEqual(presence_states_compare, db_presence_states)
339339

340340

341341
class PresenceTimeoutTestCase(unittest.TestCase):
@@ -357,6 +357,7 @@ def test_idle_timer(self):
357357
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
358358

359359
self.assertIsNotNone(new_state)
360+
assert new_state is not None
360361
self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
361362
self.assertEqual(new_state.status_msg, status_msg)
362363

@@ -380,6 +381,7 @@ def test_busy_no_idle(self):
380381
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
381382

382383
self.assertIsNotNone(new_state)
384+
assert new_state is not None
383385
self.assertEqual(new_state.state, PresenceState.BUSY)
384386
self.assertEqual(new_state.status_msg, status_msg)
385387

@@ -399,6 +401,7 @@ def test_sync_timeout(self):
399401
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
400402

401403
self.assertIsNotNone(new_state)
404+
assert new_state is not None
402405
self.assertEqual(new_state.state, PresenceState.OFFLINE)
403406
self.assertEqual(new_state.status_msg, status_msg)
404407

@@ -420,6 +423,7 @@ def test_sync_online(self):
420423
)
421424

422425
self.assertIsNotNone(new_state)
426+
assert new_state is not None
423427
self.assertEqual(new_state.state, PresenceState.ONLINE)
424428
self.assertEqual(new_state.status_msg, status_msg)
425429

@@ -477,6 +481,7 @@ def test_federation_timeout(self):
477481
)
478482

479483
self.assertIsNotNone(new_state)
484+
assert new_state is not None
480485
self.assertEqual(new_state.state, PresenceState.OFFLINE)
481486
self.assertEqual(new_state.status_msg, status_msg)
482487

@@ -653,13 +658,13 @@ def test_set_presence_with_status_msg_none(self):
653658
self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
654659

655660
def _set_presencestate_with_status_msg(
656-
self, user_id: str, state: PresenceState, status_msg: Optional[str]
661+
self, user_id: str, state: str, status_msg: Optional[str]
657662
):
658663
"""Set a PresenceState and status_msg and check the result.
659664
660665
Args:
661666
user_id: User for that the status is to be set.
662-
PresenceState: The new PresenceState.
667+
state: The new PresenceState.
663668
status_msg: Status message that is to be set.
664669
"""
665670
self.get_success(

0 commit comments

Comments
 (0)