Skip to content

Commit

Permalink
[Vertex AI] Add Developer API encoding CountTokensRequest
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Feb 28, 2025
1 parent 6039e0b commit 9054b7f
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 15 deletions.
42 changes: 32 additions & 10 deletions FirebaseVertexAI/Sources/CountTokensRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,20 @@ import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct CountTokensRequest {
let model: String

let contents: [ModelContent]
let systemInstruction: ModelContent?
let tools: [Tool]?
let generationConfig: GenerationConfig?

let options: RequestOptions
let generateContentRequest: GenerateContentRequest
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension CountTokensRequest: GenerativeAIRequest {
typealias Response = CountTokensResponse

var options: RequestOptions {
generateContentRequest.options
}

var url: URL {
URL(string: "\(Constants.baseURL)/\(options.apiVersion)/\(model):countTokens")!
let model = generateContentRequest.model
return URL(string: "\(Constants.baseURL)/\(options.apiVersion)/\(model):countTokens")!
}
}

Expand All @@ -55,12 +53,36 @@ public struct CountTokensResponse {

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension CountTokensRequest: Encodable {
enum CodingKeys: CodingKey {
enum VertexCodingKeys: CodingKey {
case contents
case systemInstruction
case tools
case generationConfig
}

enum DeveloperCodingKeys: CodingKey {
case generateContentRequest
}

func encode(to encoder: any Encoder) throws {
let backendAPI = encoder.userInfo[CodingUserInfoKey(rawValue: "BackendAPI")!] as! BackendAPI

switch backendAPI {
case .vertexAI:
var container = encoder.container(keyedBy: VertexCodingKeys.self)
try container.encode(generateContentRequest.contents, forKey: .contents)
try container.encodeIfPresent(
generateContentRequest.systemInstruction, forKey: .systemInstruction
)
try container.encodeIfPresent(generateContentRequest.tools, forKey: .tools)
try container.encodeIfPresent(
generateContentRequest.generationConfig, forKey: .generationConfig
)
case .developer:
var container = encoder.container(keyedBy: DeveloperCodingKeys.self)
try container.encode(generateContentRequest, forKey: .generateContentRequest)
}
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
Expand Down
5 changes: 4 additions & 1 deletion FirebaseVertexAI/Sources/FirebaseInfo.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,21 @@ struct FirebaseInfo: Sendable {
let apiKey: String
let googleAppID: String
let app: FirebaseApp
let backendAPI: BackendAPI

init(appCheck: AppCheckInterop? = nil,
auth: AuthInterop? = nil,
projectID: String,
apiKey: String,
googleAppID: String,
firebaseApp: FirebaseApp) {
firebaseApp: FirebaseApp,
backendAPI: BackendAPI) {
self.appCheck = appCheck
self.auth = auth
self.projectID = projectID
self.apiKey = apiKey
self.googleAppID = googleAppID
app = firebaseApp
self.backendAPI = backendAPI
}
}
1 change: 1 addition & 0 deletions FirebaseVertexAI/Sources/GenerateContentRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ struct GenerateContentRequest: Sendable {
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension GenerateContentRequest: Encodable {
enum CodingKeys: String, CodingKey {
case model
case contents
case generationConfig
case safetySettings
Expand Down
1 change: 1 addition & 0 deletions FirebaseVertexAI/Sources/GenerativeAIService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ struct GenerativeAIService {
// }

let encoder = JSONEncoder()
encoder.userInfo[CodingUserInfoKey(rawValue: "BackendAPI")!] = firebaseInfo.backendAPI
urlRequest.httpBody = try encoder.encode(request)
urlRequest.timeoutInterval = request.options.timeout

Expand Down
11 changes: 8 additions & 3 deletions FirebaseVertexAI/Sources/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -245,14 +245,19 @@ public final class GenerativeModel: Sendable {
/// - Returns: The results of running the model's tokenizer on the input; contains
/// ``CountTokensResponse/totalTokens``.
public func countTokens(_ content: [ModelContent]) async throws -> CountTokensResponse {
let countTokensRequest = CountTokensRequest(
let generateContentRequest = GenerateContentRequest(
model: modelResourceName,
contents: content,
systemInstruction: systemInstruction,
tools: tools,
generationConfig: generationConfig,
safetySettings: safetySettings,
tools: tools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
isStreaming: false,
options: requestOptions
)
let countTokensRequest = CountTokensRequest(generateContentRequest: generateContentRequest)

return try await generativeAIService.loadRequest(request: countTokensRequest)
}

Expand Down
18 changes: 18 additions & 0 deletions FirebaseVertexAI/Sources/Types/Internal/BackendAPI.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

enum BackendAPI {
case vertexAI
case developer
}
3 changes: 2 additions & 1 deletion FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ public class VertexAI {
projectID: projectID,
apiKey: apiKey,
googleAppID: app.options.googleAppID,
firebaseApp: app
firebaseApp: app,
backendAPI: .vertexAI
)
self.location = location
}
Expand Down

0 comments on commit 9054b7f

Please sign in to comment.