diff --git a/cli/src/main/java/com/box/l10n/mojito/cli/command/RepositoryAiTranslationCommand.java b/cli/src/main/java/com/box/l10n/mojito/cli/command/RepositoryAiTranslationCommand.java index d4b431c934..7986d5ab84 100644 --- a/cli/src/main/java/com/box/l10n/mojito/cli/command/RepositoryAiTranslationCommand.java +++ b/cli/src/main/java/com/box/l10n/mojito/cli/command/RepositoryAiTranslationCommand.java @@ -43,8 +43,8 @@ public class RepositoryAiTranslationCommand extends Command { @Parameter( names = {Param.REPOSITORY_LOCALES_LONG, Param.REPOSITORY_LOCALES_SHORT}, variableArity = true, - required = true, - description = "List of locales (bcp47 tags) to machine translate") + description = + "List of locales (bcp47 tags) to translate, if not provided translate all locales in the repository") List locales; @Parameter( @@ -55,6 +55,12 @@ public class RepositoryAiTranslationCommand extends Command { + "sending too many strings to MT)") int sourceTextMaxCount = 100; + @Parameter( + names = {"--use-batch"}, + arity = 1, + description = "To use the batch API or not") + boolean useBatch = false; + @Autowired CommandHelper commandHelper; @Autowired RepositoryAiTranslateClient repositoryAiTranslateClient; @@ -75,13 +81,13 @@ public void execute() throws CommandException { .reset() .a(" for locales: ") .fg(Color.CYAN) - .a(locales.stream().collect(Collectors.joining(", ", "[", "]"))) + .a(locales == null ? "" : locales.stream().collect(Collectors.joining(", ", "[", "]"))) .println(2); ProtoAiTranslateResponse protoAiTranslateResponse = repositoryAiTranslateClient.translateRepository( new RepositoryAiTranslateClient.ProtoAiTranslateRequest( - repositoryParam, locales, sourceTextMaxCount)); + repositoryParam, locales, sourceTextMaxCount, useBatch)); PollableTask pollableTask = protoAiTranslateResponse.pollableTask(); commandHelper.waitForPollableTask(pollableTask.getId()); diff --git a/common/src/main/java/com/box/l10n/mojito/openai/OpenAIClient.java b/common/src/main/java/com/box/l10n/mojito/openai/OpenAIClient.java index a804929bb4..35290cd880 100644 --- a/common/src/main/java/com/box/l10n/mojito/openai/OpenAIClient.java +++ b/common/src/main/java/com/box/l10n/mojito/openai/OpenAIClient.java @@ -25,6 +25,8 @@ import java.util.Objects; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.ForkJoinPool; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -39,11 +41,19 @@ public class OpenAIClient { final HttpClient httpClient; - OpenAIClient(String apiKey, String host, ObjectMapper objectMapper, HttpClient httpClient) { + final Executor asyncExecutor; + + OpenAIClient( + String apiKey, + String host, + ObjectMapper objectMapper, + HttpClient httpClient, + Executor asyncExecutor) { this.apiKey = Objects.requireNonNull(apiKey); this.host = Objects.requireNonNull(host); this.objectMapper = Objects.requireNonNull(objectMapper); this.httpClient = Objects.requireNonNull(httpClient); + this.asyncExecutor = Objects.requireNonNull(asyncExecutor); } public static class Builder { @@ -56,6 +66,8 @@ public static class Builder { private HttpClient httpClient; + private Executor asyncExecutor; + public Builder() {} public Builder apiKey(String apiKey) { @@ -78,6 +90,11 @@ public Builder httpClient(HttpClient httpClient) { return this; } + public Builder asyncExecutor(Executor asyncExecutor) { + this.asyncExecutor = asyncExecutor; + return this; + } + public OpenAIClient build() { if (apiKey == null) { throw new IllegalStateException("API key must be provided"); @@ -89,11 +106,16 @@ public OpenAIClient build() { if (httpClient == null) { httpClient = createHttpClient(); } - return new OpenAIClient(apiKey, host, objectMapper, httpClient); + + if (asyncExecutor == null) { + asyncExecutor = ForkJoinPool.commonPool(); + } + + return new OpenAIClient(apiKey, host, objectMapper, httpClient, asyncExecutor); } private HttpClient createHttpClient() { - return HttpClient.newHttpClient(); + return HttpClient.newBuilder().build(); } private ObjectMapper createObjectMapper() { @@ -135,7 +157,7 @@ public CompletableFuture getChatCompletions( CompletableFuture chatCompletionsResponse = httpClient .sendAsync(request, HttpResponse.BodyHandlers.ofString()) - .thenApply( + .thenApplyAsync( httpResponse -> { if (httpResponse.statusCode() != 200) { throw new OpenAIClientResponseException("ChatCompletion failed", httpResponse); @@ -148,7 +170,8 @@ public CompletableFuture getChatCompletions( "Can't deserialize ChatCompletionsResponse", e, httpResponse); } } - }); + }, + asyncExecutor); return chatCompletionsResponse; } diff --git a/common/src/main/java/com/box/l10n/mojito/openai/OpenAIClientPool.java b/common/src/main/java/com/box/l10n/mojito/openai/OpenAIClientPool.java new file mode 100644 index 0000000000..f18ed0c433 --- /dev/null +++ b/common/src/main/java/com/box/l10n/mojito/openai/OpenAIClientPool.java @@ -0,0 +1,77 @@ +package com.box.l10n.mojito.openai; + +import com.google.common.base.Function; +import java.net.http.HttpClient; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Semaphore; +import java.util.concurrent.ThreadLocalRandom; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class OpenAIClientPool { + + static Logger logger = LoggerFactory.getLogger(OpenAIClientPool.class); + + int numberOfClients; + OpenAIClientWithSemaphore[] openAIClientWithSemaphores; + + /** + * Pool to parallelize slower requests (1s+) over HTTP/2 connections. + * + * @param numberOfClients Number of OpenAIClient instances with independent HttpClients. + * @param numberOfParallelRequestPerClient Maximum parallel requests per client, controlled by a + * semaphore to prevent overload. + * @param sizeOfAsyncProcessors Shared async processors across all HttpClients to limit threads, + * as request time is the main bottleneck. + * @param apiKey API key for authentication. + */ + public OpenAIClientPool( + int numberOfClients, + int numberOfParallelRequestPerClient, + int sizeOfAsyncProcessors, + String apiKey) { + ExecutorService asyncExecutor = Executors.newWorkStealingPool(sizeOfAsyncProcessors); + this.numberOfClients = numberOfClients; + this.openAIClientWithSemaphores = new OpenAIClientWithSemaphore[numberOfClients]; + for (int i = 0; i < numberOfClients; i++) { + this.openAIClientWithSemaphores[i] = + new OpenAIClientWithSemaphore( + OpenAIClient.builder() + .apiKey(apiKey) + .asyncExecutor(asyncExecutor) + .httpClient(HttpClient.newBuilder().executor(asyncExecutor).build()) + .build(), + new Semaphore(numberOfParallelRequestPerClient)); + } + } + + public CompletableFuture submit(Function> f) { + + while (true) { + for (OpenAIClientWithSemaphore openAIClientWithSemaphore : openAIClientWithSemaphores) { + if (openAIClientWithSemaphore.semaphore().tryAcquire()) { + return f.apply(openAIClientWithSemaphore.openAIClient()) + .whenComplete((o, e) -> openAIClientWithSemaphore.semaphore().release()); + } + } + + try { + logger.debug("can't directly acquire any semaphore, do blocking"); + int randomSemaphoreIndex = + ThreadLocalRandom.current().nextInt(openAIClientWithSemaphores.length); + OpenAIClientWithSemaphore randomClientWithSemaphore = + this.openAIClientWithSemaphores[randomSemaphoreIndex]; + randomClientWithSemaphore.semaphore().acquire(); + return f.apply(randomClientWithSemaphore.openAIClient()) + .whenComplete((o, e) -> randomClientWithSemaphore.semaphore().release()); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Can't submit task to the OpenAIClientPool", e); + } + } + } + + record OpenAIClientWithSemaphore(OpenAIClient openAIClient, Semaphore semaphore) {} +} diff --git a/common/src/test/java/com/box/l10n/mojito/openai/OpenAIClientPoolTest.java b/common/src/test/java/com/box/l10n/mojito/openai/OpenAIClientPoolTest.java new file mode 100644 index 0000000000..b31ccb8071 --- /dev/null +++ b/common/src/test/java/com/box/l10n/mojito/openai/OpenAIClientPoolTest.java @@ -0,0 +1,135 @@ +package com.box.l10n.mojito.openai; + +import com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsResponse; +import com.google.common.base.Stopwatch; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import static com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsRequest.SystemMessage.systemMessageBuilder; +import static com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsRequest.UserMessage.userMessageBuilder; +import static com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsRequest.chatCompletionsRequest; + +@Disabled +public class OpenAIClientPoolTest { + + static Logger logger = LoggerFactory.getLogger(OpenAIClientPoolTest.class); + + static final String API_KEY; + + static { + try { + // API_KEY = + // + // Files.readString(Paths.get(System.getProperty("user.home")).resolve(".keys/openai")) + // .trim(); + API_KEY = "test-api-key"; + } catch (Throwable e) { + throw new RuntimeException(e); + } + } + + @Test + public void test() { + int numberOfClients = 10; + int numberOfParallelRequestPerClient = 50; + int numberOfRequests = 10000; + int sizeOfAsyncProcessors = 10; + int totalExecutions = numberOfClients * numberOfParallelRequestPerClient; + + OpenAIClientPool openAIClientPool = + new OpenAIClientPool( + numberOfClients, numberOfParallelRequestPerClient, sizeOfAsyncProcessors, API_KEY); + + AtomicInteger responseCounter = new AtomicInteger(); + AtomicInteger submitted = new AtomicInteger(); + Stopwatch stopwatch = Stopwatch.createStarted(); + + ArrayList submissionTimes = new ArrayList<>(); + ArrayList responseTimes = new ArrayList<>(); + + List> responses = new ArrayList<>(); + for (int i = 0; i < numberOfRequests; i++) { + String message = "Is %d prime?".formatted(i); + Stopwatch requestStopwatch = Stopwatch.createStarted(); + OpenAIClient.ChatCompletionsRequest chatCompletionsRequest = + chatCompletionsRequest() + .model("gpt-4o-2024-08-06") + .messages( + List.of( + systemMessageBuilder() + .content("You're an engine designed to check prime numbers") + .build(), + userMessageBuilder().content(message).build())) + .build(); + + CompletableFuture response = + openAIClientPool.submit( + openAIClient -> { + CompletableFuture chatCompletions = + openAIClient.getChatCompletions(chatCompletionsRequest); + submissionTimes.add(requestStopwatch.elapsed(TimeUnit.SECONDS)); + if (submitted.incrementAndGet() % 100 == 0) { + logger.info( + "--> request per second: " + + submitted.get() / (stopwatch.elapsed(TimeUnit.SECONDS) + 0.00001) + + ", submission count: " + + submitted.get() + + ", future response count: " + + responses.size() + + ", last submissions took: " + + submissionTimes.subList( + Math.max(0, submissionTimes.size() - 100), submissionTimes.size())); + } + return chatCompletions; + }); + + response.thenApply( + chatCompletionsResponse -> { + responseTimes.add(requestStopwatch.elapsed(TimeUnit.MILLISECONDS)); + if (responseCounter.incrementAndGet() % 10 == 0) { + double avg = + responseTimes.stream().collect(Collectors.averagingLong(Long::longValue)); + logger.info( + "<-- response per second: " + + responseCounter.get() / stopwatch.elapsed(TimeUnit.SECONDS) + + ", average response time: " + + Math.round(avg) + + " (rps: " + + Math.round(totalExecutions / (avg / 1000.0)) + + "), response count from counter: " + + responseCounter.get() + + ", last elapsed times: " + + responseTimes.subList(responseTimes.size() - 20, responseTimes.size())); + } + return chatCompletionsResponse; + }); + + responses.add(response); + } + + Stopwatch started = Stopwatch.createStarted(); + CompletableFuture.allOf(responses.toArray(new CompletableFuture[responses.size()])).join(); + logger.info("Waiting for join: " + started.elapsed()); + + double avg = responseTimes.stream().collect(Collectors.averagingLong(Long::longValue)); + logger.info( + "Total time: " + + stopwatch.elapsed().toString() + + ", request per second: " + + Math.round((double) numberOfRequests / stopwatch.elapsed(TimeUnit.SECONDS)) + + ", average response time: " + + Math.round(avg) + + " (theory rps: " + + Math.round(totalExecutions / (avg / 1000.0)) + + ")"); + } +} diff --git a/common/src/test/java/com/box/l10n/mojito/openai/OpenAIClientTest.java b/common/src/test/java/com/box/l10n/mojito/openai/OpenAIClientTest.java index dfc0391cc5..d538cb45b3 100644 --- a/common/src/test/java/com/box/l10n/mojito/openai/OpenAIClientTest.java +++ b/common/src/test/java/com/box/l10n/mojito/openai/OpenAIClientTest.java @@ -13,7 +13,6 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import com.box.l10n.mojito.io.Files; import com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsResponse; import com.box.l10n.mojito.openai.OpenAIClient.OpenAIClientResponseException; import com.box.l10n.mojito.openai.OpenAIClient.UploadFileRequest; @@ -22,23 +21,23 @@ import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; -import java.nio.file.Paths; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import org.junit.jupiter.api.Test; -class OpenAIClientTest { +public class OpenAIClientTest { static final String API_KEY; static { try { - API_KEY = - Files.readString(Paths.get(System.getProperty("user.home")).resolve(".keys/openai")) - .trim(); - // API_KEY = "test-api-key"; + // API_KEY = + // + // Files.readString(Paths.get(System.getProperty("user.home")).resolve(".keys/openai")) + // .trim(); + API_KEY = "test-api-key"; } catch (Throwable e) { throw new RuntimeException(e); } @@ -66,29 +65,29 @@ public void testGetChatCompletionsSuccess() throws Exception { String jsonResponse = """ - { - "id": "chatcmpl-9DNYjOkXJxILUK3NXFv9MCZV0P8jZ", - "object": "chat.completion", - "created": 1712975853, - "model": "gpt-3.5-turbo-0125", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "Il s'agit d'un test unitaire" - }, - "logprobs": null, - "finish_reason": "stop" - } - ], - "usage": { - "prompt_tokens": 24, - "completion_tokens": 9, - "total_tokens": 33 - }, - "system_fingerprint": "fp_c2295e73ad" - }"""; + { + "id": "chatcmpl-9DNYjOkXJxILUK3NXFv9MCZV0P8jZ", + "object": "chat.completion", + "created": 1712975853, + "model": "gpt-3.5-turbo-0125", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Il s'agit d'un test unitaire" + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 24, + "completion_tokens": 9, + "total_tokens": 33 + }, + "system_fingerprint": "fp_c2295e73ad" + }"""; HttpResponse mockResponse = mock(HttpResponse.class); when(mockResponse.statusCode()).thenReturn(200); @@ -132,14 +131,14 @@ public void testGetChatCompletionsRequestError() throws Exception { when(mockResponse.statusCode()).thenReturn(400); String errorMsg = """ - { - "error": { - "message": "The model `invalid-model` does not exist or you do not have access to it.", - "type": "invalid_request_error", - "param": null, - "code": "model_not_found" - } - }"""; + { + "error": { + "message": "The model `invalid-model` does not exist or you do not have access to it.", + "type": "invalid_request_error", + "param": null, + "code": "model_not_found" + } + }"""; when(mockResponse.body()).thenReturn(errorMsg); HttpClient mockHttpClient = mock(HttpClient.class); @@ -159,12 +158,12 @@ public void testGetChatCompletionsRequestError() throws Exception { .getMessage() .contains( """ - "error": { - "message": "The model `invalid-model` does not exist or you do not have access to it.", - "type": "invalid_request_error", - "param": null, - "code": "model_not_found" - }""")); + "error": { + "message": "The model `invalid-model` does not exist or you do not have access to it.", + "type": "invalid_request_error", + "param": null, + "code": "model_not_found" + }""")); } @Test @@ -186,29 +185,29 @@ public void testGetChatCompletionsDeserializationError() throws Exception { String jsonResponse = """ - { - "id": "chatcmpl-9DNYjOkXJxILUK3NXFv9MCZV0P8jZ", - "object": "chat.completion", - "created": "invalid date to break deserialization", - "model": "gpt-3.5-turbo-0125", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "Il s'agit d'un test unitaire" - }, - "logprobs": null, - "finish_reason": "stop" - } - ], - "usage": { - "prompt_tokens": 24, - "completion_tokens": 9, - "total_tokens": 33 - }, - "system_fingerprint": "fp_c2295e73ad" - }"""; + { + "id": "chatcmpl-9DNYjOkXJxILUK3NXFv9MCZV0P8jZ", + "object": "chat.completion", + "created": "invalid date to break deserialization", + "model": "gpt-3.5-turbo-0125", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Il s'agit d'un test unitaire" + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 24, + "completion_tokens": 9, + "total_tokens": 33 + }, + "system_fingerprint": "fp_c2295e73ad" + }"""; HttpResponse mockResponse = mock(HttpResponse.class); when(mockResponse.statusCode()).thenReturn(200); @@ -239,12 +238,12 @@ public void testUploadFileSuccess() throws Exception { when(mockResponse.body()) .thenReturn( """ -{ - "id": "file-123", - "filename": "example.jsonl", - "status": "uploaded", - "created_at": 1690000000 -}"""); + { + "id": "file-123", + "filename": "example.jsonl", + "status": "uploaded", + "created_at": 1690000000 + }"""); when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) .thenReturn(mockResponse); @@ -271,15 +270,15 @@ public void testUploadFileError() throws Exception { when(mockResponse.statusCode()).thenReturn(400); String errorMessage = """ - { - "error": { - "message": "Invalid file format for Batch API. Must be .jsonl", - "type": "invalid_request_error", - "param": null, - "code": null + { + "error": { + "message": "Invalid file format for Batch API. Must be .jsonl", + "type": "invalid_request_error", + "param": null, + "code": null + } } - } - """; + """; when(mockResponse.body()).thenReturn(errorMessage); when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) .thenReturn(mockResponse); @@ -291,10 +290,10 @@ public void testUploadFileError() throws Exception { UploadFileRequest.forBatch( "example.jsonl", """ - { - "a" : "b" - } - """); + { + "a" : "b" + } + """); OpenAIClientResponseException openAIClientResponseException = assertThrows( @@ -309,18 +308,18 @@ public void testFileUploadRequestMultiPartBody() { String actual = uploadFileRequest.getMultipartBody("test-boundary"); assertEquals( """ - --test-boundary\r - Content-Disposition: form-data; name="purpose"\r - \r - batch\r - --test-boundary\r - Content-Disposition: form-data; name="file"; filename="test.jsonl"\r - Content-Type: application/json\r - \r - {} - {}\r - --test-boundary--\r - """, + --test-boundary\r + Content-Disposition: form-data; name="purpose"\r + \r + batch\r + --test-boundary\r + Content-Disposition: form-data; name="file"; filename="test.jsonl"\r + Content-Type: application/json\r + \r + {} + {}\r + --test-boundary--\r + """, actual); } @@ -332,9 +331,9 @@ public void testDownloadFileContentSuccess() throws IOException, InterruptedExce when(mockResponse.statusCode()).thenReturn(200); String fileContent = """ - {"a" : "b"} - {"c" : "d"} - """; + {"a" : "b"} + {"c" : "d"} + """; when(mockResponse.body()).thenReturn(fileContent); when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) .thenReturn(mockResponse); @@ -357,15 +356,15 @@ public void testDownloadFileContentError() throws IOException, InterruptedExcept when(mockResponse.statusCode()).thenReturn(404); String body = """ - { - "error": { - "message": "No such File object: id-for-test", - "type": "invalid_request_error", - "param": "id", - "code": null + { + "error": { + "message": "No such File object: id-for-test", + "type": "invalid_request_error", + "param": "id", + "code": null + } } - } - """; + """; when(mockResponse.body()).thenReturn(body); when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) .thenReturn(mockResponse); @@ -392,36 +391,36 @@ public void testCreateBatchSuccess() throws IOException, InterruptedException { when(mockResponse.statusCode()).thenReturn(200); String body = """ - { - "id": "batch_67199315c20081909074e442115938a2", - "object": "batch", - "endpoint": "/v1/chat/completions", - "errors": null, - "input_file_id": "file-pp1I2zv79eAnm47wt6rCNL5a", - "completion_window": "24h", - "status": "validating", - "output_file_id": null, - "error_file_id": null, - "created_at": 1729729301, - "in_progress_at": null, - "expires_at": 1729815701, - "finalizing_at": null, - "completed_at": null, - "failed_at": null, - "expired_at": null, - "cancelling_at": null, - "cancelled_at": null, - "request_counts": { - "total": 0, - "completed": 0, - "failed": 0 - }, - "metadata": { - "k1": "v1", - "k2": "v2" + { + "id": "batch_67199315c20081909074e442115938a2", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-pp1I2zv79eAnm47wt6rCNL5a", + "completion_window": "24h", + "status": "validating", + "output_file_id": null, + "error_file_id": null, + "created_at": 1729729301, + "in_progress_at": null, + "expires_at": 1729815701, + "finalizing_at": null, + "completed_at": null, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 0, + "completed": 0, + "failed": 0 + }, + "metadata": { + "k1": "v1", + "k2": "v2" + } } - } - """; + """; when(mockResponse.body()).thenReturn(body); when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) .thenReturn(mockResponse); @@ -447,14 +446,14 @@ public void testCreateBatchError() throws IOException, InterruptedException { when(mockResponse.statusCode()).thenReturn(400); String body = """ - { - "error": { - "message": "Invalid 'input_file_id': 'wrong-id'. Expected an ID that begins with 'file'.", - "type": "invalid_request_error", - "param": "input_file_id", - "code": "invalid_value" - } - }"""; + { + "error": { + "message": "Invalid 'input_file_id': 'wrong-id'. Expected an ID that begins with 'file'.", + "type": "invalid_request_error", + "param": "input_file_id", + "code": "invalid_value" + } + }"""; when(mockResponse.body()).thenReturn(body); when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) .thenReturn(mockResponse); @@ -481,36 +480,36 @@ public void testRetrieveBatchSuccess() throws IOException, InterruptedException when(mockResponse.statusCode()).thenReturn(200); String body = """ - { - "id": "batch_67199315c20081909074e442115938a2", - "object": "batch", - "endpoint": "/v1/chat/completions", - "errors": null, - "input_file_id": "file-pp1I2zv79eAnm47wt6rCNL5a", - "completion_window": "24h", - "status": "validating", - "output_file_id": null, - "error_file_id": null, - "created_at": 1729729301, - "in_progress_at": null, - "expires_at": 1729815701, - "finalizing_at": null, - "completed_at": null, - "failed_at": null, - "expired_at": null, - "cancelling_at": null, - "cancelled_at": null, - "request_counts": { - "total": 0, - "completed": 0, - "failed": 0 - }, - "metadata": { - "k1": "v1", - "k2": "v2" + { + "id": "batch_67199315c20081909074e442115938a2", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-pp1I2zv79eAnm47wt6rCNL5a", + "completion_window": "24h", + "status": "validating", + "output_file_id": null, + "error_file_id": null, + "created_at": 1729729301, + "in_progress_at": null, + "expires_at": 1729815701, + "finalizing_at": null, + "completed_at": null, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 0, + "completed": 0, + "failed": 0 + }, + "metadata": { + "k1": "v1", + "k2": "v2" + } } - } - """; + """; when(mockResponse.body()).thenReturn(body); when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) .thenReturn(mockResponse); @@ -535,14 +534,14 @@ public void testRetrieveBatchError() throws IOException, InterruptedException { when(mockResponse.statusCode()).thenReturn(400); String body = """ - { - "error": { - "message": "Invalid 'input_file_id': 'wrong-id'. Expected an ID that begins with 'file'.", - "type": "invalid_request_error", - "param": "input_file_id", - "code": "invalid_value" - } - }"""; + { + "error": { + "message": "Invalid 'input_file_id': 'wrong-id'. Expected an ID that begins with 'file'.", + "type": "invalid_request_error", + "param": "input_file_id", + "code": "invalid_value" + } + }"""; when(mockResponse.body()).thenReturn(body); when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) .thenReturn(mockResponse); diff --git a/restclient/src/main/java/com/box/l10n/mojito/rest/client/RepositoryAiTranslateClient.java b/restclient/src/main/java/com/box/l10n/mojito/rest/client/RepositoryAiTranslateClient.java index 6262f484ff..f0f5207675 100644 --- a/restclient/src/main/java/com/box/l10n/mojito/rest/client/RepositoryAiTranslateClient.java +++ b/restclient/src/main/java/com/box/l10n/mojito/rest/client/RepositoryAiTranslateClient.java @@ -29,7 +29,10 @@ public ProtoAiTranslateResponse translateRepository( } public record ProtoAiTranslateRequest( - String repositoryName, List targetBcp47tags, int sourceTextMaxCountPerLocale) {} + String repositoryName, + List targetBcp47tags, + int sourceTextMaxCountPerLocale, + boolean useBatch) {} public record ProtoAiTranslateResponse(PollableTask pollableTask) {} } diff --git a/webapp/src/main/java/com/box/l10n/mojito/rest/textunit/AiTranslateWS.java b/webapp/src/main/java/com/box/l10n/mojito/rest/textunit/AiTranslateWS.java index 2eddf2929d..7f33de3211 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/rest/textunit/AiTranslateWS.java +++ b/webapp/src/main/java/com/box/l10n/mojito/rest/textunit/AiTranslateWS.java @@ -42,13 +42,18 @@ public ProtoAiTranslateResponse aiTranslate( new AiTranslateInput( protoAiTranslateRequest.repositoryName(), protoAiTranslateRequest.targetBcp47tags(), - protoAiTranslateRequest.sourceTextMaxCountPerLocale())); + protoAiTranslateRequest.sourceTextMaxCountPerLocale(), + protoAiTranslateRequest.useBatch())); return new ProtoAiTranslateResponse(pollableFuture.getPollableTask()); } public record ProtoAiTranslateRequest( - String repositoryName, List targetBcp47tags, int sourceTextMaxCountPerLocale) {} + String repositoryName, + List targetBcp47tags, + int sourceTextMaxCountPerLocale, + boolean useBatch, + boolean allLocales) {} public record ProtoAiTranslateResponse(PollableTask pollableTask) {} } diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/oaitranslate/AiTranslateConfig.java b/webapp/src/main/java/com/box/l10n/mojito/service/oaitranslate/AiTranslateConfig.java index a54f98a9ba..4ef75a3db5 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/service/oaitranslate/AiTranslateConfig.java +++ b/webapp/src/main/java/com/box/l10n/mojito/service/oaitranslate/AiTranslateConfig.java @@ -2,6 +2,7 @@ import com.box.l10n.mojito.json.ObjectMapper; import com.box.l10n.mojito.openai.OpenAIClient; +import com.box.l10n.mojito.openai.OpenAIClientPool; import java.time.Duration; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.context.annotation.Bean; @@ -28,6 +29,17 @@ OpenAIClient openAIClient() { return new OpenAIClient.Builder().apiKey(openaiClientToken).build(); } + @Bean + @Qualifier("AiTranslate") + OpenAIClientPool openAIClientPool() { + String openaiClientToken = aiTranslateConfigurationProperties.getOpenaiClientToken(); + if (openaiClientToken == null) { + return null; + } + return new OpenAIClientPool( + 10, 50, 5, aiTranslateConfigurationProperties.getOpenaiClientToken()); + } + @Bean @Qualifier("AiTranslate") ObjectMapper objectMapper() { diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/oaitranslate/AiTranslateService.java b/webapp/src/main/java/com/box/l10n/mojito/service/oaitranslate/AiTranslateService.java index d9014f856a..8d9b962398 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/service/oaitranslate/AiTranslateService.java +++ b/webapp/src/main/java/com/box/l10n/mojito/service/oaitranslate/AiTranslateService.java @@ -21,12 +21,15 @@ import com.box.l10n.mojito.entity.RepositoryLocale; import com.box.l10n.mojito.json.ObjectMapper; import com.box.l10n.mojito.openai.OpenAIClient; +import com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsResponse; import com.box.l10n.mojito.openai.OpenAIClient.CreateBatchResponse; import com.box.l10n.mojito.openai.OpenAIClient.RequestBatchFileLine; +import com.box.l10n.mojito.openai.OpenAIClientPool; import com.box.l10n.mojito.quartz.QuartzJobInfo; import com.box.l10n.mojito.quartz.QuartzPollableTaskScheduler; import com.box.l10n.mojito.service.blobstorage.Retention; import com.box.l10n.mojito.service.blobstorage.StructuredBlobStorage; +import com.box.l10n.mojito.service.oaitranslate.AiTranslateService.CompletionInput.ExistingTarget; import com.box.l10n.mojito.service.pollableTask.PollableFuture; import com.box.l10n.mojito.service.repository.RepositoryNameNotFoundException; import com.box.l10n.mojito.service.repository.RepositoryRepository; @@ -41,12 +44,17 @@ import com.fasterxml.jackson.databind.SerializationFeature; import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import java.io.IOException; +import java.time.Duration; import java.util.ArrayDeque; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.TimeoutException; import java.util.function.Function; import java.util.stream.Collectors; import org.slf4j.Logger; @@ -54,7 +62,9 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.stereotype.Service; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; import reactor.util.retry.RetryBackoffSpec; @Service @@ -75,6 +85,8 @@ public class AiTranslateService { OpenAIClient openAIClient; + OpenAIClientPool openAIClientPool; + TextUnitBatchImporterService textUnitBatchImporterService; StructuredBlobStorage structuredBlobStorage; @@ -93,6 +105,7 @@ public AiTranslateService( StructuredBlobStorage structuredBlobStorage, AiTranslateConfigurationProperties aiTranslateConfigurationProperties, @Qualifier("AiTranslate") @Autowired(required = false) OpenAIClient openAIClient, + @Qualifier("AiTranslate") @Autowired(required = false) OpenAIClientPool openAIClientPool, @Qualifier("AiTranslate") ObjectMapper objectMapper, @Qualifier("AiTranslate") RetryBackoffSpec retryBackoffSpec, QuartzPollableTaskScheduler quartzPollableTaskScheduler) { @@ -104,12 +117,16 @@ public AiTranslateService( this.aiTranslateConfigurationProperties = aiTranslateConfigurationProperties; this.objectMapper = objectMapper; this.openAIClient = openAIClient; + this.openAIClientPool = openAIClientPool; this.retryBackoffSpec = retryBackoffSpec; this.quartzPollableTaskScheduler = quartzPollableTaskScheduler; } public record AiTranslateInput( - String repositoryName, List targetBcp47tags, int sourceTextMaxCountPerLocale) {} + String repositoryName, + List targetBcp47tags, + int sourceTextMaxCountPerLocale, + boolean useBatch) {} public PollableFuture aiTranslateAsync(AiTranslateInput aiTranslateInput) { @@ -124,26 +141,183 @@ public PollableFuture aiTranslateAsync(AiTranslateInput aiTranslateInput) } public void aiTranslate(AiTranslateInput aiTranslateInput) throws AiTranslateException { + if (aiTranslateInput.useBatch()) { + aiTranslateBatch(aiTranslateInput); + } else { + aiTranslateNoBatch(aiTranslateInput); + } + } - Repository repository = repositoryRepository.findByName(aiTranslateInput.repositoryName()); + public void aiTranslateNoBatch(AiTranslateInput aiTranslateInput) { - if (repository == null) { - throw new RepositoryNameNotFoundException( - String.format( - "Repository with name '%s' can not be found!", aiTranslateInput.repositoryName())); + Repository repository = getRepository(aiTranslateInput); + + logger.info("Start AI Translation (no batch) for repository: {}", repository.getName()); + + Set filteredRepositoryLocales = + getFilteredRepositoryLocales(aiTranslateInput, repository); + + Flux.fromIterable(filteredRepositoryLocales) + .flatMap( + rl -> + asyncProcessLocale( + rl, aiTranslateInput.sourceTextMaxCountPerLocale(), openAIClientPool), + 10) + .then() + .doOnTerminate( + () -> + logger.info( + "Done with AI Translation (no batch) for repository: {}", repository.getName())) + .block(); + } + + Mono asyncProcessLocale( + RepositoryLocale repositoryLocale, + int sourceTextMaxCountPerLocale, + OpenAIClientPool openAIClientPool) { + + Repository repository = repositoryLocale.getRepository(); + + logger.info( + "Get untranslated strings for locale: '{}' in repository: '{}'", + repositoryLocale.getLocale().getBcp47Tag(), + repository.getName()); + + TextUnitSearcherParameters textUnitSearcherParameters = new TextUnitSearcherParameters(); + textUnitSearcherParameters.setRepositoryIds(repository.getId()); + textUnitSearcherParameters.setStatusFilter(StatusFilter.FOR_TRANSLATION); + textUnitSearcherParameters.setLocaleId(repositoryLocale.getLocale().getId()); + textUnitSearcherParameters.setLimit(sourceTextMaxCountPerLocale); + + List textUnitDTOS = textUnitSearcher.search(textUnitSearcherParameters); + + if (textUnitDTOS.isEmpty()) { + logger.debug( + "Nothing to translate for locale: {}", repositoryLocale.getLocale().getBcp47Tag()); + return Mono.empty(); } + logger.info( + "Starting parallel processing for each string in locale: {}, count: {}", + repositoryLocale.getLocale().getBcp47Tag(), + textUnitDTOS.size()); + + return Flux.fromIterable(textUnitDTOS) + .buffer(500) + .concatMap( + batch -> + Flux.fromIterable(batch) + .flatMap( + textUnitDTO -> + getChatCompletionForTextUnitDTO(textUnitDTO, openAIClientPool) + .retryWhen( + Retry.backoff(5, Duration.ofSeconds(1)) + .filter(this::isRetriableException) + .doBeforeRetry( + retrySignal -> { + logger.warn( + "Retrying request for TextUnitDTO {} due to exception of type {}", + textUnitDTO.getTmTextUnitId(), + retrySignal.failure().getMessage()); + })) + .onErrorResume( + error -> { + logger.error( + "Request for TextUnitDTO {} failed after retries: {}", + textUnitDTO.getTmTextUnitId(), + error.getMessage()); + return Mono.empty(); + })) + .collectList() + .flatMap(this::submitForImport) + .doOnTerminate(() -> logger.info("Done submitting for processing"))) + .then(); + } + + record MyRecord(TextUnitDTO textUnitDTO, ChatCompletionsResponse chatCompletionsResponse) {} + + private Mono submitForImport(List results) { + logger.info("Submit for import for locale {}", results.get(0).textUnitDTO().getTargetLocale()); + List forImport = + results.stream() + .map( + myRecord -> { + TextUnitDTO textUnitDTO = myRecord.textUnitDTO(); + ChatCompletionsResponse chatCompletionsResponse = + myRecord.chatCompletionsResponse(); + + String completionOutputAsJson = + chatCompletionsResponse.choices().getFirst().message().content(); + + CompletionOutput completionOutput = + objectMapper.readValueUnchecked( + completionOutputAsJson, CompletionOutput.class); + + textUnitDTO.setTarget(completionOutput.target().content()); + textUnitDTO.setTargetComment("ai-translate"); + return textUnitDTO; + }) + .collect(Collectors.toList()); + + textUnitBatchImporterService.importTextUnits( + forImport, + TextUnitBatchImporterService.IntegrityChecksType.ALWAYS_USE_INTEGRITY_CHECKER_STATUS); + + return Mono.empty(); + } + + private Mono getChatCompletionForTextUnitDTO( + TextUnitDTO textUnitDTO, OpenAIClientPool openAIClientPool) { + + CompletionInput completionInput = + new CompletionInput( + textUnitDTO.getTargetLocale(), + textUnitDTO.getSource(), + textUnitDTO.getComment(), + textUnitDTO.getTarget() == null + ? null + : new ExistingTarget( + textUnitDTO.getTarget(), !textUnitDTO.isIncludedInLocalizedFile())); + + String inputAsJsonString = objectMapper.writeValueAsStringUnchecked(completionInput); + ObjectNode jsonSchema = createJsonSchema(CompletionOutput.class); + + ChatCompletionsRequest chatCompletionsRequest = + chatCompletionsRequest() + .model("gpt-4o-2024-08-06") + .maxTokens(16384) + .messages( + List.of( + systemMessageBuilder().content(PROMPT).build(), + userMessageBuilder().content(inputAsJsonString).build())) + .responseFormat( + new ChatCompletionsRequest.JsonFormat( + "json_schema", + new ChatCompletionsRequest.JsonFormat.JsonSchema( + true, "request_json_format", jsonSchema))) + .build(); + + CompletableFuture futureResult = + openAIClientPool.submit( + (openAIClient) -> openAIClient.getChatCompletions(chatCompletionsRequest)); + return Mono.fromFuture(futureResult) + .map(chatCompletionsResponse -> new MyRecord(textUnitDTO, chatCompletionsResponse)); + } + + private boolean isRetriableException(Throwable throwable) { + Throwable cause = throwable instanceof CompletionException ? throwable.getCause() : throwable; + return cause instanceof IOException || cause instanceof TimeoutException; + } + + public void aiTranslateBatch(AiTranslateInput aiTranslateInput) throws AiTranslateException { + + Repository repository = getRepository(aiTranslateInput); + logger.debug("Start AI Translation for repository: {}", repository.getName()); try { Set repositoryLocalesWithoutRootLocale = - repositoryService.getRepositoryLocalesWithoutRootLocale(repository).stream() - .filter( - rl -> - aiTranslateInput.targetBcp47tags == null - || aiTranslateInput.targetBcp47tags.contains( - rl.getLocale().getBcp47Tag())) - .collect(Collectors.toSet()); + getFilteredRepositoryLocales(aiTranslateInput, repository); logger.debug("Create batches for repository: {}", repository.getName()); ArrayDeque batches = @@ -167,6 +341,27 @@ public void aiTranslate(AiTranslateInput aiTranslateInput) throws AiTranslateExc } } + private Set getFilteredRepositoryLocales( + AiTranslateInput aiTranslateInput, Repository repository) { + return repositoryService.getRepositoryLocalesWithoutRootLocale(repository).stream() + .filter( + rl -> + aiTranslateInput.targetBcp47tags == null + || aiTranslateInput.targetBcp47tags.contains(rl.getLocale().getBcp47Tag())) + .collect(Collectors.toSet()); + } + + private Repository getRepository(AiTranslateInput aiTranslateInput) { + Repository repository = repositoryRepository.findByName(aiTranslateInput.repositoryName()); + + if (repository == null) { + throw new RepositoryNameNotFoundException( + String.format( + "Repository with name '%s' can not be found!", aiTranslateInput.repositoryName())); + } + return repository; + } + void importBatch(RetrieveBatchResponse retrieveBatchResponse) { logger.info("Importing batch: {}", retrieveBatchResponse.id()); @@ -208,7 +403,7 @@ void importBatch(RetrieveBatchResponse retrieveBatchResponse) { "Response batch file line failed: " + chatCompletionResponseBatchFileLine); } - String aiTranslateOutputAsJson = + String completionOutputAsJson = chatCompletionResponseBatchFileLine .response() .chatCompletionsResponse() @@ -217,14 +412,14 @@ void importBatch(RetrieveBatchResponse retrieveBatchResponse) { .message() .content(); - AiTranslateOutput aiTranslateOutput = + CompletionOutput completionOutput = objectMapper.readValueUnchecked( - aiTranslateOutputAsJson, AiTranslateOutput.class); + completionOutputAsJson, CompletionOutput.class); TextUnitDTO textUnitDTO = tmTextUnitIdToTextUnitDTOs.get( Long.valueOf(chatCompletionResponseBatchFileLine.customId())); - textUnitDTO.setTarget(aiTranslateOutput.target().content()); + textUnitDTO.setTarget(completionOutput.target().content()); textUnitDTO.setTargetComment("ai-translate"); return textUnitDTO; }) @@ -266,7 +461,6 @@ Function createBatchForRepositoryLocale( logger.debug("Generate the batch file content"); String batchFileContent = generateBatchFileContent(textUnitDTOS); - logger.debug("Upload batch file content: {}", batchFileContent); UploadFileResponse uploadFileResponse = getOpenAIClient() .uploadFile( @@ -297,11 +491,13 @@ String generateBatchFileContent(List textUnitDTOS) { new CompletionInput( textUnitDTO.getTargetLocale(), textUnitDTO.getSource(), - textUnitDTO.getComment()); + textUnitDTO.getComment(), + new ExistingTarget( + textUnitDTO.getTarget(), !textUnitDTO.isIncludedInLocalizedFile())); String inputAsJsonString = objectMapper.writeValueAsStringUnchecked(completionInput); - ObjectNode jsonSchema = createJsonSchema(AiTranslateOutput.class); + ObjectNode jsonSchema = createJsonSchema(CompletionOutput.class); ChatCompletionsRequest chatCompletionsRequest = chatCompletionsRequest() @@ -376,9 +572,12 @@ RetrieveBatchResponse retrieveBatchWithRetry(CreateBatchResponse batch) { .block(); } - record CompletionInput(String locale, String source, String sourceDescription) {} + record CompletionInput( + String locale, String source, String sourceDescription, ExistingTarget existingTarget) { + record ExistingTarget(String content, boolean hasBrokenPlaceholders) {} + } - record AiTranslateOutput( + record CompletionOutput( String source, Target target, DescriptionRating descriptionRating, @@ -408,7 +607,7 @@ record AiTranslateBlobStorage(List textUnitDTOS) {} • "source": The source text to be translated. • "locale": The target language locale, following the BCP47 standard (e.g., “fr”, “es-419”). • "sourceDescription": A description providing context for the source text. - • "existingTarget" (optional): An existing translation to review. + • "existingTarget" (optional): An existing translation to review. Indicates if it has broken placeholders. Instructions: @@ -422,6 +621,7 @@ Some strings contain code elements such as tags (e.g., {atag}, ICU message forma • Tags like {atag} remain untouched. • In cases of nested content (e.g., text that needs translation), only translate the inner text while preserving the outer structure. • Complex structures like ICU message formats should have placeholders or variables left intact (e.g., {count, plural, one {# item} other {# items}}), but translate any inner translatable text. + • If an existing translation is provided and has broken placeholder, make sure to fix them in the new translation. Ambiguity and Context: diff --git a/webapp/src/test/java/com/box/l10n/mojito/service/oaitranslate/AiTranslateServiceTest.java b/webapp/src/test/java/com/box/l10n/mojito/service/oaitranslate/AiTranslateServiceTest.java index 4c389bceab..1ef12ede94 100644 --- a/webapp/src/test/java/com/box/l10n/mojito/service/oaitranslate/AiTranslateServiceTest.java +++ b/webapp/src/test/java/com/box/l10n/mojito/service/oaitranslate/AiTranslateServiceTest.java @@ -1,9 +1,12 @@ package com.box.l10n.mojito.service.oaitranslate; import com.box.l10n.mojito.service.assetExtraction.ServiceTestBase; +import com.box.l10n.mojito.service.repository.RepositoryNameAlreadyUsedException; +import com.box.l10n.mojito.service.repository.RepositoryService; import com.box.l10n.mojito.service.tm.TMTestData; import com.box.l10n.mojito.test.TestIdWatcher; import java.util.concurrent.ExecutionException; +import org.junit.Assume; import org.junit.Rule; import org.junit.Test; import org.slf4j.Logger; @@ -18,12 +21,33 @@ public class AiTranslateServiceTest extends ServiceTestBase { @Autowired AiTranslateService aiTranslateService; + @Autowired AiTranslateConfigurationProperties aiTranslateConfigurationProperties; + + @Autowired RepositoryService repositoryService; + + @Test + public void aiTranslateBatch() throws ExecutionException, InterruptedException { + Assume.assumeNotNull(aiTranslateConfigurationProperties.getOpenaiClientToken()); + + TMTestData tmTestData = new TMTestData(testIdWatcher); + aiTranslateService + .aiTranslateAsync( + new AiTranslateService.AiTranslateInput( + tmTestData.repository.getName(), null, 100, true)) + .get(); + } + @Test - public void aiTranslate() throws ExecutionException, InterruptedException { + public void aiTranslateNoBatch() + throws ExecutionException, InterruptedException, RepositoryNameAlreadyUsedException { + Assume.assumeNotNull(aiTranslateConfigurationProperties.getOpenaiClientToken()); + TMTestData tmTestData = new TMTestData(testIdWatcher); + aiTranslateService .aiTranslateAsync( - new AiTranslateService.AiTranslateInput(tmTestData.repository.getName(), null, 100)) + new AiTranslateService.AiTranslateInput( + tmTestData.repository.getName(), null, 100, false)) .get(); } }