Skip to content

Commit

Permalink
Implement first version of no-batch ai translate api
Browse files Browse the repository at this point in the history
  • Loading branch information
ja-openai committed Nov 22, 2024
1 parent 7f8870e commit 1fa2715
Show file tree
Hide file tree
Showing 10 changed files with 703 additions and 219 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ public class RepositoryAiTranslationCommand extends Command {
@Parameter(
names = {Param.REPOSITORY_LOCALES_LONG, Param.REPOSITORY_LOCALES_SHORT},
variableArity = true,
required = true,
description = "List of locales (bcp47 tags) to machine translate")
description =
"List of locales (bcp47 tags) to translate, if not provided translate all locales in the repository")
List<String> locales;

@Parameter(
Expand All @@ -55,6 +55,12 @@ public class RepositoryAiTranslationCommand extends Command {
+ "sending too many strings to MT)")
int sourceTextMaxCount = 100;

@Parameter(
names = {"--use-batch"},
arity = 1,
description = "To use the batch API or not")
boolean useBatch = false;

@Autowired CommandHelper commandHelper;

@Autowired RepositoryAiTranslateClient repositoryAiTranslateClient;
Expand All @@ -75,13 +81,13 @@ public void execute() throws CommandException {
.reset()
.a(" for locales: ")
.fg(Color.CYAN)
.a(locales.stream().collect(Collectors.joining(", ", "[", "]")))
.a(locales == null ? "<all>" : locales.stream().collect(Collectors.joining(", ", "[", "]")))
.println(2);

ProtoAiTranslateResponse protoAiTranslateResponse =
repositoryAiTranslateClient.translateRepository(
new RepositoryAiTranslateClient.ProtoAiTranslateRequest(
repositoryParam, locales, sourceTextMaxCount));
repositoryParam, locales, sourceTextMaxCount, useBatch));

PollableTask pollableTask = protoAiTranslateResponse.pollableTask();
commandHelper.waitForPollableTask(pollableTask.getId());
Expand Down
33 changes: 28 additions & 5 deletions common/src/main/java/com/box/l10n/mojito/openai/OpenAIClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ForkJoinPool;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand All @@ -39,11 +41,19 @@ public class OpenAIClient {

final HttpClient httpClient;

OpenAIClient(String apiKey, String host, ObjectMapper objectMapper, HttpClient httpClient) {
final Executor asyncExecutor;

OpenAIClient(
String apiKey,
String host,
ObjectMapper objectMapper,
HttpClient httpClient,
Executor asyncExecutor) {
this.apiKey = Objects.requireNonNull(apiKey);
this.host = Objects.requireNonNull(host);
this.objectMapper = Objects.requireNonNull(objectMapper);
this.httpClient = Objects.requireNonNull(httpClient);
this.asyncExecutor = Objects.requireNonNull(asyncExecutor);
}

public static class Builder {
Expand All @@ -56,6 +66,8 @@ public static class Builder {

private HttpClient httpClient;

private Executor asyncExecutor;

public Builder() {}

public Builder apiKey(String apiKey) {
Expand All @@ -78,6 +90,11 @@ public Builder httpClient(HttpClient httpClient) {
return this;
}

public Builder asyncExecutor(Executor asyncExecutor) {
this.asyncExecutor = asyncExecutor;
return this;
}

public OpenAIClient build() {
if (apiKey == null) {
throw new IllegalStateException("API key must be provided");
Expand All @@ -89,11 +106,16 @@ public OpenAIClient build() {
if (httpClient == null) {
httpClient = createHttpClient();
}
return new OpenAIClient(apiKey, host, objectMapper, httpClient);

if (asyncExecutor == null) {
asyncExecutor = ForkJoinPool.commonPool();
}

return new OpenAIClient(apiKey, host, objectMapper, httpClient, asyncExecutor);
}

private HttpClient createHttpClient() {
return HttpClient.newHttpClient();
return HttpClient.newBuilder().build();
}

private ObjectMapper createObjectMapper() {
Expand Down Expand Up @@ -135,7 +157,7 @@ public CompletableFuture<ChatCompletionsResponse> getChatCompletions(
CompletableFuture<ChatCompletionsResponse> chatCompletionsResponse =
httpClient
.sendAsync(request, HttpResponse.BodyHandlers.ofString())
.thenApply(
.thenApplyAsync(
httpResponse -> {
if (httpResponse.statusCode() != 200) {
throw new OpenAIClientResponseException("ChatCompletion failed", httpResponse);
Expand All @@ -148,7 +170,8 @@ public CompletableFuture<ChatCompletionsResponse> getChatCompletions(
"Can't deserialize ChatCompletionsResponse", e, httpResponse);
}
}
});
},
asyncExecutor);

return chatCompletionsResponse;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package com.box.l10n.mojito.openai;

import com.google.common.base.Function;
import java.net.http.HttpClient;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadLocalRandom;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Fixed-size pool of {@link OpenAIClient} instances, each guarded by its own {@link Semaphore} that
 * caps the number of requests in flight on that client.
 *
 * <p>All clients (and their underlying {@link HttpClient}s) share a single work-stealing executor
 * created in the constructor. This class never shuts that executor down.
 */
public class OpenAIClientPool {

  static Logger logger = LoggerFactory.getLogger(OpenAIClientPool.class);

  int numberOfClients;
  OpenAIClientWithSemaphore[] openAIClientWithSemaphores;

  /**
   * Pool to parallelize slower requests (1s+) over HTTP/2 connections.
   *
   * @param numberOfClients Number of OpenAIClient instances with independent HttpClients. Must be
   *     at least 1.
   * @param numberOfParallelRequestPerClient Maximum parallel requests per client, controlled by a
   *     semaphore to prevent overload. Must be at least 1.
   * @param sizeOfAsyncProcessors Shared async processors across all HttpClients to limit threads,
   *     as request time is the main bottleneck.
   * @param apiKey API key for authentication.
   * @throws IllegalArgumentException if {@code numberOfClients} or {@code
   *     numberOfParallelRequestPerClient} is smaller than 1
   */
  public OpenAIClientPool(
      int numberOfClients,
      int numberOfParallelRequestPerClient,
      int sizeOfAsyncProcessors,
      String apiKey) {
    // Fail fast: an empty pool makes submit() fail with an obscure error, and a semaphore without
    // permits makes submit() block forever.
    if (numberOfClients < 1) {
      throw new IllegalArgumentException(
          "numberOfClients must be at least 1, got: " + numberOfClients);
    }
    if (numberOfParallelRequestPerClient < 1) {
      throw new IllegalArgumentException(
          "numberOfParallelRequestPerClient must be at least 1, got: "
              + numberOfParallelRequestPerClient);
    }

    ExecutorService asyncExecutor = Executors.newWorkStealingPool(sizeOfAsyncProcessors);
    this.numberOfClients = numberOfClients;
    this.openAIClientWithSemaphores = new OpenAIClientWithSemaphore[numberOfClients];
    for (int i = 0; i < numberOfClients; i++) {
      this.openAIClientWithSemaphores[i] =
          new OpenAIClientWithSemaphore(
              OpenAIClient.builder()
                  .apiKey(apiKey)
                  .asyncExecutor(asyncExecutor)
                  .httpClient(HttpClient.newBuilder().executor(asyncExecutor).build())
                  .build(),
              new Semaphore(numberOfParallelRequestPerClient));
    }
  }

  /**
   * Runs {@code f} against a pooled client once a permit for that client is available.
   *
   * <p>Clients are first probed in order without blocking. If none has a free permit, the calling
   * thread blocks on the semaphore of a randomly chosen client. The permit is released when the
   * future returned by {@code f} completes, whether normally or exceptionally.
   *
   * @param f function that issues the request on the provided client and returns its future
   * @param <T> type of the result carried by the future
   * @return the future produced by {@code f}, with the permit release attached
   * @throws RuntimeException if the calling thread is interrupted while waiting for a permit (the
   *     interrupt flag is restored), or if {@code f} itself throws
   * @throws NullPointerException if {@code f} returns {@code null}
   */
  public <T> CompletableFuture<T> submit(Function<OpenAIClient, CompletableFuture<T>> f) {

    for (OpenAIClientWithSemaphore openAIClientWithSemaphore : openAIClientWithSemaphores) {
      if (openAIClientWithSemaphore.semaphore().tryAcquire()) {
        return applyAndReleaseOnCompletion(f, openAIClientWithSemaphore);
      }
    }

    logger.debug("can't directly acquire any semaphore, do blocking");
    int randomSemaphoreIndex =
        ThreadLocalRandom.current().nextInt(openAIClientWithSemaphores.length);
    OpenAIClientWithSemaphore randomClientWithSemaphore =
        this.openAIClientWithSemaphores[randomSemaphoreIndex];

    try {
      randomClientWithSemaphore.semaphore().acquire();
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new RuntimeException("Can't submit task to the OpenAIClientPool", e);
    }

    return applyAndReleaseOnCompletion(f, randomClientWithSemaphore);
  }

  /**
   * Invokes {@code f} with a permit already held and guarantees the permit is given back: either
   * when the returned future completes, or immediately if no future could be obtained.
   */
  private <T> CompletableFuture<T> applyAndReleaseOnCompletion(
      Function<OpenAIClient, CompletableFuture<T>> f,
      OpenAIClientWithSemaphore clientWithSemaphore) {

    CompletableFuture<T> future;
    try {
      future = f.apply(clientWithSemaphore.openAIClient());
      if (future == null) {
        throw new NullPointerException("The submitted function returned a null future");
      }
    } catch (RuntimeException | Error e) {
      // No future to hook the release onto: release now, otherwise the permit is lost for good
      // and the pool capacity shrinks with every failed submission.
      clientWithSemaphore.semaphore().release();
      throw e;
    }

    return future.whenComplete((o, e) -> clientWithSemaphore.semaphore().release());
  }

  /** Pairs a client with the semaphore that limits its concurrent requests. */
  record OpenAIClientWithSemaphore(OpenAIClient openAIClient, Semaphore semaphore) {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package com.box.l10n.mojito.openai;

import com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsResponse;
import com.google.common.base.Stopwatch;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import static com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsRequest.SystemMessage.systemMessageBuilder;
import static com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsRequest.UserMessage.userMessageBuilder;
import static com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsRequest.chatCompletionsRequest;

/**
 * Manual load test for {@link OpenAIClientPool}: submits a large number of chat completion
 * requests through the pool and logs submission/response throughput.
 *
 * <p>Disabled by default since it performs real HTTP calls and needs a valid API key.
 */
@Disabled
public class OpenAIClientPoolTest {

  static Logger logger = LoggerFactory.getLogger(OpenAIClientPoolTest.class);

  static final String API_KEY;

  static {
    try {
      // Placeholder key. To run against the real API, read the key from a local file instead,
      // e.g.:
      // API_KEY =
      //     Files.readString(Paths.get(System.getProperty("user.home")).resolve(".keys/openai"))
      //         .trim();
      API_KEY = "test-api-key";
    } catch (Throwable e) {
      throw new RuntimeException(e);
    }
  }

  @Test
  public void test() {
    int numberOfClients = 10;
    int numberOfParallelRequestPerClient = 50;
    int numberOfRequests = 10000;
    int sizeOfAsyncProcessors = 10;
    int totalExecutions = numberOfClients * numberOfParallelRequestPerClient;

    OpenAIClientPool openAIClientPool =
        new OpenAIClientPool(
            numberOfClients, numberOfParallelRequestPerClient, sizeOfAsyncProcessors, API_KEY);

    AtomicInteger responseCounter = new AtomicInteger();
    AtomicInteger submitted = new AtomicInteger();
    Stopwatch stopwatch = Stopwatch.createStarted();

    // Only touched from the submitting thread (the function passed to submit() runs in the caller).
    List<Long> submissionTimes = new ArrayList<>();
    // Appended to from response callbacks running on other threads: every access must hold the
    // list's monitor.
    List<Long> responseTimes = new ArrayList<>();

    List<CompletableFuture<ChatCompletionsResponse>> responses = new ArrayList<>();
    for (int i = 0; i < numberOfRequests; i++) {
      String message = "Is %d prime?".formatted(i);
      Stopwatch requestStopwatch = Stopwatch.createStarted();
      OpenAIClient.ChatCompletionsRequest chatCompletionsRequest =
          chatCompletionsRequest()
              .model("gpt-4o-2024-08-06")
              .messages(
                  List.of(
                      systemMessageBuilder()
                          .content("You're an engine designed to check prime numbers")
                          .build(),
                      userMessageBuilder().content(message).build()))
              .build();

      CompletableFuture<ChatCompletionsResponse> response =
          openAIClientPool.submit(
              openAIClient -> {
                CompletableFuture<ChatCompletionsResponse> chatCompletions =
                    openAIClient.getChatCompletions(chatCompletionsRequest);
                submissionTimes.add(requestStopwatch.elapsed(TimeUnit.SECONDS));
                if (submitted.incrementAndGet() % 100 == 0) {
                  logger.info(
                      "--> request per second: "
                          + submitted.get() / (stopwatch.elapsed(TimeUnit.SECONDS) + 0.00001)
                          + ", submission count: "
                          + submitted.get()
                          + ", future response count: "
                          + responses.size()
                          + ", last submissions took: "
                          + submissionTimes.subList(
                              Math.max(0, submissionTimes.size() - 100), submissionTimes.size()));
                }
                return chatCompletions;
              });

      CompletableFuture<ChatCompletionsResponse> responseWithStats =
          response.thenApply(
              chatCompletionsResponse -> {
                long elapsedMillis = requestStopwatch.elapsed(TimeUnit.MILLISECONDS);
                synchronized (responseTimes) {
                  responseTimes.add(elapsedMillis);
                }
                if (responseCounter.incrementAndGet() % 10 == 0) {
                  double avg;
                  List<Long> lastElapsedTimes;
                  synchronized (responseTimes) {
                    avg =
                        responseTimes.stream().collect(Collectors.averagingLong(Long::longValue));
                    // Copy the tail under the lock; clamp the start index since fewer than 20
                    // samples may have been recorded so far.
                    lastElapsedTimes =
                        new ArrayList<>(
                            responseTimes.subList(
                                Math.max(0, responseTimes.size() - 20), responseTimes.size()));
                  }
                  logger.info(
                      "<-- response per second: "
                          // Epsilon avoids an integer division by zero during the first second.
                          + responseCounter.get() / (stopwatch.elapsed(TimeUnit.SECONDS) + 0.00001)
                          + ", average response time: "
                          + Math.round(avg)
                          + " (rps: "
                          + Math.round(totalExecutions / (avg / 1000.0))
                          + "), response count from counter: "
                          + responseCounter.get()
                          + ", last elapsed times: "
                          + lastElapsedTimes);
                }
                return chatCompletionsResponse;
              });

      // Track the dependent future so that the final join also waits for the statistics
      // callback, not only for the raw HTTP response.
      responses.add(responseWithStats);
    }

    Stopwatch started = Stopwatch.createStarted();
    CompletableFuture.allOf(responses.toArray(new CompletableFuture[responses.size()])).join();
    logger.info("Waiting for join: " + started.elapsed());

    double avg;
    synchronized (responseTimes) {
      avg = responseTimes.stream().collect(Collectors.averagingLong(Long::longValue));
    }
    logger.info(
        "Total time: "
            + stopwatch.elapsed().toString()
            + ", request per second: "
            + Math.round(numberOfRequests / (stopwatch.elapsed(TimeUnit.SECONDS) + 0.00001))
            + ", average response time: "
            + Math.round(avg)
            + " (theory rps: "
            + Math.round(totalExecutions / (avg / 1000.0))
            + ")");
  }
}
Loading

0 comments on commit 1fa2715

Please sign in to comment.