Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add support for dimensions parameter to OpenAIEmbedding #2215

Merged
merged 5 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,24 @@ trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion {

}

/** Parameters that apply to OpenAI embedding requests, on top of the shared OpenAI parameters. */
trait HasOpenAIEmbeddingParams extends HasOpenAISharedParams with HasAPIVersion {

  /** Optional service parameter holding the number of dimensions for the output embeddings. */
  val dimensions: ServiceParam[Int] = new ServiceParam[Int](
    this,
    "dimensions",
    "Number of dimensions for output embeddings.",
    isRequired = false
  )

  def getDimensions: Int = getScalarParam(dimensions)

  def setDimensions(value: Int): this.type = setScalarParam(dimensions, value)

  /** Gathers the optional parameters that resolve to a value for row `r`.
    *
    * Each entry is keyed by `GenerationUtils.camelToSnake` applied to the parameter name;
    * parameters for which `getValueOpt` yields nothing are left out of the map.
    */
  private[ml] def getOptionalParams(r: Row): Map[String, Any] = {
    val candidates = Seq(dimensions)
    candidates.flatMap { p =>
      getValueOpt(r, p).map(resolved => GenerationUtils.camelToSnake(p.name) -> resolved)
    }.toMap
  }
}

trait HasOpenAITextParams extends HasOpenAISharedParams {

val maxTokens: ServiceParam[Int] = new ServiceParam[Int](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
import com.microsoft.azure.synapse.ml.io.http.JSONOutputParser
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.ServiceParam
Expand All @@ -21,7 +22,7 @@ import scala.language.existentials
object OpenAIEmbedding extends ComplexParamsReadable[OpenAIEmbedding]

class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAISharedParams with HasOpenAICognitiveServiceInput with SynapseMLLogging {
with HasOpenAIEmbeddingParams with HasOpenAICognitiveServiceInput with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

def this() = this(Identifiable.randomUID("OpenAIEmbedding"))
Expand Down Expand Up @@ -61,10 +62,16 @@ class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid)
s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/embeddings"
}

/** Serializes `optionalParams` plus an `"input"` entry holding `text` into a compact JSON
  * request body. An existing `"input"` key in `optionalParams` is replaced by `text`.
  */
private[this] def getStringEntity[A](text: A, optionalParams: Map[String, Any]): StringEntity = {
  val payloadJson = (optionalParams + ("input" -> text)).toJson.compactPrint
  new StringEntity(payloadJson, ContentType.APPLICATION_JSON)
}

/** Builds the HTTP request body for a row from its text value and any optional parameters.
  *
  * The text is mapped exactly once, through `getStringEntity`, so the `"input"` field carries
  * the text itself rather than an already-built entity.
  *
  * @throws IllegalArgumentException if no text value can be resolved for the row.
  */
override protected def prepareEntity: Row => Option[AbstractHttpEntity] = { r =>
  // Lazy: optional parameters are only resolved when there is text to send.
  lazy val optionalParams: Map[String, Any] = getOptionalParams(r)
  getValueOpt(r, text)
    .map(text => getStringEntity(text, optionalParams))
    // `orElse` takes its argument by name, so the throw only fires when the text is missing.
    .orElse(throw new IllegalArgumentException("Please set textCol."))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ import org.scalactic.Equality

/** Connection settings for the OpenAI test suites.
  *
  * API keys and service names are read from environment variables when set and otherwise
  * fall back to the bundled secrets / default service names. Each member is defined once.
  */
trait OpenAIAPIKey {
  lazy val openAIAPIKey: String = sys.env.getOrElse("OPENAI_API_KEY", Secrets.OpenAIApiKey)
  lazy val openAIServiceName: String = sys.env.getOrElse("OPENAI_SERVICE_NAME", "synapseml-openai")
  lazy val deploymentName: String = "gpt-35-turbo"
  lazy val modelName: String = "gpt-35-turbo"
  lazy val openAIAPIKeyGpt4: String = sys.env.getOrElse("OPENAI_API_KEY_2", Secrets.OpenAIApiKeyGpt4)
  lazy val openAIServiceNameGpt4: String = sys.env.getOrElse("OPENAI_SERVICE_NAME_2", "synapseml-openai-2")
  lazy val deploymentNameGpt4: String = "gpt-4"
  lazy val modelNameGpt4: String = "gpt-4"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,24 @@ class OpenAIEmbeddingsSuite extends TransformerFuzzing[OpenAIEmbedding] with Ope
})
}

// Embedding transformer that additionally sets the optional `dimensions` parameter (100);
// exercised by the "Extra Params Usage" test.
lazy val embeddingExtra: OpenAIEmbedding = new OpenAIEmbedding()
.setSubscriptionKey(openAIAPIKeyGpt4)
.setDeploymentName("text-embedding-3-small")
.setApiVersion("2024-03-01-preview")
.setDimensions(100)
.setUser("testUser")
.setCustomServiceName(openAIServiceNameGpt4)
.setTextCol("text")
.setOutputCol("out")

test("Extra Params Usage") {
  // Every embedding produced by `embeddingExtra` must have exactly 100 entries.
  val resultRows = embeddingExtra.transform(df).collect()
  resultRows.foreach { row =>
    val embeddingVector = row.getAs[Vector]("out")
    assert(embeddingVector.size == 100)
  }
}


override def testObjects(): Seq[TestObject[OpenAIEmbedding]] =
Seq(new TestObject(embedding, df))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ from synapse.ml.core.platform import find_secret
service_name = "synapseml-openai"
deployment_name = "gpt-35-turbo"
deployment_name_embeddings = "text-embedding-ada-002"
deployment_name_embeddings_3 = "text-embedding-3-small"

key = find_secret(
secret_name="openai-api-key", keyvault="mmlspark-build-keys"
Expand Down Expand Up @@ -132,6 +133,29 @@ embedding = (
display(embedding.transform(df))
```

### Generating Text Embeddings with Reduced Dimensions

The Text-Embedding-3 models developed by OpenAI are trained with Matryoshka Representation Learning,
a technique that allows the embedding dimension to be reduced in exchange for some loss in performance.

```python
from synapse.ml.services.openai import OpenAIEmbedding

# Embedding transformer that requests 256-dimensional output embeddings.
embedding = (
OpenAIEmbedding()
.setSubscriptionKey(key)
.setDeploymentName(deployment_name_embeddings_3)
.setCustomServiceName(service_name)
.setApiVersion("2024-03-01-preview")
# Number of dimensions for the output embeddings.
.setDimensions(256)
.setTextCol("prompt")
.setErrorCol("error")
.setOutputCol("embeddings")
)

display(embedding.transform(df))
```

### Chat Completion

Models such as ChatGPT and GPT-4 are capable of understanding chats instead of single prompts. The `OpenAIChatCompletion` transformer exposes this functionality at scale.
Expand Down