@@ -323,6 +323,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) {
323
323
cacheKey : cm .String (),
324
324
conn : conn ,
325
325
reqch : make (chan requestAndChan , 50 ),
326
+ writech : make (chan writeRequest , 50 ),
326
327
}
327
328
328
329
switch {
@@ -380,6 +381,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) {
380
381
pconn .br = bufio .NewReader (pconn .conn )
381
382
pconn .bw = bufio .NewWriter (pconn .conn )
382
383
go pconn .readLoop ()
384
+ go pconn .writeLoop ()
383
385
return pconn , nil
384
386
}
385
387
@@ -487,7 +489,8 @@ type persistConn struct {
487
489
closed bool // whether conn has been closed
488
490
br * bufio.Reader // from conn
489
491
bw * bufio.Writer // to conn
490
- reqch chan requestAndChan // written by roundTrip(); read by readLoop()
492
+ reqch chan requestAndChan // written by roundTrip; read by readLoop
493
+ writech chan writeRequest // written by roundTrip; read by writeLoop
491
494
isProxy bool
492
495
493
496
// mutateHeaderFunc is an optional func to modify extra
@@ -519,6 +522,7 @@ func remoteSideClosed(err error) bool {
519
522
}
520
523
521
524
func (pc * persistConn ) readLoop () {
525
+ defer close (pc .writech )
522
526
alive := true
523
527
var lastbody io.ReadCloser // last response body, if any, read on this connection
524
528
@@ -579,7 +583,7 @@ func (pc *persistConn) readLoop() {
579
583
if alive && ! pc .t .putIdleConn (pc ) {
580
584
alive = false
581
585
}
582
- if ! alive {
586
+ if ! alive || pc . isBroken () {
583
587
pc .close ()
584
588
}
585
589
waitForBodyRead <- true
@@ -615,6 +619,23 @@ func (pc *persistConn) readLoop() {
615
619
}
616
620
}
617
621
622
+ func (pc * persistConn ) writeLoop () {
623
+ for wr := range pc .writech {
624
+ if pc .isBroken () {
625
+ wr .ch <- errors .New ("http: can't write HTTP request on broken connection" )
626
+ continue
627
+ }
628
+ err := wr .req .Request .write (pc .bw , pc .isProxy , wr .req .extra )
629
+ if err == nil {
630
+ err = pc .bw .Flush ()
631
+ }
632
+ if err != nil {
633
+ pc .markBroken ()
634
+ }
635
+ wr .ch <- err
636
+ }
637
+ }
638
+
618
639
type responseAndError struct {
619
640
res * Response
620
641
err error
@@ -630,6 +651,15 @@ type requestAndChan struct {
630
651
addedGzip bool
631
652
}
632
653
654
+ // A writeRequest is sent by the readLoop's goroutine to the
655
+ // writeLoop's goroutine to write a request while the read loop
656
+ // concurrently waits on both the write response and the server's
657
+ // reply.
658
+ type writeRequest struct {
659
+ req * transportRequest
660
+ ch chan <- error
661
+ }
662
+
633
663
func (pc * persistConn ) roundTrip (req * transportRequest ) (resp * Response , err error ) {
634
664
if pc .mutateHeaderFunc != nil {
635
665
pc .mutateHeaderFunc (req .extraHeaders ())
@@ -652,23 +682,45 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
652
682
pc .numExpectedResponses ++
653
683
pc .lk .Unlock ()
654
684
655
- err = req .Request .write (pc .bw , pc .isProxy , req .extra )
656
- if err != nil {
657
- pc .close ()
658
- return
685
+ // Write the request concurrently with waiting for a response,
686
+ // in case the server decides to reply before reading our full
687
+ // request body.
688
+ writeErrCh := make (chan error , 1 )
689
+ pc .writech <- writeRequest {req , writeErrCh }
690
+
691
+ resc := make (chan responseAndError , 1 )
692
+ pc .reqch <- requestAndChan {req .Request , resc , requestedGzip }
693
+
694
+ var re responseAndError
695
+ WaitResponse:
696
+ for {
697
+ select {
698
+ case err := <- writeErrCh :
699
+ if err != nil {
700
+ re = responseAndError {nil , err }
701
+ break WaitResponse
702
+ }
703
+ case re = <- resc :
704
+ break WaitResponse
705
+ }
659
706
}
660
- pc .bw .Flush ()
661
707
662
- ch := make (chan responseAndError , 1 )
663
- pc .reqch <- requestAndChan {req .Request , ch , requestedGzip }
664
- re := <- ch
665
708
pc .lk .Lock ()
666
709
pc .numExpectedResponses --
667
710
pc .lk .Unlock ()
668
711
669
712
return re .res , re .err
670
713
}
671
714
715
+ // markBroken marks a connection as broken (so it's not reused).
716
+ // It differs from close in that it doesn't close the underlying
717
+ // connection for use when it's still being read.
718
+ func (pc * persistConn ) markBroken () {
719
+ pc .lk .Lock ()
720
+ defer pc .lk .Unlock ()
721
+ pc .broken = true
722
+ }
723
+
672
724
func (pc * persistConn ) close () {
673
725
pc .lk .Lock ()
674
726
defer pc .lk .Unlock ()
0 commit comments