Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

trust_number_of_proxies to ProxyHeadersMiddleware #2508

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion tests/middleware/test_proxy_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,17 @@ async def default_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISend
def make_httpx_client(
trusted_hosts: str | list[str],
client: tuple[str, int] = ("127.0.0.1", 123),
trust_number_of_proxies: int = 0,
) -> httpx.AsyncClient:
"""Create async client for use in test cases.

Args:
trusted_hosts: trusted_hosts for proxy middleware
client: transport client to use
trust_number_of_proxies: number of proxies to trust
"""

app = ProxyHeadersMiddleware(default_app, trusted_hosts)
app = ProxyHeadersMiddleware(default_app, trusted_hosts, trust_number_of_proxies)
transport = httpx.ASGITransport(app=app, client=client) # type: ignore
return httpx.AsyncClient(transport=transport, base_url="http://testserver")

Expand Down Expand Up @@ -492,3 +494,25 @@ async def test_proxy_headers_empty_x_forwarded_for() -> None:
response = await client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "https://127.0.0.1:123"


@pytest.mark.anyio
@pytest.mark.parametrize(
("x_forwarded_for", "trust_number_of_proxies", "expected"),
[
("", 0, "http://127.0.0.1:123"),
("", 1, "http://127.0.0.1:123"),
("192.168.0.0, 192.168.0.1, 192.168.0.2", 0, "http://127.0.0.1:123"),
("192.168.0.0, 192.168.0.1, 192.168.0.2", 1, "http://192.168.0.2:0"),
("192.168.0.0, 192.168.0.1, 192.168.0.2", 2, "http://192.168.0.1:0"),
("192.168.0.0, 192.168.0.1, 192.168.0.2", 3, "http://192.168.0.0:0"),
("192.168.0.0, 192.168.0.1, 192.168.0.2", 4, "http://192.168.0.0:0"),
("192.168.0.0, 192.168.0.1, 192.168.0.2", 5, "http://192.168.0.0:0"),
],
)
async def test_trust_number_of_proxies(x_forwarded_for: str, trust_number_of_proxies: int, expected: str) -> None:
async with make_httpx_client([], trust_number_of_proxies=trust_number_of_proxies) as client:
headers = {X_FORWARDED_FOR: x_forwarded_for}
response = await client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == expected
35 changes: 30 additions & 5 deletions uvicorn/middleware/proxy_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,27 @@ class ProxyHeadersMiddleware:
Modifies the `client` and `scheme` information so that they reference
the connecting client, rather that the connecting proxy.

You can pass through a list of trusted hosts via the `trusted_hosts`
parameter, which can be either a list or single entry. Each entry can be
either a host ("127.0.0.1") or a network ("192.168.0.0/24"). An entry of
"*" means that all hosts are trusted.

Alternatively, if you know how many proxies are in front of your application
you can pass the `trust_number_of_proxies` parameter to only trust the first
N proxies. In this case you probably want to pass an empty list for
`trusted_hosts`.

References:
- <https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers#Proxies>
- <https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For>
"""

def __init__(self, app: ASGI3Application, trusted_hosts: list[str] | str = "127.0.0.1") -> None:
def __init__(
self, app: ASGI3Application, trusted_hosts: list[str] | str = "127.0.0.1", trust_number_of_proxies: int = 0
) -> None:
self.app = app
self.trusted_hosts = _TrustedHosts(trusted_hosts)
self.trust_number_of_proxies = trust_number_of_proxies
self.trusted_hosts = _TrustedHosts(trusted_hosts, trust_number_of_proxies)

async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
if scope["type"] == "lifespan":
Expand All @@ -31,7 +44,7 @@ async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGIS
client_addr = scope.get("client")
client_host = client_addr[0] if client_addr else None

if client_host in self.trusted_hosts:
if self.trust_number_of_proxies > 0 or client_host in self.trusted_hosts:
headers = dict(scope["headers"])

if b"x-forwarded-proto" in headers:
Expand Down Expand Up @@ -67,13 +80,17 @@ def _parse_raw_hosts(value: str) -> list[str]:
class _TrustedHosts:
"""Container for trusted hosts and networks"""

def __init__(self, trusted_hosts: list[str] | str) -> None:
def __init__(self, trusted_hosts: list[str] | str, trust_number_of_proxies: int = 0) -> None:
self.always_trust: bool = trusted_hosts in ("*", ["*"])

self.trusted_literals: set[str] = set()
self.trusted_hosts: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
self.trusted_networks: set[ipaddress.IPv4Network | ipaddress.IPv6Network] = set()

# Always trust the first N proxies, only apply the other arguments after
# bypassing these.
self.trust_number_of_proxies: int = trust_number_of_proxies

# Notes:
# - We separate hosts from literals as there are many ways to write
# an IPv6 Address so we need to compare by object.
Expand Down Expand Up @@ -133,7 +150,15 @@ def get_trusted_client_host(self, x_forwarded_for: str) -> str:
return x_forwarded_for_hosts[0]

# Note: each proxy appends to the header list so check it in reverse order
for host in reversed(x_forwarded_for_hosts):
#
# If we have trust_number_of_proxies set, remember that the 'first' one we are skipping is the
# original source of the request, so we actually remove N-1 proxies from the X-Forwarded-For list.
x_forwarded_skip = self.trust_number_of_proxies - 1
if x_forwarded_skip > len(x_forwarded_for_hosts):
return x_forwarded_for_hosts[0]

hosts_to_check = x_forwarded_for_hosts[: len(x_forwarded_for_hosts) - x_forwarded_skip]
for host in reversed(hosts_to_check):
if host not in self:
return host

Expand Down
Loading