-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
test(export): correct unit tests for `VariantStudyService.generate_ta…
…sk` method
- Loading branch information
Showing
28 changed files
with
1,452 additions
and
375 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,107 +1,19 @@ | ||
import time | ||
from datetime import datetime, timedelta, timezone | ||
from functools import wraps | ||
from pathlib import Path | ||
from typing import Any, Callable, Dict, List, cast | ||
|
||
import numpy as np | ||
import numpy.typing as npt | ||
import pytest | ||
from antarest.core.model import SUB_JSON | ||
from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware, db | ||
from antarest.dbmodel import Base | ||
from sqlalchemy import create_engine # type: ignore | ||
|
||
# noinspection PyUnresolvedReferences | ||
from tests.conftest_db import * | ||
|
||
# noinspection PyUnresolvedReferences | ||
from tests.conftest_services import * | ||
|
||
# fmt: off | ||
HERE = Path(__file__).parent.resolve() | ||
PROJECT_DIR = next(iter(p for p in HERE.parents if p.joinpath("antarest").exists())) | ||
# fmt: on | ||
|
||
|
||
@pytest.fixture | ||
@pytest.fixture(scope="session") | ||
def project_path() -> Path: | ||
return PROJECT_DIR | ||
|
||
|
||
def with_db_context(f: Callable[..., Any]) -> Callable[..., Any]: | ||
@wraps(f) | ||
def wrapper(*args: Any, **kwargs: Any) -> Any: | ||
engine = create_engine("sqlite:///:memory:", echo=False) | ||
Base.metadata.create_all(engine) | ||
# noinspection SpellCheckingInspection | ||
DBSessionMiddleware( | ||
None, | ||
custom_engine=engine, | ||
session_args={"autocommit": False, "autoflush": False}, | ||
) | ||
with db(): | ||
return f(*args, **kwargs) | ||
|
||
return wrapper | ||
|
||
|
||
def _assert_dict(a: Dict[str, Any], b: Dict[str, Any]) -> None: | ||
if a.keys() != b.keys(): | ||
raise AssertionError( | ||
f"study level has not the same keys {a.keys()} != {b.keys()}" | ||
) | ||
for k, v in a.items(): | ||
assert_study(v, b[k]) | ||
|
||
|
||
def _assert_list(a: List[Any], b: List[Any]) -> None: | ||
for i, j in zip(a, b): | ||
assert_study(i, j) | ||
|
||
|
||
def _assert_pointer_path(a: str, b: str) -> None: | ||
# pointer is like studyfile://study-id/a/b/c | ||
# we should compare a/b/c only | ||
if a.split("/")[4:] != b.split("/")[4:]: | ||
raise AssertionError(f"element in study not the same {a} != {b}") | ||
|
||
|
||
def _assert_others(a: Any, b: Any) -> None: | ||
if a != b: | ||
raise AssertionError(f"element in study not the same {a} != {b}") | ||
|
||
|
||
def _assert_array( | ||
a: npt.NDArray[np.float64], | ||
b: npt.NDArray[np.float64], | ||
) -> None: | ||
if not (a == b).all(): | ||
raise AssertionError(f"element in study not the same {a} != {b}") | ||
|
||
|
||
def assert_study(a: SUB_JSON, b: SUB_JSON) -> None: | ||
if isinstance(a, dict) and isinstance(b, dict): | ||
_assert_dict(a, b) | ||
elif isinstance(a, list) and isinstance(b, list): | ||
_assert_list(a, b) | ||
elif ( | ||
isinstance(a, str) | ||
and isinstance(b, str) | ||
and "studyfile://" in a | ||
and "studyfile://" in b | ||
): | ||
_assert_pointer_path(a, b) | ||
elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray): | ||
_assert_array(a, b) | ||
elif isinstance(a, np.ndarray) and isinstance(b, list): | ||
_assert_list(cast(List[float], a.tolist()), b) | ||
elif isinstance(a, list) and isinstance(b, np.ndarray): | ||
_assert_list(a, cast(List[float], b.tolist())) | ||
else: | ||
_assert_others(a, b) | ||
|
||
|
||
def auto_retry_assert( | ||
predicate: Callable[..., bool], timeout: int = 2, delay: float = 0.2 | ||
) -> None: | ||
threshold = datetime.now(timezone.utc) + timedelta(seconds=timeout) | ||
while datetime.now(timezone.utc) < threshold: | ||
if predicate(): | ||
return | ||
time.sleep(delay) | ||
raise AssertionError() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import contextlib | ||
from typing import Any, Generator | ||
|
||
import pytest | ||
from sqlalchemy import create_engine # type: ignore | ||
from sqlalchemy.orm import sessionmaker | ||
|
||
from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware | ||
from antarest.dbmodel import Base | ||
|
||
__all__ = ("db_engine_fixture", "db_session_fixture", "db_middleware_fixture") | ||
|
||
|
||
@pytest.fixture(name="db_engine") | ||
def db_engine_fixture() -> Generator[Any, None, None]: | ||
""" | ||
Fixture that creates an in-memory SQLite database engine for testing. | ||
Yields: | ||
An instance of the created SQLite database engine. | ||
""" | ||
engine = create_engine("sqlite:///:memory:") | ||
Base.metadata.create_all(engine) | ||
yield engine | ||
engine.dispose() | ||
|
||
|
||
@pytest.fixture(name="db_session") | ||
def db_session_fixture(db_engine) -> Generator: | ||
""" | ||
Fixture that creates a database session for testing purposes. | ||
This fixture uses the provided db engine fixture to create a session maker, | ||
which in turn generates a new database session bound to the specified engine. | ||
Args: | ||
db_engine: The database engine instance provided by the db_engine fixture. | ||
Yields: | ||
A new SQLAlchemy session object for database operations. | ||
""" | ||
make_session = sessionmaker(bind=db_engine) | ||
with contextlib.closing(make_session()) as session: | ||
yield session | ||
|
||
|
||
@pytest.fixture(name="db_middleware", autouse=True) | ||
def db_middleware_fixture( | ||
db_engine: Any, | ||
) -> Generator[DBSessionMiddleware, None, None]: | ||
""" | ||
Fixture that sets up a database session middleware with custom engine settings. | ||
Args: | ||
db_engine: The database engine instance created by the db_engine fixture. | ||
Yields: | ||
An instance of the configured DBSessionMiddleware. | ||
""" | ||
yield DBSessionMiddleware( | ||
None, | ||
custom_engine=db_engine, | ||
session_args={"autocommit": False, "autoflush": False}, | ||
) |
Oops, something went wrong.