@@ -2246,8 +2246,13 @@ def test__set_metadata_to_none(self):
22462246 def test__get_upload_arguments (self ):
22472247 name = "blob-name"
22482248 key = b"[pXw@,p@@AfBfrR3x-2b2SCHR,.?YwRO"
2249+ custom_headers = {
2250+ "x-goog-custom-audit-foo" : "bar" ,
2251+ "x-goog-custom-audit-user" : "baz" ,
2252+ }
22492253 client = mock .Mock (_connection = _Connection )
22502254 client ._connection .user_agent = "testing 1.2.3"
2255+ client ._extra_headers = custom_headers
22512256 blob = self ._make_one (name , bucket = None , encryption_key = key )
22522257 blob .content_disposition = "inline"
22532258
@@ -2271,6 +2276,7 @@ def test__get_upload_arguments(self):
22712276 "X-Goog-Encryption-Algorithm" : "AES256" ,
22722277 "X-Goog-Encryption-Key" : header_key_value ,
22732278 "X-Goog-Encryption-Key-Sha256" : header_key_hash_value ,
2279+ ** custom_headers ,
22742280 }
22752281 self .assertEqual (
22762282 headers ["X-Goog-API-Client" ],
@@ -2325,6 +2331,7 @@ def _do_multipart_success(
23252331
23262332 client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
23272333 client ._connection .API_BASE_URL = "https://storage.googleapis.com"
2334+ client ._extra_headers = {}
23282335
23292336 # Mock get_api_base_url_for_mtls function.
23302337 mtls_url = "https://foo.mtls"
@@ -2424,11 +2431,14 @@ def _do_multipart_success(
24242431 with patch .object (
24252432 _helpers , "_get_invocation_id" , return_value = GCCL_INVOCATION_TEST_CONST
24262433 ):
2427- headers = _get_default_headers (
2428- client ._connection .user_agent ,
2429- b'multipart/related; boundary="==0=="' ,
2430- "application/xml" ,
2431- )
2434+ headers = {
2435+ ** _get_default_headers (
2436+ client ._connection .user_agent ,
2437+ b'multipart/related; boundary="==0=="' ,
2438+ "application/xml" ,
2439+ ),
2440+ ** client ._extra_headers ,
2441+ }
24322442 client ._http .request .assert_called_once_with (
24332443 "POST" , upload_url , data = payload , headers = headers , timeout = expected_timeout
24342444 )
@@ -2520,6 +2530,19 @@ def test__do_multipart_upload_with_client(self, mock_get_boundary):
25202530 transport = self ._mock_transport (http .client .OK , {})
25212531 client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
25222532 client ._connection .API_BASE_URL = "https://storage.googleapis.com"
2533+ client ._extra_headers = {}
2534+ self ._do_multipart_success (mock_get_boundary , client = client )
2535+
2536+ @mock .patch ("google.resumable_media._upload.get_boundary" , return_value = b"==0==" )
2537+ def test__do_multipart_upload_with_client_custom_headers (self , mock_get_boundary ):
2538+ custom_headers = {
2539+ "x-goog-custom-audit-foo" : "bar" ,
2540+ "x-goog-custom-audit-user" : "baz" ,
2541+ }
2542+ transport = self ._mock_transport (http .client .OK , {})
2543+ client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
2544+ client ._connection .API_BASE_URL = "https://storage.googleapis.com"
2545+ client ._extra_headers = custom_headers
25232546 self ._do_multipart_success (mock_get_boundary , client = client )
25242547
25252548 @mock .patch ("google.resumable_media._upload.get_boundary" , return_value = b"==0==" )
@@ -2597,6 +2620,7 @@ def _initiate_resumable_helper(
25972620 # Create some mock arguments and call the method under test.
25982621 client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
25992622 client ._connection .API_BASE_URL = "https://storage.googleapis.com"
2623+ client ._extra_headers = {}
26002624
26012625 # Mock get_api_base_url_for_mtls function.
26022626 mtls_url = "https://foo.mtls"
@@ -2677,13 +2701,15 @@ def _initiate_resumable_helper(
26772701 _helpers , "_get_invocation_id" , return_value = GCCL_INVOCATION_TEST_CONST
26782702 ):
26792703 if extra_headers is None :
2680- self .assertEqual (
2681- upload ._headers ,
2682- _get_default_headers (client ._connection .user_agent , content_type ),
2683- )
2704+ expected_headers = {
2705+ ** _get_default_headers (client ._connection .user_agent , content_type ),
2706+ ** client ._extra_headers ,
2707+ }
2708+ self .assertEqual (upload ._headers , expected_headers )
26842709 else :
26852710 expected_headers = {
26862711 ** _get_default_headers (client ._connection .user_agent , content_type ),
2712+ ** client ._extra_headers ,
26872713 ** extra_headers ,
26882714 }
26892715 self .assertEqual (upload ._headers , expected_headers )
@@ -2730,9 +2756,12 @@ def _initiate_resumable_helper(
27302756 with patch .object (
27312757 _helpers , "_get_invocation_id" , return_value = GCCL_INVOCATION_TEST_CONST
27322758 ):
2733- expected_headers = _get_default_headers (
2734- client ._connection .user_agent , x_upload_content_type = content_type
2735- )
2759+ expected_headers = {
2760+ ** _get_default_headers (
2761+ client ._connection .user_agent , x_upload_content_type = content_type
2762+ ),
2763+ ** client ._extra_headers ,
2764+ }
27362765 if size is not None :
27372766 expected_headers ["x-upload-content-length" ] = str (size )
27382767 if extra_headers is not None :
@@ -2824,6 +2853,21 @@ def test__initiate_resumable_upload_with_client(self):
28242853
28252854 client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
28262855 client ._connection .API_BASE_URL = "https://storage.googleapis.com"
2856+ client ._extra_headers = {}
2857+ self ._initiate_resumable_helper (client = client )
2858+
2859+ def test__initiate_resumable_upload_with_client_custom_headers (self ):
2860+ custom_headers = {
2861+ "x-goog-custom-audit-foo" : "bar" ,
2862+ "x-goog-custom-audit-user" : "baz" ,
2863+ }
2864+ resumable_url = "http://test.invalid?upload_id=hey-you"
2865+ response_headers = {"location" : resumable_url }
2866+ transport = self ._mock_transport (http .client .OK , response_headers )
2867+
2868+ client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
2869+ client ._connection .API_BASE_URL = "https://storage.googleapis.com"
2870+ client ._extra_headers = custom_headers
28272871 self ._initiate_resumable_helper (client = client )
28282872
28292873 def _make_resumable_transport (
@@ -3000,6 +3044,7 @@ def _do_resumable_helper(
30003044 client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
30013045 client ._connection .API_BASE_URL = "https://storage.googleapis.com"
30023046 client ._connection .user_agent = USER_AGENT
3047+ client ._extra_headers = {}
30033048 stream = io .BytesIO (data )
30043049
30053050 bucket = _Bucket (name = "yesterday" )
@@ -3612,26 +3657,32 @@ def _create_resumable_upload_session_helper(
36123657 if_metageneration_match = None ,
36133658 if_metageneration_not_match = None ,
36143659 retry = None ,
3660+ client = None ,
36153661 ):
36163662 bucket = _Bucket (name = "alex-trebek" )
36173663 blob = self ._make_one ("blob-name" , bucket = bucket )
36183664 chunk_size = 99 * blob ._CHUNK_SIZE_MULTIPLE
36193665 blob .chunk_size = chunk_size
3620-
3621- # Create mocks to be checked for doing transport.
36223666 resumable_url = "http://test.invalid?upload_id=clean-up-everybody"
3623- response_headers = {"location" : resumable_url }
3624- transport = self ._mock_transport (http .client .OK , response_headers )
3625- if side_effect is not None :
3626- transport .request .side_effect = side_effect
3627-
3628- # Create some mock arguments and call the method under test.
36293667 content_type = "text/plain"
36303668 size = 10000
3631- client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
3632- client ._connection .API_BASE_URL = "https://storage.googleapis.com"
3633- client ._connection .user_agent = "testing 1.2.3"
3669+ transport = None
36343670
3671+ if not client :
3672+ # Create mocks to be checked for doing transport.
3673+ response_headers = {"location" : resumable_url }
3674+ transport = self ._mock_transport (http .client .OK , response_headers )
3675+
3676+ # Create some mock arguments and call the method under test.
3677+ client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
3678+ client ._connection .API_BASE_URL = "https://storage.googleapis.com"
3679+ client ._connection .user_agent = "testing 1.2.3"
3680+ client ._extra_headers = {}
3681+
3682+ if transport is None :
3683+ transport = client ._http
3684+ if side_effect is not None :
3685+ transport .request .side_effect = side_effect
36353686 if timeout is None :
36363687 expected_timeout = self ._get_default_timeout ()
36373688 timeout_kwarg = {}
@@ -3689,6 +3740,7 @@ def _create_resumable_upload_session_helper(
36893740 ** _get_default_headers (
36903741 client ._connection .user_agent , x_upload_content_type = content_type
36913742 ),
3743+ ** client ._extra_headers ,
36923744 "x-upload-content-length" : str (size ),
36933745 "x-upload-content-type" : content_type ,
36943746 }
@@ -3750,6 +3802,28 @@ def test_create_resumable_upload_session_with_failure(self):
37503802 self .assertIn (message , exc_info .exception .message )
37513803 self .assertEqual (exc_info .exception .errors , [])
37523804
3805+ def test_create_resumable_upload_session_with_client (self ):
3806+ resumable_url = "http://test.invalid?upload_id=clean-up-everybody"
3807+ response_headers = {"location" : resumable_url }
3808+ transport = self ._mock_transport (http .client .OK , response_headers )
3809+ client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
3810+ client ._connection .API_BASE_URL = "https://storage.googleapis.com"
3811+ client ._extra_headers = {}
3812+ self ._create_resumable_upload_session_helper (client = client )
3813+
3814+ def test_create_resumable_upload_session_with_client_custom_headers (self ):
3815+ custom_headers = {
3816+ "x-goog-custom-audit-foo" : "bar" ,
3817+ "x-goog-custom-audit-user" : "baz" ,
3818+ }
3819+ resumable_url = "http://test.invalid?upload_id=clean-up-everybody"
3820+ response_headers = {"location" : resumable_url }
3821+ transport = self ._mock_transport (http .client .OK , response_headers )
3822+ client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
3823+ client ._connection .API_BASE_URL = "https://storage.googleapis.com"
3824+ client ._extra_headers = custom_headers
3825+ self ._create_resumable_upload_session_helper (client = client )
3826+
37533827 def test_get_iam_policy_defaults (self ):
37543828 from google .cloud .storage .iam import STORAGE_OWNER_ROLE
37553829 from google .cloud .storage .iam import STORAGE_EDITOR_ROLE
@@ -5815,6 +5889,34 @@ def test_open(self):
58155889 with self .assertRaises (ValueError ):
58165890 blob .open ("w" , ignore_flush = False )
58175891
5892+ def test_downloads_w_client_custom_headers (self ):
5893+ import google .auth .credentials
5894+ from google .cloud .storage import Client
5895+
5896+ custom_headers = {
5897+ "x-goog-custom-audit-foo" : "bar" ,
5898+ "x-goog-custom-audit-user" : "baz" ,
5899+ }
5900+ credentials = mock .Mock (spec = google .auth .credentials .Credentials )
5901+ client = Client (
5902+ project = "project" , credentials = credentials , extra_headers = custom_headers
5903+ )
5904+ blob = self ._make_one ("blob-name" , bucket = _Bucket (client ))
5905+ file_obj = io .BytesIO ()
5906+
5907+ downloads = {
5908+ client .download_blob_to_file : (blob , file_obj ),
5909+ blob .download_to_file : (file_obj ,),
5910+ blob .download_as_bytes : (),
5911+ }
5912+ for method , args in downloads .items ():
5913+ with mock .patch .object (blob , "_do_download" ):
5914+ method (* args )
5915+ blob ._do_download .assert_called ()
5916+ called_headers = blob ._do_download .call_args .args [- 4 ]
5917+ self .assertIsInstance (called_headers , dict )
5918+ self .assertDictContainsSubset (custom_headers , called_headers )
5919+
58185920
58195921class Test__quote (unittest .TestCase ):
58205922 @staticmethod
0 commit comments