From 46af5f942db9655ae63297edfc3bfb532b884652 Mon Sep 17 00:00:00 2001 From: Zach Reyes Date: Tue, 10 Oct 2023 19:21:53 -0400 Subject: [PATCH 1/2] Move conn begin conn end and in header to grpc layer --- internal/transport/handler_server.go | 37 ++++++------ internal/transport/handler_server_test.go | 10 ++-- internal/transport/http2_server.go | 69 ++++++----------------- internal/transport/transport.go | 24 +++++++- internal/transport/transport_test.go | 67 +++------------------- internal/xds/rbac/rbac_engine.go | 3 +- server.go | 58 +++++++++++++++++-- test/end2end_test.go | 5 +- xds/server.go | 3 +- 9 files changed, 125 insertions(+), 151 deletions(-) diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index 17f7a21b5a9f..67b987a5a407 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -167,6 +167,14 @@ func (ht *serverHandlerTransport) Close(err error) { func (ht *serverHandlerTransport) RemoteAddr() net.Addr { return strAddr(ht.req.RemoteAddr) } +func (ht *serverHandlerTransport) LocalAddr() net.Addr { return nil } // Server Handler transport has no access to local addr (was simply not calling sh with local addr). + +func (ht *serverHandlerTransport) Peer() *peer.Peer { + return &peer.Peer{ + Addr: ht.RemoteAddr(), + } +} + // strAddr is a net.Addr backed by either a TCP "ip:port" string, or // the empty string if unknown. type strAddr string @@ -347,7 +355,7 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { return err } -func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) { +func (ht *serverHandlerTransport) HandleStreams(_ context.Context, startStream func(*Stream)) { // With this transport type there will be exactly 1 stream: this HTTP request. ctx := ht.req.Context() @@ -371,16 +379,16 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) { }() req := ht.req - s := &Stream{ - id: 0, // irrelevant - requestRead: func(int) {}, - cancel: cancel, - buf: newRecvBuffer(), - st: ht, - method: req.URL.Path, - recvCompress: req.Header.Get("grpc-encoding"), - contentSubtype: ht.contentSubtype, + id: 0, // irrelevant + requestRead: func(int) {}, + cancel: cancel, + buf: newRecvBuffer(), + st: ht, + method: req.URL.Path, + recvCompress: req.Header.Get("grpc-encoding"), + contentSubtype: ht.contentSubtype, + headerWireLength: 0, // doesn't know header wire length, will call into stats handler as 0. } pr := &peer.Peer{ Addr: ht.RemoteAddr(), @@ -390,15 +398,6 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) { } ctx = metadata.NewIncomingContext(ctx, ht.headerMD) s.ctx = peer.NewContext(ctx, pr) - for _, sh := range ht.stats { - s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) - inHeader := &stats.InHeader{ - FullMethod: s.method, - RemoteAddr: ht.RemoteAddr(), - Compression: s.recvCompress, - } - sh.HandleRPC(s.ctx, inHeader) - } s.trReader = &transportReader{ reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}}, windowHandler: func(int) {}, diff --git a/internal/transport/handler_server_test.go b/internal/transport/handler_server_test.go index bf67480b318f..4d92e9b9c042 100644 --- a/internal/transport/handler_server_test.go +++ b/internal/transport/handler_server_test.go @@ -314,7 +314,7 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) { st.ht.WriteStatus(s, status.New(codes.OK, "")) } st.ht.HandleStreams( - func(s *Stream) { go handleStream(s) }, + context.Background(), func(s *Stream) { go handleStream(s) }, ) wantHeader := http.Header{ "Date": nil, @@ -347,7 +347,7 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) st.ht.WriteStatus(s, status.New(statusCode, msg)) } st.ht.HandleStreams( - func(s *Stream) { go handleStream(s) }, + context.Background(), func(s *Stream) { go handleStream(s) }, ) wantHeader := http.Header{ "Date": nil, @@ -396,7 +396,7 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow")) } ht.HandleStreams( - func(s *Stream) { go runStream(s) }, + context.Background(), func(s *Stream) { go runStream(s) }, ) wantHeader := http.Header{ "Date": nil, @@ -448,7 +448,7 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) { func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) { st := newHandleStreamTest(t) st.ht.HandleStreams( - func(s *Stream) { go handleStream(st, s) }, + context.Background(), func(s *Stream) { go handleStream(st, s) }, ) } @@ -481,7 +481,7 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) { hst.ht.WriteStatus(s, st) } hst.ht.HandleStreams( - func(s *Stream) { go handleStream(s) }, + context.Background(), func(s *Stream) { go handleStream(s) }, ) wantHeader := http.Header{ "Date": nil, diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 6fa1eb41992a..d0421304b4dc 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -69,7 +69,6 @@ var serverConnectionCounter uint64 // http2Server implements the ServerTransport interface with HTTP2. type http2Server struct { lastRead int64 // Keep this field 64-bit aligned. Accessed atomically. - ctx context.Context done chan struct{} conn net.Conn loopy *loopyWriter @@ -244,7 +243,6 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, done := make(chan struct{}) t := &http2Server{ - ctx: setConnection(context.Background(), rawConn), done: done, conn: conn, remoteAddr: conn.RemoteAddr(), @@ -267,8 +265,6 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, bufferPool: newBufferPool(), } t.logger = prefixLoggerForServerTransport(t) - // Add peer information to the http2server context. - t.ctx = peer.NewContext(t.ctx, t.getPeer()) t.controlBuf = newControlBuffer(t.done) if dynamicWindow { @@ -277,14 +273,6 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, updateFlowControl: t.updateFlowControl, } } - for _, sh := range t.stats { - t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{ - RemoteAddr: t.remoteAddr, - LocalAddr: t.localAddr, - }) - connBegin := &stats.ConnBegin{} - sh.HandleConn(t.ctx, connBegin) - } t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr)) if err != nil { return nil, err @@ -342,7 +330,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, // operateHeaders takes action on the decoded headers. Returns an error if fatal // error encountered and transport needs to close, otherwise returns nil. -func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) error { +func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeadersFrame, handle func(*Stream)) error { // Acquire max stream ID lock for entire duration t.maxStreamMu.Lock() defer t.maxStreamMu.Unlock() @@ -369,10 +357,11 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( buf := newRecvBuffer() s := &Stream{ - id: streamID, - st: t, - buf: buf, - fc: &inFlow{limit: uint32(t.initialWindowSize)}, + id: streamID, + st: t, + buf: buf, + fc: &inFlow{limit: uint32(t.initialWindowSize)}, + headerWireLength: int(frame.Header().Length), } var ( // if false, content-type was missing or invalid @@ -511,9 +500,9 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( s.state = streamReadDone } if timeoutSet { - s.ctx, s.cancel = context.WithTimeout(t.ctx, timeout) + s.ctx, s.cancel = context.WithTimeout(ctx, timeout) } else { - s.ctx, s.cancel = context.WithCancel(t.ctx) + s.ctx, s.cancel = context.WithCancel(ctx) } // Attach the received metadata to the context. @@ -592,18 +581,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( s.requestRead = func(n int) { t.adjustWindow(s, uint32(n)) } - for _, sh := range t.stats { - s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) - inHeader := &stats.InHeader{ - FullMethod: s.method, - RemoteAddr: t.remoteAddr, - LocalAddr: t.localAddr, - Compression: s.recvCompress, - WireLength: int(frame.Header().Length), - Header: mdata.Copy(), - } - sh.HandleRPC(s.ctx, inHeader) - } s.ctxDone = s.ctx.Done() s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone) s.trReader = &transportReader{ @@ -629,7 +606,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( // HandleStreams receives incoming streams using the given handler. This is // typically run in a separate goroutine. // traceCtx attaches trace to ctx and returns the new context. -func (t *http2Server) HandleStreams(handle func(*Stream)) { +func (t *http2Server) HandleStreams(ctx context.Context, handle func(*Stream)) { defer close(t.readerDone) for { t.controlBuf.throttle() @@ -664,7 +641,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) { } switch frame := frame.(type) { case *http2.MetaHeadersFrame: - if err := t.operateHeaders(frame, handle); err != nil { + if err := t.operateHeaders(ctx, frame, handle); err != nil { t.Close(err) break } @@ -1242,10 +1219,6 @@ func (t *http2Server) Close(err error) { for _, s := range streams { s.cancel() } - for _, sh := range t.stats { - connEnd := &stats.ConnEnd{} - sh.HandleConn(t.ctx, connEnd) - } } // deleteStream deletes the stream s from transport's active streams. @@ -1311,6 +1284,10 @@ func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eo }) } +func (t *http2Server) LocalAddr() net.Addr { + return t.localAddr +} + func (t *http2Server) RemoteAddr() net.Addr { return t.remoteAddr } @@ -1433,7 +1410,8 @@ func (t *http2Server) getOutFlowWindow() int64 { } } -func (t *http2Server) getPeer() *peer.Peer { +// Peer returns the peer of the transport. +func (t *http2Server) Peer() *peer.Peer { return &peer.Peer{ Addr: t.remoteAddr, AuthInfo: t.authInfo, // Can be nil @@ -1449,18 +1427,3 @@ func getJitter(v time.Duration) time.Duration { j := grpcrand.Int63n(2*r) - r return time.Duration(j) } - -type connectionKey struct{} - -// GetConnection gets the connection from the context. -func GetConnection(ctx context.Context) net.Conn { - conn, _ := ctx.Value(connectionKey{}).(net.Conn) - return conn -} - -// SetConnection adds the connection to the context to be able to get -// information about the destination ip and port for an incoming RPC. This also -// allows any unary or streaming interceptors to see the connection. -func setConnection(ctx context.Context, conn net.Conn) context.Context { - return context.WithValue(ctx, connectionKey{}, conn) -} diff --git a/internal/transport/transport.go b/internal/transport/transport.go index aac056e723bb..6e46ff6dceae 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -37,6 +37,7 @@ import ( "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" "google.golang.org/grpc/resolver" "google.golang.org/grpc/stats" "google.golang.org/grpc/status" @@ -265,7 +266,8 @@ type Stream struct { // headerValid indicates whether a valid header was received. Only // meaningful after headerChan is closed (always call waitOnHeader() before // reading its value). Not valid on server side. - headerValid bool + headerValid bool + headerWireLength int // Only set on server side. // hdrMu protects header and trailer metadata on the server-side. hdrMu sync.Mutex @@ -425,6 +427,12 @@ func (s *Stream) Context() context.Context { return s.ctx } +// SetContext sets the context of the stream. This will be deleted once the +// stats handler callouts all move to gRPC layer. +func (s *Stream) SetContext(ctx context.Context) { + s.ctx = ctx +} + // Method returns the method for the stream. func (s *Stream) Method() string { return s.method @@ -437,6 +445,12 @@ func (s *Stream) Status() *status.Status { return s.status } +// HeaderWireLength returns the size of the headers of the stream as received +// from the wire. Valid only on the server. +func (s *Stream) HeaderWireLength() int { + return s.headerWireLength +} + // SetHeader sets the header metadata. This can be called multiple times. // Server side only. // This should not be called in parallel to other data writes. @@ -698,7 +712,7 @@ type ClientTransport interface { // Write methods for a given Stream will be called serially. type ServerTransport interface { // HandleStreams receives incoming streams using the given handler. - HandleStreams(func(*Stream)) + HandleStreams(context.Context, func(*Stream)) // WriteHeader sends the header metadata for the given stream. // WriteHeader may not be called on all streams. @@ -717,9 +731,15 @@ type ServerTransport interface { // handlers will be terminated asynchronously. Close(err error) + // LocalAddr returns the local network address. + LocalAddr() net.Addr + // RemoteAddr returns the remote network address. RemoteAddr() net.Addr + // Peer returns the peer of the server transport. + Peer() *peer.Peer + // Drain notifies the client this ServerTransport stops accepting new RPCs. Drain(debugData string) diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 4488482fc09e..20292bd9f0f8 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -35,8 +35,6 @@ import ( "testing" "time" - "google.golang.org/grpc/peer" - "github.com/google/go-cmp/cmp" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" @@ -356,19 +354,19 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT s.mu.Unlock() switch ht { case notifyCall: - go transport.HandleStreams(h.handleStreamAndNotify) + go transport.HandleStreams(context.Background(), h.handleStreamAndNotify) case suspended: - go transport.HandleStreams(func(*Stream) {}) + go transport.HandleStreams(context.Background(), func(*Stream) {}) case misbehaved: - go transport.HandleStreams(func(s *Stream) { + go transport.HandleStreams(context.Background(), func(s *Stream) { go h.handleStreamMisbehave(t, s) }) case encodingRequiredStatus: - go transport.HandleStreams(func(s *Stream) { + go transport.HandleStreams(context.Background(), func(s *Stream) { go h.handleStreamEncodingRequiredStatus(s) }) case invalidHeaderField: - go transport.HandleStreams(func(s *Stream) { + go transport.HandleStreams(context.Background(), func(s *Stream) { go h.handleStreamInvalidHeaderField(s) }) case delayRead: @@ -377,15 +375,15 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT s.mu.Lock() close(s.ready) s.mu.Unlock() - go transport.HandleStreams(func(s *Stream) { + go transport.HandleStreams(context.Background(), func(s *Stream) { go h.handleStreamDelayRead(t, s) }) case pingpong: - go transport.HandleStreams(func(s *Stream) { + go transport.HandleStreams(context.Background(), func(s *Stream) { go h.handleStreamPingPong(t, s) }) default: - go transport.HandleStreams(func(s *Stream) { + go transport.HandleStreams(context.Background(), func(s *Stream) { go h.handleStream(t, s) }) } @@ -2594,52 +2592,3 @@ func TestConnectionError_Unwrap(t *testing.T) { t.Error("ConnectionError does not unwrap") } } - -func (s) TestPeerSetInServerContext(t *testing.T) { - // create client and server transports. - server, client, cancel := setUp(t, 0, normal) - defer cancel() - defer server.stop() - defer client.Close(fmt.Errorf("closed manually by test")) - - // create a stream with client transport. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - stream, err := client.NewStream(ctx, &CallHdr{}) - if err != nil { - t.Fatalf("failed to create a stream: %v", err) - } - - waitWhileTrue(t, func() (bool, error) { - server.mu.Lock() - defer server.mu.Unlock() - - if len(server.conns) == 0 { - return true, fmt.Errorf("timed-out while waiting for connection to be created on the server") - } - return false, nil - }) - - // verify peer is set in client transport context. - if _, ok := peer.FromContext(client.ctx); !ok { - t.Fatalf("Peer expected in client transport's context, but actually not found.") - } - - // verify peer is set in stream context. - if _, ok := peer.FromContext(stream.ctx); !ok { - t.Fatalf("Peer expected in stream context, but actually not found.") - } - - // verify peer is set in server transport context. - server.mu.Lock() - for k := range server.conns { - sc, ok := k.(*http2Server) - if !ok { - t.Fatalf("ServerTransport is of type %T, want %T", k, &http2Server{}) - } - if _, ok = peer.FromContext(sc.ctx); !ok { - t.Fatalf("Peer expected in server transport's context, but actually not found.") - } - } - server.mu.Unlock() -} diff --git a/internal/xds/rbac/rbac_engine.go b/internal/xds/rbac/rbac_engine.go index 63237affe23f..0d7f6c3b077c 100644 --- a/internal/xds/rbac/rbac_engine.go +++ b/internal/xds/rbac/rbac_engine.go @@ -34,7 +34,6 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" @@ -42,7 +41,7 @@ import ( var logger = grpclog.Component("rbac") -var getConnection = transport.GetConnection +var getConnection = grpc.GetConnection // ChainEngine represents a chain of RBAC Engines, used to make authorization // decisions on incoming RPCs. diff --git a/server.go b/server.go index 8f60d421437d..1a5e6fd765ac 100644 --- a/server.go +++ b/server.go @@ -917,7 +917,7 @@ func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) { return } go func() { - s.serveStreams(st) + s.serveStreams(st, rawConn) s.removeConn(lisAddr, st) }() } @@ -971,12 +971,42 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport { return st } -func (s *Server) serveStreams(st transport.ServerTransport) { - defer st.Close(errors.New("finished serving streams for the server transport")) - var wg sync.WaitGroup +type connectionKey struct{} + +// GetConnection gets the connection from the context. +func GetConnection(ctx context.Context) net.Conn { + conn, _ := ctx.Value(connectionKey{}).(net.Conn) + return conn +} + +// setConnection adds the connection to the context to be able to get +// information about the destination ip and port for an incoming RPC. This also +// allows any unary or streaming interceptors to see the connection. +func setConnection(ctx context.Context, conn net.Conn) context.Context { + return context.WithValue(ctx, connectionKey{}, conn) +} + +func (s *Server) serveStreams(st transport.ServerTransport, rawConn net.Conn) { + ctx := setConnection(context.Background(), rawConn) + ctx = peer.NewContext(ctx, st.Peer()) + for _, sh := range s.opts.statsHandlers { + ctx = sh.TagConn(ctx, &stats.ConnTagInfo{ + RemoteAddr: st.RemoteAddr(), + LocalAddr: st.LocalAddr(), + }) + sh.HandleConn(ctx, &stats.ConnBegin{}) + } + defer func() { + st.Close(errors.New("finished serving streams for the server transport")) + for _, sh := range s.opts.statsHandlers { + sh.HandleConn(ctx, &stats.ConnEnd{}) + } + }() + + var wg sync.WaitGroup streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams) - st.HandleStreams(func(stream *transport.Stream) { + st.HandleStreams(ctx, func(stream *transport.Stream) { wg.Add(1) streamQuota.acquire() @@ -1040,7 +1070,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } defer s.removeConn(listenerAddressForServeHTTP, st) - s.serveStreams(st) + s.serveStreams(st, nil) } func (s *Server) addConn(addr string, st transport.ServerTransport) bool { @@ -1731,6 +1761,22 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str service := sm[:pos] method := sm[pos+1:] + md, _ := metadata.FromIncomingContext(ctx) + for _, sh := range s.opts.statsHandlers { + ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: stream.Method()}) + sh.HandleRPC(ctx, &stats.InHeader{ + FullMethod: stream.Method(), + RemoteAddr: t.RemoteAddr(), + LocalAddr: t.LocalAddr(), + Compression: stream.RecvCompress(), + WireLength: stream.HeaderWireLength(), + Header: md, + }) + } + // To have calls in stream callouts work. Will delete once all stats handler + // calls come from the gRPC layer. + stream.SetContext(ctx) + srv, knownService := s.services[service] if knownService { if md, ok := srv.methods[method]; ok { diff --git a/test/end2end_test.go b/test/end2end_test.go index 1dd5757c7b9f..dd85d6d0c952 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -60,7 +60,6 @@ import ( "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/internal/testutils" - "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/resolver" @@ -5834,7 +5833,7 @@ func (s) TestClientSettingsFloodCloseConn(t *testing.T) { } func unaryInterceptorVerifyConn(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { - conn := transport.GetConnection(ctx) + conn := grpc.GetConnection(ctx) if conn == nil { return nil, status.Error(codes.NotFound, "connection was not in context") } @@ -5859,7 +5858,7 @@ func (s) TestUnaryServerInterceptorGetsConnection(t *testing.T) { } func streamingInterceptorVerifyConn(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - conn := transport.GetConnection(ss.Context()) + conn := grpc.GetConnection(ss.Context()) if conn == nil { return status.Error(codes.NotFound, "connection was not in context") } diff --git a/xds/server.go b/xds/server.go index fe2138c8bc24..2496d88bfb9c 100644 --- a/xds/server.go +++ b/xds/server.go @@ -35,7 +35,6 @@ import ( internalgrpclog "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpcsync" iresolver "google.golang.org/grpc/internal/resolver" - "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/grpc/xds/internal/server" @@ -341,7 +340,7 @@ func (s *GRPCServer) GracefulStop() { // table and also processes the RPC by running the incoming RPC through any HTTP // Filters configured. func routeAndProcess(ctx context.Context) error { - conn := transport.GetConnection(ctx) + conn := grpc.GetConnection(ctx) cw, ok := conn.(interface { VirtualHosts() []xdsresource.VirtualHostWithInterceptors }) From 06ab8e11d53262493004152a75df80aa1f196fcd Mon Sep 17 00:00:00 2001 From: Zach Reyes Date: Thu, 12 Oct 2023 14:27:23 -0400 Subject: [PATCH 2/2] Add server is registered method --- internal/internal.go | 5 ++ internal/stubstatshandler/stubstatshandler.go | 72 +++++++++++++++++ server.go | 43 ++++++++++ stats/stats_test.go | 79 +++++++++++++++++++ 4 files changed, 199 insertions(+) create mode 100644 internal/stubstatshandler/stubstatshandler.go diff --git a/internal/internal.go b/internal/internal.go index 0d94c63e06e2..70663c889a30 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -73,6 +73,11 @@ var ( // xDS-enabled server invokes this method on a grpc.Server when a particular // listener moves to "not-serving" mode. DrainServerTransports any // func(*grpc.Server, string) + // IsRegisteredMethod returns whether the passed in method is registered as + // a method on the server. + IsRegisteredMethod any // func(*grpc.Server, string) + // ServerFromContext returns the server from the context. + ServerFromContext any // func(context.Context) *Server // AddGlobalServerOptions adds an array of ServerOption that will be // effective globally for newly created servers. The priority will be: 1. // user-provided; 2. this method; 3. default values. diff --git a/internal/stubstatshandler/stubstatshandler.go b/internal/stubstatshandler/stubstatshandler.go new file mode 100644 index 000000000000..b3cd0ab18e34 --- /dev/null +++ b/internal/stubstatshandler/stubstatshandler.go @@ -0,0 +1,72 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package stubstatshandler is a stubbable implementation of +// google.golang.org/grpc/stats.Handler for testing purposes. +package stubstatshandler + +import ( + "context" + + "google.golang.org/grpc/stats" + + testgrpc "google.golang.org/grpc/interop/grpc_testing" +) + +// StubStatsHandler is a stats handler that is easy to customize within +// individual test cases. +type StubStatsHandler struct { + // Guarantees we satisfy this interface; panics if unimplemented methods are + // called. + testgrpc.TestServiceServer + + TagRPCF func(ctx context.Context, info *stats.RPCTagInfo) context.Context + HandleRPCF func(ctx context.Context, info stats.RPCStats) + TagConnF func(ctx context.Context, info *stats.ConnTagInfo) context.Context + HandleConnF func(ctx context.Context, info stats.ConnStats) +} + +// TagRPC calls the StubStatsHandler's TagRPCF, if set. +func (ssh *StubStatsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context { + if ssh.TagRPCF != nil { + return ssh.TagRPCF(ctx, info) + } + return ctx +} + +// HandleRPC calls the StubStatsHandler's HandleRPCF, if set. +func (ssh *StubStatsHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) { + if ssh.HandleRPCF != nil { + ssh.HandleRPCF(ctx, rs) + } +} + +// TagConn calls the StubStatsHandler's TagConnF, if set. +func (ssh *StubStatsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context { + if ssh.TagConnF != nil { + return ssh.TagConnF(ctx, info) + } + return ctx +} + +// HandleConn calls the StubStatsHandler's HandleConnF, if set. +func (ssh *StubStatsHandler) HandleConn(ctx context.Context, cs stats.ConnStats) { + if ssh.HandleConnF != nil { + ssh.HandleConnF(ctx, cs) + } +} diff --git a/server.go b/server.go index 1a5e6fd765ac..9fd7dc90ca8c 100644 --- a/server.go +++ b/server.go @@ -73,6 +73,10 @@ func init() { internal.DrainServerTransports = func(srv *Server, addr string) { srv.drainServerTransports(addr) } + internal.IsRegisteredMethod = func(srv *Server, method string) bool { + return srv.isRegisteredMethod(method) + } + internal.ServerFromContext = serverFromContext internal.AddGlobalServerOptions = func(opt ...ServerOption) { globalServerOptions = append(globalServerOptions, opt...) } @@ -1719,6 +1723,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, t transport.ServerTran func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) { ctx := stream.Context() + ctx = contextWithServer(ctx, s) var ti *traceInfo if EnableTracing { tr := trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()) @@ -1964,6 +1969,44 @@ func (s *Server) getCodec(contentSubtype string) baseCodec { return codec } +type serverKey struct{} + +// serverFromContext gets the Server from the context. +func serverFromContext(ctx context.Context) *Server { + s, _ := ctx.Value(serverKey{}).(*Server) + return s +} + +// contextWithServer sets the Server in the context. +func contextWithServer(ctx context.Context, server *Server) context.Context { + return context.WithValue(ctx, serverKey{}, server) +} + +// isRegisteredMethod returns whether the passed in method is registered as a +// method on the server. /service/method and service/method will match if the +// service and method are registered on the server. +func (s *Server) isRegisteredMethod(serviceMethod string) bool { + if serviceMethod != "" && serviceMethod[0] == '/' { + serviceMethod = serviceMethod[1:] + } + pos := strings.LastIndex(serviceMethod, "/") + if pos == -1 { // Invalid method name syntax. + return false + } + service := serviceMethod[:pos] + method := serviceMethod[pos+1:] + srv, knownService := s.services[service] + if knownService { + if _, ok := srv.methods[method]; ok { + return true + } + if _, ok := srv.streams[method]; ok { + return true + } + } + return false +} + // SetHeader sets the header metadata to be sent from the server to the client. // The context provided must be the context passed to the server's handler. // diff --git a/stats/stats_test.go b/stats/stats_test.go index 903c3ed7a774..2be5ba43f5d3 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -20,6 +20,7 @@ package stats_test import ( "context" + "errors" "fmt" "io" "net" @@ -31,7 +32,11 @@ import ( "github.com/golang/protobuf/proto" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/stubstatshandler" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/metadata" "google.golang.org/grpc/stats" "google.golang.org/grpc/status" @@ -1457,3 +1462,77 @@ func (s) TestMultipleServerStatsHandler(t *testing.T) { t.Fatalf("h.gotConn: unexpected amount of ConnStats: %v != %v", len(h.gotConn), 4) } } + +// TestStatsHandlerCallsServerIsRegisteredMethod tests whether a stats handler +// gets access to a Server on the server side, and thus the method that the +// server owns which specifies whether a method is made or not. The test sets up +// a server with a unary call and full duplex call configured, and makes an RPC. +// Within the stats handler, asking the server whether unary or duplex method +// names are registered should return true, and any other query should return +// false. +func (s) TestStatsHandlerCallsServerIsRegisteredMethod(t *testing.T) { + errorCh := testutils.NewChannel() + stubStatsHandler := &stubstatshandler.StubStatsHandler{ + TagRPCF: func(ctx context.Context, _ *stats.RPCTagInfo) context.Context { + // OpenTelemetry instrumentation needs the passed in Server to determine if + // methods are registered in different handle calls in to record metrics. + // This tag RPC call context gets passed into every handle call, so can + // assert once here, since it maps to all the handle RPC calls that come + // after. These internal calls will be how the OpenTelemetry instrumentation + // component accesses this server and the subsequent helper on the server. + server := internal.ServerFromContext.(func(context.Context) *grpc.Server)(ctx) + if server == nil { + errorCh.Send("stats handler received ctx has no server present") + } + isRegisteredMethod := internal.IsRegisteredMethod.(func(*grpc.Server, string) bool) + // /s/m and s/m are valid. + if !isRegisteredMethod(server, "/grpc.testing.TestService/UnaryCall") { + errorCh.Send(errors.New("UnaryCall should be a registered method according to server")) + return ctx + } + if !isRegisteredMethod(server, "grpc.testing.TestService/FullDuplexCall") { + errorCh.Send(errors.New("FullDuplexCall should be a registered method according to server")) + return ctx + } + if isRegisteredMethod(server, "/grpc.testing.TestService/DoesNotExistCall") { + errorCh.Send(errors.New("DoesNotExistCall should not be a registered method according to server")) + return ctx + } + if isRegisteredMethod(server, "/unknownService/UnaryCall") { + errorCh.Send(errors.New("/unknownService/UnaryCall should not be a registered method according to server")) + return ctx + } + errorCh.Send(nil) + return ctx + }, + } + ss := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + for { + if _, err := stream.Recv(); err == io.EOF { + return nil + } + } + }, + } + if err := ss.Start([]grpc.ServerOption{grpc.StatsHandler(stubStatsHandler)}); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{Payload: &testpb.Payload{}}); err != nil { + t.Fatalf("Unexpected error from UnaryCall: %v", err) + } + err, errRecv := errorCh.Receive(ctx) + if errRecv != nil { + t.Fatalf("error receiving from channel: %v", errRecv) + } + if err != nil { + t.Fatalf("error received from error channel: %v", err) + } +}