From 08e2abbbf21bb8fb9a6878b9c7f0990c6a0df701 Mon Sep 17 00:00:00 2001 From: Thomas Vitale Date: Tue, 24 Dec 2024 15:07:55 +0100 Subject: [PATCH] Add more Sequential RAG examples --- rag/rag-sequential/rag-advanced/README.md | 36 ++++++++++++ .../thomasvitale/ai/spring/RagAdvanced.java | 7 +++ .../ai/spring/RagControllerCompression.java | 55 +++++++++++++++++++ .../ai/spring/RagControllerMemory.java | 45 +++++++++++++++ .../ai/spring/RagControllerRewrite.java | 46 ++++++++++++++++ .../ai/spring/config/HttpClientConfig.java | 27 --------- 6 files changed, 189 insertions(+), 27 deletions(-) create mode 100644 rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/RagControllerCompression.java create mode 100644 rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/RagControllerMemory.java create mode 100644 rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/RagControllerRewrite.java delete mode 100644 rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/config/HttpClientConfig.java diff --git a/rag/rag-sequential/rag-advanced/README.md b/rag/rag-sequential/rag-advanced/README.md index f61880a..fe678a2 100644 --- a/rag/rag-sequential/rag-advanced/README.md +++ b/rag/rag-sequential/rag-advanced/README.md @@ -44,6 +44,42 @@ You can also explore metrics in "Explore > Metrics" and logs in "Explore > Logs" Call the application that will use a chat model to answer your questions. +### Query Transformation: Compression + +Without compression: + +```shell +http --raw "Who are the characters going on an adventure in the North Pole?" :8080/rag/memory/007 -b --pretty none +``` + +```shell +http --raw "What places do they visit?" :8080/rag/memory/007 -b --pretty none +``` + +With compression: + +```shell +http --raw "Who are the characters going on an adventure in the North Pole?" :8080/rag/compression/007 -b --pretty none +``` + +```shell +http --raw "What places do they visit?" :8080/rag/compression/007 -b --pretty none +``` + +### Query Transformation: Rewrite + +Without rewrite: + +```shell +http --raw "Where are the main characters going on an adventure?" :8080/rag/basic -b --pretty none +``` + +With rewrite: + +```shell +http --raw "Where are the main characters going on an adventure?" :8080/rag/rewrite -b --pretty none +``` + ### Query Transformation: Translation Without translation: diff --git a/rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/RagAdvanced.java b/rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/RagAdvanced.java index 7db7583..cca07c2 100644 --- a/rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/RagAdvanced.java +++ b/rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/RagAdvanced.java @@ -2,6 +2,8 @@ import io.opentelemetry.api.OpenTelemetry; import io.opentelemetry.instrumentation.logback.appender.v1_0.OpenTelemetryAppender; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.InMemoryChatMemory; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.context.event.ApplicationReadyEvent; @@ -20,4 +22,9 @@ ApplicationListener logbackOtelAppenderInitializer(OpenTe return _ -> OpenTelemetryAppender.install(openTelemetry); } + @Bean + ChatMemory chatMemory() { + return new InMemoryChatMemory(); + } + } diff --git a/rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/RagControllerCompression.java b/rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/RagControllerCompression.java new file mode 100644 index 0000000..4945097 --- /dev/null +++ b/rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/RagControllerCompression.java @@ -0,0 +1,55 @@ +package com.thomasvitale.ai.spring; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor; +import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor; +import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer; +import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RestController; + +@RestController +public class RagControllerCompression { + + private final ChatClient chatClient; + private final MessageChatMemoryAdvisor chatMemoryAdvisor; + private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor; + + public RagControllerCompression(ChatClient.Builder chatClientBuilder, ChatMemory chatMemory, VectorStore vectorStore) { + this.chatClient = chatClientBuilder.build(); + + this.chatMemoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory) + .build(); + + var documentRetriever = VectorStoreDocumentRetriever.builder() + .vectorStore(vectorStore) + .similarityThreshold(0.50) + .build(); + + var queryTransformer = CompressionQueryTransformer.builder() + .chatClientBuilder(chatClientBuilder.build().mutate()) + .build(); + + this.retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder() + .documentRetriever(documentRetriever) + .queryTransformers(queryTransformer) + .build(); + } + + @PostMapping("/rag/compression/{conversationId}") + String rag(@RequestBody String input, @PathVariable String conversationId) { + return chatClient.prompt() + .advisors(chatMemoryAdvisor, retrievalAugmentationAdvisor) + .advisors(advisors -> advisors.param( + AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, conversationId)) + .user(input) + .call() + .content(); + } + +} diff --git a/rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/RagControllerMemory.java b/rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/RagControllerMemory.java new file mode 100644 index 0000000..075a840 --- /dev/null +++ b/rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/RagControllerMemory.java @@ -0,0 +1,45 @@ +package com.thomasvitale.ai.spring; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor; +import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor; +import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RestController; + +@RestController +class RagControllerMemory { + + private final ChatClient chatClient; + private final MessageChatMemoryAdvisor chatMemoryAdvisor; + private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor; + + RagControllerMemory(ChatClient.Builder chatClientBuilder, ChatMemory chatMemory, VectorStore vectorStore) { + this.chatClient = chatClientBuilder.build(); + this.chatMemoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory) + .build(); + this.retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder() + .documentRetriever(VectorStoreDocumentRetriever.builder() + .similarityThreshold(0.50) + .vectorStore(vectorStore) + .build()) + .build(); + } + + @PostMapping("/rag/memory/{conversationId}") + String chatWithDocument(@RequestBody String question, @PathVariable String conversationId) { + return chatClient.prompt() + .advisors(chatMemoryAdvisor, retrievalAugmentationAdvisor) + .advisors(advisors -> advisors.param( + AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, conversationId)) + .user(question) + .call() + .content(); + } + +} diff --git a/rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/RagControllerRewrite.java b/rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/RagControllerRewrite.java new file mode 100644 index 0000000..e6b7bf7 --- /dev/null +++ b/rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/RagControllerRewrite.java @@ -0,0 +1,46 @@ +package com.thomasvitale.ai.spring; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor; +import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer; +import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RestController; + +@RestController +public class RagControllerRewrite { + + private final ChatClient chatClient; + private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor; + + public RagControllerRewrite(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) { + this.chatClient = chatClientBuilder.build(); + + var documentRetriever = VectorStoreDocumentRetriever.builder() + .vectorStore(vectorStore) + .similarityThreshold(0.50) + .build(); + + var queryTransformer = RewriteQueryTransformer.builder() + .chatClientBuilder(chatClientBuilder.build().mutate()) + .targetSearchSystem("vector store") + .build(); + + this.retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder() + .documentRetriever(documentRetriever) + .queryTransformers(queryTransformer) + .build(); + } + + @PostMapping("/rag/rewrite") + String rag(@RequestBody String input) { + return chatClient.prompt() + .advisors(retrievalAugmentationAdvisor) + .user(input) + .call() + .content(); + } + +} diff --git a/rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/config/HttpClientConfig.java b/rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/config/HttpClientConfig.java deleted file mode 100644 index 598482f..0000000 --- a/rag/rag-sequential/rag-advanced/src/main/java/com/thomasvitale/ai/spring/config/HttpClientConfig.java +++ /dev/null @@ -1,27 +0,0 @@ -package com.thomasvitale.ai.spring.config; - -import org.springframework.boot.web.client.ClientHttpRequestFactories; -import org.springframework.boot.web.client.ClientHttpRequestFactorySettings; -import org.springframework.boot.web.client.RestClientCustomizer; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.http.client.BufferingClientHttpRequestFactory; - -import java.time.Duration; - -@Configuration(proxyBeanMethods = false) -public class HttpClientConfig { - -// @Bean -// RestClientCustomizer restClientCustomizer() { -// return restClientBuilder -> { -// restClientBuilder -// .requestFactory(new BufferingClientHttpRequestFactory( -// ClientHttpRequestFactories.get(ClientHttpRequestFactorySettings.DEFAULTS -// .withConnectTimeout(Duration.ofSeconds(60)) -// .withReadTimeout(Duration.ofSeconds(60)) -// ))); -// }; -// } - -}