@@ -5109,10 +5109,28 @@ func TestTransportServerResetStreamAtHeaders(t *testing.T) {
5109
5109
res .Body .Close ()
5110
5110
}
5111
5111
5112
+ type trackingReader struct {
5113
+ rdr io.Reader
5114
+ wasRead uint32
5115
+ }
5116
+
5117
+ func (tr * trackingReader ) Read (p []byte ) (int , error ) {
5118
+ atomic .StoreUint32 (& tr .wasRead , 1 )
5119
+ return tr .rdr .Read (p )
5120
+ }
5121
+
5122
+ func (tr * trackingReader ) WasRead () bool {
5123
+ return atomic .LoadUint32 (& tr .wasRead ) != 0
5124
+ }
5125
+
5112
5126
func TestTransportExpectContinue (t * testing.T ) {
5113
5127
st := newServerTester (t , func (w http.ResponseWriter , r * http.Request ) {
5114
- io .Copy (io .Discard , r .Body )
5115
- return
5128
+ switch r .URL .Path {
5129
+ case "/reject" :
5130
+ w .WriteHeader (403 )
5131
+ default :
5132
+ io .Copy (io .Discard , r .Body )
5133
+ }
5116
5134
}, optOnlyServer )
5117
5135
defer st .Close ()
5118
5136
@@ -5130,31 +5148,54 @@ func TestTransportExpectContinue(t *testing.T) {
5130
5148
Transport : tr ,
5131
5149
}
5132
5150
5133
- reqCh := make (chan error )
5134
- startTime := time .Now ()
5151
+ testCases := []struct {
5152
+ Name string
5153
+ Path string
5154
+ Body * trackingReader
5155
+ ExpectedCode int
5156
+ ShouldRead bool
5157
+ }{
5158
+ {
5159
+ Name : "read-all" ,
5160
+ Path : "/" ,
5161
+ Body : & trackingReader {rdr : strings .NewReader ("hello" )},
5162
+ ExpectedCode : 200 ,
5163
+ ShouldRead : true ,
5164
+ },
5165
+ {
5166
+ Name : "reject" ,
5167
+ Path : "/reject" ,
5168
+ Body : & trackingReader {rdr : strings .NewReader ("hello" )},
5169
+ ExpectedCode : 403 ,
5170
+ ShouldRead : false ,
5171
+ },
5172
+ }
5135
5173
5136
- go func () {
5137
- req , err := http .NewRequest ("POST" , st .ts .URL , strings .NewReader ("hello" ))
5138
- if err != nil {
5139
- reqCh <- err
5140
- return
5141
- }
5142
- req .Header .Set ("Expect" , "100-continue" )
5143
- res , err := client .Do (req )
5144
- if err != nil {
5145
- reqCh <- err
5146
- return
5147
- }
5148
- reqCh <- res .Body .Close ()
5149
- }()
5174
+ for _ , tc := range testCases {
5175
+ t .Run (tc .Name , func (t * testing.T ) {
5176
+ startTime := time .Now ()
5150
5177
5151
- err = <- reqCh
5152
- if err != nil {
5153
- t .Fatal (err )
5154
- }
5155
- delta := time .Since (startTime )
5156
- if delta >= tr .ExpectContinueTimeout {
5157
- t .Error ("Request didn't resume after receiving 100 continue" )
5178
+ req , err := http .NewRequest ("POST" , st .ts .URL + tc .Path , tc .Body )
5179
+ if err != nil {
5180
+ t .Fatal (err )
5181
+ }
5182
+ req .Header .Set ("Expect" , "100-continue" )
5183
+ res , err := client .Do (req )
5184
+ if err != nil {
5185
+ t .Fatal (err )
5186
+ }
5187
+ res .Body .Close ()
5188
+
5189
+ if delta := time .Since (startTime ); delta >= tr .ExpectContinueTimeout {
5190
+ t .Error ("Request didn't finish before expect continue timeout" )
5191
+ }
5192
+ if res .StatusCode != tc .ExpectedCode {
5193
+ t .Errorf ("Unexpected status code, got %d, expected %d" , res .StatusCode , tc .ExpectedCode )
5194
+ }
5195
+ if tc .Body .WasRead () != tc .ShouldRead {
5196
+ t .Errorf ("Unexpected read status, got %v, expected %v" , tc .Body .WasRead (), tc .ShouldRead )
5197
+ }
5198
+ })
5158
5199
}
5159
5200
}
5160
5201
0 commit comments