diff --git a/testing/proxytest/https.go b/testing/proxytest/https.go new file mode 100644 index 0000000..d8fd3b7 --- /dev/null +++ b/testing/proxytest/https.go @@ -0,0 +1,216 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package proxytest + +import ( + "bufio" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "errors" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "net/url" + "strings" + + "github.com/elastic/elastic-agent-libs/testing/certutil" +) + +func (p *Proxy) serveHTTPS(w http.ResponseWriter, r *http.Request) { + log := loggerFromReqCtx(r) + + clientCon, err := hijack(w) + if err != nil { + p.http500Error(clientCon, "cannot handle request", err, log) + return + } + defer clientCon.Close() + + // Hijack successful, w is now useless, let's make sure it isn't used by + // mistake ;) + w = nil //nolint:ineffassign,wastedassign // w is now useless, let's make sure it isn't used by mistake ;) + + // ==================== CONNECT accepted, let the client know + _, err = clientCon.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")) + if err != nil { + p.http500Error(clientCon, "failed to send 200-OK after CONNECT", err, log) + return + } + + // ==================== TLS handshake + // client will proceed to perform the TLS handshake with the "target", + // which we're impersonating. + + // generate a TLS certificate matching the target's host + cert, err := p.newTLSCert(r.URL) + if err != nil { + p.http500Error(clientCon, "failed generating certificate", err, log) + return + } + + tlscfg := p.TLS.Clone() + tlscfg.Certificates = []tls.Certificate{*cert} + clientTLSConn := tls.Server(clientCon, tlscfg) + defer clientTLSConn.Close() + err = clientTLSConn.Handshake() + if err != nil { + p.http500Error(clientCon, "failed TLS handshake with client", err, log) + return + } + + clientTLSReader := bufio.NewReader(clientTLSConn) + + notEOF := func(r *bufio.Reader) bool { + _, err = r.Peek(1) + return !errors.Is(err, io.EOF) + } + // ==================== Handle the actual request + for notEOF(clientTLSReader) { + // read request from the client sent after the 1s CONNECT request + req, err := http.ReadRequest(clientTLSReader) + if err != nil { + p.http500Error(clientTLSConn, "failed reading client request", err, log) + return + } + + // carry over the original remote addr + req.RemoteAddr = r.RemoteAddr + + // the read request is relative to the host from the original CONNECT + // request and without scheme. Therefore, set them in the new request. + req.URL, err = url.Parse("https://" + r.Host + req.URL.String()) + if err != nil { + p.http500Error(clientTLSConn, "failed reading request URL from client", err, log) + return + } + cleanUpHeaders(req.Header) + + // now the request is ready, it can be altered and sent just as it's + // done for an HTTP request. + resp, err := p.processRequest(req) + if err != nil { + p.httpError(clientTLSConn, + http.StatusBadGateway, + "failed performing request to target", err, log) + return + } + + clientResp := http.Response{ + ProtoMajor: 1, + ProtoMinor: 1, + StatusCode: resp.StatusCode, + TransferEncoding: append([]string{}, resp.TransferEncoding...), + Trailer: resp.Trailer.Clone(), + Body: resp.Body, + ContentLength: resp.ContentLength, + Header: resp.Header.Clone(), + } + + err = clientResp.Write(clientTLSConn) + if err != nil { + p.http500Error(clientTLSConn, "failed writing response body", err, log) + return + } + + _ = resp.Body.Close() + } +} + +func (p *Proxy) newTLSCert(u *url.URL) (*tls.Certificate, error) { + // generate the certificate key - it needs to be RSA because Elastic Defend + // do not support EC :/ + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, fmt.Errorf("could not create RSA private key: %w", err) + } + host := u.Hostname() + + var name string + var ips []net.IP + ip := net.ParseIP(host) + if ip == nil { // host isn't an IP, therefore it must be an DNS + name = host + } else { + ips = append(ips, ip) + } + + cert, _, err := certutil.GenerateGenericChildCert( + name, + ips, + priv, + &priv.PublicKey, + p.ca.capriv, + p.ca.cacert) + if err != nil { + return nil, fmt.Errorf("could not generate TLS certificate for %s: %w", + host, err) + } + + return cert, nil +} + +func (p *Proxy) http500Error(clientCon net.Conn, msg string, err error, log *slog.Logger) { + p.httpError(clientCon, http.StatusInternalServerError, msg, err, log) +} + +func (p *Proxy) httpError(clientCon net.Conn, status int, msg string, err error, log *slog.Logger) { + log.Error(msg, "err", err) + + resp := http.Response{ + StatusCode: status, + ProtoMajor: 1, + ProtoMinor: 1, + Body: io.NopCloser(strings.NewReader(msg)), + Header: http.Header{}, + } + resp.Header.Set("Content-Type", "text/html; charset=utf-8") + + err = resp.Write(clientCon) + if err != nil { + log.Error("failed writing response", "err", err) + } +} + +func hijack(w http.ResponseWriter) (net.Conn, error) { + hijacker, ok := w.(http.Hijacker) + if !ok { + w.WriteHeader(http.StatusInternalServerError) + _, _ = fmt.Fprint(w, "cannot handle request") + return nil, errors.New("http.ResponseWriter does not support hijacking") + } + + clientCon, _, err := hijacker.Hijack() + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + _, err = fmt.Fprint(w, "cannot handle request") + + return nil, fmt.Errorf("could not Hijack HTTPS CONNECT request: %w", err) + } + + return clientCon, err +} + +func cleanUpHeaders(h http.Header) { + h.Del("Proxy-Connection") + h.Del("Proxy-Authenticate") + h.Del("Proxy-Authorization") + h.Del("Connection") +} diff --git a/testing/proxytest/proxytest.go b/testing/proxytest/proxytest.go new file mode 100644 index 0000000..381a645 --- /dev/null +++ b/testing/proxytest/proxytest.go @@ -0,0 +1,406 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package proxytest + +import ( + "bufio" + "context" + "crypto" + "crypto/tls" + "crypto/x509" + "fmt" + "io" + "log" + "log/slog" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "testing" + + "github.com/gofrs/uuid/v5" +) + +type Proxy struct { + *httptest.Server + + // Port is the port Server is listening on. + Port string + + // LocalhostURL is the server URL as "http(s)://localhost:PORT". + // Deprecated. Use Proxy.URL instead. + LocalhostURL string + + // proxiedRequests is a "request log" for every request the proxy receives. + proxiedRequests []string + proxiedRequestsMu sync.Mutex + requestsWG *sync.WaitGroup + + opts options + log *slog.Logger + + ca ca + client *http.Client +} + +type Option func(o *options) + +type options struct { + addr string + rewriteHost func(string) string + rewriteURL func(u *url.URL) + // logFn if set will be used to log every request. + logFn func(format string, a ...any) + verbose bool + serverTLSConfig *tls.Config + capriv crypto.PrivateKey + cacert *x509.Certificate + client *http.Client +} + +type ca struct { + capriv crypto.PrivateKey + cacert *x509.Certificate +} + +// WithAddress will set the address the server will listen on. The format is as +// defined by net.Listen for a tcp connection. +func WithAddress(addr string) Option { + return func(o *options) { + o.addr = addr + } +} + +// WithHTTPClient sets http.Client used to proxy requests to the target host. +func WithHTTPClient(c *http.Client) Option { + return func(o *options) { + o.client = c + } +} + +// WithMITMCA sets the CA used for MITM (men in the middle) when proxying HTTPS +// requests. It's used to generate TLS certificates matching the target host. +// Ideally the CA is the same as the one issuing the TLS certificate for the +// proxy set by WithServerTLSConfig. +func WithMITMCA(priv crypto.PrivateKey, cert *x509.Certificate) func(o *options) { + return func(o *options) { + o.capriv = priv + o.cacert = cert + } +} + +// WithRequestLog sets the proxy to log every request using logFn. It uses name +// as a prefix to the log. +func WithRequestLog(name string, logFn func(format string, a ...any)) Option { + return func(o *options) { + o.logFn = func(format string, a ...any) { + logFn("[proxy-"+name+"] "+format, a...) + } + } +} + +// WithRewrite will replace old by new on the request URL host when forwarding it. +func WithRewrite(old, new string) Option { + return func(o *options) { + o.rewriteHost = func(s string) string { + return strings.Replace(s, old, new, 1) + } + } +} + +// WithRewriteFn calls f on the request *url.URL before forwarding it. +// It takes precedence over WithRewrite. Use if more control over the rewrite +// is needed. +func WithRewriteFn(f func(u *url.URL)) Option { + return func(o *options) { + o.rewriteURL = f + } +} + +// WithServerTLSConfig sets the TLS config for the server. +func WithServerTLSConfig(tc *tls.Config) Option { + return func(o *options) { + o.serverTLSConfig = tc + } +} + +// WithVerboseLog sets the proxy to log every request verbosely and enables +// debug level logging. WithRequestLog must be used as well, otherwise +// WithVerboseLog will not take effect. +func WithVerboseLog() Option { + return func(o *options) { + o.verbose = true + } +} + +// New returns a new Proxy ready for use. Use: +// - WithAddress to set the proxy's address, +// - WithRewrite or WithRewriteFn to rewrite the URL before forwarding the request. +// +// Check the other With* functions for more options. +func New(t *testing.T, optns ...Option) *Proxy { + t.Helper() + + opts := options{addr: "127.0.0.1:0", client: &http.Client{}} + for _, o := range optns { + o(&opts) + } + + if opts.logFn == nil { + opts.logFn = func(format string, a ...any) {} + } + + l, err := net.Listen("tcp", opts.addr) //nolint:gosec,nolintlint // it's a test + if err != nil { + t.Fatalf("NewServer failed to create a net.Listener: %v", err) + } + + // Create a text handler that writes to standard output + lv := slog.LevelInfo + if opts.verbose { + lv = slog.LevelDebug + } + p := Proxy{ + requestsWG: &sync.WaitGroup{}, + opts: opts, + client: opts.client, + log: slog.New(slog.NewTextHandler(logfWriter(opts.logFn), &slog.HandlerOptions{ + Level: lv, + })), + } + if opts.capriv != nil && opts.cacert != nil { + p.ca = ca{capriv: opts.capriv, cacert: opts.cacert} + } + + p.Server = httptest.NewUnstartedServer( + http.HandlerFunc(func(ww http.ResponseWriter, r *http.Request) { + // Sometimes, on CI obviously, the last log happens after the test + // finishes. See https://github.com/elastic/elastic-agent/issues/5869. + // Therefore, let's add an extra layer to try to avoid that. + p.requestsWG.Add(1) + defer p.requestsWG.Done() + + w := &proxyResponseWriter{w: ww} + + requestID := uuid.Must(uuid.NewV4()).String() + p.log.Info(fmt.Sprintf("STARTING - %s '%s' %s %s", + r.Method, r.URL, r.Proto, r.RemoteAddr)) + + rr := addIDToReqCtx(r, requestID) + rrr := addLoggerReqCtx(rr, p.log.With("req_id", requestID)) + + p.ServeHTTP(w, rrr) + + p.log.Info(fmt.Sprintf("[%s] DONE %d - %s %s %s %s\n", + requestID, w.statusCode, r.Method, r.URL, r.Proto, r.RemoteAddr)) + }), + ) + p.Server.Listener = l + + if opts.serverTLSConfig != nil { + p.Server.TLS = opts.serverTLSConfig + } + + u, err := url.Parse(p.URL) + if err != nil { + panic(fmt.Sprintf("could parse fleet-server URL: %v", err)) + } + + p.Port = u.Port() + p.LocalhostURL = "http://localhost:" + p.Port + + return &p +} + +func (p *Proxy) Start() error { + p.Server.Start() + u, err := url.Parse(p.URL) + if err != nil { + return fmt.Errorf("could not parse fleet-server URL: %w", err) + } + + p.Port = u.Port() + p.LocalhostURL = "http://localhost:" + p.Port + + p.log.Info(fmt.Sprintf("running on %s -> %s", p.URL, p.LocalhostURL)) + return nil +} + +func (p *Proxy) StartTLS() error { + p.Server.StartTLS() + u, err := url.Parse(p.URL) + if err != nil { + return fmt.Errorf("could not parse fleet-server URL: %w", err) + } + + p.Port = u.Port() + p.LocalhostURL = "https://localhost:" + p.Port + + p.log.Info(fmt.Sprintf("running on %s -> %s", p.URL, p.LocalhostURL)) + return nil +} + +func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodConnect { + p.serveHTTPS(w, r) + return + } + + p.serveHTTP(w, r) +} + +func (p *Proxy) Close() { + // Sometimes, on CI obviously, the last log happens after the test + // finishes. See https://github.com/elastic/elastic-agent/issues/5869. + // So, manually wait all ongoing requests to finish. + p.requestsWG.Wait() + + p.Server.Close() +} + +func (p *Proxy) serveHTTP(w http.ResponseWriter, r *http.Request) { + resp, err := p.processRequest(r) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + msg := fmt.Sprintf("could not make request: %#v", err.Error()) + log.Print(msg) + _, _ = fmt.Fprint(w, msg) + return + } + defer resp.Body.Close() + + for k, v := range resp.Header { + w.Header()[k] = v + } + + w.WriteHeader(resp.StatusCode) + + if _, err = io.Copy(w, resp.Body); err != nil { + p.opts.logFn("[ERROR] could not write response body: %v", err) + } +} + +// processRequest executes the configured request manipulation and perform the +// request. +func (p *Proxy) processRequest(r *http.Request) (*http.Response, error) { + origURL := r.URL.String() + + switch { + case p.opts.rewriteURL != nil: + p.opts.rewriteURL(r.URL) + case p.opts.rewriteHost != nil: + r.URL.Host = p.opts.rewriteHost(r.URL.Host) + } + + // It should not be required, however if not set, enroll will fail with + // "Unknown resource" + r.Host = r.URL.Host + + p.log.Debug(fmt.Sprintf("original URL: %s, new URL: %s", + origURL, r.URL.String())) + + p.proxiedRequestsMu.Lock() + p.proxiedRequests = append(p.proxiedRequests, + fmt.Sprintf("%s - %s %s %s", + r.Method, r.URL.Scheme, r.URL.Host, r.URL.String())) + p.proxiedRequestsMu.Unlock() + + // when modifying the request, RequestURI isn't updated, and it isn't + // needed anyway, so remove it. + r.RequestURI = "" + + return p.client.Do(r) +} + +// ProxiedRequests returns a slice with the "request log" with every request the +// proxy received. +func (p *Proxy) ProxiedRequests() []string { + p.proxiedRequestsMu.Lock() + defer p.proxiedRequestsMu.Unlock() + + var rs []string + rs = append(rs, p.proxiedRequests...) + return rs +} + +var _ http.Hijacker = &proxyResponseWriter{} + +// proxyResponseWriter wraps a http.ResponseWriter to expose the status code +// through proxyResponseWriter.statusCode +type proxyResponseWriter struct { + w http.ResponseWriter + statusCode int +} + +func (s *proxyResponseWriter) Header() http.Header { + return s.w.Header() +} + +func (s *proxyResponseWriter) Write(bs []byte) (int, error) { + return s.w.Write(bs) +} + +func (s *proxyResponseWriter) WriteHeader(statusCode int) { + s.statusCode = statusCode + s.w.WriteHeader(statusCode) +} + +func (s *proxyResponseWriter) StatusCode() int { + return s.statusCode +} + +func (s *proxyResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := s.w.(http.Hijacker) + if !ok { + return nil, nil, fmt.Errorf("%T does not support hijacking", s.w) + } + + return hijacker.Hijack() +} + +type ctxKeyRecID struct{} +type ctxKeyLogger struct{} + +func addIDToReqCtx(r *http.Request, id string) *http.Request { + return r.WithContext(context.WithValue(r.Context(), ctxKeyRecID{}, id)) +} + +func idFromReqCtx(r *http.Request) string { //nolint:unused // kept for completeness + return r.Context().Value(ctxKeyRecID{}).(string) +} + +func addLoggerReqCtx(r *http.Request, log *slog.Logger) *http.Request { + return r.WithContext(context.WithValue(r.Context(), ctxKeyLogger{}, log)) +} + +func loggerFromReqCtx(r *http.Request) *slog.Logger { + l, ok := r.Context().Value(ctxKeyLogger{}).(*slog.Logger) + if !ok { + return slog.Default() + } + return l +} + +type logfWriter func(format string, a ...any) + +func (w logfWriter) Write(p []byte) (n int, err error) { + w(string(p)) + return len(p), nil +} diff --git a/testing/proxytest/proxytest_example_test.go b/testing/proxytest/proxytest_example_test.go new file mode 100644 index 0000000..4a439fe --- /dev/null +++ b/testing/proxytest/proxytest_example_test.go @@ -0,0 +1,190 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//go:build example + +package proxytest + +import ( + "context" + "crypto/tls" + "crypto/x509" + "net" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/elastic/elastic-agent-libs/testing/certutil" +) + +// TestRunHTTPSProxy is an example of how to use the proxytest outside tests, +// and it instructs how to perform a request through the proxy using cURL. +// From the repo's root, run this test with: +// go test -tags example -v -run TestRunHTTPSProxy$ ./testing/proxytest +func TestRunHTTPSProxy(t *testing.T) { + // Create a temporary directory to store certificates + tmpDir := t.TempDir() + + // ========================= generate certificates ========================= + serverCAKey, serverCACert, serverCAPair, err := certutil.NewRootCA( + certutil.WithCNPrefix("server")) + require.NoError(t, err, "error creating root CA") + + serverCert, _, err := certutil.GenerateChildCert( + "localhost", + []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback, net.IPv6zero}, + serverCAKey, + serverCACert, certutil.WithCNPrefix("server")) + require.NoError(t, err, "error creating server certificate") + serverCACertPool := x509.NewCertPool() + serverCACertPool.AddCert(serverCACert) + + proxyCAKey, proxyCACert, proxyCAPair, err := certutil.NewRootCA( + certutil.WithCNPrefix("proxy")) + require.NoError(t, err, "error creating root CA") + + proxyCert, proxyCertPair, err := certutil.GenerateChildCert( + "localhost", + []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback, net.IPv6zero}, + proxyCAKey, + proxyCACert, + certutil.WithCNPrefix("proxy")) + require.NoError(t, err, "error creating server certificate") + + clientCAKey, clientCACert, clientCAPair, err := certutil.NewRootCA( + certutil.WithCNPrefix("client")) + require.NoError(t, err, "error creating root CA") + clientCACertPool := x509.NewCertPool() + clientCACertPool.AddCert(clientCACert) + + _, clientCertPair, err := certutil.GenerateChildCert( + "localhost", + []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback, net.IPv6zero}, + clientCAKey, + clientCACert, + certutil.WithCNPrefix("client")) + require.NoError(t, err, "error creating server certificate") + + // =========================== save certificates =========================== + serverCACertFile := filepath.Join(tmpDir, "serverCA.crt") + if err := os.WriteFile(serverCACertFile, serverCAPair.Cert, 0644); err != nil { + t.Fatal(err) + } + serverCAKeyFile := filepath.Join(tmpDir, "serverCA.key") + if err := os.WriteFile(serverCAKeyFile, serverCAPair.Key, 0644); err != nil { + t.Fatal(err) + } + + proxyCACertFile := filepath.Join(tmpDir, "proxyCA.crt") + if err := os.WriteFile(proxyCACertFile, proxyCAPair.Cert, 0644); err != nil { + t.Fatal(err) + } + proxyCAKeyFile := filepath.Join(tmpDir, "proxyCA.key") + if err := os.WriteFile(proxyCAKeyFile, proxyCAPair.Key, 0644); err != nil { + t.Fatal(err) + } + proxyCertFile := filepath.Join(tmpDir, "proxyCert.crt") + if err := os.WriteFile(proxyCertFile, proxyCertPair.Cert, 0644); err != nil { + t.Fatal(err) + } + proxyKeyFile := filepath.Join(tmpDir, "proxyCert.key") + if err := os.WriteFile(proxyKeyFile, proxyCertPair.Key, 0644); err != nil { + t.Fatal(err) + } + + clientCACertFile := filepath.Join(tmpDir, "clientCA.crt") + if err := os.WriteFile(clientCACertFile, clientCAPair.Cert, 0644); err != nil { + t.Fatal(err) + } + clientCAKeyFile := filepath.Join(tmpDir, "clientCA.key") + if err := os.WriteFile(clientCAKeyFile, clientCAPair.Key, 0644); err != nil { + t.Fatal(err) + } + clientCertCertFile := filepath.Join(tmpDir, "clientCert.crt") + if err := os.WriteFile(clientCertCertFile, clientCertPair.Cert, 0644); err != nil { + t.Fatal(err) + } + clientCertKeyFile := filepath.Join(tmpDir, "clientCert.key") + if err := os.WriteFile(clientCertKeyFile, clientCertPair.Key, 0644); err != nil { + t.Fatal(err) + } + + // ========================== create target server ========================= + targetHost := "not-a-server.co" + server := httptest.NewUnstartedServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("It works!")) + })) + server.TLS = &tls.Config{ + Certificates: []tls.Certificate{*serverCert}, + MinVersion: tls.VersionTLS13, + } + server.StartTLS() + t.Logf("target server running on %s", server.URL) + + // ============================== create proxy ============================= + proxy := New(t, + WithVerboseLog(), + WithRequestLog("https", t.Logf), + WithRewrite(targetHost+":443", server.URL[8:]), + WithMITMCA(proxyCAKey, proxyCACert), + WithHTTPClient(&http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: serverCACertPool, + MinVersion: tls.VersionTLS13, + }, + }, + }), + WithServerTLSConfig(&tls.Config{ + Certificates: []tls.Certificate{*proxyCert}, + ClientCAs: clientCACertPool, + ClientAuth: tls.VerifyClientCertIfGiven, + MinVersion: tls.VersionTLS13, + })) + err = proxy.StartTLS() + require.NoError(t, err, "error starting proxy") + t.Logf("proxy running on %s", proxy.LocalhostURL) + defer proxy.Close() + + // ============================ test instructions ========================== + + u := "https://" + targetHost + t.Logf("make request to %q using proxy %q", u, proxy.LocalhostURL) + + t.Logf(`curl \ +--proxy-cacert %s \ +--proxy-cert %s \ +--proxy-key %s \ +--cacert %s \ +--proxy %s \ +%s`, + proxyCACertFile, + clientCertCertFile, + clientCertKeyFile, + proxyCACertFile, + proxy.URL, + u, + ) + + t.Log("CTRL+C to stop") + <-context.Background().Done() +} diff --git a/testing/proxytest/proxytest_test.go b/testing/proxytest/proxytest_test.go new file mode 100644 index 0000000..1bedbff --- /dev/null +++ b/testing/proxytest/proxytest_test.go @@ -0,0 +1,401 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package proxytest + +import ( + "context" + "crypto/tls" + "crypto/x509" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/elastic/elastic-agent-libs/testing/certutil" +) + +func TestProxy(t *testing.T) { + proxyCAKey, proxyCACert, _, err := certutil.NewRootCA() + require.NoError(t, err, "error creating root CA") + + proxyCert, _, err := certutil.GenerateChildCert( + "localhost", + []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback, net.IPv6zero}, + proxyCAKey, + proxyCACert) + require.NoError(t, err, "error creating server certificate") + + proxyCACertPool := x509.NewCertPool() + proxyCACertPool.AddCert(proxyCACert) + + type setup struct { + fakeBackendServer *httptest.Server + generateTestHttpClient func(t *testing.T, proxy *Proxy) *http.Client + } + type testRequest struct { + method string + url string + body io.Reader + } + type testcase struct { + name string + setup setup + proxyOptions []Option + proxyStartTLS bool + request testRequest + wantErr assert.ErrorAssertionFunc + assertFunc func(t *testing.T, proxy *Proxy, resp *http.Response) + } + + testcases := []testcase{ + { + name: "Basic scenario, no TLS", + setup: setup{ + fakeBackendServer: createFakeBackendServer(), + generateTestHttpClient: nil, + }, + proxyOptions: nil, + proxyStartTLS: false, + request: testRequest{ + method: http.MethodGet, + url: "http://somehost:1234/some/path/here", + body: nil, + }, + wantErr: assert.NoError, + assertFunc: func(t *testing.T, proxy *Proxy, resp *http.Response) { + assert.Equal(t, http.StatusOK, resp.StatusCode) + if assert.NotEmpty(t, proxy.ProxiedRequests(), "proxy should have captured at least 1 request") { + assert.Contains(t, proxy.ProxiedRequests()[0], "/some/path/here") + } + }, + }, + { + name: "TLS scenario, server cert validation", + setup: setup{ + fakeBackendServer: createFakeBackendServer(), + generateTestHttpClient: func(t *testing.T, proxy *Proxy) *http.Client { + proxyURL, err := url.Parse(proxy.URL) + require.NoErrorf(t, err, "failed to parse proxy URL %q", proxy.URL) + + // Client trusting the proxy cert CA + return &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + TLSClientConfig: &tls.Config{ + RootCAs: proxyCACertPool, + MinVersion: tls.VersionTLS12, + }, + }, + } + }, + }, + proxyOptions: []Option{ + WithServerTLSConfig(&tls.Config{ + ClientCAs: proxyCACertPool, + Certificates: []tls.Certificate{*proxyCert}, + MinVersion: tls.VersionTLS12, + }), + }, + proxyStartTLS: true, + request: testRequest{ + method: http.MethodGet, + url: "http://somehost:1234/some/path/here", + body: nil, + }, + wantErr: assert.NoError, + assertFunc: func(t *testing.T, proxy *Proxy, resp *http.Response) { + assert.Equal(t, http.StatusOK, resp.StatusCode) + if assert.NotEmpty(t, proxy.ProxiedRequests(), "proxy should have captured at least 1 request") { + assert.Contains(t, proxy.ProxiedRequests()[0], "/some/path/here") + } + }, + }, + { + name: "mTLS scenario, client and server cert validation", + setup: setup{ + fakeBackendServer: createFakeBackendServer(), + generateTestHttpClient: func(t *testing.T, proxy *Proxy) *http.Client { + proxyURL, err := url.Parse(proxy.URL) + require.NoErrorf(t, err, "failed to parse proxy URL %q", proxy.URL) + + // Client certificate + tlsCert, _, err := certutil.GenerateChildCert( + "localhost", + []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback, net.IPv6zero}, + proxyCAKey, + proxyCACert) + require.NoError(t, err, "failed generating client certificate") + + // Client with its own certificate and trusting the proxy cert CA + return &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + TLSClientConfig: &tls.Config{ + RootCAs: proxyCACertPool, + Certificates: []tls.Certificate{ + *tlsCert, + }, + MinVersion: tls.VersionTLS12, + }, + }} + }, + }, + proxyOptions: []Option{ + // require client authentication and verify cert + WithServerTLSConfig(&tls.Config{ + ClientCAs: proxyCACertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + Certificates: []tls.Certificate{*proxyCert}, + MinVersion: tls.VersionTLS12, + }), + }, + proxyStartTLS: true, + request: testRequest{ + method: http.MethodGet, + url: "http://somehost:1234/some/path/here", + body: nil, + }, + wantErr: assert.NoError, + assertFunc: func(t *testing.T, proxy *Proxy, resp *http.Response) { + assert.Equal(t, http.StatusOK, resp.StatusCode) + if assert.NotEmpty(t, proxy.ProxiedRequests(), "proxy should have captured at least 1 request") { + assert.Contains(t, proxy.ProxiedRequests()[0], "/some/path/here") + } + }, + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + var proxyOpts []Option + + if tt.setup.fakeBackendServer != nil { + defer tt.setup.fakeBackendServer.Close() + serverURL, err := url.Parse(tt.setup.fakeBackendServer.URL) + require.NoErrorf(t, err, "failed to parse test HTTP server URL %q", tt.setup.fakeBackendServer.URL) + proxyOpts = append(proxyOpts, WithRewriteFn(func(u *url.URL) { + // redirect the requests on the proxy itself + u.Host = serverURL.Host + })) + } + + proxyOpts = append(proxyOpts, tt.proxyOptions...) + proxy := New(t, proxyOpts...) + + if tt.proxyStartTLS { + t.Log("Starting proxytest with TLS") + err = proxy.StartTLS() + } else { + t.Log("Starting proxytest without TLS") + err = proxy.Start() + } + require.NoError(t, err, "error starting proxytest") + + defer proxy.Close() + + proxyURL, err := url.Parse(proxy.URL) + require.NoErrorf(t, err, "failed to parse proxy URL %q", proxy.URL) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, tt.request.method, tt.request.url, tt.request.body) + require.NoError(t, err, "error creating request") + + var client *http.Client + if tt.setup.generateTestHttpClient != nil { + client = tt.setup.generateTestHttpClient(t, proxy) + } else { + // basic HTTP client using the proxy + client = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)}} + } + + resp, err := client.Do(req) + if resp != nil { + defer resp.Body.Close() + } + if tt.wantErr(t, err, "unexpected error return value") && tt.assertFunc != nil { + tt.assertFunc(t, proxy, resp) + } + }) + } + +} + +func TestHTTPSProxy(t *testing.T) { + targetHost := "not-a-server.co" + proxy, client, target := prepareMTLSProxyAndTargetServer(t, targetHost) + t.Cleanup(func() { + proxy.Close() + target.Close() + }) + + tcs := []struct { + name string + target string + // assertFn should not close the response body + assertFn func(*testing.T, *http.Response, error) + }{ + { + name: "successful_request", + target: "https://" + targetHost, + assertFn: func(t *testing.T, got *http.Response, err error) { + if !assert.Equal(t, http.StatusOK, got.StatusCode, "unexpected status code") { + body, err := io.ReadAll(got.Body) + if err != nil { + t.Logf("could not read response body") + t.FailNow() + } + _ = got.Body.Close() + + t.Logf("request body: %s", string(body)) + } + + }, + }, + { + name: "request_failure", + target: "https://any.not.target.will.do", + assertFn: func(t *testing.T, got *http.Response, err error) { + assert.NoError(t, err, "request to an invalid host should not fail, but succeed with a HTTP error") + assert.Equal(t, http.StatusBadGateway, got.StatusCode) + + body, err := io.ReadAll(got.Body) + require.NoError(t, err, "failed reading response body") + assert.Contains(t, string(body), "failed performing request to target") + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + t.Logf("making request to %q using proxy %q", tc.target, proxy.URL) + + got, err := client.Get(tc.target) //nolint:noctx // it's a test + require.NoError(t, err, "request should have succeeded") + defer got.Body.Close() + + // assertFn should not close the response body + tc.assertFn(t, got, err) + }) + } +} + +func prepareMTLSProxyAndTargetServer(t *testing.T, targetHost string) (*Proxy, http.Client, *httptest.Server) { + serverCAKey, serverCACert, _, err := certutil.NewRootCA() + require.NoError(t, err, "error creating root CA") + + serverCert, _, err := certutil.GenerateChildCert( + "localhost", + []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback, net.IPv6zero}, + serverCAKey, + serverCACert) + require.NoError(t, err, "error creating server certificate") + serverCACertPool := x509.NewCertPool() + serverCACertPool.AddCert(serverCACert) + + proxyCAKey, proxyCACert, _, err := certutil.NewRootCA() + require.NoError(t, err, "error creating root CA") + + proxyCert, _, err := certutil.GenerateChildCert( + "localhost", + []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback, net.IPv6zero}, + proxyCAKey, + proxyCACert) + require.NoError(t, err, "error creating server certificate") + proxyCACertPool := x509.NewCertPool() + proxyCACertPool.AddCert(proxyCACert) + + clientCAKey, clientCACert, _, err := certutil.NewRootCA() + require.NoError(t, err, "error creating root CA") + clientCACertPool := x509.NewCertPool() + clientCACertPool.AddCert(clientCACert) + + clientCert, _, err := certutil.GenerateChildCert( + "localhost", + []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback, net.IPv6zero}, + clientCAKey, + clientCACert) + require.NoError(t, err, "error creating server certificate") + + server := httptest.NewUnstartedServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("It works!")) + })) + server.TLS = &tls.Config{ + MinVersion: tls.VersionTLS13, + Certificates: []tls.Certificate{*serverCert}, + } + server.StartTLS() + t.Logf("target server running on %s", server.URL) + + proxy := New(t, + WithVerboseLog(), + WithRequestLog("https", t.Logf), + WithRewrite(targetHost+":443", server.URL[8:]), + WithMITMCA(proxyCAKey, proxyCACert), + WithHTTPClient(&http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS13, + RootCAs: serverCACertPool, + }, + }, + }), + WithServerTLSConfig(&tls.Config{ + Certificates: []tls.Certificate{*proxyCert}, + ClientCAs: clientCACertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + MinVersion: tls.VersionTLS13, + })) + err = proxy.StartTLS() + require.NoError(t, err, "error starting proxy") + t.Logf("proxy running on %s", proxy.URL) + + proxyURL, err := url.Parse(proxy.URL) + require.NoErrorf(t, err, "failed to parse proxy URL %q", proxy.URL) + + client := http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + TLSClientConfig: &tls.Config{ + RootCAs: proxyCACertPool, + Certificates: []tls.Certificate{*clientCert}, + MinVersion: tls.VersionTLS12, + }, + }, + } + + return proxy, client, server +} + +func createFakeBackendServer() *httptest.Server { + handlerF := func(writer http.ResponseWriter, request *http.Request) { + // always return HTTP 200 + writer.WriteHeader(http.StatusOK) + } + + fakeBackendHTTPServer := httptest.NewServer(http.HandlerFunc(handlerF)) + return fakeBackendHTTPServer +}