Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ var (
// xDS-enabled server invokes this method on a grpc.Server when a particular
// listener moves to "not-serving" mode.
DrainServerTransports any // func(*grpc.Server, string)
// IsRegisteredMethod returns whether the passed in method is registered as
// a method on the server.
IsRegisteredMethod any // func(*grpc.Server, string)
// ServerFromContext returns the server from the context.
ServerFromContext any // func(context.Context) *Server
// AddGlobalServerOptions adds an array of ServerOption that will be
// effective globally for newly created servers. The priority will be: 1.
// user-provided; 2. this method; 3. default values.
Expand Down
72 changes: 72 additions & 0 deletions internal/stubstatshandler/stubstatshandler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

// Package stubstatshandler is a stubbable implementation of
// google.golang.org/grpc/stats.Handler for testing purposes.
package stubstatshandler

import (
"context"

"google.golang.org/grpc/stats"

testgrpc "google.golang.org/grpc/interop/grpc_testing"
)

// StubStatsHandler is a stats handler that is easy to customize within
// individual test cases.
type StubStatsHandler struct {
// Guarantees we satisfy this interface; panics if unimplemented methods are
// called.
testgrpc.TestServiceServer

TagRPCF func(ctx context.Context, info *stats.RPCTagInfo) context.Context
HandleRPCF func(ctx context.Context, info stats.RPCStats)
TagConnF func(ctx context.Context, info *stats.ConnTagInfo) context.Context
HandleConnF func(ctx context.Context, info stats.ConnStats)
}

// TagRPC calls the StubStatsHandler's TagRPCF, if set.
func (ssh *StubStatsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
if ssh.TagRPCF != nil {
return ssh.TagRPCF(ctx, info)
}
return ctx
}

// HandleRPC calls the StubStatsHandler's HandleRPCF, if set.
func (ssh *StubStatsHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) {
if ssh.HandleRPCF != nil {
ssh.HandleRPCF(ctx, rs)
}
}

// TagConn calls the StubStatsHandler's TagConnF, if set.
func (ssh *StubStatsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
if ssh.TagConnF != nil {
return ssh.TagConnF(ctx, info)
}
return ctx
}

// HandleConn calls the StubStatsHandler's HandleConnF, if set.
func (ssh *StubStatsHandler) HandleConn(ctx context.Context, cs stats.ConnStats) {
if ssh.HandleConnF != nil {
ssh.HandleConnF(ctx, cs)
}
}
37 changes: 18 additions & 19 deletions internal/transport/handler_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,14 @@ func (ht *serverHandlerTransport) Close(err error) {

func (ht *serverHandlerTransport) RemoteAddr() net.Addr { return strAddr(ht.req.RemoteAddr) }

func (ht *serverHandlerTransport) LocalAddr() net.Addr { return nil } // Server Handler transport has no access to local addr (was simply not calling sh with local addr).

func (ht *serverHandlerTransport) Peer() *peer.Peer {
return &peer.Peer{
Addr: ht.RemoteAddr(),
}
}

// strAddr is a net.Addr backed by either a TCP "ip:port" string, or
// the empty string if unknown.
type strAddr string
Expand Down Expand Up @@ -347,7 +355,7 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
return err
}

func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
func (ht *serverHandlerTransport) HandleStreams(_ context.Context, startStream func(*Stream)) {
// With this transport type there will be exactly 1 stream: this HTTP request.

ctx := ht.req.Context()
Expand All @@ -371,16 +379,16 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
}()

req := ht.req

s := &Stream{
id: 0, // irrelevant
requestRead: func(int) {},
cancel: cancel,
buf: newRecvBuffer(),
st: ht,
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
id: 0, // irrelevant
requestRead: func(int) {},
cancel: cancel,
buf: newRecvBuffer(),
st: ht,
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
headerWireLength: 0, // doesn't know header wire length, will call into stats handler as 0.
}
pr := &peer.Peer{
Addr: ht.RemoteAddr(),
Expand All @@ -390,15 +398,6 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
}
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
s.ctx = peer.NewContext(ctx, pr)
for _, sh := range ht.stats {
s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
inHeader := &stats.InHeader{
FullMethod: s.method,
RemoteAddr: ht.RemoteAddr(),
Compression: s.recvCompress,
}
sh.HandleRPC(s.ctx, inHeader)
}
s.trReader = &transportReader{
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}},
windowHandler: func(int) {},
Expand Down
10 changes: 5 additions & 5 deletions internal/transport/handler_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
st.ht.WriteStatus(s, status.New(codes.OK, ""))
}
st.ht.HandleStreams(
func(s *Stream) { go handleStream(s) },
context.Background(), func(s *Stream) { go handleStream(s) },
)
wantHeader := http.Header{
"Date": nil,
Expand Down Expand Up @@ -347,7 +347,7 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string)
st.ht.WriteStatus(s, status.New(statusCode, msg))
}
st.ht.HandleStreams(
func(s *Stream) { go handleStream(s) },
context.Background(), func(s *Stream) { go handleStream(s) },
)
wantHeader := http.Header{
"Date": nil,
Expand Down Expand Up @@ -396,7 +396,7 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow"))
}
ht.HandleStreams(
func(s *Stream) { go runStream(s) },
context.Background(), func(s *Stream) { go runStream(s) },
)
wantHeader := http.Header{
"Date": nil,
Expand Down Expand Up @@ -448,7 +448,7 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) {
st := newHandleStreamTest(t)
st.ht.HandleStreams(
func(s *Stream) { go handleStream(st, s) },
context.Background(), func(s *Stream) { go handleStream(st, s) },
)
}

Expand Down Expand Up @@ -481,7 +481,7 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
hst.ht.WriteStatus(s, st)
}
hst.ht.HandleStreams(
func(s *Stream) { go handleStream(s) },
context.Background(), func(s *Stream) { go handleStream(s) },
)
wantHeader := http.Header{
"Date": nil,
Expand Down
69 changes: 16 additions & 53 deletions internal/transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ var serverConnectionCounter uint64
// http2Server implements the ServerTransport interface with HTTP2.
type http2Server struct {
lastRead int64 // Keep this field 64-bit aligned. Accessed atomically.
ctx context.Context
done chan struct{}
conn net.Conn
loopy *loopyWriter
Expand Down Expand Up @@ -244,7 +243,6 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,

done := make(chan struct{})
t := &http2Server{
ctx: setConnection(context.Background(), rawConn),
done: done,
conn: conn,
remoteAddr: conn.RemoteAddr(),
Expand All @@ -267,8 +265,6 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
bufferPool: newBufferPool(),
}
t.logger = prefixLoggerForServerTransport(t)
// Add peer information to the http2server context.
t.ctx = peer.NewContext(t.ctx, t.getPeer())

t.controlBuf = newControlBuffer(t.done)
if dynamicWindow {
Expand All @@ -277,14 +273,6 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
updateFlowControl: t.updateFlowControl,
}
}
for _, sh := range t.stats {
t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{
RemoteAddr: t.remoteAddr,
LocalAddr: t.localAddr,
})
connBegin := &stats.ConnBegin{}
sh.HandleConn(t.ctx, connBegin)
}
t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr))
if err != nil {
return nil, err
Expand Down Expand Up @@ -342,7 +330,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,

// operateHeaders takes action on the decoded headers. Returns an error if fatal
// error encountered and transport needs to close, otherwise returns nil.
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) error {
func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeadersFrame, handle func(*Stream)) error {
// Acquire max stream ID lock for entire duration
t.maxStreamMu.Lock()
defer t.maxStreamMu.Unlock()
Expand All @@ -369,10 +357,11 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(

buf := newRecvBuffer()
s := &Stream{
id: streamID,
st: t,
buf: buf,
fc: &inFlow{limit: uint32(t.initialWindowSize)},
id: streamID,
st: t,
buf: buf,
fc: &inFlow{limit: uint32(t.initialWindowSize)},
headerWireLength: int(frame.Header().Length),
}
var (
// if false, content-type was missing or invalid
Expand Down Expand Up @@ -511,9 +500,9 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.state = streamReadDone
}
if timeoutSet {
s.ctx, s.cancel = context.WithTimeout(t.ctx, timeout)
s.ctx, s.cancel = context.WithTimeout(ctx, timeout)
} else {
s.ctx, s.cancel = context.WithCancel(t.ctx)
s.ctx, s.cancel = context.WithCancel(ctx)
}

// Attach the received metadata to the context.
Expand Down Expand Up @@ -592,18 +581,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n))
}
for _, sh := range t.stats {
s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
inHeader := &stats.InHeader{
FullMethod: s.method,
RemoteAddr: t.remoteAddr,
LocalAddr: t.localAddr,
Compression: s.recvCompress,
WireLength: int(frame.Header().Length),
Header: mdata.Copy(),
}
sh.HandleRPC(s.ctx, inHeader)
}
s.ctxDone = s.ctx.Done()
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
s.trReader = &transportReader{
Expand All @@ -629,7 +606,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
// HandleStreams receives incoming streams using the given handler. This is
// typically run in a separate goroutine.
// traceCtx attaches trace to ctx and returns the new context.
func (t *http2Server) HandleStreams(handle func(*Stream)) {
func (t *http2Server) HandleStreams(ctx context.Context, handle func(*Stream)) {
defer close(t.readerDone)
for {
t.controlBuf.throttle()
Expand Down Expand Up @@ -664,7 +641,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
}
switch frame := frame.(type) {
case *http2.MetaHeadersFrame:
if err := t.operateHeaders(frame, handle); err != nil {
if err := t.operateHeaders(ctx, frame, handle); err != nil {
t.Close(err)
break
}
Expand Down Expand Up @@ -1242,10 +1219,6 @@ func (t *http2Server) Close(err error) {
for _, s := range streams {
s.cancel()
}
for _, sh := range t.stats {
connEnd := &stats.ConnEnd{}
sh.HandleConn(t.ctx, connEnd)
}
}

// deleteStream deletes the stream s from transport's active streams.
Expand Down Expand Up @@ -1311,6 +1284,10 @@ func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eo
})
}

func (t *http2Server) LocalAddr() net.Addr {
return t.localAddr
}

func (t *http2Server) RemoteAddr() net.Addr {
return t.remoteAddr
}
Expand Down Expand Up @@ -1433,7 +1410,8 @@ func (t *http2Server) getOutFlowWindow() int64 {
}
}

func (t *http2Server) getPeer() *peer.Peer {
// Peer returns the peer of the transport.
func (t *http2Server) Peer() *peer.Peer {
return &peer.Peer{
Addr: t.remoteAddr,
AuthInfo: t.authInfo, // Can be nil
Expand All @@ -1449,18 +1427,3 @@ func getJitter(v time.Duration) time.Duration {
j := grpcrand.Int63n(2*r) - r
return time.Duration(j)
}

type connectionKey struct{}

// GetConnection gets the connection from the context.
func GetConnection(ctx context.Context) net.Conn {
conn, _ := ctx.Value(connectionKey{}).(net.Conn)
return conn
}

// SetConnection adds the connection to the context to be able to get
// information about the destination ip and port for an incoming RPC. This also
// allows any unary or streaming interceptors to see the connection.
func setConnection(ctx context.Context, conn net.Conn) context.Context {
return context.WithValue(ctx, connectionKey{}, conn)
}
Loading