Skip to content

Commit f86b780

Browse files
Use callback to create net.Listener (#1)
Co-authored-by: Dean Sheather <[email protected]>
1 parent 9e6b773 commit f86b780

File tree

3 files changed

+44
-26
lines changed

3 files changed

+44
-26
lines changed

ssh.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package ssh
22

33
import (
44
"crypto/subtle"
5+
"errors"
56
"net"
67

78
gossh "golang.org/x/crypto/ssh"
@@ -29,6 +30,9 @@ const (
2930
// DefaultHandler is the default Handler used by Serve.
3031
var DefaultHandler Handler
3132

33+
// ErrReject is returned by some callbacks to reject a request.
34+
var ErrRejected = errors.New("rejected")
35+
3236
// Option is a functional option handler for Server.
3337
type Option func(*Server) error
3438

@@ -69,8 +73,9 @@ type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort u
6973
type LocalUnixForwardingCallback func(ctx Context, socketPath string) bool
7074

7175
// ReverseUnixForwardingCallback is a hook for allowing reverse unix forwarding
72-
73-
type ReverseUnixForwardingCallback func(ctx Context, socketPath string) bool
76+
// ([email protected]). Returning ErrRejected will reject the
77+
// request.
78+
type ReverseUnixForwardingCallback func(ctx Context, socketPath string) (net.Listener, error)
7479

7580
// ServerConfigCallback is a hook for creating custom default server configs
7681
type ServerConfigCallback func(ctx Context) *gossh.ServerConfig

streamlocal.go

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *g
110110
return false, nil
111111
}
112112

113-
if srv.ReverseUnixForwardingCallback == nil || !srv.ReverseUnixForwardingCallback(ctx, reqPayload.SocketPath) {
113+
if srv.ReverseUnixForwardingCallback == nil {
114114
return false, []byte("unix forwarding is disabled")
115115
}
116116

@@ -123,26 +123,11 @@ func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *g
123123
return false, nil
124124
}
125125

126-
// Create socket parent dir if not exists.
127-
parentDir := filepath.Dir(addr)
128-
err = os.MkdirAll(parentDir, 0700)
129-
if err != nil {
130-
// TODO: log mkdir failure
131-
return false, nil
132-
}
133-
134-
// Remove existing socket if it exists. We do not use os.Remove() here
135-
// so that directories are kept. Note that it's possible that we will
136-
// overwrite a regular file here. Both of these behaviors match OpenSSH,
137-
// however, which is why we unlink.
138-
err = unlink(addr)
139-
if err != nil && !errors.Is(err, fs.ErrNotExist) {
140-
// TODO: log unlink failure
141-
return false, nil
142-
}
143-
144-
ln, err := net.Listen("unix", addr)
126+
ln, err := srv.ReverseUnixForwardingCallback(ctx, addr)
145127
if err != nil {
128+
if errors.Is(err, ErrRejected) {
129+
return false, []byte("unix forwarding is disabled")
130+
}
146131
// TODO: log unix listen failure
147132
return false, nil
148133
}
@@ -227,3 +212,31 @@ func unlink(path string) error {
227212
}
228213
}
229214
}
215+
216+
// SimpleUnixReverseForwardingCallback provides a basic implementation for
217+
// ReverseUnixForwardingCallback. The parent directory will be created (with
218+
// os.MkdirAll), and existing files with the same name will be removed.
219+
func SimpleUnixReverseForwardingCallback(_ Context, socketPath string) (net.Listener, error) {
220+
// Create socket parent dir if not exists.
221+
parentDir := filepath.Dir(socketPath)
222+
err := os.MkdirAll(parentDir, 0700)
223+
if err != nil {
224+
return nil, fmt.Errorf("failed to create parent directory %q for socket %q: %w", parentDir, socketPath, err)
225+
}
226+
227+
// Remove existing socket if it exists. We do not use os.Remove() here
228+
// so that directories are kept. Note that it's possible that we will
229+
// overwrite a regular file here. Both of these behaviors match OpenSSH,
230+
// however, which is why we unlink.
231+
err = unlink(socketPath)
232+
if err != nil && !errors.Is(err, fs.ErrNotExist) {
233+
return nil, fmt.Errorf("failed to remove existing file in socket path %q: %w", socketPath, err)
234+
}
235+
236+
ln, err := net.Listen("unix", socketPath)
237+
if err != nil {
238+
return nil, fmt.Errorf("failed to listen on unix socket %q: %w", socketPath, err)
239+
}
240+
241+
return ln, err
242+
}

streamlocal_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,11 @@ func TestReverseUnixForwardingWorks(t *testing.T) {
127127

128128
_, client, cleanup := newTestSession(t, &Server{
129129
Handler: func(s Session) {},
130-
ReverseUnixForwardingCallback: func(ctx Context, socketPath string) bool {
130+
ReverseUnixForwardingCallback: func(ctx Context, socketPath string) (net.Listener, error) {
131131
if socketPath != remoteSocketPath {
132132
panic("unexpected socket path: " + socketPath)
133133
}
134-
return true
134+
return SimpleUnixReverseForwardingCallback(ctx, socketPath)
135135
},
136136
}, nil)
137137
defer cleanup()
@@ -182,12 +182,12 @@ func TestReverseUnixForwardingRespectsCallback(t *testing.T) {
182182
var called int64
183183
_, client, cleanup := newTestSession(t, &Server{
184184
Handler: func(s Session) {},
185-
ReverseUnixForwardingCallback: func(ctx Context, socketPath string) bool {
185+
ReverseUnixForwardingCallback: func(ctx Context, socketPath string) (net.Listener, error) {
186186
atomic.AddInt64(&called, 1)
187187
if socketPath != remoteSocketPath {
188188
panic("unexpected socket path: " + socketPath)
189189
}
190-
return false
190+
return nil, ErrRejected
191191
},
192192
}, nil)
193193
defer cleanup()

0 commit comments

Comments
 (0)