Skip to content

Commit

Permalink
feat: [vertexai] infer location and project when user doesn't specify…
Browse files Browse the repository at this point in the history
… them.

PiperOrigin-RevId: 635997756
  • Loading branch information
yyyu-google authored and copybara-github committed May 29, 2024
1 parent 135c89c commit b73834c
Show file tree
Hide file tree
Showing 5 changed files with 366 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
public final class Constants {
// Constants for VertexAI class
public static final String USER_AGENT_HEADER = "model-builder";
public static final String DEFAULT_LOCATION = "us-central1";
public static final String GOOGLE_CLOUD_REGION = "GOOGLE_CLOUD_REGION";
public static final String CLOUD_ML_REGION = "CLOUD_ML_REGION";
public static final String GOOGLE_CLOUD_PROJECT = "GOOGLE_CLOUD_PROJECT";

private Constants() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import com.google.api.gax.rpc.FixedHeaderProvider;
import com.google.api.gax.rpc.HeaderProvider;
import com.google.auth.Credentials;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.cloud.vertexai.api.LlmUtilityServiceClient;
import com.google.cloud.vertexai.api.LlmUtilityServiceSettings;
import com.google.cloud.vertexai.api.PredictionServiceClient;
Expand Down Expand Up @@ -67,6 +68,21 @@ public class VertexAI implements AutoCloseable {
private final transient Supplier<PredictionServiceClient> predictionClientSupplier;
private final transient Supplier<LlmUtilityServiceClient> llmClientSupplier;

static Optional<String> getEnvironmentVariable(String envKey) {
return Optional.ofNullable(System.getenv(envKey));
}

static GoogleCredentials getGoogleCredentialsFromCredentialsProvider(
CredentialsProvider credentialsProvider) {
try {
final GoogleCredentials googleCredentials =
(GoogleCredentials) credentialsProvider.getCredentials();
return googleCredentials;
} catch (Exception e) {
throw new IllegalArgumentException(e);
}
}

/**
* Constructs a VertexAI instance.
*
Expand All @@ -85,6 +101,26 @@ public VertexAI(String projectId, String location) {
/* llmClientSupplierOpt= */ Optional.empty());
}

/**
* Constructs a VertexAI instance with no arguments. SDK will infer location from runtime
* environment. If there is no location inferred from runtime environment, SDK will default
* location to `us-central1`. SDK will infer projectId from GoogleCredentials. If there is
* GoogleCredentials at runtime, SDK will throw IllegalArgumentException
*
* @throws java.lang.IllegalArgumentException
*/
public VertexAI() {
this(
null,
null,
Transport.GRPC,
ImmutableList.of(),
/* credentials= */ Optional.empty(),
/* apiEndpoint= */ Optional.empty(),
/* predictionClientSupplierOpt= */ Optional.empty(),
/* llmClientSupplierOpt= */ Optional.empty());
}

private VertexAI(
String projectId,
String location,
Expand All @@ -98,12 +134,8 @@ private VertexAI(
throw new IllegalArgumentException(
"At most one of Credentials and scopes should be specified.");
}
checkArgument(!Strings.isNullOrEmpty(projectId), "projectId can't be null or empty");
checkArgument(!Strings.isNullOrEmpty(location), "location can't be null or empty");
checkNotNull(transport, "transport can't be null");

this.projectId = projectId;
this.location = location;
this.location = Strings.isNullOrEmpty(location) ? inferLocation() : location;
this.transport = transport;

if (credentials.isPresent()) {
Expand All @@ -118,13 +150,15 @@ private VertexAI(
.build();
}

this.projectId = Strings.isNullOrEmpty(projectId) ? inferProjectId() : projectId;
this.predictionClientSupplier =
Suppliers.memoize(predictionClientSupplierOpt.orElse(this::newPredictionServiceClient));

this.llmClientSupplier =
Suppliers.memoize(llmClientSupplierOpt.orElse(this::newLlmUtilityClient));

this.apiEndpoint = apiEndpoint.orElse(String.format("%s-aiplatform.googleapis.com", location));
this.apiEndpoint =
apiEndpoint.orElse(String.format("%s-aiplatform.googleapis.com", this.location));
}

/** Builder for {@link VertexAI}. */
Expand All @@ -141,8 +175,6 @@ public static class Builder {
private Supplier<LlmUtilityServiceClient> llmClientSupplier;

public VertexAI build() {
checkNotNull(projectId, "projectId must be set.");
checkNotNull(location, "location must be set.");

return new VertexAI(
projectId,
Expand Down Expand Up @@ -339,6 +371,41 @@ private LlmUtilityServiceClient newLlmUtilityClient() {
}
}

private String inferProjectId() {
final String projectNotFoundErrorMessage =
("Unable to infer your project. Please provide a project Id by one of the following:"
+ "\n- Passing a constructor argument by using new VertexAI(String projectId, String"
+ " location)"
+ "\n- Setting project using 'gcloud config set project my-project'"
+ "\n- Setting a GCP environment variable"
+ "\n- To create a Google Cloud project, please follow guidance at"
+ " https://developers.google.com/workspace/guides/create-project");
final Optional<String> projectIdOptional =
getEnvironmentVariable(Constants.GOOGLE_CLOUD_PROJECT);
if (projectIdOptional.isPresent()) {
return projectIdOptional.get();
}
String projectId;
try {
final GoogleCredentials googleCredentials =
getGoogleCredentialsFromCredentialsProvider(this.credentialsProvider);
projectId = googleCredentials.getQuotaProjectId();
} catch (IllegalArgumentException e) {
throw new IllegalArgumentException(projectNotFoundErrorMessage, e);
}
if (Strings.isNullOrEmpty(projectId)) {
throw new IllegalArgumentException(projectNotFoundErrorMessage);
}

return projectId;
}

private String inferLocation() {
return getEnvironmentVariable(Constants.GOOGLE_CLOUD_REGION)
.orElse(
getEnvironmentVariable(Constants.CLOUD_ML_REGION).orElse(Constants.DEFAULT_LOCATION));
}

private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IOException {
LlmUtilityServiceSettings.Builder settingsBuilder;
if (transport == Transport.REST) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.google.cloud.vertexai;

import com.google.auth.oauth2.GoogleCredentials;

/** */
public final class FakeGoogleCredentials extends GoogleCredentials {
private String testProject;

FakeGoogleCredentials(String testProject) {
this.testProject = testProject;
}

@Override
public String getQuotaProjectId() {
return this.testProject;
}
}
Loading

0 comments on commit b73834c

Please sign in to comment.