Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1580,6 +1580,7 @@ def run(
managed_model=managed_model,
args=args,
base_output_dir=base_output_dir,
bigquery_destination=bigquery_destination,
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
Expand All @@ -1596,6 +1597,7 @@ def _run(
managed_model: Optional[gca_model.Model] = None,
args: Optional[List[Union[str, float, int]]] = None,
base_output_dir: Optional[str] = None,
bigquery_destination: Optional[str] = None,
training_fraction_split: float = 0.8,
validation_fraction_split: float = 0.1,
test_fraction_split: float = 0.1,
Expand All @@ -1618,6 +1620,21 @@ def _run(
base_output_dir (str):
GCS output directory of job. If not provided a
timestamped directory in the staging directory will be used.
bigquery_destination (str):
Provide this field if `dataset` is a BigQuery dataset.
The BigQuery project location where the training data is to
be written to. In the given project a new dataset is created
with name
``dataset_<dataset-id>_<annotation-type>_<timestamp-of-training-call>``
where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
training input data will be written into that dataset. In
the dataset three tables will be created, ``training``,
``validation`` and ``test``.

- AIP_DATA_FORMAT = "bigquery".
- AIP_TRAINING_DATA_URI = "bigquery_destination.dataset_*.training"
- AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
- AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
training_fraction_split (float):
The fraction of the input data that is to be
used to train the Model.
Expand Down Expand Up @@ -1679,6 +1696,7 @@ def _run(
predefined_split_column_name=predefined_split_column_name,
model=managed_model,
gcs_destination_uri_prefix=base_output_dir,
bigquery_destination=bigquery_destination,
)

return model
Expand Down
144 changes: 144 additions & 0 deletions tests/unit/aiplatform/test_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,150 @@ def test_run_call_pipeline_service_create(

assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED

# Verifies that CustomTrainingJob.run() forwards `bigquery_destination` into the
# TrainingPipeline's InputDataConfig as a BigQueryDestination, alongside the
# existing GCS/base-output-dir plumbing. Parametrized over sync/async execution.
# NOTE(review): indentation here was flattened by the diff rendering this text
# was captured from; structure follows the original test file.
@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_with_bigquery_destination(
self,
mock_pipeline_service_create,
mock_python_package_to_gcs,
mock_dataset,
mock_model_service_get,
sync,
):
# Arrange: initialize the SDK with a test project and staging bucket.
aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME)

# Build the training job under test with a full serving-container spec so the
# uploaded Model can be asserted field-by-field below.
job = training_jobs.CustomTrainingJob(
display_name=_TEST_DISPLAY_NAME,
script_path=_TEST_LOCAL_SCRIPT_FILE_NAME,
container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS,
model_description=_TEST_MODEL_DESCRIPTION,
)

# Act: run the job, passing the new bigquery_destination argument this PR adds.
model_from_job = job.run(
dataset=mock_dataset,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
bigquery_destination=_TEST_BIGQUERY_DESTINATION,
args=_TEST_RUN_ARGS,
replica_count=1,
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
accelerator_count=_TEST_ACCELERATOR_COUNT,
model_display_name=_TEST_MODEL_DISPLAY_NAME,
training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction_split=_TEST_TEST_FRACTION_SPLIT,
predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
sync=sync,
)

# In async mode, block until the returned model future resolves so the
# mock-call assertions below see the completed pipeline.
if not sync:
model_from_job.wait()

true_args = _TEST_RUN_ARGS

# Expected worker pool spec derived from the machine/accelerator arguments above.
true_worker_pool_spec = {
"replicaCount": _TEST_REPLICA_COUNT,
"machineSpec": {
"machineType": _TEST_MACHINE_TYPE,
"acceleratorType": _TEST_ACCELERATOR_TYPE,
"acceleratorCount": _TEST_ACCELERATOR_COUNT,
},
"pythonPackageSpec": {
"executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE,
"pythonModule": training_jobs._TrainingScriptPythonPackager.module_name,
"packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH],
"args": true_args,
},
}

# Expected train/validation/test split mirroring the fractions passed to run().
true_fraction_split = gca_training_pipeline.FractionSplit(
training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction=_TEST_TEST_FRACTION_SPLIT,
)

# Serving-container env vars and ports, converted to their proto forms.
env = [
env_var.EnvVar(name=str(key), value=str(value))
for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items()
]

ports = [
gca_model.Port(container_port=port)
for port in _TEST_MODEL_SERVING_CONTAINER_PORTS
]

true_container_spec = gca_model.ModelContainerSpec(
image_uri=_TEST_SERVING_CONTAINER_IMAGE,
predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
env=env,
ports=ports,
)

# Expected Model proto that the pipeline should upload on success.
true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
description=_TEST_MODEL_DESCRIPTION,
container_spec=true_container_spec,
predict_schemata=gca_model.PredictSchemata(
instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
),
)

# Core assertion target of this test: the InputDataConfig must carry the
# BigQueryDestination built from `bigquery_destination` (the PR's new field).
true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split,
predefined_split=gca_training_pipeline.PredefinedSplit(
key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME
),
dataset_id=mock_dataset.name,
bigquery_destination=gca_io.BigQueryDestination(
output_uri=_TEST_BIGQUERY_DESTINATION
),
)

true_training_pipeline = gca_training_pipeline.TrainingPipeline(
display_name=_TEST_DISPLAY_NAME,
training_task_definition=schema.training_job.definition.custom_task,
training_task_inputs=json_format.ParseDict(
{
"workerPoolSpecs": [true_worker_pool_spec],
"baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR},
},
struct_pb2.Value(),
),
model_to_upload=true_managed_model,
input_data_config=true_input_data_config,
)

# Assert: the pipeline service was called exactly once with the full expected
# TrainingPipeline (including the BigQuery input-data config).
mock_pipeline_service_create.assert_called_once_with(
parent=initializer.global_config.common_location_path(),
training_pipeline=true_training_pipeline,
)

assert job._gca_resource is mock_pipeline_service_create.return_value

# The trained model is fetched by name after the pipeline succeeds.
mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)

assert model_from_job._gca_resource is mock_model_service_get.return_value

assert job.get_model()._gca_resource is mock_model_service_get.return_value

assert not job.has_failed

assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED

@pytest.mark.parametrize("sync", [True, False])
def test_run_called_twice_raises(
self,
Expand Down