Skip to content

Commit

Permalink
Merge pull request optuna#879 from porink0424/feat/add-trail-constraints
Browse files Browse the repository at this point in the history
Add `constraints` into `Trial` type
  • Loading branch information
c-bata authored May 10, 2024
2 parents 5d8b664 + 0ceaec0 commit cffc86d
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 5 deletions.
28 changes: 24 additions & 4 deletions tslib/storage/src/journal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ interface JournalOpDeleteStudy extends JournalOpBase {
interface JournalOpSetStudySystemAttr extends JournalOpBase {
study_id: number
system_attr: {
"study:metric_names": string[]
"study:metric_names"?: string[]
}
}

Expand Down Expand Up @@ -80,6 +80,13 @@ interface JournalOpSetTrialUserAttr extends JournalOpBase {
user_attr: { [key: string]: any } // eslint-disable-line @typescript-eslint/no-explicit-any
}

interface JournalOpSetTrialSystemAttr extends JournalOpBase {
trial_id: number
system_attr: {
constraints?: number[]
}
}

const trialStateNumToTrialState = (state: number): Optuna.TrialState => {
switch (state) {
case 0:
Expand Down Expand Up @@ -224,7 +231,7 @@ class JournalStorage {
}
})

const userAtter = log.user_attrs
const userAttrs = log.user_attrs
? Object.entries(log.user_attrs).map(([key, value]) => {
return {
key: key,
Expand All @@ -249,7 +256,8 @@ class JournalStorage {
})(),
params: params,
intermediate_values: [],
user_attrs: userAtter,
user_attrs: userAttrs,
constraints: [],
datetime_start: log.datetime_start
? new Date(log.datetime_start)
: undefined,
Expand Down Expand Up @@ -341,6 +349,16 @@ class JournalStorage {
}
}
}

public applySetTrialSystemAttr(log: JournalOpSetTrialSystemAttr) {
const [thisStudy, thisTrial] = this.getStudyAndTrial(log.trial_id)
if (thisStudy === undefined || thisTrial === undefined) {
return
}
if (log.system_attr.constraints) {
thisTrial.constraints = log.system_attr.constraints
}
}
}

const loadJournalStorage = (
Expand Down Expand Up @@ -434,7 +452,9 @@ const loadJournalStorage = (
)
break
case JournalOperation.SET_TRIAL_SYSTEM_ATTR:
// Unsupported
journalStorage.applySetTrialSystemAttr(
parsedLog as JournalOpSetTrialSystemAttr
)
break
}
}
Expand Down
20 changes: 20 additions & 0 deletions tslib/storage/src/sqlite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@ const getStudy = (
}
}

const systemAttrs = getTrialSystemAttributes(db, trial.trial_id)
if (systemAttrs !== undefined) {
trial.constraints = systemAttrs.constraints
}

const params = getTrialParams(db, trial.trial_id)
const param_names = new Set<string>()
for (const param of params) {
Expand Down Expand Up @@ -222,6 +227,7 @@ const getTrials = (
),
params: [], // Set this column later
user_attrs: [], // Set this column later
constraints: [],
datetime_start: vals[3],
datetime_complete: vals[4],
}
Expand Down Expand Up @@ -404,6 +410,20 @@ const getTrialUserAttributes = (
return attrs
}

const getTrialSystemAttributes = (db: SQLite3DB, trialId: number) => {
let attrs: { constraints: number[] } | undefined
db.exec({
sql: `SELECT key, value_json FROM trial_system_attributes WHERE trial_id = ${trialId} AND key = 'constraints'`,
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
callback: (vals: any[]) => {
attrs = {
constraints: JSON.parse(vals[1]),
}
},
})
return attrs
}

const getTrialIntermediateValues = (
db: SQLite3DB,
trialId: number,
Expand Down
24 changes: 24 additions & 0 deletions tslib/storage/test/generate_assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,30 @@ def objective_multi(trial: optuna.Trial) -> tuple[float, float]:

study.optimize(objective_multi, n_trials=50)

# Multi-objective study with constraints
def objective_constraints(trial: optuna.Trial) -> tuple[float, float]:
x = trial.suggest_float("x", -15, 30)
y = trial.suggest_float("y", -15, 30)
c0 = (x - 5) ** 2 + y**2 - 25
c1 = -((x - 8) ** 2) - (y + 3) ** 2 + 7.7
trial.set_user_attr("constraint", (c0, c1))
v0 = 4 * x**2 + 4 * y**2
v1 = (x - 5) ** 2 + (y - 5) ** 2
return v0, v1

def constraints(trial: optuna.Trial):
return trial.user_attrs["constraint"]

sampler = optuna.samplers.NSGAIISampler(constraints_func=constraints)
study = optuna.create_study(
study_name="multi-objective-constraints",
storage=storage,
sampler=sampler,
directions=["minimize", "minimize"],
)
print(f"Generating {study.study_name} for {type(storage).__name__}...")
study.optimize(objective_constraints, n_trials=32, timeout=600)


if __name__ == "__main__":
remove_assets()
Expand Down
9 changes: 8 additions & 1 deletion tslib/storage/test/journal.test.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,15 @@ describe("Test Journal File Storage", async () => {
assert.deepStrictEqual(study.metric_names, ["value1", "value2"])
})

it("Check the study with constraints", () => {
const study = studies.find((s) => s.name === "multi-objective-constraints")
for (const trial of study.trials) {
assert.strictEqual(trial.constraints.length, 2)
}
})

it("Check the number of studies", () => {
const N_STUDIES = 5
const N_STUDIES = 6
assert.strictEqual(studies.length, N_STUDIES)
})
})
1 change: 1 addition & 0 deletions tslib/types/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ export type Trial = {
user_attrs: Attribute[]
datetime_start?: Date
datetime_complete?: Date
constraints: number[]
}

export type TrialParam = {
Expand Down

0 comments on commit cffc86d

Please sign in to comment.