Skip to content

Commit

Permalink
Refactoring of OpenAIAssistantService
Browse files Browse the repository at this point in the history
  • Loading branch information
jschm42 committed Feb 15, 2024
1 parent 37762cf commit 8e0a8f3
Show file tree
Hide file tree
Showing 5 changed files with 303 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.talkforgeai.service.openai.assistant.dto.ApiError;
import com.talkforgeai.service.openai.assistant.dto.Assistant;
import com.talkforgeai.service.openai.assistant.dto.AssistantList;
import com.talkforgeai.service.openai.assistant.dto.GptModelList;
Expand All @@ -29,7 +28,6 @@
import com.talkforgeai.service.openai.assistant.dto.Run;
import com.talkforgeai.service.openai.assistant.dto.RunConversationRequest;
import com.talkforgeai.service.openai.assistant.dto.Thread;
import com.talkforgeai.service.openai.chat.OpenAIChatService;
import com.talkforgeai.service.openai.exception.OpenAIException;
import com.talkforgeai.service.properties.OpenAIProperties;
import java.io.IOException;
Expand All @@ -43,94 +41,184 @@
import okhttp3.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpMethod;
import org.springframework.stereotype.Service;

/**
* This service class is responsible for interacting with the OpenAI Assistant API. It provides
* methods to create, retrieve, list, modify, and delete assistants. It also provides methods to run
* conversations, create threads, post messages, retrieve and list messages. It interacts with the
* OpenAI Assistant API using the provided OpenAIProperties and OkHttpClient.
*
* @author Jean Schmitz
*/
@Service
public class OpenAIAssistantService {

public static final Logger LOGGER = LoggerFactory.getLogger(OpenAIChatService.class);
public static final Logger LOGGER = LoggerFactory.getLogger(OpenAIAssistantService.class);
public static final MediaType JSON = MediaType.get("application/json; charset=utf-8");
private final OpenAIProperties openAIProperties;
private final OkHttpClient client;
private final ObjectMapper objectMapper = new ObjectMapper();

public OpenAIAssistantService(OpenAIProperties openAIProperties, OkHttpClient client) {
this.openAIProperties = openAIProperties;
this.client = client;
}

/**
* Creates a new assistant with the provided request details.
*
* @param createAssistantRequest the details of the assistant to be created
* @return the created assistant
*/
public Assistant createAssistant(Assistant createAssistantRequest) {
String body = objectToJsonString(createAssistantRequest);
Request request = createPostRequest(body, "/assistants");
return executeRequest(request, Assistant.class);
}

/**
* Retrieves an assistant with the provided assistant ID.
*
* @param assistantId the ID of the assistant to be retrieved
* @return the retrieved assistant
*/
public Assistant retrieveAssistant(String assistantId) {
Request request = createGetRequest("/assistants/" + assistantId);
return executeRequest(request, Assistant.class);
}

/**
* Lists all assistants based on the provided request details.
*
* @param listAssistantsRequest the details of the list request
* @return the list of assistants
*/
public AssistantList listAssistants(ListRequest listAssistantsRequest) {
Request request = createGetRequest("/assistants?" + createListUrlParams(listAssistantsRequest));
return executeRequest(request, AssistantList.class);
}

private String createListUrlParams(ListRequest listRequest) {
List<String> params = new ArrayList<>();

if (listRequest.limit() != null) {
params.add("limit=" + listRequest.limit());
}
if (listRequest.order() != null) {
params.add("order=" + listRequest.order());
}
if (listRequest.before() != null) {
params.add("before=" + listRequest.before());
}
if (listRequest.after() != null) {
params.add("after=" + listRequest.after());
}

return String.join("&", params);
}

/**
* Runs a conversation with the provided thread ID and request details.
*
* @param threadId the ID of the thread
* @param runConversationRequest the details of the run conversation request
* @return the run conversation
*/
public Run runConversation(String threadId, RunConversationRequest runConversationRequest) {
String body = objectToJsonString(runConversationRequest);
Request request = createPostRequest(body, "/threads/" + threadId + "/runs");
return executeRequest(request, Run.class);
}

/**
* Creates a new thread.
*
* @return the created thread
*/
public Thread createThread() {
Request request = createPostRequest("", "/threads");
return executeRequest(request, Thread.class);
}

/**
* Posts a message with the provided thread ID and request details.
*
* @param threadId the ID of the thread
* @param postMessageRequest the details of the post message request
* @return the posted message
*/
public Message postMessage(String threadId, PostMessageRequest postMessageRequest) {
String body = objectToJsonString(postMessageRequest);
Request request = createPostRequest(body, "/threads/" + threadId + "/messages");
return executeRequest(request, Message.class);
}

/**
* Retrieves a run with the provided thread ID and run ID.
*
* @param threadId the ID of the thread
* @param runId the ID of the run
* @return the retrieved run
*/
public Run retrieveRun(String threadId, String runId) {
Request request = createGetRequest("/threads/" + threadId + "/runs/" + runId);
return executeRequest(request, Run.class);
}

/**
* Lists all messages with the provided thread ID and request details.
*
* @param threadId the ID of the thread
* @param listMessagesRequest the details of the list messages request
* @return the list of messages
*/
public MessageList listMessages(String threadId, ListRequest listMessagesRequest) {
Request request = createGetRequest(
"/threads/" + threadId + "/messages?" + createListUrlParams(listMessagesRequest));
return executeRequest(request, MessageList.class);
}

/**
* Retrieves a message with the provided thread ID and message ID.
*
* @param threadId the ID of the thread
* @param messageId the ID of the message
* @return the retrieved message
*/
public Message retrieveMessage(String threadId, String messageId) {
Request request = createGetRequest("/threads/" + threadId + "/messages/" + messageId);
return executeRequest(request, Message.class);
}

/**
* Retrieves all models.
*
* @return the list of models
*/
public GptModelList retrieveModels() {
Request request = createGetRequest("/models");
return executeRequest(request, GptModelList.class);
}


/**
* Modifies an assistant with the provided assistant ID and modified assistant details.
*
* @param assistantId the ID of the assistant to be modified
* @param openAIModifiedAssistant the modified assistant details
* @return the modified assistant
*/
public Assistant modifyAssistant(String assistantId, Assistant openAIModifiedAssistant) {
String body = objectToJsonString(openAIModifiedAssistant);
Request request = createPostRequest(body, "/assistants/" + assistantId);
return executeRequest(request, Assistant.class);
}

/**
* Deletes an assistant with the provided assistant ID.
*
* @param assistantId the ID of the assistant to be deleted
*/
public void deleteAssistant(String assistantId) {
Request request = createDeleteRequest("/assistants/" + assistantId);
executeRequest(request, Void.class);
}

/**
* Cancels a run with the provided thread ID and run ID.
*
* @param threadId the ID of the thread
* @param runId the ID of the run
* @return the cancelled run
*/
public Run cancelRun(String threadId, String runId) {
Request request = createPostRequest("/threads/" + threadId + "/runs/" + runId + "/cancel");
return executeRequest(request, Run.class);
}

private Headers.Builder createDefaultHeaderBuilder() {
Headers.Builder headersBuilder = new Headers.Builder();
headersBuilder.add("Authorization", "Bearer " + openAIProperties.apiKey());
Expand All @@ -139,7 +227,6 @@ private Headers.Builder createDefaultHeaderBuilder() {
}

private String objectToJsonString(Object object) {
ObjectMapper objectMapper = new ObjectMapper();

try {
return objectMapper.writeValueAsString(object);
Expand All @@ -149,42 +236,41 @@ private String objectToJsonString(Object object) {
}

private <T> T executeRequest(Request request, Class<T> clazz) {
ObjectMapper objectMapper = new ObjectMapper();

try (Response response = client.newCall(request).execute()) {
if (response.code() != 200) {
ApiError error = objectMapper.readValue(response.body().string(), ApiError.class);
throw new OpenAIException(
"Request failed with code " + response.code() + " and message " + error.body()
.message(), error.body());
if (!response.isSuccessful()) {
throw new OpenAIException("Request failed with code " + response.code());
}
if (response.body() != null) {
return new ObjectMapper().readValue(response.body().string(), clazz);
} else {
throw new OpenAIException("Response body is null.");
}
return objectMapper.readValue(response.body().string(), clazz);
} catch (IOException e) {
throw new OpenAIException("Message creation failed.", e);
}
}

private Request createPostRequest(String body, String path) {
return createRequest("POST", path, body);
return createRequest(HttpMethod.POST.name(), path, body);
}

private Request createPostRequest(String path) {
return createRequest("POST", path, null);
return createRequest(HttpMethod.POST.name(), path, null);
}

private Request createGetRequest(String path) {
return createRequest("GET", path, null);
return createRequest(HttpMethod.GET.name(), path, null);
}

private Request createDeleteRequest(String path) {
return createRequest("DELETE", path, null);
return createRequest(HttpMethod.DELETE.name(), path, null);
}

private Request createRequest(String method, String path, String body) {
String apiUrl = openAIProperties.apiUrl() + path;

RequestBody requestBody = null;
if (body == null && method.equals("POST")) {
if (body == null && method.equals(HttpMethod.POST.name())) {
requestBody = RequestBody.create("", null);
} else if (body != null) {
requestBody = RequestBody.create(body, JSON);
Expand All @@ -197,19 +283,23 @@ private Request createRequest(String method, String path, String body) {
.build();
}

public Assistant modifyAssistant(String assistantId, Assistant openAIModifiedAssistant) {
String body = objectToJsonString(openAIModifiedAssistant);
Request request = createPostRequest(body, "/assistants/" + assistantId);
return executeRequest(request, Assistant.class);
}
private String createListUrlParams(ListRequest listRequest) {
List<String> params = new ArrayList<>();

public void deleteAssistant(String assistantId) {
Request request = createDeleteRequest("/assistants/" + assistantId);
executeRequest(request, Void.class);
}
if (listRequest.limit() != null) {
params.add("limit=" + listRequest.limit());
}
if (listRequest.order() != null) {
params.add("order=" + listRequest.order());
}
if (listRequest.before() != null) {
params.add("before=" + listRequest.before());
}
if (listRequest.after() != null) {
params.add("after=" + listRequest.after());
}

public Run cancelRun(String threadId, String runId) {
Request request = createPostRequest("/threads/" + threadId + "/runs/" + runId + "/cancel");
return executeRequest(request, Run.class);
return String.join("&", params);
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023 Jean Schmitz.
* Copyright (c) 2023-2024 Jean Schmitz.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -38,6 +38,10 @@ public OpenAIException(String message, Throwable cause) {
super(message, cause);
}

public OpenAIException(String message) {
super(message);
}

@Deprecated
public String getDetail() {
return detail;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023 Jean Schmitz.
* Copyright (c) 2023-2024 Jean Schmitz.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,4 +27,5 @@ public record OpenAIProperties(String apiKey,
String postmanChatUrl,
String postmanRequestId,
boolean usePostman) {

}
Loading

0 comments on commit 8e0a8f3

Please sign in to comment.