From 3c7b987a46f6bb0d329487d70398dd32fe9cae8c Mon Sep 17 00:00:00 2001 From: Paul Beusterien Date: Tue, 25 Feb 2025 11:54:15 -0800 Subject: [PATCH] Add google-app-id to Vertex AI requests (#14479) --- FirebaseVertexAI/Sources/FirebaseInfo.swift | 44 +++++++ .../Sources/GenerativeAIService.swift | 30 +++-- .../Sources/GenerativeModel.swift | 10 +- .../Types/Public/Imagen/ImagenModel.swift | 10 +- FirebaseVertexAI/Sources/VertexAI.swift | 44 +++---- FirebaseVertexAI/Tests/Unit/ChatTests.swift | 14 ++- .../Tests/Unit/GenerativeModelTests.swift | 108 +++++++++--------- .../Tests/Unit/VertexComponentTests.swift | 8 +- 8 files changed, 148 insertions(+), 120 deletions(-) create mode 100644 FirebaseVertexAI/Sources/FirebaseInfo.swift diff --git a/FirebaseVertexAI/Sources/FirebaseInfo.swift b/FirebaseVertexAI/Sources/FirebaseInfo.swift new file mode 100644 index 00000000000..cddebffadde --- /dev/null +++ b/FirebaseVertexAI/Sources/FirebaseInfo.swift @@ -0,0 +1,44 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +import FirebaseAppCheckInterop +import FirebaseAuthInterop +import FirebaseCore + +/// Firebase data used by VertexAI +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +struct FirebaseInfo { + let appCheck: AppCheckInterop? + let auth: AuthInterop? + let projectID: String + let apiKey: String + let googleAppID: String + let app: FirebaseApp + + init(appCheck: AppCheckInterop? = nil, + auth: AuthInterop? = nil, + projectID: String, + apiKey: String, + googleAppID: String, + firebaseApp: FirebaseApp) { + self.appCheck = appCheck + self.auth = auth + self.projectID = projectID + self.apiKey = apiKey + self.googleAppID = googleAppID + app = firebaseApp + } +} diff --git a/FirebaseVertexAI/Sources/GenerativeAIService.swift b/FirebaseVertexAI/Sources/GenerativeAIService.swift index fc35c2b258a..b3f150b1acb 100644 --- a/FirebaseVertexAI/Sources/GenerativeAIService.swift +++ b/FirebaseVertexAI/Sources/GenerativeAIService.swift @@ -26,23 +26,12 @@ struct GenerativeAIService { /// The Firebase SDK version in the format `fire/`. static let firebaseVersionTag = "fire/\(FirebaseVersion())" - private let projectID: String - - /// Gives permission to talk to the backend. - private let apiKey: String - - private let appCheck: AppCheckInterop? - - private let auth: AuthInterop? + private let firebaseInfo: FirebaseInfo private let urlSession: URLSession - init(projectID: String, apiKey: String, appCheck: AppCheckInterop?, auth: AuthInterop?, - urlSession: URLSession) { - self.projectID = projectID - self.apiKey = apiKey - self.appCheck = appCheck - self.auth = auth + init(firebaseInfo: FirebaseInfo, urlSession: URLSession) { + self.firebaseInfo = firebaseInfo self.urlSession = urlSession } @@ -180,14 +169,14 @@ struct GenerativeAIService { private func urlRequest(request: T) async throws -> URLRequest { var urlRequest = URLRequest(url: request.url) urlRequest.httpMethod = "POST" - urlRequest.setValue(apiKey, forHTTPHeaderField: "x-goog-api-key") + urlRequest.setValue(firebaseInfo.apiKey, forHTTPHeaderField: "x-goog-api-key") urlRequest.setValue( "\(GenerativeAIService.languageTag) \(GenerativeAIService.firebaseVersionTag)", forHTTPHeaderField: "x-goog-api-client" ) urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type") - if let appCheck { + if let appCheck = firebaseInfo.appCheck { let tokenResult = await appCheck.getToken(forcingRefresh: false) urlRequest.setValue(tokenResult.token, forHTTPHeaderField: "X-Firebase-AppCheck") if let error = tokenResult.error { @@ -198,10 +187,16 @@ struct GenerativeAIService { } } - if let auth, let authToken = try await auth.getToken(forcingRefresh: false) { + if let auth = firebaseInfo.auth, let authToken = try await auth.getToken( + forcingRefresh: false + ) { urlRequest.setValue("Firebase \(authToken)", forHTTPHeaderField: "Authorization") } + if firebaseInfo.app.isDataCollectionDefaultEnabled { + urlRequest.setValue(firebaseInfo.googleAppID, forHTTPHeaderField: "X-Firebase-AppId") + } + let encoder = JSONEncoder() urlRequest.httpBody = try encoder.encode(request) urlRequest.timeoutInterval = request.options.timeout @@ -260,6 +255,7 @@ struct GenerativeAIService { // Log specific RPC errors that cannot be mitigated or handled by user code. // These errors do not produce specific GenerateContentError or CountTokensError cases. private func logRPCError(_ error: BackendError) { + let projectID = firebaseInfo.projectID if error.isVertexAIInFirebaseServiceDisabledError() { VertexLog.error(code: .vertexAIInFirebaseAPIDisabled, """ The Vertex AI in Firebase SDK requires the Vertex AI in Firebase API \ diff --git a/FirebaseVertexAI/Sources/GenerativeModel.swift b/FirebaseVertexAI/Sources/GenerativeModel.swift index 0d2ea829f55..ef104cbc8de 100644 --- a/FirebaseVertexAI/Sources/GenerativeModel.swift +++ b/FirebaseVertexAI/Sources/GenerativeModel.swift @@ -59,23 +59,17 @@ public final class GenerativeModel { /// - requestOptions: Configuration parameters for sending requests to the backend. /// - urlSession: The `URLSession` to use for requests; defaults to `URLSession.shared`. init(name: String, - projectID: String, - apiKey: String, + firebaseInfo: FirebaseInfo, generationConfig: GenerationConfig? = nil, safetySettings: [SafetySetting]? = nil, tools: [Tool]?, toolConfig: ToolConfig? = nil, systemInstruction: ModelContent? = nil, requestOptions: RequestOptions, - appCheck: AppCheckInterop?, - auth: AuthInterop?, urlSession: URLSession = .shared) { modelResourceName = name generativeAIService = GenerativeAIService( - projectID: projectID, - apiKey: apiKey, - appCheck: appCheck, - auth: auth, + firebaseInfo: firebaseInfo, urlSession: urlSession ) self.generationConfig = generationConfig diff --git a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift index 2196d4c6040..8f894a52488 100644 --- a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift +++ b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift @@ -41,20 +41,14 @@ public final class ImagenModel { let requestOptions: RequestOptions init(name: String, - projectID: String, - apiKey: String, + firebaseInfo: FirebaseInfo, generationConfig: ImagenGenerationConfig?, safetySettings: ImagenSafetySettings?, requestOptions: RequestOptions, - appCheck: AppCheckInterop?, - auth: AuthInterop?, urlSession: URLSession = .shared) { modelResourceName = name generativeAIService = GenerativeAIService( - projectID: projectID, - apiKey: apiKey, - appCheck: appCheck, - auth: auth, + firebaseInfo: firebaseInfo, urlSession: urlSession ) self.generationConfig = generationConfig diff --git a/FirebaseVertexAI/Sources/VertexAI.swift b/FirebaseVertexAI/Sources/VertexAI.swift index bb6415988a5..097f3230ec4 100644 --- a/FirebaseVertexAI/Sources/VertexAI.swift +++ b/FirebaseVertexAI/Sources/VertexAI.swift @@ -91,16 +91,13 @@ public class VertexAI { -> GenerativeModel { return GenerativeModel( name: modelResourceName(modelName: modelName), - projectID: projectID, - apiKey: apiKey, + firebaseInfo: firebaseInfo, generationConfig: generationConfig, safetySettings: safetySettings, tools: tools, toolConfig: toolConfig, systemInstruction: systemInstruction, - requestOptions: requestOptions, - appCheck: appCheck, - auth: auth + requestOptions: requestOptions ) } @@ -126,13 +123,10 @@ public class VertexAI { requestOptions: RequestOptions = RequestOptions()) -> ImagenModel { return ImagenModel( name: modelResourceName(modelName: modelName), - projectID: projectID, - apiKey: apiKey, + firebaseInfo: firebaseInfo, generationConfig: generationConfig, safetySettings: safetySettings, - requestOptions: requestOptions, - appCheck: appCheck, - auth: auth + requestOptions: requestOptions ) } @@ -142,12 +136,8 @@ public class VertexAI { // MARK: - Private - /// The `FirebaseApp` associated with this `VertexAI` instance. - private let app: FirebaseApp - - private let appCheck: AppCheckInterop? - - private let auth: AuthInterop? + /// Firebase data relevant to Vertex AI. + let firebaseInfo: FirebaseInfo #if compiler(>=6) /// A map of active `VertexAI` instances keyed by the `FirebaseApp` name and the `location`, in @@ -165,25 +155,26 @@ public class VertexAI { private static var instancesLock: os_unfair_lock = .init() #endif - let projectID: String - let apiKey: String let location: String init(app: FirebaseApp, location: String) { - self.app = app - appCheck = ComponentType.instance(for: AppCheckInterop.self, in: app.container) - auth = ComponentType.instance(for: AuthInterop.self, in: app.container) - guard let projectID = app.options.projectID else { fatalError("The Firebase app named \"\(app.name)\" has no project ID in its configuration.") } - self.projectID = projectID - guard let apiKey = app.options.apiKey else { fatalError("The Firebase app named \"\(app.name)\" has no API key in its configuration.") } - self.apiKey = apiKey - + firebaseInfo = FirebaseInfo( + appCheck: ComponentType.instance( + for: AppCheckInterop.self, + in: app.container + ), + auth: ComponentType.instance(for: AuthInterop.self, in: app.container), + projectID: projectID, + apiKey: apiKey, + googleAppID: app.options.googleAppID, + firebaseApp: app + ) self.location = location } @@ -205,6 +196,7 @@ public class VertexAI { """) } + let projectID = firebaseInfo.projectID return "projects/\(projectID)/locations/\(location)/publishers/google/models/\(modelName)" } } diff --git a/FirebaseVertexAI/Tests/Unit/ChatTests.swift b/FirebaseVertexAI/Tests/Unit/ChatTests.swift index 1c4988faf7c..a0525880da1 100644 --- a/FirebaseVertexAI/Tests/Unit/ChatTests.swift +++ b/FirebaseVertexAI/Tests/Unit/ChatTests.swift @@ -15,6 +15,7 @@ import Foundation import XCTest +import FirebaseCore @testable import FirebaseVertexAI @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) @@ -53,14 +54,19 @@ final class ChatTests: XCTestCase { return (response, fileURL.lines) } + let app = FirebaseApp(instanceWithName: "testApp", + options: FirebaseOptions(googleAppID: "ignore", + gcmSenderID: "ignore")) let model = GenerativeModel( name: "my-model", - projectID: "my-project-id", - apiKey: "API_KEY", + firebaseInfo: FirebaseInfo( + projectID: "my-project-id", + apiKey: "API_KEY", + googleAppID: "My app ID", + firebaseApp: app + ), tools: nil, requestOptions: RequestOptions(), - appCheck: nil, - auth: nil, urlSession: urlSession ) let chat = Chat(model: model, history: []) diff --git a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift index 3ed40ce2530..10a7d41d793 100644 --- a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift +++ b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift @@ -68,12 +68,9 @@ final class GenerativeModelTests: XCTestCase { urlSession = try XCTUnwrap(URLSession(configuration: configuration)) model = GenerativeModel( name: testModelResourceName, - projectID: "my-project-id", - apiKey: "API_KEY", + firebaseInfo: testFirebaseInfo(), tools: nil, requestOptions: RequestOptions(), - appCheck: nil, - auth: nil, urlSession: urlSession ) } @@ -269,12 +266,9 @@ final class GenerativeModelTests: XCTestCase { let model = GenerativeModel( // Model name is prefixed with "models/". name: "models/test-model", - projectID: "my-project-id", - apiKey: "API_KEY", + firebaseInfo: testFirebaseInfo(), tools: nil, requestOptions: RequestOptions(), - appCheck: nil, - auth: nil, urlSession: urlSession ) @@ -389,12 +383,9 @@ final class GenerativeModelTests: XCTestCase { let appCheckToken = "test-valid-token" model = GenerativeModel( name: testModelResourceName, - projectID: "my-project-id", - apiKey: "API_KEY", + firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(token: appCheckToken)), tools: nil, requestOptions: RequestOptions(), - appCheck: AppCheckInteropFake(token: appCheckToken), - auth: nil, urlSession: urlSession ) MockURLProtocol @@ -407,15 +398,33 @@ final class GenerativeModelTests: XCTestCase { _ = try await model.generateContent(testPrompt) } + func testGenerateContent_dataCollectionOff() async throws { + let appCheckToken = "test-valid-token" + model = GenerativeModel( + name: testModelResourceName, + firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(token: appCheckToken), + privateAppID: true), + tools: nil, + requestOptions: RequestOptions(), + urlSession: urlSession + ) + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "unary-success-basic-reply-short", + withExtension: "json", + appCheckToken: appCheckToken, + dataCollection: false + ) + + _ = try await model.generateContent(testPrompt) + } + func testGenerateContent_appCheck_tokenRefreshError() async throws { model = GenerativeModel( name: testModelResourceName, - projectID: "my-project-id", - apiKey: "API_KEY", + firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(error: AppCheckErrorFake())), tools: nil, requestOptions: RequestOptions(), - appCheck: AppCheckInteropFake(error: AppCheckErrorFake()), - auth: nil, urlSession: urlSession ) MockURLProtocol @@ -432,12 +441,9 @@ final class GenerativeModelTests: XCTestCase { let authToken = "test-valid-token" model = GenerativeModel( name: testModelResourceName, - projectID: "my-project-id", - apiKey: "API_KEY", + firebaseInfo: testFirebaseInfo(auth: AuthInteropFake(token: authToken)), tools: nil, requestOptions: RequestOptions(), - appCheck: nil, - auth: AuthInteropFake(token: authToken), urlSession: urlSession ) MockURLProtocol @@ -453,12 +459,9 @@ final class GenerativeModelTests: XCTestCase { func testGenerateContent_auth_nilAuthToken() async throws { model = GenerativeModel( name: testModelResourceName, - projectID: "my-project-id", - apiKey: "API_KEY", + firebaseInfo: testFirebaseInfo(auth: AuthInteropFake(token: nil)), tools: nil, requestOptions: RequestOptions(), - appCheck: nil, - auth: AuthInteropFake(token: nil), urlSession: urlSession ) MockURLProtocol @@ -474,12 +477,9 @@ final class GenerativeModelTests: XCTestCase { func testGenerateContent_auth_authTokenRefreshError() async throws { model = GenerativeModel( name: "my-model", - projectID: "my-project-id", - apiKey: "API_KEY", + firebaseInfo: testFirebaseInfo(auth: AuthInteropFake(error: AuthErrorFake())), tools: nil, requestOptions: RequestOptions(), - appCheck: nil, - auth: AuthInteropFake(error: AuthErrorFake()), urlSession: urlSession ) MockURLProtocol @@ -856,12 +856,9 @@ final class GenerativeModelTests: XCTestCase { let requestOptions = RequestOptions(timeout: expectedTimeout) model = GenerativeModel( name: testModelResourceName, - projectID: "my-project-id", - apiKey: "API_KEY", + firebaseInfo: testFirebaseInfo(), tools: nil, requestOptions: requestOptions, - appCheck: nil, - auth: nil, urlSession: urlSession ) @@ -1151,12 +1148,9 @@ final class GenerativeModelTests: XCTestCase { let appCheckToken = "test-valid-token" model = GenerativeModel( name: testModelResourceName, - projectID: "my-project-id", - apiKey: "API_KEY", + firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(token: appCheckToken)), tools: nil, requestOptions: RequestOptions(), - appCheck: AppCheckInteropFake(token: appCheckToken), - auth: nil, urlSession: urlSession ) MockURLProtocol @@ -1173,12 +1167,9 @@ final class GenerativeModelTests: XCTestCase { func testGenerateContentStream_appCheck_tokenRefreshError() async throws { model = GenerativeModel( name: testModelResourceName, - projectID: "my-project-id", - apiKey: "API_KEY", + firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(error: AppCheckErrorFake())), tools: nil, requestOptions: RequestOptions(), - appCheck: AppCheckInteropFake(error: AppCheckErrorFake()), - auth: nil, urlSession: urlSession ) MockURLProtocol @@ -1319,12 +1310,9 @@ final class GenerativeModelTests: XCTestCase { let requestOptions = RequestOptions(timeout: expectedTimeout) model = GenerativeModel( name: testModelResourceName, - projectID: "my-project-id", - apiKey: "API_KEY", + firebaseInfo: testFirebaseInfo(), tools: nil, requestOptions: requestOptions, - appCheck: nil, - auth: nil, urlSession: urlSession ) @@ -1394,14 +1382,11 @@ final class GenerativeModelTests: XCTestCase { ) model = GenerativeModel( name: testModelResourceName, - projectID: "my-project-id", - apiKey: "API_KEY", + firebaseInfo: testFirebaseInfo(), generationConfig: generationConfig, tools: [Tool(functionDeclarations: [sumFunction])], systemInstruction: systemInstruction, requestOptions: RequestOptions(), - appCheck: nil, - auth: nil, urlSession: urlSession ) @@ -1453,12 +1438,9 @@ final class GenerativeModelTests: XCTestCase { let requestOptions = RequestOptions(timeout: expectedTimeout) model = GenerativeModel( name: testModelResourceName, - projectID: "my-project-id", - apiKey: "API_KEY", + firebaseInfo: testFirebaseInfo(), tools: nil, requestOptions: requestOptions, - appCheck: nil, - auth: nil, urlSession: urlSession ) @@ -1469,6 +1451,23 @@ final class GenerativeModelTests: XCTestCase { // MARK: - Helpers + private func testFirebaseInfo(appCheck: AppCheckInterop? = nil, + auth: AuthInterop? = nil, + privateAppID: Bool = false) -> FirebaseInfo { + let app = FirebaseApp(instanceWithName: "testApp", + options: FirebaseOptions(googleAppID: "ignore", + gcmSenderID: "ignore")) + app.isDataCollectionDefaultEnabled = !privateAppID + return FirebaseInfo( + appCheck: appCheck, + auth: auth, + projectID: "my-project-id", + apiKey: "API_KEY", + googleAppID: "My app ID", + firebaseApp: app + ) + } + private func nonHTTPRequestHandler() throws -> ((URLRequest) -> ( URLResponse, AsyncLineSequence? @@ -1495,7 +1494,8 @@ final class GenerativeModelTests: XCTestCase { statusCode: Int = 200, timeout: TimeInterval = RequestOptions().timeout, appCheckToken: String? = nil, - authToken: String? = nil) throws -> ((URLRequest) throws -> ( + authToken: String? = nil, + dataCollection: Bool = true) throws -> ((URLRequest) throws -> ( URLResponse, AsyncLineSequence? )) { @@ -1515,6 +1515,8 @@ final class GenerativeModelTests: XCTestCase { XCTAssert(apiClientTags.contains(GenerativeAIService.languageTag)) XCTAssert(apiClientTags.contains(GenerativeAIService.firebaseVersionTag)) XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken) + let googleAppID = request.value(forHTTPHeaderField: "X-Firebase-AppId") + XCTAssertEqual(googleAppID, dataCollection ? "My app ID" : nil) if let authToken { XCTAssertEqual(request.value(forHTTPHeaderField: "Authorization"), "Firebase \(authToken)") } else { diff --git a/FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift b/FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift index 5e685dd98bd..832d56f9cb4 100644 --- a/FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift +++ b/FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift @@ -50,8 +50,8 @@ class VertexComponentTests: XCTestCase { let vertex = VertexAI.vertexAI(app: VertexComponentTests.app, location: location) XCTAssertNotNil(vertex) - XCTAssertEqual(vertex.projectID, VertexComponentTests.projectID) - XCTAssertEqual(vertex.apiKey, VertexComponentTests.apiKey) + XCTAssertEqual(vertex.firebaseInfo.projectID, VertexComponentTests.projectID) + XCTAssertEqual(vertex.firebaseInfo.apiKey, VertexComponentTests.apiKey) XCTAssertEqual(vertex.location, location) } @@ -121,12 +121,12 @@ class VertexComponentTests: XCTestCase { let app = try XCTUnwrap(VertexComponentTests.app) let vertex = VertexAI.vertexAI(app: app, location: location) let model = "test-model-name" - let modelResourceName = vertex.modelResourceName(modelName: model) + let projectID = vertex.firebaseInfo.projectID XCTAssertEqual( modelResourceName, - "projects/\(vertex.projectID)/locations/\(vertex.location)/publishers/google/models/\(model)" + "projects/\(projectID)/locations/\(vertex.location)/publishers/google/models/\(model)" ) }