Skip to content

Commit

Permalink
[Vertex AI] Update Imagen public APIs to match API proposal
Browse files (browse the repository at this point in the history)
  • Loading branch information
andrewheard committed Jan 28, 2025
1 parent d0e2014 commit c206ff9
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenFileDataImage {
public struct ImagenGCSImage {
public let mimeType: String
public let gcsURI: String

Expand All @@ -26,7 +26,7 @@ public struct ImagenFileDataImage {
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenFileDataImage: ImagenImageRepresentable {
extension ImagenGCSImage: ImagenImageRepresentable {
// TODO(andrewheard): Make this public when the SDK supports Imagen operations that take images as
// input (upscaling / editing).
var _internalImagenImage: _InternalImagenImage {
Expand All @@ -35,12 +35,12 @@ extension ImagenFileDataImage: ImagenImageRepresentable {
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenFileDataImage: Equatable {}
extension ImagenGCSImage: Equatable {}

// MARK: - Codable Conformances

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenFileDataImage: Decodable {
extension ImagenGCSImage: Decodable {
enum CodingKeys: String, CodingKey {
case mimeType
case gcsURI = "gcsUri"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenGenerationConfig {
public var numberOfImages: Int?
public var negativePrompt: String?
public var imageFormat: ImagenImageFormat?
public var numberOfImages: Int?
public var aspectRatio: ImagenAspectRatio?
public var imageFormat: ImagenImageFormat?
public var addWatermark: Bool?

public init(numberOfImages: Int? = nil, negativePrompt: String? = nil,
imageFormat: ImagenImageFormat? = nil, aspectRatio: ImagenAspectRatio? = nil,
public init(negativePrompt: String? = nil, numberOfImages: Int? = nil,
aspectRatio: ImagenAspectRatio? = nil, imageFormat: ImagenImageFormat? = nil,
addWatermark: Bool? = nil) {
self.numberOfImages = numberOfImages
self.negativePrompt = negativePrompt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenInlineDataImage {
public struct ImagenInlineImage {
public let mimeType: String
public let data: Data

Expand All @@ -30,7 +30,7 @@ public struct ImagenInlineDataImage {
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenInlineDataImage: ImagenImageRepresentable {
extension ImagenInlineImage: ImagenImageRepresentable {
// TODO(andrewheard): Make this public when the SDK supports Imagen operations that take images as
// input (upscaling / editing).
var _internalImagenImage: _InternalImagenImage {
Expand All @@ -43,12 +43,12 @@ extension ImagenInlineDataImage: ImagenImageRepresentable {
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenInlineDataImage: Equatable {}
extension ImagenInlineImage: Equatable {}

// MARK: - Codable Conformances

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenInlineDataImage: Decodable {
extension ImagenInlineImage: Decodable {
enum CodingKeys: CodingKey {
case mimeType
case bytesBase64Encoded
Expand Down
16 changes: 9 additions & 7 deletions FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public final class ImagenModel {
/// The backing service responsible for sending and receiving model requests to the backend.
let generativeAIService: GenerativeAIService

let generationConfig: ImagenGenerationConfig?

let safetySettings: ImagenSafetySettings?

/// Configuration parameters for sending requests to the backend.
Expand All @@ -32,6 +34,7 @@ public final class ImagenModel {
init(name: String,
projectID: String,
apiKey: String,
generationConfig: ImagenGenerationConfig?,
safetySettings: ImagenSafetySettings?,
requestOptions: RequestOptions,
appCheck: AppCheckInterop?,
Expand All @@ -45,13 +48,13 @@ public final class ImagenModel {
auth: auth,
urlSession: urlSession
)
self.generationConfig = generationConfig
self.safetySettings = safetySettings
self.requestOptions = requestOptions
}

public func generateImages(prompt: String,
generationConfig: ImagenGenerationConfig? = nil) async throws
-> ImagenGenerationResponse<ImagenInlineDataImage> {
public func generateImages(prompt: String) async throws
-> ImagenGenerationResponse<ImagenInlineImage> {
return try await generateImages(
prompt: prompt,
parameters: ImagenModel.imageGenerationParameters(
Expand All @@ -62,13 +65,12 @@ public final class ImagenModel {
)
}

public func generateImages(prompt: String, storageURI: String,
generationConfig: ImagenGenerationConfig? = nil) async throws
-> ImagenGenerationResponse<ImagenFileDataImage> {
public func generateImages(prompt: String, gcsUri: String) async throws
-> ImagenGenerationResponse<ImagenGCSImage> {
return try await generateImages(
prompt: prompt,
parameters: ImagenModel.imageGenerationParameters(
storageURI: storageURI,
storageURI: gcsUri,
generationConfig: generationConfig,
safetySettings: safetySettings
)
Expand Down
4 changes: 3 additions & 1 deletion FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,14 @@ public class VertexAI {
)
}

public func imagenModel(modelName: String, safetySettings: ImagenSafetySettings? = nil,
public func imagenModel(modelName: String, generationConfig: ImagenGenerationConfig? = nil,
safetySettings: ImagenSafetySettings? = nil,
requestOptions: RequestOptions = RequestOptions()) -> ImagenModel {
return ImagenModel(
name: modelResourceName(modelName: modelName),
projectID: projectID,
apiKey: apiKey,
generationConfig: generationConfig,
safetySettings: safetySettings,
requestOptions: requestOptions,
appCheck: appCheck,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ final class ImageGenerationParametersTests: XCTestCase {
let aspectRatio = ImagenAspectRatio.landscape16x9
let addWatermark = true
let generationConfig = ImagenGenerationConfig(
numberOfImages: sampleCount,
negativePrompt: negativePrompt,
imageFormat: imageFormat,
numberOfImages: sampleCount,
aspectRatio: aspectRatio,
imageFormat: imageFormat,
addWatermark: addWatermark
)
let expectedParameters = ImageGenerationParameters(
Expand Down Expand Up @@ -146,10 +146,10 @@ final class ImageGenerationParametersTests: XCTestCase {
let aspectRatio = ImagenAspectRatio.portrait3x4
let addWatermark = false
let generationConfig = ImagenGenerationConfig(
numberOfImages: sampleCount,
negativePrompt: negativePrompt,
imageFormat: imageFormat,
numberOfImages: sampleCount,
aspectRatio: aspectRatio,
imageFormat: imageFormat,
addWatermark: addWatermark
)
let safetyFilterLevel = ImagenSafetyFilterLevel.blockNone
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import XCTest
@testable import FirebaseVertexAI

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class ImagenFileDataImageTests: XCTestCase {
final class ImagenGCSImageTests: XCTestCase {
let decoder = JSONDecoder()

func testDecodeImage_gcsURI() throws {
Expand All @@ -31,7 +31,7 @@ final class ImagenFileDataImageTests: XCTestCase {
"""
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let image = try decoder.decode(ImagenFileDataImage.self, from: jsonData)
let image = try decoder.decode(ImagenGCSImage.self, from: jsonData)

XCTAssertEqual(image.mimeType, mimeType)
XCTAssertEqual(image.gcsURI, gcsURI)
Expand All @@ -49,10 +49,10 @@ final class ImagenFileDataImageTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

do {
_ = try decoder.decode(ImagenFileDataImage.self, from: jsonData)
_ = try decoder.decode(ImagenGCSImage.self, from: jsonData)
XCTFail("Expected an error; none thrown.")
} catch let DecodingError.keyNotFound(codingKey, _) {
let codingKey = try XCTUnwrap(codingKey as? ImagenFileDataImage.CodingKeys)
let codingKey = try XCTUnwrap(codingKey as? ImagenGCSImage.CodingKeys)
XCTAssertEqual(codingKey, .gcsURI)
} catch {
XCTFail("Expected a DecodingError.keyNotFound error; got \(error).")
Expand All @@ -68,10 +68,10 @@ final class ImagenFileDataImageTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

do {
_ = try decoder.decode(ImagenFileDataImage.self, from: jsonData)
_ = try decoder.decode(ImagenGCSImage.self, from: jsonData)
XCTFail("Expected an error; none thrown.")
} catch let DecodingError.keyNotFound(codingKey, _) {
let codingKey = try XCTUnwrap(codingKey as? ImagenFileDataImage.CodingKeys)
let codingKey = try XCTUnwrap(codingKey as? ImagenGCSImage.CodingKeys)
XCTAssertEqual(codingKey, .mimeType)
} catch {
XCTFail("Expected a DecodingError.keyNotFound error; got \(error).")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ final class ImagenGenerationRequestTests: XCTestCase {
}

func testInitializeRequest_inlineDataImage() throws {
let request = ImagenGenerationRequest<ImagenInlineDataImage>(
let request = ImagenGenerationRequest<ImagenInlineImage>(
model: modelName,
options: requestOptions,
instances: [instance],
Expand All @@ -62,7 +62,7 @@ final class ImagenGenerationRequestTests: XCTestCase {
}

func testInitializeRequest_fileDataImage() throws {
let request = ImagenGenerationRequest<ImagenFileDataImage>(
let request = ImagenGenerationRequest<ImagenGCSImage>(
model: modelName,
options: requestOptions,
instances: [instance],
Expand All @@ -82,7 +82,7 @@ final class ImagenGenerationRequestTests: XCTestCase {
// MARK: - Encoding Tests

func testEncodeRequest_inlineDataImage() throws {
let request = ImagenGenerationRequest<ImagenInlineDataImage>(
let request = ImagenGenerationRequest<ImagenInlineImage>(
model: modelName,
options: RequestOptions(),
instances: [instance],
Expand Down Expand Up @@ -110,7 +110,7 @@ final class ImagenGenerationRequestTests: XCTestCase {
}

func testEncodeRequest_fileDataImage() throws {
let request = ImagenGenerationRequest<ImagenFileDataImage>(
let request = ImagenGenerationRequest<ImagenGCSImage>(
model: modelName,
options: RequestOptions(),
instances: [instance],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ final class ImagenGenerationResponseTests: XCTestCase {
func testDecodeResponse_oneBase64Image_noneFiltered() throws {
let mimeType = "image/png"
let bytesBase64Encoded = "dGVzdC1iYXNlNjQtZGF0YQ=="
let image = ImagenInlineDataImage(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded)
let image = ImagenInlineImage(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded)
let json = """
{
"predictions": [
Expand All @@ -37,7 +37,7 @@ final class ImagenGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImagenGenerationResponse<ImagenInlineDataImage>.self,
ImagenGenerationResponse<ImagenInlineImage>.self,
from: jsonData
)

Expand All @@ -50,9 +50,9 @@ final class ImagenGenerationResponseTests: XCTestCase {
let bytesBase64Encoded1 = "dGVzdC1iYXNlNjQtYnl0ZXMtMQ=="
let bytesBase64Encoded2 = "dGVzdC1iYXNlNjQtYnl0ZXMtMg=="
let bytesBase64Encoded3 = "dGVzdC1iYXNlNjQtYnl0ZXMtMw=="
let image1 = ImagenInlineDataImage(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded1)
let image2 = ImagenInlineDataImage(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded2)
let image3 = ImagenInlineDataImage(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded3)
let image1 = ImagenInlineImage(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded1)
let image2 = ImagenInlineImage(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded2)
let image3 = ImagenInlineImage(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded3)
let json = """
{
"predictions": [
Expand All @@ -74,7 +74,7 @@ final class ImagenGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImagenGenerationResponse<ImagenInlineDataImage>.self,
ImagenGenerationResponse<ImagenInlineImage>.self,
from: jsonData
)

Expand All @@ -86,8 +86,8 @@ final class ImagenGenerationResponseTests: XCTestCase {
let mimeType = "image/png"
let bytesBase64Encoded1 = "dGVzdC1iYXNlNjQtYnl0ZXMtMQ=="
let bytesBase64Encoded2 = "dGVzdC1iYXNlNjQtYnl0ZXMtMg=="
let image1 = ImagenInlineDataImage(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded1)
let image2 = ImagenInlineDataImage(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded2)
let image1 = ImagenInlineImage(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded1)
let image2 = ImagenInlineImage(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded2)
let raiFilteredReason = """
Your current safety filter threshold filtered out 2 generated images. You will not be charged \
for blocked images. Try rephrasing the prompt. If you think this was an error, send feedback.
Expand All @@ -112,7 +112,7 @@ final class ImagenGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImagenGenerationResponse<ImagenInlineDataImage>.self,
ImagenGenerationResponse<ImagenInlineImage>.self,
from: jsonData
)

Expand All @@ -124,8 +124,8 @@ final class ImagenGenerationResponseTests: XCTestCase {
let mimeType = "image/png"
let gcsURI1 = "gs://test-bucket/images/123456789/sample_0.png"
let gcsURI2 = "gs://test-bucket/images/123456789/sample_1.png"
let image1 = ImagenFileDataImage(mimeType: mimeType, gcsURI: gcsURI1)
let image2 = ImagenFileDataImage(mimeType: mimeType, gcsURI: gcsURI2)
let image1 = ImagenGCSImage(mimeType: mimeType, gcsURI: gcsURI1)
let image2 = ImagenGCSImage(mimeType: mimeType, gcsURI: gcsURI2)
let json = """
{
"predictions": [
Expand All @@ -143,7 +143,7 @@ final class ImagenGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImagenGenerationResponse<ImagenFileDataImage>.self,
ImagenGenerationResponse<ImagenGCSImage>.self,
from: jsonData
)

Expand All @@ -169,7 +169,7 @@ final class ImagenGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImagenGenerationResponse<ImagenInlineDataImage>.self,
ImagenGenerationResponse<ImagenInlineImage>.self,
from: jsonData
)

Expand All @@ -182,7 +182,7 @@ final class ImagenGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImagenGenerationResponse<ImagenInlineDataImage>.self,
ImagenGenerationResponse<ImagenInlineImage>.self,
from: jsonData
)

Expand All @@ -208,7 +208,7 @@ final class ImagenGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImagenGenerationResponse<ImagenFileDataImage>.self,
ImagenGenerationResponse<ImagenGCSImage>.self,
from: jsonData
)

Expand All @@ -230,7 +230,7 @@ final class ImagenGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImagenGenerationResponse<ImagenInlineDataImage>.self,
ImagenGenerationResponse<ImagenInlineImage>.self,
from: jsonData
)

Expand Down
Loading

0 comments on commit c206ff9

Please sign in to comment.