Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#30361: allowing stream chat to dotAI #30377

Merged
merged 1 commit into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 51 additions & 6 deletions dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package com.dotcms.ai.client;

import com.dotcms.ai.AiKeys;
import com.dotcms.ai.domain.AIResponse;
import com.dotcms.ai.domain.AIResponseData;

import java.io.OutputStream;
import java.io.Serializable;
Expand All @@ -23,19 +25,62 @@
*/
public interface AIClientStrategy {

AIClientStrategy NOOP = (client, handler, request, output) -> AIResponse.builder().build();
AIClientStrategy NOOP = (client, handler, request, output) -> {
AIResponse.builder().build();
return null;
};

/**
* Applies the strategy to the given AI client request and handles the response.
*
* @param client the AI client to which the request is sent
* @param handler the response evaluator to handle the response
* @param request the AI request to be processed
* @param output the output stream to which the response will be written
* @param incoming the output stream to which the response will be written
* @return response data object
*/
void applyStrategy(AIClient client,
AIResponseEvaluator handler,
AIRequest<? extends Serializable> request,
OutputStream output);
AIResponseData applyStrategy(AIClient client,
AIResponseEvaluator handler,
AIRequest<? extends Serializable> request,
OutputStream incoming);

/**
* Converts the given output stream to an AIResponseData object.
*
* <p>
* This method takes an output stream, converts its content to a string, and
* sets it as the response in an AIResponseData object. The output stream is
* also set in the AIResponseData object.
* </p>
*
* @param output the output stream containing the response data
* @param isStream is stream flag
* @return an AIResponseData object containing the response and the output stream
*/
static AIResponseData response(final OutputStream output, boolean isStream) {
final AIResponseData responseData = new AIResponseData();
if (!isStream) {
responseData.setResponse(output.toString());
}
responseData.setOutput(output);

return responseData;
}

/**
* Checks if the given request is a stream request.
*
* <p>
* This method examines the payload of the provided `JSONObjectAIRequest` to determine
* if it contains a stream flag set to true. If the stream flag is present and set to true,
* the method returns true, indicating that the request is a stream request.
* </p>
*
* @param request the `JSONObjectAIRequest` to be checked
* @return true if the request is a stream request, false otherwise
*/
static boolean isStream(final JSONObjectAIRequest request) {
return request.getPayload().optBoolean(AiKeys.STREAM, false);
}

}
20 changes: 15 additions & 5 deletions dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package com.dotcms.ai.client;

import com.dotcms.ai.domain.AIResponseData;

import java.io.ByteArrayOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.Optional;

/**
* Default implementation of the {@link AIClientStrategy} interface.
Expand All @@ -22,11 +26,17 @@
public class AIDefaultStrategy implements AIClientStrategy {

@Override
public void applyStrategy(final AIClient client,
final AIResponseEvaluator handler,
final AIRequest<? extends Serializable> request,
final OutputStream output) {
client.sendRequest(request, output);
public AIResponseData applyStrategy(final AIClient client,
final AIResponseEvaluator handler,
final AIRequest<? extends Serializable> request,
final OutputStream incoming) {
final JSONObjectAIRequest jsonRequest = AIClient.useRequestOrThrow(request);
final boolean isStream = AIClientStrategy.isStream(jsonRequest);
final OutputStream output = Optional.ofNullable(incoming).orElseGet(ByteArrayOutputStream::new);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not bad


client.sendRequest(jsonRequest, output);

return AIClientStrategy.response(output, isStream);
}

}
122 changes: 56 additions & 66 deletions dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,15 @@
import com.dotcms.ai.domain.Model;
import com.dotcms.ai.exception.DotAIAllModelsExhaustedException;
import com.dotcms.ai.validator.AIAppValidator;
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.util.UtilMethods;
import io.vavr.Tuple;
import io.vavr.Tuple2;
import io.vavr.control.Try;
import org.apache.commons.io.IOUtils;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.Optional;

/**
Expand Down Expand Up @@ -55,23 +50,24 @@ public class AIModelFallbackStrategy implements AIClientStrategy {
* @param client the AI client to which the request is sent
* @param handler the response evaluator to handle the response
* @param request the AI request to be processed
* @param output the output stream to which the response will be written
* @param incoming the output stream to which the response will be written
* @return response data object
* @throws DotAIAllModelsExhaustedException if all models are exhausted and no successful response is obtained
*/
@Override
public void applyStrategy(final AIClient client,
final AIResponseEvaluator handler,
final AIRequest<? extends Serializable> request,
final OutputStream output) {
public AIResponseData applyStrategy(final AIClient client,
final AIResponseEvaluator handler,
final AIRequest<? extends Serializable> request,
final OutputStream incoming) {
final JSONObjectAIRequest jsonRequest = AIClient.useRequestOrThrow(request);
final Tuple2<AIModel, Model> modelTuple = resolveModel(jsonRequest);

final AIResponseData firstAttempt = sendAttempt(client, handler, jsonRequest, output, modelTuple);
final AIResponseData firstAttempt = sendRequest(client, handler, jsonRequest, incoming, modelTuple);
if (firstAttempt.isSuccess()) {
return;
return firstAttempt;
}

runFallbacks(client, handler, jsonRequest, output, modelTuple);
return runFallbacks(client, handler, jsonRequest, incoming, modelTuple);
}

private static Tuple2<AIModel, Model> resolveModel(final JSONObjectAIRequest request) {
Expand All @@ -96,11 +92,7 @@ private static Tuple2<AIModel, Model> resolveModel(final JSONObjectAIRequest req
}

private static boolean isSameAsFirst(final Model firstAttempt, final Model model) {
if (firstAttempt.equals(model)) {
return true;
}

return false;
return firstAttempt.equals(model);
}

private static boolean isOperational(final Model model) {
Expand All @@ -114,31 +106,22 @@ private static boolean isOperational(final Model model) {
return true;
}

private static AIResponseData doSend(final AIClient client, final AIRequest<? extends Serializable> request) {
final ByteArrayOutputStream output = new ByteArrayOutputStream();
private static AIResponseData doSend(final AIClient client,
final JSONObjectAIRequest request,
final OutputStream incoming,
final boolean isStream) {
final OutputStream output = Optional.ofNullable(incoming).orElseGet(ByteArrayOutputStream::new);
client.sendRequest(request, output);

final AIResponseData responseData = new AIResponseData();
responseData.setResponse(output.toString());
IOUtils.closeQuietly(output);

return responseData;
return AIClientStrategy.response(output, isStream);
}

private static void redirectOutput(final OutputStream output, final String response) {
try (final InputStream input = new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8))) {
IOUtils.copy(input, output);
} catch (IOException e) {
throw new DotRuntimeException(e);
}
}

private static void notifyFailure(final AIModel aiModel, final AIRequest<? extends Serializable> request) {
private static void notifyFailure(final AIModel aiModel, final JSONObjectAIRequest request) {
AIAppValidator.get().validateModelsUsage(aiModel, request.getUserId());
}

private static void handleFailure(final Tuple2<AIModel, Model> modelTuple,
final AIRequest<? extends Serializable> request,
final JSONObjectAIRequest request,
final AIResponseData responseData) {
final AIModel aiModel = modelTuple._1;
final Model model = modelTuple._2;
Expand Down Expand Up @@ -170,41 +153,46 @@ private static void handleFailure(final Tuple2<AIModel, Model> modelTuple,
}
}

private static AIResponseData sendAttempt(final AIClient client,
private static AIResponseData sendRequest(final AIClient client,
final AIResponseEvaluator evaluator,
final JSONObjectAIRequest request,
final OutputStream output,
final Tuple2<AIModel, Model> modelTuple) {

final boolean isStream = AIClientStrategy.isStream(request);
final AIResponseData responseData = Try
.of(() -> doSend(client, request))
.of(() -> doSend(client, request, output, isStream))
.getOrElseGet(exception -> fromException(evaluator, exception));

if (!responseData.isSuccess()) {
if (responseData.getStatus().doesNeedToThrow()) {
if (!modelTuple._1.isOperational()) {
AppConfig.debugLogger(
AIModelFallbackStrategy.class,
() -> String.format(
"All models from type [%s] are not operational. Throwing exception.",
modelTuple._1.getType()));
notifyFailure(modelTuple._1, request);
try {
if (!responseData.isSuccess()) {
if (responseData.getStatus().doesNeedToThrow()) {
if (!modelTuple._1.isOperational()) {
AppConfig.debugLogger(
AIModelFallbackStrategy.class,
() -> String.format(
"All models from type [%s] are not operational. Throwing exception.",
modelTuple._1.getType()));
notifyFailure(modelTuple._1, request);
}
throw responseData.getException();
}
throw responseData.getException();
} else {
evaluator.fromResponse(responseData.getResponse(), responseData, !isStream);
}
} else {
evaluator.fromResponse(responseData.getResponse(), responseData, output instanceof ByteArrayOutputStream);
}

if (responseData.isSuccess()) {
AppConfig.debugLogger(
AIModelFallbackStrategy.class,
() -> String.format("Model [%s] succeeded. No need to fallback.", modelTuple._2.getName()));
redirectOutput(output, responseData.getResponse());
} else {
logFailure(modelTuple, responseData);
if (responseData.isSuccess()) {
AppConfig.debugLogger(
AIModelFallbackStrategy.class,
() -> String.format("Model [%s] succeeded. No need to fallback.", modelTuple._2.getName()));
} else {
logFailure(modelTuple, responseData);

handleFailure(modelTuple, request, responseData);
handleFailure(modelTuple, request, responseData);
}
} finally {
if (!isStream) {
IOUtils.closeQuietly(responseData.getOutput());
}
}

return responseData;
Expand Down Expand Up @@ -235,27 +223,29 @@ private static AIResponseData fromException(final AIResponseEvaluator evaluator,
return metadata;
}

private static void runFallbacks(final AIClient client,
final AIResponseEvaluator evaluator,
final JSONObjectAIRequest request,
final OutputStream output,
final Tuple2<AIModel, Model> modelTuple) {
private static AIResponseData runFallbacks(final AIClient client,
final AIResponseEvaluator evaluator,
final JSONObjectAIRequest request,
final OutputStream output,
final Tuple2<AIModel, Model> modelTuple) {
for(final Model model : modelTuple._1.getModels()) {
if (isSameAsFirst(modelTuple._2, model) || !isOperational(model)) {
continue;
}

request.getPayload().put(AiKeys.MODEL, model.getName());
final AIResponseData responseData = sendAttempt(
final AIResponseData responseData = sendRequest(
client,
evaluator,
request,
output,
Tuple.of(modelTuple._1, model));
if (responseData.isSuccess()) {
return;
return responseData;
}
}

return null;
}

}
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package com.dotcms.ai.client;

import com.dotcms.ai.domain.AIResponse;
import com.dotcms.ai.domain.AIResponseData;

import java.io.ByteArrayOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.Optional;
Expand Down Expand Up @@ -73,13 +73,11 @@ public static AIProxiedClient of(final AIClient client, final AIProxyStrategy st
* @return the AI response
*/
public <T extends Serializable> AIResponse sendToAI(final AIRequest<T> request, final OutputStream output) {
final OutputStream finalOutput = Optional.ofNullable(output).orElseGet(ByteArrayOutputStream::new);

strategy.applyStrategy(client, responseEvaluator, request, finalOutput);
final AIResponseData responseData = strategy.applyStrategy(client, responseEvaluator, request, output);

return Optional.ofNullable(output)
.map(out -> AIResponse.EMPTY)
.orElseGet(() -> AIResponse.builder().withResponse(finalOutput.toString()).build());
.orElseGet(() -> AIResponse.builder().withResponse(responseData.getOutput().toString()).build());
}

}
Loading