-
Notifications
You must be signed in to change notification settings - Fork 77
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5c92b3a
commit 1f1e427
Showing
33 changed files
with
1,170 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Retrieval Augmented Generation (RAG): Advanced | ||
|
||
Ask questions about documents with LLMs via Ollama and PGVector. | ||
|
||
## Running the application | ||
|
||
The application relies on Ollama for providing LLMs. You can either run Ollama locally on your laptop, or rely on the Testcontainers support in Spring Boot to spin up an Ollama service automatically. | ||
Either way, Spring AI will take care of pulling the needed Ollama models if not already available in your instance. | ||
Furthermore, the application relies on the native Testcontainers support in Spring Boot to spin up a PostgreSQL database with the pgvector extension for embeddings. | ||
|
||
### Ollama as a native application | ||
|
||
First, make sure you have [Ollama](https://ollama.ai) installed on your laptop. | ||
|
||
Then, run the Spring Boot application. | ||
|
||
```shell | ||
./gradlew bootTestRun | ||
``` | ||
|
||
### Ollama as a dev service with Testcontainers | ||
|
||
The application relies on the native Testcontainers support in Spring Boot to spin up an Ollama service. | ||
|
||
```shell | ||
./gradlew bootTestRun --args='--spring.profiles.active=ollama-image' | ||
``` | ||
|
||
## Calling the application | ||
|
||
You can now call the application. It uses Ollama to turn the text documents into embeddings and to generate answers to your questions based on those documents (RAG pattern). | ||
This example uses [httpie](https://httpie.io) to send HTTP requests. | ||
|
||
```shell | ||
http --raw "What is Iorek's biggest dream?" :8080/chat/doc | ||
``` | ||
|
||
```shell | ||
http --raw "Who is Lucio?" :8080/chat/doc | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
// Gradle build script for the Spring Boot application in this module. | ||
// NOTE(review): no plugin versions are declared here; presumably they are supplied by the enclosing build (e.g. pluginManagement) - verify. | ||
plugins { | ||
id 'java' | ||
id 'org.springframework.boot' | ||
id 'io.spring.dependency-management' | ||
} | ||
|
||
// Maven coordinates (group and version) of this module. | ||
group = 'com.thomasvitale' | ||
version = '0.0.1-SNAPSHOT' | ||
|
||
// Build with a Java 23 toolchain. | ||
java { | ||
toolchain { | ||
languageVersion = JavaLanguageVersion.of(23) | ||
} | ||
} | ||
|
||
// Spring milestone and snapshot repositories are added next to Maven Central. | ||
// NOTE(review): presumably needed because Spring AI is consumed as a pre-release version - confirm. | ||
repositories { | ||
mavenCentral() | ||
maven { url 'https://repo.spring.io/milestone' } | ||
maven { url 'https://repo.spring.io/snapshot' } | ||
} | ||
|
||
dependencies { | ||
// Imports the Spring AI BOM; 'springAiVersion' is not defined in this script and must be provided elsewhere (e.g. gradle.properties - verify). | ||
implementation platform("org.springframework.ai:spring-ai-bom:${springAiVersion}") | ||
|
||
// Runtime dependencies: Spring Web plus the Spring AI starters for Ollama and the PGVector store. | ||
implementation 'org.springframework.boot:spring-boot-starter-web' | ||
implementation 'org.springframework.ai:spring-ai-ollama-spring-boot-starter' | ||
implementation 'org.springframework.ai:spring-ai-pgvector-store-spring-boot-starter' | ||
|
||
// Devtools are only added to the test/development classpath. | ||
testAndDevelopmentOnly 'org.springframework.boot:spring-boot-devtools' | ||
|
||
// Test dependencies, including the Testcontainers modules for Ollama and PostgreSQL. | ||
testImplementation 'org.springframework.boot:spring-boot-starter-test' | ||
testImplementation 'org.springframework.boot:spring-boot-testcontainers' | ||
testImplementation 'org.springframework.ai:spring-ai-spring-boot-testcontainers' | ||
testImplementation 'org.springframework:spring-webflux' | ||
testImplementation 'org.testcontainers:ollama' | ||
testImplementation 'org.testcontainers:postgresql' | ||
} | ||
|
||
// Run the 'test' task on the JUnit Platform. | ||
tasks.named('test') { | ||
useJUnitPlatform() | ||
} |
21 changes: 21 additions & 0 deletions
21
09-rag/rag-advanced/src/main/java/com/thomasvitale/ai/spring/ChatController.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
package com.thomasvitale.ai.spring; | ||
|
||
import org.springframework.web.bind.annotation.PostMapping; | ||
import org.springframework.web.bind.annotation.RequestBody; | ||
import org.springframework.web.bind.annotation.RestController; | ||
|
||
@RestController | ||
class ChatController { | ||
|
||
private final ChatService chatService; | ||
|
||
ChatController(ChatService chatService) { | ||
this.chatService = chatService; | ||
} | ||
|
||
@PostMapping("/chat/doc") | ||
String chatWithDocument(@RequestBody String input) { | ||
return chatService.chatWithDocument(input); | ||
} | ||
|
||
} |
33 changes: 33 additions & 0 deletions
33
09-rag/rag-advanced/src/main/java/com/thomasvitale/ai/spring/ChatService.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
package com.thomasvitale.ai.spring; | ||
|
||
import com.thomasvitale.ai.spring.rag.RetrievalAugmentationAdvisor; | ||
import com.thomasvitale.ai.spring.rag.retriever.VectorStoreDocumentRetriever; | ||
import org.springframework.ai.chat.client.ChatClient; | ||
import org.springframework.ai.document.DocumentRetriever; | ||
import org.springframework.ai.vectorstore.VectorStore; | ||
import org.springframework.stereotype.Service; | ||
|
||
@Service | ||
class ChatService { | ||
|
||
private final ChatClient chatClient; | ||
private final DocumentRetriever documentRetriever; | ||
|
||
ChatService(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) { | ||
this.chatClient = chatClientBuilder.build(); | ||
this.documentRetriever = VectorStoreDocumentRetriever.builder() | ||
.withVectorStore(vectorStore) | ||
.build(); | ||
} | ||
|
||
String chatWithDocument(String message) { | ||
return chatClient.prompt() | ||
.advisors(RetrievalAugmentationAdvisor.builder() | ||
.withDocumentRetriever(documentRetriever) | ||
.build()) | ||
.user(message) | ||
.call() | ||
.content(); | ||
} | ||
|
||
} |
54 changes: 54 additions & 0 deletions
54
09-rag/rag-advanced/src/main/java/com/thomasvitale/ai/spring/IngestionPipeline.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
package com.thomasvitale.ai.spring; | ||
|
||
import jakarta.annotation.PostConstruct; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
import org.springframework.ai.document.Document; | ||
import org.springframework.ai.reader.TextReader; | ||
import org.springframework.ai.transformer.splitter.TokenTextSplitter; | ||
import org.springframework.ai.vectorstore.VectorStore; | ||
import org.springframework.beans.factory.annotation.Value; | ||
import org.springframework.core.io.Resource; | ||
import org.springframework.stereotype.Component; | ||
|
||
import java.nio.charset.Charset; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
@Component | ||
public class IngestionPipeline { | ||
|
||
private static final Logger logger = LoggerFactory.getLogger(IngestionPipeline.class); | ||
private final VectorStore vectorStore; | ||
|
||
@Value("classpath:documents/story1.md") | ||
Resource textFile1; | ||
|
||
@Value("classpath:documents/story2.txt") | ||
Resource textFile2; | ||
|
||
public IngestionPipeline(VectorStore vectorStore) { | ||
this.vectorStore = vectorStore; | ||
} | ||
|
||
@PostConstruct | ||
public void run() { | ||
List<Document> documents = new ArrayList<>(); | ||
|
||
logger.info("Loading .md files as Documents"); | ||
var textReader1 = new TextReader(textFile1); | ||
textReader1.getCustomMetadata().put("location", "North Pole"); | ||
textReader1.setCharset(Charset.defaultCharset()); | ||
documents.addAll(textReader1.get()); | ||
|
||
logger.info("Loading .txt files as Documents"); | ||
var textReader2 = new TextReader(textFile2); | ||
textReader2.getCustomMetadata().put("location", "Italy"); | ||
textReader2.setCharset(Charset.defaultCharset()); | ||
documents.addAll(textReader2.get()); | ||
|
||
logger.info("Creating and storing Embeddings from Documents"); | ||
vectorStore.add(new TokenTextSplitter().split(documents)); | ||
} | ||
|
||
} |
13 changes: 13 additions & 0 deletions
13
09-rag/rag-advanced/src/main/java/com/thomasvitale/ai/spring/RagAdvanced.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
package com.thomasvitale.ai.spring; | ||
|
||
import org.springframework.boot.SpringApplication; | ||
import org.springframework.boot.autoconfigure.SpringBootApplication; | ||
|
||
@SpringBootApplication | ||
public class RagAdvanced { | ||
|
||
public static void main(String[] args) { | ||
SpringApplication.run(RagAdvanced.class, args); | ||
} | ||
|
||
} |
185 changes: 185 additions & 0 deletions
185
...g-advanced/src/main/java/com/thomasvitale/ai/spring/rag/RetrievalAugmentationAdvisor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
package com.thomasvitale.ai.spring.rag; | ||
|
||
import com.thomasvitale.ai.spring.rag.injector.DefaultDocumentInjector; | ||
import com.thomasvitale.ai.spring.rag.injector.DocumentInjector; | ||
import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor; | ||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; | ||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; | ||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; | ||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; | ||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; | ||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; | ||
import org.springframework.ai.chat.messages.UserMessage; | ||
import org.springframework.ai.chat.model.ChatResponse; | ||
import org.springframework.ai.chat.prompt.PromptTemplate; | ||
import org.springframework.ai.document.Document; | ||
import org.springframework.ai.document.DocumentRetriever; | ||
import org.springframework.ai.model.Content; | ||
import org.springframework.lang.Nullable; | ||
import org.springframework.util.StringUtils; | ||
import reactor.core.publisher.Flux; | ||
import reactor.core.publisher.Mono; | ||
import reactor.core.scheduler.Schedulers; | ||
|
||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.function.Predicate; | ||
import java.util.stream.Collectors; | ||
|
||
/** | ||
* An advisor that retrieves similar documents based on the user's query | ||
* and augments the query with the retrieved documents to provide context-aware responses. | ||
* This advisor implements a Retrieval-Augmented Generation (RAG) workflow. | ||
* | ||
* <p>Example usage: | ||
* <pre>{@code | ||
* DocumentRetriever documentRetriever = ...; | ||
* RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder() | ||
* .withDocumentRetriever(documentRetriever) | ||
* .build(); | ||
* String response = chatClient.prompt(query) | ||
* .advisors(ragAdvisor) | ||
* .call() | ||
* .content() | ||
* }</pre> | ||
* | ||
* The implementation is based on the built-in {@link QuestionAnswerAdvisor}. | ||
*/ | ||
public class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { | ||
|
||
private static final int DEFAULT_ORDER = 0; | ||
|
||
private static final DocumentInjector DEFAULT_DOCUMENT_INJECTOR = DefaultDocumentInjector.builder().build(); | ||
|
||
public static final String RETRIEVED_DOCUMENTS = "qa_retrieved_documents"; | ||
|
||
private final DocumentRetriever documentRetriever; | ||
|
||
private final DocumentInjector documentInjector; | ||
|
||
private final Boolean protectFromBlocking; | ||
|
||
private final int order; | ||
|
||
public RetrievalAugmentationAdvisor(DocumentRetriever documentRetriever, @Nullable DocumentInjector documentInjector, @Nullable Boolean protectFromBlocking, @Nullable Integer order) { | ||
this.documentRetriever = documentRetriever; | ||
this.documentInjector = documentInjector != null ? documentInjector : DEFAULT_DOCUMENT_INJECTOR; | ||
this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : false; | ||
this.order = order != null ? order : DEFAULT_ORDER; | ||
} | ||
|
||
@Override | ||
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { | ||
// transform query | ||
// list of queries | ||
// for each query, router, then retriever | ||
// map query -> List<Document> | ||
// aggregate documents | ||
// add context to original prompt | ||
|
||
AdvisedRequest processedAdvisedRequest = before(advisedRequest); | ||
|
||
AdvisedResponse advisedResponse = chain.nextAroundCall(processedAdvisedRequest); | ||
|
||
return after(advisedResponse); | ||
} | ||
|
||
@Override | ||
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { | ||
Flux<AdvisedResponse> advisedResponses = (this.protectFromBlocking) ? | ||
Mono.just(advisedRequest) | ||
.publishOn(Schedulers.boundedElastic()) | ||
.map(this::before) | ||
.flatMapMany(chain::nextAroundStream) | ||
: chain.nextAroundStream(before(advisedRequest)); | ||
|
||
return advisedResponses.map(ar -> { | ||
if (onFinishReason().test(ar)) { | ||
ar = after(ar); | ||
} | ||
return ar; | ||
}); | ||
} | ||
|
||
private AdvisedRequest before(AdvisedRequest request) { | ||
var context = new HashMap<>(request.adviseContext()); | ||
|
||
var userMessage = new PromptTemplate(request.userText(), request.userParams()).render(); | ||
|
||
// 1. Retrieve similar documents. | ||
List<Document> documents = documentRetriever.retrieve(userMessage); | ||
context.put(RETRIEVED_DOCUMENTS, documents); | ||
|
||
// 2. Inject documents into user message. | ||
UserMessage augmentedUserMessage = documentInjector.inject(userMessage, documents); | ||
|
||
// 3. Build advised request with augmented prompt. | ||
return AdvisedRequest.from(request) | ||
.withUserText(augmentedUserMessage.getContent()) | ||
.withAdviseContext(context) | ||
.build(); | ||
} | ||
|
||
private AdvisedResponse after(AdvisedResponse advisedResponse) { | ||
ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response()); | ||
chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, advisedResponse.adviseContext().get(RETRIEVED_DOCUMENTS)); | ||
return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext()); | ||
} | ||
|
||
private Predicate<AdvisedResponse> onFinishReason() { | ||
return (advisedResponse) -> advisedResponse.response() | ||
.getResults() | ||
.stream() | ||
.anyMatch(result -> result != null && result.getMetadata() != null | ||
&& StringUtils.hasText(result.getMetadata().getFinishReason())); | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return this.getClass().getSimpleName(); | ||
} | ||
|
||
@Override | ||
public int getOrder() { | ||
return order; | ||
} | ||
|
||
public static Builder builder() { | ||
return new Builder(); | ||
} | ||
|
||
public static class Builder { | ||
private DocumentRetriever documentRetriever; | ||
private DocumentInjector documentInjector; | ||
private Boolean protectFromBlocking; | ||
private Integer order; | ||
|
||
private Builder() {} | ||
|
||
public Builder withDocumentRetriever(DocumentRetriever documentRetriever) { | ||
this.documentRetriever = documentRetriever; | ||
return this; | ||
} | ||
|
||
public Builder withDocumentInjector(DocumentInjector documentInjector) { | ||
this.documentInjector = documentInjector; | ||
return this; | ||
} | ||
|
||
public Builder withProtectFromBlocking(Boolean protectFromBlocking) { | ||
this.protectFromBlocking = protectFromBlocking; | ||
return this; | ||
} | ||
|
||
public Builder withOrder(Integer order) { | ||
this.order = order; | ||
return this; | ||
} | ||
|
||
public RetrievalAugmentationAdvisor build() { | ||
return new RetrievalAugmentationAdvisor(documentRetriever, documentInjector, protectFromBlocking, order); | ||
} | ||
} | ||
|
||
} |
Oops, something went wrong.