
Commit 4459fff

morgandu and Ivan Cheung authored
feat: specialized dataset classes, fix: datasets refactor (#153)
* feat: Refactored Dataset by removing intermediate layers
* Added image_dataset and tabular_dataset subclass
* Moved metadata_schema_uri responsibility to subclass to enable forecasting
* Moved validation logic for tabular into Dataset._create_tabular
* Added validation in image_dataset and fixed bounding_box schema error
* Removed import_config
* Fixed metadata_schema_uri
* Fixed import and subclasses
* Added EmptyNontabularDatasource
* change import_metadata to ioformat
* added datasources.py
* added support of multiple gcs_sources
* fix: default (empty) dataset_metadata needs to be set to {}, not None
* 1) imported datasources 2) added _support_metadata_schema_uris and _support_import_schema_classes 3) added getter and setter/validation for resource_metadata_schema_uri, metadata_schema_uri, and import_schema_uri 4) fixed request_metadata, data_item_labels 5) encapsulated dataset_metadata and import_data_configs 6) added datasource configuration logic
* added image_dataset.py and tabular_dataset.py
* fix: refactor - create datasets module
* fix: cleanup __init__.py
* fix: data_item_labels
* fix: docstring
* fix:
  - changed NonTabularDatasource.dataset_metadata default to None
  - updated NonTabularDatasource docstring
  - changed gcs_source type hint with Union
  - changed _create_and_import to _create_encapsulated with datasource
  - removed subclass.__init__ and irrelevant parameters in create
* fix: import the module instead of the classes for datasources
* fix: removed all validation for import_schema_uri
* fix: set parameter default to immutable
* fix: replaced Datasource / DatasourceImportable abstract class instead of a concrete type (see the sketch after the commit metadata below)
* fix: added examples for gcs_source
* fix:
  - remove Sequence from utils.py
  - refactor datasources.py to _datasources.py
  - change docstring format to arg_name (arg_type): convention
  - change and include the type signature _supported_metadata_schema_uris
  - change _validate_metadata_schema_uri
  - refactor _create_encapsulated to _create_and_import
  - refactor to module level imports
  - add tests for ImageDataset and TabularDataset
* fix: remove all labels
* fix: remove Optional in docstring, add example for bq_source
* test: add import_data raise for tabular dataset test
* fix: refactor datasource creation with create_datasource
* fix: lint

Co-authored-by: Ivan Cheung <[email protected]>
1 parent ee6e275 commit 4459fff
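
The bullet above about replacing a concrete type with Datasource / DatasourceImportable abstract classes is the core of the refactor. As a quick orientation, here is a minimal, self-contained sketch of the abstract-property half of that pattern; only Datasource is shown (DatasourceImportable follows the same shape in the full diff below), and the InMemoryDatasource subclass is purely illustrative, not part of the commit:

import abc
from typing import Dict, Optional


class Datasource(abc.ABC):
    """Abstract interface: a concrete datasource must expose dataset_metadata."""

    @property
    @abc.abstractmethod
    def dataset_metadata(self) -> Optional[Dict]:
        ...


# Illustrative only (not in the commit): a concrete subclass satisfying the contract.
class InMemoryDatasource(Datasource):
    def __init__(self, metadata: Dict):
        self._metadata = metadata

    @property
    def dataset_metadata(self) -> Optional[Dict]:
        return self._metadata


print(InMemoryDatasource({"input_config": {}}).dataset_metadata)  # {'input_config': {}}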

File tree

9 files changed, +949 -260 lines changed


google/cloud/aiplatform/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -18,7 +18,11 @@
 from google.cloud.aiplatform import gapic

 from google.cloud.aiplatform import initializer
-from google.cloud.aiplatform.datasets import Dataset
+from google.cloud.aiplatform.datasets import (
+    Dataset,
+    TabularDataset,
+    ImageDataset,
+)
 from google.cloud.aiplatform.models import Endpoint
 from google.cloud.aiplatform.models import Model
 from google.cloud.aiplatform.jobs import BatchPredictionJob
@@ -42,5 +46,7 @@
     "AutoMLTabularTrainingJob",
     "Model",
     "Dataset",
+    "TabularDataset",
+    "ImageDataset",
     "Endpoint",
 )
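
With the export list expanded as above, the specialized dataset classes become importable directly from the top-level package. A small sketch of the new import surface, assuming the package is installed at this commit; only the imports and the subclass relationship stated in the commit message are exercised:

# New in this commit: TabularDataset and ImageDataset are re-exported at the top level.
from google.cloud.aiplatform import Dataset, TabularDataset, ImageDataset

# Per the commit message, both specialized classes subclass the refactored Dataset.
assert issubclass(TabularDataset, Dataset)
assert issubclass(ImageDataset, Dataset)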
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from google.cloud.aiplatform.datasets.dataset import Dataset
+from google.cloud.aiplatform.datasets.tabular_dataset import TabularDataset
+from google.cloud.aiplatform.datasets.image_dataset import ImageDataset
+
+__all__ = (
+    "Dataset",
+    "TabularDataset",
+    "ImageDataset",
+)
Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
+import abc
+from typing import Optional, Dict, Sequence, Union
+from google.cloud.aiplatform_v1beta1.types import io as gca_io
+from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset
+
+from google.cloud.aiplatform import schema
+
+
+class Datasource(abc.ABC):
+    """An abstract class that sets dataset_metadata"""
+
+    @property
+    @abc.abstractmethod
+    def dataset_metadata(self):
+        """Dataset Metadata."""
+        pass
+
+
+class DatasourceImportable(abc.ABC):
+    """An abstract class that sets import_data_config"""
+
+    @property
+    @abc.abstractmethod
+    def import_data_config(self):
+        """Import Data Config."""
+        pass
+
+
+class TabularDatasource(Datasource):
+    """Datasource for creating a tabular dataset for AI Platform"""
+
+    def __init__(
+        self,
+        gcs_source: Optional[Union[str, Sequence[str]]] = None,
+        bq_source: Optional[str] = None,
+    ):
+        """Creates a tabular datasource
+
+        Args:
+            gcs_source (Union[str, Sequence[str]]):
+                Cloud Storage URI of one or more files. Only CSV files are supported.
+                The first line of the CSV file is used as the header.
+                If there are multiple files, the header is the first line of
+                the lexicographically first file, the other files must either
+                contain the exact same header or omit the header.
+                examples:
+                    str: "gs://bucket/file.csv"
+                    Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
+            bq_source (str):
+                The URI of a BigQuery table.
+                example:
+                    "bq://project.dataset.table_name"
+
+        Raises:
+            ValueError if source configuration is not valid.
+        """
+
+        dataset_metadata = None
+
+        if gcs_source and isinstance(gcs_source, str):
+            gcs_source = [gcs_source]
+
+        if gcs_source and bq_source:
+            raise ValueError("Only one of gcs_source or bq_source can be set.")
+
+        if not any([gcs_source, bq_source]):
+            raise ValueError("One of gcs_source or bq_source must be set.")
+
+        if gcs_source:
+            dataset_metadata = {"input_config": {"gcs_source": {"uri": gcs_source}}}
+        elif bq_source:
+            dataset_metadata = {"input_config": {"bigquery_source": {"uri": bq_source}}}
+
+        self._dataset_metadata = dataset_metadata
+
+    @property
+    def dataset_metadata(self) -> Optional[Dict]:
+        """Dataset Metadata."""
+        return self._dataset_metadata
+
+
+class NonTabularDatasource(Datasource):
+    """Datasource for creating an empty non-tabular dataset for AI Platform"""
+
+    @property
+    def dataset_metadata(self) -> Optional[Dict]:
+        return None
+
+
+class NonTabularDatasourceImportable(NonTabularDatasource, DatasourceImportable):
+    """Datasource for creating a non-tabular dataset for AI Platform and importing data to the dataset"""
+
+    def __init__(
+        self,
+        gcs_source: Union[str, Sequence[str]],
+        import_schema_uri: str,
+        data_item_labels: Optional[Dict] = None,
+    ):
+        """Creates a non-tabular datasource
+
+        Args:
+            gcs_source (Union[str, Sequence[str]]):
+                Required. The Google Cloud Storage location for the input content.
+                Google Cloud Storage URI(-s) to the input file(s). May contain
+                wildcards. For more information on wildcards, see
+                https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
+                examples:
+                    str: "gs://bucket/file.csv"
+                    Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
+            import_schema_uri (str):
+                Required. Points to a YAML file stored on Google Cloud
+                Storage describing the import format. Validation will be
+                done against the schema. The schema is defined as an
+                OpenAPI 3.0.2 Schema Object.
+            data_item_labels (Dict):
+                Labels that will be applied to newly imported DataItems. If
+                an identical DataItem as one being imported already exists
+                in the Dataset, then these labels will be appended to those
+                of the already existing one, and if labels with identical
+                key is imported before, the old label value will be
+                overwritten. If two DataItems are identical in the same
+                import data operation, the labels will be combined and if
+                key collision happens in this case, one of the values will
+                be picked randomly. Two DataItems are considered identical
+                if their content bytes are identical (e.g. image bytes or
+                pdf bytes). These labels will be overridden by Annotation
+                labels specified inside index file referenced by
+                ``import_schema_uri``,
+                e.g. jsonl file.
+        """
+        super().__init__()
+        self._gcs_source = [gcs_source] if isinstance(gcs_source, str) else gcs_source
+        self._import_schema_uri = import_schema_uri
+        self._data_item_labels = data_item_labels
+
+    @property
+    def import_data_config(self) -> gca_dataset.ImportDataConfig:
+        """Import Data Config."""
+        return gca_dataset.ImportDataConfig(
+            gcs_source=gca_io.GcsSource(uris=self._gcs_source),
+            import_schema_uri=self._import_schema_uri,
+            data_item_labels=self._data_item_labels,
+        )
+
+
+def create_datasource(
+    metadata_schema_uri: str,
+    import_schema_uri: Optional[str] = None,
+    gcs_source: Optional[Union[str, Sequence[str]]] = None,
+    bq_source: Optional[str] = None,
+    data_item_labels: Optional[Dict] = None,
+) -> Datasource:
+    """Creates a datasource
+    Args:
+        metadata_schema_uri (str):
+            Required. Points to a YAML file stored on Google Cloud Storage
+            describing additional information about the Dataset. The schema
+            is defined as an OpenAPI 3.0.2 Schema Object. The schema files
+            that can be used here are found in gs://google-cloud-
+            aiplatform/schema/dataset/metadata/.
+        import_schema_uri (str):
+            Points to a YAML file stored on Google Cloud
+            Storage describing the import format. Validation will be
+            done against the schema. The schema is defined as an
+            OpenAPI 3.0.2 Schema Object.
+        gcs_source (Union[str, Sequence[str]]):
+            The Google Cloud Storage location for the input content.
+            Google Cloud Storage URI(-s) to the input file(s). May contain
+            wildcards. For more information on wildcards, see
+            https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
+            examples:
+                str: "gs://bucket/file.csv"
+                Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
+        bq_source (str):
+            BigQuery URI to the input table.
+            example:
+                "bq://project.dataset.table_name"
+        data_item_labels (Dict):
+            Labels that will be applied to newly imported DataItems. If
+            an identical DataItem as one being imported already exists
+            in the Dataset, then these labels will be appended to those
+            of the already existing one, and if labels with identical
+            key is imported before, the old label value will be
+            overwritten. If two DataItems are identical in the same
+            import data operation, the labels will be combined and if
+            key collision happens in this case, one of the values will
+            be picked randomly. Two DataItems are considered identical
+            if their content bytes are identical (e.g. image bytes or
+            pdf bytes). These labels will be overridden by Annotation
+            labels specified inside index file referenced by
+            ``import_schema_uri``,
+            e.g. jsonl file.
+
+    Returns:
+        datasource (Datasource)
+
+    Raises:
+        ValueError when below scenarios happen:
+        - import_schema_uri is identified for creating TabularDatasource
+        - either import_schema_uri or gcs_source is missing for creating NonTabularDatasourceImportable
+    """
+
+    if metadata_schema_uri == schema.dataset.metadata.tabular:
+        if import_schema_uri:
+            raise ValueError("tabular dataset does not support data import.")
+        return TabularDatasource(gcs_source, bq_source)
+
+    if not import_schema_uri and not gcs_source:
+        return NonTabularDatasource()
+    elif import_schema_uri and gcs_source:
+        return NonTabularDatasourceImportable(
+            gcs_source, import_schema_uri, data_item_labels
+        )
+    else:
+        raise ValueError(
+            "nontabular dataset requires both import_schema_uri and gcs_source for data import."
+        )
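
A brief usage sketch of the datasource helpers defined above, run with this module's names imported; the URIs are placeholders, and only behavior visible in this diff is exercised:

# Tabular: exactly one of gcs_source / bq_source may be set, otherwise __init__ raises ValueError.
tabular = TabularDatasource(bq_source="bq://project.dataset.table_name")
print(tabular.dataset_metadata)
# {'input_config': {'bigquery_source': {'uri': 'bq://project.dataset.table_name'}}}

# Non-tabular, importable: wraps GCS URIs and the import schema into an ImportDataConfig.
importable = NonTabularDatasourceImportable(
    gcs_source="gs://bucket/file1.jsonl",                 # a single string is wrapped into a list
    import_schema_uri="gs://bucket/import_schema.yaml",   # placeholder schema URI
)
config = importable.import_data_config                    # gca_dataset.ImportDataConfig

# The factory selects the datasource type from the dataset's metadata schema URI.
ds = create_datasource(
    metadata_schema_uri=schema.dataset.metadata.tabular,  # tabular schema -> TabularDatasource
    bq_source="bq://project.dataset.table_name",
)
assert isinstance(ds, TabularDatasource)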
