Skip to content

Commit ae2de87

Browse files
committed
fix: ensure dht query is aborted on early exit
If query results are consumed from a `for await..of`-style loop and that loop is exited before the results are complete, ensure we abort any running sub-queries.
1 parent 3687f1e commit ae2de87

File tree

7 files changed

+191
-90
lines changed

7 files changed

+191
-90
lines changed

packages/kad-dht/src/query/manager.ts

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
import { AbortError, TypedEventEmitter, CustomEvent, setMaxListeners } from '@libp2p/interface'
1+
import { TypedEventEmitter, CustomEvent, setMaxListeners } from '@libp2p/interface'
22
import { PeerSet } from '@libp2p/peer-collections'
33
import { anySignal } from 'any-signal'
44
import merge from 'it-merge'
5+
import { raceSignal } from 'race-signal'
56
import { toString as uint8ArrayToString } from 'uint8arrays/to-string'
67
import {
78
ALPHA, K, DEFAULT_QUERY_TIMEOUT
@@ -127,7 +128,16 @@ export class QueryManager implements Startable {
127128
}
128129
}
129130

130-
const signal = anySignal([this.shutDownController.signal, options.signal])
131+
// if the user breaks out of a for await..of loop iterating over query
132+
// results we need to cancel any in-flight network requests
133+
const queryEarlyExitController = new AbortController()
134+
setMaxListeners(Infinity, queryEarlyExitController.signal)
135+
136+
const signal = anySignal([
137+
this.shutDownController.signal,
138+
queryEarlyExitController.signal,
139+
options.signal
140+
])
131141

132142
// this signal will get listened to for every invocation of queryFunc
133143
// so make sure we don't make a lot of noise in the logs
@@ -138,19 +148,13 @@ export class QueryManager implements Startable {
138148
// query a subset of peers up to `kBucketSize / 2` in length
139149
const startTime = Date.now()
140150
const cleanUp = new TypedEventEmitter<CleanUpEvents>()
151+
let queryFinished = false
141152

142153
try {
143154
if (options.isSelfQuery !== true && this.initialQuerySelfHasRun != null) {
144155
log('waiting for initial query-self query before continuing')
145156

146-
await Promise.race([
147-
new Promise((resolve, reject) => {
148-
signal.addEventListener('abort', () => {
149-
reject(new AbortError('Query was aborted before self-query ran'))
150-
})
151-
}),
152-
this.initialQuerySelfHasRun.promise
153-
])
157+
await raceSignal(this.initialQuerySelfHasRun.promise, signal)
154158

155159
this.initialQuerySelfHasRun = undefined
156160
}
@@ -192,19 +196,26 @@ export class QueryManager implements Startable {
192196

193197
// Execute the query along each disjoint path and yield their results as they become available
194198
for await (const event of merge(...paths)) {
195-
yield event
196-
197199
if (event.name === 'QUERY_ERROR') {
198-
log('error', event.error)
200+
log.error('query error', event.error)
199201
}
202+
203+
yield event
200204
}
205+
206+
queryFinished = true
201207
} catch (err: any) {
202208
if (!this.running && err.code === 'ERR_QUERY_ABORTED') {
203209
// ignore query aborted errors that were thrown during query manager shutdown
204210
} else {
205211
throw err
206212
}
207213
} finally {
214+
if (!queryFinished) {
215+
log('query exited early')
216+
queryEarlyExitController.abort()
217+
}
218+
208219
signal.clear()
209220

210221
this.queries--

packages/kad-dht/test/query.spec.ts

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import { QueryManager, type QueryManagerInit } from '../src/query/manager.js'
1919
import { convertBuffer } from '../src/utils.js'
2020
import { createPeerId, createPeerIds } from './utils/create-peer-id.js'
2121
import { sortClosestPeers } from './utils/sort-closest-peers.js'
22-
import type { QueryFunc } from '../src/query/types.js'
22+
import type { QueryContext, QueryFunc } from '../src/query/types.js'
2323
import type { RoutingTable } from '../src/routing-table/index.js'
2424
import type { PeerId } from '@libp2p/interface'
2525

@@ -29,12 +29,9 @@ interface TopologyEntry {
2929
value?: Uint8Array
3030
closerPeers?: number[]
3131
event: QueryEvent
32+
context?: QueryContext
3233
}
33-
type Topology = Record<string, {
34-
delay?: number | undefined
35-
error?: Error | undefined
36-
event: QueryEvent
37-
}>
34+
type Topology = Record<string, TopologyEntry>
3835

3936
describe('QueryManager', () => {
4037
let ourPeerId: PeerId
@@ -55,7 +52,7 @@ describe('QueryManager', () => {
5552
}
5653

5754
function createTopology (opts: Record<number, { delay?: number, error?: Error, value?: Uint8Array, closerPeers?: number[] }>): Topology {
58-
const topology: Record<string, { delay?: number, error?: Error, event: QueryEvent }> = {}
55+
const topology: Topology = {}
5956

6057
Object.keys(opts).forEach(key => {
6158
const id = parseInt(key)
@@ -94,9 +91,12 @@ describe('QueryManager', () => {
9491
return topology
9592
}
9693

97-
function createQueryFunction (topology: Record<string, { delay?: number, event: QueryEvent }>): QueryFunc {
98-
const queryFunc: QueryFunc = async function * ({ peer }) {
94+
function createQueryFunction (topology: Topology): QueryFunc {
95+
const queryFunc: QueryFunc = async function * (context) {
96+
const { peer } = context
97+
9998
const res = topology[peer.toString()]
99+
res.context = context
100100

101101
if (res.delay != null) {
102102
await delay(res.delay)
@@ -870,4 +870,43 @@ describe('QueryManager', () => {
870870

871871
await manager.stop()
872872
})
873+
874+
it('should abort the query if we break out of the loop early', async () => {
875+
const manager = new QueryManager({
876+
peerId: ourPeerId,
877+
logger: defaultLogger()
878+
}, {
879+
...defaultInit(),
880+
disjointPaths: 2
881+
})
882+
await manager.start()
883+
884+
// 1 -> 0 [pathComplete]
885+
// 4 -> 3 [delay] -> 2 [pathComplete]
886+
const topology = createTopology({
887+
// quick value path
888+
0: { value: uint8ArrayFromString('true') },
889+
1: { closerPeers: [0] },
890+
// slow value path
891+
2: { value: uint8ArrayFromString('true') },
892+
3: { delay: 100, closerPeers: [2] },
893+
4: { closerPeers: [3] }
894+
})
895+
896+
routingTable.closestPeers.returns([peers[1], peers[4]])
897+
898+
for await (const event of manager.run(key, createQueryFunction(topology))) {
899+
if (event.name === 'VALUE') {
900+
expect(event.from.toString()).to.equal(peers[0].toString())
901+
902+
// break out of loop early
903+
break
904+
}
905+
}
906+
907+
// should have aborted query on slow path
908+
expect(topology[peers[3].toString()]).to.have.nested.property('context.signal.aborted', true)
909+
910+
await manager.stop()
911+
})
873912
})

packages/libp2p/src/connection-manager/dial-queue.ts

Lines changed: 97 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ import {
1515
MAX_PEER_ADDRS_TO_DIAL,
1616
LAST_DIAL_FAILURE_KEY
1717
} from './constants.js'
18-
import { combineSignals, resolveMultiaddrs } from './utils.js'
19-
import type { AddressSorter, AbortOptions, PendingDial, ComponentLogger, Logger, Connection, ConnectionGater, Metric, Metrics, PeerId, Address, PeerStore } from '@libp2p/interface'
18+
import { resolveMultiaddrs } from './utils.js'
19+
import type { AddressSorter, AbortOptions, ComponentLogger, Logger, Connection, ConnectionGater, Metrics, PeerId, Address, PeerStore, PeerRouting } from '@libp2p/interface'
2020
import type { TransportManager } from '@libp2p/interface-internal'
2121

2222
export interface PendingDialTarget {
@@ -59,15 +59,11 @@ interface DialQueueComponents {
5959
transportManager: TransportManager
6060
connectionGater: ConnectionGater
6161
logger: ComponentLogger
62+
peerRouting: PeerRouting
6263
}
6364

6465
export class DialQueue {
65-
public pendingDials: PendingDialInternal[]
66-
public queue: PQueue
67-
private readonly peerId: PeerId
68-
private readonly peerStore: PeerStore
69-
private readonly connectionGater: ConnectionGater
70-
private readonly transportManager: TransportManager
66+
public queue: Queue<Connection, DialQueueJobOptions>
7167
private readonly addressSorter: AddressSorter
7268
private readonly maxPeerAddrsToDial: number
7369
private readonly dialTimeout: number
@@ -76,6 +72,7 @@ export class DialQueue {
7672
private shutDownController: AbortController
7773
private readonly connections: PeerMap<Connection[]>
7874
private readonly log: Logger
75+
private readonly components: DialQueueComponents
7976

8077
constructor (components: DialQueueComponents, init: DialerInit = {}) {
8178
this.addressSorter = init.addressSorter ?? defaultOptions.addressSorter
@@ -84,10 +81,7 @@ export class DialQueue {
8481
this.connections = init.connections ?? new PeerMap()
8582
this.log = components.logger.forComponent('libp2p:connection-manager:dial-queue')
8683

87-
this.peerId = components.peerId
88-
this.peerStore = components.peerStore
89-
this.connectionGater = components.connectionGater
90-
this.transportManager = components.transportManager
84+
this.components = components
9185
this.shutDownController = new AbortController()
9286

9387
setMaxListeners(Infinity, this.shutDownController.signal)
@@ -242,6 +236,75 @@ export class DialQueue {
242236
// remove our pending dial entry
243237
this.pendingDials = this.pendingDials.filter(p => p.id !== pendingDial.id)
244238

239+
try {
240+
// load addresses from address book, resolve any dnsaddrs, filter
241+
// undialable addresses, add peer IDs, etc
242+
addrsToDial = await this.calculateMultiaddrs(peerId, options?.multiaddrs, {
243+
...options,
244+
signal
245+
})
246+
247+
addrsToDial.map(({ multiaddr }) => multiaddr.toString()).forEach(addr => {
248+
options?.multiaddrs.add(addr)
249+
})
250+
} catch (err) {
251+
signal.clear()
252+
throw err
253+
}
254+
255+
try {
256+
let dialed = 0
257+
const errors: Error[] = []
258+
259+
for (const address of addrsToDial) {
260+
if (dialed === this.maxPeerAddrsToDial) {
261+
this.log('dialed maxPeerAddrsToDial (%d) addresses for %p, not trying any others', dialed, peerId)
262+
263+
throw new CodeError('Peer had more than maxPeerAddrsToDial', codes.ERR_TOO_MANY_ADDRESSES)
264+
}
265+
266+
dialed++
267+
268+
try {
269+
const conn = await this.components.transportManager.dial(address.multiaddr, {
270+
...options,
271+
signal
272+
})
273+
274+
this.log('dial to %a succeeded', address.multiaddr)
275+
276+
return conn
277+
} catch (err: any) {
278+
this.log.error('dial failed to %a', address.multiaddr, err)
279+
280+
if (peerId != null) {
281+
// record the failed dial
282+
try {
283+
await this.components.peerStore.patch(peerId, {
284+
metadata: {
285+
[LAST_DIAL_FAILURE_KEY]: uint8ArrayFromString(Date.now().toString())
286+
}
287+
})
288+
} catch (err: any) {
289+
this.log.error('could not update last dial failure key for %p', peerId, err)
290+
}
291+
}
292+
293+
// the user/dial timeout/shutdown controller signal aborted
294+
if (signal.aborted) {
295+
throw new CodeError(err.message, ERR_TIMEOUT)
296+
}
297+
298+
errors.push(err)
299+
}
300+
}
301+
302+
if (errors.length === 1) {
303+
throw errors[0]
304+
}
305+
306+
throw new AggregateCodeError(errors, 'All multiaddr dials failed', codes.ERR_TRANSPORT_DIAL_FAILED)
307+
} finally {
245308
// clean up abort signals/controllers
246309
signal.clear()
247310
})
@@ -315,19 +378,20 @@ export class DialQueue {
315378
private async calculateMultiaddrs (peerId?: PeerId, addrs: Address[] = [], options: DialOptions = {}): Promise<Address[]> {
316379
// if a peer id or multiaddr(s) with a peer id, make sure it isn't our peer id and that we are allowed to dial it
317380
if (peerId != null) {
318-
if (this.peerId.equals(peerId)) {
381+
if (this.components.peerId.equals(peerId)) {
319382
throw new CodeError('Tried to dial self', codes.ERR_DIALED_SELF)
320383
}
321384

322-
if ((await this.connectionGater.denyDialPeer?.(peerId)) === true) {
385+
if ((await this.components.connectionGater.denyDialPeer?.(peerId)) === true) {
323386
throw new CodeError('The dial request is blocked by gater.allowDialPeer', codes.ERR_PEER_DIAL_INTERCEPTED)
324387
}
325388

326-
// if just a peer id was passed, load available multiaddrs for this peer from the address book
389+
// if just a peer id was passed, load available multiaddrs for this peer
390+
// from the peer store
327391
if (addrs.length === 0) {
328392
this.log('loading multiaddrs for %p', peerId)
329393
try {
330-
const peer = await this.peerStore.get(peerId)
394+
const peer = await this.components.peerStore.get(peerId)
331395
addrs.push(...peer.addresses)
332396
this.log('loaded multiaddrs for %p', peerId, addrs.map(({ multiaddr }) => multiaddr.toString()))
333397
} catch (err: any) {
@@ -336,9 +400,23 @@ export class DialQueue {
336400
}
337401
}
338402
}
403+
404+
// if the peer store had no addresses for the peer, try to find them via
405+
// peer routing
406+
if (addrs.length === 0) {
407+
const peer = await this.components.peerRouting.findPeer(peerId, options)
408+
409+
peer.multiaddrs.forEach(multiaddr => {
410+
addrs.push({
411+
multiaddr,
412+
isCertified: false
413+
})
414+
})
415+
}
339416
}
340417

341-
// resolve addresses - this can result in a one-to-many translation when dnsaddrs are resolved
418+
// resolve addresses - this can result in a one-to-many translation when
419+
// dnsaddrs are resolved
342420
let resolvedAddresses = (await Promise.all(
343421
addrs.map(async addr => {
344422
const result = await resolveMultiaddrs(addr.multiaddr, {
@@ -383,7 +461,7 @@ export class DialQueue {
383461

384462
const filteredAddrs = resolvedAddresses.filter(addr => {
385463
// filter out any multiaddrs that we do not have transports for
386-
if (this.transportManager.transportForMultiaddr(addr.multiaddr) == null) {
464+
if (this.components.transportManager.transportForMultiaddr(addr.multiaddr) == null) {
387465
return false
388466
}
389467

@@ -433,7 +511,7 @@ export class DialQueue {
433511
const gatedAdrs: Address[] = []
434512

435513
for (const addr of dedupedMultiaddrs) {
436-
if (this.connectionGater.denyDialMultiaddr != null && await this.connectionGater.denyDialMultiaddr(addr.multiaddr)) {
514+
if (this.components.connectionGater.denyDialMultiaddr != null && await this.components.connectionGater.denyDialMultiaddr(addr.multiaddr)) {
437515
continue
438516
}
439517

0 commit comments

Comments
 (0)