Skip to content

Commit

Permalink
Update use cases
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasVitale committed May 28, 2024
1 parent b208f2a commit f9588c8
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.InMemoryChatMemory;
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
import org.springframework.ai.tokenizer.TokenCountEstimator;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
Expand All @@ -20,9 +18,4 @@ ChatMemory chatHistory() {
return new InMemoryChatMemory();
}

@Bean
TokenCountEstimator tokenCountEstimator() {
return new JTokkitTokenCountEstimator();
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.thomasvitale.ai.spring;

import jakarta.servlet.http.HttpServletRequest;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;
Expand All @@ -14,8 +15,8 @@ class ChatbotController {
}

@PostMapping("/chat")
String chat(@RequestBody String input) {
return chatbotService.chat(input);
String chat(@RequestBody String input, HttpServletRequest request) {
return chatbotService.chat(request.getSession().getId(), input);
}

}
Original file line number Diff line number Diff line change
@@ -1,35 +1,30 @@
package com.thomasvitale.ai.spring;

import org.springframework.ai.chat.memory.*;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.transformer.ChatServiceContext;
import org.springframework.ai.chat.service.ChatService;
import org.springframework.ai.chat.service.PromptTransformingChatService;
import org.springframework.ai.tokenizer.TokenCountEstimator;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.stereotype.Service;

import java.util.List;
import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY;

@Service
class ChatbotService {

private final ChatService chatService;
private final ChatClient chatClient;

ChatbotService(ChatModel chatModel, ChatMemory chatMemory, TokenCountEstimator tokenCountEstimator) {
this.chatService = PromptTransformingChatService.builder(chatModel)
.withRetrievers(List.of(new ChatMemoryRetriever(chatMemory)))
.withContentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000)))
.withAugmentors(List.of(new SystemPromptChatMemoryAugmentor()))
.withChatServiceListeners(List.of(new ChatMemoryChatServiceListener(chatMemory)))
ChatbotService(ChatClient.Builder chatClientBuilder, ChatMemory chatMemory) {
this.chatClient = chatClientBuilder
.defaultAdvisors(new MessageChatMemoryAdvisor(chatMemory))
.build();
}

String chat(String message) {
var prompt = new Prompt(new UserMessage(message));
var chatServiceResponse = this.chatService.call(new ChatServiceContext(prompt));
return chatServiceResponse.getChatResponse().getResult().getOutput().getContent();
String chat(String chatId, String message) {
return chatClient
.prompt()
.user(message)
.advisors(a -> a.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId))
.call()
.content();
}

}
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
package com.thomasvitale.ai.spring;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;

import java.util.List;
import java.util.stream.Collectors;

@Service
class ChatService {

Expand All @@ -21,24 +18,8 @@ class ChatService {
}

String chatWithDocument(String message) {
var systemPromptTemplate = """
You are a helpful assistant, conversing with a user about the subjects contained in a set of documents.
Use the information from the DOCUMENTS section to provide accurate answers. If unsure or if the answer
isn't found in the DOCUMENTS section, simply state that you don't know the answer and do not mention
the DOCUMENTS section.
DOCUMENTS:
{documents}
""";

List<Document> similarDocuments = vectorStore.similaritySearch(SearchRequest.query(message).withTopK(5));
String content = similarDocuments.stream().map(Document::getContent).collect(Collectors.joining(System.lineSeparator()));

return chatClient.prompt()
.system(systemSpec -> systemSpec
.text(systemPromptTemplate)
.param("documents", content)
)
.advisors(new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults().withTopK(3)))
.user(message)
.call()
.content();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.reader.TextReader;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
Expand Down Expand Up @@ -48,7 +48,7 @@ public void run() {
documents.addAll(textReader2.get());

logger.info("Creating and storing Embeddings from Documents");
vectorStore.add(documents);
vectorStore.add(new TokenTextSplitter().split(documents));
}

}

0 comments on commit f9588c8

Please sign in to comment.