Skip to content

Commit

Permalink
Migration to 1.0.0-M1
Browse files Browse the repository at this point in the history
  • Loading branch information
jschm42 committed Jun 6, 2024
1 parent dea049c commit 785f90c
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import com.talkforgeai.backend.assistant.repository.MessageRepository;
import com.talkforgeai.backend.assistant.repository.ThreadRepository;
import com.talkforgeai.backend.memory.dto.DocumentWithoutEmbeddings;
import com.talkforgeai.backend.memory.repository.MemoryRepository;
import com.talkforgeai.backend.memory.service.MemoryService;
import com.talkforgeai.backend.storage.FileStorageService;
import com.talkforgeai.backend.transformers.MessageProcessor;
Expand Down Expand Up @@ -73,12 +72,12 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.autoconfigure.openai.OpenAiEmbeddingProperties;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.image.ImageResponse;
Expand Down Expand Up @@ -123,14 +122,13 @@ public class AssistantSpringService {


private final Map<String, Subscription> activeStreams = new ConcurrentHashMap<>();
private final MemoryRepository memoryRepository;

public AssistantSpringService(
UniversalChatService universalChatService, UniversalImageGenService universalImageGenService,
AssistantRepository assistantRepository, MessageRepository messageRepository,
ThreadRepository threadRepository, FileStorageService fileStorageService,
MessageProcessor messageProcessor, AssistantMapper assistantMapper,
MemoryService memoryService, MemoryRepository memoryRepository) {
MemoryService memoryService) {

this.universalChatService = universalChatService;
this.universalImageGenService = universalImageGenService;
Expand All @@ -141,7 +139,6 @@ public AssistantSpringService(
this.messageProcessor = messageProcessor;
this.assistantMapper = assistantMapper;
this.memoryService = memoryService;
this.memoryRepository = memoryRepository;
}

private static @NotNull Mono<InitInfos> getInitInfosMono(Mono<AssistantDto> assistantEntityMono,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,21 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.springframework.ai.anthropic.AnthropicChatClient;
import org.springframework.ai.anthropic.AnthropicChatModel;
import org.springframework.ai.anthropic.AnthropicChatOptions;
import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.mistralai.MistralAiChatClient;
import org.springframework.ai.mistralai.MistralAiChatModel;
import org.springframework.ai.mistralai.MistralAiChatOptions;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
Expand All @@ -48,23 +48,23 @@ public class UniversalChatService {

@Qualifier("openAiRestClient")
private final RestClient openAiRestClient;

private final OpenAiChatClient openAiChatClient;
private final MistralAiChatClient mistralAiChatClient;
private final AnthropicChatClient anthropicChatClient;
private final OllamaChatClient ollamaChatClient;
private final OpenAiChatModel openAiChatModel;
private final MistralAiChatModel mistralAiChatModel;
private final AnthropicChatModel anthropicChatModel;
private final OllamaChatModel ollamaChatModel;

@Qualifier("ollamaAiRestClient")
private final RestClient ollamaAiRestClient;

public UniversalChatService(RestClient openAiRestClient, OpenAiChatClient openAiChatClient,
MistralAiChatClient mistralAiChatClient, AnthropicChatClient anthropicChatClient,
OllamaChatClient ollamaChatClient, RestClient ollamaAiRestClient) {
public UniversalChatService(RestClient openAiRestClient,
OpenAiChatModel openAiChatModel,
MistralAiChatModel mistralAiChatModel, AnthropicChatModel anthropicChatModel,
OllamaChatModel ollamaChatModel, RestClient ollamaAiRestClient) {
this.openAiRestClient = openAiRestClient;
this.openAiChatClient = openAiChatClient;
this.mistralAiChatClient = mistralAiChatClient;
this.anthropicChatClient = anthropicChatClient;
this.ollamaChatClient = ollamaChatClient;
this.openAiChatModel = openAiChatModel;
this.mistralAiChatModel = mistralAiChatModel;
this.anthropicChatModel = anthropicChatModel;
this.ollamaChatModel = ollamaChatModel;
this.ollamaAiRestClient = ollamaAiRestClient;
}

Expand Down Expand Up @@ -134,27 +134,27 @@ Flux<ChatResponse> stream(LlmSystem system, Prompt prompt) {
return getStreamingChatClient(system).stream(prompt);
}

StreamingChatClient getStreamingChatClient(LlmSystem system) {
return (StreamingChatClient) getClient(system);
StreamingChatModel getStreamingChatClient(LlmSystem system) {
return (StreamingChatModel) getClient(system);
}

ChatClient getChatClient(LlmSystem system) {
return (ChatClient) getClient(system);
ChatModel getChatClient(LlmSystem system) {
return (ChatModel) getClient(system);
}

private Object getClient(LlmSystem system) {
switch (system) {
case OPENAI -> {
return openAiChatClient;
return openAiChatModel;
}
case MISTRAL -> {
return mistralAiChatClient;
return mistralAiChatModel;
}
case OLLAMA -> {
return ollamaChatClient;
return ollamaChatModel;
}
case ANSTHROPIC -> {
return anthropicChatClient;
return anthropicChatModel;
}
default -> throw new IllegalStateException("Unexpected system: " + system);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,29 @@

import com.talkforgeai.backend.assistant.dto.ImageGenSystem;
import com.talkforgeai.backend.assistant.exception.AssistentException;
import org.springframework.ai.image.ImageClient;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.openai.OpenAiImageClient;
import org.springframework.ai.openai.OpenAiImageModel;
import org.springframework.ai.openai.OpenAiImageOptions;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.ai.stabilityai.StabilityAiImageClient;
import org.springframework.ai.stabilityai.StabilityAiImageModel;
import org.springframework.ai.stabilityai.api.StabilityAiApi;
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
import org.springframework.stereotype.Service;

@Service
public class UniversalImageGenService {

private final OpenAiImageClient openAiImageClient;
private final OpenAiImageModel openAiImageModel;

private final StabilityAiImageClient stabilityAiImageClient;
private final StabilityAiImageModel stabilityAiImageModel;

public UniversalImageGenService(OpenAiImageClient openAiImageClient,
StabilityAiImageClient stabilityAiImageClient) {
this.openAiImageClient = openAiImageClient;
this.stabilityAiImageClient = stabilityAiImageClient;
public UniversalImageGenService(OpenAiImageModel openAiImageModel,
StabilityAiImageModel stabilityAiImageModel) {
this.openAiImageModel = openAiImageModel;
this.stabilityAiImageModel = stabilityAiImageModel;
}

public ImageResponse generate(ImageGenSystem imageGenSystem, String text) {
Expand Down Expand Up @@ -94,20 +94,20 @@ ImageOptions getImageOptions(ImageGenSystem system, UniversalImageOptions univer
}
}

private ImageClient getClient(ImageGenSystem system) {
private ImageModel getClient(ImageGenSystem system) {
switch (system) {
case OPENAI -> {
return openAiImageClient;
return openAiImageModel;
}
case STABILITY -> {
return stabilityAiImageClient;
return stabilityAiImageModel;
}
default -> throw new IllegalStateException("Unexpected image gen system: " + system);
}
}

public static record UniversalImageOptions(String model, String quality, int n, int height,
int width) {
public record UniversalImageOptions(String model, String quality, int n, int height,
int width) {

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import com.talkforgeai.backend.memory.repository.MemoryRepository;
import com.talkforgeai.backend.memory.service.DBVectorStore;
import jakarta.persistence.EntityManager;
import org.springframework.ai.openai.OpenAiEmbeddingClient;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
Expand All @@ -29,12 +29,12 @@
public class VectorStoreConfiguration {

private final EntityManager entityManager;
private final OpenAiEmbeddingClient embeddingClient;
private final OpenAiEmbeddingModel embeddingClient;
private final MemoryRepository memoryRepository;
private final AssistantRepository assistantRepository;

public VectorStoreConfiguration(EntityManager entityManager,
OpenAiEmbeddingClient embeddingClient,
OpenAiEmbeddingModel embeddingClient,
MemoryRepository memoryRepository, AssistantRepository assistantRepository) {
this.entityManager = entityManager;
this.embeddingClient = embeddingClient;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.SimpleVectorStore.EmbeddingMath;
import org.springframework.data.domain.PageRequest;
Expand All @@ -53,15 +53,15 @@ public class DBVectorStore implements ListableVectoreStore {
private final EntityManager entityManager;
private final MemoryRepository memoryRepository;
private final AssistantRepository assistantRepository;
private final EmbeddingClient embeddingClient;
private final EmbeddingModel embeddingModel;

public DBVectorStore(EntityManager entityManager, MemoryRepository memoryRepository,
AssistantRepository assistantRepository,
EmbeddingClient embeddingClient) {
EmbeddingModel embeddingModel) {
this.entityManager = entityManager;
this.memoryRepository = memoryRepository;
this.assistantRepository = assistantRepository;
this.embeddingClient = embeddingClient;
this.embeddingModel = embeddingModel;
}

@Transactional
Expand Down Expand Up @@ -89,7 +89,7 @@ public void add(List<Document> documents) {
MemoryDocument documentEntity = new MemoryDocument();

LOGGER.info("Calling EmbeddingClient for document id = {}", document.getId());
List<Double> embedding = this.embeddingClient.embed(document);
List<Double> embedding = this.embeddingModel.embed(document);
documentEntity.setEmbeddings(
embedding.stream().mapToDouble(Double::doubleValue).toArray());
// Convert List<Double> to byte[]
Expand Down Expand Up @@ -168,7 +168,7 @@ public List<Document> similaritySearch(SearchRequest request) {
}

private List<Double> getUserQueryEmbedding(String query) {
return this.embeddingClient.embed(query);
return this.embeddingModel.embed(query);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
import com.talkforgeai.backend.assistant.exception.AssistentException;
import com.talkforgeai.backend.voice.dto.TranscriptionSystem;
import java.io.File;
import org.springframework.ai.model.ModelClient;
import org.springframework.ai.model.Model;
import org.springframework.ai.model.ModelOptions;
import org.springframework.ai.openai.OpenAiAudioTranscriptionClient;
import org.springframework.ai.openai.OpenAiAudioTranscriptionModel;
import org.springframework.ai.openai.OpenAiAudioTranscriptionOptions;
import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptResponseFormat;
import org.springframework.ai.openai.api.OpenAiAudioApi.WhisperModel;
Expand All @@ -33,11 +33,11 @@
@Service
public class UniversalTranscriptionService {

private final OpenAiAudioTranscriptionClient openAiTranscriptionClient;
private final OpenAiAudioTranscriptionModel openAiTranscriptionModel;


public UniversalTranscriptionService(OpenAiAudioTranscriptionClient openAiTranscriptionClient) {
this.openAiTranscriptionClient = openAiTranscriptionClient;
public UniversalTranscriptionService(OpenAiAudioTranscriptionModel openAiTranscriptionModel) {
this.openAiTranscriptionModel = openAiTranscriptionModel;
}

ModelOptions getDefaultTranscriptionOptions(TranscriptionSystem system) {
Expand Down Expand Up @@ -90,11 +90,11 @@ public AudioTranscriptionResponse transcribe(TranscriptionSystem system, File au
}


private ModelClient<AudioTranscriptionPrompt, AudioTranscriptionResponse> getClient(
private Model<AudioTranscriptionPrompt, AudioTranscriptionResponse> getClient(
TranscriptionSystem system) {
switch (system) {
case OPENAI -> {
return openAiTranscriptionClient;
return openAiTranscriptionModel;
}
default -> throw new IllegalStateException("Unexpected transcription system: " + system);
}
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
<maven.compiler.source>21</maven.compiler.source>
<maven.compiler.target>21</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<spring-ai.version>1.0.0-SNAPSHOT</spring-ai.version>
<spring-ai.version>1.0.0-M1</spring-ai.version>
</properties>
<dependencies>
<dependency>
Expand Down

0 comments on commit 785f90c

Please sign in to comment.