Skip to content

Commit

Permalink
Merge pull request optuna#723 from eukaryo/code-fix/download_csv
Browse files Browse the repository at this point in the history
Code fix/download csv
  • Loading branch information
HideakiImamura authored Dec 8, 2023
2 parents 5dcc4cf + f63b5e9 commit 4e3f168
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 5 deletions.
46 changes: 46 additions & 0 deletions optuna_dashboard/_app.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from __future__ import annotations

import csv
import functools
import io
from itertools import chain
import logging
import os
import re
import typing
from typing import Any
from typing import Optional
Expand Down Expand Up @@ -449,6 +453,48 @@ def save_trial_note(study_id: int, trial_id: int) -> dict[str, Any]:
response.status = 204 # No content
return {}

@app.get("/csv/<study_id:int>")
def download_csv(study_id: int) -> BottleViewReturn:
# Create a CSV file
try:
study_name = storage.get_study_name_from_id(study_id)
study = optuna.load_study(storage=storage, study_name=study_name)
except KeyError:
response.status = 404 # Not found
return {"reason": f"study_id={study_id} is not found"}
trials = study.trials
param_names = sorted(set(chain.from_iterable([t.params.keys() for t in trials])))
user_attr_names = sorted(set(chain.from_iterable([t.user_attrs.keys() for t in trials])))
param_names_header = [f"Param {x}" for x in param_names]
user_attr_names_header = [f"UserAttribute {x}" for x in user_attr_names]
n_objs = len(study.directions)
if study.metric_names is not None:
value_header = study.metric_names
else:
value_header = ["Value"] if n_objs == 1 else [f"Objective {x}" for x in range(n_objs)]
column_names = (
["Number", "State"] + value_header + param_names_header + user_attr_names_header
)

buf = io.StringIO("")
writer = csv.writer(buf)
writer.writerow(column_names)
for frozen_trial in trials:
row = [frozen_trial.number, frozen_trial.state.name]
row.extend(frozen_trial.values if frozen_trial.values is not None else [None] * n_objs)
row.extend([frozen_trial.params.get(name, None) for name in param_names])
row.extend([frozen_trial.user_attrs.get(name, None) for name in user_attr_names])
writer.writerow(row)

# Set response headers
output_name = "-".join(re.sub(r'[\\/:*?"<>|]+', "", study_name).split(" "))
response.headers["Content-Type"] = "text/csv; chatset=cp932"
response.headers["Content-Disposition"] = f"attachment; filename={output_name}.csv"

# Response body
buf.seek(0)
return buf.read()

@app.get("/favicon.ico")
def favicon() -> BottleViewReturn:
use_gzip = "gzip" in request.headers["Accept-Encoding"]
Expand Down
39 changes: 34 additions & 5 deletions optuna_dashboard/ts/components/StudyDetail.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {
import Grid2 from "@mui/material/Unstable_Grid2"
import ChevronRightIcon from "@mui/icons-material/ChevronRight"
import HomeIcon from "@mui/icons-material/Home"
import DownloadIcon from "@mui/icons-material/Download"

import { StudyNote } from "./Note"
import { actionCreator } from "../action"
Expand Down Expand Up @@ -149,11 +150,39 @@ export const StudyDetail: FC<{
content = <TrialList studyDetail={studyDetail} />
} else if (page === "trialTable") {
content = (
<Card sx={{ margin: theme.spacing(2) }}>
<CardContent>
<TrialTable studyDetail={studyDetail} initialRowsPerPage={50} />
</CardContent>
</Card>
<Box sx={{ display: "flex", width: "100%", flexDirection: "column" }}>
<Card
sx={{
margin: theme.spacing(2),
width: "auto",
height: "auto",
display: "flex",
justifyContent: "left",
alignItems: "left",
}}
>
<CardContent>
<IconButton
aria-label="download csv"
size="small"
color="inherit"
download
sx={{ margin: "auto 0" }}
href={`/csv/${studyDetail?.id}`}
>
<DownloadIcon />
<Typography variant="button" sx={{ margin: theme.spacing(2) }}>
Download CSV File
</Typography>
</IconButton>
</CardContent>
</Card>
<Card sx={{ margin: theme.spacing(2) }}>
<CardContent>
<TrialTable studyDetail={studyDetail} initialRowsPerPage={50} />
</CardContent>
</Card>
</Box>
)
} else if (page === "note" && studyDetail !== null) {
content = (
Expand Down
109 changes: 109 additions & 0 deletions python_tests/test_csv_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from __future__ import annotations

from typing import Any

import optuna
from optuna.trial import TrialState
from optuna_dashboard._app import create_app
import pytest

from .wsgi_client import send_request


def _validate_output(
storage: optuna.storages.BaseStorage,
correct_status: int,
study_id: int,
expect_no_result: bool = False,
extra_col_names: list[str] | None = None,
) -> None:
app = create_app(storage)
status, _, body = send_request(
app,
f"/csv/{study_id}",
"GET",
content_type="application/json",
)
assert status == correct_status
decoded_csv = str(body.decode("utf-8"))
if expect_no_result:
assert "is not found" in decoded_csv
else:
col_names = ["Number", "State"] + ([] if extra_col_names is None else extra_col_names)
assert all(col_name in decoded_csv for col_name in col_names)


def test_download_csv_no_trial() -> None:
def objective(trial: optuna.Trial) -> float:
x = trial.suggest_float("x", -100, 100)
y = trial.suggest_categorical("y", [-1, 0, 1])
return x**2 + y

storage = optuna.storages.InMemoryStorage()
study = optuna.create_study(storage=storage)
study.optimize(objective, n_trials=0)
_validate_output(storage, 200, 0)


def test_download_csv_all_waiting() -> None:
storage = optuna.storages.InMemoryStorage()
study = optuna.create_study(storage=storage)
study.add_trial(optuna.trial.create_trial(state=TrialState.WAITING))
_validate_output(storage, 200, 0)


def test_download_csv_all_running() -> None:
storage = optuna.storages.InMemoryStorage()
study = optuna.create_study(storage=storage)
study.add_trial(optuna.trial.create_trial(state=TrialState.RUNNING))
_validate_output(storage, 200, 0)


@pytest.mark.parametrize("study_id", [0, 1])
def test_download_csv_fail(study_id: int) -> None:
def objective(trial: optuna.Trial) -> float:
x = trial.suggest_float("x", -100, 100)
y = trial.suggest_categorical("y", [-1, 0, 1])
return x**2 + y

storage = optuna.storages.InMemoryStorage()
study = optuna.create_study(storage=storage)
optuna.logging.set_verbosity(optuna.logging.ERROR)
study.optimize(objective, n_trials=10)
expect_no_result = study_id != 0
cols = ["Param x", "Param y", "Value"]
_validate_output(storage, 404 if expect_no_result else 200, study_id, expect_no_result, cols)


@pytest.mark.parametrize("is_multi_obj", [True, False])
def test_download_csv_multi_obj(is_multi_obj: bool) -> None:
def objective(trial: optuna.Trial) -> Any:
x = trial.suggest_float("x", -100, 100)
y = trial.suggest_categorical("y", [-1, 0, 1])
if is_multi_obj:
return x**2, y
return x**2 + y

storage = optuna.storages.InMemoryStorage()
directions = ["minimize", "minimize"] if is_multi_obj else ["minimize"]
study = optuna.create_study(storage=storage, directions=directions)
optuna.logging.set_verbosity(optuna.logging.ERROR)
study.optimize(objective, n_trials=10)
cols = ["Param x", "Param y"]
cols += ["Objective 0", "Objective 1"] if is_multi_obj else ["Value"]
_validate_output(storage, 200, 0, extra_col_names=cols)


def test_download_csv_user_attr() -> None:
def objective(trial: optuna.Trial) -> float:
x = trial.suggest_float("x", -100, 100)
y = trial.suggest_categorical("y", [-1, 0, 1])
trial.set_user_attr("abs_y", abs(y))
return x**2 + y

storage = optuna.storages.InMemoryStorage()
study = optuna.create_study(storage=storage)
optuna.logging.set_verbosity(optuna.logging.ERROR)
study.optimize(objective, n_trials=10)
cols = ["Param x", "Param y", "Value", "UserAttribute abs_y"]
_validate_output(storage, 200, 0, extra_col_names=cols)

0 comments on commit 4e3f168

Please sign in to comment.