diff --git a/backend/src/main/java/com/talkforgeai/backend/assistant/functions/ContextStorageFunction.java b/backend/src/main/java/com/talkforgeai/backend/assistant/functions/ContextStorageFunction.java index 509fde1d..eefb182e 100644 --- a/backend/src/main/java/com/talkforgeai/backend/assistant/functions/ContextStorageFunction.java +++ b/backend/src/main/java/com/talkforgeai/backend/assistant/functions/ContextStorageFunction.java @@ -46,7 +46,7 @@ public Response apply(Request request) { MemoryStoreRequestDto requestDto = new MemoryStoreRequestDto(request.contextInfo(), functionContext.assistantId()); DocumentWithoutEmbeddings storedDocument = memoryService.store(requestDto.content(), - requestDto.assistantId()); + requestDto.assistantId(), functionContext.runId()); return new Response( storedDocument.id(), "I stored the following information in memory: " + request.contextInfo()); diff --git a/backend/src/main/java/com/talkforgeai/backend/assistant/functions/FunctionContext.java b/backend/src/main/java/com/talkforgeai/backend/assistant/functions/FunctionContext.java index 1dcb42e6..7c25cb4a 100644 --- a/backend/src/main/java/com/talkforgeai/backend/assistant/functions/FunctionContext.java +++ b/backend/src/main/java/com/talkforgeai/backend/assistant/functions/FunctionContext.java @@ -19,6 +19,7 @@ import com.talkforgeai.backend.assistant.dto.LlmSystem; public record FunctionContext(LlmSystem embedLlmSystem, - String embedModel, String assistantId, String assistantName) { + String embedModel, String assistantId, String assistantName, + String runId) { } diff --git a/backend/src/main/java/com/talkforgeai/backend/assistant/service/AssistantSpringService.java b/backend/src/main/java/com/talkforgeai/backend/assistant/service/AssistantSpringService.java index b631ad48..ecb17a0d 100644 --- a/backend/src/main/java/com/talkforgeai/backend/assistant/service/AssistantSpringService.java +++ b/backend/src/main/java/com/talkforgeai/backend/assistant/service/AssistantSpringService.java @@ -43,6 +43,7 @@ import com.talkforgeai.backend.assistant.repository.MessageRepository; import com.talkforgeai.backend.assistant.repository.ThreadRepository; import com.talkforgeai.backend.memory.dto.DocumentWithoutEmbeddings; +import com.talkforgeai.backend.memory.repository.MemoryRepository; import com.talkforgeai.backend.memory.service.MemoryService; import com.talkforgeai.backend.storage.FileStorageService; import com.talkforgeai.backend.transformers.MessageProcessor; @@ -64,6 +65,7 @@ import java.util.Optional; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.logging.Level; import java.util.stream.Stream; import javax.imageio.ImageIO; import org.jetbrains.annotations.NotNull; @@ -121,13 +123,14 @@ public class AssistantSpringService { private final Map activeStreams = new ConcurrentHashMap<>(); + private final MemoryRepository memoryRepository; public AssistantSpringService( UniversalChatService universalChatService, UniversalImageGenService universalImageGenService, AssistantRepository assistantRepository, MessageRepository messageRepository, ThreadRepository threadRepository, FileStorageService fileStorageService, MessageProcessor messageProcessor, AssistantMapper assistantMapper, - MemoryService memoryService) { + MemoryService memoryService, MemoryRepository memoryRepository) { this.universalChatService = universalChatService; this.universalImageGenService = universalImageGenService; @@ -138,6 +141,7 @@ public AssistantSpringService( this.messageProcessor = messageProcessor; this.assistantMapper = assistantMapper; this.memoryService = memoryService; + this.memoryRepository = memoryRepository; } private static @NotNull Mono getInitInfosMono(Mono assistantEntityMono, @@ -265,74 +269,77 @@ public Flux> streamRunConversation(String assistantId, S StringBuilder assistantMessageContent = new StringBuilder(); return getRunIdEventFlux(runId).concatWith( - saveUserMessageMono - .then(initInfosMono) - .flux() - .flatMap(initInfos -> { - List memorySearchResults = getMemorySearchResults( - initInfos.assistantDto, message); - - return Flux.just( - new PreparedInfos(initInfos.assistantDto(), initInfos.pastMessages(), - memorySearchResults)); - }) - .flatMap(preparedInfos -> { - AssistantDto assistantDto = preparedInfos.assistantDto(); - List pastMessagesList = preparedInfos.pastMessages(); - List memoryResultsList = preparedInfos.memoryResults(); - - List finalPromptMessageList = getFinalPromptMessageList(message, - pastMessagesList, - assistantDto, memoryResultsList); - - FunctionCallbackWrapper memoryFunctionCallback = getMemoryFunctionCallback( - assistantId, assistantDto); - - ChatOptions promptOptions = universalChatService.getPromptOptions(assistantDto, - List.of(memoryFunctionCallback)); - - LOGGER.debug("Starting stream with prompt: {}", finalPromptMessageList); - LOGGER.debug("Prompt Options: {}", - universalChatService.printPromptOptions(assistantDto.system(), - promptOptions)); - - Prompt prompt = new Prompt(finalPromptMessageList, promptOptions); - - return universalChatService.stream(assistantDto.system(), prompt); - }) - .doOnCancel(() -> { - LOGGER.debug("doOnCancel. message={}", assistantMessageContent); - }) - .mapNotNull(chatResponse -> mapChatResponse(chatResponse, assistantMessageContent)) - .doOnSubscribe(subscription -> { - LOGGER.debug("doOnSubscribe. message={}", assistantMessageContent); - - activeStreams.put(runId, subscription); - }) - .doOnNext(chatResponse -> LOGGER.debug("doOnNext response: {}", chatResponse)) - .doOnComplete(() -> { - LOGGER.debug("doOnComplete. message={}", assistantMessageContent); - - Mono.fromRunnable( - () -> saveNewMessage(assistantId, threadId, MessageType.ASSISTANT, - assistantMessageContent.toString(), null)) // Wrap blocking call - .subscribeOn( - Schedulers.boundedElastic()) // Subscribe on separate thread pool - .subscribe(); // Subscribe to start execution - }) - .onErrorResume(throwable -> { - LOGGER.error("Error while streaming: {}", throwable.getMessage()); - return Flux.just(ServerSentEvent.builder() - .event("error") - .data(throwable.getMessage()) - .build()); - })); - } - - private @NotNull ServerSentEvent mapChatResponse(ChatResponse chatResponse, + saveUserMessageMono + .then(initInfosMono) + .flux() + .flatMap(initInfos -> { + List memorySearchResults = getMemorySearchResults( + initInfos.assistantDto, message); + + return Flux.just( + new PreparedInfos(initInfos.assistantDto(), initInfos.pastMessages(), + memorySearchResults)); + }) + .flatMap(preparedInfos -> { + AssistantDto assistantDto = preparedInfos.assistantDto(); + List pastMessagesList = preparedInfos.pastMessages(); + List memoryResultsList = preparedInfos.memoryResults(); + + List finalPromptMessageList = getFinalPromptMessageList(message, + pastMessagesList, + assistantDto, memoryResultsList); + + FunctionCallbackWrapper memoryFunctionCallback = getMemoryFunctionCallback( + assistantId, runId, assistantDto); + + ChatOptions promptOptions = universalChatService.getPromptOptions(assistantDto, + List.of(memoryFunctionCallback)); + + LOGGER.debug("Starting stream with prompt: {}", finalPromptMessageList); + LOGGER.debug("Prompt Options: {}", + universalChatService.printPromptOptions(assistantDto.system(), + promptOptions)); + + Prompt prompt = new Prompt(finalPromptMessageList, promptOptions); + + return universalChatService.stream(assistantDto.system(), prompt); + }) + .doOnCancel(() -> { + LOGGER.debug("doOnCancel. message={}", assistantMessageContent); + }) + .mapNotNull(chatResponse -> mapChatResponse(chatResponse, assistantMessageContent)) + .doOnSubscribe(subscription -> { + LOGGER.debug("doOnSubscribe. message={}", assistantMessageContent); + + activeStreams.put(runId, subscription); + }) + .doOnComplete(() -> { + LOGGER.trace("doOnComplete. message={}", assistantMessageContent); + + Mono.fromRunnable( + () -> saveNewMessage(assistantId, threadId, MessageType.ASSISTANT, + assistantMessageContent.toString(), null)) // Wrap blocking call + .subscribeOn( + Schedulers.boundedElastic()) // Subscribe on separate thread pool + .subscribe(); // Subscribe to start execution + + + }) + .onErrorResume(throwable -> { + LOGGER.error("Error while streaming: {}", throwable.getMessage()); + return Flux.just(ServerSentEvent.builder() + .event("error") + .data(throwable.getMessage()) + .build()); + })) + .log(AssistantSpringService.class.getName(), Level.FINE); + } + + private @NotNull ServerSentEvent mapChatResponse(@NotNull ChatResponse chatResponse, StringBuilder assistantMessageContent) { + String content = chatResponse.getResult().getOutput().getContent(); - LOGGER.debug("ChatResponse received: {}", content); + LOGGER.trace("ChatResponse received: {}", chatResponse.getResult()); if (content != null) { assistantMessageContent.append( @@ -342,18 +349,21 @@ public Flux> streamRunConversation(String assistantId, S ServerSentEvent responseSseEvent = createResponseSseEvent(chatResponse); if (responseSseEvent != null) { - LOGGER.debug("Sending event '{}'", responseSseEvent.event()); + LOGGER.trace("Sending event '{}'", responseSseEvent.event()); } return responseSseEvent; } private FunctionCallbackWrapper getMemoryFunctionCallback( - String assistantId, AssistantDto assistantDto) { + String assistantId, String runId, AssistantDto assistantDto) { return FunctionCallbackWrapper.builder( new ContextStorageFunction(memoryService, new FunctionContext(LlmSystem.OPENAI, - OpenAiEmbeddingProperties.DEFAULT_EMBEDDING_MODEL, assistantId, - assistantDto.name()))) + OpenAiEmbeddingProperties.DEFAULT_EMBEDDING_MODEL, + assistantId, + assistantDto.name(), + runId + ))) .withDescription( "Store relevant information in the vector database for later retrieval.") .withName(ContextTool.MEMORY_STORE.getFunctionBeanName()) diff --git a/backend/src/main/java/com/talkforgeai/backend/memory/domain/MemoryDocument.java b/backend/src/main/java/com/talkforgeai/backend/memory/domain/MemoryDocument.java index d0b224c7..25f6b198 100644 --- a/backend/src/main/java/com/talkforgeai/backend/memory/domain/MemoryDocument.java +++ b/backend/src/main/java/com/talkforgeai/backend/memory/domain/MemoryDocument.java @@ -37,6 +37,7 @@ public class MemoryDocument { @Id + @Column(name = "id", length = 50) private String id; @Column(name = "created_at", nullable = false) @@ -58,6 +59,8 @@ public class MemoryDocument { @ManyToOne private AssistantEntity assistant; + @Column(name = "run_id", length = 50) + private String runId; public String getId() { return id; @@ -135,4 +138,12 @@ public String getModel() { public void setModel(String model) { this.model = model; } + + public String getRunId() { + return runId; + } + + public void setRunId(String runId) { + this.runId = runId; + } } diff --git a/backend/src/main/java/com/talkforgeai/backend/memory/dto/MetadataKey.java b/backend/src/main/java/com/talkforgeai/backend/memory/dto/MetadataKey.java index 82b1117d..080e69cf 100644 --- a/backend/src/main/java/com/talkforgeai/backend/memory/dto/MetadataKey.java +++ b/backend/src/main/java/com/talkforgeai/backend/memory/dto/MetadataKey.java @@ -20,7 +20,8 @@ public enum MetadataKey { SYSTEM("system"), MODEL("model"), ASSISTANT_ID("assistantId"), - ASSISTANT_NAME("assistantName"); + ASSISTANT_NAME("assistantName"), + RUN_ID("runId"); private final String key; diff --git a/backend/src/main/java/com/talkforgeai/backend/memory/repository/MemoryRepository.java b/backend/src/main/java/com/talkforgeai/backend/memory/repository/MemoryRepository.java index fdbc4df4..e560ba64 100644 --- a/backend/src/main/java/com/talkforgeai/backend/memory/repository/MemoryRepository.java +++ b/backend/src/main/java/com/talkforgeai/backend/memory/repository/MemoryRepository.java @@ -37,4 +37,5 @@ public interface MemoryRepository extends JpaRepository "SELECT COUNT(md) FROM MemoryDocument md WHERE md.content = :content AND md.assistant is null") int countByContentAndEmptyAssistant(String content); + int countByRunId(String runId); } diff --git a/backend/src/main/java/com/talkforgeai/backend/memory/service/DBVectorStore.java b/backend/src/main/java/com/talkforgeai/backend/memory/service/DBVectorStore.java index 8b3d4617..df2e078a 100644 --- a/backend/src/main/java/com/talkforgeai/backend/memory/service/DBVectorStore.java +++ b/backend/src/main/java/com/talkforgeai/backend/memory/service/DBVectorStore.java @@ -98,6 +98,7 @@ public void add(List documents) { documentEntity.setCreatedAt(new Date()); documentEntity.setSystem((String) document.getMetadata().get(MetadataKey.SYSTEM.key())); documentEntity.setModel((String) document.getMetadata().get(MetadataKey.MODEL.key())); + documentEntity.setRunId((String) document.getMetadata().get(MetadataKey.RUN_ID.key())); if (document.getMetadata().containsKey(MetadataKey.ASSISTANT_ID.key())) { assistantRepository.findById( (String) document.getMetadata().get(MetadataKey.ASSISTANT_ID.key())) @@ -160,7 +161,7 @@ public List similaritySearch(SearchRequest request) { metadata.put(MetadataKey.ASSISTANT_NAME.key(), memoryDocument.getAssistant() == null ? null : memoryDocument.getAssistant().getName()); - + return new Document(s.key(), memoryDocument.getContent(), metadata); }) .toList(); diff --git a/backend/src/main/java/com/talkforgeai/backend/memory/service/MemoryService.java b/backend/src/main/java/com/talkforgeai/backend/memory/service/MemoryService.java index a64041a2..3e9eeb9b 100644 --- a/backend/src/main/java/com/talkforgeai/backend/memory/service/MemoryService.java +++ b/backend/src/main/java/com/talkforgeai/backend/memory/service/MemoryService.java @@ -44,11 +44,16 @@ public class MemoryService { } public DocumentWithoutEmbeddings store(String content, String assistantId) { + return this.store(content, assistantId, ""); + } + + public DocumentWithoutEmbeddings store(String content, String assistantId, String runId) { Document document = new Document(content); document.getMetadata().put(MetadataKey.SYSTEM.key(), LlmSystem.OPENAI.name()); document.getMetadata() .put(MetadataKey.MODEL.key(), OpenAiEmbeddingProperties.DEFAULT_EMBEDDING_MODEL); + document.getMetadata().put(MetadataKey.RUN_ID.key(), runId); if (assistantId != null && !assistantId.isBlank()) { document.getMetadata().put(MetadataKey.ASSISTANT_ID.key(), assistantId); diff --git a/backend/src/main/java/com/talkforgeai/backend/util/JsonUtil.java b/backend/src/main/java/com/talkforgeai/backend/util/JsonUtil.java new file mode 100644 index 00000000..4ac33730 --- /dev/null +++ b/backend/src/main/java/com/talkforgeai/backend/util/JsonUtil.java @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2024 Jean Schmitz. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.talkforgeai.backend.util; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +public class JsonUtil { + + private static final ObjectMapper objectMapper = new ObjectMapper(); + + public static String convertObjectToJson(Object obj) { + try { + return objectMapper.writeValueAsString(obj); + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to convert object to JSON string", e); + } + } +} diff --git a/backend/src/main/resources/db/migration/V1__Initial_Setup.sql b/backend/src/main/resources/db/migration/V1__Initial_Setup.sql index e9592448..cdd59dd7 100644 --- a/backend/src/main/resources/db/migration/V1__Initial_Setup.sql +++ b/backend/src/main/resources/db/migration/V1__Initial_Setup.sql @@ -66,6 +66,7 @@ create table memory_document system varchar(50), model varchar(50), assistant_id varchar(50), + run_id varchar(50), embeddings CLOB, foreign key (assistant_id) references assistant (id) );