diff --git a/device/allowedips.go b/device/allowedips.go index fa46f97c1..682d3d949 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -205,14 +205,14 @@ func (node *trieEntry) lookup(ip []byte) *Peer { } type AllowedIPs struct { - IPv4 *trieEntry - IPv6 *trieEntry - mutex sync.RWMutex + mu sync.RWMutex + ipv4 *trieEntry + ipv6 *trieEntry } func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) { - table.mutex.RLock() - defer table.mutex.RUnlock() + table.mu.RLock() + defer table.mu.RUnlock() for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { node := elem.Value.(*trieEntry) @@ -223,10 +223,23 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) } } +func (table *AllowedIPs) SetPeerPrefixes(peer *Peer, prefixes []netip.Prefix) { + table.mu.Lock() + defer table.mu.Unlock() + + table.removeByPeerLocked(peer) + for _, prefix := range prefixes { + table.insertLocked(prefix, peer) + } +} + func (table *AllowedIPs) RemoveByPeer(peer *Peer) { - table.mutex.Lock() - defer table.mutex.Unlock() + table.mu.Lock() + defer table.mu.Unlock() + table.removeByPeerLocked(peer) +} +func (table *AllowedIPs) removeByPeerLocked(peer *Peer) { var next *list.Element for elem := peer.trieEntries.Front(); elem != nil; elem = next { next = elem.Next() @@ -266,28 +279,31 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { } func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) { - table.mutex.Lock() - defer table.mutex.Unlock() + table.mu.Lock() + defer table.mu.Unlock() + table.insertLocked(prefix, peer) +} +func (table *AllowedIPs) insertLocked(prefix netip.Prefix, peer *Peer) { if prefix.Addr().Is6() { ip := prefix.Addr().As16() - parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + parentIndirection{&table.ipv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer) } else if prefix.Addr().Is4() { ip := prefix.Addr().As4() - parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + parentIndirection{&table.ipv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer) } else { panic(errors.New("inserting unknown address type")) } } func (table *AllowedIPs) Lookup(ip []byte) *Peer { - table.mutex.RLock() - defer table.mutex.RUnlock() + table.mu.RLock() + defer table.mu.RUnlock() switch len(ip) { case net.IPv6len: - return table.IPv6.lookup(ip) + return table.ipv6.lookup(ip) case net.IPv4len: - return table.IPv4.lookup(ip) + return table.ipv4.lookup(ip) default: panic(errors.New("looking up unknown address type")) } diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go index 07065c30a..d43d61657 100644 --- a/device/allowedips_rand_test.go +++ b/device/allowedips_rand_test.go @@ -135,7 +135,7 @@ func TestTrieRandom(t *testing.T) { allowedIPs.RemoveByPeer(peers[p]) } - if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { + if allowedIPs.ipv4 != nil || allowedIPs.ipv6 != nil { t.Error("Failed to remove all nodes from trie by peer") } } diff --git a/device/allowedips_test.go b/device/allowedips_test.go index cde068ec3..04ae9cb43 100644 --- a/device/allowedips_test.go +++ b/device/allowedips_test.go @@ -166,7 +166,7 @@ func TestTrieIPv4(t *testing.T) { allowedIPs.RemoveByPeer(e) allowedIPs.RemoveByPeer(g) allowedIPs.RemoveByPeer(h) - if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { + if allowedIPs.ipv4 != nil || allowedIPs.ipv6 != nil { t.Error("Expected removing all the peers to empty trie, but it did not") } diff --git a/device/device.go b/device/device.go index 5b2348564..dfd95ac0d 100644 --- a/device/device.go +++ b/device/device.go @@ -6,6 +6,8 @@ package device import ( + "errors" + "net/netip" "runtime" "sync" "sync/atomic" @@ -56,6 +58,7 @@ type Device struct { peers struct { sync.RWMutex // protects keyMap keyMap map[NoisePublicKey]*Peer + lookupFunc PeerLookupFunc // or nil if unused } rate struct { @@ -91,6 +94,10 @@ type Device struct { log *Logger } +func (device *Device) AllowedIPs() *AllowedIPs { + return &device.allowedips +} + // deviceState represents the state of a Device. // There are three states: down, up, closed. // Transitions: @@ -338,13 +345,63 @@ func (device *Device) BatchSize() int { return size } +// LookupPeer looks up a peer by its public key. +// +// If the peer does not exist and a [PeerLookupFunc] is set (via +// [Device.SetPeerLookupFunc]), then that function is used to create the peer +// before returning it. Peers created via this mechanism exist only their state +// machine reaches idle, and then the peers are removed. +// +// If the peer does not exist and no [PeerLookupFunc] is set, nil is returned. +// +// Use [Device.LookupActivePeer] to only return already-existing peers, without +// using a [PeerLookupFunc]. func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { device.peers.RLock() - defer device.peers.RUnlock() + p, ok := device.peers.keyMap[pk] + lookupFunc := device.peers.lookupFunc + device.peers.RUnlock() + if ok || lookupFunc == nil { + return p + } + + allowedIPs := lookupFunc(pk) + if allowedIPs == nil { + return nil + } + + p, err := device.NewPeer(pk) + if err != nil { + if errors.Is(err, errAddExistingPeer) { + device.peers.RLock() + defer device.peers.RUnlock() + return device.peers.keyMap[pk] + } + device.log.Errorf("Failed to create peer: %v", err) + return nil + } + p.SetAllowedIPs(allowedIPs) + p.deleteOnIdle = true + p.Start() + return p +} - return device.peers.keyMap[pk] +// LookupActivePeer looks up a peer by its public key. +// +// Unlike [Device.LookupPeer], this function does not use a [PeerLookupFunc] to +// create the peer if it does not already exist. +// +// If the peer does not exist or was created lazily via [PeerLookupFunc] +// and has subsequently idled away, it returns (nil, false). +func (device *Device) LookupActivePeer(pk NoisePublicKey) (_ *Peer, ok bool) { + device.peers.RLock() + defer device.peers.RUnlock() + p, ok := device.peers.keyMap[pk] + return p, ok } +var errAddExistingPeer = errors.New("adding existing peer") + func (device *Device) RemovePeer(key NoisePublicKey) { device.peers.Lock() defer device.peers.Unlock() @@ -367,6 +424,41 @@ func (device *Device) RemoveAllPeers() { device.peers.keyMap = make(map[NoisePublicKey]*Peer) } +// RemoveMatchingPeers removes all peers for which shouldRemove returns true. +// +// It returns the number of peers removed. +func (device *Device) RemoveMatchingPeers(shouldRemove func(NoisePublicKey) bool) (numRemoved int) { + device.peers.Lock() + defer device.peers.Unlock() + + for key, peer := range device.peers.keyMap { + if shouldRemove(key) { + removePeerLocked(device, peer, key) + numRemoved++ + } + } + return numRemoved +} + +// PeerLookupFunc is the type of function used to look up peers by public key +// when receiving packets for unknown peers. +// +// If it returns nil, the peer is not known. +// +// Otherwise, returning non-nil signals that wireguard-go should create the peer +// with the provided allowed IPs. +// +// See [Device.SetPeerLookupFunc] and [Device.LookupPeer]. +type PeerLookupFunc func(NoisePublicKey) (allowedIPs []netip.Prefix) + +// SetPeerLookupFunc sets the function used to look up peers by public key +// when receiving packets for unknown peers. +func (device *Device) SetPeerLookupFunc(f PeerLookupFunc) { + device.peers.Lock() + defer device.peers.Unlock() + device.peers.lookupFunc = f +} + func (device *Device) Close() { device.state.Lock() defer device.state.Unlock() diff --git a/device/peer.go b/device/peer.go index 064feb22b..181f328de 100644 --- a/device/peer.go +++ b/device/peer.go @@ -8,6 +8,8 @@ package device import ( "container/list" "errors" + "net/netip" + "slices" "sync" "sync/atomic" "time" @@ -25,6 +27,12 @@ type Peer struct { rxBytes atomic.Uint64 // bytes received from peer lastHandshakeNano atomic.Int64 // nano seconds since epoch + // deleteOnIdle indicates whether the peer should be deleted when idle + // because it was auto-created via a Device.PeerLookupFunc. + // + // This field should only be set once, before the peer is started. + deleteOnIdle bool + endpoint struct { sync.Mutex val conn.Endpoint @@ -44,7 +52,9 @@ type Peer struct { } state struct { - sync.Mutex // protects against concurrent Start/Stop + sync.Mutex // protects against concurrent Start/Stop, and fields below + + allowedIPs []netip.Prefix } queue struct { @@ -87,7 +97,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { // map public key _, ok := device.peers.keyMap[pk] if ok { - return nil, errors.New("adding existing peer") + return nil, errAddExistingPeer } // pre-compute DH @@ -113,6 +123,18 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return peer, nil } +// SetAllowedIPs sets the allowed IP prefixes for this peer. +func (p *Peer) SetAllowedIPs(allowedIPs []netip.Prefix) { + p.state.Lock() + defer p.state.Unlock() + + if slices.Equal(p.state.allowedIPs, allowedIPs) { + return + } + p.device.allowedips.SetPeerPrefixes(p, allowedIPs) + p.state.allowedIPs = slices.Clone(allowedIPs) // avoid retaining caller's slice +} + // SendBuffers sends buffers to peer. WireGuard packet data in each element of // buffers must be preceded by MessageEncapsulatingTransportSize number of // bytes. diff --git a/device/timers.go b/device/timers.go index d4a4ed4e5..dcfce85c6 100644 --- a/device/timers.go +++ b/device/timers.go @@ -8,6 +8,7 @@ package device import ( + "log" "sync" "time" _ "unsafe" @@ -126,6 +127,13 @@ func expiredNewHandshake(peer *Peer) { func expiredZeroKeyMaterial(peer *Peer) { peer.device.log.Verbosef("%s - Removing all keys, since we haven't received a new one in %d seconds", peer, int((RejectAfterTime * 3).Seconds())) peer.ZeroAndFlushAll() + if peer.deleteOnIdle { + toRemove := peer.handshake.remoteStatic + go func() { + peer.device.RemovePeer(toRemove) + log.Printf("expiredZeroKeyMaterial: removed idle lazy peer %x", toRemove) + }() + } } func expiredPersistentKeepalive(peer *Peer) { diff --git a/go.mod b/go.mod index 9c9b02a66..37b0b6daf 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/tailscale/wireguard-go -go 1.20 +go 1.25 require ( golang.org/x/crypto v0.13.0