Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AssertionImageType to support reference image validation in AI Assertion #637

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,31 @@ package com.github.takahirom.roborazzi
data class AiAssertionOptions(
val aiAssertionModel: AiAssertionModel,
val aiAssertions: List<AiAssertion> = emptyList(),
val systemPrompt: String = """Evaluate the following assertion for fulfillment in the new image.
The evaluation should be based on the comparison between the original image on the left and the new image on the right, with differences highlighted in red in the center. Focus on whether the new image fulfills the requirement specified in the user input.
val assertionImageType: AssertionImageType = AssertionImageType.Comparison,
val systemPrompt: String = when (assertionImageType) {
AssertionImageType.Reference -> """Evaluate the new image's fulfillment of the user's requirements.
The assessment should be based solely on the provided reference image
and the user's input specifications. Focus on whether the new image
meets all functional and design requirements.

Output:
For each assertion:
A fulfillment percentage from 0 to 100.
A brief explanation of how this percentage was determined.""",
- A fulfillment percentage from 0 to 100
- A justification based on requirement adherence rather than visual differences
"""

AssertionImageType.Comparison -> """Evaluate the following assertion for fulfillment in the new image.
The evaluation should be based on the comparison between the original image
on the left and the new image on the right, with differences highlighted in red
in the center. Focus on whether the new image fulfills the requirement specified
in the user input.

Output:
For each assertion:
- A fulfillment percentage from 0 to 100
- A brief explanation of how this percentage was determined
"""
},
val promptTemplate: String = """Assertions:
INPUT_PROMPT
""",
Expand All @@ -33,12 +51,18 @@ INPUT_PROMPT
actualImageFilePath: String,
aiAssertionOptions: AiAssertionOptions
): AiAssertionResults

companion object {
const val DefaultMaxOutputTokens = 300
const val DefaultTemperature = 0.4F
}
}

sealed interface AssertionImageType {
data object Comparison : AssertionImageType
data object Reference : AssertionImageType
}

data class AiAssertion(
val assertionPrompt: String,
val failIfNotFulfilled: Boolean = true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@ import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel
import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultMaxOutputTokens
import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultTemperature
import dev.shreyaspatil.ai.client.generativeai.GenerativeModel
import dev.shreyaspatil.ai.client.generativeai.type.FunctionType
import dev.shreyaspatil.ai.client.generativeai.type.GenerationConfig
import dev.shreyaspatil.ai.client.generativeai.type.PlatformImage
import dev.shreyaspatil.ai.client.generativeai.type.Schema
import dev.shreyaspatil.ai.client.generativeai.type.content
import dev.shreyaspatil.ai.client.generativeai.type.generationConfig
import dev.shreyaspatil.ai.client.generativeai.type.*
import kotlinx.coroutines.runBlocking
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
Expand Down Expand Up @@ -68,8 +63,12 @@ class GeminiAiAssertionModel(
val template = aiAssertionOptions.promptTemplate

val inputPrompt = aiAssertionOptions.inputPrompt(aiAssertionOptions)
val imageFilePath = when (aiAssertionOptions.assertionImageType) {
AiAssertionOptions.AssertionImageType.Comparison -> comparisonImageFilePath
AiAssertionOptions.AssertionImageType.Reference -> referenceImageFilePath
}
val inputContent = content {
image(readByteArrayFromFile(comparisonImageFilePath))
image(readByteArrayFromFile(imageFilePath))
val prompt = template.replace("INPUT_PROMPT", inputPrompt)
text(prompt)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,15 @@ package com.github.takahirom.roborazzi
import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultMaxOutputTokens
import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultTemperature
import com.github.takahirom.roborazzi.CaptureResults.Companion.json
import io.ktor.client.HttpClient
import io.ktor.client.plugins.HttpTimeout
import io.ktor.client.*
import io.ktor.client.plugins.*
import io.ktor.client.plugins.HttpTimeout.Plugin.INFINITE_TIMEOUT_MS
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.plugins.logging.LogLevel
import io.ktor.client.plugins.logging.Logger
import io.ktor.client.plugins.logging.Logging
import io.ktor.client.plugins.logging.SIMPLE
import io.ktor.client.request.HttpRequestBuilder
import io.ktor.client.request.header
import io.ktor.client.request.post
import io.ktor.client.request.setBody
import io.ktor.client.statement.HttpResponse
import io.ktor.client.statement.bodyAsText
import io.ktor.http.ContentType
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.client.plugins.logging.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.serialization.kotlinx.json.*
import kotlinx.coroutines.runBlocking
import kotlinx.io.buffered
import kotlinx.io.files.Path
Expand Down Expand Up @@ -57,7 +49,7 @@ class OpenAiAiAssertionModel(
}
if (loggingEnabled) {
install(Logging) {
logger = object: Logger {
logger = object : Logger {
override fun log(message: String) {
Logger.SIMPLE.log(message.replace(apiKey, "****"))
}
Expand All @@ -77,7 +69,11 @@ class OpenAiAiAssertionModel(
val systemPrompt = aiAssertionOptions.systemPrompt
val template = aiAssertionOptions.promptTemplate
val inputPrompt = aiAssertionOptions.inputPrompt(aiAssertionOptions)
val imageBytes = readByteArrayFromFile(comparisonImageFilePath)
val imageFilePath = when (aiAssertionOptions.assertionImageType) {
AiAssertionOptions.AssertionImageType.Comparison -> comparisonImageFilePath
AiAssertionOptions.AssertionImageType.Reference -> referenceImageFilePath
}
val imageBytes = readByteArrayFromFile(imageFilePath)
val imageBase64 = imageBytes.encodeBase64()
val messages = listOf(
Message(
Expand Down
Loading