Commit cb0f75d

Merge branch 'dev' into mor--dataset-refactor-datasource
2 parents 4c83478 + ee6e275 commit cb0f75d

4 files changed: +294 −3 lines changed

.kokoro/samples/python3.8/common.cfg

Lines changed: 6 additions & 0 deletions
@@ -19,6 +19,12 @@ env_vars: {
     value: "py-3.8"
 }
 
+# Run tests located under tests/system
+env_vars: {
+    key: "RUN_SYSTEM_TESTS"
+    value: "true"
+}
+
 env_vars: {
     key: "TRAMPOLINE_BUILD_FILE"
     value: "github/python-aiplatform/.kokoro/test-samples.sh"

google/cloud/aiplatform/schema.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ class metadata:
             "gs://google-cloud-aiplatform/schema/dataset/metadata/tabular_1.0.0.yaml"
         )
         image = "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml"
+        text = "gs://google-cloud-aiplatform/schema/dataset/metadata/text_1.0.0.yaml"
 
     class ioformat:
         class image:
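
The new text constant is exercised later in this commit (see the create_text_dataset fixture in the system test below). A condensed sketch of that usage follows; the parent resource, endpoint, and display name are placeholders.

from google.api_core import client_options
from google.cloud import aiplatform
from google.cloud.aiplatform_v1beta1.types import dataset
from google.cloud.aiplatform_v1beta1.services import dataset_service

# Placeholder parent resource; substitute a real project and location.
parent = "projects/my-project/locations/us-central1"

client = dataset_service.DatasetServiceClient(
    client_options=client_options.ClientOptions(
        api_endpoint="us-central1-aiplatform.googleapis.com"
    )
)

# metadata.text resolves to the text_1.0.0.yaml schema URI added above.
gapic_dataset = dataset.Dataset(
    display_name="example-text-dataset",
    metadata_schema_uri=aiplatform.schema.dataset.metadata.text,
)
create_lro = client.create_dataset(parent=parent, dataset=gapic_dataset)
print(create_lro.result().name)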

noxfile.py

Lines changed: 3 additions & 3 deletions
@@ -80,9 +80,9 @@ def default(session):
     # Run py.test against the unit tests.
     session.run(
         "py.test",
-        "--quiet",
-        "--cov=google/cloud",
-        "--cov=tests/unit",
+        "--cov=google.cloud.aiplatform",
+        "--cov=google.cloud",
+        "--cov=tests.unit",
         "--cov-append",
         "--cov-config=.coveragerc",
         "--cov-report=",
Lines changed: 284 additions & 0 deletions
@@ -0,0 +1,284 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import uuid
+import pytest
+import importlib
+
+from google import auth as google_auth
+from google.protobuf import json_format
+from google.api_core import exceptions
+from google.api_core import client_options
+
+from google.cloud import storage
+from google.cloud import aiplatform
+from google.cloud.aiplatform import utils
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform_v1beta1.types import dataset
+from google.cloud.aiplatform_v1beta1.services import dataset_service
+
+# TODO(vinnys): Replace with env var `BUILD_SPECIFIC_GCP_PROJECT` once supported
+_, _TEST_PROJECT = google_auth.default()
+
+_TEST_LOCATION = "us-central1"
+_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
+_TEST_API_ENDPOINT = f"{_TEST_LOCATION}-aiplatform.googleapis.com"
+_TEST_IMAGE_DATASET_ID = "1084241610289446912"  # permanent_50_flowers_dataset
+_TEST_TEXT_DATASET_ID = (
+    "6203215905493614592"  # permanent_text_entity_extraction_dataset
+)
+_TEST_DATASET_DISPLAY_NAME = "permanent_50_flowers_dataset"
+_TEST_TABULAR_CLASSIFICATION_GCS_SOURCE = "gs://ucaip-sample-resources/iris_1000.csv"
+_TEST_TEXT_ENTITY_EXTRACTION_GCS_SOURCE = (
+    "gs://ucaip-test-us-central1/dataset/ucaip_ten_dataset.jsonl"
+)
+_TEST_IMAGE_OBJECT_DETECTION_GCS_SOURCE = (
+    "gs://ucaip-test-us-central1/dataset/salads_oid_ml_use_public_unassigned.jsonl"
+)
+_TEST_TEXT_ENTITY_IMPORT_SCHEMA = "gs://google-cloud-aiplatform/schema/dataset/ioformat/text_extraction_io_format_1.0.0.yaml"
+_TEST_IMAGE_OBJ_DET_IMPORT_SCHEMA = "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_bounding_box_io_format_1.0.0.yaml"
+
+
+class TestDataset:
+    def setup_method(self):
+        importlib.reload(initializer)
+        importlib.reload(aiplatform)
+
+    @pytest.fixture()
+    def shared_state(self):
+        shared_state = {}
+        yield shared_state
+
+    @pytest.fixture()
+    def create_staging_bucket(self, shared_state):
+        new_staging_bucket = f"temp-sdk-integration-{uuid.uuid4()}"
+
+        storage_client = storage.Client()
+        storage_client.create_bucket(new_staging_bucket)
+        shared_state["storage_client"] = storage_client
+        shared_state["staging_bucket"] = new_staging_bucket
+        yield
+
+    @pytest.fixture()
+    def delete_staging_bucket(self, shared_state):
+        yield
+        storage_client = shared_state["storage_client"]
+
+        # Delete temp staging bucket
+        bucket_to_delete = storage_client.get_bucket(shared_state["staging_bucket"])
+        bucket_to_delete.delete(force=True)
+
+        # Close Storage Client
+        storage_client._http._auth_request.session.close()
+        storage_client._http.close()
+
+    @pytest.fixture()
+    def dataset_gapic_client(self):
+        gapic_client = dataset_service.DatasetServiceClient(
+            client_options=client_options.ClientOptions(api_endpoint=_TEST_API_ENDPOINT)
+        )
+
+        yield gapic_client
+
+    @pytest.fixture()
+    def create_text_dataset(self, dataset_gapic_client, shared_state):
+
+        gapic_dataset = dataset.Dataset(
+            display_name=f"temp_sdk_integration_test_create_text_dataset_{uuid.uuid4()}",
+            metadata_schema_uri=aiplatform.schema.dataset.metadata.text,
+        )
+
+        create_lro = dataset_gapic_client.create_dataset(
+            parent=_TEST_PARENT, dataset=gapic_dataset
+        )
+        new_dataset = create_lro.result()
+        shared_state["dataset_name"] = new_dataset.name
+        yield
+
+    @pytest.fixture()
+    def create_tabular_dataset(self, dataset_gapic_client, shared_state):
+
+        gapic_dataset = dataset.Dataset(
+            display_name=f"temp_sdk_integration_test_create_tabular_dataset_{uuid.uuid4()}",
+            metadata_schema_uri=aiplatform.schema.dataset.metadata.tabular,
+        )
+
+        create_lro = dataset_gapic_client.create_dataset(
+            parent=_TEST_PARENT, dataset=gapic_dataset
+        )
+        new_dataset = create_lro.result()
+        shared_state["dataset_name"] = new_dataset.name
+        yield
+
+    @pytest.fixture()
+    def create_image_dataset(self, dataset_gapic_client, shared_state):
+
+        gapic_dataset = dataset.Dataset(
+            display_name=f"temp_sdk_integration_test_create_image_dataset_{uuid.uuid4()}",
+            metadata_schema_uri=aiplatform.schema.dataset.metadata.image,
+        )
+
+        create_lro = dataset_gapic_client.create_dataset(
+            parent=_TEST_PARENT, dataset=gapic_dataset
+        )
+        new_dataset = create_lro.result()
+        shared_state["dataset_name"] = new_dataset.name
+        yield
+
+    @pytest.fixture()
+    def delete_new_dataset(self, dataset_gapic_client, shared_state):
+        yield
+        assert shared_state["dataset_name"]
+
+        deletion_lro = dataset_gapic_client.delete_dataset(
+            name=shared_state["dataset_name"]
+        )
+        deletion_lro.result()
+
+        shared_state["dataset_name"] = None
+
+    # TODO(vinnys): Remove pytest skip once persistent resources are accessible
+    @pytest.mark.skip(reason="System tests cannot access persistent test resources")
+    def test_get_existing_dataset(self):
+        """Retrieve a known existing dataset, ensure SDK successfully gets the
+        dataset resource."""
+
+        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+        flowers_dataset = aiplatform.Dataset(dataset_name=_TEST_IMAGE_DATASET_ID)
+        assert flowers_dataset.name == _TEST_IMAGE_DATASET_ID
+        assert flowers_dataset.display_name == _TEST_DATASET_DISPLAY_NAME
+
+    def test_get_nonexistent_dataset(self):
+        """Ensure attempting to retrieve a dataset that doesn't exist raises
+        a Google API core 404 exception."""
+
+        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+        # AI Platform service returns 404
+        with pytest.raises(exceptions.NotFound):
+            aiplatform.Dataset(dataset_name="0")
+
+    @pytest.mark.usefixtures("create_text_dataset", "delete_new_dataset")
+    def test_get_new_dataset_and_import(self, dataset_gapic_client, shared_state):
+        """Retrieve new, empty dataset and import a text dataset using import().
+        Then verify data items were successfully imported."""
+
+        assert shared_state["dataset_name"]
+        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+        my_dataset = aiplatform.Dataset(dataset_name=shared_state["dataset_name"])
+
+        data_items_pre_import = dataset_gapic_client.list_data_items(
+            parent=my_dataset.resource_name
+        )
+
+        assert len(list(data_items_pre_import)) == 0
+
+        # Blocking call to import
+        my_dataset.import_data(
+            gcs_source=_TEST_TEXT_ENTITY_EXTRACTION_GCS_SOURCE,
+            import_schema_uri=_TEST_TEXT_ENTITY_IMPORT_SCHEMA,
+        )
+
+        data_items_post_import = dataset_gapic_client.list_data_items(
+            parent=my_dataset.resource_name
+        )
+
+        assert len(list(data_items_post_import)) == 469
+
+    @pytest.mark.usefixtures("delete_new_dataset")
+    def test_create_and_import_image_dataset(self, dataset_gapic_client, shared_state):
+        """Use the Dataset.create() method to create a new image obj detection
+        dataset and import images. Then confirm images were successfully imported."""
+
+        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+        img_dataset = aiplatform.Dataset.create(
+            display_name=f"temp_sdk_integration_create_and_import_dataset_{uuid.uuid4()}",
+            metadata_schema_uri=aiplatform.schema.dataset.metadata.image,
+            gcs_source=_TEST_IMAGE_OBJECT_DETECTION_GCS_SOURCE,
+            import_schema_uri=_TEST_IMAGE_OBJ_DET_IMPORT_SCHEMA,
+        )
+
+        shared_state["dataset_name"] = img_dataset.resource_name
+
+        data_items_iterator = dataset_gapic_client.list_data_items(
+            parent=img_dataset.resource_name
+        )
+
+        assert len(list(data_items_iterator)) == 14
+
+    @pytest.mark.usefixtures("delete_new_dataset")
+    def test_create_tabular_dataset(self, dataset_gapic_client, shared_state):
+        """Use the Dataset.create() method to create a new tabular dataset.
+        Then confirm the dataset was successfully created and references GCS source."""
+
+        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+        tabular_dataset = aiplatform.Dataset.create(
+            display_name=f"temp_sdk_integration_create_and_import_dataset_{uuid.uuid4()}",
+            metadata_schema_uri=aiplatform.schema.dataset.metadata.tabular,
+            gcs_source=[_TEST_TABULAR_CLASSIFICATION_GCS_SOURCE],
+        )
+
+        gapic_dataset = tabular_dataset._gca_resource
+        shared_state["dataset_name"] = tabular_dataset.resource_name
+
+        gapic_metadata = json_format.MessageToDict(gapic_dataset._pb.metadata)
+        gcs_source_uris = gapic_metadata["inputConfig"]["gcsSource"]["uri"]
+
+        assert len(gcs_source_uris) == 1
+        assert _TEST_TABULAR_CLASSIFICATION_GCS_SOURCE == gcs_source_uris[0]
+        assert (
+            gapic_dataset.metadata_schema_uri
+            == aiplatform.schema.dataset.metadata.tabular
+        )
+
+    # TODO(vinnys): Remove pytest skip once persistent resources are accessible
+    @pytest.mark.skip(reason="System tests cannot access persistent test resources")
+    @pytest.mark.usefixtures("create_staging_bucket", "delete_staging_bucket")
+    def test_export_data(self, shared_state):
+        """Get an existing dataset, export data to a newly created folder in
+        Google Cloud Storage, then verify data was successfully exported."""
+
+        assert shared_state["staging_bucket"]
+        assert shared_state["storage_client"]
+
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+            staging_bucket=f"gs://{shared_state['staging_bucket']}",
+        )
+
+        text_dataset = aiplatform.Dataset(dataset_name=_TEST_TEXT_DATASET_ID)
+
+        exported_files = text_dataset.export_data(
+            output_dir=f"gs://{shared_state['staging_bucket']}"
+        )
+
+        assert len(exported_files)  # Ensure at least one GCS path was returned
+
+        exported_file = exported_files[0]
+        bucket, prefix = utils.extract_bucket_and_prefix_from_gcs_path(exported_file)
+
+        storage_client = shared_state["storage_client"]
+
+        bucket = storage_client.get_bucket(bucket)
+        blob = bucket.get_blob(prefix)
+
+        assert blob  # Verify the returned GCS export path exists
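
These system tests are the ones gated by the RUN_SYSTEM_TESTS variable introduced in common.cfg above. A sketch of one way to run them locally follows; the test path and environment handling are assumptions, and Application Default Credentials must already be configured because the module calls google_auth.default() at import time.

import os
import subprocess

# Opt in, mirroring the Kokoro environment variable added in this commit.
env = dict(os.environ, RUN_SYSTEM_TESTS="true")

# Assumed test location, per the "tests located under tests/system" comment above.
subprocess.run(["py.test", "--verbose", "tests/system"], env=env, check=True)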
