From 8b1ddd5eb1f2cbb9e6203d74d389685ec07581cc Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Thu, 28 Mar 2019 14:50:16 +0200 Subject: [PATCH 1/5] move server class into separate unit --- mta-sts-daemon | 177 +--------------------------- postfix_mta_sts_resolver/server.py | 179 +++++++++++++++++++++++++++++ 2 files changed, 180 insertions(+), 176 deletions(-) create mode 100644 postfix_mta_sts_resolver/server.py diff --git a/mta-sts-daemon b/mta-sts-daemon index 251924a..47d51ca 100755 --- a/mta-sts-daemon +++ b/mta-sts-daemon @@ -5,11 +5,8 @@ import argparse import asyncio import postfix_mta_sts_resolver.utils as utils import postfix_mta_sts_resolver.defaults as defaults -import pynetstring import yaml -from postfix_mta_sts_resolver.resolver import * -import collections -import time +from postfix_mta_sts_resolver.server import STSSocketmapResponder import logging @@ -72,178 +69,6 @@ def populate_cfg_defaults(cfg): return cfg -ZoneEntry = collections.namedtuple('ZoneEntry', ('strict', 'resolver')) - - -CacheEntry = collections.namedtuple('CacheEntry', ('ts', 'pol_id', 'pol_body')) - - -class STSSocketmapResponder(object): - def __init__(self, cfg, loop): - self._loop = loop - - # Construct configurations and resolvers for every socketmap name - self._default_zone = ZoneEntry(cfg["default_zone"]["strict_testing"], - STSResolver(loop=loop, - timeout=cfg["default_zone"]["timeout"])) - - self._zones = dict((k, ZoneEntry(zone["strict_testing"], - STSResolver(loop=loop, - timeout=zone["timeout"]))) - for k, zone in cfg["zones"].items()) - - # Construct cache - if cfg["cache"]["type"] == "internal": - import postfix_mta_sts_resolver.internal_cache - capacity = cfg["cache"]["options"]["cache_size"] - self._cache = postfix_mta_sts_resolver.internal_cache.InternalLRUCache(capacity) - else: - raise NotImplementedError("Unsupported cache type!") - - async def sender(self, queue, writer): - logger = logging.getLogger("STS") - try: - while True: - fut = await queue.get() - - # Check for shutdown - if fut is None: - writer.close() - return - - logger.debug("Got new future from queue") - try: - data = await fut - except asyncio.CancelledError: - writer.close() - return - except Exception as e: - logging.exception("Unhandled exception from future: %s", e) - writer.close() - return - logger.debug("Future await complete: data=%s", repr(data)) - writer.write(data) - logger.debug("Wrote: %s", repr(data)) - await writer.drain() - except asyncio.CancelledError: - try: - fut.cancel() - except: - pass - while not queue.empty(): - task = queue.get_nowait() - task.cancel() - - async def process_request(self, raw_req): - have_policy = True - - # Parse request and canonicalize domain - req_zone, _, req_domain = raw_req.decode('latin-1').partition(' ') - - domain = req_domain - - # Skip lookups for parent domain policies - # Skip lookups to non-recepient domains or non-domains at all - if domain.startswith('.') or domain.startswith('[') or ':' in domain: - return pynetstring.encode('NOTFOUND ') - - # Normalize domain name - domain = req_domain.lower().strip().rstrip('.') - - # Find appropriate zone config - if req_zone in self._zones: - zone_cfg = self._zones[req_zone] - else: - zone_cfg = self._default_zone - - # Lookup for cached policy - cached = await self._cache.get(domain) - - # Check if newer policy exists or - # retrieve policy from scratch if there is no cached one - if cached is None: - latest_pol_id = None - else: - latest_pol_id = cached.pol_id - status, policy = await zone_cfg.resolver.resolve(domain, latest_pol_id) - - # Update local cache - ts = time.time() - if status is STSFetchResult.NOT_CHANGED: - cached = CacheEntry(ts, cached.pol_id, cached.pol_body) - await self._cache.set(domain, cached) - elif status is STSFetchResult.VALID: - pol_id, pol_body = policy - cached = CacheEntry(ts, pol_id, pol_body) - await self._cache.set(domain, cached) - else: - if cached is None: - have_policy = False - else: - # Check if cached policy is expired - if cached.pol_body['max_age'] + cached.ts < ts: - have_policy = False - - - if have_policy: - mode = cached.pol_body['mode'] - if mode == 'none' or (mode == 'testing' and not zone_cfg.strict): - return pynetstring.encode('NOTFOUND ') - else: - assert cached.pol_body['mx'], "Empty MX list for restrictive policy!" - mxlist = [mx.lstrip('*') for mx in set(cached.pol_body['mx'])] - resp = "OK secure match=" + ":".join(mxlist) - return pynetstring.encode(resp) - else: - return pynetstring.encode('NOTFOUND ') - - - def enqueue_request(self, queue, raw_req): - fut = asyncio.ensure_future(self.process_request(raw_req), loop=self._loop) - queue.put_nowait(fut) - - async def handle_msg(self, reader, writer): - logger = logging.getLogger("STS") - - # Construct netstring parser - self._decoder = pynetstring.Decoder() - - # Construct queue for responses ordering - queue = asyncio.Queue(0, loop=self._loop) - - # Create coroutine which awaits for steady responses and sends them - sender = asyncio.ensure_future(self.sender(queue, writer), loop=self._loop) - - def cleanup(): - sender.cancel() - writer.close() - - while True: - try: - part = await reader.read(4096) - logger.debug("Read: %s", repr(part)) - except asyncio.CancelledError as e: - cleanup() - return - except ConnectionError as e: - cleanup() - return - if not part: - cleanup() - return - - try: - requests = self._decoder.feed(part) - except: - # Bad protocol. Do shutdown - queue.put_nowait(None) - await sender - else: - for req in requests: - logger.debug("Enq request: %s", repr(req)) - self.enqueue_request(queue, req) - - def main(): # Parse command line arguments and setup basic logging args = parse_args() diff --git a/postfix_mta_sts_resolver/server.py b/postfix_mta_sts_resolver/server.py new file mode 100644 index 0000000..25ab28a --- /dev/null +++ b/postfix_mta_sts_resolver/server.py @@ -0,0 +1,179 @@ +import asyncio +import pynetstring +import logging +import time +import collections + +from postfix_mta_sts_resolver.resolver import * + + +ZoneEntry = collections.namedtuple('ZoneEntry', ('strict', 'resolver')) + + +CacheEntry = collections.namedtuple('CacheEntry', ('ts', 'pol_id', 'pol_body')) + + +class STSSocketmapResponder(object): + def __init__(self, cfg, loop): + self._loop = loop + + # Construct configurations and resolvers for every socketmap name + self._default_zone = ZoneEntry(cfg["default_zone"]["strict_testing"], + STSResolver(loop=loop, + timeout=cfg["default_zone"]["timeout"])) + + self._zones = dict((k, ZoneEntry(zone["strict_testing"], + STSResolver(loop=loop, + timeout=zone["timeout"]))) + for k, zone in cfg["zones"].items()) + + # Construct cache + if cfg["cache"]["type"] == "internal": + import postfix_mta_sts_resolver.internal_cache + capacity = cfg["cache"]["options"]["cache_size"] + self._cache = postfix_mta_sts_resolver.internal_cache.InternalLRUCache(capacity) + else: + raise NotImplementedError("Unsupported cache type!") + + async def sender(self, queue, writer): + logger = logging.getLogger("STS") + try: + while True: + fut = await queue.get() + + # Check for shutdown + if fut is None: + writer.close() + return + + logger.debug("Got new future from queue") + try: + data = await fut + except asyncio.CancelledError: + writer.close() + return + except Exception as e: + logging.exception("Unhandled exception from future: %s", e) + writer.close() + return + logger.debug("Future await complete: data=%s", repr(data)) + writer.write(data) + logger.debug("Wrote: %s", repr(data)) + await writer.drain() + except asyncio.CancelledError: + try: + fut.cancel() + except: + pass + while not queue.empty(): + task = queue.get_nowait() + task.cancel() + + async def process_request(self, raw_req): + have_policy = True + + # Parse request and canonicalize domain + req_zone, _, req_domain = raw_req.decode('latin-1').partition(' ') + + domain = req_domain + + # Skip lookups for parent domain policies + # Skip lookups to non-recepient domains or non-domains at all + if domain.startswith('.') or domain.startswith('[') or ':' in domain: + return pynetstring.encode('NOTFOUND ') + + # Normalize domain name + domain = req_domain.lower().strip().rstrip('.') + + # Find appropriate zone config + if req_zone in self._zones: + zone_cfg = self._zones[req_zone] + else: + zone_cfg = self._default_zone + + # Lookup for cached policy + cached = await self._cache.get(domain) + + # Check if newer policy exists or + # retrieve policy from scratch if there is no cached one + if cached is None: + latest_pol_id = None + else: + latest_pol_id = cached.pol_id + status, policy = await zone_cfg.resolver.resolve(domain, latest_pol_id) + + # Update local cache + ts = time.time() + if status is STSFetchResult.NOT_CHANGED: + cached = CacheEntry(ts, cached.pol_id, cached.pol_body) + await self._cache.set(domain, cached) + elif status is STSFetchResult.VALID: + pol_id, pol_body = policy + cached = CacheEntry(ts, pol_id, pol_body) + await self._cache.set(domain, cached) + else: + if cached is None: + have_policy = False + else: + # Check if cached policy is expired + if cached.pol_body['max_age'] + cached.ts < ts: + have_policy = False + + + if have_policy: + mode = cached.pol_body['mode'] + if mode == 'none' or (mode == 'testing' and not zone_cfg.strict): + return pynetstring.encode('NOTFOUND ') + else: + assert cached.pol_body['mx'], "Empty MX list for restrictive policy!" + mxlist = [mx.lstrip('*') for mx in set(cached.pol_body['mx'])] + resp = "OK secure match=" + ":".join(mxlist) + return pynetstring.encode(resp) + else: + return pynetstring.encode('NOTFOUND ') + + + def enqueue_request(self, queue, raw_req): + fut = asyncio.ensure_future(self.process_request(raw_req), loop=self._loop) + queue.put_nowait(fut) + + async def handle_msg(self, reader, writer): + logger = logging.getLogger("STS") + + # Construct netstring parser + self._decoder = pynetstring.Decoder() + + # Construct queue for responses ordering + queue = asyncio.Queue(0, loop=self._loop) + + # Create coroutine which awaits for steady responses and sends them + sender = asyncio.ensure_future(self.sender(queue, writer), loop=self._loop) + + def cleanup(): + sender.cancel() + writer.close() + + while True: + try: + part = await reader.read(4096) + logger.debug("Read: %s", repr(part)) + except asyncio.CancelledError as e: + cleanup() + return + except ConnectionError as e: + cleanup() + return + if not part: + cleanup() + return + + try: + requests = self._decoder.feed(part) + except: + # Bad protocol. Do shutdown + queue.put_nowait(None) + await sender + else: + for req in requests: + logger.debug("Enq request: %s", repr(req)) + self.enqueue_request(queue, req) From 3e8397cc30fdae4f5d77cc6e89c14349c55332f4 Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Thu, 28 Mar 2019 15:46:26 +0200 Subject: [PATCH 2/5] improved server start/stop --- mta-sts-daemon | 73 +++++++++++++++++++----------- postfix_mta_sts_resolver/server.py | 41 +++++++++++++---- 2 files changed, 77 insertions(+), 37 deletions(-) diff --git a/mta-sts-daemon b/mta-sts-daemon index 47d51ca..17c4cf1 100755 --- a/mta-sts-daemon +++ b/mta-sts-daemon @@ -3,11 +3,14 @@ import sys import argparse import asyncio +import yaml +import logging +import signal +from functools import partial + import postfix_mta_sts_resolver.utils as utils import postfix_mta_sts_resolver.defaults as defaults -import yaml from postfix_mta_sts_resolver.server import STSSocketmapResponder -import logging def parse_args(): @@ -69,6 +72,43 @@ def populate_cfg_defaults(cfg): return cfg +def exit_handler(exit_event, signum, frame): + logger = logging.getLogger('MAIN') + if exit_event.is_set(): + logger.warning("Got second exit signal! Terminating hard.") + os._exit(1) + else: + logger.warning("Got first exit signal! Terminating gracefully.") + exit_event.set() + + +async def heartbeat(): + """ Hacky coroutine which keeps event loop spinning with some interval + even if no events are coming. This is required to handle Futures and + Events state change when no events are occuring.""" + while True: + await asyncio.sleep(.5) + + +async def amain(cfg, loop): + logger = logging.getLogger("MAIN") + # Construct request handler instance + responder = STSSocketmapResponder(cfg, loop) + + await responder.start() + logger.info("Server started.") + + exit_event = asyncio.Event() + beat = asyncio.ensure_future(heartbeat()) + sig_handler = partial(exit_handler, exit_event) + signal.signal(signal.SIGTERM, sig_handler) + signal.signal(signal.SIGINT, sig_handler) + await exit_event.wait() + logger.debug("Eventloop interrupted. Shutting down server...") + beat.cancel() + await responder.stop() + + def main(): # Parse command line arguments and setup basic logging args = parse_args() @@ -92,31 +132,10 @@ def main(): evloop = asyncio.get_event_loop() mainLogger.info("Eventloop started.") - # Construct request handler instance - responder = STSSocketmapResponder(cfg, evloop) - - # Start server - start_server = asyncio.start_server(responder.handle_msg, - cfg['host'], - cfg['port'], - loop=evloop) - server = evloop.run_until_complete(start_server) - mainLogger.info("Server started.") - - try: - evloop.run_forever() - except KeyboardInterrupt: - # Handle interruption: shutdown properly - mainLogger.info("Got exit signal. " - "Press Ctrl+C again to stop waiting connections to close.") - server.close() - try: - evloop.run_until_complete(server.wait_closed()) - except KeyboardInterrupt: - pass - finally: - mainLogger.info("Server finished its work.") - evloop.close() + + evloop.run_until_complete(amain(cfg, evloop)) + evloop.close() + mainLogger.info("Server finished its work.") if __name__ == '__main__': diff --git a/postfix_mta_sts_resolver/server.py b/postfix_mta_sts_resolver/server.py index 25ab28a..e393d0b 100644 --- a/postfix_mta_sts_resolver/server.py +++ b/postfix_mta_sts_resolver/server.py @@ -3,6 +3,7 @@ import logging import time import collections +import weakref from postfix_mta_sts_resolver.resolver import * @@ -15,7 +16,10 @@ class STSSocketmapResponder(object): def __init__(self, cfg, loop): + self._logger = logging.getLogger("STS") self._loop = loop + self._host = cfg['host'] + self._port = cfg['port'] # Construct configurations and resolvers for every socketmap name self._default_zone = ZoneEntry(cfg["default_zone"]["strict_testing"], @@ -34,9 +38,28 @@ def __init__(self, cfg, loop): self._cache = postfix_mta_sts_resolver.internal_cache.InternalLRUCache(capacity) else: raise NotImplementedError("Unsupported cache type!") + self._children = weakref.WeakSet() + + async def start(self): + def _spawn(reader, writer): + self._children.add( + self._loop.create_task(self.handler(reader, writer))) + + self._server = await asyncio.start_server(_spawn, + self._host, + self._port) + + async def stop(self): + self._server.close() + await self._server.wait_closed() + if self._children: + self._logger.debug("Cancelling %d client handlers...", + len(self._children)) + for task in self._children: + task.cancel() + await asyncio.wait(self._children) async def sender(self, queue, writer): - logger = logging.getLogger("STS") try: while True: fut = await queue.get() @@ -46,19 +69,19 @@ async def sender(self, queue, writer): writer.close() return - logger.debug("Got new future from queue") + self._logger.debug("Got new future from queue") try: data = await fut except asyncio.CancelledError: writer.close() return except Exception as e: - logging.exception("Unhandled exception from future: %s", e) + self._logger.exception("Unhandled exception from future: %s", e) writer.close() return - logger.debug("Future await complete: data=%s", repr(data)) + self._logger.debug("Future await complete: data=%s", repr(data)) writer.write(data) - logger.debug("Wrote: %s", repr(data)) + self._logger.debug("Wrote: %s", repr(data)) await writer.drain() except asyncio.CancelledError: try: @@ -137,9 +160,7 @@ def enqueue_request(self, queue, raw_req): fut = asyncio.ensure_future(self.process_request(raw_req), loop=self._loop) queue.put_nowait(fut) - async def handle_msg(self, reader, writer): - logger = logging.getLogger("STS") - + async def handler(self, reader, writer): # Construct netstring parser self._decoder = pynetstring.Decoder() @@ -156,7 +177,7 @@ def cleanup(): while True: try: part = await reader.read(4096) - logger.debug("Read: %s", repr(part)) + self._logger.debug("Read: %s", repr(part)) except asyncio.CancelledError as e: cleanup() return @@ -175,5 +196,5 @@ def cleanup(): await sender else: for req in requests: - logger.debug("Enq request: %s", repr(req)) + self._logger.debug("Enq request: %s", repr(req)) self.enqueue_request(queue, req) From 8f4f3ef9afe6c6d5ea6d5a13da2b5b0acdee0616 Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Thu, 28 Mar 2019 16:50:42 +0200 Subject: [PATCH 3/5] rewrite server handler code for sake of stability --- postfix_mta_sts_resolver/constants.py | 3 ++ postfix_mta_sts_resolver/resolver.py | 5 +- postfix_mta_sts_resolver/server.py | 77 +++++++++++++++++---------- 3 files changed, 52 insertions(+), 33 deletions(-) create mode 100644 postfix_mta_sts_resolver/constants.py diff --git a/postfix_mta_sts_resolver/constants.py b/postfix_mta_sts_resolver/constants.py new file mode 100644 index 0000000..55826d4 --- /dev/null +++ b/postfix_mta_sts_resolver/constants.py @@ -0,0 +1,3 @@ +HARD_RESP_LIMIT = 64 * 1024 +CHUNK = 4096 +QUEUE_LIMIT = 128 diff --git a/postfix_mta_sts_resolver/resolver.py b/postfix_mta_sts_resolver/resolver.py index 4600be2..ebeaa1f 100644 --- a/postfix_mta_sts_resolver/resolver.py +++ b/postfix_mta_sts_resolver/resolver.py @@ -6,10 +6,7 @@ from . import defaults from .utils import parse_mta_sts_record, parse_mta_sts_policy, is_plaintext, \ filter_text - - -HARD_RESP_LIMIT = 64 * 1024 -CHUNK = 4096 +from .constants import * class BadSTSPolicy(Exception): diff --git a/postfix_mta_sts_resolver/server.py b/postfix_mta_sts_resolver/server.py index e393d0b..df3caeb 100644 --- a/postfix_mta_sts_resolver/server.py +++ b/postfix_mta_sts_resolver/server.py @@ -5,7 +5,8 @@ import collections import weakref -from postfix_mta_sts_resolver.resolver import * +from .resolver import * +from .constants import * ZoneEntry = collections.namedtuple('ZoneEntry', ('strict', 'resolver')) @@ -155,46 +156,64 @@ async def process_request(self, raw_req): else: return pynetstring.encode('NOTFOUND ') - - def enqueue_request(self, queue, raw_req): - fut = asyncio.ensure_future(self.process_request(raw_req), loop=self._loop) - queue.put_nowait(fut) - async def handler(self, reader, writer): # Construct netstring parser self._decoder = pynetstring.Decoder() # Construct queue for responses ordering - queue = asyncio.Queue(0, loop=self._loop) + queue = asyncio.Queue(QUEUE_LIMIT, loop=self._loop) # Create coroutine which awaits for steady responses and sends them sender = asyncio.ensure_future(self.sender(queue, writer), loop=self._loop) - def cleanup(): - sender.cancel() - writer.close() + class ParserInvokationError(Exception): + pass + + class EndOfStream(Exception): + pass - while True: + async def finalize(): try: - part = await reader.read(4096) + await queue.put(None) + except asyncio.CancelledError: + sender.cancel() + raise + await sender + + try: + while True: + #Extract and parse requests + part = await reader.read(CHUNK) + if not part: + raise EndOfStream() self._logger.debug("Read: %s", repr(part)) - except asyncio.CancelledError as e: - cleanup() - return - except ConnectionError as e: - cleanup() - return - if not part: - cleanup() - return + try: + requests = self._decoder.feed(part) + except: + raise ParserInvokationError("Bad netstring protocol.") + pass - try: - requests = self._decoder.feed(part) - except: - # Bad protocol. Do shutdown - queue.put_nowait(None) - await sender - else: + # Enqueue tasks for received requests for req in requests: self._logger.debug("Enq request: %s", repr(req)) - self.enqueue_request(queue, req) + fut = asyncio.ensure_future(self.process_request(req), loop=self._loop) + await queue.put(fut) + except ParserInvokationError: + self._logger.warning("Bad netstring message received") + await finalize() + except (EndOfStream, ConnectionError, TimeoutError): + self._logger.debug("Client disconnected") + await finalize() + except OSError as e: + if e.errno == 107: + self._logger.debug("Client disconnected") + await finalize() + else: + self._logger.exception("Unhandled exception: %s", e) + except asyncio.CancelledError: + sender.cancel() + raise + except Exception as e: + self._logger.exception("Unhandled exception: %s", e) + finally: + writer.close() From 0b6b8fec2139d25f61cfa8bc40cd13ede477658c Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Thu, 28 Mar 2019 16:55:41 +0200 Subject: [PATCH 4/5] rename server.py to responder.py --- mta-sts-daemon | 2 +- postfix_mta_sts_resolver/{server.py => responder.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename postfix_mta_sts_resolver/{server.py => responder.py} (100%) diff --git a/mta-sts-daemon b/mta-sts-daemon index 17c4cf1..f41f888 100755 --- a/mta-sts-daemon +++ b/mta-sts-daemon @@ -10,7 +10,7 @@ from functools import partial import postfix_mta_sts_resolver.utils as utils import postfix_mta_sts_resolver.defaults as defaults -from postfix_mta_sts_resolver.server import STSSocketmapResponder +from postfix_mta_sts_resolver.responder import STSSocketmapResponder def parse_args(): diff --git a/postfix_mta_sts_resolver/server.py b/postfix_mta_sts_resolver/responder.py similarity index 100% rename from postfix_mta_sts_resolver/server.py rename to postfix_mta_sts_resolver/responder.py From 2c12e2af0c5fb932d067a9cb46c568d58f3be5cb Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Thu, 28 Mar 2019 17:06:34 +0200 Subject: [PATCH 5/5] use finalize() for unhandled exceptions too --- postfix_mta_sts_resolver/responder.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/postfix_mta_sts_resolver/responder.py b/postfix_mta_sts_resolver/responder.py index df3caeb..841210d 100644 --- a/postfix_mta_sts_resolver/responder.py +++ b/postfix_mta_sts_resolver/responder.py @@ -210,10 +210,12 @@ async def finalize(): await finalize() else: self._logger.exception("Unhandled exception: %s", e) + await finalize() except asyncio.CancelledError: sender.cancel() raise except Exception as e: self._logger.exception("Unhandled exception: %s", e) + await finalize() finally: writer.close()