diff --git a/mautrix/api.py b/mautrix/api.py index d034149b..b8ee2897 100644 --- a/mautrix/api.py +++ b/mautrix/api.py @@ -16,7 +16,7 @@ import platform import time -from aiohttp import ClientResponse, ClientSession, __version__ as aiohttp_version +from aiohttp import ClientResponse, ClientSession, UnixConnector, __version__ as aiohttp_version from aiohttp.client_exceptions import ClientError, ContentTypeError from yarl import URL @@ -229,9 +229,15 @@ def __init__( self.base_url = URL(base_url) self.token = token self.log = log or logging.getLogger("mau.http") - self.session = client_session or ClientSession( - loop=loop, headers={"User-Agent": self.default_ua} - ) + if client_session: + self.session = client_session + else: + connector = None + if base_url.startswith("unix://"): + connector = UnixConnector(path=base_url.replace("unix://", "")) + self.session = ClientSession( + loop=loop, headers={"User-Agent": self.default_ua}, connector=connector + ) self.as_user_id = as_user_id self.as_device_id = as_device_id if txn_id is not None: diff --git a/mautrix/appservice/appservice.py b/mautrix/appservice/appservice.py index 892db195..c877e942 100644 --- a/mautrix/appservice/appservice.py +++ b/mautrix/appservice/appservice.py @@ -152,7 +152,11 @@ async def __aexit__(self) -> None: async def start(self, host: str = "127.0.0.1", port: int = 8080) -> None: await self.state_store.open() self.log.debug(f"Starting appservice web server on {host}:{port}") - if self.server.startswith("https://") and not self.verify_ssl: + if self.server.startswith("unix://"): + path = self.server.replace("unix://", "") + self.server = "http://localhost" + connector = aiohttp.UnixConnector(limit=self.connection_limit, path=path) + elif self.server.startswith("https://") and not self.verify_ssl: connector = aiohttp.TCPConnector(limit=self.connection_limit, verify_ssl=False) else: connector = aiohttp.TCPConnector(limit=self.connection_limit) @@ -176,7 +180,10 @@ async def start(self, host: str = "127.0.0.1", port: int = 8080) -> None: ssl_ctx.load_cert_chain(self.tls_cert, self.tls_key) self.runner = web.AppRunner(self.app) await self.runner.setup() - site = web.TCPSite(self.runner, host, port, ssl_context=ssl_ctx) + if host.startswith("/"): + site = web.UnixSite(self.runner, host) + else: + site = web.TCPSite(self.runner, host, port, ssl_context=ssl_ctx) await site.start() async def stop(self) -> None: