diff --git a/backend/src/main/java/com/talkforgeai/backend/assistant/dto/AssistantDto.java b/backend/src/main/java/com/talkforgeai/backend/assistant/dto/AssistantDto.java index 259f0776..f55d0333 100644 --- a/backend/src/main/java/com/talkforgeai/backend/assistant/dto/AssistantDto.java +++ b/backend/src/main/java/com/talkforgeai/backend/assistant/dto/AssistantDto.java @@ -38,7 +38,7 @@ public record AssistantDto( public enum MemoryType { NONE, - ASSISTANT, - GLOBAL + AI_DECIDES, + HISTORY } } diff --git a/backend/src/main/java/com/talkforgeai/backend/assistant/service/AssistantSpringService.java b/backend/src/main/java/com/talkforgeai/backend/assistant/service/AssistantSpringService.java index 39d2d0cb..7b3e6625 100644 --- a/backend/src/main/java/com/talkforgeai/backend/assistant/service/AssistantSpringService.java +++ b/backend/src/main/java/com/talkforgeai/backend/assistant/service/AssistantSpringService.java @@ -34,15 +34,14 @@ import com.talkforgeai.backend.assistant.dto.ThreadTitleGenerationRequestDto; import com.talkforgeai.backend.assistant.dto.ThreadTitleUpdateRequestDto; import com.talkforgeai.backend.assistant.exception.AssistentException; -import com.talkforgeai.backend.assistant.functions.ContextStorageFunction; -import com.talkforgeai.backend.assistant.functions.ContextStorageFunction.Request; -import com.talkforgeai.backend.assistant.functions.ContextStorageFunction.Response; -import com.talkforgeai.backend.assistant.functions.ContextTool; -import com.talkforgeai.backend.assistant.functions.FunctionContext; import com.talkforgeai.backend.assistant.repository.AssistantRepository; import com.talkforgeai.backend.assistant.repository.MessageRepository; import com.talkforgeai.backend.assistant.repository.ThreadRepository; import com.talkforgeai.backend.memory.dto.DocumentWithoutEmbeddings; +import com.talkforgeai.backend.memory.functions.MemoryContextStorageFunction; +import com.talkforgeai.backend.memory.functions.MemoryContextStorageFunction.Request; +import com.talkforgeai.backend.memory.functions.MemoryContextStorageFunction.Response; +import com.talkforgeai.backend.memory.functions.MemoryFunctionContext; import com.talkforgeai.backend.memory.service.MemoryService; import com.talkforgeai.backend.storage.FileStorageService; import com.talkforgeai.backend.transformers.MessageProcessor; @@ -79,12 +78,14 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi.ChatModel; import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder; import org.springframework.core.io.FileSystemResource; import org.springframework.core.io.Resource; import org.springframework.data.domain.PageRequest; @@ -157,7 +158,7 @@ public AssistantSpringService( return runIdMono; } - private static @NotNull List getFinalPromptMessageList(String message, + private static @NotNull List getFinalPromptMessageList( List pastMessagesList, AssistantDto assistantDto, List memoryResultsList) { List promptMessageList = pastMessagesList.stream() @@ -194,9 +195,6 @@ public AssistantSpringService( result -> memoryMessage.append(result.content()).append("\n")); memoryMessage.append("\nUser message:\n"); } - - finalPromptMessageList.add( - new UserMessage(memoryMessage.append(message).toString())); return finalPromptMessageList; } @@ -258,11 +256,11 @@ public void cancelStream(String threadId, String runId) { } public Flux> streamRunConversation(String assistantId, String threadId, - String message) { + String userMessage) { final String runId = UniqueIdUtil.generateRunId(); - Mono saveUserMessageMono = getSaveUserMessageMono(assistantId, threadId, message); + Mono saveUserMessageMono = getSaveUserMessageMono(assistantId, threadId, userMessage); Mono assistantEntityMono = getAssistantEntityMono(assistantId); Mono> pastMessages = getPastMessagesMono(threadId); Mono initInfosMono = getInitInfosMono(assistantEntityMono, pastMessages); @@ -275,7 +273,7 @@ public Flux> streamRunConversation(String assistantId, S .flux() .flatMap(initInfos -> { List memorySearchResults = getMemorySearchResults( - initInfos.assistantDto, message); + initInfos.assistantDto.id(), initInfos.assistantDto.memory(), userMessage); return Flux.just( new PreparedInfos(initInfos.assistantDto(), initInfos.pastMessages(), @@ -286,36 +284,40 @@ public Flux> streamRunConversation(String assistantId, S List pastMessagesList = preparedInfos.pastMessages(); List memoryResultsList = preparedInfos.memoryResults(); - List finalPromptMessageList = getFinalPromptMessageList(message, + List finalPromptMessageList = getFinalPromptMessageList( pastMessagesList, assistantDto, memoryResultsList); - FunctionCallbackWrapper memoryFunctionCallback = getMemoryFunctionCallback( - assistantId, runId, assistantDto); + List functionCallbacks = new ArrayList<>(); + FunctionCallbackWrapper memoryFunctionCallback + = getMemoryFunctionCallback(assistantId, assistantDto.name(), runId); + + if (assistantDto.memory() == MemoryType.AI_DECIDES) { + functionCallbacks.add(memoryFunctionCallback); + } ChatOptions promptOptions = universalChatService.getPromptOptions(assistantDto, - List.of(memoryFunctionCallback)); + functionCallbacks); LOGGER.debug("Starting stream with prompt: {}", finalPromptMessageList); LOGGER.debug("Prompt Options: {}", universalChatService.printPromptOptions(assistantDto.system(), promptOptions)); - Prompt prompt = new Prompt(finalPromptMessageList, promptOptions); - - return universalChatService.stream(assistantDto.system(), prompt); + return universalChatService.stream(assistantDto, + finalPromptMessageList, userMessage, promptOptions); }) .doOnCancel(() -> { - LOGGER.debug("doOnCancel. message={}", assistantMessageContent); + LOGGER.debug("doOnCancel. userMessage={}", assistantMessageContent); }) .mapNotNull(chatResponse -> mapChatResponse(chatResponse, assistantMessageContent)) .doOnSubscribe(subscription -> { - LOGGER.debug("doOnSubscribe. message={}", assistantMessageContent); + LOGGER.debug("doOnSubscribe. userMessage={}", assistantMessageContent); activeStreams.put(runId, subscription); }) .doOnComplete(() -> { - LOGGER.trace("doOnComplete. message={}", assistantMessageContent); + LOGGER.trace("doOnComplete. userMessage={}", assistantMessageContent); Mono.fromRunnable( () -> saveNewMessage(assistantId, threadId, MessageType.ASSISTANT, @@ -356,18 +358,18 @@ public Flux> streamRunConversation(String assistantId, S } private FunctionCallbackWrapper getMemoryFunctionCallback( - String assistantId, String runId, AssistantDto assistantDto) { + String assistantId, String assistantName, String runId) { return FunctionCallbackWrapper.builder( - new ContextStorageFunction(memoryService, - new FunctionContext(LlmSystem.OPENAI, + new MemoryContextStorageFunction(memoryService, + new MemoryFunctionContext(LlmSystem.OPENAI, OpenAiEmbeddingProperties.DEFAULT_EMBEDDING_MODEL, assistantId, - assistantDto.name(), + assistantName, runId ))) .withDescription( "Store relevant information in the vector database for later retrieval.") - .withName(ContextTool.MEMORY_STORE.getFunctionBeanName()) + .withName(MemoryContextStorageFunction.NAME) .build(); } @@ -394,31 +396,32 @@ private FunctionCallbackWrapper getMemoryFunctionCallback( .subscribeOn(Schedulers.boundedElastic()); } - private @NotNull List getMemorySearchResults(AssistantDto assistantDto, + private @NotNull List getMemorySearchResults(String assistantId, + MemoryType memoryType, String message) { - if (assistantDto.memory() == MemoryType.NONE) { + if (memoryType == MemoryType.NONE) { return List.of(); } LOGGER.info("Searching memory for message: {}", message); - List searchResults = memoryService.search( - SearchRequest.query(message).withSimilarityThreshold(0.75f)); - if (assistantDto.memory() == MemoryType.ASSISTANT) { - List filteredMemory = searchResults.stream() - .filter(m -> m.assistantId() != null && m.assistantId().equals(assistantDto.id())) - .toList(); + FilterExpressionBuilder expressionBuilder = new FilterExpressionBuilder(); + Expression assistantExpression = expressionBuilder.eq("assistantId", assistantId).build(); - LOGGER.debug("Memory search results for assistant '{}': {}", assistantDto.id(), - filteredMemory); + List searchResults = memoryService.search( + SearchRequest.query(message) + .withFilterExpression(assistantExpression) + .withSimilarityThreshold(0.75f)); - return filteredMemory; - } + List filteredMemory = searchResults.stream() + .filter(m -> m.assistantId() != null && m.assistantId().equals(assistantId)) + .toList(); - LOGGER.debug("Memory search results: {}", searchResults); + LOGGER.debug("Memory search results for assistant '{}': {}", assistantId, + filteredMemory); - return searchResults; + return filteredMemory; } private ServerSentEvent createResponseSseEvent(ChatResponse chatResponse) { @@ -521,10 +524,8 @@ public ThreadTitleDto generateThreadTitle(String threadId, .withMaxTokens(256) .build(); - Prompt titlePrompt = new Prompt(new UserMessage(content), options); - try { - ChatResponse titleResponse = universalChatService.call(LlmSystem.OPENAI, titlePrompt); + ChatResponse titleResponse = universalChatService.call(LlmSystem.OPENAI, content, options); String generatedTitle = titleResponse.getResult().getOutput().getContent(); diff --git a/backend/src/main/java/com/talkforgeai/backend/assistant/service/UniversalChatService.java b/backend/src/main/java/com/talkforgeai/backend/assistant/service/UniversalChatService.java index 6d95e246..5541952b 100644 --- a/backend/src/main/java/com/talkforgeai/backend/assistant/service/UniversalChatService.java +++ b/backend/src/main/java/com/talkforgeai/backend/assistant/service/UniversalChatService.java @@ -17,19 +17,24 @@ package com.talkforgeai.backend.assistant.service; import com.talkforgeai.backend.assistant.dto.AssistantDto; +import com.talkforgeai.backend.assistant.dto.AssistantDto.MemoryType; import com.talkforgeai.backend.assistant.dto.LlmSystem; import com.talkforgeai.backend.assistant.exception.AssistentException; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; +import org.jetbrains.annotations.NotNull; import org.springframework.ai.anthropic.AnthropicChatModel; import org.springframework.ai.anthropic.AnthropicChatOptions; import org.springframework.ai.anthropic.api.AnthropicApi; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.RequestResponseAdvisor; +import org.springframework.ai.chat.client.advisor.VectorStoreChatMemoryAdvisor; +import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.mistralai.MistralAiChatModel; import org.springframework.ai.mistralai.MistralAiChatOptions; import org.springframework.ai.mistralai.api.MistralAiApi; @@ -38,6 +43,7 @@ import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.vectorstore.VectorStore; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.stereotype.Service; import org.springframework.web.client.RestClient; @@ -46,6 +52,8 @@ @Service public class UniversalChatService { + private final int DEFAULT_CHAT_MEMORY_RESPONSE_SIZE = 5; + @Qualifier("openAiRestClient") private final RestClient openAiRestClient; private final OpenAiChatModel openAiChatModel; @@ -56,16 +64,22 @@ public class UniversalChatService { @Qualifier("ollamaAiRestClient") private final RestClient ollamaAiRestClient; + @Qualifier("dbVectorStore") + private final VectorStore dbVectorStore; + public UniversalChatService(RestClient openAiRestClient, OpenAiChatModel openAiChatModel, MistralAiChatModel mistralAiChatModel, AnthropicChatModel anthropicChatModel, - OllamaChatModel ollamaChatModel, RestClient ollamaAiRestClient) { + OllamaChatModel ollamaChatModel, + RestClient ollamaAiRestClient, VectorStore dbVectorStore) { this.openAiRestClient = openAiRestClient; this.openAiChatModel = openAiChatModel; this.mistralAiChatModel = mistralAiChatModel; this.anthropicChatModel = anthropicChatModel; this.ollamaChatModel = ollamaChatModel; this.ollamaAiRestClient = ollamaAiRestClient; + + this.dbVectorStore = dbVectorStore; } public ChatOptions getPromptOptions(AssistantDto assistantDto, @@ -126,38 +140,53 @@ public String printPromptOptions(LlmSystem system, ChatOptions options) { return printedOptions.toString(); } - ChatResponse call(LlmSystem system, Prompt prompt) { - return getChatClient(system).call(prompt); + ChatResponse call(LlmSystem system, String prompt, ChatOptions options) { + return getClient(system) + .prompt() + .options(options) + .user(prompt) + .call() + .chatResponse(); } - Flux stream(LlmSystem system, Prompt prompt) { - return getStreamingChatClient(system).stream(prompt); - } + Flux stream(AssistantDto assistantDto, List messages, + String userMessage, ChatOptions options) { + + List requestResponseAdvisors = new ArrayList<>(); - StreamingChatModel getStreamingChatClient(LlmSystem system) { - return (StreamingChatModel) getClient(system); + if (assistantDto.memory() == MemoryType.HISTORY) { + requestResponseAdvisors.add(getVectorStoreChatMemoryAdvisor(assistantDto.id())); + } + + return getClient(assistantDto.system()) + .prompt() + .advisors(requestResponseAdvisors) + .options(options) + .messages(messages) + .user(userMessage) + .stream() + .chatResponse(); } - ChatModel getChatClient(LlmSystem system) { - return (ChatModel) getClient(system); + private @NotNull VectorStoreChatMemoryAdvisor getVectorStoreChatMemoryAdvisor( + String converationId) { + return new VectorStoreChatMemoryAdvisor( + dbVectorStore, + converationId, + DEFAULT_CHAT_MEMORY_RESPONSE_SIZE + ); } - private Object getClient(LlmSystem system) { - switch (system) { - case OPENAI -> { - return openAiChatModel; - } - case MISTRAL -> { - return mistralAiChatModel; - } - case OLLAMA -> { - return ollamaChatModel; - } - case ANSTHROPIC -> { - return anthropicChatModel; - } - default -> throw new IllegalStateException("Unexpected system: " + system); - } + private ChatClient getClient(LlmSystem system) { + ChatModel model = switch (system) { + case OPENAI -> openAiChatModel; + case MISTRAL -> mistralAiChatModel; + case OLLAMA -> ollamaChatModel; + case ANSTHROPIC -> anthropicChatModel; + }; + + return ChatClient.builder(model) + .build(); } private MistralAiChatOptions getMistralOptions(AssistantDto assistantDto, diff --git a/backend/src/main/java/com/talkforgeai/backend/memory/domain/MemoryDocument.java b/backend/src/main/java/com/talkforgeai/backend/memory/domain/MemoryDocument.java index 25f6b198..7aa2408c 100644 --- a/backend/src/main/java/com/talkforgeai/backend/memory/domain/MemoryDocument.java +++ b/backend/src/main/java/com/talkforgeai/backend/memory/domain/MemoryDocument.java @@ -59,8 +59,11 @@ public class MemoryDocument { @ManyToOne private AssistantEntity assistant; + @Column(name = "run_id", length = 50) private String runId; + @Column(name = "message_type", length = 20) + private String messageType; public String getId() { return id; @@ -146,4 +149,13 @@ public String getRunId() { public void setRunId(String runId) { this.runId = runId; } + + public String getMessageType() { + return messageType; + } + + public void setMessageType(String memoryType) { + this.messageType = memoryType; + } + } diff --git a/backend/src/main/java/com/talkforgeai/backend/memory/dto/DocumentWithoutEmbeddings.java b/backend/src/main/java/com/talkforgeai/backend/memory/dto/DocumentWithoutEmbeddings.java index 0e999683..a1fbc284 100644 --- a/backend/src/main/java/com/talkforgeai/backend/memory/dto/DocumentWithoutEmbeddings.java +++ b/backend/src/main/java/com/talkforgeai/backend/memory/dto/DocumentWithoutEmbeddings.java @@ -22,7 +22,7 @@ public record DocumentWithoutEmbeddings(String id, String content, String assistantId, String assistantName, - String system, String model) { + String system, String model, String messageType) { private static String getMetadataValue(Map meta, MetadataKey key) { return (String) meta.getOrDefault(key.key(), null); @@ -33,10 +33,11 @@ public static DocumentWithoutEmbeddings from(Document document) { return new DocumentWithoutEmbeddings( document.getId(), document.getContent(), - getMetadataValue(meta, MetadataKey.ASSISTANT_ID), + getMetadataValue(meta, MetadataKey.CONVERSATION_ID), getMetadataValue(meta, MetadataKey.ASSISTANT_NAME), getMetadataValue(meta, MetadataKey.SYSTEM), - getMetadataValue(meta, MetadataKey.MODEL) + getMetadataValue(meta, MetadataKey.MODEL), + getMetadataValue(meta, MetadataKey.MESSAGE_TYPE) ); } } \ No newline at end of file diff --git a/backend/src/main/java/com/talkforgeai/backend/memory/dto/MemoryImportDto.java b/backend/src/main/java/com/talkforgeai/backend/memory/dto/MemoryImportDto.java index 2f41ee01..9c161257 100644 --- a/backend/src/main/java/com/talkforgeai/backend/memory/dto/MemoryImportDto.java +++ b/backend/src/main/java/com/talkforgeai/backend/memory/dto/MemoryImportDto.java @@ -16,6 +16,6 @@ package com.talkforgeai.backend.memory.dto; -public record MemoryImportDto(String content, String assistantName) { +public record MemoryImportDto(String content, String assistantName, String messageType) { } diff --git a/backend/src/main/java/com/talkforgeai/backend/memory/dto/MetadataKey.java b/backend/src/main/java/com/talkforgeai/backend/memory/dto/MetadataKey.java index 080e69cf..b2128a67 100644 --- a/backend/src/main/java/com/talkforgeai/backend/memory/dto/MetadataKey.java +++ b/backend/src/main/java/com/talkforgeai/backend/memory/dto/MetadataKey.java @@ -19,8 +19,9 @@ public enum MetadataKey { SYSTEM("system"), MODEL("model"), - ASSISTANT_ID("assistantId"), + CONVERSATION_ID("conversationId"), ASSISTANT_NAME("assistantName"), + MESSAGE_TYPE("messageType"), RUN_ID("runId"); private final String key; diff --git a/backend/src/main/java/com/talkforgeai/backend/assistant/functions/ContextTool.java b/backend/src/main/java/com/talkforgeai/backend/memory/exceptions/MemoryImportException.java similarity index 66% rename from backend/src/main/java/com/talkforgeai/backend/assistant/functions/ContextTool.java rename to backend/src/main/java/com/talkforgeai/backend/memory/exceptions/MemoryImportException.java index f5e10a24..a68aadba 100644 --- a/backend/src/main/java/com/talkforgeai/backend/assistant/functions/ContextTool.java +++ b/backend/src/main/java/com/talkforgeai/backend/memory/exceptions/MemoryImportException.java @@ -14,18 +14,16 @@ * limitations under the License. */ -package com.talkforgeai.backend.assistant.functions; +package com.talkforgeai.backend.memory.exceptions; -public enum ContextTool { - MEMORY_STORE("contextStorageFunction"); +public class MemoryImportException extends RuntimeException { - private String functionBeanName; - - ContextTool(String functionBeanName) { - this.functionBeanName = functionBeanName; + public MemoryImportException(String message) { + super(message); } - public String getFunctionBeanName() { - return functionBeanName; + public MemoryImportException(String message, Throwable cause) { + super(message, cause); } + } diff --git a/backend/src/main/java/com/talkforgeai/backend/assistant/functions/ContextStorageFunction.java b/backend/src/main/java/com/talkforgeai/backend/memory/functions/MemoryContextStorageFunction.java similarity index 74% rename from backend/src/main/java/com/talkforgeai/backend/assistant/functions/ContextStorageFunction.java rename to backend/src/main/java/com/talkforgeai/backend/memory/functions/MemoryContextStorageFunction.java index eefb182e..f0a66b39 100644 --- a/backend/src/main/java/com/talkforgeai/backend/assistant/functions/ContextStorageFunction.java +++ b/backend/src/main/java/com/talkforgeai/backend/memory/functions/MemoryContextStorageFunction.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.talkforgeai.backend.assistant.functions; +package com.talkforgeai.backend.memory.functions; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; @@ -27,26 +27,28 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class ContextStorageFunction implements - Function { +public class MemoryContextStorageFunction implements + Function { - public static final Logger LOGGER = LoggerFactory.getLogger(ContextStorageFunction.class); + public static final String NAME = "memoryContextStorageFunction"; + public static final Logger LOGGER = LoggerFactory.getLogger(MemoryContextStorageFunction.class); private final MemoryService memoryService; - private final FunctionContext functionContext; + private final MemoryFunctionContext memoryFunctionContext; - public ContextStorageFunction(MemoryService memoryService, FunctionContext functionContext) { + public MemoryContextStorageFunction(MemoryService memoryService, + MemoryFunctionContext memoryFunctionContext) { this.memoryService = memoryService; - this.functionContext = functionContext; + this.memoryFunctionContext = memoryFunctionContext; } @Override public Response apply(Request request) { LOGGER.info("Storing information in memory: {}", request.contextInfo()); MemoryStoreRequestDto requestDto = new MemoryStoreRequestDto(request.contextInfo(), - functionContext.assistantId()); + memoryFunctionContext.assistantId()); DocumentWithoutEmbeddings storedDocument = memoryService.store(requestDto.content(), - requestDto.assistantId(), functionContext.runId()); + requestDto.assistantId()); return new Response( storedDocument.id(), "I stored the following information in memory: " + request.contextInfo()); diff --git a/backend/src/main/java/com/talkforgeai/backend/assistant/functions/FunctionContext.java b/backend/src/main/java/com/talkforgeai/backend/memory/functions/MemoryFunctionContext.java similarity index 71% rename from backend/src/main/java/com/talkforgeai/backend/assistant/functions/FunctionContext.java rename to backend/src/main/java/com/talkforgeai/backend/memory/functions/MemoryFunctionContext.java index 7c25cb4a..d6996772 100644 --- a/backend/src/main/java/com/talkforgeai/backend/assistant/functions/FunctionContext.java +++ b/backend/src/main/java/com/talkforgeai/backend/memory/functions/MemoryFunctionContext.java @@ -14,12 +14,12 @@ * limitations under the License. */ -package com.talkforgeai.backend.assistant.functions; +package com.talkforgeai.backend.memory.functions; import com.talkforgeai.backend.assistant.dto.LlmSystem; -public record FunctionContext(LlmSystem embedLlmSystem, - String embedModel, String assistantId, String assistantName, - String runId) { +public record MemoryFunctionContext(LlmSystem embedLlmSystem, + String embedModel, String assistantId, String assistantName, + String runId) { } diff --git a/backend/src/main/java/com/talkforgeai/backend/memory/repository/MemoryRepository.java b/backend/src/main/java/com/talkforgeai/backend/memory/repository/MemoryRepository.java index e560ba64..b81c6bf1 100644 --- a/backend/src/main/java/com/talkforgeai/backend/memory/repository/MemoryRepository.java +++ b/backend/src/main/java/com/talkforgeai/backend/memory/repository/MemoryRepository.java @@ -17,6 +17,7 @@ package com.talkforgeai.backend.memory.repository; import com.talkforgeai.backend.memory.domain.MemoryDocument; +import java.util.List; import org.jetbrains.annotations.NotNull; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; @@ -30,6 +31,10 @@ public interface MemoryRepository extends JpaRepository @NotNull Page findAll(@NotNull Pageable pageable); + @NotNull + List findAllByAssistantId(@NotNull String assistantId); + + @Query("SELECT COUNT(md) FROM MemoryDocument md WHERE md.content = :content AND md.assistant.id = :assistantId") int countByContentAndAssistantId(String content, String assistantId); diff --git a/backend/src/main/java/com/talkforgeai/backend/memory/service/DBVectorStore.java b/backend/src/main/java/com/talkforgeai/backend/memory/service/DBVectorStore.java index dc1f8403..eee217ab 100644 --- a/backend/src/main/java/com/talkforgeai/backend/memory/service/DBVectorStore.java +++ b/backend/src/main/java/com/talkforgeai/backend/memory/service/DBVectorStore.java @@ -16,6 +16,7 @@ package com.talkforgeai.backend.memory.service; +import com.talkforgeai.backend.assistant.dto.LlmSystem; import com.talkforgeai.backend.assistant.repository.AssistantRepository; import com.talkforgeai.backend.memory.domain.MemoryDocument; import com.talkforgeai.backend.memory.dto.MemoryListRequestDto; @@ -38,8 +39,12 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.SimpleVectorStore.EmbeddingMath; +import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; +import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.Filter.Value; import org.springframework.data.domain.PageRequest; public class DBVectorStore implements ListableVectoreStore { @@ -49,7 +54,7 @@ public class DBVectorStore implements ListableVectoreStore { public static final String SEARCH_ASSISTANT_NAME = "assistantName"; public static final String SEARCH_SYSTEM = "system"; public static final String SEARCH_CONTENT = "content"; - public static final String SEARCH_CREATED_AT = "createdAt"; + public static final String SEARCH_MESSAGE_TYPE = "messageType"; private final EntityManager entityManager; private final MemoryRepository memoryRepository; private final AssistantRepository assistantRepository; @@ -71,11 +76,11 @@ public void add(List documents) { List memoryDocuments = documents.stream() .filter(document -> { int countDocs; - if (document.getMetadata().get(MetadataKey.ASSISTANT_ID.key()) == null) { + if (document.getMetadata().get(MetadataKey.CONVERSATION_ID.key()) == null) { countDocs = memoryRepository.countByContentAndEmptyAssistant(document.getContent()); } else { countDocs = memoryRepository.countByContentAndAssistantId(document.getContent(), - (String) document.getMetadata().get(MetadataKey.ASSISTANT_ID.key())); + (String) document.getMetadata().get(MetadataKey.CONVERSATION_ID.key())); } if (countDocs > 0) { @@ -92,16 +97,18 @@ public void add(List documents) { List embedding = this.embeddingModel.embed(document); documentEntity.setEmbeddings( embedding.stream().mapToDouble(Double::doubleValue).toArray()); - // Convert List to byte[] documentEntity.setId(UniqueIdUtil.generateMemoryId()); documentEntity.setContent(document.getContent()); documentEntity.setCreatedAt(new Date()); - documentEntity.setSystem((String) document.getMetadata().get(MetadataKey.SYSTEM.key())); - documentEntity.setModel((String) document.getMetadata().get(MetadataKey.MODEL.key())); + + documentEntity.setSystem(LlmSystem.OPENAI.name()); + documentEntity.setModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL); documentEntity.setRunId((String) document.getMetadata().get(MetadataKey.RUN_ID.key())); - if (document.getMetadata().containsKey(MetadataKey.ASSISTANT_ID.key())) { + documentEntity.setMessageType( + (String) document.getMetadata().get(MetadataKey.MESSAGE_TYPE.key())); + if (document.getMetadata().containsKey(MetadataKey.CONVERSATION_ID.key())) { assistantRepository.findById( - (String) document.getMetadata().get(MetadataKey.ASSISTANT_ID.key())) + (String) document.getMetadata().get(MetadataKey.CONVERSATION_ID.key())) .ifPresent(documentEntity::setAssistant); } @@ -121,14 +128,32 @@ public Optional delete(List idList) { @Override public List similaritySearch(SearchRequest request) { + LOGGER.info("Similarity search request: {}", request); + + String assistantId = null; if (request.getFilterExpression() != null) { - throw new UnsupportedOperationException( - "The [" + this.getClass() + "] doesn't support metadata filtering!"); + ExpressionType type = request.getFilterExpression().type(); + if (type == ExpressionType.EQ) { + Key left = (Key) request.getFilterExpression().left(); + Value right = (Value) request.getFilterExpression().right(); + + if (left.key().equals(SEARCH_ASSISTANT_ID)) { + assistantId = right.value().toString(); + } + } } List userQueryEmbedding = getUserQueryEmbedding(request.getQuery()); - List documents = memoryRepository.findAll(); + List documents; + if (assistantId == null) { + LOGGER.info("Searching for all documents"); + documents = memoryRepository.findAll(); + } else { + LOGGER.info("Searching for documents with assistantId = {}", assistantId); + documents = memoryRepository.findAllByAssistantId(assistantId); + } + // Convert documents to Map Map documentMap = documents.stream() .collect(Collectors.toMap(MemoryDocument::getId, document -> document)); @@ -156,14 +181,16 @@ public List similaritySearch(SearchRequest request) { Map metadata = new HashMap<>(); metadata.put(MetadataKey.SYSTEM.key(), memoryDocument.getSystem()); metadata.put(MetadataKey.MODEL.key(), memoryDocument.getModel()); - metadata.put(MetadataKey.ASSISTANT_ID.key(), + metadata.put(MetadataKey.CONVERSATION_ID.key(), memoryDocument.getAssistant() == null ? null : memoryDocument.getAssistant().getId()); metadata.put(MetadataKey.ASSISTANT_NAME.key(), memoryDocument.getAssistant() == null ? null : memoryDocument.getAssistant().getName()); + metadata.put(MetadataKey.MESSAGE_TYPE.key(), memoryDocument.getMessageType()); return new Document(s.key(), memoryDocument.getContent(), metadata); }) + .peek(document -> LOGGER.info("Similarity search result: {}", document)) .toList(); } @@ -201,6 +228,8 @@ public List list(MemoryListRequestDto listRequest) { "a.name LIKE :assistantName"); appendQueryCondition(query, searchMap, SEARCH_SYSTEM, "md.system = :system"); + appendQueryCondition(query, searchMap, SEARCH_MESSAGE_TYPE, + "md.messageType = :messageType"); } List sortBy = listRequest.sortBy(); @@ -233,6 +262,8 @@ public List list(MemoryListRequestDto listRequest) { setQueryParameter(typedQuery, searchMap, SEARCH_ASSISTANT_NAME, "%" + searchMap.get(SEARCH_ASSISTANT_NAME) + "%"); setQueryParameter(typedQuery, searchMap, SEARCH_SYSTEM, searchMap.get(SEARCH_SYSTEM)); + setQueryParameter(typedQuery, searchMap, SEARCH_MESSAGE_TYPE, + searchMap.get(SEARCH_MESSAGE_TYPE)); } return typedQuery.getResultList().stream() @@ -251,10 +282,11 @@ private Document mapDocument(MemoryDocument memoryDocument) { metadata.put(MetadataKey.SYSTEM.key(), memoryDocument.getSystem()); metadata.put(MetadataKey.MODEL.key(), memoryDocument.getModel()); - metadata.put(MetadataKey.ASSISTANT_ID.key(), + metadata.put(MetadataKey.CONVERSATION_ID.key(), memoryDocument.getAssistant() == null ? null : memoryDocument.getAssistant().getId()); metadata.put(MetadataKey.ASSISTANT_NAME.key(), memoryDocument.getAssistant() == null ? null : memoryDocument.getAssistant().getName()); + metadata.put(MetadataKey.MESSAGE_TYPE.key(), memoryDocument.getMessageType()); return new Document( memoryDocument.getId(), diff --git a/backend/src/main/java/com/talkforgeai/backend/memory/service/MemoryImportService.java b/backend/src/main/java/com/talkforgeai/backend/memory/service/MemoryImportService.java index 5698dd88..5d7330b7 100644 --- a/backend/src/main/java/com/talkforgeai/backend/memory/service/MemoryImportService.java +++ b/backend/src/main/java/com/talkforgeai/backend/memory/service/MemoryImportService.java @@ -22,8 +22,10 @@ import com.talkforgeai.backend.assistant.repository.AssistantRepository; import com.talkforgeai.backend.memory.dto.MemoryImportDto; import com.talkforgeai.backend.memory.exceptions.MemoryException; +import com.talkforgeai.backend.memory.exceptions.MemoryImportException; import com.talkforgeai.backend.storage.FileStorageService; import jakarta.transaction.Transactional; +import java.awt.TrayIcon.MessageType; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; @@ -67,6 +69,8 @@ public void importMemory() { new TypeReference<>() { }); memoryImportDtos.forEach(memoryImportDto -> { + validateImport(memoryImportDto); + try { String assistantId = null; if (memoryImportDto.assistantName() != null && !memoryImportDto.assistantName() @@ -90,4 +94,15 @@ public void importMemory() { LOGGER.error("Failed to list files in directory: {} ", importDirectory, e); } } + + private void validateImport(MemoryImportDto memoryImportDto) { + try { + if (memoryImportDto.messageType() != null && !memoryImportDto.messageType() + .isBlank()) { + MessageType.valueOf(memoryImportDto.messageType()); + } + } catch (IllegalArgumentException e) { + throw new MemoryImportException("Invalid message type: " + memoryImportDto.messageType(), e); + } + } } diff --git a/backend/src/main/java/com/talkforgeai/backend/memory/service/MemoryService.java b/backend/src/main/java/com/talkforgeai/backend/memory/service/MemoryService.java index 3e9eeb9b..48895a0a 100644 --- a/backend/src/main/java/com/talkforgeai/backend/memory/service/MemoryService.java +++ b/backend/src/main/java/com/talkforgeai/backend/memory/service/MemoryService.java @@ -44,19 +44,21 @@ public class MemoryService { } public DocumentWithoutEmbeddings store(String content, String assistantId) { - return this.store(content, assistantId, ""); + return this.store(content, assistantId, "", ""); } - public DocumentWithoutEmbeddings store(String content, String assistantId, String runId) { + public DocumentWithoutEmbeddings store(String content, String assistantId, String runId, + String messageType) { Document document = new Document(content); document.getMetadata().put(MetadataKey.SYSTEM.key(), LlmSystem.OPENAI.name()); document.getMetadata() .put(MetadataKey.MODEL.key(), OpenAiEmbeddingProperties.DEFAULT_EMBEDDING_MODEL); document.getMetadata().put(MetadataKey.RUN_ID.key(), runId); + document.getMetadata().put(MetadataKey.MESSAGE_TYPE.key(), messageType); if (assistantId != null && !assistantId.isBlank()) { - document.getMetadata().put(MetadataKey.ASSISTANT_ID.key(), assistantId); + document.getMetadata().put(MetadataKey.CONVERSATION_ID.key(), assistantId); var assistant = assistantRepository.findById(assistantId); assistant.ifPresent( a -> document.getMetadata().put(MetadataKey.ASSISTANT_NAME.key(), a.getName())); diff --git a/backend/src/main/java/com/talkforgeai/backend/assistant/functions/FunctionsConfiguration.java b/backend/src/main/resources/db/migration/V2__Memory.sql similarity index 58% rename from backend/src/main/java/com/talkforgeai/backend/assistant/functions/FunctionsConfiguration.java rename to backend/src/main/resources/db/migration/V2__Memory.sql index 8070f0a6..b9ebe25a 100644 --- a/backend/src/main/java/com/talkforgeai/backend/assistant/functions/FunctionsConfiguration.java +++ b/backend/src/main/resources/db/migration/V2__Memory.sql @@ -14,16 +14,6 @@ * limitations under the License. */ -package com.talkforgeai.backend.assistant.functions; -import org.springframework.context.annotation.Configuration; - -@Configuration -public class FunctionsConfiguration { - -// @Bean -// @Description("Store relevant information in the vector database for later retrieval.") -// public Function contextStorageFunction(MemoryService memoryService) { -// return new ContextStorageFunction(memoryService); -// } -} +ALTER TABLE memory_document + ADD COLUMN message_type VARCHAR(20); \ No newline at end of file diff --git a/frontend/src/components/editor/EditorTabMemory.vue b/frontend/src/components/editor/EditorTabMemory.vue index 124a8d0d..138d0539 100644 --- a/frontend/src/components/editor/EditorTabMemory.vue +++ b/frontend/src/components/editor/EditorTabMemory.vue @@ -25,9 +25,12 @@ export default defineComponent({ data() { return { memoryTypes: [ - {key: 'NONE', description: 'No Memory'}, - {key: 'ASSISTANT', description: 'Use only Assistant memory'}, - {key: 'GLOBAL', description: 'Use global memory'}, + {key: 'NONE', description: 'No information is stored'}, + { + key: 'AI_DECIDES', + description: 'The AI independently decides which information is stored', + }, + {key: 'HISTORY', description: 'The entire chat history is stored'}, ], }; }, diff --git a/frontend/src/components/memory/MemoryView.vue b/frontend/src/components/memory/MemoryView.vue index 1e0bbfea..82a4226f 100644 --- a/frontend/src/components/memory/MemoryView.vue +++ b/frontend/src/components/memory/MemoryView.vue @@ -39,8 +39,7 @@ :search="searchModifier" item-value="id" show-select - @update:options="loadServerItems" - > + @update:options="loadServerItems">