Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion azure/functions/_http_wsgi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import logging
import re

from io import BytesIO, StringIO
from os import linesep
from typing import Dict, List, Optional, Any
Expand All @@ -17,6 +19,7 @@ def wsgi_encoding_dance(value):

class WsgiRequest:
_environ_cache: Optional[Dict[str, Any]] = None
_logger = logging.getLogger('azure.functions.WsgiMiddleware')

def __init__(self,
func_req: HttpRequest,
Expand Down Expand Up @@ -113,7 +116,18 @@ def to_environ(self, errors_buffer: StringIO) -> Dict[str, Any]:
def _get_port(self, parsed_url, lowercased_headers: Dict[str, str]) -> int:
port: int = 80
if lowercased_headers.get('x-forwarded-port'):
return int(lowercased_headers['x-forwarded-port'])
# Split on commas in case of multiple proxy hops
parts = [p.strip() for p in lowercased_headers['x-forwarded-port'].split(',')]

for part in parts:
# Extract leading number (port must start with digits)
match = re.match(r"(\d+)", part)
if match:
port = int(match.group(1))
return port
# If no valid port found, log a warning
self._logger.warning("Invalid X-Forwarded-Port header value: %s. "
"Using default port 80", part)
elif getattr(parsed_url, 'port', None):
return int(parsed_url.port)
elif parsed_url.scheme == 'https':
Expand Down
45 changes: 45 additions & 0 deletions tests/test_http_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,51 @@ def test_request_protocol_by_header(self):
self.assertEqual(environ['SERVER_PORT'], str(8081))
self.assertEqual(environ['wsgi.url_scheme'], 'https')

def test_request_protocol_by_header_hostlike(self):
func_request = self._generate_func_request(headers={
"x-forwarded-port": "443.example.com"
})
error_buffer = StringIO()
environ = WsgiRequest(func_request).to_environ(error_buffer)
self.assertEqual(environ['SERVER_PORT'], str(443))
self.assertEqual(environ['wsgi.url_scheme'], 'https')

def test_request_protocol_by_header_unusual_tokens(self):
func_request = self._generate_func_request(headers={
"x-forwarded-port": "443;proto=https"
})
error_buffer = StringIO()
environ = WsgiRequest(func_request).to_environ(error_buffer)
self.assertEqual(environ['SERVER_PORT'], str(443))
self.assertEqual(environ['wsgi.url_scheme'], 'https')

def test_request_protocol_by_header_with_multiple_ports(self):
func_request = self._generate_func_request(headers={
"x-forwarded-port": "443,8080,433"
})
error_buffer = StringIO()
environ = WsgiRequest(func_request).to_environ(error_buffer)
self.assertEqual(environ['SERVER_PORT'], str(443))
self.assertEqual(environ['wsgi.url_scheme'], 'https')

def test_request_protocol_by_header_with_spaces(self):
func_request = self._generate_func_request(headers={
"x-forwarded-port": " 8443 , 8080 "
})
error_buffer = StringIO()
environ = WsgiRequest(func_request).to_environ(error_buffer)
self.assertEqual(environ['SERVER_PORT'], str(8443))
self.assertEqual(environ['wsgi.url_scheme'], 'https')

def test_request_protocol_by_header_invalid(self):
func_request = self._generate_func_request(headers={
"x-forwarded-port": "abc"
})
error_buffer = StringIO()
environ = WsgiRequest(func_request).to_environ(error_buffer)
self.assertEqual(environ['SERVER_PORT'], str(80))
self.assertEqual(environ['wsgi.url_scheme'], 'https')

def test_request_protocol_by_scheme(self):
func_request = self._generate_func_request(url="http://a.b.com")
error_buffer = StringIO()
Expand Down
Loading