Skip to content

Commit

Permalink
feat: [vertexai] infer location when user doesn't specify one.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 635997756
  • Loading branch information
yyyu-google authored and copybara-github committed May 22, 2024
1 parent 0a854d8 commit af1983a
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
public final class Constants {
// Constants for VertexAI class
public static final String USER_AGENT_HEADER = "model-builder";
public static final String DEFAULT_LOCATION = "us-central1";
public static final String GOOGLE_CLOUD_REGION = "GOOGLE_CLOUD_REGION";
public static final String CLOUD_ML_REGION = "CLOUD_ML_REGION";

private Constants() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ public class VertexAI implements AutoCloseable {
private final transient Supplier<PredictionServiceClient> predictionClientSupplier;
private final transient Supplier<LlmUtilityServiceClient> llmClientSupplier;

static Optional<String> getEnvironmentVariable(String envKey) {
final String envValue = System.getenv(envKey);
if (envValue == null) {
return Optional.empty();
}
return Optional.of(envValue);
}

/**
* Constructs a VertexAI instance.
*
Expand All @@ -85,6 +93,26 @@ public VertexAI(String projectId, String location) {
/* llmClientSupplierOpt= */ Optional.empty());
}

/**
* Constructs a VertexAI instance.
*
* @param projectId the default project to use when making API calls
*/
public VertexAI(String projectId) {
this(
projectId,
getEnvironmentVariable(Constants.GOOGLE_CLOUD_REGION)
.orElse(
getEnvironmentVariable(Constants.CLOUD_ML_REGION)
.orElse(Constants.DEFAULT_LOCATION)),
Transport.GRPC,
ImmutableList.of(),
/* credentials= */ Optional.empty(),
/* apiEndpoint= */ Optional.empty(),
/* predictionClientSupplierOpt= */ Optional.empty(),
/* llmClientSupplierOpt= */ Optional.empty());
}

private VertexAI(
String projectId,
String location,
Expand Down Expand Up @@ -142,7 +170,13 @@ public static class Builder {

public VertexAI build() {
checkNotNull(projectId, "projectId must be set.");
checkNotNull(location, "location must be set.");
if (location == null) {
location =
getEnvironmentVariable(Constants.GOOGLE_CLOUD_REGION)
.orElse(
getEnvironmentVariable(Constants.CLOUD_ML_REGION)
.orElse(Constants.DEFAULT_LOCATION));
}

return new VertexAI(
projectId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.google.cloud.vertexai.api.PredictionServiceSettings;
import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.util.Optional;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand All @@ -43,6 +44,14 @@ public final class VertexAITest {
private static final String TEST_ENDPOINT = "test_endpoint";
private static final String TEST_DEFAULT_ENDPOINT =
String.format("%s-aiplatform.googleapis.com", TEST_LOCATION);
private static final String EXPECTED_DEFAULT_LOCATION = "us-central1";
private static final String EXPECTED_DEFAULT_ENDPOINT_FROM_DEFAULT_LOCATION =
String.format("%s-aiplatform.googleapis.com", EXPECTED_DEFAULT_LOCATION);
private static final String TEST_ENV_LOCATION_1 = "us-central2";
private static final String TEST_ENV_LOCATION_2 = "us-central3";
private static final String EXPECTED_ENV_LOCATION = "us-central2";
private static final String EXPECTED_DEFAULT_ENDPOINT_FROM_ENV_LOCATION =
String.format("%s-aiplatform.googleapis.com", EXPECTED_ENV_LOCATION);

private VertexAI vertexAi;

Expand All @@ -62,6 +71,63 @@ public void testInstantiateVertexAI_usingConstructor_shouldContainRightFields()
assertThat(vertexAi.getApiEndpoint()).isEqualTo(TEST_DEFAULT_ENDPOINT);
}

@Test
public void testInstantiateVertexAI_usingConstructorNoLocation_shouldContainRightFields()
throws IOException {
try (MockedStatic mockStatic = mockStatic(VertexAI.class)) {
mockStatic
.when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION"))
.thenReturn(Optional.empty());
mockStatic
.when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION"))
.thenReturn(Optional.empty());
vertexAi = new VertexAI(TEST_PROJECT);
assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT);
assertThat(vertexAi.getLocation()).isEqualTo(EXPECTED_DEFAULT_LOCATION);
assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC);
assertThat(vertexAi.getApiEndpoint())
.isEqualTo(EXPECTED_DEFAULT_ENDPOINT_FROM_DEFAULT_LOCATION);
}
}

@Test
public void
testInstantiateVertexAI_usingConstructorLocationFromGOOGLE_CLOUD_REGION_shouldContainRightFields()
throws IOException {
try (MockedStatic mockStatic = mockStatic(VertexAI.class)) {
mockStatic
.when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION"))
.thenReturn(Optional.of(TEST_ENV_LOCATION_1));
mockStatic
.when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION"))
.thenReturn(Optional.of(TEST_ENV_LOCATION_2));
vertexAi = new VertexAI(TEST_PROJECT);
assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT);
assertThat(vertexAi.getLocation()).isEqualTo(EXPECTED_ENV_LOCATION);
assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC);
assertThat(vertexAi.getApiEndpoint()).isEqualTo(EXPECTED_DEFAULT_ENDPOINT_FROM_ENV_LOCATION);
}
}

@Test
public void
testInstantiateVertexAI_usingConstructorLocationFromCLOUD_ML_REGION_shouldContainRightFields()
throws IOException {
try (MockedStatic mockStatic = mockStatic(VertexAI.class)) {
mockStatic
.when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION"))
.thenReturn(Optional.empty());
mockStatic
.when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION"))
.thenReturn(Optional.of(TEST_ENV_LOCATION_1));
vertexAi = new VertexAI(TEST_PROJECT);
assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT);
assertThat(vertexAi.getLocation()).isEqualTo(EXPECTED_ENV_LOCATION);
assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC);
assertThat(vertexAi.getApiEndpoint()).isEqualTo(EXPECTED_DEFAULT_ENDPOINT_FROM_ENV_LOCATION);
}
}

@Test
public void testInstantiateVertexAI_builderWithCredentials_shouldContainRightFields()
throws IOException {
Expand All @@ -78,6 +144,62 @@ public void testInstantiateVertexAI_builderWithCredentials_shouldContainRightFie
assertThat(vertexAi.getCredentials()).isEqualTo(mockGoogleCredentials);
}

@Test
public void testInstantiateVertexAI_builderNoLocation_shouldContainRightFields()
throws IOException {
try (MockedStatic mockStatic = mockStatic(VertexAI.class)) {
mockStatic
.when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION"))
.thenReturn(Optional.empty());
mockStatic
.when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION"))
.thenReturn(Optional.empty());
vertexAi = new VertexAI.Builder().setProjectId(TEST_PROJECT).build();
assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT);
assertThat(vertexAi.getLocation()).isEqualTo(EXPECTED_DEFAULT_LOCATION);
assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC);
assertThat(vertexAi.getApiEndpoint())
.isEqualTo(EXPECTED_DEFAULT_ENDPOINT_FROM_DEFAULT_LOCATION);
}
}

@Test
public void
testInstantiateVertexAI_builderLocationFromGOOGLE_CLOUD_REGION_shouldContainRightFields()
throws IOException {
try (MockedStatic mockStatic = mockStatic(VertexAI.class)) {
mockStatic
.when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION"))
.thenReturn(Optional.of(TEST_ENV_LOCATION_1));
mockStatic
.when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION"))
.thenReturn(Optional.of(TEST_ENV_LOCATION_2));
vertexAi = new VertexAI.Builder().setProjectId(TEST_PROJECT).build();
assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT);
assertThat(vertexAi.getLocation()).isEqualTo(EXPECTED_ENV_LOCATION);
assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC);
assertThat(vertexAi.getApiEndpoint()).isEqualTo(EXPECTED_DEFAULT_ENDPOINT_FROM_ENV_LOCATION);
}
}

@Test
public void testInstantiateVertexAI_builderLocationFromCLOUD_ML_REGION_shouldContainRightFields()
throws IOException {
try (MockedStatic mockStatic = mockStatic(VertexAI.class)) {
mockStatic
.when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION"))
.thenReturn(Optional.empty());
mockStatic
.when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION"))
.thenReturn(Optional.of(TEST_ENV_LOCATION_1));
vertexAi = new VertexAI.Builder().setProjectId(TEST_PROJECT).build();
assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT);
assertThat(vertexAi.getLocation()).isEqualTo(EXPECTED_ENV_LOCATION);
assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC);
assertThat(vertexAi.getApiEndpoint()).isEqualTo(EXPECTED_DEFAULT_ENDPOINT_FROM_ENV_LOCATION);
}
}

@Test
public void testInstantiateVertexAI_builderWithScopes_throwsIlegalArgumentException()
throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,29 @@ public void generateContent_restTransport_nonEmptyCandidateList() throws IOExcep
}
}

@Test
public void generateContentInferredLocation_restTransport_nonEmptyCandidate() throws IOException {
try (VertexAI vertexAiViaRestWithInferredLocation =
new VertexAI.Builder().setProjectId(PROJECT_ID).setTransport(Transport.REST).build()) {
GenerativeModel textModelWithRest =
new GenerativeModel(MODEL_NAME_TEXT, vertexAiViaRestWithInferredLocation);
GenerateContentResponse response = textModelWithRest.generateContent(TEXT);

assertNonEmptyAndLogResponse(name.getMethodName(), TEXT, response);
}
}

@Test
public void vertexAIInferredLocation_nonEmptyCandidate() throws IOException {
try (VertexAI vertexAiWithInferredLocation = new VertexAI(PROJECT_ID)) {
GenerativeModel textModel =
new GenerativeModel(MODEL_NAME_TEXT, vertexAiWithInferredLocation);
GenerateContentResponse response = textModel.generateContent(TEXT);

assertNonEmptyAndLogResponse(name.getMethodName(), TEXT, response);
}
}

@Test
public void generateContent_withPlainText_nonEmptyCandidateList() throws IOException {
GenerateContentResponse response = textModel.generateContent(TEXT);
Expand Down

0 comments on commit af1983a

Please sign in to comment.