Skip to content

Commit

Permalink
Update tools examples
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasVitale committed Dec 31, 2024
1 parent bb47bb1 commit 0d40c10
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 29 deletions.
10 changes: 9 additions & 1 deletion labs/tools/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,15 @@ run the application as follows.
Call the application that will use a @Tool-annotated method to retrieve the context to answer your question.

```shell
http :8080/chat/method authorName=="J.R.R. Tolkien" -b
http :8080/chat/method/void -b
```

```shell
http :8080/chat/method/single authorName=="J.R.R. Tolkien" -b
```

```shell
http :8080/chat/method/list authorName1=="J.R.R. Tolkien" authorName2=="Philip Pullman" -b
```

Call the application that will use an MCP Server to retrieve the context to answer your question.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ public List<Book> getBooksByAuthor(Author author) {
.toList();
}

public List<Book> getBooksByAuthor(List<Author> authors) {
return books.values().stream()
.filter(book -> authors.stream()
.anyMatch(author -> author.name().equals(book.author())))
.toList();
}

public record Book(String title, String author) {}
public record Author(String name) {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,20 @@ class ChatController {
this.tools = tools;
}

@GetMapping("/chat/method")
String chatMethod(String authorName) {
@GetMapping("/chat/method/void")
String chatMethodVoid() {
return chatClient.prompt()
.user("Welcome the user to the library")
.functions(MethodToolCallbackResolver.builder()
.target(tools)
.build()
.getToolCallbacks())
.call()
.content();
}

@GetMapping("/chat/method/single")
String chatMethodSingle(String authorName) {
var userPromptTemplate = "What books written by {author} are available in the library?";
return chatClient.prompt()
.user(userSpec -> userSpec
Expand All @@ -39,6 +51,23 @@ String chatMethod(String authorName) {
.content();
}

@GetMapping("/chat/method/list")
String chatMethodList(String authorName1, String authorName2) {
var userPromptTemplate = "What books written by {authorName1} and {authorName2} are available in the library?";
return chatClient.prompt()
.user(userSpec -> userSpec
.text(userPromptTemplate)
.param("authorName1", authorName1)
.param("authorName2", authorName2)
)
.functions(MethodToolCallbackResolver.builder()
.target(tools)
.build()
.getToolCallbacks())
.call()
.content();
}

@GetMapping("/chat/mcp")
String chatMcp(String question) {
return chatClient.prompt()
Expand Down
20 changes: 18 additions & 2 deletions labs/tools/src/main/java/com/thomasvitale/ai/spring/Tools.java
Original file line number Diff line number Diff line change
@@ -1,22 +1,38 @@
package com.thomasvitale.ai.spring;

import com.thomasvitale.ai.spring.api.tools.Tool;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

import java.util.List;

@Component
public class Tools {

private static final Logger logger = LoggerFactory.getLogger(Tools.class);

private final BookService bookService;

Tools(BookService bookService) {
this.bookService = bookService;
}

@Tool("Get the list of books written by the given author available in the library")
public List<BookService.Book> booksByAuthor(String author) {
return bookService.getBooksByAuthor(new BookService.Author(author));
public List<BookService.Book> booksByAuthor(BookService.Author author) {
logger.info("Getting books by author: {}", author);
return bookService.getBooksByAuthor(author);
}

@Tool("Get the list of books written by the given authors available in the library")
public List<BookService.Book> booksByAuthors(List<String> authors) {
logger.info("Getting books by authors: {}", String.join(", ", authors));
return bookService.getBooksByAuthor(authors.stream().map(BookService.Author::new).toList());
}

@Tool("Welcome users to the library")
public void welcome() {
logger.info("Welcoming users to the library");
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.thomasvitale.ai.spring.api.tools;

import org.springframework.ai.model.function.FunctionCallback;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
Expand All @@ -11,6 +13,20 @@
@Documented
public @interface Tool {

/**
* The description of the tool. If not provided, the method name will be used.
*/
String value() default "";

/**
* The name of the tool. If not provided, the method name will be used.
*/
String name() default "";

/**
* The schema type of the tool. JSON Schema will work for most cases.
* Vertex AI requires OpenAPI Schema.
*/
FunctionCallback.SchemaType schemaType() default FunctionCallback.SchemaType.JSON_SCHEMA;

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

import java.util.stream.Stream;

public class MethodToolCallbackResolver implements ToolCallbackResolver {

Expand All @@ -21,29 +21,24 @@ private MethodToolCallbackResolver(Object target) {

@Override
public FunctionCallback[] getToolCallbacks() {
List<FunctionCallback> callbacks = new ArrayList<>();

// Get all methods from the target object
Method[] methods = ReflectionUtils.getAllDeclaredMethods(target.getClass());

for (Method method : methods) {
Tool toolAnnotation = method.getAnnotation(Tool.class);
return Stream.of(ReflectionUtils.getDeclaredMethods(target.getClass()))
.filter(method -> method.isAnnotationPresent(Tool.class))
.map(method -> FunctionCallback.builder()
.method(method.getName(), method.getParameterTypes())
.name(getToolName(method.getAnnotation(Tool.class), method.getName()))
.description(getToolDescription(method.getAnnotation(Tool.class), method.getName()))
.schemaType(method.getAnnotation(Tool.class).schemaType())
.targetObject(target)
.build())
.toArray(FunctionCallback[]::new);
}

// Ignore methods without the @Tool annotation.
if (toolAnnotation == null) {
continue;
}
private static String getToolName(Tool tool, String methodName) {
return StringUtils.hasText(tool.name()) ? tool.name() : methodName;
}

// Create FunctionCallback for methods with the @Tool annotation.
FunctionCallback callback = FunctionCallback.builder()
.method(method.getName(), method.getParameterTypes())
.description(toolAnnotation.value())
.targetObject(target)
.build();
callbacks.add(callback);
}

return callbacks.toArray(new FunctionCallback[0]);
private static String getToolDescription(Tool tool, String methodName) {
return StringUtils.hasText(tool.value()) ? tool.value() : methodName;
}

public static class Builder {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package com.thomasvitale.ai.spring;

import com.thomasvitale.ai.spring.api.tools.method.MethodToolCallbackResolver;
import org.junit.jupiter.api.Test;

import java.util.stream.Stream;

import static org.assertj.core.api.Assertions.assertThat;

class ToolsTests {

String value1 = """
{
"author": {
"name" : "J.R.R. Tolkien"
}
}
""";

String value2 = """
{
"authors" : ["J.R.R. Tolkien", "Philip Pullman"]
}
""";

@Test
void nonStaticMethod() {
var object = new Tools(new BookService());

var functionCallbacks = MethodToolCallbackResolver.builder()
.target(object)
.build()
.getToolCallbacks();

var booksByAuthor = Stream.of(functionCallbacks)
.filter(func -> func.getName().equals("booksByAuthor"))
.findFirst()
.orElseThrow();

String response1 = booksByAuthor.call(value1);

assertThat(response1).isNotEmpty();
assertThat(response1)
.containsIgnoringWhitespaces("The Hobbit")
.containsIgnoringWhitespaces("The Lord of The Rings")
.containsIgnoringWhitespaces("The Silmarillion");

var booksByAuthors = Stream.of(functionCallbacks)
.filter(func -> func.getName().equals("booksByAuthors"))
.findFirst()
.orElseThrow();

String response2 = booksByAuthors.call(value2);

assertThat(response2).isNotEmpty();
}

@Test
void noArgsNoReturnMethod() {
var object = new Tools(new BookService());

var functionCallbacks = MethodToolCallbackResolver.builder()
.target(object)
.build()
.getToolCallbacks();

var welcome = Stream.of(functionCallbacks)
.filter(func -> func.getName().equals("welcome"))
.findFirst()
.orElseThrow();

String response = welcome.call("{}");

assertThat(response).isEqualTo("Done");
}

}

0 comments on commit 0d40c10

Please sign in to comment.