From b73834c0043777c3ac30649d40de1937f50bab77 Mon Sep 17 00:00:00 2001 From: Yvonne Yu Date: Tue, 21 May 2024 19:06:12 -0700 Subject: [PATCH] feat: [vertexai] infer location and project when user doesn't specify them. PiperOrigin-RevId: 635997756 --- .../com/google/cloud/vertexai/Constants.java | 4 + .../com/google/cloud/vertexai/VertexAI.java | 83 +++++- .../cloud/vertexai/FakeGoogleCredentials.java | 17 ++ .../google/cloud/vertexai/VertexAITest.java | 249 +++++++++++++++++- .../it/ITGenerativeModelIntegrationTest.java | 22 ++ 5 files changed, 366 insertions(+), 9 deletions(-) create mode 100644 java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/FakeGoogleCredentials.java diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/Constants.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/Constants.java index 3175e0cfcac5..1609fccc8d1d 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/Constants.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/Constants.java @@ -20,6 +20,10 @@ public final class Constants { // Constants for VertexAI class public static final String USER_AGENT_HEADER = "model-builder"; + public static final String DEFAULT_LOCATION = "us-central1"; + public static final String GOOGLE_CLOUD_REGION = "GOOGLE_CLOUD_REGION"; + public static final String CLOUD_ML_REGION = "CLOUD_ML_REGION"; + public static final String GOOGLE_CLOUD_PROJECT = "GOOGLE_CLOUD_PROJECT"; private Constants() {} } diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java index 30abfe14cc51..b5b3b556a0d7 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java @@ -27,6 +27,7 @@ import com.google.api.gax.rpc.FixedHeaderProvider; import com.google.api.gax.rpc.HeaderProvider; import com.google.auth.Credentials; +import com.google.auth.oauth2.GoogleCredentials; import com.google.cloud.vertexai.api.LlmUtilityServiceClient; import com.google.cloud.vertexai.api.LlmUtilityServiceSettings; import com.google.cloud.vertexai.api.PredictionServiceClient; @@ -67,6 +68,21 @@ public class VertexAI implements AutoCloseable { private final transient Supplier predictionClientSupplier; private final transient Supplier llmClientSupplier; + static Optional getEnvironmentVariable(String envKey) { + return Optional.ofNullable(System.getenv(envKey)); + } + + static GoogleCredentials getGoogleCredentialsFromCredentialsProvider( + CredentialsProvider credentialsProvider) { + try { + final GoogleCredentials googleCredentials = + (GoogleCredentials) credentialsProvider.getCredentials(); + return googleCredentials; + } catch (Exception e) { + throw new IllegalArgumentException(e); + } + } + /** * Constructs a VertexAI instance. * @@ -85,6 +101,26 @@ public VertexAI(String projectId, String location) { /* llmClientSupplierOpt= */ Optional.empty()); } + /** + * Constructs a VertexAI instance with no arguments. SDK will infer location from runtime + * environment. If there is no location inferred from runtime environment, SDK will default + * location to `us-central1`. SDK will infer projectId from GoogleCredentials. If there is + * GoogleCredentials at runtime, SDK will throw IllegalArgumentException + * + * @throws java.lang.IllegalArgumentException + */ + public VertexAI() { + this( + null, + null, + Transport.GRPC, + ImmutableList.of(), + /* credentials= */ Optional.empty(), + /* apiEndpoint= */ Optional.empty(), + /* predictionClientSupplierOpt= */ Optional.empty(), + /* llmClientSupplierOpt= */ Optional.empty()); + } + private VertexAI( String projectId, String location, @@ -98,12 +134,8 @@ private VertexAI( throw new IllegalArgumentException( "At most one of Credentials and scopes should be specified."); } - checkArgument(!Strings.isNullOrEmpty(projectId), "projectId can't be null or empty"); - checkArgument(!Strings.isNullOrEmpty(location), "location can't be null or empty"); checkNotNull(transport, "transport can't be null"); - - this.projectId = projectId; - this.location = location; + this.location = Strings.isNullOrEmpty(location) ? inferLocation() : location; this.transport = transport; if (credentials.isPresent()) { @@ -118,13 +150,15 @@ private VertexAI( .build(); } + this.projectId = Strings.isNullOrEmpty(projectId) ? inferProjectId() : projectId; this.predictionClientSupplier = Suppliers.memoize(predictionClientSupplierOpt.orElse(this::newPredictionServiceClient)); this.llmClientSupplier = Suppliers.memoize(llmClientSupplierOpt.orElse(this::newLlmUtilityClient)); - this.apiEndpoint = apiEndpoint.orElse(String.format("%s-aiplatform.googleapis.com", location)); + this.apiEndpoint = + apiEndpoint.orElse(String.format("%s-aiplatform.googleapis.com", this.location)); } /** Builder for {@link VertexAI}. */ @@ -141,8 +175,6 @@ public static class Builder { private Supplier llmClientSupplier; public VertexAI build() { - checkNotNull(projectId, "projectId must be set."); - checkNotNull(location, "location must be set."); return new VertexAI( projectId, @@ -339,6 +371,41 @@ private LlmUtilityServiceClient newLlmUtilityClient() { } } + private String inferProjectId() { + final String projectNotFoundErrorMessage = + ("Unable to infer your project. Please provide a project Id by one of the following:" + + "\n- Passing a constructor argument by using new VertexAI(String projectId, String" + + " location)" + + "\n- Setting project using 'gcloud config set project my-project'" + + "\n- Setting a GCP environment variable" + + "\n- To create a Google Cloud project, please follow guidance at" + + " https://developers.google.com/workspace/guides/create-project"); + final Optional projectIdOptional = + getEnvironmentVariable(Constants.GOOGLE_CLOUD_PROJECT); + if (projectIdOptional.isPresent()) { + return projectIdOptional.get(); + } + String projectId; + try { + final GoogleCredentials googleCredentials = + getGoogleCredentialsFromCredentialsProvider(this.credentialsProvider); + projectId = googleCredentials.getQuotaProjectId(); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(projectNotFoundErrorMessage, e); + } + if (Strings.isNullOrEmpty(projectId)) { + throw new IllegalArgumentException(projectNotFoundErrorMessage); + } + + return projectId; + } + + private String inferLocation() { + return getEnvironmentVariable(Constants.GOOGLE_CLOUD_REGION) + .orElse( + getEnvironmentVariable(Constants.CLOUD_ML_REGION).orElse(Constants.DEFAULT_LOCATION)); + } + private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IOException { LlmUtilityServiceSettings.Builder settingsBuilder; if (transport == Transport.REST) { diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/FakeGoogleCredentials.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/FakeGoogleCredentials.java new file mode 100644 index 000000000000..0c73715283f1 --- /dev/null +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/FakeGoogleCredentials.java @@ -0,0 +1,17 @@ +package com.google.cloud.vertexai; + +import com.google.auth.oauth2.GoogleCredentials; + +/** */ +public final class FakeGoogleCredentials extends GoogleCredentials { + private String testProject; + + FakeGoogleCredentials(String testProject) { + this.testProject = testProject; + } + + @Override + public String getQuotaProjectId() { + return this.testProject; + } +} diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/VertexAITest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/VertexAITest.java index 9ee2a14c7f79..0479fdeeb3db 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/VertexAITest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/VertexAITest.java @@ -21,11 +21,13 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mockStatic; +import com.google.api.gax.core.CredentialsProvider; import com.google.auth.oauth2.GoogleCredentials; import com.google.cloud.vertexai.api.PredictionServiceClient; import com.google.cloud.vertexai.api.PredictionServiceSettings; import com.google.common.collect.ImmutableList; import java.io.IOException; +import java.util.Optional; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -43,9 +45,19 @@ public final class VertexAITest { private static final String TEST_ENDPOINT = "test_endpoint"; private static final String TEST_DEFAULT_ENDPOINT = String.format("%s-aiplatform.googleapis.com", TEST_LOCATION); + private static final String EXPECTED_DEFAULT_LOCATION = "us-central1"; + private static final String EXPECTED_DEFAULT_ENDPOINT_FROM_DEFAULT_LOCATION = + String.format("%s-aiplatform.googleapis.com", EXPECTED_DEFAULT_LOCATION); + private static final String TEST_ENV_LOCATION_1 = "us-central2"; + private static final String TEST_ENV_LOCATION_2 = "us-central3"; + private static final String EXPECTED_ENV_LOCATION = "us-central2"; + private static final String EXPECTED_DEFAULT_ENDPOINT_FROM_ENV_LOCATION = + String.format("%s-aiplatform.googleapis.com", EXPECTED_ENV_LOCATION); + private static final Optional EMPTY_ENV_VAR_OPTIONAL = Optional.ofNullable(null); + private final FakeGoogleCredentials fakeGoogleCredentials = + new FakeGoogleCredentials(TEST_PROJECT); private VertexAI vertexAi; - @Rule public final MockitoRule mocksRule = MockitoJUnit.rule(); @Mock private GoogleCredentials mockGoogleCredentials; @@ -62,6 +74,124 @@ public void testInstantiateVertexAI_usingConstructor_shouldContainRightFields() assertThat(vertexAi.getApiEndpoint()).isEqualTo(TEST_DEFAULT_ENDPOINT); } + @Test + public void testInstantiateVertexAI_usingConstructorNoArgsCase1_shouldContainRightFields() + throws IOException { + try (MockedStatic mockStatic = mockStatic(VertexAI.class)) { + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_PROJECT")) + .thenReturn(Optional.of(TEST_PROJECT)); + vertexAi = new VertexAI(); + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo(EXPECTED_DEFAULT_LOCATION); + assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC); + assertThat(vertexAi.getApiEndpoint()) + .isEqualTo(EXPECTED_DEFAULT_ENDPOINT_FROM_DEFAULT_LOCATION); + } + } + + @Test + public void testInstantiateVertexAI_usingConstructorNoArgsCase2_shouldContainRightFields() + throws IOException { + try (MockedStatic mockStatic = mockStatic(VertexAI.class)) { + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_PROJECT")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStatic + .when( + () -> + VertexAI.getGoogleCredentialsFromCredentialsProvider( + any(CredentialsProvider.class))) + .thenReturn(fakeGoogleCredentials); + vertexAi = new VertexAI(); + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo(EXPECTED_DEFAULT_LOCATION); + assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC); + assertThat(vertexAi.getApiEndpoint()) + .isEqualTo(EXPECTED_DEFAULT_ENDPOINT_FROM_DEFAULT_LOCATION); + } + } + + @Test + public void testInstantiateVertexAI_usingConstructorNoArgsCase3_shouldContainRightFields() + throws IOException { + try (MockedStatic mockStatic = mockStatic(VertexAI.class)) { + final String expectedErrorMessage = + ("Unable to infer your project. Please provide a project Id by one of the following:" + + "\n- Passing a constructor argument by using new VertexAI(String projectId, String" + + " location)" + + "\n- Setting project using 'gcloud config set project my-project'" + + "\n- Setting a GCP environment variable" + + "\n- To create a Google Cloud project, please follow guidance at" + + " https://developers.google.com/workspace/guides/create-project"); + mockStatic + .when( + () -> + VertexAI.getGoogleCredentialsFromCredentialsProvider( + any(CredentialsProvider.class))) + .thenThrow(new IllegalArgumentException("")); + IllegalArgumentException thrown = + assertThrows(IllegalArgumentException.class, () -> new VertexAI()); + assertThat(thrown).hasMessageThat().contains(expectedErrorMessage); + } + } + + @Test + public void + testInstantiateVertexAI_usingConstructorLocationFromGOOGLE_CLOUD_REGION_shouldContainRightFields() + throws IOException { + try (MockedStatic mockStatic = mockStatic(VertexAI.class)) { + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION")) + .thenReturn(Optional.of(TEST_ENV_LOCATION_1)); + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION")) + .thenReturn(Optional.of(TEST_ENV_LOCATION_2)); + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_PROJECT")) + .thenReturn(Optional.of(TEST_PROJECT)); + vertexAi = new VertexAI(); + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo(EXPECTED_ENV_LOCATION); + assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC); + assertThat(vertexAi.getApiEndpoint()).isEqualTo(EXPECTED_DEFAULT_ENDPOINT_FROM_ENV_LOCATION); + } + } + + @Test + public void + testInstantiateVertexAI_usingConstructorLocationFromCLOUD_ML_REGION_shouldContainRightFields() + throws IOException { + try (MockedStatic mockStatic = mockStatic(VertexAI.class)) { + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION")) + .thenReturn(Optional.of(TEST_ENV_LOCATION_1)); + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_PROJECT")) + .thenReturn(Optional.of(TEST_PROJECT)); + vertexAi = new VertexAI(); + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo(EXPECTED_ENV_LOCATION); + assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC); + assertThat(vertexAi.getApiEndpoint()).isEqualTo(EXPECTED_DEFAULT_ENDPOINT_FROM_ENV_LOCATION); + } + } + @Test public void testInstantiateVertexAI_builderWithCredentials_shouldContainRightFields() throws IOException { @@ -78,6 +208,123 @@ public void testInstantiateVertexAI_builderWithCredentials_shouldContainRightFie assertThat(vertexAi.getCredentials()).isEqualTo(mockGoogleCredentials); } + @Test + public void testInstantiateVertexAI_builderNoArgsCase1_shouldContainRightFields() + throws IOException { + try (MockedStatic mockStatic = mockStatic(VertexAI.class)) { + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_PROJECT")) + .thenReturn(Optional.of(TEST_PROJECT)); + vertexAi = new VertexAI.Builder().build(); + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo(EXPECTED_DEFAULT_LOCATION); + assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC); + assertThat(vertexAi.getApiEndpoint()) + .isEqualTo(EXPECTED_DEFAULT_ENDPOINT_FROM_DEFAULT_LOCATION); + } + } + + @Test + public void testInstantiateVertexAI_builderNoArgsCase2_shouldContainRightFields() + throws IOException { + try (MockedStatic mockStatic = mockStatic(VertexAI.class)) { + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_PROJECT")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStatic + .when( + () -> + VertexAI.getGoogleCredentialsFromCredentialsProvider( + any(CredentialsProvider.class))) + .thenReturn(fakeGoogleCredentials); + vertexAi = new VertexAI.Builder().build(); + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo(EXPECTED_DEFAULT_LOCATION); + assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC); + assertThat(vertexAi.getApiEndpoint()) + .isEqualTo(EXPECTED_DEFAULT_ENDPOINT_FROM_DEFAULT_LOCATION); + } + } + + @Test + public void testInstantiateVertexAI_builderNoArgsCase3_shouldContainRightFields() + throws IOException { + try (MockedStatic mockStatic = mockStatic(VertexAI.class)) { + final String expectedErrorMessage = + ("Unable to infer your project. Please provide a project Id by one of the following:" + + "\n- Passing a constructor argument by using new VertexAI(String projectId, String" + + " location)" + + "\n- Setting project using 'gcloud config set project my-project'" + + "\n- Setting a GCP environment variable" + + "\n- To create a Google Cloud project, please follow guidance at" + + " https://developers.google.com/workspace/guides/create-project"); + mockStatic + .when( + () -> + VertexAI.getGoogleCredentialsFromCredentialsProvider( + any(CredentialsProvider.class))) + .thenThrow(new IllegalArgumentException("")); + IllegalArgumentException thrown = + assertThrows(IllegalArgumentException.class, () -> new VertexAI.Builder().build()); + assertThat(thrown).hasMessageThat().contains(expectedErrorMessage); + } + } + + @Test + public void + testInstantiateVertexAI_builderLocationFromGOOGLE_CLOUD_REGION_shouldContainRightFields() + throws IOException { + try (MockedStatic mockStatic = mockStatic(VertexAI.class)) { + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION")) + .thenReturn(Optional.of(TEST_ENV_LOCATION_1)); + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION")) + .thenReturn(Optional.of(TEST_ENV_LOCATION_2)); + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_PROJECT")) + .thenReturn(Optional.of(TEST_PROJECT)); + vertexAi = new VertexAI.Builder().build(); + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo(EXPECTED_ENV_LOCATION); + assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC); + assertThat(vertexAi.getApiEndpoint()).isEqualTo(EXPECTED_DEFAULT_ENDPOINT_FROM_ENV_LOCATION); + } + } + + @Test + public void testInstantiateVertexAI_builderLocationFromCLOUD_ML_REGION_shouldContainRightFields() + throws IOException { + try (MockedStatic mockStatic = mockStatic(VertexAI.class)) { + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION")) + .thenReturn(Optional.of(TEST_ENV_LOCATION_1)); + mockStatic + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_PROJECT")) + .thenReturn(Optional.of(TEST_PROJECT)); + vertexAi = new VertexAI.Builder().build(); + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo(EXPECTED_ENV_LOCATION); + assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC); + assertThat(vertexAi.getApiEndpoint()).isEqualTo(EXPECTED_DEFAULT_ENDPOINT_FROM_ENV_LOCATION); + } + } + @Test public void testInstantiateVertexAI_builderWithScopes_throwsIlegalArgumentException() throws IOException { diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java index 960037c77a1b..a7253c4132d3 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java @@ -143,6 +143,28 @@ public void generateContent_restTransport_nonEmptyCandidateList() throws IOExcep } } + @Test + public void generateContentInferredArgs_restTransport_nonEmptyCandidate() throws IOException { + try (VertexAI vertexAiViaRestWithInferredArgs = + new VertexAI.Builder().setTransport(Transport.REST).build()) { + GenerativeModel textModelWithRest = + new GenerativeModel(MODEL_NAME_TEXT, vertexAiViaRestWithInferredArgs); + GenerateContentResponse response = textModelWithRest.generateContent(TEXT); + + assertNonEmptyAndLogResponse(name.getMethodName(), TEXT, response); + } + } + + @Test + public void vertexAIInferredArgs_nonEmptyCandidate() throws IOException { + try (VertexAI vertexAiWithInferredArgs = new VertexAI()) { + GenerativeModel textModel = new GenerativeModel(MODEL_NAME_TEXT, vertexAiWithInferredArgs); + GenerateContentResponse response = textModel.generateContent(TEXT); + + assertNonEmptyAndLogResponse(name.getMethodName(), TEXT, response); + } + } + @Test public void generateContent_withPlainText_nonEmptyCandidateList() throws IOException { GenerateContentResponse response = textModel.generateContent(TEXT);