diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000000..f2ea93d45f7 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,17 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +indent_size = 4 +indent_style = space +insert_final_newline = true +max_line_length = 100 +tab_width = 4 + +[*.java] +ij_java_names_count_to_use_import_on_demand = 999 +ij_java_class_count_to_use_import_on_demand = 999 + +[{*.yaml,*.yml}] +indent_size = 2 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 34f5add8ba8..b4e6467d1cc 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -5,8 +5,9 @@ Thank you for investing your time and effort in contributing to our project, we - If you want to contribute a bug fix or a new feature that isn't listed in the [issues](https://github.com/langchain4j/langchain4j/issues) yet, please open a new issue for it. We will prioritize is shortly. - Follow [Google's Best Practices for Java Libraries](https://jlbp.dev/) - Keep the code compatible with Java 17. -- Avoid adding new dependencies as much as possible (new dependencies with test scope are OK). If absolutely necessary, try to use the same libraries which are already used in the project. +- Avoid adding new dependencies as much as possible (new dependencies with test scope are OK). If absolutely necessary, try to use the same libraries which are already used in the project. Make sure you run `mvn dependency:analyze` to identify unnecessary dependencies. - Write unit and/or integration tests for your code. This is critical: no tests, no review! +- Make sure you run all unit tests on all modules with `mvn clean test` - Avoid making breaking changes. Always keep backward compatibility in mind. For example, instead of removing fields/methods/etc, mark them `@Deprecated` and make sure they still work as before. - Follow existing naming conventions. - Avoid using Lombok in the new code, and remove it from the old code if you get a chance. diff --git a/langchain4j-core/pom.xml b/langchain4j-core/pom.xml index 5fd70502b80..9f811704ed5 100644 --- a/langchain4j-core/pom.xml +++ b/langchain4j-core/pom.xml @@ -34,12 +34,6 @@ slf4j-api - - org.projectlombok - lombok - provided - - org.junit.jupiter junit-jupiter-engine diff --git a/langchain4j-core/src/main/java/dev/langchain4j/Experimental.java b/langchain4j-core/src/main/java/dev/langchain4j/Experimental.java index accdff9d188..35fbb892d47 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/Experimental.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/Experimental.java @@ -2,7 +2,9 @@ import java.lang.annotation.Target; -import static java.lang.annotation.ElementType.*; +import static java.lang.annotation.ElementType.CONSTRUCTOR; +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.TYPE; /** * Indicates that a class/constructor/method is experimental and might change in the future. diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/document/Metadata.java b/langchain4j-core/src/main/java/dev/langchain4j/data/document/Metadata.java index 8763f27f025..6c94b15644e 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/data/document/Metadata.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/data/document/Metadata.java @@ -3,7 +3,12 @@ import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.EmbeddingStore; -import java.util.*; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.UUID; import static dev.langchain4j.internal.Exceptions.illegalArgument; import static dev.langchain4j.internal.Exceptions.runtime; @@ -67,8 +72,8 @@ public Metadata(Map metadata) { validate(key, value); if (!SUPPORTED_VALUE_TYPES.contains(value.getClass())) { throw illegalArgument("The metadata key '%s' has the value '%s', which is of the unsupported type '%s'. " + - "Currently, the supported types are: %s", - key, value, value.getClass().getName(), SUPPORTED_VALUE_TYPES + "Currently, the supported types are: %s", + key, value, value.getClass().getName(), SUPPORTED_VALUE_TYPES ); } }); @@ -116,7 +121,7 @@ public String getString(String key) { } throw runtime("Metadata entry with the key '%s' has a value of '%s' and type '%s'. " + - "It cannot be returned as a String.", key, value, value.getClass().getName()); + "It cannot be returned as a String.", key, value, value.getClass().getName()); } /** @@ -140,7 +145,7 @@ public UUID getUUID(String key) { } throw runtime("Metadata entry with the key '%s' has a value of '%s' and type '%s'. " + - "It cannot be returned as a UUID.", key, value, value.getClass().getName()); + "It cannot be returned as a UUID.", key, value, value.getClass().getName()); } /** @@ -172,7 +177,7 @@ public Integer getInteger(String key) { } throw runtime("Metadata entry with the key '%s' has a value of '%s' and type '%s'. " + - "It cannot be returned as an Integer.", key, value, value.getClass().getName()); + "It cannot be returned as an Integer.", key, value, value.getClass().getName()); } /** @@ -204,7 +209,7 @@ public Long getLong(String key) { } throw runtime("Metadata entry with the key '%s' has a value of '%s' and type '%s'. " + - "It cannot be returned as a Long.", key, value, value.getClass().getName()); + "It cannot be returned as a Long.", key, value, value.getClass().getName()); } /** @@ -236,7 +241,7 @@ public Float getFloat(String key) { } throw runtime("Metadata entry with the key '%s' has a value of '%s' and type '%s'. " + - "It cannot be returned as a Float.", key, value, value.getClass().getName()); + "It cannot be returned as a Float.", key, value, value.getClass().getName()); } /** @@ -268,7 +273,7 @@ public Double getDouble(String key) { } throw runtime("Metadata entry with the key '%s' has a value of '%s' and type '%s'. " + - "It cannot be returned as a Double.", key, value, value.getClass().getName()); + "It cannot be returned as a Double.", key, value, value.getClass().getName()); } /** @@ -449,8 +454,8 @@ public int hashCode() { @Override public String toString() { return "Metadata {" + - " metadata = " + metadata + - " }"; + " metadata = " + metadata + + " }"; } /** diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/message/AiMessage.java b/langchain4j-core/src/main/java/dev/langchain4j/data/message/AiMessage.java index 65d392c5998..6ff573f88b4 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/data/message/AiMessage.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/data/message/AiMessage.java @@ -8,7 +8,9 @@ import static dev.langchain4j.data.message.ChatMessageType.AI; import static dev.langchain4j.internal.Utils.isNullOrEmpty; import static dev.langchain4j.internal.Utils.quoted; -import static dev.langchain4j.internal.ValidationUtils.*; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static java.util.Arrays.asList; /** @@ -90,7 +92,7 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; AiMessage that = (AiMessage) o; return Objects.equals(this.text, that.text) - && Objects.equals(this.toolExecutionRequests, that.toolExecutionRequests); + && Objects.equals(this.toolExecutionRequests, that.toolExecutionRequests); } @Override @@ -101,9 +103,9 @@ public int hashCode() { @Override public String toString() { return "AiMessage {" + - " text = " + quoted(text) + - " toolExecutionRequests = " + toolExecutionRequests + - " }"; + " text = " + quoted(text) + + " toolExecutionRequests = " + toolExecutionRequests + + " }"; } /** @@ -139,7 +141,7 @@ public static AiMessage from(List toolExecutionRequests) { /** * Create a new {@link AiMessage} with the given text and tool execution requests. * - * @param text the text of the message. + * @param text the text of the message. * @param toolExecutionRequests the tool execution requests of the message. * @return the new {@link AiMessage}. */ @@ -180,7 +182,7 @@ public static AiMessage aiMessage(List toolExecutionReques /** * Create a new {@link AiMessage} with the given text and tool execution requests. * - * @param text the text of the message. + * @param text the text of the message. * @param toolExecutionRequests the tool execution requests of the message. * @return the new {@link AiMessage}. */ diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/message/ChatMessageDeserializer.java b/langchain4j-core/src/main/java/dev/langchain4j/data/message/ChatMessageDeserializer.java index 7853fb9e381..9d488fd567b 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/data/message/ChatMessageDeserializer.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/data/message/ChatMessageDeserializer.java @@ -1,8 +1,9 @@ package dev.langchain4j.data.message; -import static dev.langchain4j.data.message.ChatMessageSerializer.CODEC; import java.util.List; +import static dev.langchain4j.data.message.ChatMessageSerializer.CODEC; + /** * A deserializer for {@link ChatMessage} objects. */ diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonChatMessageAdapter.java b/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonChatMessageAdapter.java index 477a9308ffb..1fab59b105f 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonChatMessageAdapter.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonChatMessageAdapter.java @@ -1,19 +1,26 @@ package dev.langchain4j.data.message; -import com.google.gson.*; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.JsonDeserializationContext; +import com.google.gson.JsonDeserializer; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonSerializationContext; +import com.google.gson.JsonSerializer; import java.lang.reflect.Type; class GsonChatMessageAdapter implements JsonDeserializer, JsonSerializer { private static final Gson GSON = new GsonBuilder() - .registerTypeAdapter(Content.class, new GsonContentAdapter()) - .registerTypeAdapter(TextContent.class, new GsonContentAdapter()) - .registerTypeAdapter(ImageContent.class, new GsonContentAdapter()) - .registerTypeAdapter(AudioContent.class, new GsonContentAdapter()) - .registerTypeAdapter(VideoContent.class, new GsonContentAdapter()) - .registerTypeAdapter(PdfFileContent.class, new GsonContentAdapter()) - .create(); + .registerTypeAdapter(Content.class, new GsonContentAdapter()) + .registerTypeAdapter(TextContent.class, new GsonContentAdapter()) + .registerTypeAdapter(ImageContent.class, new GsonContentAdapter()) + .registerTypeAdapter(AudioContent.class, new GsonContentAdapter()) + .registerTypeAdapter(VideoContent.class, new GsonContentAdapter()) + .registerTypeAdapter(PdfFileContent.class, new GsonContentAdapter()) + .create(); private static final String CHAT_MESSAGE_TYPE = "type"; // do not change, will break backward compatibility! @@ -35,4 +42,4 @@ public ChatMessage deserialize(JsonElement messageJsonElement, Type ignored, Jso } return chatMessage; } -} \ No newline at end of file +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonChatMessageJsonCodec.java b/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonChatMessageJsonCodec.java index 94f85976a9c..5955cce09d7 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonChatMessageJsonCodec.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonChatMessageJsonCodec.java @@ -1,12 +1,14 @@ package dev.langchain4j.data.message; -import static java.util.Collections.emptyList; import com.google.gson.Gson; import com.google.gson.GsonBuilder; import com.google.gson.reflect.TypeToken; + import java.lang.reflect.Type; import java.util.List; +import static java.util.Collections.emptyList; + /** * A codec for serializing and deserializing {@link ChatMessage} objects to and from JSON. */ diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonContentAdapter.java b/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonContentAdapter.java index f1d886599c8..66e3135d403 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonContentAdapter.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonContentAdapter.java @@ -1,6 +1,12 @@ package dev.langchain4j.data.message; -import com.google.gson.*; +import com.google.gson.Gson; +import com.google.gson.JsonDeserializationContext; +import com.google.gson.JsonDeserializer; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonSerializationContext; +import com.google.gson.JsonSerializer; import java.lang.reflect.Type; @@ -23,4 +29,4 @@ public Content deserialize(JsonElement contentJsonElement, Type ignored, JsonDes ContentType contentType = ContentType.valueOf(contentTypeString); return GSON.fromJson(contentJsonElement, contentType.getContentClass()); } -} \ No newline at end of file +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelRequest.java b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelRequest.java index a4f766eca66..ff747c55f47 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelRequest.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelRequest.java @@ -5,7 +5,6 @@ import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel; -import lombok.Builder; import java.util.List; @@ -25,7 +24,6 @@ public class ChatModelRequest { private final List messages; private final List toolSpecifications; - @Builder public ChatModelRequest(String model, Double temperature, Double topP, @@ -40,6 +38,10 @@ public ChatModelRequest(String model, this.toolSpecifications = copyIfNotNull(toolSpecifications); } + public static ChatModelRequestBuilder builder() { + return new ChatModelRequestBuilder(); + } + public String model() { return model; } @@ -63,4 +65,54 @@ public List messages() { public List toolSpecifications() { return toolSpecifications; } + + public static class ChatModelRequestBuilder { + private String model; + private Double temperature; + private Double topP; + private Integer maxTokens; + private List messages; + private List toolSpecifications; + + ChatModelRequestBuilder() { + } + + public ChatModelRequestBuilder model(String model) { + this.model = model; + return this; + } + + public ChatModelRequestBuilder temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public ChatModelRequestBuilder topP(Double topP) { + this.topP = topP; + return this; + } + + public ChatModelRequestBuilder maxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public ChatModelRequestBuilder messages(List messages) { + this.messages = messages; + return this; + } + + public ChatModelRequestBuilder toolSpecifications(List toolSpecifications) { + this.toolSpecifications = toolSpecifications; + return this; + } + + public ChatModelRequest build() { + return new ChatModelRequest(this.model, this.temperature, this.topP, this.maxTokens, this.messages, this.toolSpecifications); + } + + public String toString() { + return "ChatModelRequest.ChatModelRequestBuilder(model=" + this.model + ", temperature=" + this.temperature + ", topP=" + this.topP + ", maxTokens=" + this.maxTokens + ", messages=" + this.messages + ", toolSpecifications=" + this.toolSpecifications + ")"; + } + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelResponse.java b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelResponse.java index 05397668d5d..00446ed17d1 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelResponse.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelResponse.java @@ -6,7 +6,6 @@ import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.TokenUsage; -import lombok.Builder; /** * A response from the {@link ChatLanguageModel} or {@link StreamingChatLanguageModel}, @@ -21,7 +20,6 @@ public class ChatModelResponse { private final FinishReason finishReason; private final AiMessage aiMessage; - @Builder public ChatModelResponse(String id, String model, TokenUsage tokenUsage, @@ -34,6 +32,10 @@ public ChatModelResponse(String id, this.aiMessage = aiMessage; } + public static ChatModelResponseBuilder builder() { + return new ChatModelResponseBuilder(); + } + public String id() { return id; } @@ -53,4 +55,48 @@ public FinishReason finishReason() { public AiMessage aiMessage() { return aiMessage; } + + public static class ChatModelResponseBuilder { + private String id; + private String model; + private TokenUsage tokenUsage; + private FinishReason finishReason; + private AiMessage aiMessage; + + ChatModelResponseBuilder() { + } + + public ChatModelResponseBuilder id(String id) { + this.id = id; + return this; + } + + public ChatModelResponseBuilder model(String model) { + this.model = model; + return this; + } + + public ChatModelResponseBuilder tokenUsage(TokenUsage tokenUsage) { + this.tokenUsage = tokenUsage; + return this; + } + + public ChatModelResponseBuilder finishReason(FinishReason finishReason) { + this.finishReason = finishReason; + return this; + } + + public ChatModelResponseBuilder aiMessage(AiMessage aiMessage) { + this.aiMessage = aiMessage; + return this; + } + + public ChatModelResponse build() { + return new ChatModelResponse(this.id, this.model, this.tokenUsage, this.finishReason, this.aiMessage); + } + + public String toString() { + return "ChatModelResponse.ChatModelResponseBuilder(id=" + this.id + ", model=" + this.model + ", tokenUsage=" + this.tokenUsage + ", finishReason=" + this.finishReason + ", aiMessage=" + this.aiMessage + ")"; + } + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/embedding/DisabledEmbeddingModel.java b/langchain4j-core/src/main/java/dev/langchain4j/model/embedding/DisabledEmbeddingModel.java index 9412afcbe4b..360fb829a11 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/model/embedding/DisabledEmbeddingModel.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/model/embedding/DisabledEmbeddingModel.java @@ -6,7 +6,6 @@ import dev.langchain4j.model.output.Response; import java.util.List; -import java.util.Map; /** * An {@link EmbeddingModel} which throws a {@link ModelDisabledException} for all of its methods diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java b/langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java index bc55ad9b8bd..14be9f791f0 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java @@ -2,6 +2,7 @@ import dev.langchain4j.data.image.Image; import dev.langchain4j.model.output.Response; + import java.util.List; /** diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/input/structured/DefaultStructuredPromptFactory.java b/langchain4j-core/src/main/java/dev/langchain4j/model/input/structured/DefaultStructuredPromptFactory.java index fe08fa8d7e6..55cbdf821a0 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/model/input/structured/DefaultStructuredPromptFactory.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/model/input/structured/DefaultStructuredPromptFactory.java @@ -7,6 +7,7 @@ import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.spi.prompt.structured.StructuredPromptFactory; + import java.util.Map; /** diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/AugmentationResult.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/AugmentationResult.java index 4f70e6a9935..47bffa263a1 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/rag/AugmentationResult.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/AugmentationResult.java @@ -2,7 +2,6 @@ import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.rag.content.Content; -import lombok.Builder; import java.util.List; @@ -24,12 +23,15 @@ public class AugmentationResult { */ private final List contents; - @Builder public AugmentationResult(ChatMessage chatMessage, List contents) { this.chatMessage = ensureNotNull(chatMessage, "chatMessage"); this.contents = copyIfNotNull(contents); } + public static AugmentationResultBuilder builder() { + return new AugmentationResultBuilder(); + } + public ChatMessage chatMessage() { return chatMessage; } @@ -37,4 +39,30 @@ public ChatMessage chatMessage() { public List contents() { return contents; } + + public static class AugmentationResultBuilder { + private ChatMessage chatMessage; + private List contents; + + AugmentationResultBuilder() { + } + + public AugmentationResultBuilder chatMessage(ChatMessage chatMessage) { + this.chatMessage = chatMessage; + return this; + } + + public AugmentationResultBuilder contents(List contents) { + this.contents = contents; + return this; + } + + public AugmentationResult build() { + return new AugmentationResult(this.chatMessage, this.contents); + } + + public String toString() { + return "AugmentationResult.AugmentationResultBuilder(chatMessage=" + this.chatMessage + ", contents=" + this.contents + ")"; + } + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java index db228db14a1..0604332c166 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java @@ -14,22 +14,30 @@ import dev.langchain4j.rag.query.router.QueryRouter; import dev.langchain4j.rag.query.transformer.DefaultQueryTransformer; import dev.langchain4j.rag.query.transformer.QueryTransformer; -import lombok.Builder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.Collection; import java.util.List; import java.util.Map; -import java.util.concurrent.*; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadPoolExecutor; import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; -import static java.util.Collections.*; +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; import static java.util.concurrent.CompletableFuture.allOf; import static java.util.concurrent.CompletableFuture.supplyAsync; import static java.util.concurrent.TimeUnit.SECONDS; -import static java.util.stream.Collectors.*; +import static java.util.stream.Collectors.joining; +import static java.util.stream.Collectors.toMap; /** * The default implementation of {@link RetrievalAugmentor} intended to be suitable for the majority of use cases. @@ -109,7 +117,6 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor { private final ContentInjector contentInjector; private final Executor executor; - @Builder public DefaultRetrievalAugmentor(QueryTransformer queryTransformer, QueryRouter queryRouter, ContentAggregator contentAggregator, @@ -124,9 +131,9 @@ public DefaultRetrievalAugmentor(QueryTransformer queryTransformer, private static ExecutorService createDefaultExecutor() { return new ThreadPoolExecutor( - 0, Integer.MAX_VALUE, - 1, SECONDS, - new SynchronousQueue<>() + 0, Integer.MAX_VALUE, + 1, SECONDS, + new SynchronousQueue<>() ); } @@ -160,9 +167,9 @@ public AugmentationResult augment(AugmentationRequest augmentationRequest) { log(augmentedChatMessage); return AugmentationResult.builder() - .chatMessage(augmentedChatMessage) - .contents(contents) - .build(); + .chatMessage(augmentedChatMessage) + .contents(contents) + .build(); } private Map>> process(Collection queries) { @@ -183,13 +190,13 @@ private Map>> process(Collection queries) Map>>> queryToFutureContents = new ConcurrentHashMap<>(); queries.forEach(query -> { CompletableFuture>> futureContents = - supplyAsync(() -> { - Collection retrievers = queryRouter.route(query); - log(query, retrievers); - return retrievers; - }, - executor - ).thenCompose(retrievers -> retrieveFromAll(retrievers, query)); + supplyAsync(() -> { + Collection retrievers = queryRouter.route(query); + log(query, retrievers); + return retrievers; + }, + executor + ).thenCompose(retrievers -> retrieveFromAll(retrievers, query)); queryToFutureContents.put(query, futureContents); }); return join(queryToFutureContents); @@ -201,15 +208,14 @@ private Map>> process(Collection queries) private CompletableFuture>> retrieveFromAll(Collection retrievers, Query query) { List>> futureContents = retrievers.stream() - .map(retriever -> supplyAsync(() -> retrieve(retriever, query), executor)) - .collect(toList()); + .map(retriever -> supplyAsync(() -> retrieve(retriever, query), executor)) + .toList(); return allOf(futureContents.toArray(new CompletableFuture[0])) - .thenApply(ignored -> - futureContents.stream() - .map(CompletableFuture::join) - .collect(toList()) - ); + .thenApply(ignored -> + futureContents.stream() + .map(CompletableFuture::join) + .toList()); } private static List retrieve(ContentRetriever retriever, Query query) { @@ -219,15 +225,15 @@ private static List retrieve(ContentRetriever retriever, Query query) { } private static Map>> join( - Map>>> queryToFutureContents) { + Map>>> queryToFutureContents) { return allOf(queryToFutureContents.values().toArray(new CompletableFuture[0])) - .thenApply(ignored -> - queryToFutureContents.entrySet().stream() - .collect(toMap( - Map.Entry::getKey, - entry -> entry.getValue().join() - )) - ).join(); + .thenApply(ignored -> + queryToFutureContents.entrySet().stream() + .collect(toMap( + Map.Entry::getKey, + entry -> entry.getValue().join() + )) + ).join(); } private static void logQueries(Query originalQuery, Collection queries) { @@ -235,14 +241,14 @@ private static void logQueries(Query originalQuery, Collection queries) { Query transformedQuery = queries.iterator().next(); if (!transformedQuery.equals(originalQuery)) { log.debug("Transformed original query '{}' into '{}'", - originalQuery.text(), transformedQuery.text()); + originalQuery.text(), transformedQuery.text()); } - } else { + } else if (log.isDebugEnabled()){ log.debug("Transformed original query '{}' into the following queries:\n{}", - originalQuery.text(), queries.stream() - .map(Query::text) - .map(query -> "- '" + query + "'") - .collect(joining("\n"))); + originalQuery.text(), queries.stream() + .map(Query::text) + .map(query -> "- '" + query + "'") + .collect(joining("\n"))); } } @@ -250,27 +256,40 @@ private static void log(Query query, Collection retrievers) { // TODO use retriever id if (retrievers.size() == 1) { log.debug("Routing query '{}' to the following retriever: {}", - query.text(), retrievers.iterator().next()); - } else { + query.text(), retrievers.iterator().next()); + } else if (log.isDebugEnabled()) { log.debug("Routing query '{}' to the following retrievers:\n{}", - query.text(), retrievers.stream() - .map(retriever -> "- " + retriever.toString()) - .collect(joining("\n"))); + query.text(), retrievers.stream() + .map(retriever -> "- " + retriever.toString()) + .collect(joining("\n"))); } } private static void log(Query query, ContentRetriever retriever, List contents) { // TODO use retriever id log.debug("Retrieved {} contents using query '{}' and retriever '{}'", - contents.size(), query.text(), retriever); + contents.size(), query.text(), retriever); + + if (!log.isTraceEnabled()) { + return; + } - if (contents.size() > 0) { + if (!contents.isEmpty()) { + final var contentsSting = contents.stream() + .map(Content::textSegment) + .map(segment -> "- " + escapeNewlines(segment.text())) + .collect(joining("\n")); log.trace("Retrieved {} contents using query '{}' and retriever '{}':\n{}", - contents.size(), query.text(), retriever, contents.stream() - .map(Content::textSegment) - .map(segment -> "- " + escapeNewlines(segment.text())) - .collect(joining("\n"))); + contents.size(), + query.text(), + retriever.getClass().getName(), + contentsSting); + } else { + log.trace("Retrieved 0 contents using query '{}' and retriever '{}'", + query.text(), + retriever.getClass().getName()); } + } private static void log(Map>> queryToContents, List contents) { @@ -287,15 +306,21 @@ private static void log(Map>> queryToContents, L log.debug("Aggregated {} content(s) into {}", contentCount, contents.size()); - log.trace("Aggregated {} content(s) into:\n{}", + if (log.isTraceEnabled()) { + log.trace("Aggregated {} content(s) into:\n{}", contentCount, contents.stream() - .map(Content::textSegment) - .map(segment -> "- " + escapeNewlines(segment.text())) - .collect(joining("\n"))); + .map(Content::textSegment) + .map(segment -> "- " + escapeNewlines(segment.text())) + .collect(joining("\n"))); + } } private static void log(ChatMessage augmentedChatMessage) { - log.trace("Augmented chat message: {}", escapeNewlines(augmentedChatMessage.text())); + if (log.isTraceEnabled()) { + log.trace("Augmented chat message: {}", + escapeNewlines(augmentedChatMessage.text()) + ); + } } private static String escapeNewlines(String text) { @@ -308,9 +333,51 @@ public static DefaultRetrievalAugmentorBuilder builder() { public static class DefaultRetrievalAugmentorBuilder { + private QueryTransformer queryTransformer; + private QueryRouter queryRouter; + private ContentAggregator contentAggregator; + private ContentInjector contentInjector; + private Executor executor; + + DefaultRetrievalAugmentorBuilder() { + } + public DefaultRetrievalAugmentorBuilder contentRetriever(ContentRetriever contentRetriever) { this.queryRouter = new DefaultQueryRouter(ensureNotNull(contentRetriever, "contentRetriever")); return this; } + + public DefaultRetrievalAugmentorBuilder queryTransformer(QueryTransformer queryTransformer) { + this.queryTransformer = queryTransformer; + return this; + } + + public DefaultRetrievalAugmentorBuilder queryRouter(QueryRouter queryRouter) { + this.queryRouter = queryRouter; + return this; + } + + public DefaultRetrievalAugmentorBuilder contentAggregator(ContentAggregator contentAggregator) { + this.contentAggregator = contentAggregator; + return this; + } + + public DefaultRetrievalAugmentorBuilder contentInjector(ContentInjector contentInjector) { + this.contentInjector = contentInjector; + return this; + } + + public DefaultRetrievalAugmentorBuilder executor(Executor executor) { + this.executor = executor; + return this; + } + + public DefaultRetrievalAugmentor build() { + return new DefaultRetrievalAugmentor(this.queryTransformer, this.queryRouter, this.contentAggregator, this.contentInjector, this.executor); + } + + public String toString() { + return "DefaultRetrievalAugmentor.DefaultRetrievalAugmentorBuilder(queryTransformer=" + this.queryTransformer + ", queryRouter=" + this.queryRouter + ", contentAggregator=" + this.contentAggregator + ", contentInjector=" + this.contentInjector + ", executor=" + this.executor + ")"; + } } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/aggregator/ReRankingContentAggregator.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/aggregator/ReRankingContentAggregator.java index ece58665362..72bf1263d12 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/aggregator/ReRankingContentAggregator.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/aggregator/ReRankingContentAggregator.java @@ -5,9 +5,12 @@ import dev.langchain4j.rag.content.Content; import dev.langchain4j.rag.query.Query; import dev.langchain4j.rag.query.transformer.ExpandingQueryTransformer; -import lombok.Builder; -import java.util.*; +import java.util.Collection; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.function.Function; import static dev.langchain4j.internal.Exceptions.illegalArgument; @@ -64,7 +67,6 @@ public ReRankingContentAggregator(ScoringModel scoringModel) { this(scoringModel, DEFAULT_QUERY_SELECTOR, null); } - @Builder public ReRankingContentAggregator(ScoringModel scoringModel, Function>>, Query> querySelector, Double minScore) { @@ -73,6 +75,10 @@ public ReRankingContentAggregator(ScoringModel scoringModel, this.minScore = minScore; } + public static ReRankingContentAggregatorBuilder builder() { + return new ReRankingContentAggregatorBuilder(); + } + @Override public List aggregate(Map>> queryToContents) { @@ -126,4 +132,36 @@ protected List reRankAndFilter(List contents, Query query) { .map(Content::from) .collect(toList()); } + + public static class ReRankingContentAggregatorBuilder { + private ScoringModel scoringModel; + private Function>>, Query> querySelector; + private Double minScore; + + ReRankingContentAggregatorBuilder() { + } + + public ReRankingContentAggregatorBuilder scoringModel(ScoringModel scoringModel) { + this.scoringModel = scoringModel; + return this; + } + + public ReRankingContentAggregatorBuilder querySelector(Function>>, Query> querySelector) { + this.querySelector = querySelector; + return this; + } + + public ReRankingContentAggregatorBuilder minScore(Double minScore) { + this.minScore = minScore; + return this; + } + + public ReRankingContentAggregator build() { + return new ReRankingContentAggregator(this.scoringModel, this.querySelector, this.minScore); + } + + public String toString() { + return "ReRankingContentAggregator.ReRankingContentAggregatorBuilder(scoringModel=" + this.scoringModel + ", querySelector=" + this.querySelector + ", minScore=" + this.minScore + ")"; + } + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/aggregator/ReciprocalRankFuser.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/aggregator/ReciprocalRankFuser.java index 03aa30e314d..65e3d966d72 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/aggregator/ReciprocalRankFuser.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/aggregator/ReciprocalRankFuser.java @@ -2,7 +2,12 @@ import dev.langchain4j.rag.content.Content; -import java.util.*; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import static dev.langchain4j.internal.ValidationUtils.ensureBetween; diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/injector/DefaultContentInjector.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/injector/DefaultContentInjector.java index 2a11d14f57b..811bdd49fc1 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/injector/DefaultContentInjector.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/injector/DefaultContentInjector.java @@ -7,14 +7,15 @@ import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.rag.content.Content; -import lombok.Builder; import java.util.HashMap; import java.util.List; import java.util.Map; -import static dev.langchain4j.data.message.UserMessage.userMessage; -import static dev.langchain4j.internal.Utils.*; +import static dev.langchain4j.internal.Utils.copyIfNotNull; +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.Utils.isNotNullOrBlank; +import static dev.langchain4j.internal.Utils.isNullOrEmpty; import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static java.util.stream.Collectors.joining; @@ -47,10 +48,10 @@ public class DefaultContentInjector implements ContentInjector { public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from( """ - {{userMessage}} - - Answer using the following information: - {{contents}}""" + {{userMessage}} + + Answer using the following information: + {{contents}}""" ); private final PromptTemplate promptTemplate; @@ -68,12 +69,15 @@ public DefaultContentInjector(PromptTemplate promptTemplate) { this(ensureNotNull(promptTemplate, "promptTemplate"), null); } - @Builder public DefaultContentInjector(PromptTemplate promptTemplate, List metadataKeysToInclude) { this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE); this.metadataKeysToInclude = copyIfNotNull(metadataKeysToInclude); } + public static DefaultContentInjectorBuilder builder() { + return new DefaultContentInjectorBuilder(); + } + @Override public ChatMessage inject(List contents, ChatMessage chatMessage) { @@ -161,4 +165,30 @@ protected String format(String segmentContent, String segmentMetadata) { ? segmentContent : "content: %s\n%s".formatted(segmentContent, segmentMetadata); } + + public static class DefaultContentInjectorBuilder { + private PromptTemplate promptTemplate; + private List metadataKeysToInclude; + + DefaultContentInjectorBuilder() { + } + + public DefaultContentInjectorBuilder promptTemplate(PromptTemplate promptTemplate) { + this.promptTemplate = promptTemplate; + return this; + } + + public DefaultContentInjectorBuilder metadataKeysToInclude(List metadataKeysToInclude) { + this.metadataKeysToInclude = metadataKeysToInclude; + return this; + } + + public DefaultContentInjector build() { + return new DefaultContentInjector(this.promptTemplate, this.metadataKeysToInclude); + } + + public String toString() { + return "DefaultContentInjector.DefaultContentInjectorBuilder(promptTemplate=" + this.promptTemplate + ", metadataKeysToInclude=" + this.metadataKeysToInclude + ")"; + } + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/EmbeddingStoreContentRetriever.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/EmbeddingStoreContentRetriever.java index 36949b04c66..3569a0c2903 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/EmbeddingStoreContentRetriever.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/EmbeddingStoreContentRetriever.java @@ -11,14 +11,15 @@ import dev.langchain4j.store.embedding.EmbeddingSearchResult; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.filter.Filter; -import lombok.Builder; import java.util.Collection; import java.util.List; import java.util.function.Function; import static dev.langchain4j.internal.Utils.getOrDefault; -import static dev.langchain4j.internal.ValidationUtils.*; +import static dev.langchain4j.internal.ValidationUtils.ensureBetween; +import static dev.langchain4j.internal.ValidationUtils.ensureGreaterThanZero; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static dev.langchain4j.spi.ServiceHelper.loadFactories; import static java.util.stream.Collectors.toList; @@ -109,7 +110,6 @@ public EmbeddingStoreContentRetriever(EmbeddingStore embeddingStore ); } - @Builder private EmbeddingStoreContentRetriever(String displayName, EmbeddingStore embeddingStore, EmbeddingModel embeddingModel, @@ -141,8 +141,22 @@ private static EmbeddingModel loadEmbeddingModel() { return null; } + public static EmbeddingStoreContentRetrieverBuilder builder() { + return new EmbeddingStoreContentRetrieverBuilder(); + } + public static class EmbeddingStoreContentRetrieverBuilder { + private String displayName; + private EmbeddingStore embeddingStore; + private EmbeddingModel embeddingModel; + private Function dynamicMaxResults; + private Function dynamicMinScore; + private Function dynamicFilter; + + EmbeddingStoreContentRetrieverBuilder() { + } + public EmbeddingStoreContentRetrieverBuilder maxResults(Integer maxResults) { if (maxResults != null) { dynamicMaxResults = (query) -> ensureGreaterThanZero(maxResults, "maxResults"); @@ -163,6 +177,44 @@ public EmbeddingStoreContentRetrieverBuilder filter(Filter filter) { } return this; } + + public EmbeddingStoreContentRetrieverBuilder displayName(String displayName) { + this.displayName = displayName; + return this; + } + + public EmbeddingStoreContentRetrieverBuilder embeddingStore(EmbeddingStore embeddingStore) { + this.embeddingStore = embeddingStore; + return this; + } + + public EmbeddingStoreContentRetrieverBuilder embeddingModel(EmbeddingModel embeddingModel) { + this.embeddingModel = embeddingModel; + return this; + } + + public EmbeddingStoreContentRetrieverBuilder dynamicMaxResults(Function dynamicMaxResults) { + this.dynamicMaxResults = dynamicMaxResults; + return this; + } + + public EmbeddingStoreContentRetrieverBuilder dynamicMinScore(Function dynamicMinScore) { + this.dynamicMinScore = dynamicMinScore; + return this; + } + + public EmbeddingStoreContentRetrieverBuilder dynamicFilter(Function dynamicFilter) { + this.dynamicFilter = dynamicFilter; + return this; + } + + public EmbeddingStoreContentRetriever build() { + return new EmbeddingStoreContentRetriever(this.displayName, this.embeddingStore, this.embeddingModel, this.dynamicMaxResults, this.dynamicMinScore, this.dynamicFilter); + } + + public String toString() { + return "EmbeddingStoreContentRetriever.EmbeddingStoreContentRetrieverBuilder(displayName=" + this.displayName + ", embeddingStore=" + this.embeddingStore + ", embeddingModel=" + this.embeddingModel + ", dynamicMaxResults=" + this.dynamicMaxResults + ", dynamicMinScore=" + this.dynamicMinScore + ", dynamicFilter=" + this.dynamicFilter + ")"; + } } /** diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetriever.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetriever.java index 061ede10df3..d8922c5b9f9 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetriever.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetriever.java @@ -5,7 +5,6 @@ import dev.langchain4j.web.search.WebSearchEngine; import dev.langchain4j.web.search.WebSearchRequest; import dev.langchain4j.web.search.WebSearchResults; -import lombok.Builder; import java.util.List; @@ -26,12 +25,15 @@ public class WebSearchContentRetriever implements ContentRetriever { private final WebSearchEngine webSearchEngine; private final int maxResults; - @Builder public WebSearchContentRetriever(WebSearchEngine webSearchEngine, Integer maxResults) { this.webSearchEngine = ensureNotNull(webSearchEngine, "webSearchEngine"); this.maxResults = getOrDefault(maxResults, 5); } + public static WebSearchContentRetrieverBuilder builder() { + return new WebSearchContentRetrieverBuilder(); + } + @Override public List retrieve(Query query) { @@ -46,4 +48,30 @@ public List retrieve(Query query) { .map(Content::from) .collect(toList()); } + + public static class WebSearchContentRetrieverBuilder { + private WebSearchEngine webSearchEngine; + private Integer maxResults; + + WebSearchContentRetrieverBuilder() { + } + + public WebSearchContentRetrieverBuilder webSearchEngine(WebSearchEngine webSearchEngine) { + this.webSearchEngine = webSearchEngine; + return this; + } + + public WebSearchContentRetrieverBuilder maxResults(Integer maxResults) { + this.maxResults = maxResults; + return this; + } + + public WebSearchContentRetriever build() { + return new WebSearchContentRetriever(this.webSearchEngine, this.maxResults); + } + + public String toString() { + return "WebSearchContentRetriever.WebSearchContentRetrieverBuilder(webSearchEngine=" + this.webSearchEngine + ", maxResults=" + this.maxResults + ")"; + } + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/query/router/LanguageModelQueryRouter.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/query/router/LanguageModelQueryRouter.java index af97ba704b2..38fdbd26e8f 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/rag/query/router/LanguageModelQueryRouter.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/query/router/LanguageModelQueryRouter.java @@ -6,7 +6,6 @@ import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.rag.content.retriever.ContentRetriever; import dev.langchain4j.rag.query.Query; -import lombok.Builder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -16,7 +15,9 @@ import java.util.Map; import static dev.langchain4j.internal.Utils.getOrDefault; -import static dev.langchain4j.internal.ValidationUtils.*; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static dev.langchain4j.rag.query.router.LanguageModelQueryRouter.FallbackStrategy.DO_NOT_ROUTE; import static java.util.Arrays.stream; import static java.util.Collections.emptyList; @@ -46,12 +47,12 @@ public class LanguageModelQueryRouter implements QueryRouter { public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from( """ - Based on the user query, determine the most suitable data source(s) \ - to retrieve relevant information from the following options: - {{options}} - It is very important that your answer consists of either a single number \ - or multiple numbers separated by commas and nothing else! - User query: {{query}}""" + Based on the user query, determine the most suitable data source(s) \ + to retrieve relevant information from the following options: + {{options}} + It is very important that your answer consists of either a single number \ + or multiple numbers separated by commas and nothing else! + User query: {{query}}""" ); protected final ChatLanguageModel chatLanguageModel; @@ -65,7 +66,6 @@ public LanguageModelQueryRouter(ChatLanguageModel chatLanguageModel, this(chatLanguageModel, retrieverToDescription, DEFAULT_PROMPT_TEMPLATE, DO_NOT_ROUTE); } - @Builder public LanguageModelQueryRouter(ChatLanguageModel chatLanguageModel, Map retrieverToDescription, PromptTemplate promptTemplate, @@ -94,6 +94,10 @@ public LanguageModelQueryRouter(ChatLanguageModel chatLanguageModel, this.fallbackStrategy = getOrDefault(fallbackStrategy, DO_NOT_ROUTE); } + public static LanguageModelQueryRouterBuilder builder() { + return new LanguageModelQueryRouterBuilder(); + } + @Override public Collection route(Query query) { Prompt prompt = createPrompt(query); @@ -107,17 +111,17 @@ public Collection route(Query query) { } protected Collection fallback(Query query, Exception e) { - switch (fallbackStrategy) { - case DO_NOT_ROUTE: + return switch (fallbackStrategy) { + case DO_NOT_ROUTE -> { log.debug("Fallback: query '{}' will not be routed", query.text()); - return emptyList(); - case ROUTE_TO_ALL: + yield emptyList(); + } + case ROUTE_TO_ALL -> { log.debug("Fallback: query '{}' will be routed to all available content retrievers", query.text()); - return new ArrayList<>(idToRetriever.values()); - case FAIL: - default: - throw new RuntimeException(e); - } + yield new ArrayList<>(idToRetriever.values()); + } + default -> throw new RuntimeException(e); + }; } protected Prompt createPrompt(Query query) { @@ -157,4 +161,42 @@ public enum FallbackStrategy { */ FAIL } + + public static class LanguageModelQueryRouterBuilder { + private ChatLanguageModel chatLanguageModel; + private Map retrieverToDescription; + private PromptTemplate promptTemplate; + private FallbackStrategy fallbackStrategy; + + LanguageModelQueryRouterBuilder() { + } + + public LanguageModelQueryRouterBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) { + this.chatLanguageModel = chatLanguageModel; + return this; + } + + public LanguageModelQueryRouterBuilder retrieverToDescription(Map retrieverToDescription) { + this.retrieverToDescription = retrieverToDescription; + return this; + } + + public LanguageModelQueryRouterBuilder promptTemplate(PromptTemplate promptTemplate) { + this.promptTemplate = promptTemplate; + return this; + } + + public LanguageModelQueryRouterBuilder fallbackStrategy(FallbackStrategy fallbackStrategy) { + this.fallbackStrategy = fallbackStrategy; + return this; + } + + public LanguageModelQueryRouter build() { + return new LanguageModelQueryRouter(this.chatLanguageModel, this.retrieverToDescription, this.promptTemplate, this.fallbackStrategy); + } + + public String toString() { + return "LanguageModelQueryRouter.LanguageModelQueryRouterBuilder(chatLanguageModel=" + this.chatLanguageModel + ", retrieverToDescription=" + this.retrieverToDescription + ", promptTemplate=" + this.promptTemplate + ", fallbackStrategy=" + this.fallbackStrategy + ")"; + } + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/query/transformer/CompressingQueryTransformer.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/query/transformer/CompressingQueryTransformer.java index 05f9a3d4131..c27012d6348 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/rag/query/transformer/CompressingQueryTransformer.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/query/transformer/CompressingQueryTransformer.java @@ -8,9 +8,12 @@ import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.rag.query.Query; -import lombok.Builder; -import java.util.*; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; @@ -35,18 +38,18 @@ public class CompressingQueryTransformer implements QueryTransformer { public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from( """ - Read and understand the conversation between the User and the AI. \ - Then, analyze the new query from the User. \ - Identify all relevant details, terms, and context from both the conversation and the new query. \ - Reformulate this query into a clear, concise, and self-contained format suitable for information retrieval. - - Conversation: - {{chatMemory}} - - User query: {{query}} - - It is very important that you provide only reformulated query and nothing else! \ - Do not prepend a query with anything!""" + Read and understand the conversation between the User and the AI. \ + Then, analyze the new query from the User. \ + Identify all relevant details, terms, and context from both the conversation and the new query. \ + Reformulate this query into a clear, concise, and self-contained format suitable for information retrieval. + + Conversation: + {{chatMemory}} + + User query: {{query}} + + It is very important that you provide only reformulated query and nothing else! \ + Do not prepend a query with anything!""" ); protected final PromptTemplate promptTemplate; @@ -56,12 +59,15 @@ public CompressingQueryTransformer(ChatLanguageModel chatLanguageModel) { this(chatLanguageModel, DEFAULT_PROMPT_TEMPLATE); } - @Builder public CompressingQueryTransformer(ChatLanguageModel chatLanguageModel, PromptTemplate promptTemplate) { this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel"); this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE); } + public static CompressingQueryTransformerBuilder builder() { + return new CompressingQueryTransformerBuilder(); + } + @Override public Collection transform(Query query) { @@ -105,4 +111,30 @@ protected Prompt createPrompt(Query query, String chatMemory) { variables.put("chatMemory", chatMemory); return promptTemplate.apply(variables); } + + public static class CompressingQueryTransformerBuilder { + private ChatLanguageModel chatLanguageModel; + private PromptTemplate promptTemplate; + + CompressingQueryTransformerBuilder() { + } + + public CompressingQueryTransformerBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) { + this.chatLanguageModel = chatLanguageModel; + return this; + } + + public CompressingQueryTransformerBuilder promptTemplate(PromptTemplate promptTemplate) { + this.promptTemplate = promptTemplate; + return this; + } + + public CompressingQueryTransformer build() { + return new CompressingQueryTransformer(this.chatLanguageModel, this.promptTemplate); + } + + public String toString() { + return "CompressingQueryTransformer.CompressingQueryTransformerBuilder(chatLanguageModel=" + this.chatLanguageModel + ", promptTemplate=" + this.promptTemplate + ")"; + } + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/query/transformer/ExpandingQueryTransformer.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/query/transformer/ExpandingQueryTransformer.java index 0dca9bcb499..c098e5afe5b 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/rag/query/transformer/ExpandingQueryTransformer.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/query/transformer/ExpandingQueryTransformer.java @@ -5,7 +5,6 @@ import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.rag.query.Query; -import lombok.Builder; import java.util.Collection; import java.util.HashMap; @@ -37,13 +36,13 @@ public class ExpandingQueryTransformer implements QueryTransformer { public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from( """ - Generate {{n}} different versions of a provided user query. \ - Each version should be worded differently, using synonyms or alternative sentence structures, \ - but they should all retain the original meaning. \ - These versions will be used to retrieve relevant documents. \ - It is very important to provide each query version on a separate line, \ - without enumerations, hyphens, or any additional formatting! - User query: {{query}}""" + Generate {{n}} different versions of a provided user query. \ + Each version should be worded differently, using synonyms or alternative sentence structures, \ + but they should all retain the original meaning. \ + These versions will be used to retrieve relevant documents. \ + It is very important to provide each query version on a separate line, \ + without enumerations, hyphens, or any additional formatting! + User query: {{query}}""" ); public static final int DEFAULT_N = 3; @@ -63,13 +62,16 @@ public ExpandingQueryTransformer(ChatLanguageModel chatLanguageModel, PromptTemp this(chatLanguageModel, ensureNotNull(promptTemplate, "promptTemplate"), DEFAULT_N); } - @Builder public ExpandingQueryTransformer(ChatLanguageModel chatLanguageModel, PromptTemplate promptTemplate, Integer n) { this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel"); this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE); this.n = ensureGreaterThanZero(getOrDefault(n, DEFAULT_N), "n"); } + public static ExpandingQueryTransformerBuilder builder() { + return new ExpandingQueryTransformerBuilder(); + } + @Override public Collection transform(Query query) { Prompt prompt = createPrompt(query); @@ -94,4 +96,36 @@ protected List parse(String queries) { .filter(Utils::isNotNullOrBlank) .collect(toList()); } + + public static class ExpandingQueryTransformerBuilder { + private ChatLanguageModel chatLanguageModel; + private PromptTemplate promptTemplate; + private Integer n; + + ExpandingQueryTransformerBuilder() { + } + + public ExpandingQueryTransformerBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) { + this.chatLanguageModel = chatLanguageModel; + return this; + } + + public ExpandingQueryTransformerBuilder promptTemplate(PromptTemplate promptTemplate) { + this.promptTemplate = promptTemplate; + return this; + } + + public ExpandingQueryTransformerBuilder n(Integer n) { + this.n = n; + return this; + } + + public ExpandingQueryTransformer build() { + return new ExpandingQueryTransformer(this.chatLanguageModel, this.promptTemplate, this.n); + } + + public String toString() { + return "ExpandingQueryTransformer.ExpandingQueryTransformerBuilder(chatLanguageModel=" + this.chatLanguageModel + ", promptTemplate=" + this.promptTemplate + ", n=" + this.n + ")"; + } + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/retriever/EmbeddingStoreRetriever.java b/langchain4j-core/src/main/java/dev/langchain4j/retriever/EmbeddingStoreRetriever.java index 7319c8ecf3d..b2ba707c97b 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/retriever/EmbeddingStoreRetriever.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/retriever/EmbeddingStoreRetriever.java @@ -1,7 +1,7 @@ package dev.langchain4j.retriever; -import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever; import dev.langchain4j.store.embedding.EmbeddingMatch; diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingSearchRequest.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingSearchRequest.java index eb85ba492eb..e0afd77a572 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingSearchRequest.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingSearchRequest.java @@ -4,18 +4,17 @@ import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.filter.Filter; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.ToString; + +import java.util.Objects; import static dev.langchain4j.internal.Utils.getOrDefault; -import static dev.langchain4j.internal.ValidationUtils.*; +import static dev.langchain4j.internal.ValidationUtils.ensureBetween; +import static dev.langchain4j.internal.ValidationUtils.ensureGreaterThanZero; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; /** * Represents a request to search in an {@link EmbeddingStore}. */ -@ToString -@EqualsAndHashCode public class EmbeddingSearchRequest { private final Embedding queryEmbedding; @@ -38,7 +37,6 @@ public class EmbeddingSearchRequest { * Please note that not all {@link EmbeddingStore}s support this feature yet. * This is an optional parameter. Default: no filtering */ - @Builder public EmbeddingSearchRequest(Embedding queryEmbedding, Integer maxResults, Double minScore, Filter filter) { this.queryEmbedding = ensureNotNull(queryEmbedding, "queryEmbedding"); this.maxResults = ensureGreaterThanZero(getOrDefault(maxResults, 3), "maxResults"); @@ -46,6 +44,10 @@ public EmbeddingSearchRequest(Embedding queryEmbedding, Integer maxResults, Doub this.filter = filter; } + public static EmbeddingSearchRequestBuilder builder() { + return new EmbeddingSearchRequestBuilder(); + } + public Embedding queryEmbedding() { return queryEmbedding; } @@ -61,4 +63,59 @@ public double minScore() { public Filter filter() { return filter; } + + public boolean equals(final Object o) { + if (o == this) return true; + if (!(o instanceof EmbeddingSearchRequest other)) return false; + return this.maxResults == other.maxResults + && this.minScore == other.minScore + && Objects.equals(this.queryEmbedding, other.queryEmbedding) + && Objects.equals(this.filter, other.filter); + } + + public int hashCode() { + return Objects.hash(queryEmbedding, maxResults, minScore, filter); + } + + public String toString() { + return "EmbeddingSearchRequest(queryEmbedding=" + this.queryEmbedding + ", maxResults=" + this.maxResults + ", minScore=" + this.minScore + ", filter=" + this.filter + ")"; + } + + public static class EmbeddingSearchRequestBuilder { + private Embedding queryEmbedding; + private Integer maxResults; + private Double minScore; + private Filter filter; + + EmbeddingSearchRequestBuilder() { + } + + public EmbeddingSearchRequestBuilder queryEmbedding(Embedding queryEmbedding) { + this.queryEmbedding = queryEmbedding; + return this; + } + + public EmbeddingSearchRequestBuilder maxResults(Integer maxResults) { + this.maxResults = maxResults; + return this; + } + + public EmbeddingSearchRequestBuilder minScore(Double minScore) { + this.minScore = minScore; + return this; + } + + public EmbeddingSearchRequestBuilder filter(Filter filter) { + this.filter = filter; + return this; + } + + public EmbeddingSearchRequest build() { + return new EmbeddingSearchRequest(this.queryEmbedding, this.maxResults, this.minScore, this.filter); + } + + public String toString() { + return "EmbeddingSearchRequest.EmbeddingSearchRequestBuilder(queryEmbedding=" + this.queryEmbedding + ", maxResults=" + this.maxResults + ", minScore=" + this.minScore + ", filter=" + this.filter + ")"; + } + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreIngestor.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreIngestor.java index 829cbc3e30e..2fe33000a64 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreIngestor.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreIngestor.java @@ -10,7 +10,8 @@ import dev.langchain4j.model.output.Response; import dev.langchain4j.spi.data.document.splitter.DocumentSplitterFactory; import dev.langchain4j.spi.model.embedding.EmbeddingModelFactory; -import lombok.extern.slf4j.Slf4j; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.Collection; import java.util.List; @@ -48,9 +49,10 @@ * Including a document title or a short summary in each {@code TextSegment} is a common technique * to improve the quality of similarity searches. */ -@Slf4j public class EmbeddingStoreIngestor { + private static final Logger log = LoggerFactory.getLogger(EmbeddingStoreIngestor.class); + private final DocumentTransformer documentTransformer; private final DocumentSplitter documentSplitter; private final TextSegmentTransformer textSegmentTransformer; @@ -123,6 +125,7 @@ private static EmbeddingModel loadEmbeddingModel() { *
* For the "Easy RAG", import {@code langchain4j-easy-rag} module, * which contains a {@code DocumentSplitterFactory} and {@code EmbeddingModelFactory} implementations. + * * @return result including information related to ingestion process. */ public static IngestionResult ingest(Document document, EmbeddingStore embeddingStore) { @@ -137,6 +140,7 @@ public static IngestionResult ingest(Document document, EmbeddingStore * For the "Easy RAG", import {@code langchain4j-easy-rag} module, * which contains a {@code DocumentSplitterFactory} and {@code EmbeddingModelFactory} implementations. + * * @return result including information related to ingestion process. */ public static IngestionResult ingest(List documents, EmbeddingStore embeddingStore) { diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/Filter.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/Filter.java index f24e580232a..ae1b1fbe573 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/Filter.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/Filter.java @@ -1,7 +1,14 @@ package dev.langchain4j.store.embedding.filter; import dev.langchain4j.store.embedding.EmbeddingStore; -import dev.langchain4j.store.embedding.filter.comparison.*; +import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan; +import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsIn; +import dev.langchain4j.store.embedding.filter.comparison.IsLessThan; +import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsNotEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsNotIn; import dev.langchain4j.store.embedding.filter.logical.And; import dev.langchain4j.store.embedding.filter.logical.Not; import dev.langchain4j.store.embedding.filter.logical.Or; @@ -59,4 +66,4 @@ static Filter or(Filter left, Filter right) { static Filter not(Filter expression) { return new Not(expression); } -} \ No newline at end of file +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/MetadataFilterBuilder.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/MetadataFilterBuilder.java index 593ad7a6704..ab104f053a4 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/MetadataFilterBuilder.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/MetadataFilterBuilder.java @@ -1,7 +1,14 @@ package dev.langchain4j.store.embedding.filter; import dev.langchain4j.data.document.Metadata; -import dev.langchain4j.store.embedding.filter.comparison.*; +import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan; +import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsIn; +import dev.langchain4j.store.embedding.filter.comparison.IsLessThan; +import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsNotEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsNotIn; import java.util.ArrayList; import java.util.Collection; diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsEqualTo.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsEqualTo.java index 63c444ec367..28d17bf7b7c 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsEqualTo.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsEqualTo.java @@ -2,9 +2,8 @@ import dev.langchain4j.data.document.Metadata; import dev.langchain4j.store.embedding.filter.Filter; -import lombok.EqualsAndHashCode; -import lombok.ToString; +import java.util.Objects; import java.util.UUID; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; @@ -12,8 +11,6 @@ import static dev.langchain4j.store.embedding.filter.comparison.NumberComparator.compareAsBigDecimals; import static dev.langchain4j.store.embedding.filter.comparison.TypeChecker.ensureTypesAreCompatible; -@ToString -@EqualsAndHashCode public class IsEqualTo implements Filter { private final String key; @@ -34,11 +31,10 @@ public Object comparisonValue() { @Override public boolean test(Object object) { - if (!(object instanceof Metadata)) { + if (!(object instanceof Metadata metadata)) { return false; } - Metadata metadata = (Metadata) object; if (!metadata.containsKey(key)) { return false; } @@ -56,4 +52,20 @@ public boolean test(Object object) { return actualValue.equals(comparisonValue); } + + public boolean equals(final Object o) { + if (o == this) return true; + if (!(o instanceof IsEqualTo other)) return false; + + return Objects.equals(this.key, other.key) + && Objects.equals(this.comparisonValue, other.comparisonValue); + } + + public int hashCode() { + return Objects.hash(key, comparisonValue); + } + + public String toString() { + return "IsEqualTo(key=" + this.key + ", comparisonValue=" + this.comparisonValue + ")"; + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsGreaterThan.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsGreaterThan.java index 3fb0d0e6320..a2bb729d89d 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsGreaterThan.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsGreaterThan.java @@ -2,16 +2,14 @@ import dev.langchain4j.data.document.Metadata; import dev.langchain4j.store.embedding.filter.Filter; -import lombok.EqualsAndHashCode; -import lombok.ToString; + +import java.util.Objects; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static dev.langchain4j.store.embedding.filter.comparison.NumberComparator.compareAsBigDecimals; import static dev.langchain4j.store.embedding.filter.comparison.TypeChecker.ensureTypesAreCompatible; -@ToString -@EqualsAndHashCode public class IsGreaterThan implements Filter { private final String key; @@ -32,11 +30,10 @@ public Comparable comparisonValue() { @Override public boolean test(Object object) { - if (!(object instanceof Metadata)) { + if (!(object instanceof Metadata metadata)) { return false; } - Metadata metadata = (Metadata) object; if (!metadata.containsKey(key)) { return false; } @@ -50,4 +47,20 @@ public boolean test(Object object) { return ((Comparable) actualValue).compareTo(comparisonValue) > 0; } + + public boolean equals(final Object o) { + if (o == this) return true; + if (!(o instanceof IsGreaterThan other)) return false; + + return Objects.equals(this.key, other.key) + && Objects.equals(this.comparisonValue, other.comparisonValue); + } + + public int hashCode() { + return Objects.hash(key, comparisonValue); + } + + public String toString() { + return "IsGreaterThan(key=" + this.key + ", comparisonValue=" + this.comparisonValue + ")"; + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsGreaterThanOrEqualTo.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsGreaterThanOrEqualTo.java index ab3d10efd91..7446b25fbc6 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsGreaterThanOrEqualTo.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsGreaterThanOrEqualTo.java @@ -2,16 +2,14 @@ import dev.langchain4j.data.document.Metadata; import dev.langchain4j.store.embedding.filter.Filter; -import lombok.EqualsAndHashCode; -import lombok.ToString; + +import java.util.Objects; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static dev.langchain4j.store.embedding.filter.comparison.NumberComparator.compareAsBigDecimals; import static dev.langchain4j.store.embedding.filter.comparison.TypeChecker.ensureTypesAreCompatible; -@ToString -@EqualsAndHashCode public class IsGreaterThanOrEqualTo implements Filter { private final String key; @@ -32,11 +30,10 @@ public Comparable comparisonValue() { @Override public boolean test(Object object) { - if (!(object instanceof Metadata)) { + if (!(object instanceof Metadata metadata)) { return false; } - Metadata metadata = (Metadata) object; if (!metadata.containsKey(key)) { return false; } @@ -50,4 +47,20 @@ public boolean test(Object object) { return ((Comparable) actualValue).compareTo(comparisonValue) >= 0; } + + public boolean equals(final Object o) { + if (o == this) return true; + if (!(o instanceof IsGreaterThanOrEqualTo other)) return false; + + return Objects.equals(this.key, other.key) + && Objects.equals(this.comparisonValue, other.comparisonValue); + } + + public int hashCode() { + return Objects.hash(key, comparisonValue); + } + + public String toString() { + return "IsGreaterThanOrEqualTo(key=" + this.key + ", comparisonValue=" + this.comparisonValue + ")"; + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsIn.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsIn.java index efcf4570947..6c57119c46f 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsIn.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsIn.java @@ -2,22 +2,21 @@ import dev.langchain4j.data.document.Metadata; import dev.langchain4j.store.embedding.filter.Filter; -import lombok.EqualsAndHashCode; -import lombok.ToString; import java.util.Collection; import java.util.HashSet; +import java.util.Objects; import java.util.Set; import java.util.UUID; -import static dev.langchain4j.internal.ValidationUtils.*; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static dev.langchain4j.store.embedding.filter.comparison.NumberComparator.containsAsBigDecimals; import static dev.langchain4j.store.embedding.filter.comparison.TypeChecker.ensureTypesAreCompatible; import static dev.langchain4j.store.embedding.filter.comparison.UUIDComparator.containsAsUUID; import static java.util.Collections.unmodifiableSet; -@ToString -@EqualsAndHashCode public class IsIn implements Filter { private final String key; @@ -40,11 +39,10 @@ public Collection comparisonValues() { @Override public boolean test(Object object) { - if (!(object instanceof Metadata)) { + if (!(object instanceof Metadata metadata)) { return false; } - Metadata metadata = (Metadata) object; if (!metadata.containsKey(key)) { return false; } @@ -61,4 +59,21 @@ public boolean test(Object object) { return comparisonValues.contains(actualValue); } + + public boolean equals(final Object o) { + if (o == this) return true; + if (!(o instanceof IsIn other)) return false; + + return Objects.equals(this.key, other.key) + && Objects.equals(this.comparisonValues, other.comparisonValues); + } + + public int hashCode() { + return Objects.hash(key, comparisonValues); + } + + + public String toString() { + return "IsIn(key=" + this.key + ", comparisonValues=" + this.comparisonValues + ")"; + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsLessThan.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsLessThan.java index dd4065e107e..28a2ea31f5e 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsLessThan.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsLessThan.java @@ -2,16 +2,14 @@ import dev.langchain4j.data.document.Metadata; import dev.langchain4j.store.embedding.filter.Filter; -import lombok.EqualsAndHashCode; -import lombok.ToString; + +import java.util.Objects; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static dev.langchain4j.store.embedding.filter.comparison.NumberComparator.compareAsBigDecimals; import static dev.langchain4j.store.embedding.filter.comparison.TypeChecker.ensureTypesAreCompatible; -@ToString -@EqualsAndHashCode public class IsLessThan implements Filter { private final String key; @@ -32,11 +30,10 @@ public Comparable comparisonValue() { @Override public boolean test(Object object) { - if (!(object instanceof Metadata)) { + if (!(object instanceof Metadata metadata)) { return false; } - Metadata metadata = (Metadata) object; if (!metadata.containsKey(key)) { return false; } @@ -50,4 +47,20 @@ public boolean test(Object object) { return ((Comparable) actualValue).compareTo(comparisonValue) < 0; } + + public boolean equals(final Object o) { + if (o == this) return true; + if (!(o instanceof IsLessThan other)) return false; + + return Objects.equals(this.key, other.key) + && Objects.equals(this.comparisonValue, other.comparisonValue); + } + + public int hashCode() { + return Objects.hash(key, comparisonValue); + } + + public String toString() { + return "IsLessThan(key=" + this.key + ", comparisonValue=" + this.comparisonValue + ")"; + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsLessThanOrEqualTo.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsLessThanOrEqualTo.java index e577ba4850f..5db113f32f1 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsLessThanOrEqualTo.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsLessThanOrEqualTo.java @@ -2,16 +2,14 @@ import dev.langchain4j.data.document.Metadata; import dev.langchain4j.store.embedding.filter.Filter; -import lombok.EqualsAndHashCode; -import lombok.ToString; + +import java.util.Objects; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static dev.langchain4j.store.embedding.filter.comparison.NumberComparator.compareAsBigDecimals; import static dev.langchain4j.store.embedding.filter.comparison.TypeChecker.ensureTypesAreCompatible; -@ToString -@EqualsAndHashCode public class IsLessThanOrEqualTo implements Filter { private final String key; @@ -32,11 +30,10 @@ public Comparable comparisonValue() { @Override public boolean test(Object object) { - if (!(object instanceof Metadata)) { + if (!(object instanceof Metadata metadata)) { return false; } - Metadata metadata = (Metadata) object; if (!metadata.containsKey(key)) { return false; } @@ -50,4 +47,21 @@ public boolean test(Object object) { return ((Comparable) actualValue).compareTo(comparisonValue) <= 0; } + + + public boolean equals(final Object o) { + if (o == this) return true; + if (!(o instanceof IsLessThanOrEqualTo other)) return false; + + return Objects.equals(this.key, other.key) + && Objects.equals(this.comparisonValue, other.comparisonValue); + } + + public int hashCode() { + return Objects.hash(key, comparisonValue); + } + + public String toString() { + return "IsLessThanOrEqualTo(key=" + this.key + ", comparisonValue=" + this.comparisonValue + ")"; + } } \ No newline at end of file diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsNotEqualTo.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsNotEqualTo.java index 8dbce02855d..fc05cb47867 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsNotEqualTo.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsNotEqualTo.java @@ -2,9 +2,8 @@ import dev.langchain4j.data.document.Metadata; import dev.langchain4j.store.embedding.filter.Filter; -import lombok.EqualsAndHashCode; -import lombok.ToString; +import java.util.Objects; import java.util.UUID; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; @@ -12,8 +11,6 @@ import static dev.langchain4j.store.embedding.filter.comparison.NumberComparator.compareAsBigDecimals; import static dev.langchain4j.store.embedding.filter.comparison.TypeChecker.ensureTypesAreCompatible; -@ToString -@EqualsAndHashCode public class IsNotEqualTo implements Filter { private final String key; @@ -34,11 +31,10 @@ public Object comparisonValue() { @Override public boolean test(Object object) { - if (!(object instanceof Metadata)) { + if (!(object instanceof Metadata metadata)) { return false; } - Metadata metadata = (Metadata) object; if (!metadata.containsKey(key)) { return true; } @@ -56,4 +52,20 @@ public boolean test(Object object) { return !actualValue.equals(comparisonValue); } + + public boolean equals(final Object o) { + if (o == this) return true; + if (!(o instanceof IsNotEqualTo other)) return false; + + return Objects.equals(this.key, other.key) + && Objects.equals(this.comparisonValue, other.comparisonValue); + } + + public int hashCode() { + return Objects.hash(key, comparisonValue); + } + + public String toString() { + return "IsNotEqualTo(key=" + this.key + ", comparisonValue=" + this.comparisonValue + ")"; + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsNotIn.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsNotIn.java index b0c3a2ed10c..ca4d852bef0 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsNotIn.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/comparison/IsNotIn.java @@ -2,22 +2,21 @@ import dev.langchain4j.data.document.Metadata; import dev.langchain4j.store.embedding.filter.Filter; -import lombok.EqualsAndHashCode; -import lombok.ToString; import java.util.Collection; import java.util.HashSet; +import java.util.Objects; import java.util.Set; import java.util.UUID; -import static dev.langchain4j.internal.ValidationUtils.*; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static dev.langchain4j.store.embedding.filter.comparison.NumberComparator.containsAsBigDecimals; import static dev.langchain4j.store.embedding.filter.comparison.TypeChecker.ensureTypesAreCompatible; import static dev.langchain4j.store.embedding.filter.comparison.UUIDComparator.containsAsUUID; import static java.util.Collections.unmodifiableSet; -@ToString -@EqualsAndHashCode public class IsNotIn implements Filter { private final String key; @@ -40,11 +39,10 @@ public Collection comparisonValues() { @Override public boolean test(Object object) { - if (!(object instanceof Metadata)) { + if (!(object instanceof Metadata metadata)) { return false; } - Metadata metadata = (Metadata) object; if (!metadata.containsKey(key)) { return true; } @@ -61,4 +59,20 @@ public boolean test(Object object) { return !comparisonValues.contains(actualValue); } + + public boolean equals(final Object o) { + if (o == this) return true; + if (!(o instanceof IsNotIn other)) return false; + + return Objects.equals(this.key, other.key) + && Objects.equals(this.comparisonValues, other.comparisonValues); + } + + public int hashCode() { + return Objects.hash(key, comparisonValues); + } + + public String toString() { + return "IsNotIn(key=" + this.key + ", comparisonValues=" + this.comparisonValues + ")"; + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/logical/And.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/logical/And.java index cec73dcf50a..dddd1fed662 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/logical/And.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/logical/And.java @@ -1,13 +1,11 @@ package dev.langchain4j.store.embedding.filter.logical; import dev.langchain4j.store.embedding.filter.Filter; -import lombok.EqualsAndHashCode; -import lombok.ToString; + +import java.util.Objects; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; -@ToString -@EqualsAndHashCode public class And implements Filter { private final Filter left; @@ -30,4 +28,18 @@ public Filter right() { public boolean test(Object object) { return left().test(object) && right().test(object); } + + public boolean equals(final Object o) { + if (o == this) return true; + if (!(o instanceof And other)) return false; + return Objects.equals(this.left, other.left) && Objects.equals(this.right, other.right); + } + + public int hashCode() { + return Objects.hash(left, right); + } + + public String toString() { + return "And(left=" + this.left + ", right=" + this.right + ")"; + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/logical/Not.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/logical/Not.java index c319b84cf7d..ce222a880d0 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/logical/Not.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/logical/Not.java @@ -1,13 +1,11 @@ package dev.langchain4j.store.embedding.filter.logical; import dev.langchain4j.store.embedding.filter.Filter; -import lombok.EqualsAndHashCode; -import lombok.ToString; + +import java.util.Objects; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; -@ToString -@EqualsAndHashCode public class Not implements Filter { private final Filter expression; @@ -24,4 +22,18 @@ public Filter expression() { public boolean test(Object object) { return !expression.test(object); } + + public boolean equals(final Object o) { + if (o == this) return true; + if (!(o instanceof Not other)) return false; + return Objects.equals(this.expression, other.expression); + } + + public int hashCode() { + return Objects.hash(expression); + } + + public String toString() { + return "Not(expression=" + this.expression + ")"; + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/logical/Or.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/logical/Or.java index 1b5afcddd76..c8a71444983 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/logical/Or.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/logical/Or.java @@ -1,13 +1,11 @@ package dev.langchain4j.store.embedding.filter.logical; import dev.langchain4j.store.embedding.filter.Filter; -import lombok.EqualsAndHashCode; -import lombok.ToString; + +import java.util.Objects; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; -@ToString -@EqualsAndHashCode public class Or implements Filter { private final Filter left; @@ -30,4 +28,18 @@ public Filter right() { public boolean test(Object object) { return left().test(object) || right().test(object); } + + public boolean equals(final Object o) { + if (o == this) return true; + if (!(o instanceof Or other)) return false; + return Objects.equals(this.left, other.left) && Objects.equals(this.right, other.right); + } + + public int hashCode() { + return Objects.hash(left, right); + } + + public String toString() { + return "Or(left=" + this.left + ", right=" + this.right + ")"; + } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchResults.java b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchResults.java index f191143e1ac..c5ceae36c34 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchResults.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchResults.java @@ -9,7 +9,6 @@ import java.util.Objects; import static dev.langchain4j.internal.Utils.isNullOrEmpty; -import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static java.util.stream.Collectors.toList; diff --git a/langchain4j-core/src/test/java/dev/langchain4j/agent/tool/ToolSpecificationsTest.java b/langchain4j-core/src/test/java/dev/langchain4j/agent/tool/ToolSpecificationsTest.java index ecf14a05788..1277dabc71d 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/agent/tool/ToolSpecificationsTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/agent/tool/ToolSpecificationsTest.java @@ -9,7 +9,6 @@ import dev.langchain4j.model.chat.request.json.JsonSchemaElement; import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.output.structured.Description; -import lombok.Data; import org.assertj.core.api.WithAssertions; import org.junit.jupiter.api.Test; @@ -24,16 +23,15 @@ class ToolSpecificationsTest implements WithAssertions { - @Data - public static class Person { - + public record Person( @Description("Name of the person") - private String name; - private List aliases; - private boolean active; - private Person parent; - private Address currentAddress; - private List
previousAddresses; + String name, + List aliases, + boolean active, + Person parent, + Address currentAddress, + List
previousAddresses + ) { } public static class Address { @@ -49,36 +47,36 @@ public enum E { public static class Wrapper { @Tool({"line1", "line2"}) public int f( - @P("foo") String p0, - boolean p1, - @P("b2") Boolean p2, - byte p3, - Byte p4, - short p5, - Short p6, - int p7, - Integer p8, - long p9, - Long p10, - @P("biggy") - BigInteger p11, - float p12, - Float p13, - double p14, - Double p15, - @P("bigger") BigDecimal p16, - String[] p17, - Integer[] p18, - Boolean[] p19, - int[] p20, - boolean[] p21, - List p22, - Set p23, - Collection p24, - E p25, - Person p26, - @P(value = "optional", required = false) int p27, - @P(value = "required") int p28) { + @P("foo") String p0, + boolean p1, + @P("b2") Boolean p2, + byte p3, + Byte p4, + short p5, + Short p6, + int p7, + Integer p8, + long p9, + Long p10, + @P("biggy") + BigInteger p11, + float p12, + Float p13, + double p14, + Double p15, + @P("bigger") BigDecimal p16, + String[] p17, + Integer[] p18, + Boolean[] p19, + int[] p20, + boolean[] p21, + List p22, + Set p23, + Collection p24, + E p25, + Person p26, + @P(value = "optional", required = false) int p27, + @P(value = "required") int p28) { return 42; } @@ -122,35 +120,35 @@ public int aDifferentMethod(int typeInt) { private static Method getF() throws NoSuchMethodException { return Wrapper.class.getMethod("f", - String.class,//0 - boolean.class, - Boolean.class, - byte.class, - Byte.class, - short.class,//5 - Short.class, - int.class, - Integer.class, - long.class, - Long.class, //10 - BigInteger.class, - float.class, - Float.class, - double.class, - Double.class, //15 - BigDecimal.class, - String[].class, - Integer[].class, - Boolean[].class, - int[].class,//20 - boolean[].class, - List.class, - Set.class, - Collection.class, - E.class,// 25 - Person.class, - int.class, - int.class); + String.class,//0 + boolean.class, + Boolean.class, + byte.class, + Byte.class, + short.class,//5 + Short.class, + int.class, + Integer.class, + long.class, + Long.class, //10 + BigInteger.class, + float.class, + Float.class, + double.class, + Double.class, //15 + BigDecimal.class, + String[].class, + Integer[].class, + Boolean[].class, + int[].class,//20 + boolean[].class, + List.class, + Set.class, + Collection.class, + E.class,// 25 + Person.class, + int.class, + int.class); } public static Map mapOf(K k1, V v1) { @@ -192,24 +190,24 @@ public void test_toolSpecificationsFrom() { assertThat(specs).hasSize(2); assertThat(specs).extracting(ToolSpecification::name) - .containsExactlyInAnyOrder("f", "func_name"); + .containsExactlyInAnyOrder("f", "func_name"); } @Test public void test_toolSpecificationsFrom_with_duplicate_method_names() { assertThatExceptionOfType(IllegalArgumentException.class) - .isThrownBy(() -> ToolSpecifications.toolSpecificationsFrom(new InvalidToolsWithDuplicateMethodNames())) - .withMessage("Tool names must be unique. The tool 'duplicateMethod' appears several times") - .withNoCause(); + .isThrownBy(() -> ToolSpecifications.toolSpecificationsFrom(new InvalidToolsWithDuplicateMethodNames())) + .withMessage("Tool names must be unique. The tool 'duplicateMethod' appears several times") + .withNoCause(); } @Test public void test_toolSpecificationsFrom_with_duplicate_names() { assertThatExceptionOfType(IllegalArgumentException.class) - .isThrownBy(() -> ToolSpecifications.toolSpecificationsFrom(new InvalidToolsWithDuplicateNames())) - .withMessage("Tool names must be unique. The tool 'duplicate_name' appears several times") - .withNoCause(); + .isThrownBy(() -> ToolSpecifications.toolSpecificationsFrom(new InvalidToolsWithDuplicateNames())) + .withMessage("Tool names must be unique. The tool 'duplicate_name' appears several times") + .withNoCause(); } @@ -238,73 +236,72 @@ public void test_toolSpecificationFrom() throws NoSuchMethodException { assertThat(properties).hasSize(29); assertThat(properties) - .containsEntry("arg0", JsonStringSchema.builder().description("foo").build()) - .containsEntry("arg1", new JsonBooleanSchema()) - .containsEntry("arg2", JsonBooleanSchema.builder().description("b2").build()) - .containsEntry("arg3", new JsonIntegerSchema()) - .containsEntry("arg4", new JsonIntegerSchema()) - .containsEntry("arg5", new JsonIntegerSchema()) - .containsEntry("arg6", new JsonIntegerSchema()) - .containsEntry("arg7", new JsonIntegerSchema()) - .containsEntry("arg8", new JsonIntegerSchema()) - .containsEntry("arg9", new JsonIntegerSchema()) - .containsEntry("arg10", new JsonIntegerSchema()) - .containsEntry("arg11", JsonIntegerSchema.builder().description("biggy").build()) - .containsEntry("arg12", new JsonNumberSchema()) - .containsEntry("arg13", new JsonNumberSchema()) - .containsEntry("arg14", new JsonNumberSchema()) - .containsEntry("arg15", new JsonNumberSchema()) - .containsEntry("arg16", JsonNumberSchema.builder().description("bigger").build()) - .containsEntry("arg17", JsonArraySchema.builder().items(new JsonStringSchema()).build()) - .containsEntry("arg18", JsonArraySchema.builder().items(new JsonIntegerSchema()).build()) - .containsEntry("arg19", JsonArraySchema.builder().items(new JsonBooleanSchema()).build()) - .containsEntry("arg20", JsonArraySchema.builder().items(new JsonIntegerSchema()).build()) - .containsEntry("arg21", JsonArraySchema.builder().items(new JsonBooleanSchema()).build()) - .containsEntry("arg22", JsonArraySchema.builder().items(new JsonIntegerSchema()).build()) - .containsEntry("arg23", JsonArraySchema.builder().items(new JsonNumberSchema()).build()) - .containsEntry("arg24", JsonArraySchema.builder().items(new JsonStringSchema()).build()) - .containsEntry("arg25", JsonEnumSchema.builder().enumValues("A", "B", "C").build()) - .containsEntry("arg27", JsonIntegerSchema.builder().description("optional").build()) - .containsEntry("arg28", JsonIntegerSchema.builder().description("required").build()); + .containsEntry("arg0", JsonStringSchema.builder().description("foo").build()) + .containsEntry("arg1", new JsonBooleanSchema()) + .containsEntry("arg2", JsonBooleanSchema.builder().description("b2").build()) + .containsEntry("arg3", new JsonIntegerSchema()) + .containsEntry("arg4", new JsonIntegerSchema()) + .containsEntry("arg5", new JsonIntegerSchema()) + .containsEntry("arg6", new JsonIntegerSchema()) + .containsEntry("arg7", new JsonIntegerSchema()) + .containsEntry("arg8", new JsonIntegerSchema()) + .containsEntry("arg9", new JsonIntegerSchema()) + .containsEntry("arg10", new JsonIntegerSchema()) + .containsEntry("arg11", JsonIntegerSchema.builder().description("biggy").build()) + .containsEntry("arg12", new JsonNumberSchema()) + .containsEntry("arg13", new JsonNumberSchema()) + .containsEntry("arg14", new JsonNumberSchema()) + .containsEntry("arg15", new JsonNumberSchema()) + .containsEntry("arg16", JsonNumberSchema.builder().description("bigger").build()) + .containsEntry("arg17", JsonArraySchema.builder().items(new JsonStringSchema()).build()) + .containsEntry("arg18", JsonArraySchema.builder().items(new JsonIntegerSchema()).build()) + .containsEntry("arg19", JsonArraySchema.builder().items(new JsonBooleanSchema()).build()) + .containsEntry("arg20", JsonArraySchema.builder().items(new JsonIntegerSchema()).build()) + .containsEntry("arg21", JsonArraySchema.builder().items(new JsonBooleanSchema()).build()) + .containsEntry("arg22", JsonArraySchema.builder().items(new JsonIntegerSchema()).build()) + .containsEntry("arg23", JsonArraySchema.builder().items(new JsonNumberSchema()).build()) + .containsEntry("arg24", JsonArraySchema.builder().items(new JsonStringSchema()).build()) + .containsEntry("arg25", JsonEnumSchema.builder().enumValues("A", "B", "C").build()) + .containsEntry("arg27", JsonIntegerSchema.builder().description("optional").build()) + .containsEntry("arg28", JsonIntegerSchema.builder().description("required").build()); assertThat(ts.parameters().required()) - .containsExactly("arg0", - "arg1", - "arg2", - "arg3", - "arg4", - "arg5", - "arg6", - "arg7", - "arg8", - "arg9", - "arg10", - "arg11", - "arg12", - "arg13", - "arg14", - "arg15", - "arg16", - "arg17", - "arg18", - "arg19", - "arg20", - "arg21", - "arg22", - "arg23", - "arg24", - "arg25", - "arg26", - // "arg27", params with @P(required = false) are optional - "arg28" - ); + .containsExactly("arg0", + "arg1", + "arg2", + "arg3", + "arg4", + "arg5", + "arg6", + "arg7", + "arg8", + "arg9", + "arg10", + "arg11", + "arg12", + "arg13", + "arg14", + "arg15", + "arg16", + "arg17", + "arg18", + "arg19", + "arg20", + "arg21", + "arg22", + "arg23", + "arg24", + "arg25", + "arg26", + // "arg27", params with @P(required = false) are optional + "arg28" + ); } - @Data - public static class Customer { - public String name; - public Address billingAddress; - public Address shippingAddress; + record Customer( + String name, + Address billingAddress, + Address shippingAddress) { } public static class CustomerRegistration { @@ -326,22 +323,22 @@ void test_object_used_multiple_times() { assertThat(toolSpecification.name()).isEqualTo("registerCustomer"); assertThat(toolSpecification.description()).isEqualTo("register a new customer"); assertThat(toolSpecification.parameters()).isEqualTo(JsonObjectSchema.builder() - .addProperty("arg0", JsonObjectSchema.builder() - .addStringProperty("name") - .addProperty("billingAddress", JsonObjectSchema.builder() - .addStringProperty("street") - .addStringProperty("city") - .required("street", "city") - .build()) - .addProperty("shippingAddress", JsonObjectSchema.builder() - .addStringProperty("street") - .addStringProperty("city") - .required("street", "city") - .build()) - .required("name", "billingAddress", "shippingAddress") - .build()) - .required("arg0") - .build()); + .addProperty("arg0", JsonObjectSchema.builder() + .addStringProperty("name") + .addProperty("billingAddress", JsonObjectSchema.builder() + .addStringProperty("street") + .addStringProperty("city") + .required("street", "city") + .build()) + .addProperty("shippingAddress", JsonObjectSchema.builder() + .addStringProperty("street") + .addStringProperty("city") + .required("street", "city") + .build()) + .required("name", "billingAddress", "shippingAddress") + .build()) + .required("arg0") + .build()); assertThat(toolSpecification.toolParameters()).isNull(); } -} \ No newline at end of file +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/data/document/DocumentLoaderTest.java b/langchain4j-core/src/test/java/dev/langchain4j/data/document/DocumentLoaderTest.java index 80a454e9527..25a6f8193e5 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/data/document/DocumentLoaderTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/data/document/DocumentLoaderTest.java @@ -3,7 +3,10 @@ import org.assertj.core.api.WithAssertions; import org.junit.jupiter.api.Test; -import java.io.*; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; import java.nio.charset.StandardCharsets; @@ -75,27 +78,27 @@ public void test_load() { assertThat(document).isEqualTo(Document.from("Hello, world!", new Metadata().put("foo", "bar"))); assertThatExceptionOfType(RuntimeException.class) - .isThrownBy(() -> DocumentLoader.load(new DocumentSource() { - @Override - public InputStream inputStream() throws IOException { - throw new IOException("Failed to open input stream"); - } - - @Override - public Metadata metadata() { - return new Metadata(); - } - }, new TrivialParser())) - .withMessageContaining("Failed to load document"); + .isThrownBy(() -> DocumentLoader.load(new DocumentSource() { + @Override + public InputStream inputStream() throws IOException { + throw new IOException("Failed to open input stream"); + } + + @Override + public Metadata metadata() { + return new Metadata(); + } + }, new TrivialParser())) + .withMessageContaining("Failed to load document"); assertThatExceptionOfType(RuntimeException.class) - .isThrownBy(() -> DocumentLoader.load( - source, - inputStream -> { - throw new RuntimeException("Failed to parse document"); - } - - )) - .withMessageContaining("Failed to load document"); + .isThrownBy(() -> DocumentLoader.load( + source, + inputStream -> { + throw new RuntimeException("Failed to parse document"); + } + + )) + .withMessageContaining("Failed to load document"); } -} \ No newline at end of file +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/internal/JsonTest.java b/langchain4j-core/src/test/java/dev/langchain4j/internal/JsonTest.java index 3169d6a21c2..0cd61b3a39e 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/internal/JsonTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/internal/JsonTest.java @@ -1,6 +1,7 @@ package dev.langchain4j.internal; -import static org.assertj.core.api.Assertions.assertThat; +import com.google.gson.annotations.SerializedName; +import org.junit.jupiter.api.Test; import java.io.BufferedReader; import java.io.IOException; @@ -12,8 +13,7 @@ import java.util.List; import java.util.stream.Collectors; -import com.google.gson.annotations.SerializedName; -import org.junit.jupiter.api.Test; +import static org.assertj.core.api.Assertions.assertThat; class JsonTest { diff --git a/langchain4j-core/src/test/java/dev/langchain4j/internal/RetryUtilsTest.java b/langchain4j-core/src/test/java/dev/langchain4j/internal/RetryUtilsTest.java index cb1a66bf3cf..5c7f7286608 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/internal/RetryUtilsTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/internal/RetryUtilsTest.java @@ -6,7 +6,11 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; class RetryUtilsTest { @Test @@ -138,4 +142,4 @@ void testIllegalAttemptsReached() throws Exception { verify(mockAction, times(1)).call(); verifyNoMoreInteractions(mockAction); } -} \ No newline at end of file +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/internal/UtilsTest.java b/langchain4j-core/src/test/java/dev/langchain4j/internal/UtilsTest.java index 14feb0e3072..0f1d677a25d 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/internal/UtilsTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/internal/UtilsTest.java @@ -9,13 +9,22 @@ import java.io.IOException; import java.net.HttpURLConnection; import java.net.InetSocketAddress; -import java.util.*; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; import java.util.stream.Stream; import static dev.langchain4j.internal.Utils.quoted; import static java.util.Arrays.asList; -import static java.util.Collections.*; -import static org.assertj.core.api.Assertions.*; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.entry; @SuppressWarnings({"ObviousNullCheck", "ConstantValue"}) class UtilsTest { diff --git a/langchain4j-core/src/test/java/dev/langchain4j/internal/ValidationUtilsTest.java b/langchain4j-core/src/test/java/dev/langchain4j/internal/ValidationUtilsTest.java index 3cc4348fda3..7df26a6bfce 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/internal/ValidationUtilsTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/internal/ValidationUtilsTest.java @@ -6,9 +6,15 @@ import org.junit.jupiter.params.provider.NullSource; import org.junit.jupiter.params.provider.ValueSource; -import java.util.*; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; -import static dev.langchain4j.internal.ValidationUtils.*; +import static dev.langchain4j.internal.ValidationUtils.ensureBetween; +import static dev.langchain4j.internal.ValidationUtils.ensureEq; +import static dev.langchain4j.internal.ValidationUtils.ensureGreaterThanZero; @SuppressWarnings("ConstantConditions") class ValidationUtilsTest implements WithAssertions { @@ -191,4 +197,4 @@ public void test_ensureBetween_long() { .withMessageContaining("test must be between 0 and 1, but is: -1"); } } -} \ No newline at end of file +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/model/chat/TestStreamingResponseHandler.java b/langchain4j-core/src/test/java/dev/langchain4j/model/chat/TestStreamingResponseHandler.java index 6c163e002c1..60d68272c48 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/model/chat/TestStreamingResponseHandler.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/model/chat/TestStreamingResponseHandler.java @@ -3,9 +3,10 @@ import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.model.StreamingResponseHandler; import dev.langchain4j.model.output.Response; -import lombok.SneakyThrows; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; import static dev.langchain4j.internal.Exceptions.illegalArgument; import static java.util.concurrent.TimeUnit.SECONDS; @@ -26,8 +27,7 @@ public void onNext(String token) { public void onComplete(Response response) { String expectedTextContent = textContentBuilder.toString(); - if (response.content() instanceof AiMessage) { - AiMessage aiMessage = (AiMessage) response.content(); + if (response.content() instanceof AiMessage aiMessage) { if (aiMessage.hasToolExecutionRequests()){ assertThat(aiMessage.toolExecutionRequests().size()).isGreaterThan(0); } else { @@ -47,8 +47,14 @@ public void onError(Throwable error) { futureResponse.completeExceptionally(error); } - @SneakyThrows - public Response get() { - return futureResponse.get(30, SECONDS); + public Response get() { + try { + return futureResponse.get(30, SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } catch (ExecutionException | TimeoutException e) { + throw new RuntimeException(e); + } } } diff --git a/langchain4j-core/src/test/java/dev/langchain4j/model/embedding/DimensionAwareEmbeddingModelTest.java b/langchain4j-core/src/test/java/dev/langchain4j/model/embedding/DimensionAwareEmbeddingModelTest.java index bcc1fbda88d..ee1d0b496b9 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/model/embedding/DimensionAwareEmbeddingModelTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/model/embedding/DimensionAwareEmbeddingModelTest.java @@ -8,9 +8,7 @@ import org.assertj.core.api.WithAssertions; import org.junit.jupiter.api.Test; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.stream.Collectors; class DimensionAwareEmbeddingModelTest implements WithAssertions { diff --git a/langchain4j-core/src/test/java/dev/langchain4j/model/input/PromptTemplateTest.java b/langchain4j-core/src/test/java/dev/langchain4j/model/input/PromptTemplateTest.java index c13618317fe..7e911b8bbe7 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/model/input/PromptTemplateTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/model/input/PromptTemplateTest.java @@ -4,7 +4,12 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; -import java.time.*; +import java.time.Clock; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.ZoneOffset; import java.util.HashMap; import java.util.Map; @@ -197,4 +202,4 @@ void should_support_special_characters(String s) { // then assertThat(prompt.text()).isEqualTo("This is " + s + "."); } -} \ No newline at end of file +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/model/input/structured/StructuredPromptProcessorTest.java b/langchain4j-core/src/test/java/dev/langchain4j/model/input/structured/StructuredPromptProcessorTest.java index e1b2fae945d..542b0609d9e 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/model/input/structured/StructuredPromptProcessorTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/model/input/structured/StructuredPromptProcessorTest.java @@ -1,13 +1,14 @@ package dev.langchain4j.model.input.structured; +import dev.langchain4j.model.input.Prompt; +import org.junit.jupiter.api.Test; + +import java.util.List; + import static dev.langchain4j.model.input.structured.StructuredPromptProcessor.toPrompt; import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; -import dev.langchain4j.model.input.Prompt; -import java.util.List; -import org.junit.jupiter.api.Test; - class StructuredPromptProcessorTest { @StructuredPrompt("Hello, my name is {{name}}") diff --git a/langchain4j-core/src/test/java/dev/langchain4j/rag/DefaultRetrievalAugmentorTest.java b/langchain4j-core/src/test/java/dev/langchain4j/rag/DefaultRetrievalAugmentorTest.java index e10668f5f71..cbe5a4dc4fc 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/rag/DefaultRetrievalAugmentorTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/rag/DefaultRetrievalAugmentorTest.java @@ -30,7 +30,13 @@ import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.toList; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; class DefaultRetrievalAugmentorTest { @@ -377,4 +383,4 @@ public UserMessage inject(List contents, UserMessage userMessage) { return UserMessage.from(userMessage.text() + "\n" + joinedContents); } } -} \ No newline at end of file +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/rag/content/aggregator/DefaultContentAggregatorTest.java b/langchain4j-core/src/test/java/dev/langchain4j/rag/content/aggregator/DefaultContentAggregatorTest.java index e2c2c03542c..c08e98fbc91 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/rag/content/aggregator/DefaultContentAggregatorTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/rag/content/aggregator/DefaultContentAggregatorTest.java @@ -14,7 +14,10 @@ import java.util.stream.Stream; import static java.util.Arrays.asList; -import static java.util.Collections.*; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; import static org.assertj.core.api.Assertions.assertThat; class DefaultContentAggregatorTest { @@ -207,4 +210,4 @@ private static Stream should_return_empty_list_when_there_is_no_conte )) .build(); } -} \ No newline at end of file +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/rag/content/aggregator/ReRankingContentAggregatorTest.java b/langchain4j-core/src/test/java/dev/langchain4j/rag/content/aggregator/ReRankingContentAggregatorTest.java index a2b3d97e26b..d9c27e4f36a 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/rag/content/aggregator/ReRankingContentAggregatorTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/rag/content/aggregator/ReRankingContentAggregatorTest.java @@ -9,16 +9,25 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import java.util.*; +import java.util.Collection; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.function.Function; import java.util.stream.Stream; import static java.util.Arrays.asList; -import static java.util.Collections.*; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; class ReRankingContentAggregatorTest { @@ -225,4 +234,4 @@ private static Stream should_return_empty_list_when_there_is_no_conte )) .build(); } -} \ No newline at end of file +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/EmbeddingStoreContentRetrieverTest.java b/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/EmbeddingStoreContentRetrieverTest.java index 173aabc0b0f..f78a53ef0dc 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/EmbeddingStoreContentRetrieverTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/EmbeddingStoreContentRetrieverTest.java @@ -10,7 +10,6 @@ import dev.langchain4j.store.embedding.EmbeddingSearchResult; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.filter.Filter; -import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -19,7 +18,10 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; class EmbeddingStoreContentRetrieverTest { @@ -333,4 +335,4 @@ void should_include_implicit_display_name_in_to_string() { // then assertThat(result).contains(EmbeddingStoreContentRetriever.DEFAULT_DISPLAY_NAME); } -} \ No newline at end of file +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetrieverTest.java b/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetrieverTest.java index 718ec1f2fa0..6e3f81c4a63 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetrieverTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetrieverTest.java @@ -4,7 +4,11 @@ import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.rag.content.Content; import dev.langchain4j.rag.query.Query; -import dev.langchain4j.web.search.*; +import dev.langchain4j.web.search.WebSearchEngine; +import dev.langchain4j.web.search.WebSearchInformationResult; +import dev.langchain4j.web.search.WebSearchOrganicResult; +import dev.langchain4j.web.search.WebSearchRequest; +import dev.langchain4j.web.search.WebSearchResults; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -15,7 +19,12 @@ import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; class WebSearchContentRetrieverTest { diff --git a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreIngestorTest.java b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreIngestorTest.java index c976c1933ef..9264c56ab3f 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreIngestorTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreIngestorTest.java @@ -16,7 +16,10 @@ import static java.util.Arrays.asList; import static java.util.Collections.singletonList; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; class EmbeddingStoreIngestorTest { @@ -219,4 +222,4 @@ void should_not_split_when_no_splitter_is_specified() { assertThat(ingestionResult.tokenUsage()).isEqualTo(tokenUsage); } -} \ No newline at end of file +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithRemovalIT.java b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithRemovalIT.java index 08a545c1461..37d745a03b5 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithRemovalIT.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithRemovalIT.java @@ -4,14 +4,11 @@ import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.store.embedding.filter.Filter; -import org.awaitility.Awaitility; -import org.awaitility.core.ThrowingRunnable; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.NullAndEmptySource; import org.junit.jupiter.params.provider.ValueSource; -import java.time.Duration; import java.util.Collection; import java.util.List; diff --git a/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchResultsTest.java b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchResultsTest.java index ad5cb7fefaa..e4c48f2216e 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchResultsTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchResultsTest.java @@ -4,17 +4,12 @@ import org.junit.jupiter.api.Test; import java.net.URI; -import java.util.AbstractMap; import java.util.HashMap; import java.util.Map; -import java.util.stream.Stream; -import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; -import static java.util.stream.Collectors.toMap; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType; -import static org.mockito.ArgumentMatchers.anyList; class WebSearchResultsTest { diff --git a/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchToolTest.java b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchToolTest.java index 8ad8a04e9a9..955414b6c90 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchToolTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchToolTest.java @@ -10,8 +10,11 @@ import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; class WebSearchToolTest { diff --git a/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/DefaultMetadataStorageConfig.java b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/DefaultMetadataStorageConfig.java index eb32b8d9a77..63c68e86536 100644 --- a/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/DefaultMetadataStorageConfig.java +++ b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/DefaultMetadataStorageConfig.java @@ -1,6 +1,9 @@ package dev.langchain4j.store.embedding.pgvector; -import lombok.*; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; import lombok.experimental.Accessors; import java.util.Collections; diff --git a/langchain4j/src/test/java/dev/langchain4j/internal/TestUtils.java b/langchain4j/src/test/java/dev/langchain4j/internal/TestUtils.java index 860dac93f1b..cd09956ff52 100644 --- a/langchain4j/src/test/java/dev/langchain4j/internal/TestUtils.java +++ b/langchain4j/src/test/java/dev/langchain4j/internal/TestUtils.java @@ -4,7 +4,6 @@ import dev.langchain4j.data.message.SystemMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.model.openai.OpenAiTokenizer; -import lombok.val; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -89,7 +88,7 @@ void should_repeat_n_times() { } public static List repeat(String s, int n) { - val result = new ArrayList(); + final var result = new ArrayList(); for (int i = 0; i < n; i++) { result.add(s); }