diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index 3b4026cb73869a..2eac64c8165fcc 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -145,6 +145,27 @@ def test_ipaddr_info_no_inet_pton(self, m_socket): socket.SOCK_STREAM, socket.IPPROTO_TCP)) + def test_interleave_addrinfos(self): + SIX_A = (socket.AF_INET6, 0, 0, '', ('2001:db8::1', 1)) + SIX_B = (socket.AF_INET6, 0, 0, '', ('2001:db8::2', 2)) + SIX_C = (socket.AF_INET6, 0, 0, '', ('2001:db8::3', 3)) + SIX_D = (socket.AF_INET6, 0, 0, '', ('2001:db8::4', 4)) + FOUR_A = (socket.AF_INET, 0, 0, '', ('192.0.2.1', 5)) + FOUR_B = (socket.AF_INET, 0, 0, '', ('192.0.2.2', 6)) + FOUR_C = (socket.AF_INET, 0, 0, '', ('192.0.2.3', 7)) + FOUR_D = (socket.AF_INET, 0, 0, '', ('192.0.2.4', 8)) + + addrinfos = [SIX_A, SIX_B, SIX_C, SIX_D, FOUR_A, FOUR_B, FOUR_C, FOUR_D] + expected = [SIX_A, FOUR_A, SIX_B, FOUR_B, SIX_C, FOUR_C, SIX_D, FOUR_D] + + self.assertEqual(expected, base_events._interleave_addrinfos(addrinfos)) + + expected_fafc_2 = [SIX_A, SIX_B, FOUR_A, SIX_C, FOUR_B, SIX_D, FOUR_C, FOUR_D] + self.assertEqual( + expected_fafc_2, + base_events._interleave_addrinfos(addrinfos, first_address_family_count=2), + ) + class BaseEventLoopTests(test_utils.TestCase): @@ -1426,6 +1447,65 @@ def getaddrinfo_task(*args, **kwds): self.assertRaises( OSError, self.loop.run_until_complete, coro) + @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'no IPv6 support') + @patch_socket + def test_create_connection_happy_eyeballs(self, m_socket): + + class MyProto(asyncio.Protocol): + pass + + async def getaddrinfo(*args, **kw): + return [(socket.AF_INET6, 0, 0, '', ('2001:db8::1', 1)), + (socket.AF_INET, 0, 0, '', ('192.0.2.1', 5))] + + async def sock_connect(sock, address): + if address[0] == '2001:db8::1': + await asyncio.sleep(1) + sock.connect(address) + + self.loop._add_reader = mock.Mock() + self.loop._add_writer = mock.Mock() + self.loop.getaddrinfo = getaddrinfo + self.loop.sock_connect = sock_connect + + coro = self.loop.create_connection(MyProto, 'example.com', 80, happy_eyeballs_delay=0.3) + transport, protocol = self.loop.run_until_complete(coro) + try: + sock = transport._sock + sock.connect.assert_called_with(('192.0.2.1', 5)) + finally: + transport.close() + test_utils.run_briefly(self.loop) # allow transport to close + + @patch_socket + def test_create_connection_happy_eyeballs_ipv4_only(self, m_socket): + + class MyProto(asyncio.Protocol): + pass + + async def getaddrinfo(*args, **kw): + return [(socket.AF_INET, 0, 0, '', ('192.0.2.1', 5)), + (socket.AF_INET, 0, 0, '', ('192.0.2.2', 6))] + + async def sock_connect(sock, address): + if address[0] == '192.0.2.1': + await asyncio.sleep(1) + sock.connect(address) + + self.loop._add_reader = mock.Mock() + self.loop._add_writer = mock.Mock() + self.loop.getaddrinfo = getaddrinfo + self.loop.sock_connect = sock_connect + + coro = self.loop.create_connection(MyProto, 'example.com', 80, happy_eyeballs_delay=0.3) + transport, protocol = self.loop.run_until_complete(coro) + try: + sock = transport._sock + sock.connect.assert_called_with(('192.0.2.2', 6)) + finally: + transport.close() + test_utils.run_briefly(self.loop) # allow transport to close + @patch_socket def test_create_connection_bluetooth(self, m_socket): # See http://bugs.python.org/issue27136, fallback to getaddrinfo when diff --git a/Lib/test/test_asyncio/test_staggered.py b/Lib/test/test_asyncio/test_staggered.py new file mode 100644 index 00000000000000..775f6f0901fa59 --- /dev/null +++ b/Lib/test/test_asyncio/test_staggered.py @@ -0,0 +1,115 @@ +import asyncio +import functools +import unittest +from asyncio.staggered import staggered_race + + +# To prevent a warning "test altered the execution environment" +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class TestStaggered(unittest.IsolatedAsyncioTestCase): + @staticmethod + async def waiting_coroutine(return_value, wait_seconds, success): + await asyncio.sleep(wait_seconds) + if success: + return return_value + raise RuntimeError(str(return_value)) + + def get_waiting_coroutine_factory(self, return_value, wait_seconds, success): + return functools.partial(self.waiting_coroutine, return_value, wait_seconds, success) + + async def test_single_success(self): + winner_result, winner_idx, exceptions = await staggered_race( + (self.get_waiting_coroutine_factory(0, 0.1, True),), + 0.1, + ) + self.assertEqual(winner_result, 0) + self.assertEqual(winner_idx, 0) + self.assertEqual(len(exceptions), 1) + self.assertIsNone(exceptions[0]) + + async def test_single_fail(self): + winner_result, winner_idx, exceptions = await staggered_race( + (self.get_waiting_coroutine_factory(0, 0.1, False),), + 0.1, + ) + self.assertEqual(winner_result, None) + self.assertEqual(winner_idx, None) + self.assertEqual(len(exceptions), 1) + self.assertIsInstance(exceptions[0], RuntimeError) + + async def test_first_win(self): + winner_result, winner_idx, exceptions = await staggered_race( + ( + self.get_waiting_coroutine_factory(0, 0.2, True), + self.get_waiting_coroutine_factory(1, 0.2, True), + ), + 0.1, + ) + self.assertEqual(winner_result, 0) + self.assertEqual(winner_idx, 0) + self.assertEqual(len(exceptions), 2) + self.assertIsNone(exceptions[0]) + self.assertIsInstance(exceptions[1], asyncio.CancelledError) + + async def test_second_win(self): + winner_result, winner_idx, exceptions = await staggered_race( + ( + self.get_waiting_coroutine_factory(0, 0.3, True), + self.get_waiting_coroutine_factory(1, 0.1, True), + ), + 0.1, + ) + self.assertEqual(winner_result, 1) + self.assertEqual(winner_idx, 1) + self.assertEqual(len(exceptions), 2) + self.assertIsInstance(exceptions[0], asyncio.CancelledError) + self.assertIsNone(exceptions[1]) + + async def test_first_fail(self): + winner_result, winner_idx, exceptions = await staggered_race( + ( + self.get_waiting_coroutine_factory(0, 0.2, False), + self.get_waiting_coroutine_factory(1, 0.2, True), + ), + 0.1, + ) + self.assertEqual(winner_result, 1) + self.assertEqual(winner_idx, 1) + self.assertEqual(len(exceptions), 2) + self.assertIsInstance(exceptions[0], RuntimeError) + self.assertIsNone(exceptions[1]) + + async def test_second_fail(self): + winner_result, winner_idx, exceptions = await staggered_race( + ( + self.get_waiting_coroutine_factory(0, 0.2, True), + self.get_waiting_coroutine_factory(1, 0, False), + ), + 0.1, + ) + self.assertEqual(winner_result, 0) + self.assertEqual(winner_idx, 0) + self.assertEqual(len(exceptions), 2) + self.assertIsNone(exceptions[0]) + self.assertIsInstance(exceptions[1], RuntimeError) + + async def test_simultaneous_success_fail(self): + # There's a potential race condition here: + # https://github.com/python/cpython/issues/86296 + # As with any race condition, it can be difficult to reproduce. + # This test may not fail every time. + for i in range(201): + time_unit = 0.0001 * i + winner_result, winner_idx, exceptions = await staggered_race( + ( + self.get_waiting_coroutine_factory(0, time_unit*2, True), + self.get_waiting_coroutine_factory(1, time_unit, False), + self.get_waiting_coroutine_factory(2, 0.05, True) + ), + time_unit, + ) + self.assertEqual(winner_result, 0) + self.assertEqual(winner_idx, 0)