diff --git a/answer_test.go b/answer_test.go index 24fe4ac8..f9bf4916 100644 --- a/answer_test.go +++ b/answer_test.go @@ -54,11 +54,14 @@ func TestPromiseReject(t *testing.T) { } func TestPromiseFulfill(t *testing.T) { + t.Parallel() + t.Run("Done", func(t *testing.T) { p := NewPromise(dummyMethod, dummyPipelineCaller{}) done := p.Answer().Done() msg, seg, _ := NewMessage(SingleSegment(nil)) - defer msg.Reset(nil) + defer msg.Release() + res, _ := NewStruct(seg, ObjectSize{DataSize: 8}) p.Fulfill(res.ToPtr()) select { @@ -73,7 +76,8 @@ func TestPromiseFulfill(t *testing.T) { defer p.ReleaseClients() ans := p.Answer() msg, seg, _ := NewMessage(SingleSegment(nil)) - defer msg.Reset(nil) + defer msg.Release() + res, _ := NewStruct(seg, ObjectSize{DataSize: 8}) res.SetUint32(0, 0xdeadbeef) p.Fulfill(res.ToPtr()) @@ -96,7 +100,8 @@ func TestPromiseFulfill(t *testing.T) { c := NewClient(h) defer c.Release() msg, seg, _ := NewMessage(SingleSegment(nil)) - defer msg.Reset(nil) + defer msg.Release() + res, _ := NewStruct(seg, ObjectSize{PointerCount: 3}) res.SetPtr(1, NewInterface(seg, msg.AddCap(c.AddRef())).ToPtr()) diff --git a/answerqueue.go b/answerqueue.go index e85ff9a7..39a48dc6 100644 --- a/answerqueue.go +++ b/answerqueue.go @@ -14,12 +14,12 @@ import ( // // An AnswerQueue can be in one of three states: // -// 1) Queueing. Incoming method calls will be added to the queue. -// 2) Draining, entered by calling Fulfill or Reject. Queued method -// calls will be delivered in sequence, and new incoming method calls -// will block until the AnswerQueue enters the Drained state. -// 3) Drained, entered once all queued methods have been delivered. -// Incoming methods are passthrough. +// 1. Queueing. Incoming method calls will be added to the queue. +// 2. Draining, entered by calling Fulfill or Reject. Queued method +// calls will be delivered in sequence, and new incoming method calls +// will block until the AnswerQueue enters the Drained state. +// 3. Drained, entered once all queued methods have been delivered. +// Incoming methods are passthrough. type AnswerQueue struct { method Method draining chan struct{} // closed while exiting queueing state @@ -154,8 +154,9 @@ func (qc queueCaller) PipelineRecv(ctx context.Context, transform []PipelineOp, func (qc queueCaller) PipelineSend(ctx context.Context, transform []PipelineOp, s Send) (*Answer, ReleaseFunc) { ret := new(StructReturner) r := Recv{ - Method: s.Method, - Returner: ret, + Method: s.Method, + Returner: ret, + ReleaseArgs: func() {}, } if s.PlaceArgs != nil { var err error @@ -167,12 +168,9 @@ func (qc queueCaller) PipelineSend(ctx context.Context, transform []PipelineOp, if err = s.PlaceArgs(r.Args); err != nil { return ErrorAnswer(s.Method, err), func() {} } - r.ReleaseArgs = func() { - r.Args.Message().Reset(nil) - } - } else { - r.ReleaseArgs = func() {} + r.ReleaseArgs = r.Args.Message().Release } + pcall := qc.PipelineRecv(ctx, transform, r) return ret.Answer(s.Method, pcall) } @@ -258,7 +256,7 @@ func (sr *StructReturner) ReleaseResults() { return } if err != nil && msg != nil { - msg.Reset(nil) + msg.Release() } } @@ -280,7 +278,7 @@ func (sr *StructReturner) Answer(m Method, pcall PipelineCaller) (*Answer, Relea sr.result = Struct{} sr.mu.Unlock() if msg != nil { - msg.Reset(nil) + msg.Release() } } } @@ -294,7 +292,7 @@ func (sr *StructReturner) Answer(m Method, pcall PipelineCaller) (*Answer, Relea sr.mu.Unlock() sr.p.ReleaseClients() if msg != nil { - msg.Reset(nil) + msg.Release() } } } diff --git a/message.go b/message.go index 0a642248..aefb33e3 100644 --- a/message.go +++ b/message.go @@ -94,10 +94,16 @@ func NewMultiSegmentMessage(b [][]byte) (msg *Message, first *Segment) { return msg, first } +// Release is syntactic sugar for Message.Reset(nil). See +// docstring for Reset for an important warning. +func (m *Message) Release() { + m.Reset(nil) +} + // Reset the message to use a different arena, allowing it // to be reused. This invalidates any existing pointers in -// the Message, and releases all clients in the cap table, -// so use with caution. +// the Message, releases all clients in the cap table, and +// releases the current Arena, so use with caution. func (m *Message) Reset(arena Arena) (first *Segment, err error) { for _, c := range m.CapTable { c.Release() @@ -107,10 +113,15 @@ func (m *Message) Reset(arena Arena) (first *Segment, err error) { delete(m.segs, k) } + if m.Arena != nil { + m.Arena.Release() + } + *m = Message{ Arena: arena, TraverseLimit: m.TraverseLimit, DepthLimit: m.DepthLimit, + CapTable: m.CapTable[:0], segs: m.segs, } diff --git a/request.go b/request.go index 1398f54a..df3c4cbc 100644 --- a/request.go +++ b/request.go @@ -77,9 +77,9 @@ func (r *Request) Future() *Future { // Release resources associated with the request. In particular: // -// * Release the arguments if they have not yet been released. -// * If the request has been sent, wait for the result and release -// the results. +// - Release the arguments if they have not yet been released. +// - If the request has been sent, wait for the result and release +// the results. func (r *Request) Release() { r.releaseArgs() rel := r.releaseResponse @@ -91,12 +91,9 @@ func (r *Request) Release() { } func (r *Request) releaseArgs() { - if r.args.IsValid() { - return + if !r.args.IsValid() { + msg := r.args.Message() + r.args = Struct{} + msg.Release() } - msg := r.args.Message() - r.args = Struct{} - arena := msg.Arena - msg.Reset(nil) - arena.Release() } diff --git a/rpc/answer.go b/rpc/answer.go index 067b5dd2..f1390877 100644 --- a/rpc/answer.go +++ b/rpc/answer.go @@ -103,7 +103,7 @@ func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(*lockedConn), _ *rc.Rel if err != nil { return rpccp.Return{}, nil, nil, rpcerr.WrapFailed("create return", err) } - ret, err := outMsg.Message.NewReturn() + ret, err := outMsg.Message().NewReturn() if err != nil { outMsg.Release() return rpccp.Return{}, nil, nil, rpcerr.WrapFailed("create return", err) diff --git a/rpc/flow_test.go b/rpc/flow_test.go index dd4f3973..e77eb973 100644 --- a/rpc/flow_test.go +++ b/rpc/flow_test.go @@ -34,26 +34,36 @@ func (t *measuringTransport) RecvMessage() (transport.IncomingMessage, error) { return inMsg, err } - size, err := capnp.Struct(inMsg.Message).Message().TotalSize() + size, err := inMsg.Message().Message().TotalSize() if err != nil { return inMsg, err } t.mu.Lock() - t.inUse += size - if t.inUse > t.maxInUse { + defer t.mu.Unlock() + + if t.inUse += size; t.inUse > t.maxInUse { t.maxInUse = t.inUse } - t.mu.Unlock() - - oldRelease := inMsg.Release - inMsg.Release = capnp.ReleaseFunc(func() { - oldRelease() - t.mu.Lock() - defer t.mu.Unlock() - t.inUse -= size - }) - return inMsg, err + + return releaseHook{ + t: t, + IncomingMessage: inMsg, + }, nil +} + +type releaseHook struct { + t *measuringTransport + size uint64 + transport.IncomingMessage +} + +func (rh releaseHook) Release() { + rh.IncomingMessage.Release() + + rh.t.mu.Lock() + rh.t.inUse -= rh.size + rh.t.mu.Lock() } // Test that attaching a fixed-size FlowLimiter results in actually limiting the diff --git a/rpc/level0_test.go b/rpc/level0_test.go index 7fbeac6d..ab3b6972 100644 --- a/rpc/level0_test.go +++ b/rpc/level0_test.go @@ -319,8 +319,8 @@ func TestSendBootstrapCall(t *testing.T) { if err != nil { t.Fatal("p2.NewMessage():", err) } - iptr := capnp.NewInterface(outMsg.Message.Segment(), 0) - err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{ + iptr := capnp.NewInterface(outMsg.Message().Segment(), 0) + err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{ Which: rpccp.Message_Which_return, Return: &rpcReturn{ AnswerID: qid, @@ -416,12 +416,12 @@ func TestSendBootstrapCall(t *testing.T) { if err != nil { t.Fatal("p2.NewMessage():", err) } - resp, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8}) + resp, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8}) if err != nil { t.Fatal("capnp.NewStruct:", err) } resp.SetUint64(0, 0xdeadbeef) - err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{ + err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{ Which: rpccp.Message_Which_return, Return: &rpcReturn{ AnswerID: qid, @@ -530,8 +530,8 @@ func TestSendBootstrapCallException(t *testing.T) { if err != nil { t.Fatal("p2.NewMessage():", err) } - iptr := capnp.NewInterface(outMsg.Message.Segment(), 0) - err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{ + iptr := capnp.NewInterface(outMsg.Message().Segment(), 0) + err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{ Which: rpccp.Message_Which_return, Return: &rpcReturn{ AnswerID: qid, @@ -755,12 +755,12 @@ func TestSendBootstrapPipelineCall(t *testing.T) { if err != nil { t.Fatal("p2.NewMessage():", err) } - resp, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8}) + resp, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8}) if err != nil { t.Fatal("capnp.NewStruct:", err) } resp.SetUint64(0, 0xdeadbeef) - err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{ + err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{ Which: rpccp.Message_Which_return, Return: &rpcReturn{ AnswerID: qid, @@ -963,12 +963,12 @@ func TestRecvBootstrapCall(t *testing.T) { if err != nil { t.Fatal("p2.NewMessage():", err) } - params, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8}) + params, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8}) if err != nil { t.Fatal("capnp.NewStruct:", err) } params.SetUint32(0, 0x2a2b) - err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{ + err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{ Which: rpccp.Message_Which_call, Call: &rpcCall{ QuestionID: callQID, @@ -1114,12 +1114,12 @@ func TestRecvBootstrapCallException(t *testing.T) { if err != nil { t.Fatal("p2.NewMessage():", err) } - params, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8}) + params, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8}) if err != nil { t.Fatal("capnp.NewStruct:", err) } params.SetUint32(0, 0x2a2b) - err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{ + err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{ Which: rpccp.Message_Which_call, Call: &rpcCall{ QuestionID: callQID, @@ -1258,12 +1258,12 @@ func TestRecvBootstrapPipelineCall(t *testing.T) { if err != nil { t.Fatal("p2.NewMessage():", err) } - params, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8}) + params, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8}) if err != nil { t.Fatal("capnp.NewStruct:", err) } params.SetUint32(0, 0x2a2b) - err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{ + err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{ Which: rpccp.Message_Which_call, Call: &rpcCall{ QuestionID: callQID, @@ -1452,8 +1452,8 @@ func TestCallOnClosedConn(t *testing.T) { if err != nil { t.Fatal("p2.NewMessage():", err) } - iptr := capnp.NewInterface(outMsg.Message.Segment(), 0) - err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{ + iptr := capnp.NewInterface(outMsg.Message().Segment(), 0) + err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{ Which: rpccp.Message_Which_return, Return: &rpcReturn{ AnswerID: qid, @@ -1593,7 +1593,7 @@ func TestRecvCancel(t *testing.T) { if err != nil { t.Fatal("p2.NewMessage():", err) } - err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{ + err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{ Which: rpccp.Message_Which_call, Call: &rpcCall{ QuestionID: callQID, @@ -1732,8 +1732,8 @@ func TestSendCancel(t *testing.T) { if err != nil { t.Fatal("p2.NewMessage():", err) } - iptr := capnp.NewInterface(outMsg.Message.Segment(), 0) - err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{ + iptr := capnp.NewInterface(outMsg.Message().Segment(), 0) + err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{ Which: rpccp.Message_Which_return, Return: &rpcReturn{ AnswerID: bootQID, @@ -2046,7 +2046,7 @@ func sendMessage(ctx context.Context, t rpc.Transport, msg *rpcMessage) error { return fmt.Errorf("send message: %v", err) } defer outMsg.Release() - if err := pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), msg); err != nil { + if err := pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), msg); err != nil { return fmt.Errorf("send message: %v", err) } if err := outMsg.Send(); err != nil { @@ -2061,7 +2061,7 @@ func recvMessage(ctx context.Context, t rpc.Transport) (*rpcMessage, capnp.Relea return nil, nil, err } r := new(rpcMessage) - if err := pogs.Extract(r, rpccp.Message_TypeID, capnp.Struct(inMsg.Message)); err != nil { + if err := pogs.Extract(r, rpccp.Message_TypeID, capnp.Struct(inMsg.Message())); err != nil { inMsg.Release() return nil, nil, fmt.Errorf("extract RPC message: %v", err) } diff --git a/rpc/level1_test.go b/rpc/level1_test.go index f9189bd8..1921410a 100644 --- a/rpc/level1_test.go +++ b/rpc/level1_test.go @@ -61,8 +61,8 @@ func testSendDisembargo(t *testing.T, sendPrimeTo rpccp.Call_sendResultsTo_Which if err != nil { t.Fatal("p2.NewMessage():", err) } - iptr := capnp.NewInterface(outMsg.Message.Segment(), 0) - err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{ + iptr := capnp.NewInterface(outMsg.Message().Segment(), 0) + err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{ Which: rpccp.Message_Which_return, Return: &rpcReturn{ AnswerID: bootQID, @@ -239,14 +239,14 @@ func testSendDisembargo(t *testing.T, sendPrimeTo rpccp.Call_sendResultsTo_Which if err != nil { t.Fatal("p2.NewMessage():", err) } - results, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{PointerCount: 1}) + results, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{PointerCount: 1}) if err != nil { t.Fatal("capnp.NewStruct:", err) } - if err := results.SetPtr(0, capnp.NewInterface(outMsg.Message.Segment(), 0).ToPtr()); err != nil { + if err := results.SetPtr(0, capnp.NewInterface(outMsg.Message().Segment(), 0).ToPtr()); err != nil { t.Fatal("results.SetPtr:", err) } - err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{ + err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{ Which: rpccp.Message_Which_return, Return: &rpcReturn{ AnswerID: qidA, @@ -557,7 +557,7 @@ func TestRecvDisembargo(t *testing.T) { if err != nil { t.Fatal("p2.NewMessage():", err) } - params, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{PointerCount: 1}) + params, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{PointerCount: 1}) if err != nil { outMsg.Release() t.Fatal("capnp.NewStruct:", err) @@ -567,7 +567,7 @@ func TestRecvDisembargo(t *testing.T) { outMsg.Release() t.Fatal("capnp.NewStruct.SetPtr:", err) } - err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{ + err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{ Which: rpccp.Message_Which_call, Call: &rpcCall{ QuestionID: callQID, @@ -861,7 +861,7 @@ func TestIssue3(t *testing.T) { if err != nil { t.Fatal("p2.NewMessage():", err) } - params, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{PointerCount: 1}) + params, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{PointerCount: 1}) if err != nil { outMsg.Release() t.Fatal("capnp.NewStruct:", err) @@ -871,7 +871,7 @@ func TestIssue3(t *testing.T) { outMsg.Release() t.Fatal("capnp.NewStruct.SetPtr:", err) } - err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{ + err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{ Which: rpccp.Message_Which_call, Call: &rpcCall{ QuestionID: callQID, @@ -938,13 +938,13 @@ func TestIssue3(t *testing.T) { if err != nil { t.Fatal("p2.NewMessage():", err) } - results, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8}) + results, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8}) if err != nil { outMsg.Release() t.Fatal("capnp.NewStruct:", err) } results.SetUint64(0, 0xdeadbeef) - err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{ + err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{ Which: rpccp.Message_Which_return, Return: &rpcReturn{ AnswerID: callbackQID, diff --git a/rpc/receiveranswer_test.go b/rpc/receiveranswer_test.go index 7c5acce2..e09bd2c8 100644 --- a/rpc/receiveranswer_test.go +++ b/rpc/receiveranswer_test.go @@ -153,7 +153,7 @@ func TestBootstrapReceiverAnswer(t *testing.T) { outMsg, err := trans.NewMessage() util.Chkfatal(err) - bs, err := outMsg.Message.NewBootstrap() + bs, err := outMsg.Message().NewBootstrap() util.Chkfatal(err) bs.SetQuestionId(0) outMsg.Send() @@ -163,7 +163,7 @@ func TestBootstrapReceiverAnswer(t *testing.T) { util.Chkfatal(err) // bootstrap.call(cap = bootstrap) - call, err := outMsg.Message.NewCall() + call, err := outMsg.Message().NewCall() util.Chkfatal(err) call.SetQuestionId(1) tgt, err := call.NewTarget() @@ -218,7 +218,7 @@ func TestCallReceiverAnswer(t *testing.T) { outMsg, err := trans.NewMessage() util.Chkfatal(err) - bs, err := outMsg.Message.NewBootstrap() + bs, err := outMsg.Message().NewBootstrap() util.Chkfatal(err) bs.SetQuestionId(0) outMsg.Send() @@ -228,7 +228,7 @@ func TestCallReceiverAnswer(t *testing.T) { util.Chkfatal(err) // qid1 = bootstrap.self() - call, err := outMsg.Message.NewCall() + call, err := outMsg.Message().NewCall() util.Chkfatal(err) call.SetQuestionId(1) tgt, err := call.NewTarget() @@ -245,7 +245,7 @@ func TestCallReceiverAnswer(t *testing.T) { util.Chkfatal(err) // qid1.self.call(cap = qid1.self) - call, err = outMsg.Message.NewCall() + call, err = outMsg.Message().NewCall() util.Chkfatal(err) call.SetQuestionId(2) tgt, err = call.NewTarget() diff --git a/rpc/rpc.go b/rpc/rpc.go index a30992b6..14247d16 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -543,7 +543,7 @@ func (c *Conn) abort(abortErr error) { defer outMsg.Release() // configure & send abort message - if abort, err := outMsg.Message.NewAbort(); err == nil { + if abort, err := outMsg.Message().NewAbort(); err == nil { abort.SetType(rpccp.Exception_Type(exc.TypeOf(abortErr))) if err = abort.SetReason(abortErr.Error()); err == nil { outMsg.Send() @@ -591,7 +591,7 @@ func (c *Conn) receive(ctx context.Context) func() error { return nil } - switch in.Message.Which() { + switch in.Message().Which() { case rpccp.Message_Which_unimplemented: if err := c.handleUnimplemented(in); err != nil { return err @@ -648,7 +648,7 @@ func (c *Conn) receive(ctx context.Context) func() error { func (c *Conn) handleAbort(in transport.IncomingMessage) { defer in.Release() - e, err := in.Message.Abort() + e, err := in.Message().Abort() if err != nil { c.er.ReportError(exc.WrapError("read abort", err)) return @@ -666,7 +666,7 @@ func (c *Conn) handleAbort(in transport.IncomingMessage) { func (c *Conn) handleUnimplemented(in transport.IncomingMessage) error { defer in.Release() - msg, err := in.Message.Unimplemented() + msg, err := in.Message().Unimplemented() if err != nil { return exc.WrapError("read unimplemented", err) } @@ -708,7 +708,7 @@ func (c *Conn) handleUnimplemented(in transport.IncomingMessage) error { func (c *Conn) handleBootstrap(in transport.IncomingMessage) error { defer in.Release() - bootstrap, err := in.Message.Bootstrap() + bootstrap, err := in.Message().Bootstrap() if err != nil { c.er.ReportError(exc.WrapError("read bootstrap", err)) return nil @@ -763,7 +763,7 @@ func (c *Conn) handleBootstrap(in transport.IncomingMessage) error { } func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) error { - call, err := in.Message.Call() + call, err := in.Message().Call() if err != nil { in.Release() c.er.ReportError(exc.WrapError("read call", err)) @@ -1090,7 +1090,7 @@ func parseTransform(list rpccp.PromisedAnswer_Op_List) ([]capnp.PipelineOp, erro } func (c *Conn) handleReturn(ctx context.Context, in transport.IncomingMessage) error { - ret, err := in.Message.Return() + ret, err := in.Message().Return() if err != nil { in.Release() c.er.ReportError(exc.WrapError("read return", err)) @@ -1292,7 +1292,7 @@ type parsedReturn struct { func (c *Conn) handleFinish(ctx context.Context, in transport.IncomingMessage) error { defer in.Release() - fin, err := in.Message.Finish() + fin, err := in.Message().Finish() if err != nil { c.er.ReportError(exc.WrapError("read finish", err)) return nil @@ -1543,7 +1543,7 @@ func (c *lockedConn) recvPayload(rl *releaseList, payload rpccp.Payload) (_ capn func (c *Conn) handleRelease(ctx context.Context, in transport.IncomingMessage) error { defer in.Release() - rel, err := in.Message.Release() + rel, err := in.Message().Release() if err != nil { c.er.ReportError(exc.WrapError("read release", err)) return nil @@ -1564,7 +1564,7 @@ func (c *Conn) handleRelease(ctx context.Context, in transport.IncomingMessage) } func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessage) error { - d, err := in.Message.Disembargo() + d, err := in.Message().Disembargo() if err != nil { in.Release() c.er.ReportError(exc.WrapError("read disembargo", err)) @@ -1739,13 +1739,13 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag } func (c *Conn) handleUnknownMessageType(ctx context.Context, in transport.IncomingMessage) { - err := errors.New("unknown message type " + in.Message.Which().String() + " from remote") + err := errors.New("unknown message type " + in.Message().Which().String() + " from remote") c.er.ReportError(err) c.withLocked(func(c *lockedConn) { c.sendMessage(ctx, func(m rpccp.Message) error { defer in.Release() - if err := m.SetUnimplemented(in.Message); err != nil { + if err := m.SetUnimplemented(in.Message()); err != nil { return rpcerr.Annotate(err, "send unimplemented") } return nil @@ -1786,7 +1786,7 @@ func (c *lockedConn) sendMessage(ctx context.Context, build func(rpccp.Message) send = func() error { return rpcerr.WrapFailed("create message", err) } - } else if err = build(outMsg.Message); err != nil { + } else if err = build(outMsg.Message()); err != nil { send = func() error { return rpcerr.WrapFailed("build message", err) } diff --git a/rpc/transport/transport.go b/rpc/transport/transport.go index a4715a56..17fd7275 100644 --- a/rpc/transport/transport.go +++ b/rpc/transport/transport.go @@ -22,14 +22,14 @@ import ( // with RecvMessage. type Transport interface { // NewMessage allocates a new message to be sent over the transport. - // The caller must call the release function when it no longer needs - // to reference the message. Calling the release function more than once - // has no effect. Before releasing the message, send may be called at most - // once to send the mssage. + // The caller must call OutgoingMessage.Release() when it no longer + // needs to reference the message. Calling Release() more than once + // has no effect. Before releasing the message, Send() MAY be called + // at most once to send the mssage. // - // Messages returned by NewMessage must have a nil CapTable. - // When the returned ReleaseFunc is called, any clients in the message's - // CapTable will be released. + // When Release() is called, the underlying *capnp.Message SHOULD be + // released. This will also release any clients in the CapTable and + // release its Arena. // // The Arena in the returned message should be fast at allocating new // segments. The returned ReleaseFunc MUST be safe to call concurrently @@ -37,13 +37,14 @@ type Transport interface { NewMessage() (OutgoingMessage, error) // RecvMessage receives the next message sent from the remote vat. - // The returned message is only valid until the release function is - // called. The release function may be called concurrently with - // RecvMessage or with any other release function returned by RecvMessage. + // The returned message is only valid until the release method is + // called. The IncomingMessage.Release() method may be called + // concurrently with RecvMessage or with any other release function + // returned by RecvMessage. // - // Messages returned by RecvMessage must have a nil CapTable. - // When the returned ReleaseFunc is called, any clients in the message's - // CapTable will be released. + // When Release() is called, the underlying *capnp.Message SHOULD be + // released. This will also release any clients in the CapTable and + // release its Arena. // // The Arena in the returned message should not fetch segments lazily; // the Arena should be fast to access other segments. @@ -55,15 +56,31 @@ type Transport interface { Close() error } -type OutgoingMessage struct { - Message rpccp.Message - Send func() error - Release capnp.ReleaseFunc +// OutgoingMessage is a message that can be sent at a later time. +// Release() MUST be called when the OutgoingMessage is no longer in +// use. Before releasing an ougoing message, Send() MAY be called at +// most once to send the message over the transport that produced it. +// +// Implementations SHOULD release the underlying *capnp.Message when +// the Release() method is called. +// +// Release() MUST be idempotent, and calls to Send() after a call to +// Release MUST panic. +type OutgoingMessage interface { + Send() error + Message() rpccp.Message + Release() } -type IncomingMessage struct { - Message rpccp.Message - Release capnp.ReleaseFunc +// IncomingMessage is a message that has arrived over a transport. +// Release() MUST be called when the IncomingMessage is no longer +// in use. +// +// Implementations SHOULD release the underlying *capnp.Message when +// the Release() method is called. Release() MUST be idempotent. +type IncomingMessage interface { + Message() rpccp.Message + Release() } // A Codec is responsible for encoding and decoding messages from @@ -108,43 +125,28 @@ func NewPackedStream(rwc io.ReadWriteCloser) Transport { // It is safe to call NewMessage concurrently with RecvMessage. func (s *transport) NewMessage() (OutgoingMessage, error) { arena := capnp.MultiSegment(nil) - msg, seg, err := capnp.NewMessage(arena) + _, seg, err := capnp.NewMessage(arena) if err != nil { err = transporterr.Annotate(exc.WrapError("new message", err), "stream transport") - return OutgoingMessage{}, err + return nil, err } - rmsg, err := rpccp.NewRootMessage(seg) + m, err := rpccp.NewRootMessage(seg) if err != nil { err = transporterr.Annotate(exc.WrapError("new message", err), "stream transport") - return OutgoingMessage{}, err + return nil, err } - alreadyReleased := false - - send := func() error { - if alreadyReleased { - panic("Tried to send() a message that was already released.") - } - if err = s.c.Encode(msg); err != nil { + send := func(m *capnp.Message) (err error) { + if err = s.c.Encode(m); err != nil { err = transporterr.Annotate(exc.WrapError("send", err), "stream transport") } - return err - } - - release := func() { - if alreadyReleased { - return - } - alreadyReleased = true - msg.Reset(nil) - arena.Release() + return } - return OutgoingMessage{ - Message: rmsg, - Send: send, - Release: release, + return &outgoingMsg{ + message: m, + send: send, }, nil } @@ -155,22 +157,15 @@ func (s *transport) RecvMessage() (IncomingMessage, error) { msg, err := s.c.Decode() if err != nil { err = transporterr.Annotate(exc.WrapError("receive", err), "stream transport") - return IncomingMessage{}, err + return nil, err } rmsg, err := rpccp.ReadRootMessage(msg) if err != nil { err = transporterr.Annotate(exc.WrapError("receive", err), "stream transport") - return IncomingMessage{}, err - } - - release := func() { - msg.Reset(nil) + return nil, err } - return IncomingMessage{ - Message: rmsg, - Release: release, - }, nil + return incomingMsg(rmsg), nil } // Close closes the underlying ReadWriteCloser. It is not safe to call @@ -215,3 +210,39 @@ type packedEncoding struct{} func (packedEncoding) NewEncoder(w io.Writer) *capnp.Encoder { return capnp.NewPackedEncoder(w) } func (packedEncoding) NewDecoder(r io.Reader) *capnp.Decoder { return capnp.NewPackedDecoder(r) } + +type outgoingMsg struct { + message rpccp.Message + send func(*capnp.Message) error + released bool +} + +func (o *outgoingMsg) Release() { + if m := o.message.Message(); !o.released && m != nil { + m.Release() + } +} + +func (o *outgoingMsg) Message() rpccp.Message { + return o.message +} + +func (o *outgoingMsg) Send() error { + if !o.released { + return o.send(o.message.Message()) + } + + panic("call to Send() after call to Release()") +} + +type incomingMsg rpccp.Message + +func (i incomingMsg) Message() rpccp.Message { + return rpccp.Message(i) +} + +func (i incomingMsg) Release() { + if m := i.Message().Message(); m != nil { + m.Release() + } +} diff --git a/rpc/transport/transport_test.go b/rpc/transport/transport_test.go index 1af7185f..ed9c9d8d 100644 --- a/rpc/transport/transport_test.go +++ b/rpc/transport/transport_test.go @@ -9,6 +9,8 @@ import ( capnp "capnproto.org/go/capnp/v3" rpccp "capnproto.org/go/capnp/v3/std/capnp/rpc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func testTransport(t *testing.T, makePipe func() (t1, t2 Transport, err error)) { @@ -51,14 +53,14 @@ func testTransport(t *testing.T, makePipe func() (t1, t2 Transport, err error)) defer bootMsg.Release() // Fill in bootstrap message - boot, err := bootMsg.Message.NewBootstrap() + boot, err := bootMsg.Message().NewBootstrap() if err != nil { t.Fatal("NewBootstrap:", err) } boot.SetQuestionId(42) // Fill in call message - call, err := callMsg.Message.NewCall() + call, err := callMsg.Message().NewCall() if err != nil { t.Fatal("NewCall:", err) } @@ -79,8 +81,8 @@ func testTransport(t *testing.T, makePipe func() (t1, t2 Transport, err error)) t.Fatal("NewParams:", err) } // simulate mutating CapTable - callMsg.Message.Message().AddCap(capnp.ErrorClient(errors.New("foo"))) - callMsg.Message.Message().CapTable = nil + callMsg.Message().Message().AddCap(capnp.ErrorClient(errors.New("foo"))) + callMsg.Message().Message().CapTable = nil capPtr := capnp.NewInterface(params.Segment(), 0).ToPtr() if err := params.SetContent(capPtr); err != nil { t.Fatal("SetContent:", err) @@ -100,13 +102,13 @@ func testTransport(t *testing.T, makePipe func() (t1, t2 Transport, err error)) if err != nil { t.Fatal("t2.RecvMessage:", err) } - if r1.Message.Message().CapTable != nil { + if r1.Message().Message().CapTable != nil { t.Error("t2.RecvMessage(ctx).Message().CapTable is not nil") } - if r1.Message.Which() != rpccp.Message_Which_bootstrap { - t.Errorf("t2.RecvMessage(ctx).Which = %v; want bootstrap", r1.Message.Which()) + if r1.Message().Which() != rpccp.Message_Which_bootstrap { + t.Errorf("t2.RecvMessage(ctx).Which = %v; want bootstrap", r1.Message().Which()) } else { - rboot, _ := r1.Message.Bootstrap() + rboot, _ := r1.Message().Bootstrap() if rboot.QuestionId() != 42 { t.Errorf("t2.RecvMessage(ctx).Bootstrap.QuestionID = %d; want 42", rboot.QuestionId()) } @@ -122,13 +124,13 @@ func testTransport(t *testing.T, makePipe func() (t1, t2 Transport, err error)) if err != nil { t.Fatal("t2.RecvMessage:", err) } - if r2.Message.Message().CapTable != nil { + if r2.Message().Message().CapTable != nil { t.Error("t2.RecvMessage(ctx).Message().CapTable is not nil") } - if r2.Message.Which() != rpccp.Message_Which_call { - t.Errorf("t2.RecvMessage(ctx).Which = %v; want call", r2.Message.Which()) + if r2.Message().Which() != rpccp.Message_Which_call { + t.Errorf("t2.RecvMessage(ctx).Which = %v; want call", r2.Message().Which()) } else { - rcall, _ := r2.Message.Call() + rcall, _ := r2.Message().Call() if rcall.QuestionId() != 123 { t.Errorf("t2.RecvMessage(ctx).Call.QuestionID = %d; want 123", rcall.QuestionId()) } @@ -152,25 +154,20 @@ func testTransport(t *testing.T, makePipe func() (t1, t2 Transport, err error)) }) t.Run("InterruptRecv", func(t *testing.T) { t1, t2, err := makePipe() - if err != nil { - t.Fatal("makePipe:", err) - } + require.NoError(t, err, "makePipe should not fail") go func() { time.Sleep(100 * time.Millisecond) t1.Close() }() - inMsg, err := t1.RecvMessage() // hangs here if doesn't work - if err == nil { - t.Error("interrupted RecvMessage returned nil error") - } - if inMsg.Release != nil { - inMsg.Release() - } + _, err = t1.RecvMessage() // hangs here if doesn't work + assert.Error(t, err, + "RecvMessage() should return error when transport is closed") + + err = t2.Close() + assert.NoError(t, err, + "Close() should not return error after remote side closes") - if err := t2.Close(); err != nil { - t.Error("t2.Close:", err) - } }) } diff --git a/server/server.go b/server/server.go index 04f023c9..d06776b2 100644 --- a/server/server.go +++ b/server/server.go @@ -136,7 +136,7 @@ func (srv *Server) Send(ctx context.Context, s capnp.Send) (*capnp.Answer, capnp Args: args, ReleaseArgs: func() { if msg := args.Message(); msg != nil { - msg.Reset(nil) + msg.Release() args = capnp.Struct{} } }, @@ -244,7 +244,7 @@ func sendArgsToStruct(s capnp.Send) (capnp.Struct, error) { return capnp.Struct{}, err } if err := s.PlaceArgs(st); err != nil { - st.Message().Reset(nil) + st.Message().Release() return capnp.Struct{}, exc.WrapError("place args", err) } return st, nil @@ -286,10 +286,6 @@ func (sm sortedMethods) Swap(i, j int) { sm[i], sm[j] = sm[j], sm[i] } -type resultsAllocer interface { - AllocResults(capnp.ObjectSize) (capnp.Struct, error) -} - func newError(msg string) error { return exc.New(exc.Failed, "capnp server", msg) }