From 947643288ba5b023a6811341673833b5d0440246 Mon Sep 17 00:00:00 2001 From: Mark Zealey Date: Mon, 11 Nov 2024 15:43:04 +0000 Subject: [PATCH] Add trust_number_of_proxies to ProxyHeadersMiddleware Adds a new `trust_number_of_proxies` setting to the `ProxyHeadersMiddleware` that allows the user to specify the number of proxies that are trusted to be present in the request and not otherwise verified. --- tests/middleware/test_proxy_headers.py | 26 ++++++++++++++++++- uvicorn/middleware/proxy_headers.py | 35 ++++++++++++++++++++++---- 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/tests/middleware/test_proxy_headers.py b/tests/middleware/test_proxy_headers.py index 0ade97450..71c3b6a60 100644 --- a/tests/middleware/test_proxy_headers.py +++ b/tests/middleware/test_proxy_headers.py @@ -39,15 +39,17 @@ async def default_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISend def make_httpx_client( trusted_hosts: str | list[str], client: tuple[str, int] = ("127.0.0.1", 123), + trust_number_of_proxies: int = 0, ) -> httpx.AsyncClient: """Create async client for use in test cases. 
Args: trusted_hosts: trusted_hosts for proxy middleware client: transport client to use + trust_number_of_proxies: number of proxies to trust """ - app = ProxyHeadersMiddleware(default_app, trusted_hosts) + app = ProxyHeadersMiddleware(default_app, trusted_hosts, trust_number_of_proxies) transport = httpx.ASGITransport(app=app, client=client) # type: ignore return httpx.AsyncClient(transport=transport, base_url="http://testserver") @@ -492,3 +494,25 @@ async def test_proxy_headers_empty_x_forwarded_for() -> None: response = await client.get("/", headers=headers) assert response.status_code == 200 assert response.text == "https://127.0.0.1:123" + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("x_forwarded_for", "trust_number_of_proxies", "expected"), + [ + ("", 0, "http://127.0.0.1:123"), + ("", 1, "http://127.0.0.1:123"), + ("192.168.0.0, 192.168.0.1, 192.168.0.2", 0, "http://127.0.0.1:123"), + ("192.168.0.0, 192.168.0.1, 192.168.0.2", 1, "http://192.168.0.2:0"), + ("192.168.0.0, 192.168.0.1, 192.168.0.2", 2, "http://192.168.0.1:0"), + ("192.168.0.0, 192.168.0.1, 192.168.0.2", 3, "http://192.168.0.0:0"), + ("192.168.0.0, 192.168.0.1, 192.168.0.2", 4, "http://192.168.0.0:0"), + ("192.168.0.0, 192.168.0.1, 192.168.0.2", 5, "http://192.168.0.0:0"), + ], +) +async def test_trust_number_of_proxies(x_forwarded_for: str, trust_number_of_proxies: int, expected: str) -> None: + async with make_httpx_client([], trust_number_of_proxies=trust_number_of_proxies) as client: + headers = {X_FORWARDED_FOR: x_forwarded_for} + response = await client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == expected diff --git a/uvicorn/middleware/proxy_headers.py b/uvicorn/middleware/proxy_headers.py index 7c3609de6..5fb39eceb 100644 --- a/uvicorn/middleware/proxy_headers.py +++ b/uvicorn/middleware/proxy_headers.py @@ -15,14 +15,27 @@ class ProxyHeadersMiddleware: Modifies the `client` and `scheme` information so that they reference the 
connecting client, rather that the connecting proxy. + You can pass through a list of trusted hosts via the `trusted_hosts` + parameter, which can be either a list or single entry. Each entry can be + either a host ("127.0.0.1") or a network ("192.168.0.0/24"). An entry of + "*" means that all hosts are trusted. + + Alternatively, if you know how many proxies are in front of your application + you can pass the `trust_number_of_proxies` parameter to unconditionally trust + the N proxies closest to your application. In this case you probably want to + pass an empty list for `trusted_hosts`. + References: - - """ - def __init__(self, app: ASGI3Application, trusted_hosts: list[str] | str = "127.0.0.1") -> None: + def __init__( + self, app: ASGI3Application, trusted_hosts: list[str] | str = "127.0.0.1", trust_number_of_proxies: int = 0 + ) -> None: self.app = app - self.trusted_hosts = _TrustedHosts(trusted_hosts) + self.trust_number_of_proxies = trust_number_of_proxies + self.trusted_hosts = _TrustedHosts(trusted_hosts, trust_number_of_proxies) async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: if scope["type"] == "lifespan": @@ -31,7 +44,7 @@ async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGIS client_addr = scope.get("client") client_host = client_addr[0] if client_addr else None - if client_host in self.trusted_hosts: + if self.trust_number_of_proxies > 0 or client_host in self.trusted_hosts: headers = dict(scope["headers"]) if b"x-forwarded-proto" in headers: @@ -67,13 +80,17 @@ def _parse_raw_hosts(value: str) -> list[str]: class _TrustedHosts: """Container for trusted hosts and networks""" - def __init__(self, trusted_hosts: list[str] | str) -> None: + def __init__(self, trusted_hosts: list[str] | str, trust_number_of_proxies: int = 0) -> None: self.always_trust: bool = trusted_hosts in ("*", ["*"]) self.trusted_literals: set[str] = set() self.trusted_hosts: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set() 
self.trusted_networks: set[ipaddress.IPv4Network | ipaddress.IPv6Network] = set() + # Always trust the N proxies nearest to the application, only apply the other + # arguments after bypassing these. + self.trust_number_of_proxies: int = trust_number_of_proxies + # Notes: # - We separate hosts from literals as there are many ways to write # an IPv6 Address so we need to compare by object. @@ -133,7 +150,15 @@ def get_trusted_client_host(self, x_forwarded_for: str) -> str: return x_forwarded_for_hosts[0] # Note: each proxy appends to the header list so check it in reverse order - for host in reversed(x_forwarded_for_hosts): + # + # If trust_number_of_proxies is set, the first trusted proxy is the peer that connected to us directly + # rather than an X-Forwarded-For entry, so we only drop the last N-1 entries from the X-Forwarded-For list. + x_forwarded_skip = self.trust_number_of_proxies - 1 + if x_forwarded_skip > len(x_forwarded_for_hosts): + return x_forwarded_for_hosts[0] + + hosts_to_check = x_forwarded_for_hosts[: len(x_forwarded_for_hosts) - x_forwarded_skip] + for host in reversed(hosts_to_check): if host not in self: return host