Skip to content

Commit

Permalink
Threads are assistant local now.
Browse files Browse the repository at this point in the history
  • Loading branch information
jschm42 committed Jun 7, 2024
1 parent 785f90c commit 3186c92
Show file tree
Hide file tree
Showing 13 changed files with 85 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,19 +125,20 @@ public GenerateImageResponse generateImage(@RequestBody GenerateImageRequest gen
return assistantService.generateImage(generateImageRequest.prompt());
}

@GetMapping("/threads")
public List<ThreadDto> listThreads() {
return assistantService.retrieveThreads();
@GetMapping("/threads/{assistantId}")
public List<ThreadDto> listThreads(@PathVariable("assistantId") String assistantId) {
return assistantService.retrieveThreads(assistantId);
}

@PostMapping("/threads")
public ThreadDto createThread() {
return assistantService.createThread();
@PostMapping("/threads/{assistantId}")
public ThreadDto createThread(@PathVariable("assistantId") String assistantId) {
return assistantService.createThread(assistantId);
}

@GetMapping("/threads/{threadId}")
public ThreadDto retrieveThread(@PathVariable("threadId") String threadId) {
return assistantService.retrieveThread(threadId);
@GetMapping("/threads/{assistantId}/{threadId}")
public ThreadDto retrieveThread(@PathVariable("assistantId") String assistantId,
@PathVariable("threadId") String threadId) {
return assistantService.retrieveThread(assistantId, threadId);
}

@GetMapping("/threads/{threadId}/messages")
Expand All @@ -162,7 +163,7 @@ public ParsedMessageDto postProcessMessage(@PathVariable("threadId") String thre
}


@GetMapping("/threads/{threadId}/{filename}")
@GetMapping("/threads/images/{threadId}/{filename}")
public ResponseEntity<byte[]> getImage(@PathVariable String threadId,
@PathVariable String filename) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ public class AssistantEntity {
@OneToMany(mappedBy = "assistant", fetch = FetchType.LAZY)
private List<MemoryDocument> memoryDocuments = new ArrayList<>();

@OneToMany(mappedBy = "assistant", fetch = FetchType.LAZY)
private List<ThreadEntity> threads = new ArrayList<>();

@Column(name = "memory", length = 20, nullable = false)
private String memory;

Expand Down Expand Up @@ -165,5 +168,12 @@ public void setMemoryDocuments(List<MemoryDocument> memoryDocuments) {
this.memoryDocuments = memoryDocuments;
}

public List<ThreadEntity> getThreads() {
return threads;
}

public void setThreads(List<ThreadEntity> threads) {
this.threads = threads;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import jakarta.persistence.ManyToOne;
import jakarta.persistence.Table;
import java.util.Date;

Expand All @@ -36,6 +37,9 @@ public class ThreadEntity {
@Column(name = "created_at", nullable = false)
private Date createdAt;

@ManyToOne
private AssistantEntity assistant;

// Standard getters and setters
public String getId() {
return id;
Expand All @@ -60,4 +64,12 @@ public Date getCreatedAt() {
public void setCreatedAt(Date createdAt) {
this.createdAt = createdAt;
}

public AssistantEntity getAssistant() {
return assistant;
}

public void setAssistant(AssistantEntity assistant) {
this.assistant = assistant;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

public record ThreadDto(String id,
String title,
Date createdAt) {
Date createdAt,
AssistantDto assistant) {

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023 Jean Schmitz.
* Copyright (c) 2023-2024 Jean Schmitz.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,9 +17,17 @@
package com.talkforgeai.backend.assistant.repository;

import com.talkforgeai.backend.assistant.domain.ThreadEntity;
import java.util.List;
import java.util.Optional;
import org.springframework.data.domain.Sort;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;

@Repository
public interface ThreadRepository extends JpaRepository<ThreadEntity, String> {

List<ThreadEntity> findAllByAssistantId(String assistantId, Sort sortOrder);

Optional<ThreadEntity> findByIdAndAssistantId(String id, String assistantId);

}
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public Map<String, AssistantPropertyValue> mapProperties(Map<String, String> pro

public ThreadDto toDto(ThreadEntity threadEntity) {
return new ThreadDto(threadEntity.getId(), threadEntity.getTitle(),
threadEntity.getCreatedAt());
threadEntity.getCreatedAt(), toDto(threadEntity.getAssistant()));
}

public ThreadEntity toEntity(ThreadDto threadDto) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,18 +228,22 @@ public boolean doesAssistantExistByName(String assistantName) {
}

@Transactional
public ThreadDto createThread() {
public ThreadDto createThread(String assistantId) {
ThreadEntity threadEntity = new ThreadEntity();
threadEntity.setId(UniqueIdUtil.generateThreadId());
threadEntity.setAssistant(assistantRepository.findById(assistantId)
.orElseThrow(() -> new AssistentException("Assistant not found")));
threadEntity.setTitle("<no title>");
threadEntity.setCreatedAt(new Date());
threadRepository.save(threadEntity);

return assistantMapper.toDto(threadEntity);
}

public List<ThreadDto> retrieveThreads() {
return this.threadRepository.findAll(Sort.by(Sort.Direction.DESC, "createdAt")).stream()
public List<ThreadDto> retrieveThreads(String assistantId) {
return this.threadRepository.findAllByAssistantId(assistantId,
Sort.by(Sort.Direction.DESC, "createdAt"))
.stream()
.map(assistantMapper::toDto)
.toList();
}
Expand Down Expand Up @@ -533,10 +537,12 @@ public ThreadTitleDto generateThreadTitle(String threadId,
}
}

public ThreadDto retrieveThread(String threadId) {
return threadRepository.findById(threadId)
.map(assistantMapper::toDto)
.orElseThrow(() -> new AssistentException("Thread not found"));
public ThreadDto retrieveThread(String assistantId, String threadId) {
ThreadEntity threadEntity = threadRepository.findByIdAndAssistantId(threadId, assistantId)
.orElseThrow(() -> new AssistentException(
"Thread " + threadId + " owned by assisant " + assistantId + " not found"));

return assistantMapper.toDto(threadEntity);
}

@Transactional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,6 @@ private String downloadImage(String imageUrl, String threadId, Path chatDirector
throw ex;
}

return "/api/v1/threads/" + threadId + "/" + fileName;
return "/api/v1/threads/images/" + threadId + "/" + fileName;
}
}
10 changes: 7 additions & 3 deletions backend/src/main/resources/db/migration/V1__Initial_Setup.sql
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,15 @@ create table assistant_properties

create table thread
(
id varchar(50) not null primary key,
title varchar(50),
created_at timestamp
id varchar(50) not null primary key,
title varchar(50),
created_at timestamp,
assistant_id varchar(50) not null,
foreign key (assistant_id) references assistant (id)
);

create index idx_thread_assistant_id on thread (assistant_id);

create table message
(
id varchar(50) not null primary key,
Expand Down
5 changes: 3 additions & 2 deletions frontend/src/components/App.vue
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
<v-app>
<v-navigation-drawer absolute permanent width="200">
<v-list-item>
<img alt="Talkforge AI" class="logo" src="@/assets/logo.png"
style="width: 90%" title="Talkforge AI">
<img alt="Talkforge AI" class="logo" role="button"
src="@/assets/logo.png" style="width: 90%" title="Talkforge AI"
@click.prevent="onListAssistants">

<div class="d-flex flex-grow-1 align-items-start">
</div>
Expand Down
2 changes: 2 additions & 0 deletions frontend/src/components/chat/ChatView.vue
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ export default defineComponent({
},
},
mounted() {
this.chatStore.clearThreadsList();
this.chatStore.newThread();
this.fetchData();
},
});
Expand Down
14 changes: 7 additions & 7 deletions frontend/src/composable/use-assistants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ export function useAssistants() {
chatStore.removeStatus();
isReading = false;
} else {
console.log('Chunk value: ', chunkValue);
//console.log('Chunk value: ', chunkValue);
partial += chunkValue;
const parts = partial.split('\n');
partial = parts.pop() ?? '';
Expand Down Expand Up @@ -169,7 +169,7 @@ export function useAssistants() {

const processData = (
data: string, event: string, debouncedUpdateCallback: () => void): string => {
console.log('PROCESS event=' + event + ': ', data);
//console.log('PROCESS event=' + event + ': ', data);

let result = '';
switch (event) {
Expand All @@ -191,12 +191,12 @@ export function useAssistants() {
};

function processRunStartedEvent(data: string) {
console.log('## processRunStartedEvent', data);
//console.log('## processRunStartedEvent', data);
chatStore.runId = data;
}

const processDeltaEvent = (data: string) => {
console.log('## processDeltaEvent', data);
//console.log('## processDeltaEvent', data);
const lastMessage = chatStore.getLastMessage();

if (data.length > 0) {
Expand Down Expand Up @@ -240,7 +240,7 @@ export function useAssistants() {

const retrieveThreads = async () => {
const result = await axios.get(
`/api/v1/threads`,
`/api/v1/threads/${chatStore.selectedAssistant.id}`,
{
params: {},
},
Expand All @@ -251,13 +251,13 @@ export function useAssistants() {

const retrieveThread = async (threadId: string) => {
const result = await axios.get(
`/api/v1/threads/${threadId}`,
`/api/v1/threads/${chatStore.selectedAssistant.id}/${threadId}`,
);
return result.data;
};

const createThread = async () => {
const result = await axios.post(`/api/v1/threads`);
const result = await axios.post(`/api/v1/threads/${chatStore.selectedAssistant.id}`);
return result.data;
};

Expand Down
8 changes: 7 additions & 1 deletion frontend/src/store/chat-store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,16 @@ export const useChatStore = defineStore('chat', {
},
},
actions: {
/**
* Clear the threads list.
*/
clearThreadsList() {
this.threads = [];
},
/**
* Create a new thread.
*/
async newThread() {
newThread() {
this.threadId = '';
this.threadMessages = [];
this.parsedMessages = {};
Expand Down

0 comments on commit 3186c92

Please sign in to comment.