From d20da86c8ad147cb6da9029a7d881fc75b88ec72 Mon Sep 17 00:00:00 2001 From: Randy Reddig Date: Tue, 20 Jun 2023 17:34:55 -0700 Subject: [PATCH 1/6] ssh: add (*Client).DialContext method Fixes golang/go#20288. --- ssh/tcpip.go | 32 +++++++++++++++++++++++++++ ssh/tcpip_test.go | 40 ++++++++++++++++++++++++++++++++++ ssh/test/dial_unix_test.go | 44 +++++++++++++++++++++++++++----------- 3 files changed, 104 insertions(+), 12 deletions(-) diff --git a/ssh/tcpip.go b/ssh/tcpip.go index 80d35f5ec1..fb144c0367 100644 --- a/ssh/tcpip.go +++ b/ssh/tcpip.go @@ -5,6 +5,7 @@ package ssh import ( + "context" "errors" "fmt" "io" @@ -332,6 +333,37 @@ func (l *tcpListener) Addr() net.Addr { return l.laddr } +// DialContext initiates a connection to the addr from the remote host. +// If the supplied context is cancelled before the connection can be opened, +// ctx.Err() will be returned. +// The resulting connection has a zero LocalAddr() and RemoteAddr(). +func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + type connErr struct { + conn net.Conn + err error + } + ch := make(chan connErr) + go func() { + conn, err := c.Dial(n, addr) + select { + case ch <- connErr{conn, err}: + case <-ctx.Done(): + if conn != nil { + conn.Close() + } + } + }() + select { + case res := <-ch: + return res.conn, res.err + case <-ctx.Done(): + return nil, ctx.Err() + } +} + // Dial initiates a connection to the addr from the remote host. // The resulting connection has a zero LocalAddr() and RemoteAddr(). func (c *Client) Dial(n, addr string) (net.Conn, error) { diff --git a/ssh/tcpip_test.go b/ssh/tcpip_test.go index f1265cb496..8fafeb5e26 100644 --- a/ssh/tcpip_test.go +++ b/ssh/tcpip_test.go @@ -5,7 +5,10 @@ package ssh import ( + "context" + "net" "testing" + "time" ) func TestAutoPortListenBroken(t *testing.T) { @@ -18,3 +21,40 @@ func TestAutoPortListenBroken(t *testing.T) { t.Errorf("version %q marked as broken", works) } } + +func TestClientImplementsDialContext(t *testing.T) { + type ContextDialer interface { + DialContext(context.Context, string, string) (net.Conn, error) + } + var _ ContextDialer = &Client{} +} + +func TestClientDialContextWithCancel(t *testing.T) { + c := &Client{} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := c.DialContext(ctx, "tcp", "localhost:1000") + if err != context.Canceled { + t.Errorf("DialContext: got nil error, expected %v", context.Canceled) + } +} + +func TestClientDialContextWithDeadline(t *testing.T) { + c := &Client{} + ctx, cancel := context.WithDeadline(context.Background(), time.Now()) + defer cancel() + _, err := c.DialContext(ctx, "tcp", "localhost:1000") + if err != context.DeadlineExceeded { + t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded) + } +} + +func TestClientDialContextWithTimeout(t *testing.T) { + c := &Client{} + ctx, cancel := context.WithTimeout(context.Background(), 0) + defer cancel() + _, err := c.DialContext(ctx, "tcp", "localhost:1000") + if err != context.DeadlineExceeded { + t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded) + } +} diff --git a/ssh/test/dial_unix_test.go b/ssh/test/dial_unix_test.go index 4a7ec31737..6f3f23351d 100644 --- a/ssh/test/dial_unix_test.go +++ b/ssh/test/dial_unix_test.go @@ -10,6 +10,7 @@ package test // direct-tcpip and direct-streamlocal functional tests import ( + "context" "fmt" "io" "net" @@ -47,19 +48,38 @@ func testDial(t *testing.T, n, listenAddr string, x dialTester) { } }() - conn, err := sshConn.Dial(n, l.Addr().String()) - if err != nil { - t.Fatalf("Dial: %v", err) - } - x.TestClientConn(t, conn) - defer conn.Close() - b, err := io.ReadAll(conn) - if err != nil { - t.Fatalf("ReadAll: %v", err) + { + conn, err := sshConn.Dial(n, l.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + x.TestClientConn(t, conn) + defer conn.Close() + b, err := io.ReadAll(conn) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + t.Logf("got %q", string(b)) + if string(b) != testData { + t.Fatalf("expected %q, got %q", testData, string(b)) + } } - t.Logf("got %q", string(b)) - if string(b) != testData { - t.Fatalf("expected %q, got %q", testData, string(b)) + + { + conn, err := sshConn.DialContext(context.Background(), n, l.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + x.TestClientConn(t, conn) + defer conn.Close() + b, err := io.ReadAll(conn) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + t.Logf("got %q", string(b)) + if string(b) != testData { + t.Fatalf("expected %q, got %q", testData, string(b)) + } } } From 9207fc3be438fa29bd771bb15fc1cd0ce20e05c3 Mon Sep 17 00:00:00 2001 From: Randy Reddig Date: Wed, 21 Jun 2023 08:41:38 -0700 Subject: [PATCH 2/6] ssh: CL feedback --- ssh/tcpip.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ssh/tcpip.go b/ssh/tcpip.go index fb144c0367..ef5059a11d 100644 --- a/ssh/tcpip.go +++ b/ssh/tcpip.go @@ -334,9 +334,12 @@ func (l *tcpListener) Addr() net.Addr { } // DialContext initiates a connection to the addr from the remote host. -// If the supplied context is cancelled before the connection can be opened, -// ctx.Err() will be returned. -// The resulting connection has a zero LocalAddr() and RemoteAddr(). +// +// The provided Context must be non-nil. If the context expires before the +// connection is complete, an error is returned. Once successfully connected, +// any expiration of the context will not affect the connection. +// +// See func Dial for additional information. func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) { if err := ctx.Err(); err != nil { return nil, err From 58d81774502604f7274f6b05edf2f8782291da86 Mon Sep 17 00:00:00 2001 From: Randy Reddig Date: Wed, 21 Jun 2023 08:51:59 -0700 Subject: [PATCH 3/6] ssh: CL feedback: show that net.Dialer also implements ContextDialer --- ssh/tcpip_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ssh/tcpip_test.go b/ssh/tcpip_test.go index 8fafeb5e26..9cc73e2f92 100644 --- a/ssh/tcpip_test.go +++ b/ssh/tcpip_test.go @@ -26,6 +26,9 @@ func TestClientImplementsDialContext(t *testing.T) { type ContextDialer interface { DialContext(context.Context, string, string) (net.Conn, error) } + // Belt and suspenders assertion, since package net does not + // declare a ContextDialer type. + var _ ContextDialer = &net.Dialer{} var _ ContextDialer = &Client{} } From d051e978ef5b2a646cf45134661a4db7c6ee416c Mon Sep 17 00:00:00 2001 From: Randy Reddig Date: Wed, 21 Jun 2023 09:36:54 -0700 Subject: [PATCH 4/6] ssh: remove redundant test --- ssh/tcpip_test.go | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/ssh/tcpip_test.go b/ssh/tcpip_test.go index 9cc73e2f92..4d85114727 100644 --- a/ssh/tcpip_test.go +++ b/ssh/tcpip_test.go @@ -51,13 +51,3 @@ func TestClientDialContextWithDeadline(t *testing.T) { t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded) } } - -func TestClientDialContextWithTimeout(t *testing.T) { - c := &Client{} - ctx, cancel := context.WithTimeout(context.Background(), 0) - defer cancel() - _, err := c.DialContext(ctx, "tcp", "localhost:1000") - if err != context.DeadlineExceeded { - t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded) - } -} From ce9c1793c51f61512a63715c2d880617d3ee45a7 Mon Sep 17 00:00:00 2001 From: Randy Reddig Date: Thu, 22 Jun 2023 08:44:48 -0700 Subject: [PATCH 5/6] ssh: CL feedback on test --- ssh/test/dial_unix_test.go | 47 +++++++++++++------------------------- 1 file changed, 16 insertions(+), 31 deletions(-) diff --git a/ssh/test/dial_unix_test.go b/ssh/test/dial_unix_test.go index 6f3f23351d..2a39fc0692 100644 --- a/ssh/test/dial_unix_test.go +++ b/ssh/test/dial_unix_test.go @@ -48,38 +48,23 @@ func testDial(t *testing.T, n, listenAddr string, x dialTester) { } }() - { - conn, err := sshConn.Dial(n, l.Addr().String()) - if err != nil { - t.Fatalf("Dial: %v", err) - } - x.TestClientConn(t, conn) - defer conn.Close() - b, err := io.ReadAll(conn) - if err != nil { - t.Fatalf("ReadAll: %v", err) - } - t.Logf("got %q", string(b)) - if string(b) != testData { - t.Fatalf("expected %q, got %q", testData, string(b)) - } + ctx, cancel := context.WithCancel(context.Background()) + conn, err := sshConn.DialContext(ctx, n, l.Addr().String()) + // Canceling the context after dial should have no effect + // on the opened connection. + cancel() + if err != nil { + t.Fatalf("Dial: %v", err) } - - { - conn, err := sshConn.DialContext(context.Background(), n, l.Addr().String()) - if err != nil { - t.Fatalf("Dial: %v", err) - } - x.TestClientConn(t, conn) - defer conn.Close() - b, err := io.ReadAll(conn) - if err != nil { - t.Fatalf("ReadAll: %v", err) - } - t.Logf("got %q", string(b)) - if string(b) != testData { - t.Fatalf("expected %q, got %q", testData, string(b)) - } + x.TestClientConn(t, conn) + defer conn.Close() + b, err := io.ReadAll(conn) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + t.Logf("got %q", string(b)) + if string(b) != testData { + t.Fatalf("expected %q, got %q", testData, string(b)) } } From 3176984a71a9a1422702e3a071340ecfff71ff62 Mon Sep 17 00:00:00 2001 From: Randy Reddig Date: Thu, 22 Jun 2023 09:05:25 -0700 Subject: [PATCH 6/6] empty commit