Skip to content

Commit 1a829fd

Browse files
authored
Merge pull request #489 from capnproto/cleanup/message-reset
Clean up Message.Reset
2 parents 3926d66 + 098f255 commit 1a829fd

File tree

13 files changed

+226
-181
lines changed

13 files changed

+226
-181
lines changed

answer_test.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,14 @@ func TestPromiseReject(t *testing.T) {
5454
}
5555

5656
func TestPromiseFulfill(t *testing.T) {
57+
t.Parallel()
58+
5759
t.Run("Done", func(t *testing.T) {
5860
p := NewPromise(dummyMethod, dummyPipelineCaller{})
5961
done := p.Answer().Done()
6062
msg, seg, _ := NewMessage(SingleSegment(nil))
61-
defer msg.Reset(nil)
63+
defer msg.Release()
64+
6265
res, _ := NewStruct(seg, ObjectSize{DataSize: 8})
6366
p.Fulfill(res.ToPtr())
6467
select {
@@ -73,7 +76,8 @@ func TestPromiseFulfill(t *testing.T) {
7376
defer p.ReleaseClients()
7477
ans := p.Answer()
7578
msg, seg, _ := NewMessage(SingleSegment(nil))
76-
defer msg.Reset(nil)
79+
defer msg.Release()
80+
7781
res, _ := NewStruct(seg, ObjectSize{DataSize: 8})
7882
res.SetUint32(0, 0xdeadbeef)
7983
p.Fulfill(res.ToPtr())
@@ -96,7 +100,8 @@ func TestPromiseFulfill(t *testing.T) {
96100
c := NewClient(h)
97101
defer c.Release()
98102
msg, seg, _ := NewMessage(SingleSegment(nil))
99-
defer msg.Reset(nil)
103+
defer msg.Release()
104+
100105
res, _ := NewStruct(seg, ObjectSize{PointerCount: 3})
101106
res.SetPtr(1, NewInterface(seg, msg.AddCap(c.AddRef())).ToPtr())
102107

answerqueue.go

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ import (
1414
//
1515
// An AnswerQueue can be in one of three states:
1616
//
17-
// 1) Queueing. Incoming method calls will be added to the queue.
18-
// 2) Draining, entered by calling Fulfill or Reject. Queued method
19-
// calls will be delivered in sequence, and new incoming method calls
20-
// will block until the AnswerQueue enters the Drained state.
21-
// 3) Drained, entered once all queued methods have been delivered.
22-
// Incoming methods are passthrough.
17+
// 1. Queueing. Incoming method calls will be added to the queue.
18+
// 2. Draining, entered by calling Fulfill or Reject. Queued method
19+
// calls will be delivered in sequence, and new incoming method calls
20+
// will block until the AnswerQueue enters the Drained state.
21+
// 3. Drained, entered once all queued methods have been delivered.
22+
// Incoming methods are passthrough.
2323
type AnswerQueue struct {
2424
method Method
2525
draining chan struct{} // closed while exiting queueing state
@@ -154,8 +154,9 @@ func (qc queueCaller) PipelineRecv(ctx context.Context, transform []PipelineOp,
154154
func (qc queueCaller) PipelineSend(ctx context.Context, transform []PipelineOp, s Send) (*Answer, ReleaseFunc) {
155155
ret := new(StructReturner)
156156
r := Recv{
157-
Method: s.Method,
158-
Returner: ret,
157+
Method: s.Method,
158+
Returner: ret,
159+
ReleaseArgs: func() {},
159160
}
160161
if s.PlaceArgs != nil {
161162
var err error
@@ -167,12 +168,9 @@ func (qc queueCaller) PipelineSend(ctx context.Context, transform []PipelineOp,
167168
if err = s.PlaceArgs(r.Args); err != nil {
168169
return ErrorAnswer(s.Method, err), func() {}
169170
}
170-
r.ReleaseArgs = func() {
171-
r.Args.Message().Reset(nil)
172-
}
173-
} else {
174-
r.ReleaseArgs = func() {}
171+
r.ReleaseArgs = r.Args.Message().Release
175172
}
173+
176174
pcall := qc.PipelineRecv(ctx, transform, r)
177175
return ret.Answer(s.Method, pcall)
178176
}
@@ -258,7 +256,7 @@ func (sr *StructReturner) ReleaseResults() {
258256
return
259257
}
260258
if err != nil && msg != nil {
261-
msg.Reset(nil)
259+
msg.Release()
262260
}
263261
}
264262

@@ -280,7 +278,7 @@ func (sr *StructReturner) Answer(m Method, pcall PipelineCaller) (*Answer, Relea
280278
sr.result = Struct{}
281279
sr.mu.Unlock()
282280
if msg != nil {
283-
msg.Reset(nil)
281+
msg.Release()
284282
}
285283
}
286284
}
@@ -294,7 +292,7 @@ func (sr *StructReturner) Answer(m Method, pcall PipelineCaller) (*Answer, Relea
294292
sr.mu.Unlock()
295293
sr.p.ReleaseClients()
296294
if msg != nil {
297-
msg.Reset(nil)
295+
msg.Release()
298296
}
299297
}
300298
}

message.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,16 @@ func NewMultiSegmentMessage(b [][]byte) (msg *Message, first *Segment) {
9494
return msg, first
9595
}
9696

97+
// Release is syntactic sugar for Message.Reset(nil). See
98+
// docstring for Reset for an important warning.
99+
func (m *Message) Release() {
100+
m.Reset(nil)
101+
}
102+
97103
// Reset the message to use a different arena, allowing it
98104
// to be reused. This invalidates any existing pointers in
99-
// the Message, and releases all clients in the cap table,
100-
// so use with caution.
105+
// the Message, releases all clients in the cap table, and
106+
// releases the current Arena, so use with caution.
101107
func (m *Message) Reset(arena Arena) (first *Segment, err error) {
102108
for _, c := range m.CapTable {
103109
c.Release()
@@ -107,10 +113,15 @@ func (m *Message) Reset(arena Arena) (first *Segment, err error) {
107113
delete(m.segs, k)
108114
}
109115

116+
if m.Arena != nil {
117+
m.Arena.Release()
118+
}
119+
110120
*m = Message{
111121
Arena: arena,
112122
TraverseLimit: m.TraverseLimit,
113123
DepthLimit: m.DepthLimit,
124+
CapTable: m.CapTable[:0],
114125
segs: m.segs,
115126
}
116127

request.go

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ func (r *Request) Future() *Future {
7777

7878
// Release resources associated with the request. In particular:
7979
//
80-
// * Release the arguments if they have not yet been released.
81-
// * If the request has been sent, wait for the result and release
82-
// the results.
80+
// - Release the arguments if they have not yet been released.
81+
// - If the request has been sent, wait for the result and release
82+
// the results.
8383
func (r *Request) Release() {
8484
r.releaseArgs()
8585
rel := r.releaseResponse
@@ -91,12 +91,9 @@ func (r *Request) Release() {
9191
}
9292

9393
func (r *Request) releaseArgs() {
94-
if r.args.IsValid() {
95-
return
94+
if !r.args.IsValid() {
95+
msg := r.args.Message()
96+
r.args = Struct{}
97+
msg.Release()
9698
}
97-
msg := r.args.Message()
98-
r.args = Struct{}
99-
arena := msg.Arena
100-
msg.Reset(nil)
101-
arena.Release()
10299
}

rpc/answer.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(*lockedConn), _ *rc.Rel
103103
if err != nil {
104104
return rpccp.Return{}, nil, nil, rpcerr.WrapFailed("create return", err)
105105
}
106-
ret, err := outMsg.Message.NewReturn()
106+
ret, err := outMsg.Message().NewReturn()
107107
if err != nil {
108108
outMsg.Release()
109109
return rpccp.Return{}, nil, nil, rpcerr.WrapFailed("create return", err)

rpc/flow_test.go

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,26 +34,36 @@ func (t *measuringTransport) RecvMessage() (transport.IncomingMessage, error) {
3434
return inMsg, err
3535
}
3636

37-
size, err := capnp.Struct(inMsg.Message).Message().TotalSize()
37+
size, err := inMsg.Message().Message().TotalSize()
3838
if err != nil {
3939
return inMsg, err
4040
}
4141

4242
t.mu.Lock()
43-
t.inUse += size
44-
if t.inUse > t.maxInUse {
43+
defer t.mu.Unlock()
44+
45+
if t.inUse += size; t.inUse > t.maxInUse {
4546
t.maxInUse = t.inUse
4647
}
47-
t.mu.Unlock()
48-
49-
oldRelease := inMsg.Release
50-
inMsg.Release = capnp.ReleaseFunc(func() {
51-
oldRelease()
52-
t.mu.Lock()
53-
defer t.mu.Unlock()
54-
t.inUse -= size
55-
})
56-
return inMsg, err
48+
49+
return releaseHook{
50+
t: t,
51+
IncomingMessage: inMsg,
52+
}, nil
53+
}
54+
55+
type releaseHook struct {
56+
t *measuringTransport
57+
size uint64
58+
transport.IncomingMessage
59+
}
60+
61+
func (rh releaseHook) Release() {
62+
rh.IncomingMessage.Release()
63+
64+
rh.t.mu.Lock()
65+
rh.t.inUse -= rh.size
66+
rh.t.mu.Lock()
5767
}
5868

5969
// Test that attaching a fixed-size FlowLimiter results in actually limiting the

rpc/level0_test.go

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,8 @@ func TestSendBootstrapCall(t *testing.T) {
319319
if err != nil {
320320
t.Fatal("p2.NewMessage():", err)
321321
}
322-
iptr := capnp.NewInterface(outMsg.Message.Segment(), 0)
323-
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
322+
iptr := capnp.NewInterface(outMsg.Message().Segment(), 0)
323+
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
324324
Which: rpccp.Message_Which_return,
325325
Return: &rpcReturn{
326326
AnswerID: qid,
@@ -416,12 +416,12 @@ func TestSendBootstrapCall(t *testing.T) {
416416
if err != nil {
417417
t.Fatal("p2.NewMessage():", err)
418418
}
419-
resp, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8})
419+
resp, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8})
420420
if err != nil {
421421
t.Fatal("capnp.NewStruct:", err)
422422
}
423423
resp.SetUint64(0, 0xdeadbeef)
424-
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
424+
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
425425
Which: rpccp.Message_Which_return,
426426
Return: &rpcReturn{
427427
AnswerID: qid,
@@ -530,8 +530,8 @@ func TestSendBootstrapCallException(t *testing.T) {
530530
if err != nil {
531531
t.Fatal("p2.NewMessage():", err)
532532
}
533-
iptr := capnp.NewInterface(outMsg.Message.Segment(), 0)
534-
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
533+
iptr := capnp.NewInterface(outMsg.Message().Segment(), 0)
534+
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
535535
Which: rpccp.Message_Which_return,
536536
Return: &rpcReturn{
537537
AnswerID: qid,
@@ -755,12 +755,12 @@ func TestSendBootstrapPipelineCall(t *testing.T) {
755755
if err != nil {
756756
t.Fatal("p2.NewMessage():", err)
757757
}
758-
resp, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8})
758+
resp, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8})
759759
if err != nil {
760760
t.Fatal("capnp.NewStruct:", err)
761761
}
762762
resp.SetUint64(0, 0xdeadbeef)
763-
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
763+
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
764764
Which: rpccp.Message_Which_return,
765765
Return: &rpcReturn{
766766
AnswerID: qid,
@@ -963,12 +963,12 @@ func TestRecvBootstrapCall(t *testing.T) {
963963
if err != nil {
964964
t.Fatal("p2.NewMessage():", err)
965965
}
966-
params, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8})
966+
params, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8})
967967
if err != nil {
968968
t.Fatal("capnp.NewStruct:", err)
969969
}
970970
params.SetUint32(0, 0x2a2b)
971-
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
971+
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
972972
Which: rpccp.Message_Which_call,
973973
Call: &rpcCall{
974974
QuestionID: callQID,
@@ -1114,12 +1114,12 @@ func TestRecvBootstrapCallException(t *testing.T) {
11141114
if err != nil {
11151115
t.Fatal("p2.NewMessage():", err)
11161116
}
1117-
params, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8})
1117+
params, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8})
11181118
if err != nil {
11191119
t.Fatal("capnp.NewStruct:", err)
11201120
}
11211121
params.SetUint32(0, 0x2a2b)
1122-
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
1122+
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
11231123
Which: rpccp.Message_Which_call,
11241124
Call: &rpcCall{
11251125
QuestionID: callQID,
@@ -1258,12 +1258,12 @@ func TestRecvBootstrapPipelineCall(t *testing.T) {
12581258
if err != nil {
12591259
t.Fatal("p2.NewMessage():", err)
12601260
}
1261-
params, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8})
1261+
params, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8})
12621262
if err != nil {
12631263
t.Fatal("capnp.NewStruct:", err)
12641264
}
12651265
params.SetUint32(0, 0x2a2b)
1266-
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
1266+
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
12671267
Which: rpccp.Message_Which_call,
12681268
Call: &rpcCall{
12691269
QuestionID: callQID,
@@ -1452,8 +1452,8 @@ func TestCallOnClosedConn(t *testing.T) {
14521452
if err != nil {
14531453
t.Fatal("p2.NewMessage():", err)
14541454
}
1455-
iptr := capnp.NewInterface(outMsg.Message.Segment(), 0)
1456-
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
1455+
iptr := capnp.NewInterface(outMsg.Message().Segment(), 0)
1456+
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
14571457
Which: rpccp.Message_Which_return,
14581458
Return: &rpcReturn{
14591459
AnswerID: qid,
@@ -1593,7 +1593,7 @@ func TestRecvCancel(t *testing.T) {
15931593
if err != nil {
15941594
t.Fatal("p2.NewMessage():", err)
15951595
}
1596-
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
1596+
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
15971597
Which: rpccp.Message_Which_call,
15981598
Call: &rpcCall{
15991599
QuestionID: callQID,
@@ -1732,8 +1732,8 @@ func TestSendCancel(t *testing.T) {
17321732
if err != nil {
17331733
t.Fatal("p2.NewMessage():", err)
17341734
}
1735-
iptr := capnp.NewInterface(outMsg.Message.Segment(), 0)
1736-
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
1735+
iptr := capnp.NewInterface(outMsg.Message().Segment(), 0)
1736+
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
17371737
Which: rpccp.Message_Which_return,
17381738
Return: &rpcReturn{
17391739
AnswerID: bootQID,
@@ -2046,7 +2046,7 @@ func sendMessage(ctx context.Context, t rpc.Transport, msg *rpcMessage) error {
20462046
return fmt.Errorf("send message: %v", err)
20472047
}
20482048
defer outMsg.Release()
2049-
if err := pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), msg); err != nil {
2049+
if err := pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), msg); err != nil {
20502050
return fmt.Errorf("send message: %v", err)
20512051
}
20522052
if err := outMsg.Send(); err != nil {
@@ -2061,7 +2061,7 @@ func recvMessage(ctx context.Context, t rpc.Transport) (*rpcMessage, capnp.Relea
20612061
return nil, nil, err
20622062
}
20632063
r := new(rpcMessage)
2064-
if err := pogs.Extract(r, rpccp.Message_TypeID, capnp.Struct(inMsg.Message)); err != nil {
2064+
if err := pogs.Extract(r, rpccp.Message_TypeID, capnp.Struct(inMsg.Message())); err != nil {
20652065
inMsg.Release()
20662066
return nil, nil, fmt.Errorf("extract RPC message: %v", err)
20672067
}

0 commit comments

Comments
 (0)