Skip to content

Commit 00e25e8

Browse files
feat: add Unix forwarding server implementations
Adds optional (disabled by default) implementations of local->remote and remote->local Unix forwarding through OpenSSH's protocol extensions: - [email protected] - [email protected] - [email protected] - [email protected] Adds tests for Unix forwarding, reverse Unix forwarding and reverse TCP forwarding. Co-authored-by: Samuel Corsi-House <[email protected]>
1 parent adec695 commit 00e25e8

9 files changed

+642
-30
lines changed

options_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func TestPasswordAuth(t *testing.T) {
4949

5050
func TestPasswordAuthBadPass(t *testing.T) {
5151
t.Parallel()
52-
l := newLocalListener()
52+
l := newLocalTCPListener()
5353
srv := &Server{Handler: func(s Session) {}}
5454
srv.SetOption(PasswordAuth(func(ctx Context, password string) bool {
5555
return false

server.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ type Server struct {
4747
ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling
4848
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
4949
ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil
50+
LocalUnixForwardingCallback LocalUnixForwardingCallback // callback for allowing local unix forwarding ([email protected]), denies all if nil
51+
ReverseUnixForwardingCallback ReverseUnixForwardingCallback // callback for allowing reverse unix forwarding ([email protected]), denies all if nil
5052
ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options
5153
SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions
5254

server_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func TestAddHostKey(t *testing.T) {
2929
}
3030

3131
func TestServerShutdown(t *testing.T) {
32-
l := newLocalListener()
32+
l := newLocalTCPListener()
3333
testBytes := []byte("Hello world\n")
3434
s := &Server{
3535
Handler: func(s Session) {
@@ -80,7 +80,7 @@ func TestServerShutdown(t *testing.T) {
8080
}
8181

8282
func TestServerClose(t *testing.T) {
83-
l := newLocalListener()
83+
l := newLocalTCPListener()
8484
s := &Server{
8585
Handler: func(s Session) {
8686
time.Sleep(5 * time.Second)

session_test.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,25 @@ func (srv *Server) serveOnce(l net.Listener) error {
2020
return e
2121
}
2222
srv.ChannelHandlers = map[string]ChannelHandler{
23-
"session": DefaultSessionHandler,
24-
"direct-tcpip": DirectTCPIPHandler,
23+
"session": DefaultSessionHandler,
24+
"direct-tcpip": DirectTCPIPHandler,
25+
"[email protected]": DirectStreamLocalHandler,
2526
}
27+
28+
forwardedTCPHandler := &ForwardedTCPHandler{}
29+
forwardedUnixHandler := &ForwardedUnixHandler{}
30+
srv.RequestHandlers = map[string]RequestHandler{
31+
"tcpip-forward": forwardedTCPHandler.HandleSSHRequest,
32+
"cancel-tcpip-forward": forwardedTCPHandler.HandleSSHRequest,
33+
"[email protected]": forwardedUnixHandler.HandleSSHRequest,
34+
"[email protected]": forwardedUnixHandler.HandleSSHRequest,
35+
}
36+
2637
srv.HandleConn(conn)
2738
return nil
2839
}
2940

30-
func newLocalListener() net.Listener {
41+
func newLocalTCPListener() net.Listener {
3142
l, err := net.Listen("tcp", "127.0.0.1:0")
3243
if err != nil {
3344
if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
@@ -64,7 +75,7 @@ func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*g
6475
}
6576

6677
func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) {
67-
l := newLocalListener()
78+
l := newLocalTCPListener()
6879
go srv.serveOnce(l)
6980
return newClientSession(t, l.Addr().String(), cfg)
7081
}

ssh.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package ssh
22

33
import (
44
"crypto/subtle"
5+
"errors"
56
"net"
67

78
gossh "golang.org/x/crypto/ssh"
@@ -29,6 +30,9 @@ const (
2930
// DefaultHandler is the default Handler used by Serve.
3031
var DefaultHandler Handler
3132

33+
// ErrReject is returned by some callbacks to reject a request.
34+
var ErrRejected = errors.New("ssh: rejected")
35+
3236
// Option is a functional option handler for Server.
3337
type Option func(*Server) error
3438

@@ -64,6 +68,22 @@ type LocalPortForwardingCallback func(ctx Context, destinationHost string, desti
6468
// ReversePortForwardingCallback is a hook for allowing reverse port forwarding
6569
type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool
6670

71+
// LocalUnixForwardingCallback is a hook for allowing unix forwarding
72+
// ([email protected]). Returning ErrRejected will reject the
73+
// request. The returned net.Conn will be closed by the server when no longer
74+
// needed.
75+
//
76+
// Use SimpleUnixLocalForwardingCallback for a basic implementation.
77+
type LocalUnixForwardingCallback func(ctx Context, socketPath string) (net.Conn, error)
78+
79+
// ReverseUnixForwardingCallback is a hook for allowing reverse unix forwarding
80+
// ([email protected]). Returning ErrRejected will reject the
81+
// request. The returned net.Listener will be closed by the server when no
82+
// longer needed.
83+
//
84+
// Use SimpleUnixReverseForwardingCallback for a basic implementation.
85+
type ReverseUnixForwardingCallback func(ctx Context, socketPath string) (net.Listener, error)
86+
6787
// ServerConfigCallback is a hook for creating custom default server configs
6888
type ServerConfigCallback func(ctx Context) *gossh.ServerConfig
6989

streamlocal.go

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
package ssh
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"io/fs"
8+
"net"
9+
"os"
10+
"path/filepath"
11+
"sync"
12+
"syscall"
13+
14+
gossh "golang.org/x/crypto/ssh"
15+
)
16+
17+
const (
18+
forwardedUnixChannelType = "[email protected]"
19+
)
20+
21+
// directStreamLocalChannelData data struct as specified in OpenSSH's protocol
22+
// extensions document, Section 2.4.
23+
// https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL?annotate=HEAD
24+
type directStreamLocalChannelData struct {
25+
SocketPath string
26+
27+
Reserved1 string
28+
Reserved2 uint32
29+
}
30+
31+
// DirectStreamLocalHandler provides Unix forwarding from client -> server. It
32+
// can be enabled by adding it to the server's ChannelHandlers under
33+
34+
//
35+
// Unix socket support on Windows is not widely available, so this handler may
36+
// not work on all Windows installations and is not tested on Windows.
37+
func DirectStreamLocalHandler(srv *Server, _ *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
38+
var d directStreamLocalChannelData
39+
err := gossh.Unmarshal(newChan.ExtraData(), &d)
40+
if err != nil {
41+
_ = newChan.Reject(gossh.ConnectionFailed, "error parsing direct-streamlocal data: "+err.Error())
42+
return
43+
}
44+
45+
if srv.LocalUnixForwardingCallback == nil {
46+
_ = newChan.Reject(gossh.Prohibited, "unix forwarding is disabled")
47+
return
48+
}
49+
dconn, err := srv.LocalUnixForwardingCallback(ctx, d.SocketPath)
50+
if err != nil {
51+
if errors.Is(err, ErrRejected) {
52+
_ = newChan.Reject(gossh.Prohibited, "unix forwarding is disabled")
53+
return
54+
}
55+
_ = newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("dial unix socket %q: %+v", d.SocketPath, err.Error()))
56+
return
57+
}
58+
59+
ch, reqs, err := newChan.Accept()
60+
if err != nil {
61+
_ = dconn.Close()
62+
return
63+
}
64+
go gossh.DiscardRequests(reqs)
65+
66+
bicopy(ctx, ch, dconn)
67+
}
68+
69+
// remoteUnixForwardRequest describes the extra data sent in a
70+
// [email protected] containing the socket path to bind to.
71+
type remoteUnixForwardRequest struct {
72+
SocketPath string
73+
}
74+
75+
// remoteUnixForwardChannelData describes the data sent as the payload in the new
76+
// channel request when a Unix connection is accepted by the listener.
77+
type remoteUnixForwardChannelData struct {
78+
SocketPath string
79+
Reserved uint32
80+
}
81+
82+
// ForwardedUnixHandler can be enabled by creating a ForwardedUnixHandler and
83+
// adding the HandleSSHRequest callback to the server's RequestHandlers under
84+
85+
86+
//
87+
// Unix socket support on Windows is not widely available, so this handler may
88+
// not work on all Windows installations and is not tested on Windows.
89+
type ForwardedUnixHandler struct {
90+
sync.Mutex
91+
forwards map[string]net.Listener
92+
}
93+
94+
func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) {
95+
h.Lock()
96+
if h.forwards == nil {
97+
h.forwards = make(map[string]net.Listener)
98+
}
99+
h.Unlock()
100+
conn, ok := ctx.Value(ContextKeyConn).(*gossh.ServerConn)
101+
if !ok {
102+
// TODO: log cast failure
103+
return false, nil
104+
}
105+
106+
switch req.Type {
107+
108+
var reqPayload remoteUnixForwardRequest
109+
err := gossh.Unmarshal(req.Payload, &reqPayload)
110+
if err != nil {
111+
// TODO: log parse failure
112+
return false, nil
113+
}
114+
115+
if srv.ReverseUnixForwardingCallback == nil {
116+
return false, []byte("unix forwarding is disabled")
117+
}
118+
119+
addr := reqPayload.SocketPath
120+
h.Lock()
121+
_, ok := h.forwards[addr]
122+
h.Unlock()
123+
if ok {
124+
// TODO: log failure
125+
return false, nil
126+
}
127+
128+
ln, err := srv.ReverseUnixForwardingCallback(ctx, addr)
129+
if err != nil {
130+
if errors.Is(err, ErrRejected) {
131+
return false, []byte("unix forwarding is disabled")
132+
}
133+
// TODO: log unix listen failure
134+
return false, nil
135+
}
136+
137+
// The listener needs to successfully start before it can be added to
138+
// the map, so we don't have to worry about checking for an existing
139+
// listener as you can't listen on the same socket twice.
140+
//
141+
// This is also what the TCP version of this code does.
142+
h.Lock()
143+
h.forwards[addr] = ln
144+
h.Unlock()
145+
146+
ctx, cancel := context.WithCancel(ctx)
147+
go func() {
148+
<-ctx.Done()
149+
_ = ln.Close()
150+
}()
151+
go func() {
152+
defer cancel()
153+
154+
for {
155+
c, err := ln.Accept()
156+
if err != nil {
157+
// closed below
158+
break
159+
}
160+
payload := gossh.Marshal(&remoteUnixForwardChannelData{
161+
SocketPath: addr,
162+
})
163+
164+
go func() {
165+
ch, reqs, err := conn.OpenChannel(forwardedUnixChannelType, payload)
166+
if err != nil {
167+
_ = c.Close()
168+
return
169+
}
170+
go gossh.DiscardRequests(reqs)
171+
bicopy(ctx, ch, c)
172+
}()
173+
}
174+
175+
h.Lock()
176+
ln2, ok := h.forwards[addr]
177+
if ok && ln2 == ln {
178+
delete(h.forwards, addr)
179+
}
180+
h.Unlock()
181+
_ = ln.Close()
182+
}()
183+
184+
return true, nil
185+
186+
187+
var reqPayload remoteUnixForwardRequest
188+
err := gossh.Unmarshal(req.Payload, &reqPayload)
189+
if err != nil {
190+
// TODO: log parse failure
191+
return false, nil
192+
}
193+
h.Lock()
194+
ln, ok := h.forwards[reqPayload.SocketPath]
195+
h.Unlock()
196+
if ok {
197+
_ = ln.Close()
198+
}
199+
return true, nil
200+
201+
default:
202+
return false, nil
203+
}
204+
}
205+
206+
// unlink removes files and unlike os.Remove, directories are kept.
207+
func unlink(path string) error {
208+
// Ignore EINTR like os.Remove, see ignoringEINTR in os/file_posix.go
209+
// for more details.
210+
for {
211+
err := syscall.Unlink(path)
212+
if !errors.Is(err, syscall.EINTR) {
213+
return err
214+
}
215+
}
216+
}
217+
218+
// SimpleUnixLocalForwardingCallback provides a basic implementation for
219+
// LocalUnixForwardingCallback. It will simply dial the requested socket using
220+
// a context-aware dialer.
221+
func SimpleUnixLocalForwardingCallback(ctx Context, socketPath string) (net.Conn, error) {
222+
var d net.Dialer
223+
return d.DialContext(ctx, "unix", socketPath)
224+
}
225+
226+
// SimpleUnixReverseForwardingCallback provides a basic implementation for
227+
// ReverseUnixForwardingCallback. The parent directory will be created (with
228+
// os.MkdirAll), and existing files with the same name will be removed.
229+
func SimpleUnixReverseForwardingCallback(_ Context, socketPath string) (net.Listener, error) {
230+
// Create socket parent dir if not exists.
231+
parentDir := filepath.Dir(socketPath)
232+
err := os.MkdirAll(parentDir, 0700)
233+
if err != nil {
234+
return nil, fmt.Errorf("failed to create parent directory %q for socket %q: %w", parentDir, socketPath, err)
235+
}
236+
237+
// Remove existing socket if it exists. We do not use os.Remove() here
238+
// so that directories are kept. Note that it's possible that we will
239+
// overwrite a regular file here. Both of these behaviors match OpenSSH,
240+
// however, which is why we unlink.
241+
err = unlink(socketPath)
242+
if err != nil && !errors.Is(err, fs.ErrNotExist) {
243+
return nil, fmt.Errorf("failed to remove existing file in socket path %q: %w", socketPath, err)
244+
}
245+
246+
ln, err := net.Listen("unix", socketPath)
247+
if err != nil {
248+
return nil, fmt.Errorf("failed to listen on unix socket %q: %w", socketPath, err)
249+
}
250+
251+
return ln, err
252+
}

0 commit comments

Comments
 (0)