diff --git a/FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift b/FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift index 399777ad1e8..7b3b78753db 100644 --- a/FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift +++ b/FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift @@ -142,22 +142,6 @@ struct ErrorDetailsView: View { SafetyRatingsSection(ratings: ratings) } - case GenerateContentError.invalidAPIKey: - Section("Error Type") { - Text("Invalid API Key") - } - - Section("Details") { - SubtitleFormRow(title: "Error description", value: error.localizedDescription) - SubtitleMarkdownFormRow( - title: "Help", - value: """ - The `API_KEY` provided in the `GoogleService-Info.plist` file is invalid. Download a - new copy of the file from the [Firebase Console](https://console.firebase.google.com). - """ - ) - } - default: Section("Error Type") { Text("Some other error") @@ -222,11 +206,3 @@ struct ErrorDetailsView: View { return ErrorDetailsView(error: error) } - -#Preview("Invalid API Key") { - ErrorDetailsView(error: GenerateContentError.invalidAPIKey) -} - -#Preview("Unsupported User Location") { - ErrorDetailsView(error: GenerateContentError.unsupportedUserLocation) -} diff --git a/FirebaseVertexAI/Sources/Chat.swift b/FirebaseVertexAI/Sources/Chat.swift new file mode 100644 index 00000000000..c7cfb859fa8 --- /dev/null +++ b/FirebaseVertexAI/Sources/Chat.swift @@ -0,0 +1,184 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +/// An object that represents a back-and-forth chat with a model, capturing the history and saving +/// the context in memory between each message sent. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +public class Chat { + private let model: GenerativeModel + + /// Initializes a new chat representing a 1:1 conversation between model and user. + init(model: GenerativeModel, history: [ModelContent]) { + self.model = model + self.history = history + } + + /// The previous content from the chat that has been successfully sent and received from the + /// model. This will be provided to the model for each message sent as context for the discussion. + public var history: [ModelContent] + + /// See ``sendMessage(_:)-3ify5``. + public func sendMessage(_ parts: any ThrowingPartsRepresentable...) async throws + -> GenerateContentResponse { + return try await sendMessage([ModelContent(parts: parts)]) + } + + /// Sends a message using the existing history of this chat as context. If successful, the message + /// and response will be added to the history. If unsuccessful, history will remain unchanged. + /// - Parameter content: The new content to send as a single chat message. + /// - Returns: The model's response if no error occurred. + /// - Throws: A ``GenerateContentError`` if an error occurred. + public func sendMessage(_ content: @autoclosure () throws -> [ModelContent]) async throws + -> GenerateContentResponse { + // Ensure that the new content has the role set. + let newContent: [ModelContent] + do { + newContent = try content().map(populateContentRole(_:)) + } catch let underlying { + if let contentError = underlying as? 
ImageConversionError { + throw GenerateContentError.promptImageContentError(underlying: contentError) + } else { + throw GenerateContentError.internalError(underlying: underlying) + } + } + + // Send the history alongside the new message as context. + let request = history + newContent + let result = try await model.generateContent(request) + guard let reply = result.candidates.first?.content else { + let error = NSError(domain: "com.google.generative-ai", + code: -1, + userInfo: [ + NSLocalizedDescriptionKey: "No candidates with content available.", + ]) + throw GenerateContentError.internalError(underlying: error) + } + + // Make sure we inject the role into the content received. + let toAdd = ModelContent(role: "model", parts: reply.parts) + + // Append the request and successful result to history, then return the value. + history.append(contentsOf: newContent) + history.append(toAdd) + return result + } + + /// See ``sendMessageStream(_:)-4abs3``. + @available(macOS 12.0, *) + public func sendMessageStream(_ parts: any ThrowingPartsRepresentable...) + -> AsyncThrowingStream { + return try sendMessageStream([ModelContent(parts: parts)]) + } + + /// Sends a message using the existing history of this chat as context. If successful, the message + /// and response will be added to the history. If unsuccessful, history will remain unchanged. + /// - Parameter content: The new content to send as a single chat message. + /// - Returns: A stream containing the model's response or an error if an error occurred. + @available(macOS 12.0, *) + public func sendMessageStream(_ content: @autoclosure () throws -> [ModelContent]) + -> AsyncThrowingStream { + let resolvedContent: [ModelContent] + do { + resolvedContent = try content() + } catch let underlying { + return AsyncThrowingStream { continuation in + let error: Error + if let contentError = underlying as? 
ImageConversionError { + error = GenerateContentError.promptImageContentError(underlying: contentError) + } else { + error = GenerateContentError.internalError(underlying: underlying) + } + continuation.finish(throwing: error) + } + } + + return AsyncThrowingStream { continuation in + Task { + var aggregatedContent: [ModelContent] = [] + + // Ensure that the new content has the role set. + let newContent: [ModelContent] = resolvedContent.map(populateContentRole(_:)) + + // Send the history alongside the new message as context. + let request = history + newContent + let stream = model.generateContentStream(request) + do { + for try await chunk in stream { + // Capture any content that's streaming. This should be populated if there's no error. + if let chunkContent = chunk.candidates.first?.content { + aggregatedContent.append(chunkContent) + } + + // Pass along the chunk. + continuation.yield(chunk) + } + } catch { + // Rethrow the error that the underlying stream threw. Don't add anything to history. + continuation.finish(throwing: error) + return + } + + // Save the request. + history.append(contentsOf: newContent) + + // Aggregate the content to add it to the history before we finish. + let aggregated = aggregatedChunks(aggregatedContent) + history.append(aggregated) + + continuation.finish() + } + } + } + + private func aggregatedChunks(_ chunks: [ModelContent]) -> ModelContent { + var parts: [ModelContent.Part] = [] + var combinedText = "" + for aggregate in chunks { + // Loop through all the parts, aggregating the text and adding the images. + for part in aggregate.parts { + switch part { + case let .text(str): + combinedText += str + + case .data(mimetype: _, _): + // Don't combine it, just add to the content. If there's any text pending, add that as + // a part. 
+ if !combinedText.isEmpty { + parts.append(.text(combinedText)) + combinedText = "" + } + + parts.append(part) + } + } + } + + if !combinedText.isEmpty { + parts.append(.text(combinedText)) + } + + return ModelContent(role: "model", parts: parts) + } + + /// Populates the `role` field with `user` if it doesn't exist. Required in chat sessions. + private func populateContentRole(_ content: ModelContent) -> ModelContent { + if content.role != nil { + return content + } else { + return ModelContent(role: "user", parts: content.parts) + } + } +} diff --git a/FirebaseVertexAI/Sources/CountTokensRequest.swift b/FirebaseVertexAI/Sources/CountTokensRequest.swift new file mode 100644 index 00000000000..0a58d40acc0 --- /dev/null +++ b/FirebaseVertexAI/Sources/CountTokensRequest.swift @@ -0,0 +1,45 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +import Foundation + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +struct CountTokensRequest { + let model: String + let contents: [ModelContent] + let options: RequestOptions +} + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +extension CountTokensRequest: Encodable { + enum CodingKeys: CodingKey { + case contents + } +} + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +extension CountTokensRequest: GenerativeAIRequest { + typealias Response = CountTokensResponse + + var url: URL { + URL(string: "\(GenerativeAISwift.baseURL)/\(options.apiVersion)/\(model):countTokens")! + } +} + +/// The model's response to a count tokens request. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +public struct CountTokensResponse: Decodable { + /// The total number of tokens in the input given to the model as a prompt. + public let totalTokens: Int +} diff --git a/FirebaseVertexAI/Sources/Errors.swift b/FirebaseVertexAI/Sources/Errors.swift new file mode 100644 index 00000000000..0b1b8848411 --- /dev/null +++ b/FirebaseVertexAI/Sources/Errors.swift @@ -0,0 +1,177 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +struct RPCError: Error { + let httpResponseCode: Int + let message: String + let status: RPCStatus + let details: [ErrorDetails] + + private var errorInfo: ErrorDetails? 
{ + return details.first { $0.isErrorInfo() } + } + + init(httpResponseCode: Int, message: String, status: RPCStatus, details: [ErrorDetails]) { + self.httpResponseCode = httpResponseCode + self.message = message + self.status = status + self.details = details + } +} + +extension RPCError: Decodable { + enum CodingKeys: CodingKey { + case error + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let status = try container.decode(ErrorStatus.self, forKey: .error) + + if let code = status.code { + httpResponseCode = code + } else { + httpResponseCode = -1 + } + + if let message = status.message { + self.message = message + } else { + message = "Unknown error." + } + + if let rpcStatus = status.status { + self.status = rpcStatus + } else { + self.status = .unknown + } + + details = status.details + } +} + +struct ErrorStatus { + let code: Int? + let message: String? + let status: RPCStatus? + let details: [ErrorDetails] +} + +struct ErrorDetails { + static let errorInfoType = "type.googleapis.com/google.rpc.ErrorInfo" + + let type: String + let reason: String? + let domain: String? 
+ + func isErrorInfo() -> Bool { + return type == ErrorDetails.errorInfoType + } +} + +extension ErrorDetails: Decodable, Equatable { + enum CodingKeys: String, CodingKey { + case type = "@type" + case reason + case domain + } +} + +extension ErrorStatus: Decodable { + enum CodingKeys: CodingKey { + case code + case message + case status + case details + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + code = try container.decodeIfPresent(Int.self, forKey: .code) + message = try container.decodeIfPresent(String.self, forKey: .message) + do { + status = try container.decodeIfPresent(RPCStatus.self, forKey: .status) + } catch { + status = .unknown + } + if container.contains(.details) { + details = try container.decode([ErrorDetails].self, forKey: .details) + } else { + details = [] + } + } +} + +enum RPCStatus: String, Decodable { + // Not an error; returned on success. + case ok = "OK" + + // The operation was cancelled, typically by the caller. + case cancelled = "CANCELLED" + + // Unknown error. + case unknown = "UNKNOWN" + + // The client specified an invalid argument. + case invalidArgument = "INVALID_ARGUMENT" + + // The deadline expired before the operation could complete. + case deadlineExceeded = "DEADLINE_EXCEEDED" + + // Some requested entity (e.g., file or directory) was not found. + case notFound = "NOT_FOUND" + + // The entity that a client attempted to create (e.g., file or directory) already exists. + case alreadyExists = "ALREADY_EXISTS" + + // The caller does not have permission to execute the specified operation. + case permissionDenied = "PERMISSION_DENIED" + + // The request does not have valid authentication credentials for the operation. + case unauthenticated = "UNAUTHENTICATED" + + // Some resource has been exhausted, perhaps a per-user quota, or perhaps the entire file system + // is out of space. 
+ case resourceExhausted = "RESOURCE_EXHAUSTED" + + // The operation was rejected because the system is not in a state required for the operation's + // execution. + case failedPrecondition = "FAILED_PRECONDITION" + + // The operation was aborted, typically due to a concurrency issue such as a sequencer check + // failure or transaction abort. + case aborted = "ABORTED" + + // The operation was attempted past the valid range. + case outOfRange = "OUT_OF_RANGE" + + // The operation is not implemented or is not supported/enabled in this service. + case unimplemented = "UNIMPLEMENTED" + + // Internal errors. + case internalError = "INTERNAL" + + // The service is currently unavailable. + case unavailable = "UNAVAILABLE" + + // Unrecoverable data loss or corruption. + case dataLoss = "DATA_LOSS" +} + +enum InvalidCandidateError: Error { + case emptyContent(underlyingError: Error) + case malformedContent(underlyingError: Error) +} diff --git a/FirebaseVertexAI/Sources/GenerateContentError.swift b/FirebaseVertexAI/Sources/GenerateContentError.swift new file mode 100644 index 00000000000..31501a7744c --- /dev/null +++ b/FirebaseVertexAI/Sources/GenerateContentError.swift @@ -0,0 +1,31 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +/// Errors that occur when generating content from a model. 
+@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +public enum GenerateContentError: Error { + /// An error occurred when constructing the prompt. Examine the related error for details. + case promptImageContentError(underlying: ImageConversionError) + + /// An internal error occurred. See the underlying error for more context. + case internalError(underlying: Error) + + /// A prompt was blocked. See the response's `promptFeedback.blockReason` for more information. + case promptBlocked(response: GenerateContentResponse) + + /// A response didn't fully complete. See the `FinishReason` for more information. + case responseStoppedEarly(reason: FinishReason, response: GenerateContentResponse) +} diff --git a/FirebaseVertexAI/Sources/GenerateContentRequest.swift b/FirebaseVertexAI/Sources/GenerateContentRequest.swift new file mode 100644 index 00000000000..417260bb700 --- /dev/null +++ b/FirebaseVertexAI/Sources/GenerateContentRequest.swift @@ -0,0 +1,49 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +struct GenerateContentRequest { + /// Model name. + let model: String + let contents: [ModelContent] + let generationConfig: GenerationConfig? + let safetySettings: [SafetySetting]? 
+ let isStreaming: Bool + let options: RequestOptions +} + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +extension GenerateContentRequest: Encodable { + enum CodingKeys: String, CodingKey { + case contents + case generationConfig + case safetySettings + } +} + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +extension GenerateContentRequest: GenerativeAIRequest { + typealias Response = GenerateContentResponse + + var url: URL { + let modelURL = "\(GenerativeAISwift.baseURL)/\(options.apiVersion)/\(model)" + if isStreaming { + return URL(string: "\(modelURL):streamGenerateContent?alt=sse")! + } else { + return URL(string: "\(modelURL):generateContent")! + } + } +} diff --git a/FirebaseVertexAI/Sources/GenerateContentResponse.swift b/FirebaseVertexAI/Sources/GenerateContentResponse.swift new file mode 100644 index 00000000000..03153da9ae7 --- /dev/null +++ b/FirebaseVertexAI/Sources/GenerateContentResponse.swift @@ -0,0 +1,286 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +/// The model's response to a generate content request. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +public struct GenerateContentResponse { + /// A list of candidate response content, ordered from best to worst. + public let candidates: [CandidateResponse] + + /// A value containing the safety ratings for the response, or, if the request was blocked, a + /// reason for blocking the request. 
+ public let promptFeedback: PromptFeedback? + + /// The response's content as text, if it exists. + public var text: String? { + guard let candidate = candidates.first else { + Logging.default.error("Could not get text from a response that had no candidates.") + return nil + } + guard let text = candidate.content.parts.first?.text else { + Logging.default.error("Could not get a text part from the first candidate.") + return nil + } + return text + } + + /// Initializer for SwiftUI previews or tests. + public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback?) { + self.candidates = candidates + self.promptFeedback = promptFeedback + } +} + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +extension GenerateContentResponse: Decodable { + enum CodingKeys: CodingKey { + case candidates + case promptFeedback + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + + guard container.contains(CodingKeys.candidates) || container + .contains(CodingKeys.promptFeedback) else { + let context = DecodingError.Context( + codingPath: [], + debugDescription: "Failed to decode GenerateContentResponse;" + + " missing keys 'candidates' and 'promptFeedback'." + ) + throw DecodingError.dataCorrupted(context) + } + + if let candidates = try container.decodeIfPresent( + [CandidateResponse].self, + forKey: .candidates + ) { + self.candidates = candidates + } else { + candidates = [] + } + promptFeedback = try container.decodeIfPresent(PromptFeedback.self, forKey: .promptFeedback) + } +} + +/// A struct representing a possible reply to a content generation prompt. Each content generation +/// prompt may produce multiple candidate responses. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +public struct CandidateResponse { + /// The response's content. + public let content: ModelContent + + /// The safety rating of the response content. 
+ public let safetyRatings: [SafetyRating] + + /// The reason the model stopped generating content, if it exists; for example, if the model + /// generated a predefined stop sequence. + public let finishReason: FinishReason? + + /// Cited works in the model's response content, if it exists. + public let citationMetadata: CitationMetadata? + + /// Initializer for SwiftUI previews or tests. + public init(content: ModelContent, safetyRatings: [SafetyRating], finishReason: FinishReason?, + citationMetadata: CitationMetadata?) { + self.content = content + self.safetyRatings = safetyRatings + self.finishReason = finishReason + self.citationMetadata = citationMetadata + } +} + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +extension CandidateResponse: Decodable { + enum CodingKeys: CodingKey { + case content + case safetyRatings + case finishReason + case finishMessage + case citationMetadata + } + + /// Initializes a response from a decoder. Used for decoding server responses; not for public + /// use. + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + + do { + if let content = try container.decodeIfPresent(ModelContent.self, forKey: .content) { + self.content = content + } else { + content = ModelContent(parts: []) + } + } catch { + // Check if `content` can be decoded as an empty dictionary to detect the `"content": {}` bug. + if let content = try? 
container.decode([String: String].self, forKey: .content), + content.isEmpty { + throw InvalidCandidateError.emptyContent(underlyingError: error) + } else { + throw InvalidCandidateError.malformedContent(underlyingError: error) + } + } + + if let safetyRatings = try container.decodeIfPresent( + [SafetyRating].self, + forKey: .safetyRatings + ) { + self.safetyRatings = safetyRatings + } else { + safetyRatings = [] + } + + finishReason = try container.decodeIfPresent(FinishReason.self, forKey: .finishReason) + + citationMetadata = try container.decodeIfPresent( + CitationMetadata.self, + forKey: .citationMetadata + ) + } +} + +/// A collection of source attributions for a piece of content. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +public struct CitationMetadata: Decodable { + /// A list of individual cited sources and the parts of the content to which they apply. + public let citationSources: [Citation] +} + +/// A struct describing a source attribution. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +public struct Citation: Decodable { + /// The inclusive beginning of a sequence in a model response that derives from a cited source. + public let startIndex: Int + + /// The exclusive end of a sequence in a model response that derives from a cited source. + public let endIndex: Int + + /// A link to the cited source. + public let uri: String + + /// The license the cited source work is distributed under. + public let license: String +} + +/// A value enumerating possible reasons for a model to terminate a content generation request. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +public enum FinishReason: String { + case unknown = "FINISH_REASON_UNKNOWN" + + case unspecified = "FINISH_REASON_UNSPECIFIED" + + /// Natural stop point of the model or provided stop sequence. + case stop = "STOP" + + /// The maximum number of tokens as specified in the request was reached. 
+ case maxTokens = "MAX_TOKENS" + + /// The token generation was stopped because the response was flagged for safety reasons. + /// NOTE: When streaming, the Candidate.content will be empty if content filters blocked the + /// output. + case safety = "SAFETY" + + /// The token generation was stopped because the response was flagged for unauthorized citations. + case recitation = "RECITATION" + + /// All other reasons that stopped token generation. + case other = "OTHER" +} + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +extension FinishReason: Decodable { + /// Do not explicitly use. Initializer required for Decodable conformance. + public init(from decoder: Decoder) throws { + let value = try decoder.singleValueContainer().decode(String.self) + guard let decodedFinishReason = FinishReason(rawValue: value) else { + Logging.default + .error("[GoogleGenerativeAI] Unrecognized FinishReason with value \"\(value)\".") + self = .unknown + return + } + + self = decodedFinishReason + } +} + +/// A metadata struct containing any feedback the model had on the prompt it was provided. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +public struct PromptFeedback { + /// A type describing possible reasons to block a prompt. + public enum BlockReason: String, Decodable { + /// The block reason is unknown. + case unknown = "UNKNOWN" + + /// The block reason was not specified in the server response. + case unspecified = "BLOCK_REASON_UNSPECIFIED" + + /// The prompt was blocked because it was deemed unsafe. + case safety = "SAFETY" + + /// All other block reasons. + case other = "OTHER" + + /// Do not explicitly use. Initializer required for Decodable conformance. 
+ public init(from decoder: Decoder) throws { + let value = try decoder.singleValueContainer().decode(String.self) + guard let decodedBlockReason = BlockReason(rawValue: value) else { + Logging.default + .error("[GoogleGenerativeAI] Unrecognized BlockReason with value \"\(value)\".") + self = .unknown + return + } + + self = decodedBlockReason + } + } + + /// The reason a prompt was blocked, if it was blocked. + public let blockReason: BlockReason? + + /// The safety ratings of the prompt. + public let safetyRatings: [SafetyRating] + + /// Initializer for SwiftUI previews or tests. + public init(blockReason: BlockReason?, safetyRatings: [SafetyRating]) { + self.blockReason = blockReason + self.safetyRatings = safetyRatings + } +} + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +extension PromptFeedback: Decodable { + enum CodingKeys: CodingKey { + case blockReason + case safetyRatings + } + + /// Do not explicitly use. Initializer required for Decodable conformance. + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + blockReason = try container.decodeIfPresent( + PromptFeedback.BlockReason.self, + forKey: .blockReason + ) + if let safetyRatings = try container.decodeIfPresent( + [SafetyRating].self, + forKey: .safetyRatings + ) { + self.safetyRatings = safetyRatings + } else { + safetyRatings = [] + } + } +} diff --git a/FirebaseVertexAI/Sources/GenerationConfig.swift b/FirebaseVertexAI/Sources/GenerationConfig.swift new file mode 100644 index 00000000000..2d1016c965c --- /dev/null +++ b/FirebaseVertexAI/Sources/GenerationConfig.swift @@ -0,0 +1,86 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +/// A struct defining model parameters to be used when sending generative AI +/// requests to the backend model. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +public struct GenerationConfig: Encodable { + /// A parameter controlling the degree of randomness in token selection. A + /// temperature of zero is deterministic, always choosing the + /// highest-probability response. Typical values are between 0 and 1 + /// inclusive. Defaults to 0 if unspecified. + public let temperature: Float? + + /// The `topP` parameter changes how the model selects tokens for output. + /// Tokens are selected from the most to least probable until the sum of + /// their probabilities equals the `topP` value. For example, if tokens A, B, + /// and C have probabilities of 0.3, 0.2, and 0.1 respectively and the topP + /// value is 0.5, then the model will select either A or B as the next token + /// by using the `temperature` and exclude C as a candidate. + /// Defaults to 0.95 if unset. + public let topP: Float? + + /// The `topK` parameter changes how the model selects tokens for output. A + /// `topK` of 1 means the selected token is the most probable among all the + /// tokens in the model's vocabulary, while a `topK` of 3 means that the next + /// token is selected from among the 3 most probable using the `temperature`. + /// For each token selection step, the `topK` tokens with the highest + /// probabilities are sampled. 
Tokens are then further filtered based on + /// `topP` with the final token selected using `temperature` sampling. + /// Defaults to 40 if unspecified. + public let topK: Int? + + /// The maximum number of generated response messages to return. This value + /// must be between [1, 8], inclusive. If unset, this will default to 1. + /// + /// - Note: Only unique candidates are returned. Higher temperatures are more + /// likely to produce unique candidates. Setting `temperature` to 0 will + /// always produce exactly one candidate regardless of the + /// `candidateCount`. + public let candidateCount: Int? + + /// Specifies the maximum number of tokens that can be generated in the + /// response. The number of tokens per word varies depending on the + /// language outputted. The maximum value is capped at 1024. Defaults to 0 + /// (unbounded). + public let maxOutputTokens: Int? + + /// A set of up to 5 `String`s that will stop output generation. If + /// specified, the API will stop at the first appearance of a stop sequence. + /// The stop sequence will not be included as part of the response. + public let stopSequences: [String]? + + /// Creates a new `GenerationConfig` value. + /// + /// - Parameter temperature: See ``temperature`` + /// - Parameter topP: See ``topP`` + /// - Parameter topK: See ``topK`` + /// - Parameter candidateCount: See ``candidateCount`` + /// - Parameter maxOutputTokens: See ``maxOutputTokens`` + /// - Parameter stopSequences: See ``stopSequences`` + public init(temperature: Float? = nil, topP: Float? = nil, topK: Int? = nil, + candidateCount: Int? = nil, maxOutputTokens: Int? = nil, + stopSequences: [String]? = nil) { + // Explicit init because otherwise if we re-arrange the above variables it changes the API + // surface. 
+ self.temperature = temperature + self.topP = topP + self.topK = topK + self.candidateCount = candidateCount + self.maxOutputTokens = maxOutputTokens + self.stopSequences = stopSequences + } +} diff --git a/FirebaseVertexAI/Sources/GenerativeAIRequest.swift b/FirebaseVertexAI/Sources/GenerativeAIRequest.swift new file mode 100644 index 00000000000..21f35b1b728 --- /dev/null +++ b/FirebaseVertexAI/Sources/GenerativeAIRequest.swift @@ -0,0 +1,46 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +protocol GenerativeAIRequest: Encodable { + associatedtype Response: Decodable + + var url: URL { get } + + var options: RequestOptions { get } +} + +/// Configuration parameters for sending requests to the backend. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +public struct RequestOptions { + /// The request’s timeout interval in seconds; if not specified uses the default value for a + /// `URLRequest`. + let timeout: TimeInterval? + + /// The API version to use in requests to the backend. + let apiVersion: String + + /// Initializes a request options object. + /// + /// - Parameters: + /// - timeout The request’s timeout interval in seconds; if not specified uses the default value + /// for a `URLRequest`. + /// - apiVersion The API version to use in requests to the backend; defaults to "v2beta". + public init(timeout: TimeInterval? 
= nil, apiVersion: String = "v2beta") { + self.timeout = timeout + self.apiVersion = apiVersion + } +} diff --git a/FirebaseVertexAI/Sources/GenerativeAIService.swift b/FirebaseVertexAI/Sources/GenerativeAIService.swift new file mode 100644 index 00000000000..6f0bda2d4ec --- /dev/null +++ b/FirebaseVertexAI/Sources/GenerativeAIService.swift @@ -0,0 +1,273 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import FirebaseAppCheckInterop +import FirebaseCore +import Foundation + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +struct GenerativeAIService { + /// Gives permission to talk to the backend. + private let apiKey: String + + private let appCheck: AppCheckInterop? 
+ + private let urlSession: URLSession + + init(apiKey: String, appCheck: AppCheckInterop?, urlSession: URLSession) { + self.apiKey = apiKey + self.appCheck = appCheck + self.urlSession = urlSession + } + + func loadRequest(request: T) async throws -> T.Response { + let urlRequest = try await urlRequest(request: request) + + #if DEBUG + printCURLCommand(from: urlRequest) + #endif + + let data: Data + let rawResponse: URLResponse + (data, rawResponse) = try await urlSession.data(for: urlRequest) + + let response = try httpResponse(urlResponse: rawResponse) + + // Verify the status code is 200 + guard response.statusCode == 200 else { + Logging.default.error("[GoogleGenerativeAI] The server responded with an error: \(response)") + if let responseString = String(data: data, encoding: .utf8) { + Logging.network.error("[GoogleGenerativeAI] Response payload: \(responseString)") + } + + throw parseError(responseData: data) + } + + return try parseResponse(T.Response.self, from: data) + } + + @available(macOS 12.0, *) + func loadRequestStream(request: T) + -> AsyncThrowingStream { + return AsyncThrowingStream { continuation in + Task { + let urlRequest: URLRequest + do { + urlRequest = try await self.urlRequest(request: request) + } catch { + continuation.finish(throwing: error) + return + } + + #if DEBUG + printCURLCommand(from: urlRequest) + #endif + + let stream: URLSession.AsyncBytes + let rawResponse: URLResponse + do { + (stream, rawResponse) = try await urlSession.bytes(for: urlRequest) + } catch { + continuation.finish(throwing: error) + return + } + + // Verify the status code is 200 + let response: HTTPURLResponse + do { + response = try httpResponse(urlResponse: rawResponse) + } catch { + continuation.finish(throwing: error) + return + } + + // Verify the status code is 200 + guard response.statusCode == 200 else { + Logging.default + .error("[GoogleGenerativeAI] The server responded with an error: \(response)") + var responseBody = "" + for try await line in 
stream.lines { + responseBody += line + "\n" + } + + Logging.network.error("[GoogleGenerativeAI] Response payload: \(responseBody)") + continuation.finish(throwing: parseError(responseBody: responseBody)) + + return + } + + // Received lines that are not server-sent events (SSE); these are not prefixed with "data:" + var extraLines: String = "" + + let decoder = JSONDecoder() + decoder.keyDecodingStrategy = .convertFromSnakeCase + for try await line in stream.lines { + Logging.network.debug("[GoogleGenerativeAI] Stream response: \(line)") + + if line.hasPrefix("data:") { + // We can assume 5 characters since it's utf-8 encoded, removing `data:`. + let jsonText = String(line.dropFirst(5)) + let data: Data + do { + data = try jsonData(jsonText: jsonText) + } catch { + continuation.finish(throwing: error) + return + } + + // Handle the content. + do { + let content = try parseResponse(T.Response.self, from: data) + continuation.yield(content) + } catch { + continuation.finish(throwing: error) + return + } + } else { + extraLines += line + } + } + + if extraLines.count > 0 { + continuation.finish(throwing: parseError(responseBody: extraLines)) + return + } + + continuation.finish(throwing: nil) + } + } + } + + // MARK: - Private Helpers + + private func urlRequest(request: T) async throws -> URLRequest { + var urlRequest = URLRequest(url: request.url) + urlRequest.httpMethod = "POST" + urlRequest.setValue(apiKey, forHTTPHeaderField: "x-goog-api-key") + // TODO: Determine the right client prefix to use. + urlRequest.setValue("genai-swift/\(FirebaseVersion())", + forHTTPHeaderField: "x-goog-api-client") + urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type") + + if let appCheck { + let tokenResult = await appCheck.getToken(forcingRefresh: false) + urlRequest.setValue(tokenResult.token, forHTTPHeaderField: "X-Firebase-AppCheck") + if let error = tokenResult.error { + Logging.default + .debug("[GoogleGenerativeAI] Failed to fetch AppCheck token. 
Error: \(error)") + } + } + + let encoder = JSONEncoder() + encoder.keyEncodingStrategy = .convertToSnakeCase + urlRequest.httpBody = try encoder.encode(request) + + if let timeoutInterval = request.options.timeout { + urlRequest.timeoutInterval = timeoutInterval + } + + return urlRequest + } + + private func httpResponse(urlResponse: URLResponse) throws -> HTTPURLResponse { + // Verify the status code is 200 + guard let response = urlResponse as? HTTPURLResponse else { + Logging.default + .error( + "[GoogleGenerativeAI] Response wasn't an HTTP response, internal error \(urlResponse)" + ) + throw NSError( + domain: "com.google.generative-ai", + code: -1, + userInfo: [NSLocalizedDescriptionKey: "Response was not an HTTP response."] + ) + } + + return response + } + + private func jsonData(jsonText: String) throws -> Data { + guard let data = jsonText.data(using: .utf8) else { + let error = NSError( + domain: "com.google.generative-ai", + code: -1, + userInfo: [NSLocalizedDescriptionKey: "Could not parse response as UTF8."] + ) + throw error + } + + return data + } + + private func parseError(responseBody: String) -> Error { + do { + let data = try jsonData(jsonText: responseBody) + return parseError(responseData: data) + } catch { + return error + } + } + + private func parseError(responseData: Data) -> Error { + do { + return try JSONDecoder().decode(RPCError.self, from: responseData) + } catch { + // TODO: Return an error about an unrecognized error payload with the response body + return error + } + } + + private func parseResponse(_ type: T.Type, from data: Data) throws -> T { + do { + return try JSONDecoder().decode(type, from: data) + } catch { + if let json = String(data: data, encoding: .utf8) { + Logging.network.error("[GoogleGenerativeAI] JSON response: \(json)") + } + Logging.default.error("[GoogleGenerativeAI] Error decoding server JSON: \(error)") + throw error + } + } + + #if DEBUG + private func cURLCommand(from request: URLRequest) -> String { + var 
returnValue = "curl " + if let allHeaders = request.allHTTPHeaderFields { + for (key, value) in allHeaders { + returnValue += "-H '\(key): \(value)' " + } + } + + guard let url = request.url else { return "" } + returnValue += "'\(url.absoluteString)' " + + guard let body = request.httpBody, + let jsonStr = String(bytes: body, encoding: .utf8) else { return "" } + let escapedJSON = jsonStr.replacingOccurrences(of: "'", with: "'\\''") + returnValue += "-d '\(escapedJSON)'" + + return returnValue + } + + private func printCURLCommand(from request: URLRequest) { + let command = cURLCommand(from: request) + Logging.verbose.debug(""" + [GoogleGenerativeAI] Creating request with the equivalent cURL command: + ----- cURL command ----- + \(command, privacy: .private) + ------------------------ + """) + } + #endif // DEBUG +} diff --git a/FirebaseVertexAI/Sources/GenerativeAISwift.swift b/FirebaseVertexAI/Sources/GenerativeAISwift.swift new file mode 100644 index 00000000000..0f0af2bc3e2 --- /dev/null +++ b/FirebaseVertexAI/Sources/GenerativeAISwift.swift @@ -0,0 +1,25 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +import Foundation + +#if !os(macOS) && !os(iOS) + #warning("Only iOS, macOS, and Catalyst targets are currently fully supported.") +#endif + +/// Constants associated with the GenerativeAISwift SDK +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +public enum GenerativeAISwift { + static let baseURL = "https://staging-firebaseml.sandbox.googleapis.com" +} diff --git a/FirebaseVertexAI/Sources/GenerativeModel.swift b/FirebaseVertexAI/Sources/GenerativeModel.swift new file mode 100644 index 00000000000..aad2b3e99c0 --- /dev/null +++ b/FirebaseVertexAI/Sources/GenerativeModel.swift @@ -0,0 +1,285 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import FirebaseAppCheckInterop +import Foundation + +/// A type that represents a remote multimodal model (like Gemini), with the ability to generate +/// content based on various input types. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +public final class GenerativeModel { + // The prefix for a model resource in the Gemini API. + private static let modelResourcePrefix = "models/" + + /// The resource name of the model in the backend; has the format "models/model-name". + let modelResourceName: String + + /// The backing service responsible for sending and receiving model requests to the backend. + let generativeAIService: GenerativeAIService + + /// Configuration parameters used for the MultiModalModel. + let generationConfig: GenerationConfig? 
+ + /// The safety settings to be used for prompts. + let safetySettings: [SafetySetting]? + + /// Configuration parameters for sending requests to the backend. + let requestOptions: RequestOptions + + /// Initializes a new remote model with the given parameters. + /// + /// - Parameters: + /// - name: The name of the model to use, e.g., `"gemini-1.0-pro"`; see + /// [Gemini models](https://ai.google.dev/models/gemini) for a list of supported model names. + /// - apiKey: The API key for your project. + /// - generationConfig: The content generation parameters your model should use. + /// - safetySettings: A value describing what types of harmful content your model should allow. + /// - requestOptions: Configuration parameters for sending requests to the backend. + /// - urlSession: The `URLSession` to use for requests; defaults to `URLSession.shared`. + init(name: String, + apiKey: String, + generationConfig: GenerationConfig? = nil, + safetySettings: [SafetySetting]? = nil, + requestOptions: RequestOptions, + appCheck: AppCheckInterop?, + urlSession: URLSession = .shared) { + modelResourceName = GenerativeModel.modelResourceName(name: name) + generativeAIService = GenerativeAIService( + apiKey: apiKey, + appCheck: appCheck, + urlSession: urlSession + ) + self.generationConfig = generationConfig + self.safetySettings = safetySettings + self.requestOptions = requestOptions + + Logging.default.info(""" + [GoogleGenerativeAI] Model \( + name, + privacy: .public + ) initialized. To enable additional logging, add \ + `\(Logging.enableArgumentKey, privacy: .public)` as a launch argument in Xcode. + """) + Logging.verbose.debug("[GoogleGenerativeAI] Verbose logging enabled.") + } + + /// Generates content from String and/or image inputs, given to the model as a prompt, that are + /// representable as one or more ``ModelContent/Part``s. 
+ /// + /// Since ``ModelContent/Part``s do not specify a role, this method is intended for generating + /// content from + /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting) + /// or "direct" prompts. For + /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting) + /// prompts, see ``generateContent(_:)-58rm0``. + /// + /// - Parameter content: The input(s) given to the model as a prompt (see + /// ``ThrowingPartsRepresentable`` + /// for conforming types). + /// - Returns: The content generated by the model. + /// - Throws: A ``GenerateContentError`` if the request failed. + public func generateContent(_ parts: any ThrowingPartsRepresentable...) + async throws -> GenerateContentResponse { + return try await generateContent([ModelContent(parts: parts)]) + } + + /// Generates new content from input content given to the model as a prompt. + /// + /// - Parameter content: The input(s) given to the model as a prompt. + /// - Returns: The generated content response from the model. + /// - Throws: A ``GenerateContentError`` if the request failed. + public func generateContent(_ content: @autoclosure () throws -> [ModelContent]) async throws + -> GenerateContentResponse { + let response: GenerateContentResponse + do { + let generateContentRequest = try GenerateContentRequest(model: modelResourceName, + contents: content(), + generationConfig: generationConfig, + safetySettings: safetySettings, + isStreaming: false, + options: requestOptions) + response = try await generativeAIService.loadRequest(request: generateContentRequest) + } catch { + if let imageError = error as? ImageConversionError { + throw GenerateContentError.promptImageContentError(underlying: imageError) + } + throw GenerativeModel.generateContentError(from: error) + } + + // Check the prompt feedback to see if the prompt was blocked. 
+ if response.promptFeedback?.blockReason != nil { + throw GenerateContentError.promptBlocked(response: response) + } + + // Check to see if an error should be thrown for stop reason. + if let reason = response.candidates.first?.finishReason, reason != .stop { + throw GenerateContentError.responseStoppedEarly(reason: reason, response: response) + } + + return response + } + + /// Generates content from String and/or image inputs, given to the model as a prompt, that are + /// representable as one or more ``ModelContent/Part``s. + /// + /// Since ``ModelContent/Part``s do not specify a role, this method is intended for generating + /// content from + /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting) + /// or "direct" prompts. For + /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting) + /// prompts, see ``generateContent(_:)-58rm0``. + /// + /// - Parameter content: The input(s) given to the model as a prompt (see + /// ``ThrowingPartsRepresentable`` + /// for conforming types). + /// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError`` + /// error if an error occurred. + @available(macOS 12.0, *) + public func generateContentStream(_ parts: any ThrowingPartsRepresentable...) + -> AsyncThrowingStream { + return try generateContentStream([ModelContent(parts: parts)]) + } + + /// Generates new content from input content given to the model as a prompt. + /// + /// - Parameter content: The input(s) given to the model as a prompt. + /// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError`` + /// error if an error occurred. 
+ @available(macOS 12.0, *) + public func generateContentStream(_ content: @autoclosure () throws -> [ModelContent]) + -> AsyncThrowingStream { + let evaluatedContent: [ModelContent] + do { + evaluatedContent = try content() + } catch let underlying { + return AsyncThrowingStream { continuation in + let error: Error + if let contentError = underlying as? ImageConversionError { + error = GenerateContentError.promptImageContentError(underlying: contentError) + } else { + error = GenerateContentError.internalError(underlying: underlying) + } + continuation.finish(throwing: error) + } + } + + let generateContentRequest = GenerateContentRequest(model: modelResourceName, + contents: evaluatedContent, + generationConfig: generationConfig, + safetySettings: safetySettings, + isStreaming: true, + options: requestOptions) + + var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest) + .makeAsyncIterator() + return AsyncThrowingStream { + let response: GenerateContentResponse? + do { + response = try await responseIterator.next() + } catch { + throw GenerativeModel.generateContentError(from: error) + } + + // The responseIterator will return `nil` when it's done. + guard let response = response else { + // This is the end of the stream! Signal it by sending `nil`. + return nil + } + + // Check the prompt feedback to see if the prompt was blocked. + if response.promptFeedback?.blockReason != nil { + throw GenerateContentError.promptBlocked(response: response) + } + + // If the stream ended early unexpectedly, throw an error. + if let finishReason = response.candidates.first?.finishReason, finishReason != .stop { + throw GenerateContentError.responseStoppedEarly(reason: finishReason, response: response) + } else { + // Response was valid content, pass it along and continue. + return response + } + } + } + + /// Creates a new chat conversation using this model with the provided history. 
+ public func startChat(history: [ModelContent] = []) -> Chat { + return Chat(model: self, history: history) + } + + /// Runs the model's tokenizer on String and/or image inputs that are representable as one or more + /// ``ModelContent/Part``s. + /// + /// Since ``ModelContent/Part``s do not specify a role, this method is intended for tokenizing + /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting) + /// or "direct" prompts. For + /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting) + /// input, see ``countTokens(_:)-9spwl``. + /// + /// - Parameter content: The input(s) given to the model as a prompt (see + /// ``ThrowingPartsRepresentable`` + /// for conforming types). + /// - Returns: The results of running the model's tokenizer on the input; contains + /// ``CountTokensResponse/totalTokens``. + /// - Throws: A ``CountTokensError`` if the tokenization request failed. + public func countTokens(_ parts: any ThrowingPartsRepresentable...) async throws + -> CountTokensResponse { + return try await countTokens([ModelContent(parts: parts)]) + } + + /// Runs the model's tokenizer on the input content and returns the token count. + /// + /// - Parameter content: The input given to the model as a prompt. + /// - Returns: The results of running the model's tokenizer on the input; contains + /// ``CountTokensResponse/totalTokens``. + /// - Throws: A ``CountTokensError`` if the tokenization request failed or the input content was + /// invalid. 
+ public func countTokens(_ content: @autoclosure () throws -> [ModelContent]) async throws + -> CountTokensResponse { + do { + let countTokensRequest = try CountTokensRequest( + model: modelResourceName, + contents: content(), + options: requestOptions + ) + return try await generativeAIService.loadRequest(request: countTokensRequest) + } catch { + throw CountTokensError.internalError(underlying: error) + } + } + + /// Returns a model resource name of the form "models/model-name" based on `name`. + private static func modelResourceName(name: String) -> String { + if name.contains("/") { + return name + } else { + return modelResourcePrefix + name + } + } + + /// Returns a `GenerateContentError` (for public consumption) from an internal error. + /// + /// If `error` is already a `GenerateContentError` the error is returned unchanged. + private static func generateContentError(from error: Error) -> GenerateContentError { + if let error = error as? GenerateContentError { + return error + } + return GenerateContentError.internalError(underlying: error) + } +} + +/// See ``GenerativeModel/countTokens(_:)-9spwl``. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +public enum CountTokensError: Error { + case internalError(underlying: Error) +} diff --git a/FirebaseVertexAI/Sources/Logging.swift b/FirebaseVertexAI/Sources/Logging.swift new file mode 100644 index 00000000000..458c34ed18d --- /dev/null +++ b/FirebaseVertexAI/Sources/Logging.swift @@ -0,0 +1,56 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import Foundation
+import OSLog
+
+@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
+struct Logging {
+  /// Subsystem that should be used for all Loggers.
+  static let subsystem = "com.google.generative-ai"
+
+  /// Default category used for most loggers, unless specialized.
+  static let defaultCategory = ""
+
+  /// The argument required to enable additional logging.
+  static let enableArgumentKey = "-GoogleGenerativeAIDebugLogEnabled"
+
+  // No initializer available.
+  @available(*, unavailable)
+  private init() {}
+
+  /// The default logger that is visible for all users. Note: we shouldn't be using anything lower
+  /// than `.notice`.
+  static var `default` = Logger(subsystem: subsystem, category: defaultCategory)
+
+  /// A non-default logger for network payloads; enabled only by the `enableArgumentKey` argument.
+  static var network: Logger = {
+    if ProcessInfo.processInfo.arguments.contains(enableArgumentKey) {
+      return Logger(subsystem: subsystem, category: "NetworkResponse")
+    } else {
+      // Return a valid logger that's using `OSLog.disabled` as the logger, hiding everything.
+      return Logger(.disabled)
+    }
+  }()
+
+  /// A verbose logger for debug details; enabled only by the `enableArgumentKey` launch argument.
+  static var verbose: Logger = {
+    if ProcessInfo.processInfo.arguments.contains(enableArgumentKey) {
+      return Logger(subsystem: subsystem, category: defaultCategory)
+    } else {
+      // Return a valid logger that's using `OSLog.disabled` as the logger, hiding everything.
+      return Logger(.disabled)
+    }
+  }()
+}
diff --git a/FirebaseVertexAI/Sources/ModelContent.swift b/FirebaseVertexAI/Sources/ModelContent.swift
new file mode 100644
index 00000000000..44648c57852
--- /dev/null
+++ b/FirebaseVertexAI/Sources/ModelContent.swift
@@ -0,0 +1,141 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import Foundation
+
+/// A type describing data in media formats interpretable by an AI model. Each generative AI
+/// request or response contains an `Array` of ``ModelContent``s, and each ``ModelContent`` value
+/// may comprise multiple heterogeneous ``ModelContent/Part``s.
+@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
+public struct ModelContent: Codable, Equatable {
+  /// A discrete piece of data in a media format interpretable by an AI model. Within a single value
+  /// of ``Part``, different data types may not mix.
+  public enum Part: Codable, Equatable {
+    enum CodingKeys: String, CodingKey {
+      case text
+      case inlineData
+    }
+
+    enum InlineDataKeys: String, CodingKey {
+      case mimeType = "mime_type"
+      case bytes = "data"
+    }
+
+    /// Text value.
+    case text(String)
+
+    /// Data with a specified media type. Not all media types may be supported by the AI model.
+    case data(mimetype: String, Data)
+
+    // MARK: Convenience Initializers
+
+    /// Convenience function for populating a Part with JPEG data.
+    public static func jpeg(_ data: Data) -> Self {
+      return .data(mimetype: "image/jpeg", data)
+    }
+
+    /// Convenience function for populating a Part with PNG data.
+ public static func png(_ data: Data) -> Self { + return .data(mimetype: "image/png", data) + } + + // MARK: Codable Conformance + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: ModelContent.Part.CodingKeys.self) + switch self { + case let .text(a0): + try container.encode(a0, forKey: .text) + case let .data(mimetype, bytes): + var inlineDataContainer = container.nestedContainer( + keyedBy: InlineDataKeys.self, + forKey: .inlineData + ) + try inlineDataContainer.encode(mimetype, forKey: .mimeType) + try inlineDataContainer.encode(bytes, forKey: .bytes) + } + } + + public init(from decoder: Decoder) throws { + let values = try decoder.container(keyedBy: CodingKeys.self) + if values.contains(.text) { + self = try .text(values.decode(String.self, forKey: .text)) + } else if values.contains(.inlineData) { + let dataContainer = try values.nestedContainer( + keyedBy: InlineDataKeys.self, + forKey: .inlineData + ) + let mimetype = try dataContainer.decode(String.self, forKey: .mimeType) + let bytes = try dataContainer.decode(Data.self, forKey: .bytes) + self = .data(mimetype: mimetype, bytes) + } else { + throw DecodingError.dataCorrupted(.init( + codingPath: [CodingKeys.text, CodingKeys.inlineData], + debugDescription: "Neither text or inline data was found." + )) + } + } + + /// Returns the text contents of this ``Part``, if it contains text. + public var text: String? { + switch self { + case let .text(contents): return contents + default: return nil + } + } + } + + /// The role of the entity creating the ``ModelContent``. For user-generated client requests, + /// for example, the role is `user`. + public let role: String? + + /// The data parts comprising this ``ModelContent`` value. + public let parts: [Part] + + /// Creates a new value from any data or `Array` of data interpretable as a + /// ``Part``. See ``ThrowingPartsRepresentable`` for types that can be interpreted as `Part`s. + public init(role: String? 
= "user", parts: some ThrowingPartsRepresentable) throws { + self.role = role + try self.parts = parts.tryPartsValue() + } + + /// Creates a new value from any data or `Array` of data interpretable as a + /// ``Part``. See ``ThrowingPartsRepresentable`` for types that can be interpreted as `Part`s. + public init(role: String? = "user", parts: some PartsRepresentable) { + self.role = role + self.parts = parts.partsValue + } + + /// Creates a new value from a list of ``Part``s. + public init(role: String? = "user", parts: [Part]) { + self.role = role + self.parts = parts + } + + /// Creates a new value from any data interpretable as a ``Part``. See + /// ``ThrowingPartsRepresentable`` + /// for types that can be interpreted as `Part`s. + public init(role: String? = "user", _ parts: any ThrowingPartsRepresentable...) throws { + let content = try parts.flatMap { try $0.tryPartsValue() } + self.init(role: role, parts: content) + } + + /// Creates a new value from any data interpretable as a ``Part``. See + /// ``ThrowingPartsRepresentable`` + /// for types that can be interpreted as `Part`s. + public init(role: String? = "user", _ parts: [PartsRepresentable]) { + let content = parts.flatMap { $0.partsValue } + self.init(role: role, parts: content) + } +} diff --git a/FirebaseVertexAI/Sources/PartsRepresentable+Image.swift b/FirebaseVertexAI/Sources/PartsRepresentable+Image.swift new file mode 100644 index 00000000000..052a004e52f --- /dev/null +++ b/FirebaseVertexAI/Sources/PartsRepresentable+Image.swift @@ -0,0 +1,107 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import UniformTypeIdentifiers +#if canImport(UIKit) + import UIKit // For UIImage extensions. +#elseif canImport(AppKit) + import AppKit // For NSImage extensions. +#endif + +private let imageCompressionQuality: CGFloat = 0.8 + +/// An enum describing failures that can occur when converting image types to model content data. +/// For some image types like `CIImage`, creating valid model content requires creating a JPEG +/// representation of the image that may not yet exist, which may be computationally expensive. +public enum ImageConversionError: Error { + /// The image (the receiver of the call `toModelContentParts()`) was invalid. + case invalidUnderlyingImage + + /// A valid image destination could not be allocated. + case couldNotAllocateDestination + + /// JPEG image data conversion failed, accompanied by the original image, which may be an + /// instance of `NSImageRep`, `UIImage`, `CGImage`, or `CIImage`. + case couldNotConvertToJPEG(Any) +} + +#if canImport(UIKit) + /// Enables images to be representable as ``ThrowingPartsRepresentable``. + @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) + extension UIImage: ThrowingPartsRepresentable { + public func tryPartsValue() throws -> [ModelContent.Part] { + guard let data = jpegData(compressionQuality: imageCompressionQuality) else { + throw ImageConversionError.couldNotConvertToJPEG(self) + } + return [ModelContent.Part.data(mimetype: "image/jpeg", data)] + } + } + +#elseif canImport(AppKit) + /// Enables images to be representable as ``ThrowingPartsRepresentable``. 
+ @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) + extension NSImage: ThrowingPartsRepresentable { + public func tryPartsValue() throws -> [ModelContent.Part] { + guard let cgImage = cgImage(forProposedRect: nil, context: nil, hints: nil) else { + throw ImageConversionError.invalidUnderlyingImage + } + let bmp = NSBitmapImageRep(cgImage: cgImage) + guard let data = bmp.representation(using: .jpeg, properties: [.compressionFactor: 0.8]) + else { + throw ImageConversionError.couldNotConvertToJPEG(bmp) + } + return [ModelContent.Part.data(mimetype: "image/jpeg", data)] + } + } +#endif + +/// Enables `CGImages` to be representable as model content. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +extension CGImage: ThrowingPartsRepresentable { + public func tryPartsValue() throws -> [ModelContent.Part] { + let output = NSMutableData() + guard let imageDestination = CGImageDestinationCreateWithData( + output, UTType.jpeg.identifier as CFString, 1, nil + ) else { + throw ImageConversionError.couldNotAllocateDestination + } + CGImageDestinationAddImage(imageDestination, self, nil) + CGImageDestinationSetProperties(imageDestination, [ + kCGImageDestinationLossyCompressionQuality: imageCompressionQuality, + ] as CFDictionary) + if CGImageDestinationFinalize(imageDestination) { + return [.data(mimetype: "image/jpeg", output as Data)] + } + throw ImageConversionError.couldNotConvertToJPEG(self) + } +} + +/// Enables `CIImages` to be representable as model content. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +extension CIImage: ThrowingPartsRepresentable { + public func tryPartsValue() throws -> [ModelContent.Part] { + let context = CIContext() + let jpegData = (colorSpace ?? CGColorSpace(name: CGColorSpace.sRGB)) + .flatMap { + // The docs specify kCGImageDestinationLossyCompressionQuality as a supported option, but + // Swift's type system does not allow this. 
+ // [kCGImageDestinationLossyCompressionQuality: imageCompressionQuality] + context.jpegRepresentation(of: self, colorSpace: $0, options: [:]) + } + if let jpegData = jpegData { + return [.data(mimetype: "image/jpeg", jpegData)] + } + throw ImageConversionError.couldNotConvertToJPEG(self) + } +} diff --git a/FirebaseVertexAI/Sources/PartsRepresentable.swift b/FirebaseVertexAI/Sources/PartsRepresentable.swift new file mode 100644 index 00000000000..05ba0d9dabc --- /dev/null +++ b/FirebaseVertexAI/Sources/PartsRepresentable.swift @@ -0,0 +1,66 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +/// A protocol describing any data that could be serialized to model-interpretable input data, +/// where the serialization process might fail with an error. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +public protocol ThrowingPartsRepresentable { + func tryPartsValue() throws -> [ModelContent.Part] +} + +/// A protocol describing any data that could be serialized to model-interpretable input data, +/// where the serialization process cannot fail with an error. 
For a failable conversion, see +/// ``ThrowingPartsRepresentable`` +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +public protocol PartsRepresentable: ThrowingPartsRepresentable { + var partsValue: [ModelContent.Part] { get } +} + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +public extension PartsRepresentable { + func tryPartsValue() throws -> [ModelContent.Part] { + return partsValue + } +} + +/// Enables a ``ModelContent.Part`` to be passed in as ``ThrowingPartsRepresentable``. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +extension ModelContent.Part: ThrowingPartsRepresentable { + public typealias ErrorType = Never + public func tryPartsValue() throws -> [ModelContent.Part] { + return [self] + } +} + +/// Enable an `Array` of ``ThrowingPartsRepresentable`` values to be passed in as a single +/// ``ThrowingPartsRepresentable``. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +extension [ThrowingPartsRepresentable]: ThrowingPartsRepresentable { + public func tryPartsValue() throws -> [ModelContent.Part] { + return try compactMap { element in + try element.tryPartsValue() + } + .flatMap { $0 } + } +} + +/// Enables a `String` to be passed in as ``ThrowingPartsRepresentable``. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +extension String: PartsRepresentable { + public var partsValue: [ModelContent.Part] { + return [.text(self)] + } +} diff --git a/FirebaseVertexAI/Sources/Safety.swift b/FirebaseVertexAI/Sources/Safety.swift new file mode 100644 index 00000000000..d59511f2476 --- /dev/null +++ b/FirebaseVertexAI/Sources/Safety.swift @@ -0,0 +1,182 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import Foundation
+
+/// A type defining potentially harmful media categories and their model-assigned ratings. A value
+/// of this type may be assigned to a category for every model-generated response, not just
+/// responses that exceed a certain threshold.
+@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
+public struct SafetyRating: Decodable, Equatable, Hashable {
+  /// The category describing the potential harm a piece of content may pose. See
+  /// ``SafetySetting/HarmCategory`` for a list of possible values.
+  public let category: SafetySetting.HarmCategory
+
+  /// The model-generated probability that a given piece of content falls under the harm category
+  /// described in ``category``. This does not
+  /// indicate the severity of harm for a piece of content. See ``HarmProbability`` for a list of
+  /// possible values.
+  public let probability: HarmProbability
+
+  /// Initializes a new `SafetyRating` instance with the given category and probability.
+  /// Use this initializer for SwiftUI previews or tests.
+  public init(category: SafetySetting.HarmCategory, probability: HarmProbability) {
+    self.category = category
+    self.probability = probability
+  }
+
+  /// The probability that a given model output falls under a harmful content category. This does
+  /// not indicate the severity of harm for a piece of content.
+  public enum HarmProbability: String, Codable {
+    /// Unknown. A new server value that isn't recognized by the SDK.
+    case unknown = "UNKNOWN"
+
+    /// The probability was not specified in the server response.
+    case unspecified = "HARM_PROBABILITY_UNSPECIFIED"
+
+    /// The probability is zero or close to zero. For benign content, the probability across all
+    /// categories will be this value.
+    case negligible = "NEGLIGIBLE"
+
+    /// The probability is small but non-zero.
+    case low = "LOW"
+
+    /// The probability is moderate.
+    case medium = "MEDIUM"
+
+    /// The probability is high. The content described is very likely harmful.
+    case high = "HIGH"
+
+    /// Initializes a new `HarmProbability` from a decoder.
+    /// Not for external use. Initializer required for Decodable conformance.
+    public init(from decoder: Decoder) throws {
+      let value = try decoder.singleValueContainer().decode(String.self)
+      guard let decodedProbability = HarmProbability(rawValue: value) else {
+        Logging.default
+          .error("[GoogleGenerativeAI] Unrecognized HarmProbability with value \"\(value)\".")
+        self = .unknown
+        return
+      }
+
+      self = decodedProbability
+    }
+  }
+}
+
+/// Safety feedback for an entire request.
+@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
+public struct SafetyFeedback: Decodable {
+  /// Safety rating evaluated from content.
+  public let rating: SafetyRating
+
+  /// Safety settings applied to the request.
+  public let setting: SafetySetting
+
+  /// Internal initializer.
+  init(rating: SafetyRating, setting: SafetySetting) {
+    self.rating = rating
+    self.setting = setting
+  }
+}
+
+/// A type used to specify a threshold for harmful content, beyond which the model will return a
+/// fallback response instead of generated content.
+@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
+public struct SafetySetting: Codable {
+  /// A type describing safety attributes, which include harmful categories and topics that can
+  /// be considered sensitive.
+  public enum HarmCategory: String, Codable {
+    /// Unknown. A new server value that isn't recognized by the SDK.
+    case unknown = "HARM_CATEGORY_UNKNOWN"
+
+    /// Unspecified by the server.
+    case unspecified = "HARM_CATEGORY_UNSPECIFIED"
+
+    /// Harassment content.
+    case harassment = "HARM_CATEGORY_HARASSMENT"
+
+    /// Negative or harmful comments targeting identity and/or protected attributes.
+    case hateSpeech = "HARM_CATEGORY_HATE_SPEECH"
+
+    /// Contains references to sexual acts or other lewd content.
+    case sexuallyExplicit = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
+
+    /// Promotes or enables access to harmful goods, services, or activities.
+    case dangerousContent = "HARM_CATEGORY_DANGEROUS_CONTENT"
+
+    /// Do not explicitly use. Initializer required for Decodable conformance.
+    public init(from decoder: Decoder) throws {
+      let value = try decoder.singleValueContainer().decode(String.self)
+      guard let decodedCategory = HarmCategory(rawValue: value) else {
+        Logging.default
+          .error("[GoogleGenerativeAI] Unrecognized HarmCategory with value \"\(value)\".")
+        self = .unknown
+        return
+      }
+
+      self = decodedCategory
+    }
+  }
+
+  /// Block at and beyond a specified ``SafetyRating/HarmProbability``.
+  public enum BlockThreshold: String, Codable {
+    /// Unknown. A new server value that isn't recognized by the SDK.
+    case unknown = "UNKNOWN"
+
+    /// Threshold is unspecified.
+    case unspecified = "HARM_BLOCK_THRESHOLD_UNSPECIFIED"
+
+    /// Content with `.negligible` will be allowed.
+    case blockLowAndAbove = "BLOCK_LOW_AND_ABOVE"
+
+    /// Content with `.negligible` and `.low` will be allowed.
+    case blockMediumAndAbove = "BLOCK_MEDIUM_AND_ABOVE"
+
+    /// Content with `.negligible`, `.low`, and `.medium` will be allowed.
+    case blockOnlyHigh = "BLOCK_ONLY_HIGH"
+
+    /// All content will be allowed.
+    case blockNone = "BLOCK_NONE"
+
+    /// Do not explicitly use. Initializer required for Decodable conformance.
+ public init(from decoder: Decoder) throws { + let value = try decoder.singleValueContainer().decode(String.self) + guard let decodedThreshold = BlockThreshold(rawValue: value) else { + Logging.default + .error("[GoogleGenerativeAI] Unrecognized BlockThreshold with value \"\(value)\".") + self = .unknown + return + } + + self = decodedThreshold + } + } + + enum CodingKeys: String, CodingKey { + case harmCategory = "category" + case threshold + } + + /// The category this safety setting should be applied to. + public let harmCategory: HarmCategory + + /// The threshold describing what content should be blocked. + public let threshold: BlockThreshold + + /// Initializes a new safety setting with the given category and threshold. + public init(harmCategory: HarmCategory, threshold: BlockThreshold) { + self.harmCategory = harmCategory + self.threshold = threshold + } +} diff --git a/FirebaseVertexAI/Sources/VertexAI.swift b/FirebaseVertexAI/Sources/VertexAI.swift index 8fb0f9e8c68..16fc0d20754 100644 --- a/FirebaseVertexAI/Sources/VertexAI.swift +++ b/FirebaseVertexAI/Sources/VertexAI.swift @@ -12,13 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -import Foundation - import FirebaseAppCheckInterop import FirebaseCore - -// Exports the GoogleGenerativeAI module to users of the SDK. -@_exported import GoogleGenerativeAI +import Foundation // Avoids exposing internal FirebaseCore APIs to Swift users. @_implementationOnly import FirebaseCoreExtension @@ -31,8 +27,7 @@ open class VertexAI: NSObject { /// Returns an instance of `GoogleGenerativeAI.GenerativeModel` that uses the Vertex AI API. /// /// This instance is configured with the default `FirebaseApp`. 
- public static func generativeModel(modelName: String, location: String) -> GoogleGenerativeAI - .GenerativeModel { + public static func generativeModel(modelName: String, location: String) -> GenerativeModel { guard let app = FirebaseApp.app() else { fatalError("No instance of the default Firebase app was found.") } @@ -41,7 +36,7 @@ open class VertexAI: NSObject { /// Returns an instance of `GoogleGenerativeAI.GenerativeModel` that uses the Vertex AI API. public static func generativeModel(app: FirebaseApp, modelName: String, - location: String) -> GoogleGenerativeAI.GenerativeModel { + location: String) -> GenerativeModel { guard let provider = ComponentType.instance(for: VertexAIProvider.self, in: app.container) else { fatalError("No \(VertexAIProvider.self) instance found for Firebase app: \(app.name)") @@ -64,18 +59,15 @@ open class VertexAI: NSObject { private let modelResouceName: String lazy var model: GenerativeModel = { - let options = RequestOptions( - apiVersion: "v2beta", - endpoint: "staging-firebaseml.sandbox.googleapis.com", - hooks: [addAppCheckHeader] - ) guard let apiKey = app.options.apiKey else { fatalError("The Firebase app named \"\(app.name)\" has no API key in its configuration.") } return GenerativeModel( name: modelResouceName, apiKey: apiKey, - requestOptions: options + // TODO: Consider adding RequestOptions to public API. + requestOptions: RequestOptions(), + appCheck: appCheck ) }() @@ -104,21 +96,4 @@ open class VertexAI: NSObject { return "projects/\(projectID)/locations/\(location)/publishers/google/models/\(modelName)" } - - // MARK: Request Hooks - - /// Adds an App Check token to the provided request if App Check is included in the app. - /// - /// This demonstrates how an App Check token can be added to requests; it is currently ignored by - /// the backend. - /// - /// - Parameter request: The `URLRequest` to modify by adding an App Check token header. 
- func addAppCheckHeader(request: inout URLRequest) async { - guard let appCheck else { - return - } - - let tokenResult = await appCheck.getToken(forcingRefresh: false) - request.addValue(tokenResult.token, forHTTPHeaderField: "X-Firebase-AppCheck") - } } diff --git a/Package.swift b/Package.swift index b1c20caeefd..5bcebc34ed2 100644 --- a/Package.swift +++ b/Package.swift @@ -187,10 +187,6 @@ let package = Package( "100.0.0" ..< "101.0.0" ), .package(url: "https://github.com/google/app-check.git", "10.18.0" ..< "11.0.0"), - .package( - url: "https://github.com/google/generative-ai-swift.git", - revision: "c9f2c4913bc65aa267815962c7e91358c2d8463f" - ), ], targets: [ .target( @@ -1363,7 +1359,6 @@ let package = Package( "FirebaseAppCheckInterop", "FirebaseCore", "FirebaseCoreExtension", - .product(name: "GoogleGenerativeAI", package: "generative-ai-swift"), ], path: "FirebaseVertexAI/Sources" ),