Skip to content

Commit 8d29dcd

Browse files
committed
Add client methods for Models API. (googleapis#494)
* Add client methods for Models API. * Adds hack to workaround milliseconds format for model.trainingRun.startTime. * Adds code samples for Models API, which double as system tests.
1 parent ecc0f18 commit 8d29dcd

File tree

14 files changed

+725
-2
lines changed

14 files changed

+725
-2
lines changed

bigquery/google/cloud/bigquery/client.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
from google.cloud.bigquery.dataset import DatasetListItem
4747
from google.cloud.bigquery.dataset import DatasetReference
4848
from google.cloud.bigquery import job
49+
from google.cloud.bigquery.model import Model
50+
from google.cloud.bigquery.model import ModelReference
4951
from google.cloud.bigquery.query import _QueryResults
5052
from google.cloud.bigquery.retry import DEFAULT_RETRY
5153
from google.cloud.bigquery.table import Table
@@ -427,6 +429,33 @@ def get_dataset(self, dataset_ref, retry=DEFAULT_RETRY):
427429
api_response = self._call_api(retry, method="GET", path=dataset_ref.path)
428430
return Dataset.from_api_repr(api_response)
429431

432+
def get_model(self, model_ref, retry=DEFAULT_RETRY):
433+
"""Fetch the model referenced by ``model_ref``.
434+
435+
Args:
436+
model_ref (Union[ \
437+
:class:`~google.cloud.bigquery.model.ModelReference`, \
438+
str, \
439+
]):
440+
A reference to the model to fetch from the BigQuery API.
441+
If a string is passed in, this method attempts to create a
442+
model reference from a string using
443+
:func:`google.cloud.bigquery.model.ModelReference.from_string`.
444+
retry (:class:`google.api_core.retry.Retry`):
445+
(Optional) How to retry the RPC.
446+
447+
Returns:
448+
google.cloud.bigquery.model.Model:
449+
A ``Model`` instance.
450+
"""
451+
if isinstance(model_ref, str):
452+
model_ref = ModelReference.from_string(
453+
model_ref, default_project=self.project
454+
)
455+
456+
api_response = self._call_api(retry, method="GET", path=model_ref.path)
457+
return Model.from_api_repr(api_response)
458+
430459
def get_table(self, table_ref, retry=DEFAULT_RETRY):
431460
"""Fetch the table referenced by ``table_ref``.
432461
@@ -490,6 +519,41 @@ def update_dataset(self, dataset, fields, retry=DEFAULT_RETRY):
490519
)
491520
return Dataset.from_api_repr(api_response)
492521

522+
def update_model(self, model, fields, retry=DEFAULT_RETRY):
523+
"""Change some fields of a model.
524+
525+
Use ``fields`` to specify which fields to update. At least one field
526+
must be provided. If a field is listed in ``fields`` and is ``None``
527+
in ``model``, it will be deleted.
528+
529+
If ``model.etag`` is not ``None``, the update will only succeed if
530+
the model on the server has the same ETag. Thus reading a model with
531+
``get_model``, changing its fields, and then passing it to
532+
``update_model`` will ensure that the changes will only be saved if
533+
no modifications to the model occurred since the read.
534+
535+
Args:
536+
model (google.cloud.bigquery.model.Model): The model to update.
537+
fields (Sequence[str]):
538+
The fields of ``model`` to change, spelled as the Model
539+
properties (e.g. "friendly_name").
540+
retry (google.api_core.retry.Retry):
541+
(Optional) A description of how to retry the API call.
542+
543+
Returns:
544+
google.cloud.bigquery.model.Model:
545+
The model resource returned from the API call.
546+
"""
547+
partial = model._build_resource(fields)
548+
if model.etag:
549+
headers = {"If-Match": model.etag}
550+
else:
551+
headers = None
552+
api_response = self._call_api(
553+
retry, method="PATCH", path=model.path, data=partial, headers=headers
554+
)
555+
return Model.from_api_repr(api_response)
556+
493557
def update_table(self, table, fields, retry=DEFAULT_RETRY):
494558
"""Change some fields of a table.
495559
@@ -525,6 +589,64 @@ def update_table(self, table, fields, retry=DEFAULT_RETRY):
525589
)
526590
return Table.from_api_repr(api_response)
527591

592+
def list_models(
593+
self, dataset, max_results=None, page_token=None, retry=DEFAULT_RETRY
594+
):
595+
"""List models in the dataset.
596+
597+
See
598+
https://cloud.google.com/bigquery/docs/reference/rest/v2/models/list
599+
600+
Args:
601+
dataset (Union[ \
602+
:class:`~google.cloud.bigquery.dataset.Dataset`, \
603+
:class:`~google.cloud.bigquery.dataset.DatasetReference`, \
604+
str, \
605+
]):
606+
A reference to the dataset whose models to list from the
607+
BigQuery API. If a string is passed in, this method attempts
608+
to create a dataset reference from a string using
609+
:func:`google.cloud.bigquery.dataset.DatasetReference.from_string`.
610+
max_results (int):
611+
(Optional) Maximum number of models to return. If not passed,
612+
defaults to a value set by the API.
613+
page_token (str):
614+
(Optional) Token representing a cursor into the models. If
615+
not passed, the API will return the first page of models. The
616+
token marks the beginning of the iterator to be returned and
617+
the value of the ``page_token`` can be accessed at
618+
``next_page_token`` of the
619+
:class:`~google.api_core.page_iterator.HTTPIterator`.
620+
retry (:class:`google.api_core.retry.Retry`):
621+
(Optional) How to retry the RPC.
622+
623+
Returns:
624+
google.api_core.page_iterator.Iterator:
625+
Iterator of
626+
:class:`~google.cloud.bigquery.model.Model` contained
627+
within the requested dataset.
628+
"""
629+
if isinstance(dataset, str):
630+
dataset = DatasetReference.from_string(
631+
dataset, default_project=self.project
632+
)
633+
634+
if not isinstance(dataset, (Dataset, DatasetReference)):
635+
raise TypeError("dataset must be a Dataset, DatasetReference, or string")
636+
637+
path = "%s/models" % dataset.path
638+
result = page_iterator.HTTPIterator(
639+
client=self,
640+
api_request=functools.partial(self._call_api, retry),
641+
path=path,
642+
item_to_value=_item_to_model,
643+
items_key="models",
644+
page_token=page_token,
645+
max_results=max_results,
646+
)
647+
result.dataset = dataset
648+
return result
649+
528650
def list_tables(
529651
self, dataset, max_results=None, page_token=None, retry=DEFAULT_RETRY
530652
):
@@ -631,6 +753,40 @@ def delete_dataset(
631753
if not not_found_ok:
632754
raise
633755

756+
def delete_model(self, model, retry=DEFAULT_RETRY, not_found_ok=False):
757+
"""Delete a model
758+
759+
See
760+
https://cloud.google.com/bigquery/docs/reference/rest/v2/models/delete
761+
762+
Args:
763+
model (Union[ \
764+
:class:`~google.cloud.bigquery.model.Model`, \
765+
:class:`~google.cloud.bigquery.model.ModelReference`, \
766+
str, \
767+
]):
768+
A reference to the model to delete. If a string is passed in,
769+
this method attempts to create a model reference from a
770+
string using
771+
:func:`google.cloud.bigquery.model.ModelReference.from_string`.
772+
retry (:class:`google.api_core.retry.Retry`):
773+
(Optional) How to retry the RPC.
774+
not_found_ok (bool):
775+
Defaults to ``False``. If ``True``, ignore "not found" errors
776+
when deleting the model.
777+
"""
778+
if isinstance(model, str):
779+
model = ModelReference.from_string(model, default_project=self.project)
780+
781+
if not isinstance(model, (Model, ModelReference)):
782+
raise TypeError("model must be a Model or a ModelReference")
783+
784+
try:
785+
self._call_api(retry, method="DELETE", path=model.path)
786+
except google.api_core.exceptions.NotFound:
787+
if not not_found_ok:
788+
raise
789+
634790
def delete_table(self, table, retry=DEFAULT_RETRY, not_found_ok=False):
635791
"""Delete a table
636792
@@ -1810,6 +1966,21 @@ def _item_to_job(iterator, resource):
18101966
return iterator.client.job_from_resource(resource)
18111967

18121968

1969+
def _item_to_model(iterator, resource):
1970+
"""Convert a JSON model to the native object.
1971+
1972+
Args:
1973+
iterator (google.api_core.page_iterator.Iterator):
1974+
The iterator that is currently in use.
1975+
resource (dict):
1976+
An item to be converted to a model.
1977+
1978+
Returns:
1979+
google.cloud.bigquery.model.Model: The next model in the page.
1980+
"""
1981+
return Model.from_api_repr(resource)
1982+
1983+
18131984
def _item_to_table(iterator, resource):
18141985
"""Convert a JSON table to the native object.
18151986

bigquery/google/cloud/bigquery/model.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616

1717
"""Define resources for the BigQuery ML Models API."""
1818

19-
import datetime
19+
import copy
2020

2121
from google.protobuf import json_format
2222
import six
2323

2424
import google.cloud._helpers
25+
from google.api_core import datetime_helpers
2526
from google.cloud.bigquery import _helpers
2627
from google.cloud.bigquery_v2 import types
2728

@@ -83,6 +84,26 @@ def reference(self):
8384
ref._proto = self._proto.model_reference
8485
return ref
8586

87+
@property
88+
def project(self):
89+
"""str: Project bound to the model"""
90+
return self.reference.project
91+
92+
@property
93+
def dataset_id(self):
94+
"""str: ID of dataset containing the model."""
95+
return self.reference.dataset_id
96+
97+
@property
98+
def model_id(self):
99+
"""str: The model ID."""
100+
return self.reference.model_id
101+
102+
@property
103+
def path(self):
104+
"""str: URL path for the model's APIs."""
105+
return self.reference.path
106+
86107
@property
87108
def location(self):
88109
"""str: The geographic location where the model resides. This value
@@ -192,7 +213,7 @@ def expires(self):
192213
@expires.setter
193214
def expires(self, value):
194215
if value is not None:
195-
value = google.cloud._helpers._millis_from_datetime(value)
216+
value = str(google.cloud._helpers._millis_from_datetime(value))
196217
self._properties["expirationTime"] = value
197218

198219
@property
@@ -247,6 +268,17 @@ def from_api_repr(cls, resource):
247268
google.cloud.bigquery.model.Model: Model parsed from ``resource``.
248269
"""
249270
this = cls(None)
271+
272+
# Convert from millis-from-epoch to timestamp well-known type.
273+
# TODO: Remove this hack once CL 238585470 hits prod.
274+
resource = copy.deepcopy(resource)
275+
for training_run in resource.get("trainingRuns", ()):
276+
start_time = training_run.get("startTime")
277+
if not start_time or "-" in start_time: # Already right format?
278+
continue
279+
start_time = datetime_helpers.from_microseconds(1e3 * float(start_time))
280+
training_run["startTime"] = datetime_helpers.to_rfc3339(start_time)
281+
250282
this._proto = json_format.ParseDict(resource, types.Model())
251283
for key in six.itervalues(cls._PROPERTY_TO_API_FIELD):
252284
# Leave missing keys unset. This allows us to use setdefault in the
@@ -288,6 +320,15 @@ def model_id(self):
288320
"""str: The model ID."""
289321
return self._proto.model_id
290322

323+
@property
324+
def path(self):
325+
"""str: URL path for the model's APIs."""
326+
return "/projects/%s/datasets/%s/models/%s" % (
327+
self._proto.project_id,
328+
self._proto.dataset_id,
329+
self._proto.model_id,
330+
)
331+
291332
@classmethod
292333
def from_api_repr(cls, resource):
293334
"""Factory: construct a model reference given its API representation

bigquery/noxfile.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ def snippets(session):
129129
# Run py.test against the snippets tests.
130130
session.run(
131131
'py.test', os.path.join('docs', 'snippets.py'), *session.posargs)
132+
session.run(
133+
'py.test', os.path.join('samples'), *session.posargs)
132134

133135

134136
@nox.session(python='3.6')
@@ -178,6 +180,7 @@ def blacken(session):
178180
session.run(
179181
"black",
180182
"google",
183+
"samples",
181184
"tests",
182185
"docs",
183186
)

bigquery/samples/__init__.py

Whitespace-only changes.

bigquery/samples/delete_model.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright 2019 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def delete_model(client, model_id):
17+
"""Sample ID: go/samples-tracker/1534"""
18+
19+
# [START bigquery_delete_model]
20+
from google.cloud import bigquery
21+
22+
# TODO(developer): Construct a BigQuery client object.
23+
# client = bigquery.Client()
24+
25+
# TODO(developer): Set model_id to the ID of the model to fetch.
26+
# model_id = 'your-project.your_dataset.your_model'
27+
28+
client.delete_model(model_id)
29+
# [END bigquery_delete_model]
30+
31+
print("Deleted model '{}'.".format(model_id))

bigquery/samples/get_model.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2019 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def get_model(client, model_id):
17+
"""Sample ID: go/samples-tracker/1510"""
18+
19+
# [START bigquery_get_model]
20+
from google.cloud import bigquery
21+
22+
# TODO(developer): Construct a BigQuery client object.
23+
# client = bigquery.Client()
24+
25+
# TODO(developer): Set model_id to the ID of the model to fetch.
26+
# model_id = 'your-project.your_dataset.your_model'
27+
28+
model = client.get_model(model_id)
29+
30+
full_model_id = "{}.{}.{}".format(model.project, model.dataset_id, model.model_id)
31+
friendly_name = model.friendly_name
32+
print(
33+
"Got model '{}' with friendly_name '{}'.".format(full_model_id, friendly_name)
34+
)
35+
# [END bigquery_get_model]

0 commit comments

Comments
 (0)