Skip to content

Commit

Permalink
Interim support for VectorStoreChatMemoryAdvisor
Browse files Browse the repository at this point in the history
  • Loading branch information
jschm42 committed Jun 10, 2024
1 parent 4e9a3ce commit a8c8b77
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ public AssistantSpringService(
return runIdMono;
}

private static @NotNull List<Message> getFinalPromptMessageList(String message,
private static @NotNull List<Message> getFinalPromptMessageList(
List<MessageDto> pastMessagesList, AssistantDto assistantDto,
List<DocumentWithoutEmbeddings> memoryResultsList) {
List<Message> promptMessageList = pastMessagesList.stream()
Expand Down Expand Up @@ -194,9 +194,6 @@ public AssistantSpringService(
result -> memoryMessage.append(result.content()).append("\n"));
memoryMessage.append("\nUser message:\n");
}

finalPromptMessageList.add(
new UserMessage(memoryMessage.append(message).toString()));
return finalPromptMessageList;
}

Expand Down Expand Up @@ -258,11 +255,11 @@ public void cancelStream(String threadId, String runId) {
}

public Flux<ServerSentEvent<String>> streamRunConversation(String assistantId, String threadId,
String message) {
String userMessage) {

final String runId = UniqueIdUtil.generateRunId();

Mono<Object> saveUserMessageMono = getSaveUserMessageMono(assistantId, threadId, message);
Mono<Object> saveUserMessageMono = getSaveUserMessageMono(assistantId, threadId, userMessage);
Mono<AssistantDto> assistantEntityMono = getAssistantEntityMono(assistantId);
Mono<List<MessageDto>> pastMessages = getPastMessagesMono(threadId);
Mono<InitInfos> initInfosMono = getInitInfosMono(assistantEntityMono, pastMessages);
Expand All @@ -275,7 +272,7 @@ public Flux<ServerSentEvent<String>> streamRunConversation(String assistantId, S
.flux()
.flatMap(initInfos -> {
List<DocumentWithoutEmbeddings> memorySearchResults = getMemorySearchResults(
initInfos.assistantDto, message);
initInfos.assistantDto, userMessage);

return Flux.just(
new PreparedInfos(initInfos.assistantDto(), initInfos.pastMessages(),
Expand All @@ -286,7 +283,7 @@ public Flux<ServerSentEvent<String>> streamRunConversation(String assistantId, S
List<MessageDto> pastMessagesList = preparedInfos.pastMessages();
List<DocumentWithoutEmbeddings> memoryResultsList = preparedInfos.memoryResults();

List<Message> finalPromptMessageList = getFinalPromptMessageList(message,
List<Message> finalPromptMessageList = getFinalPromptMessageList(
pastMessagesList,
assistantDto, memoryResultsList);

Expand All @@ -303,19 +300,20 @@ public Flux<ServerSentEvent<String>> streamRunConversation(String assistantId, S

Prompt prompt = new Prompt(finalPromptMessageList, promptOptions);

return universalChatService.stream(assistantDto.system(), prompt);
return universalChatService.stream(assistantDto.system(), finalPromptMessageList,
userMessage, assistantId, promptOptions);
})
.doOnCancel(() -> {
LOGGER.debug("doOnCancel. message={}", assistantMessageContent);
LOGGER.debug("doOnCancel. userMessage={}", assistantMessageContent);
})
.mapNotNull(chatResponse -> mapChatResponse(chatResponse, assistantMessageContent))
.doOnSubscribe(subscription -> {
LOGGER.debug("doOnSubscribe. message={}", assistantMessageContent);
LOGGER.debug("doOnSubscribe. userMessage={}", assistantMessageContent);

activeStreams.put(runId, subscription);
})
.doOnComplete(() -> {
LOGGER.trace("doOnComplete. message={}", assistantMessageContent);
LOGGER.trace("doOnComplete. userMessage={}", assistantMessageContent);

Mono.fromRunnable(
() -> saveNewMessage(assistantId, threadId, MessageType.ASSISTANT,
Expand Down Expand Up @@ -524,7 +522,7 @@ public ThreadTitleDto generateThreadTitle(String threadId,
Prompt titlePrompt = new Prompt(new UserMessage(content), options);

try {
ChatResponse titleResponse = universalChatService.call(LlmSystem.OPENAI, titlePrompt);
ChatResponse titleResponse = universalChatService.call(LlmSystem.OPENAI, content, options);

String generatedTitle = titleResponse.getResult().getOutput().getContent();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,16 @@
import com.talkforgeai.backend.assistant.dto.AssistantDto;
import com.talkforgeai.backend.assistant.dto.LlmSystem;
import com.talkforgeai.backend.assistant.exception.AssistentException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.jetbrains.annotations.NotNull;
import org.springframework.ai.anthropic.AnthropicChatModel;
import org.springframework.ai.anthropic.AnthropicChatOptions;
import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.VectorStoreChatMemoryAdvisor;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.mistralai.MistralAiChatModel;
import org.springframework.ai.mistralai.MistralAiChatOptions;
import org.springframework.ai.mistralai.api.MistralAiApi;
Expand All @@ -38,14 +37,20 @@
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestClient;
import reactor.core.publisher.Flux;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

@Service
public class UniversalChatService {

/**
 * Size value handed to every {@code VectorStoreChatMemoryAdvisor} this service creates.
 * Declared {@code static} because it is a compile-time constant shared by all instances.
 */
private static final int DEFAULT_CHAT_MEMORY_RESPONSE_SIZE = 5;

@Qualifier("openAiRestClient")
private final RestClient openAiRestClient;
private final OpenAiChatModel openAiChatModel;
Expand All @@ -56,16 +61,21 @@ public class UniversalChatService {
@Qualifier("ollamaAiRestClient")
private final RestClient ollamaAiRestClient;

@Qualifier("dbVectorStore")
private final VectorStore dbVectorStore;

public UniversalChatService(RestClient openAiRestClient,
OpenAiChatModel openAiChatModel,
MistralAiChatModel mistralAiChatModel, AnthropicChatModel anthropicChatModel,
OllamaChatModel ollamaChatModel, RestClient ollamaAiRestClient) {
OllamaChatModel ollamaChatModel, RestClient ollamaAiRestClient, VectorStore dbVectorStore) {
this.openAiRestClient = openAiRestClient;
this.openAiChatModel = openAiChatModel;
this.mistralAiChatModel = mistralAiChatModel;
this.anthropicChatModel = anthropicChatModel;
this.ollamaChatModel = ollamaChatModel;
this.ollamaAiRestClient = ollamaAiRestClient;

this.dbVectorStore = dbVectorStore;
}

public ChatOptions getPromptOptions(AssistantDto assistantDto,
Expand Down Expand Up @@ -126,38 +136,48 @@ public String printPromptOptions(LlmSystem system, ChatOptions options) {
return printedOptions.toString();
}

ChatResponse call(LlmSystem system, Prompt prompt) {
return getChatClient(system).call(prompt);
/**
 * Sends a single user prompt to the chat model selected by {@code system} and returns the
 * resulting response.
 *
 * @param system  the LLM backend whose chat model should handle the request
 * @param prompt  the text submitted as the user message
 * @param options the chat options applied to this request
 * @return the {@link ChatResponse} produced by the call
 */
ChatResponse call(LlmSystem system, String prompt, ChatOptions options) {
  final ChatClient client = getClient(system);
  // Options are applied before the user text, mirroring the request-building order used elsewhere.
  final var request = client.prompt().options(options).user(prompt);
  return request.call().chatResponse();
}

Flux<ChatResponse> stream(LlmSystem system, Prompt prompt) {
return getStreamingChatClient(system).stream(prompt);
/**
 * Streams a chat completion from the model selected by {@code system}.
 *
 * <p>The request is built with a vector-store chat memory advisor for {@code conversationId},
 * the supplied options, the supplied message list and finally the user message.
 *
 * @param system         the LLM backend whose chat model should handle the request
 * @param messages       messages added to the request before the user message
 * @param userMessage    the text submitted as the user message
 * @param conversationId identifier passed to the chat memory advisor
 * @param options        the chat options applied to this request
 * @return a {@link Flux} emitting the streamed {@link ChatResponse} chunks
 */
Flux<ChatResponse> stream(LlmSystem system, List<Message> messages, String userMessage,
    String conversationId,
    ChatOptions options) {

  final var request = getClient(system).prompt();
  final VectorStoreChatMemoryAdvisor memoryAdvisor =
      getVectorStoreChatMemoryAdvisor(conversationId);

  return request
      .advisors(memoryAdvisor)
      .options(options)
      .messages(messages)
      .user(userMessage)
      .stream()
      .chatResponse();
}

StreamingChatModel getStreamingChatClient(LlmSystem system) {
return (StreamingChatModel) getClient(system);
/**
 * Creates a new {@link VectorStoreChatMemoryAdvisor} backed by {@code dbVectorStore}.
 *
 * @param conversationId identifier handed to the advisor as its conversation id
 * @return a freshly constructed advisor; a new instance is created on every call
 */
private @NotNull VectorStoreChatMemoryAdvisor getVectorStoreChatMemoryAdvisor(
    String conversationId) {
  // NOTE(review): the third constructor argument is presumably the number of memory entries
  // the advisor retrieves - confirm against the Spring AI version in use.
  return new VectorStoreChatMemoryAdvisor(
      dbVectorStore,
      conversationId,
      DEFAULT_CHAT_MEMORY_RESPONSE_SIZE
  );
}

ChatModel getChatClient(LlmSystem system) {
return (ChatModel) getClient(system);
}
private ChatClient getClient(LlmSystem system) {
ChatModel model = switch (system) {
case OPENAI -> openAiChatModel;
case MISTRAL -> mistralAiChatModel;
case OLLAMA -> ollamaChatModel;
case ANSTHROPIC -> anthropicChatModel;
};

private Object getClient(LlmSystem system) {
switch (system) {
case OPENAI -> {
return openAiChatModel;
}
case MISTRAL -> {
return mistralAiChatModel;
}
case OLLAMA -> {
return ollamaChatModel;
}
case ANSTHROPIC -> {
return anthropicChatModel;
}
default -> throw new IllegalStateException("Unexpected system: " + system);
}
return ChatClient.builder(model)
.build();
}

private MistralAiChatOptions getMistralOptions(AssistantDto assistantDto,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,11 @@ public class MemoryDocument {

@ManyToOne
private AssistantEntity assistant;

@Column(name = "run_id", length = 50)
private String runId;
@Column(name = "message_type", length = 20)
private String messageType;

public String getId() {
return id;
Expand Down Expand Up @@ -146,4 +149,13 @@ public String getRunId() {
/**
 * Sets the run identifier stored with this memory document.
 *
 * @param runId the run identifier to store
 */
public void setRunId(String runId) {
this.runId = runId;
}

/**
 * Returns the message type stored with this memory document.
 *
 * @return the current message type value
 */
public String getMessageType() {
  return this.messageType;
}

/**
 * Sets the message type stored with this memory document.
 *
 * @param messageType the message type value to store
 */
public void setMessageType(String messageType) {
  this.messageType = messageType;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

public record DocumentWithoutEmbeddings(String id, String content, String assistantId,
String assistantName,
String system, String model) {
String system, String model, String messageType) {

private static String getMetadataValue(Map<String, Object> meta, MetadataKey key) {
return (String) meta.getOrDefault(key.key(), null);
Expand All @@ -33,10 +33,11 @@ public static DocumentWithoutEmbeddings from(Document document) {
return new DocumentWithoutEmbeddings(
document.getId(),
document.getContent(),
getMetadataValue(meta, MetadataKey.ASSISTANT_ID),
getMetadataValue(meta, MetadataKey.CONVERSATION_ID),
getMetadataValue(meta, MetadataKey.ASSISTANT_NAME),
getMetadataValue(meta, MetadataKey.SYSTEM),
getMetadataValue(meta, MetadataKey.MODEL)
getMetadataValue(meta, MetadataKey.MODEL),
getMetadataValue(meta, MetadataKey.MESSAGE_TYPE)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
public enum MetadataKey {
SYSTEM("system"),
MODEL("model"),
ASSISTANT_ID("assistantId"),
CONVERSATION_ID("conversationId"),
ASSISTANT_NAME("assistantName"),
MESSAGE_TYPE("messageType"),
RUN_ID("runId");

private final String key;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.talkforgeai.backend.memory.repository;

import com.talkforgeai.backend.memory.domain.MemoryDocument;
import java.util.List;
import org.jetbrains.annotations.NotNull;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
Expand All @@ -30,6 +31,10 @@ public interface MemoryRepository extends JpaRepository<MemoryDocument, String>
@NotNull
Page<MemoryDocument> findAll(@NotNull Pageable pageable);

@NotNull
List<MemoryDocument> findAllByAssistantId(@NotNull String assistantId);


@Query("SELECT COUNT(md) FROM MemoryDocument md WHERE md.content = :content AND md.assistant.id = :assistantId")
int countByContentAndAssistantId(String content, String assistantId);

Expand Down
Loading

0 comments on commit a8c8b77

Please sign in to comment.