diff --git a/.kokoro/build.sh b/.kokoro/build.sh
index 048d80ac8..e594535ac 100755
--- a/.kokoro/build.sh
+++ b/.kokoro/build.sh
@@ -77,6 +77,11 @@ samples)
source "${KOKORO_GFILE_DIR}/secret_manager/ucaip_samples_secrets"
fi
+ if [ -f "${KOKORO_GFILE_DIR}/secret_manager/java-aiplatform-samples-secrets" ]
+ then
+ source "${KOKORO_GFILE_DIR}/secret_manager/java-aiplatform-samples-secrets"
+ fi
+
pushd samples
mvn -B \
-Penable-samples \
diff --git a/.kokoro/presubmit/samples.cfg b/.kokoro/presubmit/samples.cfg
index e1c63d5c8..50f8f3fa0 100644
--- a/.kokoro/presubmit/samples.cfg
+++ b/.kokoro/presubmit/samples.cfg
@@ -29,5 +29,5 @@ env_vars: {
env_vars: {
key: "SECRET_MANAGER_KEYS"
- value: "java-docs-samples-service-account,ucaip_samples_secrets"
-}
\ No newline at end of file
+ value: "java-docs-samples-service-account,ucaip_samples_secrets,java-aiplatform-samples-secrets"
+}
diff --git a/samples/snippets/pom.xml b/samples/snippets/pom.xml
index 2b6f82157..1fc37373f 100644
--- a/samples/snippets/pom.xml
+++ b/samples/snippets/pom.xml
@@ -40,6 +40,11 @@
protobuf-java-util
4.0.0-rc-2
+
+ com.google.code.gson
+ gson
+ 2.8.6
+
junit
junit
diff --git a/samples/snippets/src/main/java/aiplatform/CreateBatchPredictionJobBigquerySample.java b/samples/snippets/src/main/java/aiplatform/CreateBatchPredictionJobBigquerySample.java
new file mode 100644
index 000000000..5ccad051a
--- /dev/null
+++ b/samples/snippets/src/main/java/aiplatform/CreateBatchPredictionJobBigquerySample.java
@@ -0,0 +1,109 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_batch_prediction_job_bigquery_sample]
+import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob;
+import com.google.cloud.aiplatform.v1beta1.BigQueryDestination;
+import com.google.cloud.aiplatform.v1beta1.BigQuerySource;
+import com.google.cloud.aiplatform.v1beta1.JobServiceClient;
+import com.google.cloud.aiplatform.v1beta1.JobServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.ModelName;
+import com.google.gson.JsonObject;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+
+public class CreateBatchPredictionJobBigquerySample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "PROJECT";
+ String displayName = "DISPLAY_NAME";
+ String modelName = "MODEL_NAME";
+ String instancesFormat = "INSTANCES_FORMAT";
+ String bigquerySourceInputUri = "BIGQUERY_SOURCE_INPUT_URI";
+ String predictionsFormat = "PREDICTIONS_FORMAT";
+ String bigqueryDestinationOutputUri = "BIGQUERY_DESTINATION_OUTPUT_URI";
+ createBatchPredictionJobBigquerySample(
+ project,
+ displayName,
+ modelName,
+ instancesFormat,
+ bigquerySourceInputUri,
+ predictionsFormat,
+ bigqueryDestinationOutputUri);
+ }
+
+ static void createBatchPredictionJobBigquerySample(
+ String project,
+ String displayName,
+ String model,
+ String instancesFormat,
+ String bigquerySourceInputUri,
+ String predictionsFormat,
+ String bigqueryDestinationOutputUri)
+ throws IOException {
+ JobServiceSettings settings =
+ JobServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ String location = "us-central1";
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (JobServiceClient client = JobServiceClient.create(settings)) {
+ JsonObject jsonModelParameters = new JsonObject();
+ Value.Builder modelParametersBuilder = Value.newBuilder();
+ JsonFormat.parser().merge(jsonModelParameters.toString(), modelParametersBuilder);
+ Value modelParameters = modelParametersBuilder.build();
+ BigQuerySource bigquerySource =
+ BigQuerySource.newBuilder().setInputUri(bigquerySourceInputUri).build();
+ BatchPredictionJob.InputConfig inputConfig =
+ BatchPredictionJob.InputConfig.newBuilder()
+ .setInstancesFormat(instancesFormat)
+ .setBigquerySource(bigquerySource)
+ .build();
+ BigQueryDestination bigqueryDestination =
+ BigQueryDestination.newBuilder().setOutputUri(bigqueryDestinationOutputUri).build();
+ BatchPredictionJob.OutputConfig outputConfig =
+ BatchPredictionJob.OutputConfig.newBuilder()
+ .setPredictionsFormat(predictionsFormat)
+ .setBigqueryDestination(bigqueryDestination)
+ .build();
+ String modelName = ModelName.of(project, location, model).toString();
+ BatchPredictionJob batchPredictionJob =
+ BatchPredictionJob.newBuilder()
+ .setDisplayName(displayName)
+ .setModel(modelName)
+ .setModelParameters(modelParameters)
+ .setInputConfig(inputConfig)
+ .setOutputConfig(outputConfig)
+ // optional
+ .setGenerateExplanation(true)
+ .build();
+ LocationName parent = LocationName.of(project, location);
+ BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob);
+ System.out.format("response: %s\n", response);
+ System.out.format("\tName: %s\n", response.getName());
+ }
+ }
+}
+
+// [END aiplatform_create_batch_prediction_job_bigquery_sample]
diff --git a/samples/snippets/src/main/java/aiplatform/CreateBatchPredictionJobSample.java b/samples/snippets/src/main/java/aiplatform/CreateBatchPredictionJobSample.java
new file mode 100644
index 000000000..cdac97ba4
--- /dev/null
+++ b/samples/snippets/src/main/java/aiplatform/CreateBatchPredictionJobSample.java
@@ -0,0 +1,125 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_batch_prediction_job_sample]
+import com.google.cloud.aiplatform.v1beta1.AcceleratorType;
+import com.google.cloud.aiplatform.v1beta1.BatchDedicatedResources;
+import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob;
+import com.google.cloud.aiplatform.v1beta1.GcsDestination;
+import com.google.cloud.aiplatform.v1beta1.GcsSource;
+import com.google.cloud.aiplatform.v1beta1.JobServiceClient;
+import com.google.cloud.aiplatform.v1beta1.JobServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.MachineSpec;
+import com.google.cloud.aiplatform.v1beta1.ModelName;
+import com.google.gson.JsonObject;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+
+public class CreateBatchPredictionJobSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "PROJECT";
+ String displayName = "DISPLAY_NAME";
+ String modelName = "MODEL_NAME";
+ String instancesFormat = "INSTANCES_FORMAT";
+ String gcsSourceUri = "GCS_SOURCE_URI";
+ String predictionsFormat = "PREDICTIONS_FORMAT";
+ String gcsDestinationOutputUriPrefix = "GCS_DESTINATION_OUTPUT_URI_PREFIX";
+ createBatchPredictionJobSample(
+ project,
+ displayName,
+ modelName,
+ instancesFormat,
+ gcsSourceUri,
+ predictionsFormat,
+ gcsDestinationOutputUriPrefix);
+ }
+
+ static void createBatchPredictionJobSample(
+ String project,
+ String displayName,
+ String model,
+ String instancesFormat,
+ String gcsSourceUri,
+ String predictionsFormat,
+ String gcsDestinationOutputUriPrefix)
+ throws IOException {
+ JobServiceSettings settings =
+ JobServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ String location = "us-central1";
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (JobServiceClient client = JobServiceClient.create(settings)) {
+
+ // Passing in an empty Value object for model parameters
+ JsonObject jsonModelParameters = new JsonObject();
+ Value.Builder modelParametersBuilder = Value.newBuilder();
+ JsonFormat.parser().merge(jsonModelParameters.toString(), modelParametersBuilder);
+ Value modelParameters = modelParametersBuilder.build();
+
+ GcsSource gcsSource = GcsSource.newBuilder().addUris(gcsSourceUri).build();
+ BatchPredictionJob.InputConfig inputConfig =
+ BatchPredictionJob.InputConfig.newBuilder()
+ .setInstancesFormat(instancesFormat)
+ .setGcsSource(gcsSource)
+ .build();
+ GcsDestination gcsDestination =
+ GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build();
+ BatchPredictionJob.OutputConfig outputConfig =
+ BatchPredictionJob.OutputConfig.newBuilder()
+ .setPredictionsFormat(predictionsFormat)
+ .setGcsDestination(gcsDestination)
+ .build();
+ MachineSpec machineSpec =
+ MachineSpec.newBuilder()
+ .setMachineType("n1-standard-2")
+ .setAcceleratorType(AcceleratorType.NVIDIA_TESLA_K80)
+ .setAcceleratorCount(1)
+ .build();
+ BatchDedicatedResources dedicatedResources =
+ BatchDedicatedResources.newBuilder()
+ .setMachineSpec(machineSpec)
+ .setStartingReplicaCount(1)
+ .setMaxReplicaCount(1)
+ .build();
+ String modelName = ModelName.of(project, location, model).toString();
+ BatchPredictionJob batchPredictionJob =
+ BatchPredictionJob.newBuilder()
+ .setDisplayName(displayName)
+ .setModel(modelName)
+ .setModelParameters(modelParameters)
+ .setInputConfig(inputConfig)
+ .setOutputConfig(outputConfig)
+ .setDedicatedResources(dedicatedResources)
+ .build();
+ LocationName parent = LocationName.of(project, location);
+ BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob);
+ System.out.format("response: %s\n", response);
+ System.out.format("\tName: %s\n", response.getName());
+ }
+ }
+}
+
+// [END aiplatform_create_batch_prediction_job_sample]
diff --git a/samples/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSample.java b/samples/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSample.java
new file mode 100644
index 000000000..b255b625c
--- /dev/null
+++ b/samples/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSample.java
@@ -0,0 +1,99 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_batch_prediction_job_video_action_recognition_sample]
+import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob;
+import com.google.cloud.aiplatform.v1beta1.GcsDestination;
+import com.google.cloud.aiplatform.v1beta1.GcsSource;
+import com.google.cloud.aiplatform.v1beta1.JobServiceClient;
+import com.google.cloud.aiplatform.v1beta1.JobServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.ModelName;
+import com.google.gson.JsonObject;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+
+public class CreateBatchPredictionJobVideoActionRecognitionSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "PROJECT";
+ String displayName = "DISPLAY_NAME";
+ String model = "MODEL";
+ String gcsSourceUri = "GCS_SOURCE_URI";
+ String gcsDestinationOutputUriPrefix = "GCS_DESTINATION_OUTPUT_URI_PREFIX";
+ createBatchPredictionJobVideoActionRecognitionSample(
+ project, displayName, model, gcsSourceUri, gcsDestinationOutputUriPrefix);
+ }
+
+ static void createBatchPredictionJobVideoActionRecognitionSample(
+ String project,
+ String displayName,
+ String model,
+ String gcsSourceUri,
+ String gcsDestinationOutputUriPrefix)
+ throws IOException {
+ JobServiceSettings settings =
+ JobServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ String location = "us-central1";
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (JobServiceClient client = JobServiceClient.create(settings)) {
+ JsonObject jsonModelParameters = new JsonObject();
+ jsonModelParameters.addProperty("confidenceThreshold", 0.5);
+ Value.Builder modelParametersBuilder = Value.newBuilder();
+ JsonFormat.parser().merge(jsonModelParameters.toString(), modelParametersBuilder);
+ Value modelParameters = modelParametersBuilder.build();
+ GcsSource gcsSource = GcsSource.newBuilder().addUris(gcsSourceUri).build();
+ BatchPredictionJob.InputConfig inputConfig =
+ BatchPredictionJob.InputConfig.newBuilder()
+ .setInstancesFormat("jsonl")
+ .setGcsSource(gcsSource)
+ .build();
+ GcsDestination gcsDestination =
+ GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build();
+ BatchPredictionJob.OutputConfig outputConfig =
+ BatchPredictionJob.OutputConfig.newBuilder()
+ .setPredictionsFormat("jsonl")
+ .setGcsDestination(gcsDestination)
+ .build();
+
+ String modelName = ModelName.of(project, location, model).toString();
+
+ BatchPredictionJob batchPredictionJob =
+ BatchPredictionJob.newBuilder()
+ .setDisplayName(displayName)
+ .setModel(modelName)
+ .setModelParameters(modelParameters)
+ .setInputConfig(inputConfig)
+ .setOutputConfig(outputConfig)
+ .build();
+ LocationName parent = LocationName.of(project, location);
+ BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob);
+ System.out.format("response: %s\n", response);
+ System.out.format("\tName: %s\n", response.getName());
+ }
+ }
+}
+
+// [END aiplatform_create_batch_prediction_job_video_action_recognition_sample]
diff --git a/samples/snippets/src/main/java/aiplatform/CreateDataLabelingJobActiveLearningSample.java b/samples/snippets/src/main/java/aiplatform/CreateDataLabelingJobActiveLearningSample.java
new file mode 100644
index 000000000..d9f069e40
--- /dev/null
+++ b/samples/snippets/src/main/java/aiplatform/CreateDataLabelingJobActiveLearningSample.java
@@ -0,0 +1,97 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_data_labeling_job_active_learning_sample]
+import com.google.cloud.aiplatform.v1beta1.ActiveLearningConfig;
+import com.google.cloud.aiplatform.v1beta1.DataLabelingJob;
+import com.google.cloud.aiplatform.v1beta1.DatasetName;
+import com.google.cloud.aiplatform.v1beta1.JobServiceClient;
+import com.google.cloud.aiplatform.v1beta1.JobServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.gson.JsonArray;
+import com.google.gson.JsonObject;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+
+public class CreateDataLabelingJobActiveLearningSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "PROJECT";
+ String displayName = "DISPLAY_NAME";
+ String dataset = "DATASET";
+ String instructionUri = "INSTRUCTION_URI";
+ String inputsSchemaUri = "INPUTS_SCHEMA_URI";
+ String annotationSpec = "ANNOTATION_SPEC";
+ createDataLabelingJobActiveLearningSample(
+ project, displayName, dataset, instructionUri, inputsSchemaUri, annotationSpec);
+ }
+
+ static void createDataLabelingJobActiveLearningSample(
+ String project,
+ String displayName,
+ String dataset,
+ String instructionUri,
+ String inputsSchemaUri,
+ String annotationSpec)
+ throws IOException {
+ JobServiceSettings settings =
+ JobServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ String location = "us-central1";
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (JobServiceClient client = JobServiceClient.create(settings)) {
+ JsonArray jsonAnnotationSpecs = new JsonArray();
+ jsonAnnotationSpecs.add(annotationSpec);
+ JsonObject jsonInputs = new JsonObject();
+ jsonInputs.add("annotation_specs", jsonAnnotationSpecs);
+ Value.Builder inputsBuilder = Value.newBuilder();
+ JsonFormat.parser().merge(jsonInputs.toString(), inputsBuilder);
+ Value inputs = inputsBuilder.build();
+ ActiveLearningConfig activeLearningConfig =
+ ActiveLearningConfig.newBuilder().setMaxDataItemCount(1).build();
+
+ String datasetName = DatasetName.of(project, location, dataset).toString();
+
+ DataLabelingJob dataLabelingJob =
+ DataLabelingJob.newBuilder()
+ .setDisplayName(displayName)
+ .addDatasets(datasetName)
+ .setLabelerCount(1)
+ .setInstructionUri(instructionUri)
+ .setInputsSchemaUri(inputsSchemaUri)
+ .setInputs(inputs)
+ .putAnnotationLabels(
+ "aiplatform.googleapis.com/annotation_set_name",
+ "data_labeling_job_active_learning")
+ .setActiveLearningConfig(activeLearningConfig)
+ .build();
+ LocationName parent = LocationName.of(project, location);
+ DataLabelingJob response = client.createDataLabelingJob(parent, dataLabelingJob);
+ System.out.format("response: %s\n", response);
+ System.out.format("Name: %s\n", response.getName());
+ }
+ }
+}
+
+// [END aiplatform_create_data_labeling_job_active_learning_sample]
diff --git a/samples/snippets/src/main/java/aiplatform/CreateDataLabelingJobImageSample.java b/samples/snippets/src/main/java/aiplatform/CreateDataLabelingJobImageSample.java
index ce81ae6b4..5ea70a42f 100644
--- a/samples/snippets/src/main/java/aiplatform/CreateDataLabelingJobImageSample.java
+++ b/samples/snippets/src/main/java/aiplatform/CreateDataLabelingJobImageSample.java
@@ -72,7 +72,7 @@ static void createDataLabelingJobImage(
.setInstructionUri(instructionUri)
.setInputsSchemaUri(
"gs://google-cloud-aiplatform/schema/datalabelingjob/inputs/"
- + "image_classification.yaml")
+ + "image_classification.yaml")
.addDatasets(datasetName.toString())
.setInputs(annotationSpecValue)
.putAnnotationLabels(
diff --git a/samples/snippets/src/main/java/aiplatform/CreateDataLabelingJobSpecialistPoolSample.java b/samples/snippets/src/main/java/aiplatform/CreateDataLabelingJobSpecialistPoolSample.java
new file mode 100644
index 000000000..04a3c4216
--- /dev/null
+++ b/samples/snippets/src/main/java/aiplatform/CreateDataLabelingJobSpecialistPoolSample.java
@@ -0,0 +1,104 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_data_labeling_job_specialist_pool_sample]
+import com.google.cloud.aiplatform.v1beta1.DataLabelingJob;
+import com.google.cloud.aiplatform.v1beta1.DatasetName;
+import com.google.cloud.aiplatform.v1beta1.JobServiceClient;
+import com.google.cloud.aiplatform.v1beta1.JobServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.SpecialistPoolName;
+import com.google.gson.JsonArray;
+import com.google.gson.JsonObject;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+
+public class CreateDataLabelingJobSpecialistPoolSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "PROJECT";
+ String displayName = "DISPLAY_NAME";
+ String dataset = "DATASET";
+ String specialistPool = "SPECIALIST_POOL";
+ String instructionUri = "INSTRUCTION_URI";
+ String inputsSchemaUri = "INPUTS_SCHEMA_URI";
+ String annotationSpec = "ANNOTATION_SPEC";
+ createDataLabelingJobSpecialistPoolSample(
+ project,
+ displayName,
+ dataset,
+ specialistPool,
+ instructionUri,
+ inputsSchemaUri,
+ annotationSpec);
+ }
+
+ static void createDataLabelingJobSpecialistPoolSample(
+ String project,
+ String displayName,
+ String dataset,
+ String specialistPool,
+ String instructionUri,
+ String inputsSchemaUri,
+ String annotationSpec)
+ throws IOException {
+ JobServiceSettings settings =
+ JobServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ String location = "us-central1";
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (JobServiceClient client = JobServiceClient.create(settings)) {
+ JsonArray jsonAnnotationSpecs = new JsonArray();
+ jsonAnnotationSpecs.add(annotationSpec);
+ JsonObject jsonInputs = new JsonObject();
+ jsonInputs.add("annotation_specs", jsonAnnotationSpecs);
+ Value.Builder inputsBuilder = Value.newBuilder();
+ JsonFormat.parser().merge(jsonInputs.toString(), inputsBuilder);
+ Value inputs = inputsBuilder.build();
+
+ String datasetName = DatasetName.of(project, location, dataset).toString();
+ String specialistPoolName = SpecialistPoolName.of(project, location, specialistPool)
+ .toString();
+
+ DataLabelingJob dataLabelingJob =
+ DataLabelingJob.newBuilder()
+ .setDisplayName(displayName)
+ .addDatasets(datasetName)
+ .setLabelerCount(1)
+ .setInstructionUri(instructionUri)
+ .setInputsSchemaUri(inputsSchemaUri)
+ .setInputs(inputs)
+ .putAnnotationLabels(
+ "aiplatform.googleapis.com/annotation_set_name",
+ "data_labeling_job_specialist_pool")
+ .addSpecialistPools(specialistPoolName)
+ .build();
+ LocationName parent = LocationName.of(project, location);
+ DataLabelingJob response = client.createDataLabelingJob(parent, dataLabelingJob);
+ System.out.format("response: %s\n", response);
+ }
+ }
+}
+
+// [END aiplatform_create_data_labeling_job_specialist_pool_sample]
diff --git a/samples/snippets/src/main/java/aiplatform/CreateDataLabelingJobVideoSample.java b/samples/snippets/src/main/java/aiplatform/CreateDataLabelingJobVideoSample.java
index 5722c335e..ae0e451ba 100644
--- a/samples/snippets/src/main/java/aiplatform/CreateDataLabelingJobVideoSample.java
+++ b/samples/snippets/src/main/java/aiplatform/CreateDataLabelingJobVideoSample.java
@@ -72,7 +72,7 @@ static void createDataLabelingJobVideo(
.setInstructionUri(instructionUri)
.setInputsSchemaUri(
"gs://google-cloud-aiplatform/schema/datalabelingjob/inputs/"
- + "video_classification.yaml")
+ + "video_classification.yaml")
.addDatasets(datasetName.toString())
.setInputs(annotationSpecValue)
.putAnnotationLabels(
diff --git a/samples/snippets/src/main/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSample.java b/samples/snippets/src/main/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSample.java
new file mode 100644
index 000000000..9d1937f3d
--- /dev/null
+++ b/samples/snippets/src/main/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSample.java
@@ -0,0 +1,174 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_hyperparameter_tuning_job_python_package_sample]
+import com.google.cloud.aiplatform.v1beta1.AcceleratorType;
+import com.google.cloud.aiplatform.v1beta1.CustomJobSpec;
+import com.google.cloud.aiplatform.v1beta1.HyperparameterTuningJob;
+import com.google.cloud.aiplatform.v1beta1.JobServiceClient;
+import com.google.cloud.aiplatform.v1beta1.JobServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.MachineSpec;
+import com.google.cloud.aiplatform.v1beta1.PythonPackageSpec;
+import com.google.cloud.aiplatform.v1beta1.StudySpec;
+import com.google.cloud.aiplatform.v1beta1.StudySpec.MetricSpec;
+import com.google.cloud.aiplatform.v1beta1.StudySpec.MetricSpec.GoalType;
+import com.google.cloud.aiplatform.v1beta1.StudySpec.ParameterSpec;
+import com.google.cloud.aiplatform.v1beta1.StudySpec.ParameterSpec.ConditionalParameterSpec;
+import com.google.cloud.aiplatform.v1beta1.StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition;
+import com.google.cloud.aiplatform.v1beta1.StudySpec.ParameterSpec.DiscreteValueSpec;
+import com.google.cloud.aiplatform.v1beta1.StudySpec.ParameterSpec.DoubleValueSpec;
+import com.google.cloud.aiplatform.v1beta1.StudySpec.ParameterSpec.ScaleType;
+import com.google.cloud.aiplatform.v1beta1.WorkerPoolSpec;
+import java.io.IOException;
+import java.util.Arrays;
+
+public class CreateHyperparameterTuningJobPythonPackageSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "PROJECT";
+ String displayName = "DISPLAY_NAME";
+ String executorImageUri = "EXECUTOR_IMAGE_URI";
+ String packageUri = "PACKAGE_URI";
+ String pythonModule = "PYTHON_MODULE";
+ createHyperparameterTuningJobPythonPackageSample(
+ project, displayName, executorImageUri, packageUri, pythonModule);
+ }
+
+ static void createHyperparameterTuningJobPythonPackageSample(
+ String project,
+ String displayName,
+ String executorImageUri,
+ String packageUri,
+ String pythonModule)
+ throws IOException {
+ JobServiceSettings settings =
+ JobServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ String location = "us-central1";
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (JobServiceClient client = JobServiceClient.create(settings)) {
+ // study spec
+ MetricSpec metric =
+ MetricSpec.newBuilder().setMetricId("val_rmse").setGoal(GoalType.MINIMIZE).build();
+
+ // decay
+ DoubleValueSpec doubleValueSpec =
+ DoubleValueSpec.newBuilder().setMinValue(1e-07).setMaxValue(1).build();
+ ParameterSpec parameterDecaySpec =
+ ParameterSpec.newBuilder()
+ .setParameterId("decay")
+ .setDoubleValueSpec(doubleValueSpec)
+ .setScaleType(ScaleType.UNIT_LINEAR_SCALE)
+ .build();
+ Double[] decayValues = {32.0, 64.0};
+ DiscreteValueCondition discreteValueDecay =
+ DiscreteValueCondition.newBuilder().addAllValues(Arrays.asList(decayValues)).build();
+ ConditionalParameterSpec conditionalParameterDecay =
+ ConditionalParameterSpec.newBuilder()
+ .setParameterSpec(parameterDecaySpec)
+ .setParentDiscreteValues(discreteValueDecay)
+ .build();
+
+ // learning rate
+ ParameterSpec parameterLearningSpec =
+ ParameterSpec.newBuilder()
+ .setParameterId("learning_rate")
+ .setDoubleValueSpec(doubleValueSpec) // Use the same min/max as for decay
+ .setScaleType(ScaleType.UNIT_LINEAR_SCALE)
+ .build();
+
+ Double[] learningRateValues = {4.0, 8.0, 16.0};
+ DiscreteValueCondition discreteValueLearning =
+ DiscreteValueCondition.newBuilder()
+ .addAllValues(Arrays.asList(learningRateValues))
+ .build();
+ ConditionalParameterSpec conditionalParameterLearning =
+ ConditionalParameterSpec.newBuilder()
+ .setParameterSpec(parameterLearningSpec)
+ .setParentDiscreteValues(discreteValueLearning)
+ .build();
+
+ // batch size
+ Double[] batchSizeValues = {4.0, 8.0, 16.0, 32.0, 64.0, 128.0};
+
+ DiscreteValueSpec discreteValueSpec =
+ DiscreteValueSpec.newBuilder().addAllValues(Arrays.asList(batchSizeValues)).build();
+ ParameterSpec parameter =
+ ParameterSpec.newBuilder()
+ .setParameterId("batch_size")
+ .setDiscreteValueSpec(discreteValueSpec)
+ .setScaleType(ScaleType.UNIT_LINEAR_SCALE)
+ .addConditionalParameterSpecs(conditionalParameterDecay)
+ .addConditionalParameterSpecs(conditionalParameterLearning)
+ .build();
+
+ // trial_job_spec
+ MachineSpec machineSpec =
+ MachineSpec.newBuilder()
+ .setMachineType("n1-standard-4")
+ .setAcceleratorType(AcceleratorType.NVIDIA_TESLA_K80)
+ .setAcceleratorCount(1)
+ .build();
+
+ PythonPackageSpec pythonPackageSpec =
+ PythonPackageSpec.newBuilder()
+ .setExecutorImageUri(executorImageUri)
+ .addPackageUris(packageUri)
+ .setPythonModule(pythonModule)
+ .build();
+
+ WorkerPoolSpec workerPoolSpec =
+ WorkerPoolSpec.newBuilder()
+ .setMachineSpec(machineSpec)
+ .setReplicaCount(1)
+ .setPythonPackageSpec(pythonPackageSpec)
+ .build();
+
+ StudySpec studySpec =
+ StudySpec.newBuilder()
+ .addMetrics(metric)
+ .addParameters(parameter)
+ .setAlgorithm(StudySpec.Algorithm.RANDOM_SEARCH)
+ .build();
+ CustomJobSpec trialJobSpec =
+ CustomJobSpec.newBuilder().addWorkerPoolSpecs(workerPoolSpec).build();
+ // hyperparameter_tuning_job
+ HyperparameterTuningJob hyperparameterTuningJob =
+ HyperparameterTuningJob.newBuilder()
+ .setDisplayName(displayName)
+ .setMaxTrialCount(4)
+ .setParallelTrialCount(2)
+ .setStudySpec(studySpec)
+ .setTrialJobSpec(trialJobSpec)
+ .build();
+ LocationName parent = LocationName.of(project, location);
+ HyperparameterTuningJob response =
+ client.createHyperparameterTuningJob(parent, hyperparameterTuningJob);
+ System.out.format("response: %s\n", response);
+ System.out.format("Name: %s\n", response.getName());
+ }
+ }
+}
+
+// [END aiplatform_create_hyperparameter_tuning_job_python_package_sample]
diff --git a/samples/snippets/src/main/java/aiplatform/CreateHyperparameterTuningJobSample.java b/samples/snippets/src/main/java/aiplatform/CreateHyperparameterTuningJobSample.java
new file mode 100644
index 000000000..37e66d512
--- /dev/null
+++ b/samples/snippets/src/main/java/aiplatform/CreateHyperparameterTuningJobSample.java
@@ -0,0 +1,106 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_hyperparameter_tuning_job_sample]
+import com.google.cloud.aiplatform.v1beta1.AcceleratorType;
+import com.google.cloud.aiplatform.v1beta1.ContainerSpec;
+import com.google.cloud.aiplatform.v1beta1.CustomJobSpec;
+import com.google.cloud.aiplatform.v1beta1.HyperparameterTuningJob;
+import com.google.cloud.aiplatform.v1beta1.JobServiceClient;
+import com.google.cloud.aiplatform.v1beta1.JobServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.MachineSpec;
+import com.google.cloud.aiplatform.v1beta1.StudySpec;
+import com.google.cloud.aiplatform.v1beta1.WorkerPoolSpec;
+import java.io.IOException;
+
+public class CreateHyperparameterTuningJobSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "PROJECT";
+ String displayName = "DISPLAY_NAME";
+ String containerImageUri = "CONTAINER_IMAGE_URI";
+ createHyperparameterTuningJobSample(project, displayName, containerImageUri);
+ }
+
+ static void createHyperparameterTuningJobSample(
+ String project, String displayName, String containerImageUri) throws IOException {
+ JobServiceSettings settings =
+ JobServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ String location = "us-central1";
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (JobServiceClient client = JobServiceClient.create(settings)) {
+ StudySpec.MetricSpec metric0 =
+ StudySpec.MetricSpec.newBuilder()
+ .setMetricId("accuracy")
+ .setGoal(StudySpec.MetricSpec.GoalType.MAXIMIZE)
+ .build();
+ StudySpec.ParameterSpec.DoubleValueSpec doubleValueSpec =
+ StudySpec.ParameterSpec.DoubleValueSpec.newBuilder()
+ .setMinValue(0.001)
+ .setMaxValue(0.1)
+ .build();
+ StudySpec.ParameterSpec parameter0 =
+ StudySpec.ParameterSpec.newBuilder()
+ // Learning rate.
+ .setParameterId("lr")
+ .setDoubleValueSpec(doubleValueSpec)
+ .build();
+ StudySpec studySpec =
+ StudySpec.newBuilder().addMetrics(metric0).addParameters(parameter0).build();
+ MachineSpec machineSpec =
+ MachineSpec.newBuilder()
+ .setMachineType("n1-standard-4")
+ .setAcceleratorType(AcceleratorType.NVIDIA_TESLA_K80)
+ .setAcceleratorCount(1)
+ .build();
+ ContainerSpec containerSpec =
+ ContainerSpec.newBuilder().setImageUri(containerImageUri).build();
+ WorkerPoolSpec workerPoolSpec0 =
+ WorkerPoolSpec.newBuilder()
+ .setMachineSpec(machineSpec)
+ .setReplicaCount(1)
+ .setContainerSpec(containerSpec)
+ .build();
+ CustomJobSpec trialJobSpec =
+ CustomJobSpec.newBuilder().addWorkerPoolSpecs(workerPoolSpec0).build();
+ HyperparameterTuningJob hyperparameterTuningJob =
+ HyperparameterTuningJob.newBuilder()
+ .setDisplayName(displayName)
+ .setMaxTrialCount(2)
+ .setParallelTrialCount(1)
+ .setMaxFailedTrialCount(1)
+ .setStudySpec(studySpec)
+ .setTrialJobSpec(trialJobSpec)
+ .build();
+ LocationName parent = LocationName.of(project, location);
+ HyperparameterTuningJob response =
+ client.createHyperparameterTuningJob(parent, hyperparameterTuningJob);
+ System.out.format("response: %s\n", response);
+ System.out.format("Name: %s\n", response.getName());
+ }
+ }
+}
+
+// [END aiplatform_create_hyperparameter_tuning_job_sample]
diff --git a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomJobSample.java b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomJobSample.java
new file mode 100644
index 000000000..7b40d0e8d
--- /dev/null
+++ b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomJobSample.java
@@ -0,0 +1,119 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_training_pipeline_custom_job_sample]
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.Model;
+import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
+import com.google.gson.JsonArray;
+import com.google.gson.JsonObject;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+
+public class CreateTrainingPipelineCustomJobSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "PROJECT";
+ String displayName = "DISPLAY_NAME";
+ String modelDisplayName = "MODEL_DISPLAY_NAME";
+ String containerImageUri = "CONTAINER_IMAGE_URI";
+ String baseOutputDirectoryPrefix = "BASE_OUTPUT_DIRECTORY_PREFIX";
+ createTrainingPipelineCustomJobSample(
+ project, displayName, modelDisplayName, containerImageUri, baseOutputDirectoryPrefix);
+ }
+
+ static void createTrainingPipelineCustomJobSample(
+ String project,
+ String displayName,
+ String modelDisplayName,
+ String containerImageUri,
+ String baseOutputDirectoryPrefix)
+ throws IOException {
+ PipelineServiceSettings settings =
+ PipelineServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ String location = "us-central1";
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PipelineServiceClient client = PipelineServiceClient.create(settings)) {
+ JsonObject jsonMachineSpec = new JsonObject();
+ jsonMachineSpec.addProperty("machineType", "n1-standard-4");
+
+ JsonArray jsonArgs = new JsonArray();
+ jsonArgs.add("--model_dir=$(AIP_MODEL_DIR)");
+
+ // A working docker image can be found at
+ // gs://cloud-samples-data/ai-platform/mnist_tfrecord/custom_job
+ JsonObject jsonContainerSpec = new JsonObject();
+ jsonContainerSpec.addProperty("imageUri", containerImageUri);
+ jsonContainerSpec.add("args", jsonArgs);
+
+ JsonObject jsonJsonWorkerPoolSpec0 = new JsonObject();
+ jsonJsonWorkerPoolSpec0.addProperty("replicaCount", 1);
+ jsonJsonWorkerPoolSpec0.add("machineSpec", jsonMachineSpec);
+ jsonJsonWorkerPoolSpec0.add("containerSpec", jsonContainerSpec);
+
+ JsonArray jsonWorkerPoolSpecs = new JsonArray();
+ jsonWorkerPoolSpecs.add(jsonJsonWorkerPoolSpec0);
+
+ JsonObject jsonBaseOutputDirectory = new JsonObject();
+ // The GCS location for outputs must be accessible by the project's AI Platform
+ // service account.
+ jsonBaseOutputDirectory.addProperty("output_uri_prefix", baseOutputDirectoryPrefix);
+
+ JsonObject jsonTrainingTaskInputs = new JsonObject();
+ jsonTrainingTaskInputs.add("workerPoolSpecs", jsonWorkerPoolSpecs);
+ jsonTrainingTaskInputs.add("baseOutputDirectory", jsonBaseOutputDirectory);
+
+ Value.Builder trainingTaskInputsBuilder = Value.newBuilder();
+ JsonFormat.parser().merge(jsonTrainingTaskInputs.toString(), trainingTaskInputsBuilder);
+ Value trainingTaskInputs = trainingTaskInputsBuilder.build();
+ String trainingTaskDefinition =
+ "gs://google-cloud-aiplatform/schema/trainingjob/definition/custom_task_1.0.0.yaml";
+ String imageUri = "gcr.io/cloud-aiplatform/prediction/tf-cpu.1-15:latest";
+ ModelContainerSpec containerSpec =
+ ModelContainerSpec.newBuilder().setImageUri(imageUri).build();
+ Model modelToUpload =
+ Model.newBuilder()
+ .setDisplayName(modelDisplayName)
+ .setContainerSpec(containerSpec)
+ .build();
+ TrainingPipeline trainingPipeline =
+ TrainingPipeline.newBuilder()
+ .setDisplayName(displayName)
+ .setTrainingTaskDefinition(trainingTaskDefinition)
+ .setTrainingTaskInputs(trainingTaskInputs)
+ .setModelToUpload(modelToUpload)
+ .build();
+ LocationName parent = LocationName.of(project, location);
+ TrainingPipeline response = client.createTrainingPipeline(parent, trainingPipeline);
+ System.out.format("response: %s\n", response);
+ System.out.format("Name: %s\n", response.getName());
+ }
+ }
+}
+
+// [END aiplatform_create_training_pipeline_custom_job_sample]
diff --git a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSample.java b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSample.java
new file mode 100644
index 000000000..739d15cf8
--- /dev/null
+++ b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSample.java
@@ -0,0 +1,145 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_training_pipeline_custom_training_managed_dataset_sample]
+import com.google.cloud.aiplatform.v1beta1.GcsDestination;
+import com.google.cloud.aiplatform.v1beta1.InputDataConfig;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.Model;
+import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
+import com.google.gson.JsonArray;
+import com.google.gson.JsonObject;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+
+public class CreateTrainingPipelineCustomTrainingManagedDatasetSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "PROJECT";
+ String displayName = "DISPLAY_NAME";
+ String modelDisplayName = "MODEL_DISPLAY_NAME";
+ String datasetId = "DATASET_ID";
+ String annotationSchemaUri = "ANNOTATION_SCHEMA_URI";
+ String trainingContainerSpecImageUri = "TRAINING_CONTAINER_SPEC_IMAGE_URI";
+ String modelContainerSpecImageUri = "MODEL_CONTAINER_SPEC_IMAGE_URI";
+ String baseOutputUriPrefix = "BASE_OUTPUT_URI_PREFIX";
+ createTrainingPipelineCustomTrainingManagedDatasetSample(
+ project,
+ displayName,
+ modelDisplayName,
+ datasetId,
+ annotationSchemaUri,
+ trainingContainerSpecImageUri,
+ modelContainerSpecImageUri,
+ baseOutputUriPrefix);
+ }
+
+ static void createTrainingPipelineCustomTrainingManagedDatasetSample(
+ String project,
+ String displayName,
+ String modelDisplayName,
+ String datasetId,
+ String annotationSchemaUri,
+ String trainingContainerSpecImageUri,
+ String modelContainerSpecImageUri,
+ String baseOutputUriPrefix)
+ throws IOException {
+ PipelineServiceSettings settings =
+ PipelineServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ String location = "us-central1";
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PipelineServiceClient client = PipelineServiceClient.create(settings)) {
+ JsonArray jsonArgs = new JsonArray();
+ jsonArgs.add("--model-dir=$(AIP_MODEL_DIR)");
+ // training_task_inputs
+ JsonObject jsonTrainingContainerSpec = new JsonObject();
+ jsonTrainingContainerSpec.addProperty("imageUri", trainingContainerSpecImageUri);
+ // AIP_MODEL_DIR is set by the service according to baseOutputDirectory.
+ jsonTrainingContainerSpec.add("args", jsonArgs);
+
+ JsonObject jsonMachineSpec = new JsonObject();
+ jsonMachineSpec.addProperty("machineType", "n1-standard-8");
+
+ JsonObject jsonTrainingWorkerPoolSpec = new JsonObject();
+ jsonTrainingWorkerPoolSpec.addProperty("replicaCount", 1);
+ jsonTrainingWorkerPoolSpec.add("machineSpec", jsonMachineSpec);
+ jsonTrainingWorkerPoolSpec.add("containerSpec", jsonTrainingContainerSpec);
+
+ JsonArray jsonWorkerPoolSpecs = new JsonArray();
+ jsonWorkerPoolSpecs.add(jsonTrainingWorkerPoolSpec);
+
+ JsonObject jsonBaseOutputDirectory = new JsonObject();
+ jsonBaseOutputDirectory.addProperty("outputUriPrefix", baseOutputUriPrefix);
+
+ JsonObject jsonTrainingTaskInputs = new JsonObject();
+ jsonTrainingTaskInputs.add("workerPoolSpecs", jsonWorkerPoolSpecs);
+ jsonTrainingTaskInputs.add("baseOutputDirectory", jsonBaseOutputDirectory);
+
+ Value.Builder trainingTaskInputsBuilder = Value.newBuilder();
+ JsonFormat.parser().merge(jsonTrainingTaskInputs.toString(), trainingTaskInputsBuilder);
+ Value trainingTaskInputs = trainingTaskInputsBuilder.build();
+ // model_to_upload
+ ModelContainerSpec modelContainerSpec =
+ ModelContainerSpec.newBuilder().setImageUri(modelContainerSpecImageUri).build();
+ Model model =
+ Model.newBuilder()
+ .setDisplayName(modelDisplayName)
+ .setContainerSpec(modelContainerSpec)
+ .build();
+ GcsDestination gcsDestination =
+ GcsDestination.newBuilder().setOutputUriPrefix(baseOutputUriPrefix).build();
+
+ // input_data_config
+ InputDataConfig inputDataConfig =
+ InputDataConfig.newBuilder()
+ .setDatasetId(datasetId)
+ .setAnnotationSchemaUri(annotationSchemaUri)
+ .setGcsDestination(gcsDestination)
+ .build();
+
+ // training_task_definition
+ String customTaskDefinition =
+ "gs://google-cloud-aiplatform/schema/trainingjob/definition/custom_task_1.0.0.yaml";
+
+ TrainingPipeline trainingPipeline =
+ TrainingPipeline.newBuilder()
+ .setDisplayName(displayName)
+ .setInputDataConfig(inputDataConfig)
+ .setTrainingTaskDefinition(customTaskDefinition)
+ .setTrainingTaskInputs(trainingTaskInputs)
+ .setModelToUpload(model)
+ .build();
+ LocationName parent = LocationName.of(project, location);
+ TrainingPipeline response = client.createTrainingPipeline(parent, trainingPipeline);
+ System.out.format("response: %s\n", response);
+ System.out.format("Name: %s\n", response.getName());
+ }
+ }
+}
+
+// [END aiplatform_create_training_pipeline_custom_training_managed_dataset_sample]
diff --git a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java
index 7327cba9b..041ec32f0 100644
--- a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java
+++ b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java
@@ -38,12 +38,10 @@
import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
-import com.google.protobuf.Any;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import com.google.rpc.Status;
import java.io.IOException;
-import java.util.List;
public class CreateTrainingPipelineImageClassificationSample {
@@ -73,7 +71,7 @@ static void createTrainingPipelineImageClassificationSample(
String location = "us-central1";
String trainingTaskDefinition =
"gs://google-cloud-aiplatform/schema/trainingjob/definition/"
- + "automl_image_classification_1.0.0.yaml";
+ + "automl_image_classification_1.0.0.yaml";
LocationName locationName = LocationName.of(project, location);
String jsonString =
diff --git a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java
index 636ab0224..d6b2af9ab 100644
--- a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java
+++ b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java
@@ -38,12 +38,10 @@
import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
-import com.google.protobuf.Any;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import com.google.rpc.Status;
import java.io.IOException;
-import java.util.List;
public class CreateTrainingPipelineImageObjectDetectionSample {
@@ -73,7 +71,7 @@ static void createTrainingPipelineImageObjectDetectionSample(
String location = "us-central1";
String trainingTaskDefinition =
"gs://google-cloud-aiplatform/schema/trainingjob/definition/"
- + "automl_image_object_detection_1.0.0.yaml";
+ + "automl_image_object_detection_1.0.0.yaml";
LocationName locationName = LocationName.of(project, location);
String jsonString =
diff --git a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineSample.java b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineSample.java
index 5c310bfea..2dcb6e88c 100644
--- a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineSample.java
+++ b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineSample.java
@@ -38,12 +38,10 @@
import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
-import com.google.protobuf.Any;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import com.google.rpc.Status;
import java.io.IOException;
-import java.util.List;
public class CreateTrainingPipelineSample {
diff --git a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java
index de54c0a0a..92ebb6ada 100644
--- a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java
+++ b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java
@@ -37,12 +37,10 @@
import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
-import com.google.protobuf.Any;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import com.google.rpc.Status;
import java.io.IOException;
-import java.util.List;
public class CreateTrainingPipelineTabularClassificationSample {
diff --git a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java
index ca24862ca..0043fdf52 100644
--- a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java
+++ b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java
@@ -37,12 +37,10 @@
import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
-import com.google.protobuf.Any;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import com.google.rpc.Status;
import java.io.IOException;
-import java.util.List;
public class CreateTrainingPipelineTabularRegressionSample {
diff --git a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java
index c67e49a05..b5335a4c7 100644
--- a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java
+++ b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java
@@ -38,12 +38,10 @@
import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
-import com.google.protobuf.Any;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import com.google.rpc.Status;
import java.io.IOException;
-import java.util.List;
public class CreateTrainingPipelineTextClassificationSample {
@@ -74,7 +72,7 @@ static void createTrainingPipelineTextClassificationSample(
String location = "us-central1";
String trainingTaskDefinition =
"gs://google-cloud-aiplatform/schema/trainingjob/definition/"
- + "automl_text_classification_1.0.0.yaml";
+ + "automl_text_classification_1.0.0.yaml";
String jsonString = "{\"multiLabel\": false}";
LocationName locationName = LocationName.of(project, location);
diff --git a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java
index 3aef27086..9d83a4d1e 100644
--- a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java
+++ b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java
@@ -38,12 +38,10 @@
import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
-import com.google.protobuf.Any;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import com.google.rpc.Status;
import java.io.IOException;
-import java.util.List;
public class CreateTrainingPipelineTextEntityExtractionSample {
@@ -74,7 +72,7 @@ static void createTrainingPipelineTextEntityExtractionSample(
String location = "us-central1";
String trainingTaskDefinition =
"gs://google-cloud-aiplatform/schema/trainingjob/definition/"
- + "automl_text_extraction_1.0.0.yaml";
+ + "automl_text_extraction_1.0.0.yaml";
String jsonString = "{}";
LocationName locationName = LocationName.of(project, location);
diff --git a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java
index a6139405d..f9c656ad0 100644
--- a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java
+++ b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java
@@ -38,12 +38,10 @@
import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
-import com.google.protobuf.Any;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import com.google.rpc.Status;
import java.io.IOException;
-import java.util.List;
public class CreateTrainingPipelineTextSentimentAnalysisSample {
@@ -74,7 +72,7 @@ static void createTrainingPipelineTextSentimentAnalysisSample(
String location = "us-central1";
String trainingTaskDefinition =
"gs://google-cloud-aiplatform/schema/trainingjob/definition/"
- + "automl_text_sentiment_1.0.0.yaml";
+ + "automl_text_sentiment_1.0.0.yaml";
// Sentiment max must be between 1 and 10 inclusive.
// Higher value means positive sentiment.
diff --git a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSample.java b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSample.java
new file mode 100644
index 000000000..70d03788e
--- /dev/null
+++ b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSample.java
@@ -0,0 +1,87 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_training_pipeline_video_action_recognition_sample]
+import com.google.cloud.aiplatform.v1beta1.InputDataConfig;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.Model;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
+import com.google.gson.JsonObject;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+
+public class CreateTrainingPipelineVideoActionRecognitionSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "PROJECT";
+ String displayName = "DISPLAY_NAME";
+ String datasetId = "DATASET_ID";
+ String modelDisplayName = "MODEL_DISPLAY_NAME";
+ String modelType = "MODEL_TYPE";
+ createTrainingPipelineVideoActionRecognitionSample(
+ project, displayName, datasetId, modelDisplayName, modelType);
+ }
+
+ static void createTrainingPipelineVideoActionRecognitionSample(
+ String project,
+ String displayName,
+ String datasetId,
+ String modelDisplayName,
+ String modelType)
+ throws IOException {
+ PipelineServiceSettings settings =
+ PipelineServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ String location = "us-central1";
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PipelineServiceClient client = PipelineServiceClient.create(settings)) {
+ JsonObject jsonTrainingTaskInputs = new JsonObject();
+ jsonTrainingTaskInputs.addProperty("modelType", modelType);
+ Value.Builder trainingTaskInputsBuilder = Value.newBuilder();
+ JsonFormat.parser().merge(jsonTrainingTaskInputs.toString(), trainingTaskInputsBuilder);
+ Value trainingTaskInputs = trainingTaskInputsBuilder.build();
+ InputDataConfig inputDataConfig =
+ InputDataConfig.newBuilder().setDatasetId(datasetId).build();
+ Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build();
+ TrainingPipeline trainingPipeline =
+ TrainingPipeline.newBuilder()
+ .setDisplayName(displayName)
+ .setTrainingTaskDefinition(
+ "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
+ + "automl_video_action_recognition_1.0.0.yaml")
+ .setTrainingTaskInputs(trainingTaskInputs)
+ .setInputDataConfig(inputDataConfig)
+ .setModelToUpload(modelToUpload)
+ .build();
+ LocationName parent = LocationName.of(project, location);
+ TrainingPipeline response = client.createTrainingPipeline(parent, trainingPipeline);
+ System.out.format("response: %s\n", response);
+ System.out.format("Name: %s\n", response.getName());
+ }
+ }
+}
+
+// [END aiplatform_create_training_pipeline_video_action_recognition_sample]
diff --git a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java
index 383e56954..eacbc5123 100644
--- a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java
+++ b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java
@@ -28,11 +28,9 @@
import com.google.cloud.aiplatform.v1beta1.PredefinedSplit;
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
-import com.google.protobuf.Any;
import com.google.protobuf.Value;
import com.google.rpc.Status;
import java.io.IOException;
-import java.util.List;
public class CreateTrainingPipelineVideoClassificationSample {
@@ -67,7 +65,7 @@ static void createTrainingPipelineVideoClassification(
LocationName locationName = LocationName.of(project, location);
String trainingTaskDefinition =
"gs://google-cloud-aiplatform/schema/trainingjob/definition/"
- + "automl_video_classification_1.0.0.yaml";
+ + "automl_video_classification_1.0.0.yaml";
InputDataConfig inputDataConfig =
InputDataConfig.newBuilder().setDatasetId(datasetId).build();
diff --git a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java
index d49fcff96..a8e5ee149 100644
--- a/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java
+++ b/samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java
@@ -28,12 +28,10 @@
import com.google.cloud.aiplatform.v1beta1.PredefinedSplit;
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
-import com.google.protobuf.Any;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import com.google.rpc.Status;
import java.io.IOException;
-import java.util.List;
public class CreateTrainingPipelineVideoObjectTrackingSample {
@@ -66,7 +64,7 @@ static void createTrainingPipelineVideoObjectTracking(
String location = "us-central1";
String trainingTaskDefinition =
"gs://google-cloud-aiplatform/schema/trainingjob/definition/"
- + "automl_video_object_tracking_1.0.0.yaml";
+ + "automl_video_object_tracking_1.0.0.yaml";
LocationName locationName = LocationName.of(project, location);
String jsonString = "{\"modelType\": \"CLOUD\"}";
diff --git a/samples/snippets/src/main/java/aiplatform/DeleteDatasetSample.java b/samples/snippets/src/main/java/aiplatform/DeleteDatasetSample.java
index a9989b564..39ad52d0f 100644
--- a/samples/snippets/src/main/java/aiplatform/DeleteDatasetSample.java
+++ b/samples/snippets/src/main/java/aiplatform/DeleteDatasetSample.java
@@ -59,7 +59,7 @@ static void deleteDatasetSample(String project, String datasetId)
System.out.format("Operation name: %s\n", operationFuture.getInitialFuture().get().getName());
System.out.println("Waiting for operation to finish...");
operationFuture.get(300, TimeUnit.SECONDS);
-
+
System.out.format("Deleted Dataset.");
}
}
diff --git a/samples/snippets/src/main/java/aiplatform/DeployModelCustomTrainedModelSample.java b/samples/snippets/src/main/java/aiplatform/DeployModelCustomTrainedModelSample.java
new file mode 100644
index 000000000..471961b69
--- /dev/null
+++ b/samples/snippets/src/main/java/aiplatform/DeployModelCustomTrainedModelSample.java
@@ -0,0 +1,92 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_deploy_model_custom_trained_model_sample]
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.DedicatedResources;
+import com.google.cloud.aiplatform.v1beta1.DeployModelOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.DeployModelResponse;
+import com.google.cloud.aiplatform.v1beta1.DeployedModel;
+import com.google.cloud.aiplatform.v1beta1.EndpointName;
+import com.google.cloud.aiplatform.v1beta1.EndpointServiceClient;
+import com.google.cloud.aiplatform.v1beta1.EndpointServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.MachineSpec;
+import com.google.cloud.aiplatform.v1beta1.ModelName;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+
+public class DeployModelCustomTrainedModelSample {
+
+ public static void main(String[] args)
+ throws IOException, ExecutionException, InterruptedException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "PROJECT";
+ String endpointId = "ENDPOINT_ID";
+ String modelName = "MODEL_NAME";
+ String deployedModelDisplayName = "DEPLOYED_MODEL_DISPLAY_NAME";
+ deployModelCustomTrainedModelSample(project, endpointId, modelName, deployedModelDisplayName);
+ }
+
+ static void deployModelCustomTrainedModelSample(
+ String project, String endpointId, String model, String deployedModelDisplayName)
+ throws IOException, ExecutionException, InterruptedException {
+ EndpointServiceSettings settings =
+ EndpointServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ String location = "us-central1";
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (EndpointServiceClient client = EndpointServiceClient.create(settings)) {
+ MachineSpec machineSpec = MachineSpec.newBuilder().setMachineType("n1-standard-2").build();
+ DedicatedResources dedicatedResources =
+ DedicatedResources.newBuilder().setMinReplicaCount(1).setMachineSpec(machineSpec).build();
+
+ String modelName = ModelName.of(project, location, model).toString();
+ DeployedModel deployedModel =
+ DeployedModel.newBuilder()
+ .setModel(modelName)
+ .setDisplayName(deployedModelDisplayName)
+ // `dedicated_resources` must be used for non-AutoML models
+ .setDedicatedResources(dedicatedResources)
+ .build();
+ // key '0' assigns traffic for the newly deployed model
+ // Traffic percentage values must add up to 100
+ // Leave dictionary empty if endpoint should not accept any traffic
+ Map trafficSplit = new HashMap<>();
+ trafficSplit.put("0", 100);
+ EndpointName endpoint = EndpointName.of(project, location, endpointId);
+ OperationFuture response =
+ client.deployModelAsync(endpoint, deployedModel, trafficSplit);
+
+ // You can use OperationFuture.getInitialFuture to get a future representing the initial
+ // response to the request, which contains information while the operation is in progress.
+ System.out.format("Operation name: %s\n", response.getInitialFuture().get().getName());
+
+ // OperationFuture.get() will block until the operation is finished.
+ DeployModelResponse deployModelResponse = response.get();
+ System.out.format("deployModelResponse: %s\n", deployModelResponse);
+ }
+ }
+}
+
+// [END aiplatform_deploy_model_custom_trained_model_sample]
diff --git a/samples/snippets/src/main/java/aiplatform/ExportModelVideoActionRecognitionSample.java b/samples/snippets/src/main/java/aiplatform/ExportModelVideoActionRecognitionSample.java
new file mode 100644
index 000000000..1d70cd0d8
--- /dev/null
+++ b/samples/snippets/src/main/java/aiplatform/ExportModelVideoActionRecognitionSample.java
@@ -0,0 +1,79 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_export_model_video_action_recognition_sample]
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.ExportModelOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.ExportModelRequest;
+import com.google.cloud.aiplatform.v1beta1.ExportModelResponse;
+import com.google.cloud.aiplatform.v1beta1.GcsDestination;
+import com.google.cloud.aiplatform.v1beta1.ModelName;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceClient;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings;
+import java.io.IOException;
+import java.util.concurrent.ExecutionException;
+
+public class ExportModelVideoActionRecognitionSample {
+
+ public static void main(String[] args)
+ throws IOException, ExecutionException, InterruptedException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "PROJECT";
+ String modelId = "MODEL_ID";
+ String gcsDestinationOutputUriPrefix = "GCS_DESTINATION_OUTPUT_URI_PREFIX";
+ String exportFormat = "EXPORT_FORMAT";
+ exportModelVideoActionRecognitionSample(
+ project, modelId, gcsDestinationOutputUriPrefix, exportFormat);
+ }
+
+ static void exportModelVideoActionRecognitionSample(
+ String project, String modelId, String gcsDestinationOutputUriPrefix, String exportFormat)
+ throws IOException, ExecutionException, InterruptedException {
+ ModelServiceSettings settings =
+ ModelServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ String location = "us-central1";
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (ModelServiceClient client = ModelServiceClient.create(settings)) {
+ GcsDestination gcsDestination =
+ GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build();
+ ExportModelRequest.OutputConfig outputConfig =
+ ExportModelRequest.OutputConfig.newBuilder()
+ .setArtifactDestination(gcsDestination)
+ .setExportFormatId(exportFormat)
+ .build();
+ ModelName name = ModelName.of(project, location, modelId);
+ OperationFuture response =
+ client.exportModelAsync(name, outputConfig);
+
+ // You can use OperationFuture.getInitialFuture to get a future representing the initial
+ // response to the request, which contains information while the operation is in progress.
+ System.out.format("Operation name: %s\n", response.getInitialFuture().get().getName());
+
+ // OperationFuture.get() will block until the operation is finished.
+ ExportModelResponse exportModelResponse = response.get();
+ System.out.format("exportModelResponse: %s\n", exportModelResponse);
+ }
+ }
+}
+
+// [END aiplatform_export_model_video_action_recognition_sample]
diff --git a/samples/snippets/src/main/java/aiplatform/GetHyperparameterTuningJobSample.java b/samples/snippets/src/main/java/aiplatform/GetHyperparameterTuningJobSample.java
new file mode 100644
index 000000000..b4d100c3c
--- /dev/null
+++ b/samples/snippets/src/main/java/aiplatform/GetHyperparameterTuningJobSample.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_get_hyperparameter_tuning_job_sample]
+import com.google.cloud.aiplatform.v1beta1.HyperparameterTuningJob;
+import com.google.cloud.aiplatform.v1beta1.HyperparameterTuningJobName;
+import com.google.cloud.aiplatform.v1beta1.JobServiceClient;
+import com.google.cloud.aiplatform.v1beta1.JobServiceSettings;
+import java.io.IOException;
+
+public class GetHyperparameterTuningJobSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "PROJECT";
+ String hyperparameterTuningJobId = "HYPERPARAMETER_TUNING_JOB_ID";
+ getHyperparameterTuningJobSample(project, hyperparameterTuningJobId);
+ }
+
+ static void getHyperparameterTuningJobSample(String project, String hyperparameterTuningJobId)
+ throws IOException {
+ JobServiceSettings settings =
+ JobServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ String location = "us-central1";
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (JobServiceClient client = JobServiceClient.create(settings)) {
+ HyperparameterTuningJobName name =
+ HyperparameterTuningJobName.of(project, location, hyperparameterTuningJobId);
+ HyperparameterTuningJob response = client.getHyperparameterTuningJob(name);
+ System.out.format("response: %s\n", response);
+ }
+ }
+}
+
+// [END aiplatform_get_hyperparameter_tuning_job_sample]
diff --git a/samples/snippets/src/main/java/aiplatform/GetModelEvaluationVideoActionRecognitionSample.java b/samples/snippets/src/main/java/aiplatform/GetModelEvaluationVideoActionRecognitionSample.java
new file mode 100644
index 000000000..2685e23f1
--- /dev/null
+++ b/samples/snippets/src/main/java/aiplatform/GetModelEvaluationVideoActionRecognitionSample.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_get_model_evaluation_video_action_recognition_sample]
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluation;
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceClient;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings;
+import java.io.IOException;
+
+public class GetModelEvaluationVideoActionRecognitionSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "PROJECT";
+ String modelId = "MODEL_ID";
+ String evaluationId = "EVALUATION_ID";
+ getModelEvaluationVideoActionRecognitionSample(project, modelId, evaluationId);
+ }
+
+ static void getModelEvaluationVideoActionRecognitionSample(
+ String project, String modelId, String evaluationId) throws IOException {
+ ModelServiceSettings settings =
+ ModelServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ String location = "us-central1";
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (ModelServiceClient client = ModelServiceClient.create(settings)) {
+ ModelEvaluationName name = ModelEvaluationName.of(project, location, modelId, evaluationId);
+ ModelEvaluation response = client.getModelEvaluation(name);
+ System.out.format("response: %s\n", response);
+ }
+ }
+}
+
+// [END aiplatform_get_model_evaluation_video_action_recognition_sample]
diff --git a/samples/snippets/src/main/java/aiplatform/ImportDataImageClassificationSample.java b/samples/snippets/src/main/java/aiplatform/ImportDataImageClassificationSample.java
index 04beb8c54..7bc46c447 100644
--- a/samples/snippets/src/main/java/aiplatform/ImportDataImageClassificationSample.java
+++ b/samples/snippets/src/main/java/aiplatform/ImportDataImageClassificationSample.java
@@ -61,7 +61,7 @@ static void importDataImageClassificationSample(
String location = "us-central1";
String importSchemaUri =
"gs://google-cloud-aiplatform/schema/dataset/ioformat/"
- + "image_classification_single_label_io_format_1.0.0.yaml";
+ + "image_classification_single_label_io_format_1.0.0.yaml";
GcsSource.Builder gcsSource = GcsSource.newBuilder();
gcsSource.addUris(gcsSourceUri);
@@ -81,8 +81,8 @@ static void importDataImageClassificationSample(
System.out.println("Waiting for operation to finish...");
ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS);
- System.out.format("Import Data Image Classification Response: %s\n",
- importDataResponse.toString());
+ System.out.format(
+ "Import Data Image Classification Response: %s\n", importDataResponse.toString());
}
}
}
diff --git a/samples/snippets/src/main/java/aiplatform/ImportDataImageObjectDetectionSample.java b/samples/snippets/src/main/java/aiplatform/ImportDataImageObjectDetectionSample.java
index ae17cfd3a..947b3f232 100644
--- a/samples/snippets/src/main/java/aiplatform/ImportDataImageObjectDetectionSample.java
+++ b/samples/snippets/src/main/java/aiplatform/ImportDataImageObjectDetectionSample.java
@@ -61,7 +61,7 @@ static void importDataImageObjectDetectionSample(
String location = "us-central1";
String importSchemaUri =
"gs://google-cloud-aiplatform/schema/dataset/ioformat/"
- + "image_bounding_box_io_format_1.0.0.yaml";
+ + "image_bounding_box_io_format_1.0.0.yaml";
GcsSource.Builder gcsSource = GcsSource.newBuilder();
gcsSource.addUris(gcsSourceUri);
DatasetName datasetName = DatasetName.of(project, location, datasetId);
@@ -80,8 +80,8 @@ static void importDataImageObjectDetectionSample(
System.out.println("Waiting for operation to finish...");
ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS);
- System.out.format("Import Data Image Object Detection Response: %s\n",
- importDataResponse.toString());
+ System.out.format(
+ "Import Data Image Object Detection Response: %s\n", importDataResponse.toString());
}
}
}
diff --git a/samples/snippets/src/main/java/aiplatform/ImportDataTextClassificationSingleLabelSample.java b/samples/snippets/src/main/java/aiplatform/ImportDataTextClassificationSingleLabelSample.java
index cf08bd38a..8c285a1b6 100644
--- a/samples/snippets/src/main/java/aiplatform/ImportDataTextClassificationSingleLabelSample.java
+++ b/samples/snippets/src/main/java/aiplatform/ImportDataTextClassificationSingleLabelSample.java
@@ -62,7 +62,7 @@ static void importDataTextClassificationSingleLabelSample(
String location = "us-central1";
String importSchemaUri =
"gs://google-cloud-aiplatform/schema/dataset/ioformat/"
- + "text_classification_single_label_io_format_1.0.0.yaml";
+ + "text_classification_single_label_io_format_1.0.0.yaml";
GcsSource.Builder gcsSource = GcsSource.newBuilder();
gcsSource.addUris(gcsSourceUri);
@@ -82,8 +82,8 @@ static void importDataTextClassificationSingleLabelSample(
System.out.println("Waiting for operation to finish...");
ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS);
- System.out.format("Import Data Text Classification Response: %s\n",
- importDataResponse.toString());
+ System.out.format(
+ "Import Data Text Classification Response: %s\n", importDataResponse.toString());
}
}
}
diff --git a/samples/snippets/src/main/java/aiplatform/ImportDataTextEntityExtractionSample.java b/samples/snippets/src/main/java/aiplatform/ImportDataTextEntityExtractionSample.java
index 6bd5c4f02..be53db261 100644
--- a/samples/snippets/src/main/java/aiplatform/ImportDataTextEntityExtractionSample.java
+++ b/samples/snippets/src/main/java/aiplatform/ImportDataTextEntityExtractionSample.java
@@ -61,7 +61,7 @@ static void importDataTextEntityExtractionSample(
String location = "us-central1";
String importSchemaUri =
"gs://google-cloud-aiplatform/schema/dataset/ioformat/"
- + "text_extraction_io_format_1.0.0.yaml";
+ + "text_extraction_io_format_1.0.0.yaml";
GcsSource.Builder gcsSource = GcsSource.newBuilder();
gcsSource.addUris(gcsSourceUri);
@@ -81,8 +81,8 @@ static void importDataTextEntityExtractionSample(
System.out.println("Waiting for operation to finish...");
ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS);
- System.out.format("Import Data Text Entity Extraction Response: %s\n",
- importDataResponse.toString());
+ System.out.format(
+ "Import Data Text Entity Extraction Response: %s\n", importDataResponse.toString());
}
}
}
diff --git a/samples/snippets/src/main/java/aiplatform/ImportDataTextSentimentAnalysisSample.java b/samples/snippets/src/main/java/aiplatform/ImportDataTextSentimentAnalysisSample.java
index e17506783..4634d6f64 100644
--- a/samples/snippets/src/main/java/aiplatform/ImportDataTextSentimentAnalysisSample.java
+++ b/samples/snippets/src/main/java/aiplatform/ImportDataTextSentimentAnalysisSample.java
@@ -62,7 +62,7 @@ static void importDataTextSentimentAnalysisSample(
String location = "us-central1";
String importSchemaUri =
"gs://google-cloud-aiplatform/schema/dataset/ioformat/"
- + "text_sentiment_io_format_1.0.0.yaml";
+ + "text_sentiment_io_format_1.0.0.yaml";
GcsSource.Builder gcsSource = GcsSource.newBuilder();
gcsSource.addUris(gcsSourceUri);
@@ -82,8 +82,8 @@ static void importDataTextSentimentAnalysisSample(
System.out.println("Waiting for operation to finish...");
ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS);
- System.out.format("Import Data Text Sentiment Analysis Response: %s\n",
- importDataResponse.toString());
+ System.out.format(
+ "Import Data Text Sentiment Analysis Response: %s\n", importDataResponse.toString());
}
}
}
diff --git a/samples/snippets/src/main/java/aiplatform/ImportDataVideoActionRecognitionSample.java b/samples/snippets/src/main/java/aiplatform/ImportDataVideoActionRecognitionSample.java
new file mode 100644
index 000000000..ac182ecbc
--- /dev/null
+++ b/samples/snippets/src/main/java/aiplatform/ImportDataVideoActionRecognitionSample.java
@@ -0,0 +1,82 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_import_data_video_action_recognition_sample]
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.DatasetName;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.GcsSource;
+import com.google.cloud.aiplatform.v1beta1.ImportDataConfig;
+import com.google.cloud.aiplatform.v1beta1.ImportDataOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.ImportDataResponse;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+
+public class ImportDataVideoActionRecognitionSample {
+
+ public static void main(String[] args)
+ throws IOException, ExecutionException, InterruptedException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "PROJECT";
+ String datasetId = "DATASET_ID";
+ String gcsSourceUri = "GCS_SOURCE_URI";
+ importDataVideoActionRecognitionSample(project, datasetId, gcsSourceUri);
+ }
+
+ static void importDataVideoActionRecognitionSample(
+ String project, String datasetId, String gcsSourceUri)
+ throws IOException, ExecutionException, InterruptedException {
+ DatasetServiceSettings settings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ String location = "us-central1";
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (DatasetServiceClient client = DatasetServiceClient.create(settings)) {
+ GcsSource gcsSource = GcsSource.newBuilder().addUris(gcsSourceUri).build();
+ ImportDataConfig importConfig0 =
+ ImportDataConfig.newBuilder()
+ .setGcsSource(gcsSource)
+ .setImportSchemaUri(
+ "gs://google-cloud-aiplatform/schema/dataset/ioformat/"
+ + "video_action_recognition_io_format_1.0.0.yaml")
+ .build();
+ List importConfigs = new ArrayList<>();
+ importConfigs.add(importConfig0);
+ DatasetName name = DatasetName.of(project, location, datasetId);
+ OperationFuture response =
+ client.importDataAsync(name, importConfigs);
+
+ // You can use OperationFuture.getInitialFuture to get a future representing the initial
+ // response to the request, which contains information while the operation is in progress.
+ System.out.format("Operation name: %s\n", response.getInitialFuture().get().getName());
+
+ // OperationFuture.get() will block until the operation is finished.
+ ImportDataResponse importDataResponse = response.get();
+ System.out.format("importDataResponse: %s\n", importDataResponse);
+ }
+ }
+}
+
+// [END aiplatform_import_data_video_action_recognition_sample]
diff --git a/samples/snippets/src/main/java/aiplatform/ImportDataVideoClassificationSample.java b/samples/snippets/src/main/java/aiplatform/ImportDataVideoClassificationSample.java
index 4bf2c37f3..10a7b673e 100644
--- a/samples/snippets/src/main/java/aiplatform/ImportDataVideoClassificationSample.java
+++ b/samples/snippets/src/main/java/aiplatform/ImportDataVideoClassificationSample.java
@@ -60,7 +60,7 @@ static void importDataVideoClassification(String gcsSourceUri, String project, S
String location = "us-central1";
String importSchemaUri =
"gs://google-cloud-aiplatform/schema/dataset/ioformat/"
- + "video_classification_io_format_1.0.0.yaml";
+ + "video_classification_io_format_1.0.0.yaml";
GcsSource.Builder gcsSource = GcsSource.newBuilder();
gcsSource.addUris(gcsSourceUri);
@@ -81,8 +81,7 @@ static void importDataVideoClassification(String gcsSourceUri, String project, S
ImportDataResponse importDataResponse = importDataResponseFuture.get(1800, TimeUnit.SECONDS);
System.out.format(
- "Import Data Video Classification Response: %s\n",
- importDataResponse.toString());
+ "Import Data Video Classification Response: %s\n", importDataResponse.toString());
}
}
}
diff --git a/samples/snippets/src/main/java/aiplatform/ImportDataVideoObjectTrackingSample.java b/samples/snippets/src/main/java/aiplatform/ImportDataVideoObjectTrackingSample.java
index f8a07d914..87807c193 100644
--- a/samples/snippets/src/main/java/aiplatform/ImportDataVideoObjectTrackingSample.java
+++ b/samples/snippets/src/main/java/aiplatform/ImportDataVideoObjectTrackingSample.java
@@ -59,7 +59,7 @@ static void importDataVideObjectTracking(String gcsSourceUri, String project, St
String location = "us-central1";
String importSchemaUri =
"gs://google-cloud-aiplatform/schema/dataset/ioformat/"
- + "video_object_tracking_io_format_1.0.0.yaml";
+ + "video_object_tracking_io_format_1.0.0.yaml";
GcsSource.Builder gcsSource = GcsSource.newBuilder();
gcsSource.addUris(gcsSourceUri);
@@ -78,8 +78,8 @@ static void importDataVideObjectTracking(String gcsSourceUri, String project, St
System.out.println("Waiting for operation to finish...");
ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS);
- System.out.format("Import Data Video Object Tracking Response: %s\n",
- importDataResponse.toString());
+ System.out.format(
+ "Import Data Video Object Tracking Response: %s\n", importDataResponse.toString());
}
}
}
diff --git a/samples/snippets/src/test/java/aiplatform/CancelTrainingPipelineSampleTest.java b/samples/snippets/src/test/java/aiplatform/CancelTrainingPipelineSampleTest.java
index 1f5066679..f3daec201 100644
--- a/samples/snippets/src/test/java/aiplatform/CancelTrainingPipelineSampleTest.java
+++ b/samples/snippets/src/test/java/aiplatform/CancelTrainingPipelineSampleTest.java
@@ -37,7 +37,7 @@ public class CancelTrainingPipelineSampleTest {
private static final String DATASET_ID = System.getenv("TRAINING_PIPELINE_DATASET_ID");
private static final String TRAINING_TASK_DEFINITION =
"gs://google-cloud-aiplatform/schema/trainingjob/definition/"
- + "automl_image_classification_1.0.0.yaml";
+ + "automl_image_classification_1.0.0.yaml";
private static String TRAINING_PIPELINE_ID = null;
private ByteArrayOutputStream bout;
private PrintStream out;
diff --git a/samples/snippets/src/test/java/aiplatform/CreateBatchPredictionJobBigquerySampleTest.java b/samples/snippets/src/test/java/aiplatform/CreateBatchPredictionJobBigquerySampleTest.java
new file mode 100644
index 000000000..25114e607
--- /dev/null
+++ b/samples/snippets/src/test/java/aiplatform/CreateBatchPredictionJobBigquerySampleTest.java
@@ -0,0 +1,109 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class CreateBatchPredictionJobBigquerySampleTest {
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String MODEL_ID = System.getenv("BATCH_PREDICTION_TABULAR_BQ_MODEL_ID");
+ private static final String BIGQUERY_SOURCE_URI =
+ "bq://ucaip-sample-tests.table_test.all_bq_types";
+ private static final String BIGQUERY_DESTINATION_OUTPUT_URI_PREFIX = "bq://ucaip-sample-tests";
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String batchPredictionJobId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("BATCH_PREDICTION_TABULAR_BQ_MODEL_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId);
+
+ // Assert
+ String cancelResponse = bout.toString();
+ assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job");
+ TimeUnit.MINUTES.sleep(2);
+
+ // Delete the Batch Prediction Job
+ DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Batch");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateBatchPredictionJobBigquerySample() throws IOException {
+ // Act
+ String batchPredictionDisplayName =
+ String.format(
+ "batch_prediction_bigquery_display_name_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateBatchPredictionJobBigquerySample.createBatchPredictionJobBigquerySample(
+ PROJECT,
+ batchPredictionDisplayName,
+ MODEL_ID,
+ "bigquery",
+ BIGQUERY_SOURCE_URI,
+ "bigquery",
+ BIGQUERY_DESTINATION_OUTPUT_URI_PREFIX);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(batchPredictionDisplayName);
+ assertThat(got).contains("response:");
+ batchPredictionJobId = got.split("Name: ")[1].split("batchPredictionJobs/")[1].split("\n")[0];
+ }
+}
diff --git a/samples/snippets/src/test/java/aiplatform/CreateBatchPredictionJobSampleTest.java b/samples/snippets/src/test/java/aiplatform/CreateBatchPredictionJobSampleTest.java
new file mode 100644
index 000000000..1def01b3d
--- /dev/null
+++ b/samples/snippets/src/test/java/aiplatform/CreateBatchPredictionJobSampleTest.java
@@ -0,0 +1,109 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class CreateBatchPredictionJobSampleTest {
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String MODEL_ID = System.getenv("BATCH_PREDICTION_MODEL_ID");
+ private static final String GCS_SOURCE_URI =
+ "gs://ucaip-samples-test-output/inputs/icn_batch_prediction_input.jsonl";
+ private static final String GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/";
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String batchPredictionJobId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("BATCH_PREDICTION_MODEL_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId);
+
+ // Assert
+ String cancelResponse = bout.toString();
+ assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job");
+ TimeUnit.MINUTES.sleep(2);
+
+ // Delete the Batch Prediction Job
+ DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Batch");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateBatchPredictionJobSample() throws IOException {
+ // Act
+ String batchPredictionDisplayName =
+ String.format(
+ "batch_prediction_bigquery_display_name_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateBatchPredictionJobSample.createBatchPredictionJobSample(
+ PROJECT,
+ batchPredictionDisplayName,
+ MODEL_ID,
+ "jsonl",
+ GCS_SOURCE_URI,
+ "jsonl",
+ GCS_OUTPUT_URI);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(batchPredictionDisplayName);
+ assertThat(got).contains("response:");
+ batchPredictionJobId = got.split("Name: ")[1].split("batchPredictionJobs/")[1].split("\n")[0];
+ }
+}
diff --git a/samples/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSampleTest.java b/samples/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSampleTest.java
new file mode 100644
index 000000000..5ca1e5f0f
--- /dev/null
+++ b/samples/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSampleTest.java
@@ -0,0 +1,105 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class CreateBatchPredictionJobVideoActionRecognitionSampleTest {
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String MODEL_ID =
+ System.getenv("BATCH_PREDICTION_VIDEO_ACTION_RECOGNITION_MODEL_ID");
+ private static final String GCS_SOURCE_URI =
+ "gs://ucaip-samples-test-output/inputs/icn_batch_prediction_input.jsonl";
+ private static final String GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/";
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String batchPredictionJobId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("BATCH_PREDICTION_VIDEO_ACTION_RECOGNITION_MODEL_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId);
+
+ // Assert
+ String cancelResponse = bout.toString();
+ assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job");
+ TimeUnit.MINUTES.sleep(2);
+
+ // Delete the Batch Prediction Job
+ DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Batch");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateBatchPredictionJobVideoActionRecognitionSample() throws IOException {
+ // Act
+ String batchPredictionDisplayName =
+ String.format(
+ "batch_prediction_video_action_recognition_display_name_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateBatchPredictionJobVideoActionRecognitionSample
+ .createBatchPredictionJobVideoActionRecognitionSample(
+ PROJECT, batchPredictionDisplayName, MODEL_ID, GCS_SOURCE_URI, GCS_OUTPUT_URI);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(batchPredictionDisplayName);
+ assertThat(got).contains("response:");
+ batchPredictionJobId = got.split("Name: ")[1].split("batchPredictionJobs/")[1].split("\n")[0];
+ }
+}
diff --git a/samples/snippets/src/test/java/aiplatform/CreateDataLabelingJobActiveLearningSampleTest.java b/samples/snippets/src/test/java/aiplatform/CreateDataLabelingJobActiveLearningSampleTest.java
new file mode 100644
index 000000000..5280476f3
--- /dev/null
+++ b/samples/snippets/src/test/java/aiplatform/CreateDataLabelingJobActiveLearningSampleTest.java
@@ -0,0 +1,115 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Ignore;
+import org.junit.Test;
+
+public class CreateDataLabelingJobActiveLearningSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String DATASET_ID =
+ System.getenv("DATA_LABELING_ACTIVE_LEARNING_DATASET_ID");
+ private static final String INSTRUCTION_URI =
+ "gs://ucaip-sample-resources/images/datalabeling_instructions.pdf";
+ private static final String INPUTS_SCHEMA_URI =
+ "gs://google-cloud-aiplatform/schema/datalabelingjob/inputs/image_classification_1.0.0.yaml";
+ private static final String ANNOTATION_SPEC = "roses";
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String dataLabelingJobId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("DATA_LABELING_ACTIVE_LEARNING_DATASET_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Cancel data labeling job
+ CancelDataLabelingJobSample.cancelDataLabelingJob(PROJECT, dataLabelingJobId);
+
+ // Assert
+ String cancelResponse = bout.toString();
+ assertThat(cancelResponse).contains("Cancelled Data labeling job");
+ TimeUnit.MINUTES.sleep(1);
+
+ // Delete the created dataset
+ DeleteDataLabelingJobSample.deleteDataLabelingJob(PROJECT, dataLabelingJobId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Data Labeling Job.");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ @Ignore("Avoid creating actual data labeling job for humans")
+ public void testCreateDataLabelingJobActiveLearningSample() throws IOException {
+ // Act
+ String dataLabelingDisplayName =
+ String.format(
+ "temp_data_labeling_job_active_learning_display_name_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateDataLabelingJobActiveLearningSample.createDataLabelingJobActiveLearningSample(
+ PROJECT,
+ dataLabelingDisplayName,
+ DATASET_ID,
+ INSTRUCTION_URI,
+ INPUTS_SCHEMA_URI,
+ ANNOTATION_SPEC);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(dataLabelingDisplayName);
+ assertThat(got).contains("Create Data Labeling Job Image Response");
+ dataLabelingJobId = got.split("Name: ")[1].split("dataLabelingJobs/")[1].split("\n")[0];
+ }
+}
diff --git a/samples/snippets/src/test/java/aiplatform/CreateDataLabelingJobSpecialistPoolSampleTest.java b/samples/snippets/src/test/java/aiplatform/CreateDataLabelingJobSpecialistPoolSampleTest.java
new file mode 100644
index 000000000..7c41c5d84
--- /dev/null
+++ b/samples/snippets/src/test/java/aiplatform/CreateDataLabelingJobSpecialistPoolSampleTest.java
@@ -0,0 +1,118 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Ignore;
+import org.junit.Test;
+
+public class CreateDataLabelingJobSpecialistPoolSampleTest {
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String DATASET_ID =
+ System.getenv("DATA_LABELING_ACTIVE_LEARNING_DATASET_ID");
+ private static final String SPECIALIST_POOL_ID =
+ System.getenv("DATA_LABELING_SPECIALIST_POOL_ID");
+ private static final String INSTRUCTION_URI =
+ "gs://ucaip-sample-resources/images/datalabeling_instructions.pdf";
+ private static final String INPUTS_SCHEMA_URI =
+ "gs://google-cloud-aiplatform/schema/datalabelingjob/inputs/image_classification_1.0.0.yaml";
+ private static final String ANNOTATION_SPEC = "roses";
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String dataLabelingJobId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("DATA_LABELING_ACTIVE_LEARNING_DATASET_ID");
+ requireEnvVar("DATA_LABELING_SPECIALIST_POOL_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Cancel data labeling job
+ CancelDataLabelingJobSample.cancelDataLabelingJob(PROJECT, dataLabelingJobId);
+
+ // Assert
+ String cancelResponse = bout.toString();
+ assertThat(cancelResponse).contains("Cancelled Data labeling job");
+ TimeUnit.MINUTES.sleep(1);
+
+ // Delete the created dataset
+ DeleteDataLabelingJobSample.deleteDataLabelingJob(PROJECT, dataLabelingJobId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Data Labeling Job.");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ @Ignore("Avoid creating actual data labeling job for humans")
+ public void testCreateDataLabelingJobSpecialistPoolSample() throws IOException {
+ // Act
+ String dataLabelingDisplayName =
+ String.format(
+ "temp_data_labeling_job_specialist_pool_display_name_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateDataLabelingJobSpecialistPoolSample.createDataLabelingJobSpecialistPoolSample(
+ PROJECT,
+ dataLabelingDisplayName,
+ DATASET_ID,
+ SPECIALIST_POOL_ID,
+ INSTRUCTION_URI,
+ INPUTS_SCHEMA_URI,
+ ANNOTATION_SPEC);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(dataLabelingDisplayName);
+ assertThat(got).contains("Create Data Labeling Job Image Response");
+ dataLabelingJobId = got.split("Name: ")[1].split("dataLabelingJobs/")[1].split("\n")[0];
+ }
+}
diff --git a/samples/snippets/src/test/java/aiplatform/CreateDatasetTabularGcsSampleTest.java b/samples/snippets/src/test/java/aiplatform/CreateDatasetTabularGcsSampleTest.java
index 10a26a5e1..7592b9aae 100644
--- a/samples/snippets/src/test/java/aiplatform/CreateDatasetTabularGcsSampleTest.java
+++ b/samples/snippets/src/test/java/aiplatform/CreateDatasetTabularGcsSampleTest.java
@@ -81,8 +81,8 @@ public void testCreateDatasetTabularGcsSample()
"temp_create_dataset_table_gcs_test_%s",
UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
- CreateDatasetTabularGcsSample.createDatasetTableGcs(PROJECT,
- datasetDisplayName, GCS_SOURCE_URI);
+ CreateDatasetTabularGcsSample.createDatasetTableGcs(
+ PROJECT, datasetDisplayName, GCS_SOURCE_URI);
// Assert
String got = bout.toString();
diff --git a/samples/snippets/src/test/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSampleTest.java b/samples/snippets/src/test/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSampleTest.java
new file mode 100644
index 000000000..486322ce4
--- /dev/null
+++ b/samples/snippets/src/test/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSampleTest.java
@@ -0,0 +1,115 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import com.google.cloud.aiplatform.v1beta1.JobServiceClient;
+import com.google.cloud.aiplatform.v1beta1.JobServiceSettings;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class CreateHyperparameterTuningJobPythonPackageSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String EXECUTOR_IMAGE_URI =
+ "us.gcr.io/cloud-aiplatform/training/tf-gpu.2-1:latest";
+ private static final String PACKAGE_URI =
+ "gs://ucaip-test-us-central1/training/pythonpackages/trainer.tar.bz2";
+ private static final String PYTHON_MODULE = "trainer.hptuning_trainer";
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String hyperparameterJobId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ JobServiceSettings settings =
+ JobServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ try (JobServiceClient client = JobServiceClient.create(settings)) {
+ // Cancel hyper parameter job
+ String hyperparameterJobName =
+ String.format(
+ "projects/%s/locations/us-central1/hyperparameterTuningJobs/%s",
+ PROJECT, hyperparameterJobId);
+ client.cancelHyperparameterTuningJob(hyperparameterJobName);
+
+ TimeUnit.MINUTES.sleep(1);
+
+ // Delete the created job
+ client.deleteHyperparameterTuningJobAsync(hyperparameterJobName);
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+ }
+
+ @Test
+ public void testCreateHyperparameterTuningJobPythonPackageSample() throws IOException {
+ String hyperparameterTuningJobDisplayName =
+ String.format(
+ "temp_hyperparameter_tuning_job_display_name_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+ CreateHyperparameterTuningJobPythonPackageSample
+ .createHyperparameterTuningJobPythonPackageSample(
+ PROJECT,
+ hyperparameterTuningJobDisplayName,
+ EXECUTOR_IMAGE_URI,
+ PACKAGE_URI,
+ PYTHON_MODULE);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(hyperparameterTuningJobDisplayName);
+ assertThat(got).contains("response:");
+ hyperparameterJobId =
+ got.split("Name: ")[1].split("hyperparameterTuningJobs/")[1].split("\n")[0];
+ }
+}
diff --git a/samples/snippets/src/test/java/aiplatform/CreateHyperparameterTuningJobSampleTest.java b/samples/snippets/src/test/java/aiplatform/CreateHyperparameterTuningJobSampleTest.java
new file mode 100644
index 000000000..be9273c6e
--- /dev/null
+++ b/samples/snippets/src/test/java/aiplatform/CreateHyperparameterTuningJobSampleTest.java
@@ -0,0 +1,105 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import com.google.cloud.aiplatform.v1beta1.JobServiceClient;
+import com.google.cloud.aiplatform.v1beta1.JobServiceSettings;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class CreateHyperparameterTuningJobSampleTest {
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String CONTAINER_IMAGE_URI = "gcr.io/ucaip-test/ucaip-training-test:latest";
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String hyperparameterJobId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ JobServiceSettings settings =
+ JobServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ try (JobServiceClient client = JobServiceClient.create(settings)) {
+ // Cancel hyper parameter job
+ String hyperparameterJobName =
+ String.format(
+ "projects/%s/locations/us-central1/hyperparameterTuningJobs/%s",
+ PROJECT, hyperparameterJobId);
+ client.cancelHyperparameterTuningJob(hyperparameterJobName);
+
+ TimeUnit.MINUTES.sleep(1);
+
+ // Delete the created job
+ client.deleteHyperparameterTuningJobAsync(hyperparameterJobName);
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+ }
+
+ @Test
+ public void testCreateHyperparameterTuningJobSample() throws IOException {
+ String hyperparameterTuningJobDisplayName =
+ String.format(
+ "temp_hyperparameter_tuning_job_display_name_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateHyperparameterTuningJobSample.createHyperparameterTuningJobSample(
+ PROJECT, hyperparameterTuningJobDisplayName, CONTAINER_IMAGE_URI);
+
+ String got = bout.toString();
+ assertThat(got).contains(hyperparameterTuningJobDisplayName);
+ assertThat(got).contains("response:");
+ hyperparameterJobId =
+ got.split("Name: ")[1].split("hyperparameterTuningJobs/")[1].split("\n")[0];
+ }
+}
diff --git a/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineCustomJobSampleTest.java b/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineCustomJobSampleTest.java
new file mode 100644
index 000000000..ac4a97bd9
--- /dev/null
+++ b/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineCustomJobSampleTest.java
@@ -0,0 +1,112 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class CreateTrainingPipelineCustomJobSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String CONTAINER_IMAGE_URI =
+ "gcr.io/ucaip-sample-tests/mnist-custom-job:latest";
+ private static final String GCS_OUTPUT_DIRECTORY =
+ "gs://ucaip-samples-us-central1/training_pipeline_output";
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String trainingPipelineId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Cancel the Training Pipeline
+ CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String cancelResponse = bout.toString();
+ assertThat(cancelResponse).contains("Cancelled the Training Pipeline");
+ TimeUnit.MINUTES.sleep(2);
+
+ // Delete the Training Pipeline
+ DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Training Pipeline.");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateTrainingPipelineCustomJobSample() throws IOException {
+ // Act
+ String trainingPipelineDisplayName =
+ String.format(
+ "temp_create_training_pipeline_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ String modelDisplayName =
+ String.format(
+ "temp_create_training_pipeline_model_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateTrainingPipelineCustomJobSample.createTrainingPipelineCustomJobSample(
+ PROJECT,
+ trainingPipelineDisplayName,
+ modelDisplayName,
+ CONTAINER_IMAGE_URI,
+ GCS_OUTPUT_DIRECTORY);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(trainingPipelineDisplayName);
+ trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0];
+ }
+}
diff --git a/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSampleTest.java b/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSampleTest.java
new file mode 100644
index 000000000..9065f8191
--- /dev/null
+++ b/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSampleTest.java
@@ -0,0 +1,120 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class CreateTrainingPipelineCustomTrainingManagedDatasetSampleTest {
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String DATASET_ID = System.getenv("CUSTOM_MANAGED_DATASET");
+ private static final String ANNOTATION_SCHEMA_URI =
+ "gs://google-cloud-aiplatform/schema/dataset/annotation/image_classification_1.0.0.yaml";
+ private static final String TRAINING_CONTAINER_IMAGE_URI =
+ "gcr.io/ucaip-test/custom-container-managed-dataset:latest";
+ private static final String MODEL_CONTAIN_SPEC_IMAGE_URI =
+ "gcr.io/cloud-aiplatform/prediction/tf-gpu.1-15:latest";
+ private static final String GCS_OUTPUT_DIRECTORY =
+ "gs://ucaip-samples-us-central1/training_pipeline_output/custom_training_managed_dataset";
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String trainingPipelineId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Cancel the Training Pipeline
+ CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String cancelResponse = bout.toString();
+ assertThat(cancelResponse).contains("Cancelled the Training Pipeline");
+ TimeUnit.MINUTES.sleep(2);
+
+ // Delete the Training Pipeline
+ DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Training Pipeline.");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateTrainingPipelineCustomTrainingManagedDatasetSample() throws IOException {
+ // Act
+ String trainingPipelineDisplayName =
+ String.format(
+ "temp_create_training_pipeline_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ String modelDisplayName =
+ String.format(
+ "temp_create_training_pipeline_model_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateTrainingPipelineCustomTrainingManagedDatasetSample
+ .createTrainingPipelineCustomTrainingManagedDatasetSample(
+ PROJECT,
+ trainingPipelineDisplayName,
+ modelDisplayName,
+ DATASET_ID,
+ ANNOTATION_SCHEMA_URI,
+ TRAINING_CONTAINER_IMAGE_URI,
+ MODEL_CONTAIN_SPEC_IMAGE_URI,
+ GCS_OUTPUT_DIRECTORY);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(trainingPipelineDisplayName);
+ trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0];
+ }
+}
diff --git a/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularClassificationSampleTest.java b/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularClassificationSampleTest.java
index 65f7d041b..f8bc4769a 100644
--- a/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularClassificationSampleTest.java
+++ b/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularClassificationSampleTest.java
@@ -39,21 +39,21 @@ public class CreateTrainingPipelineTabularClassificationSampleTest {
private static final String TARGET_COLUMN = "TripType";
private static final String TRANSFORMATION =
"[{\"numeric\":{\"columnName\":\"Age\",\"invalidValuesAllowed\":false}},"
- + "{\"categorical\":{\"columnName\":\"Job\"}},"
- + "{\"categorical\":{\"columnName\":\"MaritalStatus\"}},"
- + "{\"categorical\":{\"columnName\":\"Default\"}},"
- + "{\"numeric\":{\"columnName\":\"Balance\",\"invalidValuesAllowed\":false}},"
- + "{\"categorical\":{\"columnName\":\"Housing\"}},"
- + "{\"categorical\":{\"columnName\":\"Loan\"}},"
- + "{\"categorical\":{\"columnName\":\"Contact\"}},"
- + "{\"numeric\":{\"columnName\":\"Day\",\"invalidValuesAllowed\":false}},"
- + "{\"categorical\":{\"columnName\":\"Month\"}},"
- + "{\"numeric\":{\"columnName\":\"Duration\",\"invalidValuesAllowed\":false}},"
- + "{\"numeric\":{\"columnName\":\"Campaign\",\"invalidValuesAllowed\":false}},"
- + "{\"numeric\":{\"columnName\":\"PDays\",\"invalidValuesAllowed\":false}},"
- + "{\"numeric\":{\"columnName\":\"Previous\",\"invalidValuesAllowed\":false}},"
- + "{\"categorical\":{\"columnName\":\"POutcome\"}},"
- + "{\"categorical\":{\"columnName\":\"Deposit\"}}]";
+ + "{\"categorical\":{\"columnName\":\"Job\"}},"
+ + "{\"categorical\":{\"columnName\":\"MaritalStatus\"}},"
+ + "{\"categorical\":{\"columnName\":\"Default\"}},"
+ + "{\"numeric\":{\"columnName\":\"Balance\",\"invalidValuesAllowed\":false}},"
+ + "{\"categorical\":{\"columnName\":\"Housing\"}},"
+ + "{\"categorical\":{\"columnName\":\"Loan\"}},"
+ + "{\"categorical\":{\"columnName\":\"Contact\"}},"
+ + "{\"numeric\":{\"columnName\":\"Day\",\"invalidValuesAllowed\":false}},"
+ + "{\"categorical\":{\"columnName\":\"Month\"}},"
+ + "{\"numeric\":{\"columnName\":\"Duration\",\"invalidValuesAllowed\":false}},"
+ + "{\"numeric\":{\"columnName\":\"Campaign\",\"invalidValuesAllowed\":false}},"
+ + "{\"numeric\":{\"columnName\":\"PDays\",\"invalidValuesAllowed\":false}},"
+ + "{\"numeric\":{\"columnName\":\"Previous\",\"invalidValuesAllowed\":false}},"
+ + "{\"categorical\":{\"columnName\":\"POutcome\"}},"
+ + "{\"categorical\":{\"columnName\":\"Deposit\"}}]";
private ByteArrayOutputStream bout;
private PrintStream out;
private PrintStream originalPrintStream;
diff --git a/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularRegressionSampleTest.java b/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularRegressionSampleTest.java
index 3933106f3..3298e0eba 100644
--- a/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularRegressionSampleTest.java
+++ b/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularRegressionSampleTest.java
@@ -35,23 +35,23 @@ public class CreateTrainingPipelineTabularRegressionSampleTest {
private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
private static final String DATASET_ID =
- System.getenv("TRAINING_PIPELINE_TABLES_REGRESSION_DATASET_ID");
+ System.getenv("TRAINING_PIPELINE_TABLES_REGRESSION_DATASET_ID");
private static final String TARGET_COLUMN = "Amount";
private static final String TRANSFORMATION =
"[{\"categorical\":{\"columnName\":\"SC_Group_Desc\"}},"
- + "{\"categorical\":{\"columnName\":\"SC_GroupCommod_ID\"}},"
- + "{\"categorical\":{\"columnName\":\"SC_GroupCommod_Desc\"}},"
- + "{\"numeric\":{\"columnName\":\"SortOrder\",\"invalidValuesAllowed\":false}},"
- + "{\"text\":{\"columnName\":\"SC_GeographyIndented_Desc\"}},"
- + "{\"numeric\":{\"columnName\":\"SC_Commodity_ID\",\"invalidValuesAllowed\":false}},"
- + "{\"text\":{\"columnName\":\"SC_Commodity_Desc\"}},"
- + "{\"numeric\":{\"columnName\":\"SC_Attribute_ID\",\"invalidValuesAllowed\":false}},"
- + "{\"text\":{\"columnName\":\"SC_Attribute_Desc\"}},"
- + "{\"numeric\":{\"columnName\":\"SC_Unit_ID\",\"invalidValuesAllowed\":false}},"
- + "{\"numeric\":{\"columnName\":\"Year_ID\",\"invalidValuesAllowed\":false}},"
- + "{\"categorical\":{\"columnName\":\"SC_Frequency_Desc\"}},"
- + "{\"numeric\":{\"columnName\":\"Timeperiod_ID\",\"invalidValuesAllowed\":false}},"
- + "{\"text\":{\"columnName\":\"Timeperiod_Desc\"}}]";
+ + "{\"categorical\":{\"columnName\":\"SC_GroupCommod_ID\"}},"
+ + "{\"categorical\":{\"columnName\":\"SC_GroupCommod_Desc\"}},"
+ + "{\"numeric\":{\"columnName\":\"SortOrder\",\"invalidValuesAllowed\":false}},"
+ + "{\"text\":{\"columnName\":\"SC_GeographyIndented_Desc\"}},"
+ + "{\"numeric\":{\"columnName\":\"SC_Commodity_ID\",\"invalidValuesAllowed\":false}},"
+ + "{\"text\":{\"columnName\":\"SC_Commodity_Desc\"}},"
+ + "{\"numeric\":{\"columnName\":\"SC_Attribute_ID\",\"invalidValuesAllowed\":false}},"
+ + "{\"text\":{\"columnName\":\"SC_Attribute_Desc\"}},"
+ + "{\"numeric\":{\"columnName\":\"SC_Unit_ID\",\"invalidValuesAllowed\":false}},"
+ + "{\"numeric\":{\"columnName\":\"Year_ID\",\"invalidValuesAllowed\":false}},"
+ + "{\"categorical\":{\"columnName\":\"SC_Frequency_Desc\"}},"
+ + "{\"numeric\":{\"columnName\":\"Timeperiod_ID\",\"invalidValuesAllowed\":false}},"
+ + "{\"text\":{\"columnName\":\"Timeperiod_Desc\"}}]";
private ByteArrayOutputStream bout;
private PrintStream out;
private PrintStream originalPrintStream;
diff --git a/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSampleTest.java b/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSampleTest.java
index bd0f29461..f4e99b2a5 100644
--- a/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSampleTest.java
+++ b/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSampleTest.java
@@ -90,14 +90,10 @@ public void testCreateTrainingPipelineTextSentimentAnalysisSample() throws IOExc
String tempUuid = UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26);
// Act
String trainingPipelineDisplayName =
- String.format(
- "temp_create_training_pipeline_test_%s",
- tempUuid);
+ String.format("temp_create_training_pipeline_test_%s", tempUuid);
String modelDisplayName =
- String.format(
- "temp_create_training_pipeline_model_test_%s",
- tempUuid);
+ String.format("temp_create_training_pipeline_model_test_%s", tempUuid);
CreateTrainingPipelineTextSentimentAnalysisSample
.createTrainingPipelineTextSentimentAnalysisSample(
diff --git a/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSampleTest.java b/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSampleTest.java
new file mode 100644
index 000000000..3e39be5f8
--- /dev/null
+++ b/samples/snippets/src/test/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSampleTest.java
@@ -0,0 +1,109 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class CreateTrainingPipelineVideoActionRecognitionSampleTest {
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String DATASET_ID =
+ System.getenv("TRAINING_PIPELINE_VIDEO_ACTION_DATASET_ID");
+ private static final String MODEL = "CLOUD";
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String trainingPipelineId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("TRAINING_PIPELINE_VIDEO_ACTION_DATASET_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Cancel the Training Pipeline
+ CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String cancelResponse = bout.toString();
+ assertThat(cancelResponse).contains("Cancelled the Training Pipeline");
+ TimeUnit.MINUTES.sleep(2);
+
+ // Delete the Training Pipeline
+ DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Training Pipeline.");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateTrainingPipelineVideoActionRecognitionSample() throws IOException {
+ // Act
+ String trainingPipelineDisplayName =
+ String.format(
+ "temp_create_training_pipeline_video_action_recognition_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ String modelDisplayName =
+ String.format(
+ "temp_create_training_pipeline_video_action_recognition_model_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateTrainingPipelineVideoActionRecognitionSample
+ .createTrainingPipelineVideoActionRecognitionSample(
+ PROJECT, trainingPipelineDisplayName, DATASET_ID, modelDisplayName, MODEL);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(DATASET_ID);
+ assertThat(got).contains("response");
+ trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0];
+ }
+}
diff --git a/samples/snippets/src/test/java/aiplatform/DeployModelCustomTrainedModelSampleTest.java b/samples/snippets/src/test/java/aiplatform/DeployModelCustomTrainedModelSampleTest.java
new file mode 100644
index 000000000..798e9115a
--- /dev/null
+++ b/samples/snippets/src/test/java/aiplatform/DeployModelCustomTrainedModelSampleTest.java
@@ -0,0 +1,96 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import io.grpc.StatusRuntimeException;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Ignore;
+import org.junit.Test;
+
+public class DeployModelCustomTrainedModelSampleTest {
+
+ private static final String PROJECT_ID = "ucaip-sample-tests";
+ private static final String MODEL_ID = "4992732768149438464";
+ private static final String ENDPOINT_ID = "4366591682456584192";
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+
+ // Undeploy the model
+ try {
+ UndeployModelSample.undeployModelSample(PROJECT_ID, ENDPOINT_ID, MODEL_ID);
+ } catch (IOException | InterruptedException | ExecutionException | TimeoutException e) {
+ e.printStackTrace();
+ }
+ }
+
+ @Ignore("Issues with undeploy")
+ @Test
+ public void testDeployModelCustomTrainedModelSample() throws TimeoutException {
+ // As model deployment can take a long time, instead try to deploy a
+ // nonexistent model and confirm that the model was not found, but other
+ // elements of the request were valid.
+ String deployedModelDisplayName =
+ String.format(
+ "temp_deploy_model_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+ try {
+ DeployModelCustomTrainedModelSample.deployModelCustomTrainedModelSample(
+ PROJECT_ID, ENDPOINT_ID, MODEL_ID, deployedModelDisplayName);
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains("deployModelResponse");
+ } catch (StatusRuntimeException | ExecutionException | InterruptedException | IOException e) {
+ assertThat(e.getMessage()).contains("is not found.");
+ }
+ }
+}
diff --git a/samples/snippets/src/test/java/aiplatform/DeployModelSampleTest.java b/samples/snippets/src/test/java/aiplatform/DeployModelSampleTest.java
index 85c258319..766944fb7 100644
--- a/samples/snippets/src/test/java/aiplatform/DeployModelSampleTest.java
+++ b/samples/snippets/src/test/java/aiplatform/DeployModelSampleTest.java
@@ -25,7 +25,6 @@
import java.io.PrintStream;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
-import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.junit.After;
import org.junit.Before;
@@ -66,8 +65,7 @@ public void tearDown() {
}
@Test
- public void testDeployModelSample()
- throws TimeoutException {
+ public void testDeployModelSample() throws TimeoutException {
// As model deployment can take a long time, instead try to deploy a
// nonexistent model and confirm that the model was not found, but other
// elements of the request were valid.
@@ -76,8 +74,8 @@ public void testDeployModelSample()
"temp_deploy_model_test_%s",
UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
try {
- DeployModelSample.deployModelSample(PROJECT_ID, deployedModelDisplayName,
- "4366591682456584192", MODEL_ID);
+ DeployModelSample.deployModelSample(
+ PROJECT_ID, deployedModelDisplayName, "4366591682456584192", MODEL_ID);
// Assert
String got = bout.toString();
assertThat(got).contains("is not found.");
diff --git a/samples/snippets/src/test/java/aiplatform/ExportModelVideoActionRecognitionSampleTest.java b/samples/snippets/src/test/java/aiplatform/ExportModelVideoActionRecognitionSampleTest.java
new file mode 100644
index 000000000..57cfdb38b
--- /dev/null
+++ b/samples/snippets/src/test/java/aiplatform/ExportModelVideoActionRecognitionSampleTest.java
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class ExportModelVideoActionRecognitionSampleTest {
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String MODEL_ID =
+ System.getenv("EXPORT_MODEL_VIDEO_ACTION_RECOGNITION_MODEL_ID");
+ private static final String GCS_DESTINATION_URI_PREFIX =
+ "gs://ucaip-samples-test-output/tmp/export_model_video_action_recognition_sample";
+ private static final String EXPORT_FORMAT = "tf-saved-model";
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("EXPORT_MODEL_VIDEO_ACTION_RECOGNITION_MODEL_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ // Delete the export model
+ String bucketName = GCS_DESTINATION_URI_PREFIX.split("/", 4)[2];
+ String objectName = (GCS_DESTINATION_URI_PREFIX.split("/", 4)[3]).concat("model-" + MODEL_ID);
+ DeleteExportModelSample.deleteExportModelSample(PROJECT, bucketName, objectName);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Export Model Deleted");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testExportModelVideoActionRecognitionSample()
+ throws InterruptedException, ExecutionException, TimeoutException, IOException {
+ // Act
+ ExportModelVideoActionRecognitionSample.exportModelVideoActionRecognitionSample(
+ PROJECT, MODEL_ID, GCS_DESTINATION_URI_PREFIX, EXPORT_FORMAT);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains("exportModelResponse: ");
+ }
+}
diff --git a/samples/snippets/src/test/java/aiplatform/GetHyperparameterTuningJobSampleTest.java b/samples/snippets/src/test/java/aiplatform/GetHyperparameterTuningJobSampleTest.java
new file mode 100644
index 000000000..685768000
--- /dev/null
+++ b/samples/snippets/src/test/java/aiplatform/GetHyperparameterTuningJobSampleTest.java
@@ -0,0 +1,73 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class GetHyperparameterTuningJobSampleTest {
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String HYPERPARAMETER_TUNING_JOB_ID = System.getenv("GET_HP_TUNING_JOB_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("GET_HP_TUNING_JOB_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testGetHyperparameterTuningJobSample() throws IOException {
+ GetHyperparameterTuningJobSample.getHyperparameterTuningJobSample(
+ PROJECT, HYPERPARAMETER_TUNING_JOB_ID);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(HYPERPARAMETER_TUNING_JOB_ID);
+ }
+}
diff --git a/samples/snippets/src/test/java/aiplatform/GetModelEvaluationVideoActionRecognitionSampleTest.java b/samples/snippets/src/test/java/aiplatform/GetModelEvaluationVideoActionRecognitionSampleTest.java
new file mode 100644
index 000000000..549f7172c
--- /dev/null
+++ b/samples/snippets/src/test/java/aiplatform/GetModelEvaluationVideoActionRecognitionSampleTest.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class GetModelEvaluationVideoActionRecognitionSampleTest {
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String MODEL_ID = System.getenv("VIDEO_ACTION_MODEL_ID");
+ private static final String EVALUATION_ID = System.getenv("VIDEO_ACTION_EVALUATION_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("VIDEO_ACTION_MODEL_ID");
+ requireEnvVar("VIDEO_ACTION_EVALUATION_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testGetModelEvaluationVideoActionRecognitionSample() throws IOException {
+ // Act
+ GetModelEvaluationVideoActionRecognitionSample.getModelEvaluationVideoActionRecognitionSample(
+ PROJECT, MODEL_ID, EVALUATION_ID);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(MODEL_ID);
+ assertThat(got).contains("response:");
+ }
+}
diff --git a/samples/snippets/src/test/java/aiplatform/ImportDataSampleTest.java b/samples/snippets/src/test/java/aiplatform/ImportDataSampleTest.java
index 72d5ad90f..dcfaceb9f 100644
--- a/samples/snippets/src/test/java/aiplatform/ImportDataSampleTest.java
+++ b/samples/snippets/src/test/java/aiplatform/ImportDataSampleTest.java
@@ -72,14 +72,13 @@ public void tearDown() {
}
@Test
- public void testImportDataSample()
- throws TimeoutException {
+ public void testImportDataSample() throws TimeoutException {
// As import data into dataset can take a long time, instead try to import data into a
// nonexistent dataset and confirm that the model was not found, but other
// elements of the request were valid.
try {
ImportDataTextClassificationSingleLabelSample.importDataTextClassificationSingleLabelSample(
- PROJECT, DATASET_ID, GCS_SOURCE_URI);
+ PROJECT, DATASET_ID, GCS_SOURCE_URI);
// Assert
String got = bout.toString();
assertThat(got).contains("The Dataset does not exist.");
diff --git a/samples/snippets/src/test/java/aiplatform/ImportDataVideoActionRecognitionSampleTest.java b/samples/snippets/src/test/java/aiplatform/ImportDataVideoActionRecognitionSampleTest.java
new file mode 100644
index 000000000..dd77e61b7
--- /dev/null
+++ b/samples/snippets/src/test/java/aiplatform/ImportDataVideoActionRecognitionSampleTest.java
@@ -0,0 +1,128 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.Dataset;
+import com.google.cloud.aiplatform.v1beta1.DatasetName;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.protobuf.Empty;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class ImportDataVideoActionRecognitionSampleTest {
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String LOCATION = "us-central1";
+ private static final String GCS_SOURCE_URI =
+ "gs://automl-video-demo-data/ucaip-var/swimrun.jsonl";
+
+ private String datasetId;
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ }
+
+ @Before
+ public void setUp()
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+
+ // create a temp dataset for importing data
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ String metadataSchemaUri =
+ "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml";
+ LocationName locationName = LocationName.of(PROJECT, LOCATION);
+ Dataset dataset =
+ Dataset.newBuilder()
+ .setDisplayName("test_dataset_display_name")
+ .setMetadataSchemaUri(metadataSchemaUri)
+ .build();
+
+ OperationFuture datasetFuture =
+ datasetServiceClient.createDatasetAsync(locationName, dataset);
+ Dataset datasetResponse = datasetFuture.get(300, TimeUnit.SECONDS);
+ String[] datasetValues = datasetResponse.getName().split("/");
+ datasetId = datasetValues[datasetValues.length - 1];
+ }
+ }
+
+ @After
+ public void tearDown() throws InterruptedException, ExecutionException, IOException {
+ // delete the temp dataset
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ DatasetName datasetName = DatasetName.of(PROJECT, LOCATION, datasetId);
+
+ OperationFuture operationFuture =
+ datasetServiceClient.deleteDatasetAsync(datasetName);
+ operationFuture.get();
+ }
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testImportDataVideoActionRecognitionSample()
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // Act
+ ImportDataVideoActionRecognitionSample.importDataVideoActionRecognitionSample(
+ PROJECT, datasetId, GCS_SOURCE_URI);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains("importDataResponse:");
+ }
+}
diff --git a/samples/snippets/src/test/java/aiplatform/PredictTextEntityExtractionSampleTest.java b/samples/snippets/src/test/java/aiplatform/PredictTextEntityExtractionSampleTest.java
index dde83fa06..71fd8e8ba 100644
--- a/samples/snippets/src/test/java/aiplatform/PredictTextEntityExtractionSampleTest.java
+++ b/samples/snippets/src/test/java/aiplatform/PredictTextEntityExtractionSampleTest.java
@@ -86,4 +86,3 @@ public void testPredictTextEntityExtraction() throws IOException {
assertThat(got).contains("Predict Text Entity Extraction Response");
}
}
-
diff --git a/samples/snippets/src/test/java/aiplatform/PredictTextSentimentAnalysisSampleTest.java b/samples/snippets/src/test/java/aiplatform/PredictTextSentimentAnalysisSampleTest.java
index 1189e3917..d452dc945 100644
--- a/samples/snippets/src/test/java/aiplatform/PredictTextSentimentAnalysisSampleTest.java
+++ b/samples/snippets/src/test/java/aiplatform/PredictTextSentimentAnalysisSampleTest.java
@@ -87,4 +87,3 @@ public void testPredictTextSentimentAnalysis() throws IOException {
assertThat(got).contains("Predict Text Sentiment Analysis Response");
}
}
-
diff --git a/samples/snippets/src/test/java/aiplatform/UploadModelSampleTest.java b/samples/snippets/src/test/java/aiplatform/UploadModelSampleTest.java
index 3ebae3b0d..c085f8c67 100644
--- a/samples/snippets/src/test/java/aiplatform/UploadModelSampleTest.java
+++ b/samples/snippets/src/test/java/aiplatform/UploadModelSampleTest.java
@@ -37,8 +37,8 @@ public class UploadModelSampleTest {
private static final String METADATASCHEMA_URI = "";
private static final String IMAGE_URI =
"gcr.io/cloud-ml-service-public/"
- + "cloud-ml-online-prediction-model-server-cpu:"
- + "v1_15py3cmle_op_images_20200229_0210_RC00";
+ + "cloud-ml-online-prediction-model-server-cpu:"
+ + "v1_15py3cmle_op_images_20200229_0210_RC00";
private static final String ARTIFACT_URI = "gs://ucaip-samples-us-central1/model/explain/";
private ByteArrayOutputStream bout;
private PrintStream out;