diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 00000000000..f2ea93d45f7
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,17 @@
+root = true
+
+[*]
+charset = utf-8
+end_of_line = lf
+indent_size = 4
+indent_style = space
+insert_final_newline = true
+max_line_length = 100
+tab_width = 4
+
+[*.java]
+ij_java_names_count_to_use_import_on_demand = 999
+ij_java_class_count_to_use_import_on_demand = 999
+
+[{*.yaml,*.yml}]
+indent_size = 2
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 34f5add8ba8..b4e6467d1cc 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -5,8 +5,9 @@ Thank you for investing your time and effort in contributing to our project, we
- If you want to contribute a bug fix or a new feature that isn't listed in the [issues](https://github.com/langchain4j/langchain4j/issues) yet, please open a new issue for it. We will prioritize is shortly.
- Follow [Google's Best Practices for Java Libraries](https://jlbp.dev/)
- Keep the code compatible with Java 17.
-- Avoid adding new dependencies as much as possible (new dependencies with test scope are OK). If absolutely necessary, try to use the same libraries which are already used in the project.
+- Avoid adding new dependencies as much as possible (new dependencies with test scope are OK). If absolutely necessary, try to use the same libraries which are already used in the project. Make sure you run `mvn dependency:analyze` to identify unnecessary dependencies.
- Write unit and/or integration tests for your code. This is critical: no tests, no review!
+- Make sure you run all unit tests on all modules with `mvn clean test`
- Avoid making breaking changes. Always keep backward compatibility in mind. For example, instead of removing fields/methods/etc, mark them `@Deprecated` and make sure they still work as before.
- Follow existing naming conventions.
- Avoid using Lombok in the new code, and remove it from the old code if you get a chance.
diff --git a/langchain4j-core/pom.xml b/langchain4j-core/pom.xml
index 5fd70502b80..9f811704ed5 100644
--- a/langchain4j-core/pom.xml
+++ b/langchain4j-core/pom.xml
@@ -34,12 +34,6 @@
slf4j-api
-
- org.projectlombok
- lombok
- provided
-
-
org.junit.jupiter
junit-jupiter-engine
diff --git a/langchain4j-core/src/main/java/dev/langchain4j/Experimental.java b/langchain4j-core/src/main/java/dev/langchain4j/Experimental.java
index accdff9d188..35fbb892d47 100644
--- a/langchain4j-core/src/main/java/dev/langchain4j/Experimental.java
+++ b/langchain4j-core/src/main/java/dev/langchain4j/Experimental.java
@@ -2,7 +2,9 @@
import java.lang.annotation.Target;
-import static java.lang.annotation.ElementType.*;
+import static java.lang.annotation.ElementType.CONSTRUCTOR;
+import static java.lang.annotation.ElementType.METHOD;
+import static java.lang.annotation.ElementType.TYPE;
/**
* Indicates that a class/constructor/method is experimental and might change in the future.
diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/document/Metadata.java b/langchain4j-core/src/main/java/dev/langchain4j/data/document/Metadata.java
index 8763f27f025..6c94b15644e 100644
--- a/langchain4j-core/src/main/java/dev/langchain4j/data/document/Metadata.java
+++ b/langchain4j-core/src/main/java/dev/langchain4j/data/document/Metadata.java
@@ -3,7 +3,12 @@
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingStore;
-import java.util.*;
+import java.util.HashMap;
+import java.util.LinkedHashSet;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.UUID;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.Exceptions.runtime;
@@ -67,8 +72,8 @@ public Metadata(Map metadata) {
validate(key, value);
if (!SUPPORTED_VALUE_TYPES.contains(value.getClass())) {
throw illegalArgument("The metadata key '%s' has the value '%s', which is of the unsupported type '%s'. " +
- "Currently, the supported types are: %s",
- key, value, value.getClass().getName(), SUPPORTED_VALUE_TYPES
+ "Currently, the supported types are: %s",
+ key, value, value.getClass().getName(), SUPPORTED_VALUE_TYPES
);
}
});
@@ -116,7 +121,7 @@ public String getString(String key) {
}
throw runtime("Metadata entry with the key '%s' has a value of '%s' and type '%s'. " +
- "It cannot be returned as a String.", key, value, value.getClass().getName());
+ "It cannot be returned as a String.", key, value, value.getClass().getName());
}
/**
@@ -140,7 +145,7 @@ public UUID getUUID(String key) {
}
throw runtime("Metadata entry with the key '%s' has a value of '%s' and type '%s'. " +
- "It cannot be returned as a UUID.", key, value, value.getClass().getName());
+ "It cannot be returned as a UUID.", key, value, value.getClass().getName());
}
/**
@@ -172,7 +177,7 @@ public Integer getInteger(String key) {
}
throw runtime("Metadata entry with the key '%s' has a value of '%s' and type '%s'. " +
- "It cannot be returned as an Integer.", key, value, value.getClass().getName());
+ "It cannot be returned as an Integer.", key, value, value.getClass().getName());
}
/**
@@ -204,7 +209,7 @@ public Long getLong(String key) {
}
throw runtime("Metadata entry with the key '%s' has a value of '%s' and type '%s'. " +
- "It cannot be returned as a Long.", key, value, value.getClass().getName());
+ "It cannot be returned as a Long.", key, value, value.getClass().getName());
}
/**
@@ -236,7 +241,7 @@ public Float getFloat(String key) {
}
throw runtime("Metadata entry with the key '%s' has a value of '%s' and type '%s'. " +
- "It cannot be returned as a Float.", key, value, value.getClass().getName());
+ "It cannot be returned as a Float.", key, value, value.getClass().getName());
}
/**
@@ -268,7 +273,7 @@ public Double getDouble(String key) {
}
throw runtime("Metadata entry with the key '%s' has a value of '%s' and type '%s'. " +
- "It cannot be returned as a Double.", key, value, value.getClass().getName());
+ "It cannot be returned as a Double.", key, value, value.getClass().getName());
}
/**
@@ -449,8 +454,8 @@ public int hashCode() {
@Override
public String toString() {
return "Metadata {" +
- " metadata = " + metadata +
- " }";
+ " metadata = " + metadata +
+ " }";
}
/**
diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/message/AiMessage.java b/langchain4j-core/src/main/java/dev/langchain4j/data/message/AiMessage.java
index 65d392c5998..6ff573f88b4 100644
--- a/langchain4j-core/src/main/java/dev/langchain4j/data/message/AiMessage.java
+++ b/langchain4j-core/src/main/java/dev/langchain4j/data/message/AiMessage.java
@@ -8,7 +8,9 @@
import static dev.langchain4j.data.message.ChatMessageType.AI;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.Utils.quoted;
-import static dev.langchain4j.internal.ValidationUtils.*;
+import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
+import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
+import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static java.util.Arrays.asList;
/**
@@ -90,7 +92,7 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
AiMessage that = (AiMessage) o;
return Objects.equals(this.text, that.text)
- && Objects.equals(this.toolExecutionRequests, that.toolExecutionRequests);
+ && Objects.equals(this.toolExecutionRequests, that.toolExecutionRequests);
}
@Override
@@ -101,9 +103,9 @@ public int hashCode() {
@Override
public String toString() {
return "AiMessage {" +
- " text = " + quoted(text) +
- " toolExecutionRequests = " + toolExecutionRequests +
- " }";
+ " text = " + quoted(text) +
+ " toolExecutionRequests = " + toolExecutionRequests +
+ " }";
}
/**
@@ -139,7 +141,7 @@ public static AiMessage from(List toolExecutionRequests) {
/**
* Create a new {@link AiMessage} with the given text and tool execution requests.
*
- * @param text the text of the message.
+ * @param text the text of the message.
* @param toolExecutionRequests the tool execution requests of the message.
* @return the new {@link AiMessage}.
*/
@@ -180,7 +182,7 @@ public static AiMessage aiMessage(List toolExecutionReques
/**
* Create a new {@link AiMessage} with the given text and tool execution requests.
*
- * @param text the text of the message.
+ * @param text the text of the message.
* @param toolExecutionRequests the tool execution requests of the message.
* @return the new {@link AiMessage}.
*/
diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/message/ChatMessageDeserializer.java b/langchain4j-core/src/main/java/dev/langchain4j/data/message/ChatMessageDeserializer.java
index 7853fb9e381..9d488fd567b 100644
--- a/langchain4j-core/src/main/java/dev/langchain4j/data/message/ChatMessageDeserializer.java
+++ b/langchain4j-core/src/main/java/dev/langchain4j/data/message/ChatMessageDeserializer.java
@@ -1,8 +1,9 @@
package dev.langchain4j.data.message;
-import static dev.langchain4j.data.message.ChatMessageSerializer.CODEC;
import java.util.List;
+import static dev.langchain4j.data.message.ChatMessageSerializer.CODEC;
+
/**
* A deserializer for {@link ChatMessage} objects.
*/
diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonChatMessageAdapter.java b/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonChatMessageAdapter.java
index 477a9308ffb..1fab59b105f 100644
--- a/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonChatMessageAdapter.java
+++ b/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonChatMessageAdapter.java
@@ -1,19 +1,26 @@
package dev.langchain4j.data.message;
-import com.google.gson.*;
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.JsonDeserializationContext;
+import com.google.gson.JsonDeserializer;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonSerializationContext;
+import com.google.gson.JsonSerializer;
import java.lang.reflect.Type;
class GsonChatMessageAdapter implements JsonDeserializer, JsonSerializer {
private static final Gson GSON = new GsonBuilder()
- .registerTypeAdapter(Content.class, new GsonContentAdapter())
- .registerTypeAdapter(TextContent.class, new GsonContentAdapter())
- .registerTypeAdapter(ImageContent.class, new GsonContentAdapter())
- .registerTypeAdapter(AudioContent.class, new GsonContentAdapter())
- .registerTypeAdapter(VideoContent.class, new GsonContentAdapter())
- .registerTypeAdapter(PdfFileContent.class, new GsonContentAdapter())
- .create();
+ .registerTypeAdapter(Content.class, new GsonContentAdapter())
+ .registerTypeAdapter(TextContent.class, new GsonContentAdapter())
+ .registerTypeAdapter(ImageContent.class, new GsonContentAdapter())
+ .registerTypeAdapter(AudioContent.class, new GsonContentAdapter())
+ .registerTypeAdapter(VideoContent.class, new GsonContentAdapter())
+ .registerTypeAdapter(PdfFileContent.class, new GsonContentAdapter())
+ .create();
private static final String CHAT_MESSAGE_TYPE = "type"; // do not change, will break backward compatibility!
@@ -35,4 +42,4 @@ public ChatMessage deserialize(JsonElement messageJsonElement, Type ignored, Jso
}
return chatMessage;
}
-}
\ No newline at end of file
+}
diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonChatMessageJsonCodec.java b/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonChatMessageJsonCodec.java
index 94f85976a9c..5955cce09d7 100644
--- a/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonChatMessageJsonCodec.java
+++ b/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonChatMessageJsonCodec.java
@@ -1,12 +1,14 @@
package dev.langchain4j.data.message;
-import static java.util.Collections.emptyList;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.reflect.TypeToken;
+
import java.lang.reflect.Type;
import java.util.List;
+import static java.util.Collections.emptyList;
+
/**
* A codec for serializing and deserializing {@link ChatMessage} objects to and from JSON.
*/
diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonContentAdapter.java b/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonContentAdapter.java
index f1d886599c8..66e3135d403 100644
--- a/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonContentAdapter.java
+++ b/langchain4j-core/src/main/java/dev/langchain4j/data/message/GsonContentAdapter.java
@@ -1,6 +1,12 @@
package dev.langchain4j.data.message;
-import com.google.gson.*;
+import com.google.gson.Gson;
+import com.google.gson.JsonDeserializationContext;
+import com.google.gson.JsonDeserializer;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonSerializationContext;
+import com.google.gson.JsonSerializer;
import java.lang.reflect.Type;
@@ -23,4 +29,4 @@ public Content deserialize(JsonElement contentJsonElement, Type ignored, JsonDes
ContentType contentType = ContentType.valueOf(contentTypeString);
return GSON.fromJson(contentJsonElement, contentType.getContentClass());
}
-}
\ No newline at end of file
+}
diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelRequest.java b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelRequest.java
index a4f766eca66..ff747c55f47 100644
--- a/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelRequest.java
+++ b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelRequest.java
@@ -5,7 +5,6 @@
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
-import lombok.Builder;
import java.util.List;
@@ -25,7 +24,6 @@ public class ChatModelRequest {
private final List messages;
private final List toolSpecifications;
- @Builder
public ChatModelRequest(String model,
Double temperature,
Double topP,
@@ -40,6 +38,10 @@ public ChatModelRequest(String model,
this.toolSpecifications = copyIfNotNull(toolSpecifications);
}
+ public static ChatModelRequestBuilder builder() {
+ return new ChatModelRequestBuilder();
+ }
+
public String model() {
return model;
}
@@ -63,4 +65,54 @@ public List messages() {
public List toolSpecifications() {
return toolSpecifications;
}
+
+ public static class ChatModelRequestBuilder {
+ private String model;
+ private Double temperature;
+ private Double topP;
+ private Integer maxTokens;
+ private List messages;
+ private List toolSpecifications;
+
+ ChatModelRequestBuilder() {
+ }
+
+ public ChatModelRequestBuilder model(String model) {
+ this.model = model;
+ return this;
+ }
+
+ public ChatModelRequestBuilder temperature(Double temperature) {
+ this.temperature = temperature;
+ return this;
+ }
+
+ public ChatModelRequestBuilder topP(Double topP) {
+ this.topP = topP;
+ return this;
+ }
+
+ public ChatModelRequestBuilder maxTokens(Integer maxTokens) {
+ this.maxTokens = maxTokens;
+ return this;
+ }
+
+ public ChatModelRequestBuilder messages(List messages) {
+ this.messages = messages;
+ return this;
+ }
+
+ public ChatModelRequestBuilder toolSpecifications(List toolSpecifications) {
+ this.toolSpecifications = toolSpecifications;
+ return this;
+ }
+
+ public ChatModelRequest build() {
+ return new ChatModelRequest(this.model, this.temperature, this.topP, this.maxTokens, this.messages, this.toolSpecifications);
+ }
+
+ public String toString() {
+ return "ChatModelRequest.ChatModelRequestBuilder(model=" + this.model + ", temperature=" + this.temperature + ", topP=" + this.topP + ", maxTokens=" + this.maxTokens + ", messages=" + this.messages + ", toolSpecifications=" + this.toolSpecifications + ")";
+ }
+ }
}
diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelResponse.java b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelResponse.java
index 05397668d5d..00446ed17d1 100644
--- a/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelResponse.java
+++ b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelResponse.java
@@ -6,7 +6,6 @@
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;
-import lombok.Builder;
/**
* A response from the {@link ChatLanguageModel} or {@link StreamingChatLanguageModel},
@@ -21,7 +20,6 @@ public class ChatModelResponse {
private final FinishReason finishReason;
private final AiMessage aiMessage;
- @Builder
public ChatModelResponse(String id,
String model,
TokenUsage tokenUsage,
@@ -34,6 +32,10 @@ public ChatModelResponse(String id,
this.aiMessage = aiMessage;
}
+ public static ChatModelResponseBuilder builder() {
+ return new ChatModelResponseBuilder();
+ }
+
public String id() {
return id;
}
@@ -53,4 +55,48 @@ public FinishReason finishReason() {
public AiMessage aiMessage() {
return aiMessage;
}
+
+ public static class ChatModelResponseBuilder {
+ private String id;
+ private String model;
+ private TokenUsage tokenUsage;
+ private FinishReason finishReason;
+ private AiMessage aiMessage;
+
+ ChatModelResponseBuilder() {
+ }
+
+ public ChatModelResponseBuilder id(String id) {
+ this.id = id;
+ return this;
+ }
+
+ public ChatModelResponseBuilder model(String model) {
+ this.model = model;
+ return this;
+ }
+
+ public ChatModelResponseBuilder tokenUsage(TokenUsage tokenUsage) {
+ this.tokenUsage = tokenUsage;
+ return this;
+ }
+
+ public ChatModelResponseBuilder finishReason(FinishReason finishReason) {
+ this.finishReason = finishReason;
+ return this;
+ }
+
+ public ChatModelResponseBuilder aiMessage(AiMessage aiMessage) {
+ this.aiMessage = aiMessage;
+ return this;
+ }
+
+ public ChatModelResponse build() {
+ return new ChatModelResponse(this.id, this.model, this.tokenUsage, this.finishReason, this.aiMessage);
+ }
+
+ public String toString() {
+ return "ChatModelResponse.ChatModelResponseBuilder(id=" + this.id + ", model=" + this.model + ", tokenUsage=" + this.tokenUsage + ", finishReason=" + this.finishReason + ", aiMessage=" + this.aiMessage + ")";
+ }
+ }
}
diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/embedding/DisabledEmbeddingModel.java b/langchain4j-core/src/main/java/dev/langchain4j/model/embedding/DisabledEmbeddingModel.java
index 9412afcbe4b..360fb829a11 100644
--- a/langchain4j-core/src/main/java/dev/langchain4j/model/embedding/DisabledEmbeddingModel.java
+++ b/langchain4j-core/src/main/java/dev/langchain4j/model/embedding/DisabledEmbeddingModel.java
@@ -6,7 +6,6 @@
import dev.langchain4j.model.output.Response;
import java.util.List;
-import java.util.Map;
/**
* An {@link EmbeddingModel} which throws a {@link ModelDisabledException} for all of its methods
diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java b/langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java
index bc55ad9b8bd..14be9f791f0 100644
--- a/langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java
+++ b/langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java
@@ -2,6 +2,7 @@
import dev.langchain4j.data.image.Image;
import dev.langchain4j.model.output.Response;
+
import java.util.List;
/**
diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/input/structured/DefaultStructuredPromptFactory.java b/langchain4j-core/src/main/java/dev/langchain4j/model/input/structured/DefaultStructuredPromptFactory.java
index fe08fa8d7e6..55cbdf821a0 100644
--- a/langchain4j-core/src/main/java/dev/langchain4j/model/input/structured/DefaultStructuredPromptFactory.java
+++ b/langchain4j-core/src/main/java/dev/langchain4j/model/input/structured/DefaultStructuredPromptFactory.java
@@ -7,6 +7,7 @@
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.spi.prompt.structured.StructuredPromptFactory;
+
import java.util.Map;
/**
diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/AugmentationResult.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/AugmentationResult.java
index 4f70e6a9935..47bffa263a1 100644
--- a/langchain4j-core/src/main/java/dev/langchain4j/rag/AugmentationResult.java
+++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/AugmentationResult.java
@@ -2,7 +2,6 @@
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.rag.content.Content;
-import lombok.Builder;
import java.util.List;
@@ -24,12 +23,15 @@ public class AugmentationResult {
*/
private final List contents;
- @Builder
public AugmentationResult(ChatMessage chatMessage, List contents) {
this.chatMessage = ensureNotNull(chatMessage, "chatMessage");
this.contents = copyIfNotNull(contents);
}
+ public static AugmentationResultBuilder builder() {
+ return new AugmentationResultBuilder();
+ }
+
public ChatMessage chatMessage() {
return chatMessage;
}
@@ -37,4 +39,30 @@ public ChatMessage chatMessage() {
public List contents() {
return contents;
}
+
+ public static class AugmentationResultBuilder {
+ private ChatMessage chatMessage;
+ private List contents;
+
+ AugmentationResultBuilder() {
+ }
+
+ public AugmentationResultBuilder chatMessage(ChatMessage chatMessage) {
+ this.chatMessage = chatMessage;
+ return this;
+ }
+
+ public AugmentationResultBuilder contents(List contents) {
+ this.contents = contents;
+ return this;
+ }
+
+ public AugmentationResult build() {
+ return new AugmentationResult(this.chatMessage, this.contents);
+ }
+
+ public String toString() {
+ return "AugmentationResult.AugmentationResultBuilder(chatMessage=" + this.chatMessage + ", contents=" + this.contents + ")";
+ }
+ }
}
diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java
index db228db14a1..0604332c166 100644
--- a/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java
+++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java
@@ -14,22 +14,30 @@
import dev.langchain4j.rag.query.router.QueryRouter;
import dev.langchain4j.rag.query.transformer.DefaultQueryTransformer;
import dev.langchain4j.rag.query.transformer.QueryTransformer;
-import lombok.Builder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Collection;
import java.util.List;
import java.util.Map;
-import java.util.concurrent.*;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.Executor;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.SynchronousQueue;
+import java.util.concurrent.ThreadPoolExecutor;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
-import static java.util.Collections.*;
+import static java.util.Collections.emptyMap;
+import static java.util.Collections.singletonList;
+import static java.util.Collections.singletonMap;
import static java.util.concurrent.CompletableFuture.allOf;
import static java.util.concurrent.CompletableFuture.supplyAsync;
import static java.util.concurrent.TimeUnit.SECONDS;
-import static java.util.stream.Collectors.*;
+import static java.util.stream.Collectors.joining;
+import static java.util.stream.Collectors.toMap;
/**
* The default implementation of {@link RetrievalAugmentor} intended to be suitable for the majority of use cases.
@@ -109,7 +117,6 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
private final ContentInjector contentInjector;
private final Executor executor;
- @Builder
public DefaultRetrievalAugmentor(QueryTransformer queryTransformer,
QueryRouter queryRouter,
ContentAggregator contentAggregator,
@@ -124,9 +131,9 @@ public DefaultRetrievalAugmentor(QueryTransformer queryTransformer,
private static ExecutorService createDefaultExecutor() {
return new ThreadPoolExecutor(
- 0, Integer.MAX_VALUE,
- 1, SECONDS,
- new SynchronousQueue<>()
+ 0, Integer.MAX_VALUE,
+ 1, SECONDS,
+ new SynchronousQueue<>()
);
}
@@ -160,9 +167,9 @@ public AugmentationResult augment(AugmentationRequest augmentationRequest) {
log(augmentedChatMessage);
return AugmentationResult.builder()
- .chatMessage(augmentedChatMessage)
- .contents(contents)
- .build();
+ .chatMessage(augmentedChatMessage)
+ .contents(contents)
+ .build();
}
private Map>> process(Collection queries) {
@@ -183,13 +190,13 @@ private Map>> process(Collection queries)
Map>>> queryToFutureContents = new ConcurrentHashMap<>();
queries.forEach(query -> {
CompletableFuture>> futureContents =
- supplyAsync(() -> {
- Collection retrievers = queryRouter.route(query);
- log(query, retrievers);
- return retrievers;
- },
- executor
- ).thenCompose(retrievers -> retrieveFromAll(retrievers, query));
+ supplyAsync(() -> {
+ Collection retrievers = queryRouter.route(query);
+ log(query, retrievers);
+ return retrievers;
+ },
+ executor
+ ).thenCompose(retrievers -> retrieveFromAll(retrievers, query));
queryToFutureContents.put(query, futureContents);
});
return join(queryToFutureContents);
@@ -201,15 +208,14 @@ private Map>> process(Collection queries)
private CompletableFuture>> retrieveFromAll(Collection retrievers,
Query query) {
List>> futureContents = retrievers.stream()
- .map(retriever -> supplyAsync(() -> retrieve(retriever, query), executor))
- .collect(toList());
+ .map(retriever -> supplyAsync(() -> retrieve(retriever, query), executor))
+ .toList();
return allOf(futureContents.toArray(new CompletableFuture[0]))
- .thenApply(ignored ->
- futureContents.stream()
- .map(CompletableFuture::join)
- .collect(toList())
- );
+ .thenApply(ignored ->
+ futureContents.stream()
+ .map(CompletableFuture::join)
+ .toList());
}
private static List retrieve(ContentRetriever retriever, Query query) {
@@ -219,15 +225,15 @@ private static List retrieve(ContentRetriever retriever, Query query) {
}
private static Map>> join(
- Map>>> queryToFutureContents) {
+ Map>>> queryToFutureContents) {
return allOf(queryToFutureContents.values().toArray(new CompletableFuture[0]))
- .thenApply(ignored ->
- queryToFutureContents.entrySet().stream()
- .collect(toMap(
- Map.Entry::getKey,
- entry -> entry.getValue().join()
- ))
- ).join();
+ .thenApply(ignored ->
+ queryToFutureContents.entrySet().stream()
+ .collect(toMap(
+ Map.Entry::getKey,
+ entry -> entry.getValue().join()
+ ))
+ ).join();
}
private static void logQueries(Query originalQuery, Collection queries) {
@@ -235,14 +241,14 @@ private static void logQueries(Query originalQuery, Collection queries) {
Query transformedQuery = queries.iterator().next();
if (!transformedQuery.equals(originalQuery)) {
log.debug("Transformed original query '{}' into '{}'",
- originalQuery.text(), transformedQuery.text());
+ originalQuery.text(), transformedQuery.text());
}
- } else {
+ } else if (log.isDebugEnabled()){
log.debug("Transformed original query '{}' into the following queries:\n{}",
- originalQuery.text(), queries.stream()
- .map(Query::text)
- .map(query -> "- '" + query + "'")
- .collect(joining("\n")));
+ originalQuery.text(), queries.stream()
+ .map(Query::text)
+ .map(query -> "- '" + query + "'")
+ .collect(joining("\n")));
}
}
@@ -250,27 +256,40 @@ private static void log(Query query, Collection retrievers) {
// TODO use retriever id
if (retrievers.size() == 1) {
log.debug("Routing query '{}' to the following retriever: {}",
- query.text(), retrievers.iterator().next());
- } else {
+ query.text(), retrievers.iterator().next());
+ } else if (log.isDebugEnabled()) {
log.debug("Routing query '{}' to the following retrievers:\n{}",
- query.text(), retrievers.stream()
- .map(retriever -> "- " + retriever.toString())
- .collect(joining("\n")));
+ query.text(), retrievers.stream()
+ .map(retriever -> "- " + retriever.toString())
+ .collect(joining("\n")));
}
}
private static void log(Query query, ContentRetriever retriever, List contents) {
// TODO use retriever id
log.debug("Retrieved {} contents using query '{}' and retriever '{}'",
- contents.size(), query.text(), retriever);
+ contents.size(), query.text(), retriever);
+
+ if (!log.isTraceEnabled()) {
+ return;
+ }
- if (contents.size() > 0) {
+ if (!contents.isEmpty()) {
+ final var contentsSting = contents.stream()
+ .map(Content::textSegment)
+ .map(segment -> "- " + escapeNewlines(segment.text()))
+ .collect(joining("\n"));
log.trace("Retrieved {} contents using query '{}' and retriever '{}':\n{}",
- contents.size(), query.text(), retriever, contents.stream()
- .map(Content::textSegment)
- .map(segment -> "- " + escapeNewlines(segment.text()))
- .collect(joining("\n")));
+ contents.size(),
+ query.text(),
+ retriever.getClass().getName(),
+ contentsSting);
+ } else {
+ log.trace("Retrieved 0 contents using query '{}' and retriever '{}'",
+ query.text(),
+ retriever.getClass().getName());
}
+
}
private static void log(Map>> queryToContents, List contents) {
@@ -287,15 +306,21 @@ private static void log(Map>> queryToContents, L
log.debug("Aggregated {} content(s) into {}", contentCount, contents.size());
- log.trace("Aggregated {} content(s) into:\n{}",
+ if (log.isTraceEnabled()) {
+ log.trace("Aggregated {} content(s) into:\n{}",
contentCount, contents.stream()
- .map(Content::textSegment)
- .map(segment -> "- " + escapeNewlines(segment.text()))
- .collect(joining("\n")));
+ .map(Content::textSegment)
+ .map(segment -> "- " + escapeNewlines(segment.text()))
+ .collect(joining("\n")));
+ }
}
private static void log(ChatMessage augmentedChatMessage) {
- log.trace("Augmented chat message: {}", escapeNewlines(augmentedChatMessage.text()));
+ if (log.isTraceEnabled()) {
+ log.trace("Augmented chat message: {}",
+ escapeNewlines(augmentedChatMessage.text())
+ );
+ }
}
private static String escapeNewlines(String text) {
@@ -308,9 +333,51 @@ public static DefaultRetrievalAugmentorBuilder builder() {
public static class DefaultRetrievalAugmentorBuilder {
+ private QueryTransformer queryTransformer;
+ private QueryRouter queryRouter;
+ private ContentAggregator contentAggregator;
+ private ContentInjector contentInjector;
+ private Executor executor;
+
+ DefaultRetrievalAugmentorBuilder() {
+ }
+
public DefaultRetrievalAugmentorBuilder contentRetriever(ContentRetriever contentRetriever) {
this.queryRouter = new DefaultQueryRouter(ensureNotNull(contentRetriever, "contentRetriever"));
return this;
}
+
+ public DefaultRetrievalAugmentorBuilder queryTransformer(QueryTransformer queryTransformer) {
+ this.queryTransformer = queryTransformer;
+ return this;
+ }
+
+ public DefaultRetrievalAugmentorBuilder queryRouter(QueryRouter queryRouter) {
+ this.queryRouter = queryRouter;
+ return this;
+ }
+
+ public DefaultRetrievalAugmentorBuilder contentAggregator(ContentAggregator contentAggregator) {
+ this.contentAggregator = contentAggregator;
+ return this;
+ }
+
+ public DefaultRetrievalAugmentorBuilder contentInjector(ContentInjector contentInjector) {
+ this.contentInjector = contentInjector;
+ return this;
+ }
+
+ public DefaultRetrievalAugmentorBuilder executor(Executor executor) {
+ this.executor = executor;
+ return this;
+ }
+
+ public DefaultRetrievalAugmentor build() {
+ return new DefaultRetrievalAugmentor(this.queryTransformer, this.queryRouter, this.contentAggregator, this.contentInjector, this.executor);
+ }
+
+ public String toString() {
+ return "DefaultRetrievalAugmentor.DefaultRetrievalAugmentorBuilder(queryTransformer=" + this.queryTransformer + ", queryRouter=" + this.queryRouter + ", contentAggregator=" + this.contentAggregator + ", contentInjector=" + this.contentInjector + ", executor=" + this.executor + ")";
+ }
}
}
diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/aggregator/ReRankingContentAggregator.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/aggregator/ReRankingContentAggregator.java
index ece58665362..72bf1263d12 100644
--- a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/aggregator/ReRankingContentAggregator.java
+++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/aggregator/ReRankingContentAggregator.java
@@ -5,9 +5,12 @@
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.rag.query.transformer.ExpandingQueryTransformer;
-import lombok.Builder;
-import java.util.*;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
import java.util.function.Function;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
@@ -64,7 +67,6 @@ public ReRankingContentAggregator(ScoringModel scoringModel) {
this(scoringModel, DEFAULT_QUERY_SELECTOR, null);
}
- @Builder
public ReRankingContentAggregator(ScoringModel scoringModel,
Function