Skip to content

Commit 66a431e

Browse files
aroradamanDaman Arora
authored and
Daman Arora
committed
net: add multi listener impl for net.Listener
This adds an implementation of net.Listener which uses synchronous multiplexing to listen on and accept connections from multiple addrs. Signed-off-by: Daman Arora <[email protected]>
1 parent fe8a2dd commit 66a431e

File tree

1 file changed

+201
-0
lines changed

1 file changed

+201
-0
lines changed

net/multi_listen.go

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
/*
2+
Copyright 2024 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package net
18+
19+
import (
20+
"fmt"
21+
"net"
22+
"os"
23+
"syscall"
24+
)
25+
26+
// multiListener implements net.Listener and uses synchronous multiplexing to listen to and accept
27+
// TCP connections from multiple addresses.
28+
type multiListener struct {
29+
latestAcceptedFDIndex int
30+
fds []int
31+
addrs []net.Addr
32+
stopCh chan struct{}
33+
}
34+
35+
// compile time check to ensure *multiListener implements net.Listener.
36+
var _ net.Listener = &multiListener{}
37+
38+
// NewMultiListener returns *multiListener as net.Listener allowing consumers to synchronously
39+
// listen for TCP connections on multiple addresses.
40+
func NewMultiListener(addresses []string) (net.Listener, error) {
41+
ml := &multiListener{
42+
stopCh: make(chan struct{}),
43+
}
44+
for _, address := range addresses {
45+
fd, addr, err := createBindAndListen(address)
46+
if err != nil {
47+
return nil, err
48+
}
49+
ml.fds = append(ml.fds, fd)
50+
ml.addrs = append(ml.addrs, addr)
51+
}
52+
return ml, nil
53+
}
54+
55+
// createBindAndListen creates a TCP socket, binds it to the specified address, and starts listening on it.
56+
func createBindAndListen(address string) (int, net.Addr, error) {
57+
host, _, err := net.SplitHostPort(address)
58+
if err != nil {
59+
return -1, nil, err
60+
}
61+
62+
ipFamily := IPFamilyOf(ParseIPSloppy(host))
63+
var network string
64+
var domain int
65+
switch ipFamily {
66+
case IPv4:
67+
network = "tcp4"
68+
domain = syscall.AF_INET
69+
case IPv6:
70+
network = "tcp6"
71+
domain = syscall.AF_INET6
72+
default:
73+
return -1, nil, fmt.Errorf("failed to idenfity ip family of host '%s'", host)
74+
75+
}
76+
77+
// resolve tcp addr
78+
addr, err := net.ResolveTCPAddr(network, address)
79+
if err != nil {
80+
return -1, nil, err
81+
}
82+
83+
// create socket
84+
fd, err := syscall.Socket(domain, syscall.SOCK_STREAM, 0)
85+
if err != nil {
86+
return -1, nil, err
87+
}
88+
89+
// define socket address for bind
90+
var sockAddr syscall.Sockaddr
91+
if ipFamily == IPv4 {
92+
var ipBytes [4]byte
93+
copy(ipBytes[:], addr.IP.To4())
94+
sockAddr = &syscall.SockaddrInet4{
95+
Addr: ipBytes,
96+
Port: addr.Port,
97+
}
98+
} else {
99+
var ipBytes [16]byte
100+
copy(ipBytes[:], addr.IP.To16())
101+
sockAddr = &syscall.SockaddrInet6{
102+
Addr: ipBytes,
103+
Port: addr.Port,
104+
}
105+
}
106+
107+
// bind socket to specified addr
108+
if err = syscall.Bind(fd, sockAddr); err != nil {
109+
_ = syscall.Close(fd)
110+
return -1, nil, err
111+
}
112+
113+
// start listening on socket
114+
if err = syscall.Listen(fd, syscall.SOMAXCONN); err != nil {
115+
_ = syscall.Close(fd)
116+
return -1, nil, err
117+
}
118+
119+
return fd, addr, nil
120+
}
121+
122+
// Accept is part of net.Listener interface.
123+
func (ml *multiListener) Accept() (net.Conn, error) {
124+
for {
125+
readFds := &syscall.FdSet{}
126+
maxfd := 0
127+
128+
for _, fd := range ml.fds {
129+
if fd > maxfd {
130+
maxfd = fd
131+
}
132+
addFDToFDSet(fd, readFds)
133+
}
134+
135+
// wait for any of the sockets to be ready for accepting new connection
136+
timeout := syscall.Timeval{Sec: 1, Usec: 0}
137+
n, err := syscall.Select(maxfd+1, readFds, nil, nil, &timeout)
138+
if err != nil {
139+
return nil, err
140+
}
141+
if n == 0 {
142+
select {
143+
case <-ml.stopCh:
144+
return nil, fmt.Errorf("multiListener closed")
145+
default:
146+
continue
147+
}
148+
}
149+
for i, fd := range ml.fds {
150+
if isFDInFDSet(fd, readFds) {
151+
conn, err := acceptConnection(fd)
152+
if err != nil {
153+
return nil, err
154+
}
155+
ml.latestAcceptedFDIndex = i
156+
return conn, nil
157+
}
158+
}
159+
}
160+
}
161+
162+
// addFDToFDSet adds fd to the given fd set
163+
func addFDToFDSet(fd int, p *syscall.FdSet) {
164+
mask := 1 << (uint(fd) % syscall.FD_SETSIZE)
165+
p.Bits[fd/syscall.FD_SETSIZE] |= int64(mask)
166+
}
167+
168+
// isFDInFDSet returns true if fd is in fd set, false otherwise
169+
func isFDInFDSet(fd int, p *syscall.FdSet) bool {
170+
mask := 1 << (uint(fd) % syscall.FD_SETSIZE)
171+
return p.Bits[fd/syscall.FD_SETSIZE]&int64(mask) != 0
172+
}
173+
174+
// acceptConnection accepts connection and returns remote connection object
175+
func acceptConnection(fd int) (net.Conn, error) {
176+
connFD, _, err := syscall.Accept(fd)
177+
if err != nil {
178+
return nil, err
179+
}
180+
181+
conn, err := net.FileConn(os.NewFile(uintptr(connFD), fmt.Sprintf("fd %d", connFD)))
182+
if err != nil {
183+
_ = syscall.Close(connFD)
184+
return nil, err
185+
}
186+
return conn, nil
187+
}
188+
189+
// Close is part of net.Listener interface.
190+
func (ml *multiListener) Close() error {
191+
close(ml.stopCh)
192+
for _, fd := range ml.fds {
193+
_ = syscall.Close(fd)
194+
}
195+
return nil
196+
}
197+
198+
// Addr is part of net.Listener interface.
199+
func (ml *multiListener) Addr() net.Addr {
200+
return ml.addrs[ml.latestAcceptedFDIndex]
201+
}

0 commit comments

Comments
 (0)