Skip to content

Commit

Permalink
Change learningRateMultiplier to be a double
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanBratanov committed Apr 16, 2024
1 parent 1dc611f commit f596e72
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public Builder batchSize(String batchSize) {
* @param batchSize (string or integer) Number of examples in each batch. A larger batch size
* means that model parameters are updated less frequently, but with lower variance.
*/
public Builder batchSize(Integer batchSize) {
public Builder batchSize(int batchSize) {
this.batchSize = Optional.of(batchSize);
return this;
}
Expand All @@ -62,7 +62,7 @@ public Builder learningRateMultiplier(String learningRateMultiplier) {
* @param learningRateMultiplier (string or integer) Scaling factor for the learning rate. A
* smaller learning rate may be useful to avoid overfitting.
*/
public Builder learningRateMultiplier(Integer learningRateMultiplier) {
public Builder learningRateMultiplier(double learningRateMultiplier) {
this.learningRateMultiplier = Optional.of(learningRateMultiplier);
return this;
}
Expand All @@ -80,7 +80,7 @@ public Builder nEpochs(String nEpochs) {
* @param nEpochs (string or integer) The number of epochs to train the model for. An epoch
* refers to one full cycle through the training dataset.
*/
public Builder nEpochs(Integer nEpochs) {
public Builder nEpochs(int nEpochs) {
this.nEpochs = Optional.of(nEpochs);
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ public CreateFineTuningJobRequest randomCreateFineTuningJobRequest() {
.hyperparameters(
new CreateFineTuningJobRequest.Hyperparameters(
Optional.of(oneOf("auto", randomInt(1, 256))),
Optional.of(oneOf("auto", randomInt(0, 10_000))),
Optional.of(oneOf("auto", randomDoubleExclusiveMin(0, 10_000))),
Optional.of(oneOf("auto", randomInt(1, 50)))))
.suffix(randomString(1, 40))
.validationFile(randomString(10))
Expand Down Expand Up @@ -758,6 +758,10 @@ private double randomDouble(double min, double max) {
return random.doubles(1, min, max + EPSILON).findFirst().orElse(0.0);
}

private double randomDoubleExclusiveMin(double min, double max) {
return random.doubles(1, min + EPSILON, max + EPSILON).findFirst().orElse(0.0);
}

private long randomLong(long min, long max) {
return random.nextLong(min, max + 1);
}
Expand Down

0 comments on commit f596e72

Please sign in to comment.