diff --git a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java index fadd37322..bf4b52908 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java +++ b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java @@ -18,74 +18,87 @@ */ package org.apache.bigtop.manager.ai.assistant; -import org.apache.bigtop.manager.ai.assistant.provider.LocSystemPromptProvider; -import org.apache.bigtop.manager.ai.assistant.store.ChatMemoryStoreProvider; +import org.apache.bigtop.manager.ai.assistant.config.GeneralAssistantConfig; +import org.apache.bigtop.manager.ai.assistant.provider.ChatMemoryStoreProvider; import org.apache.bigtop.manager.ai.core.AbstractAIAssistantFactory; +import org.apache.bigtop.manager.ai.core.config.AIAssistantConfig; import org.apache.bigtop.manager.ai.core.enums.PlatformType; import org.apache.bigtop.manager.ai.core.enums.SystemPrompt; +import org.apache.bigtop.manager.ai.core.exception.AssistantConfigNotSetException; import org.apache.bigtop.manager.ai.core.factory.AIAssistant; -import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider; import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider; import org.apache.bigtop.manager.ai.dashscope.DashScopeAssistant; import org.apache.bigtop.manager.ai.openai.OpenAIAssistant; import org.apache.bigtop.manager.ai.qianfan.QianFanAssistant; +import org.springframework.stereotype.Component; + import dev.langchain4j.service.tool.ToolProvider; -import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore; +import jakarta.annotation.Resource; +import java.util.ArrayList; import java.util.List; +@Component public class GeneralAssistantFactory extends AbstractAIAssistantFactory { - private final SystemPromptProvider systemPromptProvider; - private final ChatMemoryStoreProvider chatMemoryStoreProvider; + @Resource + private SystemPromptProvider systemPromptProvider; + + @Resource + private ChatMemoryStoreProvider chatMemoryStoreProvider; - public GeneralAssistantFactory(ChatMemoryStoreProvider chatMemoryStoreProvider) { - this(new LocSystemPromptProvider(), chatMemoryStoreProvider); + private void configureSystemPrompt(AIAssistant.Builder builder, SystemPrompt systemPrompt, String locale) { + List systemPrompts = new ArrayList<>(); + if (systemPrompt != null) { + systemPrompts.add(systemPromptProvider.getSystemMessage(systemPrompt)); + } + if (locale != null) { + systemPrompts.add(systemPromptProvider.getLanguagePrompt(locale)); + } + builder.withSystemPrompt(systemPromptProvider.getSystemMessages(systemPrompts)); } - public GeneralAssistantFactory( - SystemPromptProvider systemPromptProvider, ChatMemoryStoreProvider chatMemoryStoreProvider) { - this.systemPromptProvider = systemPromptProvider; - this.chatMemoryStoreProvider = chatMemoryStoreProvider; + private AIAssistant.Builder initializeBuilder(PlatformType platformType) { + return switch (platformType) { + case OPENAI -> OpenAIAssistant.builder(); + case DASH_SCOPE -> DashScopeAssistant.builder(); + case QIANFAN -> QianFanAssistant.builder(); + }; } @Override public AIAssistant createWithPrompt( - PlatformType platformType, - AIAssistantConfigProvider assistantConfig, - Object id, - ToolProvider toolProvider, - SystemPrompt systemPrompt) { - AIAssistant.Builder builder = - switch (platformType) { - case OPENAI -> OpenAIAssistant.builder(); - case DASH_SCOPE -> DashScopeAssistant.builder(); - case QIANFAN -> QianFanAssistant.builder(); - }; - builder = builder.id(id) - .memoryStore( - (id == null) - ? new InMemoryChatMemoryStore() - : chatMemoryStoreProvider.createPersistentChatMemoryStore()) - .withConfigProvider(assistantConfig) - .withToolProvider(toolProvider); - - List systemPrompts = new java.util.ArrayList<>(); - systemPrompts.add(systemPromptProvider.getSystemMessage(systemPrompt)); - String locale = assistantConfig.getLanguage(); - if (locale != null) { - systemPrompts.add(systemPromptProvider.getLanguagePrompt(locale)); + AIAssistantConfig config, ToolProvider toolProvider, SystemPrompt systemPrompt) { + GeneralAssistantConfig generalAssistantConfig = (GeneralAssistantConfig) config; + PlatformType platformType = generalAssistantConfig.getPlatformType(); + Object id = generalAssistantConfig.getId(); + if (id == null) { + throw new AssistantConfigNotSetException("ID"); } - builder.withSystemPrompt(systemPromptProvider.getSystemMessages(systemPrompts)); + AIAssistant.Builder builder = initializeBuilder(platformType); + builder.id(id) + .memoryStore(chatMemoryStoreProvider.createPersistentChatMemoryStore()) + .withConfig(generalAssistantConfig) + .withToolProvider(toolProvider); + + configureSystemPrompt(builder, systemPrompt, generalAssistantConfig.getLanguage()); return builder.build(); } @Override - public AIAssistant createAiService( - PlatformType platformType, AIAssistantConfigProvider assistantConfig, Long id, ToolProvider toolProvider) { - return createWithPrompt(platformType, assistantConfig, id, toolProvider, SystemPrompt.DEFAULT_PROMPT); + public AIAssistant createForTest(AIAssistantConfig config, ToolProvider toolProvider) { + GeneralAssistantConfig generalAssistantConfig = (GeneralAssistantConfig) config; + PlatformType platformType = generalAssistantConfig.getPlatformType(); + AIAssistant.Builder builder = initializeBuilder(platformType); + + builder.id(null) + .memoryStore(chatMemoryStoreProvider.createInMemoryChatMemoryStore()) + .withConfig(generalAssistantConfig) + .withToolProvider(toolProvider); + + return builder.build(); } } diff --git a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/provider/AIAssistantConfig.java b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/config/GeneralAssistantConfig.java similarity index 68% rename from bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/provider/AIAssistantConfig.java rename to bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/config/GeneralAssistantConfig.java index 132c37a14..2838529e1 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/provider/AIAssistantConfig.java +++ b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/config/GeneralAssistantConfig.java @@ -16,37 +16,34 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.bigtop.manager.ai.assistant.provider; +package org.apache.bigtop.manager.ai.assistant.config; -import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider; +import org.apache.bigtop.manager.ai.core.config.AIAssistantConfig; +import org.apache.bigtop.manager.ai.core.enums.PlatformType; + +import lombok.Getter; import java.util.HashMap; import java.util.Map; +import java.util.Objects; -public class AIAssistantConfig implements AIAssistantConfigProvider { +@Getter +public class GeneralAssistantConfig implements AIAssistantConfig { - /** - * Model name for platform that we want to use - */ + private final Long id; private final String model; - - /** - * Credentials for different platforms - */ - private final Map credentials; - private final String language; - /** - * Platform extra configs are put here - */ + private final PlatformType platformType; + private final Map credentials; private final Map configs; - private AIAssistantConfig( - String model, Map credentials, String language, Map configMap) { - this.model = model; - this.credentials = credentials; - this.language = language; - this.configs = configMap; + private GeneralAssistantConfig(Builder builder) { + this.model = Objects.requireNonNull(builder.model); + this.credentials = Objects.requireNonNull(builder.credentials); + this.platformType = Objects.requireNonNull(builder.platformType); + this.language = builder.language; + this.id = builder.id; + this.configs = builder.configs; } public static Builder builder() { @@ -68,26 +65,29 @@ public Map getConfigs() { return configs; } - @Override - public String getLanguage() { - return language; - } - public static class Builder { + private Long id; private String model; private String language; - + private PlatformType platformType; private final Map credentials = new HashMap<>(); - private final Map configs = new HashMap<>(); - public Builder() {} - public Builder setModel(String model) { this.model = model; return this; } + public Builder setPlatformType(PlatformType platformType) { + this.platformType = platformType; + return this; + } + + public Builder setId(Long id) { + this.id = id; + return this; + } + public Builder addCredential(String key, String value) { credentials.put(key, value); return this; @@ -115,8 +115,8 @@ public Builder addConfigs(Map configMap) { return this; } - public AIAssistantConfig build() { - return new AIAssistantConfig(model, credentials, language, configs); + public GeneralAssistantConfig build() { + return new GeneralAssistantConfig(this); } } } diff --git a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/store/ChatMemoryStoreProvider.java b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/provider/ChatMemoryStoreProvider.java similarity index 70% rename from bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/store/ChatMemoryStoreProvider.java rename to bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/provider/ChatMemoryStoreProvider.java index 90e24e6c3..67003f10e 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/store/ChatMemoryStoreProvider.java +++ b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/provider/ChatMemoryStoreProvider.java @@ -16,31 +16,32 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.bigtop.manager.ai.assistant.store; +package org.apache.bigtop.manager.ai.assistant.provider; +import org.apache.bigtop.manager.ai.assistant.store.PersistentChatMemoryStore; import org.apache.bigtop.manager.dao.repository.ChatMessageDao; import org.apache.bigtop.manager.dao.repository.ChatThreadDao; +import org.springframework.stereotype.Component; + import dev.langchain4j.store.memory.chat.ChatMemoryStore; import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore; +import jakarta.annotation.Resource; + +@Component public class ChatMemoryStoreProvider { - private final ChatThreadDao chatThreadDao; - private final ChatMessageDao chatMessageDao; + @Resource + private ChatThreadDao chatThreadDao; - public ChatMemoryStoreProvider(ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) { - this.chatThreadDao = chatThreadDao; - this.chatMessageDao = chatMessageDao; - } - - public ChatMemoryStoreProvider() { - this(null, null); - } + @Resource + private ChatMessageDao chatMessageDao; public ChatMemoryStore createPersistentChatMemoryStore() { - if (chatThreadDao == null || chatMessageDao == null) { - return new InMemoryChatMemoryStore(); - } return new PersistentChatMemoryStore(chatThreadDao, chatMessageDao); } + + public ChatMemoryStore createInMemoryChatMemoryStore() { + return new InMemoryChatMemoryStore(); + } } diff --git a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/provider/LocSystemPromptProvider.java b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/provider/LocSystemPromptProvider.java index a3cb98943..756bcf338 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/provider/LocSystemPromptProvider.java +++ b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/provider/LocSystemPromptProvider.java @@ -21,6 +21,7 @@ import org.apache.bigtop.manager.ai.core.enums.SystemPrompt; import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider; +import org.springframework.stereotype.Component; import org.springframework.util.ResourceUtils; import lombok.extern.slf4j.Slf4j; @@ -33,6 +34,7 @@ import java.util.Objects; @Slf4j +@Component public class LocSystemPromptProvider implements SystemPromptProvider { @Override @@ -67,21 +69,14 @@ private String loadTextFromFile(String fileName) { private String loadPromptFromFile(String fileName) { final String filePath = fileName + ".st"; String text = loadTextFromFile(filePath); - if (text == null) { - return "You are a helpful assistant."; - } else { - return text; - } + return Objects.requireNonNullElse(text, "You are a helpful assistant."); } + @Override public String getLanguagePrompt(String locale) { final String filePath = SystemPrompt.LANGUAGE_PROMPT.getValue() + '-' + locale + ".st"; String text = loadTextFromFile(filePath); - if (text == null) { - return "Answer in " + locale; - } else { - return text; - } + return Objects.requireNonNullElseGet(text, () -> "Answer in " + locale); } @Override diff --git a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/store/PersistentChatMemoryStore.java b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/store/PersistentChatMemoryStore.java index 8bc3e980c..d165c7816 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/store/PersistentChatMemoryStore.java +++ b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/store/PersistentChatMemoryStore.java @@ -41,8 +41,8 @@ public class PersistentChatMemoryStore implements ChatMemoryStore { private final Map> messagesByMemoryId = new ConcurrentHashMap<>(); - protected final ChatThreadDao chatThreadDao; - protected final ChatMessageDao chatMessageDao; + private final ChatThreadDao chatThreadDao; + private final ChatMessageDao chatMessageDao; public PersistentChatMemoryStore(ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) { this.chatThreadDao = chatThreadDao; diff --git a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/test/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactoryTest.java b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/test/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactoryTest.java index 0a559b44b..b80149ad9 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/test/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactoryTest.java +++ b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/test/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactoryTest.java @@ -18,16 +18,12 @@ */ package org.apache.bigtop.manager.ai.assistant; -import org.apache.bigtop.manager.ai.assistant.store.ChatMemoryStoreProvider; -import org.apache.bigtop.manager.ai.core.enums.PlatformType; +import org.apache.bigtop.manager.ai.core.config.AIAssistantConfig; import org.apache.bigtop.manager.ai.core.factory.AIAssistant; -import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider; -import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider; import org.apache.bigtop.manager.ai.openai.OpenAIAssistant; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.MockedStatic; import org.mockito.MockitoAnnotations; @@ -42,23 +38,18 @@ class GeneralAssistantFactoryTest { @Mock - private SystemPromptProvider systemPromptProvider; + private AIAssistantConfig assistantConfigProvider; @Mock - private AIAssistantConfigProvider assistantConfigProvider; - - @InjectMocks private GeneralAssistantFactory generalAssistantFactory; @BeforeEach void setUp() { MockitoAnnotations.openMocks(this); - generalAssistantFactory = new GeneralAssistantFactory(systemPromptProvider, new ChatMemoryStoreProvider()); Map credentials = Map.of("apiKey", "123456"); when(assistantConfigProvider.getModel()).thenReturn("model"); when(assistantConfigProvider.getCredentials()).thenReturn(credentials); when(assistantConfigProvider.getConfigs()).thenReturn(null); - when(assistantConfigProvider.getLanguage()).thenReturn("en"); } @Test @@ -66,7 +57,7 @@ void testCreateAIAssistant() { AIAssistant.Builder mockBuilder = mock(OpenAIAssistant.Builder.class); when(mockBuilder.id(any())).thenReturn(mockBuilder); when(mockBuilder.memoryStore(any())).thenReturn(mockBuilder); - when(mockBuilder.withConfigProvider(any())).thenReturn(mockBuilder); + when(mockBuilder.withConfig(any())).thenReturn(mockBuilder); when(mockBuilder.withToolProvider(any())).thenReturn(mockBuilder); when(mockBuilder.withSystemPrompt(any())).thenReturn(mockBuilder); when(mockBuilder.build()).thenReturn(mock(AIAssistant.class)); @@ -74,10 +65,8 @@ void testCreateAIAssistant() { try (MockedStatic openAIAssistantMockedStatic = mockStatic(OpenAIAssistant.class)) { openAIAssistantMockedStatic.when(OpenAIAssistant::builder).thenReturn(mockBuilder); - PlatformType platformType = PlatformType.OPENAI; - generalAssistantFactory.create(platformType, assistantConfigProvider); - generalAssistantFactory = new GeneralAssistantFactory(new ChatMemoryStoreProvider()); - generalAssistantFactory.create(platformType, assistantConfigProvider); + generalAssistantFactory.createAIService(assistantConfigProvider, null); + generalAssistantFactory.createForTest(assistantConfigProvider, null); } } } diff --git a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/test/java/org/apache/bigtop/manager/ai/assistant/provider/AIAssistantConfigTest.java b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/test/java/org/apache/bigtop/manager/ai/assistant/provider/GeneralAssistantConfigTest.java similarity index 80% rename from bigtop-manager-ai/bigtop-manager-ai-assistant/src/test/java/org/apache/bigtop/manager/ai/assistant/provider/AIAssistantConfigTest.java rename to bigtop-manager-ai/bigtop-manager-ai-assistant/src/test/java/org/apache/bigtop/manager/ai/assistant/provider/GeneralAssistantConfigTest.java index 6b70a5fae..cbe44f035 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/test/java/org/apache/bigtop/manager/ai/assistant/provider/AIAssistantConfigTest.java +++ b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/test/java/org/apache/bigtop/manager/ai/assistant/provider/GeneralAssistantConfigTest.java @@ -18,6 +18,9 @@ */ package org.apache.bigtop.manager.ai.assistant.provider; +import org.apache.bigtop.manager.ai.assistant.config.GeneralAssistantConfig; +import org.apache.bigtop.manager.ai.core.enums.PlatformType; + import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -31,9 +34,9 @@ import static org.junit.jupiter.api.Assertions.assertNull; @ExtendWith(MockitoExtension.class) -public class AIAssistantConfigTest { +public class GeneralAssistantConfigTest { - private AIAssistantConfig.Builder builder; + private GeneralAssistantConfig.Builder builder; private String model; private String language; private Map credentials; @@ -41,7 +44,7 @@ public class AIAssistantConfigTest { @BeforeEach public void setUp() { - builder = AIAssistantConfig.builder(); + builder = GeneralAssistantConfig.builder(); model = "test-model"; language = "en-US"; credentials = new HashMap<>(); @@ -53,7 +56,8 @@ public void setUp() { @Test public void testBuilderSetsValuesCorrectly() { - AIAssistantConfig config = builder.setModel(model) + GeneralAssistantConfig config = builder.setPlatformType(PlatformType.OPENAI) + .setModel(model) .setLanguage(language) .addCredentials(credentials) .addConfigs(configs) @@ -68,7 +72,8 @@ public void testBuilderSetsValuesCorrectly() { @Test public void testBuilderAddsSingleCredential() { - AIAssistantConfig config = builder.setModel(model) + GeneralAssistantConfig config = builder.setPlatformType(PlatformType.OPENAI) + .setModel(model) .setLanguage(language) .addCredential("client_id", "abcd1234") .build(); @@ -79,7 +84,8 @@ public void testBuilderAddsSingleCredential() { @Test public void testBuilderAddsSingleConfig() { - AIAssistantConfig config = builder.setModel(model) + GeneralAssistantConfig config = builder.setPlatformType(PlatformType.OPENAI) + .setModel(model) .setLanguage(language) .addConfig("threadId", "123") .build(); @@ -90,10 +96,10 @@ public void testBuilderAddsSingleConfig() { @Test public void testEmptyBuilder() { - AIAssistantConfig config = builder.build(); + GeneralAssistantConfig config = + builder.setPlatformType(PlatformType.OPENAI).setModel(model).build(); assertNotNull(config); - assertNull(config.getModel()); assertNull(config.getLanguage()); assertEquals(0, config.getCredentials().size()); assertEquals(0, config.getConfigs().size()); @@ -108,7 +114,8 @@ public void testMultipleCredentialsAndConfigs() { Map extraConfigs = new HashMap<>(); extraConfigs.put("retry", "3"); - AIAssistantConfig config = builder.setModel(model) + GeneralAssistantConfig config = builder.setPlatformType(PlatformType.OPENAI) + .setModel(model) .setLanguage(language) .addCredentials(extraCredentials) .addConfigs(extraConfigs) diff --git a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/AbstractAIAssistant.java b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/AbstractAIAssistant.java index 4699a8401..58e8268a4 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/AbstractAIAssistant.java +++ b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/AbstractAIAssistant.java @@ -18,8 +18,8 @@ */ package org.apache.bigtop.manager.ai.core; +import org.apache.bigtop.manager.ai.core.config.AIAssistantConfig; import org.apache.bigtop.manager.ai.core.factory.AIAssistant; -import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.MessageWindowChatMemory; @@ -61,7 +61,7 @@ public abstract static class Builder implements AIAssistant.Builder { protected Object id; protected ChatMemoryStore chatMemoryStore; - protected AIAssistantConfigProvider configProvider; + protected AIAssistantConfig config; protected ToolProvider toolProvider; protected String systemPrompt; @@ -78,8 +78,8 @@ public Builder withSystemPrompt(String systemPrompt) { return this; } - public Builder withConfigProvider(AIAssistantConfigProvider configProvider) { - this.configProvider = configProvider; + public Builder withConfig(AIAssistantConfig config) { + this.config = config; return this; } diff --git a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/provider/AIAssistantConfigProvider.java b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/config/AIAssistantConfig.java similarity index 88% rename from bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/provider/AIAssistantConfigProvider.java rename to bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/config/AIAssistantConfig.java index b49d99c86..42d8e9ef0 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/provider/AIAssistantConfigProvider.java +++ b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/config/AIAssistantConfig.java @@ -16,16 +16,14 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.bigtop.manager.ai.core.provider; +package org.apache.bigtop.manager.ai.core.config; import java.util.Map; -public interface AIAssistantConfigProvider { +public interface AIAssistantConfig { String getModel(); Map getCredentials(); Map getConfigs(); - - String getLanguage(); } diff --git a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistant.java b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistant.java index 8257c4e78..9a7074e6f 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistant.java +++ b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistant.java @@ -18,8 +18,8 @@ */ package org.apache.bigtop.manager.ai.core.factory; +import org.apache.bigtop.manager.ai.core.config.AIAssistantConfig; import org.apache.bigtop.manager.ai.core.enums.PlatformType; -import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -74,7 +74,7 @@ interface Builder { Builder memoryStore(ChatMemoryStore memoryStore); - Builder withConfigProvider(AIAssistantConfigProvider configProvider); + Builder withConfig(AIAssistantConfig configProvider); Builder withToolProvider(ToolProvider toolProvider); diff --git a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistantFactory.java b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistantFactory.java index f06a61277..d947e778f 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistantFactory.java +++ b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistantFactory.java @@ -18,25 +18,18 @@ */ package org.apache.bigtop.manager.ai.core.factory; -import org.apache.bigtop.manager.ai.core.enums.PlatformType; +import org.apache.bigtop.manager.ai.core.config.AIAssistantConfig; import org.apache.bigtop.manager.ai.core.enums.SystemPrompt; -import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider; import dev.langchain4j.service.tool.ToolProvider; public interface AIAssistantFactory { - AIAssistant createWithPrompt( - PlatformType platformType, - AIAssistantConfigProvider assistantConfig, - Object id, - ToolProvider toolProvider, - SystemPrompt systemPrompt); + AIAssistant createWithPrompt(AIAssistantConfig config, ToolProvider toolProvider, SystemPrompt systemPrompt); - default AIAssistant create(PlatformType platformType, AIAssistantConfigProvider assistantConfig) { - return createAiService(platformType, assistantConfig, null, null); - } + AIAssistant createForTest(AIAssistantConfig config, ToolProvider toolProvider); - AIAssistant createAiService( - PlatformType platformType, AIAssistantConfigProvider assistantConfig, Long id, ToolProvider toolProvider); + default AIAssistant createAIService(AIAssistantConfig config, ToolProvider toolProvider) { + return createWithPrompt(config, toolProvider, SystemPrompt.DEFAULT_PROMPT); + } } diff --git a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/factory/ToolBox.java b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/factory/ToolBox.java deleted file mode 100644 index 47ee00444..000000000 --- a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/factory/ToolBox.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.bigtop.manager.ai.core.factory; - -import reactor.core.publisher.Flux; - -import java.util.List; - -public interface ToolBox { - - List getTools(); - - String invoke(String toolName); - - Flux streamInvoke(String toolName); -} diff --git a/bigtop-manager-ai/bigtop-manager-ai-dashscope/src/main/java/org/apache/bigtop/manager/ai/dashscope/DashScopeAssistant.java b/bigtop-manager-ai/bigtop-manager-ai-dashscope/src/main/java/org/apache/bigtop/manager/ai/dashscope/DashScopeAssistant.java index 53c6be9d6..9a6531ea7 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-dashscope/src/main/java/org/apache/bigtop/manager/ai/dashscope/DashScopeAssistant.java +++ b/bigtop-manager-ai/bigtop-manager-ai-dashscope/src/main/java/org/apache/bigtop/manager/ai/dashscope/DashScopeAssistant.java @@ -65,17 +65,17 @@ public AIAssistant build() { @Override public ChatLanguageModel getChatLanguageModel() { - String model = ValidationUtils.ensureNotNull(configProvider.getModel(), "model"); - String apiKey = ValidationUtils.ensureNotNull( - configProvider.getCredentials().get("apiKey"), "apiKey"); + String model = ValidationUtils.ensureNotNull(config.getModel(), "model"); + String apiKey = + ValidationUtils.ensureNotNull(config.getCredentials().get("apiKey"), "apiKey"); return QwenChatModel.builder().apiKey(apiKey).modelName(model).build(); } @Override public StreamingChatLanguageModel getStreamingChatLanguageModel() { - String model = ValidationUtils.ensureNotNull(configProvider.getModel(), "model"); - String apiKey = ValidationUtils.ensureNotNull( - configProvider.getCredentials().get("apiKey"), "apiKey"); + String model = ValidationUtils.ensureNotNull(config.getModel(), "model"); + String apiKey = + ValidationUtils.ensureNotNull(config.getCredentials().get("apiKey"), "apiKey"); return QwenStreamingChatModel.builder() .apiKey(apiKey) .modelName(model) diff --git a/bigtop-manager-ai/bigtop-manager-ai-openai/src/main/java/org/apache/bigtop/manager/ai/openai/OpenAIAssistant.java b/bigtop-manager-ai/bigtop-manager-ai-openai/src/main/java/org/apache/bigtop/manager/ai/openai/OpenAIAssistant.java index d5fb43cf7..ec60173cb 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-openai/src/main/java/org/apache/bigtop/manager/ai/openai/OpenAIAssistant.java +++ b/bigtop-manager-ai/bigtop-manager-ai-openai/src/main/java/org/apache/bigtop/manager/ai/openai/OpenAIAssistant.java @@ -51,9 +51,9 @@ public static class Builder extends AbstractAIAssistant.Builder { @Override public ChatLanguageModel getChatLanguageModel() { - String model = ValidationUtils.ensureNotNull(configProvider.getModel(), "model"); - String apiKey = ValidationUtils.ensureNotNull( - configProvider.getCredentials().get("apiKey"), "apiKey"); + String model = ValidationUtils.ensureNotNull(config.getModel(), "model"); + String apiKey = + ValidationUtils.ensureNotNull(config.getCredentials().get("apiKey"), "apiKey"); return OpenAiChatModel.builder() .apiKey(apiKey) .baseUrl(BASE_URL) @@ -63,9 +63,9 @@ public ChatLanguageModel getChatLanguageModel() { @Override public StreamingChatLanguageModel getStreamingChatLanguageModel() { - String model = ValidationUtils.ensureNotNull(configProvider.getModel(), "model"); - String apiKey = ValidationUtils.ensureNotNull( - configProvider.getCredentials().get("apiKey"), "apiKey"); + String model = ValidationUtils.ensureNotNull(config.getModel(), "model"); + String apiKey = + ValidationUtils.ensureNotNull(config.getCredentials().get("apiKey"), "apiKey"); return OpenAiStreamingChatModel.builder() .apiKey(apiKey) .baseUrl(BASE_URL) diff --git a/bigtop-manager-ai/bigtop-manager-ai-qianfan/src/main/java/org/apache/bigtop/manager/ai/qianfan/QianFanAssistant.java b/bigtop-manager-ai/bigtop-manager-ai-qianfan/src/main/java/org/apache/bigtop/manager/ai/qianfan/QianFanAssistant.java index 80010eb8b..344a88201 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-qianfan/src/main/java/org/apache/bigtop/manager/ai/qianfan/QianFanAssistant.java +++ b/bigtop-manager-ai/bigtop-manager-ai-qianfan/src/main/java/org/apache/bigtop/manager/ai/qianfan/QianFanAssistant.java @@ -65,11 +65,11 @@ public AIAssistant build() { @Override public ChatLanguageModel getChatLanguageModel() { - String model = ValidationUtils.ensureNotNull(configProvider.getModel(), "model"); - String apiKey = ValidationUtils.ensureNotNull( - configProvider.getCredentials().get("apiKey"), "apiKey"); - String secretKey = ValidationUtils.ensureNotNull( - configProvider.getCredentials().get("secretKey"), "secretKey"); + String model = ValidationUtils.ensureNotNull(config.getModel(), "model"); + String apiKey = + ValidationUtils.ensureNotNull(config.getCredentials().get("apiKey"), "apiKey"); + String secretKey = + ValidationUtils.ensureNotNull(config.getCredentials().get("secretKey"), "secretKey"); return QianfanChatModel.builder() .apiKey(apiKey) .secretKey(secretKey) @@ -79,11 +79,11 @@ public ChatLanguageModel getChatLanguageModel() { @Override public StreamingChatLanguageModel getStreamingChatLanguageModel() { - String model = ValidationUtils.ensureNotNull(configProvider.getModel(), "model"); - String apiKey = ValidationUtils.ensureNotNull( - configProvider.getCredentials().get("apiKey"), "apiKey"); - String secretKey = ValidationUtils.ensureNotNull( - configProvider.getCredentials().get("secretKey"), "secretKey"); + String model = ValidationUtils.ensureNotNull(config.getModel(), "model"); + String apiKey = + ValidationUtils.ensureNotNull(config.getCredentials().get("apiKey"), "apiKey"); + String secretKey = + ValidationUtils.ensureNotNull(config.getCredentials().get("secretKey"), "secretKey"); return QianfanStreamingChatModel.builder() .apiKey(apiKey) .secretKey(secretKey) diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/ServerApplication.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/ServerApplication.java index cc6a5a5c7..7c03ea334 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/ServerApplication.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/ServerApplication.java @@ -25,7 +25,12 @@ @EnableAsync @EnableScheduling -@SpringBootApplication(scanBasePackages = {"org.apache.bigtop.manager.server", "org.apache.bigtop.manager.common"}) +@SpringBootApplication( + scanBasePackages = { + "org.apache.bigtop.manager.server", + "org.apache.bigtop.manager.common", + "org.apache.bigtop.manager.ai" + }) public class ServerApplication { public static void main(String[] args) { diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/ChatbotServiceImpl.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/ChatbotServiceImpl.java index 415225815..45ae57aec 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/ChatbotServiceImpl.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/ChatbotServiceImpl.java @@ -18,9 +18,7 @@ */ package org.apache.bigtop.manager.server.service.impl; -import org.apache.bigtop.manager.ai.assistant.GeneralAssistantFactory; -import org.apache.bigtop.manager.ai.assistant.provider.AIAssistantConfig; -import org.apache.bigtop.manager.ai.assistant.store.ChatMemoryStoreProvider; +import org.apache.bigtop.manager.ai.assistant.config.GeneralAssistantConfig; import org.apache.bigtop.manager.ai.core.enums.MessageType; import org.apache.bigtop.manager.ai.core.enums.PlatformType; import org.apache.bigtop.manager.ai.core.factory.AIAssistant; @@ -60,7 +58,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Objects; @Service @Slf4j @@ -80,86 +77,30 @@ public class ChatbotServiceImpl implements ChatbotService { @Resource private AIServiceToolsProvider aiServiceToolsProvider; + @Resource private AIAssistantFactory aiAssistantFactory; - private static final int CHAT_THREAD_NAME_LENGTH = 100; - - public static String getNameFromMessage(String input) { - if (input == null || input.length() <= CHAT_THREAD_NAME_LENGTH) { - return input; - } else { - return input.substring(0, CHAT_THREAD_NAME_LENGTH); - } - } - - public AIAssistantFactory getAIAssistantFactory() { - if (aiAssistantFactory == null) { - aiAssistantFactory = - new GeneralAssistantFactory(new ChatMemoryStoreProvider(chatThreadDao, chatMessageDao)); - } - return aiAssistantFactory; - } - - private AuthPlatformPO getActiveAuthPlatform() { - List authPlatformPOS = authPlatformDao.findAll(); - for (AuthPlatformPO authPlatformPO : authPlatformPOS) { - if (AuthPlatformStatus.isActive(authPlatformPO.getStatus())) { - return authPlatformPO; - } - } - return null; - } - - private AIAssistantConfig getAIAssistantConfig(String model, Map credentials) { - return AIAssistantConfig.builder() - .setModel(model) - .setLanguage(LocaleContextHolder.getLocale().toString()) - .addCredentials(credentials) - .build(); - } - - private PlatformType getPlatformType(String platformName) { - return PlatformType.getPlatformType(platformName.toLowerCase()); - } - - private AIAssistant buildAIAssistant( - String platformName, String model, Map credentials, Long threadId, ChatbotCommand command) { - return getAIAssistantFactory() - .createAiService( - getPlatformType(platformName), - getAIAssistantConfig(model, credentials), - threadId, - aiServiceToolsProvider.getToolsProvide(command)); - } - @Override public ChatThreadVO createChatThread(ChatThreadDTO chatThreadDTO) { - AuthPlatformPO authPlatformPO = getActiveAuthPlatform(); - if (authPlatformPO == null || authPlatformPO.getIsDeleted()) { - throw new ApiException(ApiExceptionEnum.NO_PLATFORM_IN_USE); - } - - Long userId = SessionUserHolder.getUserId(); + AuthPlatformPO authPlatformPO = validateAndGetActiveAuthPlatform(); PlatformPO platformPO = platformDao.findById(authPlatformPO.getPlatformId()); chatThreadDTO.setPlatformId(platformPO.getId()); chatThreadDTO.setAuthId(authPlatformPO.getId()); ChatThreadPO chatThreadPO = ChatThreadConverter.INSTANCE.fromDTO2PO(chatThreadDTO); - chatThreadPO.setUserId(userId); + chatThreadPO.setUserId(SessionUserHolder.getUserId()); chatThreadDao.save(chatThreadPO); + return ChatThreadConverter.INSTANCE.fromPO2VO(chatThreadPO, authPlatformPO, platformPO); } @Override public boolean deleteChatThread(Long threadId) { - ChatThreadPO chatThreadPO = chatThreadDao.findById(threadId); - if (chatThreadPO == null || chatThreadPO.getIsDeleted()) { - throw new ApiException(ApiExceptionEnum.CHAT_THREAD_NOT_FOUND); - } - + ChatThreadPO chatThreadPO = validateAndGetChatThread(threadId); chatThreadPO.setIsDeleted(true); chatThreadDao.partialUpdateById(chatThreadPO); + List chatMessagePOS = chatMessageDao.findAllByThreadId(threadId); chatMessagePOS.forEach(chatMessagePO -> chatMessagePO.setIsDeleted(true)); chatMessageDao.partialUpdateByIds(chatMessagePOS); @@ -169,15 +110,11 @@ public boolean deleteChatThread(Long threadId) { @Override public List getAllChatThreads() { - AuthPlatformPO authPlatformPO = getActiveAuthPlatform(); - if (authPlatformPO == null) { - throw new ApiException(ApiExceptionEnum.NO_PLATFORM_IN_USE); - } + AuthPlatformPO authPlatformPO = validateAndGetActiveAuthPlatform(); PlatformPO platformPO = platformDao.findById(authPlatformPO.getPlatformId()); - Long authId = authPlatformPO.getId(); - Long userId = SessionUserHolder.getUserId(); - List chatThreadPOS = chatThreadDao.findAllByAuthIdAndUserId(authId, userId); + List chatThreadPOS = + chatThreadDao.findAllByAuthIdAndUserId(authPlatformPO.getId(), SessionUserHolder.getUserId()); List chatThreads = new ArrayList<>(); for (ChatThreadPO chatThreadPO : chatThreadPOS) { if (chatThreadPO.getIsDeleted()) { @@ -190,89 +127,30 @@ public List getAllChatThreads() { return chatThreads; } - private AIAssistant prepareTalk(Long threadId, ChatbotCommand command) { - ChatThreadPO chatThreadPO = chatThreadDao.findById(threadId); - Long userId = SessionUserHolder.getUserId(); - if (!Objects.equals(userId, chatThreadPO.getUserId()) || chatThreadPO.getIsDeleted()) { - throw new ApiException(ApiExceptionEnum.CHAT_THREAD_NOT_FOUND); - } - AuthPlatformPO authPlatformPO = getActiveAuthPlatform(); - if (authPlatformPO == null - || authPlatformPO.getIsDeleted() - || !authPlatformPO.getId().equals(chatThreadPO.getAuthId())) { - throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_IN_USE); - } - - AuthPlatformDTO authPlatformDTO = AuthPlatformConverter.INSTANCE.fromPO2DTO(authPlatformPO); - - PlatformPO platformPO = platformDao.findById(authPlatformPO.getPlatformId()); - return buildAIAssistant( - platformPO.getName(), - authPlatformDTO.getModel(), - authPlatformDTO.getAuthCredentials(), - threadId, - command); - } - @Override public SseEmitter talk(Long threadId, ChatbotCommand command, String message) { AIAssistant aiAssistant = prepareTalk(threadId, command); - Flux stringFlux; - if (command == null) { - stringFlux = aiAssistant.streamAsk(message); - } else { - stringFlux = Flux.just(aiAssistant.ask(message)); - } + + Flux stringFlux = + (command == null) ? aiAssistant.streamAsk(message) : Flux.just(aiAssistant.ask(message)); + SseEmitter emitter = new SseEmitter(); + stringFlux.subscribe( - s -> { - try { - TalkVO talkVO = new TalkVO(); - talkVO.setContent(s); - talkVO.setFinishReason(null); - emitter.send(talkVO); - } catch (Exception e) { - emitter.completeWithError(e); - } - }, - throwable -> { - try { - TalkVO errorVO = new TalkVO(); - errorVO.setContent(null); - errorVO.setFinishReason("Error: " + throwable.getMessage()); - emitter.send(errorVO); - } catch (Exception sendException) { - sendException.printStackTrace(); - } - emitter.completeWithError(throwable); - }, - () -> { - try { - TalkVO finishVO = new TalkVO(); - finishVO.setContent(null); - finishVO.setFinishReason("completed"); - emitter.send(finishVO); - } catch (Exception e) { - e.printStackTrace(); - } - emitter.complete(); - }); + s -> sendTalkVO(emitter, s, null), + throwable -> handleError(emitter, throwable), + () -> completeEmitter(emitter)); emitter.onTimeout(emitter::complete); + return emitter; } @Override public List history(Long threadId) { List chatMessages = new ArrayList<>(); - ChatThreadPO chatThreadPO = chatThreadDao.findById(threadId); - if (chatThreadPO == null || chatThreadPO.getIsDeleted()) { - throw new ApiException(ApiExceptionEnum.CHAT_THREAD_NOT_FOUND); - } - Long userId = SessionUserHolder.getUserId(); - if (!chatThreadPO.getUserId().equals(userId)) { - throw new ApiException(ApiExceptionEnum.PERMISSION_DENIED); - } + validateAndGetChatThread(threadId); + List chatMessagePOs = chatMessageDao.findAllByThreadId(threadId); for (ChatMessagePO chatMessagePO : chatMessagePOs) { ChatMessageVO chatMessageVO = ChatMessageConverter.INSTANCE.fromPO2VO(chatMessagePO); @@ -289,14 +167,7 @@ public List history(Long threadId) { @Override public ChatThreadVO updateChatThread(ChatThreadDTO chatThreadDTO) { - ChatThreadPO chatThreadPO = chatThreadDao.findById(chatThreadDTO.getId()); - if (chatThreadPO == null || chatThreadPO.getIsDeleted()) { - throw new ApiException(ApiExceptionEnum.CHAT_THREAD_NOT_FOUND); - } - Long userId = SessionUserHolder.getUserId(); - if (!chatThreadPO.getUserId().equals(userId)) { - throw new ApiException(ApiExceptionEnum.PERMISSION_DENIED); - } + ChatThreadPO chatThreadPO = validateAndGetChatThread(chatThreadDTO.getId()); chatThreadPO.setName(chatThreadDTO.getName()); chatThreadDao.partialUpdateById(chatThreadPO); @@ -313,6 +184,28 @@ public List getChatbotCommands() { @Override public ChatThreadVO getChatThread(Long threadId) { + ChatThreadPO chatThreadPO = validateAndGetChatThread(threadId); + + AuthPlatformPO authPlatformPO = authPlatformDao.findById(chatThreadPO.getAuthId()); + return ChatThreadConverter.INSTANCE.fromPO2VO( + chatThreadPO, authPlatformPO, platformDao.findById(authPlatformPO.getPlatformId())); + } + + private AuthPlatformPO validateAndGetActiveAuthPlatform() { + AuthPlatformPO authPlatform = null; + List authPlatformPOS = authPlatformDao.findAll(); + for (AuthPlatformPO authPlatformPO : authPlatformPOS) { + if (AuthPlatformStatus.isActive(authPlatformPO.getStatus())) { + authPlatform = authPlatformPO; + } + } + if (authPlatform == null || authPlatform.getIsDeleted()) { + throw new ApiException(ApiExceptionEnum.NO_PLATFORM_IN_USE); + } + return authPlatform; + } + + private ChatThreadPO validateAndGetChatThread(Long threadId) { ChatThreadPO chatThreadPO = chatThreadDao.findById(threadId); if (chatThreadPO == null || chatThreadPO.getIsDeleted()) { throw new ApiException(ApiExceptionEnum.CHAT_THREAD_NOT_FOUND); @@ -321,8 +214,70 @@ public ChatThreadVO getChatThread(Long threadId) { if (!chatThreadPO.getUserId().equals(userId)) { throw new ApiException(ApiExceptionEnum.PERMISSION_DENIED); } - AuthPlatformPO authPlatformPO = authPlatformDao.findById(chatThreadPO.getAuthId()); - return ChatThreadConverter.INSTANCE.fromPO2VO( - chatThreadPO, authPlatformPO, platformDao.findById(authPlatformPO.getPlatformId())); + return chatThreadPO; + } + + private GeneralAssistantConfig getAIAssistantConfig( + String platformName, String model, Map credentials, Long id) { + return GeneralAssistantConfig.builder() + .setPlatformType(getPlatformType(platformName)) + .setModel(model) + .setId(id) + .setLanguage(LocaleContextHolder.getLocale().toString()) + .addCredentials(credentials) + .build(); + } + + private PlatformType getPlatformType(String platformName) { + return PlatformType.getPlatformType(platformName.toLowerCase()); + } + + private AIAssistant buildAIAssistant( + String platformName, String model, Map credentials, Long threadId, ChatbotCommand command) { + return aiAssistantFactory.createAIService( + getAIAssistantConfig(platformName, model, credentials, threadId), + aiServiceToolsProvider.getToolsProvide(command)); + } + + private AIAssistant prepareTalk(Long threadId, ChatbotCommand command) { + ChatThreadPO chatThreadPO = validateAndGetChatThread(threadId); + AuthPlatformPO authPlatformPO = validateAndGetActiveAuthPlatform(); + + if (!authPlatformPO.getId().equals(chatThreadPO.getAuthId())) { + throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_IN_USE); + } + + AuthPlatformDTO authPlatformDTO = AuthPlatformConverter.INSTANCE.fromPO2DTO(authPlatformPO); + PlatformPO platformPO = platformDao.findById(authPlatformPO.getPlatformId()); + + return buildAIAssistant( + platformPO.getName(), + authPlatformDTO.getModel(), + authPlatformDTO.getAuthCredentials(), + threadId, + command); + } + + private void sendTalkVO(SseEmitter emitter, String content, String finishReason) { + try { + TalkVO talkVO = new TalkVO(); + talkVO.setContent(content); + talkVO.setFinishReason(finishReason); + emitter.send(talkVO); + } catch (Exception e) { + log.error("Error sending data to SseEmitter", e); + emitter.completeWithError(e); + } + } + + private void handleError(SseEmitter emitter, Throwable throwable) { + log.error("Error during SSE streaming: {}", throwable.getMessage(), throwable); + sendTalkVO(emitter, null, "Error: " + throwable.getMessage()); + emitter.completeWithError(throwable); + } + + private void completeEmitter(SseEmitter emitter) { + sendTalkVO(emitter, null, "completed"); + emitter.complete(); } } diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/LLMConfigServiceImpl.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/LLMConfigServiceImpl.java index 7b1e1ead0..1ca5558e8 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/LLMConfigServiceImpl.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/LLMConfigServiceImpl.java @@ -18,9 +18,7 @@ */ package org.apache.bigtop.manager.server.service.impl; -import org.apache.bigtop.manager.ai.assistant.GeneralAssistantFactory; -import org.apache.bigtop.manager.ai.assistant.provider.AIAssistantConfig; -import org.apache.bigtop.manager.ai.assistant.store.ChatMemoryStoreProvider; +import org.apache.bigtop.manager.ai.assistant.config.GeneralAssistantConfig; import org.apache.bigtop.manager.ai.core.enums.PlatformType; import org.apache.bigtop.manager.ai.core.factory.AIAssistant; import org.apache.bigtop.manager.ai.core.factory.AIAssistantFactory; @@ -77,91 +75,12 @@ public class LLMConfigServiceImpl implements LLMConfigService { @Resource private ChatMessageDao chatMessageDao; + @Resource private AIAssistantFactory aiAssistantFactory; private static final String TEST_FLAG = "ZmxhZw=="; private static final String TEST_KEY = "bm"; - public AIAssistantFactory getAIAssistantFactory() { - if (aiAssistantFactory == null) { - aiAssistantFactory = - new GeneralAssistantFactory(new ChatMemoryStoreProvider(chatThreadDao, chatMessageDao)); - } - return aiAssistantFactory; - } - - private AIAssistantConfig getAIAssistantConfig( - String model, Map credentials, Map configs) { - return AIAssistantConfig.builder() - .setModel(model) - .setLanguage(LocaleContextHolder.getLocale().toString()) - .addCredentials(credentials) - .addConfigs(configs) - .build(); - } - - private PlatformType getPlatformType(String platformName) { - return PlatformType.getPlatformType(platformName.toLowerCase()); - } - - private Boolean testAuthorization(String platformName, String model, Map credentials) { - Boolean result = testFuncCalling(platformName, model, credentials); - log.info("Test func calling result: {}", result); - AIAssistantConfig aiAssistantConfig = getAIAssistantConfig(model, credentials, null); - AIAssistant aiAssistant = getAIAssistantFactory().create(getPlatformType(platformName), aiAssistantConfig); - try { - return aiAssistant.test(); - } catch (Exception e) { - throw new ApiException(ApiExceptionEnum.CREDIT_INCORRECT, e.getMessage()); - } - } - - private Boolean testFuncCalling(String platformName, String model, Map credentials) { - ToolProvider toolProvider = (toolProviderRequest) -> { - ToolSpecification toolSpecification = ToolSpecification.builder() - .name("getFlag") - .description("Get flag based on key") - .addParameter("key", JsonSchemaProperty.STRING, JsonSchemaProperty.description("Lowercase key")) - .build(); - ToolExecutor toolExecutor = (toolExecutionRequest, memoryId) -> { - Map arguments = JsonUtils.readFromString(toolExecutionRequest.arguments()); - String key = arguments.get("key").toString(); - if (key.equals(TEST_KEY)) { - return TEST_FLAG; - } - return null; - }; - - return ToolProviderResult.builder() - .add(toolSpecification, toolExecutor) - .build(); - }; - - AIAssistantConfig aiAssistantConfig = getAIAssistantConfig(model, credentials, null); - AIAssistant aiAssistant = getAIAssistantFactory() - .createAiService(getPlatformType(platformName), aiAssistantConfig, null, toolProvider); - try { - return aiAssistant.ask("What is the flag of " + TEST_KEY).contains(TEST_FLAG); - } catch (Exception e) { - throw new ApiException(ApiExceptionEnum.CREDIT_INCORRECT, e.getMessage()); - } - } - - private void switchActivePlatform(Long id) { - List authPlatformPOS = authPlatformDao.findAll(); - for (AuthPlatformPO authPlatformPO : authPlatformPOS) { - if (!AuthPlatformStatus.available(authPlatformPO.getStatus())) { - continue; - } - if (authPlatformPO.getId().equals(id)) { - authPlatformPO.setStatus(AuthPlatformStatus.ACTIVE.getCode()); - } else { - authPlatformPO.setStatus(AuthPlatformStatus.AVAILABLE.getCode()); - } - } - authPlatformDao.partialUpdateByIds(authPlatformPOS); - } - @Override public List platforms() { List platformPOs = platformDao.findAll(); @@ -202,10 +121,7 @@ public List authorizedPlatforms() { @Override public AuthPlatformVO addAuthorizedPlatform(AuthPlatformDTO authPlatformDTO) { - PlatformPO platformPO = platformDao.findById(authPlatformDTO.getPlatformId()); - if (platformPO == null) { - throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_FOUND); - } + PlatformPO platformPO = validateAndGetPlatform(authPlatformDTO.getPlatformId()); Map credentialSet = getStringMap(authPlatformDTO, PlatformConverter.INSTANCE.fromPO2DTO(platformPO)); @@ -222,28 +138,9 @@ public AuthPlatformVO addAuthorizedPlatform(AuthPlatformDTO authPlatformDTO) { return AuthPlatformConverter.INSTANCE.fromPO2VO(authPlatformPO, platformPO); } - private static @NotNull Map getStringMap(AuthPlatformDTO authPlatformDTO, PlatformDTO platformDTO) { - if (platformDTO == null) { - throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_FOUND); - } - Map credentialNeed = platformDTO.getAuthCredentials(); - Map credentialGet = authPlatformDTO.getAuthCredentials(); - Map credentialSet = new HashMap<>(); - for (String key : credentialNeed.keySet()) { - if (!credentialGet.containsKey(key)) { - throw new ApiException(ApiExceptionEnum.CREDIT_INCORRECT); - } - credentialSet.put(key, credentialGet.get(key)); - } - return credentialSet; - } - @Override public boolean deleteAuthorizedPlatform(Long authId) { - AuthPlatformPO authPlatformPO = authPlatformDao.findById(authId); - if (authPlatformPO == null || authPlatformPO.getIsDeleted()) { - throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_AUTHORIZED); - } + AuthPlatformPO authPlatformPO = validateAndGetAuthPlatform(authId); if (AuthPlatformStatus.isActive(authPlatformPO.getStatus())) { throw new ApiException(ApiExceptionEnum.PLATFORM_IS_ACTIVE); @@ -253,15 +150,7 @@ public boolean deleteAuthorizedPlatform(Long authId) { authPlatformDao.partialUpdateById(authPlatformPO); List chatThreadPOS = chatThreadDao.findAllByAuthId(authPlatformPO.getId()); - for (ChatThreadPO chatThreadPO : chatThreadPOS) { - chatThreadPO.setIsDeleted(true); - chatThreadDao.partialUpdateById(chatThreadPO); - List chatMessagePOS = chatMessageDao.findAllByThreadId(chatThreadPO.getId()); - for (ChatMessagePO chatMessagePO : chatMessagePOS) { - chatMessagePO.setIsDeleted(true); - chatMessageDao.partialUpdateById(chatMessagePO); - } - } + softDeleteChatThreads(chatThreadPOS); return true; } @@ -273,20 +162,16 @@ public boolean testAuthorizedPlatform(AuthPlatformDTO authPlatformDTO) { AuthPlatformConverter.INSTANCE.fromPO2DTO(authPlatformDao.findById(authPlatformDTO.getId())); } - PlatformPO platformPO = platformDao.findById(authPlatformDTO.getPlatformId()); - if (platformPO == null) { - throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_FOUND); - } + PlatformPO platformPO = validateAndGetPlatform(authPlatformDTO.getPlatformId()); + List supportModels = List.of(platformPO.getSupportModels().split(",")); if (supportModels.isEmpty() || !supportModels.contains(authPlatformDTO.getModel())) { throw new ApiException(ApiExceptionEnum.MODEL_NOT_SUPPORTED); } if (authPlatformDTO.getId() != null) { - AuthPlatformPO authPlatformPO = authPlatformDao.findById(authPlatformDTO.getId()); - if (authPlatformPO == null || authPlatformPO.getIsDeleted()) { - throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_AUTHORIZED); - } + AuthPlatformPO authPlatformPO = validateAndGetAuthPlatform(authPlatformDTO.getId()); + AuthPlatformDTO existAuthPlatformDTO = AuthPlatformConverter.INSTANCE.fromPO2DTO(authPlatformPO); authPlatformDTO.setAuthCredentials(existAuthPlatformDTO.getAuthCredentials()); authPlatformDTO.setModel(existAuthPlatformDTO.getModel()); @@ -294,6 +179,7 @@ public boolean testAuthorizedPlatform(AuthPlatformDTO authPlatformDTO) { Map credentialSet = getStringMap(authPlatformDTO, PlatformConverter.INSTANCE.fromPO2DTO(platformPO)); + if (!testAuthorization(platformPO.getName(), authPlatformDTO.getModel(), credentialSet)) { throw new ApiException(ApiExceptionEnum.CREDIT_INCORRECT); } @@ -309,10 +195,7 @@ public boolean testAuthorizedPlatform(AuthPlatformDTO authPlatformDTO) { @Override public AuthPlatformVO updateAuthorizedPlatform(AuthPlatformDTO authPlatformDTO) { - AuthPlatformPO authPlatformPO = authPlatformDao.findById(authPlatformDTO.getId()); - if (authPlatformPO == null || authPlatformPO.getIsDeleted()) { - throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_AUTHORIZED); - } + AuthPlatformPO authPlatformPO = validateAndGetAuthPlatform(authPlatformDTO.getId()); String newModel = authPlatformDTO.getModel(); if (newModel != null) { @@ -340,10 +223,7 @@ public AuthPlatformVO updateAuthorizedPlatform(AuthPlatformDTO authPlatformDTO) @Override public boolean activateAuthorizedPlatform(Long authId) { - AuthPlatformPO authPlatformPO = authPlatformDao.findById(authId); - if (authPlatformPO == null || authPlatformPO.getIsDeleted()) { - throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_AUTHORIZED); - } + AuthPlatformPO authPlatformPO = validateAndGetAuthPlatform(authId); if (!AuthPlatformStatus.available(authPlatformPO.getStatus())) { return false; @@ -358,10 +238,8 @@ public boolean activateAuthorizedPlatform(Long authId) { @Override public boolean deactivateAuthorizedPlatform(Long authId) { - AuthPlatformPO authPlatformPO = authPlatformDao.findById(authId); - if (authPlatformPO == null || authPlatformPO.getIsDeleted()) { - throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_AUTHORIZED); - } + AuthPlatformPO authPlatformPO = validateAndGetAuthPlatform(authId); + AuthPlatformStatus authPlatformStatus = AuthPlatformStatus.fromCode(authPlatformPO.getStatus()); if (authPlatformStatus.equals(AuthPlatformStatus.ACTIVE)) { authPlatformPO.setStatus(AuthPlatformStatus.AVAILABLE.getCode()); @@ -373,10 +251,7 @@ public boolean deactivateAuthorizedPlatform(Long authId) { @Override public AuthPlatformVO getAuthorizedPlatform(Long authId) { - AuthPlatformPO authPlatformPO = authPlatformDao.findById(authId); - if (authPlatformPO == null || authPlatformPO.getIsDeleted()) { - throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_FOUND); - } + AuthPlatformPO authPlatformPO = validateAndGetAuthPlatform(authId); return AuthPlatformConverter.INSTANCE.fromPO2VO( authPlatformPO, platformDao.findById(authPlatformPO.getPlatformId())); @@ -384,10 +259,134 @@ public AuthPlatformVO getAuthorizedPlatform(Long authId) { @Override public PlatformVO getPlatform(Long id) { - PlatformPO platformPO = platformDao.findById(id); + PlatformPO platformPO = validateAndGetPlatform(id); + + return PlatformConverter.INSTANCE.fromPO2VO(platformPO); + } + + public PlatformPO validateAndGetPlatform(Long platformId) { + if (platformId == null) { + throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_FOUND); + } + + PlatformPO platformPO = platformDao.findById(platformId); if (platformPO == null) { throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_FOUND); } - return PlatformConverter.INSTANCE.fromPO2VO(platformPO); + return platformPO; + } + + public AuthPlatformPO validateAndGetAuthPlatform(Long authId) { + if (authId == null) { + throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_AUTHORIZED); + } + AuthPlatformPO authPlatformPO = authPlatformDao.findById(authId); + if (authPlatformPO == null || authPlatformPO.getIsDeleted()) { + throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_AUTHORIZED); + } + return authPlatformPO; + } + + private GeneralAssistantConfig getAIAssistantConfig( + String platformName, String model, Map credentials) { + return GeneralAssistantConfig.builder() + .setPlatformType(getPlatformType(platformName)) + .setModel(model) + .setLanguage(LocaleContextHolder.getLocale().toString()) + .addCredentials(credentials) + .build(); + } + + private PlatformType getPlatformType(String platformName) { + return PlatformType.getPlatformType(platformName.toLowerCase()); + } + + private Boolean testAuthorization(String platformName, String model, Map credentials) { + Boolean result = testFuncCalling(platformName, model, credentials); + log.info("Test func calling result: {}", result); + GeneralAssistantConfig generalAssistantConfig = getAIAssistantConfig(platformName, model, credentials); + AIAssistant aiAssistant = aiAssistantFactory.createForTest(generalAssistantConfig, null); + try { + return aiAssistant.test(); + } catch (Exception e) { + throw new ApiException(ApiExceptionEnum.CREDIT_INCORRECT, e.getMessage()); + } + } + + private Boolean testFuncCalling(String platformName, String model, Map credentials) { + ToolProvider toolProvider = (toolProviderRequest) -> { + ToolSpecification toolSpecification = ToolSpecification.builder() + .name("getFlag") + .description("Get flag based on key") + .addParameter("key", JsonSchemaProperty.STRING, JsonSchemaProperty.description("Lowercase key")) + .build(); + ToolExecutor toolExecutor = (toolExecutionRequest, memoryId) -> { + Map arguments = JsonUtils.readFromString(toolExecutionRequest.arguments()); + String key = arguments.get("key").toString(); + if (key.equals(TEST_KEY)) { + return TEST_FLAG; + } + return null; + }; + + return ToolProviderResult.builder() + .add(toolSpecification, toolExecutor) + .build(); + }; + + GeneralAssistantConfig generalAssistantConfig = getAIAssistantConfig(platformName, model, credentials); + AIAssistant aiAssistant = aiAssistantFactory.createForTest(generalAssistantConfig, toolProvider); + try { + return aiAssistant.ask("What is the flag of " + TEST_KEY).contains(TEST_FLAG); + } catch (Exception e) { + throw new ApiException(ApiExceptionEnum.CREDIT_INCORRECT, e.getMessage()); + } + } + + private void switchActivePlatform(Long id) { + List authPlatformPOS = authPlatformDao.findAll(); + for (AuthPlatformPO authPlatformPO : authPlatformPOS) { + if (!AuthPlatformStatus.available(authPlatformPO.getStatus())) { + continue; + } + if (authPlatformPO.getId().equals(id)) { + authPlatformPO.setStatus(AuthPlatformStatus.ACTIVE.getCode()); + } else { + authPlatformPO.setStatus(AuthPlatformStatus.AVAILABLE.getCode()); + } + } + authPlatformDao.partialUpdateByIds(authPlatformPOS); + } + + private void softDeleteChatMessages(List chatMessagePOS) { + for (ChatMessagePO chatMessagePO : chatMessagePOS) { + chatMessagePO.setIsDeleted(true); + chatMessageDao.partialUpdateById(chatMessagePO); + } + } + + private void softDeleteChatThreads(List chatThreadPOS) { + for (ChatThreadPO chatThreadPO : chatThreadPOS) { + chatThreadPO.setIsDeleted(true); + chatThreadDao.partialUpdateById(chatThreadPO); + List chatMessagePOS = chatMessageDao.findAllByThreadId(chatThreadPO.getId()); + softDeleteChatMessages(chatMessagePOS); + } + } + + private static @NotNull Map getStringMap(AuthPlatformDTO authPlatformDTO, PlatformDTO platformDTO) { + if (platformDTO == null) { + throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_FOUND); + } + Map credentialNeed = platformDTO.getAuthCredentials(); + Map credentialGet = authPlatformDTO.getAuthCredentials(); + Map credentialSet = new HashMap<>(); + for (String key : credentialNeed.keySet()) { + if (!credentialGet.containsKey(key)) { + throw new ApiException(ApiExceptionEnum.CREDIT_INCORRECT); + } + credentialSet.put(key, credentialGet.get(key)); + } + return credentialSet; } }