Skip to content
This repository was archived by the owner on Sep 9, 2023. It is now read-only.

Commit

Permalink
feat: samples updated for EJCL
Browse files Browse the repository at this point in the history
  • Loading branch information
telpirion committed Dec 4, 2020
1 parent a484ba6 commit 1df941a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs;
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType;
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationMetadata;
import com.google.cloud.aiplatform.v1beta1.utility.ValueConverter;
import com.google.protobuf.Any;
import com.google.protobuf.Value;
Expand Down Expand Up @@ -77,11 +80,13 @@ static void createTrainingPipelineImageClassificationSample(
+ "automl_image_classification_1.0.0.yaml";
LocationName locationName = LocationName.of(project, location);

String jsonString =
"{\"multiLabel\": false, \"modelType\": \"CLOUD\", \"budgetMilliNodeHours\": 8000,"
+ " \"disableEarlyStopping\": false}";
Value.Builder trainingTaskInputs = Value.newBuilder();
JsonFormat.parser().merge(jsonString, trainingTaskInputs);
AutoMlImageClassificationInputs autoMlImageClassificationInputs =
AutoMlImageClassificationInputs.newBuilder()
.setModelType(ModelType.CLOUD)
.setMultiLabel(false)
.setBudgetMilliNodeHours(8000)
.setDisableEarlyStopping(false)
.build();

InputDataConfig trainingInputDataConfig =
InputDataConfig.newBuilder().setDatasetId(datasetId).build();
Expand All @@ -90,7 +95,7 @@ static void createTrainingPipelineImageClassificationSample(
TrainingPipeline.newBuilder()
.setDisplayName(trainingPipelineDisplayName)
.setTrainingTaskDefinition(trainingTaskDefinition)
.setTrainingTaskInputs(trainingTaskInputs)
.setTrainingTaskInputs(ValueConverter.toValue(autoMlImageClassificationInputs))
.setInputDataConfig(trainingInputDataConfig)
.setModelToUpload(model)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,16 @@ static void predictImageClassification(String project, String fileName, String e
byte[] contents = Base64.encodeBase64(Files.readAllBytes(Paths.get(fileName)));
String content = new String(contents, StandardCharsets.UTF_8);

ImageClassificationPredictionInstance predictionInstance = ImageClassificationPredictionInstance.newBuilder()
ImageClassificationPredictionInstance predictionInstance =
ImageClassificationPredictionInstance.newBuilder()
.setContent(content)
.build();

List<Value> instances = new ArrayList<>();
instances.add(ValueConverter.toValue(predictionInstance));

ImageClassificationPredictionParams predictionParams = ImageClassificationPredictionParams.newBuilder()
ImageClassificationPredictionParams predictionParams =
ImageClassificationPredictionParams.newBuilder()
.setConfidenceThreshold((float) 0.5)
.setMaxPredictions(5)
.build();
Expand Down

0 comments on commit 1df941a

Please sign in to comment.