From 20c8252dcc4b093f6d699753ad39f0cdbef1d8ed Mon Sep 17 00:00:00 2001 From: "copybara-service[bot]" <56741989+copybara-service[bot]@users.noreply.github.com> Date: Tue, 23 Jan 2024 12:25:13 -0800 Subject: [PATCH] feat: [vertexai] add fromFunctionResponse in PartMaker (#10272) PiperOrigin-RevId: 600847017 Co-authored-by: Jaycee Li --- java-vertexai/README.md | 2 +- .../generativeai/preview/PartMaker.java | 56 ++++++++++++++++++ .../generativeai/preview/PartMakerTest.java | 57 +++++++++++++++++++ 3 files changed, 114 insertions(+), 1 deletion(-) diff --git a/java-vertexai/README.md b/java-vertexai/README.md index d03c64fcddc0..983ed5a70375 100644 --- a/java-vertexai/README.md +++ b/java-vertexai/README.md @@ -18,7 +18,7 @@ If you are using Maven with [BOM][libraries-bom], add this to your pom.xml file: com.google.cloud libraries-bom - 26.30.0 + 26.29.0 pom import diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/preview/PartMaker.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/preview/PartMaker.java index 24d4f64a008e..f37dad336da9 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/preview/PartMaker.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/preview/PartMaker.java @@ -18,9 +18,14 @@ import com.google.cloud.vertexai.api.Blob; import com.google.cloud.vertexai.api.FileData; +import com.google.cloud.vertexai.api.FunctionResponse; import com.google.cloud.vertexai.api.Part; import com.google.protobuf.ByteString; +import com.google.protobuf.NullValue; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; import java.net.URI; +import java.util.Map; /** Helper class to create {@link com.google.cloud.vertexai.api.Part} */ public class PartMaker { @@ -77,4 +82,55 @@ public static Part fromMimeTypeAndData(String mimeType, Object partData) { } return part; } + + /** + * Make a {@link com.google.cloud.vertexai.api.Part} from the output of {@link + * com.google.cloud.vertexai.api.FunctionCall}. + * + * @param name a string represents the name of the {@link + * com.google.cloud.vertexai.api.FunctionDeclaration} + * @param response a structured JSON object containing any output from the function call + */ + public static Part fromFunctionResponse(String name, Struct response) { + return Part.newBuilder() + .setFunctionResponse(FunctionResponse.newBuilder().setName(name).setResponse(response)) + .build(); + } + + /** + * Make a {@link com.google.cloud.vertexai.api.Part} from the result output of {@link + * com.google.cloud.vertexai.api.FunctionCall}. + * + * @param name a string represents the name of the {@link + * com.google.cloud.vertexai.api.FunctionDeclaration} + * @param response a map containing the output from the function call, supported output type: + * String, Double, Boolean, null + */ + public static Part fromFunctionResponse(String name, Map response) { + Struct.Builder structBuilder = Struct.newBuilder(); + response.forEach( + (key, value) -> { + if (value instanceof String) { + String stringValue = (String) value; + structBuilder.putFields(key, Value.newBuilder().setStringValue(stringValue).build()); + } else if (value instanceof Double) { + Double doubleValue = (Double) value; + structBuilder.putFields(key, Value.newBuilder().setNumberValue(doubleValue).build()); + } else if (value instanceof Boolean) { + Boolean boolValue = (Boolean) value; + structBuilder.putFields(key, Value.newBuilder().setBoolValue(boolValue).build()); + } else if (value == null) { + structBuilder.putFields( + key, Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build()); + } else { + throw new IllegalArgumentException( + "The value in the map can only be one of the following format: " + + "String, Double, Boolean, null."); + } + }); + + return Part.newBuilder() + .setFunctionResponse(FunctionResponse.newBuilder().setName(name).setResponse(structBuilder)) + .build(); + } } diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/preview/PartMakerTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/preview/PartMakerTest.java index d6f218e78db1..ffb9d8b144a4 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/preview/PartMakerTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/preview/PartMakerTest.java @@ -17,11 +17,16 @@ package com.google.cloud.vertexai.generativeai.preview; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import com.google.cloud.vertexai.api.Part; import com.google.protobuf.ByteString; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; import java.net.URI; import java.net.URISyntaxException; +import java.util.HashMap; +import java.util.Map; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -72,4 +77,56 @@ public void fromMimeTypeAndData_dataInURI() throws URISyntaxException { assertThat(part.getFileData().getMimeType()).isEqualTo("image/png"); assertThat(part.getFileData().getFileUri()).isEqualTo(fileUri.toString()); } + + @Test + public void testFromFunctionResponseWithStruct() { + String functionName = "getCurrentWeather"; + Struct functionResponse = + Struct.newBuilder() + .putFields("currentWeather", Value.newBuilder().setStringValue("Super nice!").build()) + .putFields("currentTemperature", Value.newBuilder().setNumberValue(85.0).build()) + .putFields("isRaining", Value.newBuilder().setBoolValue(false).build()) + .build(); + + Part part = PartMaker.fromFunctionResponse(functionName, functionResponse); + + assertThat(part.getFunctionResponse().getName()).isEqualTo("getCurrentWeather"); + assertThat(part.getFunctionResponse().getResponse()).isEqualTo(functionResponse); + } + + @Test + public void testFromFunctionResponseWithMap() { + String functionName = "getCurrentWeather"; + Map functionResponse = new HashMap<>(); + functionResponse.put("currentWeather", "Super nice!"); + functionResponse.put("currentTemperature", 85.0); + functionResponse.put("isRaining", false); + functionResponse.put("other", null); + + Part part = PartMaker.fromFunctionResponse(functionName, functionResponse); + + assertThat(part.getFunctionResponse().getName()).isEqualTo("getCurrentWeather"); + + Map fieldsMap = part.getFunctionResponse().getResponse().getFieldsMap(); + assertThat(fieldsMap.get("currentWeather").getStringValue()).isEqualTo("Super nice!"); + assertThat(fieldsMap.get("currentTemperature").getNumberValue()).isEqualTo(85.0); + assertThat(fieldsMap.get("isRaining").getBoolValue()).isEqualTo(false); + assertThat(fieldsMap.get("other").hasNullValue()).isEqualTo(true); + } + + @Test + public void testFromFunctionResponseWithInvalidMap() { + String functionName = "getCurrentWeather"; + Map invalidResponse = new HashMap<>(); + invalidResponse.put("currentWeather", new byte[] {1, 2, 3}); + IllegalArgumentException thrown = + assertThrows( + IllegalArgumentException.class, + () -> PartMaker.fromFunctionResponse(functionName, invalidResponse)); + assertThat(thrown) + .hasMessageThat() + .isEqualTo( + "The value in the map can only be one of the following format: " + + "String, Double, Boolean, null."); + } }