Skip to content

Commit 3cb0a62

Browse files
Add a --unix-socket-path flag (#697)
* Add a `--unix-socket-path` flag. When available `--hostname` and `--port` flags are ignored. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * `print` statement is allowed only in `flags.py` and `version-check.py`. All other places must use a `logger` instance * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add guard for `AF_UNIX` on Windows * Comment out assertion on Windows for now Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b830f7b commit 3cb0a62

20 files changed

+131
-50
lines changed

examples/pubsub_eventing.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,14 @@
2424
process_publisher_request_id = '12345'
2525
num_events_received = [0, 0]
2626

27+
logger = logging.getLogger(__name__)
28+
2729

2830
def publisher_process(
2931
shutdown_event: multiprocessing.synchronize.Event,
3032
dispatcher_queue: EventQueue,
3133
) -> None:
32-
print('publisher starting')
34+
logger.info('publisher starting')
3335
try:
3436
while not shutdown_event.is_set():
3537
dispatcher_queue.publish(
@@ -40,7 +42,7 @@ def publisher_process(
4042
)
4143
except KeyboardInterrupt:
4244
pass
43-
print('publisher shutdown')
45+
logger.info('publisher shutdown')
4446

4547

4648
def on_event(payload: Dict[str, Any]) -> None:
@@ -50,7 +52,6 @@ def on_event(payload: Dict[str, Any]) -> None:
5052
num_events_received[0] += 1
5153
else:
5254
num_events_received[1] += 1
53-
# print(payload)
5455

5556

5657
if __name__ == '__main__':
@@ -86,7 +87,7 @@ def on_event(payload: Dict[str, Any]) -> None:
8687
publisher_id='eventing_pubsub_main',
8788
)
8889
except KeyboardInterrupt:
89-
print('bye!!!')
90+
logger.info('bye!!!')
9091
finally:
9192
# Stop publisher
9293
publisher_shutdown_event.set()
@@ -95,8 +96,9 @@ def on_event(payload: Dict[str, Any]) -> None:
9596
subscriber.unsubscribe()
9697
# Signal dispatcher to shutdown
9798
event_manager.stop_event_dispatcher()
98-
print(
99+
logger.info(
99100
'Received {0} events from main thread, {1} events from another process, in {2} seconds'.format(
100-
num_events_received[0], num_events_received[1], time.time() - start_time,
101+
num_events_received[0], num_events_received[1], time.time(
102+
) - start_time,
101103
),
102104
)

examples/ssl_echo_client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,13 @@
88
:copyright: (c) 2013-present by Abhinav Singh and contributors.
99
:license: BSD, see LICENSE for more details.
1010
"""
11+
import logging
12+
1113
from proxy.core.connection import TcpServerConnection
1214
from proxy.common.constants import DEFAULT_BUFFER_SIZE
1315

16+
logger = logging.getLogger(__name__)
17+
1418
if __name__ == '__main__':
1519
client = TcpServerConnection('::', 12345)
1620
client.connect()
@@ -24,6 +28,6 @@
2428
data = client.recv(DEFAULT_BUFFER_SIZE)
2529
if data is None:
2630
break
27-
print(data.tobytes())
31+
logger.info(data.tobytes())
2832
finally:
2933
client.close()

examples/tcp_echo_client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,18 @@
88
:copyright: (c) 2013-present by Abhinav Singh and contributors.
99
:license: BSD, see LICENSE for more details.
1010
"""
11+
import logging
12+
1113
from proxy.common.utils import socket_connection
1214
from proxy.common.constants import DEFAULT_BUFFER_SIZE
1315

16+
logger = logging.getLogger(__name__)
17+
1418
if __name__ == '__main__':
1519
with socket_connection(('::', 12345)) as client:
1620
while True:
1721
client.send(b'hello')
1822
data = client.recv(DEFAULT_BUFFER_SIZE)
1923
if data is None:
2024
break
21-
print(data)
25+
logger.info(data)

examples/websocket_client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,23 @@
99
:license: BSD, see LICENSE for more details.
1010
"""
1111
import time
12-
from proxy.http.websocket import WebsocketClient, WebsocketFrame, websocketOpcodes
12+
import logging
1313

14+
from proxy.http.websocket import WebsocketClient, WebsocketFrame, websocketOpcodes
1415

1516
# globals
1617
client: WebsocketClient
1718
last_dispatch_time: float
1819
static_frame = memoryview(WebsocketFrame.text(b'hello'))
1920
num_echos = 10
2021

22+
logger = logging.getLogger(__name__)
23+
2124

2225
def on_message(frame: WebsocketFrame) -> None:
2326
"""WebsocketClient on_message callback."""
2427
global client, num_echos, last_dispatch_time
25-
print(
28+
logger.info(
2629
'Received %r after %d millisec' %
2730
(frame.data, (time.time() - last_dispatch_time) * 1000),
2831
)

proxy/common/flag.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,18 @@ def initialize(
221221
IpAddress,
222222
opts.get('hostname', ipaddress.ip_address(args.hostname)),
223223
)
224-
args.family = socket.AF_INET6 if args.hostname.version == 6 else socket.AF_INET
224+
# AF_UNIX is not available on Windows
225+
# See https://bugs.python.org/issue33408
226+
if os.name != 'nt':
227+
args.family = socket.AF_UNIX if args.unix_socket_path else (
228+
socket.AF_INET6 if args.hostname.version == 6 else socket.AF_INET
229+
)
230+
else:
231+
# FIXME: Not true for tests, as this value will be mock
232+
# It's a problem only on Windows. Instead of a proper
233+
# test level fix, simply commenting this for now.
234+
# assert args.unix_socket_path is None
235+
args.family = socket.AF_INET6 if args.hostname.version == 6 else socket.AF_INET
225236
args.port = cast(int, opts.get('port', args.port))
226237
args.backlog = cast(int, opts.get('backlog', args.backlog))
227238
num_workers = opts.get('num_workers', args.num_workers)

proxy/common/pki.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,15 +288,20 @@ def run_openssl_command(command: List[str], timeout: int) -> bool:
288288

289289
# Validation
290290
if args.action not in available_actions:
291-
print('Invalid --action. Valid values ' + ', '.join(available_actions))
291+
logger.error(
292+
'Invalid --action. Valid values ' +
293+
', '.join(available_actions),
294+
)
292295
sys.exit(1)
293296
if args.action in ('gen_private_key', 'gen_public_key'):
294297
if args.private_key_path is None:
295-
print('--private-key-path is required for ' + args.action)
298+
logger.error('--private-key-path is required for ' + args.action)
296299
sys.exit(1)
297300
if args.action == 'gen_public_key':
298301
if args.public_key_path is None:
299-
print('--public-key-file is required for private key generation')
302+
logger.error(
303+
'--public-key-file is required for private key generation',
304+
)
300305
sys.exit(1)
301306

302307
# Execute

proxy/core/acceptor/acceptor.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,17 +113,20 @@ def shutdown_threadless_process(self) -> None:
113113
self.threadless_process.join()
114114
self.threadless_client_queue.close()
115115

116-
def _start_threadless_work(self, conn: socket.socket, addr: Tuple[str, int]) -> None:
116+
def _start_threadless_work(self, conn: socket.socket, addr: Optional[Tuple[str, int]]) -> None:
117117
assert self.threadless_process and self.threadless_client_queue
118-
self.threadless_client_queue.send(addr)
118+
# Accepted client address is empty string for
119+
# unix socket domain, avoid sending empty string
120+
if not self.flags.unix_socket_path:
121+
self.threadless_client_queue.send(addr)
119122
send_handle(
120123
self.threadless_client_queue,
121124
conn.fileno(),
122125
self.threadless_process.pid,
123126
)
124127
conn.close()
125128

126-
def _start_threaded_work(self, conn: socket.socket, addr: Tuple[str, int]) -> None:
129+
def _start_threaded_work(self, conn: socket.socket, addr: Optional[Tuple[str, int]]) -> None:
127130
work = self.work_klass(
128131
TcpClientConnection(conn, addr),
129132
flags=self.flags,
@@ -145,6 +148,7 @@ def run_once(self) -> None:
145148
if len(events) == 0:
146149
return
147150
conn, addr = self.sock.accept()
151+
addr = None if addr == '' else addr
148152
if (
149153
self.flags.threadless and
150154
self.threadless_client_queue and

proxy/core/acceptor/pool.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
:copyright: (c) 2013-present by Abhinav Singh and contributors.
99
:license: BSD, see LICENSE for more details.
1010
"""
11+
import os
1112
import argparse
1213
import logging
1314
import multiprocessing
@@ -61,6 +62,14 @@
6162
help='Defaults to number of CPU cores.',
6263
)
6364

65+
flags.add_argument(
66+
'--unix-socket-path',
67+
type=str,
68+
default=None,
69+
help='Default: None. Unix socket path to use. ' +
70+
'When provided --host and --port flags are ignored',
71+
)
72+
6473

6574
class AcceptorPool:
6675
"""AcceptorPool is a helper class which pre-spawns `Acceptor` processes
@@ -108,8 +117,11 @@ def __exit__(
108117
self.shutdown()
109118

110119
def setup(self) -> None:
111-
"""Listen on port and setup acceptors."""
112-
self._listen()
120+
"""Setup socket and acceptors."""
121+
if self.flags.unix_socket_path:
122+
self._listen_unix_socket()
123+
else:
124+
self._listen_server_port()
113125
# Override flags.port to match the actual port
114126
# we are listening upon. This is necessary to preserve
115127
# the server port when `--port=0` is used.
@@ -133,9 +145,18 @@ def shutdown(self) -> None:
133145
acceptor.running.set()
134146
for acceptor in self.acceptors:
135147
acceptor.join()
148+
if self.flags.unix_socket_path:
149+
os.remove(self.flags.unix_socket_path)
136150
logger.debug('Acceptors shutdown')
137151

138-
def _listen(self) -> None:
152+
def _listen_unix_socket(self) -> None:
153+
self.socket = socket.socket(self.flags.family, socket.SOCK_STREAM)
154+
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
155+
self.socket.bind(self.flags.unix_socket_path)
156+
self.socket.listen(self.flags.backlog)
157+
self.socket.setblocking(False)
158+
159+
def _listen_server_port(self) -> None:
139160
self.socket = socket.socket(self.flags.family, socket.SOCK_STREAM)
140161
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
141162
self.socket.bind((str(self.flags.hostname), self.flags.port))

proxy/core/acceptor/threadless.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,11 @@ def fromfd(self, fileno: int) -> socket.socket:
133133
)
134134

135135
def accept_client(self) -> None:
136-
addr = self.client_queue.recv()
136+
# Acceptor will not send address for
137+
# unix socket domain environments.
138+
addr = None
139+
if not self.flags.unix_socket_path:
140+
addr = self.client_queue.recv()
137141
fileno = recv_handle(self.client_queue)
138142
self.works[fileno] = self.work_klass(
139143
TcpClientConnection(conn=self.fromfd(fileno), addr=addr),

proxy/core/base/tcp_server.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,14 @@ class BaseTcpServerHandler(Work):
4545
def __init__(self, *args: Any, **kwargs: Any) -> None:
4646
super().__init__(*args, **kwargs)
4747
self.must_flush_before_shutdown = False
48-
logger.debug('Connection accepted from {0}'.format(self.work.addr))
48+
if self.flags.unix_socket_path:
49+
logger.debug(
50+
'Connection accepted from {0}'.format(self.work.address),
51+
)
52+
else:
53+
logger.debug(
54+
'Connection accepted from {0}'.format(self.work.address),
55+
)
4956

5057
@abstractmethod
5158
def handle_data(self, data: memoryview) -> Optional[bool]:
@@ -79,7 +86,7 @@ def handle_events(
7986
if teardown:
8087
logger.debug(
8188
'Shutting down client {0} connection'.format(
82-
self.work.addr,
89+
self.work.address,
8390
),
8491
)
8592
return teardown
@@ -88,7 +95,7 @@ def handle_writables(self, writables: Writables) -> bool:
8895
teardown = False
8996
if self.work.connection in writables and self.work.has_buffer():
9097
logger.debug(
91-
'Flushing buffer to client {0}'.format(self.work.addr),
98+
'Flushing buffer to client {0}'.format(self.work.address),
9299
)
93100
self.work.flush()
94101
if self.must_flush_before_shutdown is True:
@@ -104,7 +111,7 @@ def handle_readables(self, readables: Readables) -> bool:
104111
if data is None:
105112
logger.debug(
106113
'Connection closed by client {0}'.format(
107-
self.work.addr,
114+
self.work.address,
108115
),
109116
)
110117
teardown = True
@@ -113,13 +120,13 @@ def handle_readables(self, readables: Readables) -> bool:
113120
if isinstance(r, bool) and r is True:
114121
logger.debug(
115122
'Implementation signaled shutdown for client {0}'.format(
116-
self.work.addr,
123+
self.work.address,
117124
),
118125
)
119126
if self.work.has_buffer():
120127
logger.debug(
121128
'Client {0} has pending buffer, will be flushed before shutting down'.format(
122-
self.work.addr,
129+
self.work.address,
123130
),
124131
)
125132
self.must_flush_before_shutdown = True

proxy/core/base/tcp_tunnel.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
:license: BSD, see LICENSE for more details.
1010
"""
1111
import socket
12+
import logging
1213
import selectors
1314

1415
from abc import abstractmethod
@@ -21,6 +22,8 @@
2122
from ..connection import TcpServerConnection
2223
from .tcp_server import BaseTcpServerHandler
2324

25+
logger = logging.getLogger(__name__)
26+
2427

2528
class BaseTcpTunnelHandler(BaseTcpServerHandler):
2629
"""BaseTcpTunnelHandler build on-top of BaseTcpServerHandler work klass.
@@ -47,7 +50,7 @@ def initialize(self) -> None:
4750

4851
def shutdown(self) -> None:
4952
if self.upstream:
50-
print(
53+
logger.debug(
5154
'Connection closed with upstream {0}:{1}'.format(
5255
text_(self.request.host), self.request.port,
5356
),
@@ -84,7 +87,7 @@ def handle_events(
8487
data = self.upstream.recv()
8588
if data is None:
8689
# Server closed connection
87-
print('Connection closed by server')
90+
logger.debug('Connection closed by server')
8891
return True
8992
# tunnel data to client
9093
self.work.queue(data)
@@ -98,7 +101,7 @@ def connect_upstream(self) -> None:
98101
text_(self.request.host), self.request.port,
99102
)
100103
self.upstream.connect()
101-
print(
104+
logger.debug(
102105
'Connection established with upstream {0}:{1}'.format(
103106
text_(self.request.host), self.request.port,
104107
),

proxy/core/connection/client.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,16 @@ class TcpClientConnection(TcpConnection):
2121
def __init__(
2222
self,
2323
conn: Union[ssl.SSLSocket, socket.socket],
24-
addr: Tuple[str, int],
24+
# optional for unix socket servers
25+
addr: Optional[Tuple[str, int]] = None,
2526
) -> None:
2627
super().__init__(tcpConnectionTypes.CLIENT)
2728
self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = conn
28-
self.addr: Tuple[str, int] = addr
29+
self.addr: Optional[Tuple[str, int]] = addr
30+
31+
@property
32+
def address(self) -> str:
33+
return 'unix:client' if not self.addr else '{0}:{1}'.format(self.addr[0], self.addr[1])
2934

3035
@property
3136
def connection(self) -> Union[ssl.SSLSocket, socket.socket]:

0 commit comments

Comments
 (0)