@@ -5109,6 +5109,96 @@ func TestTransportServerResetStreamAtHeaders(t *testing.T) {
5109
5109
res .Body .Close ()
5110
5110
}
5111
5111
5112
+ type trackingReader struct {
5113
+ rdr io.Reader
5114
+ wasRead uint32
5115
+ }
5116
+
5117
+ func (tr * trackingReader ) Read (p []byte ) (int , error ) {
5118
+ atomic .StoreUint32 (& tr .wasRead , 1 )
5119
+ return tr .rdr .Read (p )
5120
+ }
5121
+
5122
+ func (tr * trackingReader ) WasRead () bool {
5123
+ return atomic .LoadUint32 (& tr .wasRead ) != 0
5124
+ }
5125
+
5126
+ func TestTransportExpectContinue (t * testing.T ) {
5127
+ st := newServerTester (t , func (w http.ResponseWriter , r * http.Request ) {
5128
+ switch r .URL .Path {
5129
+ case "/reject" :
5130
+ w .WriteHeader (403 )
5131
+ default :
5132
+ io .Copy (io .Discard , r .Body )
5133
+ }
5134
+ }, optOnlyServer )
5135
+ defer st .Close ()
5136
+
5137
+ tr := & http.Transport {
5138
+ TLSClientConfig : tlsConfigInsecure ,
5139
+ MaxConnsPerHost : 1 ,
5140
+ ExpectContinueTimeout : 10 * time .Second ,
5141
+ }
5142
+
5143
+ err := ConfigureTransport (tr )
5144
+ if err != nil {
5145
+ t .Fatal (err )
5146
+ }
5147
+ client := & http.Client {
5148
+ Transport : tr ,
5149
+ }
5150
+
5151
+ testCases := []struct {
5152
+ Name string
5153
+ Path string
5154
+ Body * trackingReader
5155
+ ExpectedCode int
5156
+ ShouldRead bool
5157
+ }{
5158
+ {
5159
+ Name : "read-all" ,
5160
+ Path : "/" ,
5161
+ Body : & trackingReader {rdr : strings .NewReader ("hello" )},
5162
+ ExpectedCode : 200 ,
5163
+ ShouldRead : true ,
5164
+ },
5165
+ {
5166
+ Name : "reject" ,
5167
+ Path : "/reject" ,
5168
+ Body : & trackingReader {rdr : strings .NewReader ("hello" )},
5169
+ ExpectedCode : 403 ,
5170
+ ShouldRead : false ,
5171
+ },
5172
+ }
5173
+
5174
+ for _ , tc := range testCases {
5175
+ t .Run (tc .Name , func (t * testing.T ) {
5176
+ startTime := time .Now ()
5177
+
5178
+ req , err := http .NewRequest ("POST" , st .ts .URL + tc .Path , tc .Body )
5179
+ if err != nil {
5180
+ t .Fatal (err )
5181
+ }
5182
+ req .Header .Set ("Expect" , "100-continue" )
5183
+ res , err := client .Do (req )
5184
+ if err != nil {
5185
+ t .Fatal (err )
5186
+ }
5187
+ res .Body .Close ()
5188
+
5189
+ if delta := time .Since (startTime ); delta >= tr .ExpectContinueTimeout {
5190
+ t .Error ("Request didn't finish before expect continue timeout" )
5191
+ }
5192
+ if res .StatusCode != tc .ExpectedCode {
5193
+ t .Errorf ("Unexpected status code, got %d, expected %d" , res .StatusCode , tc .ExpectedCode )
5194
+ }
5195
+ if tc .Body .WasRead () != tc .ShouldRead {
5196
+ t .Errorf ("Unexpected read status, got %v, expected %v" , tc .Body .WasRead (), tc .ShouldRead )
5197
+ }
5198
+ })
5199
+ }
5200
+ }
5201
+
5112
5202
type closeChecker struct {
5113
5203
io.ReadCloser
5114
5204
closed chan struct {}
0 commit comments