Skip to content

Commit f0d880e

Browse files
committed
encoding/json: call MarshalJSON() and MarshalText() defined with pointer receivers even for non-addressable values of non-pointer types on marshalling JSON
1 parent 660e7d6 commit f0d880e

File tree

2 files changed

+100
-28
lines changed

2 files changed

+100
-28
lines changed

src/encoding/json/encode.go

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ func typeEncoder(t reflect.Type) encoderFunc {
363363
}
364364

365365
// Compute the real encoder and replace the indirect func with it.
366-
f = newTypeEncoder(t, true)
366+
f = newTypeEncoder(t)
367367
wg.Done()
368368
encoderCache.Store(t, f)
369369
return f
@@ -375,20 +375,19 @@ var (
375375
)
376376

377377
// newTypeEncoder constructs an encoderFunc for a type.
378-
// The returned encoder only checks CanAddr when allowAddr is true.
379-
func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
378+
func newTypeEncoder(t reflect.Type) encoderFunc {
380379
// If we have a non-pointer value whose type implements
381380
// Marshaler with a value receiver, then we're better off taking
382381
// the address of the value - otherwise we end up with an
383382
// allocation as we cast the value to an interface.
384-
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) {
385-
return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false))
383+
if t.Kind() != reflect.Pointer && reflect.PointerTo(t).Implements(marshalerType) {
384+
return addrMarshalerEncoder
386385
}
387386
if t.Implements(marshalerType) {
388387
return marshalerEncoder
389388
}
390-
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) {
391-
return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false))
389+
if t.Kind() != reflect.Pointer && reflect.PointerTo(t).Implements(textMarshalerType) {
390+
return addrTextMarshalerEncoder
392391
}
393392
if t.Implements(textMarshalerType) {
394393
return textMarshalerEncoder
@@ -451,7 +450,13 @@ func marshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
451450
}
452451

453452
func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
454-
va := v.Addr()
453+
var va reflect.Value
454+
if v.CanAddr() {
455+
va = v.Addr()
456+
} else {
457+
va = reflect.New(v.Type())
458+
va.Elem().Set(v)
459+
}
455460
if va.IsNil() {
456461
e.WriteString("null")
457462
return
@@ -487,7 +492,13 @@ func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
487492
}
488493

489494
func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
490-
va := v.Addr()
495+
var va reflect.Value
496+
if v.CanAddr() {
497+
va = v.Addr()
498+
} else {
499+
va = reflect.New(v.Type())
500+
va.Elem().Set(v)
501+
}
491502
if va.IsNil() {
492503
e.WriteString("null")
493504
return
@@ -893,25 +904,6 @@ func newPtrEncoder(t reflect.Type) encoderFunc {
893904
return enc.encode
894905
}
895906

896-
type condAddrEncoder struct {
897-
canAddrEnc, elseEnc encoderFunc
898-
}
899-
900-
func (ce condAddrEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
901-
if v.CanAddr() {
902-
ce.canAddrEnc(e, v, opts)
903-
} else {
904-
ce.elseEnc(e, v, opts)
905-
}
906-
}
907-
908-
// newCondAddrEncoder returns an encoder that checks whether its value
909-
// CanAddr and delegates to canAddrEnc if so, else to elseEnc.
910-
func newCondAddrEncoder(canAddrEnc, elseEnc encoderFunc) encoderFunc {
911-
enc := condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc}
912-
return enc.encode
913-
}
914-
915907
func isValidTag(s string) bool {
916908
if s == "" {
917909
return false

src/encoding/json/encode_test.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,3 +1219,83 @@ func TestIssue63379(t *testing.T) {
12191219
}
12201220
}
12211221
}
1222+
1223+
type structWithMarshalJSON struct{ v int }
1224+
1225+
func (s *structWithMarshalJSON) MarshalJSON() ([]byte, error) {
1226+
return []byte(fmt.Sprintf(`"marshalled(%d)"`, s.v)), nil
1227+
}
1228+
1229+
var _ = Marshaler(&structWithMarshalJSON{})
1230+
1231+
type embedderJ struct {
1232+
V structWithMarshalJSON
1233+
}
1234+
1235+
func TestMarshalJSONWithPointerJSONMarshalers(t *testing.T) {
1236+
for _, test := range []struct {
1237+
name string
1238+
v interface{}
1239+
expected string
1240+
}{
1241+
{name: "a value with MarshalJSON", v: structWithMarshalJSON{v: 1}, expected: `"marshalled(1)"`},
1242+
{name: "pointer to a value with MarshalJSON", v: &structWithMarshalJSON{v: 1}, expected: `"marshalled(1)"`},
1243+
{name: "a map with a value with MarshalJSON", v: map[string]interface{}{"v": structWithMarshalJSON{v: 1}}, expected: `{"v":"marshalled(1)"}`},
1244+
{name: "a map with a pointer to a value with MarshalJSON", v: map[string]interface{}{"v": &structWithMarshalJSON{v: 1}}, expected: `{"v":"marshalled(1)"}`},
1245+
{name: "a slice of maps with a value with MarshalJSON", v: []map[string]interface{}{{"v": structWithMarshalJSON{v: 1}}}, expected: `[{"v":"marshalled(1)"}]`},
1246+
{name: "a slice of maps with a pointer to a value with MarshalJSON", v: []map[string]interface{}{{"v": &structWithMarshalJSON{v: 1}}}, expected: `[{"v":"marshalled(1)"}]`},
1247+
{name: "a struct with a value with MarshalJSON", v: embedderJ{V: structWithMarshalJSON{v: 1}}, expected: `{"V":"marshalled(1)"}`},
1248+
{name: "a slice of structs with a value with MarshalJSON", v: []embedderJ{{V: structWithMarshalJSON{v: 1}}}, expected: `[{"V":"marshalled(1)"}]`},
1249+
} {
1250+
test := test
1251+
t.Run(test.name, func(t *testing.T) {
1252+
result, err := Marshal(test.v)
1253+
if err != nil {
1254+
t.Fatalf("Marshal error: %v", err)
1255+
}
1256+
if string(result) != test.expected {
1257+
t.Errorf("Marshal:\n\tgot: %s\n\twant: %s", result, test.expected)
1258+
}
1259+
})
1260+
}
1261+
}
1262+
1263+
type structWithMarshalText struct{ v int }
1264+
1265+
func (s *structWithMarshalText) MarshalText() ([]byte, error) {
1266+
return []byte(fmt.Sprintf("marshalled(%d)", s.v)), nil
1267+
}
1268+
1269+
var _ = encoding.TextMarshaler(&structWithMarshalText{})
1270+
1271+
type embedderT struct {
1272+
V structWithMarshalText
1273+
}
1274+
1275+
func TestMarshalJSONWithPointerTextMarshalers(t *testing.T) {
1276+
for _, test := range []struct {
1277+
name string
1278+
v interface{}
1279+
expected string
1280+
}{
1281+
{name: "a value with MarshalText", v: structWithMarshalText{v: 1}, expected: `"marshalled(1)"`},
1282+
{name: "pointer to a value with MarshalText", v: &structWithMarshalText{v: 1}, expected: `"marshalled(1)"`},
1283+
{name: "a map with a value with MarshalText", v: map[string]interface{}{"v": structWithMarshalText{v: 1}}, expected: `{"v":"marshalled(1)"}`},
1284+
{name: "a map with a pointer to a value with MarshalText", v: map[string]interface{}{"v": &structWithMarshalText{v: 1}}, expected: `{"v":"marshalled(1)"}`},
1285+
{name: "a slice of maps with a value with MarshalText", v: []map[string]interface{}{{"v": structWithMarshalText{v: 1}}}, expected: `[{"v":"marshalled(1)"}]`},
1286+
{name: "a slice of maps with a pointer to a value with MarshalText", v: []map[string]interface{}{{"v": &structWithMarshalText{v: 1}}}, expected: `[{"v":"marshalled(1)"}]`},
1287+
{name: "a struct with a value with MarshalText", v: embedderT{V: structWithMarshalText{v: 1}}, expected: `{"V":"marshalled(1)"}`},
1288+
{name: "a slice of structs with a value with MarshalText", v: []embedderT{{V: structWithMarshalText{v: 1}}}, expected: `[{"V":"marshalled(1)"}]`},
1289+
} {
1290+
test := test
1291+
t.Run(test.name, func(t *testing.T) {
1292+
result, err := Marshal(test.v)
1293+
if err != nil {
1294+
t.Fatalf("Marshal error: %v", err)
1295+
}
1296+
if string(result) != test.expected {
1297+
t.Errorf("Marshal:\n\tgot: %s\n\twant: %s", result, test.expected)
1298+
}
1299+
})
1300+
}
1301+
}

0 commit comments

Comments
 (0)