diff --git a/packages/interface-compliance-tests/src/mocks/transport-manager.ts b/packages/interface-compliance-tests/src/mocks/transport-manager.ts new file mode 100644 index 0000000000..3784513ce5 --- /dev/null +++ b/packages/interface-compliance-tests/src/mocks/transport-manager.ts @@ -0,0 +1,231 @@ +import type { TransportManager } from "@libp2p/interface-internal/transport-manager" +import type { EventEmitter } from "@libp2p/interface/events" +import type { Libp2pEvents } from "@libp2p/interface" +import type { Startable } from "@libp2p/interface/startable" +import { FaultTolerance, type Listener, type Transport, type Upgrader } from "@libp2p/interface/transport" +import type { Connection } from "@libp2p/interface/src/connection" +import type { Multiaddr } from "@multiformats/multiaddr" +import { CodeError } from "@libp2p/interface/errors" + +export interface MockTransportManagerComponents { + events: EventEmitter + upgrader: Upgrader +} + +class MockTransportManager implements TransportManager, Startable { + private readonly components: MockTransportManagerComponents + private readonly transports: Map + private readonly listeners: Map + private readonly faultTolerance: FaultTolerance + private started: boolean + + constructor(components: MockTransportManagerComponents) { + this.components = components + this.started = false + this.transports = new Map() + this.listeners = new Map() + this.faultTolerance = FaultTolerance.FATAL_ALL + } + + isStarted(): boolean { + return this.started + } + + async start(): Promise { + this.started = true + } + + async stop(): Promise { + const tasks = [] + for (const [_, listeners] of this.listeners) { + while (listeners.length > 0) { + const listener = listeners.pop() + + if (listener == null) { + continue + } + + tasks.push(listener.close()) + } + } + + await Promise.all(tasks) + for (const key of this.listeners.keys()) { + this.listeners.set(key, []) + } + + this.started = false + } + + getAddrs(): Multiaddr[] { + let addrs: Multiaddr[] = [] + for (const listeners of this.listeners.values()) { + for (const listener of listeners) { + addrs = [...addrs, ...listener.getAddrs()] + } + } + return addrs + } + + /** + * Returns all the transports instances + */ + getTransports(): Transport[] { + return Array.of(...this.transports.values()) + } + + /** + * Returns all the listener instances + */ + getListeners(): Listener[] { + return Array.of(...this.listeners.values()).flat() + } + + add(transport: Transport): void { + const tag = transport[Symbol.toStringTag] + + if (tag == null) { + throw new CodeError("Transport must have a valid tag", "INVALID_TAG") + } + + if (this.transports.has(tag)) { + throw new CodeError(`There is already a transport with the tag ${tag}`, "DUPLICATE") + } + + this.transports.set(tag, transport) + + if (!this.listeners.has(tag)) { + this.listeners.set(tag, []) + } + } + + transportForMultiaddr(ma: Multiaddr): Transport | undefined { + for (const transport of this.transports.values()) { + const addrs = transport.filter([ma]) + + if (addrs.length > 0) { + return transport + } + } + } + + async dial(ma: Multiaddr, options?: any): Promise { + const transport = this.transportForMultiaddr(ma) + + if (transport == null) { + throw new CodeError(`No transport available for address ${String(ma)}`, "TRANSPORT_UNAVAILABLE") + } + + try { + return await transport.dial(ma, { + ...options, + upgrader: this.components.upgrader, + }) + } catch (err: any) { + if (err.code == null) { + err.code = "TRANSPORT_DIAL_FAILED" + } + + throw err + } + } + + /** + * Starts listeners for each listen Multiaddr + */ + async listen(addrs: Multiaddr[]): Promise { + if (!this.isStarted()) { + throw new CodeError("Not started", "ERR_NODE_NOT_STARTED") + } + + if (addrs == null || addrs.length === 0) { + return + } + + const couldNotListen = [] + + for (const [key, transport] of this.transports.entries()) { + const supportedAddrs = transport.filter(addrs) + const tasks = [] + + // For each supported multiaddr, create a listener + for (const addr of supportedAddrs) { + const listener = transport.createListener({ + upgrader: this.components.upgrader, + }) + + let listeners: Listener[] = this.listeners.get(key) ?? [] + + if (listeners == null) { + listeners = [] + this.listeners.set(key, listeners) + } + + listeners.push(listener) + + // Track listen/close events + listener.addEventListener("listening", () => { + this.components.events.safeDispatchEvent("transport:listening", { + detail: listener, + }) + }) + listener.addEventListener("close", () => { + const index = listeners.findIndex((l) => l === listener) + + // remove the listener + listeners.splice(index, 1) + + this.components.events.safeDispatchEvent("transport:close", { + detail: listener, + }) + }) + + // We need to attempt to listen on everything + tasks.push(listener.listen(addr)) + } + + // Keep track of transports we had no addresses for + if (tasks.length === 0) { + couldNotListen.push(key) + continue + } + + const results = await Promise.allSettled(tasks) + + const isListening = results.find((r) => r.status === "fulfilled") + if (isListening == null && this.faultTolerance !== FaultTolerance.NO_FATAL) { + throw new CodeError(`Transport (${key}) could not listen on any available address`, "ERR_NO_VALID_ADDRESSES") + } + } + + if (couldNotListen.length === this.transports.size) { + const message = `no valid addresses were provided for transports [${couldNotListen.join(", ")}]` + if (this.faultTolerance === FaultTolerance.FATAL_ALL) { + throw new CodeError(message, "ERR_NO_VALID_ADDRESSES") + } + } + } + + async remove(key: string): Promise { + // Close any running listeners + for (const listener of this.listeners.get(key) ?? []) { + await listener.close() + } + + this.transports.delete(key) + this.listeners.delete(key) + } + + async removeAll(): Promise { + const tasks = [] + for (const key of this.transports.keys()) { + tasks.push(this.remove(key)) + } + + await Promise.all(tasks) + } +} + +export function mockTransportManager(components: MockTransportManagerComponents): TransportManager { + return new MockTransportManager(components) +} diff --git a/packages/interface-compliance-tests/src/transport/dial-test.ts b/packages/interface-compliance-tests/src/transport/dial-test.ts index 2139a755d0..50e95d098f 100644 --- a/packages/interface-compliance-tests/src/transport/dial-test.ts +++ b/packages/interface-compliance-tests/src/transport/dial-test.ts @@ -19,9 +19,11 @@ export default (common: TestSetup): void => { let upgrader: Upgrader let registrar: Registrar let addrs: Multiaddr[] + let listeningAddrs: Multiaddr[] let transport: Transport let connector: Connector let listener: Listener + let hasListener: boolean before(async () => { registrar = mockRegistrar() @@ -30,7 +32,7 @@ export default (common: TestSetup): void => { events: new EventEmitter() }); - ({ addrs, transport, connector } = await common.setup()) + ({ addrs, transport, connector, listeningAddrs =[], hasListener = true } = await common.setup()) }) after(async () => { @@ -38,16 +40,19 @@ export default (common: TestSetup): void => { }) beforeEach(async () => { - listener = transport.createListener({ - upgrader - }) - await listener.listen(addrs[0]) + if (hasListener) { + listener = transport.createListener({ + upgrader + }) + listeningAddrs.length > 0 ? await listener.listen(listeningAddrs[0]) : await listener.listen(addrs[0]) + } }) afterEach(async () => { sinon.restore() connector.restore() - await listener.close() + if (hasListener) + await listener.close() }) it('simple', async () => { @@ -56,13 +61,13 @@ export default (common: TestSetup): void => { void pipe([ uint8ArrayFromString('hey') ], - data.stream, - drain + data.stream, + drain ) }) const upgradeSpy = sinon.spy(upgrader, 'upgradeOutbound') - const conn = await transport.dial(addrs[0], { + const conn = await transport.dial(listeningAddrs[0], { upgrader }) diff --git a/packages/interface-compliance-tests/src/transport/index.ts b/packages/interface-compliance-tests/src/transport/index.ts index d2b9525687..46d6e491e6 100644 --- a/packages/interface-compliance-tests/src/transport/index.ts +++ b/packages/interface-compliance-tests/src/transport/index.ts @@ -12,8 +12,10 @@ export interface Connector { export interface TransportTestFixtures { addrs: Multiaddr[] + listeningAddrs?: Multiaddr[] transport: Transport connector: Connector + hasListener?: boolean } export default (common: TestSetup): void => { diff --git a/packages/interface-compliance-tests/src/transport/listen-test.ts b/packages/interface-compliance-tests/src/transport/listen-test.ts index d518511dc7..b22f36f13a 100644 --- a/packages/interface-compliance-tests/src/transport/listen-test.ts +++ b/packages/interface-compliance-tests/src/transport/listen-test.ts @@ -20,6 +20,7 @@ export default (common: TestSetup): void => { describe('listen', () => { let upgrader: Upgrader let addrs: Multiaddr[] + let listeningAddrs: Multiaddr[] let transport: Transport let registrar: Registrar @@ -30,7 +31,7 @@ export default (common: TestSetup): void => { events: new EventEmitter() }); - ({ transport, addrs } = await common.setup()) + ({ transport, addrs, listeningAddrs = [] } = await common.setup()) }) after(async () => { @@ -45,7 +46,7 @@ export default (common: TestSetup): void => { const listener = transport.createListener({ upgrader }) - await listener.listen(addrs[0]) + listeningAddrs.length > 0 ? await listener.listen(listeningAddrs[0]) : await listener.listen(addrs[0]) await listener.close() }) @@ -66,7 +67,7 @@ export default (common: TestSetup): void => { }) // Listen - await listener.listen(addrs[0]) + listeningAddrs.length > 0 ? await listener.listen(listeningAddrs[0]) : await listener.listen(addrs[0]) // Create two connections to the listener const [conn1] = await Promise.all([ @@ -113,7 +114,7 @@ export default (common: TestSetup): void => { }) // Listen - await listener.listen(addrs[0]) + listeningAddrs.length > 0 ? await listener.listen(listeningAddrs[0]) : await listener.listen(addrs[0]) // Create a connection to the listener const conn = await transport.dial(addrs[0], { @@ -139,7 +140,7 @@ export default (common: TestSetup): void => { }) void (async () => { - await listener.listen(addrs[0]) + listeningAddrs.length > 0 ? await listener.listen(listeningAddrs[0]) : await listener.listen(addrs[0]) await transport.dial(addrs[0], { upgrader }) @@ -159,7 +160,8 @@ export default (common: TestSetup): void => { listener.addEventListener('listening', () => { listener.close().then(done, done) }) - void listener.listen(addrs[0]) + const addrToListenOn = listeningAddrs.length > 0 ? listeningAddrs[0] : addrs[0] + void listener.listen(addrToListenOn) }) it('error', (done) => { @@ -182,7 +184,7 @@ export default (common: TestSetup): void => { listener.addEventListener('close', () => { done() }) void (async () => { - await listener.listen(addrs[0]) + listeningAddrs.length > 0 ? await listener.listen(listeningAddrs[0]) : await listener.listen(addrs[0]) await listener.close() })() }) diff --git a/packages/interface/package.json b/packages/interface/package.json index f4365c77f1..ae42e6a109 100644 --- a/packages/interface/package.json +++ b/packages/interface/package.json @@ -163,11 +163,15 @@ "it-stream-types": "^2.0.1", "multiformats": "^12.0.1", "p-defer": "^4.0.0", + "race-signal": "^1.0.0", "uint8arraylist": "^2.4.3" }, "devDependencies": { "@types/sinon": "^10.0.15", "aegir": "^40.0.8", + "delay": "^6.0.0", + "it-all": "^3.0.3", + "it-drain": "^3.0.3", "sinon": "^16.0.0", "sinon-ts": "^1.0.0" } diff --git a/packages/interface/src/stream-muxer/stream.ts b/packages/interface/src/stream-muxer/stream.ts index bac8c7f729..65d00cc842 100644 --- a/packages/interface/src/stream-muxer/stream.ts +++ b/packages/interface/src/stream-muxer/stream.ts @@ -1,12 +1,14 @@ import { abortableSource } from 'abortable-iterator' import { type Pushable, pushable } from 'it-pushable' import defer, { type DeferredPromise } from 'p-defer' +import { raceSignal } from 'race-signal' import { Uint8ArrayList } from 'uint8arraylist' import { CodeError } from '../errors.js' import type { Direction, ReadStatus, Stream, StreamStatus, StreamTimeline, WriteStatus } from '../connection/index.js' import type { AbortOptions } from '../index.js' import type { Source } from 'it-stream-types' +// copied from @libp2p/logger to break a circular dependency interface Logger { (formatter: any, ...args: any[]): void error: (formatter: any, ...args: any[]) => void @@ -128,7 +130,6 @@ export abstract class AbstractStream implements Stream { this.log.trace('source ended') } - this.readStatus = 'closed' this.onSourceEnd(err) } }) @@ -173,11 +174,17 @@ export abstract class AbstractStream implements Stream { } } - this.log.trace('sink finished reading from source') - this.writeStatus = 'done' + this.log.trace('sink finished reading from source, write status is "%s"', this.writeStatus) + + if (this.writeStatus === 'writing') { + this.writeStatus = 'closing' + + this.log.trace('send close write to remote') + await this.sendCloseWrite(options) + + this.writeStatus = 'closed' + } - this.log.trace('sink calling closeWrite') - await this.closeWrite(options) this.onSinkEnd() } catch (err: any) { this.log.trace('sink ended with error, calling abort with error', err) @@ -196,6 +203,7 @@ export abstract class AbstractStream implements Stream { } this.timeline.closeRead = Date.now() + this.readStatus = 'closed' if (err != null && this.endErr == null) { this.endErr = err @@ -221,6 +229,7 @@ export abstract class AbstractStream implements Stream { } this.timeline.closeWrite = Date.now() + this.writeStatus = 'closed' if (err != null && this.endErr == null) { this.endErr = err @@ -266,16 +275,16 @@ export abstract class AbstractStream implements Stream { const readStatus = this.readStatus this.readStatus = 'closing' - if (readStatus === 'ready') { - this.log.trace('ending internal source queue') - this.streamSource.end() - } - if (this.status !== 'reset' && this.status !== 'aborted' && this.timeline.closeRead == null) { this.log.trace('send close read to remote') await this.sendCloseRead(options) } + if (readStatus === 'ready') { + this.log.trace('ending internal source queue') + this.streamSource.end() + } + this.log.trace('closed readable end of stream') } @@ -286,16 +295,13 @@ export abstract class AbstractStream implements Stream { this.log.trace('closing writable end of stream with starting write status "%s"', this.writeStatus) - const writeStatus = this.writeStatus - if (this.writeStatus === 'ready') { this.log.trace('sink was never sunk, sink an empty array') - await this.sink([]) - } - this.writeStatus = 'closing' + await raceSignal(this.sink([]), options.signal) + } - if (writeStatus === 'writing') { + if (this.writeStatus === 'writing') { // stop reading from the source passed to `.sink` in the microtask queue // - this lets any data queued by the user in the current tick get read // before we exit @@ -303,16 +309,12 @@ export abstract class AbstractStream implements Stream { queueMicrotask(() => { this.log.trace('aborting source passed to .sink') this.sinkController.abort() - this.sinkEnd.promise.then(resolve, reject) + raceSignal(this.sinkEnd.promise, options.signal) + .then(resolve, reject) }) }) } - if (this.status !== 'reset' && this.status !== 'aborted' && this.timeline.closeWrite == null) { - this.log.trace('send close write to remote') - await this.sendCloseWrite(options) - } - this.writeStatus = 'closed' this.log.trace('closed writable end of stream') @@ -357,6 +359,7 @@ export abstract class AbstractStream implements Stream { const err = new CodeError('stream reset', ERR_STREAM_RESET) this.status = 'reset' + this.timeline.reset = Date.now() this._closeSinkAndSource(err) this.onReset?.() } @@ -423,7 +426,7 @@ export abstract class AbstractStream implements Stream { return } - this.log.trace('muxer destroyed') + this.log.trace('stream destroyed') this._closeSinkAndSource() } diff --git a/packages/interface/test/fixtures/logger.ts b/packages/interface/test/fixtures/logger.ts new file mode 100644 index 0000000000..bd614bd944 --- /dev/null +++ b/packages/interface/test/fixtures/logger.ts @@ -0,0 +1,16 @@ +// copied from @libp2p/logger to break a circular dependency +interface Logger { + (): void + error: () => void + trace: () => void + enabled: boolean +} + +export function logger (): Logger { + const output = (): void => {} + output.trace = (): void => {} + output.error = (): void => {} + output.enabled = false + + return output +} diff --git a/packages/interface/test/stream-muxer/stream.spec.ts b/packages/interface/test/stream-muxer/stream.spec.ts new file mode 100644 index 0000000000..eaeb33f04a --- /dev/null +++ b/packages/interface/test/stream-muxer/stream.spec.ts @@ -0,0 +1,196 @@ +import { expect } from 'aegir/chai' +import delay from 'delay' +import all from 'it-all' +import drain from 'it-drain' +import Sinon from 'sinon' +import { Uint8ArrayList } from 'uint8arraylist' +import { AbstractStream } from '../../src/stream-muxer/stream.js' +import { logger } from '../fixtures/logger.js' +import type { AbortOptions } from '../../src/index.js' + +class TestStream extends AbstractStream { + async sendNewStream (options?: AbortOptions): Promise { + + } + + async sendData (buf: Uint8ArrayList, options?: AbortOptions): Promise { + + } + + async sendReset (options?: AbortOptions): Promise { + + } + + async sendCloseWrite (options?: AbortOptions): Promise { + + } + + async sendCloseRead (options?: AbortOptions): Promise { + + } +} + +describe('abstract stream', () => { + let stream: TestStream + + beforeEach(() => { + stream = new TestStream({ + id: 'test', + direction: 'outbound', + log: logger() + }) + }) + + it('sends data', async () => { + const sendSpy = Sinon.spy(stream, 'sendData') + const data = [ + Uint8Array.from([0, 1, 2, 3, 4]) + ] + + await stream.sink(data) + + const call = sendSpy.getCall(0) + expect(call.args[0].subarray()).to.equalBytes(data[0]) + }) + + it('receives data', async () => { + const data = new Uint8ArrayList( + Uint8Array.from([0, 1, 2, 3, 4]) + ) + + stream.sourcePush(data) + stream.remoteCloseWrite() + + const output = await all(stream.source) + expect(output[0].subarray()).to.equalBytes(data.subarray()) + }) + + it('closes', async () => { + const sendCloseReadSpy = Sinon.spy(stream, 'sendCloseRead') + const sendCloseWriteSpy = Sinon.spy(stream, 'sendCloseWrite') + + await stream.close() + + expect(sendCloseReadSpy.calledOnce).to.be.true() + expect(sendCloseWriteSpy.calledOnce).to.be.true() + + expect(stream).to.have.property('status', 'closed') + expect(stream).to.have.property('writeStatus', 'closed') + expect(stream).to.have.property('readStatus', 'closed') + expect(stream).to.have.nested.property('timeline.close').that.is.a('number') + expect(stream).to.have.nested.property('timeline.closeRead').that.is.a('number') + expect(stream).to.have.nested.property('timeline.closeWrite').that.is.a('number') + expect(stream).to.not.have.nested.property('timeline.reset') + expect(stream).to.not.have.nested.property('timeline.abort') + }) + + it('closes for reading', async () => { + const sendCloseReadSpy = Sinon.spy(stream, 'sendCloseRead') + const sendCloseWriteSpy = Sinon.spy(stream, 'sendCloseWrite') + + await stream.closeRead() + + expect(sendCloseReadSpy.calledOnce).to.be.true() + expect(sendCloseWriteSpy.called).to.be.false() + + expect(stream).to.have.property('status', 'open') + expect(stream).to.have.property('writeStatus', 'ready') + expect(stream).to.have.property('readStatus', 'closed') + expect(stream).to.not.have.nested.property('timeline.close') + expect(stream).to.have.nested.property('timeline.closeRead').that.is.a('number') + expect(stream).to.not.have.nested.property('timeline.closeWrite') + expect(stream).to.not.have.nested.property('timeline.reset') + expect(stream).to.not.have.nested.property('timeline.abort') + }) + + it('closes for writing', async () => { + const sendCloseReadSpy = Sinon.spy(stream, 'sendCloseRead') + const sendCloseWriteSpy = Sinon.spy(stream, 'sendCloseWrite') + + await stream.closeWrite() + + expect(sendCloseReadSpy.called).to.be.false() + expect(sendCloseWriteSpy.calledOnce).to.be.true() + + expect(stream).to.have.property('status', 'open') + expect(stream).to.have.property('writeStatus', 'closed') + expect(stream).to.have.property('readStatus', 'ready') + expect(stream).to.not.have.nested.property('timeline.close') + expect(stream).to.not.have.nested.property('timeline.closeRead') + expect(stream).to.have.nested.property('timeline.closeWrite').that.is.a('number') + expect(stream).to.not.have.nested.property('timeline.reset') + expect(stream).to.not.have.nested.property('timeline.abort') + }) + + it('aborts', async () => { + const sendResetSpy = Sinon.spy(stream, 'sendReset') + + stream.abort(new Error('Urk!')) + + expect(sendResetSpy.calledOnce).to.be.true() + + expect(stream).to.have.property('status', 'aborted') + expect(stream).to.have.property('writeStatus', 'closed') + expect(stream).to.have.property('readStatus', 'closed') + expect(stream).to.have.nested.property('timeline.close').that.is.a('number') + expect(stream).to.have.nested.property('timeline.closeRead').that.is.a('number') + expect(stream).to.have.nested.property('timeline.closeWrite').that.is.a('number') + expect(stream).to.not.have.nested.property('timeline.reset') + expect(stream).to.have.nested.property('timeline.abort').that.is.a('number') + + await expect(stream.sink([])).to.eventually.be.rejected + .with.property('code', 'ERR_SINK_INVALID_STATE') + await expect(drain(stream.source)).to.eventually.be.rejected + .with('Urk!') + }) + + it('gets reset remotely', async () => { + stream.reset() + + expect(stream).to.have.property('status', 'reset') + expect(stream).to.have.property('writeStatus', 'closed') + expect(stream).to.have.property('readStatus', 'closed') + expect(stream).to.have.nested.property('timeline.close').that.is.a('number') + expect(stream).to.have.nested.property('timeline.closeRead').that.is.a('number') + expect(stream).to.have.nested.property('timeline.closeWrite').that.is.a('number') + expect(stream).to.have.nested.property('timeline.reset').that.is.a('number') + expect(stream).to.not.have.nested.property('timeline.abort') + + await expect(stream.sink([])).to.eventually.be.rejected + .with.property('code', 'ERR_SINK_INVALID_STATE') + await expect(drain(stream.source)).to.eventually.be.rejected + .with.property('code', 'ERR_STREAM_RESET') + }) + + it('does not send close read when remote closes write', async () => { + const sendCloseReadSpy = Sinon.spy(stream, 'sendCloseRead') + + stream.remoteCloseWrite() + + await delay(100) + + expect(sendCloseReadSpy.called).to.be.false() + }) + + it('does not send close write when remote closes read', async () => { + const sendCloseWriteSpy = Sinon.spy(stream, 'sendCloseWrite') + + stream.remoteCloseRead() + + await delay(100) + + expect(sendCloseWriteSpy.called).to.be.false() + }) + + it('does not send close read or write when remote resets', async () => { + const sendCloseReadSpy = Sinon.spy(stream, 'sendCloseRead') + const sendCloseWriteSpy = Sinon.spy(stream, 'sendCloseWrite') + + stream.reset() + + await delay(100) + + expect(sendCloseReadSpy.called).to.be.false() + expect(sendCloseWriteSpy.called).to.be.false() + }) +}) diff --git a/packages/libp2p/src/connection/index.ts b/packages/libp2p/src/connection/index.ts index 8509b8d70a..5726d1cc54 100644 --- a/packages/libp2p/src/connection/index.ts +++ b/packages/libp2p/src/connection/index.ts @@ -158,16 +158,22 @@ export class ConnectionImpl implements Connection { } catch { } try { + log.trace('closing all streams') + // close all streams gracefully - this can throw if we're not multiplexed await Promise.all( this.streams.map(async s => s.close(options)) ) - // Close raw connection + log.trace('closing underlying transport') + + // close raw connection await this._close(options) - this.timeline.close = Date.now() + log.trace('updating timeline with close time') + this.status = 'closed' + this.timeline.close = Date.now() } catch (err: any) { log.error('error encountered during graceful close of connection to %a', this.remoteAddr, err) this.abort(err) diff --git a/packages/libp2p/src/dcutr/dcutr.ts b/packages/libp2p/src/dcutr/dcutr.ts index 190766bfa2..e6852b6d03 100644 --- a/packages/libp2p/src/dcutr/dcutr.ts +++ b/packages/libp2p/src/dcutr/dcutr.ts @@ -262,7 +262,10 @@ export class DefaultDCUtRService implements Startable { } log('unilateral connection upgrade to %p succeeded via %a, closing relayed connection', relayedConnection.remotePeer, connection.remoteAddr) - await relayedConnection.close() + + await relayedConnection.close({ + signal + }) return true } catch (err) { diff --git a/packages/libp2p/src/upgrader.ts b/packages/libp2p/src/upgrader.ts index c7a6829175..18934e3ad0 100644 --- a/packages/libp2p/src/upgrader.ts +++ b/packages/libp2p/src/upgrader.ts @@ -545,11 +545,16 @@ export class DefaultUpgrader implements Upgrader { newStream: newStream ?? errConnectionNotMultiplexed, getStreams: () => { if (muxer != null) { return muxer.streams } else { return [] } }, close: async (options?: AbortOptions) => { - await maConn.close(options) // Ensure remaining streams are closed gracefully if (muxer != null) { + log.trace('close muxer') await muxer.close(options) } + + log.trace('close maconn') + // close the underlying transport + await maConn.close(options) + log.trace('closed maconn') }, abort: (err) => { maConn.abort(err) diff --git a/packages/transport-webrtc/.aegir.js b/packages/transport-webrtc/.aegir.js index 491576df87..3b38600b7a 100644 --- a/packages/transport-webrtc/.aegir.js +++ b/packages/transport-webrtc/.aegir.js @@ -8,7 +8,6 @@ export default { before: async () => { const { createLibp2p } = await import('libp2p') const { circuitRelayServer } = await import('libp2p/circuit-relay') - const { identifyService } = await import('libp2p/identify') const { webSockets } = await import('@libp2p/websockets') const { noise } = await import('@chainsafe/libp2p-noise') const { yamux } = await import('@chainsafe/libp2p-yamux') @@ -34,11 +33,11 @@ export default { reservations: { maxReservations: Infinity } - }), - identify: identifyService() + }) }, connectionManager: { - minConnections: 0 + minConnections: 0, + inboundConnectionThreshold: Infinity } }) diff --git a/packages/transport-webrtc/package.json b/packages/transport-webrtc/package.json index 46302127ae..744b203097 100644 --- a/packages/transport-webrtc/package.json +++ b/packages/transport-webrtc/package.json @@ -34,7 +34,7 @@ "scripts": { "generate": "protons src/private-to-private/pb/message.proto src/pb/message.proto", "build": "aegir build", - "test": "aegir test -t node -t browser -t electron-main -- --exit", + "test": "aegir test -t node -t electron-main -g \"interface-transport compliance\"", "test:node": "aegir test -t node --cov -- --exit", "test:chrome": "aegir test -t browser --cov", "test:firefox": "aegir test -t browser -- --browser firefox", @@ -52,7 +52,6 @@ "@multiformats/mafmt": "^12.1.2", "@multiformats/multiaddr": "^12.1.5", "@multiformats/multiaddr-matcher": "^1.0.1", - "abortable-iterator": "^5.0.1", "detect-browser": "^5.3.0", "it-length-prefixed": "^9.0.1", "it-pipe": "^3.0.1", @@ -62,16 +61,17 @@ "it-to-buffer": "^4.0.2", "multiformats": "^12.0.1", "multihashes": "^4.0.3", - "node-datachannel": "^0.4.3", + "node-datachannel": "^0.5.0-dev", "p-defer": "^4.0.0", "p-event": "^6.0.0", + "p-timeout": "^6.1.2", "protons-runtime": "^5.0.0", "uint8arraylist": "^2.4.3", "uint8arrays": "^4.0.6" }, "devDependencies": { "@chainsafe/libp2p-yamux": "^5.0.0", - "@libp2p/interface-compliance-tests": "^4.0.6", + "@libp2p/interface-compliance-tests": "file:../interface-compliance-tests", "@libp2p/peer-id-factory": "^3.0.4", "@libp2p/websockets": "^7.0.7", "@types/sinon": "^10.0.15", diff --git a/packages/transport-webrtc/src/index.ts b/packages/transport-webrtc/src/index.ts index 0245aefcc9..8ce992f839 100644 --- a/packages/transport-webrtc/src/index.ts +++ b/packages/transport-webrtc/src/index.ts @@ -3,6 +3,40 @@ import { WebRTCDirectTransport, type WebRTCTransportDirectInit, type WebRTCDirec import type { WebRTCTransportComponents, WebRTCTransportInit } from './private-to-private/transport.js' import type { Transport } from '@libp2p/interface/transport' +export interface DataChannelOptions { + /** + * The maximum message size sendable over the channel + */ + maxMessageSize?: number + + /** + * If the channel's `bufferedAmount` grows over this amount in bytes, wait + * for it to drain before sending more data (default: 16MB) + */ + maxBufferedAmount?: number + + /** + * When `bufferedAmount` is above `maxBufferedAmount`, we pause sending until + * the `bufferedAmountLow` event fires - this controls how long we wait for + * that event in ms (default: 30s) + */ + bufferedAmountLowEventTimeout?: number + + /** + * When closing a stream, we wait for `bufferedAmount` to become 0 before + * closing the underlying RTCDataChannel - this controls how long we wait + * (default: 30s) + */ + drainTimeout?: number + + /** + * When closing a stream we first send a FIN flag to the remote and wait + * for a FIN_ACK reply before closing the underlying RTCDataChannel - this + * controls how long we wait for the acknowledgement (default: 5s) + */ + closeTimeout?: number +} + /** * @param {WebRTCTransportDirectInit} init - WebRTC direct transport configuration * @param init.dataChannel - DataChannel configurations diff --git a/packages/transport-webrtc/src/maconn.ts b/packages/transport-webrtc/src/maconn.ts index c32d88760c..067fdaab46 100644 --- a/packages/transport-webrtc/src/maconn.ts +++ b/packages/transport-webrtc/src/maconn.ts @@ -5,7 +5,7 @@ import type { CounterGroup } from '@libp2p/interface/metrics' import type { AbortOptions, Multiaddr } from '@multiformats/multiaddr' import type { Source, Sink } from 'it-stream-types' -const log = logger('libp2p:webrtc:connection') +const log = logger('libp2p:webrtc:maconn') interface WebRTCMultiaddrConnectionInit { /** @@ -65,8 +65,16 @@ export class WebRTCMultiaddrConnection implements MultiaddrConnection { this.timeline = init.timeline this.peerConnection = init.peerConnection + const initialState = this.peerConnection.connectionState + this.peerConnection.onconnectionstatechange = () => { - if (this.peerConnection.connectionState === 'closed' || this.peerConnection.connectionState === 'disconnected' || this.peerConnection.connectionState === 'failed') { + log.trace('peer connection state change', this.peerConnection.connectionState, 'initial state', initialState) + + if (this.peerConnection.connectionState === 'disconnected') { + // attempt to reconnect + this.peerConnection.restartIce() + } else if (this.peerConnection.connectionState === 'closed') { + // nothing else to do but close the connection this.timeline.close = Date.now() } } diff --git a/packages/transport-webrtc/src/muxer.ts b/packages/transport-webrtc/src/muxer.ts index 991d130230..fae5f15011 100644 --- a/packages/transport-webrtc/src/muxer.ts +++ b/packages/transport-webrtc/src/muxer.ts @@ -1,6 +1,7 @@ +import { logger } from '@libp2p/logger' import { createStream } from './stream.js' -import { nopSink, nopSource } from './util.js' -import type { DataChannelOpts } from './stream.js' +import { drainAndClose, nopSink, nopSource } from './util.js' +import type { DataChannelOptions } from './index.js' import type { Stream } from '@libp2p/interface/connection' import type { CounterGroup } from '@libp2p/interface/metrics' import type { StreamMuxer, StreamMuxerFactory, StreamMuxerInit } from '@libp2p/interface/stream-muxer' @@ -8,6 +9,8 @@ import type { AbortOptions } from '@multiformats/multiaddr' import type { Source, Sink } from 'it-stream-types' import type { Uint8ArrayList } from 'uint8arraylist' +const log = logger('libp2p:webrtc:muxer') + const PROTOCOL = '/webrtc' export interface DataChannelMuxerFactoryInit { @@ -17,19 +20,16 @@ export interface DataChannelMuxerFactoryInit { peerConnection: RTCPeerConnection /** - * Optional metrics for this data channel muxer + * The protocol to use */ - metrics?: CounterGroup + protocol?: string /** - * Data channel options + * Optional metrics for this data channel muxer */ - dataChannelOptions?: Partial + metrics?: CounterGroup - /** - * The protocol to use - */ - protocol?: string + dataChannelOptions?: DataChannelOptions } export class DataChannelMuxerFactory implements StreamMuxerFactory { @@ -41,23 +41,23 @@ export class DataChannelMuxerFactory implements StreamMuxerFactory { private readonly peerConnection: RTCPeerConnection private streamBuffer: Stream[] = [] private readonly metrics?: CounterGroup - private readonly dataChannelOptions?: Partial + private readonly dataChannelOptions?: DataChannelOptions constructor (init: DataChannelMuxerFactoryInit) { this.peerConnection = init.peerConnection this.metrics = init.metrics this.protocol = init.protocol ?? PROTOCOL - this.dataChannelOptions = init.dataChannelOptions + this.dataChannelOptions = init.dataChannelOptions ?? {} // store any datachannels opened before upgrade has been completed this.peerConnection.ondatachannel = ({ channel }) => { const stream = createStream({ channel, direction: 'inbound', - dataChannelOptions: init.dataChannelOptions, onEnd: () => { this.streamBuffer = this.streamBuffer.filter(s => s.id !== stream.id) - } + }, + ...this.dataChannelOptions }) this.streamBuffer.push(stream) } @@ -90,34 +90,15 @@ export class DataChannelMuxer implements StreamMuxer { public protocol: string private readonly peerConnection: RTCPeerConnection - private readonly dataChannelOptions?: DataChannelOpts + private readonly dataChannelOptions: DataChannelOptions private readonly metrics?: CounterGroup - /** - * Gracefully close all tracked streams and stop the muxer - */ - close: (options?: AbortOptions) => Promise = async () => { } - - /** - * Abort all tracked streams and stop the muxer - */ - abort: (err: Error) => void = () => { } - - /** - * The stream source, a no-op as the transport natively supports multiplexing - */ - source: AsyncGenerator = nopSource() - - /** - * The stream destination, a no-op as the transport natively supports multiplexing - */ - sink: Sink, Promise> = nopSink - constructor (readonly init: DataChannelMuxerInit) { this.streams = init.streams this.peerConnection = init.peerConnection this.protocol = init.protocol ?? PROTOCOL this.metrics = init.metrics + this.dataChannelOptions = init.dataChannelOptions ?? {} /** * Fired when a data channel has been added to the connection has been @@ -129,19 +110,19 @@ export class DataChannelMuxer implements StreamMuxer { const stream = createStream({ channel, direction: 'inbound', - dataChannelOptions: this.dataChannelOptions, onEnd: () => { + log.trace('stream %s %s %s onEnd', stream.direction, stream.id, stream.protocol) + drainAndClose(channel, `inbound ${stream.id} ${stream.protocol}`, this.dataChannelOptions.drainTimeout) this.streams = this.streams.filter(s => s.id !== stream.id) this.metrics?.increment({ stream_end: true }) init?.onStreamEnd?.(stream) - } + }, + ...this.dataChannelOptions }) this.streams.push(stream) - if ((init?.onIncomingStream) != null) { - this.metrics?.increment({ incoming_stream: true }) - init.onIncomingStream(stream) - } + this.metrics?.increment({ incoming_stream: true }) + init?.onIncomingStream?.(stream) } const onIncomingStream = init?.onIncomingStream @@ -150,19 +131,52 @@ export class DataChannelMuxer implements StreamMuxer { } } + /** + * Gracefully close all tracked streams and stop the muxer + */ + async close (options?: AbortOptions): Promise { + try { + await Promise.all( + this.streams.map(async stream => stream.close(options)) + ) + } catch (err: any) { + this.abort(err) + } + } + + /** + * Abort all tracked streams and stop the muxer + */ + abort (err: Error): void { + for (const stream of this.streams) { + stream.abort(err) + } + } + + /** + * The stream source, a no-op as the transport natively supports multiplexing + */ + source: AsyncGenerator = nopSource() + + /** + * The stream destination, a no-op as the transport natively supports multiplexing + */ + sink: Sink, Promise> = nopSink + newStream (): Stream { // The spec says the label SHOULD be an empty string: https://github.com/libp2p/specs/blob/master/webrtc/README.md#rtcdatachannel-label const channel = this.peerConnection.createDataChannel('') const stream = createStream({ channel, direction: 'outbound', - dataChannelOptions: this.dataChannelOptions, onEnd: () => { - channel.close() // Stream initiator is responsible for closing the channel + log.trace('stream %s %s %s onEnd', stream.direction, stream.id, stream.protocol) + drainAndClose(channel, `outbound ${stream.id} ${stream.protocol}`, this.dataChannelOptions.drainTimeout) this.streams = this.streams.filter(s => s.id !== stream.id) this.metrics?.increment({ stream_end: true }) this.init?.onStreamEnd?.(stream) - } + }, + ...this.dataChannelOptions }) this.streams.push(stream) this.metrics?.increment({ outgoing_stream: true }) diff --git a/packages/transport-webrtc/src/pb/message.proto b/packages/transport-webrtc/src/pb/message.proto index 9301bd802b..ea1ae55b99 100644 --- a/packages/transport-webrtc/src/pb/message.proto +++ b/packages/transport-webrtc/src/pb/message.proto @@ -2,7 +2,8 @@ syntax = "proto3"; message Message { enum Flag { - // The sender will no longer send messages on the stream. + // The sender will no longer send messages on the stream. The recipient + // should send a FIN_ACK back to the sender. FIN = 0; // The sender will no longer read messages on the stream. Incoming data is @@ -12,6 +13,10 @@ message Message { // The sender abruptly terminates the sending part of the stream. The // receiver can discard any data that it already received on that stream. RESET = 2; + + // The sender previously received a FIN. + // Workaround for https://bugs.chromium.org/p/chromium/issues/detail?id=1484907 + FIN_ACK = 3; } optional Flag flag = 1; diff --git a/packages/transport-webrtc/src/pb/message.ts b/packages/transport-webrtc/src/pb/message.ts index a74ca6dd06..f8abb7a4a9 100644 --- a/packages/transport-webrtc/src/pb/message.ts +++ b/packages/transport-webrtc/src/pb/message.ts @@ -17,13 +17,15 @@ export namespace Message { export enum Flag { FIN = 'FIN', STOP_SENDING = 'STOP_SENDING', - RESET = 'RESET' + RESET = 'RESET', + FIN_ACK = 'FIN_ACK' } enum __FlagValues { FIN = 0, STOP_SENDING = 1, - RESET = 2 + RESET = 2, + FIN_ACK = 3 } export namespace Flag { diff --git a/packages/transport-webrtc/src/private-to-private/handler.ts b/packages/transport-webrtc/src/private-to-private/handler.ts deleted file mode 100644 index 8564fc84d2..0000000000 --- a/packages/transport-webrtc/src/private-to-private/handler.ts +++ /dev/null @@ -1,177 +0,0 @@ -import { CodeError } from '@libp2p/interface/errors' -import { logger } from '@libp2p/logger' -import { abortableDuplex } from 'abortable-iterator' -import { pbStream } from 'it-protobuf-stream' -import pDefer, { type DeferredPromise } from 'p-defer' -import { DataChannelMuxerFactory } from '../muxer.js' -import { RTCPeerConnection, RTCSessionDescription } from '../webrtc/index.js' -import { Message } from './pb/message.js' -import { readCandidatesUntilConnected, resolveOnConnected } from './util.js' -import type { DataChannelOpts } from '../stream.js' -import type { Stream } from '@libp2p/interface/connection' -import type { StreamMuxerFactory } from '@libp2p/interface/stream-muxer' -import type { IncomingStreamData } from '@libp2p/interface-internal/registrar' - -const DEFAULT_TIMEOUT = 30 * 1000 - -const log = logger('libp2p:webrtc:peer') - -export type IncomingStreamOpts = { rtcConfiguration?: RTCConfiguration, dataChannelOptions?: Partial } & IncomingStreamData - -export async function handleIncomingStream ({ rtcConfiguration, dataChannelOptions, stream: rawStream }: IncomingStreamOpts): Promise<{ pc: RTCPeerConnection, muxerFactory: StreamMuxerFactory, remoteAddress: string }> { - const signal = AbortSignal.timeout(DEFAULT_TIMEOUT) - const stream = pbStream(abortableDuplex(rawStream, signal)).pb(Message) - const pc = new RTCPeerConnection(rtcConfiguration) - - try { - const muxerFactory = new DataChannelMuxerFactory({ peerConnection: pc, dataChannelOptions }) - const connectedPromise: DeferredPromise = pDefer() - const answerSentPromise: DeferredPromise = pDefer() - - signal.onabort = () => { - connectedPromise.reject(new CodeError('Timed out while trying to connect', 'ERR_TIMEOUT')) - } - // candidate callbacks - pc.onicecandidate = ({ candidate }) => { - answerSentPromise.promise.then( - async () => { - await stream.write({ - type: Message.Type.ICE_CANDIDATE, - data: (candidate != null) ? JSON.stringify(candidate.toJSON()) : '' - }) - }, - (err) => { - log.error('cannot set candidate since sending answer failed', err) - connectedPromise.reject(err) - } - ) - } - - resolveOnConnected(pc, connectedPromise) - - // read an SDP offer - const pbOffer = await stream.read() - if (pbOffer.type !== Message.Type.SDP_OFFER) { - throw new Error(`expected message type SDP_OFFER, received: ${pbOffer.type ?? 'undefined'} `) - } - const offer = new RTCSessionDescription({ - type: 'offer', - sdp: pbOffer.data - }) - - await pc.setRemoteDescription(offer).catch(err => { - log.error('could not execute setRemoteDescription', err) - throw new Error('Failed to set remoteDescription') - }) - - // create and write an SDP answer - const answer = await pc.createAnswer().catch(err => { - log.error('could not execute createAnswer', err) - answerSentPromise.reject(err) - throw new Error('Failed to create answer') - }) - // write the answer to the remote - await stream.write({ type: Message.Type.SDP_ANSWER, data: answer.sdp }) - - await pc.setLocalDescription(answer).catch(err => { - log.error('could not execute setLocalDescription', err) - answerSentPromise.reject(err) - throw new Error('Failed to set localDescription') - }) - - answerSentPromise.resolve() - - // wait until candidates are connected - await readCandidatesUntilConnected(connectedPromise, pc, stream) - - const remoteAddress = parseRemoteAddress(pc.currentRemoteDescription?.sdp ?? '') - - return { pc, muxerFactory, remoteAddress } - } catch (err) { - pc.close() - throw err - } -} - -export interface ConnectOptions { - stream: Stream - signal: AbortSignal - rtcConfiguration?: RTCConfiguration - dataChannelOptions?: Partial -} - -export async function initiateConnection ({ rtcConfiguration, dataChannelOptions, signal, stream: rawStream }: ConnectOptions): Promise<{ pc: RTCPeerConnection, muxerFactory: StreamMuxerFactory, remoteAddress: string }> { - const stream = pbStream(abortableDuplex(rawStream, signal)).pb(Message) - // setup peer connection - const pc = new RTCPeerConnection(rtcConfiguration) - - try { - const muxerFactory = new DataChannelMuxerFactory({ peerConnection: pc, dataChannelOptions }) - - const connectedPromise: DeferredPromise = pDefer() - resolveOnConnected(pc, connectedPromise) - - // reject the connectedPromise if the signal aborts - signal.onabort = connectedPromise.reject - // we create the channel so that the peerconnection has a component for which - // to collect candidates. The label is not relevant to connection initiation - // but can be useful for debugging - const channel = pc.createDataChannel('init') - // setup callback to write ICE candidates to the remote - // peer - pc.onicecandidate = ({ candidate }) => { - void stream.write({ - type: Message.Type.ICE_CANDIDATE, - data: (candidate != null) ? JSON.stringify(candidate.toJSON()) : '' - }) - .catch(err => { - log.error('error sending ICE candidate', err) - }) - } - - // create an offer - const offerSdp = await pc.createOffer() - // write the offer to the stream - await stream.write({ type: Message.Type.SDP_OFFER, data: offerSdp.sdp }) - // set offer as local description - await pc.setLocalDescription(offerSdp).catch(err => { - log.error('could not execute setLocalDescription', err) - throw new Error('Failed to set localDescription') - }) - - // read answer - const answerMessage = await stream.read() - if (answerMessage.type !== Message.Type.SDP_ANSWER) { - throw new Error('remote should send an SDP answer') - } - - const answerSdp = new RTCSessionDescription({ type: 'answer', sdp: answerMessage.data }) - await pc.setRemoteDescription(answerSdp).catch(err => { - log.error('could not execute setRemoteDescription', err) - throw new Error('Failed to set remoteDescription') - }) - - await readCandidatesUntilConnected(connectedPromise, pc, stream) - channel.close() - - const remoteAddress = parseRemoteAddress(pc.currentRemoteDescription?.sdp ?? '') - - return { pc, muxerFactory, remoteAddress } - } catch (err) { - pc.close() - throw err - } -} - -function parseRemoteAddress (sdp: string): string { - // 'a=candidate:1746876089 1 udp 2113937151 0614fbad-b...ocal 54882 typ host generation 0 network-cost 999' - const candidateLine = sdp.split('\r\n').filter(line => line.startsWith('a=candidate')).pop() - const candidateParts = candidateLine?.split(' ') - - if (candidateLine == null || candidateParts == null || candidateParts.length < 5) { - log('could not parse remote address from', candidateLine) - return '/webrtc' - } - - return `/dnsaddr/${candidateParts[4]}/${candidateParts[2].toLowerCase()}/${candidateParts[5]}/webrtc` -} diff --git a/packages/transport-webrtc/src/private-to-private/initiate-connection.ts b/packages/transport-webrtc/src/private-to-private/initiate-connection.ts new file mode 100644 index 0000000000..2821467e38 --- /dev/null +++ b/packages/transport-webrtc/src/private-to-private/initiate-connection.ts @@ -0,0 +1,194 @@ +import { CodeError } from '@libp2p/interface/errors' +import { logger } from '@libp2p/logger' +import { peerIdFromString } from '@libp2p/peer-id' +import { multiaddr, type Multiaddr } from '@multiformats/multiaddr' +import { pbStream } from 'it-protobuf-stream' +import pDefer, { type DeferredPromise } from 'p-defer' +import { type RTCPeerConnection, RTCSessionDescription } from '../webrtc/index.js' +import { Message } from './pb/message.js' +import { SIGNALING_PROTO_ID, splitAddr, type WebRTCTransportMetrics } from './transport.js' +import { parseRemoteAddress, readCandidatesUntilConnected, resolveOnConnected } from './util.js' +import type { DataChannelOptions } from '../index.js' +import type { Connection } from '@libp2p/interface/connection' +import type { ConnectionManager } from '@libp2p/interface-internal/connection-manager' +import type { IncomingStreamData } from '@libp2p/interface-internal/registrar' +import type { TransportManager } from '@libp2p/interface-internal/transport-manager' + +const log = logger('libp2p:webrtc:initiate-connection') + +export interface IncomingStreamOpts extends IncomingStreamData { + rtcConfiguration?: RTCConfiguration + dataChannelOptions?: Partial + signal: AbortSignal +} + +export interface ConnectOptions { + peerConnection: RTCPeerConnection + multiaddr: Multiaddr + connectionManager: ConnectionManager + transportManager: TransportManager + dataChannelOptions?: Partial + signal?: AbortSignal + metrics?: WebRTCTransportMetrics +} + +export async function initiateConnection ({ peerConnection, signal, metrics, multiaddr: ma, connectionManager, transportManager }: ConnectOptions): Promise<{ remoteAddress: Multiaddr }> { + const { baseAddr, peerId } = splitAddr(ma) + + metrics?.dialerEvents.increment({ open: true }) + + log.trace('dialing base address: %a', baseAddr) + + const relayPeer = baseAddr.getPeerId() + + if (relayPeer == null) { + throw new CodeError('Relay peer was missing', 'ERR_INVALID_ADDRESS') + } + + const connections = connectionManager.getConnections(peerIdFromString(relayPeer)) + let connection: Connection + let shouldCloseConnection = false + + if (connections.length === 0) { + // use the transport manager to open a connection. Initiating a WebRTC + // connection takes place in the context of a dial - if we use the + // connection manager instead we can end up joining our own dial context + connection = await transportManager.dial(baseAddr, { + signal + }) + // this connection is unmanaged by the connection manager so we should + // close it when we are done + shouldCloseConnection = true + } else { + connection = connections[0] + } + + try { + const stream = await connection.newStream(SIGNALING_PROTO_ID, { + signal, + runOnTransientConnection: true + }) + + const messageStream = pbStream(stream).pb(Message) + const connectedPromise: DeferredPromise = pDefer() + const sdpAbortedListener = (): void => { + connectedPromise.reject(new CodeError('SDP handshake aborted', 'ERR_SDP_HANDSHAKE_ABORTED')) + } + + try { + resolveOnConnected(peerConnection, connectedPromise) + + // reject the connectedPromise if the signal aborts + signal?.addEventListener('abort', sdpAbortedListener) + + // we create the channel so that the RTCPeerConnection has a component for + // which to collect candidates. The label is not relevant to connection + // initiation but can be useful for debugging + const channel = peerConnection.createDataChannel('init') + + // setup callback to write ICE candidates to the remote peer + peerConnection.onicecandidate = ({ candidate }) => { + let data = '' + + if (candidate != null) { + data = JSON.stringify(candidate.toJSON()) + log.trace('initiator send ICE candidate %s', data) + } + + log.trace('initiator sending ICE candidate %s', data) + + void messageStream.write({ + type: Message.Type.ICE_CANDIDATE, + data + }, { + signal + }) + .catch(err => { + log.error('error sending ICE candidate', err) + }) + } + peerConnection.onicecandidateerror = (event) => { + log('initiator ICE candidate error', event) + } + + // create an offer + const offerSdp = await peerConnection.createOffer() + + log.trace('initiator send SDP offer %s', offerSdp.sdp) + + // write the offer to the stream + await messageStream.write({ type: Message.Type.SDP_OFFER, data: offerSdp.sdp }, { + signal + }) + + // set offer as local description + await peerConnection.setLocalDescription(offerSdp).catch(err => { + log.error('could not execute setLocalDescription', err) + throw new Error('Failed to set localDescription') + }) + + // read answer + const answerMessage = await messageStream.read({ + signal + }) + + if (answerMessage.type !== Message.Type.SDP_ANSWER) { + throw new Error('remote should send an SDP answer') + } + + log.trace('initiator receive SDP answer %s', answerMessage.data) + + const answerSdp = new RTCSessionDescription({ type: 'answer', sdp: answerMessage.data }) + await peerConnection.setRemoteDescription(answerSdp).catch(err => { + log.error('could not execute setRemoteDescription', err) + throw new Error('Failed to set remoteDescription') + }) + + log.trace('initiator read candidates until connected') + + await readCandidatesUntilConnected(connectedPromise, peerConnection, messageStream, { + direction: 'initiator', + signal + }) + + log.trace('initiator connected, closing init channel') + channel.close() + + const remoteAddress = parseRemoteAddress(peerConnection.currentRemoteDescription?.sdp ?? '') + + log.trace('initiator connected to remote address %s', remoteAddress) + + // close the signalling stream + await messageStream.unwrap().unwrap().close({ + signal + }) + + log.trace('initiator closed signalling stream') + + return { + remoteAddress: multiaddr(remoteAddress).encapsulate(`/p2p/${peerId.toString()}`) + } + } catch (err: any) { + peerConnection.close() + stream.abort(err) + throw err + } finally { + // remove event listeners + signal?.removeEventListener('abort', sdpAbortedListener) + peerConnection.onicecandidate = null + peerConnection.onicecandidateerror = null + } + } finally { + // if we had to open a connection to perform the SDP handshake + // close it because it's not tracked by the connection manager + if (shouldCloseConnection) { + try { + await connection.close({ + signal + }) + } catch (err: any) { + connection.abort(err) + } + } + } +} diff --git a/packages/transport-webrtc/src/private-to-private/listener.ts b/packages/transport-webrtc/src/private-to-private/listener.ts index 1dccac6e25..53a3d299c6 100644 --- a/packages/transport-webrtc/src/private-to-private/listener.ts +++ b/packages/transport-webrtc/src/private-to-private/listener.ts @@ -5,20 +5,27 @@ import type { ListenerEvents, Listener } from '@libp2p/interface/transport' import type { TransportManager } from '@libp2p/interface-internal/transport-manager' import type { Multiaddr } from '@multiformats/multiaddr' -export interface ListenerOptions { +export interface WebRTCPeerListenerComponents { peerId: PeerId transportManager: TransportManager } +export interface WebRTCPeerListenerInit { + shutdownController: AbortController +} + export class WebRTCPeerListener extends EventEmitter implements Listener { private readonly peerId: PeerId private readonly transportManager: TransportManager + private readonly shutdownController: AbortController - constructor (opts: ListenerOptions) { + constructor (components: WebRTCPeerListenerComponents, init: WebRTCPeerListenerInit) { super() - this.peerId = opts.peerId - this.transportManager = opts.transportManager + this.peerId = components.peerId + this.transportManager = components.transportManager + + this.shutdownController = init.shutdownController } async listen (): Promise { @@ -39,6 +46,7 @@ export class WebRTCPeerListener extends EventEmitter implements } async close (): Promise { + this.shutdownController.abort() this.safeDispatchEvent('close', {}) } } diff --git a/packages/transport-webrtc/src/private-to-private/signaling-stream-handler.ts b/packages/transport-webrtc/src/private-to-private/signaling-stream-handler.ts new file mode 100644 index 0000000000..ab3733eceb --- /dev/null +++ b/packages/transport-webrtc/src/private-to-private/signaling-stream-handler.ts @@ -0,0 +1,129 @@ +import { CodeError } from '@libp2p/interface/errors' +import { logger } from '@libp2p/logger' +import { pbStream } from 'it-protobuf-stream' +import pDefer, { type DeferredPromise } from 'p-defer' +import { type RTCPeerConnection, RTCSessionDescription } from '../webrtc/index.js' +import { Message } from './pb/message.js' +import { parseRemoteAddress, readCandidatesUntilConnected, resolveOnConnected } from './util.js' +import type { IncomingStreamData } from '@libp2p/interface-internal/registrar' + +const log = logger('libp2p:webrtc:signaling-stream-handler') + +export interface IncomingStreamOpts extends IncomingStreamData { + peerConnection: RTCPeerConnection + signal: AbortSignal +} + +export async function handleIncomingStream ({ peerConnection, stream, signal, connection }: IncomingStreamOpts): Promise<{ remoteAddress: string }> { + log.trace('new inbound signaling stream') + + const messageStream = pbStream(stream).pb(Message) + + try { + const connectedPromise: DeferredPromise = pDefer() + const answerSentPromise: DeferredPromise = pDefer() + + signal.onabort = () => { + connectedPromise.reject(new CodeError('Timed out while trying to connect', 'ERR_TIMEOUT')) + } + + // candidate callbacks + peerConnection.onicecandidate = ({ candidate }) => { + answerSentPromise.promise.then( + async () => { + let data = '' + + if (candidate != null) { + data = JSON.stringify(candidate.toJSON()) + log.trace('recipient send ICE candidate %s', data) + } + + await messageStream.write({ + type: Message.Type.ICE_CANDIDATE, + data + }, { + signal + }) + }, + (err) => { + log.error('cannot set candidate since sending answer failed', err) + connectedPromise.reject(err) + } + ) + } + + resolveOnConnected(peerConnection, connectedPromise) + + // read an SDP offer + const pbOffer = await messageStream.read({ + signal + }) + + if (pbOffer.type !== Message.Type.SDP_OFFER) { + throw new Error(`expected message type SDP_OFFER, received: ${pbOffer.type ?? 'undefined'} `) + } + + log.trace('recipient receive SDP offer %s', pbOffer.data) + + const offer = new RTCSessionDescription({ + type: 'offer', + sdp: pbOffer.data + }) + + await peerConnection.setRemoteDescription(offer).catch(err => { + log.error('could not execute setRemoteDescription', err) + throw new Error('Failed to set remoteDescription') + }) + + // create and write an SDP answer + const answer = await peerConnection.createAnswer().catch(err => { + log.error('could not execute createAnswer', err) + answerSentPromise.reject(err) + throw new Error('Failed to create answer') + }) + + log.trace('recipient send SDP answer %s', answer.sdp) + + // write the answer to the remote + await messageStream.write({ type: Message.Type.SDP_ANSWER, data: answer.sdp }, { + signal + }) + + await peerConnection.setLocalDescription(answer).catch(err => { + log.error('could not execute setLocalDescription', err) + answerSentPromise.reject(err) + throw new Error('Failed to set localDescription') + }) + + answerSentPromise.resolve() + + log.trace('recipient read candidates until connected') + + // wait until candidates are connected + await readCandidatesUntilConnected(connectedPromise, peerConnection, messageStream, { + direction: 'recipient', + signal + }) + + log.trace('recipient connected, closing signaling stream') + + await messageStream.unwrap().unwrap().close({ + signal + }) + } catch (err: any) { + if (peerConnection.connectionState !== 'connected') { + log.error('error while handling signaling stream from peer %a', connection.remoteAddr, err) + + peerConnection.close() + throw err + } else { + log('error while handling signaling stream from peer %a, ignoring as the RTCPeerConnection is already connected', connection.remoteAddr, err) + } + } + + const remoteAddress = parseRemoteAddress(peerConnection.currentRemoteDescription?.sdp ?? '') + + log.trace('recipient connected to remote address %s', remoteAddress) + + return { remoteAddress } +} diff --git a/packages/transport-webrtc/src/private-to-private/transport.ts b/packages/transport-webrtc/src/private-to-private/transport.ts index 9ab8de1f1e..848438e465 100644 --- a/packages/transport-webrtc/src/private-to-private/transport.ts +++ b/packages/transport-webrtc/src/private-to-private/transport.ts @@ -6,26 +6,36 @@ import { multiaddr, type Multiaddr } from '@multiformats/multiaddr' import { WebRTC } from '@multiformats/multiaddr-matcher' import { codes } from '../error.js' import { WebRTCMultiaddrConnection } from '../maconn.js' -import { cleanup } from '../webrtc/index.js' -import { initiateConnection, handleIncomingStream } from './handler.js' +import { DataChannelMuxerFactory } from '../muxer.js' +import { cleanup, RTCPeerConnection } from '../webrtc/index.js' +import { initiateConnection } from './initiate-connection.js' import { WebRTCPeerListener } from './listener.js' -import type { DataChannelOpts } from '../stream.js' +import { handleIncomingStream } from './signaling-stream-handler.js' +import type { DataChannelOptions } from '../index.js' import type { Connection } from '@libp2p/interface/connection' import type { PeerId } from '@libp2p/interface/peer-id' import type { CounterGroup, Metrics } from '@libp2p/interface/src/metrics/index.js' import type { Startable } from '@libp2p/interface/startable' import type { IncomingStreamData, Registrar } from '@libp2p/interface-internal/registrar' +import type { ConnectionManager } from '@libp2p/interface-internal/src/connection-manager/index.js' import type { TransportManager } from '@libp2p/interface-internal/transport-manager' const log = logger('libp2p:webrtc:peer') -const WEBRTC_TRANSPORT = '/webrtc' -const CIRCUIT_RELAY_TRANSPORT = '/p2p-circuit' -const SIGNALING_PROTO_ID = '/webrtc-signaling/0.0.1' +export const WEBRTC_TRANSPORT = '/webrtc' +export const CIRCUIT_RELAY_TRANSPORT = '/p2p-circuit' +export const SIGNALING_PROTO_ID = '/webrtc-signaling/0.0.1' +export const INBOUND_CONNECTION_TIMEOUT = 30 * 1000 export interface WebRTCTransportInit { rtcConfiguration?: RTCConfiguration - dataChannel?: Partial + dataChannel?: DataChannelOptions + + /** + * Inbound connections must complete the upgrade within this many ms + * (default: 30s) + */ + inboundConnectionTimeout?: number } export interface WebRTCTransportComponents { @@ -33,6 +43,7 @@ export interface WebRTCTransportComponents { registrar: Registrar upgrader: Upgrader transportManager: TransportManager + connectionManager: ConnectionManager metrics?: Metrics } @@ -44,11 +55,14 @@ export interface WebRTCTransportMetrics { export class WebRTCTransport implements Transport, Startable { private _started = false private readonly metrics?: WebRTCTransportMetrics + private readonly shutdownController: AbortController constructor ( private readonly components: WebRTCTransportComponents, private readonly init: WebRTCTransportInit = {} ) { + this.shutdownController = new AbortController() + if (components.metrics != null) { this.metrics = { dialerEvents: components.metrics.registerCounterGroup('libp2p_webrtc_dialer_events_total', { @@ -83,7 +97,9 @@ export class WebRTCTransport implements Transport, Startable { } createListener (options: CreateListenerOptions): Listener { - return new WebRTCPeerListener(this.components) + return new WebRTCPeerListener(this.components, { + shutdownController: this.shutdownController + }) } readonly [Symbol.toStringTag] = '@libp2p/webrtc' @@ -102,84 +118,124 @@ export class WebRTCTransport implements Transport, Startable { * /p2p//p2p-circuit/webrtc/p2p/ */ async dial (ma: Multiaddr, options: DialOptions): Promise { - log.trace('dialing address: ', ma) - const { baseAddr, peerId } = splitAddr(ma) + log.trace('dialing address: %a', ma) - if (options.signal == null) { - const controller = new AbortController() - options.signal = controller.signal - } + const peerConnection = new RTCPeerConnection(this.init.rtcConfiguration) + const muxerFactory = new DataChannelMuxerFactory({ + peerConnection, + dataChannelOptions: this.init.dataChannel + }) - this.metrics?.dialerEvents.increment({ open: true }) - const connection = await this.components.transportManager.dial(baseAddr, options) - const signalingStream = await connection.newStream(SIGNALING_PROTO_ID, { - ...options, - runOnTransientConnection: true + const { remoteAddress } = await initiateConnection({ + peerConnection, + multiaddr: ma, + dataChannelOptions: this.init.dataChannel, + signal: options.signal, + connectionManager: this.components.connectionManager, + transportManager: this.components.transportManager }) - try { - const { pc, muxerFactory, remoteAddress } = await initiateConnection({ - stream: signalingStream, - rtcConfiguration: this.init.rtcConfiguration, + const webRTCConn = new WebRTCMultiaddrConnection({ + peerConnection, + timeline: { open: Date.now() }, + remoteAddr: remoteAddress, + metrics: this.metrics?.dialerEvents + }) + + const connection = await options.upgrader.upgradeOutbound(webRTCConn, { + skipProtection: true, + skipEncryption: true, + muxerFactory + }) + + peerConnection.onnegotiationneeded = () => { + log('initiator renegotiating connection') + + this.metrics?.dialerEvents.increment({ renegotiate: true }) + + let signal = options.signal + + if (signal?.aborted === true) { + signal = undefined + } + + void initiateConnection({ + peerConnection, + multiaddr: ma, dataChannelOptions: this.init.dataChannel, - signal: options.signal + signal: options.signal, + connectionManager: this.components.connectionManager, + transportManager: this.components.transportManager }) - - const result = await options.upgrader.upgradeOutbound( - new WebRTCMultiaddrConnection({ - peerConnection: pc, - timeline: { open: Date.now() }, - remoteAddr: multiaddr(remoteAddress).encapsulate(`/p2p/${peerId.toString()}`), - metrics: this.metrics?.dialerEvents - }), - { - skipProtection: true, - skipEncryption: true, - muxerFactory - } - ) - - // close the stream if SDP has been exchanged successfully - await signalingStream.close() - return result - } catch (err: any) { - this.metrics?.dialerEvents.increment({ error: true }) - // reset the stream in case of any error - signalingStream.abort(err) - throw err - } finally { - // Close the signaling connection - await connection.close() + .then(({ remoteAddress }) => { + webRTCConn.remoteAddr = multiaddr(remoteAddress) + }) + .catch(err => { + log.error('initiator errored while renegotiating connection') + connection.abort(err) + }) } + + // close the connection on shut down + this._closeOnShutdown(peerConnection, webRTCConn) + + return connection } async _onProtocol ({ connection, stream }: IncomingStreamData): Promise { + const signal = AbortSignal.timeout(this.init.inboundConnectionTimeout ?? INBOUND_CONNECTION_TIMEOUT) + const peerConnection = new RTCPeerConnection(this.init.rtcConfiguration) + const muxerFactory = new DataChannelMuxerFactory({ peerConnection, dataChannelOptions: this.init.dataChannel }) + try { - const { pc, muxerFactory, remoteAddress } = await handleIncomingStream({ - rtcConfiguration: this.init.rtcConfiguration, + const { remoteAddress } = await handleIncomingStream({ + peerConnection, connection, stream, - dataChannelOptions: this.init.dataChannel + signal }) - await this.components.upgrader.upgradeInbound(new WebRTCMultiaddrConnection({ - peerConnection: pc, + const webRTCConn = new WebRTCMultiaddrConnection({ + peerConnection, timeline: { open: (new Date()).getTime() }, remoteAddr: multiaddr(remoteAddress).encapsulate(`/p2p/${connection.remotePeer.toString()}`), metrics: this.metrics?.listenerEvents - }), { + }) + + // close the connection on shut down + this._closeOnShutdown(peerConnection, webRTCConn) + + await this.components.upgrader.upgradeInbound(webRTCConn, { skipEncryption: true, skipProtection: true, muxerFactory }) + + // close the stream if SDP messages have been exchanged successfully + await stream.close({ + signal + }) } catch (err: any) { stream.abort(err) throw err - } finally { - // Close the signaling connection - await connection.close() } } + + private _closeOnShutdown (pc: RTCPeerConnection, webRTCConn: WebRTCMultiaddrConnection): void { + // close the connection on shut down + const shutDownListener = (): void => { + webRTCConn.close() + .catch(err => { + log.error('could not close WebRTCMultiaddrConnection', err) + }) + } + + this.shutdownController.signal.addEventListener('abort', shutDownListener) + + pc.addEventListener('close', () => { + this.shutdownController.signal.removeEventListener('abort', shutDownListener) + }) + } } export function splitAddr (ma: Multiaddr): { baseAddr: Multiaddr, peerId: PeerId } { diff --git a/packages/transport-webrtc/src/private-to-private/util.ts b/packages/transport-webrtc/src/private-to-private/util.ts index 6d2b97898d..618bc49029 100644 --- a/packages/transport-webrtc/src/private-to-private/util.ts +++ b/packages/transport-webrtc/src/private-to-private/util.ts @@ -2,48 +2,55 @@ import { logger } from '@libp2p/logger' import { isFirefox } from '../util.js' import { RTCIceCandidate } from '../webrtc/index.js' import { Message } from './pb/message.js' +import type { AbortOptions, MessageStream } from 'it-protobuf-stream' import type { DeferredPromise } from 'p-defer' -interface MessageStream { - read: () => Promise - write: (d: Message) => void | Promise -} - const log = logger('libp2p:webrtc:peer:util') -export const readCandidatesUntilConnected = async (connectedPromise: DeferredPromise, pc: RTCPeerConnection, stream: MessageStream): Promise => { +export interface ReadCandidatesOptions extends AbortOptions { + direction: string +} + +export const readCandidatesUntilConnected = async (connectedPromise: DeferredPromise, pc: RTCPeerConnection, stream: MessageStream, options: ReadCandidatesOptions): Promise => { while (true) { - const readResult = await Promise.race([connectedPromise.promise, stream.read()]) - // check if readResult is a message - if (readResult instanceof Object) { - const message = readResult - if (message.type !== Message.Type.ICE_CANDIDATE) { - throw new Error('expected only ice candidates') - } - // end of candidates has been signalled - if (message.data == null || message.data === '') { - log.trace('end-of-candidates received') - break - } - - log.trace('received new ICE candidate: %s', message.data) - try { - await pc.addIceCandidate(new RTCIceCandidate(JSON.parse(message.data))) - } catch (err) { - log.error('bad candidate received: ', err) - throw new Error('bad candidate received') - } - } else { + const readResult = await Promise.race([ + connectedPromise.promise, + stream.read(options) + ]) + + if (readResult == null) { // connected promise resolved break } + + const message = readResult + + if (message.type !== Message.Type.ICE_CANDIDATE) { + throw new Error('expected only ice candidates') + } + + // end of candidates has been signalled + if (message.data == null || message.data === '') { + log.trace('%s received end-of-candidates', options.direction) + break + } + + log.trace('%s received new ICE candidate: %s', options.direction, message.data) + + try { + await pc.addIceCandidate(new RTCIceCandidate(JSON.parse(message.data))) + } catch (err) { + log.error('%s bad candidate received:', options.direction, err) + throw new Error('bad candidate received') + } } + await connectedPromise.promise } export function resolveOnConnected (pc: RTCPeerConnection, promise: DeferredPromise): void { pc[isFirefox ? 'oniceconnectionstatechange' : 'onconnectionstatechange'] = (_) => { - log.trace('receiver peerConnectionState state: ', pc.connectionState) + log.trace('receiver peerConnectionState state change: %s', pc.connectionState) switch (isFirefox ? pc.iceConnectionState : pc.connectionState) { case 'connected': promise.resolve() @@ -58,3 +65,16 @@ export function resolveOnConnected (pc: RTCPeerConnection, promise: DeferredProm } } } + +export function parseRemoteAddress (sdp: string): string { + // 'a=candidate:1746876089 1 udp 2113937151 0614fbad-b...ocal 54882 typ host generation 0 network-cost 999' + const candidateLine = sdp.split('\r\n').filter(line => line.startsWith('a=candidate')).pop() + const candidateParts = candidateLine?.split(' ') + + if (candidateLine == null || candidateParts == null || candidateParts.length < 5) { + log('could not parse remote address from', candidateLine) + return '/webrtc' + } + + return `/dnsaddr/${candidateParts[4]}/${candidateParts[2].toLowerCase()}/${candidateParts[5]}/webrtc` +} diff --git a/packages/transport-webrtc/src/private-to-public/transport.ts b/packages/transport-webrtc/src/private-to-public/transport.ts index 23bb5d1994..a82f84a2b5 100644 --- a/packages/transport-webrtc/src/private-to-public/transport.ts +++ b/packages/transport-webrtc/src/private-to-public/transport.ts @@ -16,7 +16,7 @@ import { RTCPeerConnection } from '../webrtc/index.js' import * as sdp from './sdp.js' import { genUfrag } from './util.js' import type { WebRTCDialOptions } from './options.js' -import type { DataChannelOpts } from '../stream.js' +import type { DataChannelOptions } from '../index.js' import type { Connection } from '@libp2p/interface/connection' import type { CounterGroup, Metrics } from '@libp2p/interface/metrics' import type { PeerId } from '@libp2p/interface/peer-id' @@ -56,7 +56,7 @@ export interface WebRTCMetrics { } export interface WebRTCTransportDirectInit { - dataChannel?: Partial + dataChannel?: DataChannelOptions } export class WebRTCDirectTransport implements Transport { @@ -81,7 +81,7 @@ export class WebRTCDirectTransport implements Transport { */ async dial (ma: Multiaddr, options: WebRTCDialOptions): Promise { const rawConn = await this._connect(ma, options) - log(`dialing address - ${ma.toString()}`) + log('dialing address: %a', ma) return rawConn } @@ -194,7 +194,7 @@ export class WebRTCDirectTransport implements Transport { // we pass in undefined for these parameters. const noise = Noise({ prologueBytes: fingerprintsPrologue })() - const wrappedChannel = createStream({ channel: handshakeDataChannel, direction: 'inbound', dataChannelOptions: this.init.dataChannel }) + const wrappedChannel = createStream({ channel: handshakeDataChannel, direction: 'inbound', ...(this.init.dataChannel ?? {}) }) const wrappedDuplex = { ...wrappedChannel, sink: wrappedChannel.sink.bind(wrappedChannel), diff --git a/packages/transport-webrtc/src/stream.ts b/packages/transport-webrtc/src/stream.ts index 1e9bb0c3c7..8e47b767a9 100644 --- a/packages/transport-webrtc/src/stream.ts +++ b/packages/transport-webrtc/src/stream.ts @@ -3,18 +3,17 @@ import { AbstractStream, type AbstractStreamInit } from '@libp2p/interface/strea import { logger } from '@libp2p/logger' import * as lengthPrefixed from 'it-length-prefixed' import { type Pushable, pushable } from 'it-pushable' +import pDefer from 'p-defer' import { pEvent, TimeoutError } from 'p-event' +import pTimeout from 'p-timeout' import { Uint8ArrayList } from 'uint8arraylist' import { Message } from './pb/message.js' +import type { DataChannelOptions } from './index.js' +import type { AbortOptions } from '@libp2p/interface' import type { Direction } from '@libp2p/interface/connection' +import type { DeferredPromise } from 'p-defer' -export interface DataChannelOpts { - maxMessageSize: number - maxBufferedAmount: number - bufferedAmountLowEventTimeout: number -} - -export interface WebRTCStreamInit extends AbstractStreamInit { +export interface WebRTCStreamInit extends AbstractStreamInit, DataChannelOptions { /** * The network channel used for bidirectional peer-to-peer transfers of * arbitrary data @@ -22,38 +21,46 @@ export interface WebRTCStreamInit extends AbstractStreamInit { * {@link https://developer.mozilla.org/en-US/docs/Web/API/RTCDataChannel} */ channel: RTCDataChannel - - dataChannelOptions?: Partial - - maxDataSize: number } -// Max message size that can be sent to the DataChannel -export const MAX_MESSAGE_SIZE = 16 * 1024 - -// How much can be buffered to the DataChannel at once +/** + * How much can be buffered to the DataChannel at once + */ export const MAX_BUFFERED_AMOUNT = 16 * 1024 * 1024 -// How long time we wait for the 'bufferedamountlow' event to be emitted +/** + * How long time we wait for the 'bufferedamountlow' event to be emitted + */ export const BUFFERED_AMOUNT_LOW_TIMEOUT = 30 * 1000 -// protobuf field definition overhead +/** + * protobuf field definition overhead + */ export const PROTOBUF_OVERHEAD = 5 -// Length of varint, in bytes. +/** + * Length of varint, in bytes + */ export const VARINT_LENGTH = 2 +/** + * Max message size that can be sent to the DataChannel + */ +export const MAX_MESSAGE_SIZE = 16 * 1024 + +/** + * When closing streams we send a FIN then wait for the remote to + * reply with a FIN_ACK. If that does not happen within this timeout + * we close the stream anyway. + */ +export const FIN_ACK_TIMEOUT = 5000 + export class WebRTCStream extends AbstractStream { /** * The data channel used to send and receive data */ private readonly channel: RTCDataChannel - /** - * Data channel options - */ - private readonly dataChannelOptions: DataChannelOpts - /** * push data from the underlying datachannel to the length prefix decoder * and then the protobuf decoder. @@ -62,24 +69,58 @@ export class WebRTCStream extends AbstractStream { private messageQueue?: Uint8ArrayList + private readonly maxBufferedAmount: number + + private readonly bufferedAmountLowEventTimeout: number + /** * The maximum size of a message in bytes */ - private readonly maxDataSize: number + private readonly maxMessageSize: number + + /** + * When this promise is resolved, the remote has sent us a FIN flag + */ + private readonly receiveFinAck: DeferredPromise + private receivedFinAck: boolean + private readonly finAckTimeout: number constructor (init: WebRTCStreamInit) { + // override onEnd to send/receive FIN_ACK before closing the stream + const originalOnEnd = init.onEnd + init.onEnd = (err?: Error): void => { + this.log.trace('received FIN, sending FIN_ACK', this.status) + this._sendFlag(Message.Flag.FIN_ACK) + .catch(err => { + this.log.error('error sending FIN_ACK', err) + }) + .then(async () => { + this.receivedFinAck.toString() + + await pTimeout(this.receiveFinAck.promise, { + milliseconds: this.finAckTimeout + }) + }) + .catch(err => { + this.log.error('error receiving FIN_ACK', err) + }) + .finally(() => { + originalOnEnd?.(err) + }) + } + super(init) this.channel = init.channel this.channel.binaryType = 'arraybuffer' this.incomingData = pushable() this.messageQueue = new Uint8ArrayList() - this.dataChannelOptions = { - bufferedAmountLowEventTimeout: init.dataChannelOptions?.bufferedAmountLowEventTimeout ?? BUFFERED_AMOUNT_LOW_TIMEOUT, - maxBufferedAmount: init.dataChannelOptions?.maxBufferedAmount ?? MAX_BUFFERED_AMOUNT, - maxMessageSize: init.dataChannelOptions?.maxMessageSize ?? init.maxDataSize - } - this.maxDataSize = init.maxDataSize + this.bufferedAmountLowEventTimeout = init.bufferedAmountLowEventTimeout ?? BUFFERED_AMOUNT_LOW_TIMEOUT + this.maxBufferedAmount = init.maxBufferedAmount ?? MAX_BUFFERED_AMOUNT + this.maxMessageSize = (init.maxMessageSize ?? MAX_MESSAGE_SIZE) - PROTOBUF_OVERHEAD - VARINT_LENGTH + this.receiveFinAck = pDefer() + this.receivedFinAck = false + this.finAckTimeout = init.closeTimeout ?? FIN_ACK_TIMEOUT // set up initial state switch (this.channel.readyState) { @@ -105,14 +146,18 @@ export class WebRTCStream extends AbstractStream { this.channel.onopen = (_evt) => { this.timeline.open = new Date().getTime() - if (this.messageQueue != null) { + if (this.messageQueue != null && this.messageQueue.byteLength > 0) { + this.log.trace('dataChannel opened, sending queued messages', this.messageQueue.byteLength, this.channel.readyState) + // send any queued messages this._sendMessage(this.messageQueue) .catch(err => { + this.log.error('error sending queued messages', err) this.abort(err) }) - this.messageQueue = undefined } + + this.messageQueue = undefined } this.channel.onclose = (_evt) => { @@ -126,8 +171,6 @@ export class WebRTCStream extends AbstractStream { this.abort(err) } - const self = this - this.channel.onmessage = async (event: MessageEvent) => { const { data } = event @@ -138,6 +181,8 @@ export class WebRTCStream extends AbstractStream { this.incomingData.push(new Uint8Array(data, 0, data.byteLength)) } + const self = this + // pipe framed protobuf messages through a length prefixed decoder, and // surface data from the `Message.message` field through a source. Promise.resolve().then(async () => { @@ -159,9 +204,9 @@ export class WebRTCStream extends AbstractStream { } async _sendMessage (data: Uint8ArrayList, checkBuffer: boolean = true): Promise { - if (checkBuffer && this.channel.bufferedAmount > this.dataChannelOptions.maxBufferedAmount) { + if (checkBuffer && this.channel.bufferedAmount > this.maxBufferedAmount) { try { - await pEvent(this.channel, 'bufferedamountlow', { timeout: this.dataChannelOptions.bufferedAmountLowEventTimeout }) + await pEvent(this.channel, 'bufferedamountlow', { timeout: this.bufferedAmountLowEventTimeout }) } catch (err: any) { if (err instanceof TimeoutError) { throw new Error('Timed out waiting for DataChannel buffer to clear') @@ -172,7 +217,7 @@ export class WebRTCStream extends AbstractStream { } if (this.channel.readyState === 'closed' || this.channel.readyState === 'closing') { - throw new CodeError('Invalid datachannel state - closed or closing', 'ERR_INVALID_STATE') + throw new CodeError(`Invalid datachannel state - ${this.channel.readyState}`, 'ERR_INVALID_STATE') } if (this.channel.readyState === 'open') { @@ -194,10 +239,12 @@ export class WebRTCStream extends AbstractStream { } async sendData (data: Uint8ArrayList): Promise { + // sending messages is an async operation so use a copy of the list as it + // may be changed beneath us data = data.sublist() while (data.byteLength > 0) { - const toSend = Math.min(data.byteLength, this.maxDataSize) + const toSend = Math.min(data.byteLength, this.maxMessageSize) const buf = data.subarray(0, toSend) const msgbuf = Message.encode({ message: buf }) const sendbuf = lengthPrefixed.encode.single(msgbuf) @@ -211,7 +258,12 @@ export class WebRTCStream extends AbstractStream { await this._sendFlag(Message.Flag.RESET) } - async sendCloseWrite (): Promise { + async sendCloseWrite (options: AbortOptions): Promise { + if (this.channel.readyState === 'closed') { + return + } + + this.log.trace('send FIN') await this._sendFlag(Message.Flag.FIN) } @@ -226,6 +278,8 @@ export class WebRTCStream extends AbstractStream { const message = Message.decode(buffer) if (message.flag !== undefined) { + this.log.trace('incoming flag', message.flag) + if (message.flag === Message.Flag.FIN) { // We should expect no more data from the remote, stop reading this.incomingData.end() @@ -241,13 +295,19 @@ export class WebRTCStream extends AbstractStream { // The remote has stopped reading this.remoteCloseRead() } + + if (message.flag === Message.Flag.FIN_ACK) { + this.log.trace('received FIN_ACK') + this.receivedFinAck = true + this.receiveFinAck.resolve() + } } return message.message } private async _sendFlag (flag: Message.Flag): Promise { - this.log.trace('Sending flag: %s', flag.toString()) + this.log.trace('sending flag: %s', flag.toString()) const msgbuf = Message.encode({ flag }) const prefixedBuf = lengthPrefixed.encode.single(msgbuf) @@ -255,7 +315,7 @@ export class WebRTCStream extends AbstractStream { } } -export interface WebRTCStreamOptions { +export interface WebRTCStreamOptions extends DataChannelOptions { /** * The network channel used for bidirectional peer-to-peer transfers of * arbitrary data @@ -269,23 +329,18 @@ export interface WebRTCStreamOptions { */ direction: Direction - dataChannelOptions?: Partial - - maxMsgSize?: number - + /** + * A callback invoked when the channel ends + */ onEnd?: (err?: Error | undefined) => void } export function createStream (options: WebRTCStreamOptions): WebRTCStream { - const { channel, direction, onEnd, dataChannelOptions } = options + const { channel, direction } = options return new WebRTCStream({ id: direction === 'inbound' ? (`i${channel.id}`) : `r${channel.id}`, - direction, - maxDataSize: (dataChannelOptions?.maxMessageSize ?? MAX_MESSAGE_SIZE) - PROTOBUF_OVERHEAD - VARINT_LENGTH, - dataChannelOptions, - onEnd, - channel, - log: logger(`libp2p:webrtc:stream:${direction}:${channel.id}`) + log: logger(`libp2p:webrtc:stream:${direction}:${channel.id}`), + ...options }) } diff --git a/packages/transport-webrtc/src/util.ts b/packages/transport-webrtc/src/util.ts index e26e64dd5f..e35b90ca43 100644 --- a/packages/transport-webrtc/src/util.ts +++ b/packages/transport-webrtc/src/util.ts @@ -1,4 +1,9 @@ +import { logger } from '@libp2p/logger' import { detect } from 'detect-browser' +import pDefer from 'p-defer' +import pTimeout from 'p-timeout' + +const log = logger('libp2p:webrtc:utils') const browser = detect() export const isFirefox = ((browser != null) && browser.name === 'firefox') @@ -6,3 +11,58 @@ export const isFirefox = ((browser != null) && browser.name === 'firefox') export const nopSource = async function * nop (): AsyncGenerator {} export const nopSink = async (_: any): Promise => {} + +export const DATA_CHANNEL_DRAIN_TIMEOUT = 30 * 1000 + +export function drainAndClose (channel: RTCDataChannel, direction: string, drainTimeout: number = DATA_CHANNEL_DRAIN_TIMEOUT): void { + if (channel.readyState !== 'open') { + return + } + + void Promise.resolve() + .then(async () => { + // wait for bufferedAmount to become zero + if (channel.bufferedAmount > 0) { + log('%s drain channel with %d buffered bytes', direction, channel.bufferedAmount) + const deferred = pDefer() + let drained = false + + channel.bufferedAmountLowThreshold = 0 + + const closeListener = (): void => { + if (!drained) { + log('%s drain channel closed before drain', direction) + deferred.resolve() + } + } + + channel.addEventListener('close', closeListener, { + once: true + }) + + channel.addEventListener('bufferedamountlow', () => { + drained = true + channel.removeEventListener('close', closeListener) + deferred.resolve() + }) + + await pTimeout(deferred.promise, { + milliseconds: drainTimeout + }) + } + }) + .then(async () => { + // only close if the channel is still open + if (channel.readyState === 'open') { + channel.close() + } + }) + .catch(err => { + log.error('error closing outbound stream', err) + }) +} + +export interface AbortPromiseOptions { + signal?: AbortSignal + message?: string +} diff --git a/packages/transport-webrtc/test/basics.spec.ts b/packages/transport-webrtc/test/basics.spec.ts index 03d89a1c8b..3aabbd6373 100644 --- a/packages/transport-webrtc/test/basics.spec.ts +++ b/packages/transport-webrtc/test/basics.spec.ts @@ -12,12 +12,13 @@ import { pipe } from 'it-pipe' import toBuffer from 'it-to-buffer' import { createLibp2p } from 'libp2p' import { circuitRelayTransport } from 'libp2p/circuit-relay' -import { identifyService } from 'libp2p/identify' +import pDefer from 'p-defer' import { webRTC } from '../src/index.js' import type { Libp2p } from '@libp2p/interface' import type { Connection } from '@libp2p/interface/connection' +import type { StreamHandler } from '@libp2p/interface/stream-handler' -async function createNode (): Promise { +export async function createRelayNode (): Promise { return createLibp2p({ addresses: { listen: [ @@ -38,9 +39,6 @@ async function createNode (): Promise { streamMuxers: [ yamux() ], - services: { - identify: identifyService() - }, connectionGater: { denyDialMultiaddr: () => false }, @@ -55,6 +53,7 @@ describe('basics', () => { let localNode: Libp2p let remoteNode: Libp2p + let streamHandler: StreamHandler async function connectNodes (): Promise { const remoteAddr = remoteNode.getMultiaddrs() @@ -64,11 +63,8 @@ describe('basics', () => { throw new Error('Remote peer could not listen on relay') } - await remoteNode.handle(echo, ({ stream }) => { - void pipe( - stream, - stream - ) + await remoteNode.handle(echo, (info) => { + streamHandler(info) }, { runOnTransientConnection: true }) @@ -83,8 +79,15 @@ describe('basics', () => { } beforeEach(async () => { - localNode = await createNode() - remoteNode = await createNode() + streamHandler = ({ stream }) => { + void pipe( + stream, + stream + ) + } + + localNode = await createRelayNode() + remoteNode = await createRelayNode() }) afterEach(async () => { @@ -101,9 +104,7 @@ describe('basics', () => { const connection = await connectNodes() // open a stream on the echo protocol - const stream = await connection.newStream(echo, { - runOnTransientConnection: true - }) + const stream = await connection.newStream(echo) // send and receive some data const input = new Array(5).fill(0).map(() => new Uint8Array(10)) @@ -138,4 +139,69 @@ describe('basics', () => { // asset that we got the right data expect(output).to.equalBytes(toBuffer(input)) }) + + it('can close local stream for reading but send a large file', async () => { + let output: Uint8Array = new Uint8Array(0) + const streamClosed = pDefer() + + streamHandler = ({ stream }) => { + void Promise.resolve().then(async () => { + output = await toBuffer(map(stream.source, (buf) => buf.subarray())) + await stream.close() + streamClosed.resolve() + }) + } + + const connection = await connectNodes() + + // open a stream on the echo protocol + const stream = await connection.newStream(echo, { + runOnTransientConnection: true + }) + + // close for reading + await stream.closeRead() + + // send some data + const input = new Array(5).fill(0).map(() => new Uint8Array(1024 * 1024)) + + await stream.sink(input) + await stream.close() + + // wait for remote to receive all data + await streamClosed.promise + + // asset that we got the right data + expect(output).to.equalBytes(toBuffer(input)) + }) + + it('can close local stream for writing but receive a large file', async () => { + const input = new Array(5).fill(0).map(() => new Uint8Array(1024 * 1024)) + + streamHandler = ({ stream }) => { + void Promise.resolve().then(async () => { + // send some data + await stream.sink(input) + await stream.close() + }) + } + + const connection = await connectNodes() + + // open a stream on the echo protocol + const stream = await connection.newStream(echo, { + runOnTransientConnection: true + }) + + // close for reading + await stream.closeWrite() + + // receive some data + const output = await toBuffer(map(stream.source, (buf) => buf.subarray())) + + await stream.close() + + // asset that we got the right data + expect(output).to.equalBytes(toBuffer(input)) + }) }) diff --git a/packages/transport-webrtc/test/compliance.spec.ts b/packages/transport-webrtc/test/compliance.spec.ts new file mode 100644 index 0000000000..dd30998a73 --- /dev/null +++ b/packages/transport-webrtc/test/compliance.spec.ts @@ -0,0 +1,66 @@ +/* eslint-env mocha */ + +import tests from '@libp2p/interface-compliance-tests/transport' +import { multiaddr } from '@multiformats/multiaddr' +import { mockUpgrader } from '@libp2p/interface-compliance-tests/mocks' +import { mockRegistrar } from '@libp2p/interface-compliance-tests/mocks' +import { WebRTCTransport } from '../src/private-to-private/transport.js' +import { stubInterface } from 'sinon-ts' +import type { ConnectionManager } from '@libp2p/interface-internal/connection-manager' +import { createRelayNode } from './basics.spec.js' +import { WebRTC } from '@multiformats/multiaddr-matcher' +import type { Connection } from '@libp2p/interface/connection' + + +describe('interface-transport compliance', () => { + tests({ + async setup() { + + const relayNode = await createRelayNode() + + const node = await createRelayNode() + + await node.start() + + const remoteAddr = relayNode.getMultiaddrs() + .filter(ma => WebRTC.matches(ma)).pop() + + if (remoteAddr == null) { + throw new Error('Remote peer could not listen on relay') + } + + const connection = await node.dial(remoteAddr) + + const peerA: any = { + peerId: node.peerId, + registrar: mockRegistrar(), + upgrader: mockUpgrader(), + } + + peerA.connectionManager = stubInterface() + + peerA.connectionManager.getConnections.returns([connection]) + + const wrtc = new WebRTCTransport(peerA) + + await wrtc.start() + + const addrs = [ + multiaddr(`/ip4/1.2.3.4/udp/1234/webrtc-direct/certhash/uEiAUqV7kzvM1wI5DYDc1RbcekYVmXli_Qprlw3IkiEg6tQ/p2p/${node.peerId.toString()}`) + ] + + const listeningAddrs = [ + remoteAddr, + ] + + // Used by the dial tests to simulate a delayed connect + const connector = { + delay() { }, + restore() { } + } + + return { transport: wrtc, addrs, connector, listeningAddrs } + }, + async teardown() { } + }) +}) \ No newline at end of file diff --git a/packages/transport-webrtc/test/listener.spec.ts b/packages/transport-webrtc/test/listener.spec.ts index 34feedb859..036e727d7c 100644 --- a/packages/transport-webrtc/test/listener.spec.ts +++ b/packages/transport-webrtc/test/listener.spec.ts @@ -16,6 +16,8 @@ describe('webrtc private-to-private listener', () => { const listener = new WebRTCPeerListener({ peerId, transportManager + }, { + shutdownController: new AbortController() }) const otherListener = stubInterface({ diff --git a/packages/transport-webrtc/test/peer.browser.spec.ts b/packages/transport-webrtc/test/peer.browser.spec.ts index 623a8a8542..5e98c1078a 100644 --- a/packages/transport-webrtc/test/peer.browser.spec.ts +++ b/packages/transport-webrtc/test/peer.browser.spec.ts @@ -1,56 +1,119 @@ -import { mockConnection, mockMultiaddrConnection, mockRegistrar, mockStream, mockUpgrader } from '@libp2p/interface-compliance-tests/mocks' +import { mockRegistrar, mockUpgrader, streamPair } from '@libp2p/interface-compliance-tests/mocks' import { createEd25519PeerId } from '@libp2p/peer-id-factory' -import { multiaddr } from '@multiformats/multiaddr' +import { multiaddr, type Multiaddr } from '@multiformats/multiaddr' import { expect } from 'aegir/chai' import { detect } from 'detect-browser' -import { pair } from 'it-pair' import { duplexPair } from 'it-pair/duplex' import { pbStream } from 'it-protobuf-stream' import Sinon from 'sinon' -import { initiateConnection, handleIncomingStream } from '../src/private-to-private/handler.js' +import { stubInterface, type StubbedInstance } from 'sinon-ts' +import { initiateConnection } from '../src/private-to-private/initiate-connection.js' import { Message } from '../src/private-to-private/pb/message.js' -import { WebRTCTransport, splitAddr } from '../src/private-to-private/transport.js' +import { handleIncomingStream } from '../src/private-to-private/signaling-stream-handler.js' +import { SIGNALING_PROTO_ID, WebRTCTransport, splitAddr } from '../src/private-to-private/transport.js' import { RTCPeerConnection, RTCSessionDescription } from '../src/webrtc/index.js' +import type { Connection, Stream } from '@libp2p/interface/connection' +import type { ConnectionManager } from '@libp2p/interface-internal/connection-manager' +import type { TransportManager } from '@libp2p/interface-internal/transport-manager' const browser = detect() +interface PrivateToPrivateComponents { + initiator: { + multiaddr: Multiaddr + peerConnection: RTCPeerConnection + connectionManager: StubbedInstance + transportManager: StubbedInstance + connection: StubbedInstance + stream: Stream + } + recipient: { + peerConnection: RTCPeerConnection + connection: StubbedInstance + abortController: AbortController + signal: AbortSignal + stream: Stream + } +} + +async function getComponents (): Promise { + const relayPeerId = await createEd25519PeerId() + const receiverPeerId = await createEd25519PeerId() + const receiverMultiaddr = multiaddr(`/ip4/123.123.123.123/tcp/123/p2p/${relayPeerId}/p2p-circuit/webrtc/p2p/${receiverPeerId}`) + const [initiatorToReceiver, receiverToInitiator] = duplexPair() + const [initiatorStream, receiverStream] = streamPair({ + duplex: initiatorToReceiver, + init: { + protocol: SIGNALING_PROTO_ID + } + }, { + duplex: receiverToInitiator, + init: { + protocol: SIGNALING_PROTO_ID + } + }) + + const recipientAbortController = new AbortController() + + return { + initiator: { + multiaddr: receiverMultiaddr, + peerConnection: new RTCPeerConnection(), + connectionManager: stubInterface(), + transportManager: stubInterface(), + connection: stubInterface(), + stream: initiatorStream + }, + recipient: { + peerConnection: new RTCPeerConnection(), + connection: stubInterface(), + abortController: recipientAbortController, + signal: recipientAbortController.signal, + stream: receiverStream + } + } +} + describe('webrtc basic', () => { const isFirefox = ((browser != null) && browser.name === 'firefox') it('should connect', async () => { - const [receiver, initiator] = duplexPair() - const dstPeerId = await createEd25519PeerId() - const connection = mockConnection( - mockMultiaddrConnection(pair(), dstPeerId) - ) - const controller = new AbortController() - const initiatorPeerConnectionPromise = initiateConnection({ stream: mockStream(initiator), signal: controller.signal }) - const receiverPeerConnectionPromise = handleIncomingStream({ stream: mockStream(receiver), connection }) - await expect(initiatorPeerConnectionPromise).to.be.fulfilled() - await expect(receiverPeerConnectionPromise).to.be.fulfilled() - const [{ pc: pc0 }, { pc: pc1 }] = await Promise.all([initiatorPeerConnectionPromise, receiverPeerConnectionPromise]) + const { initiator, recipient } = await getComponents() + + // no existing connection + initiator.connectionManager.getConnections.returns([]) + + // transport manager dials recipient + initiator.transportManager.dial.resolves(initiator.connection) + + // signalling stream opens successfully + initiator.connection.newStream.withArgs(SIGNALING_PROTO_ID).resolves(initiator.stream) + + await expect( + Promise.all([ + initiateConnection(initiator), + handleIncomingStream(recipient) + ]) + ).to.eventually.be.fulfilled() + if (isFirefox) { - expect(pc0.iceConnectionState).eq('connected') - expect(pc1.iceConnectionState).eq('connected') + expect(initiator.peerConnection.iceConnectionState).eq('connected') + expect(recipient.peerConnection.iceConnectionState).eq('connected') return } - expect(pc0.connectionState).eq('connected') - expect(pc1.connectionState).eq('connected') + expect(initiator.peerConnection.connectionState).eq('connected') + expect(recipient.peerConnection.connectionState).eq('connected') - pc0.close() - pc1.close() + initiator.peerConnection.close() + recipient.peerConnection.close() }) }) describe('webrtc receiver', () => { it('should fail receiving on invalid sdp offer', async () => { - const [receiver, initiator] = duplexPair() - const dstPeerId = await createEd25519PeerId() - const connection = mockConnection( - mockMultiaddrConnection(pair(), dstPeerId) - ) - const receiverPeerConnectionPromise = handleIncomingStream({ stream: mockStream(receiver), connection }) - const stream = pbStream(initiator).pb(Message) + const { initiator, recipient } = await getComponents() + const receiverPeerConnectionPromise = handleIncomingStream(recipient) + const stream = pbStream(initiator.stream).pb(Message) await stream.write({ type: Message.Type.SDP_OFFER, data: 'bad' }) await expect(receiverPeerConnectionPromise).to.be.rejectedWith(/Failed to set remoteDescription/) @@ -59,10 +122,18 @@ describe('webrtc receiver', () => { describe('webrtc dialer', () => { it('should fail receiving on invalid sdp answer', async () => { - const [receiver, initiator] = duplexPair() - const controller = new AbortController() - const initiatorPeerConnectionPromise = initiateConnection({ signal: controller.signal, stream: mockStream(initiator) }) - const stream = pbStream(receiver).pb(Message) + const { initiator, recipient } = await getComponents() + + // existing connection already exists + initiator.connectionManager.getConnections.returns([ + initiator.connection + ]) + + // signalling stream opens successfully + initiator.connection.newStream.withArgs(SIGNALING_PROTO_ID).resolves(initiator.stream) + + const initiatorPeerConnectionPromise = initiateConnection(initiator) + const stream = pbStream(recipient.stream).pb(Message) const offerMessage = await stream.read() expect(offerMessage.type).to.eq(Message.Type.SDP_OFFER) @@ -72,10 +143,19 @@ describe('webrtc dialer', () => { }) it('should fail on receiving a candidate before an answer', async () => { - const [receiver, initiator] = duplexPair() - const controller = new AbortController() - const initiatorPeerConnectionPromise = initiateConnection({ signal: controller.signal, stream: mockStream(initiator) }) - const stream = pbStream(receiver).pb(Message) + const { initiator, recipient } = await getComponents() + + // existing connection already exists + initiator.connectionManager.getConnections.returns([ + initiator.connection + ]) + + // signalling stream opens successfully + initiator.connection.newStream.withArgs(SIGNALING_PROTO_ID).resolves(initiator.stream) + + const initiatorPeerConnectionPromise = initiateConnection(initiator) + + const stream = pbStream(recipient.stream).pb(Message) const pc = new RTCPeerConnection() pc.onicecandidate = ({ candidate }) => { @@ -99,7 +179,8 @@ describe('webrtc dialer', () => { describe('webrtc filter', () => { it('can filter multiaddrs to dial', async () => { const transport = new WebRTCTransport({ - transportManager: Sinon.stub() as any, + transportManager: stubInterface(), + connectionManager: stubInterface(), peerId: Sinon.stub() as any, registrar: mockRegistrar(), upgrader: mockUpgrader({}) diff --git a/packages/transport-webrtc/test/stream.browser.spec.ts b/packages/transport-webrtc/test/stream.browser.spec.ts index 457f95317d..dcbc3394e1 100644 --- a/packages/transport-webrtc/test/stream.browser.spec.ts +++ b/packages/transport-webrtc/test/stream.browser.spec.ts @@ -11,7 +11,7 @@ const TEST_MESSAGE = 'test_message' function setup (): { peerConnection: RTCPeerConnection, dataChannel: RTCDataChannel, stream: WebRTCStream } { const peerConnection = new RTCPeerConnection() const dataChannel = peerConnection.createDataChannel('whatever', { negotiated: true, id: 91 }) - const stream = createStream({ channel: dataChannel, direction: 'outbound' }) + const stream = createStream({ channel: dataChannel, direction: 'outbound', closeTimeout: 1 }) return { peerConnection, dataChannel, stream } } @@ -28,9 +28,10 @@ function generatePbByFlag (flag?: Message.Flag): Uint8Array { describe('Stream Stats', () => { let stream: WebRTCStream let peerConnection: RTCPeerConnection + let dataChannel: RTCDataChannel beforeEach(async () => { - ({ stream, peerConnection } = setup()) + ({ stream, peerConnection, dataChannel } = setup()) }) afterEach(() => { @@ -45,7 +46,14 @@ describe('Stream Stats', () => { it('close marks it closed', async () => { expect(stream.timeline.close).to.not.exist() - await stream.close() + + const msgbuf = Message.encode({ flag: Message.Flag.FIN_ACK }) + const prefixedBuf = lengthPrefixed.encode.single(msgbuf) + + const p = stream.close() + dataChannel.dispatchEvent(new MessageEvent('message', { data: prefixedBuf })) + await p + expect(stream.timeline.close).to.be.a('number') }) diff --git a/packages/transport-webrtc/test/stream.spec.ts b/packages/transport-webrtc/test/stream.spec.ts index 500cbb02de..4e8073a530 100644 --- a/packages/transport-webrtc/test/stream.spec.ts +++ b/packages/transport-webrtc/test/stream.spec.ts @@ -4,6 +4,7 @@ import { expect } from 'aegir/chai' import length from 'it-length' import * as lengthPrefixed from 'it-length-prefixed' import { pushable } from 'it-pushable' +import pDefer from 'p-defer' import { Uint8ArrayList } from 'uint8arraylist' import { Message } from '../src/pb/message.js' import { MAX_BUFFERED_AMOUNT, MAX_MESSAGE_SIZE, PROTOBUF_OVERHEAD, createStream } from '../src/stream.js' @@ -33,7 +34,8 @@ describe('Max message size', () => { sent.append(bytes) } }), - direction: 'outbound' + direction: 'outbound', + closeTimeout: 1 }) p.push(data) @@ -78,11 +80,10 @@ describe('Max message size', () => { it('closes the stream if bufferamountlow timeout', async () => { const timeout = 100 - let closed = false + const closed = pDefer() const webrtcStream = createStream({ - dataChannelOptions: { - bufferedAmountLowEventTimeout: timeout - }, + bufferedAmountLowEventTimeout: timeout, + closeTimeout: 1, channel: mockDataChannel({ send: () => { throw new Error('Expected to not send') @@ -91,7 +92,7 @@ describe('Max message size', () => { }), direction: 'outbound', onEnd: () => { - closed = true + closed.resolve() } }) @@ -102,7 +103,7 @@ describe('Max message size', () => { const t1 = Date.now() expect(t1 - t0).greaterThan(timeout) expect(t1 - t0).lessThan(timeout + 1000) // Some upper bound - expect(closed).true() + await closed.promise expect(webrtcStream.timeline.close).to.be.greaterThan(webrtcStream.timeline.open) expect(webrtcStream.timeline.abort).to.be.greaterThan(webrtcStream.timeline.open) }) diff --git a/packages/transport-websockets/package.json b/packages/transport-websockets/package.json index 5da0939ac9..45d530f230 100644 --- a/packages/transport-websockets/package.json +++ b/packages/transport-websockets/package.json @@ -82,7 +82,7 @@ "ws": "^8.12.1" }, "devDependencies": { - "@libp2p/interface-compliance-tests": "^4.0.6", + "@libp2p/interface-compliance-tests": "file://Users/horizon/Desktop/work/js-libp2p/packages/interface-compliance-tests", "aegir": "^40.0.8", "is-loopback-addr": "^2.0.1", "it-all": "^3.0.1", diff --git a/packages/transport-websockets/test/node.ts b/packages/transport-websockets/test/node.ts index 3526129a71..32aa1b02ae 100644 --- a/packages/transport-websockets/test/node.ts +++ b/packages/transport-websockets/test/node.ts @@ -332,6 +332,10 @@ describe('dial', () => { return !isLoopbackAddr(address) }) + if (addrs.length === 0) { + return + } + // Dial first no loopback address const conn = await ws.dial(addrs[0], { upgrader }) const s = goodbye({ source: [uint8ArrayFromString('hey')], sink: all })