Skip to content

Commit

Permalink
Update RAG, observability and OpenAI examples
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasVitale committed Nov 14, 2024
1 parent 59df8cd commit 9007ba5
Show file tree
Hide file tree
Showing 32 changed files with 51 additions and 406 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.retrieval.source.VectorStoreDocumentRetriever;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import org.springframework.ai.converter.ListOutputConverter;
import org.springframework.ai.converter.MapOutputConverter;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.ResponseFormat;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
Expand Down Expand Up @@ -38,7 +38,7 @@ ArtistInfo chatBeanOutput(@RequestBody MusicQuestion question) {
.param("instrument", question.instrument())
)
.options(OpenAiChatOptions.builder()
.withResponseFormat(new OpenAiApi.ChatCompletionRequest.ResponseFormat(OpenAiApi.ChatCompletionRequest.ResponseFormat.Type.JSON_OBJECT))
.withResponseFormat(new ResponseFormat(ResponseFormat.Type.JSON_OBJECT, null))
.build())
.call()
.entity(ArtistInfo.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.springframework.ai.converter.MapOutputConverter;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.ResponseFormat;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
Expand Down Expand Up @@ -44,7 +45,7 @@ ArtistInfo chatBeanOutput(@RequestBody MusicQuestion question) {
"genre", question.genre(),
"format", outputConverter.getFormat());
var prompt = userPromptTemplate.create(model, OpenAiChatOptions.builder()
.withResponseFormat(new OpenAiApi.ChatCompletionRequest.ResponseFormat(OpenAiApi.ChatCompletionRequest.ResponseFormat.Type.JSON_OBJECT))
.withResponseFormat(new ResponseFormat(ResponseFormat.Type.JSON_OBJECT, null))
.build());

var chatResponse = chatModel.call(prompt);
Expand Down Expand Up @@ -92,7 +93,7 @@ ArtistInfoVariant chatJsonOutput(@RequestBody MusicQuestion question) {
Map<String,Object> model = Map.of("instrument", question.instrument(), "genre", question.genre());
var prompt = userPromptTemplate.create(model, OpenAiChatOptions.builder()
.withModel(OpenAiApi.ChatModel.GPT_4_O.getValue())
.withResponseFormat(new OpenAiApi.ChatCompletionRequest.ResponseFormat(OpenAiApi.ChatCompletionRequest.ResponseFormat.Type.JSON_SCHEMA, outputConverter.getJsonSchema()))
.withResponseFormat(new ResponseFormat(ResponseFormat.Type.JSON_SCHEMA, outputConverter.getJsonSchema()))
.build());

var chatResponse = chatModel.call(prompt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import com.thomasvitale.ai.spring.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.rag.retrieval.source.VectorStoreDocumentRetriever;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import com.thomasvitale.ai.spring.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.rag.augmentation.ContextualQueryAugmentor;
import org.springframework.ai.rag.retrieval.source.VectorStoreDocumentRetriever;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package com.thomasvitale.ai.spring;

import com.thomasvitale.ai.spring.advisor.RetrievalAugmentationAdvisor;
import com.thomasvitale.ai.spring.rag.preretrieval.query.expansion.MultiQueryExpander;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.rag.retrieval.source.VectorStoreDocumentRetriever;
import org.springframework.ai.rag.analysis.query.expansion.MultiQueryExpander;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package com.thomasvitale.ai.spring;

import com.thomasvitale.ai.spring.advisor.RetrievalAugmentationAdvisor;
import com.thomasvitale.ai.spring.rag.preretrieval.query.expansion.MultiQueryExpander;
import com.thomasvitale.ai.spring.rag.preretrieval.query.transformation.TranslationQueryTransformer;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.rag.retrieval.source.VectorStoreDocumentRetriever;
import org.springframework.ai.rag.analysis.query.expansion.MultiQueryExpander;
import org.springframework.ai.rag.analysis.query.transformation.TranslationQueryTransformer;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import com.thomasvitale.ai.spring.advisor.RetrievalAugmentationAdvisor;
import com.thomasvitale.ai.spring.rag.orchestration.routing.QueryRouter;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.rag.retrieval.source.DocumentRetriever;
import org.springframework.ai.rag.retrieval.source.VectorStoreDocumentRetriever;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
import org.springframework.web.bind.annotation.PostMapping;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package com.thomasvitale.ai.spring;

import com.thomasvitale.ai.spring.advisor.RetrievalAugmentationAdvisor;
import com.thomasvitale.ai.spring.rag.preretrieval.query.transformation.TranslationQueryTransformer;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.rag.retrieval.source.VectorStoreDocumentRetriever;
import org.springframework.ai.rag.analysis.query.transformation.TranslationQueryTransformer;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
package com.thomasvitale.ai.spring.advisor;

import com.thomasvitale.ai.spring.rag.analysis.query.expansion.IdentityQueryExpander;
import com.thomasvitale.ai.spring.rag.orchestration.routing.AllRetrieversQueryRouter;
import com.thomasvitale.ai.spring.rag.orchestration.routing.QueryRouter;
import com.thomasvitale.ai.spring.rag.preretrieval.query.expansion.IdentityQueryExpander;
import com.thomasvitale.ai.spring.rag.preretrieval.query.expansion.QueryExpander;
import com.thomasvitale.ai.spring.rag.preretrieval.query.transformation.IdentityQueryTransformer;
import com.thomasvitale.ai.spring.rag.preretrieval.query.transformation.QueryTransformer;
import com.thomasvitale.ai.spring.rag.retrieval.combination.DocumentCombiner;
import com.thomasvitale.ai.spring.rag.retrieval.combination.MergeDocumentCombiner;
import com.thomasvitale.ai.spring.rag.retrieval.fusion.DocumentFuser;
import com.thomasvitale.ai.spring.rag.retrieval.fusion.MergeDocumentFuser;
import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
Expand All @@ -19,9 +16,11 @@
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.analysis.query.expansion.QueryExpander;
import org.springframework.ai.rag.analysis.query.transformation.QueryTransformer;
import org.springframework.ai.rag.augmentation.ContextualQueryAugmentor;
import org.springframework.ai.rag.augmentation.QueryAugmentor;
import org.springframework.ai.rag.retrieval.source.DocumentRetriever;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
Expand Down Expand Up @@ -66,19 +65,19 @@ public class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAr

private final QueryRouter queryRouter;

private final DocumentCombiner documentCombiner;
private final DocumentFuser documentFuser;

private final QueryAugmentor queryAugmentor;

private final Boolean protectFromBlocking;

private final int order;

public RetrievalAugmentationAdvisor(List<QueryTransformer> queryTransformers, @Nullable QueryExpander queryExpander, QueryRouter queryRouter, @Nullable DocumentCombiner documentCombiner, @Nullable QueryAugmentor queryAugmentor, @Nullable Boolean protectFromBlocking, @Nullable Integer order) {
this.queryTransformers = queryTransformers.isEmpty() ? List.of(new IdentityQueryTransformer()) : queryTransformers;
public RetrievalAugmentationAdvisor(List<QueryTransformer> queryTransformers, @Nullable QueryExpander queryExpander, QueryRouter queryRouter, @Nullable DocumentFuser documentFuser, @Nullable QueryAugmentor queryAugmentor, @Nullable Boolean protectFromBlocking, @Nullable Integer order) {
this.queryTransformers = queryTransformers.isEmpty() ? List.of() : queryTransformers;
this.queryExpander = queryExpander != null ? queryExpander : new IdentityQueryExpander();
this.queryRouter = queryRouter;
this.documentCombiner = documentCombiner != null ? documentCombiner : new MergeDocumentCombiner();
this.documentFuser = documentFuser != null ? documentFuser : new MergeDocumentFuser();
this.queryAugmentor = queryAugmentor != null ? queryAugmentor : ContextualQueryAugmentor.builder().build();
this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : false;
this.order = order != null ? order : 0;
Expand Down Expand Up @@ -141,12 +140,12 @@ private AdvisedRequest before(AdvisedRequest request) {
.collect(Collectors.toMap(
Map.Entry::getKey,
entry -> entry.getValue().stream()
.map(retriever -> retriever.retrieve(entry.getKey().text()))
.map(retriever -> retriever.retrieve(entry.getKey()))
.toList()
));

// 5. Combine documents retrieved across multiple queries and retrievers.
List<Document> documents = documentCombiner.combine(documentsForQuery);
List<Document> documents = documentFuser.fuse(documentsForQuery);
context.put(DOCUMENT_CONTEXT, documents);

// 6. Augment user query with the document contextual data.
Expand Down Expand Up @@ -191,7 +190,7 @@ public static class Builder {
private List<QueryTransformer> queryTransformers = new ArrayList<>();
private QueryExpander queryExpander;
private QueryRouter queryRouter;
private DocumentCombiner documentCombiner;
private DocumentFuser documentFuser;
private QueryAugmentor queryAugmentor;
private Boolean protectFromBlocking;
private Integer order;
Expand Down Expand Up @@ -227,8 +226,8 @@ public Builder documentRetriever(DocumentRetriever documentRetriever) {
return this;
}

public Builder documentCombiner(DocumentCombiner documentCombiner) {
this.documentCombiner = documentCombiner;
public Builder documentCombiner(DocumentFuser documentFuser) {
this.documentFuser = documentFuser;
return this;
}

Expand All @@ -248,7 +247,7 @@ public Builder order(Integer order) {
}

public RetrievalAugmentationAdvisor build() {
return new RetrievalAugmentationAdvisor(queryTransformers, queryExpander, queryRouter, documentCombiner, queryAugmentor, protectFromBlocking, order);
return new RetrievalAugmentationAdvisor(queryTransformers, queryExpander, queryRouter, documentFuser, queryAugmentor, protectFromBlocking, order);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
package com.thomasvitale.ai.spring.rag.preretrieval.query.expansion;
package com.thomasvitale.ai.spring.rag.analysis.query.expansion;

import org.springframework.ai.rag.Query;
import org.springframework.util.Assert;
import org.springframework.ai.rag.analysis.query.expansion.QueryExpander;

import java.util.List;

/**
* An expander that keeps the query as is.
*/
public class IdentityQueryExpander implements QueryExpander {

@Override
public List<Query> expand(Query query) {
Assert.notNull(query, "query cannot be null");
return List.of(query);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.thomasvitale.ai.spring.rag.orchestration.routing;

import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.retrieval.source.DocumentRetriever;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.util.Assert;

import java.util.Arrays;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.thomasvitale.ai.spring.rag.orchestration.routing;

import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.retrieval.source.DocumentRetriever;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;

import java.util.List;
import java.util.function.Function;
Expand Down

This file was deleted.

This file was deleted.

Loading

0 comments on commit 9007ba5

Please sign in to comment.