Skip to content

Commit

Permalink
BIGTOP-4260: Add chatbot command and tools (#101)
Browse files Browse the repository at this point in the history
  • Loading branch information
lhpqaq authored Dec 11, 2024
1 parent da37120 commit 70dbb6c
Show file tree
Hide file tree
Showing 9 changed files with 257 additions and 23 deletions.
5 changes: 0 additions & 5 deletions bigtop-manager-ai/bigtop-manager-ai-dashscope/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,5 @@
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-dashscope</artifactId>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-dashscope</artifactId>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.bigtop.manager.server.controller;

import org.apache.bigtop.manager.server.enums.ChatbotCommand;
import org.apache.bigtop.manager.server.model.converter.ChatThreadConverter;
import org.apache.bigtop.manager.server.model.dto.ChatThreadDTO;
import org.apache.bigtop.manager.server.model.req.ChatbotMessageReq;
Expand Down Expand Up @@ -88,12 +89,24 @@ public ResponseEntity<List<ChatThreadVO>> getAllChatThreads() {
@Operation(summary = "talk", description = "Talk with Chatbot")
@PostMapping("/threads/{threadId}/talk")
public SseEmitter talk(@PathVariable Long threadId, @RequestBody ChatbotMessageReq messageReq) {
return chatbotService.talk(threadId, messageReq.getMessage());
ChatbotCommand command = ChatbotCommand.getCommandFromMessage(messageReq.getMessage());
if (command != null) {
messageReq.setMessage(
messageReq.getMessage().substring(command.getCmd().length() + 2));
return chatbotService.talk(threadId, command, messageReq.getMessage());
}
return chatbotService.talk(threadId, null, messageReq.getMessage());
}

@Operation(summary = "history", description = "Get chat records")
@GetMapping("/threads/{threadId}/history")
public ResponseEntity<List<ChatMessageVO>> history(@PathVariable Long threadId) {
return ResponseEntity.success(chatbotService.history(threadId));
}

@Operation(summary = "get commands", description = "Get all commands")
@GetMapping("/commands")
public ResponseEntity<List<String>> getCommands() {
return ResponseEntity.success(chatbotService.getChatbotCommands());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.bigtop.manager.server.enums;

import lombok.Getter;

import java.util.ArrayList;
import java.util.List;

@Getter
public enum ChatbotCommand {
INFO("info"),
SEARCH("search"),
HELP("help");

private final String cmd;

ChatbotCommand(String cmd) {
this.cmd = cmd;
}

public static List<String> getAllCommands() {
List<String> commands = new ArrayList<>();
for (ChatbotCommand command : ChatbotCommand.values()) {
commands.add(command.cmd);
}
return commands;
}

public static ChatbotCommand getCommand(String cmd) {
for (ChatbotCommand command : ChatbotCommand.values()) {
if (command.cmd.equals(cmd)) {
return command;
}
}
return null;
}

public static ChatbotCommand getCommandFromMessage(String message) {
if (message.startsWith("/")) {
String[] parts = message.split(" ");
return getCommand(parts[0].substring(1));
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.bigtop.manager.server.service;

import org.apache.bigtop.manager.server.enums.ChatbotCommand;
import org.apache.bigtop.manager.server.model.dto.ChatThreadDTO;
import org.apache.bigtop.manager.server.model.vo.ChatMessageVO;
import org.apache.bigtop.manager.server.model.vo.ChatThreadVO;
Expand All @@ -34,11 +35,13 @@ public interface ChatbotService {

List<ChatThreadVO> getAllChatThreads();

SseEmitter talk(Long threadId, String message);
SseEmitter talk(Long threadId, ChatbotCommand command, String message);

List<ChatMessageVO> history(Long threadId);

ChatThreadVO updateChatThread(ChatThreadDTO chatThreadDTO);

List<String> getChatbotCommands();

ChatThreadVO getChatThread(Long threadId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.bigtop.manager.dao.repository.PlatformDao;
import org.apache.bigtop.manager.server.enums.ApiExceptionEnum;
import org.apache.bigtop.manager.server.enums.AuthPlatformStatus;
import org.apache.bigtop.manager.server.enums.ChatbotCommand;
import org.apache.bigtop.manager.server.exception.ApiException;
import org.apache.bigtop.manager.server.holder.SessionUserHolder;
import org.apache.bigtop.manager.server.model.converter.AuthPlatformConverter;
Expand All @@ -46,6 +47,7 @@
import org.apache.bigtop.manager.server.model.vo.ChatThreadVO;
import org.apache.bigtop.manager.server.model.vo.TalkVO;
import org.apache.bigtop.manager.server.service.ChatbotService;
import org.apache.bigtop.manager.server.tools.AiServiceToolsProvider;

import org.springframework.context.i18n.LocaleContextHolder;
import org.springframework.stereotype.Service;
Expand Down Expand Up @@ -118,10 +120,19 @@ private PlatformType getPlatformType(String platformName) {
}

private AIAssistant buildAIAssistant(
String platformName, String model, Map<String, String> credentials, Long threadId) {
return getAIAssistantFactory()
.createAiService(
getPlatformType(platformName), getAIAssistantConfig(model, credentials), threadId, null);
String platformName, String model, Map<String, String> credentials, Long threadId, ChatbotCommand command) {
if (command == null) {
return getAIAssistantFactory()
.createAiService(
getPlatformType(platformName), getAIAssistantConfig(model, credentials), threadId, null);
} else {
return getAIAssistantFactory()
.createAiService(
getPlatformType(platformName),
getAIAssistantConfig(model, credentials),
threadId,
new AiServiceToolsProvider(command));
}
}

@Override
Expand Down Expand Up @@ -182,7 +193,7 @@ public List<ChatThreadVO> getAllChatThreads() {
return chatThreads;
}

private AIAssistant prepareTalk(Long threadId) {
private AIAssistant prepareTalk(Long threadId, ChatbotCommand command, String message) {
ChatThreadPO chatThreadPO = chatThreadDao.findById(threadId);
Long userId = SessionUserHolder.getUserId();
if (!Objects.equals(userId, chatThreadPO.getUserId()) || chatThreadPO.getIsDeleted()) {
Expand All @@ -199,14 +210,22 @@ private AIAssistant prepareTalk(Long threadId) {

PlatformPO platformPO = platformDao.findById(authPlatformPO.getPlatformId());
return buildAIAssistant(
platformPO.getName(), authPlatformDTO.getModel(), authPlatformDTO.getAuthCredentials(), threadId);
platformPO.getName(),
authPlatformDTO.getModel(),
authPlatformDTO.getAuthCredentials(),
threadId,
command);
}

@Override
public SseEmitter talk(Long threadId, String message) {
AIAssistant aiAssistant = prepareTalk(threadId);
Flux<String> stringFlux = aiAssistant.streamAsk(message);

public SseEmitter talk(Long threadId, ChatbotCommand command, String message) {
AIAssistant aiAssistant = prepareTalk(threadId, command, message);
Flux<String> stringFlux;
if (command == null) {
stringFlux = aiAssistant.streamAsk(message);
} else {
stringFlux = Flux.just(aiAssistant.ask(message));
}
SseEmitter emitter = new SseEmitter();
stringFlux.subscribe(
s -> {
Expand Down Expand Up @@ -290,6 +309,11 @@ public ChatThreadVO updateChatThread(ChatThreadDTO chatThreadDTO) {
chatThreadPO, authPlatformPO, platformDao.findById(authPlatformPO.getPlatformId()));
}

@Override
public List<String> getChatbotCommands() {
return ChatbotCommand.getAllCommands();
}

@Override
public ChatThreadVO getChatThread(Long threadId) {
ChatThreadPO chatThreadPO = chatThreadDao.findById(threadId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;
import org.apache.bigtop.manager.ai.core.factory.AIAssistantFactory;
import org.apache.bigtop.manager.common.utils.JsonUtils;
import org.apache.bigtop.manager.dao.po.AuthPlatformPO;
import org.apache.bigtop.manager.dao.po.ChatMessagePO;
import org.apache.bigtop.manager.dao.po.ChatThreadPO;
Expand All @@ -48,6 +49,11 @@
import org.springframework.context.i18n.LocaleContextHolder;
import org.springframework.stereotype.Service;

import dev.langchain4j.agent.tool.JsonSchemaProperty;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.service.tool.ToolExecutor;
import dev.langchain4j.service.tool.ToolProvider;
import dev.langchain4j.service.tool.ToolProviderResult;
import lombok.extern.slf4j.Slf4j;

import jakarta.annotation.Resource;
Expand All @@ -73,6 +79,9 @@ public class LLMConfigServiceImpl implements LLMConfigService {

private AIAssistantFactory aiAssistantFactory;

private static final String TEST_FLAG = "ZmxhZw==";
private static final String TEST_KEY = "bm";

public AIAssistantFactory getAIAssistantFactory() {
if (aiAssistantFactory == null) {
aiAssistantFactory =
Expand All @@ -96,11 +105,9 @@ private PlatformType getPlatformType(String platformName) {
}

private Boolean testAuthorization(String platformName, String model, Map<String, String> credentials) {
AIAssistantConfig aiAssistantConfig = AIAssistantConfig.builder()
.setModel(model)
.setLanguage(LocaleContextHolder.getLocale().toString())
.addCredentials(credentials)
.build();
Boolean result = testFuncCalling(platformName, model, credentials);
log.info("Test func calling result: {}", result);
AIAssistantConfig aiAssistantConfig = getAIAssistantConfig(model, credentials, null);
AIAssistant aiAssistant = getAIAssistantFactory().create(getPlatformType(platformName), aiAssistantConfig);
try {
return aiAssistant.test();
Expand All @@ -109,6 +116,37 @@ private Boolean testAuthorization(String platformName, String model, Map<String,
}
}

private Boolean testFuncCalling(String platformName, String model, Map<String, String> credentials) {
ToolProvider toolProvider = (toolProviderRequest) -> {
ToolSpecification toolSpecification = ToolSpecification.builder()
.name("getFlag")
.description("Get flag based on key")
.addParameter("key", JsonSchemaProperty.STRING, JsonSchemaProperty.description("Lowercase key"))
.build();
ToolExecutor toolExecutor = (toolExecutionRequest, memoryId) -> {
Map<String, Object> arguments = JsonUtils.readFromString(toolExecutionRequest.arguments());
String key = arguments.get("key").toString();
if (key.equals(TEST_KEY)) {
return TEST_FLAG;
}
return null;
};

return ToolProviderResult.builder()
.add(toolSpecification, toolExecutor)
.build();
};

AIAssistantConfig aiAssistantConfig = getAIAssistantConfig(model, credentials, null);
AIAssistant aiAssistant = getAIAssistantFactory()
.createAiService(getPlatformType(platformName), aiAssistantConfig, null, toolProvider);
try {
return aiAssistant.ask("What is the flag of " + TEST_KEY).contains(TEST_FLAG);
} catch (Exception e) {
throw new ApiException(ApiExceptionEnum.CREDIT_INCORRECT, e.getMessage());
}
}

private void switchActivePlatform(Long id) {
List<AuthPlatformPO> authPlatformPOS = authPlatformDao.findAll();
for (AuthPlatformPO authPlatformPO : authPlatformPOS) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.bigtop.manager.server.tools;

import org.apache.bigtop.manager.server.enums.ChatbotCommand;

import dev.langchain4j.service.tool.ToolProvider;
import dev.langchain4j.service.tool.ToolProviderRequest;
import dev.langchain4j.service.tool.ToolProviderResult;

public class AiServiceToolsProvider implements ToolProvider {

ChatbotCommand chatbotCommand;

public AiServiceToolsProvider(ChatbotCommand chatbotCommand) {
this.chatbotCommand = chatbotCommand;
}

public AiServiceToolsProvider() {
this.chatbotCommand = null;
}

@Override
public ToolProviderResult provideTools(ToolProviderRequest toolProviderRequest) {
if (chatbotCommand.equals(ChatbotCommand.INFO)) {
ClusterInfoTools clusterInfoTools = new ClusterInfoTools();
return ToolProviderResult.builder().addAll(clusterInfoTools.list()).build();
}
return null;
}
}
Loading

0 comments on commit 70dbb6c

Please sign in to comment.