Skip to content

Commit 62e782f

Browse files
author
bcmills
committed
proto: Fix a Marshal race on messages with extensions.
(*Buffer).enc_exts was not acquiring a necessary lock when writing lazily-decoded extensions back to the map. PiperOrigin-RevId: 139345543
1 parent 224aaba commit 62e782f

File tree

4 files changed

+39
-20
lines changed

4 files changed

+39
-20
lines changed

jsonpb/jsonpb.go

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -585,14 +585,7 @@ func (u *Unmarshaler) unmarshalValue(target reflect.Value, inputValue json.RawMe
585585
case "Any":
586586
return fmt.Errorf("unmarshaling Any not supported yet")
587587
case "Duration":
588-
ivStr := string(inputValue)
589-
if ivStr == "null" {
590-
target.Field(0).SetInt(0)
591-
target.Field(1).SetInt(0)
592-
return nil
593-
}
594-
595-
unq, err := strconv.Unquote(ivStr)
588+
unq, err := strconv.Unquote(string(inputValue))
596589
if err != nil {
597590
return err
598591
}
@@ -607,14 +600,7 @@ func (u *Unmarshaler) unmarshalValue(target reflect.Value, inputValue json.RawMe
607600
target.Field(1).SetInt(ns)
608601
return nil
609602
case "Timestamp":
610-
ivStr := string(inputValue)
611-
if ivStr == "null" {
612-
target.Field(0).SetInt(0)
613-
target.Field(1).SetInt(0)
614-
return nil
615-
}
616-
617-
unq, err := strconv.Unquote(ivStr)
603+
unq, err := strconv.Unquote(string(inputValue))
618604
if err != nil {
619605
return err
620606
}

jsonpb/jsonpb_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,11 +467,9 @@ var unmarshalingTests = []struct {
467467
{"camelName input", Unmarshaler{}, `{"oBool":true}`, &pb.Simple{OBool: proto.Bool(true)}},
468468

469469
{"Duration", Unmarshaler{}, `{"dur":"3.000s"}`, &pb.KnownTypes{Dur: &durpb.Duration{Seconds: 3}}},
470-
{"null Duration", Unmarshaler{}, `{"dur":null}`, &pb.KnownTypes{Dur: &durpb.Duration{Seconds: 0}}},
471470
{"Timestamp", Unmarshaler{}, `{"ts":"2014-05-13T16:53:20.021Z"}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: 14e8, Nanos: 21e6}}},
472471
{"PreEpochTimestamp", Unmarshaler{}, `{"ts":"1969-12-31T23:59:58.999999995Z"}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: -2, Nanos: 999999995}}},
473472
{"ZeroTimeTimestamp", Unmarshaler{}, `{"ts":"0001-01-01T00:00:00Z"}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: -62135596800, Nanos: 0}}},
474-
{"null Timestamp", Unmarshaler{}, `{"ts":null}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: 0, Nanos: 0}}},
475473

476474
{"DoubleValue", Unmarshaler{}, `{"dbl":1.2}`, &pb.KnownTypes{Dbl: &wpb.DoubleValue{Value: 1.2}}},
477475
{"FloatValue", Unmarshaler{}, `{"flt":1.2}`, &pb.KnownTypes{Flt: &wpb.FloatValue{Value: 1.2}}},

proto/encode.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,10 +1075,17 @@ func (o *Buffer) enc_map(p *Properties, base structPointer) error {
10751075

10761076
func (o *Buffer) enc_exts(p *Properties, base structPointer) error {
10771077
exts := structPointer_Extensions(base, p.field)
1078-
if err := encodeExtensions(exts); err != nil {
1078+
1079+
v, mu := exts.extensionsRead()
1080+
if v == nil {
1081+
return nil
1082+
}
1083+
1084+
mu.Lock()
1085+
defer mu.Unlock()
1086+
if err := encodeExtensionsMap(v); err != nil {
10791087
return err
10801088
}
1081-
v, _ := exts.extensionsRead()
10821089

10831090
return o.enc_map_body(v)
10841091
}

proto/extensions_test.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ import (
4040

4141
"github.com/golang/protobuf/proto"
4242
pb "github.com/golang/protobuf/proto/testdata"
43+
"golang.org/x/sync/errgroup"
4344
)
4445

4546
func TestGetExtensionsWithMissingExtensions(t *testing.T) {
@@ -506,3 +507,30 @@ func TestClearAllExtensions(t *testing.T) {
506507
t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
507508
}
508509
}
510+
511+
func TestMarshalRace(t *testing.T) {
512+
// unregistered extension
513+
desc := &proto.ExtensionDesc{
514+
ExtendedType: (*pb.MyMessage)(nil),
515+
ExtensionType: (*bool)(nil),
516+
Field: 101010100,
517+
Name: "emptyextension",
518+
Tag: "varint,0,opt",
519+
}
520+
521+
m := &pb.MyMessage{Count: proto.Int32(4)}
522+
if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
523+
t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
524+
}
525+
526+
var g errgroup.Group
527+
for n := 3; n > 0; n-- {
528+
g.Go(func() error {
529+
_, err := proto.Marshal(m)
530+
return err
531+
})
532+
}
533+
if err := g.Wait(); err != nil {
534+
t.Fatal(err)
535+
}
536+
}

0 commit comments

Comments
 (0)