Skip to content

Commit 7b05eea

Browse files
committed
Revert client pool
1 parent 7710186 commit 7b05eea

File tree

1 file changed

+33
-44
lines changed

1 file changed

+33
-44
lines changed

client.go

Lines changed: 33 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import (
77
"os"
88
"strconv"
99
"strings"
10-
"sync/atomic"
10+
"sync"
1111
"time"
1212

1313
"golang.org/x/crypto/ssh"
@@ -25,9 +25,8 @@ func NewDialer(addr string) (*Dialer, error) {
2525

2626
func NewDialerWithConfig(host string, config *ssh.ClientConfig) (*Dialer, error) {
2727
return &Dialer{
28-
host: host,
29-
config: config,
30-
clients: make(chan *ssh.Client, 5),
28+
host: host,
29+
config: config,
3130
}, nil
3231
}
3332

@@ -103,13 +102,19 @@ type Dialer struct {
103102
host string
104103
config *ssh.ClientConfig
105104

106-
conns int32
107-
clients chan *ssh.Client
105+
mut sync.RWMutex
106+
sshCli *ssh.Client
108107
}
109108

110109
func (d *Dialer) Close() error {
111-
// In practice, closing the connection doesn't actually release the ssh.Conn but causes a memory leak
112-
return nil
110+
d.mut.Lock()
111+
defer d.mut.Unlock()
112+
if d.sshCli == nil {
113+
return nil
114+
}
115+
err := d.sshCli.Close()
116+
d.sshCli = nil
117+
return err
113118
}
114119

115120
func (d *Dialer) proxyDial(ctx context.Context, network, address string) (net.Conn, error) {
@@ -122,38 +127,20 @@ func (d *Dialer) proxyDial(ctx context.Context, network, address string) (net.Co
122127
}
123128

124129
func (d *Dialer) SSHClient(ctx context.Context) (*ssh.Client, error) {
125-
cli, err := d.getClient(ctx)
126-
if err != nil {
127-
return nil, err
128-
}
129-
d.putClient(cli)
130-
return cli, nil
131-
}
130+
d.mut.RLock()
131+
sshCli := d.sshCli
132+
d.mut.RUnlock()
132133

133-
func (d *Dialer) getClient(ctx context.Context) (*ssh.Client, error) {
134-
if atomic.LoadInt32(&d.conns) >= int32(cap(d.clients)) {
135-
select {
136-
case <-ctx.Done():
137-
return nil, ctx.Err()
138-
case cli := <-d.clients:
139-
return cli, nil
140-
}
134+
if sshCli != nil {
135+
return sshCli, nil
141136
}
142-
atomic.AddInt32(&d.conns, 1)
143137

144-
cli, err := d.sshClient(ctx)
145-
if err != nil {
146-
atomic.AddInt32(&d.conns, -1)
147-
return nil, err
138+
d.mut.Lock()
139+
defer d.mut.Unlock()
140+
if d.sshCli != nil {
141+
return d.sshCli, nil
148142
}
149-
return cli, nil
150-
}
151143

152-
func (d *Dialer) putClient(cli *ssh.Client) {
153-
d.clients <- cli
154-
}
155-
156-
func (d *Dialer) sshClient(ctx context.Context) (*ssh.Client, error) {
157144
conn, err := d.proxyDial(ctx, "tcp", d.host)
158145
if err != nil {
159146
return nil, err
@@ -163,7 +150,9 @@ func (d *Dialer) sshClient(ctx context.Context) (*ssh.Client, error) {
163150
if err != nil {
164151
return nil, err
165152
}
166-
return ssh.NewClient(con, chans, reqs), nil
153+
154+
d.sshCli = ssh.NewClient(con, chans, reqs)
155+
return d.sshCli, nil
167156
}
168157

169158
func buildCmd(name string, args ...string) string {
@@ -176,14 +165,14 @@ func buildCmd(name string, args ...string) string {
176165
}
177166

178167
func (d *Dialer) CommandDialContext(ctx context.Context, name string, args ...string) (net.Conn, error) {
179-
cli, err := d.getClient(ctx)
168+
cli, err := d.SSHClient(ctx)
180169
if err != nil {
181170
return nil, err
182171
}
183-
defer d.putClient(cli)
184172

185173
sess, err := cli.NewSession()
186174
if err != nil {
175+
d.Close()
187176
return nil, err
188177
}
189178

@@ -222,29 +211,29 @@ func (d *Dialer) CommandDialContext(ctx context.Context, name string, args ...st
222211
}
223212

224213
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
225-
cli, err := d.getClient(ctx)
214+
cli, err := d.SSHClient(ctx)
226215
if err != nil {
227216
return nil, err
228217
}
229-
defer d.putClient(cli)
230218

231219
conn, err := cli.DialContext(ctx, network, address)
232220
if err != nil {
221+
d.Close()
233222
return nil, err
234223
}
235224

236225
return conn, nil
237226
}
238227

239228
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
240-
cli, err := d.getClient(context.Background())
229+
cli, err := d.SSHClient(context.Background())
241230
if err != nil {
242231
return nil, err
243232
}
244-
defer d.putClient(cli)
245233

246234
conn, err := cli.Dial(network, address)
247235
if err != nil {
236+
d.Close()
248237
return nil, err
249238
}
250239

@@ -259,14 +248,14 @@ func (d *Dialer) Listen(ctx context.Context, network, address string) (net.Liste
259248
}
260249
}
261250

262-
cli, err := d.getClient(ctx)
251+
cli, err := d.SSHClient(ctx)
263252
if err != nil {
264253
return nil, err
265254
}
266-
defer d.putClient(cli)
267255

268256
listener, err := cli.Listen(network, address)
269257
if err != nil {
258+
d.Close()
270259
return nil, err
271260
}
272261

0 commit comments

Comments
 (0)