Skip to content

Commit

Permalink
Use a Spliterator for the Assistant streaming functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanBratanov committed Apr 9, 2024
1 parent 8dd4c3e commit 8dabf8a
Showing 1 changed file with 68 additions and 24 deletions.
92 changes: 68 additions & 24 deletions src/main/java/io/github/stefanbratanov/jvm/openai/RunsClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Spliterator;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/**
* Represents an execution run on a thread.
Expand Down Expand Up @@ -289,31 +292,34 @@ public ThreadRun cancelRun(String threadId, String runId) {
return deserializeResponse(httpResponse.body(), ThreadRun.class);
}

private record AssistantStreamEvent(String event, String data) {}

private void streamAndHandleAssistantEvents(
HttpRequest httpRequest, AssistantStreamEventSubscriber subscriber) {
CompletableFuture.supplyAsync(() -> streamServerSentEvents(httpRequest))
.thenAccept(sseEvents -> handleAssistantSseEvents(sseEvents, subscriber))
.whenComplete((result, ex) -> handleAssistantEventsStreamCompletion(ex, subscriber));
CompletableFuture.supplyAsync(() -> streamAssistantEvents(httpRequest))
.thenAccept(
assistantStreamEvents ->
assistantStreamEvents.forEach(
assistantStreamEvent ->
handleAssistantStreamEvent(assistantStreamEvent, subscriber)))
.whenComplete(
(result, ex) -> {
if (ex != null) {
subscriber.onException(ex);
}
subscriber.onComplete();
});
}

private void handleAssistantSseEvents(
Stream<String> sseEvents, AssistantStreamEventSubscriber subscriber) {
Iterator<String> iterator = sseEvents.iterator();
while (iterator.hasNext()) {
// have to group the event and the data because they are received separately
String event = iterator.next().split(":", 2)[1].trim();
String data;
if (iterator.hasNext()) {
data = iterator.next().split(":", 2)[1].trim();
} else {
throw new IllegalStateException("No data available for event " + event);
}
handleAssistantSseEvent(event, data, subscriber);
}
private Stream<AssistantStreamEvent> streamAssistantEvents(HttpRequest httpRequest) {
Stream<String> sseEvents = streamServerSentEvents(httpRequest);
return StreamSupport.stream(new AssistantStreamEventSpliterator(sseEvents), false);
}

private void handleAssistantSseEvent(
String event, String data, AssistantStreamEventSubscriber subscriber) {
private void handleAssistantStreamEvent(
AssistantStreamEvent assistantStreamEvent, AssistantStreamEventSubscriber subscriber) {
String event = assistantStreamEvent.event;
String data = assistantStreamEvent.data;
if (event.startsWith("thread.run.step.delta")) {
subscriber.onThreadRunStepDelta(event, deserializeData(data, ThreadRunStepDelta.class));
} else if (event.startsWith("thread.run.step")) {
Expand All @@ -331,11 +337,49 @@ private void handleAssistantSseEvent(
}
}

private void handleAssistantEventsStreamCompletion(
Throwable ex, AssistantStreamEventSubscriber subscriber) {
if (ex != null) {
subscriber.onException(ex);
private static class AssistantStreamEventSpliterator
implements Spliterator<AssistantStreamEvent> {

private final Iterator<String> sseEventsIterator;

AssistantStreamEventSpliterator(Stream<String> sseEvents) {
this.sseEventsIterator = sseEvents.iterator();
}

@Override
public boolean tryAdvance(Consumer<? super AssistantStreamEvent> action) {
String event = getNextValue();
if (event == null) {
return false;
}
String data = getNextValue();
if (data == null) {
return false;
}
action.accept(new AssistantStreamEvent(event, data));
return true;
}

@Override
public Spliterator<AssistantStreamEvent> trySplit() {
return null;
}

@Override
public long estimateSize() {
return Long.MAX_VALUE;
}

@Override
public int characteristics() {
return ORDERED | NONNULL;
}

private String getNextValue() {
if (!sseEventsIterator.hasNext()) {
return null;
}
return sseEventsIterator.next().split(":", 2)[1].trim();
}
subscriber.onComplete();
}
}

0 comments on commit 8dabf8a

Please sign in to comment.