From 5a1d6365cd076312e3f698a8e2dbc6a72819514e Mon Sep 17 00:00:00 2001 From: c-bata Date: Wed, 11 Sep 2024 16:32:33 +0900 Subject: [PATCH] Fix mypy errors in Optuna 4.0+ --- python_tests/artifact/test_backend.py | 16 ++++++++++++---- .../artifact/test_optuna_compatibility.py | 6 ++++-- python_tests/test_api.py | 8 ++++++-- python_tests/wsgi_client.py | 6 +++++- 4 files changed, 27 insertions(+), 9 deletions(-) diff --git a/python_tests/artifact/test_backend.py b/python_tests/artifact/test_backend.py index 38333fe9e..e36ad5860 100644 --- a/python_tests/artifact/test_backend.py +++ b/python_tests/artifact/test_backend.py @@ -142,7 +142,9 @@ def test_successful_study_artifact_retrieval() -> None: with tempfile.NamedTemporaryFile() as f: f.write(b"dummy_content") f.flush() - artifact_id = upload_artifact(study, f.name, artifact_store=artifact_store) + artifact_id = upload_artifact( + study_or_trial=study, file_path=f.name, artifact_store=artifact_store + ) app = create_app(storage, artifact_store) status, _, body = send_request( app, @@ -188,7 +190,9 @@ def test_successful_trial_artifact_retrieval() -> None: with tempfile.NamedTemporaryFile() as f: f.write(b"dummy_content") f.flush() - artifact_id = upload_artifact(trial, f.name, artifact_store=artifact_store) + artifact_id = upload_artifact( + study_or_trial=trial, file_path=f.name, artifact_store=artifact_store + ) app = create_app(storage, artifact_store) status, _, body = send_request( app, @@ -281,7 +285,9 @@ def test_delete_study_artifact() -> None: with tempfile.NamedTemporaryFile() as f: f.write(b"dummy_content") f.flush() - artifact_id = upload_artifact(study, f.name, artifact_store=artifact_store) + artifact_id = upload_artifact( + study_or_trial=study, file_path=f.name, artifact_store=artifact_store + ) app = create_app(storage, artifact_store) status, _, _ = send_request( app, @@ -306,7 +312,9 @@ def test_delete_trial_artifact() -> None: with tempfile.NamedTemporaryFile() as f: f.write(b"dummy_content") f.flush() - artifact_id = upload_artifact(trial, f.name, artifact_store=artifact_store) + artifact_id = upload_artifact( + study_or_trial=trial, file_path=f.name, artifact_store=artifact_store + ) app = create_app(storage, artifact_store) status, _, _ = send_request( app, diff --git a/python_tests/artifact/test_optuna_compatibility.py b/python_tests/artifact/test_optuna_compatibility.py index 1ade64abe..d055e0ab3 100644 --- a/python_tests/artifact/test_optuna_compatibility.py +++ b/python_tests/artifact/test_optuna_compatibility.py @@ -31,7 +31,7 @@ def test_list_optuna_trial_artifacts() -> None: with tempfile.NamedTemporaryFile() as f: f.write(dummy_content) f.flush() - upload_artifact(trial, f.name, artifact_store=artifact_store) + upload_artifact(study_or_trial=trial, file_path=f.name, artifact_store=artifact_store) study.tell(trial, 0.0) @@ -76,7 +76,9 @@ def test_delete_optuna_study_artifacts() -> None: artifact_store = FileSystemArtifactStore(tmpdir) def objective(trial: optuna.Trial) -> float: - upload_artifact(trial, dummy_file_path, artifact_store=artifact_store) + upload_artifact( + study_or_trial=trial, file_path=dummy_file_path, artifact_store=artifact_store + ) return 0.0 study.optimize(objective, n_trials=10) diff --git a/python_tests/test_api.py b/python_tests/test_api.py index d6732903c..e4e755b84 100644 --- a/python_tests/test_api.py +++ b/python_tests/test_api.py @@ -539,7 +539,9 @@ def test_delete_study_with_removing_artifacts(self) -> None: with tempfile.NamedTemporaryFile() as f: f.write(b"dummy") f.flush() - artifact_id = upload_artifact(study, f.name, artifact_store) + artifact_id = upload_artifact( + study_or_trial=study, file_path=f.name, artifact_store=artifact_store + ) app = create_app(storage, artifact_store) @@ -575,7 +577,9 @@ def test_delete_study_without_removing_artifacts(self) -> None: with tempfile.NamedTemporaryFile() as f: f.write(b"dummy") f.flush() - artifact_id = upload_artifact(study, f.name, artifact_store) + artifact_id = upload_artifact( + study_or_trial=study, file_path=f.name, artifact_store=artifact_store + ) app = create_app(storage, artifact_store) diff --git a/python_tests/wsgi_client.py b/python_tests/wsgi_client.py index 0bf527bfa..eab252c17 100644 --- a/python_tests/wsgi_client.py +++ b/python_tests/wsgi_client.py @@ -10,6 +10,8 @@ if typing.TYPE_CHECKING: + from sys import _OptExcInfo + from _typeshed.wsgi import WSGIEnvironment @@ -64,7 +66,9 @@ def send_request( status: str = "" response_headers: list[tuple[str, str]] = [] - def start_response(status_: str, headers_: list[tuple[str, str]]) -> None: + def start_response( + status_: str, headers_: list[tuple[str, str]], exc_info: _OptExcInfo | None = None + ) -> None: nonlocal status, response_headers status = status_ response_headers = headers_