Skip to content

Commit b4be494

Browse files
bmizeranybradfitz
authored andcommitted
http2: add server-side trailer support
Change-Id: I39dbf0cdeee0123b6c6efff1fc6854bcedb94753 Reviewed-on: https://go-review.googlesource.com/17878 Reviewed-by: Brad Fitzpatrick <[email protected]>
1 parent c24de9d commit b4be494

File tree

3 files changed

+121
-29
lines changed

3 files changed

+121
-29
lines changed

http2/server.go

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ import (
4646
"log"
4747
"net"
4848
"net/http"
49+
"net/textproto"
4950
"net/url"
5051
"runtime"
5152
"strconv"
@@ -1877,6 +1878,7 @@ type responseWriterState struct {
18771878
// mutated by http.Handler goroutine:
18781879
handlerHeader http.Header // nil until called
18791880
snapHeader http.Header // snapshot of handlerHeader at WriteHeader time
1881+
trailers []string // set in writeChunk
18801882
status int // status code passed to WriteHeader
18811883
wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet.
18821884
sentHeader bool // have we sent the header frame?
@@ -1893,6 +1895,21 @@ type chunkWriter struct{ rws *responseWriterState }
18931895

18941896
func (cw chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) }
18951897

1898+
func (rws *responseWriterState) hasTrailers() bool { return len(rws.trailers) != 0 }
1899+
1900+
// declareTrailer is called for each Trailer header when the
1901+
// response header is written. It notes that a header will need to be
1902+
// written in the trailers at the end of the response.
1903+
func (rws *responseWriterState) declareTrailer(k string) {
1904+
k = http.CanonicalHeaderKey(k)
1905+
switch k {
1906+
case "Transfer-Encoding", "Content-Length", "Trailer":
1907+
// Forbidden by RFC 2616 14.40.
1908+
return
1909+
}
1910+
rws.trailers = append(rws.trailers, k)
1911+
}
1912+
18961913
// writeChunk writes chunks from the bufio.Writer. But because
18971914
// bufio.Writer may bypass its chunking, sometimes p may be
18981915
// arbitrarily large.
@@ -1903,6 +1920,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
19031920
if !rws.wroteHeader {
19041921
rws.writeHeader(200)
19051922
}
1923+
19061924
isHeadResp := rws.req.Method == "HEAD"
19071925
if !rws.sentHeader {
19081926
rws.sentHeader = true
@@ -1928,7 +1946,12 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
19281946
// TODO(bradfitz): be faster here, like net/http? measure.
19291947
date = time.Now().UTC().Format(http.TimeFormat)
19301948
}
1931-
endStream := (rws.handlerDone && len(p) == 0) || isHeadResp
1949+
1950+
for _, v := range rws.snapHeader["Trailer"] {
1951+
foreachHeaderElement(v, rws.declareTrailer)
1952+
}
1953+
1954+
endStream := (rws.handlerDone && !rws.hasTrailers() && len(p) == 0) || isHeadResp
19321955
err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{
19331956
streamID: rws.stream.id,
19341957
httpResCode: rws.status,
@@ -1952,8 +1975,22 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
19521975
return 0, nil
19531976
}
19541977

1955-
if err := rws.conn.writeDataFromHandler(rws.stream, p, rws.handlerDone); err != nil {
1956-
return 0, err
1978+
endStream := rws.handlerDone && !rws.hasTrailers()
1979+
if len(p) > 0 || endStream {
1980+
// only send a 0 byte DATA frame if we're ending the stream.
1981+
if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil {
1982+
return 0, err
1983+
}
1984+
}
1985+
1986+
if rws.handlerDone && rws.hasTrailers() {
1987+
err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{
1988+
streamID: rws.stream.id,
1989+
h: rws.handlerHeader,
1990+
trailers: rws.trailers,
1991+
endStream: true,
1992+
})
1993+
return len(p), err
19571994
}
19581995
return len(p), nil
19591996
}
@@ -2083,3 +2120,21 @@ func (w *responseWriter) handlerDone() {
20832120
w.rws = nil
20842121
responseWriterStatePool.Put(rws)
20852122
}
2123+
2124+
// foreachHeaderElement splits v according to the "#rule" construction
2125+
// in RFC 2616 section 2.1 and calls fn for each non-empty element.
2126+
func foreachHeaderElement(v string, fn func(string)) {
2127+
v = textproto.TrimString(v)
2128+
if v == "" {
2129+
return
2130+
}
2131+
if !strings.Contains(v, ",") {
2132+
fn(v)
2133+
return
2134+
}
2135+
for _, f := range strings.Split(v, ",") {
2136+
if f = textproto.TrimString(f); f != "" {
2137+
fn(f)
2138+
}
2139+
}
2140+
}

http2/server_test.go

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2515,17 +2515,32 @@ func TestServerReadsTrailers(t *testing.T) {
25152515
}
25162516

25172517
// test that a server handler can send trailers
2518-
func TestServerWritesTrailers(t *testing.T) {
2519-
t.Skip("known failing test; see golang.org/issue/13557")
2518+
func TestServerWritesTrailers_WithFlush(t *testing.T) { testServerWritesTrailers(t, true) }
2519+
func TestServerWritesTrailers_WithoutFlush(t *testing.T) { testServerWritesTrailers(t, false) }
2520+
2521+
func testServerWritesTrailers(t *testing.T, withFlush bool) {
25202522
// See https://httpwg.github.io/specs/rfc7540.html#rfc.section.8.1.3
25212523
testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
25222524
w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
25232525
w.Header().Add("Trailer", "Server-Trailer-C")
2526+
2527+
// TODO: decide if the server should filter these while
2528+
// writing the Trailer header in the response. Currently it
2529+
// appears net/http doesn't do this for http/1.1
2530+
w.Header().Add("Trailer", "Transfer-Encoding, Content-Length, Trailer") // filtered
25242531
w.Header().Set("Foo", "Bar")
2532+
w.Header().Set("Content-Length", "5")
2533+
25252534
io.WriteString(w, "Hello")
2526-
w.(http.Flusher).Flush()
2535+
if withFlush {
2536+
w.(http.Flusher).Flush()
2537+
}
25272538
w.Header().Set("Server-Trailer-A", "valuea")
25282539
w.Header().Set("Server-Trailer-C", "valuec") // skipping B
2540+
w.Header().Set("Server-Surpise", "surprise! this isn't predeclared!")
2541+
w.Header().Set("Transfer-Encoding", "should not be included; Forbidden by RFC 2616 14.40")
2542+
w.Header().Set("Content-Length", "should not be included; Forbidden by RFC 2616 14.40")
2543+
w.Header().Set("Trailer", "should not be included; Forbidden by RFC 2616 14.40")
25292544
return nil
25302545
}, func(st *serverTester) {
25312546
getSlash(st)
@@ -2542,7 +2557,9 @@ func TestServerWritesTrailers(t *testing.T) {
25422557
{"foo", "Bar"},
25432558
{"trailer", "Server-Trailer-A, Server-Trailer-B"},
25442559
{"trailer", "Server-Trailer-C"},
2560+
{"trailer", "Transfer-Encoding, Content-Length, Trailer"},
25452561
{"content-type", "text/plain; charset=utf-8"},
2562+
{"content-length", "5"},
25462563
}
25472564
if !reflect.DeepEqual(goth, wanth) {
25482565
t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
@@ -2561,8 +2578,14 @@ func TestServerWritesTrailers(t *testing.T) {
25612578
if !tf.HeadersEnded() {
25622579
t.Fatalf("trailers HEADERS lacked END_HEADERS")
25632580
}
2564-
pairs := st.decodeHeader(tf.HeaderBlockFragment())
2565-
t.Logf("Got: %v", pairs)
2581+
wanth = [][2]string{
2582+
{"server-trailer-a", "valuea"},
2583+
{"server-trailer-c", "valuec"},
2584+
}
2585+
goth = st.decodeHeader(tf.HeaderBlockFragment())
2586+
if !reflect.DeepEqual(goth, wanth) {
2587+
t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
2588+
}
25662589
})
25672590
}
25682591

http2/write.go

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,12 @@ func (writeSettingsAck) writeFrame(ctx writeContext) error {
123123
}
124124

125125
// writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames
126-
// for HTTP response headers from a server handler.
126+
// for HTTP response headers or trailers from a server handler.
127127
type writeResHeaders struct {
128128
streamID uint32
129-
httpResCode int
129+
httpResCode int // 0 means no ":status" line
130130
h http.Header // may be nil
131+
trailers []string // if non-nil, which keys of h to write. nil means all.
131132
endStream bool
132133

133134
date string
@@ -138,26 +139,16 @@ type writeResHeaders struct {
138139
func (w *writeResHeaders) writeFrame(ctx writeContext) error {
139140
enc, buf := ctx.HeaderEncoder()
140141
buf.Reset()
141-
enc.WriteField(hpack.HeaderField{Name: ":status", Value: httpCodeString(w.httpResCode)})
142142

143-
// TODO: garbage. pool sorters like http1? hot path for 1 key?
144-
keys := make([]string, 0, len(w.h))
145-
for k := range w.h {
146-
keys = append(keys, k)
147-
}
148-
sort.Strings(keys)
149-
for _, k := range keys {
150-
vv := w.h[k]
151-
k = lowerHeader(k)
152-
isTE := k == "transfer-encoding"
153-
for _, v := range vv {
154-
// TODO: more of "8.1.2.2 Connection-Specific Header Fields"
155-
if isTE && v != "trailers" {
156-
continue
157-
}
158-
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
159-
}
143+
if w.httpResCode != 0 {
144+
enc.WriteField(hpack.HeaderField{
145+
Name: ":status",
146+
Value: httpCodeString(w.httpResCode),
147+
})
160148
}
149+
150+
encodeHeaders(enc, w.h, w.trailers)
151+
161152
if w.contentType != "" {
162153
enc.WriteField(hpack.HeaderField{Name: "content-type", Value: w.contentType})
163154
}
@@ -169,7 +160,7 @@ func (w *writeResHeaders) writeFrame(ctx writeContext) error {
169160
}
170161

171162
headerBlock := buf.Bytes()
172-
if len(headerBlock) == 0 {
163+
if len(headerBlock) == 0 && w.trailers == nil {
173164
panic("unexpected empty hpack")
174165
}
175166

@@ -232,3 +223,26 @@ type writeWindowUpdate struct {
232223
func (wu writeWindowUpdate) writeFrame(ctx writeContext) error {
233224
return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n)
234225
}
226+
227+
func encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) {
228+
// TODO: garbage. pool sorters like http1? hot path for 1 key?
229+
if keys == nil {
230+
keys = make([]string, 0, len(h))
231+
for k := range h {
232+
keys = append(keys, k)
233+
}
234+
sort.Strings(keys)
235+
}
236+
for _, k := range keys {
237+
vv := h[k]
238+
k = lowerHeader(k)
239+
isTE := k == "transfer-encoding"
240+
for _, v := range vv {
241+
// TODO: more of "8.1.2.2 Connection-Specific Header Fields"
242+
if isTE && v != "trailers" {
243+
continue
244+
}
245+
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
246+
}
247+
}
248+
}

0 commit comments

Comments
 (0)