77 "os"
88 "strconv"
99 "strings"
10- "sync/atomic "
10+ "sync"
1111 "time"
1212
1313 "golang.org/x/crypto/ssh"
@@ -25,9 +25,8 @@ func NewDialer(addr string) (*Dialer, error) {
2525
2626func NewDialerWithConfig (host string , config * ssh.ClientConfig ) (* Dialer , error ) {
2727 return & Dialer {
28- host : host ,
29- config : config ,
30- clients : make (chan * ssh.Client , 5 ),
28+ host : host ,
29+ config : config ,
3130 }, nil
3231}
3332
@@ -103,13 +102,19 @@ type Dialer struct {
103102 host string
104103 config * ssh.ClientConfig
105104
106- conns int32
107- clients chan * ssh.Client
105+ mut sync. RWMutex
106+ sshCli * ssh.Client
108107}
109108
110109func (d * Dialer ) Close () error {
111- // In practice, closing the connection doesn't actually release the ssh.Conn but causes a memory leak
112- return nil
110+ d .mut .Lock ()
111+ defer d .mut .Unlock ()
112+ if d .sshCli == nil {
113+ return nil
114+ }
115+ err := d .sshCli .Close ()
116+ d .sshCli = nil
117+ return err
113118}
114119
115120func (d * Dialer ) proxyDial (ctx context.Context , network , address string ) (net.Conn , error ) {
@@ -122,38 +127,20 @@ func (d *Dialer) proxyDial(ctx context.Context, network, address string) (net.Co
122127}
123128
124129func (d * Dialer ) SSHClient (ctx context.Context ) (* ssh.Client , error ) {
125- cli , err := d .getClient (ctx )
126- if err != nil {
127- return nil , err
128- }
129- d .putClient (cli )
130- return cli , nil
131- }
130+ d .mut .RLock ()
131+ sshCli := d .sshCli
132+ d .mut .RUnlock ()
132133
133- func (d * Dialer ) getClient (ctx context.Context ) (* ssh.Client , error ) {
134- if atomic .LoadInt32 (& d .conns ) >= int32 (cap (d .clients )) {
135- select {
136- case <- ctx .Done ():
137- return nil , ctx .Err ()
138- case cli := <- d .clients :
139- return cli , nil
140- }
134+ if sshCli != nil {
135+ return sshCli , nil
141136 }
142- atomic .AddInt32 (& d .conns , 1 )
143137
144- cli , err := d . sshClient ( ctx )
145- if err != nil {
146- atomic . AddInt32 ( & d . conns , - 1 )
147- return nil , err
138+ d . mut . Lock ( )
139+ defer d . mut . Unlock ()
140+ if d . sshCli != nil {
141+ return d . sshCli , nil
148142 }
149- return cli , nil
150- }
151143
152- func (d * Dialer ) putClient (cli * ssh.Client ) {
153- d .clients <- cli
154- }
155-
156- func (d * Dialer ) sshClient (ctx context.Context ) (* ssh.Client , error ) {
157144 conn , err := d .proxyDial (ctx , "tcp" , d .host )
158145 if err != nil {
159146 return nil , err
@@ -163,7 +150,9 @@ func (d *Dialer) sshClient(ctx context.Context) (*ssh.Client, error) {
163150 if err != nil {
164151 return nil , err
165152 }
166- return ssh .NewClient (con , chans , reqs ), nil
153+
154+ d .sshCli = ssh .NewClient (con , chans , reqs )
155+ return d .sshCli , nil
167156}
168157
169158func buildCmd (name string , args ... string ) string {
@@ -176,14 +165,14 @@ func buildCmd(name string, args ...string) string {
176165}
177166
178167func (d * Dialer ) CommandDialContext (ctx context.Context , name string , args ... string ) (net.Conn , error ) {
179- cli , err := d .getClient (ctx )
168+ cli , err := d .SSHClient (ctx )
180169 if err != nil {
181170 return nil , err
182171 }
183- defer d .putClient (cli )
184172
185173 sess , err := cli .NewSession ()
186174 if err != nil {
175+ d .Close ()
187176 return nil , err
188177 }
189178
@@ -222,29 +211,29 @@ func (d *Dialer) CommandDialContext(ctx context.Context, name string, args ...st
222211}
223212
224213func (d * Dialer ) DialContext (ctx context.Context , network , address string ) (net.Conn , error ) {
225- cli , err := d .getClient (ctx )
214+ cli , err := d .SSHClient (ctx )
226215 if err != nil {
227216 return nil , err
228217 }
229- defer d .putClient (cli )
230218
231219 conn , err := cli .DialContext (ctx , network , address )
232220 if err != nil {
221+ d .Close ()
233222 return nil , err
234223 }
235224
236225 return conn , nil
237226}
238227
239228func (d * Dialer ) Dial (network , address string ) (net.Conn , error ) {
240- cli , err := d .getClient (context .Background ())
229+ cli , err := d .SSHClient (context .Background ())
241230 if err != nil {
242231 return nil , err
243232 }
244- defer d .putClient (cli )
245233
246234 conn , err := cli .Dial (network , address )
247235 if err != nil {
236+ d .Close ()
248237 return nil , err
249238 }
250239
@@ -259,14 +248,14 @@ func (d *Dialer) Listen(ctx context.Context, network, address string) (net.Liste
259248 }
260249 }
261250
262- cli , err := d .getClient (ctx )
251+ cli , err := d .SSHClient (ctx )
263252 if err != nil {
264253 return nil , err
265254 }
266- defer d .putClient (cli )
267255
268256 listener , err := cli .Listen (network , address )
269257 if err != nil {
258+ d .Close ()
270259 return nil , err
271260 }
272261
0 commit comments