Skip to content

Commit

Permalink
Merge pull request #10 from Snawoot/restruct
Browse files Browse the repository at this point in the history
Restruct & rewrite server handler
  • Loading branch information
Snawoot authored Mar 28, 2019
2 parents 10dafa5 + 2c12e2a commit 4abed9e
Show file tree
Hide file tree
Showing 4 changed files with 264 additions and 199 deletions.
234 changes: 39 additions & 195 deletions mta-sts-daemon
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import sys
import argparse
import asyncio
import postfix_mta_sts_resolver.utils as utils
import postfix_mta_sts_resolver.defaults as defaults
import pynetstring
import yaml
from postfix_mta_sts_resolver.resolver import *
import collections
import time
import logging
import signal
from functools import partial

import postfix_mta_sts_resolver.utils as utils
import postfix_mta_sts_resolver.defaults as defaults
from postfix_mta_sts_resolver.responder import STSSocketmapResponder


def parse_args():
Expand Down Expand Up @@ -72,176 +72,41 @@ def populate_cfg_defaults(cfg):
return cfg


# Per-zone lookup configuration: `strict` — whether "testing"-mode policies
# are enforced like "enforce" (bool); `resolver` — the STSResolver instance
# used for policy lookups in this zone.
ZoneEntry = collections.namedtuple('ZoneEntry', ('strict', 'resolver'))


# Cached policy record: `ts` — fetch timestamp (time.time()); `pol_id` —
# policy version identifier used for freshness checks; `pol_body` — parsed
# policy dict (keys used below: 'mode', 'mx', 'max_age').
CacheEntry = collections.namedtuple('CacheEntry', ('ts', 'pol_id', 'pol_body'))

def exit_handler(exit_event, signum, frame):
    """Signal handler driving two-stage shutdown.

    The first signal requests a graceful stop by setting *exit_event*;
    any further signal aborts the process immediately via os._exit.
    *signum* and *frame* are required by the signal-handler calling
    convention and are otherwise unused.
    """
    log = logging.getLogger('MAIN')
    if not exit_event.is_set():
        log.warning("Got first exit signal! Terminating gracefully.")
        exit_event.set()
        return
    log.warning("Got second exit signal! Terminating hard.")
    os._exit(1)

class STSSocketmapResponder(object):
def __init__(self, cfg, loop):
self._loop = loop

# Construct configurations and resolvers for every socketmap name
self._default_zone = ZoneEntry(cfg["default_zone"]["strict_testing"],
STSResolver(loop=loop,
timeout=cfg["default_zone"]["timeout"]))
async def heartbeat():
    """Keep the event loop ticking at a fixed interval.

    A workaround coroutine: with no traffic the loop can sit idle and
    never re-examine Future/Event state changes, so this simply wakes
    up every half second, forever.
    """
    interval = .5
    while True:
        await asyncio.sleep(interval)

self._zones = dict((k, ZoneEntry(zone["strict_testing"],
STSResolver(loop=loop,
timeout=zone["timeout"])))
for k, zone in cfg["zones"].items())

# Construct cache
if cfg["cache"]["type"] == "internal":
import postfix_mta_sts_resolver.internal_cache
capacity = cfg["cache"]["options"]["cache_size"]
self._cache = postfix_mta_sts_resolver.internal_cache.InternalLRUCache(capacity)
else:
raise NotImplementedError("Unsupported cache type!")

async def sender(self, queue, writer):
    """Drain *queue* of response futures and write each result to *writer*.

    Futures are awaited strictly in queue order, which preserves
    request/response ordering on the socket.  A ``None`` item is the
    shutdown sentinel: the writer is closed and the coroutine returns.
    On cancellation, the future currently awaited (if any) and every
    future still queued are cancelled.
    """
    logger = logging.getLogger("STS")
    fut = None  # fix: ensure name is bound for the cancellation handler
    try:
        while True:
            fut = await queue.get()

            # Check for shutdown sentinel
            if fut is None:
                writer.close()
                return

            logger.debug("Got new future from queue")
            try:
                data = await fut
            except asyncio.CancelledError:
                writer.close()
                return
            except Exception as e:
                # fix: use the module logger, not the root logger
                logger.exception("Unhandled exception from future: %s", e)
                writer.close()
                return
            logger.debug("Future await complete: data=%s", repr(data))
            writer.write(data)
            logger.debug("Wrote: %s", repr(data))
            await writer.drain()
    except asyncio.CancelledError:
        # Cancelled mid-flight: cancel the future we were waiting on
        # (no-op if it already completed) and everything still queued.
        if fut is not None:
            fut.cancel()
        while not queue.empty():
            task = queue.get_nowait()
            if task is not None:  # fix: skip the shutdown sentinel
                task.cancel()

async def process_request(self, raw_req):
have_policy = True

# Parse request and canonicalize domain
req_zone, _, req_domain = raw_req.decode('latin-1').partition(' ')

domain = req_domain

# Skip lookups for parent domain policies
    # Skip lookups to non-recipient domains or non-domains at all
if domain.startswith('.') or domain.startswith('[') or ':' in domain:
return pynetstring.encode('NOTFOUND ')

# Normalize domain name
domain = req_domain.lower().strip().rstrip('.')

# Find appropriate zone config
if req_zone in self._zones:
zone_cfg = self._zones[req_zone]
else:
zone_cfg = self._default_zone
async def amain(cfg, loop):
logger = logging.getLogger("MAIN")
# Construct request handler instance
responder = STSSocketmapResponder(cfg, loop)

# Lookup for cached policy
cached = await self._cache.get(domain)
await responder.start()
logger.info("Server started.")

# Check if newer policy exists or
# retrieve policy from scratch if there is no cached one
if cached is None:
latest_pol_id = None
else:
latest_pol_id = cached.pol_id
status, policy = await zone_cfg.resolver.resolve(domain, latest_pol_id)

# Update local cache
ts = time.time()
if status is STSFetchResult.NOT_CHANGED:
cached = CacheEntry(ts, cached.pol_id, cached.pol_body)
await self._cache.set(domain, cached)
elif status is STSFetchResult.VALID:
pol_id, pol_body = policy
cached = CacheEntry(ts, pol_id, pol_body)
await self._cache.set(domain, cached)
else:
if cached is None:
have_policy = False
else:
# Check if cached policy is expired
if cached.pol_body['max_age'] + cached.ts < ts:
have_policy = False


if have_policy:
mode = cached.pol_body['mode']
if mode == 'none' or (mode == 'testing' and not zone_cfg.strict):
return pynetstring.encode('NOTFOUND ')
else:
assert cached.pol_body['mx'], "Empty MX list for restrictive policy!"
mxlist = [mx.lstrip('*') for mx in set(cached.pol_body['mx'])]
resp = "OK secure match=" + ":".join(mxlist)
return pynetstring.encode(resp)
else:
return pynetstring.encode('NOTFOUND ')


def enqueue_request(self, queue, raw_req):
    """Schedule processing of *raw_req* and enqueue the resulting future.

    The future is placed on *queue* immediately (before it completes),
    so the companion sender coroutine emits responses in request order.
    """
    queue.put_nowait(
        asyncio.ensure_future(self.process_request(raw_req), loop=self._loop))

async def handle_msg(self, reader, writer):
    """Serve one socketmap client connection.

    Reads netstring-framed requests from *reader*, schedules each for
    processing, and lets a companion `sender` task write responses back
    in order.  Returns when the peer disconnects, the connection errors
    out, the task is cancelled, or the peer violates the netstring
    protocol.
    """
    logger = logging.getLogger("STS")

    # fix: the netstring parser must be per-connection.  Storing it as
    # self._decoder shared (and corrupted) parser state across
    # concurrent connections handled by the same responder instance.
    decoder = pynetstring.Decoder()

    # Construct queue for responses ordering
    queue = asyncio.Queue(0, loop=self._loop)

    # Create coroutine which awaits for steady responses and sends them
    sender = asyncio.ensure_future(self.sender(queue, writer), loop=self._loop)

    def cleanup():
        # Tear down the response path along with the connection.
        sender.cancel()
        writer.close()

    while True:
        try:
            part = await reader.read(4096)
            logger.debug("Read: %s", repr(part))
        except (asyncio.CancelledError, ConnectionError):
            cleanup()
            return
        if not part:  # EOF: peer closed the connection
            cleanup()
            return

        try:
            requests = decoder.feed(part)
        except Exception:
            # Bad protocol. Do shutdown: signal the sender to drain and
            # close, then stop serving this connection.
            queue.put_nowait(None)
            await sender
            return  # fix: original fell through and kept reading with a broken decoder
        else:
            for req in requests:
                logger.debug("Enq request: %s", repr(req))
                self.enqueue_request(queue, req)
exit_event = asyncio.Event()
beat = asyncio.ensure_future(heartbeat())
sig_handler = partial(exit_handler, exit_event)
signal.signal(signal.SIGTERM, sig_handler)
signal.signal(signal.SIGINT, sig_handler)
await exit_event.wait()
logger.debug("Eventloop interrupted. Shutting down server...")
beat.cancel()
await responder.stop()


def main():
Expand All @@ -267,31 +132,10 @@ def main():
evloop = asyncio.get_event_loop()
mainLogger.info("Eventloop started.")

# Construct request handler instance
responder = STSSocketmapResponder(cfg, evloop)

# Start server
start_server = asyncio.start_server(responder.handle_msg,
cfg['host'],
cfg['port'],
loop=evloop)
server = evloop.run_until_complete(start_server)
mainLogger.info("Server started.")

try:
evloop.run_forever()
except KeyboardInterrupt:
# Handle interruption: shutdown properly
mainLogger.info("Got exit signal. "
"Press Ctrl+C again to stop waiting connections to close.")
server.close()
try:
evloop.run_until_complete(server.wait_closed())
except KeyboardInterrupt:
pass
finally:
mainLogger.info("Server finished its work.")
evloop.close()

evloop.run_until_complete(amain(cfg, evloop))
evloop.close()
mainLogger.info("Server finished its work.")


if __name__ == '__main__':
Expand Down
3 changes: 3 additions & 0 deletions postfix_mta_sts_resolver/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Hard upper bound (bytes) on a fetched response body; previously defined in
# resolver.py (see the removal in this same change) and moved here.
HARD_RESP_LIMIT = 64 * 1024
# Chunk size (bytes) used for stream reads.
CHUNK = 4096
# NOTE(review): presumably caps a per-connection request/response queue —
# confirm against the responder implementation.
QUEUE_LIMIT = 128
5 changes: 1 addition & 4 deletions postfix_mta_sts_resolver/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
from . import defaults
from .utils import parse_mta_sts_record, parse_mta_sts_policy, is_plaintext, \
filter_text


HARD_RESP_LIMIT = 64 * 1024
CHUNK = 4096
from .constants import *


class BadSTSPolicy(Exception):
Expand Down
Loading

0 comments on commit 4abed9e

Please sign in to comment.