-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
aebe3e9
commit f973522
Showing
6 changed files
with
176 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,80 @@ | ||
DUMMY = True | ||
from types import TracebackType | ||
from typing import Any, Generator, Optional, Type | ||
|
||
from asyncpg import Connection, create_pool | ||
from asyncpg.pool import Pool | ||
|
||
__all__ = ["Engine"] | ||
|
||
|
||
class Engine: | ||
__slots__ = "_pool", "_use_single_con", "_global_con" | ||
|
||
_pool: Pool | ||
_use_single_con: bool | ||
_global_con: Optional[Connection] | ||
|
||
def __init__(self, pool: Pool, use_single_connection: bool): | ||
self._pool = pool | ||
self._use_single_con = use_single_connection | ||
|
||
self._global_con = None | ||
|
||
@classmethod | ||
async def create( | ||
cls, url: str, *, use_single_connection: bool = False, **kwargs: Any | ||
) -> "Engine": | ||
pool = await create_pool(url, min_size=2, init=cls._set_codecs, **kwargs) | ||
return cls(pool, use_single_connection) | ||
|
||
async def close(self) -> None: | ||
await self._pool.close() | ||
|
||
def acquire(self) -> "ConnectionAcquire": | ||
return ConnectionAcquire(self) | ||
|
||
async def _acquire(self) -> Connection: | ||
if self._use_single_con: | ||
self._global_con = self._global_con or await self._pool.acquire() | ||
return self._global_con | ||
|
||
return await self._pool.acquire() | ||
|
||
async def release(self, con: Connection, *, force: bool = False) -> None: | ||
if not self._use_single_con or force: | ||
await self._pool.release(con) | ||
self._global_con = None | ||
|
||
async def healthcheck(self) -> None: | ||
async with self.acquire() as con: # type: Connection | ||
await con.execute("SELECT 1") | ||
|
||
@staticmethod | ||
async def _set_codecs(con: Connection) -> None: | ||
"""Override this method if you want to set custom codecs.""" | ||
|
||
|
||
class ConnectionAcquire: | ||
__slots__ = "engine", "con" | ||
|
||
engine: Engine | ||
con: Optional[Connection] | ||
|
||
def __init__(self, engine: Engine): | ||
self.engine = engine | ||
self.con = None | ||
|
||
async def __call__(self) -> Connection: | ||
self.con = await self.engine._acquire() | ||
return self.con | ||
|
||
def __await__(self) -> Generator[Connection, None, None]: | ||
return self().__await__() | ||
|
||
async def __aenter__(self) -> Connection: | ||
return await self() | ||
|
||
async def __aexit__( | ||
self, exc_type: Type[BaseException], exc: BaseException, tv: TracebackType | ||
) -> None: | ||
await self.engine.release(self.con) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
version: '3.1' | ||
|
||
services: | ||
|
||
postgres: | ||
image: postgres:13 | ||
container_name: upmarket-cc-partners | ||
environment: | ||
POSTGRES_PASSWORD: guest | ||
POSTGRES_USER: guest | ||
POSTGRES_DB: guest | ||
ports: | ||
- 5432:5432 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import pytest | ||
|
||
_config = { | ||
"POSTGRES_URL": "postgres://guest:guest@localhost:5432/guest?sslmode=disable" | ||
} | ||
|
||
|
||
@pytest.fixture() | ||
def config() -> dict[str, str]: | ||
return _config.copy() |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import pytest | ||
from asyncpg import InterfaceError | ||
|
||
from asyncpg_engine import Engine | ||
|
||
pytestmark = [pytest.mark.asyncio] | ||
|
||
|
||
@pytest.fixture() | ||
async def engine(config: dict[str, str]) -> Engine: | ||
return await Engine.create(config["POSTGRES_URL"]) | ||
|
||
|
||
async def test_returns_new_connection_every_acquire(engine: Engine) -> None: | ||
async with engine.acquire() as con_0: | ||
async with engine.acquire() as con_1: | ||
|
||
assert con_0 != con_1 | ||
|
||
|
||
async def test_returns_the_same_connection_every_acquire_if_single( | ||
config: dict[str, str] | ||
) -> None: | ||
engine = await Engine.create(config["POSTGRES_URL"], use_single_connection=True) | ||
|
||
async with engine.acquire() as con_0: | ||
async with engine.acquire() as con_1: | ||
|
||
assert con_0 == con_1 | ||
|
||
|
||
async def test_non_force_release_ignored_for_single_connection( | ||
config: dict[str, str] | ||
) -> None: | ||
engine = await Engine.create(config["POSTGRES_URL"], use_single_connection=True) | ||
async with engine.acquire() as con_0: | ||
pass | ||
async with engine.acquire() as con_1: | ||
|
||
assert con_0 is con_1 | ||
|
||
await engine.release(con_0, force=True) | ||
await engine.release(con_1, force=True) | ||
|
||
|
||
async def test_closes_well(engine: Engine) -> None: | ||
await engine.close() | ||
|
||
with pytest.raises(InterfaceError): | ||
await engine.acquire() | ||
|
||
|
||
async def test_healthcheck_returns_nothing(engine: Engine) -> None: | ||
got = await engine.healthcheck() | ||
|
||
assert got is None | ||
|
||
|
||
async def test_healthcheck_raises_if_something_went_wrong(engine: Engine) -> None: | ||
await engine.close() | ||
|
||
with pytest.raises(InterfaceError): | ||
await engine.healthcheck() |