From 785f90c9b975bca75f387694bf73cfc3cba7a3f0 Mon Sep 17 00:00:00 2001 From: jschm42 Date: Thu, 6 Jun 2024 12:37:55 +0200 Subject: [PATCH] Migration to 1.0.0-M1 --- .../service/AssistantSpringService.java | 7 +-- .../service/UniversalChatService.java | 54 +++++++++---------- .../service/UniversalImageGenService.java | 28 +++++----- .../memory/VectorStoreConfiguration.java | 6 +-- .../backend/memory/service/DBVectorStore.java | 12 ++--- .../UniversalTranscriptionService.java | 14 ++--- pom.xml | 2 +- 7 files changed, 60 insertions(+), 63 deletions(-) diff --git a/backend/src/main/java/com/talkforgeai/backend/assistant/service/AssistantSpringService.java b/backend/src/main/java/com/talkforgeai/backend/assistant/service/AssistantSpringService.java index ecb17a0d..aa06e08a 100644 --- a/backend/src/main/java/com/talkforgeai/backend/assistant/service/AssistantSpringService.java +++ b/backend/src/main/java/com/talkforgeai/backend/assistant/service/AssistantSpringService.java @@ -43,7 +43,6 @@ import com.talkforgeai.backend.assistant.repository.MessageRepository; import com.talkforgeai.backend.assistant.repository.ThreadRepository; import com.talkforgeai.backend.memory.dto.DocumentWithoutEmbeddings; -import com.talkforgeai.backend.memory.repository.MemoryRepository; import com.talkforgeai.backend.memory.service.MemoryService; import com.talkforgeai.backend.storage.FileStorageService; import com.talkforgeai.backend.transformers.MessageProcessor; @@ -73,12 +72,12 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.autoconfigure.openai.OpenAiEmbeddingProperties; -import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.image.ImageResponse; @@ -123,14 +122,13 @@ public class AssistantSpringService { private final Map activeStreams = new ConcurrentHashMap<>(); - private final MemoryRepository memoryRepository; public AssistantSpringService( UniversalChatService universalChatService, UniversalImageGenService universalImageGenService, AssistantRepository assistantRepository, MessageRepository messageRepository, ThreadRepository threadRepository, FileStorageService fileStorageService, MessageProcessor messageProcessor, AssistantMapper assistantMapper, - MemoryService memoryService, MemoryRepository memoryRepository) { + MemoryService memoryService) { this.universalChatService = universalChatService; this.universalImageGenService = universalImageGenService; @@ -141,7 +139,6 @@ public AssistantSpringService( this.messageProcessor = messageProcessor; this.assistantMapper = assistantMapper; this.memoryService = memoryService; - this.memoryRepository = memoryRepository; } private static @NotNull Mono getInitInfosMono(Mono assistantEntityMono, diff --git a/backend/src/main/java/com/talkforgeai/backend/assistant/service/UniversalChatService.java b/backend/src/main/java/com/talkforgeai/backend/assistant/service/UniversalChatService.java index 3804366d..6d95e246 100644 --- a/backend/src/main/java/com/talkforgeai/backend/assistant/service/UniversalChatService.java +++ b/backend/src/main/java/com/talkforgeai/backend/assistant/service/UniversalChatService.java @@ -22,21 +22,21 @@ import java.util.Arrays; import java.util.List; import java.util.Map; -import org.springframework.ai.anthropic.AnthropicChatClient; +import org.springframework.ai.anthropic.AnthropicChatModel; import org.springframework.ai.anthropic.AnthropicChatOptions; import org.springframework.ai.anthropic.api.AnthropicApi; -import org.springframework.ai.chat.ChatClient; -import org.springframework.ai.chat.ChatResponse; -import org.springframework.ai.chat.StreamingChatClient; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.mistralai.MistralAiChatClient; +import org.springframework.ai.mistralai.MistralAiChatModel; import org.springframework.ai.mistralai.MistralAiChatOptions; import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.ollama.OllamaChatClient; +import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.ai.ollama.api.OllamaOptions; -import org.springframework.ai.openai.OpenAiChatClient; +import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.stereotype.Service; @@ -48,23 +48,23 @@ public class UniversalChatService { @Qualifier("openAiRestClient") private final RestClient openAiRestClient; - - private final OpenAiChatClient openAiChatClient; - private final MistralAiChatClient mistralAiChatClient; - private final AnthropicChatClient anthropicChatClient; - private final OllamaChatClient ollamaChatClient; + private final OpenAiChatModel openAiChatModel; + private final MistralAiChatModel mistralAiChatModel; + private final AnthropicChatModel anthropicChatModel; + private final OllamaChatModel ollamaChatModel; @Qualifier("ollamaAiRestClient") private final RestClient ollamaAiRestClient; - public UniversalChatService(RestClient openAiRestClient, OpenAiChatClient openAiChatClient, - MistralAiChatClient mistralAiChatClient, AnthropicChatClient anthropicChatClient, - OllamaChatClient ollamaChatClient, RestClient ollamaAiRestClient) { + public UniversalChatService(RestClient openAiRestClient, + OpenAiChatModel openAiChatModel, + MistralAiChatModel mistralAiChatModel, AnthropicChatModel anthropicChatModel, + OllamaChatModel ollamaChatModel, RestClient ollamaAiRestClient) { this.openAiRestClient = openAiRestClient; - this.openAiChatClient = openAiChatClient; - this.mistralAiChatClient = mistralAiChatClient; - this.anthropicChatClient = anthropicChatClient; - this.ollamaChatClient = ollamaChatClient; + this.openAiChatModel = openAiChatModel; + this.mistralAiChatModel = mistralAiChatModel; + this.anthropicChatModel = anthropicChatModel; + this.ollamaChatModel = ollamaChatModel; this.ollamaAiRestClient = ollamaAiRestClient; } @@ -134,27 +134,27 @@ Flux stream(LlmSystem system, Prompt prompt) { return getStreamingChatClient(system).stream(prompt); } - StreamingChatClient getStreamingChatClient(LlmSystem system) { - return (StreamingChatClient) getClient(system); + StreamingChatModel getStreamingChatClient(LlmSystem system) { + return (StreamingChatModel) getClient(system); } - ChatClient getChatClient(LlmSystem system) { - return (ChatClient) getClient(system); + ChatModel getChatClient(LlmSystem system) { + return (ChatModel) getClient(system); } private Object getClient(LlmSystem system) { switch (system) { case OPENAI -> { - return openAiChatClient; + return openAiChatModel; } case MISTRAL -> { - return mistralAiChatClient; + return mistralAiChatModel; } case OLLAMA -> { - return ollamaChatClient; + return ollamaChatModel; } case ANSTHROPIC -> { - return anthropicChatClient; + return anthropicChatModel; } default -> throw new IllegalStateException("Unexpected system: " + system); } diff --git a/backend/src/main/java/com/talkforgeai/backend/assistant/service/UniversalImageGenService.java b/backend/src/main/java/com/talkforgeai/backend/assistant/service/UniversalImageGenService.java index dd0de57d..83f4a02f 100644 --- a/backend/src/main/java/com/talkforgeai/backend/assistant/service/UniversalImageGenService.java +++ b/backend/src/main/java/com/talkforgeai/backend/assistant/service/UniversalImageGenService.java @@ -18,14 +18,14 @@ import com.talkforgeai.backend.assistant.dto.ImageGenSystem; import com.talkforgeai.backend.assistant.exception.AssistentException; -import org.springframework.ai.image.ImageClient; +import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageOptions; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; -import org.springframework.ai.openai.OpenAiImageClient; +import org.springframework.ai.openai.OpenAiImageModel; import org.springframework.ai.openai.OpenAiImageOptions; import org.springframework.ai.openai.api.OpenAiImageApi; -import org.springframework.ai.stabilityai.StabilityAiImageClient; +import org.springframework.ai.stabilityai.StabilityAiImageModel; import org.springframework.ai.stabilityai.api.StabilityAiApi; import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; import org.springframework.stereotype.Service; @@ -33,14 +33,14 @@ @Service public class UniversalImageGenService { - private final OpenAiImageClient openAiImageClient; + private final OpenAiImageModel openAiImageModel; - private final StabilityAiImageClient stabilityAiImageClient; + private final StabilityAiImageModel stabilityAiImageModel; - public UniversalImageGenService(OpenAiImageClient openAiImageClient, - StabilityAiImageClient stabilityAiImageClient) { - this.openAiImageClient = openAiImageClient; - this.stabilityAiImageClient = stabilityAiImageClient; + public UniversalImageGenService(OpenAiImageModel openAiImageModel, + StabilityAiImageModel stabilityAiImageModel) { + this.openAiImageModel = openAiImageModel; + this.stabilityAiImageModel = stabilityAiImageModel; } public ImageResponse generate(ImageGenSystem imageGenSystem, String text) { @@ -94,20 +94,20 @@ ImageOptions getImageOptions(ImageGenSystem system, UniversalImageOptions univer } } - private ImageClient getClient(ImageGenSystem system) { + private ImageModel getClient(ImageGenSystem system) { switch (system) { case OPENAI -> { - return openAiImageClient; + return openAiImageModel; } case STABILITY -> { - return stabilityAiImageClient; + return stabilityAiImageModel; } default -> throw new IllegalStateException("Unexpected image gen system: " + system); } } - public static record UniversalImageOptions(String model, String quality, int n, int height, - int width) { + public record UniversalImageOptions(String model, String quality, int n, int height, + int width) { } diff --git a/backend/src/main/java/com/talkforgeai/backend/memory/VectorStoreConfiguration.java b/backend/src/main/java/com/talkforgeai/backend/memory/VectorStoreConfiguration.java index 8d473ac4..1eb234a1 100644 --- a/backend/src/main/java/com/talkforgeai/backend/memory/VectorStoreConfiguration.java +++ b/backend/src/main/java/com/talkforgeai/backend/memory/VectorStoreConfiguration.java @@ -20,7 +20,7 @@ import com.talkforgeai.backend.memory.repository.MemoryRepository; import com.talkforgeai.backend.memory.service.DBVectorStore; import jakarta.persistence.EntityManager; -import org.springframework.ai.openai.OpenAiEmbeddingClient; +import org.springframework.ai.openai.OpenAiEmbeddingModel; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -29,12 +29,12 @@ public class VectorStoreConfiguration { private final EntityManager entityManager; - private final OpenAiEmbeddingClient embeddingClient; + private final OpenAiEmbeddingModel embeddingClient; private final MemoryRepository memoryRepository; private final AssistantRepository assistantRepository; public VectorStoreConfiguration(EntityManager entityManager, - OpenAiEmbeddingClient embeddingClient, + OpenAiEmbeddingModel embeddingClient, MemoryRepository memoryRepository, AssistantRepository assistantRepository) { this.entityManager = entityManager; this.embeddingClient = embeddingClient; diff --git a/backend/src/main/java/com/talkforgeai/backend/memory/service/DBVectorStore.java b/backend/src/main/java/com/talkforgeai/backend/memory/service/DBVectorStore.java index df2e078a..dc1f8403 100644 --- a/backend/src/main/java/com/talkforgeai/backend/memory/service/DBVectorStore.java +++ b/backend/src/main/java/com/talkforgeai/backend/memory/service/DBVectorStore.java @@ -37,7 +37,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.SimpleVectorStore.EmbeddingMath; import org.springframework.data.domain.PageRequest; @@ -53,15 +53,15 @@ public class DBVectorStore implements ListableVectoreStore { private final EntityManager entityManager; private final MemoryRepository memoryRepository; private final AssistantRepository assistantRepository; - private final EmbeddingClient embeddingClient; + private final EmbeddingModel embeddingModel; public DBVectorStore(EntityManager entityManager, MemoryRepository memoryRepository, AssistantRepository assistantRepository, - EmbeddingClient embeddingClient) { + EmbeddingModel embeddingModel) { this.entityManager = entityManager; this.memoryRepository = memoryRepository; this.assistantRepository = assistantRepository; - this.embeddingClient = embeddingClient; + this.embeddingModel = embeddingModel; } @Transactional @@ -89,7 +89,7 @@ public void add(List documents) { MemoryDocument documentEntity = new MemoryDocument(); LOGGER.info("Calling EmbeddingClient for document id = {}", document.getId()); - List embedding = this.embeddingClient.embed(document); + List embedding = this.embeddingModel.embed(document); documentEntity.setEmbeddings( embedding.stream().mapToDouble(Double::doubleValue).toArray()); // Convert List to byte[] @@ -168,7 +168,7 @@ public List similaritySearch(SearchRequest request) { } private List getUserQueryEmbedding(String query) { - return this.embeddingClient.embed(query); + return this.embeddingModel.embed(query); } @Override diff --git a/backend/src/main/java/com/talkforgeai/backend/voice/service/UniversalTranscriptionService.java b/backend/src/main/java/com/talkforgeai/backend/voice/service/UniversalTranscriptionService.java index 8504b9eb..46891b90 100644 --- a/backend/src/main/java/com/talkforgeai/backend/voice/service/UniversalTranscriptionService.java +++ b/backend/src/main/java/com/talkforgeai/backend/voice/service/UniversalTranscriptionService.java @@ -19,9 +19,9 @@ import com.talkforgeai.backend.assistant.exception.AssistentException; import com.talkforgeai.backend.voice.dto.TranscriptionSystem; import java.io.File; -import org.springframework.ai.model.ModelClient; +import org.springframework.ai.model.Model; import org.springframework.ai.model.ModelOptions; -import org.springframework.ai.openai.OpenAiAudioTranscriptionClient; +import org.springframework.ai.openai.OpenAiAudioTranscriptionModel; import org.springframework.ai.openai.OpenAiAudioTranscriptionOptions; import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptResponseFormat; import org.springframework.ai.openai.api.OpenAiAudioApi.WhisperModel; @@ -33,11 +33,11 @@ @Service public class UniversalTranscriptionService { - private final OpenAiAudioTranscriptionClient openAiTranscriptionClient; + private final OpenAiAudioTranscriptionModel openAiTranscriptionModel; - public UniversalTranscriptionService(OpenAiAudioTranscriptionClient openAiTranscriptionClient) { - this.openAiTranscriptionClient = openAiTranscriptionClient; + public UniversalTranscriptionService(OpenAiAudioTranscriptionModel openAiTranscriptionModel) { + this.openAiTranscriptionModel = openAiTranscriptionModel; } ModelOptions getDefaultTranscriptionOptions(TranscriptionSystem system) { @@ -90,11 +90,11 @@ public AudioTranscriptionResponse transcribe(TranscriptionSystem system, File au } - private ModelClient getClient( + private Model getClient( TranscriptionSystem system) { switch (system) { case OPENAI -> { - return openAiTranscriptionClient; + return openAiTranscriptionModel; } default -> throw new IllegalStateException("Unexpected transcription system: " + system); } diff --git a/pom.xml b/pom.xml index 53d28301..e86b716a 100644 --- a/pom.xml +++ b/pom.xml @@ -42,7 +42,7 @@ 21 21 UTF-8 - 1.0.0-SNAPSHOT + 1.0.0-M1