@@ -16,6 +16,7 @@ import (
16
16
"log"
17
17
. "net/http"
18
18
"net/http/httptest"
19
+ "net/url"
19
20
"os"
20
21
"reflect"
21
22
"sort"
@@ -675,3 +676,61 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) {
675
676
t .Errorf ("read %q; want %q" , data , resBody )
676
677
}
677
678
}
679
+
680
+ func TestConnectRequest_h1 (t * testing.T ) { testConnectRequest (t , h1Mode ) }
681
+ func TestConnectRequest_h2 (t * testing.T ) { testConnectRequest (t , h2Mode ) }
682
+ func testConnectRequest (t * testing.T , h2 bool ) {
683
+ defer afterTest (t )
684
+ gotc := make (chan * Request , 1 )
685
+ cst := newClientServerTest (t , h2 , HandlerFunc (func (w ResponseWriter , r * Request ) {
686
+ gotc <- r
687
+ }))
688
+ defer cst .close ()
689
+
690
+ u , err := url .Parse (cst .ts .URL )
691
+ if err != nil {
692
+ t .Fatal (err )
693
+ }
694
+
695
+ tests := []struct {
696
+ req * Request
697
+ want string
698
+ }{
699
+ {
700
+ req : & Request {
701
+ Method : "CONNECT" ,
702
+ Header : Header {},
703
+ URL : u ,
704
+ },
705
+ want : u .Host ,
706
+ },
707
+ {
708
+ req : & Request {
709
+ Method : "CONNECT" ,
710
+ Header : Header {},
711
+ URL : u ,
712
+ Host : "example.com:123" ,
713
+ },
714
+ want : "example.com:123" ,
715
+ },
716
+ }
717
+
718
+ for i , tt := range tests {
719
+ res , err := cst .c .Do (tt .req )
720
+ if err != nil {
721
+ t .Errorf ("%d. RoundTrip = %v" , i , err )
722
+ continue
723
+ }
724
+ res .Body .Close ()
725
+ req := <- gotc
726
+ if req .Method != "CONNECT" {
727
+ t .Errorf ("method = %q; want CONNECT" , req .Method )
728
+ }
729
+ if req .Host != tt .want {
730
+ t .Errorf ("Host = %q; want %q" , req .Host , tt .want )
731
+ }
732
+ if req .URL .Host != tt .want {
733
+ t .Errorf ("URL.Host = %q; want %q" , req .URL .Host , tt .want )
734
+ }
735
+ }
736
+ }
0 commit comments