|
| 1 | +package websocket |
| 2 | + |
| 3 | +import ( |
| 4 | + "bufio" |
| 5 | + "errors" |
| 6 | + "fmt" |
| 7 | + "io" |
| 8 | + "net/http" |
| 9 | + "strings" |
| 10 | +) |
| 11 | + |
| 12 | +// http2Handshaker performs a websocket handshake over an HTTP/2 connection. |
| 13 | +// It is similar to a serverHandshaker, but doesn't use quite the same |
| 14 | +// interface due to differences in the underlying transport protocol. |
| 15 | +type http2Handshaker struct { |
| 16 | + // The server's config. |
| 17 | + config *Config |
| 18 | + // The user-supplied userHandshake callback. |
| 19 | + userHandshake func(*Config, *http.Request) error |
| 20 | +} |
| 21 | + |
| 22 | +// handshake performs a handshake for an HTTP/2 connection and returns a |
| 23 | +// websocket connection or an HTTP status code and an error. The status |
| 24 | +// code is only valid if the error is non-nil. |
| 25 | +func (h *http2Handshaker) handshake(w http.ResponseWriter, req *http.Request) (conn *Conn, statusCode int, err error) { |
| 26 | + statusCode, err = h.checkHeaders(req) |
| 27 | + if err != nil { |
| 28 | + return nil, statusCode, err |
| 29 | + } |
| 30 | + |
| 31 | + // allow the user to perform protocol negotiation |
| 32 | + err = h.userHandshake(h.config, req) |
| 33 | + if err != nil { |
| 34 | + return nil, http.StatusForbidden, ErrBadHandshake |
| 35 | + } |
| 36 | + |
| 37 | + // All the headers we've been sent check out, so we can write |
| 38 | + // a 200 response and inform the client if we have chosen a particular |
| 39 | + // application protocol. |
| 40 | + if len(h.config.Protocol) > 0 { |
| 41 | + w.Header().Add("Sec-Websocket-Protocol", h.config.Protocol[0]) |
| 42 | + } |
| 43 | + w.WriteHeader(http.StatusOK) |
| 44 | + |
| 45 | + // Flush to force the status onto the wire so that clients can start |
| 46 | + // listening. |
| 47 | + flusher, ok := w.(http.Flusher) |
| 48 | + if !ok { |
| 49 | + return nil, http.StatusInternalServerError, errors.New("websocket: response writer must implement flusher") |
| 50 | + } |
| 51 | + flusher.Flush() |
| 52 | + |
| 53 | + // to get a conn, we need a buffered readwriter, a readwritecloser, and |
| 54 | + // the request |
| 55 | + stream := newHTTP2ServerStream(w, req) |
| 56 | + buf := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream)) |
| 57 | + conn = newHybiConn(h.config, buf, stream, req) |
| 58 | + return conn, 0, err |
| 59 | +} |
| 60 | + |
| 61 | +func (h *http2Handshaker) checkHeaders(req *http.Request) (statusCode int, err error) { |
| 62 | + // TODO(ethan): write tests for all of these checks |
| 63 | + if req.Method != "CONNECT" { |
| 64 | + return http.StatusMethodNotAllowed, ErrBadRequestMethod |
| 65 | + } |
| 66 | + |
| 67 | + protocol := req.Header.Get("Hack-Http2-Protocol") |
| 68 | + if protocol != "websocket" { |
| 69 | + return http.StatusBadRequest, ErrBadProtocol |
| 70 | + } |
| 71 | + |
| 72 | + // "On requests that contain the :protocol pseudo-header field, the |
| 73 | + // :scheme and :path pseudo-header fields of the target URI (see |
| 74 | + // Section 5) MUST also be included." |
| 75 | + if req.URL.Path == "" { |
| 76 | + return http.StatusBadRequest, ErrBadPath |
| 77 | + } |
| 78 | + |
| 79 | + version := req.Header.Get("Sec-Websocket-Version") |
| 80 | + if version == "13" { |
| 81 | + h.config.Version = ProtocolVersionHybi13 |
| 82 | + } else { |
| 83 | + return http.StatusBadRequest, ErrBadProtocolVersion |
| 84 | + } |
| 85 | + |
| 86 | + // parse the list of request protocols |
| 87 | + protocolCSV := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol")) |
| 88 | + if protocolCSV != "" { |
| 89 | + protocols := strings.Split(protocolCSV, ",") |
| 90 | + for i := 0; i < len(protocols); i++ { |
| 91 | + // It is ok to mutate Protocol like this because server takes its |
| 92 | + // receiver by value, not reference, so the whole thing is copied |
| 93 | + // for each request. |
| 94 | + h.config.Protocol = append(h.config.Protocol, strings.TrimSpace(protocols[i])) |
| 95 | + } |
| 96 | + } |
| 97 | + |
| 98 | + return 0, nil |
| 99 | +} |
| 100 | + |
| 101 | +// |
| 102 | +// http2ServerStream |
| 103 | +// |
| 104 | + |
| 105 | +// http2ServerStream is a wrapper around a request and response writer that |
| 106 | +// implements io.ReadWriteCloser |
| 107 | +type http2ServerStream struct { |
| 108 | + w http.ResponseWriter |
| 109 | + flusher http.Flusher |
| 110 | + req *http.Request |
| 111 | +} |
| 112 | + |
| 113 | +func newHTTP2ServerStream(w http.ResponseWriter, req *http.Request) *http2ServerStream { |
| 114 | + flusher, ok := w.(http.Flusher) |
| 115 | + if !ok { |
| 116 | + panic("websocket: response writer must implement flusher") |
| 117 | + } |
| 118 | + |
| 119 | + return &http2ServerStream{ |
| 120 | + w: w, |
| 121 | + flusher: flusher, |
| 122 | + req: req, |
| 123 | + } |
| 124 | +} |
| 125 | + |
| 126 | +func (s *http2ServerStream) Read(p []byte) (n int, err error) { |
| 127 | + return s.req.Body.Read(p) |
| 128 | +} |
| 129 | +func (s *http2ServerStream) Write(p []byte) (n int, err error) { |
| 130 | + n, err = s.w.Write(p) |
| 131 | + if err != nil { |
| 132 | + return n, err |
| 133 | + } |
| 134 | + |
| 135 | + // We flush every time since the main websocket code is going to wrap |
| 136 | + // this in a bufio.Writer and expect that when the bufio.Writer is flushed |
| 137 | + // the bytes actually land on the wire. |
| 138 | + s.flusher.Flush() |
| 139 | + |
| 140 | + return n, err |
| 141 | +} |
| 142 | +func (s *http2ServerStream) Close() error { |
| 143 | + return s.req.Body.Close() |
| 144 | +} |
| 145 | + |
| 146 | +// |
| 147 | +// http2ClientStream |
| 148 | +// |
| 149 | + |
| 150 | +// http2ClientStream is a wrapper around a writer and an http response that |
| 151 | +// implements io.ReadWriteCloser |
| 152 | +type http2ClientStream struct { |
| 153 | + w *io.PipeWriter |
| 154 | + resp *http.Response |
| 155 | +} |
| 156 | + |
| 157 | +func newHTTP2ClientStream(w *io.PipeWriter, resp *http.Response) *http2ClientStream { |
| 158 | + return &http2ClientStream{ |
| 159 | + w: w, |
| 160 | + resp: resp, |
| 161 | + } |
| 162 | +} |
| 163 | + |
| 164 | +func (s *http2ClientStream) Read(p []byte) (n int, err error) { |
| 165 | + return s.resp.Body.Read(p) |
| 166 | +} |
| 167 | +func (s *http2ClientStream) Write(p []byte) (n int, err error) { |
| 168 | + return s.w.Write(p) |
| 169 | +} |
| 170 | +func (s *http2ClientStream) Close() error { |
| 171 | + wErr := s.w.Close() |
| 172 | + rErr := s.resp.Body.Close() |
| 173 | + if wErr != nil && rErr != nil { |
| 174 | + return fmt.Errorf("client close: %s: %w", wErr, rErr) |
| 175 | + } |
| 176 | + if wErr != nil { |
| 177 | + return wErr |
| 178 | + } |
| 179 | + if rErr != nil { |
| 180 | + return rErr |
| 181 | + } |
| 182 | + return nil |
| 183 | +} |
0 commit comments