Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow specifying default values in client configuration #12

Merged
merged 1 commit into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 58 additions & 45 deletions src/commonMain/kotlin/io/github/vyfor/groqkt/GroqClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import io.github.vyfor.groqkt.api.chat.ChatCompletionRequest
import io.github.vyfor.groqkt.api.chat.StreamingChatCompletion
import io.github.vyfor.groqkt.api.model.Model
import io.github.vyfor.groqkt.api.model.Models
import io.github.vyfor.groqkt.util.applyDefaults
import io.github.vyfor.groqkt.util.parse
import io.github.vyfor.groqkt.util.parseHeaders
import io.github.vyfor.groqkt.util.validate
Expand Down Expand Up @@ -67,6 +68,7 @@ class GroqClient(
contentType(ContentType.Application.Json)
setBody(
ChatCompletionRequest.Builder()
.applyDefaults(config.defaults?.chatCompletion)
.apply {
block()
stream = false
Expand Down Expand Up @@ -119,6 +121,7 @@ class GroqClient(
contentType(ContentType.Application.Json)
setBody(
ChatCompletionRequest.Builder()
.applyDefaults(config.defaults?.chatCompletion)
.apply {
block()
stream = true
Expand Down Expand Up @@ -165,7 +168,7 @@ class GroqClient(
"filename=\"${data.filename.encodeURLPathPart()}\"")
},
)
append("model", data.model.id)
append("model", data.model!!.id)
data.prompt?.let { append("prompt", it) }
data.responseFormat?.let { append("response_format", it.name) }
data.temperature?.let { append("temperature", it.toString()) }
Expand All @@ -186,21 +189,25 @@ class GroqClient(
.submitFormWithBinaryData(
AudioTranslationRequest.ENDPOINT,
formData {
AudioTranslationRequest.Builder().apply(block).build().let { data ->
append(
"file",
data.file,
Headers.build {
append(
HttpHeaders.ContentDisposition,
"filename=\"${data.filename.encodeURLPathPart()}\"")
},
)
append("model", data.model.id)
data.prompt?.let { append("prompt", it) }
data.responseFormat?.let { append("response_format", it.name) }
data.temperature?.let { append("temperature", it.toString()) }
}
AudioTranslationRequest.Builder()
.applyDefaults(config.defaults?.audioTranslation)
.apply(block)
.build()
.let { data ->
append(
"file",
data.file,
Headers.build {
append(
HttpHeaders.ContentDisposition,
"filename=\"${data.filename.encodeURLPathPart()}\"")
},
)
append("model", data.model!!.id)
data.prompt?.let { append("prompt", it) }
data.responseFormat?.let { append("response_format", it.name) }
data.temperature?.let { append("temperature", it.toString()) }
}
},
) {
contentType(ContentType.MultiPart.FormData)
Expand Down Expand Up @@ -229,7 +236,7 @@ class GroqClient(
},
)
}
append("model", data.model.id)
append("model", data.model!!.id)
data.language?.let { append("language", it) }
data.prompt?.let { append("prompt", it) }
data.responseFormat?.let { append("response_format", it.name) }
Expand Down Expand Up @@ -260,34 +267,40 @@ class GroqClient(
.submitFormWithBinaryData(
AudioTranscriptionRequest.ENDPOINT,
formData {
AudioTranscriptionRequest.Builder().apply(block).build().let { data ->
data.file?.let {
append(
"file",
it,
Headers.build {
append(
HttpHeaders.ContentDisposition,
"filename=\"${data.filename.encodeURLPathPart()}\"")
},
)
}
append("model", data.model.id)
data.url?.let { append("url", it) }
data.language?.let { append("language", it) }
data.prompt?.let { append("prompt", it) }
data.responseFormat?.let { append("response_format", it.name) }
data.temperature?.let { append("temperature", it.toString()) }
data.timestampGranularities?.let {
append(
"timestamp_granularities",
config.json.encodeToString(
buildJsonArray { it.forEach { enum -> add(JsonPrimitive(enum.value)) } }
.toString(),
),
)
}
}
AudioTranscriptionRequest.Builder()
.applyDefaults(config.defaults?.audioTranscription)
.apply(block)
.build()
.let { data ->
data.file?.let {
append(
"file",
it,
Headers.build {
append(
HttpHeaders.ContentDisposition,
"filename=\"${data.filename.encodeURLPathPart()}\"")
},
)
}
append("model", data.model!!.id)
data.url?.let { append("url", it) }
data.language?.let { append("language", it) }
data.prompt?.let { append("prompt", it) }
data.responseFormat?.let { append("response_format", it.name) }
data.temperature?.let { append("temperature", it.toString()) }
data.timestampGranularities?.let {
append(
"timestamp_granularities",
config.json.encodeToString(
buildJsonArray {
it.forEach { enum -> add(JsonPrimitive(enum.value)) }
}
.toString(),
),
)
}
}
},
) {
contentType(ContentType.MultiPart.FormData)
Expand Down
64 changes: 64 additions & 0 deletions src/commonMain/kotlin/io/github/vyfor/groqkt/GroqConfig.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
package io.github.vyfor.groqkt

import io.github.vyfor.groqkt.GroqClient.Companion.BASE_URL
import io.github.vyfor.groqkt.api.audio.transcription.AudioTranscriptionRequest
import io.github.vyfor.groqkt.api.audio.translation.AudioTranslationRequest
import io.github.vyfor.groqkt.api.chat.ChatCompletionRequest
import io.ktor.client.*
import io.ktor.client.plugins.*
import io.ktor.client.plugins.contentnegotiation.*
Expand All @@ -19,10 +22,27 @@ import kotlinx.serialization.json.JsonNamingStrategy
*
* @property json The JSON serializer.
* @property client The HTTP client.
* @property defaults The default values for the [GroqClient]. These values are applied to every
* request made using a DSL function.
*/
data class GroqConfig(
val json: Json,
val client: HttpClient,
val defaults: GroqDefaults?,
)

/**
* Default values for use with the [GroqClient]. These values are applied to every request made
* using a DSL function.
*
* @property chatCompletion The default values for [ChatCompletionRequest].
* @property audioTranslation The default values for [AudioTranslationRequest].
* @property audioTranscription The default values for [AudioTranscriptionRequest].
*/
data class GroqDefaults(
val chatCompletion: (ChatCompletionRequest.Builder.() -> Unit)? = null,
val audioTranslation: (AudioTranslationRequest.Builder.() -> Unit)? = null,
val audioTranscription: (AudioTranscriptionRequest.Builder.() -> Unit)? = null,
)

/**
Expand All @@ -42,6 +62,7 @@ class GroqConfigBuilder(
namingStrategy = JsonNamingStrategy.SnakeCase
classDiscriminatorMode = ClassDiscriminatorMode.NONE
}
private var defaults: GroqDefaults? = null
var client: HttpClient = HttpClient {
install(ContentNegotiation) { json(json) }

Expand Down Expand Up @@ -74,9 +95,52 @@ class GroqConfigBuilder(
}
}

/**
* Sets the default values for the [GroqClient]. These values are applied to every request made
* using a DSL function.
*
* @param block The default values for the [GroqClient].
*/
fun defaults(block: GroqDefaultsBuilder.() -> Unit) {
defaults = GroqDefaultsBuilder().apply(block).build()
}

internal fun build(): GroqConfig =
GroqConfig(
json,
client,
defaults,
)
}

/**
* Groq defaults builder class.
*
* @property chatCompletion The default values for [ChatCompletionRequest].
* @property audioTranslation The default values for [AudioTranslationRequest].
* @property audioTranscription The default values for [AudioTranscriptionRequest].
*/
class GroqDefaultsBuilder {
private var chatCompletion: (ChatCompletionRequest.Builder.() -> Unit)? = null
private var audioTranslation: (AudioTranslationRequest.Builder.() -> Unit)? = null
private var audioTranscription: (AudioTranscriptionRequest.Builder.() -> Unit)? = null

fun chatCompletion(block: ChatCompletionRequest.Builder.() -> Unit) {
chatCompletion = block
}

fun audioTranslation(block: AudioTranslationRequest.Builder.() -> Unit) {
audioTranslation = block
}

fun audioTranscription(block: AudioTranscriptionRequest.Builder.() -> Unit) {
audioTranscription = block
}

internal fun build(): GroqDefaults =
GroqDefaults(
chatCompletion,
audioTranslation,
audioTranscription,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ data class AudioTranscriptionRequest(
val file: ByteArray? = null,
val url: String? = null,
val language: String? = null,
val model: GroqModel,
val model: GroqModel?,
val prompt: String? = null,
val responseFormat: AudioResponseFormat? = null,
val temperature: Double? = null,
Expand All @@ -49,6 +49,7 @@ data class AudioTranscriptionRequest(
var filename: String = "audio.mp3"

init {
require(model != null) { "model must be set" }
require(file != null || url != null) { "either file or url must be set" }
}

Expand Down Expand Up @@ -130,7 +131,7 @@ data class AudioTranscriptionRequest(
file,
url,
language,
requireNotNull(model) { "model must be set" },
model,
prompt,
responseFormat,
temperature,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,17 @@ import kotlinx.serialization.Serializable
@Serializable
data class AudioTranslationRequest(
val file: ByteArray,
val model: GroqModel,
val model: GroqModel?,
val prompt: String? = null,
val responseFormat: AudioResponseFormat? = null,
val temperature: Double? = null,
) {
var filename: String = "audio.mp3"

init {
require(model != null) { "model must be set" }
}

companion object {
const val ENDPOINT = "audio/translations"
}
Expand Down Expand Up @@ -103,7 +107,7 @@ data class AudioTranslationRequest(
fun build(): AudioTranslationRequest {
return AudioTranslationRequest(
requireNotNull(file) { "file must be set" },
requireNotNull(model) { "model must be set" },
model,
prompt,
responseFormat,
temperature,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ data class ChatCompletionRequest(
/* val logprobs: Boolean? = null, */
val maxTokens: Int? = null,
val messages: List<CompletionMessage>,
val model: GroqModel,
val model: GroqModel?,
val n: Int? = null,
val parallelToolCalls: Boolean? = null,
var presencePenalty: Double? = null,
Expand All @@ -96,9 +96,10 @@ data class ChatCompletionRequest(
}

init {
require(n == null || n == 1) { "Currently only n = 1 is supported." }
require(model != null) { "model must be set" }
require(n == null || n == 1) { "currently only n = 1 is supported." }
require(streamOptions == null || stream == true) { "streamOptions must have stream = true." }
require(tools == null || tools.size <= 128) { "Currently only up to 128 tools are supported." }
require(tools == null || tools.size <= 128) { "currently only up to 128 tools are supported." }
require(messages.isNotEmpty()) { "messages must not be empty." }
presencePenalty = presencePenalty?.coerceIn(-2.0, 2.0)
temperature = temperature?.coerceIn(-2.0, 2.0)
Expand Down Expand Up @@ -210,7 +211,7 @@ data class ChatCompletionRequest(
functions,
maxTokens,
requireNotNull(messages) { "messages must be set" },
requireNotNull(model) { "model must be set" },
model,
n,
parallelToolCalls,
presencePenalty,
Expand Down Expand Up @@ -243,18 +244,13 @@ class ChatCompletionMessageBuilder {
fun image(image: String) {
messages.add(
CompletionMessage.User(
UserMessageType.Array(
imageContent =
Image(ImageObject(url = image)))))
UserMessageType.Array(imageContent = Image(ImageObject(url = image)))))
}

fun user(content: String?, image: String?, name: String? = null) {
messages.add(
CompletionMessage.User(
UserMessageType.Array(
Text(content),
Image(ImageObject(url = image))),
name))
UserMessageType.Array(Text(content), Image(ImageObject(url = image))), name))
}

fun assistant(
Expand Down Expand Up @@ -544,17 +540,10 @@ sealed class CompletionMessage(val role: String) {
fun text(content: String) = User(UserMessageType.Text(content))

fun image(image: String) =
User(
UserMessageType.Array(
imageContent =
Image(ImageObject(url = image))))
User(UserMessageType.Array(imageContent = Image(ImageObject(url = image))))

fun user(content: String?, image: String?, name: String? = null) =
User(
UserMessageType.Array(
Text(content),
Image(ImageObject(url = image))),
name)
User(UserMessageType.Array(Text(content), Image(ImageObject(url = image))), name)

fun assistant(
content: String,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@file:Suppress("unused")
@file:Suppress("unused", "NAME_SHADOWING")

package io.github.vyfor.groqkt.util

Expand All @@ -14,6 +14,9 @@ import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.minutes
import kotlin.time.Duration.Companion.seconds

internal inline fun <reified T> T.applyDefaults(noinline defaults: (T.() -> Unit)?): T =
defaults?.let { defaults -> apply(defaults) } ?: this

internal suspend inline fun <reified T> HttpResponse.validate(): Result<T> =
if (status.isSuccess()) {
Result.success(body<T>())
Expand Down
Loading