Skip to content

Commit

Permalink
#30361: allowing stream chat to dotAI
Browse files Browse the repository at this point in the history
  • Loading branch information
victoralfaro-dotcms committed Oct 24, 2024
1 parent b5017f1 commit b98ee1a
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 91 deletions.
57 changes: 51 additions & 6 deletions dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package com.dotcms.ai.client;

import com.dotcms.ai.AiKeys;
import com.dotcms.ai.domain.AIResponse;
import com.dotcms.ai.domain.AIResponseData;

import java.io.OutputStream;
import java.io.Serializable;
Expand All @@ -23,19 +25,62 @@
*/
public interface AIClientStrategy {

AIClientStrategy NOOP = (client, handler, request, output) -> AIResponse.builder().build();
AIClientStrategy NOOP = (client, handler, request, output) -> {
AIResponse.builder().build();
return null;
};

/**
* Applies the strategy to the given AI client request and handles the response.
*
* @param client the AI client to which the request is sent
* @param handler the response evaluator to handle the response
* @param request the AI request to be processed
* @param output the output stream to which the response will be written
* @param incoming the output stream to which the response will be written
* @return response data object
*/
void applyStrategy(AIClient client,
AIResponseEvaluator handler,
AIRequest<? extends Serializable> request,
OutputStream output);
AIResponseData applyStrategy(AIClient client,
AIResponseEvaluator handler,
AIRequest<? extends Serializable> request,
OutputStream incoming);

/**
* Converts the given output stream to an AIResponseData object.
*
* <p>
* This method takes an output stream, converts its content to a string, and
* sets it as the response in an AIResponseData object. The output stream is
* also set in the AIResponseData object.
* </p>
*
* @param output the output stream containing the response data
* @param isStream is stream flag
* @return an AIResponseData object containing the response and the output stream
*/
static AIResponseData response(final OutputStream output, boolean isStream) {
final AIResponseData responseData = new AIResponseData();
if (!isStream) {
responseData.setResponse(output.toString());
}
responseData.setOutput(output);

return responseData;
}

/**
* Checks if the given request is a stream request.
*
* <p>
* This method examines the payload of the provided `JSONObjectAIRequest` to determine
* if it contains a stream flag set to true. If the stream flag is present and set to true,
* the method returns true, indicating that the request is a stream request.
* </p>
*
* @param request the `JSONObjectAIRequest` to be checked
* @return true if the request is a stream request, false otherwise
*/
static boolean isStream(final JSONObjectAIRequest request) {
return request.getPayload().optBoolean(AiKeys.STREAM, false);
}

}
20 changes: 15 additions & 5 deletions dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package com.dotcms.ai.client;

import com.dotcms.ai.domain.AIResponseData;

import java.io.ByteArrayOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.Optional;

/**
* Default implementation of the {@link AIClientStrategy} interface.
Expand All @@ -22,11 +26,17 @@
public class AIDefaultStrategy implements AIClientStrategy {

@Override
public void applyStrategy(final AIClient client,
final AIResponseEvaluator handler,
final AIRequest<? extends Serializable> request,
final OutputStream output) {
client.sendRequest(request, output);
public AIResponseData applyStrategy(final AIClient client,
final AIResponseEvaluator handler,
final AIRequest<? extends Serializable> request,
final OutputStream incoming) {
final JSONObjectAIRequest jsonRequest = AIClient.useRequestOrThrow(request);
final boolean isStream = AIClientStrategy.isStream(jsonRequest);
final OutputStream output = Optional.ofNullable(incoming).orElseGet(ByteArrayOutputStream::new);

client.sendRequest(jsonRequest, output);

return AIClientStrategy.response(output, isStream);
}

}
122 changes: 56 additions & 66 deletions dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,15 @@
import com.dotcms.ai.domain.Model;
import com.dotcms.ai.exception.DotAIAllModelsExhaustedException;
import com.dotcms.ai.validator.AIAppValidator;
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.util.UtilMethods;
import io.vavr.Tuple;
import io.vavr.Tuple2;
import io.vavr.control.Try;
import org.apache.commons.io.IOUtils;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.Optional;

/**
Expand Down Expand Up @@ -55,23 +50,24 @@ public class AIModelFallbackStrategy implements AIClientStrategy {
* @param client the AI client to which the request is sent
* @param handler the response evaluator to handle the response
* @param request the AI request to be processed
* @param output the output stream to which the response will be written
* @param incoming the output stream to which the response will be written
* @return response data object
* @throws DotAIAllModelsExhaustedException if all models are exhausted and no successful response is obtained
*/
@Override
public void applyStrategy(final AIClient client,
final AIResponseEvaluator handler,
final AIRequest<? extends Serializable> request,
final OutputStream output) {
public AIResponseData applyStrategy(final AIClient client,
final AIResponseEvaluator handler,
final AIRequest<? extends Serializable> request,
final OutputStream incoming) {
final JSONObjectAIRequest jsonRequest = AIClient.useRequestOrThrow(request);
final Tuple2<AIModel, Model> modelTuple = resolveModel(jsonRequest);

final AIResponseData firstAttempt = sendAttempt(client, handler, jsonRequest, output, modelTuple);
final AIResponseData firstAttempt = sendRequest(client, handler, jsonRequest, incoming, modelTuple);
if (firstAttempt.isSuccess()) {
return;
return firstAttempt;
}

runFallbacks(client, handler, jsonRequest, output, modelTuple);
return runFallbacks(client, handler, jsonRequest, incoming, modelTuple);
}

private static Tuple2<AIModel, Model> resolveModel(final JSONObjectAIRequest request) {
Expand All @@ -96,11 +92,7 @@ private static Tuple2<AIModel, Model> resolveModel(final JSONObjectAIRequest req
}

private static boolean isSameAsFirst(final Model firstAttempt, final Model model) {
if (firstAttempt.equals(model)) {
return true;
}

return false;
return firstAttempt.equals(model);
}

private static boolean isOperational(final Model model) {
Expand All @@ -114,31 +106,22 @@ private static boolean isOperational(final Model model) {
return true;
}

private static AIResponseData doSend(final AIClient client, final AIRequest<? extends Serializable> request) {
final ByteArrayOutputStream output = new ByteArrayOutputStream();
private static AIResponseData doSend(final AIClient client,
final JSONObjectAIRequest request,
final OutputStream incoming,
final boolean isStream) {
final OutputStream output = Optional.ofNullable(incoming).orElseGet(ByteArrayOutputStream::new);
client.sendRequest(request, output);

final AIResponseData responseData = new AIResponseData();
responseData.setResponse(output.toString());
IOUtils.closeQuietly(output);

return responseData;
return AIClientStrategy.response(output, isStream);
}

private static void redirectOutput(final OutputStream output, final String response) {
try (final InputStream input = new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8))) {
IOUtils.copy(input, output);
} catch (IOException e) {
throw new DotRuntimeException(e);
}
}

private static void notifyFailure(final AIModel aiModel, final AIRequest<? extends Serializable> request) {
private static void notifyFailure(final AIModel aiModel, final JSONObjectAIRequest request) {
AIAppValidator.get().validateModelsUsage(aiModel, request.getUserId());
}

private static void handleFailure(final Tuple2<AIModel, Model> modelTuple,
final AIRequest<? extends Serializable> request,
final JSONObjectAIRequest request,
final AIResponseData responseData) {
final AIModel aiModel = modelTuple._1;
final Model model = modelTuple._2;
Expand Down Expand Up @@ -170,41 +153,46 @@ private static void handleFailure(final Tuple2<AIModel, Model> modelTuple,
}
}

private static AIResponseData sendAttempt(final AIClient client,
private static AIResponseData sendRequest(final AIClient client,
final AIResponseEvaluator evaluator,
final JSONObjectAIRequest request,
final OutputStream output,
final Tuple2<AIModel, Model> modelTuple) {

final boolean isStream = AIClientStrategy.isStream(request);
final AIResponseData responseData = Try
.of(() -> doSend(client, request))
.of(() -> doSend(client, request, output, isStream))
.getOrElseGet(exception -> fromException(evaluator, exception));

if (!responseData.isSuccess()) {
if (responseData.getStatus().doesNeedToThrow()) {
if (!modelTuple._1.isOperational()) {
AppConfig.debugLogger(
AIModelFallbackStrategy.class,
() -> String.format(
"All models from type [%s] are not operational. Throwing exception.",
modelTuple._1.getType()));
notifyFailure(modelTuple._1, request);
try {
if (!responseData.isSuccess()) {
if (responseData.getStatus().doesNeedToThrow()) {
if (!modelTuple._1.isOperational()) {
AppConfig.debugLogger(
AIModelFallbackStrategy.class,
() -> String.format(
"All models from type [%s] are not operational. Throwing exception.",
modelTuple._1.getType()));
notifyFailure(modelTuple._1, request);
}
throw responseData.getException();
}
throw responseData.getException();
} else {
evaluator.fromResponse(responseData.getResponse(), responseData, !isStream);
}
} else {
evaluator.fromResponse(responseData.getResponse(), responseData, output instanceof ByteArrayOutputStream);
}

if (responseData.isSuccess()) {
AppConfig.debugLogger(
AIModelFallbackStrategy.class,
() -> String.format("Model [%s] succeeded. No need to fallback.", modelTuple._2.getName()));
redirectOutput(output, responseData.getResponse());
} else {
logFailure(modelTuple, responseData);
if (responseData.isSuccess()) {
AppConfig.debugLogger(
AIModelFallbackStrategy.class,
() -> String.format("Model [%s] succeeded. No need to fallback.", modelTuple._2.getName()));
} else {
logFailure(modelTuple, responseData);

handleFailure(modelTuple, request, responseData);
handleFailure(modelTuple, request, responseData);
}
} finally {
if (!isStream) {
IOUtils.closeQuietly(responseData.getOutput());
}
}

return responseData;
Expand Down Expand Up @@ -235,27 +223,29 @@ private static AIResponseData fromException(final AIResponseEvaluator evaluator,
return metadata;
}

private static void runFallbacks(final AIClient client,
final AIResponseEvaluator evaluator,
final JSONObjectAIRequest request,
final OutputStream output,
final Tuple2<AIModel, Model> modelTuple) {
private static AIResponseData runFallbacks(final AIClient client,
final AIResponseEvaluator evaluator,
final JSONObjectAIRequest request,
final OutputStream output,
final Tuple2<AIModel, Model> modelTuple) {
for(final Model model : modelTuple._1.getModels()) {
if (isSameAsFirst(modelTuple._2, model) || !isOperational(model)) {
continue;
}

request.getPayload().put(AiKeys.MODEL, model.getName());
final AIResponseData responseData = sendAttempt(
final AIResponseData responseData = sendRequest(
client,
evaluator,
request,
output,
Tuple.of(modelTuple._1, model));
if (responseData.isSuccess()) {
return;
return responseData;
}
}

return null;
}

}
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package com.dotcms.ai.client;

import com.dotcms.ai.domain.AIResponse;
import com.dotcms.ai.domain.AIResponseData;

import java.io.ByteArrayOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.Optional;
Expand Down Expand Up @@ -73,13 +73,11 @@ public static AIProxiedClient of(final AIClient client, final AIProxyStrategy st
* @return the AI response
*/
public <T extends Serializable> AIResponse sendToAI(final AIRequest<T> request, final OutputStream output) {
final OutputStream finalOutput = Optional.ofNullable(output).orElseGet(ByteArrayOutputStream::new);

strategy.applyStrategy(client, responseEvaluator, request, finalOutput);
final AIResponseData responseData = strategy.applyStrategy(client, responseEvaluator, request, output);

return Optional.ofNullable(output)
.map(out -> AIResponse.EMPTY)
.orElseGet(() -> AIResponse.builder().withResponse(finalOutput.toString()).build());
.orElseGet(() -> AIResponse.builder().withResponse(responseData.getOutput().toString()).build());
}

}
Loading

0 comments on commit b98ee1a

Please sign in to comment.