
Commit 983889a

Trying async downloads
1 parent 323cddd commit 983889a

4 files changed: 275 additions & 0 deletions

google/cloud/storage/_experimental/asyncio/json/__init__.py

Whitespace-only changes.
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
"""Async classes for holding the credentials, and connection"""


import google.auth._credentials_async
from google.cloud.client import _ClientProjectMixin
from google.cloud.client import _CREDENTIALS_REFRESH_TIMEOUT
from google.auth.transport import _aiohttp_requests as async_requests
from google.cloud.storage import retry as storage_retry
from google.auth import _default_async
from google.api_core import retry_async


DEFAULT_ASYNC_RETRY = retry_async.AsyncRetry(predicate=storage_retry._should_retry)


class Client:
    SCOPE = None
    # Would be overridden by child classes.

    def __init__(self):
        async_creds, _ = _default_async.default_async(scopes=self.SCOPE)
        self._async_credentials = google.auth._credentials_async.with_scopes_if_required(
            async_creds, scopes=self.SCOPE
        )
        self._async_http_internal = None

    @property
    def _async_http(self):
        if self._async_http_internal is None:
            self._async_http_internal = async_requests.AuthorizedSession(
                self._async_credentials,
                refresh_timeout=_CREDENTIALS_REFRESH_TIMEOUT,
            )
        return self._async_http_internal

    async def __aenter__(self):
        return self

    async def __aexit__(self, _exc_type, _exc_val, _exc_tb):
        if self._async_http_internal is not None:
            await self._async_http_internal.close()


class ClientWithProjectAsync(Client, _ClientProjectMixin):
    _SET_PROJECT = True

    def __init__(self, project=None):
        _ClientProjectMixin.__init__(self, project=project)
        Client.__init__(self)
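
Usage note: `Client` builds its aiohttp-based `AuthorizedSession` lazily on first access to `_async_http` and closes it in `__aexit__`, so it is meant to be driven as an async context manager. A minimal sketch, assuming application-default credentials are available; the `ReadOnlyClient` subclass, scope, and bucket name below are hypothetical, for illustration only:

import asyncio

class ReadOnlyClient(Client):
    # Hypothetical subclass; the SDK's own client would normally set SCOPE.
    SCOPE = ("https://www.googleapis.com/auth/devstorage.read_only",)

async def main():
    async with ReadOnlyClient() as client:
        session = client._async_http  # AuthorizedSession is created on this first access
        resp = await session.request(
            "GET", "https://storage.googleapis.com/storage/v1/b/my-bucket"
        )
        print(resp.status)

asyncio.run(main())
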
Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
"""Async client for SDK downloads"""

import os
import asyncio
import aiofiles
import aiofiles.os  # aiofiles.os is a submodule; "import aiofiles" alone does not expose it

from google.cloud.storage._experimental.asyncio.json import _helpers
from google.cloud.storage._experimental.asyncio.json import download
from google.cloud.storage._helpers import _DEFAULT_SCHEME
from google.cloud.storage._helpers import _STORAGE_HOST_TEMPLATE
from google.cloud.storage._helpers import _DEFAULT_UNIVERSE_DOMAIN
from google.cloud.storage import blob


_SLICED_DOWNLOAD_THRESHOLD = 1024 * 1024 * 1024  # 1 GB
_SLICED_DOWNLOAD_PARTS = 5
_USERAGENT = 'test-prototype'


class AsyncClient(_helpers.ClientWithProjectAsync):

    SCOPE = (
        "https://www.googleapis.com/auth/devstorage.full_control",
        "https://www.googleapis.com/auth/devstorage.read_only",
        "https://www.googleapis.com/auth/devstorage.read_write",
    )

    @property
    def api_endpoint(self):
        return _DEFAULT_SCHEME + _STORAGE_HOST_TEMPLATE.format(
            universe_domain=_DEFAULT_UNIVERSE_DOMAIN
        )

    def _get_download_url(self, blob_obj):
        return f'{self.api_endpoint}/download/storage/v1/b/{blob_obj.bucket.name}/o/{blob_obj.name}?alt=media'

    async def _perform_download(
        self,
        transport,
        file_obj,
        download_url,
        headers,
        start=None,
        end=None,
        timeout=None,
        checksum="md5",
        retry=_helpers.DEFAULT_ASYNC_RETRY,
        sequential_read=False,
    ):
        download_obj = download.DownloadAsync(
            download_url,
            stream=file_obj,
            headers=headers,
            start=start,
            end=end,
            checksum=checksum,
            retry=retry,
            sequential_read=sequential_read,
        )
        await download_obj.consume(transport, timeout=timeout)

    def _check_if_sliced_download_is_eligible(self, obj_size, checksum):
        if obj_size < _SLICED_DOWNLOAD_THRESHOLD:
            return False
        # Checksum validation is not yet supported for parallel downloads,
        # so only requests made without a checksum are eligible.
        return checksum is None

    async def download_to_file(
        self,
        blob_obj,
        filename,
        start=None,
        end=None,
        timeout=None,
        checksum="md5",
        retry=_helpers.DEFAULT_ASYNC_RETRY,
        sequential_read=False,
    ):
        download_url = self._get_download_url(blob_obj)
        headers = blob._get_encryption_headers(blob_obj._encryption_key)
        headers["accept-encoding"] = "gzip"
        headers = {
            **blob._get_default_headers(_USERAGENT),
            **headers,
        }

        transport = self._async_http
        if not blob_obj.size:
            blob_obj.reload()
        obj_size = blob_obj.size
        try:
            if not sequential_read and self._check_if_sliced_download_is_eligible(obj_size, checksum):
                print("Sliced Download Preferred, and Starting...")
                # Build prefix-summed byte offsets; the last part absorbs the
                # division remainder so the parts cover the object exactly.
                chunks_offset = (
                    [0]
                    + [obj_size // _SLICED_DOWNLOAD_PARTS] * (_SLICED_DOWNLOAD_PARTS - 1)
                    + [obj_size - obj_size // _SLICED_DOWNLOAD_PARTS * (_SLICED_DOWNLOAD_PARTS - 1)]
                )
                for i in range(1, _SLICED_DOWNLOAD_PARTS + 1):
                    chunks_offset[i] += chunks_offset[i - 1]

                # Truncates the file to zero length and keeps the file.
                with open(filename, 'wb') as _:
                    pass

                tasks, file_handles = [], []
                try:
                    for idx in range(_SLICED_DOWNLOAD_PARTS):
                        file_handle = await aiofiles.open(filename, 'r+b')
                        await file_handle.seek(chunks_offset[idx])
                        tasks.append(
                            self._perform_download(
                                transport,
                                file_handle,
                                download_url,
                                headers,
                                chunks_offset[idx],
                                chunks_offset[idx + 1] - 1,
                                timeout=timeout,
                                checksum=checksum,
                                retry=retry,
                                sequential_read=sequential_read,
                            )
                        )
                        file_handles.append(file_handle)
                    await asyncio.gather(*tasks)
                finally:
                    for file_handle in file_handles:
                        await file_handle.close()
            else:
                print("Sequential Download Preferred, and Starting...")
                async with aiofiles.open(filename, "wb") as file_obj:
                    await self._perform_download(
                        transport,
                        file_obj,
                        download_url,
                        headers,
                        start,
                        end,
                        timeout=timeout,
                        checksum=checksum,
                        retry=retry,
                        sequential_read=sequential_read,
                    )
        except (blob.DataCorruption, blob.NotFound):
            await aiofiles.os.remove(filename)
            raise
        except blob.InvalidResponse as exc:
            blob._raise_from_invalid_response(exc)
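
Usage note: an object at or above the 1 GB threshold requested with `checksum=None` takes the sliced path. For example, a 5 GB object split five ways yields prefix-summed offsets [0, 1 GB, 2 GB, 3 GB, 4 GB, 5 GB], and task i writes the byte range [offset[i], offset[i+1] - 1] at its own seek position in the shared file. A hedged sketch of calling `download_to_file`, assuming application-default credentials and project are available; the `async_client` module name is a placeholder (this view does not show the new file's path) and the bucket and object names are illustrative:

import asyncio

from google.cloud import storage
from google.cloud.storage._experimental.asyncio.json import async_client  # hypothetical module name

async def main():
    # Resolve the blob with the synchronous client; download_to_file calls
    # blob_obj.reload() itself if the size is not yet populated.
    blob_obj = storage.Client().bucket("my-bucket").blob("large-object.bin")

    client = async_client.AsyncClient()
    # checksum=None makes a >= 1 GB object eligible for the sliced path;
    # smaller objects (or sequential_read=True) download sequentially.
    await client.download_to_file(blob_obj, "large-object.bin", checksum=None)

asyncio.run(main())
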
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
"""Async based download code"""

import http
import aiohttp

from google.cloud.storage._experimental.asyncio.json._helpers import DEFAULT_ASYNC_RETRY
from google.cloud.storage._media.requests import _request_helpers
from google.cloud.storage._media import _download
from google.cloud.storage._media import _helpers
from google.cloud.storage._media.requests import download as storage_download


class DownloadAsync(_request_helpers.RequestsMixin, _download.Download):

    def __init__(
        self,
        media_url,
        stream=None,
        start=None,
        end=None,
        headers=None,
        checksum="md5",
        retry=DEFAULT_ASYNC_RETRY,
        sequential_read=False,
    ):
        super().__init__(
            media_url, stream=stream, start=start, end=end, headers=headers, checksum=checksum, retry=retry
        )
        self.sequential_read = sequential_read

    async def _write_to_stream(self, response):
        if not self.sequential_read:
            # For a non-sequential (sliced) download the API does not return a
            # hash for each chunk, so no checksum is validated. Ideally we would
            # compute a crc32c checksum per chunk and combine them afterwards to
            # verify, but this prototype does not implement that.
            expected_checksum = None
            checksum_object = _helpers._DoNothingHash()
            self._expected_checksum = expected_checksum
            self._checksum_object = checksum_object
        else:
            # Sequential read, so fetch the expected hash from the response headers.
            expected_checksum, checksum_object = _helpers._get_expected_checksum(
                response, self._get_headers, self.media_url, checksum_type=self.checksum
            )
            self._expected_checksum = expected_checksum
            self._checksum_object = checksum_object

        async with response:
            chunk_size = 4096 * 32
            async for chunk in response.content.iter_chunked(chunk_size):
                await self._stream.write(chunk)
                self._bytes_downloaded += len(chunk)
                checksum_object.update(chunk)

        if (
            expected_checksum is not None
            and response.status != http.client.PARTIAL_CONTENT
        ):
            actual_checksum = _helpers.prepare_checksum_digest(checksum_object.digest())

            if actual_checksum != expected_checksum:
                raise storage_download.DataCorruption('Corrupted download!')

    async def consume(
        self,
        transport,
        timeout=aiohttp.ClientTimeout(total=None, sock_read=300),
    ):
        method, _, payload, headers = self._prepare_request()
        request_kwargs = {
            "data": payload,
            "headers": headers,
            "timeout": timeout,
        }

        async def retriable_request():
            url = self.media_url
            result = await transport.request(method, url, **request_kwargs)
            await self._write_to_stream(result)
            if result.status not in (http.client.OK, http.client.PARTIAL_CONTENT):
                result.raise_for_status()
            return result

        return await _request_helpers.wait_and_retry(retriable_request, self._retry_strategy)
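
Usage note: `consume()` only needs a transport exposing an aiohttp-style `request` coroutine, so `DownloadAsync` can be driven directly with the `AuthorizedSession` from the `_helpers` module above. A minimal sketch, assuming ambient application-default credentials already carry suitable scopes (the base class leaves SCOPE as None) and using placeholder URL and file names:

import asyncio
import aiofiles

from google.cloud.storage._experimental.asyncio.json import _helpers
from google.cloud.storage._experimental.asyncio.json import download

async def fetch(media_url, filename):
    # The client's _async_http property yields the AuthorizedSession transport
    # and __aexit__ closes it when the block ends.
    async with _helpers.ClientWithProjectAsync() as client:
        async with aiofiles.open(filename, "wb") as stream:
            # sequential_read=True keeps the default md5 validation, using the
            # hash from the response headers.
            dl = download.DownloadAsync(media_url, stream=stream, sequential_read=True)
            await dl.consume(client._async_http)

asyncio.run(fetch(
    "https://storage.googleapis.com/download/storage/v1/b/my-bucket/o/my-object?alt=media",
    "my-object.bin",
))
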
