Skip to content

Commit 6fc92df

Browse files
author
Achille Roussel
committed
add check for addressable types
1 parent cf97deb commit 6fc92df

File tree

2 files changed

+39
-25
lines changed

2 files changed

+39
-25
lines changed

json/codec.go

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func typeid(t reflect.Type) unsafe.Pointer {
5959
}
6060

6161
func constructCachedCodec(t reflect.Type, cache map[unsafe.Pointer]codec) codec {
62-
c := constructCodec(t, map[reflect.Type]*structType{})
62+
c := constructCodec(t, map[reflect.Type]*structType{}, t.Kind() == reflect.Ptr)
6363

6464
if inlined(t) {
6565
c.encode = constructInlineValueEncodeFunc(c.encode)
@@ -69,7 +69,7 @@ func constructCachedCodec(t reflect.Type, cache map[unsafe.Pointer]codec) codec
6969
return c
7070
}
7171

72-
func constructCodec(t reflect.Type, seen map[reflect.Type]*structType) (c codec) {
72+
func constructCodec(t reflect.Type, seen map[reflect.Type]*structType, canAddr bool) (c codec) {
7373
switch t {
7474
case nullType, nil:
7575
c = codec{encode: encoder.encodeNull, decode: decoder.decodeNull}
@@ -159,7 +159,7 @@ func constructCodec(t reflect.Type, seen map[reflect.Type]*structType) (c codec)
159159
c = codec{encode: encoder.encodeInterface, decode: constructNonEmptyInterfaceDecoderFunc(t)}
160160

161161
case reflect.Array:
162-
c = constructArrayCodec(t, seen)
162+
c = constructArrayCodec(t, seen, canAddr)
163163

164164
case reflect.Slice:
165165
c = constructSliceCodec(t, seen)
@@ -168,7 +168,7 @@ func constructCodec(t reflect.Type, seen map[reflect.Type]*structType) (c codec)
168168
c = constructMapCodec(t, seen)
169169

170170
case reflect.Struct:
171-
c = constructStructCodec(t, seen)
171+
c = constructStructCodec(t, seen, canAddr)
172172

173173
case reflect.Ptr:
174174
c = constructPointerCodec(t, seen)
@@ -179,27 +179,34 @@ func constructCodec(t reflect.Type, seen map[reflect.Type]*structType) (c codec)
179179

180180
p := reflect.PtrTo(t)
181181

182+
if canAddr {
183+
switch {
184+
case p.Implements(jsonMarshalerType):
185+
c.encode = constructJSONMarshalerEncodeFunc(t, true)
186+
case p.Implements(textMarshalerType):
187+
c.encode = constructTextMarshalerEncodeFunc(t, true)
188+
}
189+
}
190+
182191
switch {
183192
case t.Implements(jsonMarshalerType):
184193
c.encode = constructJSONMarshalerEncodeFunc(t, false)
185-
186194
case t.Implements(textMarshalerType):
187195
c.encode = constructTextMarshalerEncodeFunc(t, false)
188196
}
189197

190198
switch {
191199
case p.Implements(jsonUnmarshalerType):
192200
c.decode = constructJSONUnmarshalerDecodeFunc(t, true)
193-
194201
case p.Implements(textUnmarshalerType):
195202
c.decode = constructTextUnmarshalerDecodeFunc(t, true)
196203
}
197204

198205
return
199206
}
200207

201-
func constructStringCodec(t reflect.Type, seen map[reflect.Type]*structType) codec {
202-
c := constructCodec(t, seen)
208+
func constructStringCodec(t reflect.Type, seen map[reflect.Type]*structType, canAddr bool) codec {
209+
c := constructCodec(t, seen, canAddr)
203210
return codec{
204211
encode: constructStringEncodeFunc(c.encode),
205212
decode: constructStringDecodeFunc(c.decode),
@@ -224,9 +231,9 @@ func constructStringToIntDecodeFunc(t reflect.Type, decode decodeFunc) decodeFun
224231
}
225232
}
226233

227-
func constructArrayCodec(t reflect.Type, seen map[reflect.Type]*structType) codec {
234+
func constructArrayCodec(t reflect.Type, seen map[reflect.Type]*structType, canAddr bool) codec {
228235
e := t.Elem()
229-
c := constructCodec(e, seen)
236+
c := constructCodec(e, seen, canAddr)
230237
s := alignedSize(e)
231238
return codec{
232239
encode: constructArrayEncodeFunc(s, t, c.encode),
@@ -253,12 +260,12 @@ func constructSliceCodec(t reflect.Type, seen map[reflect.Type]*structType) code
253260
s := alignedSize(e)
254261

255262
if e.Kind() == reflect.Uint8 {
256-
p := reflect.PtrTo(e)
257-
c := codec{}
258-
259263
// Go 1.7+ behavior: slices of byte types (and aliases) may override the
260264
// default encoding and decoding behaviors by implementing marshaler and
261265
// unmarshaler interfaces.
266+
p := reflect.PtrTo(e)
267+
c := codec{}
268+
262269
switch {
263270
case e.Implements(jsonMarshalerType):
264271
c.encode = constructJSONMarshalerEncodeFunc(e, false)
@@ -296,7 +303,7 @@ func constructSliceCodec(t reflect.Type, seen map[reflect.Type]*structType) code
296303
return c
297304
}
298305

299-
c := constructCodec(e, seen)
306+
c := constructCodec(e, seen, true)
300307
return codec{
301308
encode: constructSliceEncodeFunc(s, t, c.encode),
302309
decode: constructSliceDecodeFunc(s, t, c.decode),
@@ -336,7 +343,7 @@ func constructMapCodec(t reflect.Type, seen map[reflect.Type]*structType) codec
336343
}
337344

338345
kc := codec{}
339-
vc := constructCodec(v, seen)
346+
vc := constructCodec(v, seen, false)
340347

341348
if k.Implements(textMarshalerType) || reflect.PtrTo(k).Implements(textUnmarshalerType) {
342349
kc.encode = constructTextMarshalerEncodeFunc(k, false)
@@ -366,7 +373,7 @@ func constructMapCodec(t reflect.Type, seen map[reflect.Type]*structType) codec
366373
reflect.Int16,
367374
reflect.Int32,
368375
reflect.Int64:
369-
kc = constructStringCodec(k, seen)
376+
kc = constructStringCodec(k, seen, false)
370377

371378
sortKeys = func(keys []reflect.Value) {
372379
sort.Slice(keys, func(i, j int) bool { return keys[i].Int() < keys[j].Int() })
@@ -378,7 +385,7 @@ func constructMapCodec(t reflect.Type, seen map[reflect.Type]*structType) codec
378385
reflect.Uint16,
379386
reflect.Uint32,
380387
reflect.Uint64:
381-
kc = constructStringCodec(k, seen)
388+
kc = constructStringCodec(k, seen, false)
382389

383390
sortKeys = func(keys []reflect.Value) {
384391
sort.Slice(keys, func(i, j int) bool { return keys[i].Uint() < keys[j].Uint() })
@@ -416,15 +423,15 @@ func constructMapDecodeFunc(t reflect.Type, decodeKey, decodeValue decodeFunc) d
416423
}
417424
}
418425

419-
func constructStructCodec(t reflect.Type, seen map[reflect.Type]*structType) codec {
420-
st := constructStructType(t, seen)
426+
func constructStructCodec(t reflect.Type, seen map[reflect.Type]*structType, canAddr bool) codec {
427+
st := constructStructType(t, seen, canAddr)
421428
return codec{
422429
encode: constructStructEncodeFunc(st),
423430
decode: constructStructDecodeFunc(st),
424431
}
425432
}
426433

427-
func constructStructType(t reflect.Type, seen map[reflect.Type]*structType) *structType {
434+
func constructStructType(t reflect.Type, seen map[reflect.Type]*structType, canAddr bool) *structType {
428435
// Used for preventing infinite recursion on types that have pointers to
429436
// themselves.
430437
st := seen[t]
@@ -438,7 +445,7 @@ func constructStructType(t reflect.Type, seen map[reflect.Type]*structType) *str
438445
}
439446

440447
seen[t] = st
441-
st.fields = appendStructFields(st.fields, t, 0, seen)
448+
st.fields = appendStructFields(st.fields, t, 0, seen, canAddr)
442449

443450
for i := range st.fields {
444451
f := &st.fields[i]
@@ -486,7 +493,7 @@ func constructEmbeddedStructPointerDecodeFunc(t reflect.Type, unexported bool, o
486493
}
487494
}
488495

489-
func appendStructFields(fields []structField, t reflect.Type, offset uintptr, seen map[reflect.Type]*structType) []structField {
496+
func appendStructFields(fields []structField, t reflect.Type, offset uintptr, seen map[reflect.Type]*structType, canAddr bool) []structField {
490497
type embeddedField struct {
491498
index int
492499
offset uintptr
@@ -551,7 +558,7 @@ func appendStructFields(fields []structField, t reflect.Type, offset uintptr, se
551558
// up by offset from the address of the wrapping object, so we
552559
// simply add the embedded struct fields to the list of fields
553560
// of the current struct type.
554-
subtype := constructStructType(typ, seen)
561+
subtype := constructStructType(typ, seen, canAddr)
555562

556563
for j := range subtype.fields {
557564
embedded = append(embedded, embeddedField{
@@ -572,7 +579,7 @@ func appendStructFields(fields []structField, t reflect.Type, offset uintptr, se
572579
}
573580
}
574581

575-
codec := constructCodec(f.Type, seen)
582+
codec := constructCodec(f.Type, seen, canAddr)
576583

577584
if stringify {
578585
// https://golang.org/pkg/encoding/json/#Marshal
@@ -686,7 +693,7 @@ func encodeString(s string, flags AppendFlags) string {
686693

687694
func constructPointerCodec(t reflect.Type, seen map[reflect.Type]*structType) codec {
688695
e := t.Elem()
689-
c := constructCodec(e, seen)
696+
c := constructCodec(e, seen, true)
690697
return codec{
691698
encode: constructPointerEncodeFunc(e, c.encode),
692699
decode: constructPointerDecodeFunc(e, c.decode),

json/json_test.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,6 +1286,9 @@ func (*intPtrB) MarshalText() ([]byte, error) {
12861286
return []byte("B"), nil
12871287
}
12881288

1289+
type structA struct{ I intPtrA }
1290+
type structB struct{ I intPtrB }
1291+
12891292
func TestGithubIssue16(t *testing.T) {
12901293
// https://github.com/segmentio/encoding/issues/16
12911294
tests := []struct {
@@ -1302,6 +1305,10 @@ func TestGithubIssue16(t *testing.T) {
13021305
{value: new(intPtrB), output: `"B"`},
13031306
{value: (*intPtrA)(nil), output: `null`},
13041307
{value: (*intPtrB)(nil), output: `null`},
1308+
{value: structA{I: 1}, output: `{"I":1}`},
1309+
{value: structB{I: 2}, output: `{"I":2}`},
1310+
{value: &structA{I: 1}, output: `{"I":"A"}`},
1311+
{value: &structB{I: 2}, output: `{"I":"B"}`},
13051312
}
13061313

13071314
for _, test := range tests {

0 commit comments

Comments
 (0)