diff --git a/.gitignore b/.gitignore index 96e4ceb9..73c64f1f 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ dist/ pip-selfcheck.json *.pyc __pycache__ +.idea \ No newline at end of file diff --git a/mautrix/api.py b/mautrix/api.py index d034149b..f833ecbe 100644 --- a/mautrix/api.py +++ b/mautrix/api.py @@ -16,7 +16,7 @@ import platform import time -from aiohttp import ClientResponse, ClientSession, __version__ as aiohttp_version +from aiohttp import ClientResponse, ClientSession, __version__ as aiohttp_version, UnixConnector from aiohttp.client_exceptions import ClientError, ContentTypeError from yarl import URL @@ -229,9 +229,15 @@ def __init__( self.base_url = URL(base_url) self.token = token self.log = log or logging.getLogger("mau.http") - self.session = client_session or ClientSession( - loop=loop, headers={"User-Agent": self.default_ua} - ) + if client_session: + self.session = client_session + else: + connector = None + if base_url.startswith('unix://'): + connector = UnixConnector(path=base_url.replace('unix://', '')) + self.session = ClientSession( + loop=loop, headers={"User-Agent": self.default_ua}, connector=connector + ) self.as_user_id = as_user_id self.as_device_id = as_device_id if txn_id is not None: diff --git a/mautrix/appservice/appservice.py b/mautrix/appservice/appservice.py index 892db195..aff6d278 100644 --- a/mautrix/appservice/appservice.py +++ b/mautrix/appservice/appservice.py @@ -152,7 +152,11 @@ async def __aexit__(self) -> None: async def start(self, host: str = "127.0.0.1", port: int = 8080) -> None: await self.state_store.open() self.log.debug(f"Starting appservice web server on {host}:{port}") - if self.server.startswith("https://") and not self.verify_ssl: + if self.server.startswith("unix://"): + path = self.server.replace('unix://', '') + self.server = 'http://localhost' + connector = aiohttp.UnixConnector(limit=self.connection_limit, path=path) + elif self.server.startswith("https://") and not self.verify_ssl: connector = aiohttp.TCPConnector(limit=self.connection_limit, verify_ssl=False) else: connector = aiohttp.TCPConnector(limit=self.connection_limit) @@ -176,7 +180,10 @@ async def start(self, host: str = "127.0.0.1", port: int = 8080) -> None: ssl_ctx.load_cert_chain(self.tls_cert, self.tls_key) self.runner = web.AppRunner(self.app) await self.runner.setup() - site = web.TCPSite(self.runner, host, port, ssl_context=ssl_ctx) + if host.startswith("/"): + site = web.UnixSite(self.runner, host) + else: + site = web.TCPSite(self.runner, host, port, ssl_context=ssl_ctx) await site.start() async def stop(self) -> None: