Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions answer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,14 @@ func TestPromiseReject(t *testing.T) {
}

func TestPromiseFulfill(t *testing.T) {
t.Parallel()

t.Run("Done", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
done := p.Answer().Done()
msg, seg, _ := NewMessage(SingleSegment(nil))
defer msg.Reset(nil)
defer msg.Release()

res, _ := NewStruct(seg, ObjectSize{DataSize: 8})
p.Fulfill(res.ToPtr())
select {
Expand All @@ -73,7 +76,8 @@ func TestPromiseFulfill(t *testing.T) {
defer p.ReleaseClients()
ans := p.Answer()
msg, seg, _ := NewMessage(SingleSegment(nil))
defer msg.Reset(nil)
defer msg.Release()

res, _ := NewStruct(seg, ObjectSize{DataSize: 8})
res.SetUint32(0, 0xdeadbeef)
p.Fulfill(res.ToPtr())
Expand All @@ -96,7 +100,8 @@ func TestPromiseFulfill(t *testing.T) {
c := NewClient(h)
defer c.Release()
msg, seg, _ := NewMessage(SingleSegment(nil))
defer msg.Reset(nil)
defer msg.Release()

res, _ := NewStruct(seg, ObjectSize{PointerCount: 3})
res.SetPtr(1, NewInterface(seg, msg.AddCap(c.AddRef())).ToPtr())

Expand Down
30 changes: 14 additions & 16 deletions answerqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ import (
//
// An AnswerQueue can be in one of three states:
//
// 1) Queueing. Incoming method calls will be added to the queue.
// 2) Draining, entered by calling Fulfill or Reject. Queued method
// calls will be delivered in sequence, and new incoming method calls
// will block until the AnswerQueue enters the Drained state.
// 3) Drained, entered once all queued methods have been delivered.
// Incoming methods are passthrough.
// 1. Queueing. Incoming method calls will be added to the queue.
// 2. Draining, entered by calling Fulfill or Reject. Queued method
// calls will be delivered in sequence, and new incoming method calls
// will block until the AnswerQueue enters the Drained state.
// 3. Drained, entered once all queued methods have been delivered.
// Incoming methods are passthrough.
type AnswerQueue struct {
method Method
draining chan struct{} // closed while exiting queueing state
Expand Down Expand Up @@ -154,8 +154,9 @@ func (qc queueCaller) PipelineRecv(ctx context.Context, transform []PipelineOp,
func (qc queueCaller) PipelineSend(ctx context.Context, transform []PipelineOp, s Send) (*Answer, ReleaseFunc) {
ret := new(StructReturner)
r := Recv{
Method: s.Method,
Returner: ret,
Method: s.Method,
Returner: ret,
ReleaseArgs: func() {},
}
if s.PlaceArgs != nil {
var err error
Expand All @@ -167,12 +168,9 @@ func (qc queueCaller) PipelineSend(ctx context.Context, transform []PipelineOp,
if err = s.PlaceArgs(r.Args); err != nil {
return ErrorAnswer(s.Method, err), func() {}
}
r.ReleaseArgs = func() {
r.Args.Message().Reset(nil)
}
} else {
r.ReleaseArgs = func() {}
r.ReleaseArgs = r.Args.Message().Release
}

pcall := qc.PipelineRecv(ctx, transform, r)
return ret.Answer(s.Method, pcall)
}
Expand Down Expand Up @@ -258,7 +256,7 @@ func (sr *StructReturner) ReleaseResults() {
return
}
if err != nil && msg != nil {
msg.Reset(nil)
msg.Release()
}
}

Expand All @@ -280,7 +278,7 @@ func (sr *StructReturner) Answer(m Method, pcall PipelineCaller) (*Answer, Relea
sr.result = Struct{}
sr.mu.Unlock()
if msg != nil {
msg.Reset(nil)
msg.Release()
}
}
}
Expand All @@ -294,7 +292,7 @@ func (sr *StructReturner) Answer(m Method, pcall PipelineCaller) (*Answer, Relea
sr.mu.Unlock()
sr.p.ReleaseClients()
if msg != nil {
msg.Reset(nil)
msg.Release()
}
}
}
15 changes: 13 additions & 2 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,16 @@ func NewMultiSegmentMessage(b [][]byte) (msg *Message, first *Segment) {
return msg, first
}

// Release is syntactic sugar for Message.Reset(nil). See
// docstring for Reset for an important warning.
func (m *Message) Release() {
m.Reset(nil)
}

// Reset the message to use a different arena, allowing it
// to be reused. This invalidates any existing pointers in
// the Message, and releases all clients in the cap table,
// so use with caution.
// the Message, releases all clients in the cap table, and
// releases the current Arena, so use with caution.
func (m *Message) Reset(arena Arena) (first *Segment, err error) {
for _, c := range m.CapTable {
c.Release()
Expand All @@ -107,10 +113,15 @@ func (m *Message) Reset(arena Arena) (first *Segment, err error) {
delete(m.segs, k)
}

if m.Arena != nil {
m.Arena.Release()
}

*m = Message{
Arena: arena,
TraverseLimit: m.TraverseLimit,
DepthLimit: m.DepthLimit,
CapTable: m.CapTable[:0],
segs: m.segs,
}

Expand Down
17 changes: 7 additions & 10 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ func (r *Request) Future() *Future {

// Release resources associated with the request. In particular:
//
// * Release the arguments if they have not yet been released.
// * If the request has been sent, wait for the result and release
// the results.
// - Release the arguments if they have not yet been released.
// - If the request has been sent, wait for the result and release
// the results.
func (r *Request) Release() {
r.releaseArgs()
rel := r.releaseResponse
Expand All @@ -91,12 +91,9 @@ func (r *Request) Release() {
}

func (r *Request) releaseArgs() {
if r.args.IsValid() {
return
if !r.args.IsValid() {
msg := r.args.Message()
r.args = Struct{}
msg.Release()
}
msg := r.args.Message()
r.args = Struct{}
arena := msg.Arena
msg.Reset(nil)
arena.Release()
}
2 changes: 1 addition & 1 deletion rpc/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(*lockedConn), _ *rc.Rel
if err != nil {
return rpccp.Return{}, nil, nil, rpcerr.WrapFailed("create return", err)
}
ret, err := outMsg.Message.NewReturn()
ret, err := outMsg.Message().NewReturn()
if err != nil {
outMsg.Release()
return rpccp.Return{}, nil, nil, rpcerr.WrapFailed("create return", err)
Expand Down
36 changes: 23 additions & 13 deletions rpc/flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,36 @@ func (t *measuringTransport) RecvMessage() (transport.IncomingMessage, error) {
return inMsg, err
}

size, err := capnp.Struct(inMsg.Message).Message().TotalSize()
size, err := inMsg.Message().Message().TotalSize()
if err != nil {
return inMsg, err
}

t.mu.Lock()
t.inUse += size
if t.inUse > t.maxInUse {
defer t.mu.Unlock()

if t.inUse += size; t.inUse > t.maxInUse {
t.maxInUse = t.inUse
}
t.mu.Unlock()

oldRelease := inMsg.Release
inMsg.Release = capnp.ReleaseFunc(func() {
oldRelease()
t.mu.Lock()
defer t.mu.Unlock()
t.inUse -= size
})
return inMsg, err

return releaseHook{
t: t,
IncomingMessage: inMsg,
}, nil
}

type releaseHook struct {
t *measuringTransport
size uint64
transport.IncomingMessage
}

func (rh releaseHook) Release() {
rh.IncomingMessage.Release()

rh.t.mu.Lock()
rh.t.inUse -= rh.size
rh.t.mu.Lock()
}

// Test that attaching a fixed-size FlowLimiter results in actually limiting the
Expand Down
42 changes: 21 additions & 21 deletions rpc/level0_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,8 @@ func TestSendBootstrapCall(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
iptr := capnp.NewInterface(outMsg.Message.Segment(), 0)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
iptr := capnp.NewInterface(outMsg.Message().Segment(), 0)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_return,
Return: &rpcReturn{
AnswerID: qid,
Expand Down Expand Up @@ -416,12 +416,12 @@ func TestSendBootstrapCall(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
resp, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8})
resp, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8})
if err != nil {
t.Fatal("capnp.NewStruct:", err)
}
resp.SetUint64(0, 0xdeadbeef)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_return,
Return: &rpcReturn{
AnswerID: qid,
Expand Down Expand Up @@ -530,8 +530,8 @@ func TestSendBootstrapCallException(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
iptr := capnp.NewInterface(outMsg.Message.Segment(), 0)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
iptr := capnp.NewInterface(outMsg.Message().Segment(), 0)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_return,
Return: &rpcReturn{
AnswerID: qid,
Expand Down Expand Up @@ -755,12 +755,12 @@ func TestSendBootstrapPipelineCall(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
resp, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8})
resp, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8})
if err != nil {
t.Fatal("capnp.NewStruct:", err)
}
resp.SetUint64(0, 0xdeadbeef)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_return,
Return: &rpcReturn{
AnswerID: qid,
Expand Down Expand Up @@ -963,12 +963,12 @@ func TestRecvBootstrapCall(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
params, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8})
params, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8})
if err != nil {
t.Fatal("capnp.NewStruct:", err)
}
params.SetUint32(0, 0x2a2b)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_call,
Call: &rpcCall{
QuestionID: callQID,
Expand Down Expand Up @@ -1114,12 +1114,12 @@ func TestRecvBootstrapCallException(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
params, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8})
params, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8})
if err != nil {
t.Fatal("capnp.NewStruct:", err)
}
params.SetUint32(0, 0x2a2b)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_call,
Call: &rpcCall{
QuestionID: callQID,
Expand Down Expand Up @@ -1258,12 +1258,12 @@ func TestRecvBootstrapPipelineCall(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
params, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8})
params, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8})
if err != nil {
t.Fatal("capnp.NewStruct:", err)
}
params.SetUint32(0, 0x2a2b)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_call,
Call: &rpcCall{
QuestionID: callQID,
Expand Down Expand Up @@ -1452,8 +1452,8 @@ func TestCallOnClosedConn(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
iptr := capnp.NewInterface(outMsg.Message.Segment(), 0)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
iptr := capnp.NewInterface(outMsg.Message().Segment(), 0)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_return,
Return: &rpcReturn{
AnswerID: qid,
Expand Down Expand Up @@ -1593,7 +1593,7 @@ func TestRecvCancel(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_call,
Call: &rpcCall{
QuestionID: callQID,
Expand Down Expand Up @@ -1732,8 +1732,8 @@ func TestSendCancel(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
iptr := capnp.NewInterface(outMsg.Message.Segment(), 0)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
iptr := capnp.NewInterface(outMsg.Message().Segment(), 0)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_return,
Return: &rpcReturn{
AnswerID: bootQID,
Expand Down Expand Up @@ -2046,7 +2046,7 @@ func sendMessage(ctx context.Context, t rpc.Transport, msg *rpcMessage) error {
return fmt.Errorf("send message: %v", err)
}
defer outMsg.Release()
if err := pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), msg); err != nil {
if err := pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), msg); err != nil {
return fmt.Errorf("send message: %v", err)
}
if err := outMsg.Send(); err != nil {
Expand All @@ -2061,7 +2061,7 @@ func recvMessage(ctx context.Context, t rpc.Transport) (*rpcMessage, capnp.Relea
return nil, nil, err
}
r := new(rpcMessage)
if err := pogs.Extract(r, rpccp.Message_TypeID, capnp.Struct(inMsg.Message)); err != nil {
if err := pogs.Extract(r, rpccp.Message_TypeID, capnp.Struct(inMsg.Message())); err != nil {
inMsg.Release()
return nil, nil, fmt.Errorf("extract RPC message: %v", err)
}
Expand Down
Loading