|  | 
|  | 1 | +import asyncio | 
|  | 2 | + | 
|  | 3 | +from aiobotocore.session import AioSession | 
|  | 4 | + | 
|  | 5 | +from ...mock_server import AIOServer | 
|  | 6 | +from .. import ClientHTTPStubber | 
|  | 7 | + | 
|  | 8 | + | 
|  | 9 | +def get_captured_ua_strings(stubber): | 
|  | 10 | +    """Get captured request-level user agent strings from stubber. | 
|  | 11 | +    :type stubber: tests.BaseHTTPStubber | 
|  | 12 | +    """ | 
|  | 13 | +    return [req.headers['User-Agent'].decode() for req in stubber.requests] | 
|  | 14 | + | 
|  | 15 | + | 
|  | 16 | +def parse_registered_feature_ids(ua_string): | 
|  | 17 | +    """Parse registered feature ids in user agent string. | 
|  | 18 | +    :type ua_string: str | 
|  | 19 | +    :rtype: list[str] | 
|  | 20 | +    """ | 
|  | 21 | +    ua_fields = ua_string.split(' ') | 
|  | 22 | +    feature_field = [field for field in ua_fields if field.startswith('m/')][0] | 
|  | 23 | +    return feature_field[2:].split(',') | 
|  | 24 | + | 
|  | 25 | + | 
|  | 26 | +async def test_user_agent_has_registered_feature_id(): | 
|  | 27 | +    session = AioSession() | 
|  | 28 | + | 
|  | 29 | +    async with ( | 
|  | 30 | +        AIOServer() as server, | 
|  | 31 | +        session.create_client( | 
|  | 32 | +            's3', | 
|  | 33 | +            endpoint_url=server.endpoint_url, | 
|  | 34 | +            aws_secret_access_key='xxx', | 
|  | 35 | +            aws_access_key_id='xxx', | 
|  | 36 | +        ) as s3_client, | 
|  | 37 | +    ): | 
|  | 38 | +        with ClientHTTPStubber(s3_client) as stub_client: | 
|  | 39 | +            stub_client.add_response() | 
|  | 40 | +            paginator = s3_client.get_paginator('list_buckets') | 
|  | 41 | +            # The `paginate()` method registers `'PAGINATOR': 'C'` | 
|  | 42 | +            async for _ in paginator.paginate(): | 
|  | 43 | +                pass | 
|  | 44 | + | 
|  | 45 | +        ua_string = get_captured_ua_strings(stub_client)[0] | 
|  | 46 | +        feature_list = parse_registered_feature_ids(ua_string) | 
|  | 47 | +        assert 'C' in feature_list | 
|  | 48 | + | 
|  | 49 | + | 
|  | 50 | +async def test_registered_feature_ids_dont_bleed_between_requests(): | 
|  | 51 | +    session = AioSession() | 
|  | 52 | + | 
|  | 53 | +    async with ( | 
|  | 54 | +        AIOServer() as server, | 
|  | 55 | +        session.create_client( | 
|  | 56 | +            's3', | 
|  | 57 | +            endpoint_url=server.endpoint_url, | 
|  | 58 | +            aws_secret_access_key='xxx', | 
|  | 59 | +            aws_access_key_id='xxx', | 
|  | 60 | +        ) as s3_client, | 
|  | 61 | +    ): | 
|  | 62 | +        with ClientHTTPStubber(s3_client) as stub_client: | 
|  | 63 | +            stub_client.add_response() | 
|  | 64 | +            waiter = s3_client.get_waiter('bucket_exists') | 
|  | 65 | +            # The `wait()` method registers `'WAITER': 'B'` | 
|  | 66 | +            await waiter.wait(Bucket='mybucket') | 
|  | 67 | + | 
|  | 68 | +            stub_client.add_response() | 
|  | 69 | +            paginator = s3_client.get_paginator('list_buckets') | 
|  | 70 | +            # The `paginate()` method registers `'PAGINATOR': 'C'` | 
|  | 71 | +            async for _ in paginator.paginate(): | 
|  | 72 | +                pass | 
|  | 73 | + | 
|  | 74 | +        ua_strings = get_captured_ua_strings(stub_client) | 
|  | 75 | +        waiter_feature_list = parse_registered_feature_ids(ua_strings[0]) | 
|  | 76 | +        assert 'B' in waiter_feature_list | 
|  | 77 | + | 
|  | 78 | +        paginator_feature_list = parse_registered_feature_ids(ua_strings[1]) | 
|  | 79 | +        assert 'C' in paginator_feature_list | 
|  | 80 | +        assert 'B' not in paginator_feature_list | 
|  | 81 | + | 
|  | 82 | + | 
|  | 83 | +# This tests context's bleeding across tasks instead | 
|  | 84 | +async def test_registered_feature_ids_dont_bleed_across_threads(): | 
|  | 85 | +    session = AioSession() | 
|  | 86 | + | 
|  | 87 | +    async with ( | 
|  | 88 | +        AIOServer() as server, | 
|  | 89 | +        session.create_client( | 
|  | 90 | +            's3', | 
|  | 91 | +            endpoint_url=server.endpoint_url, | 
|  | 92 | +            aws_secret_access_key='xxx', | 
|  | 93 | +            aws_access_key_id='xxx', | 
|  | 94 | +        ) as s3_client, | 
|  | 95 | +    ): | 
|  | 96 | + | 
|  | 97 | +        async def wait(): | 
|  | 98 | +            with ClientHTTPStubber(s3_client) as stub_client: | 
|  | 99 | +                stub_client.add_response() | 
|  | 100 | +                waiter = s3_client.get_waiter('bucket_exists') | 
|  | 101 | +                # The `wait()` method registers `'WAITER': 'B'` | 
|  | 102 | +                await waiter.wait(Bucket='mybucket') | 
|  | 103 | +            ua_string = get_captured_ua_strings(stub_client)[0] | 
|  | 104 | +            return parse_registered_feature_ids(ua_string) | 
|  | 105 | + | 
|  | 106 | +        async def paginate(): | 
|  | 107 | +            with ClientHTTPStubber(s3_client) as stub_client: | 
|  | 108 | +                stub_client.add_response() | 
|  | 109 | +                paginator = s3_client.get_paginator('list_buckets') | 
|  | 110 | +                # The `paginate()` method registers `'PAGINATOR': 'C'` | 
|  | 111 | +                async for _ in paginator.paginate(): | 
|  | 112 | +                    pass | 
|  | 113 | +            ua_string = get_captured_ua_strings(stub_client)[0] | 
|  | 114 | +            return parse_registered_feature_ids(ua_string) | 
|  | 115 | + | 
|  | 116 | +        waiter_features, paginator_features = await asyncio.gather( | 
|  | 117 | +            wait(), paginate() | 
|  | 118 | +        ) | 
|  | 119 | + | 
|  | 120 | +        assert 'B' in waiter_features | 
|  | 121 | +        assert 'C' not in waiter_features | 
|  | 122 | +        assert 'C' in paginator_features | 
|  | 123 | +        assert 'B' not in paginator_features | 
0 commit comments