-
Notifications
You must be signed in to change notification settings - Fork 73
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement first version of no-batch ai translate api
- Loading branch information
Showing
10 changed files
with
703 additions
and
219 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
77 changes: 77 additions & 0 deletions
77
common/src/main/java/com/box/l10n/mojito/openai/OpenAIClientPool.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
package com.box.l10n.mojito.openai; | ||
|
||
import com.google.common.base.Function; | ||
import java.net.http.HttpClient; | ||
import java.util.concurrent.CompletableFuture; | ||
import java.util.concurrent.ExecutorService; | ||
import java.util.concurrent.Executors; | ||
import java.util.concurrent.Semaphore; | ||
import java.util.concurrent.ThreadLocalRandom; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
public class OpenAIClientPool { | ||
|
||
static Logger logger = LoggerFactory.getLogger(OpenAIClientPool.class); | ||
|
||
int numberOfClients; | ||
OpenAIClientWithSemaphore[] openAIClientWithSemaphores; | ||
|
||
/** | ||
* Pool to parallelize slower requests (1s+) over HTTP/2 connections. | ||
* | ||
* @param numberOfClients Number of OpenAIClient instances with independent HttpClients. | ||
* @param numberOfParallelRequestPerClient Maximum parallel requests per client, controlled by a | ||
* semaphore to prevent overload. | ||
* @param sizeOfAsyncProcessors Shared async processors across all HttpClients to limit threads, | ||
* as request time is the main bottleneck. | ||
* @param apiKey API key for authentication. | ||
*/ | ||
public OpenAIClientPool( | ||
int numberOfClients, | ||
int numberOfParallelRequestPerClient, | ||
int sizeOfAsyncProcessors, | ||
String apiKey) { | ||
ExecutorService asyncExecutor = Executors.newWorkStealingPool(sizeOfAsyncProcessors); | ||
this.numberOfClients = numberOfClients; | ||
this.openAIClientWithSemaphores = new OpenAIClientWithSemaphore[numberOfClients]; | ||
for (int i = 0; i < numberOfClients; i++) { | ||
this.openAIClientWithSemaphores[i] = | ||
new OpenAIClientWithSemaphore( | ||
OpenAIClient.builder() | ||
.apiKey(apiKey) | ||
.asyncExecutor(asyncExecutor) | ||
.httpClient(HttpClient.newBuilder().executor(asyncExecutor).build()) | ||
.build(), | ||
new Semaphore(numberOfParallelRequestPerClient)); | ||
} | ||
} | ||
|
||
public <T> CompletableFuture<T> submit(Function<OpenAIClient, CompletableFuture<T>> f) { | ||
|
||
while (true) { | ||
for (OpenAIClientWithSemaphore openAIClientWithSemaphore : openAIClientWithSemaphores) { | ||
if (openAIClientWithSemaphore.semaphore().tryAcquire()) { | ||
return f.apply(openAIClientWithSemaphore.openAIClient()) | ||
.whenComplete((o, e) -> openAIClientWithSemaphore.semaphore().release()); | ||
} | ||
} | ||
|
||
try { | ||
logger.debug("can't directly acquire any semaphore, do blocking"); | ||
int randomSemaphoreIndex = | ||
ThreadLocalRandom.current().nextInt(openAIClientWithSemaphores.length); | ||
OpenAIClientWithSemaphore randomClientWithSemaphore = | ||
this.openAIClientWithSemaphores[randomSemaphoreIndex]; | ||
randomClientWithSemaphore.semaphore().acquire(); | ||
return f.apply(randomClientWithSemaphore.openAIClient()) | ||
.whenComplete((o, e) -> randomClientWithSemaphore.semaphore().release()); | ||
} catch (InterruptedException e) { | ||
Thread.currentThread().interrupt(); | ||
throw new RuntimeException("Can't submit task to the OpenAIClientPool", e); | ||
} | ||
} | ||
} | ||
|
||
record OpenAIClientWithSemaphore(OpenAIClient openAIClient, Semaphore semaphore) {} | ||
} |
135 changes: 135 additions & 0 deletions
135
common/src/test/java/com/box/l10n/mojito/openai/OpenAIClientPoolTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
package com.box.l10n.mojito.openai; | ||
|
||
import com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsResponse; | ||
import com.google.common.base.Stopwatch; | ||
import org.junit.jupiter.api.Disabled; | ||
import org.junit.jupiter.api.Test; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.concurrent.CompletableFuture; | ||
import java.util.concurrent.TimeUnit; | ||
import java.util.concurrent.atomic.AtomicInteger; | ||
import java.util.stream.Collectors; | ||
|
||
import static com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsRequest.SystemMessage.systemMessageBuilder; | ||
import static com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsRequest.UserMessage.userMessageBuilder; | ||
import static com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsRequest.chatCompletionsRequest; | ||
|
||
@Disabled | ||
public class OpenAIClientPoolTest { | ||
|
||
static Logger logger = LoggerFactory.getLogger(OpenAIClientPoolTest.class); | ||
|
||
static final String API_KEY; | ||
|
||
static { | ||
try { | ||
// API_KEY = | ||
// | ||
// Files.readString(Paths.get(System.getProperty("user.home")).resolve(".keys/openai")) | ||
// .trim(); | ||
API_KEY = "test-api-key"; | ||
} catch (Throwable e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
@Test | ||
public void test() { | ||
int numberOfClients = 10; | ||
int numberOfParallelRequestPerClient = 50; | ||
int numberOfRequests = 10000; | ||
int sizeOfAsyncProcessors = 10; | ||
int totalExecutions = numberOfClients * numberOfParallelRequestPerClient; | ||
|
||
OpenAIClientPool openAIClientPool = | ||
new OpenAIClientPool( | ||
numberOfClients, numberOfParallelRequestPerClient, sizeOfAsyncProcessors, API_KEY); | ||
|
||
AtomicInteger responseCounter = new AtomicInteger(); | ||
AtomicInteger submitted = new AtomicInteger(); | ||
Stopwatch stopwatch = Stopwatch.createStarted(); | ||
|
||
ArrayList<Long> submissionTimes = new ArrayList<>(); | ||
ArrayList<Long> responseTimes = new ArrayList<>(); | ||
|
||
List<CompletableFuture<ChatCompletionsResponse>> responses = new ArrayList<>(); | ||
for (int i = 0; i < numberOfRequests; i++) { | ||
String message = "Is %d prime?".formatted(i); | ||
Stopwatch requestStopwatch = Stopwatch.createStarted(); | ||
OpenAIClient.ChatCompletionsRequest chatCompletionsRequest = | ||
chatCompletionsRequest() | ||
.model("gpt-4o-2024-08-06") | ||
.messages( | ||
List.of( | ||
systemMessageBuilder() | ||
.content("You're an engine designed to check prime numbers") | ||
.build(), | ||
userMessageBuilder().content(message).build())) | ||
.build(); | ||
|
||
CompletableFuture<ChatCompletionsResponse> response = | ||
openAIClientPool.submit( | ||
openAIClient -> { | ||
CompletableFuture<ChatCompletionsResponse> chatCompletions = | ||
openAIClient.getChatCompletions(chatCompletionsRequest); | ||
submissionTimes.add(requestStopwatch.elapsed(TimeUnit.SECONDS)); | ||
if (submitted.incrementAndGet() % 100 == 0) { | ||
logger.info( | ||
"--> request per second: " | ||
+ submitted.get() / (stopwatch.elapsed(TimeUnit.SECONDS) + 0.00001) | ||
+ ", submission count: " | ||
+ submitted.get() | ||
+ ", future response count: " | ||
+ responses.size() | ||
+ ", last submissions took: " | ||
+ submissionTimes.subList( | ||
Math.max(0, submissionTimes.size() - 100), submissionTimes.size())); | ||
} | ||
return chatCompletions; | ||
}); | ||
|
||
response.thenApply( | ||
chatCompletionsResponse -> { | ||
responseTimes.add(requestStopwatch.elapsed(TimeUnit.MILLISECONDS)); | ||
if (responseCounter.incrementAndGet() % 10 == 0) { | ||
double avg = | ||
responseTimes.stream().collect(Collectors.averagingLong(Long::longValue)); | ||
logger.info( | ||
"<-- response per second: " | ||
+ responseCounter.get() / stopwatch.elapsed(TimeUnit.SECONDS) | ||
+ ", average response time: " | ||
+ Math.round(avg) | ||
+ " (rps: " | ||
+ Math.round(totalExecutions / (avg / 1000.0)) | ||
+ "), response count from counter: " | ||
+ responseCounter.get() | ||
+ ", last elapsed times: " | ||
+ responseTimes.subList(responseTimes.size() - 20, responseTimes.size())); | ||
} | ||
return chatCompletionsResponse; | ||
}); | ||
|
||
responses.add(response); | ||
} | ||
|
||
Stopwatch started = Stopwatch.createStarted(); | ||
CompletableFuture.allOf(responses.toArray(new CompletableFuture[responses.size()])).join(); | ||
logger.info("Waiting for join: " + started.elapsed()); | ||
|
||
double avg = responseTimes.stream().collect(Collectors.averagingLong(Long::longValue)); | ||
logger.info( | ||
"Total time: " | ||
+ stopwatch.elapsed().toString() | ||
+ ", request per second: " | ||
+ Math.round((double) numberOfRequests / stopwatch.elapsed(TimeUnit.SECONDS)) | ||
+ ", average response time: " | ||
+ Math.round(avg) | ||
+ " (theory rps: " | ||
+ Math.round(totalExecutions / (avg / 1000.0)) | ||
+ ")"); | ||
} | ||
} |
Oops, something went wrong.