|
13 | 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | 14 | # See the License for the specific language governing permissions and |
15 | 15 | # limitations under the License. |
| 16 | +from typing import Iterable |
16 | 17 | from unittest import mock |
17 | 18 |
|
| 19 | +from parameterized import parameterized |
18 | 20 | from signedjson import key as key, sign as sign |
19 | 21 |
|
20 | 22 | from twisted.internet import defer |
|
23 | 25 | from synapse.api.errors import Codes, SynapseError |
24 | 26 |
|
25 | 27 | from tests import unittest |
| 28 | +from tests.test_utils import make_awaitable |
26 | 29 |
|
27 | 30 |
|
28 | 31 | class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): |
@@ -765,6 +768,8 @@ def test_query_devices_remote_sync(self): |
765 | 768 | remote_user_id = "@test:other" |
766 | 769 | local_user_id = "@test:test" |
767 | 770 |
|
| 771 | + # Pretend we're sharing a room with the user we're querying. If not, |
| 772 | + # `_query_devices_for_destination` will return early. |
768 | 773 | self.store.get_rooms_for_user = mock.Mock( |
769 | 774 | return_value=defer.succeed({"some_room_id"}) |
770 | 775 | ) |
@@ -831,3 +836,94 @@ def test_query_devices_remote_sync(self): |
831 | 836 | } |
832 | 837 | }, |
833 | 838 | ) |
| 839 | + |
| 840 | + @parameterized.expand( |
| 841 | + [ |
| 842 | + # The remote homeserver's response indicates that this user has 0/1/2 devices. |
| 843 | + ([],), |
| 844 | + (["device_1"],), |
| 845 | + (["device_1", "device_2"],), |
| 846 | + ] |
| 847 | + ) |
| 848 | + def test_query_all_devices_caches_result(self, device_ids: Iterable[str]): |
| 849 | + """Test that requests for all of a remote user's devices are cached. |
| 850 | +
|
| 851 | + We do this by asserting that only one call over federation was made, and that |
| 852 | + the two queries to the local homeserver produce the same response. |
| 853 | + """ |
| 854 | + local_user_id = "@test:test" |
| 855 | + remote_user_id = "@test:other" |
| 856 | + request_body = {"device_keys": {remote_user_id: []}} |
| 857 | + |
| 858 | + response_devices = [ |
| 859 | + { |
| 860 | + "device_id": device_id, |
| 861 | + "keys": { |
| 862 | + "algorithms": ["dummy"], |
| 863 | + "device_id": device_id, |
| 864 | + "keys": {f"dummy:{device_id}": "dummy"}, |
| 865 | + "signatures": {device_id: {f"dummy:{device_id}": "dummy"}}, |
| 866 | + "unsigned": {}, |
| 867 | + "user_id": "@test:other", |
| 868 | + }, |
| 869 | + } |
| 870 | + for device_id in device_ids |
| 871 | + ] |
| 872 | + |
| 873 | + response_body = { |
| 874 | + "devices": response_devices, |
| 875 | + "user_id": remote_user_id, |
| 876 | + "stream_id": 12345, # an integer, according to the spec |
| 877 | + } |
| 878 | + |
| 879 | + e2e_handler = self.hs.get_e2e_keys_handler() |
| 880 | + |
| 881 | + # Pretend we're sharing a room with the user we're querying. If not, |
| 882 | + # `_query_devices_for_destination` will return early. |
| 883 | + mock_get_rooms = mock.patch.object( |
| 884 | + self.store, |
| 885 | + "get_rooms_for_user", |
| 886 | + new_callable=mock.MagicMock, |
| 887 | + return_value=make_awaitable(["some_room_id"]), |
| 888 | + ) |
| 889 | + mock_request = mock.patch.object( |
| 890 | + self.hs.get_federation_client(), |
| 891 | + "query_user_devices", |
| 892 | + new_callable=mock.MagicMock, |
| 893 | + return_value=make_awaitable(response_body), |
| 894 | + ) |
| 895 | + |
| 896 | + with mock_get_rooms, mock_request as mocked_federation_request: |
| 897 | + # Make the first query and sanity check it succeeds. |
| 898 | + response_1 = self.get_success( |
| 899 | + e2e_handler.query_devices( |
| 900 | + request_body, |
| 901 | + timeout=10, |
| 902 | + from_user_id=local_user_id, |
| 903 | + from_device_id="some_device_id", |
| 904 | + ) |
| 905 | + ) |
| 906 | + self.assertEqual(response_1["failures"], {}) |
| 907 | + |
| 908 | + # We should have made a federation request to do so. |
| 909 | + mocked_federation_request.assert_called_once() |
| 910 | + |
| 911 | + # Reset the mock so we can prove we don't make a second federation request. |
| 912 | + mocked_federation_request.reset_mock() |
| 913 | + |
| 914 | + # Repeat the query. |
| 915 | + response_2 = self.get_success( |
| 916 | + e2e_handler.query_devices( |
| 917 | + request_body, |
| 918 | + timeout=10, |
| 919 | + from_user_id=local_user_id, |
| 920 | + from_device_id="some_device_id", |
| 921 | + ) |
| 922 | + ) |
| 923 | + self.assertEqual(response_2["failures"], {}) |
| 924 | + |
| 925 | + # We should not have made a second federation request. |
| 926 | + mocked_federation_request.assert_not_called() |
| 927 | + |
| 928 | + # The two requests to the local homeserver should be identical. |
| 929 | + self.assertEqual(response_1, response_2) |
0 commit comments