Skip to content

Commit

Permalink
unix socket support
Browse files Browse the repository at this point in the history
  • Loading branch information
cyberb committed Nov 24, 2023
1 parent c29c371 commit bb72d6e
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ dist/
pip-selfcheck.json
*.pyc
__pycache__
.idea
14 changes: 10 additions & 4 deletions mautrix/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import platform
import time

from aiohttp import ClientResponse, ClientSession, __version__ as aiohttp_version
from aiohttp import ClientResponse, ClientSession, __version__ as aiohttp_version, UnixConnector
from aiohttp.client_exceptions import ClientError, ContentTypeError
from yarl import URL

Expand Down Expand Up @@ -229,9 +229,15 @@ def __init__(
self.base_url = URL(base_url)
self.token = token
self.log = log or logging.getLogger("mau.http")
self.session = client_session or ClientSession(
loop=loop, headers={"User-Agent": self.default_ua}
)
if client_session:
self.session = client_session
else:
connector = None
if base_url.startswith('unix://'):
connector = UnixConnector(path=base_url.replace('unix://', ''))
self.session = ClientSession(
loop=loop, headers={"User-Agent": self.default_ua}, connector=connector
)
self.as_user_id = as_user_id
self.as_device_id = as_device_id
if txn_id is not None:
Expand Down
11 changes: 9 additions & 2 deletions mautrix/appservice/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ async def __aexit__(self) -> None:
async def start(self, host: str = "127.0.0.1", port: int = 8080) -> None:
await self.state_store.open()
self.log.debug(f"Starting appservice web server on {host}:{port}")
if self.server.startswith("https://") and not self.verify_ssl:
if self.server.startswith("unix://"):
path = self.server.replace('unix://', '')
self.server = 'http://localhost'
connector = aiohttp.UnixConnector(limit=self.connection_limit, path=path)
elif self.server.startswith("https://") and not self.verify_ssl:
connector = aiohttp.TCPConnector(limit=self.connection_limit, verify_ssl=False)
else:
connector = aiohttp.TCPConnector(limit=self.connection_limit)
Expand All @@ -176,7 +180,10 @@ async def start(self, host: str = "127.0.0.1", port: int = 8080) -> None:
ssl_ctx.load_cert_chain(self.tls_cert, self.tls_key)
self.runner = web.AppRunner(self.app)
await self.runner.setup()
site = web.TCPSite(self.runner, host, port, ssl_context=ssl_ctx)
if host.startswith("/"):
site = web.UnixSite(self.runner, host)
else:
site = web.TCPSite(self.runner, host, port, ssl_context=ssl_ctx)
await site.start()

async def stop(self) -> None:
Expand Down

0 comments on commit bb72d6e

Please sign in to comment.