diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0d22388..e674a7c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -51,6 +51,16 @@ jobs: matrix: python-version: ["3.9", "3.10"] + services: + postgres: + image: postgres:13 + env: + POSTGRES_PASSWORD: guest + POSTGRES_USER: guest + POSTGRES_DB: guest + ports: + - 5432:5432 + steps: - uses: actions/checkout@v2 diff --git a/asyncpg_engine.py b/asyncpg_engine.py index 63ed07e..ecc9d7d 100644 --- a/asyncpg_engine.py +++ b/asyncpg_engine.py @@ -1 +1,80 @@ -DUMMY = True +from types import TracebackType +from typing import Any, Generator, Optional, Type + +from asyncpg import Connection, create_pool +from asyncpg.pool import Pool + +__all__ = ["Engine"] + + +class Engine: + __slots__ = "_pool", "_use_single_con", "_global_con" + + _pool: Pool + _use_single_con: bool + _global_con: Optional[Connection] + + def __init__(self, pool: Pool, use_single_connection: bool): + self._pool = pool + self._use_single_con = use_single_connection + + self._global_con = None + + @classmethod + async def create( + cls, url: str, *, use_single_connection: bool = False, **kwargs: Any + ) -> "Engine": + pool = await create_pool(url, min_size=2, init=cls._set_codecs, **kwargs) + return cls(pool, use_single_connection) + + async def close(self) -> None: + await self._pool.close() + + def acquire(self) -> "ConnectionAcquire": + return ConnectionAcquire(self) + + async def _acquire(self) -> Connection: + if self._use_single_con: + self._global_con = self._global_con or await self._pool.acquire() + return self._global_con + + return await self._pool.acquire() + + async def release(self, con: Connection, *, force: bool = False) -> None: + if not self._use_single_con or force: + await self._pool.release(con) + self._global_con = None + + async def healthcheck(self) -> None: + async with self.acquire() as con: # type: Connection + await con.execute("SELECT 1") + + @staticmethod + async def _set_codecs(con: Connection) -> None: + """Override this method if you want to set custom codecs.""" + + +class ConnectionAcquire: + __slots__ = "engine", "con" + + engine: Engine + con: Optional[Connection] + + def __init__(self, engine: Engine): + self.engine = engine + self.con = None + + async def __call__(self) -> Connection: + self.con = await self.engine._acquire() + return self.con + + def __await__(self) -> Generator[Connection, None, None]: + return self().__await__() + + async def __aenter__(self) -> Connection: + return await self() + + async def __aexit__( + self, exc_type: Type[BaseException], exc: BaseException, tv: TracebackType + ) -> None: + await self.engine.release(self.con) diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000..f7a1b8e --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,13 @@ +version: '3.1' + +services: + + postgres: + image: postgres:13 + container_name: upmarket-cc-partners + environment: + POSTGRES_PASSWORD: guest + POSTGRES_USER: guest + POSTGRES_DB: guest + ports: + - 5432:5432 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..f90be9e --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,10 @@ +import pytest + +_config = { + "POSTGRES_URL": "postgres://guest:guest@localhost:5432/guest?sslmode=disable" +} + + +@pytest.fixture() +def config() -> dict[str, str]: + return _config.copy() diff --git a/tests/test_dummy.py b/tests/test_dummy.py deleted file mode 100644 index 4d81e1a..0000000 --- a/tests/test_dummy.py +++ /dev/null @@ -1,7 +0,0 @@ -import pytest - -pytestmark = [pytest.mark.asyncio] - - -def test_dummy() -> None: - assert True diff --git a/tests/test_engine.py b/tests/test_engine.py new file mode 100644 index 0000000..8b909a5 --- /dev/null +++ b/tests/test_engine.py @@ -0,0 +1,63 @@ +import pytest +from asyncpg import InterfaceError + +from asyncpg_engine import Engine + +pytestmark = [pytest.mark.asyncio] + + +@pytest.fixture() +async def engine(config: dict[str, str]) -> Engine: + return await Engine.create(config["POSTGRES_URL"]) + + +async def test_returns_new_connection_every_acquire(engine: Engine) -> None: + async with engine.acquire() as con_0: + async with engine.acquire() as con_1: + + assert con_0 != con_1 + + +async def test_returns_the_same_connection_every_acquire_if_single( + config: dict[str, str] +) -> None: + engine = await Engine.create(config["POSTGRES_URL"], use_single_connection=True) + + async with engine.acquire() as con_0: + async with engine.acquire() as con_1: + + assert con_0 == con_1 + + +async def test_non_force_release_ignored_for_single_connection( + config: dict[str, str] +) -> None: + engine = await Engine.create(config["POSTGRES_URL"], use_single_connection=True) + async with engine.acquire() as con_0: + pass + async with engine.acquire() as con_1: + + assert con_0 is con_1 + + await engine.release(con_0, force=True) + await engine.release(con_1, force=True) + + +async def test_closes_well(engine: Engine) -> None: + await engine.close() + + with pytest.raises(InterfaceError): + await engine.acquire() + + +async def test_healthcheck_returns_nothing(engine: Engine) -> None: + got = await engine.healthcheck() + + assert got is None + + +async def test_healthcheck_raises_if_something_went_wrong(engine: Engine) -> None: + await engine.close() + + with pytest.raises(InterfaceError): + await engine.healthcheck()