Skip to content

Commit 8d3c647

Browse files
authored
Add support for negotiating a subprotocol (#1150)
1 parent e71e739 commit 8d3c647

File tree

11 files changed

+226
-5
lines changed

11 files changed

+226
-5
lines changed

pkgs/web_socket/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
## 0.1.0-wip
22

3-
- Abstract interface definition.
3+
- Basic functionality in place.

pkgs/web_socket/lib/src/browser_web_socket.dart

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,23 @@ class BrowserWebSocket implements WebSocket {
1818
final web.WebSocket _webSocket;
1919
final _events = StreamController<WebSocketEvent>();
2020

21+
/// Create a new WebSocket connection using the JavaScript WebSocket API.
22+
///
23+
/// The URL supplied in [url] must use the scheme ws or wss.
24+
///
25+
/// If provided, the [protocols] argument indicates that subprotocols that
26+
/// the peer is able to select. See
27+
/// [RFC-6455 1.9](https://datatracker.ietf.org/doc/html/rfc6455#section-1.9).
2128
static Future<BrowserWebSocket> connect(Uri url,
2229
{Iterable<String>? protocols}) async {
23-
final webSocket = web.WebSocket(url.toString())..binaryType = 'arraybuffer';
30+
if (!url.isScheme('ws') && !url.isScheme('wss')) {
31+
throw ArgumentError.value(
32+
url, 'url', 'only ws: and wss: schemes are supported');
33+
}
34+
35+
final webSocket = web.WebSocket(url.toString(),
36+
protocols?.map((e) => e.toJS).toList().toJS ?? JSArray())
37+
..binaryType = 'arraybuffer';
2438
final browserSocket = BrowserWebSocket._(webSocket);
2539
final webSocketConnected = Completer<BrowserWebSocket>();
2640

@@ -126,6 +140,9 @@ class BrowserWebSocket implements WebSocket {
126140

127141
@override
128142
Stream<WebSocketEvent> get events => _events.stream;
143+
144+
@override
145+
String get protocol => _webSocket.protocol;
129146
}
130147

131148
const connect = BrowserWebSocket.connect;

pkgs/web_socket/lib/src/io_web_socket.dart

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ import 'dart:async';
66
import 'dart:io' as io;
77
import 'dart:typed_data';
88

9-
import '../web_socket.dart';
109
import 'utils.dart';
10+
import 'web_socket.dart';
1111

1212
/// A `dart-io`-based [WebSocket] implementation.
1313
///
@@ -16,14 +16,37 @@ class IOWebSocket implements WebSocket {
1616
final io.WebSocket _webSocket;
1717
final _events = StreamController<WebSocketEvent>();
1818

19+
/// Create a new WebSocket connection using dart:io WebSocket.
20+
///
21+
/// The URL supplied in [url] must use the scheme ws or wss.
22+
///
23+
/// If provided, the [protocols] argument indicates that subprotocols that
24+
/// the peer is able to select. See
25+
/// [RFC-6455 1.9](https://datatracker.ietf.org/doc/html/rfc6455#section-1.9).
1926
static Future<IOWebSocket> connect(Uri url,
2027
{Iterable<String>? protocols}) async {
28+
if (!url.isScheme('ws') && !url.isScheme('wss')) {
29+
throw ArgumentError.value(
30+
url, 'url', 'only ws: and wss: schemes are supported');
31+
}
32+
33+
final io.WebSocket webSocket;
2134
try {
22-
final webSocket = await io.WebSocket.connect(url.toString());
23-
return IOWebSocket._(webSocket);
35+
webSocket =
36+
await io.WebSocket.connect(url.toString(), protocols: protocols);
2437
} on io.WebSocketException catch (e) {
2538
throw WebSocketException(e.message);
2639
}
40+
41+
if (webSocket.protocol != null &&
42+
!(protocols ?? []).contains(webSocket.protocol)) {
43+
// dart:io WebSocket does not correctly validate the returned protocol.
44+
// See https://github.com/dart-lang/sdk/issues/55106
45+
await webSocket.close(1002); // protocol error
46+
throw WebSocketException(
47+
'unexpected protocol selected by peer: ${webSocket.protocol}');
48+
}
49+
return IOWebSocket._(webSocket);
2750
}
2851

2952
IOWebSocket._(this._webSocket) {
@@ -90,6 +113,9 @@ class IOWebSocket implements WebSocket {
90113

91114
@override
92115
Stream<WebSocketEvent> get events => _events.stream;
116+
117+
@override
118+
String get protocol => _webSocket.protocol ?? '';
93119
}
94120

95121
const connect = IOWebSocket.connect;

pkgs/web_socket/lib/src/web_socket.dart

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,13 @@ class WebSocketConnectionClosed extends WebSocketException {
115115
/// socket.sendText('Hello Dart WebSockets! 🎉');
116116
/// }
117117
abstract interface class WebSocket {
118+
/// Create a new WebSocket connection.
119+
///
120+
/// The URL supplied in [url] must use the scheme ws or wss.
121+
///
122+
/// If provided, the [protocols] argument indicates that subprotocols that
123+
/// the peer is able to select. See
124+
/// [RFC-6455 1.9](https://datatracker.ietf.org/doc/html/rfc6455#section-1.9).
118125
static Future<WebSocket> connect(Uri url, {Iterable<String>? protocols}) =>
119126
connector.connect(url, protocols: protocols);
120127

@@ -169,4 +176,12 @@ abstract interface class WebSocket {
169176
///
170177
/// Errors will never appear in this [Stream].
171178
Stream<WebSocketEvent> get events;
179+
180+
/// The WebSocket subprotocol negotiated with the peer.
181+
///
182+
/// Will be the empty string if no subprotocol was negotiated.
183+
///
184+
/// See
185+
/// [RFC-6455 1.9](https://datatracker.ietf.org/doc/html/rfc6455#section-1.9).
186+
String get protocol;
172187
}

pkgs/web_socket_conformance_tests/example/client_test.dart

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ class MyWebSocketImplementation implements WebSocket {
2020

2121
@override
2222
void sendText(String s) => throw UnimplementedError();
23+
24+
@override
25+
String get protocol => throw UnimplementedError();
2326
}
2427

2528
void main() {
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// Copyright (c) 2024, the Dart project authors. Please see the AUTHORS file
2+
// for details. All rights reserved. Use of this source code is governed by a
3+
// BSD-style license that can be found in the LICENSE file.
4+
5+
import 'package:test/test.dart';
6+
import 'package:web_socket/web_socket.dart';
7+
8+
/// Tests that the [WebSocket] rejects invalid connection URIs.
9+
void testConnectUri(
10+
Future<WebSocket> Function(Uri uri, {Iterable<String>? protocols})
11+
channelFactory) {
12+
group('connect uri', () {
13+
test('no protocol', () async {
14+
await expectLater(() => channelFactory(Uri.https('www.example.com', '/')),
15+
throwsA(isA<ArgumentError>()));
16+
});
17+
});
18+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright (c) 2024, the Dart project authors. Please see the AUTHORS file
2+
// for details. All rights reserved. Use of this source code is governed by a
3+
// BSD-style license that can be found in the LICENSE file.
4+
5+
import 'dart:async';
6+
import 'dart:convert';
7+
import 'dart:io';
8+
9+
import 'package:crypto/crypto.dart';
10+
import 'package:stream_channel/stream_channel.dart';
11+
12+
const _webSocketGuid = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11';
13+
14+
/// Starts an WebSocket server that responds with a scripted subprotocol.
15+
void hybridMain(StreamChannel<Object?> channel) async {
16+
late final HttpServer server;
17+
server = (await HttpServer.bind('localhost', 0))
18+
..listen((request) async {
19+
final serverProtocol = request.requestedUri.queryParameters['protocol'];
20+
var key = request.headers.value('Sec-WebSocket-Key');
21+
var digest = sha1.convert('$key$_webSocketGuid'.codeUnits);
22+
var accept = base64.encode(digest.bytes);
23+
channel.sink.add(request.headers['Sec-WebSocket-Protocol']);
24+
request.response
25+
..statusCode = HttpStatus.switchingProtocols
26+
..headers.add(HttpHeaders.connectionHeader, 'Upgrade')
27+
..headers.add(HttpHeaders.upgradeHeader, 'websocket')
28+
..headers.add('Sec-WebSocket-Accept', accept);
29+
if (serverProtocol != null) {
30+
request.response.headers.add('Sec-WebSocket-Protocol', serverProtocol);
31+
}
32+
request.response.contentLength = 0;
33+
final socket = await request.response.detachSocket();
34+
final webSocket = WebSocket.fromUpgradedSocket(socket,
35+
protocol: serverProtocol, serverSide: true);
36+
webSocket.listen((e) async {
37+
webSocket.add(e);
38+
await webSocket.close();
39+
});
40+
});
41+
42+
channel.sink.add(server.port);
43+
44+
await channel
45+
.stream.first; // Any writes indicates that the server should exit.
46+
unawaited(server.close());
47+
}

pkgs/web_socket_conformance_tests/lib/src/protocol_server_vm.dart

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkgs/web_socket_conformance_tests/lib/src/protocol_server_web.dart

Lines changed: 9 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// Copyright (c) 2024, the Dart project authors. Please see the AUTHORS file
2+
// for details. All rights reserved. Use of this source code is governed by a
3+
// BSD-style license that can be found in the LICENSE file.
4+
5+
import 'package:async/async.dart';
6+
import 'package:stream_channel/stream_channel.dart';
7+
import 'package:test/test.dart';
8+
import 'package:web_socket/web_socket.dart';
9+
10+
import 'protocol_server_vm.dart'
11+
if (dart.library.html) 'protocol_server_web.dart';
12+
13+
/// Tests that the [WebSocket] can correctly negotiate a subprotocol with the
14+
/// peer.
15+
///
16+
/// See
17+
/// [RFC-6455 1.9](https://datatracker.ietf.org/doc/html/rfc6455#section-1.9).
18+
void testProtocols(
19+
Future<WebSocket> Function(Uri uri, {Iterable<String>? protocols})
20+
channelFactory) {
21+
group('protocols', () {
22+
late Uri uri;
23+
late StreamChannel<Object?> httpServerChannel;
24+
late StreamQueue<Object?> httpServerQueue;
25+
26+
setUp(() async {
27+
httpServerChannel = await startServer();
28+
httpServerQueue = StreamQueue(httpServerChannel.stream);
29+
uri = Uri.parse('ws://localhost:${await httpServerQueue.next}');
30+
});
31+
tearDown(() => httpServerChannel.sink.add(null));
32+
33+
test('no protocol', () async {
34+
final socket = await channelFactory(uri);
35+
36+
expect(await httpServerQueue.next, null);
37+
expect(socket.protocol, '');
38+
socket.sendText('Hello World!');
39+
});
40+
41+
test('single protocol', () async {
42+
final socket = await channelFactory(
43+
uri.replace(queryParameters: {'protocol': 'chat.example.com'}),
44+
protocols: ['chat.example.com']);
45+
46+
expect(await httpServerQueue.next, ['chat.example.com']);
47+
expect(socket.protocol, 'chat.example.com');
48+
socket.sendText('Hello World!');
49+
});
50+
51+
test('mutiple protocols', () async {
52+
final socket = await channelFactory(
53+
uri.replace(queryParameters: {'protocol': 'text.example.com'}),
54+
protocols: ['chat.example.com', 'text.example.com']);
55+
56+
expect(
57+
await httpServerQueue.next, ['chat.example.com, text.example.com']);
58+
expect(socket.protocol, 'text.example.com');
59+
socket.sendText('Hello World!');
60+
});
61+
62+
test('protocol mismatch', () async {
63+
await expectLater(
64+
() => channelFactory(
65+
uri.replace(queryParameters: {'protocol': 'example.example.com'}),
66+
protocols: ['chat.example.com']),
67+
throwsA(isA<WebSocketException>()));
68+
});
69+
});
70+
}

pkgs/web_socket_conformance_tests/lib/web_socket_conformance_tests.dart

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,23 @@
55
import 'package:web_socket/web_socket.dart';
66
import 'src/close_local_tests.dart';
77
import 'src/close_remote_tests.dart';
8+
import 'src/connect_uri_tests.dart';
89
import 'src/disconnect_after_upgrade_tests.dart';
910
import 'src/no_upgrade_tests.dart';
1011
import 'src/payload_transfer_tests.dart';
1112
import 'src/peer_protocol_errors_tests.dart';
13+
import 'src/protocol_tests.dart';
1214

1315
/// Runs the entire test suite against the given [WebSocket].
1416
void testAll(
1517
Future<WebSocket> Function(Uri uri, {Iterable<String>? protocols})
1618
webSocketFactory) {
1719
testCloseLocal(webSocketFactory);
1820
testCloseRemote(webSocketFactory);
21+
testConnectUri(webSocketFactory);
1922
testDisconnectAfterUpgrade(webSocketFactory);
2023
testNoUpgrade(webSocketFactory);
2124
testPayloadTransfer(webSocketFactory);
2225
testPeerProtocolErrors(webSocketFactory);
26+
testProtocols(webSocketFactory);
2327
}

0 commit comments

Comments
 (0)