diff --git a/config_examples/README.md b/config_examples/README.md
index 9a5db38..70dbf94 100644
--- a/config_examples/README.md
+++ b/config_examples/README.md
@@ -15,6 +15,8 @@ Reference for configuration values:
     * `cache_size`: (int) number of cache entries to store in memory
   * Options for `sqlite` type:
     * `filename`: (str) path to database file
+    * `threads`: (int) number of threads in the pool for SQLite connections
+    * `timeout`: (float) timeout in seconds for acquiring a connection from the pool or a DB lock
   * Options for `redis` type:
     * All parameters are passed to [aioredis.create_redis_pool](https://aioredis.readthedocs.io/en/latest/api_reference.html#aioredis.create_redis_pool). Use it for parameter reference.
 * `default_zone`:
diff --git a/postfix_mta_sts_resolver/defaults.py b/postfix_mta_sts_resolver/defaults.py
index bba05d8..28910b2 100644
--- a/postfix_mta_sts_resolver/defaults.py
+++ b/postfix_mta_sts_resolver/defaults.py
@@ -1,3 +1,5 @@
+from multiprocessing import cpu_count
+
 HOST = "127.0.0.1"
 PORT = 8461
 REUSE_PORT = True
@@ -7,5 +9,7 @@
 CONFIG_LOCATION = "/etc/postfix/mta-sts-daemon.yml"
 CACHE_BACKEND = "internal"
 INTERNAL_CACHE_SIZE = 10000
+SQLITE_THREADS = cpu_count()
+SQLITE_TIMEOUT = 5
 REDIS_TIMEOUT = 5
 CACHE_GRACE = 60
diff --git a/postfix_mta_sts_resolver/sqlite_cache.py b/postfix_mta_sts_resolver/sqlite_cache.py
index 16d3b43..7f25514 100644
--- a/postfix_mta_sts_resolver/sqlite_cache.py
+++ b/postfix_mta_sts_resolver/sqlite_cache.py
@@ -1,31 +1,101 @@
+import asyncio
 import aiosqlite
 import sqlite3
 import json
 import logging
 
+from .utils import _anext
+from .defaults import SQLITE_THREADS, SQLITE_TIMEOUT
 from .base_cache import BaseCache, CacheEntry
 
 
+class SqliteConnPool:
+    def __init__(self, threads, conn_args=(), conn_kwargs={}, init_queries=()):
+        self._threads = threads
+        self._conn_args = conn_args
+        self._conn_kwargs = conn_kwargs
+        self._init_queries = init_queries
+        self._free_conns = asyncio.Queue()
+        self._ready = False
+        self._stopped = False
+
+    async def _new_conn(self):
+        async def gen():
+            async with aiosqlite.connect(*self._conn_args, **self._conn_kwargs) as c:
+                for q in self._init_queries:
+                    await c.execute(q)
+                yield c
+        it = gen()
+        return it, await _anext(it)
+
+    async def prepare(self):
+        for _ in range(self._threads):
+            self._free_conns.put_nowait(await self._new_conn())
+        self._ready = True
+
+    async def stop(self):
+        self._ready = False
+        self._stopped = True
+        try:
+            while True:
+                g, db = self._free_conns.get_nowait()
+                await _anext(g, None)
+        except asyncio.QueueEmpty:
+            pass
+
+    def borrow(self, timeout=None):
+        #assert self._ready
+        class PoolBorrow:
+            async def __aenter__(s):
+                s._conn = await asyncio.wait_for(self._free_conns.get(),
+                                                 timeout)
+                return s._conn[1]
+
+            async def __aexit__(s, exc_type, exc, tb):
+                if self._stopped:
+                    await _anext(s._conn[0], None)
+                    return
+                if exc_type is not None:
+                    await _anext(s._conn[0], None)
+                    s._conn = await self._new_conn()
+                self._free_conns.put_nowait(s._conn)
+        return PoolBorrow()
+
+
 class SqliteCache(BaseCache):
-    def __init__(self, filename):
+    def __init__(self, filename, *,
+                 threads=SQLITE_THREADS, timeout=SQLITE_TIMEOUT):
         self._filename = filename
+        self._threads = threads
+        self._timeout = timeout
         sqlitelogger = logging.getLogger("aiosqlite")
         if not sqlitelogger.hasHandlers():
             sqlitelogger.addHandler(logging.NullHandler())
 
     async def setup(self):
+        conn_init = [
+            "PRAGMA journal_mode=WAL",
+            "PRAGMA synchronous=NORMAL",
+        ]
+        self._pool = SqliteConnPool(self._threads,
+                                    conn_args=(self._filename,),
+                                    conn_kwargs={
+                                        "timeout": self._timeout,
+                                    },
+                                    init_queries=conn_init)
+        await self._pool.prepare()
         queries = [
-            "create table if not exists sts_policy_cache (domain text, ts integer, pol_id text, pol_body text)",
-            "create unique index if not exists sts_policy_domain on sts_policy_cache (domain)",
-            "create index if not exists sts_policy_domain_ts on sts_policy_cache (domain, ts)",
+            "create table if not exists sts_policy_cache (domain text, ts integer, pol_id text, pol_body text)",
+            "create unique index if not exists sts_policy_domain on sts_policy_cache (domain)",
+            "create index if not exists sts_policy_domain_ts on sts_policy_cache (domain, ts)",
         ]
-        async with aiosqlite.connect(self._filename) as db:
+        async with self._pool.borrow(self._timeout) as db:
             for q in queries:
                 await db.execute(q)
             await db.commit()
 
     async def get(self, key):
-        async with aiosqlite.connect(self._filename) as db:
+        async with self._pool.borrow(self._timeout) as db:
             async with db.execute('select ts, pol_id, pol_body from '
                                   'sts_policy_cache where domain=?',
                                   (key,)) as cur:
@@ -41,7 +111,7 @@ async def get(self, key):
     async def set(self, key, value):
         ts, pol_id, pol_body = value
         pol_body = json.dumps(pol_body)
-        async with aiosqlite.connect(self._filename) as db:
+        async with self._pool.borrow(self._timeout) as db:
             try:
                 await db.execute('insert into sts_policy_cache (domain, ts, '
                                  'pol_id, pol_body) values (?, ?, ?, ?)',
@@ -55,4 +125,4 @@ async def set(self, key, value):
             await db.commit()
 
     async def teardown(self):
-        pass
+        await self._pool.stop()
diff --git a/postfix_mta_sts_resolver/utils.py b/postfix_mta_sts_resolver/utils.py
index f287caf..ee80d12 100644
--- a/postfix_mta_sts_resolver/utils.py
+++ b/postfix_mta_sts_resolver/utils.py
@@ -183,3 +183,17 @@ def create_cache(type, options):
     else:
         raise NotImplementedError("Unsupported cache type!")
     return cache
+
+class NoDefault:
+    pass
+
+NODEFAULT = NoDefault()
+
+async def _anext(gen, default=NODEFAULT):
+    try:
+        return await gen.__anext__()
+    except StopAsyncIteration:
+        if default is NODEFAULT:
+            raise
+        else:
+            return default
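
Usage note (not part of the patch): the sketch below shows how the pooled SqliteCache introduced above could be exercised end to end. It is a minimal, hypothetical driver: the database path and the stored values are illustrative, while the constructor signature and the setup()/get()/set()/teardown() coroutines come from the diff. set() expects a (ts, pol_id, pol_body) tuple, and get() resolves to the cached entry or None.

    import asyncio

    from postfix_mta_sts_resolver.sqlite_cache import SqliteCache

    async def main():
        # threads= and timeout= keyword arguments are added by this patch;
        # the database filename here is purely illustrative.
        cache = SqliteCache("/tmp/sts-cache.db", threads=2, timeout=5.0)
        await cache.setup()     # builds the connection pool, creates the schema
        await cache.set("example.com", (1234567890, "some-policy-id",
                                        {"mode": "enforce"}))
        print(await cache.get("example.com"))
        await cache.teardown()  # drains the pool and closes its connections

    asyncio.run(main())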