diff --git a/README.md b/README.md index a17c9b0..5e948b7 100644 --- a/README.md +++ b/README.md @@ -164,6 +164,12 @@ Vector Store Observability for different vector stores: * **[PGVector](https://github.com/ThomasVitale/llm-apps-java-spring-ai/tree/main/observability/observability-vector-stores-pgvector)** +## ⚙️ Model Context Protocol + +Integrations with MCP Servers for providing contexts to LLMs. + +* **[Brave Search API](https://github.com/ThomasVitale/llm-apps-java-spring-ai/tree/main/model-context-protocol/mcp-brave)** + ## 📋 Evaluation _Coming soon_ diff --git a/gradle.properties b/gradle.properties index 1461bcc..17baeb7 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,2 +1,3 @@ springAiVersion=1.0.0-SNAPSHOT +springAiMcpVersion=0.2.0 otelInstrumentationVersion=2.10.0-alpha \ No newline at end of file diff --git a/labs/tools/README.md b/labs/tools/README.md new file mode 100644 index 0000000..4e8dde4 --- /dev/null +++ b/labs/tools/README.md @@ -0,0 +1,63 @@ +# Labs: Tools + +Integrating with Tools, including @Tools-annotated methods and MCP Servers. + +## Brave + +The application consumes the [Brave Search API](https://api.search.brave.com). + +### Create an account + +Visit [api.search.brave.com](https://api.search.brave.com) and sign up for a new account. +Then, in the Brave Search API console, navigate to _Subscriptions_ and choose a subscription plan. +You can choose the "Free AI" plan to get started. + +### Configure API Key + +In the Brave Search API console, navigate to _API Keys_ and generate a new API key. +Copy and securely store your API key on your machine as an environment variable. +The application will use it to access the Brave Search API. + +```shell +export BRAVE_API_KEY= +``` + +## Ollama + +The application consumes models from an [Ollama](https://ollama.ai) inference server. You can either run Ollama locally on your laptop, +or rely on the Testcontainers support in Spring Boot to spin up an Ollama service automatically. +If you choose the first option, make sure you have Ollama installed and running on your laptop. +Either way, Spring AI will take care of pulling the needed Ollama models when the application starts, +if they are not available yet on your machine. + +## Running the application + +If you're using the native Ollama application, run the application as follows. + +```shell +./gradlew bootRun +``` + +If you want to rely on the native Testcontainers support in Spring Boot to spin up an Ollama service at startup time, +run the application as follows. + +```shell +./gradlew bootTestRun +``` + +## Calling the application + +> [!NOTE] +> These examples use the [httpie](https://httpie.io) CLI to send HTTP requests. + +Call the application that will use a @Tool-annotated method to retrieve the context to answer your question. + +```shell +http :8080/chat/method authorName=="J.R.R. Tolkien" -b +``` + +Call the application that will use an MCP Server to retrieve the context to answer your question. + +```shell +http :8080/chat/mcp question=="Does Spring AI supports a Modular RAG architecture? Please provide some references." +``` diff --git a/labs/tools/build.gradle b/labs/tools/build.gradle new file mode 100644 index 0000000..11b2e4d --- /dev/null +++ b/labs/tools/build.gradle @@ -0,0 +1,41 @@ +plugins { + id 'java' + id 'org.springframework.boot' + id 'io.spring.dependency-management' + id 'org.graalvm.buildtools.native' +} + +group = 'com.thomasvitale' +version = '0.0.1-SNAPSHOT' + +java { + toolchain { + languageVersion = JavaLanguageVersion.of(23) + } +} + +repositories { + mavenCentral() + maven { url 'https://repo.spring.io/milestone' } + maven { url 'https://repo.spring.io/snapshot' } +} + +dependencies { + implementation platform("org.springframework.ai:spring-ai-bom:${springAiVersion}") + + implementation 'org.springframework.boot:spring-boot-starter-web' + implementation "org.springframework.ai:spring-ai-ollama-spring-boot-starter" + + implementation "org.springframework.experimental:spring-ai-mcp:${springAiMcpVersion}" + + testAndDevelopmentOnly 'org.springframework.boot:spring-boot-devtools' + + testImplementation 'org.springframework.boot:spring-boot-starter-test' + testImplementation 'org.springframework.ai:spring-ai-spring-boot-testcontainers' + testImplementation 'org.testcontainers:ollama' + testRuntimeOnly 'org.junit.platform:junit-platform-launcher' +} + +tasks.named('test') { + useJUnitPlatform() +} diff --git a/labs/tools/src/main/java/com/thomasvitale/ai/spring/BookService.java b/labs/tools/src/main/java/com/thomasvitale/ai/spring/BookService.java new file mode 100644 index 0000000..8251571 --- /dev/null +++ b/labs/tools/src/main/java/com/thomasvitale/ai/spring/BookService.java @@ -0,0 +1,31 @@ +package com.thomasvitale.ai.spring; + +import org.springframework.stereotype.Service; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +@Service +public class BookService { + + private static final Map books = new ConcurrentHashMap<>(); + + static { + books.put(1, new Book("His Dark Materials", "Philip Pullman")); + books.put(2, new Book("Narnia", "C.S. Lewis")); + books.put(3, new Book("The Hobbit", "J.R.R. Tolkien")); + books.put(4, new Book("The Lord of The Rings", "J.R.R. Tolkien")); + books.put(5, new Book("The Silmarillion", "J.R.R. Tolkien")); + } + + public List getBooksByAuthor(Author author) { + return books.values().stream() + .filter(book -> author.name().equals(book.author())) + .toList(); + } + + public record Book(String title, String author) {} + public record Author(String name) {} + +} diff --git a/labs/tools/src/main/java/com/thomasvitale/ai/spring/ChatController.java b/labs/tools/src/main/java/com/thomasvitale/ai/spring/ChatController.java new file mode 100644 index 0000000..365680b --- /dev/null +++ b/labs/tools/src/main/java/com/thomasvitale/ai/spring/ChatController.java @@ -0,0 +1,54 @@ +package com.thomasvitale.ai.spring; + +import com.thomasvitale.ai.spring.api.tools.mcp.McpToolCallbackResolver; +import com.thomasvitale.ai.spring.api.tools.method.MethodToolCallbackResolver; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.mcp.client.McpSyncClient; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RestController; + +/** + * Chat examples using the high-level ChatClient API. + */ +@RestController +class ChatController { + + private final ChatClient chatClient; + private final McpSyncClient mcpClient; + private final Tools tools; + + ChatController(ChatClient.Builder chatClientBuilder, McpSyncClient mcpClient, Tools tools) { + this.chatClient = chatClientBuilder.build(); + this.mcpClient = mcpClient; + this.tools = tools; + } + + @GetMapping("/chat/method") + String chatMethod(String authorName) { + var userPromptTemplate = "What books written by {author} are available in the library?"; + return chatClient.prompt() + .user(userSpec -> userSpec + .text(userPromptTemplate) + .param("author", authorName) + ) + .functions(MethodToolCallbackResolver.builder() + .target(tools) + .build() + .getToolCallbacks()) + .call() + .content(); + } + + @GetMapping("/chat/mcp") + String chatMcp(String question) { + return chatClient.prompt() + .user(question) + .functions(McpToolCallbackResolver.builder() + .mcpClients(mcpClient) + .build() + .getToolCallbacks()) + .call() + .content(); + } + +} diff --git a/labs/tools/src/main/java/com/thomasvitale/ai/spring/Functions.java b/labs/tools/src/main/java/com/thomasvitale/ai/spring/Functions.java new file mode 100644 index 0000000..a31251e --- /dev/null +++ b/labs/tools/src/main/java/com/thomasvitale/ai/spring/Functions.java @@ -0,0 +1,19 @@ +package com.thomasvitale.ai.spring; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Description; + +import java.util.List; +import java.util.function.Function; + +@Configuration(proxyBeanMethods = false) +class Functions { + + @Bean + @Description("Get the list of books written by the given author available in the library") + Function> booksByAuthor(BookService bookService) { + return bookService::getBooksByAuthor; + } + +} diff --git a/labs/tools/src/main/java/com/thomasvitale/ai/spring/LabsToolsApplication.java b/labs/tools/src/main/java/com/thomasvitale/ai/spring/LabsToolsApplication.java new file mode 100644 index 0000000..96457a9 --- /dev/null +++ b/labs/tools/src/main/java/com/thomasvitale/ai/spring/LabsToolsApplication.java @@ -0,0 +1,37 @@ +package com.thomasvitale.ai.spring; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.mcp.client.McpClient; +import org.springframework.ai.mcp.client.McpSyncClient; +import org.springframework.ai.mcp.client.stdio.ServerParameters; +import org.springframework.ai.mcp.client.stdio.StdioClientTransport; +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.context.annotation.Bean; + +@SpringBootApplication +public class LabsToolsApplication { + + private static final Logger logger = LoggerFactory.getLogger(LabsToolsApplication.class); + + public static void main(String[] args) { + SpringApplication.run(LabsToolsApplication.class, args); + } + + @Bean + McpSyncClient mcpClient() { + var serverParameters = ServerParameters.builder("npx") + .args("-y", "@modelcontextprotocol/server-brave-search") + .addEnvVar("BRAVE_API_KEY", System.getenv("BRAVE_API_KEY")) + .build(); + + var mcpClient = McpClient.using(new StdioClientTransport(serverParameters)).sync(); + + var initializeResult = mcpClient.initialize(); + logger.info("MCP Initialized: {}", initializeResult); + + return mcpClient; + } + +} diff --git a/labs/tools/src/main/java/com/thomasvitale/ai/spring/Tools.java b/labs/tools/src/main/java/com/thomasvitale/ai/spring/Tools.java new file mode 100644 index 0000000..98b37a4 --- /dev/null +++ b/labs/tools/src/main/java/com/thomasvitale/ai/spring/Tools.java @@ -0,0 +1,22 @@ +package com.thomasvitale.ai.spring; + +import com.thomasvitale.ai.spring.api.tools.Tool; +import org.springframework.stereotype.Component; + +import java.util.List; + +@Component +public class Tools { + + private final BookService bookService; + + Tools(BookService bookService) { + this.bookService = bookService; + } + + @Tool("Get the list of books written by the given author available in the library") + public List booksByAuthor(String author) { + return bookService.getBooksByAuthor(new BookService.Author(author)); + } + +} diff --git a/labs/tools/src/main/java/com/thomasvitale/ai/spring/api/tools/Tool.java b/labs/tools/src/main/java/com/thomasvitale/ai/spring/api/tools/Tool.java new file mode 100644 index 0000000..8511fbb --- /dev/null +++ b/labs/tools/src/main/java/com/thomasvitale/ai/spring/api/tools/Tool.java @@ -0,0 +1,16 @@ +package com.thomasvitale.ai.spring.api.tools; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Target({ElementType.METHOD, ElementType.ANNOTATION_TYPE}) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface Tool { + + String value() default ""; + +} diff --git a/labs/tools/src/main/java/com/thomasvitale/ai/spring/api/tools/ToolCallback.java b/labs/tools/src/main/java/com/thomasvitale/ai/spring/api/tools/ToolCallback.java new file mode 100644 index 0000000..6879ad1 --- /dev/null +++ b/labs/tools/src/main/java/com/thomasvitale/ai/spring/api/tools/ToolCallback.java @@ -0,0 +1,6 @@ +package com.thomasvitale.ai.spring.api.tools; + +import org.springframework.ai.model.function.FunctionCallback; + +public interface ToolCallback extends FunctionCallback { +} diff --git a/labs/tools/src/main/java/com/thomasvitale/ai/spring/api/tools/ToolCallbackResolver.java b/labs/tools/src/main/java/com/thomasvitale/ai/spring/api/tools/ToolCallbackResolver.java new file mode 100644 index 0000000..94ef7b7 --- /dev/null +++ b/labs/tools/src/main/java/com/thomasvitale/ai/spring/api/tools/ToolCallbackResolver.java @@ -0,0 +1,9 @@ +package com.thomasvitale.ai.spring.api.tools; + +import org.springframework.ai.model.function.FunctionCallback; + +public interface ToolCallbackResolver { + + FunctionCallback[] getToolCallbacks(); + +} diff --git a/labs/tools/src/main/java/com/thomasvitale/ai/spring/api/tools/mcp/McpToolCallback.java b/labs/tools/src/main/java/com/thomasvitale/ai/spring/api/tools/mcp/McpToolCallback.java new file mode 100644 index 0000000..970f58b --- /dev/null +++ b/labs/tools/src/main/java/com/thomasvitale/ai/spring/api/tools/mcp/McpToolCallback.java @@ -0,0 +1,14 @@ +package com.thomasvitale.ai.spring.api.tools.mcp; + +import com.thomasvitale.ai.spring.api.tools.ToolCallback; +import org.springframework.ai.mcp.client.McpSyncClient; +import org.springframework.ai.mcp.spec.McpSchema; +import org.springframework.ai.mcp.spring.McpFunctionCallback; + +public class McpToolCallback extends McpFunctionCallback implements ToolCallback { + + public McpToolCallback(McpSyncClient clientSession, McpSchema.Tool tool) { + super(clientSession, tool); + } + +} diff --git a/labs/tools/src/main/java/com/thomasvitale/ai/spring/api/tools/mcp/McpToolCallbackResolver.java b/labs/tools/src/main/java/com/thomasvitale/ai/spring/api/tools/mcp/McpToolCallbackResolver.java new file mode 100644 index 0000000..abaef43 --- /dev/null +++ b/labs/tools/src/main/java/com/thomasvitale/ai/spring/api/tools/mcp/McpToolCallbackResolver.java @@ -0,0 +1,62 @@ +package com.thomasvitale.ai.spring.api.tools.mcp; + +import com.thomasvitale.ai.spring.api.tools.ToolCallback; +import com.thomasvitale.ai.spring.api.tools.ToolCallbackResolver; +import org.springframework.ai.mcp.client.McpAsyncClient; +import org.springframework.ai.mcp.client.McpSyncClient; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.util.Assert; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Stream; + +public class McpToolCallbackResolver implements ToolCallbackResolver { + + private final List mcpClients; + + public McpToolCallbackResolver(List mcpClients) { + Assert.notNull(mcpClients, "mcpClients cannot be null"); + Assert.noNullElements(mcpClients, "mcpClients cannot contain null elements"); + this.mcpClients = mcpClients; + } + + @Override + public FunctionCallback[] getToolCallbacks() { + return mcpClients.stream() + .flatMap(mcpClient -> mcpClient.listTools().tools().stream() + .map(tool -> (ToolCallback) new McpToolCallback(mcpClient, tool))) + .toArray(ToolCallback[]::new); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private List mcpClients; + + public Builder mcpClients(List mcpClients) { + this.mcpClients = mcpClients; + return this; + } + + public Builder mcpClients(McpSyncClient... mcpClients) { + Assert.notNull(mcpClients, "mcpClients cannot be null"); + this.mcpClients = Arrays.asList(mcpClients); + return this; + } + + public Builder mcpClients(McpAsyncClient... mcpClients) { + this.mcpClients = Stream.of(mcpClients) + .map(McpSyncClient::new) + .toList(); + return this; + } + + public McpToolCallbackResolver build() { + return new McpToolCallbackResolver(mcpClients); + } + } + +} diff --git a/labs/tools/src/main/java/com/thomasvitale/ai/spring/api/tools/method/MethodToolCallbackResolver.java b/labs/tools/src/main/java/com/thomasvitale/ai/spring/api/tools/method/MethodToolCallbackResolver.java new file mode 100644 index 0000000..cdb27bd --- /dev/null +++ b/labs/tools/src/main/java/com/thomasvitale/ai/spring/api/tools/method/MethodToolCallbackResolver.java @@ -0,0 +1,66 @@ +package com.thomasvitale.ai.spring.api.tools.method; + +import com.thomasvitale.ai.spring.api.tools.Tool; +import com.thomasvitale.ai.spring.api.tools.ToolCallbackResolver; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; + +public class MethodToolCallbackResolver implements ToolCallbackResolver { + + private final Object target; + + private MethodToolCallbackResolver(Object target) { + Assert.notNull(target, "target cannot be null"); + this.target = target; + } + + @Override + public FunctionCallback[] getToolCallbacks() { + List callbacks = new ArrayList<>(); + + // Get all methods from the target object + Method[] methods = ReflectionUtils.getAllDeclaredMethods(target.getClass()); + + for (Method method : methods) { + Tool toolAnnotation = method.getAnnotation(Tool.class); + + // Ignore methods without the @Tool annotation. + if (toolAnnotation == null) { + continue; + } + + // Create FunctionCallback for methods with the @Tool annotation. + FunctionCallback callback = FunctionCallback.builder() + .method(method.getName(), method.getParameterTypes()) + .description(toolAnnotation.value()) + .targetObject(target) + .build(); + callbacks.add(callback); + } + + return callbacks.toArray(new FunctionCallback[0]); + } + + public static class Builder { + private Object target; + + public Builder target(Object target) { + this.target = target; + return this; + } + + public MethodToolCallbackResolver build() { + return new MethodToolCallbackResolver(target); + } + } + + public static Builder builder() { + return new Builder(); + } + +} diff --git a/labs/tools/src/main/java/com/thomasvitale/ai/spring/model/ChatModelController.java b/labs/tools/src/main/java/com/thomasvitale/ai/spring/model/ChatModelController.java new file mode 100644 index 0000000..68d68d8 --- /dev/null +++ b/labs/tools/src/main/java/com/thomasvitale/ai/spring/model/ChatModelController.java @@ -0,0 +1,64 @@ +package com.thomasvitale.ai.spring.model; + +import com.thomasvitale.ai.spring.Tools; +import com.thomasvitale.ai.spring.api.tools.mcp.McpToolCallbackResolver; +import com.thomasvitale.ai.spring.api.tools.method.MethodToolCallbackResolver; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.mcp.client.McpSyncClient; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import java.util.Map; + +/** + * Chat examples using the low-level ChatModel API. + */ +@RestController +@RequestMapping("/model") +class ChatModelController { + + private final ChatModel chatModel; + private final McpSyncClient mcpClient; + private final Tools tools; + + ChatModelController(ChatModel chatModel, McpSyncClient mcpClient, Tools tools) { + this.chatModel = chatModel; + this.mcpClient = mcpClient; + this.tools = tools; + } + + @GetMapping("/chat/method") + String chatMethod(String authorName) { + var userPromptTemplate = new PromptTemplate(""" + What books written by {author} are available in the library? + """); + Map model = Map.of("author", authorName); + var prompt = userPromptTemplate.create(model, FunctionCallingOptions.builder() + .functionCallbacks(MethodToolCallbackResolver.builder() + .target(tools) + .build() + .getToolCallbacks()) + .build()); + + var chatResponse = chatModel.call(prompt); + return chatResponse.getResult().getOutput().getText(); + } + + @GetMapping("/chat/mcp") + String chatMcp(String question) { + var prompt = new Prompt(question, FunctionCallingOptions.builder() + .functionCallbacks(McpToolCallbackResolver.builder() + .mcpClients(mcpClient) + .build() + .getToolCallbacks()) + .build()); + + var chatResponse = chatModel.call(prompt); + return chatResponse.getResult().getOutput().getText(); + } + +} diff --git a/labs/tools/src/main/resources/application.yml b/labs/tools/src/main/resources/application.yml new file mode 100644 index 0000000..9445adf --- /dev/null +++ b/labs/tools/src/main/resources/application.yml @@ -0,0 +1,11 @@ +spring: + ai: + ollama: + init: + pull-model-strategy: when_missing + embedding: + include: false + chat: + options: + model: llama3.2 + temperature: 0.7 diff --git a/labs/tools/src/test/java/com/thomasvitale/ai/spring/LabsToolsApplicationTests.java b/labs/tools/src/test/java/com/thomasvitale/ai/spring/LabsToolsApplicationTests.java new file mode 100644 index 0000000..4565387 --- /dev/null +++ b/labs/tools/src/test/java/com/thomasvitale/ai/spring/LabsToolsApplicationTests.java @@ -0,0 +1,52 @@ +package com.thomasvitale.ai.spring; + +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.autoconfigure.web.reactive.AutoConfigureWebTestClient; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Import; +import org.springframework.test.web.reactive.server.WebTestClient; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) +@AutoConfigureWebTestClient(timeout = "180s") +@Import(TestcontainersConfiguration.class) +@EnabledIfEnvironmentVariable(named = "BRAVE_API_KEY", matches = ".*") +class LabsToolsApplicationTests { + + @Autowired + WebTestClient webTestClient; + + @ParameterizedTest + @ValueSource(strings = {"/chat/method", "/model/chat/method"}) + void chatMethod(String path) { + webTestClient + .get() + .uri(uriBuilder -> uriBuilder + .path(path) + .queryParam("authorName", "Philip Pullman") + .build()) + .exchange() + .expectStatus().isOk() + .expectBody(String.class).value(result -> { + assertThat(result).containsIgnoringCase("His Dark Materials"); + }); + } + + @ParameterizedTest + @ValueSource(strings = {"/chat/mcp", "/model/chat/mcp"}) + void chatMcp(String path) { + webTestClient + .get() + .uri(uriBuilder -> uriBuilder + .path(path) + .queryParam("question", "Does Spring AI supports integrations with Ollama?") + .build()) + .exchange() + .expectStatus().isOk(); + } + +} diff --git a/labs/tools/src/test/java/com/thomasvitale/ai/spring/TestLabsToolsApplication.java b/labs/tools/src/test/java/com/thomasvitale/ai/spring/TestLabsToolsApplication.java new file mode 100644 index 0000000..641a3fc --- /dev/null +++ b/labs/tools/src/test/java/com/thomasvitale/ai/spring/TestLabsToolsApplication.java @@ -0,0 +1,11 @@ +package com.thomasvitale.ai.spring; + +import org.springframework.boot.SpringApplication; + +public class TestLabsToolsApplication { + + public static void main(String[] args) { + SpringApplication.from(LabsToolsApplication::main).with(TestcontainersConfiguration.class).run(args); + } + +} diff --git a/labs/tools/src/test/java/com/thomasvitale/ai/spring/TestcontainersConfiguration.java b/labs/tools/src/test/java/com/thomasvitale/ai/spring/TestcontainersConfiguration.java new file mode 100644 index 0000000..1db5053 --- /dev/null +++ b/labs/tools/src/test/java/com/thomasvitale/ai/spring/TestcontainersConfiguration.java @@ -0,0 +1,19 @@ +package com.thomasvitale.ai.spring; + +import org.springframework.boot.devtools.restart.RestartScope; +import org.springframework.boot.test.context.TestConfiguration; +import org.springframework.boot.testcontainers.service.connection.ServiceConnection; +import org.springframework.context.annotation.Bean; +import org.testcontainers.ollama.OllamaContainer; + +@TestConfiguration(proxyBeanMethods = false) +class TestcontainersConfiguration { + + @Bean + @RestartScope + @ServiceConnection + OllamaContainer ollama() { + return new OllamaContainer("ollama/ollama").withReuse(true); + } + +} diff --git a/labs/tools/src/test/java/com/thomasvitale/ai/spring/api/tools/method/MethodToolCallbackResolverTests.java b/labs/tools/src/test/java/com/thomasvitale/ai/spring/api/tools/method/MethodToolCallbackResolverTests.java new file mode 100644 index 0000000..36a2f95 --- /dev/null +++ b/labs/tools/src/test/java/com/thomasvitale/ai/spring/api/tools/method/MethodToolCallbackResolverTests.java @@ -0,0 +1,50 @@ +package com.thomasvitale.ai.spring.api.tools.method; + +import com.thomasvitale.ai.spring.api.tools.Tool; +import org.junit.jupiter.api.Test; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.MethodInvokingFunctionCallback; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class MethodToolCallbackResolverTests { + + @Test + void shouldResolveToolCallbacks() { + TestComponent testComponent = new TestComponent(); + MethodToolCallbackResolver resolver = MethodToolCallbackResolver.builder() + .target(testComponent) + .build(); + + FunctionCallback[] callbacks = resolver.getToolCallbacks(); + + assertThat(callbacks).hasSize(1); + MethodInvokingFunctionCallback callback = (MethodInvokingFunctionCallback) callbacks[0]; + assertThat(callback.getName()).isEqualTo("testMethod"); + assertThat(callback.getDescription()).isEqualTo("Test description"); + } + + @Test + void shouldFailWhenTargetIsNotProvided() { + assertThatThrownBy(() -> MethodToolCallbackResolver.builder().build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("target cannot be null"); + } + + static class TestComponent { + + @Tool("Test description") + public List testMethod(String input) { + return List.of(input); + } + + public void nonToolMethod() { + // This method should be ignored as it doesn't have @Tool annotation + } + + } + +} diff --git a/model-context-protocol/mcp-brave/README.md b/model-context-protocol/mcp-brave/README.md new file mode 100644 index 0000000..26e672b --- /dev/null +++ b/model-context-protocol/mcp-brave/README.md @@ -0,0 +1,57 @@ +# Model Context Protocol: Brave + +Integrating with the Brave Search API via the Model Context Protocol. + +## Brave + +The application consumes the [Brave Search API](https://api.search.brave.com). + +### Create an account + +Visit [api.search.brave.com](https://api.search.brave.com) and sign up for a new account. +Then, in the Brave Search API console, navigate to _Subscriptions_ and choose a subscription plan. +You can choose the "Free AI" plan to get started. + +### Configure API Key + +In the Brave Search API console, navigate to _API Keys_ and generate a new API key. +Copy and securely store your API key on your machine as an environment variable. +The application will use it to access the Brave Search API. + +```shell +export BRAVE_API_KEY= +``` + +## Ollama + +The application consumes models from an [Ollama](https://ollama.ai) inference server. You can either run Ollama locally on your laptop, +or rely on the Testcontainers support in Spring Boot to spin up an Ollama service automatically. +If you choose the first option, make sure you have Ollama installed and running on your laptop. +Either way, Spring AI will take care of pulling the needed Ollama models when the application starts, +if they are not available yet on your machine. + +## Running the application + +If you're using the native Ollama application, run the application as follows. + +```shell +./gradlew bootRun +``` + +If you want to rely on the native Testcontainers support in Spring Boot to spin up an Ollama service at startup time, +run the application as follows. + +```shell +./gradlew bootTestRun +``` + +## Calling the application + +> [!NOTE] +> These examples use the [httpie](https://httpie.io) CLI to send HTTP requests. + +Call the application that will use an MCP Server to retrieve the context to answer your question. + +```shell +http :8080/chat/mcp question=="Does Spring AI supports a Modular RAG architecture? Please provide some references." +``` diff --git a/model-context-protocol/mcp-brave/build.gradle b/model-context-protocol/mcp-brave/build.gradle new file mode 100644 index 0000000..11b2e4d --- /dev/null +++ b/model-context-protocol/mcp-brave/build.gradle @@ -0,0 +1,41 @@ +plugins { + id 'java' + id 'org.springframework.boot' + id 'io.spring.dependency-management' + id 'org.graalvm.buildtools.native' +} + +group = 'com.thomasvitale' +version = '0.0.1-SNAPSHOT' + +java { + toolchain { + languageVersion = JavaLanguageVersion.of(23) + } +} + +repositories { + mavenCentral() + maven { url 'https://repo.spring.io/milestone' } + maven { url 'https://repo.spring.io/snapshot' } +} + +dependencies { + implementation platform("org.springframework.ai:spring-ai-bom:${springAiVersion}") + + implementation 'org.springframework.boot:spring-boot-starter-web' + implementation "org.springframework.ai:spring-ai-ollama-spring-boot-starter" + + implementation "org.springframework.experimental:spring-ai-mcp:${springAiMcpVersion}" + + testAndDevelopmentOnly 'org.springframework.boot:spring-boot-devtools' + + testImplementation 'org.springframework.boot:spring-boot-starter-test' + testImplementation 'org.springframework.ai:spring-ai-spring-boot-testcontainers' + testImplementation 'org.testcontainers:ollama' + testRuntimeOnly 'org.junit.platform:junit-platform-launcher' +} + +tasks.named('test') { + useJUnitPlatform() +} diff --git a/model-context-protocol/mcp-brave/src/main/java/com/thomasvitale/ai/spring/ChatController.java b/model-context-protocol/mcp-brave/src/main/java/com/thomasvitale/ai/spring/ChatController.java new file mode 100644 index 0000000..3149c7e --- /dev/null +++ b/model-context-protocol/mcp-brave/src/main/java/com/thomasvitale/ai/spring/ChatController.java @@ -0,0 +1,32 @@ +package com.thomasvitale.ai.spring; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.mcp.client.McpSyncClient; +import org.springframework.ai.mcp.spring.McpFunctionCallback; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RestController; + +/** + * Chat examples using the high-level ChatClient API. + */ +@RestController +class ChatController { + + private final ChatClient chatClient; + private final McpSyncClient mcpClient; + + ChatController(ChatClient.Builder chatClientBuilder, McpSyncClient mcpClient) { + this.chatClient = chatClientBuilder.build(); + this.mcpClient = mcpClient; + } + + @GetMapping("/chat/mcp") + String chat(String question) { + return chatClient.prompt() + .user(question) + .functions(McpFunctionCallbackResolver.resolve(mcpClient)) + .call() + .content(); + } + +} diff --git a/model-context-protocol/mcp-brave/src/main/java/com/thomasvitale/ai/spring/McpBraveApplication.java b/model-context-protocol/mcp-brave/src/main/java/com/thomasvitale/ai/spring/McpBraveApplication.java new file mode 100644 index 0000000..5aecb0f --- /dev/null +++ b/model-context-protocol/mcp-brave/src/main/java/com/thomasvitale/ai/spring/McpBraveApplication.java @@ -0,0 +1,37 @@ +package com.thomasvitale.ai.spring; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.mcp.client.McpClient; +import org.springframework.ai.mcp.client.McpSyncClient; +import org.springframework.ai.mcp.client.stdio.ServerParameters; +import org.springframework.ai.mcp.client.stdio.StdioClientTransport; +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.context.annotation.Bean; + +@SpringBootApplication +public class McpBraveApplication { + + private static final Logger logger = LoggerFactory.getLogger(McpBraveApplication.class); + + public static void main(String[] args) { + SpringApplication.run(McpBraveApplication.class, args); + } + + @Bean + public McpSyncClient mcpClient() { + var serverParameters = ServerParameters.builder("npx") + .args("-y", "@modelcontextprotocol/server-brave-search") + .addEnvVar("BRAVE_API_KEY", System.getenv("BRAVE_API_KEY")) + .build(); + + var mcpClient = McpClient.using(new StdioClientTransport(serverParameters)).sync(); + + var initializeResult = mcpClient.initialize(); + logger.info("MCP Initialized: {}", initializeResult); + + return mcpClient; + } + +} diff --git a/model-context-protocol/mcp-brave/src/main/java/com/thomasvitale/ai/spring/McpFunctionCallbackResolver.java b/model-context-protocol/mcp-brave/src/main/java/com/thomasvitale/ai/spring/McpFunctionCallbackResolver.java new file mode 100644 index 0000000..d1d9f95 --- /dev/null +++ b/model-context-protocol/mcp-brave/src/main/java/com/thomasvitale/ai/spring/McpFunctionCallbackResolver.java @@ -0,0 +1,28 @@ +package com.thomasvitale.ai.spring; + +import org.springframework.ai.mcp.client.McpSyncClient; +import org.springframework.ai.mcp.spring.McpFunctionCallback; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.util.Assert; + +import java.util.Arrays; +import java.util.List; + +public final class McpFunctionCallbackResolver { + + public static FunctionCallback[] resolve(McpSyncClient... mcpClients) { + Assert.notNull(mcpClients, "mcpClients cannot be null"); + return resolve(Arrays.asList(mcpClients)); + } + + public static FunctionCallback[] resolve(List mcpClients) { + Assert.notNull(mcpClients, "mcpClients cannot be null"); + Assert.noNullElements(mcpClients, "mcpClients cannot contain null elements"); + + return mcpClients.stream() + .flatMap(mcpClient -> mcpClient.listTools().tools().stream() + .map(tool -> (FunctionCallback) new McpFunctionCallback(mcpClient, tool))) + .toArray(FunctionCallback[]::new); + } + +} diff --git a/model-context-protocol/mcp-brave/src/main/java/com/thomasvitale/ai/spring/model/ChatModelController.java b/model-context-protocol/mcp-brave/src/main/java/com/thomasvitale/ai/spring/model/ChatModelController.java new file mode 100644 index 0000000..30e2bea --- /dev/null +++ b/model-context-protocol/mcp-brave/src/main/java/com/thomasvitale/ai/spring/model/ChatModelController.java @@ -0,0 +1,37 @@ +package com.thomasvitale.ai.spring.model; + +import com.thomasvitale.ai.spring.McpFunctionCallbackResolver; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.mcp.client.McpSyncClient; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +/** + * Chat examples using the low-level ChatModel API. + */ +@RestController +@RequestMapping("/model") +class ChatModelController { + + private final ChatModel chatModel; + private final McpSyncClient mcpClient; + + ChatModelController(ChatModel chatModel, McpSyncClient mcpClient) { + this.chatModel = chatModel; + this.mcpClient = mcpClient; + } + + @GetMapping("/chat/mcp") + String chat(String question) { + var prompt = new Prompt(question, FunctionCallingOptions.builder() + .functionCallbacks(McpFunctionCallbackResolver.resolve(mcpClient)) + .build()); + + var chatResponse = chatModel.call(prompt); + return chatResponse.getResult().getOutput().getText(); + } + +} diff --git a/model-context-protocol/mcp-brave/src/main/resources/application.yml b/model-context-protocol/mcp-brave/src/main/resources/application.yml new file mode 100644 index 0000000..9445adf --- /dev/null +++ b/model-context-protocol/mcp-brave/src/main/resources/application.yml @@ -0,0 +1,11 @@ +spring: + ai: + ollama: + init: + pull-model-strategy: when_missing + embedding: + include: false + chat: + options: + model: llama3.2 + temperature: 0.7 diff --git a/model-context-protocol/mcp-brave/src/test/java/com/thomasvitale/ai/spring/McpBraveApplicationTests.java b/model-context-protocol/mcp-brave/src/test/java/com/thomasvitale/ai/spring/McpBraveApplicationTests.java new file mode 100644 index 0000000..b7e00fd --- /dev/null +++ b/model-context-protocol/mcp-brave/src/test/java/com/thomasvitale/ai/spring/McpBraveApplicationTests.java @@ -0,0 +1,34 @@ +package com.thomasvitale.ai.spring; + +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.autoconfigure.web.reactive.AutoConfigureWebTestClient; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Import; +import org.springframework.test.web.reactive.server.WebTestClient; + +@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) +@AutoConfigureWebTestClient(timeout = "180s") +@Import(TestcontainersConfiguration.class) +@EnabledIfEnvironmentVariable(named = "BRAVE_API_KEY", matches = ".*") +class McpBraveApplicationTests { + + @Autowired + WebTestClient webTestClient; + + @ParameterizedTest + @ValueSource(strings = {"/chat/mcp"}) + void chat(String path) { + webTestClient + .get() + .uri(uriBuilder -> uriBuilder + .path(path) + .queryParam("question", "Does Spring AI supports integrations with Ollama?") + .build()) + .exchange() + .expectStatus().isOk(); + } + +} diff --git a/model-context-protocol/mcp-brave/src/test/java/com/thomasvitale/ai/spring/TestMcpBraveApplication.java b/model-context-protocol/mcp-brave/src/test/java/com/thomasvitale/ai/spring/TestMcpBraveApplication.java new file mode 100644 index 0000000..f537746 --- /dev/null +++ b/model-context-protocol/mcp-brave/src/test/java/com/thomasvitale/ai/spring/TestMcpBraveApplication.java @@ -0,0 +1,11 @@ +package com.thomasvitale.ai.spring; + +import org.springframework.boot.SpringApplication; + +public class TestMcpBraveApplication { + + public static void main(String[] args) { + SpringApplication.from(McpBraveApplication::main).with(TestcontainersConfiguration.class).run(args); + } + +} diff --git a/model-context-protocol/mcp-brave/src/test/java/com/thomasvitale/ai/spring/TestcontainersConfiguration.java b/model-context-protocol/mcp-brave/src/test/java/com/thomasvitale/ai/spring/TestcontainersConfiguration.java new file mode 100644 index 0000000..1db5053 --- /dev/null +++ b/model-context-protocol/mcp-brave/src/test/java/com/thomasvitale/ai/spring/TestcontainersConfiguration.java @@ -0,0 +1,19 @@ +package com.thomasvitale.ai.spring; + +import org.springframework.boot.devtools.restart.RestartScope; +import org.springframework.boot.test.context.TestConfiguration; +import org.springframework.boot.testcontainers.service.connection.ServiceConnection; +import org.springframework.context.annotation.Bean; +import org.testcontainers.ollama.OllamaContainer; + +@TestConfiguration(proxyBeanMethods = false) +class TestcontainersConfiguration { + + @Bean + @RestartScope + @ServiceConnection + OllamaContainer ollama() { + return new OllamaContainer("ollama/ollama").withReuse(true); + } + +} diff --git a/settings.gradle b/settings.gradle index 65e5ca4..61cc221 100644 --- a/settings.gradle +++ b/settings.gradle @@ -14,6 +14,8 @@ include 'data-ingestion:document-readers:document-readers-tika-ollama' include 'data-ingestion:document-transformers:document-transformers-metadata-ollama' include 'data-ingestion:document-transformers:document-transformers-splitters-ollama' +include 'model-context-protocol:mcp-brave' + include 'models:chat:chat-mistral-ai' include 'models:chat:chat-multiple-providers' include 'models:chat:chat-ollama' @@ -61,3 +63,5 @@ include 'use-cases:question-answering' include 'use-cases:semantic-search' include 'use-cases:structured-data-extraction' include 'use-cases:text-classification' + +include 'labs:tools'