Skip to content

Commit

Permalink
Add integrations and seed field to CreateFineTuningJobRequest
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanBratanov committed Apr 9, 2024
1 parent 5aee1bf commit 8f17ffc
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,6 @@ private Constants() {}
static final String IMAGE_CONTENT_PART_TYPE = "image_url";

static final String SUBMIT_TOOL_OUTPUTS_REQUIRED_ACTION_TYPE = "submit_tool_outputs";

static final String WANDB_INTEGRATION_TYPE = "wandb";
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
package io.github.stefanbratanov.jvm.openai;

import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public record CreateFineTuningJobRequest(
String model,
String trainingFile,
Optional<Hyperparameters> hyperparameters,
Optional<String> suffix,
Optional<String> validationFile) {
Optional<String> validationFile,
Optional<List<Integration>> integrations,
Optional<Integer> seed) {

public static Builder newBuilder() {
return new Builder();
Expand Down Expand Up @@ -88,13 +93,87 @@ public Hyperparameters build() {
}
}

public sealed interface Integration permits Integration.Wandb {

@JsonProperty(access = JsonProperty.Access.READ_ONLY)
String type();

record Wandb(
String project,
Optional<String> name,
Optional<String> entity,
Optional<Map<String, String>> tags)
implements Integration {

@Override
public String type() {
return Constants.WANDB_INTEGRATION_TYPE;
}

public static Builder newBuilder() {
return new Builder();
}

public static class Builder {

private String project;
private Optional<String> name = Optional.empty();
private Optional<String> entity = Optional.empty();
private Optional<Map<String, String>> tags = Optional.empty();

/**
* @param project The name of the project that the new run will be created under.
*/
public Builder project(String project) {
this.project = project;
return this;
}

/**
* @param name A display name to set for the run. If not set, we will use the Job ID as the
* name.
*/
public Builder name(String name) {
this.name = Optional.of(name);
return this;
}

/**
* @param entity The entity to use for the run. This allows you to set the team or username
* of the WandB user that you would like associated with the run. If not set, the
* default entity for the registered WandB API key is used.
*/
public Builder entity(String entity) {
this.entity = Optional.of(entity);
return this;
}

/**
* @param project A list of tags to be attached to the newly created run. These tags are
* passed through directly to WandB. Some default tags are generated by OpenAI:
* "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}".
*/
public Builder tags(Map<String, String> tags) {
this.tags = Optional.of(tags);
return this;
}

public Wandb build() {
return new Wandb(project, name, entity, tags);
}
}
}
}

public static class Builder {

private String model;
private String trainingFile;
private Optional<Hyperparameters> hyperparameters = Optional.empty();
private Optional<String> suffix = Optional.empty();
private Optional<String> validationFile = Optional.empty();
private Optional<List<Integration>> integrations = Optional.empty();
private Optional<Integer> seed = Optional.empty();

/**
* @param model The name of the model to fine-tune
Expand Down Expand Up @@ -150,9 +229,27 @@ public Builder validationFile(String validationFile) {
return this;
}

/**
* @param integrations A list of integrations to enable for your fine-tuning job.
*/
public Builder integrations(List<Integration> integrations) {
this.integrations = Optional.of(integrations);
return this;
}

/**
* @param seed The seed controls the reproducibility of the job. Passing in the same seed and
* job parameters should produce the same results, but may differ in rare cases. If a seed
* is not specified, one will be generated for you.
*/
public Builder seed(int seed) {
this.seed = Optional.of(seed);
return this;
}

public CreateFineTuningJobRequest build() {
return new CreateFineTuningJobRequest(
model, trainingFile, hyperparameters, suffix, validationFile);
model, trainingFile, hyperparameters, suffix, validationFile, integrations, seed);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@ void validateFineTuning() {

Response response = createResponseWithBody(serializeObject(fineTuningJob));

validate(request, response);
validate(
request,
response,
// https://github.com/openai/openai-openapi/issues/217
"Object instance has properties which are not allowed by the schema: [\"integrations\",\"seed\"]");

FineTuningClient.PaginatedFineTuningJobs paginatedFineTuningJobs =
testDataUtil.randomPaginatedFineTuningJobs();
Expand Down Expand Up @@ -292,9 +296,9 @@ void validateRuns() {
validate(request);
}

private void validate(Request request, Response response) {
private void validate(Request request, Response response, String... reportMessagesToIgnore) {
ValidationReport report = validator.validate(request, response);
validateReport(report);
validateReport(report, reportMessagesToIgnore);
}

private void validate(Request request, String... reportMessagesToIgnore) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import io.github.stefanbratanov.jvm.openai.ChatMessage.UserMessage.UserMessageWithContentParts.ContentPart;
import io.github.stefanbratanov.jvm.openai.CreateChatCompletionRequest.ResponseFormat;
import io.github.stefanbratanov.jvm.openai.CreateFineTuningJobRequest.Integration;
import io.github.stefanbratanov.jvm.openai.ThreadMessage.Content.ImageFileContent;
import io.github.stefanbratanov.jvm.openai.ThreadMessage.Content.TextContent;
import io.github.stefanbratanov.jvm.openai.ThreadMessage.Content.TextContent.Text.Annotation;
Expand Down Expand Up @@ -107,6 +108,8 @@ public CreateFineTuningJobRequest randomCreateFineTuningJobRequest() {
Optional.of(oneOf("auto", randomInt(1, 50)))))
.suffix(randomString(1, 40))
.validationFile(randomString(10))
.integrations(listOf(randomInt(1, 5), this::randomIntegration))
.seed(randomInt())
.build();
}

Expand Down Expand Up @@ -398,6 +401,16 @@ public SubmitToolOutputsRequest randomSubmitToolOutputsRequest() {
.build();
}

private Integration randomIntegration() {
return oneOf(
Integration.Wandb.newBuilder()
.project(randomString(5, 20))
.name(randomString(5, 20))
.entity(randomString(5, 20))
.tags(randomKeyValueMap(randomInt(1, 10), () -> randomString(5), () -> randomString(6)))
.build());
}

private StepDetails randomStepDetails() {
return oneOf(
new MessageCreationStepDetails(
Expand Down Expand Up @@ -474,12 +487,17 @@ private Map<Integer, Integer> randomLogitBias(int length) {
}

private Map<String, String> randomMetadata() {
int length = randomInt(1, 16);
Map<String, String> metadata = new HashMap<>();
return randomKeyValueMap(
randomInt(1, 16), () -> randomString(3, 64), () -> randomString(10, 512));
}

private Map<String, String> randomKeyValueMap(
int length, Supplier<String> keyGenerator, Supplier<String> valueGenerator) {
Map<String, String> keyValueMap = new HashMap<>();
for (int i = 0; i < length; i++) {
metadata.put(randomString(3, 64), randomString(10, 512));
keyValueMap.put(keyGenerator.get(), valueGenerator.get());
}
return metadata;
return keyValueMap;
}

private ChatCompletion.Choice randomChatCompletionChoice() {
Expand Down

0 comments on commit 8f17ffc

Please sign in to comment.