Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BIGTOP-4260: Add chatbot command and tools #101

Merged
merged 28 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions bigtop-manager-ai/bigtop-manager-ai-dashscope/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,5 @@
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-dashscope</artifactId>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-dashscope</artifactId>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.bigtop.manager.server.controller;

import org.apache.bigtop.manager.server.enums.ChatbotCommand;
import org.apache.bigtop.manager.server.model.converter.ChatThreadConverter;
import org.apache.bigtop.manager.server.model.dto.ChatThreadDTO;
import org.apache.bigtop.manager.server.model.req.ChatbotMessageReq;
Expand Down Expand Up @@ -88,12 +89,24 @@ public ResponseEntity<List<ChatThreadVO>> getAllChatThreads() {
@Operation(summary = "talk", description = "Talk with Chatbot")
@PostMapping("/threads/{threadId}/talk")
public SseEmitter talk(@PathVariable Long threadId, @RequestBody ChatbotMessageReq messageReq) {
return chatbotService.talk(threadId, messageReq.getMessage());
ChatbotCommand command = ChatbotCommand.getCommandFromMessage(messageReq.getMessage());
if (command != null) {
messageReq.setMessage(
messageReq.getMessage().substring(command.getCmd().length() + 2));
return chatbotService.talk(threadId, command, messageReq.getMessage());
}
return chatbotService.talk(threadId, null, messageReq.getMessage());
}

@Operation(summary = "history", description = "Get chat records")
@GetMapping("/threads/{threadId}/history")
public ResponseEntity<List<ChatMessageVO>> history(@PathVariable Long threadId) {
return ResponseEntity.success(chatbotService.history(threadId));
}

@Operation(summary = "get commands", description = "Get all commands")
@GetMapping("/commands")
public ResponseEntity<List<String>> getCommands() {
return ResponseEntity.success(chatbotService.getChatbotCommands());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.bigtop.manager.server.enums;

import lombok.Getter;

import java.util.ArrayList;
import java.util.List;

@Getter
public enum ChatbotCommand {
INFO("info"),
SEARCH("search"),
HELP("help");

private final String cmd;

ChatbotCommand(String cmd) {
this.cmd = cmd;
}

public static List<String> getAllCommands() {
List<String> commands = new ArrayList<>();
for (ChatbotCommand command : ChatbotCommand.values()) {
commands.add(command.cmd);
}
return commands;
}

public static ChatbotCommand getCommand(String cmd) {
for (ChatbotCommand command : ChatbotCommand.values()) {
if (command.cmd.equals(cmd)) {
return command;
}
}
return null;
}

public static ChatbotCommand getCommandFromMessage(String message) {
if (message.startsWith("/")) {
String[] parts = message.split(" ");
return getCommand(parts[0].substring(1));
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.bigtop.manager.server.service;

import org.apache.bigtop.manager.server.enums.ChatbotCommand;
import org.apache.bigtop.manager.server.model.dto.ChatThreadDTO;
import org.apache.bigtop.manager.server.model.vo.ChatMessageVO;
import org.apache.bigtop.manager.server.model.vo.ChatThreadVO;
Expand All @@ -34,11 +35,13 @@ public interface ChatbotService {

List<ChatThreadVO> getAllChatThreads();

SseEmitter talk(Long threadId, String message);
SseEmitter talk(Long threadId, ChatbotCommand command, String message);

List<ChatMessageVO> history(Long threadId);

ChatThreadVO updateChatThread(ChatThreadDTO chatThreadDTO);

List<String> getChatbotCommands();

ChatThreadVO getChatThread(Long threadId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.bigtop.manager.dao.repository.PlatformDao;
import org.apache.bigtop.manager.server.enums.ApiExceptionEnum;
import org.apache.bigtop.manager.server.enums.AuthPlatformStatus;
import org.apache.bigtop.manager.server.enums.ChatbotCommand;
import org.apache.bigtop.manager.server.exception.ApiException;
import org.apache.bigtop.manager.server.holder.SessionUserHolder;
import org.apache.bigtop.manager.server.model.converter.AuthPlatformConverter;
Expand All @@ -46,6 +47,7 @@
import org.apache.bigtop.manager.server.model.vo.ChatThreadVO;
import org.apache.bigtop.manager.server.model.vo.TalkVO;
import org.apache.bigtop.manager.server.service.ChatbotService;
import org.apache.bigtop.manager.server.tools.AiServiceToolsProvider;

import org.springframework.context.i18n.LocaleContextHolder;
import org.springframework.stereotype.Service;
Expand Down Expand Up @@ -118,10 +120,19 @@ private PlatformType getPlatformType(String platformName) {
}

private AIAssistant buildAIAssistant(
String platformName, String model, Map<String, String> credentials, Long threadId) {
return getAIAssistantFactory()
.createAiService(
getPlatformType(platformName), getAIAssistantConfig(model, credentials), threadId, null);
String platformName, String model, Map<String, String> credentials, Long threadId, ChatbotCommand command) {
if (command == null) {
return getAIAssistantFactory()
.createAiService(
getPlatformType(platformName), getAIAssistantConfig(model, credentials), threadId, null);
} else {
return getAIAssistantFactory()
.createAiService(
getPlatformType(platformName),
getAIAssistantConfig(model, credentials),
threadId,
new AiServiceToolsProvider(command));
}
}

@Override
Expand Down Expand Up @@ -182,7 +193,7 @@ public List<ChatThreadVO> getAllChatThreads() {
return chatThreads;
}

private AIAssistant prepareTalk(Long threadId) {
private AIAssistant prepareTalk(Long threadId, ChatbotCommand command, String message) {
ChatThreadPO chatThreadPO = chatThreadDao.findById(threadId);
Long userId = SessionUserHolder.getUserId();
if (!Objects.equals(userId, chatThreadPO.getUserId()) || chatThreadPO.getIsDeleted()) {
Expand All @@ -199,14 +210,22 @@ private AIAssistant prepareTalk(Long threadId) {

PlatformPO platformPO = platformDao.findById(authPlatformPO.getPlatformId());
return buildAIAssistant(
platformPO.getName(), authPlatformDTO.getModel(), authPlatformDTO.getAuthCredentials(), threadId);
platformPO.getName(),
authPlatformDTO.getModel(),
authPlatformDTO.getAuthCredentials(),
threadId,
command);
}

@Override
public SseEmitter talk(Long threadId, String message) {
AIAssistant aiAssistant = prepareTalk(threadId);
Flux<String> stringFlux = aiAssistant.streamAsk(message);

public SseEmitter talk(Long threadId, ChatbotCommand command, String message) {
AIAssistant aiAssistant = prepareTalk(threadId, command, message);
Flux<String> stringFlux;
if (command == null) {
stringFlux = aiAssistant.streamAsk(message);
} else {
stringFlux = Flux.just(aiAssistant.ask(message));
}
SseEmitter emitter = new SseEmitter();
stringFlux.subscribe(
s -> {
Expand Down Expand Up @@ -290,6 +309,11 @@ public ChatThreadVO updateChatThread(ChatThreadDTO chatThreadDTO) {
chatThreadPO, authPlatformPO, platformDao.findById(authPlatformPO.getPlatformId()));
}

@Override
public List<String> getChatbotCommands() {
return ChatbotCommand.getAllCommands();
}

@Override
public ChatThreadVO getChatThread(Long threadId) {
ChatThreadPO chatThreadPO = chatThreadDao.findById(threadId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;
import org.apache.bigtop.manager.ai.core.factory.AIAssistantFactory;
import org.apache.bigtop.manager.common.utils.JsonUtils;
import org.apache.bigtop.manager.dao.po.AuthPlatformPO;
import org.apache.bigtop.manager.dao.po.ChatMessagePO;
import org.apache.bigtop.manager.dao.po.ChatThreadPO;
Expand All @@ -48,6 +49,11 @@
import org.springframework.context.i18n.LocaleContextHolder;
import org.springframework.stereotype.Service;

import dev.langchain4j.agent.tool.JsonSchemaProperty;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.service.tool.ToolExecutor;
import dev.langchain4j.service.tool.ToolProvider;
import dev.langchain4j.service.tool.ToolProviderResult;
import lombok.extern.slf4j.Slf4j;

import jakarta.annotation.Resource;
Expand All @@ -73,6 +79,9 @@ public class LLMConfigServiceImpl implements LLMConfigService {

private AIAssistantFactory aiAssistantFactory;

private static final String TEST_FLAG = "ZmxhZw==";
private static final String TEST_KEY = "bm";

public AIAssistantFactory getAIAssistantFactory() {
if (aiAssistantFactory == null) {
aiAssistantFactory =
Expand All @@ -96,11 +105,9 @@ private PlatformType getPlatformType(String platformName) {
}

private Boolean testAuthorization(String platformName, String model, Map<String, String> credentials) {
AIAssistantConfig aiAssistantConfig = AIAssistantConfig.builder()
.setModel(model)
.setLanguage(LocaleContextHolder.getLocale().toString())
.addCredentials(credentials)
.build();
Boolean result = testFuncCalling(platformName, model, credentials);
log.info("Test func calling result: {}", result);
AIAssistantConfig aiAssistantConfig = getAIAssistantConfig(model, credentials, null);
AIAssistant aiAssistant = getAIAssistantFactory().create(getPlatformType(platformName), aiAssistantConfig);
try {
return aiAssistant.test();
Expand All @@ -109,6 +116,37 @@ private Boolean testAuthorization(String platformName, String model, Map<String,
}
}

private Boolean testFuncCalling(String platformName, String model, Map<String, String> credentials) {
ToolProvider toolProvider = (toolProviderRequest) -> {
ToolSpecification toolSpecification = ToolSpecification.builder()
.name("getFlag")
.description("Get flag based on key")
.addParameter("key", JsonSchemaProperty.STRING, JsonSchemaProperty.description("Lowercase key"))
.build();
ToolExecutor toolExecutor = (toolExecutionRequest, memoryId) -> {
Map<String, Object> arguments = JsonUtils.readFromString(toolExecutionRequest.arguments());
String key = arguments.get("key").toString();
if (key.equals(TEST_KEY)) {
return TEST_FLAG;
}
return null;
};

return ToolProviderResult.builder()
.add(toolSpecification, toolExecutor)
.build();
};

AIAssistantConfig aiAssistantConfig = getAIAssistantConfig(model, credentials, null);
AIAssistant aiAssistant = getAIAssistantFactory()
.createAiService(getPlatformType(platformName), aiAssistantConfig, null, toolProvider);
try {
return aiAssistant.ask("What is the flag of " + TEST_KEY).contains(TEST_FLAG);
} catch (Exception e) {
throw new ApiException(ApiExceptionEnum.CREDIT_INCORRECT, e.getMessage());
}
}

private void switchActivePlatform(Long id) {
List<AuthPlatformPO> authPlatformPOS = authPlatformDao.findAll();
for (AuthPlatformPO authPlatformPO : authPlatformPOS) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.bigtop.manager.server.tools;

import org.apache.bigtop.manager.server.enums.ChatbotCommand;

import dev.langchain4j.service.tool.ToolProvider;
import dev.langchain4j.service.tool.ToolProviderRequest;
import dev.langchain4j.service.tool.ToolProviderResult;

public class AiServiceToolsProvider implements ToolProvider {

ChatbotCommand chatbotCommand;

public AiServiceToolsProvider(ChatbotCommand chatbotCommand) {
this.chatbotCommand = chatbotCommand;
}

public AiServiceToolsProvider() {
this.chatbotCommand = null;
}

@Override
public ToolProviderResult provideTools(ToolProviderRequest toolProviderRequest) {
if (chatbotCommand.equals(ChatbotCommand.INFO)) {
ClusterInfoTools clusterInfoTools = new ClusterInfoTools();
return ToolProviderResult.builder().addAll(clusterInfoTools.list()).build();
}
return null;
}
}
Loading
Loading