Skip to content

Commit

Permalink
Add initial RAG examples
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasVitale committed Oct 20, 2024
1 parent 5c92b3a commit 1f1e427
Show file tree
Hide file tree
Showing 33 changed files with 1,170 additions and 1 deletion.
40 changes: 40 additions & 0 deletions 09-rag/rag-advanced/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Retrieval Augmented Generation (RAG): Advanced

Ask questions about documents with LLMs via Ollama and PGVector.

## Running the application

The application relies on Ollama for providing LLMs. You can either run Ollama locally on your laptop, or rely on the Testcontainers support in Spring Boot to spin up an Ollama service automatically.
Either way, Spring AI will take care of pulling the needed Ollama models if not already available in your instance.
Furthermore, the application relies on the native Testcontainers support in Spring Boot to spin up a PostgreSQL database with the pgvector extension for embeddings.

### Ollama as a native application

First, make sure you have [Ollama](https://ollama.ai) installed on your laptop.

Then, run the Spring Boot application.

```shell
./gradlew bootTestRun
```

### Ollama as a dev service with Testcontainers

The application relies on the native Testcontainers support in Spring Boot to spin up an Ollama service.

```shell
./gradlew bootTestRun --args='--spring.profiles.active=ollama-image'
```

## Calling the application

You can now call the application that will use Ollama to load text documents as embeddings and generate an answer to your questions based on those documents (RAG pattern).
This example uses [httpie](https://httpie.io) to send HTTP requests.

```shell
http --raw "What is Iorek's biggest dream?" :8080/chat/doc
```

```shell
http --raw "Who is Lucio?" :8080/chat/doc
```
41 changes: 41 additions & 0 deletions 09-rag/rag-advanced/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
plugins {
id 'java'
id 'org.springframework.boot'
id 'io.spring.dependency-management'
}

group = 'com.thomasvitale'
version = '0.0.1-SNAPSHOT'

java {
toolchain {
languageVersion = JavaLanguageVersion.of(23)
}
}

repositories {
mavenCentral()
maven { url 'https://repo.spring.io/milestone' }
maven { url 'https://repo.spring.io/snapshot' }
}

dependencies {
implementation platform("org.springframework.ai:spring-ai-bom:${springAiVersion}")

implementation 'org.springframework.boot:spring-boot-starter-web'
implementation 'org.springframework.ai:spring-ai-ollama-spring-boot-starter'
implementation 'org.springframework.ai:spring-ai-pgvector-store-spring-boot-starter'

testAndDevelopmentOnly 'org.springframework.boot:spring-boot-devtools'

testImplementation 'org.springframework.boot:spring-boot-starter-test'
testImplementation 'org.springframework.boot:spring-boot-testcontainers'
testImplementation 'org.springframework.ai:spring-ai-spring-boot-testcontainers'
testImplementation 'org.springframework:spring-webflux'
testImplementation 'org.testcontainers:ollama'
testImplementation 'org.testcontainers:postgresql'
}

tasks.named('test') {
useJUnitPlatform()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.thomasvitale.ai.spring;

import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;

@RestController
class ChatController {

private final ChatService chatService;

ChatController(ChatService chatService) {
this.chatService = chatService;
}

@PostMapping("/chat/doc")
String chatWithDocument(@RequestBody String input) {
return chatService.chatWithDocument(input);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package com.thomasvitale.ai.spring;

import com.thomasvitale.ai.spring.rag.RetrievalAugmentationAdvisor;
import com.thomasvitale.ai.spring.rag.retriever.VectorStoreDocumentRetriever;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.DocumentRetriever;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;

@Service
class ChatService {

private final ChatClient chatClient;
private final DocumentRetriever documentRetriever;

ChatService(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) {
this.chatClient = chatClientBuilder.build();
this.documentRetriever = VectorStoreDocumentRetriever.builder()
.withVectorStore(vectorStore)
.build();
}

String chatWithDocument(String message) {
return chatClient.prompt()
.advisors(RetrievalAugmentationAdvisor.builder()
.withDocumentRetriever(documentRetriever)
.build())
.user(message)
.call()
.content();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.thomasvitale.ai.spring;

import jakarta.annotation.PostConstruct;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.reader.TextReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Component;

import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;

@Component
public class IngestionPipeline {

private static final Logger logger = LoggerFactory.getLogger(IngestionPipeline.class);
private final VectorStore vectorStore;

@Value("classpath:documents/story1.md")
Resource textFile1;

@Value("classpath:documents/story2.txt")
Resource textFile2;

public IngestionPipeline(VectorStore vectorStore) {
this.vectorStore = vectorStore;
}

@PostConstruct
public void run() {
List<Document> documents = new ArrayList<>();

logger.info("Loading .md files as Documents");
var textReader1 = new TextReader(textFile1);
textReader1.getCustomMetadata().put("location", "North Pole");
textReader1.setCharset(Charset.defaultCharset());
documents.addAll(textReader1.get());

logger.info("Loading .txt files as Documents");
var textReader2 = new TextReader(textFile2);
textReader2.getCustomMetadata().put("location", "Italy");
textReader2.setCharset(Charset.defaultCharset());
documents.addAll(textReader2.get());

logger.info("Creating and storing Embeddings from Documents");
vectorStore.add(new TokenTextSplitter().split(documents));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.thomasvitale.ai.spring;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class RagAdvanced {

public static void main(String[] args) {
SpringApplication.run(RagAdvanced.class, args);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
package com.thomasvitale.ai.spring.rag;

import com.thomasvitale.ai.spring.rag.injector.DefaultDocumentInjector;
import com.thomasvitale.ai.spring.rag.injector.DocumentInjector;
import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentRetriever;
import org.springframework.ai.model.Content;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/**
* An advisor that retrieves similar documents based on the user's query
* and augments the query with the retrieved documents to provide context-aware responses.
* This advisor implements a Retrieval-Augmented Generation (RAG) workflow.
*
* <p>Example usage:
* <pre>{@code
* DocumentRetriever documentRetriever = ...;
* RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder()
* .withDocumentRetriever(documentRetriever)
* .build();
* String response = chatClient.prompt(query)
* .advisors(ragAdvisor)
* .call()
* .content()
* }</pre>
*
* The implementation is based on the built-in {@link QuestionAnswerAdvisor}.
*/
public class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {

private static final int DEFAULT_ORDER = 0;

private static final DocumentInjector DEFAULT_DOCUMENT_INJECTOR = DefaultDocumentInjector.builder().build();

public static final String RETRIEVED_DOCUMENTS = "qa_retrieved_documents";

private final DocumentRetriever documentRetriever;

private final DocumentInjector documentInjector;

private final Boolean protectFromBlocking;

private final int order;

public RetrievalAugmentationAdvisor(DocumentRetriever documentRetriever, @Nullable DocumentInjector documentInjector, @Nullable Boolean protectFromBlocking, @Nullable Integer order) {
this.documentRetriever = documentRetriever;
this.documentInjector = documentInjector != null ? documentInjector : DEFAULT_DOCUMENT_INJECTOR;
this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : false;
this.order = order != null ? order : DEFAULT_ORDER;
}

@Override
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
// transform query
// list of queries
// for each query, router, then retriever
// map query -> List<Document>
// aggregate documents
// add context to original prompt

AdvisedRequest processedAdvisedRequest = before(advisedRequest);

AdvisedResponse advisedResponse = chain.nextAroundCall(processedAdvisedRequest);

return after(advisedResponse);
}

@Override
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
Flux<AdvisedResponse> advisedResponses = (this.protectFromBlocking) ?
Mono.just(advisedRequest)
.publishOn(Schedulers.boundedElastic())
.map(this::before)
.flatMapMany(chain::nextAroundStream)
: chain.nextAroundStream(before(advisedRequest));

return advisedResponses.map(ar -> {
if (onFinishReason().test(ar)) {
ar = after(ar);
}
return ar;
});
}

private AdvisedRequest before(AdvisedRequest request) {
var context = new HashMap<>(request.adviseContext());

var userMessage = new PromptTemplate(request.userText(), request.userParams()).render();

// 1. Retrieve similar documents.
List<Document> documents = documentRetriever.retrieve(userMessage);
context.put(RETRIEVED_DOCUMENTS, documents);

// 2. Inject documents into user message.
UserMessage augmentedUserMessage = documentInjector.inject(userMessage, documents);

// 3. Build advised request with augmented prompt.
return AdvisedRequest.from(request)
.withUserText(augmentedUserMessage.getContent())
.withAdviseContext(context)
.build();
}

private AdvisedResponse after(AdvisedResponse advisedResponse) {
ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response());
chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, advisedResponse.adviseContext().get(RETRIEVED_DOCUMENTS));
return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext());
}

private Predicate<AdvisedResponse> onFinishReason() {
return (advisedResponse) -> advisedResponse.response()
.getResults()
.stream()
.anyMatch(result -> result != null && result.getMetadata() != null
&& StringUtils.hasText(result.getMetadata().getFinishReason()));
}

@Override
public String getName() {
return this.getClass().getSimpleName();
}

@Override
public int getOrder() {
return order;
}

public static Builder builder() {
return new Builder();
}

public static class Builder {
private DocumentRetriever documentRetriever;
private DocumentInjector documentInjector;
private Boolean protectFromBlocking;
private Integer order;

private Builder() {}

public Builder withDocumentRetriever(DocumentRetriever documentRetriever) {
this.documentRetriever = documentRetriever;
return this;
}

public Builder withDocumentInjector(DocumentInjector documentInjector) {
this.documentInjector = documentInjector;
return this;
}

public Builder withProtectFromBlocking(Boolean protectFromBlocking) {
this.protectFromBlocking = protectFromBlocking;
return this;
}

public Builder withOrder(Integer order) {
this.order = order;
return this;
}

public RetrievalAugmentationAdvisor build() {
return new RetrievalAugmentationAdvisor(documentRetriever, documentInjector, protectFromBlocking, order);
}
}

}
Loading

0 comments on commit 1f1e427

Please sign in to comment.