Optional usage of history memory store implemented
jschm42 committed Jun 20, 2024
1 parent a8c8b77 commit 4693360
Showing 14 changed files with 136 additions and 131 deletions.
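
In short, memory use becomes opt-in per assistant via AssistantDto.MemoryType (NONE, AI_DECIDES, HISTORY). The sketch below condenses how the new setting is gated in AssistantService and UniversalChatService in the diff that follows; the wrapper class and method are placeholders and not part of the patch.

import com.talkforgeai.backend.assistant.dto.AssistantDto;
import com.talkforgeai.backend.assistant.dto.AssistantDto.MemoryType;
import java.util.ArrayList;
import java.util.List;
import org.springframework.ai.chat.client.RequestResponseAdvisor;
import org.springframework.ai.model.function.FunctionCallback;

// Placeholder class; in the patch this logic sits inline in streamRunConversation/stream.
class MemoryGatingSketch {

  void applyMemorySetting(AssistantDto assistantDto,
      FunctionCallback memoryFunctionCallback,
      RequestResponseAdvisor vectorStoreChatMemoryAdvisor) {

    // AI_DECIDES: expose the memory-store function so the model can persist context itself.
    List<FunctionCallback> functionCallbacks = new ArrayList<>();
    if (assistantDto.memory() == MemoryType.AI_DECIDES) {
      functionCallbacks.add(memoryFunctionCallback);
    }

    // HISTORY: attach the vector-store chat-memory advisor for this assistant's history.
    List<RequestResponseAdvisor> advisors = new ArrayList<>();
    if (assistantDto.memory() == MemoryType.HISTORY) {
      advisors.add(vectorStoreChatMemoryAdvisor);
    }

    // NONE: getMemorySearchResults() short-circuits to an empty list and neither the
    // callback nor the advisor is registered. The callbacks then feed
    // universalChatService.getPromptOptions(...), the advisors feed the ChatClient prompt.
  }
}
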
@@ -38,7 +38,7 @@ public record AssistantDto(

public enum MemoryType {
NONE,
ASSISTANT,
GLOBAL
AI_DECIDES,
HISTORY
}
}

This file was deleted.

@@ -34,15 +34,14 @@
import com.talkforgeai.backend.assistant.dto.ThreadTitleGenerationRequestDto;
import com.talkforgeai.backend.assistant.dto.ThreadTitleUpdateRequestDto;
import com.talkforgeai.backend.assistant.exception.AssistentException;
import com.talkforgeai.backend.assistant.functions.ContextStorageFunction;
import com.talkforgeai.backend.assistant.functions.ContextStorageFunction.Request;
import com.talkforgeai.backend.assistant.functions.ContextStorageFunction.Response;
import com.talkforgeai.backend.assistant.functions.ContextTool;
import com.talkforgeai.backend.assistant.functions.FunctionContext;
import com.talkforgeai.backend.assistant.repository.AssistantRepository;
import com.talkforgeai.backend.assistant.repository.MessageRepository;
import com.talkforgeai.backend.assistant.repository.ThreadRepository;
import com.talkforgeai.backend.memory.dto.DocumentWithoutEmbeddings;
import com.talkforgeai.backend.memory.functions.MemoryContextStorageFunction;
import com.talkforgeai.backend.memory.functions.MemoryContextStorageFunction.Request;
import com.talkforgeai.backend.memory.functions.MemoryContextStorageFunction.Response;
import com.talkforgeai.backend.memory.functions.MemoryFunctionContext;
import com.talkforgeai.backend.memory.service.MemoryService;
import com.talkforgeai.backend.storage.FileStorageService;
import com.talkforgeai.backend.transformers.MessageProcessor;
@@ -79,12 +78,14 @@
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackWrapper;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi.ChatModel;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter.Expression;
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.springframework.data.domain.PageRequest;
@@ -272,7 +273,7 @@ public Flux<ServerSentEvent<String>> streamRunConversation(String assistantId, S
.flux()
.flatMap(initInfos -> {
List<DocumentWithoutEmbeddings> memorySearchResults = getMemorySearchResults(
initInfos.assistantDto, userMessage);
initInfos.assistantDto.id(), initInfos.assistantDto.memory(), userMessage);

return Flux.just(
new PreparedInfos(initInfos.assistantDto(), initInfos.pastMessages(),
@@ -287,21 +288,24 @@ public Flux<ServerSentEvent<String>> streamRunConversation(String assistantId, S
pastMessagesList,
assistantDto, memoryResultsList);

FunctionCallbackWrapper<Request, Response> memoryFunctionCallback = getMemoryFunctionCallback(
assistantId, runId, assistantDto);
List<FunctionCallback> functionCallbacks = new ArrayList<>();
FunctionCallbackWrapper<Request, Response> memoryFunctionCallback
= getMemoryFunctionCallback(assistantId, assistantDto.name(), runId);

if (assistantDto.memory() == MemoryType.AI_DECIDES) {
functionCallbacks.add(memoryFunctionCallback);
}

ChatOptions promptOptions = universalChatService.getPromptOptions(assistantDto,
List.of(memoryFunctionCallback));
functionCallbacks);

LOGGER.debug("Starting stream with prompt: {}", finalPromptMessageList);
LOGGER.debug("Prompt Options: {}",
universalChatService.printPromptOptions(assistantDto.system(),
promptOptions));

Prompt prompt = new Prompt(finalPromptMessageList, promptOptions);

return universalChatService.stream(assistantDto.system(), finalPromptMessageList,
userMessage, assistantId, promptOptions);
return universalChatService.stream(assistantDto,
finalPromptMessageList, userMessage, promptOptions);
})
.doOnCancel(() -> {
LOGGER.debug("doOnCancel. userMessage={}", assistantMessageContent);
@@ -354,18 +358,18 @@ public Flux<ServerSentEvent<String>> streamRunConversation(String assistantId, S
}

private FunctionCallbackWrapper<Request, Response> getMemoryFunctionCallback(
String assistantId, String runId, AssistantDto assistantDto) {
String assistantId, String assistantName, String runId) {
return FunctionCallbackWrapper.builder(
new ContextStorageFunction(memoryService,
new FunctionContext(LlmSystem.OPENAI,
new MemoryContextStorageFunction(memoryService,
new MemoryFunctionContext(LlmSystem.OPENAI,
OpenAiEmbeddingProperties.DEFAULT_EMBEDDING_MODEL,
assistantId,
assistantDto.name(),
assistantName,
runId
)))
.withDescription(
"Store relevant information in the vector database for later retrieval.")
.withName(ContextTool.MEMORY_STORE.getFunctionBeanName())
.withName(MemoryContextStorageFunction.NAME)
.build();
}

@@ -392,31 +396,32 @@ private FunctionCallbackWrapper<Request, Response> getMemoryFunctionCallback(
.subscribeOn(Schedulers.boundedElastic());
}

private @NotNull List<DocumentWithoutEmbeddings> getMemorySearchResults(AssistantDto assistantDto,
private @NotNull List<DocumentWithoutEmbeddings> getMemorySearchResults(String assistantId,
MemoryType memoryType,
String message) {

if (assistantDto.memory() == MemoryType.NONE) {
if (memoryType == MemoryType.NONE) {
return List.of();
}

LOGGER.info("Searching memory for message: {}", message);
List<DocumentWithoutEmbeddings> searchResults = memoryService.search(
SearchRequest.query(message).withSimilarityThreshold(0.75f));

if (assistantDto.memory() == MemoryType.ASSISTANT) {
List<DocumentWithoutEmbeddings> filteredMemory = searchResults.stream()
.filter(m -> m.assistantId() != null && m.assistantId().equals(assistantDto.id()))
.toList();
FilterExpressionBuilder expressionBuilder = new FilterExpressionBuilder();
Expression assistantExpression = expressionBuilder.eq("assistantId", assistantId).build();

LOGGER.debug("Memory search results for assistant '{}': {}", assistantDto.id(),
filteredMemory);
List<DocumentWithoutEmbeddings> searchResults = memoryService.search(
SearchRequest.query(message)
.withFilterExpression(assistantExpression)
.withSimilarityThreshold(0.75f));

return filteredMemory;
}
List<DocumentWithoutEmbeddings> filteredMemory = searchResults.stream()
.filter(m -> m.assistantId() != null && m.assistantId().equals(assistantId))
.toList();

LOGGER.debug("Memory search results: {}", searchResults);
LOGGER.debug("Memory search results for assistant '{}': {}", assistantId,
filteredMemory);

return searchResults;
return filteredMemory;
}

private ServerSentEvent<String> createResponseSseEvent(ChatResponse chatResponse) {
@@ -519,8 +524,6 @@ public ThreadTitleDto generateThreadTitle(String threadId,
.withMaxTokens(256)
.build();

Prompt titlePrompt = new Prompt(new UserMessage(content), options);

try {
ChatResponse titleResponse = universalChatService.call(LlmSystem.OPENAI, content, options);

@@ -17,13 +17,19 @@
package com.talkforgeai.backend.assistant.service;

import com.talkforgeai.backend.assistant.dto.AssistantDto;
import com.talkforgeai.backend.assistant.dto.AssistantDto.MemoryType;
import com.talkforgeai.backend.assistant.dto.LlmSystem;
import com.talkforgeai.backend.assistant.exception.AssistentException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.jetbrains.annotations.NotNull;
import org.springframework.ai.anthropic.AnthropicChatModel;
import org.springframework.ai.anthropic.AnthropicChatOptions;
import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.RequestResponseAdvisor;
import org.springframework.ai.chat.client.advisor.VectorStoreChatMemoryAdvisor;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.model.ChatModel;
@@ -42,9 +48,6 @@
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestClient;
import reactor.core.publisher.Flux;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

@Service
public class UniversalChatService {
@@ -67,7 +70,8 @@ public class UniversalChatService {
public UniversalChatService(RestClient openAiRestClient,
OpenAiChatModel openAiChatModel,
MistralAiChatModel mistralAiChatModel, AnthropicChatModel anthropicChatModel,
OllamaChatModel ollamaChatModel, RestClient ollamaAiRestClient, VectorStore dbVectorStore) {
OllamaChatModel ollamaChatModel,
RestClient ollamaAiRestClient, VectorStore dbVectorStore) {
this.openAiRestClient = openAiRestClient;
this.openAiChatModel = openAiChatModel;
this.mistralAiChatModel = mistralAiChatModel;
@@ -145,13 +149,18 @@ ChatResponse call(LlmSystem system, String prompt, ChatOptions options) {
.chatResponse();
}

Flux<ChatResponse> stream(LlmSystem system, List<Message> messages, String userMessage,
String conversationId,
ChatOptions options) {
Flux<ChatResponse> stream(AssistantDto assistantDto, List<Message> messages,
String userMessage, ChatOptions options) {

return getClient(system)
List<RequestResponseAdvisor> requestResponseAdvisors = new ArrayList<>();

if (assistantDto.memory() == MemoryType.HISTORY) {
requestResponseAdvisors.add(getVectorStoreChatMemoryAdvisor(assistantDto.id()));
}

return getClient(assistantDto.system())
.prompt()
.advisors(getVectorStoreChatMemoryAdvisor(conversationId))
.advisors(requestResponseAdvisors)
.options(options)
.messages(messages)
.user(userMessage)
@@ -16,6 +16,6 @@

package com.talkforgeai.backend.memory.dto;

public record MemoryImportDto(String content, String assistantName) {
public record MemoryImportDto(String content, String assistantName, String messageType) {

}
@@ -14,18 +14,16 @@
* limitations under the License.
*/

package com.talkforgeai.backend.assistant.functions;
package com.talkforgeai.backend.memory.exceptions;

public enum ContextTool {
MEMORY_STORE("contextStorageFunction");
public class MemoryImportException extends RuntimeException {

private String functionBeanName;

ContextTool(String functionBeanName) {
this.functionBeanName = functionBeanName;
public MemoryImportException(String message) {
super(message);
}

public String getFunctionBeanName() {
return functionBeanName;
public MemoryImportException(String message, Throwable cause) {
super(message, cause);
}

}
@@ -14,7 +14,7 @@
* limitations under the License.
*/

package com.talkforgeai.backend.assistant.functions;
package com.talkforgeai.backend.memory.functions;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
@@ -27,26 +27,28 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ContextStorageFunction implements
Function<ContextStorageFunction.Request, ContextStorageFunction.Response> {
public class MemoryContextStorageFunction implements
Function<MemoryContextStorageFunction.Request, MemoryContextStorageFunction.Response> {

public static final Logger LOGGER = LoggerFactory.getLogger(ContextStorageFunction.class);
public static final String NAME = "memoryContextStorageFunction";
public static final Logger LOGGER = LoggerFactory.getLogger(MemoryContextStorageFunction.class);

private final MemoryService memoryService;
private final FunctionContext functionContext;
private final MemoryFunctionContext memoryFunctionContext;

public ContextStorageFunction(MemoryService memoryService, FunctionContext functionContext) {
public MemoryContextStorageFunction(MemoryService memoryService,
MemoryFunctionContext memoryFunctionContext) {
this.memoryService = memoryService;
this.functionContext = functionContext;
this.memoryFunctionContext = memoryFunctionContext;
}

@Override
public Response apply(Request request) {
LOGGER.info("Storing information in memory: {}", request.contextInfo());
MemoryStoreRequestDto requestDto = new MemoryStoreRequestDto(request.contextInfo(),
functionContext.assistantId());
memoryFunctionContext.assistantId());
DocumentWithoutEmbeddings storedDocument = memoryService.store(requestDto.content(),
requestDto.assistantId(), functionContext.runId());
requestDto.assistantId());
return new Response(
storedDocument.id(),
"I stored the following information in memory: " + request.contextInfo());
@@ -14,12 +14,12 @@
* limitations under the License.
*/

package com.talkforgeai.backend.assistant.functions;
package com.talkforgeai.backend.memory.functions;

import com.talkforgeai.backend.assistant.dto.LlmSystem;

public record FunctionContext(LlmSystem embedLlmSystem,
String embedModel, String assistantId, String assistantName,
String runId) {
public record MemoryFunctionContext(LlmSystem embedLlmSystem,
String embedModel, String assistantId, String assistantName,
String runId) {

}
@@ -131,22 +131,26 @@ public List<Document> similaritySearch(SearchRequest request) {
LOGGER.info("Similarity search request: {}", request);

String assistantId = null;

// TODO Handle expression
if (request.getFilterExpression() != null) {
ExpressionType type = request.getFilterExpression().type();
Key left = (Key) request.getFilterExpression().left();
Value right = (Value) request.getFilterExpression().right();
if (type == ExpressionType.EQ) {
Key left = (Key) request.getFilterExpression().left();
Value right = (Value) request.getFilterExpression().right();

assistantId = right.value().toString();
if (left.key().equals(SEARCH_ASSISTANT_ID)) {
assistantId = right.value().toString();
}
}
}

List<Double> userQueryEmbedding = getUserQueryEmbedding(request.getQuery());

List<MemoryDocument> documents;
if (assistantId == null) {
LOGGER.info("Searching for all documents");
documents = memoryRepository.findAll();
} else {
LOGGER.info("Searching for documents with assistantId = {}", assistantId);
documents = memoryRepository.findAllByAssistantId(assistantId);
}

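To close, a condensed sketch of how the assistantId filter now flows from AssistantService.getMemorySearchResults into the custom similaritySearch shown above. The wrapper class and method are placeholders, and the "assistantId" key is assumed to match the store's SEARCH_ASSISTANT_ID constant.

import com.talkforgeai.backend.memory.dto.DocumentWithoutEmbeddings;
import com.talkforgeai.backend.memory.service.MemoryService;
import java.util.List;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter.Expression;
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;

// Placeholder class; the caller side lives in AssistantService.getMemorySearchResults.
class AssistantFilterSketch {

  List<DocumentWithoutEmbeddings> searchForAssistant(MemoryService memoryService,
      String assistantId, String message) {

    FilterExpressionBuilder expressionBuilder = new FilterExpressionBuilder();
    Expression assistantExpression = expressionBuilder.eq("assistantId", assistantId).build();

    // The custom similaritySearch above unpacks this EQ expression: when its key matches
    // SEARCH_ASSISTANT_ID, it scans memoryRepository.findAllByAssistantId(assistantId)
    // instead of memoryRepository.findAll().
    return memoryService.search(
        SearchRequest.query(message)
            .withFilterExpression(assistantExpression)
            .withSimilarityThreshold(0.75f));
  }
}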