Skip to content

Commit

Permalink
add controller
Browse files Browse the repository at this point in the history
  • Loading branch information
lhpqaq committed Sep 23, 2024
1 parent 1dd7251 commit 4547195
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ public class ChatThreadPO extends BasePO implements Serializable {
@Column(name = "model", nullable = false, length = 255)
private String model;

@Column(name = "name")
private String name;

@Column(name = "thread_info", columnDefinition = "json")
private Map<String, String> threadInfo;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.bigtop.manager.server.model.converter.PlatformConverter;
import org.apache.bigtop.manager.server.model.dto.PlatformDTO;
import org.apache.bigtop.manager.server.model.req.ChatbotMessageReq;
import org.apache.bigtop.manager.server.model.req.ChatbotThreadReq;
import org.apache.bigtop.manager.server.model.req.PlatformReq;
import org.apache.bigtop.manager.server.model.vo.ChatMessageVO;
import org.apache.bigtop.manager.server.model.vo.ChatThreadVO;
Expand Down Expand Up @@ -112,6 +113,19 @@ public SseEmitter talk(
return chatbotService.talk(platformId, threadId, messageReq.getMessage());
}

@Operation(summary = "get name", description = "Get name of the thread")
@GetMapping("platforms/{platformId}/threads/{threadId}/name")
public ResponseEntity<ChatThreadVO> getThreadName(@PathVariable Long platformId, @PathVariable Long threadId) {
return ResponseEntity.success(chatbotService.getThreadName(platformId, threadId));
}

@Operation(summary = "get name", description = "Get name of the thread")
@PostMapping("platforms/{platformId}/threads/{threadId}/name")
public ResponseEntity<Boolean> setThreadName(
@PathVariable Long platformId, @PathVariable Long threadId, @RequestBody ChatbotThreadReq threadReq) {
return ResponseEntity.success(chatbotService.setThreadName(platformId, threadId, threadReq.getNewName()));
}

@Operation(summary = "history", description = "Get chat records")
@GetMapping("platforms/{platformId}/threads/{threadId}/history")
public ResponseEntity<List<ChatMessageVO>> history(@PathVariable Long platformId, @PathVariable Long threadId) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.bigtop.manager.server.model.req;

import lombok.Data;

import jakarta.validation.constraints.NotEmpty;

@Data
public class ChatbotThreadReq {
@NotEmpty
private String newName;
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,11 @@ public class ChatThreadVO {

private String model;

private String name;

private String createTime;

private String updateTime;

public ChatThreadVO(Long threadId, Long platformId, String model, String createTime) {
this.threadId = threadId;
this.platformId = platformId;
this.model = model;
this.createTime = createTime;
this.updateTime = createTime;
}

public ChatThreadVO() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,8 @@ public interface ChatbotService {
SseEmitter talk(Long platformId, Long threadId, String message);

List<ChatMessageVO> history(Long platformId, Long threadId);

ChatThreadVO getThreadName(Long platformId, Long threadId);

boolean setThreadName(Long platformId, Long threadId, String newName);
}
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,23 @@ private AIAssistant buildAIAssistant(
threadId);
}

private AIAssistant getAIAssistant(Long platformId, Long threadId) {
ChatThreadPO chatThreadPO = chatThreadDao.findByThreadId(threadId);
Long userId = SessionUserHolder.getUserId();
if (!Objects.equals(userId, chatThreadPO.getUserId())) {
throw new ApiException(ApiExceptionEnum.CHAT_THREAD_NOT_FOUND);
}
PlatformAuthorizedPO platformAuthorizedPO = platformAuthorizedDao.findByPlatformId(platformId);
if (platformAuthorizedPO == null || !platformAuthorizedPO.getId().equals(chatThreadPO.getPlatformId())) {
throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_AUTHORIZED);
}

PlatformPO platformPO = platformDao.findById(platformAuthorizedPO.getPlatformId());
PlatformAuthorizedDTO platformAuthorizedDTO = new PlatformAuthorizedDTO(
platformPO.getName(), platformAuthorizedPO.getCredentials(), chatThreadPO.getModel());
return buildAIAssistant(platformAuthorizedDTO, chatThreadPO.getId(), chatThreadPO.getThreadInfo());
}

private Boolean testAuthorization(PlatformAuthorizedDTO platformAuthorizedDTO) {
AIAssistant aiAssistant = getAiAssistantFactory()
.create(
Expand Down Expand Up @@ -289,21 +306,7 @@ public List<ChatThreadVO> getAllChatThreads(Long platformId, String model) {

@Override
public SseEmitter talk(Long platformId, Long threadId, String message) {
ChatThreadPO chatThreadPO = chatThreadDao.findByThreadId(threadId);
Long userId = SessionUserHolder.getUserId();
if (!Objects.equals(userId, chatThreadPO.getUserId())) {
throw new ApiException(ApiExceptionEnum.CHAT_THREAD_NOT_FOUND);
}
PlatformAuthorizedPO platformAuthorizedPO = platformAuthorizedDao.findByPlatformId(platformId);
if (platformAuthorizedPO == null) {
throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_AUTHORIZED);
}

PlatformPO platformPO = platformDao.findById(platformAuthorizedPO.getPlatformId());
PlatformAuthorizedDTO platformAuthorizedDTO = new PlatformAuthorizedDTO(
platformPO.getName(), platformAuthorizedPO.getCredentials(), chatThreadPO.getModel());
AIAssistant aiAssistant =
buildAIAssistant(platformAuthorizedDTO, chatThreadPO.getId(), chatThreadPO.getThreadInfo());
AIAssistant aiAssistant = getAIAssistant(platformId, threadId);
Flux<String> stringFlux = aiAssistant.streamAsk(message);

SseEmitter emitter = new SseEmitter();
Expand Down Expand Up @@ -346,4 +349,33 @@ public List<ChatMessageVO> history(Long platformId, Long threadId) {
}
return chatMessages;
}

@Override
public ChatThreadVO getThreadName(Long platformId, Long threadId) {
AIAssistant aiAssistant = getAIAssistant(platformId, threadId);
ChatThreadPO chatThreadPO = chatThreadDao.findByThreadId(threadId);
String threadName = aiAssistant.getThreadName();
if (threadName == null) {
threadName = "Unnamed thread " + threadId;
}
chatThreadPO.setName(threadName);
chatThreadDao.partialUpdateById(chatThreadPO);
return ChatThreadConverter.INSTANCE.fromPO2VO(chatThreadPO);
}

@Override
public boolean setThreadName(Long platformId, Long threadId, String newName) {
ChatThreadPO chatThreadPO = chatThreadDao.findByThreadId(threadId);
Long userId = SessionUserHolder.getUserId();
if (!Objects.equals(userId, chatThreadPO.getUserId())) {
throw new ApiException(ApiExceptionEnum.CHAT_THREAD_NOT_FOUND);
}
PlatformAuthorizedPO platformAuthorizedPO = platformAuthorizedDao.findByPlatformId(platformId);
if (platformAuthorizedPO == null || !platformAuthorizedPO.getId().equals(chatThreadPO.getPlatformId())) {
throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_AUTHORIZED);
}
chatThreadPO.setName(newName);
chatThreadDao.partialUpdateById(chatThreadPO);
return true;
}
}

0 comments on commit 4547195

Please sign in to comment.