Skip to content

Commit

Permalink
Refactoring in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
lhpqaq committed Dec 29, 2024
1 parent a6d80e4 commit 121e09e
Show file tree
Hide file tree
Showing 14 changed files with 108 additions and 175 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,45 +18,41 @@
*/
package org.apache.bigtop.manager.ai.assistant;

import org.apache.bigtop.manager.ai.assistant.provider.LocSystemPromptProvider;
import org.apache.bigtop.manager.ai.assistant.provider.GeneralAssistantConfig;
import org.apache.bigtop.manager.ai.assistant.store.ChatMemoryStoreProvider;
import org.apache.bigtop.manager.ai.core.AbstractAIAssistantFactory;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.enums.SystemPrompt;
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;
import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider;
import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfig;
import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider;
import org.apache.bigtop.manager.ai.dashscope.DashScopeAssistant;
import org.apache.bigtop.manager.ai.openai.OpenAIAssistant;
import org.apache.bigtop.manager.ai.qianfan.QianFanAssistant;

import org.springframework.stereotype.Component;

import dev.langchain4j.service.tool.ToolProvider;
import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;

import jakarta.annotation.Resource;
import java.util.List;

@Component
public class GeneralAssistantFactory extends AbstractAIAssistantFactory {

private final SystemPromptProvider systemPromptProvider;
private final ChatMemoryStoreProvider chatMemoryStoreProvider;

public GeneralAssistantFactory(ChatMemoryStoreProvider chatMemoryStoreProvider) {
this(new LocSystemPromptProvider(), chatMemoryStoreProvider);
}
@Resource
private SystemPromptProvider systemPromptProvider;

public GeneralAssistantFactory(
SystemPromptProvider systemPromptProvider, ChatMemoryStoreProvider chatMemoryStoreProvider) {
this.systemPromptProvider = systemPromptProvider;
this.chatMemoryStoreProvider = chatMemoryStoreProvider;
}
@Resource
private ChatMemoryStoreProvider chatMemoryStoreProvider;

@Override
public AIAssistant createWithPrompt(
PlatformType platformType,
AIAssistantConfigProvider assistantConfig,
Object id,
ToolProvider toolProvider,
SystemPrompt systemPrompt) {
AIAssistantConfig config, ToolProvider toolProvider, SystemPrompt systemPrompt) {
GeneralAssistantConfig generalAssistantConfig = (GeneralAssistantConfig) config;
PlatformType platformType = generalAssistantConfig.getPlatformType();
Object id = generalAssistantConfig.getId();
AIAssistant.Builder builder =
switch (platformType) {
case OPENAI -> OpenAIAssistant.builder();
Expand All @@ -68,12 +64,12 @@ public AIAssistant createWithPrompt(
(id == null)
? new InMemoryChatMemoryStore()
: chatMemoryStoreProvider.createPersistentChatMemoryStore())
.withConfigProvider(assistantConfig)
.withConfigProvider(generalAssistantConfig)
.withToolProvider(toolProvider);

List<String> systemPrompts = new java.util.ArrayList<>();
systemPrompts.add(systemPromptProvider.getSystemMessage(systemPrompt));
String locale = assistantConfig.getLanguage();
String locale = generalAssistantConfig.getLanguage();
if (locale != null) {
systemPrompts.add(systemPromptProvider.getLanguagePrompt(locale));
}
Expand All @@ -82,10 +78,4 @@ public AIAssistant createWithPrompt(

return builder.build();
}

@Override
public AIAssistant createAiService(
PlatformType platformType, AIAssistantConfigProvider assistantConfig, Long id, ToolProvider toolProvider) {
return createWithPrompt(platformType, assistantConfig, id, toolProvider, SystemPrompt.DEFAULT_PROMPT);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,35 +18,34 @@
*/
package org.apache.bigtop.manager.ai.assistant.provider;

import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfig;

import lombok.Getter;

import java.util.HashMap;
import java.util.Map;

public class AIAssistantConfig implements AIAssistantConfigProvider {
@Getter
public class GeneralAssistantConfig implements AIAssistantConfig {

/**
* Model name for platform that we want to use
*/
private final Long id;
private final String model;

/**
* Credentials for different platforms
*/
private final Map<String, String> credentials;

private final String language;
private final PlatformType platformType;
private final Map<String, String> credentials;
/**
* Platform extra configs are put here
*/
private final Map<String, String> configs;

private AIAssistantConfig(
String model, Map<String, String> credentials, String language, Map<String, String> configMap) {
this.model = model;
this.credentials = credentials;
this.language = language;
this.configs = configMap;
private GeneralAssistantConfig(Builder builder) {
this.model = builder.model;
this.credentials = builder.credentials;
this.language = builder.language;
this.platformType = builder.platformType;
this.id = builder.id;
this.configs = builder.configs;
}

public static Builder builder() {
Expand All @@ -68,17 +67,12 @@ public Map<String, String> getConfigs() {
return configs;
}

@Override
public String getLanguage() {
return language;
}

public static class Builder {
private Long id;
private String model;
private String language;

private PlatformType platformType;
private final Map<String, String> credentials = new HashMap<>();

private final Map<String, String> configs = new HashMap<>();

public Builder() {}
Expand All @@ -88,6 +82,16 @@ public Builder setModel(String model) {
return this;
}

public Builder setPlatformType(PlatformType platformType) {
this.platformType = platformType;
return this;
}

public Builder setId(Long id) {
this.id = id;
return this;
}

public Builder addCredential(String key, String value) {
credentials.put(key, value);
return this;
Expand Down Expand Up @@ -115,8 +119,8 @@ public Builder addConfigs(Map<String, String> configMap) {
return this;
}

public AIAssistantConfig build() {
return new AIAssistantConfig(model, credentials, language, configs);
public GeneralAssistantConfig build() {
return new GeneralAssistantConfig(this);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.bigtop.manager.ai.core.enums.SystemPrompt;
import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider;

import org.springframework.stereotype.Component;
import org.springframework.util.ResourceUtils;

import lombok.extern.slf4j.Slf4j;
Expand All @@ -33,6 +34,7 @@
import java.util.Objects;

@Slf4j
@Component
public class LocSystemPromptProvider implements SystemPromptProvider {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,20 @@
import org.apache.bigtop.manager.dao.repository.ChatMessageDao;
import org.apache.bigtop.manager.dao.repository.ChatThreadDao;

import org.springframework.stereotype.Component;

import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;

public class ChatMemoryStoreProvider {
private final ChatThreadDao chatThreadDao;
private final ChatMessageDao chatMessageDao;
import jakarta.annotation.Resource;

public ChatMemoryStoreProvider(ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) {
this.chatThreadDao = chatThreadDao;
this.chatMessageDao = chatMessageDao;
}
@Component
public class ChatMemoryStoreProvider {
@Resource
private ChatThreadDao chatThreadDao;

public ChatMemoryStoreProvider() {
this(null, null);
}
@Resource
private ChatMessageDao chatMessageDao;

public ChatMemoryStore createPersistentChatMemoryStore() {
if (chatThreadDao == null || chatMessageDao == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,13 @@
*/
package org.apache.bigtop.manager.ai.assistant;

import org.apache.bigtop.manager.ai.assistant.store.ChatMemoryStoreProvider;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;
import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider;
import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider;
import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfig;
import org.apache.bigtop.manager.ai.openai.OpenAIAssistant;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.MockitoAnnotations;
Expand All @@ -42,23 +39,18 @@
class GeneralAssistantFactoryTest {

@Mock
private SystemPromptProvider systemPromptProvider;
private AIAssistantConfig assistantConfigProvider;

@Mock
private AIAssistantConfigProvider assistantConfigProvider;

@InjectMocks
private GeneralAssistantFactory generalAssistantFactory;

@BeforeEach
void setUp() {
MockitoAnnotations.openMocks(this);
generalAssistantFactory = new GeneralAssistantFactory(systemPromptProvider, new ChatMemoryStoreProvider());
Map<String, String> credentials = Map.of("apiKey", "123456");
when(assistantConfigProvider.getModel()).thenReturn("model");
when(assistantConfigProvider.getCredentials()).thenReturn(credentials);
when(assistantConfigProvider.getConfigs()).thenReturn(null);
when(assistantConfigProvider.getLanguage()).thenReturn("en");
}

@Test
Expand All @@ -75,9 +67,8 @@ void testCreateAIAssistant() {
openAIAssistantMockedStatic.when(OpenAIAssistant::builder).thenReturn(mockBuilder);

PlatformType platformType = PlatformType.OPENAI;
generalAssistantFactory.create(platformType, assistantConfigProvider);
generalAssistantFactory = new GeneralAssistantFactory(new ChatMemoryStoreProvider());
generalAssistantFactory.create(platformType, assistantConfigProvider);
generalAssistantFactory.create(assistantConfigProvider);
generalAssistantFactory.create(assistantConfigProvider);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,17 @@
import static org.junit.jupiter.api.Assertions.assertNull;

@ExtendWith(MockitoExtension.class)
public class AIAssistantConfigTest {
public class GeneralAssistantConfigTest {

private AIAssistantConfig.Builder builder;
private GeneralAssistantConfig.Builder builder;
private String model;
private String language;
private Map<String, String> credentials;
private Map<String, String> configs;

@BeforeEach
public void setUp() {
builder = AIAssistantConfig.builder();
builder = GeneralAssistantConfig.builder();
model = "test-model";
language = "en-US";
credentials = new HashMap<>();
Expand All @@ -53,7 +53,7 @@ public void setUp() {

@Test
public void testBuilderSetsValuesCorrectly() {
AIAssistantConfig config = builder.setModel(model)
GeneralAssistantConfig config = builder.setModel(model)
.setLanguage(language)
.addCredentials(credentials)
.addConfigs(configs)
Expand All @@ -68,7 +68,7 @@ public void testBuilderSetsValuesCorrectly() {

@Test
public void testBuilderAddsSingleCredential() {
AIAssistantConfig config = builder.setModel(model)
GeneralAssistantConfig config = builder.setModel(model)
.setLanguage(language)
.addCredential("client_id", "abcd1234")
.build();
Expand All @@ -79,7 +79,7 @@ public void testBuilderAddsSingleCredential() {

@Test
public void testBuilderAddsSingleConfig() {
AIAssistantConfig config = builder.setModel(model)
GeneralAssistantConfig config = builder.setModel(model)
.setLanguage(language)
.addConfig("threadId", "123")
.build();
Expand All @@ -90,7 +90,7 @@ public void testBuilderAddsSingleConfig() {

@Test
public void testEmptyBuilder() {
AIAssistantConfig config = builder.build();
GeneralAssistantConfig config = builder.build();

assertNotNull(config);
assertNull(config.getModel());
Expand All @@ -108,7 +108,7 @@ public void testMultipleCredentialsAndConfigs() {
Map<String, String> extraConfigs = new HashMap<>();
extraConfigs.put("retry", "3");

AIAssistantConfig config = builder.setModel(model)
GeneralAssistantConfig config = builder.setModel(model)
.setLanguage(language)
.addCredentials(extraCredentials)
.addConfigs(extraConfigs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
package org.apache.bigtop.manager.ai.core;

import org.apache.bigtop.manager.ai.core.factory.AIAssistant;
import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider;
import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfig;

import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
Expand Down Expand Up @@ -61,7 +61,7 @@ public abstract static class Builder implements AIAssistant.Builder {
protected Object id;

protected ChatMemoryStore chatMemoryStore;
protected AIAssistantConfigProvider configProvider;
protected AIAssistantConfig configProvider;

protected ToolProvider toolProvider;
protected String systemPrompt;
Expand All @@ -78,7 +78,7 @@ public Builder withSystemPrompt(String systemPrompt) {
return this;
}

public Builder withConfigProvider(AIAssistantConfigProvider configProvider) {
public Builder withConfigProvider(AIAssistantConfig configProvider) {
this.configProvider = configProvider;
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
package org.apache.bigtop.manager.ai.core.factory;

import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider;
import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfig;

import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
Expand Down Expand Up @@ -74,7 +74,7 @@ interface Builder {

Builder memoryStore(ChatMemoryStore memoryStore);

Builder withConfigProvider(AIAssistantConfigProvider configProvider);
Builder withConfigProvider(AIAssistantConfig configProvider);

Builder withToolProvider(ToolProvider toolProvider);

Expand Down
Loading

0 comments on commit 121e09e

Please sign in to comment.