Skip to content

Commit 3b90a77

Browse files
committed
context/ctxhttp: allow cancellation after Do returns
Fixes #13325. Change-Id: I17f35232cd0ea43e50ea12db09272195789426e9 Reviewed-on: https://go-review.googlesource.com/18188 Reviewed-by: Brad Fitzpatrick <[email protected]>
1 parent 0ab0090 commit 3b90a77

File tree

3 files changed

+85
-1
lines changed

3 files changed

+85
-1
lines changed

context/ctxhttp/cancelreq.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ package ctxhttp
99
import "net/http"
1010

1111
func canceler(client *http.Client, req *http.Request) func() {
12+
// TODO(djd): Respect any existing value of req.Cancel.
1213
ch := make(chan struct{})
1314
req.Cancel = ch
1415

context/ctxhttp/ctxhttp.go

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,32 @@ func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Resp
3636
result <- responseAndError{resp, err}
3737
}()
3838

39+
var resp *http.Response
40+
3941
select {
4042
case <-ctx.Done():
4143
cancel()
4244
return nil, ctx.Err()
4345
case r := <-result:
44-
return r.resp, r.err
46+
var err error
47+
resp, err = r.resp, r.err
48+
if err != nil {
49+
return resp, err
50+
}
4551
}
52+
53+
c := make(chan struct{})
54+
go func() {
55+
select {
56+
case <-ctx.Done():
57+
cancel()
58+
case <-c:
59+
// The response's Body is closed.
60+
}
61+
}()
62+
resp.Body = &notifyingReader{resp.Body, c}
63+
64+
return resp, nil
4665
}
4766

4867
// Get issues a GET request via the Do function.
@@ -77,3 +96,28 @@ func Post(ctx context.Context, client *http.Client, url string, bodyType string,
7796
func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) {
7897
return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
7998
}
99+
100+
// notifyingReader is an io.ReadCloser that closes the notify channel after
101+
// Close is called or a Read fails on the underlying ReadCloser.
102+
type notifyingReader struct {
103+
io.ReadCloser
104+
notify chan<- struct{}
105+
}
106+
107+
func (r *notifyingReader) Read(p []byte) (int, error) {
108+
n, err := r.ReadCloser.Read(p)
109+
if err != nil && r.notify != nil {
110+
close(r.notify)
111+
r.notify = nil
112+
}
113+
return n, err
114+
}
115+
116+
func (r *notifyingReader) Close() error {
117+
err := r.ReadCloser.Close()
118+
if r.notify != nil {
119+
close(r.notify)
120+
r.notify = nil
121+
}
122+
return err
123+
}

context/ctxhttp/ctxhttp_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ func TestNoTimeout(t *testing.T) {
2727
t.Fatalf("error received from client: %v %v", err, resp)
2828
}
2929
}
30+
3031
func TestCancel(t *testing.T) {
3132
ctx, cancel := context.WithCancel(context.Background())
3233
go func() {
@@ -59,6 +60,44 @@ func TestCancelAfterRequest(t *testing.T) {
5960
}
6061
}
6162

63+
func TestCancelAfterHangingRequest(t *testing.T) {
64+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
65+
w.WriteHeader(http.StatusOK)
66+
w.(http.Flusher).Flush()
67+
<-w.(http.CloseNotifier).CloseNotify()
68+
})
69+
70+
serv := httptest.NewServer(handler)
71+
defer serv.Close()
72+
73+
ctx, cancel := context.WithCancel(context.Background())
74+
resp, err := Get(ctx, nil, serv.URL)
75+
if err != nil {
76+
t.Fatalf("unexpected error in Get: %v", err)
77+
}
78+
79+
// Cancel befer reading the body.
80+
// Reading Request.Body should fail, since the request was
81+
// canceled before anything was written.
82+
cancel()
83+
84+
done := make(chan struct{})
85+
86+
go func() {
87+
b, err := ioutil.ReadAll(resp.Body)
88+
if len(b) != 0 || err == nil {
89+
t.Errorf(`Read got (%q, %v); want ("", error)`, b, err)
90+
}
91+
close(done)
92+
}()
93+
94+
select {
95+
case <-time.After(1 * time.Second):
96+
t.Errorf("Test timed out")
97+
case <-done:
98+
}
99+
}
100+
62101
func doRequest(ctx context.Context) (*http.Response, error) {
63102
var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
64103
time.Sleep(requestDuration)

0 commit comments

Comments
 (0)