Skip to content

Commit

Permalink
Merge pull request optuna#960 from c-bata/fix-ci-issues
Browse files Browse the repository at this point in the history
Fix mypy errors in Optuna 4.0+
  • Loading branch information
c-bata authored Sep 11, 2024
2 parents 7d6b619 + 5a1d636 commit 3b894a5
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 9 deletions.
16 changes: 12 additions & 4 deletions python_tests/artifact/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,9 @@ def test_successful_study_artifact_retrieval() -> None:
with tempfile.NamedTemporaryFile() as f:
f.write(b"dummy_content")
f.flush()
artifact_id = upload_artifact(study, f.name, artifact_store=artifact_store)
artifact_id = upload_artifact(
study_or_trial=study, file_path=f.name, artifact_store=artifact_store
)
app = create_app(storage, artifact_store)
status, _, body = send_request(
app,
Expand Down Expand Up @@ -188,7 +190,9 @@ def test_successful_trial_artifact_retrieval() -> None:
with tempfile.NamedTemporaryFile() as f:
f.write(b"dummy_content")
f.flush()
artifact_id = upload_artifact(trial, f.name, artifact_store=artifact_store)
artifact_id = upload_artifact(
study_or_trial=trial, file_path=f.name, artifact_store=artifact_store
)
app = create_app(storage, artifact_store)
status, _, body = send_request(
app,
Expand Down Expand Up @@ -281,7 +285,9 @@ def test_delete_study_artifact() -> None:
with tempfile.NamedTemporaryFile() as f:
f.write(b"dummy_content")
f.flush()
artifact_id = upload_artifact(study, f.name, artifact_store=artifact_store)
artifact_id = upload_artifact(
study_or_trial=study, file_path=f.name, artifact_store=artifact_store
)
app = create_app(storage, artifact_store)
status, _, _ = send_request(
app,
Expand All @@ -306,7 +312,9 @@ def test_delete_trial_artifact() -> None:
with tempfile.NamedTemporaryFile() as f:
f.write(b"dummy_content")
f.flush()
artifact_id = upload_artifact(trial, f.name, artifact_store=artifact_store)
artifact_id = upload_artifact(
study_or_trial=trial, file_path=f.name, artifact_store=artifact_store
)
app = create_app(storage, artifact_store)
status, _, _ = send_request(
app,
Expand Down
6 changes: 4 additions & 2 deletions python_tests/artifact/test_optuna_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_list_optuna_trial_artifacts() -> None:
with tempfile.NamedTemporaryFile() as f:
f.write(dummy_content)
f.flush()
upload_artifact(trial, f.name, artifact_store=artifact_store)
upload_artifact(study_or_trial=trial, file_path=f.name, artifact_store=artifact_store)

study.tell(trial, 0.0)

Expand Down Expand Up @@ -76,7 +76,9 @@ def test_delete_optuna_study_artifacts() -> None:
artifact_store = FileSystemArtifactStore(tmpdir)

def objective(trial: optuna.Trial) -> float:
upload_artifact(trial, dummy_file_path, artifact_store=artifact_store)
upload_artifact(
study_or_trial=trial, file_path=dummy_file_path, artifact_store=artifact_store
)
return 0.0

study.optimize(objective, n_trials=10)
Expand Down
8 changes: 6 additions & 2 deletions python_tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,9 @@ def test_delete_study_with_removing_artifacts(self) -> None:
with tempfile.NamedTemporaryFile() as f:
f.write(b"dummy")
f.flush()
artifact_id = upload_artifact(study, f.name, artifact_store)
artifact_id = upload_artifact(
study_or_trial=study, file_path=f.name, artifact_store=artifact_store
)

app = create_app(storage, artifact_store)

Expand Down Expand Up @@ -575,7 +577,9 @@ def test_delete_study_without_removing_artifacts(self) -> None:
with tempfile.NamedTemporaryFile() as f:
f.write(b"dummy")
f.flush()
artifact_id = upload_artifact(study, f.name, artifact_store)
artifact_id = upload_artifact(
study_or_trial=study, file_path=f.name, artifact_store=artifact_store
)

app = create_app(storage, artifact_store)

Expand Down
6 changes: 5 additions & 1 deletion python_tests/wsgi_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@


if typing.TYPE_CHECKING:
from sys import _OptExcInfo

from _typeshed.wsgi import WSGIEnvironment


Expand Down Expand Up @@ -64,7 +66,9 @@ def send_request(
status: str = ""
response_headers: list[tuple[str, str]] = []

def start_response(status_: str, headers_: list[tuple[str, str]]) -> None:
def start_response(
status_: str, headers_: list[tuple[str, str]], exc_info: _OptExcInfo | None = None
) -> None:
nonlocal status, response_headers
status = status_
response_headers = headers_
Expand Down

0 comments on commit 3b894a5

Please sign in to comment.