From b98ee1a225edadad312b1edc66abe346d9ba66e9 Mon Sep 17 00:00:00 2001 From: Victor Alfaro Date: Wed, 16 Oct 2024 20:46:17 -0600 Subject: [PATCH] #30361: allowing stream chat to dotAI --- .../dotcms/ai/client/AIClientStrategy.java | 57 +++++++- .../dotcms/ai/client/AIDefaultStrategy.java | 20 ++- .../ai/client/AIModelFallbackStrategy.java | 122 ++++++++---------- .../com/dotcms/ai/client/AIProxiedClient.java | 8 +- .../dotcms/ai/client/openai/OpenAIClient.java | 23 +++- .../java/com/dotcms/ai/domain/AIResponse.java | 1 - .../com/dotcms/ai/domain/AIResponseData.java | 11 ++ .../dotcms/ai/rest/CompletionsResource.java | 11 +- .../dotcms/ai/client/AIProxyClientTest.java | 1 + 9 files changed, 163 insertions(+), 91 deletions(-) diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java index 6ac784ef2a2a..371852569eb3 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java @@ -1,6 +1,8 @@ package com.dotcms.ai.client; +import com.dotcms.ai.AiKeys; import com.dotcms.ai.domain.AIResponse; +import com.dotcms.ai.domain.AIResponseData; import java.io.OutputStream; import java.io.Serializable; @@ -23,7 +25,10 @@ */ public interface AIClientStrategy { - AIClientStrategy NOOP = (client, handler, request, output) -> AIResponse.builder().build(); + AIClientStrategy NOOP = (client, handler, request, output) -> { + AIResponse.builder().build(); + return null; + }; /** * Applies the strategy to the given AI client request and handles the response. @@ -31,11 +36,51 @@ public interface AIClientStrategy { * @param client the AI client to which the request is sent * @param handler the response evaluator to handle the response * @param request the AI request to be processed - * @param output the output stream to which the response will be written + * @param incoming the output stream to which the response will be written + * @return response data object */ - void applyStrategy(AIClient client, - AIResponseEvaluator handler, - AIRequest request, - OutputStream output); + AIResponseData applyStrategy(AIClient client, + AIResponseEvaluator handler, + AIRequest request, + OutputStream incoming); + + /** + * Converts the given output stream to an AIResponseData object. + * + *

+ * This method takes an output stream, converts its content to a string, and + * sets it as the response in an AIResponseData object. The output stream is + * also set in the AIResponseData object. + *

+ * + * @param output the output stream containing the response data + * @param isStream is stream flag + * @return an AIResponseData object containing the response and the output stream + */ + static AIResponseData response(final OutputStream output, boolean isStream) { + final AIResponseData responseData = new AIResponseData(); + if (!isStream) { + responseData.setResponse(output.toString()); + } + responseData.setOutput(output); + + return responseData; + } + + /** + * Checks if the given request is a stream request. + * + *

+ * This method examines the payload of the provided `JSONObjectAIRequest` to determine + * if it contains a stream flag set to true. If the stream flag is present and set to true, + * the method returns true, indicating that the request is a stream request. + *

+ * + * @param request the `JSONObjectAIRequest` to be checked + * @return true if the request is a stream request, false otherwise + */ + static boolean isStream(final JSONObjectAIRequest request) { + return request.getPayload().optBoolean(AiKeys.STREAM, false); + } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java index 02149d98a7b1..37974fa9c3c6 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java @@ -1,7 +1,11 @@ package com.dotcms.ai.client; +import com.dotcms.ai.domain.AIResponseData; + +import java.io.ByteArrayOutputStream; import java.io.OutputStream; import java.io.Serializable; +import java.util.Optional; /** * Default implementation of the {@link AIClientStrategy} interface. @@ -22,11 +26,17 @@ public class AIDefaultStrategy implements AIClientStrategy { @Override - public void applyStrategy(final AIClient client, - final AIResponseEvaluator handler, - final AIRequest request, - final OutputStream output) { - client.sendRequest(request, output); + public AIResponseData applyStrategy(final AIClient client, + final AIResponseEvaluator handler, + final AIRequest request, + final OutputStream incoming) { + final JSONObjectAIRequest jsonRequest = AIClient.useRequestOrThrow(request); + final boolean isStream = AIClientStrategy.isStream(jsonRequest); + final OutputStream output = Optional.ofNullable(incoming).orElseGet(ByteArrayOutputStream::new); + + client.sendRequest(jsonRequest, output); + + return AIClientStrategy.response(output, isStream); } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java index 0553645ece58..c763a50c43ed 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java @@ -8,20 +8,15 @@ import com.dotcms.ai.domain.Model; import com.dotcms.ai.exception.DotAIAllModelsExhaustedException; import com.dotcms.ai.validator.AIAppValidator; -import com.dotmarketing.exception.DotRuntimeException; import com.dotmarketing.util.UtilMethods; import io.vavr.Tuple; import io.vavr.Tuple2; import io.vavr.control.Try; import org.apache.commons.io.IOUtils; -import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStream; import java.io.OutputStream; import java.io.Serializable; -import java.nio.charset.StandardCharsets; import java.util.Optional; /** @@ -55,23 +50,24 @@ public class AIModelFallbackStrategy implements AIClientStrategy { * @param client the AI client to which the request is sent * @param handler the response evaluator to handle the response * @param request the AI request to be processed - * @param output the output stream to which the response will be written + * @param incoming the output stream to which the response will be written + * @return response data object * @throws DotAIAllModelsExhaustedException if all models are exhausted and no successful response is obtained */ @Override - public void applyStrategy(final AIClient client, - final AIResponseEvaluator handler, - final AIRequest request, - final OutputStream output) { + public AIResponseData applyStrategy(final AIClient client, + final AIResponseEvaluator handler, + final AIRequest request, + final OutputStream incoming) { final JSONObjectAIRequest jsonRequest = AIClient.useRequestOrThrow(request); final Tuple2 modelTuple = resolveModel(jsonRequest); - final AIResponseData firstAttempt = sendAttempt(client, handler, jsonRequest, output, modelTuple); + final AIResponseData firstAttempt = sendRequest(client, handler, jsonRequest, incoming, modelTuple); if (firstAttempt.isSuccess()) { - return; + return firstAttempt; } - runFallbacks(client, handler, jsonRequest, output, modelTuple); + return runFallbacks(client, handler, jsonRequest, incoming, modelTuple); } private static Tuple2 resolveModel(final JSONObjectAIRequest request) { @@ -96,11 +92,7 @@ private static Tuple2 resolveModel(final JSONObjectAIRequest req } private static boolean isSameAsFirst(final Model firstAttempt, final Model model) { - if (firstAttempt.equals(model)) { - return true; - } - - return false; + return firstAttempt.equals(model); } private static boolean isOperational(final Model model) { @@ -114,31 +106,22 @@ private static boolean isOperational(final Model model) { return true; } - private static AIResponseData doSend(final AIClient client, final AIRequest request) { - final ByteArrayOutputStream output = new ByteArrayOutputStream(); + private static AIResponseData doSend(final AIClient client, + final JSONObjectAIRequest request, + final OutputStream incoming, + final boolean isStream) { + final OutputStream output = Optional.ofNullable(incoming).orElseGet(ByteArrayOutputStream::new); client.sendRequest(request, output); - final AIResponseData responseData = new AIResponseData(); - responseData.setResponse(output.toString()); - IOUtils.closeQuietly(output); - - return responseData; + return AIClientStrategy.response(output, isStream); } - private static void redirectOutput(final OutputStream output, final String response) { - try (final InputStream input = new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8))) { - IOUtils.copy(input, output); - } catch (IOException e) { - throw new DotRuntimeException(e); - } - } - - private static void notifyFailure(final AIModel aiModel, final AIRequest request) { + private static void notifyFailure(final AIModel aiModel, final JSONObjectAIRequest request) { AIAppValidator.get().validateModelsUsage(aiModel, request.getUserId()); } private static void handleFailure(final Tuple2 modelTuple, - final AIRequest request, + final JSONObjectAIRequest request, final AIResponseData responseData) { final AIModel aiModel = modelTuple._1; final Model model = modelTuple._2; @@ -170,41 +153,46 @@ private static void handleFailure(final Tuple2 modelTuple, } } - private static AIResponseData sendAttempt(final AIClient client, + private static AIResponseData sendRequest(final AIClient client, final AIResponseEvaluator evaluator, final JSONObjectAIRequest request, final OutputStream output, final Tuple2 modelTuple) { - + final boolean isStream = AIClientStrategy.isStream(request); final AIResponseData responseData = Try - .of(() -> doSend(client, request)) + .of(() -> doSend(client, request, output, isStream)) .getOrElseGet(exception -> fromException(evaluator, exception)); - if (!responseData.isSuccess()) { - if (responseData.getStatus().doesNeedToThrow()) { - if (!modelTuple._1.isOperational()) { - AppConfig.debugLogger( - AIModelFallbackStrategy.class, - () -> String.format( - "All models from type [%s] are not operational. Throwing exception.", - modelTuple._1.getType())); - notifyFailure(modelTuple._1, request); + try { + if (!responseData.isSuccess()) { + if (responseData.getStatus().doesNeedToThrow()) { + if (!modelTuple._1.isOperational()) { + AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format( + "All models from type [%s] are not operational. Throwing exception.", + modelTuple._1.getType())); + notifyFailure(modelTuple._1, request); + } + throw responseData.getException(); } - throw responseData.getException(); + } else { + evaluator.fromResponse(responseData.getResponse(), responseData, !isStream); } - } else { - evaluator.fromResponse(responseData.getResponse(), responseData, output instanceof ByteArrayOutputStream); - } - if (responseData.isSuccess()) { - AppConfig.debugLogger( - AIModelFallbackStrategy.class, - () -> String.format("Model [%s] succeeded. No need to fallback.", modelTuple._2.getName())); - redirectOutput(output, responseData.getResponse()); - } else { - logFailure(modelTuple, responseData); + if (responseData.isSuccess()) { + AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format("Model [%s] succeeded. No need to fallback.", modelTuple._2.getName())); + } else { + logFailure(modelTuple, responseData); - handleFailure(modelTuple, request, responseData); + handleFailure(modelTuple, request, responseData); + } + } finally { + if (!isStream) { + IOUtils.closeQuietly(responseData.getOutput()); + } } return responseData; @@ -235,27 +223,29 @@ private static AIResponseData fromException(final AIResponseEvaluator evaluator, return metadata; } - private static void runFallbacks(final AIClient client, - final AIResponseEvaluator evaluator, - final JSONObjectAIRequest request, - final OutputStream output, - final Tuple2 modelTuple) { + private static AIResponseData runFallbacks(final AIClient client, + final AIResponseEvaluator evaluator, + final JSONObjectAIRequest request, + final OutputStream output, + final Tuple2 modelTuple) { for(final Model model : modelTuple._1.getModels()) { if (isSameAsFirst(modelTuple._2, model) || !isOperational(model)) { continue; } request.getPayload().put(AiKeys.MODEL, model.getName()); - final AIResponseData responseData = sendAttempt( + final AIResponseData responseData = sendRequest( client, evaluator, request, output, Tuple.of(modelTuple._1, model)); if (responseData.isSuccess()) { - return; + return responseData; } } + + return null; } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java index 73d675a3b90e..151fe44e2aaa 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java @@ -1,8 +1,8 @@ package com.dotcms.ai.client; import com.dotcms.ai.domain.AIResponse; +import com.dotcms.ai.domain.AIResponseData; -import java.io.ByteArrayOutputStream; import java.io.OutputStream; import java.io.Serializable; import java.util.Optional; @@ -73,13 +73,11 @@ public static AIProxiedClient of(final AIClient client, final AIProxyStrategy st * @return the AI response */ public AIResponse sendToAI(final AIRequest request, final OutputStream output) { - final OutputStream finalOutput = Optional.ofNullable(output).orElseGet(ByteArrayOutputStream::new); - - strategy.applyStrategy(client, responseEvaluator, request, finalOutput); + final AIResponseData responseData = strategy.applyStrategy(client, responseEvaluator, request, output); return Optional.ofNullable(output) .map(out -> AIResponse.EMPTY) - .orElseGet(() -> AIResponse.builder().withResponse(finalOutput.toString()).build()); + .orElseGet(() -> AIResponse.builder().withResponse(responseData.getOutput().toString()).build()); } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java index ab12dbba58f3..5bb158ae8972 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java @@ -16,7 +16,6 @@ import com.dotmarketing.util.Logger; import com.dotmarketing.util.json.JSONObject; import io.vavr.Lazy; -import io.vavr.Tuple; import io.vavr.Tuple2; import io.vavr.control.Try; import org.apache.http.HttpHeaders; @@ -29,6 +28,7 @@ import org.apache.http.impl.client.HttpClients; import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; import java.io.BufferedInputStream; import java.io.OutputStream; import java.io.Serializable; @@ -129,17 +129,19 @@ public void sendRequest(final AIRequest request, fin lastRestCall.put(aiModel, System.currentTimeMillis()); - try (CloseableHttpClient httpClient = HttpClients.createDefault()) { + try (final CloseableHttpClient httpClient = HttpClients.createDefault()) { final StringEntity jsonEntity = new StringEntity(payload.toString(), ContentType.APPLICATION_JSON); final HttpUriRequest httpRequest = AIClient.resolveMethod(jsonRequest.getMethod(), jsonRequest.getUrl()); httpRequest.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON); httpRequest.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + appConfig.getApiKey()); if (!payload.getAsMap().isEmpty()) { - Try.run(() -> HttpEntityEnclosingRequestBase.class.cast(httpRequest).setEntity(jsonEntity)); + Try.run(() -> ((HttpEntityEnclosingRequestBase) httpRequest).setEntity(jsonEntity)); } - try (CloseableHttpResponse response = httpClient.execute(httpRequest)) { + try (final CloseableHttpResponse response = httpClient.execute(httpRequest)) { + onStreamCheckFotStatusCode(modelName, payload, response); + final BufferedInputStream in = new BufferedInputStream(response.getEntity().getContent()); final byte[] buffer = new byte[1024]; int len; @@ -161,4 +163,17 @@ public void sendRequest(final AIRequest request, fin } } + private static void onStreamCheckFotStatusCode(final String modelName, + final JSONObject payload, + final CloseableHttpResponse response) { + if (payload.optBoolean(AiKeys.STREAM, false)) { + final int statusCode = response.getStatusLine().getStatusCode(); + if (Response.Status.Family.familyOf(statusCode) == Response.Status.Family.CLIENT_ERROR) { + throw new DotAIModelNotFoundException(String.format( + "Model used [%s] in request in stream mode is not found", + modelName)); + } + } + } + } diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java index 8d9887b24571..dff8cacca25d 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java @@ -41,7 +41,6 @@ public Builder withResponse(final String response) { return this; } - public AIResponse build() { return new AIResponse(this); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseData.java b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseData.java index 85ac2d9d0483..b8c58b17eae6 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseData.java +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseData.java @@ -3,6 +3,8 @@ import com.dotmarketing.exception.DotRuntimeException; import org.apache.commons.lang3.StringUtils; +import java.io.OutputStream; + /** * Represents the data of a response from an AI service. * @@ -20,6 +22,7 @@ public class AIResponseData { private String error; private ModelStatus status; private DotRuntimeException exception; + private OutputStream output; public String getResponse() { return response; @@ -53,6 +56,14 @@ public void setException(DotRuntimeException exception) { this.exception = exception; } + public OutputStream getOutput() { + return output; + } + + public void setOutput(OutputStream output) { + this.output = output; + } + public boolean isSuccess() { return StringUtils.isBlank(error); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java b/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java index 5499de4ce660..0351ec0bb4c8 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java @@ -56,14 +56,15 @@ public class CompletionsResource { public final Response summarizeFromContent(@Context final HttpServletRequest request, @Context final HttpServletResponse response, final CompletionsForm formIn) { + final CompletionsForm resolvedForm = resolveForm(request, response, formIn); return getResponse( request, response, formIn, - () -> APILocator.getDotAIAPI().getCompletionsAPI().summarize(formIn), + () -> APILocator.getDotAIAPI().getCompletionsAPI().summarize(resolvedForm), output -> APILocator.getDotAIAPI() .getCompletionsAPI() - .summarizeStream(formIn, new LineReadingOutputStream(output))); + .summarizeStream(resolvedForm, new LineReadingOutputStream(output))); } /** @@ -81,14 +82,15 @@ public final Response summarizeFromContent(@Context final HttpServletRequest req public final Response rawPrompt(@Context final HttpServletRequest request, @Context final HttpServletResponse response, final CompletionsForm formIn) { + final CompletionsForm resolvedForm = resolveForm(request, response, formIn); return getResponse( request, response, formIn, - () -> APILocator.getDotAIAPI().getCompletionsAPI().raw(formIn), + () -> APILocator.getDotAIAPI().getCompletionsAPI().raw(resolvedForm), output -> APILocator.getDotAIAPI() .getCompletionsAPI() - .rawStream(formIn, new LineReadingOutputStream(output))); + .rawStream(resolvedForm, new LineReadingOutputStream(output))); } /** @@ -180,6 +182,7 @@ private static Response getResponse(final HttpServletRequest request, final JSONObject jsonResponse = noStream.get(); jsonResponse.put(AiKeys.TOTAL_TIME, System.currentTimeMillis() - startTime + "ms"); + return Response.ok(jsonResponse.toString(), MediaType.APPLICATION_JSON).build(); } diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/client/AIProxyClientTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/client/AIProxyClientTest.java index af100d725f2f..5fa1d0f05213 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/client/AIProxyClientTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/client/AIProxyClientTest.java @@ -238,6 +238,7 @@ public void test_callToAI_withProvidedOutput() throws Exception { final JSONObjectAIRequest request = textRequest( model, "What are the major achievements of the Apollo space program?"); + request.getPayload().put(AiKeys.STREAM, true); final AIResponse aiResponse = aiProxyClient.callToAI( request,