diff --git a/automl/cloud-client/README.md b/automl/cloud-client/README.md index f9482342b51..3702992e591 100644 --- a/automl/cloud-client/README.md +++ b/automl/cloud-client/README.md @@ -63,6 +63,8 @@ small section of code to print out the `metadata` field. * [Deploy Model](src/main/java/com/example/automl/DeployModel.java) - Not supported by Translation * [Uneploy Model](src/main/java/com/example/automl/UndeployModel.java) - Not supported by Translation +### Batch Prediction +* [Batch Predict](src/main/java/com/example/automl/BatchPredict.java) - Supported by: Natural Language Entity Extraction, Vision Classification, and Vision Object Detection. ### Operation Management * [List Operation Statuses](src/main/java/com/example/automl/ListOperationStatus.java) diff --git a/automl/cloud-client/src/main/java/com/example/automl/BatchPredict.java b/automl/cloud-client/src/main/java/com/example/automl/BatchPredict.java new file mode 100644 index 00000000000..3eb4d9beaa8 --- /dev/null +++ b/automl/cloud-client/src/main/java/com/example/automl/BatchPredict.java @@ -0,0 +1,78 @@ +/* + * Copyright 2019 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.automl; + +// [START automl_batch_predict] +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.automl.v1.BatchPredictInputConfig; +import com.google.cloud.automl.v1.BatchPredictOutputConfig; +import com.google.cloud.automl.v1.BatchPredictRequest; +import com.google.cloud.automl.v1.BatchPredictResult; +import com.google.cloud.automl.v1.GcsDestination; +import com.google.cloud.automl.v1.GcsSource; +import com.google.cloud.automl.v1.ModelName; +import com.google.cloud.automl.v1.OperationMetadata; +import com.google.cloud.automl.v1.PredictionServiceClient; + +import java.io.IOException; +import java.util.concurrent.ExecutionException; + +class BatchPredict { + + static void batchPredict() throws IOException, ExecutionException, InterruptedException { + // TODO(developer): Replace these variables before running the sample. + String projectId = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String inputUri = "gs://YOUR_BUCKET_ID/path_to_your_input_csv_or_jsonl"; + String outputUri = "gs://YOUR_BUCKET_ID/path_to_save_results/"; + batchPredict(projectId, modelId, inputUri, outputUri); + } + + static void batchPredict(String projectId, String modelId, String inputUri, String outputUri) + throws IOException, ExecutionException, InterruptedException { + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PredictionServiceClient client = PredictionServiceClient.create()) { + // Get the full path of the model. + ModelName name = ModelName.of(projectId, "us-central1", modelId); + GcsSource gcsSource = GcsSource.newBuilder().addInputUris(inputUri).build(); + BatchPredictInputConfig inputConfig = + BatchPredictInputConfig.newBuilder().setGcsSource(gcsSource).build(); + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(outputUri).build(); + BatchPredictOutputConfig outputConfig = + BatchPredictOutputConfig.newBuilder().setGcsDestination(gcsDestination).build(); + BatchPredictRequest request = + BatchPredictRequest.newBuilder() + .setName(name.toString()) + .setInputConfig(inputConfig) + .setOutputConfig(outputConfig) + // [0.0-1.0] Only produce results higher than this value + .putParams("score_threshold", "0.8") + .build(); + + OperationFuture future = + client.batchPredictAsync(request); + + System.out.println("Waiting for operation to complete..."); + BatchPredictResult response = future.get(); + System.out.println("Batch Prediction results saved to specified Cloud Storage bucket."); + } + } +} +// [END automl_batch_predict] diff --git a/automl/cloud-client/src/test/java/com/example/automl/LanguageEntityExtractionPredictIT.java b/automl/cloud-client/src/test/java/com/example/automl/LanguageEntityExtractionPredictIT.java index 8b5400d1a2c..7c9af56b6a6 100644 --- a/automl/cloud-client/src/test/java/com/example/automl/LanguageEntityExtractionPredictIT.java +++ b/automl/cloud-client/src/test/java/com/example/automl/LanguageEntityExtractionPredictIT.java @@ -86,7 +86,7 @@ public void testBatchPredict() throws IOException, ExecutionException, Interrupt String inputUri = String.format("gs://%s/entity_extraction/input.jsonl", BUCKET_ID); String outputUri = String.format("gs://%s/TEST_BATCH_PREDICT/", BUCKET_ID); // Act - LanguageBatchPredict.batchPredict(PROJECT_ID, modelId, inputUri, outputUri); + BatchPredict.batchPredict(PROJECT_ID, modelId, inputUri, outputUri); // Assert String got = bout.toString(); diff --git a/automl/cloud-client/src/test/java/com/example/automl/VisionClassificationPredictIT.java b/automl/cloud-client/src/test/java/com/example/automl/VisionClassificationPredictIT.java index 532fe2ae258..ff60b569dd1 100644 --- a/automl/cloud-client/src/test/java/com/example/automl/VisionClassificationPredictIT.java +++ b/automl/cloud-client/src/test/java/com/example/automl/VisionClassificationPredictIT.java @@ -86,7 +86,7 @@ public void testBatchPredict() throws IOException, ExecutionException, Interrupt String inputUri = String.format("gs://%s/batch_predict_test.csv", BUCKET_ID); String outputUri = String.format("gs://%s/TEST_BATCH_PREDICT/", BUCKET_ID); // Act - VisionBatchPredict.batchPredict(PROJECT_ID, modelId, inputUri, outputUri); + BatchPredict.batchPredict(PROJECT_ID, modelId, inputUri, outputUri); // Assert String got = bout.toString(); diff --git a/automl/cloud-client/src/test/java/com/example/automl/VisionObjectDetectionPredictIT.java b/automl/cloud-client/src/test/java/com/example/automl/VisionObjectDetectionPredictIT.java index 394947580ae..a9f5029033d 100644 --- a/automl/cloud-client/src/test/java/com/example/automl/VisionObjectDetectionPredictIT.java +++ b/automl/cloud-client/src/test/java/com/example/automl/VisionObjectDetectionPredictIT.java @@ -88,7 +88,7 @@ public void testBatchPredict() throws IOException, ExecutionException, Interrupt String.format("gs://%s/vision_object_detection_batch_predict_test.csv", BUCKET_ID); String outputUri = String.format("gs://%s/TEST_BATCH_PREDICT/", BUCKET_ID); // Act - VisionBatchPredict.batchPredict(PROJECT_ID, modelId, inputUri, outputUri); + BatchPredict.batchPredict(PROJECT_ID, modelId, inputUri, outputUri); // Assert String got = bout.toString();