Skip to content

Commit

Permalink
Fine-tuning updates, add gpt-4-turbo and other spec updates
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanBratanov committed Apr 10, 2024
1 parent 8dabf8a commit bbd3aa4
Show file tree
Hide file tree
Showing 11 changed files with 139 additions and 140 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public Builder description(String description) {

/**
* @param instructions The system instructions that the assistant uses. The maximum length is
* 32768 characters.
* 256,000 characters.
*/
public Builder instructions(String instructions) {
this.instructions = Optional.of(instructions);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package io.github.stefanbratanov.jvm.openai;

import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public record CreateFineTuningJobRequest(
Expand All @@ -11,7 +9,7 @@ public record CreateFineTuningJobRequest(
Optional<Hyperparameters> hyperparameters,
Optional<String> suffix,
Optional<String> validationFile,
Optional<List<Integration>> integrations,
Optional<List<FineTuningJobIntegration>> integrations,
Optional<Integer> seed) {

public static Builder newBuilder() {
Expand Down Expand Up @@ -93,86 +91,14 @@ public Hyperparameters build() {
}
}

public sealed interface Integration permits Integration.Wandb {

@JsonProperty(access = JsonProperty.Access.READ_ONLY)
String type();

record Wandb(
String project,
Optional<String> name,
Optional<String> entity,
Optional<Map<String, String>> tags)
implements Integration {

@Override
public String type() {
return Constants.WANDB_INTEGRATION_TYPE;
}

public static Builder newBuilder() {
return new Builder();
}

public static class Builder {

private String project;
private Optional<String> name = Optional.empty();
private Optional<String> entity = Optional.empty();
private Optional<Map<String, String>> tags = Optional.empty();

/**
* @param project The name of the project that the new run will be created under.
*/
public Builder project(String project) {
this.project = project;
return this;
}

/**
* @param name A display name to set for the run. If not set, we will use the Job ID as the
* name.
*/
public Builder name(String name) {
this.name = Optional.of(name);
return this;
}

/**
* @param entity The entity to use for the run. This allows you to set the team or username
* of the WandB user that you would like associated with the run. If not set, the
* default entity for the registered WandB API key is used.
*/
public Builder entity(String entity) {
this.entity = Optional.of(entity);
return this;
}

/**
* @param tags A list of tags to be attached to the newly created run. These tags are passed
* through directly to WandB. Some default tags are generated by OpenAI:
* "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}".
*/
public Builder tags(Map<String, String> tags) {
this.tags = Optional.of(tags);
return this;
}

public Wandb build() {
return new Wandb(project, name, entity, tags);
}
}
}
}

public static class Builder {

private String model;
private String trainingFile;
private Optional<Hyperparameters> hyperparameters = Optional.empty();
private Optional<String> suffix = Optional.empty();
private Optional<String> validationFile = Optional.empty();
private Optional<List<Integration>> integrations = Optional.empty();
private Optional<List<FineTuningJobIntegration>> integrations = Optional.empty();
private Optional<Integer> seed = Optional.empty();

/**
Expand Down Expand Up @@ -232,7 +158,7 @@ public Builder validationFile(String validationFile) {
/**
* @param integrations A list of integrations to enable for your fine-tuning job.
*/
public Builder integrations(List<Integration> integrations) {
public Builder integrations(List<FineTuningJobIntegration> integrations) {
this.integrations = Optional.of(integrations);
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ public record FineTuningJob(
String status,
Integer trainedTokens,
String trainingFile,
String validationFile) {
String validationFile,
List<FineTuningJobIntegration> integrations,
int seed) {

public record Error(String code, String message, String param) {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ public record FineTuningJobCheckpoint(
long createdAt,
String fineTunedModelCheckpoint,
int stepNumber,
Long finishedAt,
Metrics metrics,
String fineTuningJobId) {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package io.github.stefanbratanov.jvm.openai;

import java.util.List;
import java.util.Optional;

public record FineTuningJobIntegration(String type, Wandb wandb) {

public record Wandb(
String project, Optional<String> name, Optional<String> entity, Optional<List<String>> tags) {

public static Builder newBuilder() {
return new Builder();
}

public static class Builder {

private String project;
private Optional<String> name = Optional.empty();
private Optional<String> entity = Optional.empty();
private Optional<List<String>> tags = Optional.empty();

/**
* @param project The name of the project that the new run will be created under.
*/
public Builder project(String project) {
this.project = project;
return this;
}

/**
* @param name A display name to set for the run. If not set, we will use the Job ID as the
* name.
*/
public Builder name(String name) {
this.name = Optional.of(name);
return this;
}

/**
* @param entity The entity to use for the run. This allows you to set the team or username of
* the WandB user that you would like associated with the run. If not set, the default
* entity for the registered WandB API key is used.
*/
public Builder entity(String entity) {
this.entity = Optional.of(entity);
return this;
}

/**
* @param tags A list of tags to be attached to the newly created run. These tags are passed
* through directly to WandB. Some default tags are generated by OpenAI:
* "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}".
*/
public Builder tags(List<String> tags) {
this.tags = Optional.of(tags);
return this;
}

public Wandb build() {
return new Wandb(project, name, entity, tags);
}
}
}

static FineTuningJobIntegration wandbIntegration(Wandb wandb) {
return new FineTuningJobIntegration(Constants.WANDB_INTEGRATION_TYPE, wandb);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public Builder description(String description) {

/**
* @param instructions The system instructions that the assistant uses. The maximum length is
* 32768 characters.
* 256,000 characters.
*/
public Builder instructions(String instructions) {
this.instructions = Optional.of(instructions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public enum OpenAIModel {

// GPT-4 and GPT-4 Turbo (https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo)
GPT_4("gpt-4"),
GPT_4_TURBO("gpt-4-turbo"),
GPT_4_TURBO_PREVIEW("gpt-4-turbo-preview"),
GPT_4_1106_PREVIEW("gpt-4-1106-preview"),
GPT_4_VISION_PREVIEW("gpt-4-vision-preview"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,7 @@ void validateFineTuning() {

Response response = createResponseWithBody(serializeObject(fineTuningJob));

validate(
request,
response,
// https://github.com/openai/openai-openapi/issues/217
"Object instance has properties which are not allowed by the schema: [\"integrations\",\"seed\"]");
validate(request, response);

FineTuningClient.PaginatedFineTuningJobs paginatedFineTuningJobs =
testDataUtil.randomPaginatedFineTuningJobs();
Expand All @@ -120,6 +116,18 @@ void validateFineTuning() {
listEventsResponse,
// https://github.com/openai/openai-openapi/pull/168
"Object instance has properties which are not allowed by the schema: [\"has_more\"]");

FineTuningClient.PaginatedFineTuningCheckpoints paginatedFineTuningCheckpoints =
testDataUtil.randomPaginatedFineTuningCheckpoints();

Response listCheckpointsResponse =
createResponseWithBody(serializeObject(paginatedFineTuningCheckpoints));

validate(
"/" + Endpoint.FINE_TUNING.getPath() + "/{fine_tuning_job_id}/checkpoints",
Method.GET,
listCheckpointsResponse,
"Object has missing required properties ([\"n_epochs\"]");
}

@RepeatedTest(50)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.stefanbratanov.jvm.openai.FineTuningJobCheckpoint.Metrics;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -34,27 +33,6 @@ void deserializesChatCompletionChunk() throws JsonProcessingException {
});
}

@Test
void deserializesFineTuningJobCheckpoint() throws JsonProcessingException {
FineTuningJobCheckpoint fineTuningJobCheckpoint =
objectMapper.readValue(
getStringResource("/fine-tuning-job-checkpoint.json"), FineTuningJobCheckpoint.class);

assertThat(fineTuningJobCheckpoint).isNotNull();
assertThat(fineTuningJobCheckpoint.id()).isEqualTo("ftckpt_qtZ5Gyk4BLq1SfLFWp3RtO3P");
assertThat(fineTuningJobCheckpoint.fineTunedModelCheckpoint())
.isEqualTo("ft:gpt-3.5-turbo-0125:my-org:custom_suffix:9ABel2dg:ckpt-step-88");
assertThat(fineTuningJobCheckpoint.fineTuningJobId())
.isEqualTo("ftjob-fpbNQ3H1GrMehXRf8cO97xTN");
Metrics metrics = fineTuningJobCheckpoint.metrics();

assertThat(metrics.step()).isEqualTo(88.0);
assertThat(metrics.trainMeanTokenAccuracy()).isEqualTo(0.924);
assertThat(metrics.fullValidLoss()).isEqualTo(0.567);

assertThat(fineTuningJobCheckpoint.stepNumber()).isEqualTo(88);
}

@RepeatedTest(50)
void serializesAndDeserializesThreadMessageDelta() throws JsonProcessingException {
ThreadMessageDelta threadMessageDelta = testDataUtil.randomThreadMessageDelta();
Expand Down
Loading

0 comments on commit bbd3aa4

Please sign in to comment.