# zarr/storage.py additions: a Google Cloud Storage backed store.

# utility functions for object stores


def _strip_prefix_from_path(path, prefix):
    """Return *path* relative to *prefix*.

    Both arguments are normalized (no leading or trailing slashes) before
    comparison.  If *path* does not live under *prefix*, it is returned
    unchanged.
    """
    path_norm = normalize_storage_path(path)
    prefix_norm = normalize_storage_path(prefix)
    if not prefix_norm:
        # empty prefix: nothing to strip; naive slicing here would
        # wrongly drop the first character of the path
        return path_norm
    if path_norm == prefix_norm:
        return ''
    if path_norm.startswith(prefix_norm + '/'):
        # match a whole path component, not a bare startswith, so a
        # prefix of 'foo' does not strip from 'foobar/...'
        return path_norm[len(prefix_norm) + 1:]
    return path


def _append_path_to_prefix(path, prefix):
    """Join *prefix* and *path* with '/', skipping empty components so the
    result never carries a leading or trailing slash."""
    return '/'.join(p for p in
                    [normalize_storage_path(prefix),
                     normalize_storage_path(path)]
                    if p)


def atexit_rmgcspath(bucket, path):
    """atexit hook: delete every blob under *path* in the GCS *bucket*."""
    from google.cloud import storage
    client = storage.Client()
    bucket = client.get_bucket(bucket)
    bucket.delete_blobs(bucket.list_blobs(prefix=path))


class GCSStore(MutableMapping):
    """Storage class using a Google Cloud Storage (GCS) bucket.

    Parameters
    ----------
    bucket_name : string
        The name of the GCS bucket.
    prefix : string, optional
        The prefix within the bucket (i.e. subdirectory).
    client_kwargs : dict, optional
        Extra options passed to ``google.cloud.storage.Client`` when
        connecting to GCS.

    Notes
    -----
    In order to use this store, you must install the Google Cloud Storage
    `Python Client Library
    <https://googleapis.github.io/google-cloud-python/latest/storage/>`_.
    You must also provide valid application credentials, either by setting
    the ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable or via
    `default credentials
    <https://cloud.google.com/docs/authentication/production>`_.
    """

    def __init__(self, bucket_name, prefix=None, client_kwargs=None):
        self.bucket_name = bucket_name
        self.prefix = normalize_storage_path(prefix)
        # avoid a shared mutable default argument
        self.client_kwargs = {} if client_kwargs is None else client_kwargs
        self.initialize_bucket()

    def initialize_bucket(self):
        """(Re-)create the GCS client and bucket handle.

        Called from ``__init__`` and again after unpickling, because the
        client object itself is not picklable.
        """
        from google.cloud import storage
        # run `gcloud auth application-default login` from shell
        client = storage.Client(**self.client_kwargs)
        self.bucket = client.get_bucket(self.bucket_name)
        # need to properly handle exceptions
        import google.api_core.exceptions as exceptions
        self.exceptions = exceptions

    # needed for pickling: drop the unpicklable client state and rebuild
    # it on the other side
    def __getstate__(self):
        state = self.__dict__.copy()
        del state['bucket']
        del state['exceptions']
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.initialize_bucket()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # no persistent local resources to release
        pass

    def full_path(self, path=None):
        """Return *path* qualified with this store's prefix."""
        return _append_path_to_prefix(path, self.prefix)

    def list_gcs_directory_blobs(self, path):
        """Return list of all blobs *directly* under a gcs prefix."""
        prefix = normalize_storage_path(path) + '/'
        return [blob.name for blob in
                self.bucket.list_blobs(prefix=prefix, delimiter='/')]

    # from https://github.com/GoogleCloudPlatform/google-cloud-python/issues/920
    def list_gcs_subdirectories(self, path):
        """Return list of all "subdirectories" from a gcs prefix."""
        prefix = normalize_storage_path(path) + '/'
        iterator = self.bucket.list_blobs(prefix=prefix, delimiter='/')
        prefixes = set()
        for page in iterator.pages:
            prefixes.update(page.prefixes)
        # need to strip trailing slash to be consistent with os.listdir
        return [path[:-1] for path in prefixes]

    def list_gcs_directory(self, prefix, strip_prefix=True):
        """Return a list of all blobs and subdirectories from a gcs prefix."""
        items = set()
        items.update(self.list_gcs_directory_blobs(prefix))
        items.update(self.list_gcs_subdirectories(prefix))
        items = list(items)
        if strip_prefix:
            items = [_strip_prefix_from_path(path, prefix) for path in items]
        return items

    def listdir(self, path=None):
        dir_path = self.full_path(path)
        return sorted(self.list_gcs_directory(dir_path, strip_prefix=True))

    def rmdir(self, path=None):
        # append a trailing slash to make sure it's a directory
        dir_path = normalize_storage_path(self.full_path(path)) + '/'
        self.bucket.delete_blobs(self.bucket.list_blobs(prefix=dir_path))

    def getsize(self, path=None):
        # this function should *not* be recursive
        # a lot of slash trickery is required to make this work right
        full_path = self.full_path(path)
        blob = self.bucket.get_blob(full_path)
        if blob is not None:
            # path names a single blob
            return blob.size
        # path names a "directory": sum sizes of the blobs directly below it
        dir_path = normalize_storage_path(full_path) + '/'
        blobs = self.bucket.list_blobs(prefix=dir_path, delimiter='/')
        return sum(blob.size for blob in blobs)

    def clear(self):
        self.rmdir()

    def __getitem__(self, key):
        blob_name = self.full_path(key)
        blob = self.bucket.get_blob(blob_name)
        if blob:
            return blob.download_as_string()
        else:
            raise KeyError('Blob %s not found' % blob_name)

    def __setitem__(self, key, value):
        blob_name = self.full_path(key)
        blob = self.bucket.blob(blob_name)
        blob.upload_from_string(value)

    def __delitem__(self, key):
        blob_name = self.full_path(key)
        try:
            self.bucket.delete_blob(blob_name)
        except self.exceptions.NotFound as er:
            # translate the GCS error into the mapping protocol's KeyError
            raise KeyError(er.message)

    def __contains__(self, key):
        blob_name = self.full_path(key)
        return self.bucket.get_blob(blob_name) is not None

    def __eq__(self, other):
        return (
            isinstance(other, GCSStore) and
            self.bucket_name == other.bucket_name and
            self.prefix == other.prefix
        )

    def __iter__(self):
        blobs = self.bucket.list_blobs(prefix=self.prefix)
        for blob in blobs:
            yield _strip_prefix_from_path(blob.name, self.prefix)

    def __len__(self):
        # count without materializing the whole listing as a list
        iterator = self.bucket.list_blobs(prefix=self.prefix)
        return sum(1 for _ in iterator)
# zarr/tests/test_core.py additions (the file also needs ``import uuid``
# and the GCSStore/atexit_rmgcspath names added to the zarr.storage import).

try:
    from google.cloud import storage as gcstorage
except ImportError:  # pragma: no cover
    gcstorage = None


@unittest.skipIf(gcstorage is None, 'google-cloud-storage is not installed')
class TestGCSArray(TestArray):
    """Run the full Array test suite against a GCS-backed store."""

    def create_array(self, read_only=False, **kwargs):
        # would need to be replaced with a dedicated test bucket
        bucket = 'zarr-test'
        # GCSStore and atexit_rmgcspath expect a string path;
        # uuid.uuid4() returns a UUID object, so convert explicitly
        prefix = str(uuid.uuid4())
        # schedule cleanup of everything written under this prefix
        atexit.register(atexit_rmgcspath, bucket, prefix)
        store = GCSStore(bucket, prefix)
        cache_metadata = kwargs.pop('cache_metadata', True)
        cache_attrs = kwargs.pop('cache_attrs', True)
        kwargs.setdefault('compressor', Zlib(1))
        init_array(store, **kwargs)
        return Array(store, read_only=read_only, cache_metadata=cache_metadata,
                     cache_attrs=cache_attrs)
# zarr/tests/test_storage.py additions (the file also needs ``import uuid``
# and the GCSStore/atexit_rmgcspath names added to the zarr.storage import).

try:
    from google.cloud import storage as gcstorage
except ImportError:  # pragma: no cover
    gcstorage = None


@unittest.skipIf(gcstorage is None, 'google-cloud-storage is not installed')
class TestGCSStore(StoreTests, unittest.TestCase):
    """Run the generic MutableMapping store tests against GCSStore."""

    def create_store(self):
        # would need to be replaced with a dedicated test bucket
        bucket = 'zarr-test'
        # GCSStore and atexit_rmgcspath expect a string path;
        # uuid.uuid4() returns a UUID object, so convert explicitly
        prefix = str(uuid.uuid4())
        # schedule cleanup of everything written under this prefix
        atexit.register(atexit_rmgcspath, bucket, prefix)
        return GCSStore(bucket, prefix)

    def test_context_manager(self):
        with self.create_store() as store:
            store['foo'] = b'bar'
            store['baz'] = b'qux'
            assert 2 == len(store)