Skip to content

Commit

Permalink
BIGTOP-4209: Add backend API for AI Chat functionality (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
lhpqaq authored Aug 30, 2024
1 parent 6121e4c commit 492c11d
Show file tree
Hide file tree
Showing 50 changed files with 1,560 additions and 263 deletions.
4 changes: 4 additions & 0 deletions bigtop-manager-ai/bigtop-manager-ai-assistant/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
<groupId>org.apache.bigtop</groupId>
<artifactId>bigtop-manager-ai-core</artifactId>
</dependency>
<dependency>
<groupId>org.apache.bigtop</groupId>
<artifactId>bigtop-manager-dao</artifactId>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,16 @@
import org.apache.bigtop.manager.ai.assistant.provider.LocSystemPromptProvider;
import org.apache.bigtop.manager.ai.core.AbstractAIAssistantFactory;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.enums.SystemPrompt;
import org.apache.bigtop.manager.ai.core.exception.PlatformNotFoundException;
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;
import org.apache.bigtop.manager.ai.core.factory.ToolBox;
import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider;
import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider;
import org.apache.bigtop.manager.ai.openai.OpenAIAssistant;

import org.apache.commons.lang3.NotImplementedException;

import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
Expand All @@ -35,13 +39,19 @@

public class GeneralAssistantFactory extends AbstractAIAssistantFactory {

private SystemPromptProvider systemPromptProvider = new LocSystemPromptProvider();
private ChatMemoryStore chatMemoryStore = new InMemoryChatMemoryStore();
private final SystemPromptProvider systemPromptProvider;
private final ChatMemoryStore chatMemoryStore;

public GeneralAssistantFactory() {}
public GeneralAssistantFactory() {
this(new LocSystemPromptProvider(), new InMemoryChatMemoryStore());
}

public GeneralAssistantFactory(SystemPromptProvider systemPromptProvider) {
this.systemPromptProvider = systemPromptProvider;
this(systemPromptProvider, new InMemoryChatMemoryStore());
}

public GeneralAssistantFactory(ChatMemoryStore chatMemoryStore) {
this(new LocSystemPromptProvider(), chatMemoryStore);
}

public GeneralAssistantFactory(SystemPromptProvider systemPromptProvider, ChatMemoryStore chatMemoryStore) {
Expand All @@ -51,29 +61,33 @@ public GeneralAssistantFactory(SystemPromptProvider systemPromptProvider, ChatMe

@Override
public AIAssistant createWithPrompt(
PlatformType platformType, AIAssistantConfigProvider assistantConfig, Object id, Object promptId) {
AIAssistant aiAssistant = create(platformType, assistantConfig, id);
SystemMessage systemPrompt = systemPromptProvider.getSystemPrompt(promptId);
aiAssistant.setSystemPrompt(systemPrompt);
return aiAssistant;
}

@Override
public AIAssistant create(PlatformType platformType, AIAssistantConfigProvider assistantConfig, Object id) {
PlatformType platformType,
AIAssistantConfigProvider assistantConfig,
Object id,
SystemPrompt systemPrompts) {
AIAssistant aiAssistant;
if (Objects.requireNonNull(platformType) == PlatformType.OPENAI) {
AIAssistant aiAssistant = OpenAIAssistant.builder()
aiAssistant = OpenAIAssistant.builder()
.id(id)
.memoryStore(chatMemoryStore)
.withConfigProvider(assistantConfig)
.build();
aiAssistant.setSystemPrompt(systemPromptProvider.getSystemPrompt());
return aiAssistant;
} else {
throw new PlatformNotFoundException(platformType.getValue());
}
return null;

SystemMessage systemPrompt = systemPromptProvider.getSystemPrompt(systemPrompts);
aiAssistant.setSystemPrompt(systemPrompt);
return aiAssistant;
}

@Override
public AIAssistant create(PlatformType platformType, AIAssistantConfigProvider assistantConfig, Object id) {
return createWithPrompt(platformType, assistantConfig, id, SystemPrompt.DEFAULT_PROMPT);
}

@Override
public ToolBox createToolBox(PlatformType platformType) {
return null;
throw new NotImplementedException("ToolBox is not implemented for GeneralAssistantFactory");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,42 +24,83 @@
import java.util.Map;

public class AIAssistantConfig implements AIAssistantConfigProvider {
private final Map<String, String> configMap;

private AIAssistantConfig(Map<String, String> configMap) {
this.configMap = configMap;
/**
* Model name for platform that we want to use
*/
private final String model;

/**
* Credentials for different platforms
*/
private final Map<String, String> credentials;

/**
* Platform extra configs are put here
*/
private final Map<String, String> configs;

private AIAssistantConfig(String model, Map<String, String> credentials, Map<String, String> configMap) {
this.model = model;
this.credentials = credentials;
this.configs = configMap;
}

public static Builder builder() {
return new Builder();
}

public static Builder withDefault(String baseUrl, String apiKey) {
Builder builder = new Builder();
return builder.set("baseUrl", baseUrl).set("apiKey", apiKey);
@Override
public String getModel() {
return model;
}

@Override
public Map<String, String> configs() {
public Map<String, String> getCredentials() {
return credentials;
}

return configMap;
@Override
public Map<String, String> getConfigs() {
return configs;
}

public static class Builder {
private final Map<String, String> configs;
private String model;

private final Map<String, String> credentials = new HashMap<>();

private final Map<String, String> configs = new HashMap<>();

public Builder() {
configs = new HashMap<>();
configs.put("memoryLen", "30");
public Builder() {}

public Builder setModel(String model) {
this.model = model;
return this;
}

public Builder set(String key, String value) {
public Builder addCredential(String key, String value) {
credentials.put(key, value);
return this;
}

public Builder addCredentials(Map<String, String> credentialMap) {
credentials.putAll(credentialMap);
return this;
}

public Builder addConfig(String key, String value) {
configs.put(key, value);
return this;
}

public Builder addConfigs(Map<String, String> configMap) {
configs.putAll(configMap);
return this;
}

public AIAssistantConfig build() {
return new AIAssistantConfig(configs);
return new AIAssistantConfig(model, credentials, configs);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.bigtop.manager.ai.assistant.provider;

import org.apache.bigtop.manager.ai.core.enums.SystemPrompt;
import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider;

import org.springframework.util.ResourceUtils;
Expand All @@ -29,41 +30,37 @@
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.Objects;

@Slf4j
public class LocSystemPromptProvider implements SystemPromptProvider {

public static final String DEFAULT = "default";
private static final String SYSTEM_PROMPT_PATH = "src/main/resources/";
private static final String DEFAULT_NAME = "big-data-professor.st";

@Override
public SystemMessage getSystemPrompt(Object id) {
if (Objects.equals(id.toString(), DEFAULT)) {
return getSystemPrompt();
} else {
return loadPromptFromFile(id.toString());
public SystemMessage getSystemPrompt(SystemPrompt systemPrompt) {
if (systemPrompt == SystemPrompt.DEFAULT_PROMPT) {
systemPrompt = SystemPrompt.BIGDATA_PROFESSOR;
}

return loadPromptFromFile(systemPrompt.getValue());
}

@Override
public SystemMessage getSystemPrompt() {
return loadPromptFromFile(DEFAULT_NAME);
return getSystemPrompt(SystemPrompt.DEFAULT_PROMPT);
}

private SystemMessage loadPromptFromFile(String fileName) {
final String filePath = SYSTEM_PROMPT_PATH + fileName;
final String filePath = SYSTEM_PROMPT_PATH + fileName + ".st";
try {
File file = ResourceUtils.getFile(filePath);
String text = Files.readString(file.toPath(), StandardCharsets.UTF_8);
return SystemMessage.from(text);
} catch (IOException e) {
//
log.error(
"Exception occurred while loading SystemPrompt from local. Here is some information:{}",
e.getMessage());
return SystemMessage.from("");
return SystemMessage.from("You are a helpful assistant.");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.bigtop.manager.ai.assistant.store;

import org.apache.bigtop.manager.ai.core.enums.MessageSender;
import org.apache.bigtop.manager.dao.po.ChatMessagePO;
import org.apache.bigtop.manager.dao.po.ChatThreadPO;
import org.apache.bigtop.manager.dao.repository.ChatMessageDao;
import org.apache.bigtop.manager.dao.repository.ChatThreadDao;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

public class PersistentChatMemoryStore implements ChatMemoryStore {

private final ChatThreadDao chatThreadDao;
private final ChatMessageDao chatMessageDao;

public PersistentChatMemoryStore(ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) {
this.chatThreadDao = chatThreadDao;
this.chatMessageDao = chatMessageDao;
}

private ChatMessage convertToChatMessage(ChatMessagePO chatMessagePO) {
String sender = chatMessagePO.getSender().toLowerCase();
if (sender.equals(MessageSender.AI.getValue())) {
return new AiMessage(chatMessagePO.getMessage());
} else if (sender.equals(MessageSender.USER.getValue())) {
return new UserMessage(chatMessagePO.getMessage());
} else if (sender.equals(MessageSender.SYSTEM.getValue())) {
return new SystemMessage(chatMessagePO.getMessage());
} else {
return null;
}
}

private ChatMessagePO convertToChatMessagePO(ChatMessage chatMessage, Long chatThreadId) {
ChatMessagePO chatMessagePO = new ChatMessagePO();
if (chatMessage.type().equals(ChatMessageType.AI)) {
chatMessagePO.setSender(MessageSender.AI.getValue());
AiMessage aiMessage = (AiMessage) chatMessage;
chatMessagePO.setMessage(aiMessage.text());
} else if (chatMessage.type().equals(ChatMessageType.USER)) {
chatMessagePO.setSender(MessageSender.USER.getValue());
UserMessage userMessage = (UserMessage) chatMessage;
chatMessagePO.setMessage(userMessage.singleText());
} else if (chatMessage.type().equals(ChatMessageType.SYSTEM)) {
chatMessagePO.setSender(MessageSender.SYSTEM.getValue());
SystemMessage systemMessage = (SystemMessage) chatMessage;
chatMessagePO.setMessage(systemMessage.text());
} else {
chatMessagePO.setSender(chatMessage.type().toString());
}
ChatThreadPO chatThreadPO = chatThreadDao.findById(chatThreadId);
chatMessagePO.setUserId(chatThreadPO.getUserId());
chatMessagePO.setThreadId(chatThreadId);
return chatMessagePO;
}

@Override
public List<ChatMessage> getMessages(Object threadId) {
List<ChatMessagePO> chatMessages = chatMessageDao.findAllByThreadId((Long) threadId);
if (chatMessages.isEmpty()) {
return new ArrayList<>();
} else {
return chatMessages.stream().map(this::convertToChatMessage).collect(Collectors.toList());
}
}

@Override
public void updateMessages(Object threadId, List<ChatMessage> messages) {
ChatMessagePO chatMessagePO = convertToChatMessagePO(messages.get(messages.size() - 1), (Long) threadId);
chatMessageDao.save(chatMessagePO);
}

@Override
public void deleteMessages(Object threadId) {
chatMessageDao.deleteByThreadId((Long) threadId);
}
}
Loading

0 comments on commit 492c11d

Please sign in to comment.