Skip to content

Commit

Permalink
Use new Spring AI APIs for chat models
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasVitale committed Jan 24, 2024
1 parent 64a311e commit 80fecb2
Show file tree
Hide file tree
Showing 14 changed files with 79 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class ChatController {

@GetMapping("/ai/chat")
String chat(@RequestParam(defaultValue = "What did Gandalf say to the Balrog?") String message) {
return chatClient.generate(message);
return chatClient.call(message);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class ChatController {

@GetMapping("/ai/chat")
String chat(@RequestParam(defaultValue = "What did Gandalf say to the Balrog?") String message) {
return chatClient.generate(message);
return chatClient.call(message);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ String chatWithText(@RequestBody String input) {

@PostMapping("/ai/chat/prompt")
String chatWithPrompt(@RequestBody String input) {
return chatService.chatWithPrompt(input).getGeneration().getContent();
return chatService.chatWithPrompt(input).getResult().getOutput().getContent();
}

@PostMapping("/ai/chat/full")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.stereotype.Service;

@Service
Expand All @@ -15,11 +15,11 @@ class ChatService {
}

String chatWithText(String message) {
return chatClient.generate(message);
return chatClient.call(message);
}

ChatResponse chatWithPrompt(String message) {
return chatClient.generate(new Prompt(message));
return chatClient.call(new Prompt(message));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ String chatWithText(@RequestBody String input) {

@PostMapping("/ai/chat/prompt")
String chatWithPrompt(@RequestBody String input) {
return chatService.chatWithPrompt(input).getGeneration().getContent();
return chatService.chatWithPrompt(input).getResult().getOutput().getContent();
}

@PostMapping("/ai/chat/full")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.stereotype.Service;

@Service
Expand All @@ -15,11 +15,11 @@ class ChatService {
}

String chatWithText(String message) {
return chatClient.generate(message);
return chatClient.call(message);
}

ChatResponse chatWithPrompt(String message) {
return chatClient.generate(new Prompt(message));
return chatClient.call(new Prompt(message));
}

}
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package com.thomasvitale.ai.spring;

import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.prompt.messages.AssistantMessage;
import org.springframework.ai.prompt.messages.SystemMessage;
import org.springframework.ai.prompt.messages.UserMessage;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;
Expand All @@ -26,8 +26,8 @@ class ChatService {
AssistantMessage chatWithSingleMessage(String message) {
var userMessage = new UserMessage(message);
var prompt = new Prompt(userMessage);
var chatResponse = chatClient.generate(prompt);
return new AssistantMessage(chatResponse.getGeneration().getContent(), chatResponse.getGeneration().getProperties());
var chatResponse = chatClient.call(prompt);
return chatResponse.getResult().getOutput();
}

AssistantMessage chatWithMultipleMessages(String message) {
Expand All @@ -38,16 +38,16 @@ AssistantMessage chatWithMultipleMessages(String message) {
""");
var userMessage = new UserMessage(message);
var prompt = new Prompt(List.of(systemMessage, userMessage));
var chatResponse = chatClient.generate(prompt);
return new AssistantMessage(chatResponse.getGeneration().getContent(), chatResponse.getGeneration().getProperties());
var chatResponse = chatClient.call(prompt);
return chatResponse.getResult().getOutput();
}

AssistantMessage chatWithExternalMessage(String message) {
var systemMessage = new SystemMessage(systemMessageResource);
var userMessage = new UserMessage(message);
var prompt = new Prompt(List.of(systemMessage, userMessage));
var chatResponse = chatClient.generate(prompt);
return new AssistantMessage(chatResponse.getGeneration().getContent(), chatResponse.getGeneration().getProperties());
var chatResponse = chatClient.call(prompt);
return chatResponse.getResult().getOutput();
}

}
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package com.thomasvitale.ai.spring;

import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.prompt.messages.AssistantMessage;
import org.springframework.ai.prompt.messages.SystemMessage;
import org.springframework.ai.prompt.messages.UserMessage;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;
Expand All @@ -26,8 +26,8 @@ class ChatService {
AssistantMessage chatWithSingleMessage(String message) {
var userMessage = new UserMessage(message);
var prompt = new Prompt(userMessage);
var chatResponse = chatClient.generate(prompt);
return new AssistantMessage(chatResponse.getGeneration().getContent(), chatResponse.getGeneration().getProperties());
var chatResponse = chatClient.call(prompt);
return chatResponse.getResult().getOutput();
}

AssistantMessage chatWithMultipleMessages(String message) {
Expand All @@ -38,16 +38,16 @@ AssistantMessage chatWithMultipleMessages(String message) {
""");
var userMessage = new UserMessage(message);
var prompt = new Prompt(List.of(systemMessage, userMessage));
var chatResponse = chatClient.generate(prompt);
return new AssistantMessage(chatResponse.getGeneration().getContent(), chatResponse.getGeneration().getProperties());
var chatResponse = chatClient.call(prompt);
return chatResponse.getResult().getOutput();
}

AssistantMessage chatWithExternalMessage(String message) {
var systemMessage = new SystemMessage(systemMessageResource);
var userMessage = new UserMessage(message);
var prompt = new Prompt(List.of(systemMessage, userMessage));
var chatResponse = chatClient.generate(prompt);
return new AssistantMessage(chatResponse.getGeneration().getContent(), chatResponse.getGeneration().getProperties());
var chatResponse = chatClient.call(prompt);
return chatResponse.getResult().getOutput();
}

}
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package com.thomasvitale.ai.spring;

import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.prompt.PromptTemplate;
import org.springframework.ai.prompt.SystemPromptTemplate;
import org.springframework.ai.prompt.messages.AssistantMessage;
import org.springframework.ai.prompt.messages.UserMessage;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;
Expand Down Expand Up @@ -35,8 +35,8 @@ AssistantMessage chatWithUserMessageTemplate(MusicQuestion question) {
var userMessage = userPromptTemplate.createMessage(model);

var prompt = new Prompt(userMessage);
var chatResponse = chatClient.generate(prompt);
return new AssistantMessage(chatResponse.getGeneration().getContent(), chatResponse.getGeneration().getProperties());
var chatResponse = chatClient.call(prompt);
return chatResponse.getResult().getOutput();
}

AssistantMessage chatWithSystemMessageTemplate(String message) {
Expand All @@ -50,8 +50,8 @@ AssistantMessage chatWithSystemMessageTemplate(String message) {

var prompt = new Prompt(List.of(systemMessage, userMessage));

var chatResponse = chatClient.generate(prompt);
return new AssistantMessage(chatResponse.getGeneration().getContent(), chatResponse.getGeneration().getProperties());
var chatResponse = chatClient.call(prompt);
return chatResponse.getResult().getOutput();
}

AssistantMessage chatWithSystemMessageTemplateExternal(String message) {
Expand All @@ -63,8 +63,8 @@ AssistantMessage chatWithSystemMessageTemplateExternal(String message) {

var prompt = new Prompt(List.of(systemMessage, userMessage));

var chatResponse = chatClient.generate(prompt);
return new AssistantMessage(chatResponse.getGeneration().getContent(), chatResponse.getGeneration().getProperties());
var chatResponse = chatClient.call(prompt);
return chatResponse.getResult().getOutput();
}

private String randomGreeting() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package com.thomasvitale.ai.spring;

import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.prompt.PromptTemplate;
import org.springframework.ai.prompt.SystemPromptTemplate;
import org.springframework.ai.prompt.messages.AssistantMessage;
import org.springframework.ai.prompt.messages.UserMessage;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;
Expand Down Expand Up @@ -35,8 +35,8 @@ AssistantMessage chatWithUserMessageTemplate(MusicQuestion question) {
var userMessage = userPromptTemplate.createMessage(model);

var prompt = new Prompt(userMessage);
var chatResponse = chatClient.generate(prompt);
return new AssistantMessage(chatResponse.getGeneration().getContent(), chatResponse.getGeneration().getProperties());
var chatResponse = chatClient.call(prompt);
return chatResponse.getResult().getOutput();
}

AssistantMessage chatWithSystemMessageTemplate(String message) {
Expand All @@ -50,8 +50,8 @@ AssistantMessage chatWithSystemMessageTemplate(String message) {

var prompt = new Prompt(List.of(systemMessage, userMessage));

var chatResponse = chatClient.generate(prompt);
return new AssistantMessage(chatResponse.getGeneration().getContent(), chatResponse.getGeneration().getProperties());
var chatResponse = chatClient.call(prompt);
return chatResponse.getResult().getOutput();
}

AssistantMessage chatWithSystemMessageTemplateExternal(String message) {
Expand All @@ -63,8 +63,8 @@ AssistantMessage chatWithSystemMessageTemplateExternal(String message) {

var prompt = new Prompt(List.of(systemMessage, userMessage));

var chatResponse = chatClient.generate(prompt);
return new AssistantMessage(chatResponse.getGeneration().getContent(), chatResponse.getGeneration().getProperties());
var chatResponse = chatClient.call(prompt);
return chatResponse.getResult().getOutput();
}

private String randomGreeting() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package com.thomasvitale.ai.spring;

import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.parser.BeanOutputParser;
import org.springframework.ai.parser.ListOutputParser;
import org.springframework.ai.parser.MapOutputParser;
import org.springframework.ai.prompt.PromptTemplate;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.stereotype.Service;

Expand All @@ -31,8 +31,8 @@ ArtistInfo chatWithBeanOutput(MusicQuestion question) {
Map<String,Object> model = Map.of("instrument", question.instrument(), "genre", question.genre(), "format", outputParser.getFormat());
var prompt = userPromptTemplate.create(model);

var chatResponse = chatClient.generate(prompt);
return outputParser.parse(chatResponse.getGeneration().getContent());
var chatResponse = chatClient.call(prompt);
return outputParser.parse(chatResponse.getResult().getOutput().getContent());
}

Map<String,Object> chatWithMapOutput(MusicQuestion question) {
Expand All @@ -46,8 +46,8 @@ Map<String,Object> chatWithMapOutput(MusicQuestion question) {
Map<String,Object> model = Map.of("instrument", question.instrument(), "genre", question.genre(), "format", outputParser.getFormat());
var prompt = userPromptTemplate.create(model);

var chatResponse = chatClient.generate(prompt);
return outputParser.parse(chatResponse.getGeneration().getContent());
var chatResponse = chatClient.call(prompt);
return outputParser.parse(chatResponse.getResult().getOutput().getContent());
}

List<String> chatWithListOutput(MusicQuestion question) {
Expand All @@ -61,8 +61,8 @@ List<String> chatWithListOutput(MusicQuestion question) {
Map<String,Object> model = Map.of("instrument", question.instrument(), "genre", question.genre(), "format", outputParser.getFormat());
var prompt = userPromptTemplate.create(model);

var chatResponse = chatClient.generate(prompt);
return outputParser.parse(chatResponse.getGeneration().getContent());
var chatResponse = chatClient.call(prompt);
return outputParser.parse(chatResponse.getResult().getOutput().getContent());
}

}
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package com.thomasvitale.ai.spring;

import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.parser.BeanOutputParser;
import org.springframework.ai.parser.ListOutputParser;
import org.springframework.ai.parser.MapOutputParser;
import org.springframework.ai.prompt.PromptTemplate;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.stereotype.Service;

Expand All @@ -31,8 +31,8 @@ ArtistInfo chatWithBeanOutput(MusicQuestion question) {
Map<String,Object> model = Map.of("instrument", question.instrument(), "genre", question.genre(), "format", outputParser.getFormat());
var prompt = userPromptTemplate.create(model);

var chatResponse = chatClient.generate(prompt);
return outputParser.parse(chatResponse.getGeneration().getContent());
var chatResponse = chatClient.call(prompt);
return outputParser.parse(chatResponse.getResult().getOutput().getContent());
}

Map<String,Object> chatWithMapOutput(MusicQuestion question) {
Expand All @@ -46,8 +46,8 @@ Map<String,Object> chatWithMapOutput(MusicQuestion question) {
Map<String,Object> model = Map.of("instrument", question.instrument(), "genre", question.genre(), "format", outputParser.getFormat());
var prompt = userPromptTemplate.create(model);

var chatResponse = chatClient.generate(prompt);
return outputParser.parse(chatResponse.getGeneration().getContent());
var chatResponse = chatClient.call(prompt);
return outputParser.parse(chatResponse.getResult().getOutput().getContent());
}

List<String> chatWithListOutput(MusicQuestion question) {
Expand All @@ -61,8 +61,8 @@ List<String> chatWithListOutput(MusicQuestion question) {
Map<String,Object> model = Map.of("instrument", question.instrument(), "genre", question.genre(), "format", outputParser.getFormat());
var prompt = userPromptTemplate.create(model);

var chatResponse = chatClient.generate(prompt);
return outputParser.parse(chatResponse.getGeneration().getContent());
var chatResponse = chatClient.call(prompt);
return outputParser.parse(chatResponse.getResult().getOutput().getContent());
}

}
Loading

0 comments on commit 80fecb2

Please sign in to comment.