@@ -1523,6 +1523,24 @@ func TestOnProxyConnectResponse(t *testing.T) {
1523
1523
1524
1524
c := proxy .Client ()
1525
1525
1526
+ var (
1527
+ dials atomic.Int32
1528
+ closes atomic.Int32
1529
+ )
1530
+ c .Transport .(* Transport ).DialContext = func (ctx context.Context , network , addr string ) (net.Conn , error ) {
1531
+ conn , err := net .Dial (network , addr )
1532
+ if err != nil {
1533
+ return nil , err
1534
+ }
1535
+ dials .Add (1 )
1536
+ return noteCloseConn {
1537
+ Conn : conn ,
1538
+ closeFunc : func () {
1539
+ closes .Add (1 )
1540
+ },
1541
+ }, nil
1542
+ }
1543
+
1526
1544
c .Transport .(* Transport ).Proxy = ProxyURL (pu )
1527
1545
c .Transport .(* Transport ).OnProxyConnectResponse = func (ctx context.Context , proxyURL * url.URL , connectReq * Request , connectRes * Response ) error {
1528
1546
if proxyURL .String () != pu .String () {
@@ -1534,10 +1552,23 @@ func TestOnProxyConnectResponse(t *testing.T) {
1534
1552
}
1535
1553
return tcase .err
1536
1554
}
1555
+ wantCloses := int32 (0 )
1537
1556
if _ , err := c .Head (ts .URL ); err != nil {
1557
+ wantCloses = 1
1538
1558
if tcase .err != nil && ! strings .Contains (err .Error (), tcase .err .Error ()) {
1539
1559
t .Errorf ("got %v, want %v" , err , tcase .err )
1540
1560
}
1561
+ } else {
1562
+ if tcase .err != nil {
1563
+ t .Errorf ("got %v, want nil" , err )
1564
+ }
1565
+ }
1566
+ if got , want := dials .Load (), int32 (1 ); got != want {
1567
+ t .Errorf ("got %v dials, want %v" , got , want )
1568
+ }
1569
+ // #64804: If OnProxyConnectResponse returns an error, we should close the conn.
1570
+ if got , want := closes .Load (), wantCloses ; got != want {
1571
+ t .Errorf ("got %v closes, want %v" , got , want )
1541
1572
}
1542
1573
}
1543
1574
}
0 commit comments