Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Vertex AI] Move ImagenModelConfig params to ImagenGenerationConfig #14340

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@
public struct ImagenGenerationConfig {
public var numberOfImages: Int?
public var negativePrompt: String?
public var imageFormat: ImagenImageFormat?
public var aspectRatio: ImagenAspectRatio?
public var addWatermark: Bool?

public init(numberOfImages: Int? = nil, negativePrompt: String? = nil,
aspectRatio: ImagenAspectRatio? = nil) {
imageFormat: ImagenImageFormat? = nil, aspectRatio: ImagenAspectRatio? = nil,
addWatermark: Bool? = nil) {
self.numberOfImages = numberOfImages
self.negativePrompt = negativePrompt
self.imageFormat = imageFormat
self.aspectRatio = aspectRatio
self.addWatermark = addWatermark
}
}
11 changes: 2 additions & 9 deletions FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ public final class ImagenModel {
/// The backing service responsible for sending and receiving model requests to the backend.
let generativeAIService: GenerativeAIService

let modelConfig: ImagenModelConfig?

let safetySettings: ImagenSafetySettings?

/// Configuration parameters for sending requests to the backend.
Expand All @@ -34,7 +32,6 @@ public final class ImagenModel {
init(name: String,
projectID: String,
apiKey: String,
modelConfig: ImagenModelConfig?,
safetySettings: ImagenSafetySettings?,
requestOptions: RequestOptions,
appCheck: AppCheckInterop?,
Expand All @@ -48,7 +45,6 @@ public final class ImagenModel {
auth: auth,
urlSession: urlSession
)
self.modelConfig = modelConfig
self.safetySettings = safetySettings
self.requestOptions = requestOptions
}
Expand All @@ -61,7 +57,6 @@ public final class ImagenModel {
parameters: ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: generationConfig,
modelConfig: modelConfig,
safetySettings: safetySettings
)
)
Expand All @@ -75,7 +70,6 @@ public final class ImagenModel {
parameters: ImagenModel.imageGenerationParameters(
storageURI: storageURI,
generationConfig: generationConfig,
modelConfig: modelConfig,
safetySettings: safetySettings
)
)
Expand All @@ -96,7 +90,6 @@ public final class ImagenModel {

static func imageGenerationParameters(storageURI: String?,
generationConfig: ImagenGenerationConfig?,
modelConfig: ImagenModelConfig?,
safetySettings: ImagenSafetySettings?)
-> ImageGenerationParameters {
return ImageGenerationParameters(
Expand All @@ -106,13 +99,13 @@ public final class ImagenModel {
aspectRatio: generationConfig?.aspectRatio?.rawValue,
safetyFilterLevel: safetySettings?.safetyFilterLevel?.rawValue,
personGeneration: safetySettings?.personFilterLevel?.rawValue,
outputOptions: modelConfig?.imageFormat.map {
outputOptions: generationConfig?.imageFormat.map {
ImageGenerationOutputOptions(
mimeType: $0.mimeType,
compressionQuality: $0.compressionQuality
)
},
addWatermark: modelConfig?.addWatermark,
addWatermark: generationConfig?.addWatermark,
includeResponsibleAIFilterReason: true
)
}
Expand Down

This file was deleted.

4 changes: 1 addition & 3 deletions FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,12 @@ public class VertexAI {
)
}

public func imagenModel(modelName: String, modelConfig: ImagenModelConfig? = nil,
safetySettings: ImagenSafetySettings? = nil,
public func imagenModel(modelName: String, safetySettings: ImagenSafetySettings? = nil,
requestOptions: RequestOptions = RequestOptions()) -> ImagenModel {
return ImagenModel(
name: modelResourceName(modelName: modelName),
projectID: projectID,
apiKey: apiKey,
modelConfig: modelConfig,
safetySettings: safetySettings,
requestOptions: requestOptions,
appCheck: appCheck,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ final class IntegrationTests: XCTestCase {
)
imagenModel = vertex.imagenModel(
modelName: "imagen-3.0-fast-generate-001",
modelConfig: ImagenModelConfig(imageFormat: .jpeg(compressionQuality: 70)),
safetySettings: ImagenSafetySettings(
safetyFilterLevel: .blockLowAndAbove,
personFilterLevel: .blockAll
Expand Down Expand Up @@ -254,7 +253,10 @@ final class IntegrationTests: XCTestCase {
overlooking a vast African savanna at sunset. Golden hour light, long shadows, sharp focus on
the lion, shallow depth of field, detailed fur texture, DSLR, 85mm lens.
"""
let generationConfig = ImagenGenerationConfig(aspectRatio: .landscape16x9)
let generationConfig = ImagenGenerationConfig(
imageFormat: .jpeg(compressionQuality: 70),
aspectRatio: .landscape16x9
)

let response = try await imagenModel.generateImages(
prompt: imagePrompt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ final class ImageGenerationParametersTests: XCTestCase {
let parameters = ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: nil,
modelConfig: nil,
safetySettings: nil
)

Expand All @@ -64,37 +63,6 @@ final class ImageGenerationParametersTests: XCTestCase {
let parameters = ImagenModel.imageGenerationParameters(
storageURI: storageURI,
generationConfig: nil,
modelConfig: nil,
safetySettings: nil
)

XCTAssertEqual(parameters, expectedParameters)
}

func testParameters_includeModelConfig() throws {
let compressionQuality = 80
let imageFormat = ImagenImageFormat.jpeg(compressionQuality: compressionQuality)
let addWatermark = true
let modelConfig = ImagenModelConfig(imageFormat: imageFormat, addWatermark: addWatermark)
let expectedParameters = ImageGenerationParameters(
sampleCount: 1,
storageURI: nil,
negativePrompt: nil,
aspectRatio: nil,
safetyFilterLevel: nil,
personGeneration: nil,
outputOptions: ImageGenerationOutputOptions(
mimeType: imageFormat.mimeType,
compressionQuality: imageFormat.compressionQuality
),
addWatermark: addWatermark,
includeResponsibleAIFilterReason: true
)

let parameters = ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: nil,
modelConfig: modelConfig,
safetySettings: nil
)

Expand All @@ -104,11 +72,16 @@ final class ImageGenerationParametersTests: XCTestCase {
func testParameters_includeGenerationConfig() throws {
let sampleCount = 2
let negativePrompt = "test-negative-prompt"
let compressionQuality = 80
let imageFormat = ImagenImageFormat.jpeg(compressionQuality: compressionQuality)
let aspectRatio = ImagenAspectRatio.landscape16x9
let addWatermark = true
let generationConfig = ImagenGenerationConfig(
numberOfImages: sampleCount,
negativePrompt: negativePrompt,
aspectRatio: aspectRatio
imageFormat: imageFormat,
aspectRatio: aspectRatio,
addWatermark: addWatermark
)
let expectedParameters = ImageGenerationParameters(
sampleCount: sampleCount,
Expand All @@ -117,15 +90,17 @@ final class ImageGenerationParametersTests: XCTestCase {
aspectRatio: aspectRatio.rawValue,
safetyFilterLevel: nil,
personGeneration: nil,
outputOptions: nil,
addWatermark: nil,
outputOptions: ImageGenerationOutputOptions(
mimeType: imageFormat.mimeType,
compressionQuality: imageFormat.compressionQuality
),
addWatermark: addWatermark,
includeResponsibleAIFilterReason: true
)

let parameters = ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: generationConfig,
modelConfig: nil,
safetySettings: nil
)

Expand Down Expand Up @@ -155,7 +130,6 @@ final class ImageGenerationParametersTests: XCTestCase {
let parameters = ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: nil,
modelConfig: nil,
safetySettings: safetySettings
)

Expand All @@ -168,15 +142,16 @@ final class ImageGenerationParametersTests: XCTestCase {
let storageURI = "gs://test-bucket/path"
let sampleCount = 4
let negativePrompt = "test-negative-prompt"
let imageFormat = ImagenImageFormat.png()
let aspectRatio = ImagenAspectRatio.portrait3x4
let addWatermark = false
let generationConfig = ImagenGenerationConfig(
numberOfImages: sampleCount,
negativePrompt: negativePrompt,
aspectRatio: aspectRatio
imageFormat: imageFormat,
aspectRatio: aspectRatio,
addWatermark: addWatermark
)
let imageFormat = ImagenImageFormat.png()
let addWatermark = false
let modelConfig = ImagenModelConfig(imageFormat: imageFormat, addWatermark: addWatermark)
let safetyFilterLevel = ImagenSafetyFilterLevel.blockNone
let personFilterLevel = ImagenPersonFilterLevel.blockAll
let safetySettings = ImagenSafetySettings(
Expand All @@ -201,7 +176,6 @@ final class ImageGenerationParametersTests: XCTestCase {
let parameters = ImagenModel.imageGenerationParameters(
storageURI: storageURI,
generationConfig: generationConfig,
modelConfig: modelConfig,
safetySettings: safetySettings
)

Expand Down
Loading