77 "os"
88 "strconv"
99 "strings"
10- "sync"
10+ "sync/atomic"
11+ "time"
1112
1213 "golang.org/x/crypto/ssh"
1314)
@@ -24,8 +25,9 @@ func NewDialer(addr string) (*Dialer, error) {
2425
2526func NewDialerWithConfig (host string , config * ssh.ClientConfig ) (* Dialer , error ) {
2627 return & Dialer {
27- host : host ,
28- config : config ,
28+ host : host ,
29+ config : config ,
30+ clients : make (chan * ssh.Client , 5 ),
2931 }, nil
3032}
3133
@@ -69,6 +71,17 @@ func parseClientConfig(addr string) (*clientConfig, error) {
6971 config .Auth = append (config .Auth , ssh .PublicKeys (signer ))
7072 }
7173
74+ var timeout = 30 * time .Second
75+ timeoutStr := ur .Query ().Get ("timeout" )
76+ if timeoutStr != "" {
77+ timeout , err = time .ParseDuration (timeoutStr )
78+ if err != nil {
79+ return nil , err
80+ }
81+ }
82+
83+ config .Timeout = timeout
84+
7285 host := ur .Hostname ()
7386 port := ur .Port ()
7487 if port == "" {
@@ -82,7 +95,6 @@ func parseClientConfig(addr string) (*clientConfig, error) {
8295}
8396
8497type Dialer struct {
85- mut sync.Mutex
8698 localAddr net.Addr
8799 // ProxyDial specifies the optional dial function for
88100 // establishing the transport connection.
@@ -91,17 +103,12 @@ type Dialer struct {
91103 host string
92104 config * ssh.ClientConfig
93105
94- pool sync.Pool
106+ conns int32
107+ clients chan * ssh.Client
95108}
96109
97110func (d * Dialer ) Close () error {
98- for {
99- a := d .pool .Get ()
100- if a == nil {
101- break
102- }
103- a .(* ssh.Client ).Close ()
104- }
111+ // In practice, closing the connection doesn't actually release the ssh.Conn but causes a memory leak
105112 return nil
106113}
107114
@@ -115,33 +122,35 @@ func (d *Dialer) proxyDial(ctx context.Context, network, address string) (net.Co
115122}
116123
117124func (d * Dialer ) SSHClient (ctx context.Context ) (* ssh.Client , error ) {
118- return d .GetClient (ctx )
119- }
120-
121- func (d * Dialer ) GetClient (ctx context.Context ) (* ssh.Client , error ) {
122- a := d .pool .Get ()
123- if a != nil {
124- return a .(* ssh.Client ), nil
125+ cli , err := d .getClient (ctx )
126+ if err != nil {
127+ return nil , err
125128 }
129+ d .putClient (cli )
130+ return cli , nil
131+ }
126132
127- d .mut .Lock ()
128- defer d .mut .Unlock ()
129-
130- a = d .pool .Get ()
131- if a != nil {
132- return a .(* ssh.Client ), nil
133+ func (d * Dialer ) getClient (ctx context.Context ) (* ssh.Client , error ) {
134+ if atomic .LoadInt32 (& d .conns ) >= int32 (cap (d .clients )) {
135+ select {
136+ case <- ctx .Done ():
137+ return nil , ctx .Err ()
138+ case cli := <- d .clients :
139+ return cli , nil
140+ }
133141 }
142+ atomic .AddInt32 (& d .conns , 1 )
134143
135144 cli , err := d .sshClient (ctx )
136145 if err != nil {
146+ atomic .AddInt32 (& d .conns , - 1 )
137147 return nil , err
138148 }
139-
140149 return cli , nil
141150}
142151
143- func (d * Dialer ) PutClient (cli * ssh.Client ) {
144- d .pool . Put ( cli )
152+ func (d * Dialer ) putClient (cli * ssh.Client ) {
153+ d .clients <- cli
145154}
146155
147156func (d * Dialer ) sshClient (ctx context.Context ) (* ssh.Client , error ) {
@@ -167,20 +176,16 @@ func buildCmd(name string, args ...string) string {
167176}
168177
169178func (d * Dialer ) CommandDialContext (ctx context.Context , name string , args ... string ) (net.Conn , error ) {
170- cli , err := d .GetClient (ctx )
179+ cli , err := d .getClient (ctx )
171180 if err != nil {
172181 return nil , err
173182 }
183+ defer d .putClient (cli )
184+
174185 sess , err := cli .NewSession ()
175186 if err != nil {
176- if isSSHError (err ) {
177- d .PutClient (cli )
178- } else {
179- cli .Close ()
180- }
181187 return nil , err
182188 }
183- defer d .PutClient (cli )
184189
185190 conn1 , conn2 := net .Pipe ()
186191 sess .Stdin = conn1
@@ -217,42 +222,32 @@ func (d *Dialer) CommandDialContext(ctx context.Context, name string, args ...st
217222}
218223
219224func (d * Dialer ) DialContext (ctx context.Context , network , address string ) (net.Conn , error ) {
220- cli , err := d .GetClient (ctx )
225+ cli , err := d .getClient (ctx )
221226 if err != nil {
222227 return nil , err
223228 }
229+ defer d .putClient (cli )
224230
225231 conn , err := cli .DialContext (ctx , network , address )
226232 if err != nil {
227- if isSSHError (err ) {
228- d .PutClient (cli )
229- } else {
230- cli .Close ()
231- }
232233 return nil , err
233234 }
234235
235- d .PutClient (cli )
236236 return conn , nil
237237}
238238
239239func (d * Dialer ) Dial (network , address string ) (net.Conn , error ) {
240- cli , err := d .GetClient (context .Background ())
240+ cli , err := d .getClient (context .Background ())
241241 if err != nil {
242242 return nil , err
243243 }
244+ defer d .putClient (cli )
244245
245246 conn , err := cli .Dial (network , address )
246247 if err != nil {
247- if isSSHError (err ) {
248- d .PutClient (cli )
249- } else {
250- cli .Close ()
251- }
252248 return nil , err
253249 }
254250
255- d .PutClient (cli )
256251 return conn , nil
257252}
258253
@@ -264,43 +259,16 @@ func (d *Dialer) Listen(ctx context.Context, network, address string) (net.Liste
264259 }
265260 }
266261
267- cli , err := d .GetClient (ctx )
262+ cli , err := d .getClient (ctx )
268263 if err != nil {
269264 return nil , err
270265 }
266+ defer d .putClient (cli )
271267
272268 listener , err := cli .Listen (network , address )
273269 if err != nil {
274- if isSSHError (err ) {
275- d .PutClient (cli )
276- } else {
277- cli .Close ()
278- }
279270 return nil , err
280271 }
281272
282- listener = & listenerWithCleanup {
283- Listener : listener ,
284- cleanup : func () {
285- d .PutClient (cli )
286- },
287- }
288-
289273 return listener , nil
290274}
291-
292- type listenerWithCleanup struct {
293- net.Listener
294- cleanup func ()
295- }
296-
297- func (l * listenerWithCleanup ) Close () error {
298- err := l .Listener .Close ()
299- l .cleanup ()
300- return err
301- }
302-
303- func isSSHError (err error ) bool {
304- msg := err .Error ()
305- return strings .Contains (msg , "ssh: " )
306- }
0 commit comments