Skip to content

Commit

Permalink
basic done
Browse files Browse the repository at this point in the history
  • Loading branch information
bartekpacia committed Aug 13, 2024
1 parent f456c97 commit 6ba3d0a
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 24 deletions.
19 changes: 6 additions & 13 deletions maestro-orchestra/src/main/java/maestro/orchestra/Orchestra.kt
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ import okio.sink
import java.io.File
import java.lang.Long.max
import java.nio.file.Files
import java.time.LocalDateTime

class Orchestra(
private val maestro: Maestro,
Expand Down Expand Up @@ -363,29 +362,23 @@ class Orchestra(
maestro.takeScreenshot(imageData, compressed = false)
val imageDataBytes = imageData.readByteArray()

File("${LocalDateTime.now()}.png").apply {
createNewFile()
writeBytes(imageDataBytes)
}

val response = Prediction.findDefects(
val defects = Prediction.findDefects(
aiClient = ai,
assertion = null,
screen = imageDataBytes,
previousFalsePositives = listOf(), // TODO: take it from WorkspaceConfig (or MaestroConfig?)
)

// TODO: request response in a specific JSON from AI
// https://platform.openai.com/docs/guides/structured-outputs/introduction
if (!response.response.contains("No defects found")) {
// TODO: Check optional flag (see assertConditionCommand)
if (defects.isNotEmpty()) {
if (command.optional) throw CommandSkipped

throw MaestroException.AssertionFailure(
"Visual AI found defects: ${response.response}",
"Visual AI found possible defects: ${defects.joinToString { "${it.category}: ${it.reasoning}" }}",
maestro.viewHierarchy().root,
)
}

// TODO: Add resul to some "post-flow analysis store" (so results can be viewed in Maestro Studio)
// TODO: Add result to some "post-flow analysis store" (so results can be viewed in Maestro Studio)

false
}
Expand Down
15 changes: 15 additions & 0 deletions maestro-orchestra/src/main/java/maestro/orchestra/ai/AI.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package maestro.orchestra.ai

import kotlinx.serialization.json.JsonObject
import java.io.Closeable

data class CompletionData(
Expand All @@ -21,6 +22,20 @@ abstract class AI : Closeable {
maxTokens: Int? = null,
imageDetail: String? = null,
identifier: String? = null,
jsonSchema: JsonObject? = null,
): CompletionData

companion object {
    // The response shape expected from the LLM is described with a JSON schema
    // (a.k.a. JSON mode / Structured Outputs). Provider documentation:
    // * OpenAI: https://platform.openai.com/docs/guides/structured-outputs
    // * Gemini: https://ai.google.dev/gemini-api/docs/json-mode

    /**
     * Text content of the `/assertVisualAI_schema.json` classpath resource.
     *
     * Read once, when this companion object is initialized.
     *
     * @throws IllegalStateException if the resource cannot be found.
     */
    val assertVisualSchema: String = run {
        val schemaStream = checkNotNull(this::class.java.getResourceAsStream("/assertVisualAI_schema.json")) {
            "Could not find assertVisualAI_schema.json in resources"
        }

        // `use` closes the reader (and the underlying stream) once the text is read.
        schemaStream.bufferedReader().use { reader -> reader.readText() }
    }
}

}
39 changes: 31 additions & 8 deletions maestro-orchestra/src/main/java/maestro/orchestra/ai/Prediction.kt
Original file line number Diff line number Diff line change
@@ -1,13 +1,34 @@
package maestro.orchestra.ai

import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.jsonObject

/**
 * A single possible UI defect reported by the AI model for a screenshot.
 *
 * Instances are decoded from the model's JSON response.
 *
 * @property category Name of the category the defect belongs to.
 * @property reasoning The model's explanation of why it reported the defect.
 */
@Serializable
data class Defect(
val category: String,
val reasoning: String,
)

/**
 * Top-level JSON object expected in the AI model's reply to a find-defects request.
 *
 * @property defects The defects reported by the model; empty when it reported none.
 */
@Serializable
private data class FindDefectsResponse(
val defects: List<Defect>,
)

object Prediction {
private val json = Json { ignoreUnknownKeys = true }

private val categories = listOf(
"localization" to "Inconsistent use of language, for example mixed English and Portuguese",
"layout" to "Some UI elements are overlapping or are cropped",
)

suspend fun findDefects(aiClient: AI, screen: ByteArray, previousFalsePositives: List<String>): CompletionData {
suspend fun findDefects(
aiClient: AI,
screen: ByteArray,
assertion: String?,
previousFalsePositives: List<String>,
): List<Defect> {

// Prompt instructions that were tried but failed to stop the model from making up false positives:
// |* If you don't see any defect, return "No defects found".
Expand All @@ -22,28 +43,30 @@ object Prediction {
|* All defects you find must belong to one of the following categories:
|${categories.joinToString { "\n * ${it.first}: ${it.second}" }}
|
|* If you see defects, your response MUST only include defect name and reasoning for each defect.
|* If you see defects, your response MUST only include defect name and detailed reasoning for each defect.
|* Provide response in the format: <defect name>:<reasoning>
|* Do not raise false positives. Some example responses that have a high chance of being a false positive:
|
| * button is partially cropped at the bottom
| * button is not aligned horizontall/vertically within its container
| * button is not aligned horizontally/vertically within its container
|
|${if (previousFalsePositives.isNotEmpty()) "Additionally, the following defects are false positives:" else ""}
|${if (previousFalsePositives.isNotEmpty()) previousFalsePositives.joinToString("\n") { " * $it" } else ""}
""".trimMargin("|")

// println("Prompt:")
// println(prompt)
// println("Prompt:\n$prompt")

return aiClient.chatCompletion(
val aiResponse = aiClient.chatCompletion(
prompt,
// model = "gpt-4o-2024-08-03",
model = "gpt-4o",
model = "gpt-4o-2024-08-06",
maxTokens = 4096,
identifier = "find-defects",
imageDetail = "high",
images = listOf(screen),
jsonSchema = json.parseToJsonElement(AI.assertVisualSchema).jsonObject,
)

val defects = json.decodeFromString<FindDefectsResponse>(aiResponse.response)
return defects.defects
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,23 @@ import io.ktor.client.request.setBody
import io.ktor.client.statement.bodyAsText
import io.ktor.http.ContentType
import io.ktor.http.contentType
import io.ktor.http.isSuccess
import io.ktor.util.encodeBase64
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import maestro.orchestra.ai.AI
import maestro.orchestra.ai.CompletionData
import org.slf4j.LoggerFactory
import java.io.File

private const val API_URL = "https://api.openai.com/v1/chat/completions"

private val logger = LoggerFactory.getLogger(OpenAI::class.java)

class OpenAI(
private val apiKey: String,
private val defaultModel: String = "gpt-4o",
private val defaultModel: String = "gpt-4o-2024-08-06",
private val defaultTemperature: Double = 0.2,
private val defaultMaxTokens: Int = 2048,
private val defaultImageDetail: String = "low",
Expand Down Expand Up @@ -51,6 +54,7 @@ class OpenAI(
maxTokens: Int?,
imageDetail: String?,
identifier: String?,
jsonSchema: JsonObject?,
): CompletionData {
val imagesBase64 = images.map { it.encodeBase64() }

Expand Down Expand Up @@ -81,7 +85,10 @@ class OpenAI(
messages = messages,
maxTokens = actualMaxTokens,
seed = 1566,
responseFormat = null,
responseFormat = if (jsonSchema == null) null else ResponseFormat(
type = "json_schema",
jsonSchema = jsonSchema,
),
)

val chatCompletionResponse = try {
Expand All @@ -91,6 +98,11 @@ class OpenAI(
setBody(json.encodeToString(chatCompletionRequest))
}

if (!httpResponse.status.isSuccess()) {
logger.error("Failed to complete request to OpenAI: ${httpResponse.status}, ${httpResponse.bodyAsText()}")
throw Exception("Failed to complete request to OpenAI: ${httpResponse.status}, ${httpResponse.bodyAsText()}")
}

json.decodeFromString<ChatCompletionResponse>(httpResponse.bodyAsText())
} catch (e: Exception) {
logger.error("Failed to complete request to OpenAI", e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package maestro.orchestra.ai.openai

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonObject

@Serializable
data class ChatCompletionRequest(
Expand All @@ -14,8 +15,9 @@ data class ChatCompletionRequest(
)

@Serializable
data class ResponseFormat(
class ResponseFormat(
val type: String,
@SerialName("json_schema") val jsonSchema: JsonObject,
)

@Serializable
Expand Down
32 changes: 32 additions & 0 deletions maestro-orchestra/src/main/resources/assertVisualAI_schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
{
"name": "assertVisualAI",
"description": "List of possible defects found in the mobile app's UI",
"strict": true,
"schema": {
"type": "object",
"required": ["defects"],
"additionalProperties": false,
"properties": {
"defects": {
"type": "array",
"items": {
"type": "object",
"required": ["category", "reasoning"],
"additionalProperties": false,
"properties": {
"category": {
"type": "string",
"enum": [
"layout",
"localization"
]
},
"reasoning": {
"type": "string"
}
}
}
}
}
}
}

0 comments on commit 6ba3d0a

Please sign in to comment.