diff --git a/rpc/block_writer.go b/rpc/block_writer.go index 949b1e5f..a8bd7c43 100644 --- a/rpc/block_writer.go +++ b/rpc/block_writer.go @@ -110,7 +110,13 @@ func (bw *BlockWriter) Close() error { func (bw *BlockWriter) connectNext() error { address := getDatanodeAddress(bw.currentPipeline()[0]) - conn, err := net.DialTimeout("tcp", address, connectTimeout) + var conn net.Conn + var err error + if bw.namenode.BlockWriterDialTimeout != nil { + conn, err = bw.namenode.BlockWriterDialTimeout("tcp", address, connectTimeout) + } else { + conn, err = net.DialTimeout("tcp", address, connectTimeout) + } if err != nil { return err } diff --git a/rpc/namenode.go b/rpc/namenode.go index 15a1184e..45768880 100644 --- a/rpc/namenode.go +++ b/rpc/namenode.go @@ -7,6 +7,7 @@ import ( "io" "net" "sync" + "time" hadoop "github.com/colinmarc/hdfs/protocol/hadoop_common" "github.com/golang/protobuf/proto" @@ -23,12 +24,13 @@ const ( // NamenodeConnection represents an open connection to a namenode. type NamenodeConnection struct { - clientId []byte - clientName string - currentRequestID int - user string - conn net.Conn - reqLock sync.Mutex + clientId []byte + clientName string + currentRequestID int + user string + conn net.Conn + reqLock sync.Mutex + BlockWriterDialTimeout func(network, address string, timeout time.Duration) (net.Conn, error) } // NamenodeError represents an interepreted error from the Namenode, including