Skip to content

Commit

Permalink
Merge branch 'main' of github.com:porink0424/optuna-dashboard into fi…
Browse files Browse the repository at this point in the history
…x/move-graphSlice-to-tslib
  • Loading branch information
porink0424 committed Jun 26, 2024
2 parents 41e6c8c + 0c733bc commit b1dcb22
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 40 deletions.
72 changes: 55 additions & 17 deletions tslib/storage/src/journal.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,36 @@
import * as Optuna from "@optuna/types"
import { OptunaStorage } from "./storage"

// TODO(porink0424): Refactor to common function with sqlite.ts (current workaround duplicates code due to missing file extensions in tsc build output).
const isDistributionEqual = (
a: Optuna.Distribution,
b: Optuna.Distribution
) => {
if (a.type !== b.type) {
return false
}

if (a.type === "IntDistribution" || a.type === "FloatDistribution") {
if (b.type !== "IntDistribution" && b.type !== "FloatDistribution") {
throw new Error("Invalid distribution type")
}
return (
a.low === b.low &&
a.high === b.high &&
a.step === b.step &&
a.log === b.log
)
}
if (a.type === "CategoricalDistribution") {
if (b.type !== "CategoricalDistribution") {
throw new Error("Invalid distribution type")
}
return JSON.stringify(a.choices) === JSON.stringify(b.choices)
}

throw new Error("Invalid distribution type")
}

// JournalStorage
enum JournalOperation {
CREATE_STUDY = 0,
Expand Down Expand Up @@ -137,30 +167,42 @@ class JournalStorage {
public getStudies(): Optuna.Study[] {
for (const study of this.studies) {
const unionUserAttrs: Set<string> = new Set()
const unionSearchSpace: Set<string> = new Set()
let intersectionSearchSpace: string[] = []
const nameToSearchSpaceItem: Map<string, Optuna.SearchSpaceItem> =
new Map()
const unionSearchSpace: Optuna.SearchSpaceItem[] = []
let intersectionSearchSpace: Optuna.SearchSpaceItem[] = []

study.trials.forEach((trial, index) => {
for (const userAttr of trial.user_attrs) {
unionUserAttrs.add(userAttr.key)
}
for (const param of trial.params) {
unionSearchSpace.add(param.name)
if (!nameToSearchSpaceItem.has(param.name)) {
nameToSearchSpaceItem.set(param.name, {
if (
!unionSearchSpace.some(
(item) =>
item.name === param.name &&
isDistributionEqual(item.distribution, param.distribution)
)
) {
unionSearchSpace.push({
name: param.name,
distribution: param.distribution,
})
}
}
if (index === 0) {
intersectionSearchSpace = Array.from(unionSearchSpace)
intersectionSearchSpace = [...unionSearchSpace]
} else {
intersectionSearchSpace = intersectionSearchSpace.filter((name) => {
return trial.params.some((param) => param.name === name)
})
intersectionSearchSpace = intersectionSearchSpace.filter(
(searchSpaceItem) => {
return trial.params.some(
(param) =>
param.name === searchSpaceItem.name &&
isDistributionEqual(
param.distribution,
searchSpaceItem.distribution
)
)
}
)
}
})
study.union_user_attrs = Array.from(unionUserAttrs).map((key) => {
Expand All @@ -169,12 +211,8 @@ class JournalStorage {
sortable: false,
}
})
study.union_search_space = Array.from(unionSearchSpace).map((name) => {
return nameToSearchSpaceItem.get(name) as Optuna.SearchSpaceItem
})
study.intersection_search_space = intersectionSearchSpace.map((name) => {
return nameToSearchSpaceItem.get(name) as Optuna.SearchSpaceItem
})
study.union_search_space = unionSearchSpace
study.intersection_search_space = intersectionSearchSpace
}

return this.studies
Expand Down
64 changes: 43 additions & 21 deletions tslib/storage/src/sqlite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,36 @@ import * as Optuna from "@optuna/types"
import sqlite3InitModule from "@sqlite.org/sqlite-wasm"
import { OptunaStorage } from "./storage"

// TODO(porink0424): Refactor to common function with journal.ts (current workaround duplicates code due to missing file extensions in tsc build output).
const isDistributionEqual = (
a: Optuna.Distribution,
b: Optuna.Distribution
) => {
if (a.type !== b.type) {
return false
}

if (a.type === "IntDistribution" || a.type === "FloatDistribution") {
if (b.type !== "IntDistribution" && b.type !== "FloatDistribution") {
throw new Error("Invalid distribution type")
}
return (
a.low === b.low &&
a.high === b.high &&
a.step === b.step &&
a.log === b.log
)
}
if (a.type === "CategoricalDistribution") {
if (b.type !== "CategoricalDistribution") {
throw new Error("Invalid distribution type")
}
return JSON.stringify(a.choices) === JSON.stringify(b.choices)
}

throw new Error("Invalid distribution type")
}

type SQLite3DB = {
exec(options: {
sql: string
Expand Down Expand Up @@ -149,7 +179,7 @@ const getStudy = (
study.metric_names = studySystemAttrs.metric_names
}

let intersection_search_space: Set<Optuna.SearchSpaceItem> = new Set()
let intersectionSearchSpace: Optuna.SearchSpaceItem[] = []
study.trials = getTrials(db, summary.id, schemaVersion)
for (const trial of study.trials) {
const userAttrs = getTrialUserAttributes(db, trial.trial_id)
Expand All @@ -165,16 +195,7 @@ const getStudy = (
}

const params = getTrialParams(db, trial.trial_id)
const paramNames = new Set<string>()
const paramNameToSearchSpaceItem = new Map<string, Optuna.SearchSpaceItem>()
for (const param of params) {
paramNames.add(param.name)
if (paramNameToSearchSpaceItem.has(param.name)) {
paramNameToSearchSpaceItem.set(param.name, {
name: param.name,
distribution: param.distribution,
})
}
if (
study.union_search_space.findIndex((s) => s.name === param.name) === -1
) {
Expand All @@ -184,23 +205,24 @@ const getStudy = (
})
}
}
if (intersection_search_space.size === 0) {
for (const s of paramNames) {
intersection_search_space.add(
paramNameToSearchSpaceItem.get(s) as Optuna.SearchSpaceItem
)
}
if (intersectionSearchSpace.length === 0) {
intersectionSearchSpace = params.map((param) => ({
name: param.name,
distribution: param.distribution,
}))
} else {
intersection_search_space = new Set(
Array.from(intersection_search_space).filter((s) =>
paramNames.has(s.name)
intersectionSearchSpace = intersectionSearchSpace.filter((item) => {
return params.some(
(param) =>
item.name === param.name &&
isDistributionEqual(item.distribution, param.distribution)
)
)
})
}
trial.params = params
trial.user_attrs = userAttrs
}
study.intersection_search_space = Array.from(intersection_search_space)
study.intersection_search_space = intersectionSearchSpace
return study
}

Expand Down
4 changes: 2 additions & 2 deletions tslib/storage/test/generate_assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def objective_single(trial: optuna.Trial) -> float:
def objective_single_dynamic(trial: optuna.Trial) -> float:
category = trial.suggest_categorical("category", ["foo", "bar"])
if category == "foo":
return (trial.suggest_float("x1", 0, 10) - 2) ** 2
return (trial.suggest_float("x", 0, 10) - 2) ** 2
else:
return -((trial.suggest_float("x2", -10, 0) + 5) ** 2)
return -((trial.suggest_float("x", -10, 0) + 5) ** 2)

study.optimize(objective_single_dynamic, n_trials=50)

Expand Down
54 changes: 54 additions & 0 deletions tslib/storage/test/journal.test.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,60 @@ describe("Test Journal File Storage", async () => {
studySummaries.map((_summary, index) => storage.getStudy(index))
)

it("Check the study with dynamic search space", () => {
const study = studies.find((s) => s.name === "single-objective-dynamic")
assert.deepStrictEqual(
study.union_search_space.map((item) => item.name).sort(),
["x", "x", "category"].sort()
)
assert.strictEqual(
study.union_search_space.some(
(item) =>
item.name === "category" &&
item.distribution.type === "CategoricalDistribution" &&
item.distribution.choices.length === 2
),
true
)
assert.strictEqual(
study.union_search_space.some(
(item) =>
item.name === "x" &&
item.distribution.type === "FloatDistribution" &&
item.distribution.low === 0 &&
item.distribution.high === 10 &&
item.distribution.step === null &&
item.distribution.log === false
),
true
)
assert.strictEqual(
study.union_search_space.some(
(item) =>
item.name === "x" &&
item.distribution.type === "FloatDistribution" &&
item.distribution.low === -10 &&
item.distribution.high === 0 &&
item.distribution.step === null &&
item.distribution.log === false
),
true
)
assert.deepStrictEqual(
study.intersection_search_space.map((item) => item.name).sort(),
["category"].sort()
)
assert.strictEqual(
study.intersection_search_space.some(
(item) =>
item.name === "category" &&
item.distribution.type === "CategoricalDistribution" &&
item.distribution.choices.length === 2
),
true
)
})

it("Check the study including Infinities", () => {
const study = studies.find((s) => s.name === "single-inf")
study.trials.forEach((trial, index) => {
Expand Down

0 comments on commit b1dcb22

Please sign in to comment.