@@ -110,7 +110,7 @@ func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *g
110
110
return false , nil
111
111
}
112
112
113
- if srv .ReverseUnixForwardingCallback == nil || ! srv . ReverseUnixForwardingCallback ( ctx , reqPayload . SocketPath ) {
113
+ if srv .ReverseUnixForwardingCallback == nil {
114
114
return false , []byte ("unix forwarding is disabled" )
115
115
}
116
116
@@ -123,26 +123,11 @@ func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *g
123
123
return false , nil
124
124
}
125
125
126
- // Create socket parent dir if not exists.
127
- parentDir := filepath .Dir (addr )
128
- err = os .MkdirAll (parentDir , 0700 )
129
- if err != nil {
130
- // TODO: log mkdir failure
131
- return false , nil
132
- }
133
-
134
- // Remove existing socket if it exists. We do not use os.Remove() here
135
- // so that directories are kept. Note that it's possible that we will
136
- // overwrite a regular file here. Both of these behaviors match OpenSSH,
137
- // however, which is why we unlink.
138
- err = unlink (addr )
139
- if err != nil && ! errors .Is (err , fs .ErrNotExist ) {
140
- // TODO: log unlink failure
141
- return false , nil
142
- }
143
-
144
- ln , err := net .Listen ("unix" , addr )
126
+ ln , err := srv .ReverseUnixForwardingCallback (ctx , addr )
145
127
if err != nil {
128
+ if errors .Is (err , ErrRejected ) {
129
+ return false , []byte ("unix forwarding is disabled" )
130
+ }
146
131
// TODO: log unix listen failure
147
132
return false , nil
148
133
}
@@ -227,3 +212,31 @@ func unlink(path string) error {
227
212
}
228
213
}
229
214
}
215
+
216
+ // SimpleUnixReverseForwardingCallback provides a basic implementation for
217
+ // ReverseUnixForwardingCallback. The parent directory will be created (with
218
+ // os.MkdirAll), and existing files with the same name will be removed.
219
+ func SimpleUnixReverseForwardingCallback (_ Context , socketPath string ) (net.Listener , error ) {
220
+ // Create socket parent dir if not exists.
221
+ parentDir := filepath .Dir (socketPath )
222
+ err := os .MkdirAll (parentDir , 0700 )
223
+ if err != nil {
224
+ return nil , fmt .Errorf ("failed to create parent directory %q for socket %q: %w" , parentDir , socketPath , err )
225
+ }
226
+
227
+ // Remove existing socket if it exists. We do not use os.Remove() here
228
+ // so that directories are kept. Note that it's possible that we will
229
+ // overwrite a regular file here. Both of these behaviors match OpenSSH,
230
+ // however, which is why we unlink.
231
+ err = unlink (socketPath )
232
+ if err != nil && ! errors .Is (err , fs .ErrNotExist ) {
233
+ return nil , fmt .Errorf ("failed to remove existing file in socket path %q: %w" , socketPath , err )
234
+ }
235
+
236
+ ln , err := net .Listen ("unix" , socketPath )
237
+ if err != nil {
238
+ return nil , fmt .Errorf ("failed to listen on unix socket %q: %w" , socketPath , err )
239
+ }
240
+
241
+ return ln , err
242
+ }
0 commit comments