Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions google/cloud/storage/blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -1738,11 +1738,13 @@ def _get_upload_arguments(self, client, content_type, filename=None, command=Non
* The ``content_type`` as a string (according to precedence)
"""
content_type = self._get_content_type(content_type, filename=filename)
# Add any client attached custom headers to the upload headers.
headers = {
**_get_default_headers(
client._connection.user_agent, content_type, command=command
),
**_get_encryption_headers(self._encryption_key),
**client._extra_headers,
}
object_metadata = self._get_writable_metadata()
return headers, object_metadata, content_type
Expand Down Expand Up @@ -4313,9 +4315,11 @@ def _prep_and_do_download(
if_etag_match=if_etag_match,
if_etag_not_match=if_etag_not_match,
)
# Add any client attached custom headers to be sent with the request.
headers = {
**_get_default_headers(client._connection.user_agent, command=command),
**headers,
**client._extra_headers,
}

transport = client._http
Expand Down
12 changes: 11 additions & 1 deletion google/cloud/storage/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ class Client(ClientWithProject):
(Optional) Whether authentication is required under custom endpoints.
If false, uses AnonymousCredentials and bypasses authentication.
Defaults to True. Note this is only used when a custom endpoint is set in conjunction.

:type extra_headers: dict
:param extra_headers:
(Optional) Custom headers to be sent with the requests attached to the client.
For example, you can add custom audit logging headers.
"""

SCOPE = (
Expand All @@ -111,6 +116,7 @@ def __init__(
client_info=None,
client_options=None,
use_auth_w_custom_endpoint=True,
extra_headers={},
):
self._base_connection = None

Expand All @@ -127,6 +133,7 @@ def __init__(
# are passed along, for use in __reduce__ defined elsewhere.
self._initial_client_info = client_info
self._initial_client_options = client_options
self._extra_headers = extra_headers

kw_args = {"client_info": client_info}

Expand Down Expand Up @@ -172,7 +179,10 @@ def __init__(
if no_project:
self.project = None

self._connection = Connection(self, **kw_args)
# Pass extra_headers to Connection
connection = Connection(self, **kw_args)
connection.extra_headers = extra_headers
self._connection = connection
self._batch_stack = _LocalStack()

@classmethod
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/storage/transfer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,7 @@ def _reduce_client(cl):
_http = None # Can't carry this over
client_info = cl._initial_client_info
client_options = cl._initial_client_options
extra_headers = cl._extra_headers

return _LazyClient, (
client_object_id,
Expand All @@ -1297,6 +1298,7 @@ def _reduce_client(cl):
_http,
client_info,
client_options,
extra_headers,
)


Expand Down
49 changes: 49 additions & 0 deletions tests/unit/test__http.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,55 @@ def test_extra_headers(self):
timeout=_DEFAULT_TIMEOUT,
)

def test_metadata_op_has_client_custom_headers(self):
import requests
import google.auth.credentials
from google.cloud import _http as base_http
from google.cloud.storage import Client
from google.cloud.storage.constants import _DEFAULT_TIMEOUT

custom_headers = {
"x-goog-custom-audit-foo": "bar",
"x-goog-custom-audit-user": "baz",
}
http = mock.create_autospec(requests.Session, instance=True)
response = requests.Response()
response.status_code = 200
data = b"brent-spiner"
response._content = data
http.is_mtls = False
http.request.return_value = response
credentials = mock.Mock(spec=google.auth.credentials.Credentials)
client = Client(
project="project",
credentials=credentials,
_http=http,
extra_headers=custom_headers,
)
req_data = "hey-yoooouuuuu-guuuuuyyssss"
with patch.object(
_helpers, "_get_invocation_id", return_value=GCCL_INVOCATION_TEST_CONST
):
result = client._connection.api_request(
"GET", "/rainbow", data=req_data, expect_json=False
)
self.assertEqual(result, data)

expected_headers = {
**custom_headers,
"Accept-Encoding": "gzip",
base_http.CLIENT_INFO_HEADER: f"{client._connection.user_agent} {GCCL_INVOCATION_TEST_CONST}",
"User-Agent": client._connection.user_agent,
}
expected_uri = client._connection.build_api_url("/rainbow")
http.request.assert_called_once_with(
data=req_data,
headers=expected_headers,
method="GET",
url=expected_uri,
timeout=_DEFAULT_TIMEOUT,
)

def test_build_api_url_no_extra_query_params(self):
from urllib.parse import parse_qsl
from urllib.parse import urlsplit
Expand Down
148 changes: 125 additions & 23 deletions tests/unit/test_blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -2246,8 +2246,13 @@ def test__set_metadata_to_none(self):
def test__get_upload_arguments(self):
name = "blob-name"
key = b"[pXw@,p@@AfBfrR3x-2b2SCHR,.?YwRO"
custom_headers = {
"x-goog-custom-audit-foo": "bar",
"x-goog-custom-audit-user": "baz",
}
client = mock.Mock(_connection=_Connection)
client._connection.user_agent = "testing 1.2.3"
client._extra_headers = custom_headers
blob = self._make_one(name, bucket=None, encryption_key=key)
blob.content_disposition = "inline"

Expand All @@ -2271,6 +2276,7 @@ def test__get_upload_arguments(self):
"X-Goog-Encryption-Algorithm": "AES256",
"X-Goog-Encryption-Key": header_key_value,
"X-Goog-Encryption-Key-Sha256": header_key_hash_value,
**custom_headers,
}
self.assertEqual(
headers["X-Goog-API-Client"],
Expand Down Expand Up @@ -2325,6 +2331,7 @@ def _do_multipart_success(

client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._extra_headers = {}

# Mock get_api_base_url_for_mtls function.
mtls_url = "https://foo.mtls"
Expand Down Expand Up @@ -2424,11 +2431,14 @@ def _do_multipart_success(
with patch.object(
_helpers, "_get_invocation_id", return_value=GCCL_INVOCATION_TEST_CONST
):
headers = _get_default_headers(
client._connection.user_agent,
b'multipart/related; boundary="==0=="',
"application/xml",
)
headers = {
**_get_default_headers(
client._connection.user_agent,
b'multipart/related; boundary="==0=="',
"application/xml",
),
**client._extra_headers,
}
client._http.request.assert_called_once_with(
"POST", upload_url, data=payload, headers=headers, timeout=expected_timeout
)
Expand Down Expand Up @@ -2520,6 +2530,19 @@ def test__do_multipart_upload_with_client(self, mock_get_boundary):
transport = self._mock_transport(http.client.OK, {})
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._extra_headers = {}
self._do_multipart_success(mock_get_boundary, client=client)

@mock.patch("google.resumable_media._upload.get_boundary", return_value=b"==0==")
def test__do_multipart_upload_with_client_custom_headers(self, mock_get_boundary):
custom_headers = {
"x-goog-custom-audit-foo": "bar",
"x-goog-custom-audit-user": "baz",
}
transport = self._mock_transport(http.client.OK, {})
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._extra_headers = custom_headers
self._do_multipart_success(mock_get_boundary, client=client)

@mock.patch("google.resumable_media._upload.get_boundary", return_value=b"==0==")
Expand Down Expand Up @@ -2597,6 +2620,7 @@ def _initiate_resumable_helper(
# Create some mock arguments and call the method under test.
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._extra_headers = {}

# Mock get_api_base_url_for_mtls function.
mtls_url = "https://foo.mtls"
Expand Down Expand Up @@ -2677,13 +2701,15 @@ def _initiate_resumable_helper(
_helpers, "_get_invocation_id", return_value=GCCL_INVOCATION_TEST_CONST
):
if extra_headers is None:
self.assertEqual(
upload._headers,
_get_default_headers(client._connection.user_agent, content_type),
)
expected_headers = {
**_get_default_headers(client._connection.user_agent, content_type),
**client._extra_headers,
}
self.assertEqual(upload._headers, expected_headers)
else:
expected_headers = {
**_get_default_headers(client._connection.user_agent, content_type),
**client._extra_headers,
**extra_headers,
}
self.assertEqual(upload._headers, expected_headers)
Expand Down Expand Up @@ -2730,9 +2756,12 @@ def _initiate_resumable_helper(
with patch.object(
_helpers, "_get_invocation_id", return_value=GCCL_INVOCATION_TEST_CONST
):
expected_headers = _get_default_headers(
client._connection.user_agent, x_upload_content_type=content_type
)
expected_headers = {
**_get_default_headers(
client._connection.user_agent, x_upload_content_type=content_type
),
**client._extra_headers,
}
if size is not None:
expected_headers["x-upload-content-length"] = str(size)
if extra_headers is not None:
Expand Down Expand Up @@ -2824,6 +2853,21 @@ def test__initiate_resumable_upload_with_client(self):

client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._extra_headers = {}
self._initiate_resumable_helper(client=client)

def test__initiate_resumable_upload_with_client_custom_headers(self):
custom_headers = {
"x-goog-custom-audit-foo": "bar",
"x-goog-custom-audit-user": "baz",
}
resumable_url = "http://test.invalid?upload_id=hey-you"
response_headers = {"location": resumable_url}
transport = self._mock_transport(http.client.OK, response_headers)

client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._extra_headers = custom_headers
self._initiate_resumable_helper(client=client)

def _make_resumable_transport(
Expand Down Expand Up @@ -3000,6 +3044,7 @@ def _do_resumable_helper(
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._connection.user_agent = USER_AGENT
client._extra_headers = {}
stream = io.BytesIO(data)

bucket = _Bucket(name="yesterday")
Expand Down Expand Up @@ -3612,26 +3657,32 @@ def _create_resumable_upload_session_helper(
if_metageneration_match=None,
if_metageneration_not_match=None,
retry=None,
client=None,
):
bucket = _Bucket(name="alex-trebek")
blob = self._make_one("blob-name", bucket=bucket)
chunk_size = 99 * blob._CHUNK_SIZE_MULTIPLE
blob.chunk_size = chunk_size

# Create mocks to be checked for doing transport.
resumable_url = "http://test.invalid?upload_id=clean-up-everybody"
response_headers = {"location": resumable_url}
transport = self._mock_transport(http.client.OK, response_headers)
if side_effect is not None:
transport.request.side_effect = side_effect

# Create some mock arguments and call the method under test.
content_type = "text/plain"
size = 10000
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._connection.user_agent = "testing 1.2.3"
transport = None

if not client:
# Create mocks to be checked for doing transport.
response_headers = {"location": resumable_url}
transport = self._mock_transport(http.client.OK, response_headers)

# Create some mock arguments and call the method under test.
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._connection.user_agent = "testing 1.2.3"
client._extra_headers = {}

if transport is None:
transport = client._http
if side_effect is not None:
transport.request.side_effect = side_effect
if timeout is None:
expected_timeout = self._get_default_timeout()
timeout_kwarg = {}
Expand Down Expand Up @@ -3689,6 +3740,7 @@ def _create_resumable_upload_session_helper(
**_get_default_headers(
client._connection.user_agent, x_upload_content_type=content_type
),
**client._extra_headers,
"x-upload-content-length": str(size),
"x-upload-content-type": content_type,
}
Expand Down Expand Up @@ -3750,6 +3802,28 @@ def test_create_resumable_upload_session_with_failure(self):
self.assertIn(message, exc_info.exception.message)
self.assertEqual(exc_info.exception.errors, [])

def test_create_resumable_upload_session_with_client(self):
resumable_url = "http://test.invalid?upload_id=clean-up-everybody"
response_headers = {"location": resumable_url}
transport = self._mock_transport(http.client.OK, response_headers)
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._extra_headers = {}
self._create_resumable_upload_session_helper(client=client)

def test_create_resumable_upload_session_with_client_custom_headers(self):
custom_headers = {
"x-goog-custom-audit-foo": "bar",
"x-goog-custom-audit-user": "baz",
}
resumable_url = "http://test.invalid?upload_id=clean-up-everybody"
response_headers = {"location": resumable_url}
transport = self._mock_transport(http.client.OK, response_headers)
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._extra_headers = custom_headers
self._create_resumable_upload_session_helper(client=client)

def test_get_iam_policy_defaults(self):
from google.cloud.storage.iam import STORAGE_OWNER_ROLE
from google.cloud.storage.iam import STORAGE_EDITOR_ROLE
Expand Down Expand Up @@ -5815,6 +5889,34 @@ def test_open(self):
with self.assertRaises(ValueError):
blob.open("w", ignore_flush=False)

def test_downloads_w_client_custom_headers(self):
import google.auth.credentials
from google.cloud.storage import Client

custom_headers = {
"x-goog-custom-audit-foo": "bar",
"x-goog-custom-audit-user": "baz",
}
credentials = mock.Mock(spec=google.auth.credentials.Credentials)
client = Client(
project="project", credentials=credentials, extra_headers=custom_headers
)
blob = self._make_one("blob-name", bucket=_Bucket(client))
file_obj = io.BytesIO()

downloads = {
client.download_blob_to_file: (blob, file_obj),
blob.download_to_file: (file_obj,),
blob.download_as_bytes: (),
}
for method, args in downloads.items():
with mock.patch.object(blob, "_do_download"):
method(*args)
blob._do_download.assert_called()
called_headers = blob._do_download.call_args.args[-4]
self.assertIsInstance(called_headers, dict)
self.assertDictContainsSubset(custom_headers, called_headers)


class Test__quote(unittest.TestCase):
@staticmethod
Expand Down
Loading