diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index b1579a796..d746a55ca 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -82,20 +82,33 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) - def _get_forwarded_url_parts(self, scope: Scope) -> Tuple[str, str, str]: + def _get_forwarded_url_parts(self, scope: Scope) -> Tuple[str, str, int]: proto = scope.get("scheme", "http") - header_host = self._get_header_value_by_name(scope, "host") - if header_host is None: - domain, port = scope["server"] - else: + # Assume default port based on protocol, can be overridden later + port = 443 if proto == "https" else 80 + + if header_host := self._get_header_value_by_name(scope, "host"): header_host_parts = header_host.split(":") + domain = header_host_parts[0] if len(header_host_parts) == 2: - domain, port = header_host_parts - else: - domain = header_host_parts[0] - port = None - - port_str = None # make sure it is defined in all paths since we access it later + with contextlib.suppress(ValueError): + port = int(header_host_parts[1]) + else: + # Not sure when we would not have a host header, but fallback to server info + domain, port = scope["server"] + port = int(port) + + forwarded_port: Optional[str] = None + forwarding_occurred = any( + key + in [ + b"forwarded", + b"x-forwarded-proto", + b"x-forwarded-host", + b"x-forwarded-port", + ] + for key, _ in scope["headers"] + ) if forwarded := self._get_header_value_by_name(scope, "forwarded"): for proxy in forwarded.split(","): @@ -103,15 +116,21 @@ def _get_forwarded_url_parts(self, scope: Scope) -> Tuple[str, str, str]: proto = proto_expr.group("proto") if host_expr := _HOST_HEADER_REGEX.search(proxy): domain = host_expr.group("host") - port_str = host_expr.group("port") # None if not present in the match + forwarded_port = host_expr.group("port") # None if not present else: - domain = self._get_header_value_by_name(scope, "x-forwarded-host", domain) - proto = self._get_header_value_by_name(scope, "x-forwarded-proto", proto) - port_str = self._get_header_value_by_name(scope, "x-forwarded-port", port) - - with contextlib.suppress(ValueError): # ignore ports that are not valid integers - port = int(port_str) if port_str is not None else port + domain = self._get_header_value_by_name(scope, "x-forwarded-host") or domain + proto = self._get_header_value_by_name(scope, "x-forwarded-proto") or proto + forwarded_port = self._get_header_value_by_name(scope, "x-forwarded-port") + + if forwarding_occurred and not forwarded_port: + # If forwarding occurred but no port was specified, use protocol default + forwarded_port = "443" if proto == "https" else "80" + + if forwarded_port: + # ignore ports that are not valid integers + with contextlib.suppress(ValueError): + port = int(forwarded_port) return (proto, domain, port) diff --git a/stac_fastapi/api/tests/test_middleware.py b/stac_fastapi/api/tests/test_middleware.py index 4ce97ff80..a07cc57f7 100644 --- a/stac_fastapi/api/tests/test_middleware.py +++ b/stac_fastapi/api/tests/test_middleware.py @@ -75,8 +75,8 @@ def test_replace_header_value_by_name( "scope,expected", [ ( - {"scheme": "https", "server": ["testserver", 80], "headers": []}, - ("https", "testserver", 80), + {"scheme": "https", "server": ["testserver", 8000], "headers": []}, + ("https", "testserver", 8000), ), ( { @@ -92,7 +92,7 @@ def test_replace_header_value_by_name( "server": ["testserver", 80], "headers": [(b"host", b"testserver")], }, - ("http", "testserver", None), + ("http", "testserver", 80), ), ( { @@ -108,7 +108,7 @@ def test_replace_header_value_by_name( "server": ["testserver", 80], "headers": [(b"forwarded", b"proto=https;host=test:not-an-integer")], }, - ("https", "test", 80), + ("https", "test", 443), ), ( { @@ -124,7 +124,7 @@ def test_replace_header_value_by_name( "server": ["testserver", 80], "headers": [(b"x-forwarded-proto", b"https")], }, - ("https", "testserver", 80), + ("https", "testserver", 443), ), ( { @@ -191,11 +191,90 @@ def test_replace_header_value_by_name( ( b"forwarded", # proto is set, but no host - b'for="85.193.181.55";proto=https,for="85.193.181.55";proto=https', + b'for="85.193.181.55";proto=https', ) ], }, - ("https", "testserver", 80), + ("https", "testserver", 443), + ), + ( + { + "scheme": "http", + "server": ["testserver", 8080], + "headers": [ + (b"host", b"example.com:8080"), + ( + b"forwarded", + # Forwarded header with proto and host but no port + # Should use default port for https (443), not the original host + # port (8080) + b"proto=https;host=myproxy.com", + ), + ], + }, + ("https", "myproxy.com", 443), + ), + ( + { + "scheme": "http", + "server": ["testserver", 8080], + "headers": [ + (b"host", b"example.com:8080"), + ( + b"forwarded", + # Forwarded header with proto and host but no port + # Should use default port for http (80), not the original host + # port (8080) + b"proto=http;host=myproxy.com", + ), + ], + }, + ("http", "myproxy.com", 80), + ), + ( + { + "scheme": "http", + "server": ["testserver", 8080], + "headers": [ + (b"host", b"example.com:8080"), + (b"x-forwarded-proto", b"https"), + (b"x-forwarded-host", b"myproxy.com"), + # No x-forwarded-port header + # Should use default port for https (443), not the original host + # port (8080) + ], + }, + ("https", "myproxy.com", 443), + ), + ( + { + "scheme": "http", + "server": ["testserver", 8080], + "headers": [ + (b"host", b"example.com:8080"), + (b"x-forwarded-proto", b"http"), + (b"x-forwarded-host", b"myproxy.com"), + # No x-forwarded-port header + # Should use default port for http (80), not the original host port + # (8080) + ], + }, + ("http", "myproxy.com", 80), + ), + ( + { + "scheme": "http", + "server": ["testserver", 80], + "headers": [ + (b"host", b"testserver"), + (b"x-forwarded-proto", b"https"), + # No x-forwarded-host (domain stays the same) + # No x-forwarded-port header + # Should use default port for https (443), not the original port (80) + # because the protocol changed + ], + }, + ("https", "testserver", 443), ), ], )