diff --git a/examples/pubsub_eventing.py b/examples/pubsub_eventing.py index 3e247c38eb..d5bcd12f67 100644 --- a/examples/pubsub_eventing.py +++ b/examples/pubsub_eventing.py @@ -9,22 +9,17 @@ :license: BSD, see LICENSE for more details. """ import time -import threading import multiprocessing import logging from typing import Dict, Any -from proxy.core.event import EventQueue, EventSubscriber, EventDispatcher, eventNames +from proxy.core.event import EventQueue, EventSubscriber, eventNames +from proxy.core.event.manager import EventManager # Enable debug logging to view core event logs logging.basicConfig(level=logging.DEBUG) -# Eventing requires a multiprocess safe queue -# so that events can be safely published and received -# between processes. -manager = multiprocessing.Manager() - main_publisher_request_id = '1234' process_publisher_request_id = '12345' num_events_received = [0, 0] @@ -59,17 +54,13 @@ def on_event(payload: Dict[str, Any]) -> None: if __name__ == '__main__': start_time = time.time() - # Start dispatcher thread - dispatcher_queue = EventQueue(manager.Queue()) - dispatcher_shutdown_event = threading.Event() - dispatcher = EventDispatcher( - shutdown=dispatcher_shutdown_event, - event_queue=dispatcher_queue) - dispatcher_thread = threading.Thread(target=dispatcher.run) - dispatcher_thread.start() + # Start dispatcher thread using EventManager + event_manager = EventManager() + event_manager.start_event_dispatcher() + assert event_manager.event_queue # Create a subscriber - subscriber = EventSubscriber(dispatcher_queue) + subscriber = EventSubscriber(event_manager.event_queue) # Internally, subscribe will start a separate thread # to receive incoming published messages subscriber.subscribe(on_event) @@ -79,13 +70,13 @@ def on_event(payload: Dict[str, Any]) -> None: publisher_shutdown_event = multiprocessing.Event() publisher = multiprocessing.Process( target=publisher_process, args=( - publisher_shutdown_event, dispatcher_queue, )) + publisher_shutdown_event, event_manager.event_queue, )) publisher.start() try: while True: # Dispatch event from main process - dispatcher_queue.publish( + event_manager.event_queue.publish( request_id=main_publisher_request_id, event_name=eventNames.WORK_STARTED, event_payload={'time': time.time()}, @@ -100,8 +91,6 @@ def on_event(payload: Dict[str, Any]) -> None: # Stop subscriber thread subscriber.unsubscribe() # Signal dispatcher to shutdown - dispatcher_shutdown_event.set() - # Wait for dispatcher shutdown - dispatcher_thread.join() + event_manager.stop_event_dispatcher() print('Received {0} events from main thread, {1} events from another process, in {2} seconds'.format( num_events_received[0], num_events_received[1], time.time() - start_time)) diff --git a/proxy/common/constants.py b/proxy/common/constants.py index 21e3955089..1e520cbf71 100644 --- a/proxy/common/constants.py +++ b/proxy/common/constants.py @@ -60,8 +60,15 @@ DEFAULT_IPV6_HOSTNAME = ipaddress.IPv6Address('::1') DEFAULT_KEY_FILE = None DEFAULT_LOG_FILE = None -DEFAULT_LOG_FORMAT = '%(asctime)s - pid:%(process)d [%(levelname)-.1s] %(funcName)s:%(lineno)d - %(message)s' +DEFAULT_LOG_FORMAT = '%(asctime)s - pid:%(process)d [%(levelname)-.1s] %(filename)s:%(funcName)s:%(lineno)d - %(message)s' DEFAULT_LOG_LEVEL = 'INFO' +DEFAULT_HTTP_ACCESS_LOG_FORMAT = '{client_ip}:{client_port} - ' + \ + '{request_method} {server_host}:{server_port}{request_path} - ' + \ + '{response_code} {response_reason} - {response_bytes} bytes - ' + \ + '{connection_time_ms} ms' 
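For reference, a minimal sketch (made-up values) of how the new access log format string above renders; the keys mirror the context dict that HttpProxyPlugin builds in on_client_connection_close later in this diff:

```python
from proxy.common.constants import DEFAULT_HTTP_ACCESS_LOG_FORMAT

# Illustrative context only; at runtime the core fills these from the live connection.
sample_context = {
    'client_ip': '127.0.0.1', 'client_port': 50404,
    'request_method': 'GET', 'server_host': 'example.org', 'server_port': 80,
    'request_path': '/', 'response_code': '200', 'response_reason': 'OK',
    'response_bytes': 1256, 'connection_time_ms': '23.01',
}
print(DEFAULT_HTTP_ACCESS_LOG_FORMAT.format_map(sample_context))
# 127.0.0.1:50404 - GET example.org:80/ - 200 OK - 1256 bytes - 23.01 ms
```

The HTTPS variant of the format string follows next in the diff.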
+DEFAULT_HTTPS_ACCESS_LOG_FORMAT = '{client_ip}:{client_port} - ' + \ + '{request_method} {server_host}:{server_port} - ' + \ + '{response_bytes} bytes - {connection_time_ms} ms' DEFAULT_NUM_WORKERS = 0 DEFAULT_OPEN_FILE_LIMIT = 1024 DEFAULT_PAC_FILE = None diff --git a/proxy/common/utils.py b/proxy/common/utils.py index 203a17aee2..b894da5d08 100644 --- a/proxy/common/utils.py +++ b/proxy/common/utils.py @@ -9,15 +9,17 @@ :license: BSD, see LICENSE for more details. """ import ssl -import contextlib +import socket +import logging import functools import ipaddress -import socket +import contextlib from types import TracebackType from typing import Optional, Dict, Any, List, Tuple, Type, Callable from .constants import HTTP_1_1, COLON, WHITESPACE, CRLF, DEFAULT_TIMEOUT +from .constants import DEFAULT_LOG_FILE, DEFAULT_LOG_FORMAT, DEFAULT_LOG_LEVEL def text_(s: Any, encoding: str = 'utf-8', errors: str = 'strict') -> Any: @@ -89,14 +91,14 @@ def build_http_pkt(line: List[bytes], headers: Optional[Dict[bytes, bytes]] = None, body: Optional[bytes] = None) -> bytes: """Build and returns a HTTP request or response packet.""" - req = WHITESPACE.join(line) + CRLF + pkt = WHITESPACE.join(line) + CRLF if headers is not None: for k in headers: - req += build_http_header(k, headers[k]) + CRLF - req += CRLF + pkt += build_http_header(k, headers[k]) + CRLF + pkt += CRLF if body: - req += body - return req + pkt += body + return pkt def build_websocket_handshake_request( @@ -226,3 +228,24 @@ def get_available_port() -> int: sock.bind(('', 0)) _, port = sock.getsockname() return int(port) + + +def setup_logger( + log_file: Optional[str] = DEFAULT_LOG_FILE, + log_level: str = DEFAULT_LOG_LEVEL, + log_format: str = DEFAULT_LOG_FORMAT) -> None: + ll = getattr( + logging, + {'D': 'DEBUG', + 'I': 'INFO', + 'W': 'WARNING', + 'E': 'ERROR', + 'C': 'CRITICAL'}[log_level.upper()[0]]) + if log_file: + logging.basicConfig( + filename=log_file, + filemode='a', + level=ll, + format=log_format) + else: + logging.basicConfig(level=ll, format=log_format) diff --git a/proxy/core/acceptor/acceptor.py b/proxy/core/acceptor/acceptor.py index 3a9ca61828..411bc2ad2d 100644 --- a/proxy/core/acceptor/acceptor.py +++ b/proxy/core/acceptor/acceptor.py @@ -27,6 +27,7 @@ from ..event import EventQueue, eventNames from ...common.constants import DEFAULT_THREADLESS from ...common.flag import flags +from ...common.utils import setup_logger logger = logging.getLogger(__name__) @@ -133,6 +134,8 @@ def run_once(self) -> None: self.start_work(conn, addr) def run(self) -> None: + setup_logger(self.flags.log_file, self.flags.log_level, + self.flags.log_format) self.selector = selectors.DefaultSelector() fileno = recv_handle(self.work_queue) self.work_queue.close() diff --git a/proxy/core/acceptor/pool.py b/proxy/core/acceptor/pool.py index 358619760e..2152ab1f28 100644 --- a/proxy/core/acceptor/pool.py +++ b/proxy/core/acceptor/pool.py @@ -12,7 +12,7 @@ import logging import multiprocessing import socket -import threading + from multiprocessing import connection from multiprocessing.reduction import send_handle from typing import List, Optional, Type @@ -20,10 +20,10 @@ from .acceptor import Acceptor from .work import Work -from ..event import EventQueue, EventDispatcher +from ..event import EventQueue + from ...common.flag import flags -from ...common.constants import DEFAULT_BACKLOG, DEFAULT_ENABLE_EVENTS -from ...common.constants import DEFAULT_IPV6_HOSTNAME, DEFAULT_NUM_WORKERS, DEFAULT_PORT +from ...common.constants import 
DEFAULT_BACKLOG, DEFAULT_IPV6_HOSTNAME, DEFAULT_NUM_WORKERS, DEFAULT_PORT logger = logging.getLogger(__name__) @@ -37,14 +37,6 @@ default=DEFAULT_BACKLOG, help='Default: 100. Maximum number of pending connections to proxy server') -flags.add_argument( - '--enable-events', - action='store_true', - default=DEFAULT_ENABLE_EVENTS, - help='Default: False. Enables core to dispatch lifecycle events. ' - 'Plugins can be used to subscribe for core events.' -) - flags.add_argument( '--hostname', type=str, @@ -79,31 +71,16 @@ class AcceptorPool: pool.shutdown() `work_klass` must implement `work.Work` class. - - Optionally, AcceptorPool also initialize a global event queue. - It is a multiprocess safe queue which can be used to build pubsub patterns - for message sharing or signaling. - - TODO(abhinavsingh): Decouple event queue setup & teardown into its own class. """ def __init__(self, flags: argparse.Namespace, - work_klass: Type[Work]) -> None: + work_klass: Type[Work], event_queue: Optional[EventQueue] = None) -> None: self.flags = flags self.socket: Optional[socket.socket] = None self.acceptors: List[Acceptor] = [] self.work_queues: List[connection.Connection] = [] self.work_klass = work_klass - - self.event_queue: Optional[EventQueue] = None - self.event_dispatcher: Optional[EventDispatcher] = None - self.event_dispatcher_thread: Optional[threading.Thread] = None - self.event_dispatcher_shutdown: Optional[threading.Event] = None - self.manager: Optional[multiprocessing.managers.SyncManager] = None - - if self.flags.enable_events: - self.manager = multiprocessing.Manager() - self.event_queue = EventQueue(self.manager.Queue()) + self.event_queue: Optional[EventQueue] = event_queue def listen(self) -> None: self.socket = socket.socket(self.flags.family, socket.SOCK_STREAM) @@ -137,32 +114,10 @@ def start_workers(self) -> None: self.work_queues.append(work_queue[0]) logger.info('Started %d workers' % self.flags.num_workers) - def start_event_dispatcher(self) -> None: - self.event_dispatcher_shutdown = threading.Event() - assert self.event_dispatcher_shutdown - assert self.event_queue - self.event_dispatcher = EventDispatcher( - shutdown=self.event_dispatcher_shutdown, - event_queue=self.event_queue - ) - self.event_dispatcher_thread = threading.Thread( - target=self.event_dispatcher.run - ) - self.event_dispatcher_thread.start() - logger.debug('Thread ID: %d', self.event_dispatcher_thread.ident) - def shutdown(self) -> None: logger.info('Shutting down %d workers' % self.flags.num_workers) for acceptor in self.acceptors: acceptor.running.set() - if self.flags.enable_events: - assert self.event_dispatcher_shutdown - assert self.event_dispatcher_thread - self.event_dispatcher_shutdown.set() - self.event_dispatcher_thread.join() - logger.debug( - 'Shutdown of global event dispatcher thread %d successful', - self.event_dispatcher_thread.ident) for acceptor in self.acceptors: acceptor.join() logger.debug('Acceptors shutdown') @@ -170,9 +125,6 @@ def shutdown(self) -> None: def setup(self) -> None: """Listen on port, setup workers and pass server socket to workers.""" self.listen() - if self.flags.enable_events: - logger.info('Core Event enabled') - self.start_event_dispatcher() self.start_workers() # Send server socket to all acceptor processes. 
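Since AcceptorPool no longer owns the dispatcher lifecycle, here is a minimal wiring sketch (assuming default flags) of how an externally created event queue is handed to the pool; it mirrors what Proxy.__enter__ does later in this diff:

```python
from proxy.proxy import Proxy
from proxy.core.acceptor import AcceptorPool
from proxy.core.event import EventManager
from proxy.http.handler import HttpProtocolHandler

flags = Proxy.initialize([])           # default flags; accepts an argv-style list
event_manager = EventManager()
event_manager.start_event_dispatcher()
try:
    pool = AcceptorPool(
        flags=flags,
        work_klass=HttpProtocolHandler,
        event_queue=event_manager.event_queue,
    )
    pool.setup()
    # ... serve traffic ...
    pool.shutdown()
finally:
    event_manager.stop_event_dispatcher()
```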
assert self.socket is not None diff --git a/proxy/core/acceptor/threadless.py b/proxy/core/acceptor/threadless.py index 6cfc99e97f..0b5d40b3d8 100644 --- a/proxy/core/acceptor/threadless.py +++ b/proxy/core/acceptor/threadless.py @@ -26,6 +26,7 @@ from ..connection import TcpClientConnection from ..event import EventQueue, eventNames +from ...common.utils import setup_logger from ...common.types import Readables, Writables from ...common.constants import DEFAULT_TIMEOUT @@ -179,6 +180,8 @@ def run_once(self) -> None: self.cleanup_inactive() def run(self) -> None: + setup_logger(self.flags.log_file, self.flags.log_level, + self.flags.log_format) try: self.selector = selectors.DefaultSelector() self.selector.register(self.client_queue, selectors.EVENT_READ) diff --git a/proxy/core/connection/server.py b/proxy/core/connection/server.py index c5636e6a92..b19d8e4734 100644 --- a/proxy/core/connection/server.py +++ b/proxy/core/connection/server.py @@ -10,9 +10,11 @@ """ import ssl import socket + from typing import Optional, Union, Tuple from .connection import TcpConnection, tcpConnectionTypes, TcpConnectionUninitializedException + from ...common.utils import new_socket_connection @@ -23,6 +25,7 @@ def __init__(self, host: str, port: int): super().__init__(tcpConnectionTypes.SERVER) self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None self.addr: Tuple[str, int] = (host, int(port)) + self.closed = True @property def connection(self) -> Union[ssl.SSLSocket, socket.socket]: @@ -31,9 +34,9 @@ def connection(self) -> Union[ssl.SSLSocket, socket.socket]: return self._conn def connect(self) -> None: - if self._conn is not None: - return - self._conn = new_socket_connection(self.addr) + if self._conn is None: + self._conn = new_socket_connection(self.addr) + self.closed = False def wrap(self, hostname: str, ca_file: Optional[str]) -> None: ctx = ssl.create_default_context( diff --git a/proxy/core/event/__init__.py b/proxy/core/event/__init__.py index 6907dcd55b..17e1074e6e 100644 --- a/proxy/core/event/__init__.py +++ b/proxy/core/event/__init__.py @@ -12,6 +12,7 @@ from .names import EventNames, eventNames from .dispatcher import EventDispatcher from .subscriber import EventSubscriber +from .manager import EventManager __all__ = [ 'eventNames', @@ -19,4 +20,5 @@ 'EventQueue', 'EventDispatcher', 'EventSubscriber', + 'EventManager', ] diff --git a/proxy/core/event/dispatcher.py b/proxy/core/event/dispatcher.py index fb7c527533..80cb13043d 100644 --- a/proxy/core/event/dispatcher.py +++ b/proxy/core/event/dispatcher.py @@ -35,7 +35,7 @@ class EventDispatcher: module is not-recommended. Python native multiprocessing queue doesn't provide a fanout functionality which core dispatcher module implements so that several plugins can consume same published - event at a time. + event concurrently. When --enable-events is used, a multiprocessing.Queue is created and attached to global argparse. This queue can then be used for diff --git a/proxy/core/event/manager.py b/proxy/core/event/manager.py new file mode 100644 index 0000000000..74e10b033a --- /dev/null +++ b/proxy/core/event/manager.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. 
+""" +import logging +import threading +import multiprocessing + +from typing import Optional + +from .queue import EventQueue +from .dispatcher import EventDispatcher + +logger = logging.getLogger(__name__) + + +class EventManager: + """Event manager is an encapsulation around various initialization, dispatcher + start / stop API required for end-to-end eventing. + """ + + def __init__(self) -> None: + self.event_queue: Optional[EventQueue] = None + self.event_dispatcher: Optional[EventDispatcher] = None + self.event_dispatcher_thread: Optional[threading.Thread] = None + self.event_dispatcher_shutdown: Optional[threading.Event] = None + self.manager: Optional[multiprocessing.managers.SyncManager] = None + + def start_event_dispatcher(self) -> None: + self.manager = multiprocessing.Manager() + self.event_queue = EventQueue(self.manager.Queue()) + self.event_dispatcher_shutdown = threading.Event() + assert self.event_dispatcher_shutdown + assert self.event_queue + self.event_dispatcher = EventDispatcher( + shutdown=self.event_dispatcher_shutdown, + event_queue=self.event_queue + ) + self.event_dispatcher_thread = threading.Thread( + target=self.event_dispatcher.run + ) + self.event_dispatcher_thread.start() + logger.debug('Thread ID: %d', self.event_dispatcher_thread.ident) + + def stop_event_dispatcher(self) -> None: + assert self.event_dispatcher_shutdown + assert self.event_dispatcher_thread + self.event_dispatcher_shutdown.set() + self.event_dispatcher_thread.join() + logger.debug( + 'Shutdown of global event dispatcher thread %d successful', + self.event_dispatcher_thread.ident) diff --git a/proxy/core/event/names.py b/proxy/core/event/names.py index b45a70b2d5..e56f5626b9 100644 --- a/proxy/core/event/names.py +++ b/proxy/core/event/names.py @@ -10,6 +10,9 @@ """ from typing import NamedTuple +# Name of the events that eventing framework will support +# Ideally this must be configurable via command line or +# at-least extendable via plugins. EventNames = NamedTuple('EventNames', [ ('SUBSCRIBE', int), ('UNSUBSCRIBE', int), diff --git a/proxy/core/event/queue.py b/proxy/core/event/queue.py index 36b246648d..b4e6ab9615 100644 --- a/proxy/core/event/queue.py +++ b/proxy/core/event/queue.py @@ -19,23 +19,22 @@ class EventQueue: - """Global event queue. + """Global event queue. Ideally the queue must come from multiprocessing.Manager, + specially if you intent to publish/subscribe from multiple processes. - Each event contains: - - 1. Request ID - Globally unique - 2. Process ID - Process ID of event publisher. - This will be process id of acceptor workers. - 3. Thread ID - Thread ID of event publisher. - When --threadless is enabled, this value will - be same for all the requests - received by a single acceptor worker. - When --threadless is disabled, this value will be - Thread ID of the thread handling the client request. - 4. Event Timestamp - Time when this event occur - 5. Event Name - One of the defined or custom event name - 6. Event Payload - Optional data associated with the event - 7. Publisher ID (optional) - Optionally, publishing entity unique name / ID + Each published event contains following schema: + { + 'request_id': 'Globally unique request ID', + 'process_id': 'Process ID of event publisher. ' + 'This will be the process ID of acceptor workers.', + 'thread_id': 'Thread ID of event publisher. ' + 'When --threadless is enabled, this value will be ' + 'same for all the requests.' 
+ 'event_timestamp': 'Time when this event occured', + 'event_name': 'one of the pre-defined or custom event name', + 'event_payload': 'Optional data associated with the event', + 'publisher_id': 'Optional publisher entity unique name', + } """ def __init__(self, queue: DictQueueType) -> None: diff --git a/proxy/http/handler.py b/proxy/http/handler.py index bff6f32829..d4f82b9a5b 100644 --- a/proxy/http/handler.py +++ b/proxy/http/handler.py @@ -273,17 +273,15 @@ def handle_readables(self, readables: Readables) -> bool: try: # HttpProtocolHandlerPlugin.on_client_data # Can raise HttpProtocolException to teardown the connection - plugin_index = 0 - plugins = list(self.plugins.values()) - while plugin_index < len(plugins) and client_data: - client_data = plugins[plugin_index].on_client_data( - client_data) + for plugin in self.plugins.values(): + client_data = plugin.on_client_data(client_data) if client_data is None: break - plugin_index += 1 - # Don't parse request any further after 1st request has completed. + # Don't parse incoming data any further after 1st request has completed. + # # This specially does happen for pipeline requests. + # # Plugins can utilize on_client_data for such cases and # apply custom logic to handle request data sent after 1st # valid request. diff --git a/proxy/http/parser.py b/proxy/http/parser.py index 341b1599e0..368e11ab87 100644 --- a/proxy/http/parser.py +++ b/proxy/http/parser.py @@ -14,7 +14,7 @@ from .methods import httpMethods from .chunk_parser import ChunkParser, chunkParserStates -from ..common.constants import DEFAULT_DISABLE_HEADERS, COLON, CRLF, WHITESPACE, HTTP_1_1, DEFAULT_HTTP_PORT +from ..common.constants import DEFAULT_DISABLE_HEADERS, COLON, SLASH, CRLF, WHITESPACE, HTTP_1_1, DEFAULT_HTTP_PORT from ..common.utils import build_http_request, build_http_response, find_http_line, text_ @@ -237,7 +237,7 @@ def build_path(self) -> bytes: url += b'#' + self.url.fragment return url - def build(self, disable_headers: Optional[List[bytes]] = None) -> bytes: + def build(self, disable_headers: Optional[List[bytes]] = None, for_proxy: bool = False) -> bytes: """Rebuild the request object.""" assert self.method and self.version and self.path and self.type == httpParserTypes.REQUEST_PARSER if disable_headers is None: @@ -245,8 +245,20 @@ def build(self, disable_headers: Optional[List[bytes]] = None) -> bytes: body: Optional[bytes] = ChunkParser.to_chunks(self.body) \ if self.is_chunked_encoded() and self.body else \ self.body + path = self.path + if for_proxy: + assert self.url and self.host and self.port and self.path + path = ( + self.url.scheme + + COLON + SLASH + SLASH + + self.host + + COLON + + str(self.port).encode() + + self.path + ) if self.method != httpMethods.CONNECT else (self.host + COLON + str(self.port).encode()) + return build_http_request( - self.method, self.path, self.version, + self.method, path, self.version, headers={} if not self.headers else {self.headers[k][0]: self.headers[k][1] for k in self.headers if k.lower() not in disable_headers}, body=body @@ -263,7 +275,7 @@ def build_response(self) -> bytes: self.headers[k][0]: self.headers[k][1] for k in self.headers}, body=self.body if not self.is_chunked_encoded() else ChunkParser.to_chunks(self.body)) - def has_upstream_server(self) -> bool: + def has_host(self) -> bool: """Host field SHOULD be None for incoming local WebServer requests.""" return self.host is not None diff --git a/proxy/http/proxy/auth.py b/proxy/http/proxy/auth.py index d1ac9a8611..24d00f2aa9 100644 --- 
a/proxy/http/proxy/auth.py +++ b/proxy/http/proxy/auth.py @@ -45,7 +45,7 @@ def handle_client_request( return request def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: - return chunk + return chunk # pragma: no cover def on_upstream_connection_close(self) -> None: - pass + pass # pragma: no cover diff --git a/proxy/http/proxy/plugin.py b/proxy/http/proxy/plugin.py index b439ac7f2a..aa679c2314 100644 --- a/proxy/http/proxy/plugin.py +++ b/proxy/http/proxy/plugin.py @@ -12,7 +12,7 @@ import argparse from uuid import UUID -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from abc import ABC, abstractmethod from ..parser import HttpParser @@ -86,6 +86,22 @@ def before_upstream_connection( Raise HttpRequestRejected or HttpProtocolException directly to drop the connection.""" return request # pragma: no cover + # Since 3.4.0 + # + # @abstractmethod + def handle_client_data( + self, raw: memoryview) -> Optional[memoryview]: + """Handler called in special scenarios when an upstream server connection + is never established. + + Essentially, if you return None from within before_upstream_connection, + be prepared to handle raw client data via handle_client_data instead of handle_client_request. + + Raise HttpRequestRejected to tear down the connection. + Return None to drop the connection. + """ + return raw # pragma: no cover + @abstractmethod def handle_client_request( self, request: HttpParser) -> Optional[HttpParser]: @@ -117,3 +133,17 @@ def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: def on_upstream_connection_close(self) -> None: """Handler called right after upstream connection has been closed.""" pass # pragma: no cover + + def on_access_log(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Use this method to override the default access log format (see + DEFAULT_HTTP_ACCESS_LOG_FORMAT and DEFAULT_HTTPS_ACCESS_LOG_FORMAT) and to + add/update/modify/delete context for the next plugin.on_access_log invocation. + + This is especially useful if a plugin wants to provide extra context + in the access log which may not be available within other plugins' context or even + in proxy.py core. + + Returns the log context or None. If a plugin chooses to write the access log itself, it + must ideally return None to prevent further plugin.on_access_log invocations. + """ + return context diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index 92a0253387..5b26b17163 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -8,14 +8,15 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details.
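To make the two new hooks concrete, a minimal plugin sketch (illustrative, not part of this PR) that implements the full HttpProxyBasePlugin surface and simply annotates the access log context:

```python
from typing import Any, Dict, Optional

from proxy.http.parser import HttpParser
from proxy.http.proxy import HttpProxyBasePlugin


class AccessLogTaggerPlugin(HttpProxyBasePlugin):
    """Illustrative plugin exercising the hooks introduced in this diff."""

    def before_upstream_connection(self, request: HttpParser) -> Optional[HttpParser]:
        # Returning the request keeps default behaviour: core connects upstream.
        # Returning None would skip the upstream connection, after which raw client
        # data arrives via handle_client_data instead of handle_client_request.
        return request

    def handle_client_data(self, raw: memoryview) -> Optional[memoryview]:
        return raw

    def handle_client_request(self, request: HttpParser) -> Optional[HttpParser]:
        return request

    def handle_upstream_chunk(self, chunk: memoryview) -> memoryview:
        return chunk

    def on_upstream_connection_close(self) -> None:
        pass

    def on_access_log(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        # Extra keys are visible to subsequent plugins' on_access_log; returning
        # None here instead would suppress core access logging entirely.
        context['tag'] = 'tagged-by-plugin'
        return context
```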
""" -import logging -import threading -import subprocess import os import ssl -import socket import time import errno +import socket +import logging +import threading +import subprocess + from typing import Optional, List, Union, Dict, cast, Any, Tuple from .plugin import HttpProxyBasePlugin @@ -30,6 +31,7 @@ from ...common.constants import DEFAULT_CA_KEY_FILE, DEFAULT_CA_SIGNING_KEY_FILE from ...common.constants import COMMA, DEFAULT_SERVER_RECVBUF_SIZE, DEFAULT_CERT_FILE from ...common.constants import PROXY_AGENT_HEADER_VALUE, DEFAULT_DISABLE_HEADERS +from ...common.constants import DEFAULT_HTTP_ACCESS_LOG_FORMAT, DEFAULT_HTTPS_ACCESS_LOG_FORMAT from ...common.utils import build_http_response, text_ from ...common.pki import gen_public_key, gen_csr, sign_csr @@ -137,7 +139,7 @@ def tls_interception_enabled(self) -> bool: def get_descriptors( self) -> Tuple[List[socket.socket], List[socket.socket]]: - if not self.request.has_upstream_server(): + if not self.request.has_host(): return [], [] r: List[socket.socket] = [] @@ -159,13 +161,15 @@ def get_descriptors( return r, w def write_to_descriptors(self, w: Writables) -> bool: - if self.server and self.server.connection not in w: + if (self.server and self.server.connection not in w) or not self.server: # Currently, we just call write/read block of each plugins. It is # plugins responsibility to ignore this callback, if passed descriptors # doesn't contain the descriptor they registered. for plugin in self.plugins.values(): - plugin.write_to_descriptors(w) - elif self.request.has_upstream_server() and \ + teardown = plugin.write_to_descriptors(w) + if teardown: + return True + elif self.request.has_host() and \ self.server and not self.server.closed and \ self.server.has_buffer() and \ self.server.connection in w: @@ -187,13 +191,15 @@ def write_to_descriptors(self, w: Writables) -> bool: return False def read_from_descriptors(self, r: Readables) -> bool: - if self.server and self.server.connection not in r: + if (self.server and self.server.connection not in r) or not self.server: # Currently, we just call write/read block of each plugins. It is # plugins responsibility to ignore this callback, if passed descriptors - # doesn't contain the descriptor they registered. + # doesn't contain the descriptor they registered for. 
for plugin in self.plugins.values(): - plugin.write_to_descriptors(r) - elif self.request.has_upstream_server() \ + teardown = plugin.read_from_descriptors(r) + if teardown: + return True + elif self.request.has_host() \ and self.server \ and not self.server.closed \ and self.server.connection in r: @@ -253,21 +259,49 @@ def read_from_descriptors(self, r: Readables) -> bool: return False def on_client_connection_close(self) -> None: - if not self.request.has_upstream_server(): + if not self.request.has_host(): return - self.access_log() - - # If server was never initialized, return - if self.server is None: - return + context = { + 'client_ip': self.client.addr[0], + 'client_port': self.client.addr[1], + 'request_method': text_(self.request.method), + 'request_path': text_(self.request.path), + 'server_host': text_(self.server.addr[0] if self.server else None), + 'server_port': text_(self.server.addr[1] if self.server else None), + 'response_bytes': self.response.total_size, + 'connection_time_ms': '%.2f' % ((time.time() - self.start_time) * 1000), + 'response_code': text_(self.response.code), + 'response_reason': text_(self.response.reason), + } + log_handled = False + for plugin in self.plugins.values(): + ctx = plugin.on_access_log(context) + if ctx is None: + log_handled = True + break + context = ctx + if not log_handled: + self.access_log(context) # Note that, server instance was initialized # but not necessarily the connection object exists. + # + # Unfortunately this is still being called when an upstream + # server connection was never established. This is done currently + # to assist proxy pool plugin to close its upstream proxy connections. + # + # In short, treat on_upstream_connection_close as on_client_connection_close + # equivalent within proxy plugins. + # # Invoke plugin.on_upstream_connection_close for plugin in self.plugins.values(): plugin.on_upstream_connection_close() + # If server was never initialized, return + if self.server is None: + return + try: try: self.server.connection.shutdown(socket.SHUT_WR) @@ -283,6 +317,12 @@ def on_client_connection_close(self) -> None: 'Closed server connection, has buffer %s' % self.server.has_buffer()) + def access_log(self, log_attrs: Dict[str, Any]) -> None: + access_log_format = DEFAULT_HTTPS_ACCESS_LOG_FORMAT + if self.request.method != httpMethods.CONNECT: + access_log_format = DEFAULT_HTTP_ACCESS_LOG_FORMAT + logger.info(access_log_format.format_map(log_attrs)) + def on_response_chunk(self, chunk: List[memoryview]) -> List[memoryview]: # TODO: Allow to output multiple access_log lines # for each request over a pipelined HTTP connection (not for HTTPS). @@ -294,11 +334,31 @@ def on_response_chunk(self, chunk: List[memoryview]) -> List[memoryview]: # self.access_log() return chunk + # Can return None to teardown connection def on_client_data(self, raw: memoryview) -> Optional[memoryview]: - if not self.request.has_upstream_server(): + if not self.request.has_host(): return raw - if self.server and not self.server.closed: + # For scenarios when an upstream connection was never established, + # let plugin do whatever they wish to. These are special scenarios + # where plugins are trying to do something magical. Within the core + # we don't know the context. Infact, we are not even sure if data + # exchanged is http spec compliant. + # + # Hence, here we pass raw data to HTTP proxy plugins as is. 
+ # + # We only call handle_client_data once the original request has been + # completely received + if not self.server: + for plugin in self.plugins.values(): + o = plugin.handle_client_data(raw) + if o is None: + return None + raw = o + elif self.server and not self.server.closed: + # For http proxy requests, handle the pipeline case. + # We also handle the pipeline scenario for https proxy + # requests if TLS interception is enabled. if self.request.state == httpParserStates.COMPLETE and ( self.request.method != httpMethods.CONNECT or self.tls_interception_enabled()): @@ -332,19 +392,26 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: self.pipeline_request.build())) if not self.pipeline_request.is_connection_upgrade(): self.pipeline_request = None + # For scenarios where we cannot peek into the data, + # simply queue it for the upstream server. else: self.server.queue(raw) return None return raw def on_request_complete(self) -> Union[socket.socket, bool]: - if not self.request.has_upstream_server(): + if not self.request.has_host(): return False self.emit_request_complete() - # Note: can raise HttpRequestRejected exception # Invoke plugin.before_upstream_connection + # + # before_upstream_connection can: + # 1) Raise HttpRequestRejected exception to reject the connection + # 2) Return None to continue without establishing an upstream server connection, + # e.g. for scenarios when plugins want to return a response from cache or + # via an out-of-band network request. do_connect = True for plugin in self.plugins.values(): r = plugin.before_upstream_connection(self.request) @@ -353,9 +420,11 @@ def on_request_complete(self) -> Union[socket.socket, bool]: break self.request = r + # Connect to upstream if do_connect: self.connect_upstream() + # Invoke plugin.handle_client_request for plugin in self.plugins.values(): assert self.request is not None r = plugin.handle_client_request(self.request) @@ -364,30 +433,35 @@ def on_request_complete(self) -> Union[socket.socket, bool]: else: return False - if self.request.method == httpMethods.CONNECT: - self.client.queue( - HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) - if self.tls_interception_enabled(): - return self.intercept() - elif self.server: - # - proxy-connection header is a mistake, it doesn't seem to be - # officially documented in any specification, drop it. - # - proxy-authorization is of no use for upstream, remove it. - self.request.del_headers( - [b'proxy-authorization', b'proxy-connection']) - # - For HTTP/1.0, connection header defaults to close - # - For HTTP/1.1, connection header defaults to keep-alive - # Respect headers sent by client instead of manipulating - # Connection or Keep-Alive header. However, note that per - # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Connection - # connection headers are meant for communication between client and - # first intercepting proxy. - self.request.add_headers( - [(b'Via', b'1.1 %s' % PROXY_AGENT_HEADER_VALUE)]) - # Disable args.disable_headers before dispatching to upstream - self.server.queue( - memoryview(self.request.build( - disable_headers=self.flags.disable_headers))) + # For https requests, respond back with a tunnel established response. + # Optionally, set up the interceptor if TLS interception is enabled.
+ if self.server: + if self.request.method == httpMethods.CONNECT: + self.client.queue( + HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) + if self.tls_interception_enabled(): + return self.intercept() + # If an upstream server connection was established for http request, + # queue the request for upstream server. + else: + # - proxy-connection header is a mistake, it doesn't seem to be + # officially documented in any specification, drop it. + # - proxy-authorization is of no use for upstream, remove it. + self.request.del_headers( + [b'proxy-authorization', b'proxy-connection']) + # - For HTTP/1.0, connection header defaults to close + # - For HTTP/1.1, connection header defaults to keep-alive + # Respect headers sent by client instead of manipulating + # Connection or Keep-Alive header. However, note that per + # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Connection + # connection headers are meant for communication between client and + # first intercepting proxy. + self.request.add_headers( + [(b'Via', b'1.1 %s' % PROXY_AGENT_HEADER_VALUE)]) + # Disable args.disable_headers before dispatching to upstream + self.server.queue( + memoryview(self.request.build( + disable_headers=self.flags.disable_headers))) return False def handle_pipeline_response(self, raw: memoryview) -> None: @@ -400,32 +474,6 @@ def handle_pipeline_response(self, raw: memoryview) -> None: if self.pipeline_response.state == httpParserStates.COMPLETE: self.pipeline_response = None - def access_log(self) -> None: - server_host, server_port = self.server.addr if self.server else ( - None, None) - connection_time_ms = (time.time() - self.start_time) * 1000 - if self.request.method == httpMethods.CONNECT: - logger.info( - '%s:%s - %s %s:%s - %s bytes - %.2f ms' % - (self.client.addr[0], - self.client.addr[1], - text_(self.request.method), - text_(server_host), - text_(server_port), - self.response.total_size, - connection_time_ms)) - elif self.request.method: - logger.info( - '%s:%s - %s %s:%s%s - %s %s - %s bytes - %.2f ms' % - (self.client.addr[0], self.client.addr[1], - text_(self.request.method), - text_(server_host), server_port, - text_(self.request.path), - text_(self.response.code), - text_(self.response.reason), - self.response.total_size, - connection_time_ms)) - def connect_upstream(self) -> None: host, port = self.request.host, self.request.port if host and port: diff --git a/proxy/http/server/web.py b/proxy/http/server/web.py index 313932fd4a..04abd298ab 100644 --- a/proxy/http/server/web.py +++ b/proxy/http/server/web.py @@ -135,7 +135,7 @@ def try_upgrade(self) -> bool: return False def on_request_complete(self) -> Union[socket.socket, bool]: - if self.request.has_upstream_server(): + if self.request.has_host(): return False assert self.request.path @@ -234,7 +234,7 @@ def on_response_chunk(self, chunk: List[memoryview]) -> List[memoryview]: return chunk def on_client_connection_close(self) -> None: - if self.request.has_upstream_server(): + if self.request.has_host(): return if self.switched_protocol: # Invoke plugin.on_websocket_close diff --git a/proxy/plugin/proxy_pool.py b/proxy/plugin/proxy_pool.py index 5b4c96cb0b..c0a969e018 100644 --- a/proxy/plugin/proxy_pool.py +++ b/proxy/plugin/proxy_pool.py @@ -10,16 +10,36 @@ """ import random import socket -from typing import Optional, Any +import logging -from ..common.constants import DEFAULT_BUFFER_SIZE, SLASH, COLON -from ..common.utils import new_socket_connection +from typing import Dict, List, Optional, Any, Tuple + +from 
..core.connection.server import TcpServerConnection +from ..common.types import Readables, Writables +from ..http.exception import HttpProtocolException from ..http.proxy import HttpProxyBasePlugin from ..http.parser import HttpParser +from ..http.methods import httpMethods + +logger = logging.getLogger(__name__) + +DEFAULT_HTTP_ACCESS_LOG_FORMAT = '{client_ip}:{client_port} - ' + \ + '{request_method} {server_host}:{server_port}{request_path} -> ' + \ + '{upstream_proxy_host}:{upstream_proxy_port} - ' + \ + '{response_code} {response_reason} - {response_bytes} bytes - ' + \ + '{connection_time_ms} ms' + +DEFAULT_HTTPS_ACCESS_LOG_FORMAT = '{client_ip}:{client_port} - ' + \ + '{request_method} {server_host}:{server_port} -> ' + \ + '{upstream_proxy_host}:{upstream_proxy_port} - ' + \ + '{response_bytes} bytes - {connection_time_ms} ms' class ProxyPoolPlugin(HttpProxyBasePlugin): - """Proxy incoming client proxy requests through a set of upstream proxies.""" + """Proxy pool plugin simply acts as a proxy adapter for proxy.py itself. + + Imagine this plugin as setting up proxy settings for proxy.py instance itself. + All incoming client requests are proxied to configured upstream proxies.""" # Run two separate instances of proxy.py # on port 9000 and 9001 BUT WITHOUT ProxyPool plugin @@ -31,56 +51,132 @@ class ProxyPoolPlugin(HttpProxyBasePlugin): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self.conn: Optional[socket.socket] = None + self.upstream: Optional[TcpServerConnection] = None + # Cached attributes to be used during access log override + self.request_host_port_path_method: List[Any] = [ + None, None, None, None] + self.total_size = 0 + + def get_descriptors(self) -> Tuple[List[socket.socket], List[socket.socket]]: + if not self.upstream: + return [], [] + return [self.upstream.connection], [self.upstream.connection] if self.upstream.has_buffer() else [] + + def read_from_descriptors(self, r: Readables) -> bool: + # Read from upstream proxy and queue for client + if self.upstream and self.upstream.connection in r: + try: + raw = self.upstream.recv(self.flags.server_recvbuf_size) + if raw is not None: + self.total_size += len(raw) + self.client.queue(raw) + else: + return True # Teardown because upstream proxy closed the connection + except ConnectionResetError: + logger.debug('Connection reset by upstream proxy') + return True + return False # Do not teardown connection + + def write_to_descriptors(self, w: Writables) -> bool: + # Flush queued data to upstream proxy now + if self.upstream and self.upstream.connection in w and self.upstream.has_buffer(): + try: + self.upstream.flush() + except BrokenPipeError: + logger.debug('BrokenPipeError when flushing to upstream proxy') + return True + return False def before_upstream_connection( self, request: HttpParser) -> Optional[HttpParser]: - """Avoids upstream connection to the server by returning None. - - Initialize, connection to upstream proxy. + """Avoids establishing the default connection to upstream server + by returning None. """ + # TODO(abhinavsingh): Ideally connection to upstream proxy endpoints + # must be bootstrapped within it's own re-usable and gc'd pool, to avoid establishing + # a fresh upstream proxy connection for each client request. + # # Implement your own logic here e.g. round-robin, least connection etc. 
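Following up on the round-robin suggestion in the comment above, a tiny self-contained sketch of an alternative endpoint selector (ports match the two upstream instances the plugin docstring asks you to run):

```python
import itertools
from typing import Iterator, List, Tuple

def round_robin(pool: List[Tuple[str, int]]) -> Iterator[Tuple[str, int]]:
    """Yield (host, port) endpoints from the pool in round-robin order."""
    return itertools.cycle(pool)

endpoints = round_robin([('localhost', 9000), ('localhost', 9001)])
assert next(endpoints) == ('localhost', 9000)
assert next(endpoints) == ('localhost', 9001)
assert next(endpoints) == ('localhost', 9000)
```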
- self.conn = new_socket_connection( - random.choice(self.UPSTREAM_PROXY_POOL)) + endpoint = random.choice(self.UPSTREAM_PROXY_POOL) + logger.debug('Using endpoint: {0}:{1}'.format(*endpoint)) + self.upstream = TcpServerConnection( + endpoint[0], endpoint[1]) + try: + self.upstream.connect() + except ConnectionRefusedError: + # TODO(abhinavsingh): Try another choice, when all (or max configured) choices have + # exhausted, retry for configured number of times before giving up. + # + # Failing upstream proxies, must be removed from the pool temporarily. + # A periodic health check must put them back in the pool. This can be achieved + # using a datastructure without having to spawn separate thread/process for health + # check. + logger.info( + 'Connection refused by upstream proxy {0}:{1}'.format(*endpoint)) + raise HttpProtocolException() + logger.debug( + 'Established connection to upstream proxy {0}:{1}'.format(*endpoint)) return None def handle_client_request( self, request: HttpParser) -> Optional[HttpParser]: - request.path = self.rebuild_original_path(request) - self.tunnel(request) - # Returning None indicates the core to gracefully - # flush client buffer and teardown the connection - return None + """Only invoked once after client original proxy request has been received completely.""" + assert self.upstream + # For log sanity (i.e. to avoid None:None), expose upstream host:port from headers + host, port = None, None + # Browser or applications may sometime send + # CONNECT / HTTP/1.0\r\n\r\n + # for proxy keep alive check + if request.has_header(b'host'): + parts = request.header(b'host').decode().split(':') + if len(parts) == 2: + host, port = parts[0], parts[1] + else: + assert len(parts) == 1 + host = parts[0] + port = '443' if request.method == httpMethods.CONNECT else '80' + path = None if not request.path else request.path.decode() + self.request_host_port_path_method = [ + host, port, path, request.method] + # Queue original request to upstream proxy + self.upstream.queue(memoryview(request.build(for_proxy=True))) + return request - def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: - """Will never be called since we didn't establish an upstream connection.""" - raise Exception("This should have never been called") + def handle_client_data(self, raw: memoryview) -> Optional[memoryview]: + """Only invoked when before_upstream_connection returns None""" + # Queue data to the proxy endpoint + assert self.upstream + self.upstream.queue(raw) + return raw def on_upstream_connection_close(self) -> None: - """Will never be called since we didn't establish an upstream connection.""" - raise Exception("This should have never been called") + """Called when client connection has been closed.""" + if self.upstream and not self.upstream.closed: + logger.debug('Closing upstream proxy connection') + self.upstream.close() + self.upstream = None - # TODO(abhinavsingh): Upgrade to use non-blocking get/read/write API. 
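handle_client_request above queues request.build(for_proxy=True) to the upstream proxy. A standalone illustration (local stand-in variables, not the parser's attributes) of the path rewriting that for_proxy performs, per the HttpParser.build() change earlier in this diff:

```python
# Non-CONNECT requests are rewritten to absolute-form, CONNECT requests to authority-form.
scheme, host, port, path = b'http', b'example.org', 80, b'/index.html'
absolute_form = scheme + b'://' + host + b':' + str(port).encode() + path
print(absolute_form)   # b'http://example.org:80/index.html'

connect_host, connect_port = b'example.org', 443
authority_form = connect_host + b':' + str(connect_port).encode()
print(authority_form)  # b'example.org:443'
```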
- def tunnel(self, request: HttpParser) -> None: - """Send to upstream proxy, receive from upstream proxy, queue back to client.""" - assert self.conn - self.conn.send(request.build()) - response = self.conn.recv(DEFAULT_BUFFER_SIZE) - self.client.queue(memoryview(response)) + def on_access_log(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]: + addr, port = ( + self.upstream.addr[0], self.upstream.addr[1]) if self.upstream else (None, None) + context.update({ + 'upstream_proxy_host': addr, + 'upstream_proxy_port': port, + 'server_host': self.request_host_port_path_method[0], + 'server_port': self.request_host_port_path_method[1], + 'request_path': self.request_host_port_path_method[2], + 'response_bytes': self.total_size, + }) + self.access_log(context) + return None - @staticmethod - def rebuild_original_path(request: HttpParser) -> bytes: - """Re-builds original upstream server URL. + def access_log(self, log_attrs: Dict[str, Any]) -> None: + access_log_format = DEFAULT_HTTPS_ACCESS_LOG_FORMAT + request_method = self.request_host_port_path_method[3] + if request_method and request_method != httpMethods.CONNECT: + access_log_format = DEFAULT_HTTP_ACCESS_LOG_FORMAT + logger.info(access_log_format.format_map(log_attrs)) - proxy server core by default strips upstream host:port - from incoming client proxy request. - """ - assert request.url and request.host and request.port and request.path - return ( - request.url.scheme + - COLON + SLASH + SLASH + - request.host + - COLON + - str(request.port).encode() + - request.path - ) + def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: + """Will never be called since we didn't establish an upstream connection.""" + raise Exception("This should have never been called") diff --git a/proxy/proxy.py b/proxy/proxy.py index 2a5382eaf1..950bdf793f 100644 --- a/proxy/proxy.py +++ b/proxy/proxy.py @@ -28,11 +28,12 @@ from proxy.core.acceptor.work import Work -from .common.utils import bytes_, text_ +from .common.utils import bytes_, text_, setup_logger from .common.types import IpAddress from .common.version import __version__ from .core.acceptor import AcceptorPool from .http.handler import HttpProtocolHandler +from .core.event import EventManager from .common.flag import flags from .common.constants import COMMA, DEFAULT_DATA_DIRECTORY_PATH, PLUGIN_PROXY_AUTH from .common.constants import DEFAULT_DEVTOOLS_WS_PATH, DEFAULT_DISABLE_HEADERS @@ -43,7 +44,7 @@ from .common.constants import DEFAULT_OPEN_FILE_LIMIT, DEFAULT_PID_FILE, DEFAULT_PLUGINS from .common.constants import DEFAULT_VERSION, DOT, PLUGIN_DASHBOARD, PLUGIN_DEVTOOLS_PROTOCOL from .common.constants import PLUGIN_HTTP_PROXY, PLUGIN_INSPECT_TRAFFIC, PLUGIN_PAC_FILE -from .common.constants import PLUGIN_WEB_SERVER, PY2_DEPRECATION_MESSAGE +from .common.constants import PLUGIN_WEB_SERVER, PY2_DEPRECATION_MESSAGE, DEFAULT_ENABLE_EVENTS if os.name != 'nt': import resource @@ -93,6 +94,13 @@ action='store_true', default=DEFAULT_ENABLE_WEB_SERVER, help='Default: False. Whether to enable proxy.HttpWebServerPlugin.') +flags.add_argument( + '--enable-events', + action='store_true', + default=DEFAULT_ENABLE_EVENTS, + help='Default: False. Enables core to dispatch lifecycle events. ' + 'Plugins can be used to subscribe for core events.' +) flags.add_argument( '--log-level', type=str, @@ -128,6 +136,10 @@ class Proxy: By default, AcceptorPool is started with HttpProtocolHandler worker class i.e. we are only expecting HTTP traffic to flow between clients and server. 
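Before the proxy.py changes below, a minimal embedded-usage sketch (flag values are illustrative): with --enable-events, the reworked __enter__ starts an EventManager and passes its queue to AcceptorPool, and __exit__ stops it again:

```python
from proxy.proxy import Proxy

with Proxy(input_args=['--enable-events']) as p:
    assert p.event_manager is not None
    assert p.event_manager.event_queue is not None
    # ... serve traffic, or subscribe to core events via p.event_manager.event_queue ...
```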
+ + Optionally, also initialize a global event queue. + It is a multiprocess safe queue which can be used to build pubsub patterns + for message sharing or signaling. """ def __init__(self, input_args: Optional[List[str]], **opts: Any) -> None: @@ -137,6 +149,7 @@ def __init__(self, input_args: Optional[List[str]], **opts: Any) -> None: # e.g. A clear text protocol. Or imagine a TelnetProtocolHandler instead # of default HttpProtocolHandler. self.work_klass: Type[Work] = HttpProtocolHandler + self.event_manager: Optional[EventManager] = None def write_pid_file(self) -> None: if self.flags.pid_file is not None: @@ -150,9 +163,14 @@ def delete_pid_file(self) -> None: os.remove(self.flags.pid_file) def __enter__(self) -> 'Proxy': + if self.flags.enable_events: + logger.info('Core Event enabled') + self.event_manager = EventManager() + self.event_manager.start_event_dispatcher() self.pool = AcceptorPool( flags=self.flags, - work_klass=self.work_klass + work_klass=self.work_klass, + event_queue=self.event_manager.event_queue if self.event_manager is not None else None ) self.pool.setup() self.write_pid_file() @@ -165,6 +183,9 @@ def __exit__( exc_tb: Optional[TracebackType]) -> None: assert self.pool self.pool.shutdown() + if self.flags.enable_events: + assert self.event_manager is not None + self.event_manager.stop_event_dispatcher() self.delete_pid_file() @staticmethod @@ -192,7 +213,7 @@ def initialize(input_args: Optional[List[str]] sys.exit(0) # Setup logging module - Proxy.setup_logger(args.log_file, args.log_level, args.log_format) + setup_logger(args.log_file, args.log_level, args.log_format) # Setup limits Proxy.set_open_file_limit(args.open_file_limit) @@ -397,27 +418,6 @@ def is_py3() -> bool: """Exists only to avoid mocking sys.version_info in tests.""" return sys.version_info[0] == 3 - @staticmethod - def setup_logger( - log_file: Optional[str] = DEFAULT_LOG_FILE, - log_level: str = DEFAULT_LOG_LEVEL, - log_format: str = DEFAULT_LOG_FORMAT) -> None: - ll = getattr( - logging, - {'D': 'DEBUG', - 'I': 'INFO', - 'W': 'WARNING', - 'E': 'ERROR', - 'C': 'CRITICAL'}[log_level.upper()[0]]) - if log_file: - logging.basicConfig( - filename=log_file, - filemode='a', - level=ll, - format=log_format) - else: - logging.basicConfig(level=ll, format=log_format) - @staticmethod def set_open_file_limit(soft_limit: int) -> None: """Configure open file description soft limit on supported OS.""" @@ -455,6 +455,13 @@ def main( (proxy.pool.flags.hostname, proxy.pool.flags.port)) # TODO: Introduce cron feature # https://github.com/abhinavsingh/proxy.py/issues/392 + # + # TODO: Introduce ability to publish + # adhoc events which can modify behaviour of server + # at runtime. Example, updating flags, plugin + # configuration etc. 
+ # + # TODO: Python shell within running proxy.py environment while True: time.sleep(1) except KeyboardInterrupt: diff --git a/tests/http/test_http_proxy.py b/tests/http/test_http_proxy.py index 1be5af9657..3d85ed52ae 100644 --- a/tests/http/test_http_proxy.py +++ b/tests/http/test_http_proxy.py @@ -52,6 +52,8 @@ def test_proxy_plugin_initialized(self) -> None: def test_proxy_plugin_on_and_before_upstream_connection( self, mock_server_conn: mock.Mock) -> None: + self.plugin.return_value.write_to_descriptors.return_value = False + self.plugin.return_value.read_from_descriptors.return_value = False self.plugin.return_value.before_upstream_connection.side_effect = lambda r: r self.plugin.return_value.handle_client_request.side_effect = lambda r: r @@ -76,6 +78,8 @@ def test_proxy_plugin_on_and_before_upstream_connection( def test_proxy_plugin_before_upstream_connection_can_teardown( self, mock_server_conn: mock.Mock) -> None: + self.plugin.return_value.write_to_descriptors.return_value = False + self.plugin.return_value.read_from_descriptors.return_value = False self.plugin.return_value.before_upstream_connection.side_effect = HttpProtocolException() self._conn.recv.return_value = build_http_request( @@ -91,5 +95,5 @@ def test_proxy_plugin_before_upstream_connection_can_teardown( data=None), selectors.EVENT_READ)], ] self.protocol_handler.run_once() - self.plugin.return_value.before_upstream_connection.assert_called() mock_server_conn.assert_not_called() + self.plugin.return_value.before_upstream_connection.assert_called() diff --git a/tests/http/test_http_proxy_tls_interception.py b/tests/http/test_http_proxy_tls_interception.py index ddb7987018..96564fcf3d 100644 --- a/tests/http/test_http_proxy_tls_interception.py +++ b/tests/http/test_http_proxy_tls_interception.py @@ -124,6 +124,8 @@ def mock_connection() -> Any: self.plugin.return_value.on_client_connection_close.return_value = None # Prepare mocked HttpProxyBasePlugin + self.proxy_plugin.return_value.write_to_descriptors.return_value = False + self.proxy_plugin.return_value.read_from_descriptors.return_value = False self.proxy_plugin.return_value.before_upstream_connection.side_effect = lambda r: r self.proxy_plugin.return_value.handle_client_request.side_effect = lambda r: r diff --git a/tests/test_main.py b/tests/test_main.py index 0e171eb04f..598a9ed45e 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -53,37 +53,42 @@ def mock_default_args(mock_args: mock.Mock) -> None: mock_args.port = DEFAULT_PORT mock_args.num_workers = DEFAULT_NUM_WORKERS mock_args.disable_http_proxy = DEFAULT_DISABLE_HTTP_PROXY - mock_args.enable_web_server = DEFAULT_ENABLE_WEB_SERVER mock_args.pac_file = DEFAULT_PAC_FILE mock_args.plugins = DEFAULT_PLUGINS mock_args.server_recvbuf_size = DEFAULT_SERVER_RECVBUF_SIZE mock_args.client_recvbuf_size = DEFAULT_CLIENT_RECVBUF_SIZE mock_args.open_file_limit = DEFAULT_OPEN_FILE_LIMIT - mock_args.enable_static_server = DEFAULT_ENABLE_STATIC_SERVER - mock_args.enable_devtools = DEFAULT_ENABLE_DEVTOOLS mock_args.devtools_event_queue = None mock_args.devtools_ws_path = DEFAULT_DEVTOOLS_WS_PATH mock_args.timeout = DEFAULT_TIMEOUT mock_args.threadless = DEFAULT_THREADLESS + mock_args.enable_web_server = DEFAULT_ENABLE_WEB_SERVER + mock_args.enable_static_server = DEFAULT_ENABLE_STATIC_SERVER + mock_args.enable_devtools = DEFAULT_ENABLE_DEVTOOLS mock_args.enable_events = DEFAULT_ENABLE_EVENTS @mock.patch('time.sleep') @mock.patch('proxy.proxy.Proxy.initialize') + @mock.patch('proxy.proxy.EventManager') 
@mock.patch('proxy.proxy.AcceptorPool') @mock.patch('logging.basicConfig') def test_init_with_no_arguments( self, mock_logging_config: mock.Mock, mock_acceptor_pool: mock.Mock, + mock_event_manager: mock.Mock, mock_initialize: mock.Mock, mock_sleep: mock.Mock) -> None: mock_sleep.side_effect = KeyboardInterrupt() input_args: List[str] = [] + mock_initialize.return_value.enable_events = False main(input_args) + mock_event_manager.assert_not_called() mock_acceptor_pool.assert_called_with( flags=mock_initialize.return_value, work_klass=HttpProtocolHandler, + event_queue=None ) mock_acceptor_pool.return_value.setup.assert_called() mock_acceptor_pool.return_value.shutdown.assert_called() @@ -93,12 +98,14 @@ def test_init_with_no_arguments( @mock.patch('os.remove') @mock.patch('os.path.exists') @mock.patch('builtins.open') + @mock.patch('proxy.proxy.EventManager') @mock.patch('proxy.proxy.AcceptorPool') @mock.patch('proxy.common.flag.FlagParser.parse_args') def test_pid_file_is_written_and_removed( self, mock_parse_args: mock.Mock, mock_acceptor_pool: mock.Mock, + mock_event_manager: mock.Mock, mock_open: mock.Mock, mock_exists: mock.Mock, mock_remove: mock.Mock, @@ -108,9 +115,12 @@ def test_pid_file_is_written_and_removed( mock_args = mock_parse_args.return_value self.mock_default_args(mock_args) mock_args.pid_file = pid_file + mock_args.enable_dashboard = False main(['--pid-file', pid_file]) + mock_parse_args.assert_called_once() mock_acceptor_pool.assert_called() mock_acceptor_pool.return_value.setup.assert_called() + mock_event_manager.assert_not_called() mock_open.assert_called_with(pid_file, 'wb') mock_open.return_value.__enter__.return_value.write.assert_called_with( bytes_(os.getpid())) @@ -118,10 +128,12 @@ def test_pid_file_is_written_and_removed( mock_remove.assert_called_with(pid_file) @mock.patch('time.sleep') + @mock.patch('proxy.proxy.EventManager') @mock.patch('proxy.proxy.AcceptorPool') def test_basic_auth( self, mock_acceptor_pool: mock.Mock, + mock_event_manager: mock.Mock, mock_sleep: mock.Mock) -> None: mock_sleep.side_effect = KeyboardInterrupt() @@ -129,6 +141,7 @@ def test_basic_auth( flgs = Proxy.initialize(input_args) main(input_args=input_args) + mock_event_manager.assert_not_called() mock_acceptor_pool.assert_called_once() self.assertEqual( flgs.auth_code, @@ -136,12 +149,14 @@ def test_basic_auth( @mock.patch('time.sleep') @mock.patch('builtins.print') + @mock.patch('proxy.proxy.EventManager') @mock.patch('proxy.proxy.AcceptorPool') @mock.patch('proxy.proxy.Proxy.is_py3') def test_main_py3_runs( self, mock_is_py3: mock.Mock, mock_acceptor_pool: mock.Mock, + mock_event_manager: mock.Mock, mock_print: mock.Mock, mock_sleep: mock.Mock) -> None: mock_sleep.side_effect = KeyboardInterrupt() @@ -153,6 +168,8 @@ def test_main_py3_runs( mock_is_py3.assert_called() mock_print.assert_not_called() + + mock_event_manager.assert_not_called() mock_acceptor_pool.assert_called_once() mock_acceptor_pool.return_value.setup.assert_called()
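Finally, an end-to-end sketch for trying the reworked ProxyPool plugin (ports follow the plugin docstring; the --plugins value is the assumed dotted path to the plugin class):

```python
from proxy.proxy import main

# Terminal 1 and 2: two plain proxy.py instances acting as the upstream pool.
#   main(['--port', '9000'])
#   main(['--port', '9001'])
# Terminal 3: a third instance fronted by ProxyPoolPlugin.
main(['--port', '8899', '--plugins', 'proxy.plugin.ProxyPoolPlugin'])
```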