Skip to content

Commit

Permalink
Merge pull request optuna#609 from c-bata/support-optuna-study-artifacts
Browse files Browse the repository at this point in the history
Add support for Optuna's study artifacts
  • Loading branch information
c-bata authored Sep 27, 2023
2 parents e0cecb1 + 0eae639 commit 2f99e3b
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 17 deletions.
2 changes: 2 additions & 0 deletions optuna_dashboard/_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ._named_objectives import get_objective_names
from ._preference_setting import _SYSTEM_ATTR_FEEDBACK_COMPONENT
from ._preferential_history import _SYSTEM_ATTR_PREFIX_HISTORY
from .artifact._backend import list_study_artifacts
from .artifact._backend import list_trial_artifacts
from .preferential._study import _SYSTEM_ATTR_PREFERENTIAL_STUDY
from .preferential._system_attrs import get_preferences
Expand Down Expand Up @@ -144,6 +145,7 @@ def serialize_study_detail(
"user_attrs": serialize_attrs(summary.user_attrs),
}
system_attrs = getattr(summary, "system_attrs", {})
serialized["artifacts"] = list_study_artifacts(system_attrs)
if summary.datetime_start is not None:
serialized["datetime_start"] = summary.datetime_start.isoformat()

Expand Down
66 changes: 56 additions & 10 deletions optuna_dashboard/artifact/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,24 +49,49 @@


def get_artifact_path(
trial: optuna.Trial,
study_or_trial: optuna.Trial | optuna.Study,
artifact_id: str,
) -> str:
"""Get the URL path for a given artifact ID."""
study_id = trial.study._study_id
trial_id = trial._trial_id
if isinstance(study_or_trial, optuna.Study):
study_id = study_or_trial._study_id
return f"/artifacts/{study_id}/{artifact_id}"

study_id = study_or_trial.study._study_id
trial_id = study_or_trial._trial_id
return f"/artifacts/{study_id}/{trial_id}/{artifact_id}"


def register_artifact_route(
app: Bottle, storage: BaseStorage, artifact_store: Optional[ArtifactStore]
) -> None:
@app.get("/artifacts/<study_id:int>/<artifact_id:re:[0-9a-fA-F-]+>")
def proxy_study_artifact(study_id: int, artifact_id: str) -> HTTPResponse | bytes:
if artifact_store is None:
response.status = 400 # Bad Request
return b"Cannot access to the artifacts."
artifact_dict = get_study_artifact_meta(storage, study_id, artifact_id)
if artifact_dict is None:
response.status = 404
return b"Not Found"
headers = {"Content-Type": artifact_dict["mimetype"]}
encoding = artifact_dict.get("encoding")
if encoding:
headers["Content-Encodings"] = encoding

fp = artifact_store.open_reader(artifact_id)
return HTTPResponse(fp, headers=headers)

@app.get("/artifacts/<study_id:int>/<trial_id:int>/<artifact_id:re:[0-9a-fA-F-]+>")
def proxy_artifact(study_id: int, trial_id: int, artifact_id: str) -> HTTPResponse | bytes:
def proxy_trial_artifact(
study_id: int,
trial_id: int,
artifact_id: str,
) -> HTTPResponse | bytes:
if artifact_store is None:
response.status = 400 # Bad Request
return b"Cannot access to the artifacts."
artifact_dict = get_artifact_meta(storage, study_id, trial_id, artifact_id)
artifact_dict = get_trial_artifact_meta(storage, study_id, trial_id, artifact_id)
if artifact_dict is None:
response.status = 404
return b"Not Found"
Expand Down Expand Up @@ -129,7 +154,7 @@ def delete_artifact(study_id: int, trial_id: int, artifact_id: str) -> dict[str,

# The artifact's metadata is stored in one of the following two locations:
storage.set_study_system_attr(
study_id, _artifact_prefix(trial_id) + artifact_id, json.dumps(None)
study_id, _dashboard_trial_artifact_prefix(trial_id) + artifact_id, json.dumps(None)
)
storage.set_trial_system_attr(
trial_id, ARTIFACTS_ATTR_PREFIX + artifact_id, json.dumps(None)
Expand Down Expand Up @@ -195,16 +220,27 @@ def objective(trial: optuna.Trial) -> float:
return artifact_id


def _artifact_prefix(trial_id: int) -> str:
def _dashboard_trial_artifact_prefix(trial_id: int) -> str:
return DASHBOARD_ARTIFACTS_ATTR_PREFIX + f"{trial_id}:"


def get_artifact_meta(
def get_study_artifact_meta(
storage: BaseStorage, study_id: int, artifact_id: str
) -> Optional[ArtifactMeta]:
study_system_attrs = storage.get_study_system_attrs(study_id)
attr_key = ARTIFACTS_ATTR_PREFIX + artifact_id
artifact_meta = study_system_attrs.get(attr_key)
if artifact_meta is not None:
return json.loads(artifact_meta)
return None


def get_trial_artifact_meta(
storage: BaseStorage, study_id: int, trial_id: int, artifact_id: str
) -> Optional[ArtifactMeta]:
# Search study_system_attrs due to backward compatibility.
study_system_attrs = storage.get_study_system_attrs(study_id)
attr_key = _artifact_prefix(trial_id=trial_id) + artifact_id
attr_key = _dashboard_trial_artifact_prefix(trial_id=trial_id) + artifact_id
artifact_meta = study_system_attrs.get(attr_key)
if artifact_meta is not None:
return json.loads(artifact_meta)
Expand All @@ -223,6 +259,7 @@ def get_artifact_meta(
def delete_all_artifacts(backend: ArtifactStore, storage: BaseStorage, study_id: int) -> None:
artifact_metas = []
study_system_attrs = storage.get_study_system_attrs(study_id)
artifact_metas.extend(list_study_artifacts(study_system_attrs))
for trial in storage.get_all_trials(study_id):
trial_artifacts = list_trial_artifacts(study_system_attrs, trial)
artifact_metas.extend(trial_artifacts)
Expand All @@ -231,14 +268,23 @@ def delete_all_artifacts(backend: ArtifactStore, storage: BaseStorage, study_id:
backend.remove(meta["artifact_id"])


def list_study_artifacts(study_system_attrs: dict[str, Any]) -> list[ArtifactMeta]:
artifact_metas = [
json.loads(value)
for key, value in study_system_attrs.items()
if key.startswith(ARTIFACTS_ATTR_PREFIX)
]
return [a for a in artifact_metas if a is not None]


def list_trial_artifacts(
study_system_attrs: dict[str, Any], trial: FrozenTrial
) -> list[ArtifactMeta]:
# Collect ArtifactMeta from study_system_attrs due to backward compatibility.
dashboard_artifact_metas = [
json.loads(value)
for key, value in study_system_attrs.items()
if key.startswith(_artifact_prefix(trial._trial_id))
if key.startswith(_dashboard_trial_artifact_prefix(trial._trial_id))
]

# Collect ArtifactMeta from trial_system_attrs. Note that artifacts uploaded via
Expand Down
2 changes: 2 additions & 0 deletions optuna_dashboard/ts/apiClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ interface StudyDetailResponse {
preferences?: [number, number][]
preference_history?: PreferenceHistoryResponse[]
plotly_graph_objects: PlotlyGraphObject[]
artifacts: Artifact[]
feedback_component_type: FeedbackComponentType
skipped_trial_numbers?: number[]
}
Expand Down Expand Up @@ -142,6 +143,7 @@ export const getStudyDetailAPI = (
convertPreferenceHistory
),
plotly_graph_objects: res.data.plotly_graph_objects,
artifacts: res.data.artifacts,
skipped_trial_numbers: res.data.skipped_trial_numbers ?? [],
}
})
Expand Down
1 change: 1 addition & 0 deletions optuna_dashboard/ts/types/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ type StudyDetail = {
preferences?: [number, number][]
preference_history?: PreferenceHistory[]
plotly_graph_objects: PlotlyGraphObject[]
artifacts: Artifact[]
skipped_trial_numbers: number[]
}

Expand Down
10 changes: 5 additions & 5 deletions python_tests/artifact/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
def test_get_artifact_path() -> None:
study = MagicMock(_study_id=0)
trial = MagicMock(_trial_id=0, study=study)
assert _backend.get_artifact_path(trial=trial, artifact_id="id0") == "/artifacts/0/0/id0"
assert _backend.get_artifact_path(trial, "id0") == "/artifacts/0/0/id0"


def test_artifact_prefix() -> None:
actual = _backend._artifact_prefix(trial_id=0)
actual = _backend._dashboard_trial_artifact_prefix(trial_id=0)
assert actual == "dashboard:artifacts:0:"


Expand Down Expand Up @@ -47,13 +47,13 @@ def init_storage_with_artifact_meta() -> BaseStorage:
def test_get_artifact_meta(init_storage_with_artifact_meta: MagicMock) -> None:
storage = init_storage_with_artifact_meta

actual = _backend.get_artifact_meta(storage, study_id=0, trial_id=0, artifact_id="id0")
actual = _backend.get_trial_artifact_meta(storage, study_id=0, trial_id=0, artifact_id="id0")
assert actual == {"artifact_id": "id0", "filename": "foo.txt"}

actual = _backend.get_artifact_meta(storage, study_id=0, trial_id=1, artifact_id="id3")
actual = _backend.get_trial_artifact_meta(storage, study_id=0, trial_id=1, artifact_id="id3")
assert actual == {"artifact_id": "id3", "filename": "qux.txt"}

actual = _backend.get_artifact_meta(storage, study_id=0, trial_id=0, artifact_id="id4")
actual = _backend.get_trial_artifact_meta(storage, study_id=0, trial_id=0, artifact_id="id4")
assert actual is None


Expand Down
4 changes: 2 additions & 2 deletions python_tests/artifact/test_optuna_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import optuna
from optuna.version import __version__ as optuna_ver
from optuna_dashboard.artifact._backend import delete_all_artifacts
from optuna_dashboard.artifact._backend import get_artifact_meta
from optuna_dashboard.artifact._backend import get_trial_artifact_meta
from optuna_dashboard.artifact._backend import list_trial_artifacts
from packaging import version
import pytest
Expand Down Expand Up @@ -44,7 +44,7 @@ def test_list_optuna_trial_artifacts() -> None:
with artifact_store.open_reader(artifact_id) as reader:
assert reader.read() == dummy_content

artifact_meta = get_artifact_meta(
artifact_meta = get_trial_artifact_meta(
storage=storage,
study_id=study._study_id,
trial_id=trial._trial_id,
Expand Down

0 comments on commit 2f99e3b

Please sign in to comment.