
Commit 068e507

Authored by sasha-gitg, ivanmkc (Ivan Cheung), morgandu (Morgan Du), and vinnysenthil (Vinny Senthil)
chore: merge dev into mb-experimental-release (#185)
* fix: unblock builds (#132)
* chore: Update README with Experimental verbiage (#131)
* fix: Fixed comments (#116)
  Co-authored-by: Ivan Cheung <[email protected]>
* feat: Implements a wrapped client that instantiates the client at every API invocation (#139)
* feat: Added optional model args for custom training (#129)
  * Added optional model args
  * fix: Removed etag
  * fix: Added predict schemata and fixed type error
  * fix: Added description and fixed predict_schemata
  * Added _model_serving_container_command, _model_serving_container_args, env=self._model_serving_container_environment_variables and _model_serving_container_ports
  * fix: Ran linter
  * fix: Added tests for model_instance_schema_uri, model_parameters_schema_uri and model_prediction_schema_uri
  * fix: Fixed env and ports and added tests
  * fix: Removed model_labels
  * fix: Moved container spec creation into init function
  * fix: Fixed docstrings
  * fix: Moved import to be alphabetical
  * fix: Moved model creation to init function
  * fix: Fixed predict_schemata
  * fix: simplified predict schemata
  * fix: added linter
  * fix: Fixed trailing comma
  * fix: Removed CustomTrainingJob private fields
  * fix: Fixed model tests
  * fix: Set managed_model to None
  Co-authored-by: Ivan Cheung <[email protected]>
* Fix: refactor class constructor for retrieving resource (#125)
  * Added property and abstract method _getter_method and _resource_noun, implemented method _get_gca_resource to class AiPlatformResourceNoun; Added _resource_noun, _getter_method to Dataset, Model, Endpoint, subclasses of _Job, _TrainingJob; refactored (_)get_* and utils.full_resource_name in class constructor to self._get_gca_resource for Dataset, Model, Endpoint, _Job
  * Added return value in _get_gca_resource, added method _sync_gca_resource in AiPlatformResourceNoun class; removed job_type, updated status method with _sync_gca_resource in _Job class
  * fix: added return type and lint issues
  * fix: merge conflict issue with models.py
  * fix: F401 'abc' imported but unused
* chore: merge main into dev (#154)
* test: Dataset integration tests (#126)
  * Add dataset.metadata.text to schemas
  * Add first integration tests, Dataset class
  * Make teardown work if test fails, update asserts
  * Change test folder name, enable system tests
  * Hide test_base, test_end_to_end for Kokoro CI bug
  * Add GCP Project env var to Kokoro presubmit cfg
  * Restore presubmit cfg, drop --quiet in unit tests
  * Restore test_base, test_end_to_end to find timeout
  * Skip tests depending on persistent resources
  * Use auth default creds for system tests
  * Drop unused import os
* feat: specialized dataset classes, fix: datasets refactor (#153)
  * feat: Refactored Dataset by removing intermediate layers
  * Added image_dataset and tabular_dataset subclass
  * Moved metadata_schema_uri responsibility to subclass to enable forecasting
  * Moved validation logic for tabular into Dataset._create_tabular
  * Added validation in image_dataset and fixed bounding_box schema error
  * Removed import_config
  * Fixed metadata_schema_uri
  * Fixed import and subclasses
  * Added EmptyNontabularDatasource
  * change import_metadata to ioformat
  * added datasources.py
  * added support of multiple gcs_sources
  * fix: default (empty) dataset_metadata need to be set to {}, not None
  * 1) imported datasources 2) added _support_metadata_schema_uris and _support_import_schema_classes 3) added getter and setter/validation for resource_metadata_schema_uri, metadata_schema_uri, and import_schema_uri 4) fixed request_metadata, data_item_labels 5) encapsulated dataset_metadata and import_data_configs 6) added datasource configuration logic
  * added image_dataset.py and tabular_dataset.py
  * fix: refactor - create datasets module
  * fix: cleanup __init__.py
  * fix: data_item_labels
  * fix: docstring
  * fix: changed NonTabularDatasource.dataset_metadata default to None; updated NonTabularDatasource docstring; changed gcs_source type hint with Union; changed _create_and_import to _create_encapsulated with datasource; removed subclass.__init__ and irrelevant parameters in create
  * fix: import the module instead of the classes for datasources
  * fix: removed all validation for import_schema_uri
  * fix: set parameter default to immutable
  * fix: replaced Datasource / DatasourceImportable abstract class instead of a concrete type
  * fix: added examples for gcs_source
  * fix: remove Sequence from utils.py; refactor datasources.py to _datasources.py; change docstring format to arg_name (arg_type): convention; change and include the type signature _supported_metadata_schema_uris; change _validate_metadata_schema_uri; refactor _create_encapsulated to _create_and_import; refactor to module level imports; add tests for ImageDataset and TabularDataset
  * fix: remove all labels
  * fix: remove Optional in docstring, add example for bq_source
  * test: add import_data raise for tabular dataset test
  * fix: refactor datasource creation with create_datasource
  * fix: lint
  Co-authored-by: Ivan Cheung <[email protected]>
* feat: Add AutoML Image Training Job class (#152)
  * Add AutoMLImageTrainingJob, tests, constants
  * Address reviewer comments
* feat: Add custom container support (#164)
* chore: merge main into dev (#162)
* fix: suppress no project id warning (#160)
  * fix: suppress no project id warning
  * fix: temporarily suppress logging.WARNING and set credentials as google.auth.default credentials
  * fix: move default credentials config to credentials property
  * fix: add property setter for credentials to avoid resetting every time
* fix: Fixed wrong key value for multilabel (#168)
  Co-authored-by: Ivan Cheung <[email protected]>
* feat: Add delete methods, add list_models and undeploy_all for Endpoint class (#165)
  * Endpoint list_models, delete, undeploy_all WIP
  * Finish delete + undeploy methods, tests
  * Add global pool teardowns for test timeout issue
  * Address reviewer comments, add async support
* fix: Fixed bug causing training failure for object detection (#171)
  Co-authored-by: Ivan Cheung <[email protected]>
* fix: Support intermediary BQ Table for Custom Training (#166)
* chore: add AutoMLImageTrainingJob to aiplatform namespace (#173)
* fix: Unblock build (#174)
* fix: default credentials config related test failures (#167)
  * fix: suppress no project id warning
  * fix: temporarily suppress logging.WARNING and set credentials as google.auth.default credentials
  * fix: move default credentials config to credentials property
  * fix: add property setter for credentials to avoid resetting every time
  * fix: tests for set credentials to default when default not provided
  * fix: change credentials with initializer default when not provided in AiPlatformResourceNoun
  * fix: use credential mock in tests
  * fix: lint
  Co-authored-by: sasha-gitg <[email protected]>
* Fix: pass bq_destination to input data config when using training script (#181)
  * fix: pass bigquery destination
  * fix: add tests and formatting
  Co-authored-by: Ivan Cheung <[email protected]>

Co-authored-by: Ivan Cheung <[email protected]>
Co-authored-by: Morgan Du <[email protected]>
Co-authored-by: Vinny Senthil <[email protected]>
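
The change this particular commit lands (PR #181 in the list above) threads a new bigquery_destination argument from CustomTrainingJob.run through to the training pipeline's input data config. Below is a minimal usage sketch of that argument; the project, bucket, container images, and dataset resource name are placeholders, and the "bq://" destination format and the TabularDataset constructor are assumptions based on the surrounding SDK rather than anything shown in this diff.

from google.cloud import aiplatform
from google.cloud.aiplatform import datasets, training_jobs

aiplatform.init(project="my-project", staging_bucket="gs://my-staging-bucket")

# Placeholder: an existing managed tabular dataset backed by a BigQuery table.
dataset = datasets.TabularDataset(
    dataset_name="projects/my-project/locations/us-central1/datasets/1234567890"
)

job = training_jobs.CustomTrainingJob(
    display_name="train-with-bq-destination",
    script_path="task.py",  # local training script to package and upload
    container_uri="gcr.io/my-project/training-image:latest",  # placeholder training container
    model_serving_container_image_uri="gcr.io/my-project/serving-image:latest",
)

# New in this commit: bigquery_destination selects the BigQuery project into which
# the generated training/validation/test tables are written.
model = job.run(
    dataset=dataset,
    base_output_dir="gs://my-staging-bucket/output",
    bigquery_destination="bq://my-project",  # assumed URI format
    model_display_name="my-model",
)

Passing the destination at run() time rather than in the constructor mirrors the other per-run data options (base_output_dir, split fractions), which is also how the new unit test below exercises it.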
1 parent 4ea20ae · commit 068e507

File tree

2 files changed: +162 -0 lines
  google/cloud/aiplatform/training_jobs.py
  tests/unit/aiplatform/test_training_jobs.py

google/cloud/aiplatform/training_jobs.py

Lines changed: 18 additions & 0 deletions
@@ -1580,6 +1580,7 @@ def run(
             managed_model=managed_model,
             args=args,
             base_output_dir=base_output_dir,
+            bigquery_destination=bigquery_destination,
             training_fraction_split=training_fraction_split,
             validation_fraction_split=validation_fraction_split,
             test_fraction_split=test_fraction_split,
@@ -1596,6 +1597,7 @@ def _run(
         managed_model: Optional[gca_model.Model] = None,
         args: Optional[List[Union[str, float, int]]] = None,
         base_output_dir: Optional[str] = None,
+        bigquery_destination: Optional[str] = None,
         training_fraction_split: float = 0.8,
         validation_fraction_split: float = 0.1,
         test_fraction_split: float = 0.1,
@@ -1618,6 +1620,21 @@ def _run(
             base_output_dir (str):
                 GCS output directory of job. If not provided a
                 timestamped directory in the staging directory will be used.
+            bigquery_destination (str):
+                Provide this field if `dataset` is a BigQuery dataset.
+                The BigQuery project location where the training data is to
+                be written to. In the given project a new dataset is created
+                with name
+                ``dataset_<dataset-id>_<annotation-type>_<timestamp-of-training-call>``
+                where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
+                training input data will be written into that dataset. In
+                the dataset three tables will be created, ``training``,
+                ``validation`` and ``test``.
+
+                - AIP_DATA_FORMAT = "bigquery".
+                - AIP_TRAINING_DATA_URI = "bigquery_destination.dataset_*.training"
+                - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
+                - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
             training_fraction_split (float):
                 The fraction of the input data that is to be
                 used to train the Model.
@@ -1679,6 +1696,7 @@ def _run(
             predefined_split_column_name=predefined_split_column_name,
             model=managed_model,
             gcs_destination_uri_prefix=base_output_dir,
+            bigquery_destination=bigquery_destination,
         )
 
         return model
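
The docstring added above also lists the AIP_* environment variables the service sets inside the training container when a BigQuery destination is used. As a sketch of how a training script might consume them: the variable names come from the docstring, while the "bq://" prefix handling and the use of the google-cloud-bigquery client are assumptions for illustration.

import os

# Environment variables named in the new docstring; with AIP_DATA_FORMAT == "bigquery"
# each *_DATA_URI points at one of the generated training/validation/test tables.
data_format = os.environ["AIP_DATA_FORMAT"]
training_uri = os.environ["AIP_TRAINING_DATA_URI"]
validation_uri = os.environ["AIP_VALIDATION_DATA_URI"]
test_uri = os.environ["AIP_TEST_DATA_URI"]

if data_format == "bigquery":
    from google.cloud import bigquery  # assumed available in the training container

    client = bigquery.Client()

    def read_split(uri: str):
        # Assumption: strip a possible "bq://" scheme to get project.dataset.table.
        table = uri.replace("bq://", "", 1)
        return client.query(f"SELECT * FROM `{table}`").to_dataframe()

    train_df = read_split(training_uri)
    val_df = read_split(validation_uri)
    test_df = read_split(test_uri)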

tests/unit/aiplatform/test_training_jobs.py

Lines changed: 144 additions & 0 deletions
@@ -581,6 +581,150 @@ def test_run_call_pipeline_service_create(
 
         assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
 
+    @pytest.mark.parametrize("sync", [True, False])
+    def test_run_call_pipeline_service_create_with_bigquery_destination(
+        self,
+        mock_pipeline_service_create,
+        mock_python_package_to_gcs,
+        mock_dataset,
+        mock_model_service_get,
+        sync,
+    ):
+        aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME)
+
+        job = training_jobs.CustomTrainingJob(
+            display_name=_TEST_DISPLAY_NAME,
+            script_path=_TEST_LOCAL_SCRIPT_FILE_NAME,
+            container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
+            model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
+            model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
+            model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
+            model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
+            model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
+            model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
+            model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
+            model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
+            model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
+            model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS,
+            model_description=_TEST_MODEL_DESCRIPTION,
+        )
+
+        model_from_job = job.run(
+            dataset=mock_dataset,
+            base_output_dir=_TEST_BASE_OUTPUT_DIR,
+            bigquery_destination=_TEST_BIGQUERY_DESTINATION,
+            args=_TEST_RUN_ARGS,
+            replica_count=1,
+            machine_type=_TEST_MACHINE_TYPE,
+            accelerator_type=_TEST_ACCELERATOR_TYPE,
+            accelerator_count=_TEST_ACCELERATOR_COUNT,
+            model_display_name=_TEST_MODEL_DISPLAY_NAME,
+            training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT,
+            validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT,
+            test_fraction_split=_TEST_TEST_FRACTION_SPLIT,
+            predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
+            sync=sync,
+        )
+
+        if not sync:
+            model_from_job.wait()
+
+        true_args = _TEST_RUN_ARGS
+
+        true_worker_pool_spec = {
+            "replicaCount": _TEST_REPLICA_COUNT,
+            "machineSpec": {
+                "machineType": _TEST_MACHINE_TYPE,
+                "acceleratorType": _TEST_ACCELERATOR_TYPE,
+                "acceleratorCount": _TEST_ACCELERATOR_COUNT,
+            },
+            "pythonPackageSpec": {
+                "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE,
+                "pythonModule": training_jobs._TrainingScriptPythonPackager.module_name,
+                "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH],
+                "args": true_args,
+            },
+        }
+
+        true_fraction_split = gca_training_pipeline.FractionSplit(
+            training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
+            validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
+            test_fraction=_TEST_TEST_FRACTION_SPLIT,
+        )
+
+        env = [
+            env_var.EnvVar(name=str(key), value=str(value))
+            for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items()
+        ]
+
+        ports = [
+            gca_model.Port(container_port=port)
+            for port in _TEST_MODEL_SERVING_CONTAINER_PORTS
+        ]
+
+        true_container_spec = gca_model.ModelContainerSpec(
+            image_uri=_TEST_SERVING_CONTAINER_IMAGE,
+            predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
+            health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
+            command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
+            args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
+            env=env,
+            ports=ports,
+        )
+
+        true_managed_model = gca_model.Model(
+            display_name=_TEST_MODEL_DISPLAY_NAME,
+            description=_TEST_MODEL_DESCRIPTION,
+            container_spec=true_container_spec,
+            predict_schemata=gca_model.PredictSchemata(
+                instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
+                parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
+                prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
+            ),
+        )
+
+        true_input_data_config = gca_training_pipeline.InputDataConfig(
+            fraction_split=true_fraction_split,
+            predefined_split=gca_training_pipeline.PredefinedSplit(
+                key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME
+            ),
+            dataset_id=mock_dataset.name,
+            bigquery_destination=gca_io.BigQueryDestination(
+                output_uri=_TEST_BIGQUERY_DESTINATION
+            ),
+        )
+
+        true_training_pipeline = gca_training_pipeline.TrainingPipeline(
+            display_name=_TEST_DISPLAY_NAME,
+            training_task_definition=schema.training_job.definition.custom_task,
+            training_task_inputs=json_format.ParseDict(
+                {
+                    "workerPoolSpecs": [true_worker_pool_spec],
+                    "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR},
+                },
+                struct_pb2.Value(),
+            ),
+            model_to_upload=true_managed_model,
+            input_data_config=true_input_data_config,
+        )
+
+        mock_pipeline_service_create.assert_called_once_with(
+            parent=initializer.global_config.common_location_path(),
+            training_pipeline=true_training_pipeline,
+        )
+
+        assert job._gca_resource is mock_pipeline_service_create.return_value
+
+        mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
+
+        assert model_from_job._gca_resource is mock_model_service_get.return_value
+
+        assert job.get_model()._gca_resource is mock_model_service_get.return_value
+
+        assert not job.has_failed
+
+        assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+
     @pytest.mark.parametrize("sync", [True, False])
     def test_run_called_twice_raises(
         self,
