Skip to content

Commit 6bd4edf

Browse files
aroradamanDaman Arora
authored and
Daman Arora
committed
net: add multi listener impl for net.Listener
This adds an implementation of net.Listener which listens on and accepts connections from multiple addresses. Signed-off-by: Daman Arora <[email protected]>
1 parent fe8a2dd commit 6bd4edf

File tree

3 files changed

+325
-0
lines changed

3 files changed

+325
-0
lines changed

net/multi_listen.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
/*
2+
Copyright 2024 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package net
18+
19+
import (
20+
"fmt"
21+
"net"
22+
"syscall"
23+
)
24+
25+
// multiListener implements net.Listener and uses multiplexing to listen to and accept
26+
// TCP connections from multiple addresses.
27+
type multiListener struct {
28+
latestAcceptedFDIndex int
29+
fds []int
30+
addrs []net.Addr
31+
stopCh chan struct{}
32+
}
33+
34+
// compile time check to ensure *multiListener implements net.Listener.
35+
var _ net.Listener = &multiListener{}
36+
37+
// NewMultiListener returns *multiListener as net.Listener allowing consumers to
38+
// listen for TCP connections on multiple addresses.
39+
func NewMultiListener(addresses []string) (net.Listener, error) {
40+
ml := &multiListener{
41+
stopCh: make(chan struct{}),
42+
}
43+
for _, address := range addresses {
44+
fd, addr, err := createBindAndListen(address)
45+
if err != nil {
46+
return nil, err
47+
}
48+
ml.fds = append(ml.fds, fd)
49+
ml.addrs = append(ml.addrs, addr)
50+
}
51+
return ml, nil
52+
}
53+
54+
// Accept is part of net.Listener interface.
55+
func (ml *multiListener) Accept() (net.Conn, error) {
56+
return ml.accept()
57+
}
58+
59+
// Close is part of net.Listener interface.
60+
func (ml *multiListener) Close() error {
61+
close(ml.stopCh)
62+
for _, fd := range ml.fds {
63+
_ = syscall.Close(fd)
64+
}
65+
return nil
66+
}
67+
68+
// Addr is part of net.Listener interface.
69+
func (ml *multiListener) Addr() net.Addr {
70+
return ml.addrs[ml.latestAcceptedFDIndex]
71+
}
72+
73+
// createBindAndListen creates a TCP socket, binds it to the specified address, and starts listening on it.
74+
func createBindAndListen(address string) (int, net.Addr, error) {
75+
host, _, err := net.SplitHostPort(address)
76+
if err != nil {
77+
return -1, nil, err
78+
}
79+
80+
ipFamily := IPFamilyOf(ParseIPSloppy(host))
81+
var network string
82+
var domain int
83+
switch ipFamily {
84+
case IPv4:
85+
network = "tcp4"
86+
domain = syscall.AF_INET
87+
case IPv6:
88+
network = "tcp6"
89+
domain = syscall.AF_INET6
90+
default:
91+
return -1, nil, fmt.Errorf("failed to identify ip family of host '%s'", host)
92+
93+
}
94+
95+
// resolve tcp addr
96+
addr, err := net.ResolveTCPAddr(network, address)
97+
if err != nil {
98+
return -1, nil, err
99+
}
100+
101+
// create socket
102+
fd, err := syscall.Socket(domain, syscall.SOCK_STREAM, 0)
103+
if err != nil {
104+
return -1, nil, err
105+
}
106+
107+
// define socket address for bind
108+
var sockAddr syscall.Sockaddr
109+
if ipFamily == IPv4 {
110+
var ipBytes [4]byte
111+
copy(ipBytes[:], addr.IP.To4())
112+
sockAddr = &syscall.SockaddrInet4{
113+
Addr: ipBytes,
114+
Port: addr.Port,
115+
}
116+
} else {
117+
var ipBytes [16]byte
118+
copy(ipBytes[:], addr.IP.To16())
119+
sockAddr = &syscall.SockaddrInet6{
120+
Addr: ipBytes,
121+
Port: addr.Port,
122+
}
123+
}
124+
125+
// bind socket to specified addr
126+
if err = syscall.Bind(fd, sockAddr); err != nil {
127+
_ = syscall.Close(fd)
128+
return -1, nil, err
129+
}
130+
131+
// start listening on socket
132+
if err = syscall.Listen(fd, syscall.SOMAXCONN); err != nil {
133+
_ = syscall.Close(fd)
134+
return -1, nil, err
135+
}
136+
137+
return fd, addr, nil
138+
}

net/multi_listen_darwin.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
//go:build darwin
2+
// +build darwin
3+
4+
/*
5+
Copyright 2024 The Kubernetes Authors.
6+
7+
Licensed under the Apache License, Version 2.0 (the "License");
8+
you may not use this file except in compliance with the License.
9+
You may obtain a copy of the License at
10+
11+
http://www.apache.org/licenses/LICENSE-2.0
12+
13+
Unless required by applicable law or agreed to in writing, software
14+
distributed under the License is distributed on an "AS IS" BASIS,
15+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
See the License for the specific language governing permissions and
17+
limitations under the License.
18+
*/
19+
20+
package net
21+
22+
import (
23+
"fmt"
24+
"net"
25+
"os"
26+
"syscall"
27+
)
28+
29+
// Accept is part of net.Listener interface.
30+
func (ml *multiListener) accept() (net.Conn, error) {
31+
for {
32+
readFds := &syscall.FdSet{}
33+
maxfd := 0
34+
35+
for _, fd := range ml.fds {
36+
if fd > maxfd {
37+
maxfd = fd
38+
}
39+
addFDToFDSet(fd, readFds)
40+
}
41+
42+
// wait for any of the sockets to be ready for accepting new connection
43+
timeout := syscall.Timeval{Sec: 1, Usec: 0}
44+
err := syscall.Select(maxfd+1, readFds, nil, nil, &timeout)
45+
if err != nil {
46+
return nil, err
47+
}
48+
49+
for i, fd := range ml.fds {
50+
if isFDInFDSet(fd, readFds) {
51+
conn, err := acceptConnection(fd)
52+
if err != nil {
53+
return nil, err
54+
}
55+
ml.latestAcceptedFDIndex = i
56+
return conn, nil
57+
}
58+
}
59+
60+
select {
61+
case <-ml.stopCh:
62+
return nil, fmt.Errorf("multiListener closed")
63+
default:
64+
continue
65+
}
66+
}
67+
}
68+
69+
// addFDToFDSet adds fd to the given fd set
70+
func addFDToFDSet(fd int, p *syscall.FdSet) {
71+
mask := 1 << (uint(fd) % syscall.FD_SETSIZE)
72+
p.Bits[fd/syscall.FD_SETSIZE] |= int32(mask)
73+
}
74+
75+
// isFDInFDSet returns true if fd is in fd set, false otherwise
76+
func isFDInFDSet(fd int, p *syscall.FdSet) bool {
77+
mask := 1 << (uint(fd) % syscall.FD_SETSIZE)
78+
return p.Bits[fd/syscall.FD_SETSIZE]&int32(mask) != 0
79+
}
80+
81+
// acceptConnection accepts connection and returns remote connection object
82+
func acceptConnection(fd int) (net.Conn, error) {
83+
connFD, _, err := syscall.Accept(fd)
84+
if err != nil {
85+
return nil, err
86+
}
87+
88+
conn, err := net.FileConn(os.NewFile(uintptr(connFD), fmt.Sprintf("fd %d", connFD)))
89+
if err != nil {
90+
_ = syscall.Close(connFD)
91+
return nil, err
92+
}
93+
return conn, nil
94+
}

net/multi_listen_linux.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
//go:build linux
2+
// +build linux
3+
4+
/*
5+
Copyright 2024 The Kubernetes Authors.
6+
7+
Licensed under the Apache License, Version 2.0 (the "License");
8+
you may not use this file except in compliance with the License.
9+
You may obtain a copy of the License at
10+
11+
http://www.apache.org/licenses/LICENSE-2.0
12+
13+
Unless required by applicable law or agreed to in writing, software
14+
distributed under the License is distributed on an "AS IS" BASIS,
15+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
See the License for the specific language governing permissions and
17+
limitations under the License.
18+
*/
19+
20+
package net
21+
22+
import (
23+
"fmt"
24+
"net"
25+
"os"
26+
"syscall"
27+
)
28+
29+
func (ml *multiListener) accept() (net.Conn, error) {
30+
for {
31+
readFds := &syscall.FdSet{}
32+
maxfd := 0
33+
34+
for _, fd := range ml.fds {
35+
if fd > maxfd {
36+
maxfd = fd
37+
}
38+
addFDToFDSet(fd, readFds)
39+
}
40+
41+
// wait for any of the sockets to be ready for accepting new connection
42+
timeout := syscall.Timeval{Sec: 1, Usec: 0}
43+
n, err := syscall.Select(maxfd+1, readFds, nil, nil, &timeout)
44+
if err != nil {
45+
return nil, err
46+
}
47+
if n == 0 {
48+
select {
49+
case <-ml.stopCh:
50+
return nil, fmt.Errorf("multiListener closed")
51+
default:
52+
continue
53+
}
54+
}
55+
for i, fd := range ml.fds {
56+
if isFDInFDSet(fd, readFds) {
57+
conn, err := acceptConnection(fd)
58+
if err != nil {
59+
return nil, err
60+
}
61+
ml.latestAcceptedFDIndex = i
62+
return conn, nil
63+
}
64+
}
65+
}
66+
}
67+
68+
// addFDToFDSet adds fd to the given fd set
69+
func addFDToFDSet(fd int, p *syscall.FdSet) {
70+
mask := 1 << (uint(fd) % syscall.FD_SETSIZE)
71+
p.Bits[fd/syscall.FD_SETSIZE] |= int64(mask)
72+
}
73+
74+
// isFDInFDSet returns true if fd is in fd set, false otherwise
75+
func isFDInFDSet(fd int, p *syscall.FdSet) bool {
76+
mask := 1 << (uint(fd) % syscall.FD_SETSIZE)
77+
return p.Bits[fd/syscall.FD_SETSIZE]&int64(mask) != 0
78+
}
79+
80+
// acceptConnection accepts connection and returns remote connection object
81+
func acceptConnection(fd int) (net.Conn, error) {
82+
connFD, _, err := syscall.Accept(fd)
83+
if err != nil {
84+
return nil, err
85+
}
86+
87+
conn, err := net.FileConn(os.NewFile(uintptr(connFD), fmt.Sprintf("fd %d", connFD)))
88+
if err != nil {
89+
_ = syscall.Close(connFD)
90+
return nil, err
91+
}
92+
return conn, nil
93+
}

0 commit comments

Comments
 (0)