Skip to content

Commit

Permalink
Publish a Spring Event after registering the AiService Bean in AiServ…
Browse files Browse the repository at this point in the history
…icesAutoConfig.

This event contains the AiService class and its corresponding tools description information.

Once a user implements the AiServiceRegisteredEventListener to listen for this event, they can receive the event during the Spring Boot startup phase and handle their business logic as needed.

original PR link: langchain4j#77
Issues link: langchain4j/langchain4j#2112
  • Loading branch information
catofdestruction committed Nov 26, 2024
1 parent a803860 commit 4f5c678
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package dev.langchain4j.service.spring;

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.agent.tool.ToolSpecifications;
import dev.langchain4j.exception.IllegalConfigurationException;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
Expand All @@ -9,19 +11,27 @@
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent;
import org.springframework.beans.MutablePropertyValues;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.GenericBeanDefinition;
import org.springframework.beans.factory.support.ManagedList;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.context.annotation.Bean;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
Expand All @@ -31,7 +41,14 @@
import static dev.langchain4j.service.spring.AiServiceWiringMode.EXPLICIT;
import static java.util.Arrays.asList;

public class AiServicesAutoConfig {
public class AiServicesAutoConfig implements ApplicationEventPublisherAware {

private ApplicationEventPublisher eventPublisher;

@Override
public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
this.eventPublisher = applicationEventPublisher;
}

@Bean
BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
Expand All @@ -47,6 +64,7 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
String[] moderationModels = beanFactory.getBeanNamesForType(ModerationModel.class);

Set<String> tools = new HashSet<>();
Map<String, List<ToolSpecification>> beanToolSpecifications = new HashMap<>();
for (String beanName : beanFactory.getBeanDefinitionNames()) {
try {
String beanClassName = beanFactory.getBeanDefinition(beanName).getBeanClassName();
Expand All @@ -58,6 +76,10 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
for (Method beanMethod : beanClass.getDeclaredMethods()) {
if (beanMethod.isAnnotationPresent(Tool.class)) {
tools.add(beanName);
List<ToolSpecification> toolSpecifications =
beanToolSpecifications.getOrDefault(beanName, new ArrayList<>());
toolSpecifications.add(ToolSpecifications.toolSpecificationFrom(beanMethod));
beanToolSpecifications.put(beanName, toolSpecifications);
}
}
} catch (Exception e) {
Expand Down Expand Up @@ -146,17 +168,24 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
propertyValues
);

AiServiceRegisteredEvent registeredEvent;
if (aiServiceAnnotation.wiringMode() == EXPLICIT) {
propertyValues.add("tools", toManagedList(asList(aiServiceAnnotation.tools())));
registeredEvent = buildEvent(aiServiceClass, beanToolSpecifications, asList(aiServiceAnnotation.tools()));
} else if (aiServiceAnnotation.wiringMode() == AUTOMATIC) {
propertyValues.add("tools", toManagedList(tools));
registeredEvent = buildEvent(aiServiceClass, beanToolSpecifications, tools);
} else {
throw illegalArgument("Unknown wiring mode: " + aiServiceAnnotation.wiringMode());
}

BeanDefinitionRegistry registry = (BeanDefinitionRegistry) beanFactory;
registry.removeBeanDefinition(aiService);
registry.registerBeanDefinition(lowercaseFirstLetter(aiService), aiServiceBeanDefinition);

if (eventPublisher != null) {
eventPublisher.publishEvent(registeredEvent);
}
}
};
}
Expand Down Expand Up @@ -204,4 +233,14 @@ private static ManagedList<RuntimeBeanReference> toManagedList(Collection<String
}
return managedList;
}

private static AiServiceRegisteredEvent buildEvent(Class<?> aiServiceClass,
Map<String, List<ToolSpecification>> toolSpecifications,
Collection<String> tools) {
return new AiServiceRegisteredEvent(aiServiceClass, aiServiceClass,
tools.stream()
.filter(toolSpecifications::containsKey)
.flatMap(tool -> toolSpecifications.get(tool).stream())
.collect(Collectors.toList()));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package dev.langchain4j.service.spring.event;

import dev.langchain4j.agent.tool.ToolSpecification;
import org.springframework.context.ApplicationEvent;

import java.util.List;

public class AiServiceRegisteredEvent extends ApplicationEvent {

private final Class<?> aiServiceClass;
private final List<ToolSpecification> toolSpecifications;

public AiServiceRegisteredEvent(Object source, Class<?> aiServiceClass, List<ToolSpecification> toolSpecifications) {
super(source);
this.aiServiceClass = aiServiceClass;
this.toolSpecifications = toolSpecifications;
}

public Class<?> getAiServiceClass() {
return aiServiceClass;
}

public List<ToolSpecification> getToolSpecifications() {
return toolSpecifications;
}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
package dev.langchain4j.service.spring.mode.automatic.withTools;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.ApplicationListener;

import java.util.List;

@SpringBootApplication
class AiServiceWithToolsApplication {
class AiServiceWithToolsApplication implements ApplicationListener<AiServiceRegisteredEvent> {

public static void main(String[] args) {
SpringApplication.run(AiServiceWithToolsApplication.class, args);
}

@Override
public void onApplicationEvent(AiServiceRegisteredEvent event) {
Class<?> aiServiceClass = event.getAiServiceClass();
List<ToolSpecification> toolSpecifications = event.getToolSpecifications();
for (int i = 0; i < toolSpecifications.size(); i++) {
System.out.printf("[%s]: [Tool-%s]: %s%n", aiServiceClass.getSimpleName(), i + 1, toolSpecifications.get(i));
}
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
package dev.langchain4j.service.spring.mode.automatic.withTools;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.service.spring.AiServicesAutoConfig;
import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent;
import dev.langchain4j.service.spring.mode.automatic.withTools.aop.ToolObserverAspect;
import dev.langchain4j.service.spring.mode.automatic.withTools.listener.AiServiceRegisteredEventListener;
import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;

import java.util.List;

import static dev.langchain4j.service.spring.mode.ApiKeys.OPENAI_API_KEY;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY_NAME_DESCRIPTION;
Expand All @@ -16,6 +21,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

class AiServicesAutoConfigIT {
Expand Down Expand Up @@ -69,6 +75,32 @@ void should_create_AI_service_with_tool_that_is_package_private_method_in_packag
});
}

@Test
void should_receive_ai_service_registered_event() {
contextRunner
.withUserConfiguration(AiServiceWithToolsApplication.class)
.run(context -> {

// given
AiServiceRegisteredEventListener listener = context.getBean(AiServiceRegisteredEventListener.class);

// then should receive AiServiceRegisteredEvent
assertTrue(listener.isEventReceived());
assertEquals(1, listener.getReceivedEvents().size());

AiServiceRegisteredEvent event = listener.getReceivedEvents().stream().findFirst().orElse(null);
assertNotNull(event);
assertEquals(AiServiceWithTools.class, event.getAiServiceClass());
assertEquals(4, event.getToolSpecifications().size());

List<String> tools = event.getToolSpecifications().stream().map(ToolSpecification::name).toList();
assertTrue(tools.contains("getCurrentDate"));
assertTrue(tools.contains("getCurrentTime"));
assertTrue(tools.contains("getToolObserverPackageName"));
assertTrue(tools.contains("getToolObserverKey"));
});
}

@Test
void should_create_AI_service_with_tool_which_is_enhanced_by_spring_aop() {
contextRunner
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package dev.langchain4j.service.spring.mode.automatic.withTools.listener;

import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationListener;

import java.util.ArrayList;
import java.util.List;

public class AbstractApplicationListener<E extends ApplicationEvent> implements ApplicationListener<E> {
private final List<E> receivedEvents = new ArrayList<>();

@Override
public void onApplicationEvent(E event) {
receivedEvents.add(event);
}

public List<E> getReceivedEvents() {
return receivedEvents;
}

public boolean isEventReceived() {
return !receivedEvents.isEmpty();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package dev.langchain4j.service.spring.mode.automatic.withTools.listener;

import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent;
import org.springframework.stereotype.Component;
@Component
public class AiServiceRegisteredEventListener extends AbstractApplicationListener<AiServiceRegisteredEvent> {
}

0 comments on commit 4f5c678

Please sign in to comment.