Skip to content

Commit

Permalink
add prepareTalk
Browse files Browse the repository at this point in the history
  • Loading branch information
lhpqaq committed Nov 13, 2024
1 parent ed826da commit f17e2c3
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

@Slf4j
public class AiServiceChatMemoryStore extends PersistentChatMemoryStore {
private final Map<Object, List<ChatMessage>> messagesByMemoryId = new ConcurrentHashMap();
private final Map<Object, List<ChatMessage>> messagesByMemoryId = new ConcurrentHashMap<>();

public AiServiceChatMemoryStore(ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) {
super(chatThreadDao, chatMessageDao);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,10 @@ public ResponseEntity<List<ChatThreadVO>> getAllChatThreads() {
@Operation(summary = "talk", description = "Talk with Chatbot")
@PostMapping("/threads/{threadId}/talk")
public SseEmitter talk(@PathVariable Long threadId, @RequestBody ChatbotMessageReq messageReq) {
ChatbotCommand command;
if (messageReq.getCommand() == null) {
command = ChatbotCommand.getCommandFromMessage(messageReq.getMessage());
if (command != null) {
messageReq.setMessage(
messageReq.getMessage().substring(command.getCmd().length() + 2));
}
} else {
command = ChatbotCommand.getCommand(messageReq.getCommand());
ChatbotCommand command = ChatbotCommand.getCommandFromMessage(messageReq.getMessage());
if (command != null) {
messageReq.setMessage(
messageReq.getMessage().substring(command.getCmd().length() + 2));
}
if (command != null) {
return chatbotService.talkWithTools(threadId, command, messageReq.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,4 @@
public class ChatbotMessageReq {
@NotEmpty
private String message;

private String command;
}
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,7 @@ public List<ChatThreadVO> getAllChatThreads() {
return chatThreads;
}

@Override
public SseEmitter talk(Long threadId, String message) {
private AIAssistant prepareTalk(Long threadId, ChatbotCommand command, String message) {
ChatThreadPO chatThreadPO = chatThreadDao.findById(threadId);
Long userId = SessionUserHolder.getUserId();
if (!Objects.equals(userId, chatThreadPO.getUserId()) || chatThreadPO.getIsDeleted()) {
Expand All @@ -218,21 +217,21 @@ public SseEmitter talk(Long threadId, String message) {
throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_IN_USE);
}

if (chatThreadPO.getName() == null) {
chatThreadPO.setName(getNameFromMessage(message));
chatThreadDao.partialUpdateById(chatThreadPO);
}

AuthPlatformDTO authPlatformDTO = AuthPlatformConverter.INSTANCE.fromPO2DTO(authPlatformPO);
ChatThreadDTO chatThreadDTO = ChatThreadConverter.INSTANCE.fromPO2DTO(chatThreadPO);
PlatformPO platformPO = platformDao.findById(authPlatformPO.getPlatformId());
AIAssistant aiAssistant = buildAIAssistant(
return buildAIAssistant(
platformPO.getName(),
authPlatformDTO.getModel(),
authPlatformDTO.getAuthCredentials(),
chatThreadPO.getId(),
threadId,
chatThreadDTO.getThreadInfo(),
null);
command);
}

@Override
public SseEmitter talk(Long threadId, String message) {
AIAssistant aiAssistant = prepareTalk(threadId, null, message);
Flux<String> stringFlux = aiAssistant.streamAsk(message);

SseEmitter emitter = new SseEmitter();
Expand All @@ -253,42 +252,16 @@ public SseEmitter talk(Long threadId, String message) {

@Override
public SseEmitter talkWithTools(Long threadId, ChatbotCommand command, String message) {
ChatThreadPO chatThreadPO = chatThreadDao.findById(threadId);
Long userId = SessionUserHolder.getUserId();
if (!Objects.equals(userId, chatThreadPO.getUserId()) || chatThreadPO.getIsDeleted()) {
throw new ApiException(ApiExceptionEnum.CHAT_THREAD_NOT_FOUND);
}
AuthPlatformPO authPlatformPO = getActiveAuthPlatform();
if (authPlatformPO == null
|| authPlatformPO.getIsDeleted()
|| !authPlatformPO.getId().equals(chatThreadPO.getAuthId())) {
throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_IN_USE);
}
AIAssistant aiAssistant = prepareTalk(threadId, command, message);

AuthPlatformDTO authPlatformDTO = AuthPlatformConverter.INSTANCE.fromPO2DTO(authPlatformPO);
ChatThreadDTO chatThreadDTO = ChatThreadConverter.INSTANCE.fromPO2DTO(chatThreadPO);
PlatformPO platformPO = platformDao.findById(authPlatformPO.getPlatformId());
AIAssistant aiAssistant = buildAIAssistant(
platformPO.getName(),
authPlatformDTO.getModel(),
authPlatformDTO.getAuthCredentials(),
chatThreadDTO.getId(),
chatThreadDTO.getThreadInfo(),
command);

log.info("message: {}", message);
String result = aiAssistant.ask(message);
SseEmitter emitter = new SseEmitter(30_000L);
SseEmitter emitter = new SseEmitter();
try {
emitter.send(result);
emitter.complete();
} catch (Exception e) {
emitter.completeWithError(e);
}

emitter.onCompletion(() -> {
System.out.println("Data has been sent, performing post-send actions.");
});
return emitter;
}

Expand Down

0 comments on commit f17e2c3

Please sign in to comment.