Skip to content

Commit d9176a1

Browse files
committed
device: add API for on-demand configuration of peers
Updates tailscale/tailscale#17858 Signed-off-by: Brad Fitzpatrick <[email protected]>
1 parent 1d0488a commit d9176a1

File tree

7 files changed

+133
-22
lines changed

7 files changed

+133
-22
lines changed

device/allowedips.go

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -205,14 +205,14 @@ func (node *trieEntry) lookup(ip []byte) *Peer {
205205
}
206206

207207
type AllowedIPs struct {
208-
IPv4 *trieEntry
209-
IPv6 *trieEntry
210-
mutex sync.RWMutex
208+
mu sync.RWMutex
209+
ipv4 *trieEntry
210+
ipv6 *trieEntry
211211
}
212212

213213
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
214-
table.mutex.RLock()
215-
defer table.mutex.RUnlock()
214+
table.mu.RLock()
215+
defer table.mu.RUnlock()
216216

217217
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
218218
node := elem.Value.(*trieEntry)
@@ -223,10 +223,23 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix)
223223
}
224224
}
225225

226+
func (table *AllowedIPs) SetPeerPrefixes(peer *Peer, prefixes []netip.Prefix) {
227+
table.mu.Lock()
228+
defer table.mu.Unlock()
229+
230+
table.removeByPeerLocked(peer)
231+
for _, prefix := range prefixes {
232+
table.insertLocked(prefix, peer)
233+
}
234+
}
235+
226236
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
227-
table.mutex.Lock()
228-
defer table.mutex.Unlock()
237+
table.mu.Lock()
238+
defer table.mu.Unlock()
239+
table.removeByPeerLocked(peer)
240+
}
229241

242+
func (table *AllowedIPs) removeByPeerLocked(peer *Peer) {
230243
var next *list.Element
231244
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
232245
next = elem.Next()
@@ -266,28 +279,31 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
266279
}
267280

268281
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
269-
table.mutex.Lock()
270-
defer table.mutex.Unlock()
282+
table.mu.Lock()
283+
defer table.mu.Unlock()
284+
table.insertLocked(prefix, peer)
285+
}
271286

287+
func (table *AllowedIPs) insertLocked(prefix netip.Prefix, peer *Peer) {
272288
if prefix.Addr().Is6() {
273289
ip := prefix.Addr().As16()
274-
parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
290+
parentIndirection{&table.ipv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
275291
} else if prefix.Addr().Is4() {
276292
ip := prefix.Addr().As4()
277-
parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
293+
parentIndirection{&table.ipv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
278294
} else {
279295
panic(errors.New("inserting unknown address type"))
280296
}
281297
}
282298

283299
func (table *AllowedIPs) Lookup(ip []byte) *Peer {
284-
table.mutex.RLock()
285-
defer table.mutex.RUnlock()
300+
table.mu.RLock()
301+
defer table.mu.RUnlock()
286302
switch len(ip) {
287303
case net.IPv6len:
288-
return table.IPv6.lookup(ip)
304+
return table.ipv6.lookup(ip)
289305
case net.IPv4len:
290-
return table.IPv4.lookup(ip)
306+
return table.ipv4.lookup(ip)
291307
default:
292308
panic(errors.New("looking up unknown address type"))
293309
}

device/allowedips_rand_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ func TestTrieRandom(t *testing.T) {
135135
allowedIPs.RemoveByPeer(peers[p])
136136
}
137137

138-
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
138+
if allowedIPs.ipv4 != nil || allowedIPs.ipv6 != nil {
139139
t.Error("Failed to remove all nodes from trie by peer")
140140
}
141141
}

device/allowedips_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ func TestTrieIPv4(t *testing.T) {
166166
allowedIPs.RemoveByPeer(e)
167167
allowedIPs.RemoveByPeer(g)
168168
allowedIPs.RemoveByPeer(h)
169-
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
169+
if allowedIPs.ipv4 != nil || allowedIPs.ipv6 != nil {
170170
t.Error("Expected removing all the peers to empty trie, but it did not")
171171
}
172172

device/device.go

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
package device
77

88
import (
9+
"errors"
10+
"net/netip"
911
"runtime"
1012
"sync"
1113
"sync/atomic"
@@ -56,6 +58,7 @@ type Device struct {
5658
peers struct {
5759
sync.RWMutex // protects keyMap
5860
keyMap map[NoisePublicKey]*Peer
61+
lookupFunc PeerLookupFunc // or nil if unused
5962
}
6063

6164
rate struct {
@@ -91,6 +94,10 @@ type Device struct {
9194
log *Logger
9295
}
9396

97+
func (device *Device) AllowedIPs() *AllowedIPs {
98+
return &device.allowedips
99+
}
100+
94101
// deviceState represents the state of a Device.
95102
// There are three states: down, up, closed.
96103
// Transitions:
@@ -340,11 +347,36 @@ func (device *Device) BatchSize() int {
340347

341348
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
342349
device.peers.RLock()
343-
defer device.peers.RUnlock()
350+
p, ok := device.peers.keyMap[pk]
351+
lookupFunc := device.peers.lookupFunc
352+
device.peers.RUnlock()
353+
if ok || lookupFunc == nil {
354+
return p
355+
}
344356

345-
return device.peers.keyMap[pk]
357+
allowedIPs := lookupFunc(pk)
358+
if allowedIPs == nil {
359+
return nil
360+
}
361+
362+
p, err := device.NewPeer(pk)
363+
if err != nil {
364+
if errors.Is(err, errAddExistingPeer) {
365+
device.peers.RLock()
366+
defer device.peers.RUnlock()
367+
return device.peers.keyMap[pk]
368+
}
369+
device.log.Errorf("Failed to create peer: %v", err)
370+
return nil
371+
}
372+
p.SetAllowedIPs(allowedIPs)
373+
p.deleteOnIdle = true
374+
p.Start()
375+
return p
346376
}
347377

378+
var errAddExistingPeer = errors.New("adding existing peer")
379+
348380
func (device *Device) RemovePeer(key NoisePublicKey) {
349381
device.peers.Lock()
350382
defer device.peers.Unlock()
@@ -367,6 +399,39 @@ func (device *Device) RemoveAllPeers() {
367399
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
368400
}
369401

402+
// RemoveMatchingPeers removes all peers for which shouldRemove returns true.
403+
//
404+
// It returns the number of peers removed.
405+
func (device *Device) RemoveMatchingPeers(shouldRemove func(NoisePublicKey) bool) (numRemoved int) {
406+
device.peers.Lock()
407+
defer device.peers.Unlock()
408+
409+
for key, peer := range device.peers.keyMap {
410+
if shouldRemove(key) {
411+
removePeerLocked(device, peer, key)
412+
numRemoved++
413+
}
414+
}
415+
return numRemoved
416+
}
417+
418+
// PeerLookupFunc is the type of function used to look up peers by public key
419+
// when receiving packets for unknown peers.
420+
//
421+
// If it returns nil, the peer is not known.
422+
//
423+
// Otherwise, returning non-nil signals that wireguard-go should create the peer
424+
// with the provided allowed IPs.
425+
type PeerLookupFunc func(NoisePublicKey) (allowedIPs []netip.Prefix)
426+
427+
// SetPeerLookupFunc sets the function used to look up peers by public key
428+
// when receiving packets for unknown peers.
429+
func (device *Device) SetPeerLookupFunc(f PeerLookupFunc) {
430+
device.peers.Lock()
431+
defer device.peers.Unlock()
432+
device.peers.lookupFunc = f
433+
}
434+
370435
func (device *Device) Close() {
371436
device.state.Lock()
372437
defer device.state.Unlock()

device/peer.go

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ package device
88
import (
99
"container/list"
1010
"errors"
11+
"net/netip"
12+
"slices"
1113
"sync"
1214
"sync/atomic"
1315
"time"
@@ -25,6 +27,12 @@ type Peer struct {
2527
rxBytes atomic.Uint64 // bytes received from peer
2628
lastHandshakeNano atomic.Int64 // nano seconds since epoch
2729

30+
// deleteOnIdle indicates whether the peer should be deleted when idle
31+
// because it was auto-created via a Device.PeerLookupFunc.
32+
//
33+
// This field should only be set once, before the peer is started.
34+
deleteOnIdle bool
35+
2836
endpoint struct {
2937
sync.Mutex
3038
val conn.Endpoint
@@ -44,7 +52,9 @@ type Peer struct {
4452
}
4553

4654
state struct {
47-
sync.Mutex // protects against concurrent Start/Stop
55+
sync.Mutex // protects against concurrent Start/Stop, and fields below
56+
57+
allowedIPs []netip.Prefix
4858
}
4959

5060
queue struct {
@@ -87,7 +97,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
8797
// map public key
8898
_, ok := device.peers.keyMap[pk]
8999
if ok {
90-
return nil, errors.New("adding existing peer")
100+
return nil, errAddExistingPeer
91101
}
92102

93103
// pre-compute DH
@@ -113,6 +123,18 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
113123
return peer, nil
114124
}
115125

126+
// SetAllowedIPs sets the allowed IP prefixes for this peer.
127+
func (p *Peer) SetAllowedIPs(allowedIPs []netip.Prefix) {
128+
p.state.Lock()
129+
defer p.state.Unlock()
130+
131+
if slices.Equal(p.state.allowedIPs, allowedIPs) {
132+
return
133+
}
134+
p.device.allowedips.SetPeerPrefixes(p, allowedIPs)
135+
p.state.allowedIPs = slices.Clone(allowedIPs) // avoid retaining caller's slice
136+
}
137+
116138
// SendBuffers sends buffers to peer. WireGuard packet data in each element of
117139
// buffers must be preceded by MessageEncapsulatingTransportSize number of
118140
// bytes.

device/timers.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package device
99

1010
import (
11+
"log"
1112
"sync"
1213
"time"
1314
_ "unsafe"
@@ -126,6 +127,13 @@ func expiredNewHandshake(peer *Peer) {
126127
func expiredZeroKeyMaterial(peer *Peer) {
127128
peer.device.log.Verbosef("%s - Removing all keys, since we haven't received a new one in %d seconds", peer, int((RejectAfterTime * 3).Seconds()))
128129
peer.ZeroAndFlushAll()
130+
if peer.deleteOnIdle {
131+
toRemove := peer.handshake.remoteStatic
132+
go func() {
133+
peer.device.RemovePeer(toRemove)
134+
log.Printf("XXX expiredZeroKeyMaterial: removed idle peer %x", toRemove)
135+
}()
136+
}
129137
}
130138

131139
func expiredPersistentKeepalive(peer *Peer) {

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module github.com/tailscale/wireguard-go
22

3-
go 1.20
3+
go 1.25
44

55
require (
66
golang.org/x/crypto v0.13.0

0 commit comments

Comments
 (0)