diff --git a/src/bootstrap/cloud/terraform/service-account.tf b/src/bootstrap/cloud/terraform/service-account.tf index b5cd88164..0c538e74f 100644 --- a/src/bootstrap/cloud/terraform/service-account.tf +++ b/src/bootstrap/cloud/terraform/service-account.tf @@ -60,7 +60,7 @@ resource "google_project_iam_member" "robot-service-roles" { project = data.google_project.project.project_id member = "serviceAccount:${google_service_account.robot-service.email}" for_each = toset([ - "roles/cloudtrace.agent", # Upload cloud traces + "roles/cloudtrace.agent", # Upload cloud traces "roles/logging.logWriter", # Upload text logs to Cloud logging # Required to use robot-service@ for GKE clusters that simulate robots "roles/monitoring.viewer", diff --git a/src/go/cmd/http-relay-client/client/client.go b/src/go/cmd/http-relay-client/client/client.go index f29bef878..24b3c344a 100644 --- a/src/go/cmd/http-relay-client/client/client.go +++ b/src/go/cmd/http-relay-client/client/client.go @@ -91,6 +91,18 @@ type ClientConfig struct { ForceHttp2 bool } +type RelayServerError struct { + msg string +} + +func NewRelayServerError(msg string) error { + return &RelayServerError{msg} +} + +func (e *RelayServerError) Error() string { + return e.msg +} + func DefaultClientConfig() ClientConfig { return ClientConfig{ RemoteRequestTimeout: 60 * time.Second, @@ -384,7 +396,7 @@ func (c *Client) postResponse(remote *http.Client, br *pb.HttpResponse) error { return fmt.Errorf("couldn't read relay server's response body: %v", err) } if resp.StatusCode != http.StatusOK { - err := fmt.Errorf("relay server responded %s: %s", http.StatusText(resp.StatusCode), body) + err := NewRelayServerError(fmt.Sprintf("relay server responded %s: %s", http.StatusText(resp.StatusCode), body)) if resp.StatusCode == http.StatusBadRequest { // http-relay-server may have restarted during the request. return backoff.Permanent(err) @@ -643,8 +655,11 @@ func (c *Client) handleRequest(remote *http.Client, local *http.Client, pbreq *p log.Printf("[%s] Failed to post response to relay: %v", *resp.Id, err) }, ) - if _, ok := err.(*backoff.PermanentError); ok { + if _, ok := err.(*RelayServerError); ok { // A permanent error suggests the request should be aborted. + log.Printf("[%s] %s", *resp.Id, err) + log.Printf("[%s] Closing backend connection", *resp.Id) + hresp.Body.Close() break } } diff --git a/src/go/cmd/http-relay-server/server/broker.go b/src/go/cmd/http-relay-server/server/broker.go index 94c20f0dd..b1d5b8dd3 100644 --- a/src/go/cmd/http-relay-server/server/broker.go +++ b/src/go/cmd/http-relay-server/server/broker.go @@ -171,6 +171,12 @@ func (r *broker) RelayRequest(server string, request *pb.HttpRequest) (<-chan *p } } +// StopRelayRequest forgets a relaying request, this causes the next chunk from the backend +// with the relay id to not be recognized, resulting in the relay server returning an error. +func (r *broker) StopRelayRequest(requestId string) { + delete(r.resp, requestId) +} + // GetRequest obtains a client's request for the server identifier. It blocks // until a client makes a request. func (r *broker) GetRequest(ctx context.Context, server, path string) (*pb.HttpRequest, error) { diff --git a/src/go/cmd/http-relay-server/server/server.go b/src/go/cmd/http-relay-server/server/server.go index 913a37f2a..cce7f1076 100644 --- a/src/go/cmd/http-relay-server/server/server.go +++ b/src/go/cmd/http-relay-server/server/server.go @@ -257,8 +257,10 @@ func (s *Server) bidirectionalStream(backendCtx backendContext, w http.ResponseW numBytes := 0 for responseChunk := range responseChunks { - // TODO(b/130706300): detect dropped connection and end request in broker - _, _ = bufrw.Write(responseChunk.Body) + if _, err = w.Write(responseChunk.Body); err != nil { + log.Printf("[%s] %s", backendCtx.Id, err) + return + } bufrw.Flush() numBytes += len(responseChunk.Body) } @@ -378,6 +380,7 @@ func (s *Server) userClientRequest(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } + defer s.b.StopRelayRequest(backendCtx.Id) header, responseChunksChan, done := s.waitForFirstResponseAndHandleSwitching(ctx, *backendCtx, w, backendRespChan) if done { @@ -392,8 +395,10 @@ func (s *Server) userClientRequest(w http.ResponseWriter, r *http.Request) { // i.e. this will block until numBytes := 0 for responseChunk := range responseChunksChan { - // TODO(b/130706300): detect dropped connection and end request in broker - _, _ = w.Write(responseChunk.Body) + if _, err = w.Write(responseChunk.Body); err != nil { + log.Printf("[%s] %s", backendCtx.Id, err) + return + } if flush, ok := w.(http.Flusher); ok { flush.Flush() } diff --git a/src/go/tests/relay/nok8s_relay_test.go b/src/go/tests/relay/nok8s_relay_test.go index ef6cf68fe..eb071aa6c 100644 --- a/src/go/tests/relay/nok8s_relay_test.go +++ b/src/go/tests/relay/nok8s_relay_test.go @@ -15,6 +15,7 @@ package main import ( + "bufio" "bytes" "context" "fmt" @@ -128,7 +129,7 @@ func (r *relay) stop() error { } // TestHttpRelay launches a local http relay (client + server) and connects a -// test-hhtp-server as a backend. The test is then interacting with the backend +// test-http-server as a backend. The test is then interacting with the backend // through the local relay. func TestHttpRelay(t *testing.T) { tests := []struct { @@ -201,6 +202,67 @@ func TestHttpRelay(t *testing.T) { } } +// TestDroppedUserClientFreesRelayChannel checks that when the user client closes a connection, +// it is propagated to the relay server and client, closing the backend connection as well. +func TestDroppedUserClientFreesRelayChannel(t *testing.T) { + // setup http test server + connClosed := make(chan error) + defer close(connClosed) + finishServer := make(chan bool) + defer close(finishServer) + + // mock a long running backend that uses chunking to send periodic updates + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for { + select { + case <-finishServer: + return + default: + if _, err := fmt.Fprintln(w, "DEADBEEF"); err != nil { + connClosed <- err + return + } + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } else { + t.Fatal("cannot flush") + } + time.Sleep(time.Second) + } + } + })) + defer func() { ts.Close() }() + + backendAddress := strings.TrimPrefix(ts.URL, "http://") + r := &relay{} + if err := r.start(backendAddress); err != nil { + t.Fatal("failed to start relay: ", err) + } + defer func() { + if err := r.stop(); err != nil { + t.Fatal("failed to stop relay: ", err) + } + }() + relayAddress := "http://127.0.0.1:" + r.rsPort + + res, err := http.Get(relayAddress + "/client/remote1/") + if err != nil { + t.Fatal(err) + } + // receive the first chunk then terminates the connection + if _, err := bufio.NewReader(res.Body).ReadString('\n'); err != nil { + t.Fatal(err) + } + res.Body.Close() + + // wait for up to 30s for backend connection to be closed + select { + case <-connClosed: + case <-time.After(30 * time.Second): + t.Error("Server did not close connection") + } +} + type testServer struct { testpb.UnimplementedTestServiceServer responsePayload []byte