Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BIGTOP-4313: Adjust the code of AI module #136

Merged
merged 9 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,74 +18,87 @@
*/
package org.apache.bigtop.manager.ai.assistant;

import org.apache.bigtop.manager.ai.assistant.provider.LocSystemPromptProvider;
import org.apache.bigtop.manager.ai.assistant.store.ChatMemoryStoreProvider;
import org.apache.bigtop.manager.ai.assistant.config.GeneralAssistantConfig;
import org.apache.bigtop.manager.ai.assistant.provider.ChatMemoryStoreProvider;
import org.apache.bigtop.manager.ai.core.AbstractAIAssistantFactory;
import org.apache.bigtop.manager.ai.core.config.AIAssistantConfig;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.enums.SystemPrompt;
import org.apache.bigtop.manager.ai.core.exception.AssistantConfigNotSetException;
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;
import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider;
import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider;
import org.apache.bigtop.manager.ai.dashscope.DashScopeAssistant;
import org.apache.bigtop.manager.ai.openai.OpenAIAssistant;
import org.apache.bigtop.manager.ai.qianfan.QianFanAssistant;

import org.springframework.stereotype.Component;

import dev.langchain4j.service.tool.ToolProvider;
import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;

import jakarta.annotation.Resource;
import java.util.ArrayList;
import java.util.List;

@Component
public class GeneralAssistantFactory extends AbstractAIAssistantFactory {

private final SystemPromptProvider systemPromptProvider;
private final ChatMemoryStoreProvider chatMemoryStoreProvider;
@Resource
private SystemPromptProvider systemPromptProvider;

@Resource
private ChatMemoryStoreProvider chatMemoryStoreProvider;

public GeneralAssistantFactory(ChatMemoryStoreProvider chatMemoryStoreProvider) {
this(new LocSystemPromptProvider(), chatMemoryStoreProvider);
private void configureSystemPrompt(AIAssistant.Builder builder, SystemPrompt systemPrompt, String locale) {
List<String> systemPrompts = new ArrayList<>();
if (systemPrompt != null) {
systemPrompts.add(systemPromptProvider.getSystemMessage(systemPrompt));
}
if (locale != null) {
systemPrompts.add(systemPromptProvider.getLanguagePrompt(locale));
}
builder.withSystemPrompt(systemPromptProvider.getSystemMessages(systemPrompts));
}

public GeneralAssistantFactory(
SystemPromptProvider systemPromptProvider, ChatMemoryStoreProvider chatMemoryStoreProvider) {
this.systemPromptProvider = systemPromptProvider;
this.chatMemoryStoreProvider = chatMemoryStoreProvider;
private AIAssistant.Builder initializeBuilder(PlatformType platformType) {
return switch (platformType) {
case OPENAI -> OpenAIAssistant.builder();
case DASH_SCOPE -> DashScopeAssistant.builder();
case QIANFAN -> QianFanAssistant.builder();
};
}

@Override
public AIAssistant createWithPrompt(
PlatformType platformType,
AIAssistantConfigProvider assistantConfig,
Object id,
ToolProvider toolProvider,
SystemPrompt systemPrompt) {
AIAssistant.Builder builder =
switch (platformType) {
case OPENAI -> OpenAIAssistant.builder();
case DASH_SCOPE -> DashScopeAssistant.builder();
case QIANFAN -> QianFanAssistant.builder();
};
builder = builder.id(id)
.memoryStore(
(id == null)
? new InMemoryChatMemoryStore()
: chatMemoryStoreProvider.createPersistentChatMemoryStore())
.withConfigProvider(assistantConfig)
.withToolProvider(toolProvider);

List<String> systemPrompts = new java.util.ArrayList<>();
systemPrompts.add(systemPromptProvider.getSystemMessage(systemPrompt));
String locale = assistantConfig.getLanguage();
if (locale != null) {
systemPrompts.add(systemPromptProvider.getLanguagePrompt(locale));
AIAssistantConfig config, ToolProvider toolProvider, SystemPrompt systemPrompt) {
GeneralAssistantConfig generalAssistantConfig = (GeneralAssistantConfig) config;
PlatformType platformType = generalAssistantConfig.getPlatformType();
Object id = generalAssistantConfig.getId();
if (id == null) {
throw new AssistantConfigNotSetException("ID");
}

builder.withSystemPrompt(systemPromptProvider.getSystemMessages(systemPrompts));
AIAssistant.Builder builder = initializeBuilder(platformType);
builder.id(id)
.memoryStore(chatMemoryStoreProvider.createPersistentChatMemoryStore())
.withConfig(generalAssistantConfig)
.withToolProvider(toolProvider);

configureSystemPrompt(builder, systemPrompt, generalAssistantConfig.getLanguage());

return builder.build();
}

@Override
public AIAssistant createAiService(
PlatformType platformType, AIAssistantConfigProvider assistantConfig, Long id, ToolProvider toolProvider) {
return createWithPrompt(platformType, assistantConfig, id, toolProvider, SystemPrompt.DEFAULT_PROMPT);
public AIAssistant createForTest(AIAssistantConfig config, ToolProvider toolProvider) {
GeneralAssistantConfig generalAssistantConfig = (GeneralAssistantConfig) config;
PlatformType platformType = generalAssistantConfig.getPlatformType();
AIAssistant.Builder builder = initializeBuilder(platformType);

builder.id(null)
.memoryStore(chatMemoryStoreProvider.createInMemoryChatMemoryStore())
.withConfig(generalAssistantConfig)
.withToolProvider(toolProvider);

return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,34 @@
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.bigtop.manager.ai.assistant.provider;
package org.apache.bigtop.manager.ai.assistant.config;

import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider;
import org.apache.bigtop.manager.ai.core.config.AIAssistantConfig;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;

import lombok.Getter;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

public class AIAssistantConfig implements AIAssistantConfigProvider {
@Getter
public class GeneralAssistantConfig implements AIAssistantConfig {

/**
* Model name for platform that we want to use
*/
private final Long id;
private final String model;

/**
* Credentials for different platforms
*/
private final Map<String, String> credentials;

private final String language;
/**
* Platform extra configs are put here
*/
private final PlatformType platformType;
private final Map<String, String> credentials;
private final Map<String, String> configs;

private AIAssistantConfig(
String model, Map<String, String> credentials, String language, Map<String, String> configMap) {
this.model = model;
this.credentials = credentials;
this.language = language;
this.configs = configMap;
private GeneralAssistantConfig(Builder builder) {
this.model = Objects.requireNonNull(builder.model);
this.credentials = Objects.requireNonNull(builder.credentials);
this.platformType = Objects.requireNonNull(builder.platformType);
this.language = builder.language;
this.id = builder.id;
this.configs = builder.configs;
}

public static Builder builder() {
Expand All @@ -68,26 +65,29 @@ public Map<String, String> getConfigs() {
return configs;
}

@Override
public String getLanguage() {
return language;
}

public static class Builder {
private Long id;
private String model;
private String language;

private PlatformType platformType;
private final Map<String, String> credentials = new HashMap<>();

private final Map<String, String> configs = new HashMap<>();

public Builder() {}

public Builder setModel(String model) {
this.model = model;
return this;
}

public Builder setPlatformType(PlatformType platformType) {
this.platformType = platformType;
return this;
}

public Builder setId(Long id) {
this.id = id;
return this;
}

public Builder addCredential(String key, String value) {
credentials.put(key, value);
return this;
Expand Down Expand Up @@ -115,8 +115,8 @@ public Builder addConfigs(Map<String, String> configMap) {
return this;
}

public AIAssistantConfig build() {
return new AIAssistantConfig(model, credentials, language, configs);
public GeneralAssistantConfig build() {
return new GeneralAssistantConfig(this);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,32 @@
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.bigtop.manager.ai.assistant.store;
package org.apache.bigtop.manager.ai.assistant.provider;

import org.apache.bigtop.manager.ai.assistant.store.PersistentChatMemoryStore;
import org.apache.bigtop.manager.dao.repository.ChatMessageDao;
import org.apache.bigtop.manager.dao.repository.ChatThreadDao;

import org.springframework.stereotype.Component;

import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;

import jakarta.annotation.Resource;

@Component
public class ChatMemoryStoreProvider {
private final ChatThreadDao chatThreadDao;
private final ChatMessageDao chatMessageDao;
@Resource
private ChatThreadDao chatThreadDao;

public ChatMemoryStoreProvider(ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) {
this.chatThreadDao = chatThreadDao;
this.chatMessageDao = chatMessageDao;
}

public ChatMemoryStoreProvider() {
this(null, null);
}
@Resource
private ChatMessageDao chatMessageDao;

public ChatMemoryStore createPersistentChatMemoryStore() {
if (chatThreadDao == null || chatMessageDao == null) {
return new InMemoryChatMemoryStore();
}
return new PersistentChatMemoryStore(chatThreadDao, chatMessageDao);
}

public ChatMemoryStore createInMemoryChatMemoryStore() {
return new InMemoryChatMemoryStore();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.bigtop.manager.ai.core.enums.SystemPrompt;
import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider;

import org.springframework.stereotype.Component;
import org.springframework.util.ResourceUtils;

import lombok.extern.slf4j.Slf4j;
Expand All @@ -33,6 +34,7 @@
import java.util.Objects;

@Slf4j
@Component
public class LocSystemPromptProvider implements SystemPromptProvider {

@Override
Expand Down Expand Up @@ -67,21 +69,14 @@ private String loadTextFromFile(String fileName) {
private String loadPromptFromFile(String fileName) {
final String filePath = fileName + ".st";
String text = loadTextFromFile(filePath);
if (text == null) {
return "You are a helpful assistant.";
} else {
return text;
}
return Objects.requireNonNullElse(text, "You are a helpful assistant.");
}

@Override
public String getLanguagePrompt(String locale) {
final String filePath = SystemPrompt.LANGUAGE_PROMPT.getValue() + '-' + locale + ".st";
String text = loadTextFromFile(filePath);
if (text == null) {
return "Answer in " + locale;
} else {
return text;
}
return Objects.requireNonNullElseGet(text, () -> "Answer in " + locale);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
public class PersistentChatMemoryStore implements ChatMemoryStore {

private final Map<Object, List<ChatMessage>> messagesByMemoryId = new ConcurrentHashMap<>();
protected final ChatThreadDao chatThreadDao;
protected final ChatMessageDao chatMessageDao;
private final ChatThreadDao chatThreadDao;
private final ChatMessageDao chatMessageDao;

public PersistentChatMemoryStore(ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) {
this.chatThreadDao = chatThreadDao;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,12 @@
*/
package org.apache.bigtop.manager.ai.assistant;

import org.apache.bigtop.manager.ai.assistant.store.ChatMemoryStoreProvider;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.config.AIAssistantConfig;
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;
import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider;
import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider;
import org.apache.bigtop.manager.ai.openai.OpenAIAssistant;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.MockitoAnnotations;
Expand All @@ -42,42 +38,35 @@
class GeneralAssistantFactoryTest {

@Mock
private SystemPromptProvider systemPromptProvider;
private AIAssistantConfig assistantConfigProvider;

@Mock
private AIAssistantConfigProvider assistantConfigProvider;

@InjectMocks
private GeneralAssistantFactory generalAssistantFactory;

@BeforeEach
void setUp() {
MockitoAnnotations.openMocks(this);
generalAssistantFactory = new GeneralAssistantFactory(systemPromptProvider, new ChatMemoryStoreProvider());
Map<String, String> credentials = Map.of("apiKey", "123456");
when(assistantConfigProvider.getModel()).thenReturn("model");
when(assistantConfigProvider.getCredentials()).thenReturn(credentials);
when(assistantConfigProvider.getConfigs()).thenReturn(null);
when(assistantConfigProvider.getLanguage()).thenReturn("en");
}

@Test
void testCreateAIAssistant() {
AIAssistant.Builder mockBuilder = mock(OpenAIAssistant.Builder.class);
when(mockBuilder.id(any())).thenReturn(mockBuilder);
when(mockBuilder.memoryStore(any())).thenReturn(mockBuilder);
when(mockBuilder.withConfigProvider(any())).thenReturn(mockBuilder);
when(mockBuilder.withConfig(any())).thenReturn(mockBuilder);
when(mockBuilder.withToolProvider(any())).thenReturn(mockBuilder);
when(mockBuilder.withSystemPrompt(any())).thenReturn(mockBuilder);
when(mockBuilder.build()).thenReturn(mock(AIAssistant.class));

try (MockedStatic<OpenAIAssistant> openAIAssistantMockedStatic = mockStatic(OpenAIAssistant.class)) {
openAIAssistantMockedStatic.when(OpenAIAssistant::builder).thenReturn(mockBuilder);

PlatformType platformType = PlatformType.OPENAI;
generalAssistantFactory.create(platformType, assistantConfigProvider);
generalAssistantFactory = new GeneralAssistantFactory(new ChatMemoryStoreProvider());
generalAssistantFactory.create(platformType, assistantConfigProvider);
generalAssistantFactory.createAIService(assistantConfigProvider, null);
generalAssistantFactory.createForTest(assistantConfigProvider, null);
}
}
}
Loading
Loading