diff --git a/src/main/java/io/github/stefanbratanov/jvm/openai/EmbeddingsRequest.java b/src/main/java/io/github/stefanbratanov/jvm/openai/EmbeddingsRequest.java index ef982e7..29669ed 100644 --- a/src/main/java/io/github/stefanbratanov/jvm/openai/EmbeddingsRequest.java +++ b/src/main/java/io/github/stefanbratanov/jvm/openai/EmbeddingsRequest.java @@ -5,7 +5,11 @@ import java.util.Optional; public record EmbeddingsRequest( - List input, String model, Optional encodingFormat, Optional user) { + List input, + String model, + Optional encodingFormat, + Optional dimensions, + Optional user) { public static Builder newBuilder() { return new Builder(); @@ -16,6 +20,7 @@ public static class Builder { private List input; private String model; private Optional encodingFormat = Optional.empty(); + private Optional dimensions = Optional.empty(); private Optional user = Optional.empty(); /** @@ -58,6 +63,15 @@ public Builder encodingFormat(String encodingFormat) { return this; } + /** + * @param dimensions The number of dimensions the resulting output embeddings should have. Only + * supported in text-embedding-3 and later models. + */ + public Builder dimensions(int dimensions) { + this.dimensions = Optional.of(dimensions); + return this; + } + /** * @param user A unique identifier representing your end-user, which can help OpenAI to monitor * and detect abuse. @@ -74,7 +88,7 @@ public EmbeddingsRequest build() { if (model == null) { throw new IllegalStateException("model must be set"); } - return new EmbeddingsRequest(List.copyOf(input), model, encodingFormat, user); + return new EmbeddingsRequest(List.copyOf(input), model, encodingFormat, dimensions, user); } } } diff --git a/src/test/java/io/github/stefanbratanov/jvm/openai/TestDataUtil.java b/src/test/java/io/github/stefanbratanov/jvm/openai/TestDataUtil.java index d5ce23d..79e0804 100644 --- a/src/test/java/io/github/stefanbratanov/jvm/openai/TestDataUtil.java +++ b/src/test/java/io/github/stefanbratanov/jvm/openai/TestDataUtil.java @@ -83,8 +83,9 @@ public EmbeddingsRequest randomEmbeddingsRequest() { () -> builder.input(randomIntArray(randomInt(1, 5))), () -> builder.input(listOf(randomInt(1, 10), () -> randomIntArray(randomInt(1, 5))))); return builder - .model(randomModel()) + .model(oneOf("text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large")) .encodingFormat(oneOf("float", "base64")) + .dimensions(randomInt(1, 10)) .user(randomString(10)) .build(); }