Skip to content

Commit df38e5d

Browse files
committed
websocket: strawman http2 support
This patch adds http2 support to x/net/websocket. It is still pretty hacky and not well tested yet, but it shows that it can be done. Change-Id: I123253a74a2dbb6e42e7e31b724362814da112a5
1 parent d233d0c commit df38e5d

File tree

7 files changed

+424
-4
lines changed

7 files changed

+424
-4
lines changed

http2/stream_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func TestHTTP2Stream(t *testing.T) {
8787
// psudo headers by setting things in the headers hashmap.
8888
// I think the real solution here is to add a new `Protocol`
8989
// field to the `http.Request` struct.
90-
req.Header.Add("HACK-HTTP2-Protocol", "websocket")
90+
req.Header.Add("Hack-Http2-Protocol", "websocket")
9191

9292
resp, err := client.Transport.RoundTrip(req)
9393
if err != nil {

http2/transport.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,7 +1121,7 @@ func (cc *ClientConn) decrStreamReservationsLocked() {
11211121
}
11221122

11231123
func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
1124-
if req.Method == "CONNECT" && req.Header.Get("HACK-HTTP2-Protocol") != "" {
1124+
if req.Method == "CONNECT" && req.Header.Get("Hack-Http2-Protocol") != "" {
11251125
// This is an extended CONNECT https://datatracker.ietf.org/doc/html/rfc8441#section-4
11261126
// We need to check if the server supports it.
11271127
if err := cc.checkServerSupportsExtendedConnect(); err != nil {
@@ -1783,7 +1783,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
17831783
return nil, err
17841784
}
17851785

1786-
protocol := req.Header.Get("HACK-HTTP2-Protocol")
1786+
protocol := req.Header.Get("Hack-Http2-Protocol")
17871787

17881788
var path string
17891789
if req.Method != "CONNECT" || (cc.serverAllowsExtendedConnect && protocol != "") {

websocket/client.go

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,14 @@ package websocket
66

77
import (
88
"bufio"
9+
"crypto/tls"
10+
"errors"
11+
"fmt"
912
"io"
1013
"net"
1114
"net/http"
1215
"net/url"
16+
"strings"
1317
)
1418

1519
// DialError is an error that occurs while dialling a websocket server.
@@ -79,13 +83,22 @@ func parseAuthority(location *url.URL) string {
7983

8084
// DialConfig opens a new client connection to a WebSocket with a config.
8185
func DialConfig(config *Config) (ws *Conn, err error) {
82-
var client net.Conn
8386
if config.Location == nil {
8487
return nil, &DialError{config, ErrBadWebSocketLocation}
8588
}
8689
if config.Origin == nil {
8790
return nil, &DialError{config, ErrBadWebSocketOrigin}
8891
}
92+
93+
if config.HTTP2Transport != nil {
94+
return dialHTTP2(config)
95+
}
96+
97+
return dialHTTP1(config)
98+
}
99+
100+
func dialHTTP1(config *Config) (ws *Conn, err error) {
101+
var client net.Conn
89102
dialer := config.Dialer
90103
if dialer == nil {
91104
dialer = &net.Dialer{}
@@ -104,3 +117,97 @@ func DialConfig(config *Config) (ws *Conn, err error) {
104117
Error:
105118
return nil, &DialError{config, err}
106119
}
120+
121+
func dialHTTP2(config *Config) (ws *Conn, err error) {
122+
// Respect tls config set on the top level config if the transport doesn't
123+
// already have one set.
124+
if config.TlsConfig != nil && config.HTTP2Transport.TLSClientConfig == nil {
125+
config.HTTP2Transport.TLSClientConfig = config.TlsConfig
126+
}
127+
128+
// try to respect the dialer configured in the websocket config
129+
if config.Dialer != nil && config.HTTP2Transport.DialTLS == nil {
130+
config.HTTP2Transport.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
131+
d := tls.Dialer{NetDialer: config.Dialer, Config: cfg}
132+
return d.Dial(network, addr)
133+
}
134+
}
135+
136+
if config.Location.Scheme == "ws" && !config.HTTP2Transport.AllowHTTP {
137+
return nil, &DialError{Config: config, Err: errors.New("HTTP/2 requires TLS")}
138+
}
139+
140+
if config.Version != ProtocolVersionHybi13 {
141+
return nil, &DialError{Config: config, Err: ErrBadProtocolVersion}
142+
}
143+
144+
// https://datatracker.ietf.org/doc/html/rfc8441#section-5
145+
// 'The scheme of the target URI (Section 5.1 of [RFC7230]) MUST be
146+
// "https" for "wss"-schemed WebSockets and "http" for "ws"-schemed
147+
// WebSockets.'
148+
if config.Location.Scheme == "wss" {
149+
config.Location.Scheme = "https"
150+
}
151+
if config.Location.Scheme == "ws" {
152+
config.Location.Scheme = "http"
153+
}
154+
155+
// TODO(ethan): replace pipe with something context cancelable
156+
sr, sw := io.Pipe()
157+
req, err := http.NewRequest("CONNECT", config.Location.String(), sr)
158+
if err != nil {
159+
return nil, &DialError{Config: config, Err: err}
160+
}
161+
162+
req.Header.Add("Hack-Http2-Protocol", "websocket")
163+
req.Header.Add("Origin", config.Origin.String())
164+
req.Header.Add("Sec-Websocket-Version", fmt.Sprintf("%d", config.Version))
165+
if len(config.Protocol) > 0 {
166+
req.Header.Add("Sec-Websocket-Protocol", strings.Join(config.Protocol, ","))
167+
}
168+
169+
// inject user supplied headers, if any
170+
for k, vals := range config.Header {
171+
req.Header[k] = vals
172+
}
173+
174+
resp, err := config.HTTP2Transport.RoundTrip(req)
175+
if err != nil {
176+
return nil, &DialError{Config: config, Err: err}
177+
}
178+
179+
// check response headers and status
180+
181+
if resp.Header.Get("Sec-WebSocket-Extensions") != "" {
182+
// we don't support any extentions
183+
return nil, &DialError{Config: config, Err: ErrUnsupportedExtensions}
184+
}
185+
186+
if resp.StatusCode != http.StatusOK {
187+
return nil, &DialError{Config: config, Err: ErrBadStatus}
188+
}
189+
190+
// TODO(ethan): this logic is copied from the HTTP/1.1 branch.
191+
// I should refactor to consolidate.
192+
offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol")
193+
if offeredProtocol != "" {
194+
protocolMatched := false
195+
for i := 0; i < len(config.Protocol); i++ {
196+
if config.Protocol[i] == offeredProtocol {
197+
protocolMatched = true
198+
break
199+
}
200+
}
201+
if !protocolMatched {
202+
return nil, &DialError{Config: config, Err: ErrBadWebSocketProtocol}
203+
}
204+
config.Protocol = []string{offeredProtocol}
205+
}
206+
207+
// The handshake is complete, so we wrap things up in a Conn and return.
208+
stream := newHTTP2ClientStream(sw, resp)
209+
buf := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
210+
conn := newHybiClientConn(config, buf, stream)
211+
212+
return conn, nil
213+
}

websocket/http2.go

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
package websocket
2+
3+
import (
4+
"bufio"
5+
"errors"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"strings"
10+
)
11+
12+
// http2Handshaker performs a websocket handshake over an HTTP/2 connection.
13+
// It is similar to a serverHandshaker, but doesn't use quite the same
14+
// interface due to differences in the underlying transport protocol.
15+
type http2Handshaker struct {
16+
// The server's config.
17+
config *Config
18+
// The user-supplied userHandshake callback.
19+
userHandshake func(*Config, *http.Request) error
20+
}
21+
22+
// handshake performs a handshake for an HTTP/2 connection and returns a
23+
// websocket connection or an HTTP status code and an error. The status
24+
// code is only valid if the error is non-nil.
25+
func (h *http2Handshaker) handshake(w http.ResponseWriter, req *http.Request) (conn *Conn, statusCode int, err error) {
26+
statusCode, err = h.checkHeaders(req)
27+
if err != nil {
28+
return nil, statusCode, err
29+
}
30+
31+
// allow the user to perform protocol negotiation
32+
err = h.userHandshake(h.config, req)
33+
if err != nil {
34+
return nil, http.StatusForbidden, ErrBadHandshake
35+
}
36+
37+
// All the headers we've been sent check out, so we can write
38+
// a 200 response and inform the client if we have chosen a particular
39+
// application protocol.
40+
if len(h.config.Protocol) > 0 {
41+
w.Header().Add("Sec-Websocket-Protocol", h.config.Protocol[0])
42+
}
43+
w.WriteHeader(http.StatusOK)
44+
45+
// Flush to force the status onto the wire so that clients can start
46+
// listening.
47+
flusher, ok := w.(http.Flusher)
48+
if !ok {
49+
return nil, http.StatusInternalServerError, errors.New("websocket: response writer must implement flusher")
50+
}
51+
flusher.Flush()
52+
53+
// to get a conn, we need a buffered readwriter, a readwritecloser, and
54+
// the request
55+
stream := newHTTP2ServerStream(w, req)
56+
buf := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
57+
conn = newHybiConn(h.config, buf, stream, req)
58+
return conn, 0, err
59+
}
60+
61+
func (h *http2Handshaker) checkHeaders(req *http.Request) (statusCode int, err error) {
62+
// TODO(ethan): write tests for all of these checks
63+
if req.Method != "CONNECT" {
64+
return http.StatusMethodNotAllowed, ErrBadRequestMethod
65+
}
66+
67+
protocol := req.Header.Get("Hack-Http2-Protocol")
68+
if protocol != "websocket" {
69+
return http.StatusBadRequest, ErrBadProtocol
70+
}
71+
72+
// "On requests that contain the :protocol pseudo-header field, the
73+
// :scheme and :path pseudo-header fields of the target URI (see
74+
// Section 5) MUST also be included."
75+
if req.URL.Path == "" {
76+
return http.StatusBadRequest, ErrBadPath
77+
}
78+
79+
version := req.Header.Get("Sec-Websocket-Version")
80+
if version == "13" {
81+
h.config.Version = ProtocolVersionHybi13
82+
} else {
83+
return http.StatusBadRequest, ErrBadProtocolVersion
84+
}
85+
86+
// parse the list of request protocols
87+
protocolCSV := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
88+
if protocolCSV != "" {
89+
protocols := strings.Split(protocolCSV, ",")
90+
for i := 0; i < len(protocols); i++ {
91+
// It is ok to mutate Protocol like this because server takes its
92+
// receiver by value, not reference, so the whole thing is copied
93+
// for each request.
94+
h.config.Protocol = append(h.config.Protocol, strings.TrimSpace(protocols[i]))
95+
}
96+
}
97+
98+
return 0, nil
99+
}
100+
101+
//
102+
// http2ServerStream
103+
//
104+
105+
// http2ServerStream is a wrapper around a request and response writer that
106+
// implements io.ReadWriteCloser
107+
type http2ServerStream struct {
108+
w http.ResponseWriter
109+
flusher http.Flusher
110+
req *http.Request
111+
}
112+
113+
func newHTTP2ServerStream(w http.ResponseWriter, req *http.Request) *http2ServerStream {
114+
flusher, ok := w.(http.Flusher)
115+
if !ok {
116+
panic("websocket: response writer must implement flusher")
117+
}
118+
119+
return &http2ServerStream{
120+
w: w,
121+
flusher: flusher,
122+
req: req,
123+
}
124+
}
125+
126+
func (s *http2ServerStream) Read(p []byte) (n int, err error) {
127+
return s.req.Body.Read(p)
128+
}
129+
func (s *http2ServerStream) Write(p []byte) (n int, err error) {
130+
n, err = s.w.Write(p)
131+
if err != nil {
132+
return n, err
133+
}
134+
135+
// We flush every time since the main websocket code is going to wrap
136+
// this in a bufio.Writer and expect that when the bufio.Writer is flushed
137+
// the bytes actually land on the wire.
138+
s.flusher.Flush()
139+
140+
return n, err
141+
}
142+
func (s *http2ServerStream) Close() error {
143+
return s.req.Body.Close()
144+
}
145+
146+
//
147+
// http2ClientStream
148+
//
149+
150+
// http2ClientStream is a wrapper around a writer and an http response that
151+
// implements io.ReadWriteCloser
152+
type http2ClientStream struct {
153+
w *io.PipeWriter
154+
resp *http.Response
155+
}
156+
157+
func newHTTP2ClientStream(w *io.PipeWriter, resp *http.Response) *http2ClientStream {
158+
return &http2ClientStream{
159+
w: w,
160+
resp: resp,
161+
}
162+
}
163+
164+
func (s *http2ClientStream) Read(p []byte) (n int, err error) {
165+
return s.resp.Body.Read(p)
166+
}
167+
func (s *http2ClientStream) Write(p []byte) (n int, err error) {
168+
return s.w.Write(p)
169+
}
170+
func (s *http2ClientStream) Close() error {
171+
wErr := s.w.Close()
172+
rErr := s.resp.Body.Close()
173+
if wErr != nil && rErr != nil {
174+
return fmt.Errorf("client close: %s: %w", wErr, rErr)
175+
}
176+
if wErr != nil {
177+
return wErr
178+
}
179+
if rErr != nil {
180+
return rErr
181+
}
182+
return nil
183+
}

0 commit comments

Comments
 (0)