66package device
77
88import (
9+ "errors"
10+ "net/netip"
911 "runtime"
1012 "sync"
1113 "sync/atomic"
@@ -56,6 +58,7 @@ type Device struct {
5658 peers struct {
5759 sync.RWMutex // protects keyMap
5860 keyMap map [NoisePublicKey ]* Peer
61+ lookupFunc PeerLookupFunc // or nil if unused
5962 }
6063
6164 rate struct {
@@ -91,6 +94,10 @@ type Device struct {
9194 log * Logger
9295}
9396
97+ func (device * Device ) AllowedIPs () * AllowedIPs {
98+ return & device .allowedips
99+ }
100+
94101// deviceState represents the state of a Device.
95102// There are three states: down, up, closed.
96103// Transitions:
@@ -340,11 +347,36 @@ func (device *Device) BatchSize() int {
340347
341348func (device * Device ) LookupPeer (pk NoisePublicKey ) * Peer {
342349 device .peers .RLock ()
343- defer device .peers .RUnlock ()
350+ p , ok := device .peers .keyMap [pk ]
351+ lookupFunc := device .peers .lookupFunc
352+ device .peers .RUnlock ()
353+ if ok || lookupFunc == nil {
354+ return p
355+ }
344356
345- return device .peers .keyMap [pk ]
357+ allowedIPs := lookupFunc (pk )
358+ if allowedIPs == nil {
359+ return nil
360+ }
361+
362+ p , err := device .NewPeer (pk )
363+ if err != nil {
364+ if errors .Is (err , errAddExistingPeer ) {
365+ device .peers .RLock ()
366+ defer device .peers .RUnlock ()
367+ return device .peers .keyMap [pk ]
368+ }
369+ device .log .Errorf ("Failed to create peer: %v" , err )
370+ return nil
371+ }
372+ p .SetAllowedIPs (allowedIPs )
373+ p .deleteOnIdle = true
374+ p .Start ()
375+ return p
346376}
347377
378+ var errAddExistingPeer = errors .New ("adding existing peer" )
379+
348380func (device * Device ) RemovePeer (key NoisePublicKey ) {
349381 device .peers .Lock ()
350382 defer device .peers .Unlock ()
@@ -367,6 +399,39 @@ func (device *Device) RemoveAllPeers() {
367399 device .peers .keyMap = make (map [NoisePublicKey ]* Peer )
368400}
369401
402+ // RemoveMatchingPeers removes all peers for which shouldRemove returns true.
403+ //
404+ // It returns the number of peers removed.
405+ func (device * Device ) RemoveMatchingPeers (shouldRemove func (NoisePublicKey ) bool ) (numRemoved int ) {
406+ device .peers .Lock ()
407+ defer device .peers .Unlock ()
408+
409+ for key , peer := range device .peers .keyMap {
410+ if shouldRemove (key ) {
411+ removePeerLocked (device , peer , key )
412+ numRemoved ++
413+ }
414+ }
415+ return numRemoved
416+ }
417+
418+ // PeerLookupFunc is the type of function used to look up peers by public key
419+ // when receiving packets for unknown peers.
420+ //
421+ // If it returns nil, the peer is not known.
422+ //
423+ // Otherwise, returning non-nil signals that wireguard-go should create the peer
424+ // with the provided allowed IPs.
425+ type PeerLookupFunc func (NoisePublicKey ) (allowedIPs []netip.Prefix )
426+
427+ // SetPeerLookupFunc sets the function used to look up peers by public key
428+ // when receiving packets for unknown peers.
429+ func (device * Device ) SetPeerLookupFunc (f PeerLookupFunc ) {
430+ device .peers .Lock ()
431+ defer device .peers .Unlock ()
432+ device .peers .lookupFunc = f
433+ }
434+
370435func (device * Device ) Close () {
371436 device .state .Lock ()
372437 defer device .state .Unlock ()
0 commit comments