Skip to content

Commit f48771f

Browse files
authored
Top-level notion of work not client (#695)
* Top-level notion of work not client * Update ssl echo server example
1 parent d3cee32 commit f48771f

13 files changed

+76
-73
lines changed

examples/https_connect_tunnel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
5353

5454
# Drop the request if not a CONNECT request
5555
if self.request.method != httpMethods.CONNECT:
56-
self.client.queue(
56+
self.work.queue(
5757
HttpsConnectTunnelHandler.PROXY_TUNNEL_UNSUPPORTED_SCHEME,
5858
)
5959
return True
@@ -66,7 +66,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
6666
self.connect_upstream()
6767

6868
# Queue tunnel established response to client
69-
self.client.queue(
69+
self.work.queue(
7070
HttpsConnectTunnelHandler.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
7171
)
7272

examples/ssl_echo_server.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,19 @@ def initialize(self) -> None:
2727
# here using wrap_socket() utility.
2828
assert self.flags.keyfile is not None and self.flags.certfile is not None
2929
conn = wrap_socket(
30-
self.client.connection,
30+
self.work.connection,
3131
self.flags.keyfile,
3232
self.flags.certfile,
3333
)
3434
conn.setblocking(False)
3535
# Upgrade plain TcpClientConnection to SSL connection object
36-
self.client = TcpClientConnection(
37-
conn=conn, addr=self.client.addr,
36+
self.work = TcpClientConnection(
37+
conn=conn, addr=self.work.addr,
3838
)
3939

4040
def handle_data(self, data: memoryview) -> Optional[bool]:
4141
# echo back to client
42-
self.client.queue(data)
42+
self.work.queue(data)
4343
return None
4444

4545

examples/tcp_echo_server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ class EchoServerHandler(BaseTcpServerHandler):
2020
"""Sets client socket to non-blocking during initialization."""
2121

2222
def initialize(self) -> None:
23-
self.client.connection.setblocking(False)
23+
self.work.connection.setblocking(False)
2424

2525
def handle_data(self, data: memoryview) -> Optional[bool]:
2626
# echo back to client
27-
self.client.queue(data)
27+
self.work.queue(data)
2828
return None
2929

3030

proxy/core/acceptor/work.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,18 @@ class Work(ABC):
2525

2626
def __init__(
2727
self,
28-
client: TcpClientConnection,
28+
work: TcpClientConnection,
2929
flags: argparse.Namespace,
3030
event_queue: Optional[EventQueue] = None,
3131
uid: Optional[UUID] = None,
3232
) -> None:
33-
self.client = client
33+
# Work uuid
34+
self.uid: UUID = uid if uid is not None else uuid4()
3435
self.flags = flags
36+
# Eventing core queue
3537
self.event_queue = event_queue
36-
self.uid: UUID = uid if uid is not None else uuid4()
38+
# Accept work
39+
self.work = work
3740

3841
@abstractmethod
3942
def get_events(self) -> Dict[socket.socket, int]:

proxy/core/base/tcp_server.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class BaseTcpServerHandler(Work):
4545
def __init__(self, *args: Any, **kwargs: Any) -> None:
4646
super().__init__(*args, **kwargs)
4747
self.must_flush_before_shutdown = False
48-
logger.debug('Connection accepted from {0}'.format(self.client.addr))
48+
logger.debug('Connection accepted from {0}'.format(self.work.addr))
4949

5050
@abstractmethod
5151
def handle_data(self, data: memoryview) -> Optional[bool]:
@@ -57,14 +57,14 @@ def get_events(self) -> Dict[socket.socket, int]:
5757
# We always want to read from client
5858
# Register for EVENT_READ events
5959
if self.must_flush_before_shutdown is False:
60-
events[self.client.connection] = selectors.EVENT_READ
60+
events[self.work.connection] = selectors.EVENT_READ
6161
# If there is pending buffer for client
6262
# also register for EVENT_WRITE events
63-
if self.client.has_buffer():
64-
if self.client.connection in events:
65-
events[self.client.connection] |= selectors.EVENT_WRITE
63+
if self.work.has_buffer():
64+
if self.work.connection in events:
65+
events[self.work.connection] |= selectors.EVENT_WRITE
6666
else:
67-
events[self.client.connection] = selectors.EVENT_WRITE
67+
events[self.work.connection] = selectors.EVENT_WRITE
6868
return events
6969

7070
def handle_events(
@@ -79,32 +79,32 @@ def handle_events(
7979
if teardown:
8080
logger.debug(
8181
'Shutting down client {0} connection'.format(
82-
self.client.addr,
82+
self.work.addr,
8383
),
8484
)
8585
return teardown
8686

8787
def handle_writables(self, writables: Writables) -> bool:
8888
teardown = False
89-
if self.client.connection in writables and self.client.has_buffer():
89+
if self.work.connection in writables and self.work.has_buffer():
9090
logger.debug(
91-
'Flushing buffer to client {0}'.format(self.client.addr),
91+
'Flushing buffer to client {0}'.format(self.work.addr),
9292
)
93-
self.client.flush()
93+
self.work.flush()
9494
if self.must_flush_before_shutdown is True:
95-
if not self.client.has_buffer():
95+
if not self.work.has_buffer():
9696
teardown = True
9797
self.must_flush_before_shutdown = False
9898
return teardown
9999

100100
def handle_readables(self, readables: Readables) -> bool:
101101
teardown = False
102-
if self.client.connection in readables:
103-
data = self.client.recv(self.flags.client_recvbuf_size)
102+
if self.work.connection in readables:
103+
data = self.work.recv(self.flags.client_recvbuf_size)
104104
if data is None:
105105
logger.debug(
106106
'Connection closed by client {0}'.format(
107-
self.client.addr,
107+
self.work.addr,
108108
),
109109
)
110110
teardown = True
@@ -113,13 +113,13 @@ def handle_readables(self, readables: Readables) -> bool:
113113
if isinstance(r, bool) and r is True:
114114
logger.debug(
115115
'Implementation signaled shutdown for client {0}'.format(
116-
self.client.addr,
116+
self.work.addr,
117117
),
118118
)
119-
if self.client.has_buffer():
119+
if self.work.has_buffer():
120120
logger.debug(
121121
'Client {0} has pending buffer, will be flushed before shutting down'.format(
122-
self.client.addr,
122+
self.work.addr,
123123
),
124124
)
125125
self.must_flush_before_shutdown = True

proxy/core/base/tcp_tunnel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
4343
pass # pragma: no cover
4444

4545
def initialize(self) -> None:
46-
self.client.connection.setblocking(False)
46+
self.work.connection.setblocking(False)
4747

4848
def shutdown(self) -> None:
4949
if self.upstream:
@@ -87,7 +87,7 @@ def handle_events(
8787
print('Connection closed by server')
8888
return True
8989
# tunnel data to client
90-
self.client.queue(data)
90+
self.work.queue(data)
9191
if self.upstream and self.upstream.connection in writables:
9292
self.upstream.flush()
9393
return False

proxy/http/handler.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -89,25 +89,25 @@ def __init__(self, *args: Any, **kwargs: Any):
8989

9090
def initialize(self) -> None:
9191
"""Optionally upgrades connection to HTTPS, set conn in non-blocking mode and initializes plugins."""
92-
conn = self._optionally_wrap_socket(self.client.connection)
92+
conn = self._optionally_wrap_socket(self.work.connection)
9393
conn.setblocking(False)
9494
# Update client connection reference if connection was wrapped
9595
if self._encryption_enabled():
96-
self.client = TcpClientConnection(conn=conn, addr=self.client.addr)
96+
self.work = TcpClientConnection(conn=conn, addr=self.work.addr)
9797
if b'HttpProtocolHandlerPlugin' in self.flags.plugins:
9898
for klass in self.flags.plugins[b'HttpProtocolHandlerPlugin']:
9999
instance: HttpProtocolHandlerPlugin = klass(
100100
self.uid,
101101
self.flags,
102-
self.client,
102+
self.work,
103103
self.request,
104104
self.event_queue,
105105
)
106106
self.plugins[instance.name()] = instance
107-
logger.debug('Handling connection %r' % self.client.connection)
107+
logger.debug('Handling connection %r' % self.work.connection)
108108

109109
def is_inactive(self) -> bool:
110-
if not self.client.has_buffer() and \
110+
if not self.work.has_buffer() and \
111111
self._connection_inactive_for() > self.flags.timeout:
112112
return True
113113
return False
@@ -127,20 +127,20 @@ def shutdown(self) -> None:
127127
logger.debug(
128128
'Closing client connection %r '
129129
'at address %r has buffer %s' %
130-
(self.client.connection, self.client.addr, self.client.has_buffer()),
130+
(self.work.connection, self.work.addr, self.work.has_buffer()),
131131
)
132132

133-
conn = self.client.connection
133+
conn = self.work.connection
134134
# Unwrap if wrapped before shutdown.
135135
if self._encryption_enabled() and \
136-
isinstance(self.client.connection, ssl.SSLSocket):
137-
conn = self.client.connection.unwrap()
136+
isinstance(self.work.connection, ssl.SSLSocket):
137+
conn = self.work.connection.unwrap()
138138
conn.shutdown(socket.SHUT_WR)
139139
logger.debug('Client connection shutdown successful')
140140
except OSError:
141141
pass
142142
finally:
143-
self.client.connection.close()
143+
self.work.connection.close()
144144
logger.debug('Client connection closed')
145145
super().shutdown()
146146

@@ -196,7 +196,7 @@ def handle_events(
196196
def handle_data(self, data: memoryview) -> Optional[bool]:
197197
if data is None:
198198
logger.debug('Client closed connection, tearing down...')
199-
self.client.closed = True
199+
self.work.closed = True
200200
return True
201201

202202
try:
@@ -227,7 +227,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
227227
logger.debug(
228228
'Updated client conn to %s', upgraded_sock,
229229
)
230-
self.client._conn = upgraded_sock
230+
self.work._conn = upgraded_sock
231231
for plugin_ in self.plugins.values():
232232
if plugin_ != plugin:
233233
plugin_.client._conn = upgraded_sock
@@ -237,20 +237,20 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
237237
logger.debug('HttpProtocolException raised')
238238
response: Optional[memoryview] = e.response(self.request)
239239
if response:
240-
self.client.queue(response)
240+
self.work.queue(response)
241241
return True
242242
return False
243243

244244
def handle_writables(self, writables: Writables) -> bool:
245-
if self.client.connection in writables and self.client.has_buffer():
245+
if self.work.connection in writables and self.work.has_buffer():
246246
logger.debug('Client is ready for writes, flushing buffer')
247247
self.last_activity = time.time()
248248

249249
# TODO(abhinavsingh): This hook could just reside within server recv block
250250
# instead of invoking when flushed to client.
251251
#
252252
# Invoke plugin.on_response_chunk
253-
chunk = self.client.buffer
253+
chunk = self.work.buffer
254254
for plugin in self.plugins.values():
255255
chunk = plugin.on_response_chunk(chunk)
256256
if chunk is None:
@@ -272,7 +272,7 @@ def handle_writables(self, writables: Writables) -> bool:
272272
return False
273273

274274
def handle_readables(self, readables: Readables) -> bool:
275-
if self.client.connection in readables:
275+
if self.work.connection in readables:
276276
logger.debug('Client is ready for reads, reading')
277277
self.last_activity = time.time()
278278
try:
@@ -290,7 +290,7 @@ def handle_readables(self, readables: Readables) -> bool:
290290
else:
291291
logger.exception(
292292
'Exception while receiving from %s connection %r with reason %r' %
293-
(self.client.tag, self.client.connection, e),
293+
(self.work.tag, self.work.connection, e),
294294
)
295295
return True
296296
return False
@@ -324,7 +324,7 @@ def run(self) -> None:
324324
except Exception as e:
325325
logger.exception(
326326
'Exception while handling connection %r' %
327-
self.client.connection, exc_info=e,
327+
self.work.connection, exc_info=e,
328328
)
329329
finally:
330330
self.shutdown()
@@ -377,24 +377,24 @@ def _run_once(self) -> bool:
377377

378378
def _flush(self) -> None:
379379
assert self.selector
380-
if not self.client.has_buffer():
380+
if not self.work.has_buffer():
381381
return
382382
try:
383383
self.selector.register(
384-
self.client.connection,
384+
self.work.connection,
385385
selectors.EVENT_WRITE,
386386
)
387-
while self.client.has_buffer():
387+
while self.work.has_buffer():
388388
ev: List[
389389
Tuple[selectors.SelectorKey, int]
390390
] = self.selector.select(timeout=1)
391391
if len(ev) == 0:
392392
continue
393-
self.client.flush()
393+
self.work.flush()
394394
except BrokenPipeError:
395395
pass
396396
finally:
397-
self.selector.unregister(self.client.connection)
397+
self.selector.unregister(self.work.connection)
398398

399399
def _connection_inactive_for(self) -> float:
400400
return time.time() - self.last_activity

tests/http/exceptions/test_http_proxy_auth_failed.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ def test_proxy_auth_fails_without_cred(self, mock_server_conn: mock.Mock) -> Non
6363

6464
self.protocol_handler._run_once()
6565
mock_server_conn.assert_not_called()
66-
self.assertEqual(self.protocol_handler.client.has_buffer(), True)
66+
self.assertEqual(self.protocol_handler.work.has_buffer(), True)
6767
self.assertEqual(
68-
self.protocol_handler.client.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
68+
self.protocol_handler.work.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
6969
)
7070
self._conn.send.assert_not_called()
7171

@@ -92,9 +92,9 @@ def test_proxy_auth_fails_with_invalid_cred(self, mock_server_conn: mock.Mock) -
9292

9393
self.protocol_handler._run_once()
9494
mock_server_conn.assert_not_called()
95-
self.assertEqual(self.protocol_handler.client.has_buffer(), True)
95+
self.assertEqual(self.protocol_handler.work.has_buffer(), True)
9696
self.assertEqual(
97-
self.protocol_handler.client.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
97+
self.protocol_handler.work.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
9898
)
9999
self._conn.send.assert_not_called()
100100

@@ -121,7 +121,7 @@ def test_proxy_auth_works_with_valid_cred(self, mock_server_conn: mock.Mock) ->
121121

122122
self.protocol_handler._run_once()
123123
mock_server_conn.assert_called_once()
124-
self.assertEqual(self.protocol_handler.client.has_buffer(), False)
124+
self.assertEqual(self.protocol_handler.work.has_buffer(), False)
125125

126126
@mock.patch('proxy.http.proxy.server.TcpServerConnection')
127127
def test_proxy_auth_works_with_mixed_case_basic_string(self, mock_server_conn: mock.Mock) -> None:
@@ -146,4 +146,4 @@ def test_proxy_auth_works_with_mixed_case_basic_string(self, mock_server_conn: m
146146

147147
self.protocol_handler._run_once()
148148
mock_server_conn.assert_called_once()
149-
self.assertEqual(self.protocol_handler.client.has_buffer(), False)
149+
self.assertEqual(self.protocol_handler.work.has_buffer(), False)

tests/http/test_http_proxy_tls_interception.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def mock_connection() -> Any:
201201
)
202202
self.assertEqual(self._conn.setblocking.call_count, 2)
203203
self.assertEqual(
204-
self.protocol_handler.client.connection,
204+
self.protocol_handler.work.connection,
205205
self.mock_ssl_wrap.return_value,
206206
)
207207

0 commit comments

Comments
 (0)