# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import importlib
import uuid

import pytest

from google import auth as google_auth
from google.protobuf import json_format
from google.api_core import exceptions
from google.api_core import client_options

from google.cloud import storage
from google.cloud import aiplatform
from google.cloud.aiplatform import utils
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform_v1beta1.types import dataset
from google.cloud.aiplatform_v1beta1.services import dataset_service

# TODO(vinnys): Replace with env var `BUILD_SPECIFIC_GCP_PROJECT` once supported
_, _TEST_PROJECT = google_auth.default()

_TEST_LOCATION = "us-central1"
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
_TEST_API_ENDPOINT = f"{_TEST_LOCATION}-aiplatform.googleapis.com"
_TEST_IMAGE_DATASET_ID = "1084241610289446912"  # permanent_50_flowers_dataset
_TEST_TEXT_DATASET_ID = (
    "6203215905493614592"  # permanent_text_entity_extraction_dataset
)
_TEST_DATASET_DISPLAY_NAME = "permanent_50_flowers_dataset"
_TEST_TABULAR_CLASSIFICATION_GCS_SOURCE = "gs://ucaip-sample-resources/iris_1000.csv"
_TEST_TEXT_ENTITY_EXTRACTION_GCS_SOURCE = (
    "gs://ucaip-test-us-central1/dataset/ucaip_ten_dataset.jsonl"
)
_TEST_IMAGE_OBJECT_DETECTION_GCS_SOURCE = (
    "gs://ucaip-test-us-central1/dataset/salads_oid_ml_use_public_unassigned.jsonl"
)
_TEST_TEXT_ENTITY_IMPORT_SCHEMA = "gs://google-cloud-aiplatform/schema/dataset/ioformat/text_extraction_io_format_1.0.0.yaml"
_TEST_IMAGE_OBJ_DET_IMPORT_SCHEMA = "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_bounding_box_io_format_1.0.0.yaml"
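
# For reference: the dataset IDs above are short-form IDs; the SDK resolves
# them against the initialized project and location to the full resource name
#   projects/{project}/locations/{location}/datasets/{dataset_id}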


class TestDataset:
    def setup_method(self):
        # Reload the SDK so each test starts with fresh module-level
        # initializer state (project, location, credentials).
        importlib.reload(initializer)
        importlib.reload(aiplatform)

    @pytest.fixture()
    def shared_state(self):
        # Mutable mapping used to hand resource names from setup fixtures
        # to tests and teardown fixtures.
        shared_state = {}
        yield shared_state

    @pytest.fixture()
    def create_staging_bucket(self, shared_state):
        new_staging_bucket = f"temp-sdk-integration-{uuid.uuid4()}"

        storage_client = storage.Client()
        storage_client.create_bucket(new_staging_bucket)
        shared_state["storage_client"] = storage_client
        shared_state["staging_bucket"] = new_staging_bucket
        yield

    @pytest.fixture()
    def delete_staging_bucket(self, shared_state):
        yield
        storage_client = shared_state["storage_client"]

        # Delete the temp staging bucket; force=True deletes any blobs
        # written into it first
        bucket_to_delete = storage_client.get_bucket(shared_state["staging_bucket"])
        bucket_to_delete.delete(force=True)

        # Close the storage client's underlying HTTP sessions (private
        # attributes) so the test run does not leak sockets
        storage_client._http._auth_request.session.close()
        storage_client._http.close()

    @pytest.fixture()
    def dataset_gapic_client(self):
        # Raw GAPIC client, used to create and inspect datasets out-of-band
        # so assertions do not depend on the SDK surface under test
        gapic_client = dataset_service.DatasetServiceClient(
            client_options=client_options.ClientOptions(api_endpoint=_TEST_API_ENDPOINT)
        )

        yield gapic_client

    @pytest.fixture()
    def create_text_dataset(self, dataset_gapic_client, shared_state):
        gapic_dataset = dataset.Dataset(
            display_name=f"temp_sdk_integration_test_create_text_dataset_{uuid.uuid4()}",
            metadata_schema_uri=aiplatform.schema.dataset.metadata.text,
        )

        create_lro = dataset_gapic_client.create_dataset(
            parent=_TEST_PARENT, dataset=gapic_dataset
        )
        new_dataset = create_lro.result()
        shared_state["dataset_name"] = new_dataset.name
        yield

    @pytest.fixture()
    def create_tabular_dataset(self, dataset_gapic_client, shared_state):
        gapic_dataset = dataset.Dataset(
            display_name=f"temp_sdk_integration_test_create_tabular_dataset_{uuid.uuid4()}",
            metadata_schema_uri=aiplatform.schema.dataset.metadata.tabular,
        )

        create_lro = dataset_gapic_client.create_dataset(
            parent=_TEST_PARENT, dataset=gapic_dataset
        )
        new_dataset = create_lro.result()
        shared_state["dataset_name"] = new_dataset.name
        yield

    @pytest.fixture()
    def create_image_dataset(self, dataset_gapic_client, shared_state):
        gapic_dataset = dataset.Dataset(
            display_name=f"temp_sdk_integration_test_create_image_dataset_{uuid.uuid4()}",
            metadata_schema_uri=aiplatform.schema.dataset.metadata.image,
        )

        create_lro = dataset_gapic_client.create_dataset(
            parent=_TEST_PARENT, dataset=gapic_dataset
        )
        new_dataset = create_lro.result()
        shared_state["dataset_name"] = new_dataset.name
        yield

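    # NOTE: the three create_* fixtures above are identical apart from the
    # metadata schema they use; a shared helper could collapse them. A minimal
    # sketch (hypothetical, not used by the tests below):
    #
    #   def _create_temp_dataset(gapic_client, schema_uri, name_prefix):
    #       gapic_dataset = dataset.Dataset(
    #           display_name=f"{name_prefix}_{uuid.uuid4()}",
    #           metadata_schema_uri=schema_uri,
    #       )
    #       return gapic_client.create_dataset(
    #           parent=_TEST_PARENT, dataset=gapic_dataset
    #       ).result()
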
    @pytest.fixture()
    def delete_new_dataset(self, dataset_gapic_client, shared_state):
        yield
        # Teardown runs after the test body, even when the test fails
        assert shared_state["dataset_name"]

        deletion_lro = dataset_gapic_client.delete_dataset(
            name=shared_state["dataset_name"]
        )
        deletion_lro.result()

        shared_state["dataset_name"] = None

    # TODO(vinnys): Remove pytest skip once persistent resources are accessible
    @pytest.mark.skip(reason="System tests cannot access persistent test resources")
    def test_get_existing_dataset(self):
        """Retrieve a known existing dataset and ensure the SDK fetches the
        correct dataset resource."""

        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

        flowers_dataset = aiplatform.Dataset(dataset_name=_TEST_IMAGE_DATASET_ID)
        assert flowers_dataset.name == _TEST_IMAGE_DATASET_ID
        assert flowers_dataset.display_name == _TEST_DATASET_DISPLAY_NAME

    def test_get_nonexistent_dataset(self):
        """Ensure attempting to retrieve a dataset that doesn't exist raises
        a google.api_core NotFound exception."""

        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

        # AI Platform service returns 404
        with pytest.raises(exceptions.NotFound):
            aiplatform.Dataset(dataset_name="0")

    @pytest.mark.usefixtures("create_text_dataset", "delete_new_dataset")
    def test_get_new_dataset_and_import(self, dataset_gapic_client, shared_state):
        """Retrieve a new, empty dataset and import a text dataset using
        import_data(), then verify the data items were successfully imported."""

        assert shared_state["dataset_name"]
        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

        my_dataset = aiplatform.Dataset(dataset_name=shared_state["dataset_name"])

        data_items_pre_import = dataset_gapic_client.list_data_items(
            parent=my_dataset.resource_name
        )

        assert len(list(data_items_pre_import)) == 0

        # import_data() blocks until the import LRO completes
        my_dataset.import_data(
            gcs_source=_TEST_TEXT_ENTITY_EXTRACTION_GCS_SOURCE,
            import_schema_uri=_TEST_TEXT_ENTITY_IMPORT_SCHEMA,
        )

        data_items_post_import = dataset_gapic_client.list_data_items(
            parent=my_dataset.resource_name
        )

        assert len(list(data_items_post_import)) == 469

    @pytest.mark.usefixtures("delete_new_dataset")
    def test_create_and_import_image_dataset(self, dataset_gapic_client, shared_state):
        """Use the Dataset.create() method to create a new image object
        detection dataset and import images, then confirm the images were
        successfully imported."""

        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

        img_dataset = aiplatform.Dataset.create(
            display_name=f"temp_sdk_integration_create_and_import_dataset_{uuid.uuid4()}",
            metadata_schema_uri=aiplatform.schema.dataset.metadata.image,
            gcs_source=_TEST_IMAGE_OBJECT_DETECTION_GCS_SOURCE,
            import_schema_uri=_TEST_IMAGE_OBJ_DET_IMPORT_SCHEMA,
        )

        shared_state["dataset_name"] = img_dataset.resource_name

        data_items_iterator = dataset_gapic_client.list_data_items(
            parent=img_dataset.resource_name
        )

        assert len(list(data_items_iterator)) == 14

    @pytest.mark.usefixtures("delete_new_dataset")
    def test_create_tabular_dataset(self, dataset_gapic_client, shared_state):
        """Use the Dataset.create() method to create a new tabular dataset,
        then confirm the dataset was successfully created and references the
        GCS source."""

        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

        tabular_dataset = aiplatform.Dataset.create(
            display_name=f"temp_sdk_integration_create_and_import_dataset_{uuid.uuid4()}",
            metadata_schema_uri=aiplatform.schema.dataset.metadata.tabular,
            gcs_source=[_TEST_TABULAR_CLASSIFICATION_GCS_SOURCE],
        )

        gapic_dataset = tabular_dataset._gca_resource
        shared_state["dataset_name"] = tabular_dataset.resource_name

        gapic_metadata = json_format.MessageToDict(gapic_dataset._pb.metadata)
        gcs_source_uris = gapic_metadata["inputConfig"]["gcsSource"]["uri"]

        assert len(gcs_source_uris) == 1
        assert gcs_source_uris[0] == _TEST_TABULAR_CLASSIFICATION_GCS_SOURCE
        assert (
            gapic_dataset.metadata_schema_uri
            == aiplatform.schema.dataset.metadata.tabular
        )

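    # For reference, the tabular dataset metadata inspected above has this
    # shape (inferred from the assertions in this test):
    #
    #   {"inputConfig": {"gcsSource": {"uri": ["gs://.../iris_1000.csv"]}}}
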
    # TODO(vinnys): Remove pytest skip once persistent resources are accessible
    @pytest.mark.skip(reason="System tests cannot access persistent test resources")
    @pytest.mark.usefixtures("create_staging_bucket", "delete_staging_bucket")
    def test_export_data(self, shared_state):
        """Get an existing dataset, export data to a newly created folder in
        Google Cloud Storage, then verify data was successfully exported."""

        assert shared_state["staging_bucket"]
        assert shared_state["storage_client"]

        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            staging_bucket=f"gs://{shared_state['staging_bucket']}",
        )

        text_dataset = aiplatform.Dataset(dataset_name=_TEST_TEXT_DATASET_ID)

        exported_files = text_dataset.export_data(
            output_dir=f"gs://{shared_state['staging_bucket']}"
        )

        assert len(exported_files)  # Ensure at least one GCS path was returned

        exported_file = exported_files[0]
        bucket_name, prefix = utils.extract_bucket_and_prefix_from_gcs_path(
            exported_file
        )

        storage_client = shared_state["storage_client"]

        bucket = storage_client.get_bucket(bucket_name)
        blob = bucket.get_blob(prefix)

        assert blob  # Verify the returned GCS export path exists