diff --git a/tslib/storage/src/journal.ts b/tslib/storage/src/journal.ts index 2bd9540d..79a746d1 100644 --- a/tslib/storage/src/journal.ts +++ b/tslib/storage/src/journal.ts @@ -32,7 +32,7 @@ interface JournalOpDeleteStudy extends JournalOpBase { interface JournalOpSetStudySystemAttr extends JournalOpBase { study_id: number system_attr: { - "study:metric_names": string[] + "study:metric_names"?: string[] } } @@ -80,6 +80,13 @@ interface JournalOpSetTrialUserAttr extends JournalOpBase { user_attr: { [key: string]: any } // eslint-disable-line @typescript-eslint/no-explicit-any } +interface JournalOpSetTrialSystemAttr extends JournalOpBase { + trial_id: number + system_attr: { + constraints?: number[] + } +} + const trialStateNumToTrialState = (state: number): Optuna.TrialState => { switch (state) { case 0: @@ -224,7 +231,7 @@ class JournalStorage { } }) - const userAtter = log.user_attrs + const userAttrs = log.user_attrs ? Object.entries(log.user_attrs).map(([key, value]) => { return { key: key, @@ -249,7 +256,8 @@ class JournalStorage { })(), params: params, intermediate_values: [], - user_attrs: userAtter, + user_attrs: userAttrs, + constraints: [], datetime_start: log.datetime_start ? new Date(log.datetime_start) : undefined, @@ -341,6 +349,16 @@ class JournalStorage { } } } + + public applySetTrialSystemAttr(log: JournalOpSetTrialSystemAttr) { + const [thisStudy, thisTrial] = this.getStudyAndTrial(log.trial_id) + if (thisStudy === undefined || thisTrial === undefined) { + return + } + if (log.system_attr.constraints) { + thisTrial.constraints = log.system_attr.constraints + } + } } const loadJournalStorage = ( @@ -434,7 +452,9 @@ const loadJournalStorage = ( ) break case JournalOperation.SET_TRIAL_SYSTEM_ATTR: - // Unsupported + journalStorage.applySetTrialSystemAttr( + parsedLog as JournalOpSetTrialSystemAttr + ) break } } diff --git a/tslib/storage/src/sqlite.ts b/tslib/storage/src/sqlite.ts index 3cd2cdc1..cb4be1ff 100644 --- a/tslib/storage/src/sqlite.ts +++ b/tslib/storage/src/sqlite.ts @@ -159,6 +159,11 @@ const getStudy = ( } } + const systemAttrs = getTrialSystemAttributes(db, trial.trial_id) + if (systemAttrs !== undefined) { + trial.constraints = systemAttrs.constraints + } + const params = getTrialParams(db, trial.trial_id) const param_names = new Set() for (const param of params) { @@ -222,6 +227,7 @@ const getTrials = ( ), params: [], // Set this column later user_attrs: [], // Set this column later + constraints: [], datetime_start: vals[3], datetime_complete: vals[4], } @@ -404,6 +410,20 @@ const getTrialUserAttributes = ( return attrs } +const getTrialSystemAttributes = (db: SQLite3DB, trialId: number) => { + let attrs: { constraints: number[] } | undefined + db.exec({ + sql: `SELECT key, value_json FROM trial_system_attributes WHERE trial_id = ${trialId} AND key = 'constraints'`, + // biome-ignore lint/suspicious/noExplicitAny: + callback: (vals: any[]) => { + attrs = { + constraints: JSON.parse(vals[1]), + } + }, + }) + return attrs +} + const getTrialIntermediateValues = ( db: SQLite3DB, trialId: number, diff --git a/tslib/storage/test/generate_assets.py b/tslib/storage/test/generate_assets.py index c7cfaca4..2e79e141 100644 --- a/tslib/storage/test/generate_assets.py +++ b/tslib/storage/test/generate_assets.py @@ -96,6 +96,30 @@ def objective_multi(trial: optuna.Trial) -> tuple[float, float]: study.optimize(objective_multi, n_trials=50) + # Multi-objective study with constraints + def objective_constraints(trial: optuna.Trial) -> tuple[float, float]: + x = trial.suggest_float("x", -15, 30) + y = trial.suggest_float("y", -15, 30) + c0 = (x - 5) ** 2 + y**2 - 25 + c1 = -((x - 8) ** 2) - (y + 3) ** 2 + 7.7 + trial.set_user_attr("constraint", (c0, c1)) + v0 = 4 * x**2 + 4 * y**2 + v1 = (x - 5) ** 2 + (y - 5) ** 2 + return v0, v1 + + def constraints(trial: optuna.Trial): + return trial.user_attrs["constraint"] + + sampler = optuna.samplers.NSGAIISampler(constraints_func=constraints) + study = optuna.create_study( + study_name="multi-objective-constraints", + storage=storage, + sampler=sampler, + directions=["minimize", "minimize"], + ) + print(f"Generating {study.study_name} for {type(storage).__name__}...") + study.optimize(objective_constraints, n_trials=32, timeout=600) + if __name__ == "__main__": remove_assets() diff --git a/tslib/storage/test/journal.test.mjs b/tslib/storage/test/journal.test.mjs index 5df85e7f..1f8f1299 100644 --- a/tslib/storage/test/journal.test.mjs +++ b/tslib/storage/test/journal.test.mjs @@ -57,8 +57,15 @@ describe("Test Journal File Storage", async () => { assert.deepStrictEqual(study.metric_names, ["value1", "value2"]) }) + it("Check the study with constraints", () => { + const study = studies.find((s) => s.name === "multi-objective-constraints") + for (const trial of study.trials) { + assert.strictEqual(trial.constraints.length, 2) + } + }) + it("Check the number of studies", () => { - const N_STUDIES = 5 + const N_STUDIES = 6 assert.strictEqual(studies.length, N_STUDIES) }) }) diff --git a/tslib/types/src/index.ts b/tslib/types/src/index.ts index fab91d7d..40073924 100644 --- a/tslib/types/src/index.ts +++ b/tslib/types/src/index.ts @@ -74,6 +74,7 @@ export type Trial = { user_attrs: Attribute[] datetime_start?: Date datetime_complete?: Date + constraints: number[] } export type TrialParam = {