From d6602ce7afb8595f8a8bd7fd8d8ffa48dcb29bad Mon Sep 17 00:00:00 2001 From: Eric Schmidt Date: Wed, 16 Dec 2020 15:15:58 -0800 Subject: [PATCH] samples: adds custom model, action recognition samples and tests (#111) * samples: adds custom mode, action recognition samples and tests * fix: refactor to use Gson --- aiplatform/snippets/pom.xml | 5 + ...reateBatchPredictionJobBigquerySample.java | 109 +++++++++++ .../CreateBatchPredictionJobSample.java | 125 +++++++++++++ ...ictionJobVideoActionRecognitionSample.java | 99 ++++++++++ ...teDataLabelingJobActiveLearningSample.java | 97 ++++++++++ .../CreateDataLabelingJobImageSample.java | 2 +- ...teDataLabelingJobSpecialistPoolSample.java | 104 +++++++++++ .../CreateDataLabelingJobVideoSample.java | 2 +- ...parameterTuningJobPythonPackageSample.java | 174 ++++++++++++++++++ .../CreateHyperparameterTuningJobSample.java | 106 +++++++++++ ...CreateTrainingPipelineCustomJobSample.java | 119 ++++++++++++ ...ineCustomTrainingManagedDatasetSample.java | 145 +++++++++++++++ ...ningPipelineImageClassificationSample.java | 4 +- ...ingPipelineImageObjectDetectionSample.java | 4 +- .../CreateTrainingPipelineSample.java | 2 - ...ngPipelineTabularClassificationSample.java | 2 - ...ainingPipelineTabularRegressionSample.java | 2 - ...iningPipelineTextClassificationSample.java | 4 +- ...ingPipelineTextEntityExtractionSample.java | 4 +- ...ngPipelineTextSentimentAnalysisSample.java | 4 +- ...gPipelineVideoActionRecognitionSample.java | 87 +++++++++ ...ningPipelineVideoClassificationSample.java | 4 +- ...ningPipelineVideoObjectTrackingSample.java | 4 +- .../java/aiplatform/DeleteDatasetSample.java | 2 +- .../DeployModelCustomTrainedModelSample.java | 92 +++++++++ ...portModelVideoActionRecognitionSample.java | 79 ++++++++ .../GetHyperparameterTuningJobSample.java | 55 ++++++ ...valuationVideoActionRecognitionSample.java | 55 ++++++ .../ImportDataImageClassificationSample.java | 6 +- .../ImportDataImageObjectDetectionSample.java | 6 +- ...taTextClassificationSingleLabelSample.java | 6 +- .../ImportDataTextEntityExtractionSample.java | 6 +- ...ImportDataTextSentimentAnalysisSample.java | 6 +- ...mportDataVideoActionRecognitionSample.java | 82 +++++++++ .../ImportDataVideoClassificationSample.java | 5 +- .../ImportDataVideoObjectTrackingSample.java | 6 +- .../CancelTrainingPipelineSampleTest.java | 2 +- ...eBatchPredictionJobBigquerySampleTest.java | 109 +++++++++++ .../CreateBatchPredictionJobSampleTest.java | 109 +++++++++++ ...onJobVideoActionRecognitionSampleTest.java | 105 +++++++++++ ...taLabelingJobActiveLearningSampleTest.java | 115 ++++++++++++ ...taLabelingJobSpecialistPoolSampleTest.java | 118 ++++++++++++ .../CreateDatasetTabularGcsSampleTest.java | 4 +- ...meterTuningJobPythonPackageSampleTest.java | 115 ++++++++++++ ...eateHyperparameterTuningJobSampleTest.java | 105 +++++++++++ ...teTrainingPipelineCustomJobSampleTest.java | 112 +++++++++++ ...ustomTrainingManagedDatasetSampleTest.java | 120 ++++++++++++ ...pelineTabularClassificationSampleTest.java | 30 +-- ...ngPipelineTabularRegressionSampleTest.java | 28 +-- ...pelineTextSentimentAnalysisSampleTest.java | 8 +- ...elineVideoActionRecognitionSampleTest.java | 109 +++++++++++ ...ployModelCustomTrainedModelSampleTest.java | 96 ++++++++++ .../aiplatform/DeployModelSampleTest.java | 8 +- ...ModelVideoActionRecognitionSampleTest.java | 89 +++++++++ .../GetHyperparameterTuningJobSampleTest.java | 73 ++++++++ ...ationVideoActionRecognitionSampleTest.java | 77 ++++++++ .../java/aiplatform/ImportDataSampleTest.java | 5 +- ...tDataVideoActionRecognitionSampleTest.java | 128 +++++++++++++ ...PredictTextEntityExtractionSampleTest.java | 1 - ...redictTextSentimentAnalysisSampleTest.java | 1 - .../aiplatform/UploadModelSampleTest.java | 4 +- 61 files changed, 3184 insertions(+), 101 deletions(-) create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobBigquerySample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobActiveLearningSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobSpecialistPoolSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateHyperparameterTuningJobSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomJobSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/DeployModelCustomTrainedModelSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/ExportModelVideoActionRecognitionSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/GetHyperparameterTuningJobSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationVideoActionRecognitionSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoActionRecognitionSample.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobBigquerySampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateDataLabelingJobActiveLearningSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateDataLabelingJobSpecialistPoolSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateHyperparameterTuningJobSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineCustomJobSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/DeployModelCustomTrainedModelSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/ExportModelVideoActionRecognitionSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/GetHyperparameterTuningJobSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationVideoActionRecognitionSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/ImportDataVideoActionRecognitionSampleTest.java diff --git a/aiplatform/snippets/pom.xml b/aiplatform/snippets/pom.xml index 2b6f8215728..1fc37373fdf 100644 --- a/aiplatform/snippets/pom.xml +++ b/aiplatform/snippets/pom.xml @@ -40,6 +40,11 @@ protobuf-java-util 4.0.0-rc-2 + + com.google.code.gson + gson + 2.8.6 + junit junit diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobBigquerySample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobBigquerySample.java new file mode 100644 index 00000000000..5ccad051aaa --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobBigquerySample.java @@ -0,0 +1,109 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_batch_prediction_job_bigquery_sample] +import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1beta1.BigQueryDestination; +import com.google.cloud.aiplatform.v1beta1.BigQuerySource; +import com.google.cloud.aiplatform.v1beta1.JobServiceClient; +import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.ModelName; +import com.google.gson.JsonObject; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; + +public class CreateBatchPredictionJobBigquerySample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String modelName = "MODEL_NAME"; + String instancesFormat = "INSTANCES_FORMAT"; + String bigquerySourceInputUri = "BIGQUERY_SOURCE_INPUT_URI"; + String predictionsFormat = "PREDICTIONS_FORMAT"; + String bigqueryDestinationOutputUri = "BIGQUERY_DESTINATION_OUTPUT_URI"; + createBatchPredictionJobBigquerySample( + project, + displayName, + modelName, + instancesFormat, + bigquerySourceInputUri, + predictionsFormat, + bigqueryDestinationOutputUri); + } + + static void createBatchPredictionJobBigquerySample( + String project, + String displayName, + String model, + String instancesFormat, + String bigquerySourceInputUri, + String predictionsFormat, + String bigqueryDestinationOutputUri) + throws IOException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + JsonObject jsonModelParameters = new JsonObject(); + Value.Builder modelParametersBuilder = Value.newBuilder(); + JsonFormat.parser().merge(jsonModelParameters.toString(), modelParametersBuilder); + Value modelParameters = modelParametersBuilder.build(); + BigQuerySource bigquerySource = + BigQuerySource.newBuilder().setInputUri(bigquerySourceInputUri).build(); + BatchPredictionJob.InputConfig inputConfig = + BatchPredictionJob.InputConfig.newBuilder() + .setInstancesFormat(instancesFormat) + .setBigquerySource(bigquerySource) + .build(); + BigQueryDestination bigqueryDestination = + BigQueryDestination.newBuilder().setOutputUri(bigqueryDestinationOutputUri).build(); + BatchPredictionJob.OutputConfig outputConfig = + BatchPredictionJob.OutputConfig.newBuilder() + .setPredictionsFormat(predictionsFormat) + .setBigqueryDestination(bigqueryDestination) + .build(); + String modelName = ModelName.of(project, location, model).toString(); + BatchPredictionJob batchPredictionJob = + BatchPredictionJob.newBuilder() + .setDisplayName(displayName) + .setModel(modelName) + .setModelParameters(modelParameters) + .setInputConfig(inputConfig) + .setOutputConfig(outputConfig) + // optional + .setGenerateExplanation(true) + .build(); + LocationName parent = LocationName.of(project, location); + BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob); + System.out.format("response: %s\n", response); + System.out.format("\tName: %s\n", response.getName()); + } + } +} + +// [END aiplatform_create_batch_prediction_job_bigquery_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobSample.java new file mode 100644 index 00000000000..cdac97ba47e --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobSample.java @@ -0,0 +1,125 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_batch_prediction_job_sample] +import com.google.cloud.aiplatform.v1beta1.AcceleratorType; +import com.google.cloud.aiplatform.v1beta1.BatchDedicatedResources; +import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1beta1.GcsDestination; +import com.google.cloud.aiplatform.v1beta1.GcsSource; +import com.google.cloud.aiplatform.v1beta1.JobServiceClient; +import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.MachineSpec; +import com.google.cloud.aiplatform.v1beta1.ModelName; +import com.google.gson.JsonObject; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; + +public class CreateBatchPredictionJobSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String modelName = "MODEL_NAME"; + String instancesFormat = "INSTANCES_FORMAT"; + String gcsSourceUri = "GCS_SOURCE_URI"; + String predictionsFormat = "PREDICTIONS_FORMAT"; + String gcsDestinationOutputUriPrefix = "GCS_DESTINATION_OUTPUT_URI_PREFIX"; + createBatchPredictionJobSample( + project, + displayName, + modelName, + instancesFormat, + gcsSourceUri, + predictionsFormat, + gcsDestinationOutputUriPrefix); + } + + static void createBatchPredictionJobSample( + String project, + String displayName, + String model, + String instancesFormat, + String gcsSourceUri, + String predictionsFormat, + String gcsDestinationOutputUriPrefix) + throws IOException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + + // Passing in an empty Value object for model parameters + JsonObject jsonModelParameters = new JsonObject(); + Value.Builder modelParametersBuilder = Value.newBuilder(); + JsonFormat.parser().merge(jsonModelParameters.toString(), modelParametersBuilder); + Value modelParameters = modelParametersBuilder.build(); + + GcsSource gcsSource = GcsSource.newBuilder().addUris(gcsSourceUri).build(); + BatchPredictionJob.InputConfig inputConfig = + BatchPredictionJob.InputConfig.newBuilder() + .setInstancesFormat(instancesFormat) + .setGcsSource(gcsSource) + .build(); + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build(); + BatchPredictionJob.OutputConfig outputConfig = + BatchPredictionJob.OutputConfig.newBuilder() + .setPredictionsFormat(predictionsFormat) + .setGcsDestination(gcsDestination) + .build(); + MachineSpec machineSpec = + MachineSpec.newBuilder() + .setMachineType("n1-standard-2") + .setAcceleratorType(AcceleratorType.NVIDIA_TESLA_K80) + .setAcceleratorCount(1) + .build(); + BatchDedicatedResources dedicatedResources = + BatchDedicatedResources.newBuilder() + .setMachineSpec(machineSpec) + .setStartingReplicaCount(1) + .setMaxReplicaCount(1) + .build(); + String modelName = ModelName.of(project, location, model).toString(); + BatchPredictionJob batchPredictionJob = + BatchPredictionJob.newBuilder() + .setDisplayName(displayName) + .setModel(modelName) + .setModelParameters(modelParameters) + .setInputConfig(inputConfig) + .setOutputConfig(outputConfig) + .setDedicatedResources(dedicatedResources) + .build(); + LocationName parent = LocationName.of(project, location); + BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob); + System.out.format("response: %s\n", response); + System.out.format("\tName: %s\n", response.getName()); + } + } +} + +// [END aiplatform_create_batch_prediction_job_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSample.java new file mode 100644 index 00000000000..b255b625ccd --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSample.java @@ -0,0 +1,99 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_batch_prediction_job_video_action_recognition_sample] +import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1beta1.GcsDestination; +import com.google.cloud.aiplatform.v1beta1.GcsSource; +import com.google.cloud.aiplatform.v1beta1.JobServiceClient; +import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.ModelName; +import com.google.gson.JsonObject; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; + +public class CreateBatchPredictionJobVideoActionRecognitionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String model = "MODEL"; + String gcsSourceUri = "GCS_SOURCE_URI"; + String gcsDestinationOutputUriPrefix = "GCS_DESTINATION_OUTPUT_URI_PREFIX"; + createBatchPredictionJobVideoActionRecognitionSample( + project, displayName, model, gcsSourceUri, gcsDestinationOutputUriPrefix); + } + + static void createBatchPredictionJobVideoActionRecognitionSample( + String project, + String displayName, + String model, + String gcsSourceUri, + String gcsDestinationOutputUriPrefix) + throws IOException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + JsonObject jsonModelParameters = new JsonObject(); + jsonModelParameters.addProperty("confidenceThreshold", 0.5); + Value.Builder modelParametersBuilder = Value.newBuilder(); + JsonFormat.parser().merge(jsonModelParameters.toString(), modelParametersBuilder); + Value modelParameters = modelParametersBuilder.build(); + GcsSource gcsSource = GcsSource.newBuilder().addUris(gcsSourceUri).build(); + BatchPredictionJob.InputConfig inputConfig = + BatchPredictionJob.InputConfig.newBuilder() + .setInstancesFormat("jsonl") + .setGcsSource(gcsSource) + .build(); + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build(); + BatchPredictionJob.OutputConfig outputConfig = + BatchPredictionJob.OutputConfig.newBuilder() + .setPredictionsFormat("jsonl") + .setGcsDestination(gcsDestination) + .build(); + + String modelName = ModelName.of(project, location, model).toString(); + + BatchPredictionJob batchPredictionJob = + BatchPredictionJob.newBuilder() + .setDisplayName(displayName) + .setModel(modelName) + .setModelParameters(modelParameters) + .setInputConfig(inputConfig) + .setOutputConfig(outputConfig) + .build(); + LocationName parent = LocationName.of(project, location); + BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob); + System.out.format("response: %s\n", response); + System.out.format("\tName: %s\n", response.getName()); + } + } +} + +// [END aiplatform_create_batch_prediction_job_video_action_recognition_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobActiveLearningSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobActiveLearningSample.java new file mode 100644 index 00000000000..d9f069e408f --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobActiveLearningSample.java @@ -0,0 +1,97 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_data_labeling_job_active_learning_sample] +import com.google.cloud.aiplatform.v1beta1.ActiveLearningConfig; +import com.google.cloud.aiplatform.v1beta1.DataLabelingJob; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.JobServiceClient; +import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; + +public class CreateDataLabelingJobActiveLearningSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String dataset = "DATASET"; + String instructionUri = "INSTRUCTION_URI"; + String inputsSchemaUri = "INPUTS_SCHEMA_URI"; + String annotationSpec = "ANNOTATION_SPEC"; + createDataLabelingJobActiveLearningSample( + project, displayName, dataset, instructionUri, inputsSchemaUri, annotationSpec); + } + + static void createDataLabelingJobActiveLearningSample( + String project, + String displayName, + String dataset, + String instructionUri, + String inputsSchemaUri, + String annotationSpec) + throws IOException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + JsonArray jsonAnnotationSpecs = new JsonArray(); + jsonAnnotationSpecs.add(annotationSpec); + JsonObject jsonInputs = new JsonObject(); + jsonInputs.add("annotation_specs", jsonAnnotationSpecs); + Value.Builder inputsBuilder = Value.newBuilder(); + JsonFormat.parser().merge(jsonInputs.toString(), inputsBuilder); + Value inputs = inputsBuilder.build(); + ActiveLearningConfig activeLearningConfig = + ActiveLearningConfig.newBuilder().setMaxDataItemCount(1).build(); + + String datasetName = DatasetName.of(project, location, dataset).toString(); + + DataLabelingJob dataLabelingJob = + DataLabelingJob.newBuilder() + .setDisplayName(displayName) + .addDatasets(datasetName) + .setLabelerCount(1) + .setInstructionUri(instructionUri) + .setInputsSchemaUri(inputsSchemaUri) + .setInputs(inputs) + .putAnnotationLabels( + "aiplatform.googleapis.com/annotation_set_name", + "data_labeling_job_active_learning") + .setActiveLearningConfig(activeLearningConfig) + .build(); + LocationName parent = LocationName.of(project, location); + DataLabelingJob response = client.createDataLabelingJob(parent, dataLabelingJob); + System.out.format("response: %s\n", response); + System.out.format("Name: %s\n", response.getName()); + } + } +} + +// [END aiplatform_create_data_labeling_job_active_learning_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobImageSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobImageSample.java index ce81ae6b47f..5ea70a42fef 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobImageSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobImageSample.java @@ -72,7 +72,7 @@ static void createDataLabelingJobImage( .setInstructionUri(instructionUri) .setInputsSchemaUri( "gs://google-cloud-aiplatform/schema/datalabelingjob/inputs/" - + "image_classification.yaml") + + "image_classification.yaml") .addDatasets(datasetName.toString()) .setInputs(annotationSpecValue) .putAnnotationLabels( diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobSpecialistPoolSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobSpecialistPoolSample.java new file mode 100644 index 00000000000..04a3c421634 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobSpecialistPoolSample.java @@ -0,0 +1,104 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_data_labeling_job_specialist_pool_sample] +import com.google.cloud.aiplatform.v1beta1.DataLabelingJob; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.JobServiceClient; +import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.SpecialistPoolName; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; + +public class CreateDataLabelingJobSpecialistPoolSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String dataset = "DATASET"; + String specialistPool = "SPECIALIST_POOL"; + String instructionUri = "INSTRUCTION_URI"; + String inputsSchemaUri = "INPUTS_SCHEMA_URI"; + String annotationSpec = "ANNOTATION_SPEC"; + createDataLabelingJobSpecialistPoolSample( + project, + displayName, + dataset, + specialistPool, + instructionUri, + inputsSchemaUri, + annotationSpec); + } + + static void createDataLabelingJobSpecialistPoolSample( + String project, + String displayName, + String dataset, + String specialistPool, + String instructionUri, + String inputsSchemaUri, + String annotationSpec) + throws IOException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + JsonArray jsonAnnotationSpecs = new JsonArray(); + jsonAnnotationSpecs.add(annotationSpec); + JsonObject jsonInputs = new JsonObject(); + jsonInputs.add("annotation_specs", jsonAnnotationSpecs); + Value.Builder inputsBuilder = Value.newBuilder(); + JsonFormat.parser().merge(jsonInputs.toString(), inputsBuilder); + Value inputs = inputsBuilder.build(); + + String datasetName = DatasetName.of(project, location, dataset).toString(); + String specialistPoolName = SpecialistPoolName.of(project, location, specialistPool) + .toString(); + + DataLabelingJob dataLabelingJob = + DataLabelingJob.newBuilder() + .setDisplayName(displayName) + .addDatasets(datasetName) + .setLabelerCount(1) + .setInstructionUri(instructionUri) + .setInputsSchemaUri(inputsSchemaUri) + .setInputs(inputs) + .putAnnotationLabels( + "aiplatform.googleapis.com/annotation_set_name", + "data_labeling_job_specialist_pool") + .addSpecialistPools(specialistPoolName) + .build(); + LocationName parent = LocationName.of(project, location); + DataLabelingJob response = client.createDataLabelingJob(parent, dataLabelingJob); + System.out.format("response: %s\n", response); + } + } +} + +// [END aiplatform_create_data_labeling_job_specialist_pool_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobVideoSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobVideoSample.java index 5722c335e55..ae0e451ba52 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobVideoSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobVideoSample.java @@ -72,7 +72,7 @@ static void createDataLabelingJobVideo( .setInstructionUri(instructionUri) .setInputsSchemaUri( "gs://google-cloud-aiplatform/schema/datalabelingjob/inputs/" - + "video_classification.yaml") + + "video_classification.yaml") .addDatasets(datasetName.toString()) .setInputs(annotationSpecValue) .putAnnotationLabels( diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSample.java new file mode 100644 index 00000000000..9d1937f3da7 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSample.java @@ -0,0 +1,174 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_hyperparameter_tuning_job_python_package_sample] +import com.google.cloud.aiplatform.v1beta1.AcceleratorType; +import com.google.cloud.aiplatform.v1beta1.CustomJobSpec; +import com.google.cloud.aiplatform.v1beta1.HyperparameterTuningJob; +import com.google.cloud.aiplatform.v1beta1.JobServiceClient; +import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.MachineSpec; +import com.google.cloud.aiplatform.v1beta1.PythonPackageSpec; +import com.google.cloud.aiplatform.v1beta1.StudySpec; +import com.google.cloud.aiplatform.v1beta1.StudySpec.MetricSpec; +import com.google.cloud.aiplatform.v1beta1.StudySpec.MetricSpec.GoalType; +import com.google.cloud.aiplatform.v1beta1.StudySpec.ParameterSpec; +import com.google.cloud.aiplatform.v1beta1.StudySpec.ParameterSpec.ConditionalParameterSpec; +import com.google.cloud.aiplatform.v1beta1.StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition; +import com.google.cloud.aiplatform.v1beta1.StudySpec.ParameterSpec.DiscreteValueSpec; +import com.google.cloud.aiplatform.v1beta1.StudySpec.ParameterSpec.DoubleValueSpec; +import com.google.cloud.aiplatform.v1beta1.StudySpec.ParameterSpec.ScaleType; +import com.google.cloud.aiplatform.v1beta1.WorkerPoolSpec; +import java.io.IOException; +import java.util.Arrays; + +public class CreateHyperparameterTuningJobPythonPackageSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String executorImageUri = "EXECUTOR_IMAGE_URI"; + String packageUri = "PACKAGE_URI"; + String pythonModule = "PYTHON_MODULE"; + createHyperparameterTuningJobPythonPackageSample( + project, displayName, executorImageUri, packageUri, pythonModule); + } + + static void createHyperparameterTuningJobPythonPackageSample( + String project, + String displayName, + String executorImageUri, + String packageUri, + String pythonModule) + throws IOException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + // study spec + MetricSpec metric = + MetricSpec.newBuilder().setMetricId("val_rmse").setGoal(GoalType.MINIMIZE).build(); + + // decay + DoubleValueSpec doubleValueSpec = + DoubleValueSpec.newBuilder().setMinValue(1e-07).setMaxValue(1).build(); + ParameterSpec parameterDecaySpec = + ParameterSpec.newBuilder() + .setParameterId("decay") + .setDoubleValueSpec(doubleValueSpec) + .setScaleType(ScaleType.UNIT_LINEAR_SCALE) + .build(); + Double[] decayValues = {32.0, 64.0}; + DiscreteValueCondition discreteValueDecay = + DiscreteValueCondition.newBuilder().addAllValues(Arrays.asList(decayValues)).build(); + ConditionalParameterSpec conditionalParameterDecay = + ConditionalParameterSpec.newBuilder() + .setParameterSpec(parameterDecaySpec) + .setParentDiscreteValues(discreteValueDecay) + .build(); + + // learning rate + ParameterSpec parameterLearningSpec = + ParameterSpec.newBuilder() + .setParameterId("learning_rate") + .setDoubleValueSpec(doubleValueSpec) // Use the same min/max as for decay + .setScaleType(ScaleType.UNIT_LINEAR_SCALE) + .build(); + + Double[] learningRateValues = {4.0, 8.0, 16.0}; + DiscreteValueCondition discreteValueLearning = + DiscreteValueCondition.newBuilder() + .addAllValues(Arrays.asList(learningRateValues)) + .build(); + ConditionalParameterSpec conditionalParameterLearning = + ConditionalParameterSpec.newBuilder() + .setParameterSpec(parameterLearningSpec) + .setParentDiscreteValues(discreteValueLearning) + .build(); + + // batch size + Double[] batchSizeValues = {4.0, 8.0, 16.0, 32.0, 64.0, 128.0}; + + DiscreteValueSpec discreteValueSpec = + DiscreteValueSpec.newBuilder().addAllValues(Arrays.asList(batchSizeValues)).build(); + ParameterSpec parameter = + ParameterSpec.newBuilder() + .setParameterId("batch_size") + .setDiscreteValueSpec(discreteValueSpec) + .setScaleType(ScaleType.UNIT_LINEAR_SCALE) + .addConditionalParameterSpecs(conditionalParameterDecay) + .addConditionalParameterSpecs(conditionalParameterLearning) + .build(); + + // trial_job_spec + MachineSpec machineSpec = + MachineSpec.newBuilder() + .setMachineType("n1-standard-4") + .setAcceleratorType(AcceleratorType.NVIDIA_TESLA_K80) + .setAcceleratorCount(1) + .build(); + + PythonPackageSpec pythonPackageSpec = + PythonPackageSpec.newBuilder() + .setExecutorImageUri(executorImageUri) + .addPackageUris(packageUri) + .setPythonModule(pythonModule) + .build(); + + WorkerPoolSpec workerPoolSpec = + WorkerPoolSpec.newBuilder() + .setMachineSpec(machineSpec) + .setReplicaCount(1) + .setPythonPackageSpec(pythonPackageSpec) + .build(); + + StudySpec studySpec = + StudySpec.newBuilder() + .addMetrics(metric) + .addParameters(parameter) + .setAlgorithm(StudySpec.Algorithm.RANDOM_SEARCH) + .build(); + CustomJobSpec trialJobSpec = + CustomJobSpec.newBuilder().addWorkerPoolSpecs(workerPoolSpec).build(); + // hyperparameter_tuning_job + HyperparameterTuningJob hyperparameterTuningJob = + HyperparameterTuningJob.newBuilder() + .setDisplayName(displayName) + .setMaxTrialCount(4) + .setParallelTrialCount(2) + .setStudySpec(studySpec) + .setTrialJobSpec(trialJobSpec) + .build(); + LocationName parent = LocationName.of(project, location); + HyperparameterTuningJob response = + client.createHyperparameterTuningJob(parent, hyperparameterTuningJob); + System.out.format("response: %s\n", response); + System.out.format("Name: %s\n", response.getName()); + } + } +} + +// [END aiplatform_create_hyperparameter_tuning_job_python_package_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateHyperparameterTuningJobSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateHyperparameterTuningJobSample.java new file mode 100644 index 00000000000..37e66d512a5 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateHyperparameterTuningJobSample.java @@ -0,0 +1,106 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_hyperparameter_tuning_job_sample] +import com.google.cloud.aiplatform.v1beta1.AcceleratorType; +import com.google.cloud.aiplatform.v1beta1.ContainerSpec; +import com.google.cloud.aiplatform.v1beta1.CustomJobSpec; +import com.google.cloud.aiplatform.v1beta1.HyperparameterTuningJob; +import com.google.cloud.aiplatform.v1beta1.JobServiceClient; +import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.MachineSpec; +import com.google.cloud.aiplatform.v1beta1.StudySpec; +import com.google.cloud.aiplatform.v1beta1.WorkerPoolSpec; +import java.io.IOException; + +public class CreateHyperparameterTuningJobSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String containerImageUri = "CONTAINER_IMAGE_URI"; + createHyperparameterTuningJobSample(project, displayName, containerImageUri); + } + + static void createHyperparameterTuningJobSample( + String project, String displayName, String containerImageUri) throws IOException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + StudySpec.MetricSpec metric0 = + StudySpec.MetricSpec.newBuilder() + .setMetricId("accuracy") + .setGoal(StudySpec.MetricSpec.GoalType.MAXIMIZE) + .build(); + StudySpec.ParameterSpec.DoubleValueSpec doubleValueSpec = + StudySpec.ParameterSpec.DoubleValueSpec.newBuilder() + .setMinValue(0.001) + .setMaxValue(0.1) + .build(); + StudySpec.ParameterSpec parameter0 = + StudySpec.ParameterSpec.newBuilder() + // Learning rate. + .setParameterId("lr") + .setDoubleValueSpec(doubleValueSpec) + .build(); + StudySpec studySpec = + StudySpec.newBuilder().addMetrics(metric0).addParameters(parameter0).build(); + MachineSpec machineSpec = + MachineSpec.newBuilder() + .setMachineType("n1-standard-4") + .setAcceleratorType(AcceleratorType.NVIDIA_TESLA_K80) + .setAcceleratorCount(1) + .build(); + ContainerSpec containerSpec = + ContainerSpec.newBuilder().setImageUri(containerImageUri).build(); + WorkerPoolSpec workerPoolSpec0 = + WorkerPoolSpec.newBuilder() + .setMachineSpec(machineSpec) + .setReplicaCount(1) + .setContainerSpec(containerSpec) + .build(); + CustomJobSpec trialJobSpec = + CustomJobSpec.newBuilder().addWorkerPoolSpecs(workerPoolSpec0).build(); + HyperparameterTuningJob hyperparameterTuningJob = + HyperparameterTuningJob.newBuilder() + .setDisplayName(displayName) + .setMaxTrialCount(2) + .setParallelTrialCount(1) + .setMaxFailedTrialCount(1) + .setStudySpec(studySpec) + .setTrialJobSpec(trialJobSpec) + .build(); + LocationName parent = LocationName.of(project, location); + HyperparameterTuningJob response = + client.createHyperparameterTuningJob(parent, hyperparameterTuningJob); + System.out.format("response: %s\n", response); + System.out.format("Name: %s\n", response.getName()); + } + } +} + +// [END aiplatform_create_hyperparameter_tuning_job_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomJobSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomJobSample.java new file mode 100644 index 00000000000..7b40d0e8d05 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomJobSample.java @@ -0,0 +1,119 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_custom_job_sample] +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.Model; +import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; + +public class CreateTrainingPipelineCustomJobSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String modelDisplayName = "MODEL_DISPLAY_NAME"; + String containerImageUri = "CONTAINER_IMAGE_URI"; + String baseOutputDirectoryPrefix = "BASE_OUTPUT_DIRECTORY_PREFIX"; + createTrainingPipelineCustomJobSample( + project, displayName, modelDisplayName, containerImageUri, baseOutputDirectoryPrefix); + } + + static void createTrainingPipelineCustomJobSample( + String project, + String displayName, + String modelDisplayName, + String containerImageUri, + String baseOutputDirectoryPrefix) + throws IOException { + PipelineServiceSettings settings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient client = PipelineServiceClient.create(settings)) { + JsonObject jsonMachineSpec = new JsonObject(); + jsonMachineSpec.addProperty("machineType", "n1-standard-4"); + + JsonArray jsonArgs = new JsonArray(); + jsonArgs.add("--model_dir=$(AIP_MODEL_DIR)"); + + // A working docker image can be found at + // gs://cloud-samples-data/ai-platform/mnist_tfrecord/custom_job + JsonObject jsonContainerSpec = new JsonObject(); + jsonContainerSpec.addProperty("imageUri", containerImageUri); + jsonContainerSpec.add("args", jsonArgs); + + JsonObject jsonJsonWorkerPoolSpec0 = new JsonObject(); + jsonJsonWorkerPoolSpec0.addProperty("replicaCount", 1); + jsonJsonWorkerPoolSpec0.add("machineSpec", jsonMachineSpec); + jsonJsonWorkerPoolSpec0.add("containerSpec", jsonContainerSpec); + + JsonArray jsonWorkerPoolSpecs = new JsonArray(); + jsonWorkerPoolSpecs.add(jsonJsonWorkerPoolSpec0); + + JsonObject jsonBaseOutputDirectory = new JsonObject(); + // The GCS location for outputs must be accessible by the project's AI Platform + // service account. + jsonBaseOutputDirectory.addProperty("output_uri_prefix", baseOutputDirectoryPrefix); + + JsonObject jsonTrainingTaskInputs = new JsonObject(); + jsonTrainingTaskInputs.add("workerPoolSpecs", jsonWorkerPoolSpecs); + jsonTrainingTaskInputs.add("baseOutputDirectory", jsonBaseOutputDirectory); + + Value.Builder trainingTaskInputsBuilder = Value.newBuilder(); + JsonFormat.parser().merge(jsonTrainingTaskInputs.toString(), trainingTaskInputsBuilder); + Value trainingTaskInputs = trainingTaskInputsBuilder.build(); + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/custom_task_1.0.0.yaml"; + String imageUri = "gcr.io/cloud-aiplatform/prediction/tf-cpu.1-15:latest"; + ModelContainerSpec containerSpec = + ModelContainerSpec.newBuilder().setImageUri(imageUri).build(); + Model modelToUpload = + Model.newBuilder() + .setDisplayName(modelDisplayName) + .setContainerSpec(containerSpec) + .build(); + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(displayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(trainingTaskInputs) + .setModelToUpload(modelToUpload) + .build(); + LocationName parent = LocationName.of(project, location); + TrainingPipeline response = client.createTrainingPipeline(parent, trainingPipeline); + System.out.format("response: %s\n", response); + System.out.format("Name: %s\n", response.getName()); + } + } +} + +// [END aiplatform_create_training_pipeline_custom_job_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSample.java new file mode 100644 index 00000000000..739d15cf8ee --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSample.java @@ -0,0 +1,145 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_custom_training_managed_dataset_sample] +import com.google.cloud.aiplatform.v1beta1.GcsDestination; +import com.google.cloud.aiplatform.v1beta1.InputDataConfig; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.Model; +import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; + +public class CreateTrainingPipelineCustomTrainingManagedDatasetSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String modelDisplayName = "MODEL_DISPLAY_NAME"; + String datasetId = "DATASET_ID"; + String annotationSchemaUri = "ANNOTATION_SCHEMA_URI"; + String trainingContainerSpecImageUri = "TRAINING_CONTAINER_SPEC_IMAGE_URI"; + String modelContainerSpecImageUri = "MODEL_CONTAINER_SPEC_IMAGE_URI"; + String baseOutputUriPrefix = "BASE_OUTPUT_URI_PREFIX"; + createTrainingPipelineCustomTrainingManagedDatasetSample( + project, + displayName, + modelDisplayName, + datasetId, + annotationSchemaUri, + trainingContainerSpecImageUri, + modelContainerSpecImageUri, + baseOutputUriPrefix); + } + + static void createTrainingPipelineCustomTrainingManagedDatasetSample( + String project, + String displayName, + String modelDisplayName, + String datasetId, + String annotationSchemaUri, + String trainingContainerSpecImageUri, + String modelContainerSpecImageUri, + String baseOutputUriPrefix) + throws IOException { + PipelineServiceSettings settings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient client = PipelineServiceClient.create(settings)) { + JsonArray jsonArgs = new JsonArray(); + jsonArgs.add("--model-dir=$(AIP_MODEL_DIR)"); + // training_task_inputs + JsonObject jsonTrainingContainerSpec = new JsonObject(); + jsonTrainingContainerSpec.addProperty("imageUri", trainingContainerSpecImageUri); + // AIP_MODEL_DIR is set by the service according to baseOutputDirectory. + jsonTrainingContainerSpec.add("args", jsonArgs); + + JsonObject jsonMachineSpec = new JsonObject(); + jsonMachineSpec.addProperty("machineType", "n1-standard-8"); + + JsonObject jsonTrainingWorkerPoolSpec = new JsonObject(); + jsonTrainingWorkerPoolSpec.addProperty("replicaCount", 1); + jsonTrainingWorkerPoolSpec.add("machineSpec", jsonMachineSpec); + jsonTrainingWorkerPoolSpec.add("containerSpec", jsonTrainingContainerSpec); + + JsonArray jsonWorkerPoolSpecs = new JsonArray(); + jsonWorkerPoolSpecs.add(jsonTrainingWorkerPoolSpec); + + JsonObject jsonBaseOutputDirectory = new JsonObject(); + jsonBaseOutputDirectory.addProperty("outputUriPrefix", baseOutputUriPrefix); + + JsonObject jsonTrainingTaskInputs = new JsonObject(); + jsonTrainingTaskInputs.add("workerPoolSpecs", jsonWorkerPoolSpecs); + jsonTrainingTaskInputs.add("baseOutputDirectory", jsonBaseOutputDirectory); + + Value.Builder trainingTaskInputsBuilder = Value.newBuilder(); + JsonFormat.parser().merge(jsonTrainingTaskInputs.toString(), trainingTaskInputsBuilder); + Value trainingTaskInputs = trainingTaskInputsBuilder.build(); + // model_to_upload + ModelContainerSpec modelContainerSpec = + ModelContainerSpec.newBuilder().setImageUri(modelContainerSpecImageUri).build(); + Model model = + Model.newBuilder() + .setDisplayName(modelDisplayName) + .setContainerSpec(modelContainerSpec) + .build(); + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(baseOutputUriPrefix).build(); + + // input_data_config + InputDataConfig inputDataConfig = + InputDataConfig.newBuilder() + .setDatasetId(datasetId) + .setAnnotationSchemaUri(annotationSchemaUri) + .setGcsDestination(gcsDestination) + .build(); + + // training_task_definition + String customTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/custom_task_1.0.0.yaml"; + + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(displayName) + .setInputDataConfig(inputDataConfig) + .setTrainingTaskDefinition(customTaskDefinition) + .setTrainingTaskInputs(trainingTaskInputs) + .setModelToUpload(model) + .build(); + LocationName parent = LocationName.of(project, location); + TrainingPipeline response = client.createTrainingPipeline(parent, trainingPipeline); + System.out.format("response: %s\n", response); + System.out.format("Name: %s\n", response.getName()); + } + } +} + +// [END aiplatform_create_training_pipeline_custom_training_managed_dataset_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java index 7327cba9b20..041ec32f080 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java @@ -38,12 +38,10 @@ import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; import com.google.cloud.aiplatform.v1beta1.TimestampSplit; import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; -import com.google.protobuf.Any; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; import com.google.rpc.Status; import java.io.IOException; -import java.util.List; public class CreateTrainingPipelineImageClassificationSample { @@ -73,7 +71,7 @@ static void createTrainingPipelineImageClassificationSample( String location = "us-central1"; String trainingTaskDefinition = "gs://google-cloud-aiplatform/schema/trainingjob/definition/" - + "automl_image_classification_1.0.0.yaml"; + + "automl_image_classification_1.0.0.yaml"; LocationName locationName = LocationName.of(project, location); String jsonString = diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java index 636ab022418..d6b2af9ab23 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java @@ -38,12 +38,10 @@ import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; import com.google.cloud.aiplatform.v1beta1.TimestampSplit; import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; -import com.google.protobuf.Any; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; import com.google.rpc.Status; import java.io.IOException; -import java.util.List; public class CreateTrainingPipelineImageObjectDetectionSample { @@ -73,7 +71,7 @@ static void createTrainingPipelineImageObjectDetectionSample( String location = "us-central1"; String trainingTaskDefinition = "gs://google-cloud-aiplatform/schema/trainingjob/definition/" - + "automl_image_object_detection_1.0.0.yaml"; + + "automl_image_object_detection_1.0.0.yaml"; LocationName locationName = LocationName.of(project, location); String jsonString = diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineSample.java index 5c310bfea7f..2dcb6e88cd2 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineSample.java @@ -38,12 +38,10 @@ import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; import com.google.cloud.aiplatform.v1beta1.TimestampSplit; import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; -import com.google.protobuf.Any; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; import com.google.rpc.Status; import java.io.IOException; -import java.util.List; public class CreateTrainingPipelineSample { diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java index de54c0a0a07..92ebb6adab1 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java @@ -37,12 +37,10 @@ import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; import com.google.cloud.aiplatform.v1beta1.TimestampSplit; import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; -import com.google.protobuf.Any; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; import com.google.rpc.Status; import java.io.IOException; -import java.util.List; public class CreateTrainingPipelineTabularClassificationSample { diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java index ca24862cad2..0043fdf52a7 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java @@ -37,12 +37,10 @@ import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; import com.google.cloud.aiplatform.v1beta1.TimestampSplit; import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; -import com.google.protobuf.Any; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; import com.google.rpc.Status; import java.io.IOException; -import java.util.List; public class CreateTrainingPipelineTabularRegressionSample { diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java index c67e49a058e..b5335a4c734 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java @@ -38,12 +38,10 @@ import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; import com.google.cloud.aiplatform.v1beta1.TimestampSplit; import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; -import com.google.protobuf.Any; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; import com.google.rpc.Status; import java.io.IOException; -import java.util.List; public class CreateTrainingPipelineTextClassificationSample { @@ -74,7 +72,7 @@ static void createTrainingPipelineTextClassificationSample( String location = "us-central1"; String trainingTaskDefinition = "gs://google-cloud-aiplatform/schema/trainingjob/definition/" - + "automl_text_classification_1.0.0.yaml"; + + "automl_text_classification_1.0.0.yaml"; String jsonString = "{\"multiLabel\": false}"; LocationName locationName = LocationName.of(project, location); diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java index 3aef27086ba..9d83a4d1e0f 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java @@ -38,12 +38,10 @@ import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; import com.google.cloud.aiplatform.v1beta1.TimestampSplit; import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; -import com.google.protobuf.Any; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; import com.google.rpc.Status; import java.io.IOException; -import java.util.List; public class CreateTrainingPipelineTextEntityExtractionSample { @@ -74,7 +72,7 @@ static void createTrainingPipelineTextEntityExtractionSample( String location = "us-central1"; String trainingTaskDefinition = "gs://google-cloud-aiplatform/schema/trainingjob/definition/" - + "automl_text_extraction_1.0.0.yaml"; + + "automl_text_extraction_1.0.0.yaml"; String jsonString = "{}"; LocationName locationName = LocationName.of(project, location); diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java index a6139405d31..f9c656ad0b4 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java @@ -38,12 +38,10 @@ import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; import com.google.cloud.aiplatform.v1beta1.TimestampSplit; import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; -import com.google.protobuf.Any; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; import com.google.rpc.Status; import java.io.IOException; -import java.util.List; public class CreateTrainingPipelineTextSentimentAnalysisSample { @@ -74,7 +72,7 @@ static void createTrainingPipelineTextSentimentAnalysisSample( String location = "us-central1"; String trainingTaskDefinition = "gs://google-cloud-aiplatform/schema/trainingjob/definition/" - + "automl_text_sentiment_1.0.0.yaml"; + + "automl_text_sentiment_1.0.0.yaml"; // Sentiment max must be between 1 and 10 inclusive. // Higher value means positive sentiment. diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSample.java new file mode 100644 index 00000000000..70d03788ea4 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSample.java @@ -0,0 +1,87 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_video_action_recognition_sample] +import com.google.cloud.aiplatform.v1beta1.InputDataConfig; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.Model; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; +import com.google.gson.JsonObject; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; + +public class CreateTrainingPipelineVideoActionRecognitionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String datasetId = "DATASET_ID"; + String modelDisplayName = "MODEL_DISPLAY_NAME"; + String modelType = "MODEL_TYPE"; + createTrainingPipelineVideoActionRecognitionSample( + project, displayName, datasetId, modelDisplayName, modelType); + } + + static void createTrainingPipelineVideoActionRecognitionSample( + String project, + String displayName, + String datasetId, + String modelDisplayName, + String modelType) + throws IOException { + PipelineServiceSettings settings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient client = PipelineServiceClient.create(settings)) { + JsonObject jsonTrainingTaskInputs = new JsonObject(); + jsonTrainingTaskInputs.addProperty("modelType", modelType); + Value.Builder trainingTaskInputsBuilder = Value.newBuilder(); + JsonFormat.parser().merge(jsonTrainingTaskInputs.toString(), trainingTaskInputsBuilder); + Value trainingTaskInputs = trainingTaskInputsBuilder.build(); + InputDataConfig inputDataConfig = + InputDataConfig.newBuilder().setDatasetId(datasetId).build(); + Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build(); + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(displayName) + .setTrainingTaskDefinition( + "gs://google-cloud-aiplatform/schema/trainingjob/definition/" + + "automl_video_action_recognition_1.0.0.yaml") + .setTrainingTaskInputs(trainingTaskInputs) + .setInputDataConfig(inputDataConfig) + .setModelToUpload(modelToUpload) + .build(); + LocationName parent = LocationName.of(project, location); + TrainingPipeline response = client.createTrainingPipeline(parent, trainingPipeline); + System.out.format("response: %s\n", response); + System.out.format("Name: %s\n", response.getName()); + } + } +} + +// [END aiplatform_create_training_pipeline_video_action_recognition_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java index 383e56954af..eacbc512383 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java @@ -28,11 +28,9 @@ import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; import com.google.cloud.aiplatform.v1beta1.TimestampSplit; import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; -import com.google.protobuf.Any; import com.google.protobuf.Value; import com.google.rpc.Status; import java.io.IOException; -import java.util.List; public class CreateTrainingPipelineVideoClassificationSample { @@ -67,7 +65,7 @@ static void createTrainingPipelineVideoClassification( LocationName locationName = LocationName.of(project, location); String trainingTaskDefinition = "gs://google-cloud-aiplatform/schema/trainingjob/definition/" - + "automl_video_classification_1.0.0.yaml"; + + "automl_video_classification_1.0.0.yaml"; InputDataConfig inputDataConfig = InputDataConfig.newBuilder().setDatasetId(datasetId).build(); diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java index d49fcff96dd..a8e5ee1490d 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java @@ -28,12 +28,10 @@ import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; import com.google.cloud.aiplatform.v1beta1.TimestampSplit; import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; -import com.google.protobuf.Any; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; import com.google.rpc.Status; import java.io.IOException; -import java.util.List; public class CreateTrainingPipelineVideoObjectTrackingSample { @@ -66,7 +64,7 @@ static void createTrainingPipelineVideoObjectTracking( String location = "us-central1"; String trainingTaskDefinition = "gs://google-cloud-aiplatform/schema/trainingjob/definition/" - + "automl_video_object_tracking_1.0.0.yaml"; + + "automl_video_object_tracking_1.0.0.yaml"; LocationName locationName = LocationName.of(project, location); String jsonString = "{\"modelType\": \"CLOUD\"}"; diff --git a/aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java b/aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java index a9989b564e1..39ad52d0fdf 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java @@ -59,7 +59,7 @@ static void deleteDatasetSample(String project, String datasetId) System.out.format("Operation name: %s\n", operationFuture.getInitialFuture().get().getName()); System.out.println("Waiting for operation to finish..."); operationFuture.get(300, TimeUnit.SECONDS); - + System.out.format("Deleted Dataset."); } } diff --git a/aiplatform/snippets/src/main/java/aiplatform/DeployModelCustomTrainedModelSample.java b/aiplatform/snippets/src/main/java/aiplatform/DeployModelCustomTrainedModelSample.java new file mode 100644 index 00000000000..471961b69c1 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/DeployModelCustomTrainedModelSample.java @@ -0,0 +1,92 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_deploy_model_custom_trained_model_sample] +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.DedicatedResources; +import com.google.cloud.aiplatform.v1beta1.DeployModelOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.DeployModelResponse; +import com.google.cloud.aiplatform.v1beta1.DeployedModel; +import com.google.cloud.aiplatform.v1beta1.EndpointName; +import com.google.cloud.aiplatform.v1beta1.EndpointServiceClient; +import com.google.cloud.aiplatform.v1beta1.EndpointServiceSettings; +import com.google.cloud.aiplatform.v1beta1.MachineSpec; +import com.google.cloud.aiplatform.v1beta1.ModelName; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +public class DeployModelCustomTrainedModelSample { + + public static void main(String[] args) + throws IOException, ExecutionException, InterruptedException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String endpointId = "ENDPOINT_ID"; + String modelName = "MODEL_NAME"; + String deployedModelDisplayName = "DEPLOYED_MODEL_DISPLAY_NAME"; + deployModelCustomTrainedModelSample(project, endpointId, modelName, deployedModelDisplayName); + } + + static void deployModelCustomTrainedModelSample( + String project, String endpointId, String model, String deployedModelDisplayName) + throws IOException, ExecutionException, InterruptedException { + EndpointServiceSettings settings = + EndpointServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (EndpointServiceClient client = EndpointServiceClient.create(settings)) { + MachineSpec machineSpec = MachineSpec.newBuilder().setMachineType("n1-standard-2").build(); + DedicatedResources dedicatedResources = + DedicatedResources.newBuilder().setMinReplicaCount(1).setMachineSpec(machineSpec).build(); + + String modelName = ModelName.of(project, location, model).toString(); + DeployedModel deployedModel = + DeployedModel.newBuilder() + .setModel(modelName) + .setDisplayName(deployedModelDisplayName) + // `dedicated_resources` must be used for non-AutoML models + .setDedicatedResources(dedicatedResources) + .build(); + // key '0' assigns traffic for the newly deployed model + // Traffic percentage values must add up to 100 + // Leave dictionary empty if endpoint should not accept any traffic + Map trafficSplit = new HashMap<>(); + trafficSplit.put("0", 100); + EndpointName endpoint = EndpointName.of(project, location, endpointId); + OperationFuture response = + client.deployModelAsync(endpoint, deployedModel, trafficSplit); + + // You can use OperationFuture.getInitialFuture to get a future representing the initial + // response to the request, which contains information while the operation is in progress. + System.out.format("Operation name: %s\n", response.getInitialFuture().get().getName()); + + // OperationFuture.get() will block until the operation is finished. + DeployModelResponse deployModelResponse = response.get(); + System.out.format("deployModelResponse: %s\n", deployModelResponse); + } + } +} + +// [END aiplatform_deploy_model_custom_trained_model_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/ExportModelVideoActionRecognitionSample.java b/aiplatform/snippets/src/main/java/aiplatform/ExportModelVideoActionRecognitionSample.java new file mode 100644 index 00000000000..1d70cd0d82d --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/ExportModelVideoActionRecognitionSample.java @@ -0,0 +1,79 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_export_model_video_action_recognition_sample] +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.ExportModelOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.ExportModelRequest; +import com.google.cloud.aiplatform.v1beta1.ExportModelResponse; +import com.google.cloud.aiplatform.v1beta1.GcsDestination; +import com.google.cloud.aiplatform.v1beta1.ModelName; +import com.google.cloud.aiplatform.v1beta1.ModelServiceClient; +import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings; +import java.io.IOException; +import java.util.concurrent.ExecutionException; + +public class ExportModelVideoActionRecognitionSample { + + public static void main(String[] args) + throws IOException, ExecutionException, InterruptedException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String modelId = "MODEL_ID"; + String gcsDestinationOutputUriPrefix = "GCS_DESTINATION_OUTPUT_URI_PREFIX"; + String exportFormat = "EXPORT_FORMAT"; + exportModelVideoActionRecognitionSample( + project, modelId, gcsDestinationOutputUriPrefix, exportFormat); + } + + static void exportModelVideoActionRecognitionSample( + String project, String modelId, String gcsDestinationOutputUriPrefix, String exportFormat) + throws IOException, ExecutionException, InterruptedException { + ModelServiceSettings settings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient client = ModelServiceClient.create(settings)) { + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build(); + ExportModelRequest.OutputConfig outputConfig = + ExportModelRequest.OutputConfig.newBuilder() + .setArtifactDestination(gcsDestination) + .setExportFormatId(exportFormat) + .build(); + ModelName name = ModelName.of(project, location, modelId); + OperationFuture response = + client.exportModelAsync(name, outputConfig); + + // You can use OperationFuture.getInitialFuture to get a future representing the initial + // response to the request, which contains information while the operation is in progress. + System.out.format("Operation name: %s\n", response.getInitialFuture().get().getName()); + + // OperationFuture.get() will block until the operation is finished. + ExportModelResponse exportModelResponse = response.get(); + System.out.format("exportModelResponse: %s\n", exportModelResponse); + } + } +} + +// [END aiplatform_export_model_video_action_recognition_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetHyperparameterTuningJobSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetHyperparameterTuningJobSample.java new file mode 100644 index 00000000000..b4d100c3c79 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/GetHyperparameterTuningJobSample.java @@ -0,0 +1,55 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_hyperparameter_tuning_job_sample] +import com.google.cloud.aiplatform.v1beta1.HyperparameterTuningJob; +import com.google.cloud.aiplatform.v1beta1.HyperparameterTuningJobName; +import com.google.cloud.aiplatform.v1beta1.JobServiceClient; +import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; +import java.io.IOException; + +public class GetHyperparameterTuningJobSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String hyperparameterTuningJobId = "HYPERPARAMETER_TUNING_JOB_ID"; + getHyperparameterTuningJobSample(project, hyperparameterTuningJobId); + } + + static void getHyperparameterTuningJobSample(String project, String hyperparameterTuningJobId) + throws IOException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + HyperparameterTuningJobName name = + HyperparameterTuningJobName.of(project, location, hyperparameterTuningJobId); + HyperparameterTuningJob response = client.getHyperparameterTuningJob(name); + System.out.format("response: %s\n", response); + } + } +} + +// [END aiplatform_get_hyperparameter_tuning_job_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationVideoActionRecognitionSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationVideoActionRecognitionSample.java new file mode 100644 index 00000000000..2685e23f1a0 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationVideoActionRecognitionSample.java @@ -0,0 +1,55 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_video_action_recognition_sample] +import com.google.cloud.aiplatform.v1beta1.ModelEvaluation; +import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1beta1.ModelServiceClient; +import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationVideoActionRecognitionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String modelId = "MODEL_ID"; + String evaluationId = "EVALUATION_ID"; + getModelEvaluationVideoActionRecognitionSample(project, modelId, evaluationId); + } + + static void getModelEvaluationVideoActionRecognitionSample( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings settings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient client = ModelServiceClient.create(settings)) { + ModelEvaluationName name = ModelEvaluationName.of(project, location, modelId, evaluationId); + ModelEvaluation response = client.getModelEvaluation(name); + System.out.format("response: %s\n", response); + } + } +} + +// [END aiplatform_get_model_evaluation_video_action_recognition_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageClassificationSample.java index 04beb8c54b2..7bc46c4477d 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageClassificationSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageClassificationSample.java @@ -61,7 +61,7 @@ static void importDataImageClassificationSample( String location = "us-central1"; String importSchemaUri = "gs://google-cloud-aiplatform/schema/dataset/ioformat/" - + "image_classification_single_label_io_format_1.0.0.yaml"; + + "image_classification_single_label_io_format_1.0.0.yaml"; GcsSource.Builder gcsSource = GcsSource.newBuilder(); gcsSource.addUris(gcsSourceUri); @@ -81,8 +81,8 @@ static void importDataImageClassificationSample( System.out.println("Waiting for operation to finish..."); ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS); - System.out.format("Import Data Image Classification Response: %s\n", - importDataResponse.toString()); + System.out.format( + "Import Data Image Classification Response: %s\n", importDataResponse.toString()); } } } diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageObjectDetectionSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageObjectDetectionSample.java index ae17cfd3a49..947b3f23258 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageObjectDetectionSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageObjectDetectionSample.java @@ -61,7 +61,7 @@ static void importDataImageObjectDetectionSample( String location = "us-central1"; String importSchemaUri = "gs://google-cloud-aiplatform/schema/dataset/ioformat/" - + "image_bounding_box_io_format_1.0.0.yaml"; + + "image_bounding_box_io_format_1.0.0.yaml"; GcsSource.Builder gcsSource = GcsSource.newBuilder(); gcsSource.addUris(gcsSourceUri); DatasetName datasetName = DatasetName.of(project, location, datasetId); @@ -80,8 +80,8 @@ static void importDataImageObjectDetectionSample( System.out.println("Waiting for operation to finish..."); ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS); - System.out.format("Import Data Image Object Detection Response: %s\n", - importDataResponse.toString()); + System.out.format( + "Import Data Image Object Detection Response: %s\n", importDataResponse.toString()); } } } diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextClassificationSingleLabelSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextClassificationSingleLabelSample.java index cf08bd38ae9..8c285a1b6c6 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextClassificationSingleLabelSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextClassificationSingleLabelSample.java @@ -62,7 +62,7 @@ static void importDataTextClassificationSingleLabelSample( String location = "us-central1"; String importSchemaUri = "gs://google-cloud-aiplatform/schema/dataset/ioformat/" - + "text_classification_single_label_io_format_1.0.0.yaml"; + + "text_classification_single_label_io_format_1.0.0.yaml"; GcsSource.Builder gcsSource = GcsSource.newBuilder(); gcsSource.addUris(gcsSourceUri); @@ -82,8 +82,8 @@ static void importDataTextClassificationSingleLabelSample( System.out.println("Waiting for operation to finish..."); ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS); - System.out.format("Import Data Text Classification Response: %s\n", - importDataResponse.toString()); + System.out.format( + "Import Data Text Classification Response: %s\n", importDataResponse.toString()); } } } diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextEntityExtractionSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextEntityExtractionSample.java index 6bd5c4f0297..be53db2616a 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextEntityExtractionSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextEntityExtractionSample.java @@ -61,7 +61,7 @@ static void importDataTextEntityExtractionSample( String location = "us-central1"; String importSchemaUri = "gs://google-cloud-aiplatform/schema/dataset/ioformat/" - + "text_extraction_io_format_1.0.0.yaml"; + + "text_extraction_io_format_1.0.0.yaml"; GcsSource.Builder gcsSource = GcsSource.newBuilder(); gcsSource.addUris(gcsSourceUri); @@ -81,8 +81,8 @@ static void importDataTextEntityExtractionSample( System.out.println("Waiting for operation to finish..."); ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS); - System.out.format("Import Data Text Entity Extraction Response: %s\n", - importDataResponse.toString()); + System.out.format( + "Import Data Text Entity Extraction Response: %s\n", importDataResponse.toString()); } } } diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextSentimentAnalysisSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextSentimentAnalysisSample.java index e1750678392..4634d6f64b9 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextSentimentAnalysisSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextSentimentAnalysisSample.java @@ -62,7 +62,7 @@ static void importDataTextSentimentAnalysisSample( String location = "us-central1"; String importSchemaUri = "gs://google-cloud-aiplatform/schema/dataset/ioformat/" - + "text_sentiment_io_format_1.0.0.yaml"; + + "text_sentiment_io_format_1.0.0.yaml"; GcsSource.Builder gcsSource = GcsSource.newBuilder(); gcsSource.addUris(gcsSourceUri); @@ -82,8 +82,8 @@ static void importDataTextSentimentAnalysisSample( System.out.println("Waiting for operation to finish..."); ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS); - System.out.format("Import Data Text Sentiment Analysis Response: %s\n", - importDataResponse.toString()); + System.out.format( + "Import Data Text Sentiment Analysis Response: %s\n", importDataResponse.toString()); } } } diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoActionRecognitionSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoActionRecognitionSample.java new file mode 100644 index 00000000000..ac182ecbc8a --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoActionRecognitionSample.java @@ -0,0 +1,82 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_import_data_video_action_recognition_sample] +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.GcsSource; +import com.google.cloud.aiplatform.v1beta1.ImportDataConfig; +import com.google.cloud.aiplatform.v1beta1.ImportDataOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.ImportDataResponse; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutionException; + +public class ImportDataVideoActionRecognitionSample { + + public static void main(String[] args) + throws IOException, ExecutionException, InterruptedException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String datasetId = "DATASET_ID"; + String gcsSourceUri = "GCS_SOURCE_URI"; + importDataVideoActionRecognitionSample(project, datasetId, gcsSourceUri); + } + + static void importDataVideoActionRecognitionSample( + String project, String datasetId, String gcsSourceUri) + throws IOException, ExecutionException, InterruptedException { + DatasetServiceSettings settings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient client = DatasetServiceClient.create(settings)) { + GcsSource gcsSource = GcsSource.newBuilder().addUris(gcsSourceUri).build(); + ImportDataConfig importConfig0 = + ImportDataConfig.newBuilder() + .setGcsSource(gcsSource) + .setImportSchemaUri( + "gs://google-cloud-aiplatform/schema/dataset/ioformat/" + + "video_action_recognition_io_format_1.0.0.yaml") + .build(); + List importConfigs = new ArrayList<>(); + importConfigs.add(importConfig0); + DatasetName name = DatasetName.of(project, location, datasetId); + OperationFuture response = + client.importDataAsync(name, importConfigs); + + // You can use OperationFuture.getInitialFuture to get a future representing the initial + // response to the request, which contains information while the operation is in progress. + System.out.format("Operation name: %s\n", response.getInitialFuture().get().getName()); + + // OperationFuture.get() will block until the operation is finished. + ImportDataResponse importDataResponse = response.get(); + System.out.format("importDataResponse: %s\n", importDataResponse); + } + } +} + +// [END aiplatform_import_data_video_action_recognition_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoClassificationSample.java index 4bf2c37f303..10a7b673eca 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoClassificationSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoClassificationSample.java @@ -60,7 +60,7 @@ static void importDataVideoClassification(String gcsSourceUri, String project, S String location = "us-central1"; String importSchemaUri = "gs://google-cloud-aiplatform/schema/dataset/ioformat/" - + "video_classification_io_format_1.0.0.yaml"; + + "video_classification_io_format_1.0.0.yaml"; GcsSource.Builder gcsSource = GcsSource.newBuilder(); gcsSource.addUris(gcsSourceUri); @@ -81,8 +81,7 @@ static void importDataVideoClassification(String gcsSourceUri, String project, S ImportDataResponse importDataResponse = importDataResponseFuture.get(1800, TimeUnit.SECONDS); System.out.format( - "Import Data Video Classification Response: %s\n", - importDataResponse.toString()); + "Import Data Video Classification Response: %s\n", importDataResponse.toString()); } } } diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoObjectTrackingSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoObjectTrackingSample.java index f8a07d91485..87807c19382 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoObjectTrackingSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoObjectTrackingSample.java @@ -59,7 +59,7 @@ static void importDataVideObjectTracking(String gcsSourceUri, String project, St String location = "us-central1"; String importSchemaUri = "gs://google-cloud-aiplatform/schema/dataset/ioformat/" - + "video_object_tracking_io_format_1.0.0.yaml"; + + "video_object_tracking_io_format_1.0.0.yaml"; GcsSource.Builder gcsSource = GcsSource.newBuilder(); gcsSource.addUris(gcsSourceUri); @@ -78,8 +78,8 @@ static void importDataVideObjectTracking(String gcsSourceUri, String project, St System.out.println("Waiting for operation to finish..."); ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS); - System.out.format("Import Data Video Object Tracking Response: %s\n", - importDataResponse.toString()); + System.out.format( + "Import Data Video Object Tracking Response: %s\n", importDataResponse.toString()); } } } diff --git a/aiplatform/snippets/src/test/java/aiplatform/CancelTrainingPipelineSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CancelTrainingPipelineSampleTest.java index 1f5066679ff..f3daec20155 100644 --- a/aiplatform/snippets/src/test/java/aiplatform/CancelTrainingPipelineSampleTest.java +++ b/aiplatform/snippets/src/test/java/aiplatform/CancelTrainingPipelineSampleTest.java @@ -37,7 +37,7 @@ public class CancelTrainingPipelineSampleTest { private static final String DATASET_ID = System.getenv("TRAINING_PIPELINE_DATASET_ID"); private static final String TRAINING_TASK_DEFINITION = "gs://google-cloud-aiplatform/schema/trainingjob/definition/" - + "automl_image_classification_1.0.0.yaml"; + + "automl_image_classification_1.0.0.yaml"; private static String TRAINING_PIPELINE_ID = null; private ByteArrayOutputStream bout; private PrintStream out; diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobBigquerySampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobBigquerySampleTest.java new file mode 100644 index 00000000000..25114e60731 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobBigquerySampleTest.java @@ -0,0 +1,109 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateBatchPredictionJobBigquerySampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = System.getenv("BATCH_PREDICTION_TABULAR_BQ_MODEL_ID"); + private static final String BIGQUERY_SOURCE_URI = + "bq://ucaip-sample-tests.table_test.all_bq_types"; + private static final String BIGQUERY_DESTINATION_OUTPUT_URI_PREFIX = "bq://ucaip-sample-tests"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String batchPredictionJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("BATCH_PREDICTION_TABULAR_BQ_MODEL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Batch Prediction Job + DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Batch"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateBatchPredictionJobBigquerySample() throws IOException { + // Act + String batchPredictionDisplayName = + String.format( + "batch_prediction_bigquery_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateBatchPredictionJobBigquerySample.createBatchPredictionJobBigquerySample( + PROJECT, + batchPredictionDisplayName, + MODEL_ID, + "bigquery", + BIGQUERY_SOURCE_URI, + "bigquery", + BIGQUERY_DESTINATION_OUTPUT_URI_PREFIX); + + // Assert + String got = bout.toString(); + assertThat(got).contains(batchPredictionDisplayName); + assertThat(got).contains("response:"); + batchPredictionJobId = got.split("Name: ")[1].split("batchPredictionJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobSampleTest.java new file mode 100644 index 00000000000..1def01b3ddc --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobSampleTest.java @@ -0,0 +1,109 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateBatchPredictionJobSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = System.getenv("BATCH_PREDICTION_MODEL_ID"); + private static final String GCS_SOURCE_URI = + "gs://ucaip-samples-test-output/inputs/icn_batch_prediction_input.jsonl"; + private static final String GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String batchPredictionJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("BATCH_PREDICTION_MODEL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Batch Prediction Job + DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Batch"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateBatchPredictionJobSample() throws IOException { + // Act + String batchPredictionDisplayName = + String.format( + "batch_prediction_bigquery_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateBatchPredictionJobSample.createBatchPredictionJobSample( + PROJECT, + batchPredictionDisplayName, + MODEL_ID, + "jsonl", + GCS_SOURCE_URI, + "jsonl", + GCS_OUTPUT_URI); + + // Assert + String got = bout.toString(); + assertThat(got).contains(batchPredictionDisplayName); + assertThat(got).contains("response:"); + batchPredictionJobId = got.split("Name: ")[1].split("batchPredictionJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSampleTest.java new file mode 100644 index 00000000000..5ca1e5f0f87 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSampleTest.java @@ -0,0 +1,105 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateBatchPredictionJobVideoActionRecognitionSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = + System.getenv("BATCH_PREDICTION_VIDEO_ACTION_RECOGNITION_MODEL_ID"); + private static final String GCS_SOURCE_URI = + "gs://ucaip-samples-test-output/inputs/icn_batch_prediction_input.jsonl"; + private static final String GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String batchPredictionJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("BATCH_PREDICTION_VIDEO_ACTION_RECOGNITION_MODEL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Batch Prediction Job + DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Batch"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateBatchPredictionJobVideoActionRecognitionSample() throws IOException { + // Act + String batchPredictionDisplayName = + String.format( + "batch_prediction_video_action_recognition_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateBatchPredictionJobVideoActionRecognitionSample + .createBatchPredictionJobVideoActionRecognitionSample( + PROJECT, batchPredictionDisplayName, MODEL_ID, GCS_SOURCE_URI, GCS_OUTPUT_URI); + + // Assert + String got = bout.toString(); + assertThat(got).contains(batchPredictionDisplayName); + assertThat(got).contains("response:"); + batchPredictionJobId = got.split("Name: ")[1].split("batchPredictionJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateDataLabelingJobActiveLearningSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateDataLabelingJobActiveLearningSampleTest.java new file mode 100644 index 00000000000..5280476f333 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateDataLabelingJobActiveLearningSampleTest.java @@ -0,0 +1,115 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class CreateDataLabelingJobActiveLearningSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("DATA_LABELING_ACTIVE_LEARNING_DATASET_ID"); + private static final String INSTRUCTION_URI = + "gs://ucaip-sample-resources/images/datalabeling_instructions.pdf"; + private static final String INPUTS_SCHEMA_URI = + "gs://google-cloud-aiplatform/schema/datalabelingjob/inputs/image_classification_1.0.0.yaml"; + private static final String ANNOTATION_SPEC = "roses"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String dataLabelingJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("DATA_LABELING_ACTIVE_LEARNING_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel data labeling job + CancelDataLabelingJobSample.cancelDataLabelingJob(PROJECT, dataLabelingJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled Data labeling job"); + TimeUnit.MINUTES.sleep(1); + + // Delete the created dataset + DeleteDataLabelingJobSample.deleteDataLabelingJob(PROJECT, dataLabelingJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Data Labeling Job."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("Avoid creating actual data labeling job for humans") + public void testCreateDataLabelingJobActiveLearningSample() throws IOException { + // Act + String dataLabelingDisplayName = + String.format( + "temp_data_labeling_job_active_learning_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateDataLabelingJobActiveLearningSample.createDataLabelingJobActiveLearningSample( + PROJECT, + dataLabelingDisplayName, + DATASET_ID, + INSTRUCTION_URI, + INPUTS_SCHEMA_URI, + ANNOTATION_SPEC); + + // Assert + String got = bout.toString(); + assertThat(got).contains(dataLabelingDisplayName); + assertThat(got).contains("Create Data Labeling Job Image Response"); + dataLabelingJobId = got.split("Name: ")[1].split("dataLabelingJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateDataLabelingJobSpecialistPoolSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateDataLabelingJobSpecialistPoolSampleTest.java new file mode 100644 index 00000000000..7c41c5d844a --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateDataLabelingJobSpecialistPoolSampleTest.java @@ -0,0 +1,118 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class CreateDataLabelingJobSpecialistPoolSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("DATA_LABELING_ACTIVE_LEARNING_DATASET_ID"); + private static final String SPECIALIST_POOL_ID = + System.getenv("DATA_LABELING_SPECIALIST_POOL_ID"); + private static final String INSTRUCTION_URI = + "gs://ucaip-sample-resources/images/datalabeling_instructions.pdf"; + private static final String INPUTS_SCHEMA_URI = + "gs://google-cloud-aiplatform/schema/datalabelingjob/inputs/image_classification_1.0.0.yaml"; + private static final String ANNOTATION_SPEC = "roses"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String dataLabelingJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("DATA_LABELING_ACTIVE_LEARNING_DATASET_ID"); + requireEnvVar("DATA_LABELING_SPECIALIST_POOL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel data labeling job + CancelDataLabelingJobSample.cancelDataLabelingJob(PROJECT, dataLabelingJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled Data labeling job"); + TimeUnit.MINUTES.sleep(1); + + // Delete the created dataset + DeleteDataLabelingJobSample.deleteDataLabelingJob(PROJECT, dataLabelingJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Data Labeling Job."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("Avoid creating actual data labeling job for humans") + public void testCreateDataLabelingJobSpecialistPoolSample() throws IOException { + // Act + String dataLabelingDisplayName = + String.format( + "temp_data_labeling_job_specialist_pool_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateDataLabelingJobSpecialistPoolSample.createDataLabelingJobSpecialistPoolSample( + PROJECT, + dataLabelingDisplayName, + DATASET_ID, + SPECIALIST_POOL_ID, + INSTRUCTION_URI, + INPUTS_SCHEMA_URI, + ANNOTATION_SPEC); + + // Assert + String got = bout.toString(); + assertThat(got).contains(dataLabelingDisplayName); + assertThat(got).contains("Create Data Labeling Job Image Response"); + dataLabelingJobId = got.split("Name: ")[1].split("dataLabelingJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularGcsSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularGcsSampleTest.java index 10a26a5e144..7592b9aae1e 100644 --- a/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularGcsSampleTest.java +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularGcsSampleTest.java @@ -81,8 +81,8 @@ public void testCreateDatasetTabularGcsSample() "temp_create_dataset_table_gcs_test_%s", UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); - CreateDatasetTabularGcsSample.createDatasetTableGcs(PROJECT, - datasetDisplayName, GCS_SOURCE_URI); + CreateDatasetTabularGcsSample.createDatasetTableGcs( + PROJECT, datasetDisplayName, GCS_SOURCE_URI); // Assert String got = bout.toString(); diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSampleTest.java new file mode 100644 index 00000000000..486322ce46e --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSampleTest.java @@ -0,0 +1,115 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.cloud.aiplatform.v1beta1.JobServiceClient; +import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateHyperparameterTuningJobPythonPackageSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String EXECUTOR_IMAGE_URI = + "us.gcr.io/cloud-aiplatform/training/tf-gpu.2-1:latest"; + private static final String PACKAGE_URI = + "gs://ucaip-test-us-central1/training/pythonpackages/trainer.tar.bz2"; + private static final String PYTHON_MODULE = "trainer.hptuning_trainer"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String hyperparameterJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + try (JobServiceClient client = JobServiceClient.create(settings)) { + // Cancel hyper parameter job + String hyperparameterJobName = + String.format( + "projects/%s/locations/us-central1/hyperparameterTuningJobs/%s", + PROJECT, hyperparameterJobId); + client.cancelHyperparameterTuningJob(hyperparameterJobName); + + TimeUnit.MINUTES.sleep(1); + + // Delete the created job + client.deleteHyperparameterTuningJobAsync(hyperparameterJobName); + System.out.flush(); + System.setOut(originalPrintStream); + } + } + + @Test + public void testCreateHyperparameterTuningJobPythonPackageSample() throws IOException { + String hyperparameterTuningJobDisplayName = + String.format( + "temp_hyperparameter_tuning_job_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + CreateHyperparameterTuningJobPythonPackageSample + .createHyperparameterTuningJobPythonPackageSample( + PROJECT, + hyperparameterTuningJobDisplayName, + EXECUTOR_IMAGE_URI, + PACKAGE_URI, + PYTHON_MODULE); + + // Assert + String got = bout.toString(); + assertThat(got).contains(hyperparameterTuningJobDisplayName); + assertThat(got).contains("response:"); + hyperparameterJobId = + got.split("Name: ")[1].split("hyperparameterTuningJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateHyperparameterTuningJobSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateHyperparameterTuningJobSampleTest.java new file mode 100644 index 00000000000..be9273c6efb --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateHyperparameterTuningJobSampleTest.java @@ -0,0 +1,105 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.cloud.aiplatform.v1beta1.JobServiceClient; +import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateHyperparameterTuningJobSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String CONTAINER_IMAGE_URI = "gcr.io/ucaip-test/ucaip-training-test:latest"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String hyperparameterJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + try (JobServiceClient client = JobServiceClient.create(settings)) { + // Cancel hyper parameter job + String hyperparameterJobName = + String.format( + "projects/%s/locations/us-central1/hyperparameterTuningJobs/%s", + PROJECT, hyperparameterJobId); + client.cancelHyperparameterTuningJob(hyperparameterJobName); + + TimeUnit.MINUTES.sleep(1); + + // Delete the created job + client.deleteHyperparameterTuningJobAsync(hyperparameterJobName); + System.out.flush(); + System.setOut(originalPrintStream); + } + } + + @Test + public void testCreateHyperparameterTuningJobSample() throws IOException { + String hyperparameterTuningJobDisplayName = + String.format( + "temp_hyperparameter_tuning_job_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateHyperparameterTuningJobSample.createHyperparameterTuningJobSample( + PROJECT, hyperparameterTuningJobDisplayName, CONTAINER_IMAGE_URI); + + String got = bout.toString(); + assertThat(got).contains(hyperparameterTuningJobDisplayName); + assertThat(got).contains("response:"); + hyperparameterJobId = + got.split("Name: ")[1].split("hyperparameterTuningJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineCustomJobSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineCustomJobSampleTest.java new file mode 100644 index 00000000000..ac4a97bd97f --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineCustomJobSampleTest.java @@ -0,0 +1,112 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateTrainingPipelineCustomJobSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String CONTAINER_IMAGE_URI = + "gcr.io/ucaip-sample-tests/mnist-custom-job:latest"; + private static final String GCS_OUTPUT_DIRECTORY = + "gs://ucaip-samples-us-central1/training_pipeline_output"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateTrainingPipelineCustomJobSample() throws IOException { + // Act + String trainingPipelineDisplayName = + String.format( + "temp_create_training_pipeline_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + String modelDisplayName = + String.format( + "temp_create_training_pipeline_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineCustomJobSample.createTrainingPipelineCustomJobSample( + PROJECT, + trainingPipelineDisplayName, + modelDisplayName, + CONTAINER_IMAGE_URI, + GCS_OUTPUT_DIRECTORY); + + // Assert + String got = bout.toString(); + assertThat(got).contains(trainingPipelineDisplayName); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSampleTest.java new file mode 100644 index 00000000000..9065f8191b8 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSampleTest.java @@ -0,0 +1,120 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateTrainingPipelineCustomTrainingManagedDatasetSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = System.getenv("CUSTOM_MANAGED_DATASET"); + private static final String ANNOTATION_SCHEMA_URI = + "gs://google-cloud-aiplatform/schema/dataset/annotation/image_classification_1.0.0.yaml"; + private static final String TRAINING_CONTAINER_IMAGE_URI = + "gcr.io/ucaip-test/custom-container-managed-dataset:latest"; + private static final String MODEL_CONTAIN_SPEC_IMAGE_URI = + "gcr.io/cloud-aiplatform/prediction/tf-gpu.1-15:latest"; + private static final String GCS_OUTPUT_DIRECTORY = + "gs://ucaip-samples-us-central1/training_pipeline_output/custom_training_managed_dataset"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateTrainingPipelineCustomTrainingManagedDatasetSample() throws IOException { + // Act + String trainingPipelineDisplayName = + String.format( + "temp_create_training_pipeline_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + String modelDisplayName = + String.format( + "temp_create_training_pipeline_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineCustomTrainingManagedDatasetSample + .createTrainingPipelineCustomTrainingManagedDatasetSample( + PROJECT, + trainingPipelineDisplayName, + modelDisplayName, + DATASET_ID, + ANNOTATION_SCHEMA_URI, + TRAINING_CONTAINER_IMAGE_URI, + MODEL_CONTAIN_SPEC_IMAGE_URI, + GCS_OUTPUT_DIRECTORY); + + // Assert + String got = bout.toString(); + assertThat(got).contains(trainingPipelineDisplayName); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularClassificationSampleTest.java index 65f7d041bf2..f8bc4769a84 100644 --- a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularClassificationSampleTest.java +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularClassificationSampleTest.java @@ -39,21 +39,21 @@ public class CreateTrainingPipelineTabularClassificationSampleTest { private static final String TARGET_COLUMN = "TripType"; private static final String TRANSFORMATION = "[{\"numeric\":{\"columnName\":\"Age\",\"invalidValuesAllowed\":false}}," - + "{\"categorical\":{\"columnName\":\"Job\"}}," - + "{\"categorical\":{\"columnName\":\"MaritalStatus\"}}," - + "{\"categorical\":{\"columnName\":\"Default\"}}," - + "{\"numeric\":{\"columnName\":\"Balance\",\"invalidValuesAllowed\":false}}," - + "{\"categorical\":{\"columnName\":\"Housing\"}}," - + "{\"categorical\":{\"columnName\":\"Loan\"}}," - + "{\"categorical\":{\"columnName\":\"Contact\"}}," - + "{\"numeric\":{\"columnName\":\"Day\",\"invalidValuesAllowed\":false}}," - + "{\"categorical\":{\"columnName\":\"Month\"}}," - + "{\"numeric\":{\"columnName\":\"Duration\",\"invalidValuesAllowed\":false}}," - + "{\"numeric\":{\"columnName\":\"Campaign\",\"invalidValuesAllowed\":false}}," - + "{\"numeric\":{\"columnName\":\"PDays\",\"invalidValuesAllowed\":false}}," - + "{\"numeric\":{\"columnName\":\"Previous\",\"invalidValuesAllowed\":false}}," - + "{\"categorical\":{\"columnName\":\"POutcome\"}}," - + "{\"categorical\":{\"columnName\":\"Deposit\"}}]"; + + "{\"categorical\":{\"columnName\":\"Job\"}}," + + "{\"categorical\":{\"columnName\":\"MaritalStatus\"}}," + + "{\"categorical\":{\"columnName\":\"Default\"}}," + + "{\"numeric\":{\"columnName\":\"Balance\",\"invalidValuesAllowed\":false}}," + + "{\"categorical\":{\"columnName\":\"Housing\"}}," + + "{\"categorical\":{\"columnName\":\"Loan\"}}," + + "{\"categorical\":{\"columnName\":\"Contact\"}}," + + "{\"numeric\":{\"columnName\":\"Day\",\"invalidValuesAllowed\":false}}," + + "{\"categorical\":{\"columnName\":\"Month\"}}," + + "{\"numeric\":{\"columnName\":\"Duration\",\"invalidValuesAllowed\":false}}," + + "{\"numeric\":{\"columnName\":\"Campaign\",\"invalidValuesAllowed\":false}}," + + "{\"numeric\":{\"columnName\":\"PDays\",\"invalidValuesAllowed\":false}}," + + "{\"numeric\":{\"columnName\":\"Previous\",\"invalidValuesAllowed\":false}}," + + "{\"categorical\":{\"columnName\":\"POutcome\"}}," + + "{\"categorical\":{\"columnName\":\"Deposit\"}}]"; private ByteArrayOutputStream bout; private PrintStream out; private PrintStream originalPrintStream; diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularRegressionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularRegressionSampleTest.java index 3933106f3e6..3298e0eba1d 100644 --- a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularRegressionSampleTest.java +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularRegressionSampleTest.java @@ -35,23 +35,23 @@ public class CreateTrainingPipelineTabularRegressionSampleTest { private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); private static final String DATASET_ID = - System.getenv("TRAINING_PIPELINE_TABLES_REGRESSION_DATASET_ID"); + System.getenv("TRAINING_PIPELINE_TABLES_REGRESSION_DATASET_ID"); private static final String TARGET_COLUMN = "Amount"; private static final String TRANSFORMATION = "[{\"categorical\":{\"columnName\":\"SC_Group_Desc\"}}," - + "{\"categorical\":{\"columnName\":\"SC_GroupCommod_ID\"}}," - + "{\"categorical\":{\"columnName\":\"SC_GroupCommod_Desc\"}}," - + "{\"numeric\":{\"columnName\":\"SortOrder\",\"invalidValuesAllowed\":false}}," - + "{\"text\":{\"columnName\":\"SC_GeographyIndented_Desc\"}}," - + "{\"numeric\":{\"columnName\":\"SC_Commodity_ID\",\"invalidValuesAllowed\":false}}," - + "{\"text\":{\"columnName\":\"SC_Commodity_Desc\"}}," - + "{\"numeric\":{\"columnName\":\"SC_Attribute_ID\",\"invalidValuesAllowed\":false}}," - + "{\"text\":{\"columnName\":\"SC_Attribute_Desc\"}}," - + "{\"numeric\":{\"columnName\":\"SC_Unit_ID\",\"invalidValuesAllowed\":false}}," - + "{\"numeric\":{\"columnName\":\"Year_ID\",\"invalidValuesAllowed\":false}}," - + "{\"categorical\":{\"columnName\":\"SC_Frequency_Desc\"}}," - + "{\"numeric\":{\"columnName\":\"Timeperiod_ID\",\"invalidValuesAllowed\":false}}," - + "{\"text\":{\"columnName\":\"Timeperiod_Desc\"}}]"; + + "{\"categorical\":{\"columnName\":\"SC_GroupCommod_ID\"}}," + + "{\"categorical\":{\"columnName\":\"SC_GroupCommod_Desc\"}}," + + "{\"numeric\":{\"columnName\":\"SortOrder\",\"invalidValuesAllowed\":false}}," + + "{\"text\":{\"columnName\":\"SC_GeographyIndented_Desc\"}}," + + "{\"numeric\":{\"columnName\":\"SC_Commodity_ID\",\"invalidValuesAllowed\":false}}," + + "{\"text\":{\"columnName\":\"SC_Commodity_Desc\"}}," + + "{\"numeric\":{\"columnName\":\"SC_Attribute_ID\",\"invalidValuesAllowed\":false}}," + + "{\"text\":{\"columnName\":\"SC_Attribute_Desc\"}}," + + "{\"numeric\":{\"columnName\":\"SC_Unit_ID\",\"invalidValuesAllowed\":false}}," + + "{\"numeric\":{\"columnName\":\"Year_ID\",\"invalidValuesAllowed\":false}}," + + "{\"categorical\":{\"columnName\":\"SC_Frequency_Desc\"}}," + + "{\"numeric\":{\"columnName\":\"Timeperiod_ID\",\"invalidValuesAllowed\":false}}," + + "{\"text\":{\"columnName\":\"Timeperiod_Desc\"}}]"; private ByteArrayOutputStream bout; private PrintStream out; private PrintStream originalPrintStream; diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSampleTest.java index bd0f29461bf..f4e99b2a567 100644 --- a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSampleTest.java +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSampleTest.java @@ -90,14 +90,10 @@ public void testCreateTrainingPipelineTextSentimentAnalysisSample() throws IOExc String tempUuid = UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26); // Act String trainingPipelineDisplayName = - String.format( - "temp_create_training_pipeline_test_%s", - tempUuid); + String.format("temp_create_training_pipeline_test_%s", tempUuid); String modelDisplayName = - String.format( - "temp_create_training_pipeline_model_test_%s", - tempUuid); + String.format("temp_create_training_pipeline_model_test_%s", tempUuid); CreateTrainingPipelineTextSentimentAnalysisSample .createTrainingPipelineTextSentimentAnalysisSample( diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSampleTest.java new file mode 100644 index 00000000000..3e39be5f84e --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSampleTest.java @@ -0,0 +1,109 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateTrainingPipelineVideoActionRecognitionSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("TRAINING_PIPELINE_VIDEO_ACTION_DATASET_ID"); + private static final String MODEL = "CLOUD"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_VIDEO_ACTION_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateTrainingPipelineVideoActionRecognitionSample() throws IOException { + // Act + String trainingPipelineDisplayName = + String.format( + "temp_create_training_pipeline_video_action_recognition_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + String modelDisplayName = + String.format( + "temp_create_training_pipeline_video_action_recognition_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineVideoActionRecognitionSample + .createTrainingPipelineVideoActionRecognitionSample( + PROJECT, trainingPipelineDisplayName, DATASET_ID, modelDisplayName, MODEL); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/DeployModelCustomTrainedModelSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/DeployModelCustomTrainedModelSampleTest.java new file mode 100644 index 00000000000..798e9115a59 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/DeployModelCustomTrainedModelSampleTest.java @@ -0,0 +1,96 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import io.grpc.StatusRuntimeException; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class DeployModelCustomTrainedModelSampleTest { + + private static final String PROJECT_ID = "ucaip-sample-tests"; + private static final String MODEL_ID = "4992732768149438464"; + private static final String ENDPOINT_ID = "4366591682456584192"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + + // Undeploy the model + try { + UndeployModelSample.undeployModelSample(PROJECT_ID, ENDPOINT_ID, MODEL_ID); + } catch (IOException | InterruptedException | ExecutionException | TimeoutException e) { + e.printStackTrace(); + } + } + + @Ignore("Issues with undeploy") + @Test + public void testDeployModelCustomTrainedModelSample() throws TimeoutException { + // As model deployment can take a long time, instead try to deploy a + // nonexistent model and confirm that the model was not found, but other + // elements of the request were valid. + String deployedModelDisplayName = + String.format( + "temp_deploy_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + try { + DeployModelCustomTrainedModelSample.deployModelCustomTrainedModelSample( + PROJECT_ID, ENDPOINT_ID, MODEL_ID, deployedModelDisplayName); + // Assert + String got = bout.toString(); + assertThat(got).contains("deployModelResponse"); + } catch (StatusRuntimeException | ExecutionException | InterruptedException | IOException e) { + assertThat(e.getMessage()).contains("is not found."); + } + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/DeployModelSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/DeployModelSampleTest.java index 85c2583195e..766944fb73a 100644 --- a/aiplatform/snippets/src/test/java/aiplatform/DeployModelSampleTest.java +++ b/aiplatform/snippets/src/test/java/aiplatform/DeployModelSampleTest.java @@ -25,7 +25,6 @@ import java.io.PrintStream; import java.util.UUID; import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import org.junit.After; import org.junit.Before; @@ -66,8 +65,7 @@ public void tearDown() { } @Test - public void testDeployModelSample() - throws TimeoutException { + public void testDeployModelSample() throws TimeoutException { // As model deployment can take a long time, instead try to deploy a // nonexistent model and confirm that the model was not found, but other // elements of the request were valid. @@ -76,8 +74,8 @@ public void testDeployModelSample() "temp_deploy_model_test_%s", UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); try { - DeployModelSample.deployModelSample(PROJECT_ID, deployedModelDisplayName, - "4366591682456584192", MODEL_ID); + DeployModelSample.deployModelSample( + PROJECT_ID, deployedModelDisplayName, "4366591682456584192", MODEL_ID); // Assert String got = bout.toString(); assertThat(got).contains("is not found."); diff --git a/aiplatform/snippets/src/test/java/aiplatform/ExportModelVideoActionRecognitionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/ExportModelVideoActionRecognitionSampleTest.java new file mode 100644 index 00000000000..57cfdb38ba7 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/ExportModelVideoActionRecognitionSampleTest.java @@ -0,0 +1,89 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class ExportModelVideoActionRecognitionSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = + System.getenv("EXPORT_MODEL_VIDEO_ACTION_RECOGNITION_MODEL_ID"); + private static final String GCS_DESTINATION_URI_PREFIX = + "gs://ucaip-samples-test-output/tmp/export_model_video_action_recognition_sample"; + private static final String EXPORT_FORMAT = "tf-saved-model"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("EXPORT_MODEL_VIDEO_ACTION_RECOGNITION_MODEL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + // Delete the export model + String bucketName = GCS_DESTINATION_URI_PREFIX.split("/", 4)[2]; + String objectName = (GCS_DESTINATION_URI_PREFIX.split("/", 4)[3]).concat("model-" + MODEL_ID); + DeleteExportModelSample.deleteExportModelSample(PROJECT, bucketName, objectName); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Export Model Deleted"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testExportModelVideoActionRecognitionSample() + throws InterruptedException, ExecutionException, TimeoutException, IOException { + // Act + ExportModelVideoActionRecognitionSample.exportModelVideoActionRecognitionSample( + PROJECT, MODEL_ID, GCS_DESTINATION_URI_PREFIX, EXPORT_FORMAT); + + // Assert + String got = bout.toString(); + assertThat(got).contains("exportModelResponse: "); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetHyperparameterTuningJobSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetHyperparameterTuningJobSampleTest.java new file mode 100644 index 00000000000..685768000d1 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/GetHyperparameterTuningJobSampleTest.java @@ -0,0 +1,73 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class GetHyperparameterTuningJobSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String HYPERPARAMETER_TUNING_JOB_ID = System.getenv("GET_HP_TUNING_JOB_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("GET_HP_TUNING_JOB_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetHyperparameterTuningJobSample() throws IOException { + GetHyperparameterTuningJobSample.getHyperparameterTuningJobSample( + PROJECT, HYPERPARAMETER_TUNING_JOB_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(HYPERPARAMETER_TUNING_JOB_ID); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationVideoActionRecognitionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationVideoActionRecognitionSampleTest.java new file mode 100644 index 00000000000..549f7172c9d --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationVideoActionRecognitionSampleTest.java @@ -0,0 +1,77 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class GetModelEvaluationVideoActionRecognitionSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = System.getenv("VIDEO_ACTION_MODEL_ID"); + private static final String EVALUATION_ID = System.getenv("VIDEO_ACTION_EVALUATION_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("VIDEO_ACTION_MODEL_ID"); + requireEnvVar("VIDEO_ACTION_EVALUATION_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetModelEvaluationVideoActionRecognitionSample() throws IOException { + // Act + GetModelEvaluationVideoActionRecognitionSample.getModelEvaluationVideoActionRecognitionSample( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("response:"); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/ImportDataSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/ImportDataSampleTest.java index 72d5ad90f5f..dcfaceb9f55 100644 --- a/aiplatform/snippets/src/test/java/aiplatform/ImportDataSampleTest.java +++ b/aiplatform/snippets/src/test/java/aiplatform/ImportDataSampleTest.java @@ -72,14 +72,13 @@ public void tearDown() { } @Test - public void testImportDataSample() - throws TimeoutException { + public void testImportDataSample() throws TimeoutException { // As import data into dataset can take a long time, instead try to import data into a // nonexistent dataset and confirm that the model was not found, but other // elements of the request were valid. try { ImportDataTextClassificationSingleLabelSample.importDataTextClassificationSingleLabelSample( - PROJECT, DATASET_ID, GCS_SOURCE_URI); + PROJECT, DATASET_ID, GCS_SOURCE_URI); // Assert String got = bout.toString(); assertThat(got).contains("The Dataset does not exist."); diff --git a/aiplatform/snippets/src/test/java/aiplatform/ImportDataVideoActionRecognitionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/ImportDataVideoActionRecognitionSampleTest.java new file mode 100644 index 00000000000..dd77e61b7be --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/ImportDataVideoActionRecognitionSampleTest.java @@ -0,0 +1,128 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.Dataset; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.protobuf.Empty; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class ImportDataVideoActionRecognitionSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String LOCATION = "us-central1"; + private static final String GCS_SOURCE_URI = + "gs://automl-video-demo-data/ucaip-var/swimrun.jsonl"; + + private String datasetId; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + + // create a temp dataset for importing data + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml"; + LocationName locationName = LocationName.of(PROJECT, LOCATION); + Dataset dataset = + Dataset.newBuilder() + .setDisplayName("test_dataset_display_name") + .setMetadataSchemaUri(metadataSchemaUri) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + Dataset datasetResponse = datasetFuture.get(300, TimeUnit.SECONDS); + String[] datasetValues = datasetResponse.getName().split("/"); + datasetId = datasetValues[datasetValues.length - 1]; + } + } + + @After + public void tearDown() throws InterruptedException, ExecutionException, IOException { + // delete the temp dataset + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + DatasetName datasetName = DatasetName.of(PROJECT, LOCATION, datasetId); + + OperationFuture operationFuture = + datasetServiceClient.deleteDatasetAsync(datasetName); + operationFuture.get(); + } + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testImportDataVideoActionRecognitionSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + ImportDataVideoActionRecognitionSample.importDataVideoActionRecognitionSample( + PROJECT, datasetId, GCS_SOURCE_URI); + + // Assert + String got = bout.toString(); + assertThat(got).contains("importDataResponse:"); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/PredictTextEntityExtractionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/PredictTextEntityExtractionSampleTest.java index dde83fa06d9..71fd8e8ba26 100644 --- a/aiplatform/snippets/src/test/java/aiplatform/PredictTextEntityExtractionSampleTest.java +++ b/aiplatform/snippets/src/test/java/aiplatform/PredictTextEntityExtractionSampleTest.java @@ -86,4 +86,3 @@ public void testPredictTextEntityExtraction() throws IOException { assertThat(got).contains("Predict Text Entity Extraction Response"); } } - diff --git a/aiplatform/snippets/src/test/java/aiplatform/PredictTextSentimentAnalysisSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/PredictTextSentimentAnalysisSampleTest.java index 1189e391728..d452dc94574 100644 --- a/aiplatform/snippets/src/test/java/aiplatform/PredictTextSentimentAnalysisSampleTest.java +++ b/aiplatform/snippets/src/test/java/aiplatform/PredictTextSentimentAnalysisSampleTest.java @@ -87,4 +87,3 @@ public void testPredictTextSentimentAnalysis() throws IOException { assertThat(got).contains("Predict Text Sentiment Analysis Response"); } } - diff --git a/aiplatform/snippets/src/test/java/aiplatform/UploadModelSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/UploadModelSampleTest.java index 3ebae3b0d5c..c085f8c6776 100644 --- a/aiplatform/snippets/src/test/java/aiplatform/UploadModelSampleTest.java +++ b/aiplatform/snippets/src/test/java/aiplatform/UploadModelSampleTest.java @@ -37,8 +37,8 @@ public class UploadModelSampleTest { private static final String METADATASCHEMA_URI = ""; private static final String IMAGE_URI = "gcr.io/cloud-ml-service-public/" - + "cloud-ml-online-prediction-model-server-cpu:" - + "v1_15py3cmle_op_images_20200229_0210_RC00"; + + "cloud-ml-online-prediction-model-server-cpu:" + + "v1_15py3cmle_op_images_20200229_0210_RC00"; private static final String ARTIFACT_URI = "gs://ucaip-samples-us-central1/model/explain/"; private ByteArrayOutputStream bout; private PrintStream out;