Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make OpenAiAiAssertionModel gemini API compatible #626

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class OpenAiAiAssertionModel(
private val loggingEnabled: Boolean = false,
private val temperature: Float = DefaultTemperature,
private val maxTokens: Int = DefaultMaxOutputTokens,
private val seed: Int = 1566,
private val seed: Int? = 1566,
private val requestBuilderModifier: (HttpRequestBuilder.() -> Unit) = {
header("Authorization", "Bearer $apiKey")
},
Expand Down Expand Up @@ -231,7 +231,7 @@ private data class ChatCompletionRequest(
val temperature: Float,
@SerialName("max_tokens") val maxTokens: Int,
@SerialName("response_format") val responseFormat: ResponseFormat?,
val seed: Int,
val seed: Int?,
)

@Serializable
Expand Down Expand Up @@ -260,7 +260,8 @@ private data class ImageUrl(

@Serializable
private data class ChatCompletionResponse(
val id: String,
// null on gemini
val id: String? = null,
val `object`: String,
val created: Long,
val model: String,
Expand All @@ -283,9 +284,12 @@ private data class ChoiceMessage(

@Serializable
private data class Usage(
@SerialName("prompt_tokens") val promptTokens: Int,
// null on gemini
@SerialName("prompt_tokens") val promptTokens: Int? = null,
// null on gemini
@SerialName("completion_tokens") val completionTokens: Int? = null,
@SerialName("total_tokens") val totalTokens: Int,
// null on gemini
@SerialName("total_tokens") val totalTokens: Int? = null,
)


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package com.github.takahirom.roborazzi.sample

import androidx.compose.ui.test.junit4.createAndroidComposeRule
import androidx.test.espresso.Espresso.onView
import androidx.test.espresso.matcher.ViewMatchers
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.github.takahirom.roborazzi.AiAssertionOptions
import com.github.takahirom.roborazzi.DEFAULT_ROBORAZZI_OUTPUT_DIR_PATH
import com.github.takahirom.roborazzi.ExperimentalRoborazziApi
import com.github.takahirom.roborazzi.OpenAiAiAssertionModel
import com.github.takahirom.roborazzi.ROBORAZZI_DEBUG
import com.github.takahirom.roborazzi.RobolectricDeviceQualifiers
import com.github.takahirom.roborazzi.RoborazziOptions
import com.github.takahirom.roborazzi.RoborazziRule
import com.github.takahirom.roborazzi.RoborazziTaskType
import com.github.takahirom.roborazzi.captureRoboImage
import com.github.takahirom.roborazzi.provideRoborazziContext
import com.github.takahirom.roborazzi.roboOutputName
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.annotation.Config
import org.robolectric.annotation.GraphicsMode
import java.io.File

@OptIn(ExperimentalRoborazziApi::class)
@RunWith(AndroidJUnit4::class)
@GraphicsMode(GraphicsMode.Mode.NATIVE)
@Config(
sdk = [30],
qualifiers = RobolectricDeviceQualifiers.NexusOne
)
class GeminiWithOpenAiApiInterfaceTest {
@get:Rule
val composeTestRule = createAndroidComposeRule<MainActivity>()

@get:Rule
val roborazziRule = RoborazziRule(
options = RoborazziRule.Options(
roborazziOptions = RoborazziOptions(
taskType = RoborazziTaskType.Compare,
compareOptions = RoborazziOptions.CompareOptions(
aiAssertionOptions = AiAssertionOptions(
aiAssertionModel = OpenAiAiAssertionModel(
baseUrl = "https://generativelanguage.googleapis.com/v1beta/openai/",
apiKey = System.getenv("gemini_api_key").orEmpty(),
modelName = "gemini-1.5-flash",
seed = null
),
)
)
)
)
)

@Test
fun captureWithAi() {
ROBORAZZI_DEBUG = true
if (System.getenv("gemini_api_key") == null) {
println("Skip the test because openai_api_key is not set.")
return
}
File(DEFAULT_ROBORAZZI_OUTPUT_DIR_PATH + File.separator + roboOutputName() + ".png").delete()
onView(ViewMatchers.isRoot())
.captureRoboImage(
roborazziOptions = provideRoborazziContext().options.addedAiAssertions(
AiAssertionOptions.AiAssertion(
assertionPrompt = "it should have PREVIOUS button",
requiredFulfillmentPercent = 90,
),
AiAssertionOptions.AiAssertion(
assertionPrompt = "it should show First Fragment",
requiredFulfillmentPercent = 90,
)
)
)
File(DEFAULT_ROBORAZZI_OUTPUT_DIR_PATH + File.separator + roboOutputName() + "_compare.png").delete()
}
}
Loading