Skip to content

Commit

Permalink
added getImageResponse to LLM
Browse files Browse the repository at this point in the history
  • Loading branch information
supertick committed Jun 28, 2024
1 parent 3862221 commit 5fb5c4f
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 63 deletions.
239 changes: 176 additions & 63 deletions src/main/java/org/myrobotlab/service/LLM.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import org.myrobotlab.programab.Response;
import org.myrobotlab.service.config.HttpClientConfig;
import org.myrobotlab.service.config.LLMConfig;
import org.myrobotlab.service.data.ImageData;
import org.myrobotlab.service.data.Utterance;
import org.myrobotlab.service.interfaces.ImageListener;
import org.myrobotlab.service.interfaces.ResponsePublisher;
import org.myrobotlab.service.interfaces.TextListener;
import org.myrobotlab.service.interfaces.TextPublisher;
Expand Down Expand Up @@ -55,7 +57,7 @@
*
*/

public class LLM extends Service<LLMConfig> implements TextListener, TextPublisher, UtterancePublisher, UtteranceListener, ResponsePublisher {
public class LLM extends Service<LLMConfig> implements TextListener, TextPublisher, UtterancePublisher, UtteranceListener, ResponsePublisher, ImageListener {

private static final long serialVersionUID = 1L;

Expand Down Expand Up @@ -89,66 +91,92 @@ public void clearInputs() {
inputs.clear();
}

public String createImagePrompt(String model, String prompt, List<String> images) {

if (model == null) {
model = config.model;
}

if (prompt == null) {
prompt = config.defaultImagePrompt;
}

if (images == null || images.size() == 0) {
error("no images in image request");
return null;
}

LinkedHashMap<String, Object> msg = new LinkedHashMap<>();
msg.put("model", model);
msg.put("prompt", prompt);
msg.put("images", images);
msg.put("stream", false);
msg.put("n", 1);

return CodecUtils.toJson(msg);

}

public String createChatCompletionPayload(String model, String systemContent, String userContent, int n, float temperature, int maxTokens) {
try {
// Create the map to hold the request parameters
LinkedHashMap<String, Object> requestPayload = new LinkedHashMap<>();
requestPayload.put("model", model);

// Create and format date and time strings
LocalDateTime currentDateTime = LocalDateTime.now();
DateTimeFormatter dateFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd");
DateTimeFormatter timeFormatter = DateTimeFormatter.ofPattern("h:mm a");
DateTimeFormatter fullDateFormatter = DateTimeFormatter.ofPattern("EEEE MMMM d'th' yyyy h:mm a");

inputs.put("Date", currentDateTime.format(dateFormatter));
inputs.put("Time", currentDateTime.format(timeFormatter));
inputs.put("DateTime", currentDateTime.format(fullDateFormatter));

// Replace placeholders in system content
for (Map.Entry<String, Object> entry : inputs.entrySet()) {
if (entry.getValue() != null) {
systemContent = systemContent.replace(String.format("{{%s}}", entry.getKey()), entry.getValue().toString());
}
// Create the map to hold the request parameters
LinkedHashMap<String, Object> requestPayload = new LinkedHashMap<>();
requestPayload.put("model", model);

// Create and format date and time strings
LocalDateTime currentDateTime = LocalDateTime.now();
DateTimeFormatter dateFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd");
DateTimeFormatter timeFormatter = DateTimeFormatter.ofPattern("h:mm a");
DateTimeFormatter fullDateFormatter = DateTimeFormatter.ofPattern("EEEE MMMM d'th' yyyy h:mm a");

inputs.put("Date", currentDateTime.format(dateFormatter));
inputs.put("Time", currentDateTime.format(timeFormatter));
inputs.put("DateTime", currentDateTime.format(fullDateFormatter));

// Replace placeholders in system content
for (Map.Entry<String, Object> entry : inputs.entrySet()) {
if (entry.getValue() != null) {
systemContent = systemContent.replace(String.format("{{%s}}", entry.getKey()), entry.getValue().toString());
}
}

// Create system message
LinkedHashMap<String, Object> systemMessage = new LinkedHashMap<>();
systemMessage.put("role", "system");
systemMessage.put("content", systemContent);

// Handle message history
LinkedHashMap<String, Object> userMessage = new LinkedHashMap<>();
userMessage.put("role", "user");
userMessage.put("content", userContent);
userMessages.add(userMessage);

if (config.maxHistory > 0) {
while (userMessages.size() > config.maxHistory) {
userMessages.remove(0);
}
} else {
userMessages.clear();
// Create system message
LinkedHashMap<String, Object> systemMessage = new LinkedHashMap<>();
systemMessage.put("role", "system");
systemMessage.put("content", systemContent);

// Handle message history
LinkedHashMap<String, Object> userMessage = new LinkedHashMap<>();
userMessage.put("role", "user");
userMessage.put("content", userContent);
userMessages.add(userMessage);

if (config.maxHistory > 0) {
while (userMessages.size() > config.maxHistory) {
userMessages.remove(0);
}
} else {
userMessages.clear();
}

// Combine messages
List<LinkedHashMap<String, Object>> allMessages = new ArrayList<>();
allMessages.add(systemMessage);
allMessages.addAll(userMessages);
requestPayload.put("messages", allMessages);
// Combine messages
List<LinkedHashMap<String, Object>> allMessages = new ArrayList<>();
allMessages.add(systemMessage);
allMessages.addAll(userMessages);
requestPayload.put("messages", allMessages);

// Add other parameters
requestPayload.put("n", n);
requestPayload.put("temperature", temperature);
requestPayload.put("max_tokens", maxTokens);
// Add other parameters
requestPayload.put("n", n);
requestPayload.put("temperature", temperature);
requestPayload.put("max_tokens", maxTokens);

return CodecUtils.toJson(requestPayload);
return CodecUtils.toJson(requestPayload);

} catch (Exception e) {
error(e);
return null;
error(e);
return null;
}
}
}

public LinkedHashMap<String, Object> createFunctionDefinition(String name, String description, LinkedHashMap<String, Object> parameters) {
LinkedHashMap<String, Object> functionDefinition = new LinkedHashMap<>();
Expand Down Expand Up @@ -258,6 +286,61 @@ public Response getResponse(String text) {
return null;
}

Response getImageResponse(String base64Image) {
return getImageResponse(base64Image, null, null);
}

Response getImageResponse(String base64Image, String prompt, String model) {
try {

if (prompt == null) {
prompt = config.defaultImagePrompt;
}
// String.format("data:image/%s;base64,%s",
invoke("publishImageRequest", new ImageRequest(base64Image, prompt));

List<String> images = new ArrayList<>();
images.add(base64Image);
String json = createImagePrompt(model, prompt, images);

HttpClient<HttpClientConfig> http = (HttpClient) startPeer("http");

// log.info("curl {} -d '{}'", config.url, json);

String msg = http.postJson(config.password, config.url, json);
log.error("url: {}", config.url);

Map<String, Object> payload = CodecUtils.fromJson(msg, new StaticType<>() {
});

Response response = null;

if (payload.get("response") != null) {
String responseText = payload.get("response").toString();
response = new Response("friend", getName(), responseText, null);
Utterance utterance = new Utterance();
utterance.username = getName();
utterance.text = responseText;
utterance.isBot = true;
utterance.channel = currentChannel;
utterance.channelType = currentChannelType;
utterance.channelBotName = currentBotName;
utterance.channelName = currentChannelName;
if (responseText != null && responseText.length() > 0) {
invoke("publishUtterance", utterance);
invoke("publishResponse", response);
invoke("publishText", responseText);
}
}

return response;

} catch (Exception e) {
error(e);
}
return null;
}

/**
* Overridden error to also publish the errors probably would be a better
* solution to self subscribe to errors and have the subscriptions publish
Expand All @@ -279,6 +362,21 @@ public String publishRequest(String text) {
return text;
}

public static class ImageRequest {
public String base64Image;
public String prompt;

public ImageRequest(String base64Image, String prompt) {
this.base64Image = base64Image;
this.prompt = prompt;

}
}

public ImageRequest publishImageRequest(ImageRequest request) {
return request;
}

public void setToken(String password) {
config.password = password;
}
Expand Down Expand Up @@ -355,36 +453,51 @@ public void attach(Attachable attachable) {
log.error("don't know how to attach a {}", attachable.getName());
}
}

public void clearHistory() {
userMessages.clear();
}

@Override
public void onImage(ImageData img) {
getImageResponse(img.src);
}

public static void main(String[] args) {
try {

LoggingFactory.init(Level.INFO);

// Runtime runtime = Runtime.getInstance();
// Runtime.startConfig("gpt3-01");
Runtime.start("llm", "LLM");
Response response = null;
LLM llm = (LLM) Runtime.start("llm", "LLM");
LLM imagellm = (LLM) Runtime.start("imagellm", "LLM");

OpenCV cv = (OpenCV) Runtime.start("cv", "OpenCV");
cv.capture();

WebGui webgui = (WebGui) Runtime.create("webgui", "WebGui");
webgui.autoStartBrowser(false);
webgui.startService();

/*
* Gpt3 i01_chatBot = (Gpt3) Runtime.start("i01.chatBot", "Gpt3");
*
* bot.attach("i01.chatBot"); i01_chatBot.attach("bot");
*
* i01_chatBot.getResponse("hi, how are you?");
*
* Runtime.start("webgui", "WebGui");
*/
// llm.config.url = "http://fast:11434/v1/chat/completions";
// response = llm.getResponse("Hello, why is the sky blue?");
// System.out.println(response.msg);
for (int i = 0; i < 100; ++i) {

while (cv.getBase64Image() == null) {
Service.sleep(1000);
}

String base64Image = cv.getBase64Image();
imagellm.config.url = "http://fast:11434/api/generate";
imagellm.config.model = "bakllava";
response = imagellm.getImageResponse(base64Image);
System.out.println(response.msg);
}

} catch (Exception e) {
log.error("main threw", e);
}
}

}
24 changes: 24 additions & 0 deletions src/main/java/org/myrobotlab/service/OpenCV.java
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,9 @@ public Color getColor() {
}

public BufferedImage getDisplay() {
if (data == null) {
return null;
}
return data.getDisplay();
}

Expand Down Expand Up @@ -1091,6 +1094,22 @@ public String getGrabberType() {
public IplImage getImage() {
return lastImage;
}

public String getBase64Image() {
try {
final ByteArrayOutputStream os = new ByteArrayOutputStream();
String imgType = "jpg";
BufferedImage bi = getDisplay();
if (bi != null) {
ImageIO.write(bi, imgType, os);
os.close();
return String.format(CodecUtils.toBase64(os.toByteArray()));
}
} catch (Exception e) {
error(e);
}
return null;
}

/**
* "Easy" Base64 web image from display last frame
Expand Down Expand Up @@ -1246,6 +1265,11 @@ private void processVideo(OpenCVData data) throws org.bytedeco.javacv.FrameGrabb
// latency use the original ts from before fetch image and the filters
// !
webImage.ts = data.getTs();
// try {
// FileIO.toFile(String.format("image-%s-%d-base64.txt", getName(), frameIndex), webImage.data);
// } catch(Exception e) {
// error(e);
// }
broadcast("publishWebDisplay", webImage);
}

Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/myrobotlab/service/config/LLMConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public class LLMConfig extends ServiceConfig {
public float temperature = 0.7f;
// public String url = "https://api.openai.com/v1/chat/completions";
// http://localhost:11434/v1/chat/completions
// http://localhost:11434/api/generate : for image completion or I'd imagine https://api.openai.com/v1/api/generate
public String url = null;
public String password = null;
public String model = "llama3"; //"gpt-3.5-turbo"; // "text-davinci-003"
Expand All @@ -34,6 +35,8 @@ public class LLMConfig extends ServiceConfig {
*/
public String system = "You are a helpful robot.";

public String defaultImagePrompt = "What is this image?";

@Override
public Plan getDefault(Plan plan, String name) {
super.getDefault(plan, name);
Expand Down
10 changes: 10 additions & 0 deletions src/main/resources/resource/WebGui/app/service/js/LLMGui.js
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ angular.module("mrlapp.service.LLMGui", []).controller("LLMGuiCtrl", [
}
$scope.$apply()
break
case "onImageRequest":
request = { username: "friend", text: data.prompt, img: 'data:image/jpeg;base64,' + data.base64Image }
$scope.utterances.push(request)
// remove the beginning if we are at maxRecords
if ($scope.utterances.length > $scope.maxRecords) {
$scope.utterances.shift()
}
$scope.$apply()
break
case "onEpoch":
$scope.onEpoch = data
$scope.$apply()
Expand Down Expand Up @@ -131,6 +140,7 @@ angular.module("mrlapp.service.LLMGui", []).controller("LLMGuiCtrl", [
}

msg.subscribe("publishRequest")
msg.subscribe("publishImageRequest")
msg.subscribe("publishUtterance")
msg.subscribe(this)
},
Expand Down
Loading

0 comments on commit 5fb5c4f

Please sign in to comment.