Skip to content

Commit

Permalink
Refactor APIConfig.endpoint as an associated value of Service
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Mar 5, 2025
1 parent 46b56b1 commit ff5b460
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 72 deletions.
2 changes: 1 addition & 1 deletion FirebaseVertexAI/Sources/CountTokensRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ extension CountTokensRequest: GenerativeAIRequest {

var url: URL {
URL(string:
"\(apiConfig.serviceEndpoint.rawValue)/\(apiConfig.version.rawValue)/\(model):countTokens")!
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(model):countTokens")!
}
}

Expand Down
2 changes: 1 addition & 1 deletion FirebaseVertexAI/Sources/GenerateContentRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ extension GenerateContentRequest: GenerativeAIRequest {
typealias Response = GenerateContentResponse

var url: URL {
let modelURL = "\(apiConfig.serviceEndpoint.rawValue)/\(apiConfig.version.rawValue)/\(model)"
let modelURL = "\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(model)"
switch apiMethod {
case .generateContent:
return URL(string: "\(modelURL):\(apiMethod.rawValue)")!
Expand Down
29 changes: 17 additions & 12 deletions FirebaseVertexAI/Sources/Types/Internal/APIConfig.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,16 @@ struct APIConfig: Sendable, Hashable {
/// This controls which backend API is used by the SDK.
let service: Service

/// The specific network address to use for API requests.
///
/// This must correspond with the API set in `service`.
let serviceEndpoint: ServiceEndpoint

/// The version of the selected API to use, e.g., "v1".
let version: Version

/// Initializes an API configuration.
///
/// - Parameters:
/// - service: The API service to use for generative AI.
/// - serviceEndpoint: The network address to use for the API service.
/// - version: The version of the API to use.
init(service: Service, serviceEndpoint: ServiceEndpoint, version: Version) {
init(service: Service, version: Version) {
self.service = service
self.serviceEndpoint = serviceEndpoint
self.version = version
}
}
Expand All @@ -46,7 +39,7 @@ extension APIConfig {
/// See [Vertex AI and Google AI
/// differences](https://cloud.google.com/vertex-ai/generative-ai/docs/overview#how-gemini-vertex-different-gemini-aistudio)
/// for a comparison of the two [API services](https://google.aip.dev/9#api-service).
enum Service {
enum Service: Hashable {
/// The Gemini Enterprise API provided by Vertex AI.
///
/// See the [Cloud
Expand All @@ -57,13 +50,25 @@ extension APIConfig {
/// The Gemini Developer API provided by Google AI.
///
/// See the [Google AI docs](https://ai.google.dev/gemini-api/docs) for more details.
case developer
case developer(endpoint: Endpoint)

/// The specific network address to use for API requests.
///
/// This must correspond with the API set in `service`.
var endpoint: Endpoint {
switch self {
case .vertexAI:
return .firebaseVertexAIProd
case let .developer(endpoint: endpoint):
return endpoint
}
}
}
}

extension APIConfig {
extension APIConfig.Service {
/// Network addresses for generative AI API services.
enum ServiceEndpoint: String {
enum Endpoint: String {
/// The Vertex AI in Firebase production endpoint.
case firebaseVertexAIProd = "https://firebasevertexai.googleapis.com"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ extension ImagenGenerationRequest: GenerativeAIRequest where ImageType: Decodabl
typealias Response = ImagenGenerationResponse<ImageType>

var url: URL {
return URL(
string: "\(apiConfig.serviceEndpoint.rawValue)/\(apiConfig.version.rawValue)/\(model):predict"
)!
return URL(string:
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(model):predict")!
}
}

Expand Down
17 changes: 6 additions & 11 deletions FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public class VertexAI {
}
let vertexInstance = vertexAI(app: app, location: location)
assert(vertexInstance.apiConfig.service == .vertexAI)
assert(vertexInstance.apiConfig.serviceEndpoint == .firebaseVertexAIProd)
assert(vertexInstance.apiConfig.service.endpoint == .firebaseVertexAIProd)
assert(vertexInstance.apiConfig.version == .v1beta)

return vertexInstance
Expand All @@ -56,7 +56,7 @@ public class VertexAI {
public static func vertexAI(app: FirebaseApp, location: String = "us-central1") -> VertexAI {
let vertexInstance = vertexAI(app: app, location: location, apiConfig: defaultVertexAIAPIConfig)
assert(vertexInstance.apiConfig.service == .vertexAI)
assert(vertexInstance.apiConfig.serviceEndpoint == .firebaseVertexAIProd)
assert(vertexInstance.apiConfig.service.endpoint == .firebaseVertexAIProd)
assert(vertexInstance.apiConfig.version == .v1beta)

return vertexInstance
Expand Down Expand Up @@ -159,14 +159,9 @@ public class VertexAI {

let location: String?

static let defaultVertexAIAPIConfig = APIConfig(
service: .vertexAI,
serviceEndpoint: .firebaseVertexAIProd,
version: .v1beta
)
static let defaultVertexAIAPIConfig = APIConfig(service: .vertexAI, version: .v1beta)
static let defaultDeveloperAPIConfig = APIConfig(
service: .developer,
serviceEndpoint: .generativeLanguage,
service: .developer(endpoint: .generativeLanguage),
version: .v1beta
)

Expand Down Expand Up @@ -256,14 +251,14 @@ public class VertexAI {
}

private func developerModelResourceName(modelName: String) -> String {
switch apiConfig.serviceEndpoint {
switch apiConfig.service.endpoint {
case .firebaseVertexAIStaging:
let projectID = firebaseInfo.projectID
return "projects/\(projectID)/models/\(modelName)"
case .generativeLanguage:
return "models/\(modelName)"
default:
fatalError("The Developer API is not supported on '\(apiConfig.serviceEndpoint)'.")
fatalError("The Developer API is not supported on '\(apiConfig.service.endpoint)'.")
}
}

Expand Down
6 changes: 1 addition & 5 deletions FirebaseVertexAI/Tests/Unit/ChatTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,7 @@ final class ChatTests: XCTestCase {
googleAppID: "My app ID",
firebaseApp: app
),
apiConfig: APIConfig(
service: .vertexAI,
serviceEndpoint: .firebaseVertexAIProd,
version: .v1beta
),
apiConfig: APIConfig(service: .vertexAI, version: .v1beta),
tools: nil,
requestOptions: RequestOptions(),
urlSession: urlSession
Expand Down
6 changes: 1 addition & 5 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,7 @@ final class GenerativeModelTests: XCTestCase {
].sorted()
let testModelResourceName =
"projects/test-project-id/locations/test-location/publishers/google/models/test-model"
let apiConfig = APIConfig(
service: .vertexAI,
serviceEndpoint: .firebaseVertexAIProd,
version: .v1beta
)
let apiConfig = APIConfig(service: .vertexAI, version: .v1beta)

var urlSession: URLSession!
var model: GenerativeModel!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@ final class ImagenGenerationRequestTests: XCTestCase {
addWatermark: nil,
includeResponsibleAIFilterReason: includeResponsibleAIFilterReason
)
let apiConfig = APIConfig(
service: .vertexAI,
serviceEndpoint: .firebaseVertexAIProd,
version: .v1beta
)
let apiConfig = APIConfig(service: .vertexAI, version: .v1beta)

let instance = ImageGenerationInstance(prompt: "test-prompt")

Expand All @@ -64,7 +60,7 @@ final class ImagenGenerationRequestTests: XCTestCase {
XCTAssertEqual(
request.url,
URL(string:
"\(apiConfig.serviceEndpoint.rawValue)/\(apiConfig.version.rawValue)/\(modelName):predict")
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(modelName):predict")
)
}

Expand All @@ -84,7 +80,7 @@ final class ImagenGenerationRequestTests: XCTestCase {
XCTAssertEqual(
request.url,
URL(string:
"\(apiConfig.serviceEndpoint.rawValue)/\(apiConfig.version.rawValue)/\(modelName):predict")
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(modelName):predict")
)
}

Expand Down
31 changes: 8 additions & 23 deletions FirebaseVertexAI/Tests/Unit/Types/Internal/APIConfigTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,49 +19,34 @@ import XCTest
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class APIConfigTests: XCTestCase {
func testInitialize_vertexAI_prod_v1() {
let apiConfig = APIConfig(
service: .vertexAI,
serviceEndpoint: .firebaseVertexAIProd,
version: .v1
)
let apiConfig = APIConfig(service: .vertexAI, version: .v1)

XCTAssertEqual(apiConfig.serviceEndpoint.rawValue, "https://firebasevertexai.googleapis.com")
XCTAssertEqual(apiConfig.service.endpoint.rawValue, "https://firebasevertexai.googleapis.com")
XCTAssertEqual(apiConfig.version.rawValue, "v1")
}

func testInitialize_vertexAI_prod_v1beta() {
let apiConfig = APIConfig(
service: .vertexAI,
serviceEndpoint: .firebaseVertexAIProd,
version: .v1beta
)
let apiConfig = APIConfig(service: .vertexAI, version: .v1beta)

XCTAssertEqual(apiConfig.serviceEndpoint.rawValue, "https://firebasevertexai.googleapis.com")
XCTAssertEqual(apiConfig.service.endpoint.rawValue, "https://firebasevertexai.googleapis.com")
XCTAssertEqual(apiConfig.version.rawValue, "v1beta")
}

func testInitialize_developer_staging_v1beta() {
let apiConfig = APIConfig(
service: .developer,
serviceEndpoint: .firebaseVertexAIStaging,
version: .v1beta
service: .developer(endpoint: .firebaseVertexAIStaging), version: .v1beta
)

XCTAssertEqual(
apiConfig.serviceEndpoint.rawValue,
"https://staging-firebasevertexai.sandbox.googleapis.com"
apiConfig.service.endpoint.rawValue, "https://staging-firebasevertexai.sandbox.googleapis.com"
)
XCTAssertEqual(apiConfig.version.rawValue, "v1beta")
}

func testInitialize_developer_generativeLanguage_v1beta() {
let apiConfig = APIConfig(
service: .developer,
serviceEndpoint: .generativeLanguage,
version: .v1beta
)
let apiConfig = APIConfig(service: .developer(endpoint: .generativeLanguage), version: .v1beta)

XCTAssertEqual(apiConfig.serviceEndpoint.rawValue, "https://generativelanguage.googleapis.com")
XCTAssertEqual(apiConfig.service.endpoint.rawValue, "https://generativelanguage.googleapis.com")
XCTAssertEqual(apiConfig.version.rawValue, "v1beta")
}
}
7 changes: 3 additions & 4 deletions FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class VertexComponentTests: XCTestCase {
XCTAssertEqual(vertex.firebaseInfo.apiKey, VertexComponentTests.apiKey)
XCTAssertEqual(vertex.location, location)
XCTAssertEqual(vertex.apiConfig.service, .vertexAI)
XCTAssertEqual(vertex.apiConfig.serviceEndpoint, .firebaseVertexAIProd)
XCTAssertEqual(vertex.apiConfig.service.endpoint, .firebaseVertexAIProd)
XCTAssertEqual(vertex.apiConfig.version, .v1beta)
}

Expand All @@ -73,7 +73,7 @@ class VertexComponentTests: XCTestCase {
XCTAssertEqual(vertex.firebaseInfo.apiKey, VertexComponentTests.apiKey)
XCTAssertEqual(vertex.location, location)
XCTAssertEqual(vertex.apiConfig.service, .vertexAI)
XCTAssertEqual(vertex.apiConfig.serviceEndpoint, .firebaseVertexAIProd)
XCTAssertEqual(vertex.apiConfig.service.endpoint, .firebaseVertexAIProd)
XCTAssertEqual(vertex.apiConfig.version, .v1beta)
}

Expand Down Expand Up @@ -179,8 +179,7 @@ class VertexComponentTests: XCTestCase {
func testModelResourceName_developerAPI_firebaseVertexAI() throws {
let app = try XCTUnwrap(VertexComponentTests.app)
let apiConfig = APIConfig(
service: .developer,
serviceEndpoint: .firebaseVertexAIStaging,
service: .developer(endpoint: .firebaseVertexAIStaging),
version: .v1beta
)
let vertex = VertexAI.developerAPI(app: app, apiConfig: apiConfig)
Expand Down

0 comments on commit ff5b460

Please sign in to comment.