Skip to content

Commit 0d98132

Browse files
committed
net: add multi listener impl for net.Listener
This adds an implementation of net.Listener which listens on and accepts connections from multiple addresses. Signed-off-by: Daman Arora <[email protected]>
1 parent fe8a2dd commit 0d98132

File tree

2 files changed

+421
-0
lines changed

2 files changed

+421
-0
lines changed

net/multi_listen.go

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/*
2+
Copyright 2024 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package net
18+
19+
import (
20+
"fmt"
21+
"net"
22+
"sync"
23+
"sync/atomic"
24+
)
25+
26+
// connAndErrPair packs results of accept call on actual listeners together along with the
27+
// index of the listener. It is used for communication between main and listener goroutines.
28+
type connAndErrPair struct {
29+
index int
30+
conn net.Conn
31+
err error
32+
}
33+
34+
// multiListener implements net.Listener
35+
type multiListener struct {
36+
latestAcceptedIndex atomic.Int32
37+
addrs []net.Addr
38+
listeners []net.Listener
39+
40+
wg sync.WaitGroup
41+
connAndErrCh chan connAndErrPair
42+
stopCh chan struct{}
43+
}
44+
45+
// compile time check to ensure *multiListener implements net.Listener.
46+
var _ net.Listener = &multiListener{}
47+
48+
// MultiListen returns net.Listener which can listen for and accept
49+
// TCP connections on multiple addresses.
50+
func MultiListen(addresses []string) (net.Listener, error) {
51+
return multiListen(addresses, net.ResolveTCPAddr, net.Listen)
52+
}
53+
54+
func multiListen(
55+
addresses []string,
56+
resolveTCPAddrFunc func(network, address string) (*net.TCPAddr, error),
57+
listenFunc func(network, address string) (net.Listener, error),
58+
) (net.Listener, error) {
59+
ml := &multiListener{
60+
connAndErrCh: make(chan connAndErrPair),
61+
stopCh: make(chan struct{}),
62+
}
63+
64+
for _, address := range addresses {
65+
addr, err := resolveTCPAddrFunc("tcp", address)
66+
if err != nil {
67+
// close all listeners
68+
_ = ml.Close()
69+
return nil, err
70+
}
71+
72+
var network string
73+
host, _, err := net.SplitHostPort(addr.String())
74+
if err != nil {
75+
// close all listeners
76+
_ = ml.Close()
77+
return nil, err
78+
}
79+
switch IPFamilyOf(ParseIPSloppy(host)) {
80+
case IPv4:
81+
network = "tcp4"
82+
case IPv6:
83+
network = "tcp6"
84+
default:
85+
// close all listeners
86+
_ = ml.Close()
87+
return nil, fmt.Errorf("failed to identify ip family of address '%s", addr.String())
88+
}
89+
90+
l, err := listenFunc(network, addr.String())
91+
if err != nil {
92+
// close all listeners
93+
_ = ml.Close()
94+
return nil, err
95+
}
96+
97+
ml.addrs = append(ml.addrs, addr)
98+
ml.listeners = append(ml.listeners, l)
99+
}
100+
101+
for i := range ml.listeners {
102+
index := i
103+
ml.wg.Add(1)
104+
// spawn a go routine for every listener to wait for incoming connection requests
105+
go func() {
106+
defer ml.wg.Done()
107+
for {
108+
conn, err := ml.listeners[index].Accept()
109+
if err != nil {
110+
select {
111+
// Accept() will throw "use of closed network connection" when listener is closed,
112+
// we can ignore that error and break out of this goroutine.
113+
case <-ml.stopCh:
114+
return
115+
default:
116+
ml.connAndErrCh <- connAndErrPair{conn: conn, err: err, index: index}
117+
}
118+
}
119+
ml.connAndErrCh <- connAndErrPair{conn: conn, err: err, index: index}
120+
}
121+
}()
122+
}
123+
return ml, nil
124+
}
125+
126+
// Accept is part of net.Listener interface.
127+
func (ml *multiListener) Accept() (net.Conn, error) {
128+
connAndErr := <-ml.connAndErrCh
129+
// update latestAcceptedIndex with index of the listener which accepted the connection
130+
ml.latestAcceptedIndex.Store(int32(connAndErr.index))
131+
return connAndErr.conn, connAndErr.err
132+
}
133+
134+
// Close is part of net.Listener interface.
135+
func (ml *multiListener) Close() error {
136+
close(ml.stopCh)
137+
for i := range ml.listeners {
138+
_ = ml.listeners[i].Close()
139+
}
140+
ml.wg.Wait()
141+
return nil
142+
}
143+
144+
// Addr is part of net.Listener interface.
145+
func (ml *multiListener) Addr() net.Addr {
146+
return ml.addrs[ml.latestAcceptedIndex.Load()]
147+
}

0 commit comments

Comments
 (0)