Skip to content

Commit

Permalink
Create authentication headers only once
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanBratanov committed Apr 9, 2024
1 parent 678b325 commit 5ac8892
Show file tree
Hide file tree
Showing 15 changed files with 57 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@ public final class AssistantsClient extends OpenAIAssistantsClient {

AssistantsClient(
URI baseUrl,
String apiKey,
Optional<String> organization,
String[] authenticationHeaders,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
super(authenticationHeaders, httpClient, requestTimeout);
this.baseUrl = baseUrl;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,10 @@ public final class AudioClient extends OpenAIClient {

AudioClient(
URI baseUrl,
String apiKey,
Optional<String> organization,
String[] authenticationHeaders,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
super(authenticationHeaders, httpClient, requestTimeout);
this.baseUrl = baseUrl;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ public final class ChatClient extends OpenAIClient {

ChatClient(
URI baseUrl,
String apiKey,
Optional<String> organization,
String[] authenticationHeaders,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
super(authenticationHeaders, httpClient, requestTimeout);
endpoint = baseUrl.resolve(Endpoint.CHAT.getPath());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ public final class EmbeddingsClient extends OpenAIClient {

EmbeddingsClient(
URI baseUrl,
String apiKey,
Optional<String> organization,
String[] authenticationHeaders,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
super(authenticationHeaders, httpClient, requestTimeout);
endpoint = baseUrl.resolve(Endpoint.EMBEDDINCS.getPath());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ public final class FilesClient extends OpenAIClient {

FilesClient(
URI baseUrl,
String apiKey,
Optional<String> organization,
String[] authenticationHeaders,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
super(authenticationHeaders, httpClient, requestTimeout);
this.baseUrl = baseUrl;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@ public final class FineTuningClient extends OpenAIClient {

FineTuningClient(
URI baseUrl,
String apiKey,
Optional<String> organization,
String[] authenticationHeaders,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
super(authenticationHeaders, httpClient, requestTimeout);
this.baseUrl = baseUrl;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ public final class ImagesClient extends OpenAIClient {

ImagesClient(
URI baseUrl,
String apiKey,
Optional<String> organization,
String[] authenticationHeaders,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
super(authenticationHeaders, httpClient, requestTimeout);
this.baseUrl = baseUrl;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ public final class MessagesClient extends OpenAIAssistantsClient {

MessagesClient(
URI baseUrl,
String apiKey,
Optional<String> organization,
String[] authenticationHeaders,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
super(authenticationHeaders, httpClient, requestTimeout);
this.baseUrl = baseUrl;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ public final class ModelsClient extends OpenAIClient {

ModelsClient(
URI baseUrl,
String apiKey,
Optional<String> organization,
String[] authenticationHeaders,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
super(authenticationHeaders, httpClient, requestTimeout);
this.baseUrl = baseUrl;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ public final class ModerationsClient extends OpenAIClient {

ModerationsClient(
URI baseUrl,
String apiKey,
Optional<String> organization,
String[] authenticationHeaders,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
super(authenticationHeaders, httpClient, requestTimeout);
endpoint = baseUrl.resolve(Endpoint.MODERATIONS.getPath());
}

Expand Down
39 changes: 27 additions & 12 deletions src/main/java/io/github/stefanbratanov/jvm/openai/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import java.net.URI;
import java.net.http.HttpClient;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/**
Expand Down Expand Up @@ -31,22 +33,23 @@ private OpenAI(
Optional<String> organization,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
audioClient = new AudioClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
chatClient = new ChatClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
String[] authenticationHeaders = createAuthenticationHeaders(apiKey, organization);
audioClient = new AudioClient(baseUrl, authenticationHeaders, httpClient, requestTimeout);
chatClient = new ChatClient(baseUrl, authenticationHeaders, httpClient, requestTimeout);
embeddingsClient =
new EmbeddingsClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
new EmbeddingsClient(baseUrl, authenticationHeaders, httpClient, requestTimeout);
fineTuningClient =
new FineTuningClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
filesClient = new FilesClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
imagesClient = new ImagesClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
modelsClient = new ModelsClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
new FineTuningClient(baseUrl, authenticationHeaders, httpClient, requestTimeout);
filesClient = new FilesClient(baseUrl, authenticationHeaders, httpClient, requestTimeout);
imagesClient = new ImagesClient(baseUrl, authenticationHeaders, httpClient, requestTimeout);
modelsClient = new ModelsClient(baseUrl, authenticationHeaders, httpClient, requestTimeout);
moderationsClient =
new ModerationsClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
new ModerationsClient(baseUrl, authenticationHeaders, httpClient, requestTimeout);
assistantsClient =
new AssistantsClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
threadsClient = new ThreadsClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
messagesClient = new MessagesClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
runsClient = new RunsClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
new AssistantsClient(baseUrl, authenticationHeaders, httpClient, requestTimeout);
threadsClient = new ThreadsClient(baseUrl, authenticationHeaders, httpClient, requestTimeout);
messagesClient = new MessagesClient(baseUrl, authenticationHeaders, httpClient, requestTimeout);
runsClient = new RunsClient(baseUrl, authenticationHeaders, httpClient, requestTimeout);
}

/**
Expand Down Expand Up @@ -145,6 +148,18 @@ public RunsClient runsClient() {
return runsClient;
}

private String[] createAuthenticationHeaders(String apiKey, Optional<String> organization) {
List<String> authHeaders = new ArrayList<>();
authHeaders.add("Authorization");
authHeaders.add("Bearer " + apiKey);
organization.ifPresent(
org -> {
authHeaders.add("OpenAI-Organization");
authHeaders.add(org);
});
return authHeaders.toArray(new String[] {});
}

/**
* @param apiKey the API key used for authentication
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,8 @@
class OpenAIAssistantsClient extends OpenAIClient {

OpenAIAssistantsClient(
String apiKey,
Optional<String> organization,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
String[] authenticationHeaders, HttpClient httpClient, Optional<Duration> requestTimeout) {
super(authenticationHeaders, httpClient, requestTimeout);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
*/
abstract class OpenAIClient {

private static final String STREAM_TERMINATION = "(data: \\[DONE]|event: done)";
private static final String STREAM_TERMINATION_REGEX = "(data: \\[DONE]|event: done)";

private final ObjectMapper objectMapper = ObjectMapperSingleton.getInstance();

Expand All @@ -32,11 +32,8 @@ abstract class OpenAIClient {
private final Optional<Duration> requestTimeout;

OpenAIClient(
String apiKey,
Optional<String> organization,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
this.authenticationHeaders = getAuthenticationHeaders(apiKey, organization);
String[] authenticationHeaders, HttpClient httpClient, Optional<Duration> requestTimeout) {
this.authenticationHeaders = authenticationHeaders;
this.httpClient = httpClient;
this.requestTimeout = requestTimeout;
}
Expand Down Expand Up @@ -103,7 +100,7 @@ Stream<String> streamServerSentEvents(HttpRequest httpRequest) {
return sendHttpRequest(httpRequest, HttpResponse.BodyHandlers.ofLines())
.body()
.filter(sseEvent -> !sseEvent.isBlank())
.takeWhile(sseEvent -> !sseEvent.matches(STREAM_TERMINATION));
.takeWhile(sseEvent -> !sseEvent.matches(STREAM_TERMINATION_REGEX));
}

void validateStreamRequest(Supplier<Optional<Boolean>> streamField) {
Expand Down Expand Up @@ -153,18 +150,6 @@ <T> List<T> deserializeDataInResponseAsList(byte[] response, Class<T> elementTyp
}
}

private String[] getAuthenticationHeaders(String apiKey, Optional<String> organization) {
List<String> authHeaders = new ArrayList<>();
authHeaders.add("Authorization");
authHeaders.add("Bearer " + apiKey);
organization.ifPresent(
org -> {
authHeaders.add("OpenAI-Organization");
authHeaders.add(org);
});
return authHeaders.toArray(new String[] {});
}

private Optional<OpenAIException.Error> getErrorFromHttpResponse(HttpResponse<?> httpResponse) {
return getErrorBodyFromHttpResponse(httpResponse)
.flatMap(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@ public final class RunsClient extends OpenAIAssistantsClient {

RunsClient(
URI baseUrl,
String apiKey,
Optional<String> organization,
String[] authenticationHeaders,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
super(authenticationHeaders, httpClient, requestTimeout);
this.baseUrl = baseUrl;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@ public final class ThreadsClient extends OpenAIAssistantsClient {

ThreadsClient(
URI baseUrl,
String apiKey,
Optional<String> organization,
String[] authenticationHeaders,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
super(authenticationHeaders, httpClient, requestTimeout);
this.baseUrl = baseUrl;
}

Expand Down

0 comments on commit 5ac8892

Please sign in to comment.