From b214d590714d1c6eb1a5fd7f2cb2fdba67a018d1 Mon Sep 17 00:00:00 2001 From: Antoine Tollenaere Date: Mon, 16 Oct 2023 15:12:21 +0200 Subject: [PATCH 01/17] xds/internal/xdsclient: A65 - mTLS Credentials Implement A65: mTLS Credentials in xDS Bootstrap File described in https://github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md. --- xds/internal/xdsclient/bootstrap/bootstrap.go | 15 ++ .../xdsclient/bootstrap/bootstrap_test.go | 23 ++ xds/internal/xdsclient/creds/tls.go | 177 +++++++++++++ xds/internal/xdsclient/creds/tls_test.go | 233 ++++++++++++++++++ 4 files changed, 448 insertions(+) create mode 100644 xds/internal/xdsclient/creds/tls.go create mode 100644 xds/internal/xdsclient/creds/tls_test.go diff --git a/xds/internal/xdsclient/bootstrap/bootstrap.go b/xds/internal/xdsclient/bootstrap/bootstrap.go index 57fcb087b28b..027dfab74633 100644 --- a/xds/internal/xdsclient/bootstrap/bootstrap.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap.go @@ -28,6 +28,8 @@ import ( "os" "strings" + "google.golang.org/grpc/xds/internal/xdsclient/creds" + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" "github.com/golang/protobuf/jsonpb" "google.golang.org/grpc" @@ -60,6 +62,7 @@ const ( func init() { bootstrap.RegisterCredentials(&insecureCredsBuilder{}) bootstrap.RegisterCredentials(&googleDefaultCredsBuilder{}) + bootstrap.RegisterCredentials(&tlsCredsBuilder{}) } // For overriding in unit tests. @@ -77,6 +80,18 @@ func (i *insecureCredsBuilder) Name() string { return "insecure" } +// tlsCredsBuilder implements the `Credentials` interface defined in +// package `xds/bootstrap` and encapsulates a TLS credential. +type tlsCredsBuilder struct{} + +func (t *tlsCredsBuilder) Build(config json.RawMessage) (credentials.Bundle, error) { + return creds.NewTLS(config) +} + +func (t *tlsCredsBuilder) Name() string { + return "tls" +} + // googleDefaultCredsBuilder implements the `Credentials` interface defined in // package `xds/boostrap` and encapsulates a Google Default credential. type googleDefaultCredsBuilder struct{} diff --git a/xds/internal/xdsclient/bootstrap/bootstrap_test.go b/xds/internal/xdsclient/bootstrap/bootstrap_test.go index 84075743a8fe..2918dfe5d880 100644 --- a/xds/internal/xdsclient/bootstrap/bootstrap_test.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap_test.go @@ -1015,6 +1015,10 @@ func TestDefaultBundles(t *testing.T) { if c := bootstrap.GetCredentials("insecure"); c == nil { t.Errorf(`bootstrap.GetCredentials("insecure") credential is nil, want non-nil`) } + + if c := bootstrap.GetCredentials("tls"); c == nil { + t.Errorf(`bootstrap.GetCredentials("tls") credential is nil, want non-nil`) + } } func TestCredsBuilders(t *testing.T) { @@ -1034,4 +1038,23 @@ func TestCredsBuilders(t *testing.T) { if got, want := i.Name(), "insecure"; got != want { t.Errorf("insecureCredsBuilder.Name = %v, want %v", got, want) } + + tcb := &tlsCredsBuilder{} + if _, err := tcb.Build(nil); err == nil { + t.Errorf("tlsCredsBuilder.Build succeeded, want failure") + } + if got, want := tcb.Name(), "tls"; got != want { + t.Errorf("tlsCredsBuilder.Name = %v, want %v", got, want) + } +} + +func TestTlsCredsBuilder(t *testing.T) { + tls := &tlsCredsBuilder{} + if _, err := tls.Build(json.RawMessage(`{}`)); err != nil { + t.Errorf("unexpected error with empty config: %s", err) + } + if _, err := tls.Build(json.RawMessage(`{"ca_certificate_file":"/ca_certificates.pem","refresh_interval": "asdf"}`)); err == nil { + t.Errorf("expected an error with invalid refresh interval") + } + // more tests for config validity are defined in creds subpackage. } diff --git a/xds/internal/xdsclient/creds/tls.go b/xds/internal/xdsclient/creds/tls.go new file mode 100644 index 000000000000..9ef3d1fcc297 --- /dev/null +++ b/xds/internal/xdsclient/creds/tls.go @@ -0,0 +1,177 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package creds implements gRFC A65: mTLS Credentials in xDS Bootstrap File. +package creds + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "net" + "sync" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/tls/certprovider" + _ "google.golang.org/grpc/credentials/tls/certprovider/pemfile" // for file_watcher provider +) + +type tlsBundle struct { + jd json.RawMessage + transportCredentials credentials.TransportCredentials +} + +// NewTLS returns a credentials.Bundle which delegates certificate loading to +// a file_watcher provider for mTLS transport security. See gRFC A65. +func NewTLS(jd json.RawMessage) (credentials.Bundle, error) { + cfg := &struct { + CertificateFile string `json:"certificate_file"` + CACertificateFile string `json:"ca_certificate_file"` + PrivateKeyFile string `json:"private_key_file"` + }{} + + tlsConfig := tls.Config{} + if err := json.Unmarshal(jd, cfg); err != nil { + return nil, err + } + + // We cannot simply always use a file_watcher provider because it behaves + // slightly differently from the xDS TLS config. Quoting A65: + // + // > The only difference between the file-watcher certificate provider + // > config and this one is that in the file-watcher certificate provider, + // > at least one of the "certificate_file" or "ca_certificate_file" fields + // > must be specified, whereas in this configuration, it is acceptable to + // > specify neither one. + // + // We only use a file_watcher provider if either one of them or both are + // specified. + if cfg.CACertificateFile != "" || cfg.CertificateFile != "" || cfg.PrivateKeyFile != "" { + // file_watcher currently ignores BuildOptions, but we set them for good + // measure. + opts := certprovider.BuildOptions{} + if cfg.CACertificateFile != "" { + opts.WantRoot = true + } + if cfg.CertificateFile != "" { + opts.WantIdentity = true + } + provider, err := certprovider.GetProvider("file_watcher", jd, opts) + if err != nil { + // GetProvider fails if jd is invalid, e.g. if only one of private + // key and certificate is specified. + return nil, fmt.Errorf("failed to get TLS provider: %w", err) + } + if cfg.CertificateFile != "" { + tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + // Client cert reloading for mTLS. + km, err := provider.KeyMaterial(context.Background()) + if err != nil { + return nil, err + } + if len(km.Certs) != 1 { + return nil, fmt.Errorf("there should be exactly exactly one certificate") + } + return &km.Certs[0], nil + } + if cfg.CACertificateFile == "" { + // no need for a callback to load the CA. Use the normal mTLS + // transport credentials. + return &tlsBundle{ + jd: jd, + transportCredentials: credentials.NewTLS(&tlsConfig), + }, nil + } + } + return &tlsBundle{ + jd: jd, + transportCredentials: &caReloadingClientTLSCreds{ + baseConfig: &tlsConfig, + provider: provider, + }, + }, nil + } + + // None of certificate_file and ca_certificate_file are set. + // Use the system-wide root certs. + return &tlsBundle{ + jd: jd, + transportCredentials: credentials.NewTLS(&tlsConfig), + }, nil +} + +func (t *tlsBundle) TransportCredentials() credentials.TransportCredentials { + return t.transportCredentials +} + +func (t *tlsBundle) PerRPCCredentials() credentials.PerRPCCredentials { + // No per-RPC credentials in A65. + return nil +} + +func (t *tlsBundle) NewWithMode(_ string) (credentials.Bundle, error) { + return NewTLS(t.jd) +} + +// caReloadingClientTLSCreds is a client mTLS credentials.TransportCredentials +// that attempts to reload the server root certificate from its provider on +// every client handshake. This is needed because Go's tls.Config does not +// support reloading the root CAs. +type caReloadingClientTLSCreds struct { + mu sync.Mutex + provider certprovider.Provider + baseConfig *tls.Config +} + +func (c *caReloadingClientTLSCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + km, err := c.provider.KeyMaterial(ctx) + if err != nil { + return nil, nil, err + } + c.mu.Lock() + if !km.Roots.Equal(c.baseConfig.RootCAs) { + // provider returned a different root CA. Update it. + c.baseConfig.RootCAs = km.Roots + } + c.mu.Unlock() + return credentials.NewTLS(c.baseConfig).ClientHandshake(ctx, authority, rawConn) +} + +func (c *caReloadingClientTLSCreds) Info() credentials.ProtocolInfo { + c.mu.Lock() + defer c.mu.Unlock() + return credentials.NewTLS(c.baseConfig).Info() +} + +func (c *caReloadingClientTLSCreds) Clone() credentials.TransportCredentials { + c.mu.Lock() + defer c.mu.Unlock() + return &caReloadingClientTLSCreds{ + provider: c.provider, + baseConfig: c.baseConfig.Clone(), + } +} + +func (c *caReloadingClientTLSCreds) OverrideServerName(_ string) error { + panic("cannot override server name for private xds tls credentials") +} + +func (c *caReloadingClientTLSCreds) ServerHandshake(_ net.Conn) (net.Conn, credentials.AuthInfo, error) { + panic("server handshake for xds tls credentials, which are client only") +} diff --git a/xds/internal/xdsclient/creds/tls_test.go b/xds/internal/xdsclient/creds/tls_test.go new file mode 100644 index 000000000000..2a7734ec64eb --- /dev/null +++ b/xds/internal/xdsclient/creds/tls_test.go @@ -0,0 +1,233 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package creds + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "net" + "os" + "strings" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/status" + "google.golang.org/grpc/testdata" + + testgrpc "google.golang.org/grpc/interop/grpc_testing" + testpb "google.golang.org/grpc/interop/grpc_testing" +) + +func TestValidTlsBuilder(t *testing.T) { + tests := []string{ + `{}`, + `{"ca_certificate_file": "foo"}`, + `{"certificate_file":"bar","private_key_file":"baz"}`, + `{"ca_certificate_file":"foo","certificate_file":"bar","private_key_file":"baz"}`, + `{"refresh_interval": "1s"}`, + `{"refresh_interval": "1s","ca_certificate_file": "foo"}`, + `{"refresh_interval": "1s","certificate_file":"bar","private_key_file":"baz"}`, + `{"refresh_interval": "1s","ca_certificate_file":"foo","certificate_file":"bar","private_key_file":"baz"}`, + } + + for _, jd := range tests { + t.Run(jd, func(t *testing.T) { + msg := json.RawMessage(jd) + if _, err := NewTLS(msg); err != nil { + t.Errorf("expected no error but got: %s", err) + } + }) + } +} + +func TestInvalidTlsBuilder(t *testing.T) { + tests := []struct { + jd, err string + }{ + {`{"ca_certificate_file": 1}`, "json: cannot unmarshal number into Go struct field .ca_certificate_file of type string"}, + {`{"certificate_file":"bar"}`, "failed to get TLS provider: pemfile: private key file and identity cert file should be both specified or not specified"}, + } + + for _, test := range tests { + t.Run(test.jd, func(t *testing.T) { + msg := json.RawMessage(test.jd) + if _, err := NewTLS(msg); err.Error() != test.err { + t.Errorf("expected: %s, got: %s", test.err, err) + } + }) + } +} + +type testServer struct { + testgrpc.UnimplementedTestServiceServer +} + +func (t testServer) EmptyCall(_ context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil +} + +func TestCaReloading(t *testing.T) { + serverCa, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem")) + if err != nil { + t.Fatalf("failed to read test CA cert: %s", err) + } + + // Write CA to a temporary file so that we can modify it later. + caPath := t.TempDir() + "/ca.pem" + err = os.WriteFile(caPath, serverCa, 0644) + if err != nil { + t.Fatalf("failed to write test CA cert: %v", err) + } + cfg := fmt.Sprintf(`{ + "ca_certificate_file": "%s", + "refresh_interval": ".01s" + }`, caPath) + tlsBundle, err := NewTLS([]byte(cfg)) + if err != nil { + t.Fatalf("failed to create TLS bundle: %v", err) + } + + // TLS server with a valid cert + serverCreds, err := credentials.NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) + if err != nil { + t.Fatalf("failed to generate server credentials: %v", err) + } + s := grpc.NewServer(grpc.Creds(serverCreds)) + testgrpc.RegisterTestServiceServer(s, &testServer{}) + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + go s.Serve(lis) + + conn, err := grpc.Dial( + lis.Addr().String(), + grpc.WithCredentialsBundle(tlsBundle), + grpc.WithAuthority("x.test.example.com"), + ) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + defer conn.Close() + client := testgrpc.NewTestServiceClient(conn) + _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) + if err != nil { + t.Errorf("error calling EmptyCall: %v", err) + } + // close the server so that we force a new handshake later on. + s.Stop() + + invalidCa, err := os.ReadFile(testdata.Path("ca.pem")) + if err != nil { + t.Fatalf("failed to read test CA cert: %v", err) + } + // unload root cert + err = os.WriteFile(caPath, invalidCa, 0644) + if err != nil { + t.Fatalf("failed to write test CA cert: %v", err) + } + + // Leave time for the file_watcher provider to reload the CA. + time.Sleep(100 * time.Millisecond) + + s = grpc.NewServer(grpc.Creds(serverCreds)) + defer s.Stop() + testgrpc.RegisterTestServiceServer(s, &testServer{}) + lis, err = net.Listen("tcp", lis.Addr().String()) + if err != nil { + t.Fatalf("error listening: %v", err) + } + go s.Serve(lis) + + // Client handshake should fail because the server cert is signed by an + // unknown CA. + _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) + if st, ok := status.FromError(err); !ok || st.Code() != codes.Unavailable { + t.Errorf("expected unavailable error, got %v", err) + if strings.Contains(st.Message(), "x509: certificate signed by unknown authority") { + t.Errorf("expected error to contain 'x509: certificate signed by unknown authority', got %v", st.Message()) + } + } +} + +func TestMTLS(t *testing.T) { + cfg := fmt.Sprintf(`{ + "ca_certificate_file": "%s", + "certificate_file": "%s", + "private_key_file": "%s" + }`, + testdata.Path("x509/server_ca_cert.pem"), + testdata.Path("x509/client1_cert.pem"), + testdata.Path("x509/client1_key.pem")) + tlsBundle, err := NewTLS([]byte(cfg)) + if err != nil { + t.Fatalf("failed to create TLS bundle: %v", err) + } + + // Create a TLS server with a valid cert that requires a client cert. + serverCert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) + if err != nil { + t.Fatalf("failed to load server cert: %v", err) + } + pemClientCA, err := os.ReadFile(testdata.Path("x509/client_ca_cert.pem")) + if err != nil { + t.Fatalf("failed to read test client CA cert: %v", err) + } + clientCA := x509.NewCertPool() + if !clientCA.AppendCertsFromPEM(pemClientCA) { + t.Fatal("failed to add client CA's certificate") + } + serverTLSCfg := &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: clientCA, + } + if err != nil { + t.Fatalf("failed to generate server credentials: %v", err) + } + s := grpc.NewServer(grpc.Creds(credentials.NewTLS(serverTLSCfg))) + testgrpc.RegisterTestServiceServer(s, &testServer{}) + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + go s.Serve(lis) + defer s.Stop() + + conn, err := grpc.Dial( + lis.Addr().String(), + grpc.WithCredentialsBundle(tlsBundle), + grpc.WithAuthority("x.test.example.com"), + ) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + defer conn.Close() + client := testgrpc.NewTestServiceClient(conn) + _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) + if err != nil { + t.Errorf("error calling EmptyCall: %v", err) + } +} From 114af39c7ee17efa8d3d3f6e240c724f75af2499 Mon Sep 17 00:00:00 2001 From: Antoine Tollenaere Date: Thu, 2 Nov 2023 12:26:02 +0100 Subject: [PATCH 02/17] feedback from Arvind --- xds/internal/xdsclient/bootstrap/bootstrap.go | 3 +- xds/internal/xdsclient/creds/tls.go | 48 ++++++++++--------- 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/xds/internal/xdsclient/bootstrap/bootstrap.go b/xds/internal/xdsclient/bootstrap/bootstrap.go index 027dfab74633..127987b71647 100644 --- a/xds/internal/xdsclient/bootstrap/bootstrap.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap.go @@ -28,8 +28,6 @@ import ( "os" "strings" - "google.golang.org/grpc/xds/internal/xdsclient/creds" - v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" "github.com/golang/protobuf/jsonpb" "google.golang.org/grpc" @@ -41,6 +39,7 @@ import ( "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/xds/bootstrap" + "google.golang.org/grpc/xds/internal/xdsclient/creds" ) const ( diff --git a/xds/internal/xdsclient/creds/tls.go b/xds/internal/xdsclient/creds/tls.go index 9ef3d1fcc297..70048d873ebb 100644 --- a/xds/internal/xdsclient/creds/tls.go +++ b/xds/internal/xdsclient/creds/tls.go @@ -16,7 +16,8 @@ * */ -// Package creds implements gRFC A65: mTLS Credentials in xDS Bootstrap File. +// Package creds implements mTLS Credentials in xDS Bootstrap File. +// See [gRFC A65](github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md). package creds import ( @@ -32,13 +33,18 @@ import ( _ "google.golang.org/grpc/credentials/tls/certprovider/pemfile" // for file_watcher provider ) +// tlsBundle is an implementation of credentials.Bundle which implements mTLS +// Credentials in xDS Bootstrap File. +// See [gRFC A65](github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md). type tlsBundle struct { jd json.RawMessage transportCredentials credentials.TransportCredentials } -// NewTLS returns a credentials.Bundle which delegates certificate loading to -// a file_watcher provider for mTLS transport security. See gRFC A65. +// NewTLS returns a credentials.Bundle which implements mTLS Credentials in xDS +// Bootstrap File. It delegates certificate loading to a file_watcher provider +// if either client certificates or server root CA is specified. +// See [gRFC A65](github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md). func NewTLS(jd json.RawMessage) (credentials.Bundle, error) { cfg := &struct { CertificateFile string `json:"certificate_file"` @@ -63,16 +69,8 @@ func NewTLS(jd json.RawMessage) (credentials.Bundle, error) { // We only use a file_watcher provider if either one of them or both are // specified. if cfg.CACertificateFile != "" || cfg.CertificateFile != "" || cfg.PrivateKeyFile != "" { - // file_watcher currently ignores BuildOptions, but we set them for good - // measure. - opts := certprovider.BuildOptions{} - if cfg.CACertificateFile != "" { - opts.WantRoot = true - } - if cfg.CertificateFile != "" { - opts.WantIdentity = true - } - provider, err := certprovider.GetProvider("file_watcher", jd, opts) + // file_watcher currently ignores BuildOptions. + provider, err := certprovider.GetProvider("file_watcher", jd, certprovider.BuildOptions{}) if err != nil { // GetProvider fails if jd is invalid, e.g. if only one of private // key and certificate is specified. @@ -86,12 +84,15 @@ func NewTLS(jd json.RawMessage) (credentials.Bundle, error) { return nil, err } if len(km.Certs) != 1 { - return nil, fmt.Errorf("there should be exactly exactly one certificate") + // xDS bootstrap has a single private key file, so there + // must be exactly one certificate or certificate chain + // matching this private key. + return nil, fmt.Errorf("certificate_file must contains exactly one certificate or certificate chain") } return &km.Certs[0], nil } if cfg.CACertificateFile == "" { - // no need for a callback to load the CA. Use the normal mTLS + // No need for a callback to load the CA. Use the normal mTLS // transport credentials. return &tlsBundle{ jd: jd, @@ -121,7 +122,8 @@ func (t *tlsBundle) TransportCredentials() credentials.TransportCredentials { } func (t *tlsBundle) PerRPCCredentials() credentials.PerRPCCredentials { - // No per-RPC credentials in A65. + // mTLS provides transport credentials only. There are no per-RPC + // credentials. return nil } @@ -129,10 +131,10 @@ func (t *tlsBundle) NewWithMode(_ string) (credentials.Bundle, error) { return NewTLS(t.jd) } -// caReloadingClientTLSCreds is a client mTLS credentials.TransportCredentials -// that attempts to reload the server root certificate from its provider on -// every client handshake. This is needed because Go's tls.Config does not -// support reloading the root CAs. +// caReloadingClientTLSCreds is credentials.TransportCredentials for client +// side mTLS that attempts to reload the server root certificate from its +// provider on every client handshake. This is needed because Go's tls.Config +// does not support reloading the root CAs. type caReloadingClientTLSCreds struct { mu sync.Mutex provider certprovider.Provider @@ -145,11 +147,11 @@ func (c *caReloadingClientTLSCreds) ClientHandshake(ctx context.Context, authori return nil, nil, err } c.mu.Lock() + defer c.mu.Unlock() if !km.Roots.Equal(c.baseConfig.RootCAs) { - // provider returned a different root CA. Update it. + // Provider returned a different root CA. Update it. c.baseConfig.RootCAs = km.Roots } - c.mu.Unlock() return credentials.NewTLS(c.baseConfig).ClientHandshake(ctx, authority, rawConn) } @@ -173,5 +175,5 @@ func (c *caReloadingClientTLSCreds) OverrideServerName(_ string) error { } func (c *caReloadingClientTLSCreds) ServerHandshake(_ net.Conn) (net.Conn, credentials.AuthInfo, error) { - panic("server handshake for xds tls credentials, which are client only") + panic("cannot perform handshake for server. xDS TLS credentials are client only.") } From 85e4902e818529f3cf5d81079031f832b0a1ade0 Mon Sep 17 00:00:00 2001 From: Antoine Tollenaere Date: Thu, 7 Dec 2023 20:34:37 +0100 Subject: [PATCH 03/17] comments from easwars --- .../tls/certprovider/pemfile/builder.go | 6 +- xds/internal/xdsclient/bootstrap/bootstrap.go | 4 +- .../{creds/tls.go => tlscreds/bundle.go} | 131 +++++++++--------- .../tls_test.go => tlscreds/bundle_test.go} | 12 +- 4 files changed, 77 insertions(+), 76 deletions(-) rename xds/internal/xdsclient/{creds/tls.go => tlscreds/bundle.go} (53%) rename xds/internal/xdsclient/{creds/tls_test.go => tlscreds/bundle_test.go} (95%) diff --git a/credentials/tls/certprovider/pemfile/builder.go b/credentials/tls/certprovider/pemfile/builder.go index 8d8e2d4a0f5a..8c15baeb59f6 100644 --- a/credentials/tls/certprovider/pemfile/builder.go +++ b/credentials/tls/certprovider/pemfile/builder.go @@ -29,7 +29,7 @@ import ( ) const ( - pluginName = "file_watcher" + PluginName = "file_watcher" defaultRefreshInterval = 10 * time.Minute ) @@ -48,13 +48,13 @@ func (p *pluginBuilder) ParseConfig(c any) (*certprovider.BuildableConfig, error if err != nil { return nil, err } - return certprovider.NewBuildableConfig(pluginName, opts.canonical(), func(certprovider.BuildOptions) certprovider.Provider { + return certprovider.NewBuildableConfig(PluginName, opts.canonical(), func(certprovider.BuildOptions) certprovider.Provider { return newProvider(opts) }), nil } func (p *pluginBuilder) Name() string { - return pluginName + return PluginName } func pluginConfigFromJSON(jd json.RawMessage) (Options, error) { diff --git a/xds/internal/xdsclient/bootstrap/bootstrap.go b/xds/internal/xdsclient/bootstrap/bootstrap.go index 127987b71647..89b66952ef5f 100644 --- a/xds/internal/xdsclient/bootstrap/bootstrap.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap.go @@ -39,7 +39,7 @@ import ( "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/xds/bootstrap" - "google.golang.org/grpc/xds/internal/xdsclient/creds" + "google.golang.org/grpc/xds/internal/xdsclient/tlscreds" ) const ( @@ -84,7 +84,7 @@ func (i *insecureCredsBuilder) Name() string { type tlsCredsBuilder struct{} func (t *tlsCredsBuilder) Build(config json.RawMessage) (credentials.Bundle, error) { - return creds.NewTLS(config) + return tlscreds.NewBundle(config) } func (t *tlsCredsBuilder) Name() string { diff --git a/xds/internal/xdsclient/creds/tls.go b/xds/internal/xdsclient/tlscreds/bundle.go similarity index 53% rename from xds/internal/xdsclient/creds/tls.go rename to xds/internal/xdsclient/tlscreds/bundle.go index 70048d873ebb..e9ac4e90ef69 100644 --- a/xds/internal/xdsclient/creds/tls.go +++ b/xds/internal/xdsclient/tlscreds/bundle.go @@ -16,9 +16,9 @@ * */ -// Package creds implements mTLS Credentials in xDS Bootstrap File. +// Package tlscreds implements mTLS Credentials in xDS Bootstrap File. // See [gRFC A65](github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md). -package creds +package tlscreds import ( "context" @@ -30,105 +30,106 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/tls/certprovider" - _ "google.golang.org/grpc/credentials/tls/certprovider/pemfile" // for file_watcher provider + "google.golang.org/grpc/credentials/tls/certprovider/pemfile" ) -// tlsBundle is an implementation of credentials.Bundle which implements mTLS +// bundle is an implementation of credentials.Bundle which implements mTLS // Credentials in xDS Bootstrap File. // See [gRFC A65](github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md). -type tlsBundle struct { +type bundle struct { jd json.RawMessage transportCredentials credentials.TransportCredentials } -// NewTLS returns a credentials.Bundle which implements mTLS Credentials in xDS +// NewBundle returns a credentials.Bundle which implements mTLS Credentials in xDS // Bootstrap File. It delegates certificate loading to a file_watcher provider // if either client certificates or server root CA is specified. // See [gRFC A65](github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md). -func NewTLS(jd json.RawMessage) (credentials.Bundle, error) { +func NewBundle(jd json.RawMessage) (credentials.Bundle, error) { cfg := &struct { CertificateFile string `json:"certificate_file"` CACertificateFile string `json:"ca_certificate_file"` PrivateKeyFile string `json:"private_key_file"` }{} - tlsConfig := tls.Config{} if err := json.Unmarshal(jd, cfg); err != nil { return nil, err } - // We cannot simply always use a file_watcher provider because it behaves - // slightly differently from the xDS TLS config. Quoting A65: - // - // > The only difference between the file-watcher certificate provider - // > config and this one is that in the file-watcher certificate provider, - // > at least one of the "certificate_file" or "ca_certificate_file" fields - // > must be specified, whereas in this configuration, it is acceptable to - // > specify neither one. - // - // We only use a file_watcher provider if either one of them or both are - // specified. - if cfg.CACertificateFile != "" || cfg.CertificateFile != "" || cfg.PrivateKeyFile != "" { - // file_watcher currently ignores BuildOptions. - provider, err := certprovider.GetProvider("file_watcher", jd, certprovider.BuildOptions{}) - if err != nil { - // GetProvider fails if jd is invalid, e.g. if only one of private - // key and certificate is specified. - return nil, fmt.Errorf("failed to get TLS provider: %w", err) - } - if cfg.CertificateFile != "" { - tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { - // Client cert reloading for mTLS. - km, err := provider.KeyMaterial(context.Background()) - if err != nil { - return nil, err - } - if len(km.Certs) != 1 { - // xDS bootstrap has a single private key file, so there - // must be exactly one certificate or certificate chain - // matching this private key. - return nil, fmt.Errorf("certificate_file must contains exactly one certificate or certificate chain") - } - return &km.Certs[0], nil + if cfg.CACertificateFile == "" && cfg.CertificateFile == "" && cfg.PrivateKeyFile == "" { + // We do not always use a file_watcher provider because it behaves + // slightly differently from the xDS TLS config provider. Quoting A65: + // + // > The only difference between the file-watcher certificate provider + // > config and this one is that in the file-watcher certificate provider, + // > at least one of the "certificate_file" or "ca_certificate_file" fields + // > must be specified, whereas in this configuration, it is acceptable to + // > specify neither one. + // + // Here, none of certificate_file and ca_certificate_file are set. + // Use the system-wide root certs. No need to use the file_watcher + // provider. + return &bundle{ + jd: jd, + transportCredentials: credentials.NewTLS(&tls.Config{}), + }, nil + } + + tlsConfig := tls.Config{} + // The pemfile plugin (file_watcher) currently ignores BuildOptions. + provider, err := certprovider.GetProvider(pemfile.PluginName, jd, certprovider.BuildOptions{}) + if err != nil { + // GetProvider fails if jd is invalid, e.g. if only one of private + // key and certificate is specified. + return nil, err + } + if cfg.CertificateFile != "" { + tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + // Client cert reloading for mTLS. + km, err := provider.KeyMaterial(context.Background()) + if err != nil { + return nil, err } - if cfg.CACertificateFile == "" { - // No need for a callback to load the CA. Use the normal mTLS - // transport credentials. - return &tlsBundle{ - jd: jd, - transportCredentials: credentials.NewTLS(&tlsConfig), - }, nil + if len(km.Certs) != 1 { + // xDS bootstrap has a single private key file, so there + // must be exactly one certificate or certificate chain + // matching this private key. + return nil, fmt.Errorf("certificate_file must contains exactly one certificate or certificate chain") } + return &km.Certs[0], nil + } + if cfg.CACertificateFile == "" { + // No need for a callback to load the CA. Use the normal mTLS + // transport credentials. + return &bundle{ + jd: jd, + transportCredentials: credentials.NewTLS(&tlsConfig), + }, nil } - return &tlsBundle{ - jd: jd, - transportCredentials: &caReloadingClientTLSCreds{ - baseConfig: &tlsConfig, - provider: provider, - }, - }, nil } - - // None of certificate_file and ca_certificate_file are set. - // Use the system-wide root certs. - return &tlsBundle{ - jd: jd, - transportCredentials: credentials.NewTLS(&tlsConfig), + return &bundle{ + jd: jd, + transportCredentials: &caReloadingClientTLSCreds{ + baseConfig: &tlsConfig, + provider: provider, + }, }, nil } -func (t *tlsBundle) TransportCredentials() credentials.TransportCredentials { +func (t *bundle) TransportCredentials() credentials.TransportCredentials { return t.transportCredentials } -func (t *tlsBundle) PerRPCCredentials() credentials.PerRPCCredentials { +func (t *bundle) PerRPCCredentials() credentials.PerRPCCredentials { // mTLS provides transport credentials only. There are no per-RPC // credentials. return nil } -func (t *tlsBundle) NewWithMode(_ string) (credentials.Bundle, error) { - return NewTLS(t.jd) +func (t *bundle) NewWithMode(string) (credentials.Bundle, error) { + // This bundle has a single mode which only uses TLS transport credentials, + // so there is no legitimate case where callers would call NewWithMode. + return nil, fmt.Errorf("xDS TLS credentials only support one mode") } // caReloadingClientTLSCreds is credentials.TransportCredentials for client diff --git a/xds/internal/xdsclient/creds/tls_test.go b/xds/internal/xdsclient/tlscreds/bundle_test.go similarity index 95% rename from xds/internal/xdsclient/creds/tls_test.go rename to xds/internal/xdsclient/tlscreds/bundle_test.go index 2a7734ec64eb..5a9df4a2e979 100644 --- a/xds/internal/xdsclient/creds/tls_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_test.go @@ -16,7 +16,7 @@ * */ -package creds +package tlscreds import ( "context" @@ -55,7 +55,7 @@ func TestValidTlsBuilder(t *testing.T) { for _, jd := range tests { t.Run(jd, func(t *testing.T) { msg := json.RawMessage(jd) - if _, err := NewTLS(msg); err != nil { + if _, err := NewBundle(msg); err != nil { t.Errorf("expected no error but got: %s", err) } }) @@ -67,13 +67,13 @@ func TestInvalidTlsBuilder(t *testing.T) { jd, err string }{ {`{"ca_certificate_file": 1}`, "json: cannot unmarshal number into Go struct field .ca_certificate_file of type string"}, - {`{"certificate_file":"bar"}`, "failed to get TLS provider: pemfile: private key file and identity cert file should be both specified or not specified"}, + {`{"certificate_file":"bar"}`, "pemfile: private key file and identity cert file should be both specified or not specified"}, } for _, test := range tests { t.Run(test.jd, func(t *testing.T) { msg := json.RawMessage(test.jd) - if _, err := NewTLS(msg); err.Error() != test.err { + if _, err := NewBundle(msg); err.Error() != test.err { t.Errorf("expected: %s, got: %s", test.err, err) } }) @@ -104,7 +104,7 @@ func TestCaReloading(t *testing.T) { "ca_certificate_file": "%s", "refresh_interval": ".01s" }`, caPath) - tlsBundle, err := NewTLS([]byte(cfg)) + tlsBundle, err := NewBundle([]byte(cfg)) if err != nil { t.Fatalf("failed to create TLS bundle: %v", err) } @@ -181,7 +181,7 @@ func TestMTLS(t *testing.T) { testdata.Path("x509/server_ca_cert.pem"), testdata.Path("x509/client1_cert.pem"), testdata.Path("x509/client1_key.pem")) - tlsBundle, err := NewTLS([]byte(cfg)) + tlsBundle, err := NewBundle([]byte(cfg)) if err != nil { t.Fatalf("failed to create TLS bundle: %v", err) } From 1c1654ba97c178b242b559b9cc14ca7c7fda324f Mon Sep 17 00:00:00 2001 From: Antoine Tollenaere Date: Thu, 7 Dec 2023 21:37:11 +0100 Subject: [PATCH 04/17] reload both CA and certs on each handshake - add test for failing provider --- xds/internal/xdsclient/tlscreds/bundle.go | 106 +++++----------- .../xdsclient/tlscreds/bundle_test.go | 115 +++++++++++------- 2 files changed, 101 insertions(+), 120 deletions(-) diff --git a/xds/internal/xdsclient/tlscreds/bundle.go b/xds/internal/xdsclient/tlscreds/bundle.go index e9ac4e90ef69..469c7c6bb5ab 100644 --- a/xds/internal/xdsclient/tlscreds/bundle.go +++ b/xds/internal/xdsclient/tlscreds/bundle.go @@ -17,7 +17,7 @@ */ // Package tlscreds implements mTLS Credentials in xDS Bootstrap File. -// See [gRFC A65](github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md). +// See gRFC A65: github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md. package tlscreds import ( @@ -26,7 +26,6 @@ import ( "encoding/json" "fmt" "net" - "sync" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/tls/certprovider" @@ -35,16 +34,14 @@ import ( // bundle is an implementation of credentials.Bundle which implements mTLS // Credentials in xDS Bootstrap File. -// See [gRFC A65](github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md). type bundle struct { - jd json.RawMessage transportCredentials credentials.TransportCredentials } // NewBundle returns a credentials.Bundle which implements mTLS Credentials in xDS // Bootstrap File. It delegates certificate loading to a file_watcher provider // if either client certificates or server root CA is specified. -// See [gRFC A65](github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md). +// See gRFC A65: github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md func NewBundle(jd json.RawMessage) (credentials.Bundle, error) { cfg := &struct { CertificateFile string `json:"certificate_file"` @@ -57,61 +54,30 @@ func NewBundle(jd json.RawMessage) (credentials.Bundle, error) { } if cfg.CACertificateFile == "" && cfg.CertificateFile == "" && cfg.PrivateKeyFile == "" { - // We do not always use a file_watcher provider because it behaves - // slightly differently from the xDS TLS config provider. Quoting A65: + // We cannot use (and do not need) a file_watcher provider in this case, + // and can simply directly use the TLS transport credentials. + // Quoting A65: // // > The only difference between the file-watcher certificate provider - // > config and this one is that in the file-watcher certificate provider, - // > at least one of the "certificate_file" or "ca_certificate_file" fields - // > must be specified, whereas in this configuration, it is acceptable to - // > specify neither one. - // - // Here, none of certificate_file and ca_certificate_file are set. - // Use the system-wide root certs. No need to use the file_watcher - // provider. + // > config and this one is that in the file-watcher certificate + // > provider, at least one of the "certificate_file" or + // > "ca_certificate_file" fields must be specified, whereas in this + // > configuration, it is acceptable to specify neither one. return &bundle{ - jd: jd, transportCredentials: credentials.NewTLS(&tls.Config{}), }, nil } + // Otherwise we need to use a file_watcher provider to watch the CA, + // private and public keys. - tlsConfig := tls.Config{} // The pemfile plugin (file_watcher) currently ignores BuildOptions. provider, err := certprovider.GetProvider(pemfile.PluginName, jd, certprovider.BuildOptions{}) if err != nil { - // GetProvider fails if jd is invalid, e.g. if only one of private - // key and certificate is specified. return nil, err } - if cfg.CertificateFile != "" { - tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { - // Client cert reloading for mTLS. - km, err := provider.KeyMaterial(context.Background()) - if err != nil { - return nil, err - } - if len(km.Certs) != 1 { - // xDS bootstrap has a single private key file, so there - // must be exactly one certificate or certificate chain - // matching this private key. - return nil, fmt.Errorf("certificate_file must contains exactly one certificate or certificate chain") - } - return &km.Certs[0], nil - } - if cfg.CACertificateFile == "" { - // No need for a callback to load the CA. Use the normal mTLS - // transport credentials. - return &bundle{ - jd: jd, - transportCredentials: credentials.NewTLS(&tlsConfig), - }, nil - } - } return &bundle{ - jd: jd, - transportCredentials: &caReloadingClientTLSCreds{ - baseConfig: &tlsConfig, - provider: provider, + transportCredentials: &reloadingCreds{ + provider: provider, }, }, nil } @@ -132,49 +98,41 @@ func (t *bundle) NewWithMode(string) (credentials.Bundle, error) { return nil, fmt.Errorf("xDS TLS credentials only support one mode") } -// caReloadingClientTLSCreds is credentials.TransportCredentials for client -// side mTLS that attempts to reload the server root certificate from its -// provider on every client handshake. This is needed because Go's tls.Config -// does not support reloading the root CAs. -type caReloadingClientTLSCreds struct { - mu sync.Mutex - provider certprovider.Provider - baseConfig *tls.Config +// reloadingCreds is a credentials.TransportCredentials for client +// side mTLS that reloads the server root CA certificate and the client +// certificates from the provider on every client handshake. This is necessary +// because the standard TLS credentials do not support reloading CA +// certificates. +type reloadingCreds struct { + provider certprovider.Provider } -func (c *caReloadingClientTLSCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { +func (c *reloadingCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { km, err := c.provider.KeyMaterial(ctx) if err != nil { return nil, nil, err } - c.mu.Lock() - defer c.mu.Unlock() - if !km.Roots.Equal(c.baseConfig.RootCAs) { - // Provider returned a different root CA. Update it. - c.baseConfig.RootCAs = km.Roots + config := &tls.Config{ + RootCAs: km.Roots, + Certificates: km.Certs, } - return credentials.NewTLS(c.baseConfig).ClientHandshake(ctx, authority, rawConn) + return credentials.NewTLS(config).ClientHandshake(ctx, authority, rawConn) } -func (c *caReloadingClientTLSCreds) Info() credentials.ProtocolInfo { - c.mu.Lock() - defer c.mu.Unlock() - return credentials.NewTLS(c.baseConfig).Info() +func (c *reloadingCreds) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{SecurityProtocol: "tls"} } -func (c *caReloadingClientTLSCreds) Clone() credentials.TransportCredentials { - c.mu.Lock() - defer c.mu.Unlock() - return &caReloadingClientTLSCreds{ - provider: c.provider, - baseConfig: c.baseConfig.Clone(), +func (c *reloadingCreds) Clone() credentials.TransportCredentials { + return &reloadingCreds{ + provider: c.provider, } } -func (c *caReloadingClientTLSCreds) OverrideServerName(_ string) error { +func (c *reloadingCreds) OverrideServerName(_ string) error { panic("cannot override server name for private xds tls credentials") } -func (c *caReloadingClientTLSCreds) ServerHandshake(_ net.Conn) (net.Conn, credentials.AuthInfo, error) { +func (c *reloadingCreds) ServerHandshake(_ net.Conn) (net.Conn, credentials.AuthInfo, error) { panic("cannot perform handshake for server. xDS TLS credentials are client only.") } diff --git a/xds/internal/xdsclient/tlscreds/bundle_test.go b/xds/internal/xdsclient/tlscreds/bundle_test.go index 5a9df4a2e979..7982117431c9 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_test.go @@ -89,12 +89,14 @@ func (t testServer) EmptyCall(_ context.Context, _ *testpb.Empty) (*testpb.Empty } func TestCaReloading(t *testing.T) { + srvAddr, stopSrv := startServer(t, "localhost:0", tls.NoClientCert) + serverCa, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem")) if err != nil { t.Fatalf("failed to read test CA cert: %s", err) } - // Write CA to a temporary file so that we can modify it later. + // Write CA certs to a temporary file so that we can modify it later. caPath := t.TempDir() + "/ca.pem" err = os.WriteFile(caPath, serverCa, 0644) if err != nil { @@ -109,21 +111,8 @@ func TestCaReloading(t *testing.T) { t.Fatalf("failed to create TLS bundle: %v", err) } - // TLS server with a valid cert - serverCreds, err := credentials.NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) - if err != nil { - t.Fatalf("failed to generate server credentials: %v", err) - } - s := grpc.NewServer(grpc.Creds(serverCreds)) - testgrpc.RegisterTestServiceServer(s, &testServer{}) - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("error listening: %v", err) - } - go s.Serve(lis) - conn, err := grpc.Dial( - lis.Addr().String(), + srvAddr.String(), grpc.WithCredentialsBundle(tlsBundle), grpc.WithAuthority("x.test.example.com"), ) @@ -136,8 +125,9 @@ func TestCaReloading(t *testing.T) { if err != nil { t.Errorf("error calling EmptyCall: %v", err) } - // close the server so that we force a new handshake later on. - s.Stop() + // close the server and create a new one to force client to do a new + // handshake. + stopSrv() invalidCa, err := os.ReadFile(testdata.Path("ca.pem")) if err != nil { @@ -152,27 +142,23 @@ func TestCaReloading(t *testing.T) { // Leave time for the file_watcher provider to reload the CA. time.Sleep(100 * time.Millisecond) - s = grpc.NewServer(grpc.Creds(serverCreds)) - defer s.Stop() - testgrpc.RegisterTestServiceServer(s, &testServer{}) - lis, err = net.Listen("tcp", lis.Addr().String()) - if err != nil { - t.Fatalf("error listening: %v", err) - } - go s.Serve(lis) + _, stopFunc := startServer(t, srvAddr.String(), tls.NoClientCert) + defer stopFunc() // Client handshake should fail because the server cert is signed by an // unknown CA. _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) if st, ok := status.FromError(err); !ok || st.Code() != codes.Unavailable { t.Errorf("expected unavailable error, got %v", err) - if strings.Contains(st.Message(), "x509: certificate signed by unknown authority") { - t.Errorf("expected error to contain 'x509: certificate signed by unknown authority', got %v", st.Message()) - } + } else if !strings.Contains(st.Message(), "certificate signed by unknown authority") { + t.Errorf("expected error to contain 'certificate signed by unknown authority', got %v", st.Message()) } } func TestMTLS(t *testing.T) { + srvAddr, stopFunc := startServer(t, "localhost:0", tls.RequireAndVerifyClientCert) + defer stopFunc() + cfg := fmt.Sprintf(`{ "ca_certificate_file": "%s", "certificate_file": "%s", @@ -185,7 +171,59 @@ func TestMTLS(t *testing.T) { if err != nil { t.Fatalf("failed to create TLS bundle: %v", err) } + dialOpts := []grpc.DialOption{ + grpc.WithCredentialsBundle(tlsBundle), + grpc.WithAuthority("x.test.example.com"), + } + + t.Run("ValidClientCert", func(t *testing.T) { + conn, err := grpc.Dial(srvAddr.String(), dialOpts...) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + client := testgrpc.NewTestServiceClient(conn) + + _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) + if err != nil { + t.Errorf("error calling EmptyCall: %v", err) + } + conn.Close() + }) + t.Run("Provider failing", func(t *testing.T) { + // Check that if the provider returns an errors, we fail the handshake. + // It's not easy to trigger this condition, so we rely on closing the + // provider. + creds, ok := tlsBundle.TransportCredentials().(*reloadingCreds) + + // Force the provider to be initialized. The test is flaky otherwise, + // since close may be a noop. + _, _ = creds.provider.KeyMaterial(context.Background()) + + if !ok { + t.Fatalf("expected reloadingCreds, got %T", tlsBundle.TransportCredentials()) + } + + creds.provider.Close() + + conn, err := grpc.Dial(srvAddr.String(), dialOpts...) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + client := testgrpc.NewTestServiceClient(conn) + _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) + if st, ok := status.FromError(err); !ok || st.Code() != codes.Unavailable { + t.Errorf("expected unavailable error, got %v", err) + } else if !strings.Contains(st.Message(), "provider instance is closed") { + t.Errorf("expected error to contain 'provider instance is closed', got %v", st.Message()) + } + conn.Close() + }) +} + +type stopFunc func() + +func startServer(t *testing.T, addr string, clientAuth tls.ClientAuthType) (net.Addr, stopFunc) { // Create a TLS server with a valid cert that requires a client cert. serverCert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) if err != nil { @@ -201,7 +239,7 @@ func TestMTLS(t *testing.T) { } serverTLSCfg := &tls.Config{ Certificates: []tls.Certificate{serverCert}, - ClientAuth: tls.RequireAndVerifyClientCert, + ClientAuth: clientAuth, ClientCAs: clientCA, } if err != nil { @@ -209,25 +247,10 @@ func TestMTLS(t *testing.T) { } s := grpc.NewServer(grpc.Creds(credentials.NewTLS(serverTLSCfg))) testgrpc.RegisterTestServiceServer(s, &testServer{}) - lis, err := net.Listen("tcp", "localhost:0") + lis, err := net.Listen("tcp", addr) if err != nil { t.Fatalf("error listening: %v", err) } go s.Serve(lis) - defer s.Stop() - - conn, err := grpc.Dial( - lis.Addr().String(), - grpc.WithCredentialsBundle(tlsBundle), - grpc.WithAuthority("x.test.example.com"), - ) - if err != nil { - t.Fatalf("error dialing: %v", err) - } - defer conn.Close() - client := testgrpc.NewTestServiceClient(conn) - _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) - if err != nil { - t.Errorf("error calling EmptyCall: %v", err) - } + return lis.Addr(), func() { s.Stop() } } From c0c080408330c4e96f0914343f6110a1766dfe47 Mon Sep 17 00:00:00 2001 From: Antoine Tollenaere Date: Fri, 15 Dec 2023 22:21:33 +0100 Subject: [PATCH 05/17] address easwars comments --- .../xdsclient/bootstrap/bootstrap_test.go | 10 ++-- xds/internal/xdsclient/tlscreds/bundle.go | 29 +++++------ .../xdsclient/tlscreds/bundle_test.go | 52 +++++++++---------- 3 files changed, 44 insertions(+), 47 deletions(-) diff --git a/xds/internal/xdsclient/bootstrap/bootstrap_test.go b/xds/internal/xdsclient/bootstrap/bootstrap_test.go index 2918dfe5d880..1ac9fb01fc72 100644 --- a/xds/internal/xdsclient/bootstrap/bootstrap_test.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap_test.go @@ -1040,8 +1040,8 @@ func TestCredsBuilders(t *testing.T) { } tcb := &tlsCredsBuilder{} - if _, err := tcb.Build(nil); err == nil { - t.Errorf("tlsCredsBuilder.Build succeeded, want failure") + if _, err := tcb.Build(nil); err != nil { + t.Errorf("tlsCredsBuilder.Build failed: %v", err) } if got, want := tcb.Name(), "tls"; got != want { t.Errorf("tlsCredsBuilder.Name = %v, want %v", got, want) @@ -1051,10 +1051,10 @@ func TestCredsBuilders(t *testing.T) { func TestTlsCredsBuilder(t *testing.T) { tls := &tlsCredsBuilder{} if _, err := tls.Build(json.RawMessage(`{}`)); err != nil { - t.Errorf("unexpected error with empty config: %s", err) + t.Errorf("tls.Build() failed with empty config: %s", err) } if _, err := tls.Build(json.RawMessage(`{"ca_certificate_file":"/ca_certificates.pem","refresh_interval": "asdf"}`)); err == nil { - t.Errorf("expected an error with invalid refresh interval") + t.Errorf("tls.Build() succeeded with an invalid refresh interval, when expected to fail") } - // more tests for config validity are defined in creds subpackage. + // more tests for config validity are defined in tlscreds subpackage. } diff --git a/xds/internal/xdsclient/tlscreds/bundle.go b/xds/internal/xdsclient/tlscreds/bundle.go index 469c7c6bb5ab..c253cb69b5a3 100644 --- a/xds/internal/xdsclient/tlscreds/bundle.go +++ b/xds/internal/xdsclient/tlscreds/bundle.go @@ -24,6 +24,7 @@ import ( "context" "crypto/tls" "encoding/json" + "errors" "fmt" "net" @@ -49,9 +50,11 @@ func NewBundle(jd json.RawMessage) (credentials.Bundle, error) { PrivateKeyFile string `json:"private_key_file"` }{} - if err := json.Unmarshal(jd, cfg); err != nil { - return nil, err - } + if jd != nil { + if err := json.Unmarshal(jd, cfg); err != nil { + return nil, fmt.Errorf("failed to unmarshal config: %v", err) + } + } // Else the config field is absent. Treat it as an empty config. if cfg.CACertificateFile == "" && cfg.CertificateFile == "" && cfg.PrivateKeyFile == "" { // We cannot use (and do not need) a file_watcher provider in this case, @@ -63,9 +66,7 @@ func NewBundle(jd json.RawMessage) (credentials.Bundle, error) { // > provider, at least one of the "certificate_file" or // > "ca_certificate_file" fields must be specified, whereas in this // > configuration, it is acceptable to specify neither one. - return &bundle{ - transportCredentials: credentials.NewTLS(&tls.Config{}), - }, nil + return &bundle{transportCredentials: credentials.NewTLS(&tls.Config{})}, nil } // Otherwise we need to use a file_watcher provider to watch the CA, // private and public keys. @@ -76,9 +77,7 @@ func NewBundle(jd json.RawMessage) (credentials.Bundle, error) { return nil, err } return &bundle{ - transportCredentials: &reloadingCreds{ - provider: provider, - }, + transportCredentials: &reloadingCreds{provider: provider}, }, nil } @@ -124,15 +123,13 @@ func (c *reloadingCreds) Info() credentials.ProtocolInfo { } func (c *reloadingCreds) Clone() credentials.TransportCredentials { - return &reloadingCreds{ - provider: c.provider, - } + return &reloadingCreds{provider: c.provider} } -func (c *reloadingCreds) OverrideServerName(_ string) error { - panic("cannot override server name for private xds tls credentials") +func (c *reloadingCreds) OverrideServerName(string) error { + return errors.New("overriding server name is not supported by xDS client TLS credentials") } -func (c *reloadingCreds) ServerHandshake(_ net.Conn) (net.Conn, credentials.AuthInfo, error) { - panic("cannot perform handshake for server. xDS TLS credentials are client only.") +func (c *reloadingCreds) ServerHandshake(net.Conn) (net.Conn, credentials.AuthInfo, error) { + return nil, nil, errors.New("server handshake is not supported by xDS client TLS credentials") } diff --git a/xds/internal/xdsclient/tlscreds/bundle_test.go b/xds/internal/xdsclient/tlscreds/bundle_test.go index 7982117431c9..a69d570e4586 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_test.go @@ -56,7 +56,7 @@ func TestValidTlsBuilder(t *testing.T) { t.Run(jd, func(t *testing.T) { msg := json.RawMessage(jd) if _, err := NewBundle(msg); err != nil { - t.Errorf("expected no error but got: %s", err) + t.Errorf("NewBundle(%s): expected no error but got: %s", jd, err) } }) } @@ -66,7 +66,7 @@ func TestInvalidTlsBuilder(t *testing.T) { tests := []struct { jd, err string }{ - {`{"ca_certificate_file": 1}`, "json: cannot unmarshal number into Go struct field .ca_certificate_file of type string"}, + {`{"ca_certificate_file": 1}`, "failed to unmarshal config: json: cannot unmarshal number into Go struct field .ca_certificate_file of type string"}, {`{"certificate_file":"bar"}`, "pemfile: private key file and identity cert file should be both specified or not specified"}, } @@ -74,7 +74,7 @@ func TestInvalidTlsBuilder(t *testing.T) { t.Run(test.jd, func(t *testing.T) { msg := json.RawMessage(test.jd) if _, err := NewBundle(msg); err.Error() != test.err { - t.Errorf("expected: %s, got: %s", test.err, err) + t.Errorf("NewBundle(%s): want error %s, got: %s", msg, test.err, err) } }) } @@ -93,14 +93,14 @@ func TestCaReloading(t *testing.T) { serverCa, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem")) if err != nil { - t.Fatalf("failed to read test CA cert: %s", err) + t.Fatalf("Failed to read test CA cert: %s", err) } // Write CA certs to a temporary file so that we can modify it later. caPath := t.TempDir() + "/ca.pem" err = os.WriteFile(caPath, serverCa, 0644) if err != nil { - t.Fatalf("failed to write test CA cert: %v", err) + t.Fatalf("Failed to write test CA cert: %v", err) } cfg := fmt.Sprintf(`{ "ca_certificate_file": "%s", @@ -108,7 +108,7 @@ func TestCaReloading(t *testing.T) { }`, caPath) tlsBundle, err := NewBundle([]byte(cfg)) if err != nil { - t.Fatalf("failed to create TLS bundle: %v", err) + t.Fatalf("Failed to create TLS bundle: %v", err) } conn, err := grpc.Dial( @@ -117,13 +117,13 @@ func TestCaReloading(t *testing.T) { grpc.WithAuthority("x.test.example.com"), ) if err != nil { - t.Fatalf("error dialing: %v", err) + t.Fatalf("Error dialing: %v", err) } defer conn.Close() client := testgrpc.NewTestServiceClient(conn) _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) if err != nil { - t.Errorf("error calling EmptyCall: %v", err) + t.Errorf("Error calling EmptyCall: %v", err) } // close the server and create a new one to force client to do a new // handshake. @@ -131,12 +131,12 @@ func TestCaReloading(t *testing.T) { invalidCa, err := os.ReadFile(testdata.Path("ca.pem")) if err != nil { - t.Fatalf("failed to read test CA cert: %v", err) + t.Fatalf("Failed to read test CA cert: %v", err) } // unload root cert err = os.WriteFile(caPath, invalidCa, 0644) if err != nil { - t.Fatalf("failed to write test CA cert: %v", err) + t.Fatalf("Failed to write test CA cert: %v", err) } // Leave time for the file_watcher provider to reload the CA. @@ -149,9 +149,9 @@ func TestCaReloading(t *testing.T) { // unknown CA. _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) if st, ok := status.FromError(err); !ok || st.Code() != codes.Unavailable { - t.Errorf("expected unavailable error, got %v", err) - } else if !strings.Contains(st.Message(), "certificate signed by unknown authority") { - t.Errorf("expected error to contain 'certificate signed by unknown authority', got %v", st.Message()) + t.Errorf("Expected unavailable error, got %v", err) + } else if want := "certificate signed by unknown authority"; !strings.Contains(st.Message(), want) { + t.Errorf("Expected call error to contain '%s', got %v", want, st.Message()) } } @@ -169,7 +169,7 @@ func TestMTLS(t *testing.T) { testdata.Path("x509/client1_key.pem")) tlsBundle, err := NewBundle([]byte(cfg)) if err != nil { - t.Fatalf("failed to create TLS bundle: %v", err) + t.Fatalf("Failed to create TLS bundle: %v", err) } dialOpts := []grpc.DialOption{ grpc.WithCredentialsBundle(tlsBundle), @@ -179,13 +179,13 @@ func TestMTLS(t *testing.T) { t.Run("ValidClientCert", func(t *testing.T) { conn, err := grpc.Dial(srvAddr.String(), dialOpts...) if err != nil { - t.Fatalf("error dialing: %v", err) + t.Fatalf("Error dialing: %v", err) } client := testgrpc.NewTestServiceClient(conn) _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) if err != nil { - t.Errorf("error calling EmptyCall: %v", err) + t.Errorf("Error calling EmptyCall: %v", err) } conn.Close() }) @@ -201,21 +201,21 @@ func TestMTLS(t *testing.T) { _, _ = creds.provider.KeyMaterial(context.Background()) if !ok { - t.Fatalf("expected reloadingCreds, got %T", tlsBundle.TransportCredentials()) + t.Fatalf("Expected reloadingCreds, got %T", tlsBundle.TransportCredentials()) } creds.provider.Close() conn, err := grpc.Dial(srvAddr.String(), dialOpts...) if err != nil { - t.Fatalf("error dialing: %v", err) + t.Fatalf("Error dialing: %v", err) } client := testgrpc.NewTestServiceClient(conn) _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) if st, ok := status.FromError(err); !ok || st.Code() != codes.Unavailable { - t.Errorf("expected unavailable error, got %v", err) - } else if !strings.Contains(st.Message(), "provider instance is closed") { - t.Errorf("expected error to contain 'provider instance is closed', got %v", st.Message()) + t.Errorf("Expected unavailable error, got %v", err) + } else if want := "provider instance is closed"; !strings.Contains(st.Message(), want) { + t.Errorf("Expected error to contain '%s', got %v", want, st.Message()) } conn.Close() }) @@ -227,15 +227,15 @@ func startServer(t *testing.T, addr string, clientAuth tls.ClientAuthType) (net. // Create a TLS server with a valid cert that requires a client cert. serverCert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) if err != nil { - t.Fatalf("failed to load server cert: %v", err) + t.Fatalf("Failed to load server cert: %v", err) } pemClientCA, err := os.ReadFile(testdata.Path("x509/client_ca_cert.pem")) if err != nil { - t.Fatalf("failed to read test client CA cert: %v", err) + t.Fatalf("Failed to read test client CA cert: %v", err) } clientCA := x509.NewCertPool() if !clientCA.AppendCertsFromPEM(pemClientCA) { - t.Fatal("failed to add client CA's certificate") + t.Fatal("Failed to add client CA's certificate") } serverTLSCfg := &tls.Config{ Certificates: []tls.Certificate{serverCert}, @@ -243,13 +243,13 @@ func startServer(t *testing.T, addr string, clientAuth tls.ClientAuthType) (net. ClientCAs: clientCA, } if err != nil { - t.Fatalf("failed to generate server credentials: %v", err) + t.Fatalf("Failed to generate server credentials: %v", err) } s := grpc.NewServer(grpc.Creds(credentials.NewTLS(serverTLSCfg))) testgrpc.RegisterTestServiceServer(s, &testServer{}) lis, err := net.Listen("tcp", addr) if err != nil { - t.Fatalf("error listening: %v", err) + t.Fatalf("Error listening: %v", err) } go s.Serve(lis) return lis.Addr(), func() { s.Stop() } From ce575816a025452c782b0ea691363d6decf47368 Mon Sep 17 00:00:00 2001 From: Antoine Tollenaere Date: Mon, 18 Dec 2023 12:03:23 +0100 Subject: [PATCH 06/17] comments from easwar on bundle test --- internal/testutils/xds/e2e/setup_certs.go | 24 ++ .../xdsclient/tlscreds/bundle_ext_test.go | 185 +++++++++++++ .../xdsclient/tlscreds/bundle_test.go | 245 ++---------------- 3 files changed, 233 insertions(+), 221 deletions(-) create mode 100644 xds/internal/xdsclient/tlscreds/bundle_ext_test.go diff --git a/internal/testutils/xds/e2e/setup_certs.go b/internal/testutils/xds/e2e/setup_certs.go index 799e18564879..15627453b3a6 100644 --- a/internal/testutils/xds/e2e/setup_certs.go +++ b/internal/testutils/xds/e2e/setup_certs.go @@ -95,3 +95,27 @@ func CreateClientTLSCredentials(t *testing.T) credentials.TransportCredentials { ServerName: "x.test.example.com", }) } + +// CreateServerTLSCredentials creates server-side TLS transport credentials +// using certificate and key files from testdata/x509 directory. +func CreateServerTLSCredentials(t *testing.T, clientAuth tls.ClientAuthType) credentials.TransportCredentials { + t.Helper() + + cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) + if err != nil { + t.Fatalf("tls.LoadX509KeyPair(x509/server1_cert.pem, x509/server1_key.pem) failed: %v", err) + } + b, err := os.ReadFile(testdata.Path("x509/client_ca_cert.pem")) + if err != nil { + t.Fatalf("os.ReadFile(x509/client_ca_cert.pem) failed: %v", err) + } + ca := x509.NewCertPool() + if !ca.AppendCertsFromPEM(b) { + t.Fatal("Failed to append certificates") + } + return credentials.NewTLS(&tls.Config{ + ClientAuth: clientAuth, + Certificates: []tls.Certificate{cert}, + ClientCAs: ca, + }) +} diff --git a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go new file mode 100644 index 000000000000..bbe97f38d64f --- /dev/null +++ b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go @@ -0,0 +1,185 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package tlscreds_test + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "os" + "strings" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils/xds/e2e" + testgrpc "google.golang.org/grpc/interop/grpc_testing" + testpb "google.golang.org/grpc/interop/grpc_testing" + "google.golang.org/grpc/status" + "google.golang.org/grpc/testdata" + "google.golang.org/grpc/xds/internal/xdsclient/tlscreds" +) + +func TestValidTlsBuilder(t *testing.T) { + tests := []struct { + name string + jd string + }{ + {"Absent configuration", `null`}, + {"Empty configuration", `{}`}, + {"Only CA certificate chain", `{"ca_certificate_file": "foo"}`}, + {"Only private key and certificate chain", `{"certificate_file":"bar","private_key_file":"baz"}`}, + {"CA chain, private key and certificate chain", `{"ca_certificate_file":"foo","certificate_file":"bar","private_key_file":"baz"}`}, + {"Only refresh interval", `{"refresh_interval": "1s"}`}, + {"Refresh interval and CA certificate chain", `{"refresh_interval": "1s","ca_certificate_file": "foo"}`}, + {"Refresh interval, private key and certificate chain", `{"refresh_interval": "1s","certificate_file":"bar","private_key_file":"baz"}`}, + {"Refresh interval, CA chain, private key and certificate chain", `{"refresh_interval": "1s","ca_certificate_file":"foo","certificate_file":"bar","private_key_file":"baz"}`}, + {"Unknown field", `{"unknown_field": "foo"}`}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + msg := json.RawMessage(test.jd) + if _, err := tlscreds.NewBundle(msg); err != nil { + t.Errorf("NewBundle(%s): expected no error but got: %s", test.jd, err) + } + }) + } +} + +func TestInvalidTlsBuilder(t *testing.T) { + tests := []struct { + name, jd, wantErrPrefix string + }{ + {"Wrong type in json", `{"ca_certificate_file": 1}`, "failed to unmarshal config:"}, + {"Missing private key", `{"certificate_file":"bar"}`, "pemfile: private key file and identity cert file should be both specified or not specified"}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + msg := json.RawMessage(test.jd) + if _, err := tlscreds.NewBundle(msg); err == nil || !strings.HasPrefix(err.Error(), test.wantErrPrefix) { + t.Errorf("NewBundle(%s): want error %s, got: %s", msg, test.wantErrPrefix, err) + } + }) + } +} + +func TestCaReloading(t *testing.T) { + serverCa, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem")) + if err != nil { + t.Fatalf("Failed to read test CA cert: %s", err) + } + + // Write CA certs to a temporary file so that we can modify it later. + caPath := t.TempDir() + "/ca.pem" + err = os.WriteFile(caPath, serverCa, 0644) + if err != nil { + t.Fatalf("Failed to write test CA cert: %v", err) + } + cfg := fmt.Sprintf(`{ + "ca_certificate_file": "%s", + "refresh_interval": ".01s" + }`, caPath) + tlsBundle, err := tlscreds.NewBundle([]byte(cfg)) + if err != nil { + t.Fatalf("Failed to create TLS bundle: %v", err) + } + + serverCredentials := grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.NoClientCert)) + server := stubserver.StartTestService(t, nil, serverCredentials) + + conn, err := grpc.Dial( + server.Address, + grpc.WithCredentialsBundle(tlsBundle), + grpc.WithAuthority("x.test.example.com"), + ) + if err != nil { + t.Fatalf("Error dialing: %v", err) + } + defer conn.Close() + client := testgrpc.NewTestServiceClient(conn) + _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) + if err != nil { + t.Errorf("Error calling EmptyCall: %v", err) + } + // close the server and create a new one to force client to do a new + // handshake. + server.Stop() + + invalidCa, err := os.ReadFile(testdata.Path("ca.pem")) + if err != nil { + t.Fatalf("Failed to read test CA cert: %v", err) + } + // unload root cert + err = os.WriteFile(caPath, invalidCa, 0644) + if err != nil { + t.Fatalf("Failed to write test CA cert: %v", err) + } + + // Leave time for the file_watcher provider to reload the CA. + time.Sleep(100 * time.Millisecond) + + server = stubserver.StartTestService(t, &stubserver.StubServer{Address: server.Address}, serverCredentials) + defer server.Stop() + + // Client handshake should fail because the server cert is signed by an + // unknown CA. + _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) + if st, ok := status.FromError(err); !ok || st.Code() != codes.Unavailable { + t.Errorf("Expected unavailable error, got %v", err) + } else if want := "certificate signed by unknown authority"; !strings.Contains(st.Message(), want) { + t.Errorf("Expected call error to contain '%s', got %v", want, st.Message()) + } +} + +func TestMTLS(t *testing.T) { + s := stubserver.StartTestService(t, nil, grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert))) + defer s.Stop() + + cfg := fmt.Sprintf(`{ + "ca_certificate_file": "%s", + "certificate_file": "%s", + "private_key_file": "%s" + }`, + testdata.Path("x509/server_ca_cert.pem"), + testdata.Path("x509/client1_cert.pem"), + testdata.Path("x509/client1_key.pem")) + tlsBundle, err := tlscreds.NewBundle([]byte(cfg)) + if err != nil { + t.Fatalf("Failed to create TLS bundle: %v", err) + } + dialOpts := []grpc.DialOption{ + grpc.WithCredentialsBundle(tlsBundle), + grpc.WithAuthority("x.test.example.com"), + } + + conn, err := grpc.Dial(s.Address, dialOpts...) + if err != nil { + t.Fatalf("Error dialing: %v", err) + } + client := testgrpc.NewTestServiceClient(conn) + _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) + if err != nil { + t.Errorf("Error calling EmptyCall: %v", err) + } +} diff --git a/xds/internal/xdsclient/tlscreds/bundle_test.go b/xds/internal/xdsclient/tlscreds/bundle_test.go index a69d570e4586..94371d965487 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_test.go @@ -1,163 +1,22 @@ -/* - * - * Copyright 2023 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - package tlscreds import ( "context" "crypto/tls" - "crypto/x509" - "encoding/json" "fmt" - "net" - "os" - "strings" "testing" - "time" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/status" - "google.golang.org/grpc/testdata" - + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils/xds/e2e" testgrpc "google.golang.org/grpc/interop/grpc_testing" testpb "google.golang.org/grpc/interop/grpc_testing" + "google.golang.org/grpc/testdata" ) -func TestValidTlsBuilder(t *testing.T) { - tests := []string{ - `{}`, - `{"ca_certificate_file": "foo"}`, - `{"certificate_file":"bar","private_key_file":"baz"}`, - `{"ca_certificate_file":"foo","certificate_file":"bar","private_key_file":"baz"}`, - `{"refresh_interval": "1s"}`, - `{"refresh_interval": "1s","ca_certificate_file": "foo"}`, - `{"refresh_interval": "1s","certificate_file":"bar","private_key_file":"baz"}`, - `{"refresh_interval": "1s","ca_certificate_file":"foo","certificate_file":"bar","private_key_file":"baz"}`, - } - - for _, jd := range tests { - t.Run(jd, func(t *testing.T) { - msg := json.RawMessage(jd) - if _, err := NewBundle(msg); err != nil { - t.Errorf("NewBundle(%s): expected no error but got: %s", jd, err) - } - }) - } -} - -func TestInvalidTlsBuilder(t *testing.T) { - tests := []struct { - jd, err string - }{ - {`{"ca_certificate_file": 1}`, "failed to unmarshal config: json: cannot unmarshal number into Go struct field .ca_certificate_file of type string"}, - {`{"certificate_file":"bar"}`, "pemfile: private key file and identity cert file should be both specified or not specified"}, - } - - for _, test := range tests { - t.Run(test.jd, func(t *testing.T) { - msg := json.RawMessage(test.jd) - if _, err := NewBundle(msg); err.Error() != test.err { - t.Errorf("NewBundle(%s): want error %s, got: %s", msg, test.err, err) - } - }) - } -} - -type testServer struct { - testgrpc.UnimplementedTestServiceServer -} - -func (t testServer) EmptyCall(_ context.Context, _ *testpb.Empty) (*testpb.Empty, error) { - return &testpb.Empty{}, nil -} - -func TestCaReloading(t *testing.T) { - srvAddr, stopSrv := startServer(t, "localhost:0", tls.NoClientCert) - - serverCa, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem")) - if err != nil { - t.Fatalf("Failed to read test CA cert: %s", err) - } - - // Write CA certs to a temporary file so that we can modify it later. - caPath := t.TempDir() + "/ca.pem" - err = os.WriteFile(caPath, serverCa, 0644) - if err != nil { - t.Fatalf("Failed to write test CA cert: %v", err) - } - cfg := fmt.Sprintf(`{ - "ca_certificate_file": "%s", - "refresh_interval": ".01s" - }`, caPath) - tlsBundle, err := NewBundle([]byte(cfg)) - if err != nil { - t.Fatalf("Failed to create TLS bundle: %v", err) - } - - conn, err := grpc.Dial( - srvAddr.String(), - grpc.WithCredentialsBundle(tlsBundle), - grpc.WithAuthority("x.test.example.com"), - ) - if err != nil { - t.Fatalf("Error dialing: %v", err) - } - defer conn.Close() - client := testgrpc.NewTestServiceClient(conn) - _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) - if err != nil { - t.Errorf("Error calling EmptyCall: %v", err) - } - // close the server and create a new one to force client to do a new - // handshake. - stopSrv() - - invalidCa, err := os.ReadFile(testdata.Path("ca.pem")) - if err != nil { - t.Fatalf("Failed to read test CA cert: %v", err) - } - // unload root cert - err = os.WriteFile(caPath, invalidCa, 0644) - if err != nil { - t.Fatalf("Failed to write test CA cert: %v", err) - } - - // Leave time for the file_watcher provider to reload the CA. - time.Sleep(100 * time.Millisecond) - - _, stopFunc := startServer(t, srvAddr.String(), tls.NoClientCert) - defer stopFunc() - - // Client handshake should fail because the server cert is signed by an - // unknown CA. - _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) - if st, ok := status.FromError(err); !ok || st.Code() != codes.Unavailable { - t.Errorf("Expected unavailable error, got %v", err) - } else if want := "certificate signed by unknown authority"; !strings.Contains(st.Message(), want) { - t.Errorf("Expected call error to contain '%s', got %v", want, st.Message()) - } -} - -func TestMTLS(t *testing.T) { - srvAddr, stopFunc := startServer(t, "localhost:0", tls.RequireAndVerifyClientCert) - defer stopFunc() +func TestFaillingProvider(t *testing.T) { + s := stubserver.StartTestService(t, nil, grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert))) + defer s.Stop() cfg := fmt.Sprintf(`{ "ca_certificate_file": "%s", @@ -168,89 +27,33 @@ func TestMTLS(t *testing.T) { testdata.Path("x509/client1_cert.pem"), testdata.Path("x509/client1_key.pem")) tlsBundle, err := NewBundle([]byte(cfg)) - if err != nil { - t.Fatalf("Failed to create TLS bundle: %v", err) - } dialOpts := []grpc.DialOption{ grpc.WithCredentialsBundle(tlsBundle), grpc.WithAuthority("x.test.example.com"), } - t.Run("ValidClientCert", func(t *testing.T) { - conn, err := grpc.Dial(srvAddr.String(), dialOpts...) - if err != nil { - t.Fatalf("Error dialing: %v", err) - } - client := testgrpc.NewTestServiceClient(conn) - - _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) - if err != nil { - t.Errorf("Error calling EmptyCall: %v", err) - } - conn.Close() - }) - - t.Run("Provider failing", func(t *testing.T) { - // Check that if the provider returns an errors, we fail the handshake. - // It's not easy to trigger this condition, so we rely on closing the - // provider. - creds, ok := tlsBundle.TransportCredentials().(*reloadingCreds) - - // Force the provider to be initialized. The test is flaky otherwise, - // since close may be a noop. - _, _ = creds.provider.KeyMaterial(context.Background()) - - if !ok { - t.Fatalf("Expected reloadingCreds, got %T", tlsBundle.TransportCredentials()) - } - - creds.provider.Close() + // Check that if the provider returns an errors, we fail the handshake. + // It's not easy to trigger this condition, so we rely on closing the + // provider. + creds, ok := tlsBundle.TransportCredentials().(*reloadingCreds) + if !ok { + t.Fatalf("Expected reloadingCreds, got %T", tlsBundle.TransportCredentials()) + } - conn, err := grpc.Dial(srvAddr.String(), dialOpts...) - if err != nil { - t.Fatalf("Error dialing: %v", err) - } - client := testgrpc.NewTestServiceClient(conn) - _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) - if st, ok := status.FromError(err); !ok || st.Code() != codes.Unavailable { - t.Errorf("Expected unavailable error, got %v", err) - } else if want := "provider instance is closed"; !strings.Contains(st.Message(), want) { - t.Errorf("Expected error to contain '%s', got %v", want, st.Message()) - } - conn.Close() - }) -} + // Force the provider to be initialized. The test is flaky otherwise, + // since close may be a noop. + _, _ = creds.provider.KeyMaterial(context.Background()) -type stopFunc func() + creds.provider.Close() -func startServer(t *testing.T, addr string, clientAuth tls.ClientAuthType) (net.Addr, stopFunc) { - // Create a TLS server with a valid cert that requires a client cert. - serverCert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) - if err != nil { - t.Fatalf("Failed to load server cert: %v", err) - } - pemClientCA, err := os.ReadFile(testdata.Path("x509/client_ca_cert.pem")) + conn, err := grpc.Dial(s.Address, dialOpts...) if err != nil { - t.Fatalf("Failed to read test client CA cert: %v", err) - } - clientCA := x509.NewCertPool() - if !clientCA.AppendCertsFromPEM(pemClientCA) { - t.Fatal("Failed to add client CA's certificate") - } - serverTLSCfg := &tls.Config{ - Certificates: []tls.Certificate{serverCert}, - ClientAuth: clientAuth, - ClientCAs: clientCA, - } - if err != nil { - t.Fatalf("Failed to generate server credentials: %v", err) + t.Fatalf("Error dialing: %v", err) } - s := grpc.NewServer(grpc.Creds(credentials.NewTLS(serverTLSCfg))) - testgrpc.RegisterTestServiceServer(s, &testServer{}) - lis, err := net.Listen("tcp", addr) - if err != nil { - t.Fatalf("Error listening: %v", err) + client := testgrpc.NewTestServiceClient(conn) + _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) + if wantErr := "provider instance is closed"; err.Error() != wantErr { + t.Errorf("Expected error to end with %v, got %v", wantErr, err) } - go s.Serve(lis) - return lis.Addr(), func() { s.Stop() } + conn.Close() } From 31df4f3a788266b60773cd3e4ef3c53a844eac2f Mon Sep 17 00:00:00 2001 From: Antoine Tollenaere Date: Tue, 19 Dec 2023 10:05:46 +0100 Subject: [PATCH 07/17] update tests based on easwars feedback --- .../xdsclient/tlscreds/bundle_ext_test.go | 65 ++++++++++++------- .../xdsclient/tlscreds/bundle_test.go | 27 +++++++- 2 files changed, 65 insertions(+), 27 deletions(-) diff --git a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go index bbe97f38d64f..4d1c0bdcd4d3 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go @@ -39,28 +39,32 @@ import ( "google.golang.org/grpc/xds/internal/xdsclient/tlscreds" ) +const ( + defaultTestTimeout = 5 * time.Second +) + func TestValidTlsBuilder(t *testing.T) { tests := []struct { name string jd string }{ - {"Absent configuration", `null`}, - {"Empty configuration", `{}`}, - {"Only CA certificate chain", `{"ca_certificate_file": "foo"}`}, - {"Only private key and certificate chain", `{"certificate_file":"bar","private_key_file":"baz"}`}, - {"CA chain, private key and certificate chain", `{"ca_certificate_file":"foo","certificate_file":"bar","private_key_file":"baz"}`}, - {"Only refresh interval", `{"refresh_interval": "1s"}`}, - {"Refresh interval and CA certificate chain", `{"refresh_interval": "1s","ca_certificate_file": "foo"}`}, - {"Refresh interval, private key and certificate chain", `{"refresh_interval": "1s","certificate_file":"bar","private_key_file":"baz"}`}, - {"Refresh interval, CA chain, private key and certificate chain", `{"refresh_interval": "1s","ca_certificate_file":"foo","certificate_file":"bar","private_key_file":"baz"}`}, - {"Unknown field", `{"unknown_field": "foo"}`}, + {name: "Absent configuration", jd: `null`}, + {name: "Empty configuration", jd: `{}`}, + {name: "Only CA certificate chain", jd: `{"ca_certificate_file": "foo"}`}, + {name: "Only private key and certificate chain", jd: `{"certificate_file":"bar","private_key_file":"baz"}`}, + {name: "CA chain, private key and certificate chain", jd: `{"ca_certificate_file":"foo","certificate_file":"bar","private_key_file":"baz"}`}, + {name: "Only refresh interval", jd: `{"refresh_interval": "1s"}`}, + {name: "Refresh interval and CA certificate chain", jd: `{"refresh_interval": "1s","ca_certificate_file": "foo"}`}, + {name: "Refresh interval, private key and certificate chain", jd: `{"refresh_interval": "1s","certificate_file":"bar","private_key_file":"baz"}`}, + {name: "Refresh interval, CA chain, private key and certificate chain", jd: `{"refresh_interval": "1s","ca_certificate_file":"foo","certificate_file":"bar","private_key_file":"baz"}`}, + {name: "Unknown field", jd: `{"unknown_field": "foo"}`}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { msg := json.RawMessage(test.jd) if _, err := tlscreds.NewBundle(msg); err != nil { - t.Errorf("NewBundle(%s): expected no error but got: %s", test.jd, err) + t.Errorf("NewBundle(%s) returned error %s when expected to succeed", test.jd, err) } }) } @@ -78,7 +82,7 @@ func TestInvalidTlsBuilder(t *testing.T) { t.Run(test.name, func(t *testing.T) { msg := json.RawMessage(test.jd) if _, err := tlscreds.NewBundle(msg); err == nil || !strings.HasPrefix(err.Error(), test.wantErrPrefix) { - t.Errorf("NewBundle(%s): want error %s, got: %s", msg, test.wantErrPrefix, err) + t.Errorf("NewBundle(%s): got error %s, want %s", msg, err, test.wantErrPrefix) } }) } @@ -136,19 +140,30 @@ func TestCaReloading(t *testing.T) { t.Fatalf("Failed to write test CA cert: %v", err) } - // Leave time for the file_watcher provider to reload the CA. - time.Sleep(100 * time.Millisecond) - - server = stubserver.StartTestService(t, &stubserver.StubServer{Address: server.Address}, serverCredentials) - defer server.Stop() - - // Client handshake should fail because the server cert is signed by an - // unknown CA. - _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) - if st, ok := status.FromError(err); !ok || st.Code() != codes.Unavailable { - t.Errorf("Expected unavailable error, got %v", err) - } else if want := "certificate signed by unknown authority"; !strings.Contains(st.Message(), want) { - t.Errorf("Expected call error to contain '%s', got %v", want, st.Message()) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + for ; ctx.Err() == nil; <-time.After(10 * time.Millisecond) { + ss := stubserver.StubServer{ + Address: server.Address, + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, + } + server = stubserver.StartTestService(t, &ss, serverCredentials) + + // Client handshake should fail because the server cert is signed by an + // unknown CA. + t.Log(server) + _, err = client.EmptyCall(ctx, &testpb.Empty{}) + const wantErr = "certificate signed by unknown authority" + if status.Code(err) == codes.Unavailable && strings.Contains(err.Error(), wantErr) { + // Certs have reloaded. + break + } + t.Logf("EmptyCall() want code: %s, want err: %s, got err: %s", codes.Unavailable, wantErr, err) + server.Stop() + } + if ctx.Err() != nil { + t.Errorf("Timed out waiting for CA certs reloading") } } diff --git a/xds/internal/xdsclient/tlscreds/bundle_test.go b/xds/internal/xdsclient/tlscreds/bundle_test.go index 94371d965487..2656298be612 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_test.go @@ -1,9 +1,28 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + package tlscreds import ( "context" "crypto/tls" "fmt" + "strings" "testing" "google.golang.org/grpc" @@ -14,7 +33,7 @@ import ( "google.golang.org/grpc/testdata" ) -func TestFaillingProvider(t *testing.T) { +func TestFailingProvider(t *testing.T) { s := stubserver.StartTestService(t, nil, grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert))) defer s.Stop() @@ -27,6 +46,10 @@ func TestFaillingProvider(t *testing.T) { testdata.Path("x509/client1_cert.pem"), testdata.Path("x509/client1_key.pem")) tlsBundle, err := NewBundle([]byte(cfg)) + if err != nil { + t.Fatalf("Failed to create TLS bundle: %v", err) + } + dialOpts := []grpc.DialOption{ grpc.WithCredentialsBundle(tlsBundle), grpc.WithAuthority("x.test.example.com"), @@ -52,7 +75,7 @@ func TestFaillingProvider(t *testing.T) { } client := testgrpc.NewTestServiceClient(conn) _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) - if wantErr := "provider instance is closed"; err.Error() != wantErr { + if wantErr := "provider instance is closed"; strings.HasSuffix(err.Error(), wantErr) { t.Errorf("Expected error to end with %v, got %v", wantErr, err) } conn.Close() From e3a4724c3d706688de3beda9c27ac32ccd57a103 Mon Sep 17 00:00:00 2001 From: Antoine Tollenaere Date: Tue, 19 Dec 2023 10:44:07 +0100 Subject: [PATCH 08/17] fix more test formatting --- .../xdsclient/bootstrap/bootstrap_test.go | 2 +- .../xdsclient/tlscreds/bundle_ext_test.go | 19 +++++++++++++------ .../xdsclient/tlscreds/bundle_test.go | 6 +++--- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/xds/internal/xdsclient/bootstrap/bootstrap_test.go b/xds/internal/xdsclient/bootstrap/bootstrap_test.go index 1ac9fb01fc72..52133292256b 100644 --- a/xds/internal/xdsclient/bootstrap/bootstrap_test.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap_test.go @@ -1051,7 +1051,7 @@ func TestCredsBuilders(t *testing.T) { func TestTlsCredsBuilder(t *testing.T) { tls := &tlsCredsBuilder{} if _, err := tls.Build(json.RawMessage(`{}`)); err != nil { - t.Errorf("tls.Build() failed with empty config: %s", err) + t.Errorf("tls.Build() failed with error %s when expected to succeed", err) } if _, err := tls.Build(json.RawMessage(`{"ca_certificate_file":"/ca_certificates.pem","refresh_interval": "asdf"}`)); err == nil { t.Errorf("tls.Build() succeeded with an invalid refresh interval, when expected to fail") diff --git a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go index 4d1c0bdcd4d3..1ea48b2cb533 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go @@ -74,8 +74,15 @@ func TestInvalidTlsBuilder(t *testing.T) { tests := []struct { name, jd, wantErrPrefix string }{ - {"Wrong type in json", `{"ca_certificate_file": 1}`, "failed to unmarshal config:"}, - {"Missing private key", `{"certificate_file":"bar"}`, "pemfile: private key file and identity cert file should be both specified or not specified"}, + { + name: "Wrong type in json", + jd: `{"ca_certificate_file": 1}`, + wantErrPrefix: "failed to unmarshal config:"}, + { + name: "Missing private key", + jd: `{"certificate_file":"bar"}`, + wantErrPrefix: "pemfile: private key file and identity cert file should be both specified or not specified", + }, } for _, test := range tests { @@ -150,8 +157,8 @@ func TestCaReloading(t *testing.T) { } server = stubserver.StartTestService(t, &ss, serverCredentials) - // Client handshake should fail because the server cert is signed by an - // unknown CA. + // Client handshake should eventually fail because the client CA was + // reloaded, and thus the server cert is signed by an unknown CA. t.Log(server) _, err = client.EmptyCall(ctx, &testpb.Empty{}) const wantErr = "certificate signed by unknown authority" @@ -159,7 +166,7 @@ func TestCaReloading(t *testing.T) { // Certs have reloaded. break } - t.Logf("EmptyCall() want code: %s, want err: %s, got err: %s", codes.Unavailable, wantErr, err) + t.Logf("EmptyCall() got err: %s, want code: %s, want err: %s", err, codes.Unavailable, wantErr) server.Stop() } if ctx.Err() != nil { @@ -195,6 +202,6 @@ func TestMTLS(t *testing.T) { client := testgrpc.NewTestServiceClient(conn) _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) if err != nil { - t.Errorf("Error calling EmptyCall: %v", err) + t.Errorf("EmptyCall(): got error %v when expected to succeed", err) } } diff --git a/xds/internal/xdsclient/tlscreds/bundle_test.go b/xds/internal/xdsclient/tlscreds/bundle_test.go index 2656298be612..a7834f1700f7 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_test.go @@ -60,7 +60,7 @@ func TestFailingProvider(t *testing.T) { // provider. creds, ok := tlsBundle.TransportCredentials().(*reloadingCreds) if !ok { - t.Fatalf("Expected reloadingCreds, got %T", tlsBundle.TransportCredentials()) + t.Fatalf("Got %T, expected reloadingCreds", tlsBundle.TransportCredentials()) } // Force the provider to be initialized. The test is flaky otherwise, @@ -75,8 +75,8 @@ func TestFailingProvider(t *testing.T) { } client := testgrpc.NewTestServiceClient(conn) _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) - if wantErr := "provider instance is closed"; strings.HasSuffix(err.Error(), wantErr) { - t.Errorf("Expected error to end with %v, got %v", wantErr, err) + if wantErr := "provider instance is closed"; err == nil || !strings.Contains(err.Error(), wantErr) { + t.Errorf("got error %v when expected error to contain %v", err, wantErr) } conn.Close() } From 0053a949242f605d46cd28132cb5c5a44cd4456b Mon Sep 17 00:00:00 2001 From: Antoine Tollenaere Date: Tue, 19 Dec 2023 11:06:55 +0100 Subject: [PATCH 09/17] fix flaky TestFailingProvider test --- .../xdsclient/tlscreds/bundle_ext_test.go | 4 +-- .../xdsclient/tlscreds/bundle_test.go | 29 ++++++++++++++----- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go index 1ea48b2cb533..bc9b1e8c9c94 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go @@ -39,9 +39,7 @@ import ( "google.golang.org/grpc/xds/internal/xdsclient/tlscreds" ) -const ( - defaultTestTimeout = 5 * time.Second -) +const defaultTestTimeout = 5 * time.Second func TestValidTlsBuilder(t *testing.T) { tests := []struct { diff --git a/xds/internal/xdsclient/tlscreds/bundle_test.go b/xds/internal/xdsclient/tlscreds/bundle_test.go index a7834f1700f7..79579c732d84 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_test.go @@ -24,6 +24,7 @@ import ( "fmt" "strings" "testing" + "time" "google.golang.org/grpc" "google.golang.org/grpc/internal/stubserver" @@ -33,6 +34,8 @@ import ( "google.golang.org/grpc/testdata" ) +const defaultTestTimeout = 5 * time.Second + func TestFailingProvider(t *testing.T) { s := stubserver.StartTestService(t, nil, grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert))) defer s.Stop() @@ -69,14 +72,24 @@ func TestFailingProvider(t *testing.T) { creds.provider.Close() - conn, err := grpc.Dial(s.Address, dialOpts...) - if err != nil { - t.Fatalf("Error dialing: %v", err) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + for ; ctx.Err() == nil; <-time.After(10 * time.Millisecond) { + conn, err := grpc.Dial(s.Address, dialOpts...) + if err != nil { + t.Fatalf("Error dialing: %v", err) + } + client := testgrpc.NewTestServiceClient(conn) + _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) + const wantErr = "provider instance is closed" + if err != nil && strings.Contains(err.Error(), wantErr) { + break + } + t.Logf("EmptyCall() got err: %s, want err to contain: %s", err, "provider instance is closed") + conn.Close() } - client := testgrpc.NewTestServiceClient(conn) - _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) - if wantErr := "provider instance is closed"; err == nil || !strings.Contains(err.Error(), wantErr) { - t.Errorf("got error %v when expected error to contain %v", err, wantErr) + if ctx.Err() != nil { + t.Errorf("Timed out waiting for provider closed to trigger an RPC error") } - conn.Close() } From 50ab88ddaba50365097a564e5f3d87695dc653ea Mon Sep 17 00:00:00 2001 From: Antoine Tollenaere Date: Tue, 19 Dec 2023 15:04:00 +0100 Subject: [PATCH 10/17] provider error test: use a fake provider instead of trying to close. --- .../xdsclient/tlscreds/bundle_test.go | 76 ++++++------------- 1 file changed, 25 insertions(+), 51 deletions(-) diff --git a/xds/internal/xdsclient/tlscreds/bundle_test.go b/xds/internal/xdsclient/tlscreds/bundle_test.go index 79579c732d84..b7fe69783aa1 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_test.go @@ -1,32 +1,15 @@ -/* - * - * Copyright 2023 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - package tlscreds import ( "context" "crypto/tls" + "errors" "fmt" "strings" "testing" - "time" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/internal/testutils/xds/e2e" testgrpc "google.golang.org/grpc/interop/grpc_testing" @@ -34,17 +17,23 @@ import ( "google.golang.org/grpc/testdata" ) -const defaultTestTimeout = 5 * time.Second +type failingProvider struct{} + +func (f failingProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) { + return nil, errors.New("test error") +} + +func (f failingProvider) Close() {} func TestFailingProvider(t *testing.T) { s := stubserver.StartTestService(t, nil, grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert))) defer s.Stop() cfg := fmt.Sprintf(`{ - "ca_certificate_file": "%s", - "certificate_file": "%s", - "private_key_file": "%s" - }`, + "ca_certificate_file": "%s", + "certificate_file": "%s", + "private_key_file": "%s" + }`, testdata.Path("x509/server_ca_cert.pem"), testdata.Path("x509/client1_cert.pem"), testdata.Path("x509/client1_key.pem")) @@ -58,38 +47,23 @@ func TestFailingProvider(t *testing.T) { grpc.WithAuthority("x.test.example.com"), } - // Check that if the provider returns an errors, we fail the handshake. - // It's not easy to trigger this condition, so we rely on closing the - // provider. + // Force a provider that returns an error, and make sure the client fails + // the handshake. creds, ok := tlsBundle.TransportCredentials().(*reloadingCreds) if !ok { t.Fatalf("Got %T, expected reloadingCreds", tlsBundle.TransportCredentials()) } + creds.provider = &failingProvider{} - // Force the provider to be initialized. The test is flaky otherwise, - // since close may be a noop. - _, _ = creds.provider.KeyMaterial(context.Background()) - - creds.provider.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - - for ; ctx.Err() == nil; <-time.After(10 * time.Millisecond) { - conn, err := grpc.Dial(s.Address, dialOpts...) - if err != nil { - t.Fatalf("Error dialing: %v", err) - } - client := testgrpc.NewTestServiceClient(conn) - _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) - const wantErr = "provider instance is closed" - if err != nil && strings.Contains(err.Error(), wantErr) { - break - } - t.Logf("EmptyCall() got err: %s, want err to contain: %s", err, "provider instance is closed") - conn.Close() + conn, err := grpc.Dial(s.Address, dialOpts...) + if err != nil { + t.Fatalf("Error dialing: %v", err) } - if ctx.Err() != nil { - t.Errorf("Timed out waiting for provider closed to trigger an RPC error") + defer conn.Close() + + client := testgrpc.NewTestServiceClient(conn) + _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) + if wantErr := "test error"; err == nil || !strings.Contains(err.Error(), wantErr) { + t.Errorf("EmptyCall() got err: %s, want err to contain: %s", err, wantErr) } } From 7c2cda2d3b4cad8c25849888d58623646980e2d7 Mon Sep 17 00:00:00 2001 From: Antoine Tollenaere Date: Tue, 19 Dec 2023 15:14:57 +0100 Subject: [PATCH 11/17] add license --- .../xdsclient/tlscreds/bundle_ext_test.go | 2 +- xds/internal/xdsclient/tlscreds/bundle_test.go | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go index bc9b1e8c9c94..ca417e9985d3 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go @@ -87,7 +87,7 @@ func TestInvalidTlsBuilder(t *testing.T) { t.Run(test.name, func(t *testing.T) { msg := json.RawMessage(test.jd) if _, err := tlscreds.NewBundle(msg); err == nil || !strings.HasPrefix(err.Error(), test.wantErrPrefix) { - t.Errorf("NewBundle(%s): got error %s, want %s", msg, err, test.wantErrPrefix) + t.Errorf("NewBundle(%s): got error %s, want an error with prefix %s", msg, err, test.wantErrPrefix) } }) } diff --git a/xds/internal/xdsclient/tlscreds/bundle_test.go b/xds/internal/xdsclient/tlscreds/bundle_test.go index b7fe69783aa1..0a8b448681dc 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_test.go @@ -1,3 +1,21 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + package tlscreds import ( From 0d8263dea67f0856c8db4fee28f62887b752a386 Mon Sep 17 00:00:00 2001 From: Antoine Tollenaere Date: Tue, 19 Dec 2023 21:27:43 +0100 Subject: [PATCH 12/17] comments from easwar --- .../xdsclient/tlscreds/bundle_ext_test.go | 25 ++++++++----------- .../xdsclient/tlscreds/bundle_test.go | 7 +----- 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go index ca417e9985d3..46eb3e802aac 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go @@ -101,8 +101,7 @@ func TestCaReloading(t *testing.T) { // Write CA certs to a temporary file so that we can modify it later. caPath := t.TempDir() + "/ca.pem" - err = os.WriteFile(caPath, serverCa, 0644) - if err != nil { + if err = os.WriteFile(caPath, serverCa, 0644); err != nil { t.Fatalf("Failed to write test CA cert: %v", err) } cfg := fmt.Sprintf(`{ @@ -126,9 +125,12 @@ func TestCaReloading(t *testing.T) { t.Fatalf("Error dialing: %v", err) } defer conn.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + client := testgrpc.NewTestServiceClient(conn) - _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) - if err != nil { + if _, err = client.EmptyCall(ctx, &testpb.Empty{}); err != nil { t.Errorf("Error calling EmptyCall: %v", err) } // close the server and create a new one to force client to do a new @@ -145,9 +147,6 @@ func TestCaReloading(t *testing.T) { t.Fatalf("Failed to write test CA cert: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - for ; ctx.Err() == nil; <-time.After(10 * time.Millisecond) { ss := stubserver.StubServer{ Address: server.Address, @@ -162,6 +161,7 @@ func TestCaReloading(t *testing.T) { const wantErr = "certificate signed by unknown authority" if status.Code(err) == codes.Unavailable && strings.Contains(err.Error(), wantErr) { // Certs have reloaded. + server.Stop() break } t.Logf("EmptyCall() got err: %s, want code: %s, want err: %s", err, codes.Unavailable, wantErr) @@ -188,18 +188,13 @@ func TestMTLS(t *testing.T) { if err != nil { t.Fatalf("Failed to create TLS bundle: %v", err) } - dialOpts := []grpc.DialOption{ - grpc.WithCredentialsBundle(tlsBundle), - grpc.WithAuthority("x.test.example.com"), - } - - conn, err := grpc.Dial(s.Address, dialOpts...) + conn, err := grpc.Dial(s.Address, grpc.WithCredentialsBundle(tlsBundle), grpc.WithAuthority("x.test.example.com")) if err != nil { t.Fatalf("Error dialing: %v", err) } + defer conn.Close() client := testgrpc.NewTestServiceClient(conn) - _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) - if err != nil { + if _, err = client.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { t.Errorf("EmptyCall(): got error %v when expected to succeed", err) } } diff --git a/xds/internal/xdsclient/tlscreds/bundle_test.go b/xds/internal/xdsclient/tlscreds/bundle_test.go index 0a8b448681dc..4e95823c00bb 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_test.go @@ -60,11 +60,6 @@ func TestFailingProvider(t *testing.T) { t.Fatalf("Failed to create TLS bundle: %v", err) } - dialOpts := []grpc.DialOption{ - grpc.WithCredentialsBundle(tlsBundle), - grpc.WithAuthority("x.test.example.com"), - } - // Force a provider that returns an error, and make sure the client fails // the handshake. creds, ok := tlsBundle.TransportCredentials().(*reloadingCreds) @@ -73,7 +68,7 @@ func TestFailingProvider(t *testing.T) { } creds.provider = &failingProvider{} - conn, err := grpc.Dial(s.Address, dialOpts...) + conn, err := grpc.Dial(s.Address, grpc.WithCredentialsBundle(tlsBundle), grpc.WithAuthority("x.test.example.com")) if err != nil { t.Fatalf("Error dialing: %v", err) } From 1505f65518984a2feee728f808d3d1de52783f47 Mon Sep 17 00:00:00 2001 From: Antoine Tollenaere Date: Tue, 19 Dec 2023 21:47:14 +0100 Subject: [PATCH 13/17] make returned bundle closeable --- xds/internal/xdsclient/bootstrap/bootstrap.go | 16 ++-- xds/internal/xdsclient/clientimpl.go | 5 ++ xds/internal/xdsclient/tlscreds/bundle.go | 9 +++ .../xdsclient/tlscreds/bundle_ext_test.go | 80 +++++++++++++++---- .../xdsclient/tlscreds/bundle_test.go | 12 ++- xds/internal/xdsclient/transport/transport.go | 4 +- 6 files changed, 98 insertions(+), 28 deletions(-) diff --git a/xds/internal/xdsclient/bootstrap/bootstrap.go b/xds/internal/xdsclient/bootstrap/bootstrap.go index 89b66952ef5f..c12268329382 100644 --- a/xds/internal/xdsclient/bootstrap/bootstrap.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap.go @@ -155,9 +155,9 @@ type ServerConfig struct { // As part of unmarshaling the JSON config into this struct, we ensure that // the credentials config is valid by building an instance of the specified - // credentials and store it here as a grpc.DialOption for easy access when - // dialing this xDS server. - credsDialOption grpc.DialOption + // credentials and store it here for easy access when dialing this xDS + // server. + credsBundle credentials.Bundle // IgnoreResourceDeletion controls the behavior of the xDS client when the // server deletes a previously sent Listener or Cluster resource. If set, the @@ -167,9 +167,9 @@ type ServerConfig struct { IgnoreResourceDeletion bool } -// CredsDialOption returns the configured credentials as a grpc dial option. -func (sc *ServerConfig) CredsDialOption() grpc.DialOption { - return sc.credsDialOption +// CredsBundle returns the configured credentials bundle. +func (sc *ServerConfig) CredsBundle() credentials.Bundle { + return sc.credsBundle } // String returns the string representation of the ServerConfig. @@ -225,7 +225,7 @@ func (sc *ServerConfig) UnmarshalJSON(data []byte) error { return fmt.Errorf("failed to build credentials bundle from bootstrap for %q: %v", cc.Type, err) } sc.Creds = ChannelCreds(cc) - sc.credsDialOption = grpc.WithCredentialsBundle(bundle) + sc.credsBundle = bundle break } return nil @@ -538,7 +538,7 @@ func newConfigFromContents(data []byte) (*Config, error) { if config.XDSServer.ServerURI == "" { return nil, fmt.Errorf("xds: required field %q not found in bootstrap %s", "xds_servers.server_uri", jsonData["xds_servers"]) } - if config.XDSServer.CredsDialOption() == nil { + if config.XDSServer.CredsBundle() == nil { return nil, fmt.Errorf("xds: required field %q doesn't contain valid value in bootstrap %s", "xds_servers.channel_creds", jsonData["xds_servers"]) } // Post-process the authorities' client listener resource template field: diff --git a/xds/internal/xdsclient/clientimpl.go b/xds/internal/xdsclient/clientimpl.go index 2c05ea66f5f9..6fa870cea607 100644 --- a/xds/internal/xdsclient/clientimpl.go +++ b/xds/internal/xdsclient/clientimpl.go @@ -85,5 +85,10 @@ func (c *clientImpl) close() { c.authorityMu.Unlock() c.serializerClose() + if closableBundle, ok := c.config.XDSServer.CredsBundle().(interface { + Close() + }); ok { + closableBundle.Close() + } c.logger.Infof("Shutdown") } diff --git a/xds/internal/xdsclient/tlscreds/bundle.go b/xds/internal/xdsclient/tlscreds/bundle.go index c253cb69b5a3..1edd10e867b3 100644 --- a/xds/internal/xdsclient/tlscreds/bundle.go +++ b/xds/internal/xdsclient/tlscreds/bundle.go @@ -97,6 +97,15 @@ func (t *bundle) NewWithMode(string) (credentials.Bundle, error) { return nil, fmt.Errorf("xDS TLS credentials only support one mode") } +// Close releases the underlying provider. Note that credentials.Bundle are +// not closeable, so users of this type must use a type assertion to call Close. +func (t *bundle) Close() { + cred, ok := t.transportCredentials.(*reloadingCreds) + if ok { + cred.provider.Close() + } +} + // reloadingCreds is a credentials.TransportCredentials for client // side mTLS that reloads the server root CA certificate and the client // certificates from the provider on every client handshake. This is necessary diff --git a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go index 46eb3e802aac..7ed9462cdbc5 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go @@ -30,6 +30,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/internal/testutils/xds/e2e" testgrpc "google.golang.org/grpc/interop/grpc_testing" @@ -41,34 +42,76 @@ import ( const defaultTestTimeout = 5 * time.Second -func TestValidTlsBuilder(t *testing.T) { +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +func (s) TestValidTlsBuilder(t *testing.T) { + caCert := testdata.Path("x509/server_ca_cert.pem") + clientCert := testdata.Path("x509/client1_cert.pem") + clientKey := testdata.Path("x509/client1_key.pem") tests := []struct { name string jd string }{ - {name: "Absent configuration", jd: `null`}, - {name: "Empty configuration", jd: `{}`}, - {name: "Only CA certificate chain", jd: `{"ca_certificate_file": "foo"}`}, - {name: "Only private key and certificate chain", jd: `{"certificate_file":"bar","private_key_file":"baz"}`}, - {name: "CA chain, private key and certificate chain", jd: `{"ca_certificate_file":"foo","certificate_file":"bar","private_key_file":"baz"}`}, - {name: "Only refresh interval", jd: `{"refresh_interval": "1s"}`}, - {name: "Refresh interval and CA certificate chain", jd: `{"refresh_interval": "1s","ca_certificate_file": "foo"}`}, - {name: "Refresh interval, private key and certificate chain", jd: `{"refresh_interval": "1s","certificate_file":"bar","private_key_file":"baz"}`}, - {name: "Refresh interval, CA chain, private key and certificate chain", jd: `{"refresh_interval": "1s","ca_certificate_file":"foo","certificate_file":"bar","private_key_file":"baz"}`}, - {name: "Unknown field", jd: `{"unknown_field": "foo"}`}, + { + name: "Absent configuration", + jd: `null`, + }, + { + name: "Empty configuration", + jd: `{}`, + }, + { + name: "Only CA certificate chain", + jd: fmt.Sprintf(`{"ca_certificate_file": "%s"}`, caCert), + }, + { + name: "Only private key and certificate chain", + jd: fmt.Sprintf(`{"certificate_file":"%s","private_key_file":"%s"}`, clientCert, clientKey), + }, + { + name: "CA chain, private key and certificate chain", + jd: fmt.Sprintf(`{"ca_certificate_file":"%s","certificate_file":"%s","private_key_file":"%s"}`, caCert, clientCert, clientKey), + }, + { + name: "Only refresh interval", jd: `{"refresh_interval": "1s"}`, + }, + { + name: "Refresh interval and CA certificate chain", + jd: fmt.Sprintf(`{"refresh_interval": "1s","ca_certificate_file": "%s"}`, caCert), + }, + { + name: "Refresh interval, private key and certificate chain", + jd: fmt.Sprintf(`{"refresh_interval": "1s","certificate_file":"%s","private_key_file":"%s"}`, clientCert, clientKey), + }, + { + name: "Refresh interval, CA chain, private key and certificate chain", + jd: fmt.Sprintf(`{"refresh_interval": "1s","ca_certificate_file":"%s","certificate_file":"%s","private_key_file":"%s"}`, caCert, clientCert, clientKey), + }, + { + name: "Unknown field", + jd: `{"unknown_field": "foo"}`, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { msg := json.RawMessage(test.jd) - if _, err := tlscreds.NewBundle(msg); err != nil { + if bundle, err := tlscreds.NewBundle(msg); err != nil { t.Errorf("NewBundle(%s) returned error %s when expected to succeed", test.jd, err) + } else { + bundle.Close() } }) } } -func TestInvalidTlsBuilder(t *testing.T) { +func (s) TestInvalidTlsBuilder(t *testing.T) { tests := []struct { name, jd, wantErrPrefix string }{ @@ -78,7 +121,7 @@ func TestInvalidTlsBuilder(t *testing.T) { wantErrPrefix: "failed to unmarshal config:"}, { name: "Missing private key", - jd: `{"certificate_file":"bar"}`, + jd: fmt.Sprintf(`{"certificate_file":"%s"}`, testdata.Path("x509/server_cert.pem")), wantErrPrefix: "pemfile: private key file and identity cert file should be both specified or not specified", }, } @@ -86,14 +129,15 @@ func TestInvalidTlsBuilder(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { msg := json.RawMessage(test.jd) - if _, err := tlscreds.NewBundle(msg); err == nil || !strings.HasPrefix(err.Error(), test.wantErrPrefix) { + if bundle, err := tlscreds.NewBundle(msg); err == nil || !strings.HasPrefix(err.Error(), test.wantErrPrefix) { t.Errorf("NewBundle(%s): got error %s, want an error with prefix %s", msg, err, test.wantErrPrefix) + bundle.Close() } }) } } -func TestCaReloading(t *testing.T) { +func (s) TestCaReloading(t *testing.T) { serverCa, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem")) if err != nil { t.Fatalf("Failed to read test CA cert: %s", err) @@ -112,6 +156,7 @@ func TestCaReloading(t *testing.T) { if err != nil { t.Fatalf("Failed to create TLS bundle: %v", err) } + defer tlsBundle.Close() serverCredentials := grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.NoClientCert)) server := stubserver.StartTestService(t, nil, serverCredentials) @@ -172,7 +217,7 @@ func TestCaReloading(t *testing.T) { } } -func TestMTLS(t *testing.T) { +func (s) TestMTLS(t *testing.T) { s := stubserver.StartTestService(t, nil, grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert))) defer s.Stop() @@ -188,6 +233,7 @@ func TestMTLS(t *testing.T) { if err != nil { t.Fatalf("Failed to create TLS bundle: %v", err) } + defer tlsBundle.Close() conn, err := grpc.Dial(s.Address, grpc.WithCredentialsBundle(tlsBundle), grpc.WithAuthority("x.test.example.com")) if err != nil { t.Fatalf("Error dialing: %v", err) diff --git a/xds/internal/xdsclient/tlscreds/bundle_test.go b/xds/internal/xdsclient/tlscreds/bundle_test.go index 4e95823c00bb..67a54dcd0eda 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_test.go @@ -28,6 +28,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/tls/certprovider" + "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/internal/testutils/xds/e2e" testgrpc "google.golang.org/grpc/interop/grpc_testing" @@ -35,6 +36,14 @@ import ( "google.golang.org/grpc/testdata" ) +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + type failingProvider struct{} func (f failingProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) { @@ -43,7 +52,7 @@ func (f failingProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMate func (f failingProvider) Close() {} -func TestFailingProvider(t *testing.T) { +func (s) TestFailingProvider(t *testing.T) { s := stubserver.StartTestService(t, nil, grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert))) defer s.Stop() @@ -66,6 +75,7 @@ func TestFailingProvider(t *testing.T) { if !ok { t.Fatalf("Got %T, expected reloadingCreds", tlsBundle.TransportCredentials()) } + creds.provider.Close() creds.provider = &failingProvider{} conn, err := grpc.Dial(s.Address, grpc.WithCredentialsBundle(tlsBundle), grpc.WithAuthority("x.test.example.com")) diff --git a/xds/internal/xdsclient/transport/transport.go b/xds/internal/xdsclient/transport/transport.go index 001552d7b479..76016a3073ff 100644 --- a/xds/internal/xdsclient/transport/transport.go +++ b/xds/internal/xdsclient/transport/transport.go @@ -177,7 +177,7 @@ func New(opts Options) (*Transport, error) { switch { case opts.ServerCfg.ServerURI == "": return nil, errors.New("missing server URI when creating a new transport") - case opts.ServerCfg.CredsDialOption() == nil: + case opts.ServerCfg.CredsBundle() == nil: return nil, errors.New("missing credentials when creating a new transport") case opts.OnRecvHandler == nil: return nil, errors.New("missing OnRecv callback handler when creating a new transport") @@ -189,7 +189,7 @@ func New(opts Options) (*Transport, error) { // Dial the xDS management with the passed in credentials. dopts := []grpc.DialOption{ - opts.ServerCfg.CredsDialOption(), + grpc.WithCredentialsBundle(opts.ServerCfg.CredsBundle()), grpc.WithKeepaliveParams(keepalive.ClientParameters{ // We decided to use these sane defaults in all languages, and // kicked the can down the road as far making these configurable. From 1ea86030d7f0fc19810c9cbd5f5c53f57fe1d053 Mon Sep 17 00:00:00 2001 From: Antoine Tollenaere Date: Wed, 20 Dec 2023 09:27:25 +0100 Subject: [PATCH 14/17] fix tests --- test/xds/xds_client_certificate_providers_test.go | 3 ++- xds/internal/xdsclient/tlscreds/bundle_ext_test.go | 12 ++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/test/xds/xds_client_certificate_providers_test.go b/test/xds/xds_client_certificate_providers_test.go index a2979ca1beae..7741dc7581b9 100644 --- a/test/xds/xds_client_certificate_providers_test.go +++ b/test/xds/xds_client_certificate_providers_test.go @@ -20,6 +20,7 @@ package xds_test import ( "context" + "crypto/tls" "fmt" "strings" "testing" @@ -226,7 +227,7 @@ func (s) TestClientSideXDS_WithValidAndInvalidSecurityConfiguration(t *testing.T // backend1 configured with TLS creds, represents cluster1 // backend2 configured with insecure creds, represents cluster2 // backend3 configured with insecure creds, represents cluster3 - creds := e2e.CreateServerTLSCredentials(t) + creds := e2e.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert) server1 := stubserver.StartTestService(t, nil, grpc.Creds(creds)) defer server1.Stop() server2 := stubserver.StartTestService(t, nil) diff --git a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go index 7ed9462cdbc5..5ef986619da0 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go @@ -50,6 +50,10 @@ func Test(t *testing.T) { grpctest.RunSubTests(t, s{}) } +type Closable interface { + Close() +} + func (s) TestValidTlsBuilder(t *testing.T) { caCert := testdata.Path("x509/server_ca_cert.pem") clientCert := testdata.Path("x509/client1_cert.pem") @@ -105,7 +109,7 @@ func (s) TestValidTlsBuilder(t *testing.T) { if bundle, err := tlscreds.NewBundle(msg); err != nil { t.Errorf("NewBundle(%s) returned error %s when expected to succeed", test.jd, err) } else { - bundle.Close() + bundle.(Closable).Close() } }) } @@ -131,7 +135,7 @@ func (s) TestInvalidTlsBuilder(t *testing.T) { msg := json.RawMessage(test.jd) if bundle, err := tlscreds.NewBundle(msg); err == nil || !strings.HasPrefix(err.Error(), test.wantErrPrefix) { t.Errorf("NewBundle(%s): got error %s, want an error with prefix %s", msg, err, test.wantErrPrefix) - bundle.Close() + bundle.(Closable).Close() } }) } @@ -156,7 +160,7 @@ func (s) TestCaReloading(t *testing.T) { if err != nil { t.Fatalf("Failed to create TLS bundle: %v", err) } - defer tlsBundle.Close() + defer tlsBundle.(Closable).Close() serverCredentials := grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.NoClientCert)) server := stubserver.StartTestService(t, nil, serverCredentials) @@ -233,7 +237,7 @@ func (s) TestMTLS(t *testing.T) { if err != nil { t.Fatalf("Failed to create TLS bundle: %v", err) } - defer tlsBundle.Close() + defer tlsBundle.(Closable).Close() conn, err := grpc.Dial(s.Address, grpc.WithCredentialsBundle(tlsBundle), grpc.WithAuthority("x.test.example.com")) if err != nil { t.Fatalf("Error dialing: %v", err) From 8f3472e1bea91a94a5868fd6cdeab9999c2f643f Mon Sep 17 00:00:00 2001 From: Antoine Tollenaere Date: Thu, 21 Dec 2023 14:26:01 +0100 Subject: [PATCH 15/17] generic xds client cleanups --- xds/bootstrap/bootstrap.go | 6 ++-- xds/bootstrap/bootstrap_test.go | 6 ++-- xds/internal/xdsclient/bootstrap/bootstrap.go | 33 +++++++++++-------- .../xdsclient/bootstrap/bootstrap_test.go | 19 ++++++++--- xds/internal/xdsclient/clientimpl.go | 6 ++-- xds/internal/xdsclient/tlscreds/bundle.go | 10 +++--- .../xdsclient/tlscreds/bundle_ext_test.go | 16 ++++----- .../xdsclient/tlscreds/bundle_test.go | 4 +-- xds/internal/xdsclient/transport/transport.go | 4 +-- 9 files changed, 59 insertions(+), 45 deletions(-) diff --git a/xds/bootstrap/bootstrap.go b/xds/bootstrap/bootstrap.go index fcb99bdfd967..ef55ff0c02db 100644 --- a/xds/bootstrap/bootstrap.go +++ b/xds/bootstrap/bootstrap.go @@ -37,8 +37,10 @@ var registry = make(map[string]Credentials) // Credentials interface encapsulates a credentials.Bundle builder // that can be used for communicating with the xDS Management server. type Credentials interface { - // Build returns a credential bundle associated with this credential. - Build(config json.RawMessage) (credentials.Bundle, error) + // Build returns a credential bundle associated with this credential, and + // a function to cleans up additional resources associated with this bundle + // when it is no longer needed. + Build(config json.RawMessage) (credentials.Bundle, func(), error) // Name returns the credential name associated with this credential. Name() string } diff --git a/xds/bootstrap/bootstrap_test.go b/xds/bootstrap/bootstrap_test.go index 80ae31ccd2e3..1afc3ce7075a 100644 --- a/xds/bootstrap/bootstrap_test.go +++ b/xds/bootstrap/bootstrap_test.go @@ -36,9 +36,9 @@ type testCredsBuilder struct { config json.RawMessage } -func (t *testCredsBuilder) Build(config json.RawMessage) (credentials.Bundle, error) { +func (t *testCredsBuilder) Build(config json.RawMessage) (credentials.Bundle, func(), error) { t.config = config - return nil, nil + return nil, nil, nil } func (t *testCredsBuilder) Name() string { @@ -53,7 +53,7 @@ func TestRegisterNew(t *testing.T) { const sampleConfig = "sample_config" rawMessage := json.RawMessage(sampleConfig) - if _, err := c.Build(rawMessage); err != nil { + if _, _, err := c.Build(rawMessage); err != nil { t.Errorf("Build(%v) error = %v, want nil", rawMessage, err) } diff --git a/xds/internal/xdsclient/bootstrap/bootstrap.go b/xds/internal/xdsclient/bootstrap/bootstrap.go index c12268329382..31a7a69f93cf 100644 --- a/xds/internal/xdsclient/bootstrap/bootstrap.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap.go @@ -71,8 +71,8 @@ var bootstrapFileReadFunc = os.ReadFile // package `xds/bootstrap` and encapsulates an insecure credential. type insecureCredsBuilder struct{} -func (i *insecureCredsBuilder) Build(json.RawMessage) (credentials.Bundle, error) { - return insecure.NewBundle(), nil +func (i *insecureCredsBuilder) Build(json.RawMessage) (credentials.Bundle, func(), error) { + return insecure.NewBundle(), func() {}, nil } func (i *insecureCredsBuilder) Name() string { @@ -83,7 +83,7 @@ func (i *insecureCredsBuilder) Name() string { // package `xds/bootstrap` and encapsulates a TLS credential. type tlsCredsBuilder struct{} -func (t *tlsCredsBuilder) Build(config json.RawMessage) (credentials.Bundle, error) { +func (t *tlsCredsBuilder) Build(config json.RawMessage) (credentials.Bundle, func(), error) { return tlscreds.NewBundle(config) } @@ -95,8 +95,8 @@ func (t *tlsCredsBuilder) Name() string { // package `xds/boostrap` and encapsulates a Google Default credential. type googleDefaultCredsBuilder struct{} -func (d *googleDefaultCredsBuilder) Build(json.RawMessage) (credentials.Bundle, error) { - return google.NewDefaultCredentials(), nil +func (d *googleDefaultCredsBuilder) Build(json.RawMessage) (credentials.Bundle, func(), error) { + return google.NewDefaultCredentials(), func() {}, nil } func (d *googleDefaultCredsBuilder) Name() string { @@ -155,9 +155,9 @@ type ServerConfig struct { // As part of unmarshaling the JSON config into this struct, we ensure that // the credentials config is valid by building an instance of the specified - // credentials and store it here for easy access when dialing this xDS - // server. - credsBundle credentials.Bundle + // credentials and store it here as a grpc.DialOption for easy access when + // dialing this xDS server. + credsDialOption grpc.DialOption // IgnoreResourceDeletion controls the behavior of the xDS client when the // server deletes a previously sent Listener or Cluster resource. If set, the @@ -165,11 +165,15 @@ type ServerConfig struct { // when a resource is deleted, nor will it remove the existing resource value // from its cache. IgnoreResourceDeletion bool + + // Cleanups are called when the xDS client for this server is closed. Allows + // cleaning up resources created specifically for the xDS client. + Cleanups []func() } -// CredsBundle returns the configured credentials bundle. -func (sc *ServerConfig) CredsBundle() credentials.Bundle { - return sc.credsBundle +// CredsDialOption returns the configured credentials as a grpc dial option. +func (sc *ServerConfig) CredsDialOption() grpc.DialOption { + return sc.credsDialOption } // String returns the string representation of the ServerConfig. @@ -220,12 +224,13 @@ func (sc *ServerConfig) UnmarshalJSON(data []byte) error { if c == nil { continue } - bundle, err := c.Build(cc.Config) + bundle, cancel, err := c.Build(cc.Config) if err != nil { return fmt.Errorf("failed to build credentials bundle from bootstrap for %q: %v", cc.Type, err) } sc.Creds = ChannelCreds(cc) - sc.credsBundle = bundle + sc.credsDialOption = grpc.WithCredentialsBundle(bundle) + sc.Cleanups = append(sc.Cleanups, cancel) break } return nil @@ -538,7 +543,7 @@ func newConfigFromContents(data []byte) (*Config, error) { if config.XDSServer.ServerURI == "" { return nil, fmt.Errorf("xds: required field %q not found in bootstrap %s", "xds_servers.server_uri", jsonData["xds_servers"]) } - if config.XDSServer.CredsBundle() == nil { + if config.XDSServer.CredsDialOption() == nil { return nil, fmt.Errorf("xds: required field %q doesn't contain valid value in bootstrap %s", "xds_servers.channel_creds", jsonData["xds_servers"]) } // Post-process the authorities' client listener resource template field: diff --git a/xds/internal/xdsclient/bootstrap/bootstrap_test.go b/xds/internal/xdsclient/bootstrap/bootstrap_test.go index 52133292256b..ac822ae219bb 100644 --- a/xds/internal/xdsclient/bootstrap/bootstrap_test.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap_test.go @@ -1023,16 +1023,20 @@ func TestDefaultBundles(t *testing.T) { func TestCredsBuilders(t *testing.T) { b := &googleDefaultCredsBuilder{} - if _, err := b.Build(nil); err != nil { + if _, stop, err := b.Build(nil); err != nil { t.Errorf("googleDefaultCredsBuilder.Build failed: %v", err) + } else { + stop() } if got, want := b.Name(), "google_default"; got != want { t.Errorf("googleDefaultCredsBuilder.Name = %v, want %v", got, want) } i := &insecureCredsBuilder{} - if _, err := i.Build(nil); err != nil { + if _, stop, err := i.Build(nil); err != nil { t.Errorf("insecureCredsBuilder.Build failed: %v", err) + } else { + stop() } if got, want := i.Name(), "insecure"; got != want { @@ -1040,8 +1044,10 @@ func TestCredsBuilders(t *testing.T) { } tcb := &tlsCredsBuilder{} - if _, err := tcb.Build(nil); err != nil { + if _, stop, err := tcb.Build(nil); err != nil { t.Errorf("tlsCredsBuilder.Build failed: %v", err) + } else { + stop() } if got, want := tcb.Name(), "tls"; got != want { t.Errorf("tlsCredsBuilder.Name = %v, want %v", got, want) @@ -1050,11 +1056,14 @@ func TestCredsBuilders(t *testing.T) { func TestTlsCredsBuilder(t *testing.T) { tls := &tlsCredsBuilder{} - if _, err := tls.Build(json.RawMessage(`{}`)); err != nil { + if _, stop, err := tls.Build(json.RawMessage(`{}`)); err != nil { t.Errorf("tls.Build() failed with error %s when expected to succeed", err) + } else { + stop() } - if _, err := tls.Build(json.RawMessage(`{"ca_certificate_file":"/ca_certificates.pem","refresh_interval": "asdf"}`)); err == nil { + if _, stop, err := tls.Build(json.RawMessage(`{"ca_certificate_file":"/ca_certificates.pem","refresh_interval": "asdf"}`)); err == nil { t.Errorf("tls.Build() succeeded with an invalid refresh interval, when expected to fail") + stop() } // more tests for config validity are defined in tlscreds subpackage. } diff --git a/xds/internal/xdsclient/clientimpl.go b/xds/internal/xdsclient/clientimpl.go index 6fa870cea607..3dc5ceac7d16 100644 --- a/xds/internal/xdsclient/clientimpl.go +++ b/xds/internal/xdsclient/clientimpl.go @@ -85,10 +85,8 @@ func (c *clientImpl) close() { c.authorityMu.Unlock() c.serializerClose() - if closableBundle, ok := c.config.XDSServer.CredsBundle().(interface { - Close() - }); ok { - closableBundle.Close() + for _, f := range c.config.XDSServer.Cleanups { + f() } c.logger.Infof("Shutdown") } diff --git a/xds/internal/xdsclient/tlscreds/bundle.go b/xds/internal/xdsclient/tlscreds/bundle.go index 1edd10e867b3..c4e977c9d76d 100644 --- a/xds/internal/xdsclient/tlscreds/bundle.go +++ b/xds/internal/xdsclient/tlscreds/bundle.go @@ -43,7 +43,7 @@ type bundle struct { // Bootstrap File. It delegates certificate loading to a file_watcher provider // if either client certificates or server root CA is specified. // See gRFC A65: github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md -func NewBundle(jd json.RawMessage) (credentials.Bundle, error) { +func NewBundle(jd json.RawMessage) (credentials.Bundle, func(), error) { cfg := &struct { CertificateFile string `json:"certificate_file"` CACertificateFile string `json:"ca_certificate_file"` @@ -52,7 +52,7 @@ func NewBundle(jd json.RawMessage) (credentials.Bundle, error) { if jd != nil { if err := json.Unmarshal(jd, cfg); err != nil { - return nil, fmt.Errorf("failed to unmarshal config: %v", err) + return nil, nil, fmt.Errorf("failed to unmarshal config: %v", err) } } // Else the config field is absent. Treat it as an empty config. @@ -66,7 +66,7 @@ func NewBundle(jd json.RawMessage) (credentials.Bundle, error) { // > provider, at least one of the "certificate_file" or // > "ca_certificate_file" fields must be specified, whereas in this // > configuration, it is acceptable to specify neither one. - return &bundle{transportCredentials: credentials.NewTLS(&tls.Config{})}, nil + return &bundle{transportCredentials: credentials.NewTLS(&tls.Config{})}, func() {}, nil } // Otherwise we need to use a file_watcher provider to watch the CA, // private and public keys. @@ -74,11 +74,11 @@ func NewBundle(jd json.RawMessage) (credentials.Bundle, error) { // The pemfile plugin (file_watcher) currently ignores BuildOptions. provider, err := certprovider.GetProvider(pemfile.PluginName, jd, certprovider.BuildOptions{}) if err != nil { - return nil, err + return nil, nil, err } return &bundle{ transportCredentials: &reloadingCreds{provider: provider}, - }, nil + }, func() { provider.Close() }, nil } func (t *bundle) TransportCredentials() credentials.TransportCredentials { diff --git a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go index 5ef986619da0..432dd70cba9c 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go @@ -106,10 +106,10 @@ func (s) TestValidTlsBuilder(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { msg := json.RawMessage(test.jd) - if bundle, err := tlscreds.NewBundle(msg); err != nil { + if _, stop, err := tlscreds.NewBundle(msg); err != nil { t.Errorf("NewBundle(%s) returned error %s when expected to succeed", test.jd, err) } else { - bundle.(Closable).Close() + stop() } }) } @@ -133,9 +133,9 @@ func (s) TestInvalidTlsBuilder(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { msg := json.RawMessage(test.jd) - if bundle, err := tlscreds.NewBundle(msg); err == nil || !strings.HasPrefix(err.Error(), test.wantErrPrefix) { + if _, stop, err := tlscreds.NewBundle(msg); err == nil || !strings.HasPrefix(err.Error(), test.wantErrPrefix) { t.Errorf("NewBundle(%s): got error %s, want an error with prefix %s", msg, err, test.wantErrPrefix) - bundle.(Closable).Close() + stop() } }) } @@ -156,11 +156,11 @@ func (s) TestCaReloading(t *testing.T) { "ca_certificate_file": "%s", "refresh_interval": ".01s" }`, caPath) - tlsBundle, err := tlscreds.NewBundle([]byte(cfg)) + tlsBundle, stop, err := tlscreds.NewBundle([]byte(cfg)) if err != nil { t.Fatalf("Failed to create TLS bundle: %v", err) } - defer tlsBundle.(Closable).Close() + defer stop() serverCredentials := grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.NoClientCert)) server := stubserver.StartTestService(t, nil, serverCredentials) @@ -233,11 +233,11 @@ func (s) TestMTLS(t *testing.T) { testdata.Path("x509/server_ca_cert.pem"), testdata.Path("x509/client1_cert.pem"), testdata.Path("x509/client1_key.pem")) - tlsBundle, err := tlscreds.NewBundle([]byte(cfg)) + tlsBundle, stop, err := tlscreds.NewBundle([]byte(cfg)) if err != nil { t.Fatalf("Failed to create TLS bundle: %v", err) } - defer tlsBundle.(Closable).Close() + defer stop() conn, err := grpc.Dial(s.Address, grpc.WithCredentialsBundle(tlsBundle), grpc.WithAuthority("x.test.example.com")) if err != nil { t.Fatalf("Error dialing: %v", err) diff --git a/xds/internal/xdsclient/tlscreds/bundle_test.go b/xds/internal/xdsclient/tlscreds/bundle_test.go index 67a54dcd0eda..8bc3f55b4c13 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_test.go @@ -64,10 +64,11 @@ func (s) TestFailingProvider(t *testing.T) { testdata.Path("x509/server_ca_cert.pem"), testdata.Path("x509/client1_cert.pem"), testdata.Path("x509/client1_key.pem")) - tlsBundle, err := NewBundle([]byte(cfg)) + tlsBundle, stop, err := NewBundle([]byte(cfg)) if err != nil { t.Fatalf("Failed to create TLS bundle: %v", err) } + stop() // Force a provider that returns an error, and make sure the client fails // the handshake. @@ -75,7 +76,6 @@ func (s) TestFailingProvider(t *testing.T) { if !ok { t.Fatalf("Got %T, expected reloadingCreds", tlsBundle.TransportCredentials()) } - creds.provider.Close() creds.provider = &failingProvider{} conn, err := grpc.Dial(s.Address, grpc.WithCredentialsBundle(tlsBundle), grpc.WithAuthority("x.test.example.com")) diff --git a/xds/internal/xdsclient/transport/transport.go b/xds/internal/xdsclient/transport/transport.go index 76016a3073ff..001552d7b479 100644 --- a/xds/internal/xdsclient/transport/transport.go +++ b/xds/internal/xdsclient/transport/transport.go @@ -177,7 +177,7 @@ func New(opts Options) (*Transport, error) { switch { case opts.ServerCfg.ServerURI == "": return nil, errors.New("missing server URI when creating a new transport") - case opts.ServerCfg.CredsBundle() == nil: + case opts.ServerCfg.CredsDialOption() == nil: return nil, errors.New("missing credentials when creating a new transport") case opts.OnRecvHandler == nil: return nil, errors.New("missing OnRecv callback handler when creating a new transport") @@ -189,7 +189,7 @@ func New(opts Options) (*Transport, error) { // Dial the xDS management with the passed in credentials. dopts := []grpc.DialOption{ - grpc.WithCredentialsBundle(opts.ServerCfg.CredsBundle()), + opts.ServerCfg.CredsDialOption(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ // We decided to use these sane defaults in all languages, and // kicked the can down the road as far making these configurable. From e694f4cde8a70babdfb9dfc5f3bc98df0b251908 Mon Sep 17 00:00:00 2001 From: Antoine Tollenaere Date: Thu, 21 Dec 2023 22:35:14 +0100 Subject: [PATCH 16/17] feedback from easwar: - close credentials provider for each authority - Fatal tests where possible. --- xds/internal/xdsclient/authority.go | 4 ++++ xds/internal/xdsclient/bootstrap/bootstrap_test.go | 8 ++++---- xds/internal/xdsclient/clientimpl.go | 3 --- xds/internal/xdsclient/tlscreds/bundle_ext_test.go | 4 +++- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/xds/internal/xdsclient/authority.go b/xds/internal/xdsclient/authority.go index 6ad61dae4ae4..ba0b080c92bc 100644 --- a/xds/internal/xdsclient/authority.go +++ b/xds/internal/xdsclient/authority.go @@ -448,6 +448,10 @@ func (a *authority) close() { a.resourcesMu.Lock() a.closed = true a.resourcesMu.Unlock() + + for _, cleanup := range a.serverCfg.Cleanups { + cleanup() + } } func (a *authority) watchResource(rType xdsresource.Type, resourceName string, watcher xdsresource.ResourceWatcher) func() { diff --git a/xds/internal/xdsclient/bootstrap/bootstrap_test.go b/xds/internal/xdsclient/bootstrap/bootstrap_test.go index ac822ae219bb..7975c66667b9 100644 --- a/xds/internal/xdsclient/bootstrap/bootstrap_test.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap_test.go @@ -1056,11 +1056,11 @@ func TestCredsBuilders(t *testing.T) { func TestTlsCredsBuilder(t *testing.T) { tls := &tlsCredsBuilder{} - if _, stop, err := tls.Build(json.RawMessage(`{}`)); err != nil { - t.Errorf("tls.Build() failed with error %s when expected to succeed", err) - } else { - stop() + _, stop, err := tls.Build(json.RawMessage(`{}`)) + if err != nil { + t.Fatalf("tls.Build() failed with error %s when expected to succeed", err) } + stop() if _, stop, err := tls.Build(json.RawMessage(`{"ca_certificate_file":"/ca_certificates.pem","refresh_interval": "asdf"}`)); err == nil { t.Errorf("tls.Build() succeeded with an invalid refresh interval, when expected to fail") stop() diff --git a/xds/internal/xdsclient/clientimpl.go b/xds/internal/xdsclient/clientimpl.go index 3dc5ceac7d16..2c05ea66f5f9 100644 --- a/xds/internal/xdsclient/clientimpl.go +++ b/xds/internal/xdsclient/clientimpl.go @@ -85,8 +85,5 @@ func (c *clientImpl) close() { c.authorityMu.Unlock() c.serializerClose() - for _, f := range c.config.XDSServer.Cleanups { - f() - } c.logger.Infof("Shutdown") } diff --git a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go index 432dd70cba9c..02eedf78dee3 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go @@ -135,7 +135,9 @@ func (s) TestInvalidTlsBuilder(t *testing.T) { msg := json.RawMessage(test.jd) if _, stop, err := tlscreds.NewBundle(msg); err == nil || !strings.HasPrefix(err.Error(), test.wantErrPrefix) { t.Errorf("NewBundle(%s): got error %s, want an error with prefix %s", msg, err, test.wantErrPrefix) - stop() + if err == nil { + stop() + } } }) } From 01f8c827dc99891412cb8a82438f4e6ab9a69e27 Mon Sep 17 00:00:00 2001 From: Antoine Tollenaere Date: Fri, 22 Dec 2023 12:04:33 +0100 Subject: [PATCH 17/17] latest review from easwar - cleanup in clientimpl close rather than individual authority close - convert tests in bootstrap_test to table driven tests --- xds/internal/xdsclient/authority.go | 4 -- xds/internal/xdsclient/bootstrap/bootstrap.go | 2 +- .../xdsclient/bootstrap/bootstrap_test.go | 63 ++++++++----------- xds/internal/xdsclient/clientimpl.go | 12 ++++ xds/internal/xdsclient/tlscreds/bundle.go | 16 ++--- .../xdsclient/tlscreds/bundle_ext_test.go | 15 ++--- .../xdsclient/tlscreds/bundle_test.go | 2 +- 7 files changed, 54 insertions(+), 60 deletions(-) diff --git a/xds/internal/xdsclient/authority.go b/xds/internal/xdsclient/authority.go index ba0b080c92bc..6ad61dae4ae4 100644 --- a/xds/internal/xdsclient/authority.go +++ b/xds/internal/xdsclient/authority.go @@ -448,10 +448,6 @@ func (a *authority) close() { a.resourcesMu.Lock() a.closed = true a.resourcesMu.Unlock() - - for _, cleanup := range a.serverCfg.Cleanups { - cleanup() - } } func (a *authority) watchResource(rType xdsresource.Type, resourceName string, watcher xdsresource.ResourceWatcher) func() { diff --git a/xds/internal/xdsclient/bootstrap/bootstrap.go b/xds/internal/xdsclient/bootstrap/bootstrap.go index 31a7a69f93cf..0736a06d73a4 100644 --- a/xds/internal/xdsclient/bootstrap/bootstrap.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap.go @@ -167,7 +167,7 @@ type ServerConfig struct { IgnoreResourceDeletion bool // Cleanups are called when the xDS client for this server is closed. Allows - // cleaning up resources created specifically for the xDS client. + // cleaning up resources created specifically for this ServerConfig. Cleanups []func() } diff --git a/xds/internal/xdsclient/bootstrap/bootstrap_test.go b/xds/internal/xdsclient/bootstrap/bootstrap_test.go index 7975c66667b9..a1138e2363d5 100644 --- a/xds/internal/xdsclient/bootstrap/bootstrap_test.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap_test.go @@ -1008,49 +1008,39 @@ func TestServerConfigMarshalAndUnmarshal(t *testing.T) { } func TestDefaultBundles(t *testing.T) { - if c := bootstrap.GetCredentials("google_default"); c == nil { - t.Errorf(`bootstrap.GetCredentials("google_default") credential is nil, want non-nil`) - } - - if c := bootstrap.GetCredentials("insecure"); c == nil { - t.Errorf(`bootstrap.GetCredentials("insecure") credential is nil, want non-nil`) - } + tests := []string{"google_default", "insecure", "tls"} - if c := bootstrap.GetCredentials("tls"); c == nil { - t.Errorf(`bootstrap.GetCredentials("tls") credential is nil, want non-nil`) + for _, typename := range tests { + t.Run(typename, func(t *testing.T) { + if c := bootstrap.GetCredentials(typename); c == nil { + t.Errorf(`bootstrap.GetCredentials(%s) credential is nil, want non-nil`, typename) + } + }) } } func TestCredsBuilders(t *testing.T) { - b := &googleDefaultCredsBuilder{} - if _, stop, err := b.Build(nil); err != nil { - t.Errorf("googleDefaultCredsBuilder.Build failed: %v", err) - } else { - stop() - } - if got, want := b.Name(), "google_default"; got != want { - t.Errorf("googleDefaultCredsBuilder.Name = %v, want %v", got, want) - } - - i := &insecureCredsBuilder{} - if _, stop, err := i.Build(nil); err != nil { - t.Errorf("insecureCredsBuilder.Build failed: %v", err) - } else { - stop() + tests := []struct { + typename string + builder bootstrap.Credentials + }{ + {"google_default", &googleDefaultCredsBuilder{}}, + {"insecure", &insecureCredsBuilder{}}, + {"tls", &tlsCredsBuilder{}}, } - if got, want := i.Name(), "insecure"; got != want { - t.Errorf("insecureCredsBuilder.Name = %v, want %v", got, want) - } + for _, test := range tests { + t.Run(test.typename, func(t *testing.T) { + if got, want := test.builder.Name(), test.typename; got != want { + t.Errorf("%T.Name = %v, want %v", test.builder, got, want) + } - tcb := &tlsCredsBuilder{} - if _, stop, err := tcb.Build(nil); err != nil { - t.Errorf("tlsCredsBuilder.Build failed: %v", err) - } else { - stop() - } - if got, want := tcb.Name(), "tls"; got != want { - t.Errorf("tlsCredsBuilder.Name = %v, want %v", got, want) + _, stop, err := test.builder.Build(nil) + if err != nil { + t.Fatalf("%T.Build failed: %v", test.builder, err) + } + stop() + }) } } @@ -1061,9 +1051,10 @@ func TestTlsCredsBuilder(t *testing.T) { t.Fatalf("tls.Build() failed with error %s when expected to succeed", err) } stop() + if _, stop, err := tls.Build(json.RawMessage(`{"ca_certificate_file":"/ca_certificates.pem","refresh_interval": "asdf"}`)); err == nil { t.Errorf("tls.Build() succeeded with an invalid refresh interval, when expected to fail") stop() } - // more tests for config validity are defined in tlscreds subpackage. + // package internal/xdsclient/tlscreds has tests for config validity. } diff --git a/xds/internal/xdsclient/clientimpl.go b/xds/internal/xdsclient/clientimpl.go index 2c05ea66f5f9..1088b60301cb 100644 --- a/xds/internal/xdsclient/clientimpl.go +++ b/xds/internal/xdsclient/clientimpl.go @@ -85,5 +85,17 @@ func (c *clientImpl) close() { c.authorityMu.Unlock() c.serializerClose() + for _, f := range c.config.XDSServer.Cleanups { + f() + } + for _, a := range c.config.Authorities { + if a.XDSServer == nil { + // The server for this authority is the top-level one, cleaned up above. + continue + } + for _, f := range a.XDSServer.Cleanups { + f() + } + } c.logger.Infof("Shutdown") } diff --git a/xds/internal/xdsclient/tlscreds/bundle.go b/xds/internal/xdsclient/tlscreds/bundle.go index c4e977c9d76d..02da3dbf3496 100644 --- a/xds/internal/xdsclient/tlscreds/bundle.go +++ b/xds/internal/xdsclient/tlscreds/bundle.go @@ -31,6 +31,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/credentials/tls/certprovider/pemfile" + "google.golang.org/grpc/internal/grpcsync" ) // bundle is an implementation of credentials.Bundle which implements mTLS @@ -41,7 +42,9 @@ type bundle struct { // NewBundle returns a credentials.Bundle which implements mTLS Credentials in xDS // Bootstrap File. It delegates certificate loading to a file_watcher provider -// if either client certificates or server root CA is specified. +// if either client certificates or server root CA is specified. The second +// return value is a close func that should be called when the caller no longer +// needs this bundle. // See gRFC A65: github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md func NewBundle(jd json.RawMessage) (credentials.Bundle, func(), error) { cfg := &struct { @@ -78,7 +81,7 @@ func NewBundle(jd json.RawMessage) (credentials.Bundle, func(), error) { } return &bundle{ transportCredentials: &reloadingCreds{provider: provider}, - }, func() { provider.Close() }, nil + }, grpcsync.OnceFunc(func() { provider.Close() }), nil } func (t *bundle) TransportCredentials() credentials.TransportCredentials { @@ -97,15 +100,6 @@ func (t *bundle) NewWithMode(string) (credentials.Bundle, error) { return nil, fmt.Errorf("xDS TLS credentials only support one mode") } -// Close releases the underlying provider. Note that credentials.Bundle are -// not closeable, so users of this type must use a type assertion to call Close. -func (t *bundle) Close() { - cred, ok := t.transportCredentials.(*reloadingCreds) - if ok { - cred.provider.Close() - } -} - // reloadingCreds is a credentials.TransportCredentials for client // side mTLS that reloads the server root CA certificate and the client // certificates from the provider on every client handshake. This is necessary diff --git a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go index 02eedf78dee3..bda7319d83ce 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go @@ -106,11 +106,11 @@ func (s) TestValidTlsBuilder(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { msg := json.RawMessage(test.jd) - if _, stop, err := tlscreds.NewBundle(msg); err != nil { - t.Errorf("NewBundle(%s) returned error %s when expected to succeed", test.jd, err) - } else { - stop() + _, stop, err := tlscreds.NewBundle(msg) + if err != nil { + t.Fatalf("NewBundle(%s) returned error %s when expected to succeed", test.jd, err) } + stop() }) } } @@ -133,11 +133,12 @@ func (s) TestInvalidTlsBuilder(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { msg := json.RawMessage(test.jd) - if _, stop, err := tlscreds.NewBundle(msg); err == nil || !strings.HasPrefix(err.Error(), test.wantErrPrefix) { - t.Errorf("NewBundle(%s): got error %s, want an error with prefix %s", msg, err, test.wantErrPrefix) - if err == nil { + _, stop, err := tlscreds.NewBundle(msg) + if err == nil || !strings.HasPrefix(err.Error(), test.wantErrPrefix) { + if stop != nil { stop() } + t.Fatalf("NewBundle(%s): got error %s, want an error with prefix %s", msg, err, test.wantErrPrefix) } }) } diff --git a/xds/internal/xdsclient/tlscreds/bundle_test.go b/xds/internal/xdsclient/tlscreds/bundle_test.go index 8bc3f55b4c13..ad50508aeb94 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_test.go @@ -46,7 +46,7 @@ func Test(t *testing.T) { type failingProvider struct{} -func (f failingProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) { +func (f failingProvider) KeyMaterial(context.Context) (*certprovider.KeyMaterial, error) { return nil, errors.New("test error") }