Skip to content

Commit

Permalink
Copying x-forwarded-functions from !gh/sdfidk/stac-fastapi.
Browse files Browse the repository at this point in the history
  • Loading branch information
elvios committed Nov 4, 2024
1 parent 9b67faa commit ddcf63f
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 48 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ RUN conda run -n webproj python -m pip install --no-deps .
RUN conda run -n webproj pyproj sync --source-id dk_sdfe --target-dir $WEBPROJ_LIB
RUN conda run -n webproj pyproj sync --source-id dk_sdfi --target-dir $WEBPROJ_LIB

CMD ["conda", "run", "-n", "webproj", "uvicorn", "--proxy-headers", "app.main:app", "--host", "0.0.0.0", "--port", "80"]
CMD ["conda", "run", "-n", "webproj", "uvicorn", "app.main:app", "--proxy-headers", "--host", "0.0.0.0", "--port", "80"]

EXPOSE 80
131 changes: 84 additions & 47 deletions webproj/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,62 +11,99 @@
Altered to accomodate x-forwarded-host instead of x-forwarded-for
Altered: 27-01-2022
"""
import re
from typing import List, Optional, Tuple, Union

from http.client import HTTP_PORT, HTTPS_PORT
from starlette.types import ASGIApp, Receive, Scope, Send

Headers = List[Tuple[bytes, bytes]]


class ProxyHeadersMiddleware:
def __init__(self, app, trusted_hosts: Union[List[str], str] = "127.0.0.1") -> None:
self.app = app
if isinstance(trusted_hosts, str):
self.trusted_hosts = {item.strip() for item in trusted_hosts.split(",")}
else:
self.trusted_hosts = set(trusted_hosts)
self.always_trust = "*" in self.trusted_hosts
class ProxyHeaderMiddleware:
"""Account for forwarding headers when deriving base URL.
def remap_headers(self, src: Headers, before: bytes, after: bytes) -> Headers:
remapped = []
before_value = None
after_value = None
for header in src:
k, v = header
if k == before:
before_value = v
continue
elif k == after:
after_value = v
continue
remapped.append(header)
if after_value:
remapped.append((before, after_value))
elif before_value:
remapped.append((before, before_value))
return remapped
Prioritise standard Forwarded header, look for non-standard X-Forwarded-* if missing.
Default to what can be derived from the URL if no headers provided. Middleware updates
the host header that is interpreted by starlette when deriving Request.base_url.
"""

async def __call__(self, scope, receive, send) -> None:
if scope["type"] in ("http", "websocket"):
def __init__(self, app: ASGIApp):
"""Create proxy header middleware."""
self.app = app

client_addr: Optional[Tuple[str, int]] = scope.get("client")
client_host = client_addr[0] if client_addr else None
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Call from stac-fastapi framework."""
if scope["type"] == "http":
proto, domain, port = self._get_forwarded_url_parts(scope)
scope["scheme"] = proto
if domain is not None:
port_suffix = ""
if port is not None:
if (proto == "http" and port != HTTP_PORT) or (
proto == "https" and port != HTTPS_PORT
):
port_suffix = f":{port}"
scope["headers"] = self._replace_header_value_by_name(
scope,
"host",
f"{domain}{port_suffix}",
)
await self.app(scope, receive, send)

def _get_forwarded_url_parts(self, scope: Scope) -> Tuple[str]:
proto = scope.get("scheme", "http")
header_host = self._get_header_value_by_name(scope, "host")
if header_host is None:
domain, port = scope.get("server")
else:
header_host_parts = header_host.split(":")
if len(header_host_parts) == 2:
domain, port = header_host_parts
else:
domain = header_host_parts[0]
port = None
forwarded = self._get_header_value_by_name(scope, "forwarded")
if forwarded is not None:
parts = forwarded.split(";")
for part in parts:
if len(part) > 0 and re.search("=", part):
key, value = part.split("=")
if key == "proto":
proto = value
elif key == "host":
host_parts = value.split(":")
domain = host_parts[0]
try:
port = int(host_parts[1]) if len(host_parts) == 2 else None
except ValueError:
# ignore ports that are not valid integers
pass
else:
proto = self._get_header_value_by_name(scope, "x-forwarded-proto", proto)
port_str = self._get_header_value_by_name(scope, "x-forwarded-port", port)
try:
port = int(port_str) if port_str is not None else None
except ValueError:
# ignore ports that are not valid integers
pass

if self.always_trust or client_host in self.trusted_hosts:
headers = dict(scope["headers"])
if b"x-forwarded-proto" in headers:
# Determine if the incoming request was http or https based on
# the X-Forwarded-Proto header.
x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1")
scope["scheme"] = x_forwarded_proto.strip() # type: ignore[index]
return (proto, domain, port)

if b"x-forwarded-host" in headers:
# Setting scope["server"] is not enough because of https://github.com/encode/starlette/issues/604#issuecomment-543945716
scope["headers"] = self.remap_headers(
scope["headers"], b"host", b"x-forwarded-host"
)
if b"x-forwarded-prefix" in headers:
x_forwarded_prefix = headers[b"x-forwarded-prefix"].decode("latin1")
scope["root_path"] = x_forwarded_prefix
def _get_header_value_by_name(
self, scope: Scope, header_name: str, default_value: str = None
) -> str:
headers = scope["headers"]
candidates = [
value.decode() for key, value in headers if key.decode() == header_name
]
return candidates[0] if len(candidates) == 1 else default_value

return await self.app(scope, receive, send)
@staticmethod
def _replace_header_value_by_name(
scope: Scope, header_name: str, new_value: str
) -> List[Tuple[str]]:
return [
(name, value)
for name, value in scope["headers"]
if name.decode() != header_name
] + [(str.encode(header_name), str.encode(new_value))]

0 comments on commit ddcf63f

Please sign in to comment.