diff --git a/bson/benchmark_test.go b/bson/benchmark_test.go index a2190518d6..c77f2dde6b 100644 --- a/bson/benchmark_test.go +++ b/bson/benchmark_test.go @@ -7,12 +7,15 @@ package bson import ( + "bytes" "compress/gzip" "encoding/json" "fmt" + "io" "io/ioutil" "os" "path" + "sync" "testing" ) @@ -129,10 +132,20 @@ var nestedInstance = nestedtest1{ const extendedBSONDir = "../testdata/extended_bson" +var ( + extJSONFiles map[string]map[string]interface{} + extJSONFilesMu sync.Mutex +) + // readExtJSONFile reads the GZIP-compressed extended JSON document from the given filename in the // "extended BSON" test data directory (../testdata/extended_bson) and returns it as a // map[string]interface{}. It panics on any errors. func readExtJSONFile(filename string) map[string]interface{} { + extJSONFilesMu.Lock() + defer extJSONFilesMu.Unlock() + if v, ok := extJSONFiles[filename]; ok { + return v + } filePath := path.Join(extendedBSONDir, filename) file, err := os.Open(filePath) if err != nil { @@ -161,6 +174,10 @@ func readExtJSONFile(filename string) map[string]interface{} { panic(fmt.Sprintf("error unmarshalling extended JSON: %s", err)) } + if extJSONFiles == nil { + extJSONFiles = make(map[string]map[string]interface{}) + } + extJSONFiles[filename] = v return v } @@ -305,3 +322,128 @@ func BenchmarkUnmarshal(b *testing.B) { }) } } + +// The following benchmarks are copied from the Go standard library's +// encoding/json package. + +type codeResponse struct { + Tree *codeNode `json:"tree"` + Username string `json:"username"` +} + +type codeNode struct { + Name string `json:"name"` + Kids []*codeNode `json:"kids"` + CLWeight float64 `json:"cl_weight"` + Touches int `json:"touches"` + MinT int64 `json:"min_t"` + MaxT int64 `json:"max_t"` + MeanT int64 `json:"mean_t"` +} + +var codeJSON []byte +var codeBSON []byte +var codeStruct codeResponse + +func codeInit() { + f, err := os.Open("testdata/code.json.gz") + if err != nil { + panic(err) + } + defer f.Close() + gz, err := gzip.NewReader(f) + if err != nil { + panic(err) + } + data, err := io.ReadAll(gz) + if err != nil { + panic(err) + } + + codeJSON = data + + if err := json.Unmarshal(codeJSON, &codeStruct); err != nil { + panic("json.Unmarshal code.json: " + err.Error()) + } + + if data, err = json.Marshal(&codeStruct); err != nil { + panic("json.Marshal code.json: " + err.Error()) + } + + if codeBSON, err = Marshal(&codeStruct); err != nil { + panic("Marshal code.json: " + err.Error()) + } + + if !bytes.Equal(data, codeJSON) { + println("different lengths", len(data), len(codeJSON)) + for i := 0; i < len(data) && i < len(codeJSON); i++ { + if data[i] != codeJSON[i] { + println("re-marshal: changed at byte", i) + println("orig: ", string(codeJSON[i-10:i+10])) + println("new: ", string(data[i-10:i+10])) + break + } + } + panic("re-marshal code.json: different result") + } +} + +func BenchmarkCodeUnmarshal(b *testing.B) { + b.ReportAllocs() + if codeJSON == nil { + b.StopTimer() + codeInit() + b.StartTimer() + } + b.Run("BSON", func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + var r codeResponse + if err := Unmarshal(codeBSON, &r); err != nil { + b.Fatal("Unmarshal:", err) + } + } + }) + b.SetBytes(int64(len(codeBSON))) + }) + b.Run("JSON", func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + var r codeResponse + if err := json.Unmarshal(codeJSON, &r); err != nil { + b.Fatal("json.Unmarshal:", err) + } + } + }) + b.SetBytes(int64(len(codeJSON))) + }) +} + +func BenchmarkCodeMarshal(b *testing.B) { + b.ReportAllocs() + if codeJSON == nil { + b.StopTimer() + codeInit() + b.StartTimer() + } + b.Run("BSON", func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if _, err := Marshal(&codeStruct); err != nil { + b.Fatal("Marshal:", err) + } + } + }) + b.SetBytes(int64(len(codeBSON))) + }) + b.Run("JSON", func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if _, err := json.Marshal(&codeStruct); err != nil { + b.Fatal("json.Marshal:", err) + } + } + }) + b.SetBytes(int64(len(codeJSON))) + }) +} diff --git a/bson/bsoncodec/codec_cache.go b/bson/bsoncodec/codec_cache.go new file mode 100644 index 0000000000..844b50299f --- /dev/null +++ b/bson/bsoncodec/codec_cache.go @@ -0,0 +1,166 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsoncodec + +import ( + "reflect" + "sync" + "sync/atomic" +) + +// Runtime check that the kind encoder and decoder caches can store any valid +// reflect.Kind constant. +func init() { + if s := reflect.Kind(len(kindEncoderCache{}.entries)).String(); s != "kind27" { + panic("The capacity of kindEncoderCache is too small.\n" + + "This is due to a new type being added to reflect.Kind.") + } +} + +// statically assert array size +var _ = (kindEncoderCache{}).entries[reflect.UnsafePointer] +var _ = (kindDecoderCache{}).entries[reflect.UnsafePointer] + +type typeEncoderCache struct { + cache sync.Map // map[reflect.Type]ValueEncoder +} + +func (c *typeEncoderCache) Store(rt reflect.Type, enc ValueEncoder) { + c.cache.Store(rt, enc) +} + +func (c *typeEncoderCache) Load(rt reflect.Type) (ValueEncoder, bool) { + if v, _ := c.cache.Load(rt); v != nil { + return v.(ValueEncoder), true + } + return nil, false +} + +func (c *typeEncoderCache) LoadOrStore(rt reflect.Type, enc ValueEncoder) ValueEncoder { + if v, loaded := c.cache.LoadOrStore(rt, enc); loaded { + enc = v.(ValueEncoder) + } + return enc +} + +func (c *typeEncoderCache) Clone() *typeEncoderCache { + cc := new(typeEncoderCache) + c.cache.Range(func(k, v interface{}) bool { + if k != nil && v != nil { + cc.cache.Store(k, v) + } + return true + }) + return cc +} + +type typeDecoderCache struct { + cache sync.Map // map[reflect.Type]ValueDecoder +} + +func (c *typeDecoderCache) Store(rt reflect.Type, dec ValueDecoder) { + c.cache.Store(rt, dec) +} + +func (c *typeDecoderCache) Load(rt reflect.Type) (ValueDecoder, bool) { + if v, _ := c.cache.Load(rt); v != nil { + return v.(ValueDecoder), true + } + return nil, false +} + +func (c *typeDecoderCache) LoadOrStore(rt reflect.Type, dec ValueDecoder) ValueDecoder { + if v, loaded := c.cache.LoadOrStore(rt, dec); loaded { + dec = v.(ValueDecoder) + } + return dec +} + +func (c *typeDecoderCache) Clone() *typeDecoderCache { + cc := new(typeDecoderCache) + c.cache.Range(func(k, v interface{}) bool { + if k != nil && v != nil { + cc.cache.Store(k, v) + } + return true + }) + return cc +} + +// atomic.Value requires that all calls to Store() have the same concrete type +// so we wrap the ValueEncoder with a kindEncoderCacheEntry to ensure the type +// is always the same (since different concrete types may implement the +// ValueEncoder interface). +type kindEncoderCacheEntry struct { + enc ValueEncoder +} + +type kindEncoderCache struct { + entries [reflect.UnsafePointer + 1]atomic.Value // *kindEncoderCacheEntry +} + +func (c *kindEncoderCache) Store(rt reflect.Kind, enc ValueEncoder) { + if enc != nil && rt < reflect.Kind(len(c.entries)) { + c.entries[rt].Store(&kindEncoderCacheEntry{enc: enc}) + } +} + +func (c *kindEncoderCache) Load(rt reflect.Kind) (ValueEncoder, bool) { + if rt < reflect.Kind(len(c.entries)) { + if ent, ok := c.entries[rt].Load().(*kindEncoderCacheEntry); ok { + return ent.enc, ent.enc != nil + } + } + return nil, false +} + +func (c *kindEncoderCache) Clone() *kindEncoderCache { + cc := new(kindEncoderCache) + for i, v := range c.entries { + if val := v.Load(); val != nil { + cc.entries[i].Store(val) + } + } + return cc +} + +// atomic.Value requires that all calls to Store() have the same concrete type +// so we wrap the ValueDecoder with a kindDecoderCacheEntry to ensure the type +// is always the same (since different concrete types may implement the +// ValueDecoder interface). +type kindDecoderCacheEntry struct { + dec ValueDecoder +} + +type kindDecoderCache struct { + entries [reflect.UnsafePointer + 1]atomic.Value // *kindDecoderCacheEntry +} + +func (c *kindDecoderCache) Store(rt reflect.Kind, dec ValueDecoder) { + if rt < reflect.Kind(len(c.entries)) { + c.entries[rt].Store(&kindDecoderCacheEntry{dec: dec}) + } +} + +func (c *kindDecoderCache) Load(rt reflect.Kind) (ValueDecoder, bool) { + if rt < reflect.Kind(len(c.entries)) { + if ent, ok := c.entries[rt].Load().(*kindDecoderCacheEntry); ok { + return ent.dec, ent.dec != nil + } + } + return nil, false +} + +func (c *kindDecoderCache) Clone() *kindDecoderCache { + cc := new(kindDecoderCache) + for i, v := range c.entries { + if val := v.Load(); val != nil { + cc.entries[i].Store(val) + } + } + return cc +} diff --git a/bson/bsoncodec/codec_cache_test.go b/bson/bsoncodec/codec_cache_test.go new file mode 100644 index 0000000000..054a45df50 --- /dev/null +++ b/bson/bsoncodec/codec_cache_test.go @@ -0,0 +1,176 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsoncodec + +import ( + "reflect" + "strconv" + "strings" + "testing" +) + +// NB(charlie): the array size is a power of 2 because we use the remainder of +// it (mod) in benchmarks and that is faster when the size is a power of 2. +var codecCacheTestTypes = [16]reflect.Type{ + reflect.TypeOf(uint8(0)), + reflect.TypeOf(uint16(0)), + reflect.TypeOf(uint32(0)), + reflect.TypeOf(uint64(0)), + reflect.TypeOf(uint(0)), + reflect.TypeOf(uintptr(0)), + reflect.TypeOf(int8(0)), + reflect.TypeOf(int16(0)), + reflect.TypeOf(int32(0)), + reflect.TypeOf(int64(0)), + reflect.TypeOf(int(0)), + reflect.TypeOf(float32(0)), + reflect.TypeOf(float64(0)), + reflect.TypeOf(true), + reflect.TypeOf(struct{ A int }{}), + reflect.TypeOf(map[int]int{}), +} + +func TestTypeCache(t *testing.T) { + rt := reflect.TypeOf(int(0)) + ec := new(typeEncoderCache) + dc := new(typeDecoderCache) + + codec := new(fakeCodec) + ec.Store(rt, codec) + dc.Store(rt, codec) + if v, ok := ec.Load(rt); !ok || !reflect.DeepEqual(v, codec) { + t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, codec, true) + } + if v, ok := dc.Load(rt); !ok || !reflect.DeepEqual(v, codec) { + t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, codec, true) + } + + // Make sure we overwrite the stored value with nil + ec.Store(rt, nil) + dc.Store(rt, nil) + if v, ok := ec.Load(rt); ok || v != nil { + t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, nil, false) + } + if v, ok := dc.Load(rt); ok || v != nil { + t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, nil, false) + } +} + +func TestTypeCacheClone(t *testing.T) { + codec := new(fakeCodec) + ec1 := new(typeEncoderCache) + dc1 := new(typeDecoderCache) + for _, rt := range codecCacheTestTypes { + ec1.Store(rt, codec) + dc1.Store(rt, codec) + } + ec2 := ec1.Clone() + dc2 := dc1.Clone() + for _, rt := range codecCacheTestTypes { + if v, _ := ec2.Load(rt); !reflect.DeepEqual(v, codec) { + t.Errorf("Load(%s) = %#v; want: %#v", rt, v, codec) + } + if v, _ := dc2.Load(rt); !reflect.DeepEqual(v, codec) { + t.Errorf("Load(%s) = %#v; want: %#v", rt, v, codec) + } + } +} + +func TestKindCacheArray(t *testing.T) { + // Check array bounds + var c kindEncoderCache + codec := new(fakeCodec) + c.Store(reflect.UnsafePointer, codec) // valid + c.Store(reflect.UnsafePointer+1, codec) // ignored + if v, ok := c.Load(reflect.UnsafePointer); !ok || v != codec { + t.Errorf("Load(reflect.UnsafePointer) = %v, %t; want: %v, %t", v, ok, codec, true) + } + if v, ok := c.Load(reflect.UnsafePointer + 1); ok || v != nil { + t.Errorf("Load(reflect.UnsafePointer + 1) = %v, %t; want: %v, %t", v, ok, nil, false) + } + + // Make sure that reflect.UnsafePointer is the last/largest reflect.Type. + // + // The String() method of invalid reflect.Type types are of the format + // "kind{NUMBER}". + for rt := reflect.UnsafePointer + 1; rt < reflect.UnsafePointer+16; rt++ { + s := rt.String() + if !strings.Contains(s, strconv.Itoa(int(rt))) { + t.Errorf("reflect.Type(%d) appears to be valid: %q", rt, s) + } + } +} + +func TestKindCacheClone(t *testing.T) { + e1 := new(kindEncoderCache) + d1 := new(kindDecoderCache) + codec := new(fakeCodec) + for k := reflect.Invalid; k <= reflect.UnsafePointer; k++ { + e1.Store(k, codec) + d1.Store(k, codec) + } + e2 := e1.Clone() + for k := reflect.Invalid; k <= reflect.UnsafePointer; k++ { + v1, ok1 := e1.Load(k) + v2, ok2 := e2.Load(k) + if ok1 != ok2 || !reflect.DeepEqual(v1, v2) || v1 == nil || v2 == nil { + t.Errorf("Encoder(%s): %#v, %t != %#v, %t", k, v1, ok1, v2, ok2) + } + } + d2 := d1.Clone() + for k := reflect.Invalid; k <= reflect.UnsafePointer; k++ { + v1, ok1 := d1.Load(k) + v2, ok2 := d2.Load(k) + if ok1 != ok2 || !reflect.DeepEqual(v1, v2) || v1 == nil || v2 == nil { + t.Errorf("Decoder(%s): %#v, %t != %#v, %t", k, v1, ok1, v2, ok2) + } + } +} + +func TestKindCacheEncoderNilEncoder(t *testing.T) { + t.Run("Encoder", func(t *testing.T) { + c := new(kindEncoderCache) + c.Store(reflect.Invalid, ValueEncoder(nil)) + v, ok := c.Load(reflect.Invalid) + if v != nil || ok { + t.Errorf("Load of nil ValueEncoder should return: nil, false; got: %v, %t", v, ok) + } + }) + t.Run("Decoder", func(t *testing.T) { + c := new(kindDecoderCache) + c.Store(reflect.Invalid, ValueDecoder(nil)) + v, ok := c.Load(reflect.Invalid) + if v != nil || ok { + t.Errorf("Load of nil ValueDecoder should return: nil, false; got: %v, %t", v, ok) + } + }) +} + +func BenchmarkEncoderCacheLoad(b *testing.B) { + c := new(typeEncoderCache) + codec := new(fakeCodec) + typs := codecCacheTestTypes + for _, t := range typs { + c.Store(t, codec) + } + b.RunParallel(func(pb *testing.PB) { + for i := 0; pb.Next(); i++ { + c.Load(typs[i%len(typs)]) + } + }) +} + +func BenchmarkEncoderCacheStore(b *testing.B) { + c := new(typeEncoderCache) + codec := new(fakeCodec) + b.RunParallel(func(pb *testing.PB) { + typs := codecCacheTestTypes + for i := 0; pb.Next(); i++ { + c.Store(typs[i%len(typs)], codec) + } + }) +} diff --git a/bson/bsoncodec/pointer_codec.go b/bson/bsoncodec/pointer_codec.go index a1bf9c3e2b..e5923230b0 100644 --- a/bson/bsoncodec/pointer_codec.go +++ b/bson/bsoncodec/pointer_codec.go @@ -8,7 +8,6 @@ package bsoncodec import ( "reflect" - "sync" "go.mongodb.org/mongo-driver/bson/bsonrw" "go.mongodb.org/mongo-driver/bson/bsontype" @@ -22,9 +21,8 @@ var _ ValueDecoder = &PointerCodec{} // Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the // PointerCodec registered. type PointerCodec struct { - ecache map[reflect.Type]ValueEncoder - dcache map[reflect.Type]ValueDecoder - l sync.RWMutex + ecache typeEncoderCache + dcache typeDecoderCache } // NewPointerCodec returns a PointerCodec that has been initialized. @@ -32,10 +30,7 @@ type PointerCodec struct { // Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the // PointerCodec registered. func NewPointerCodec() *PointerCodec { - return &PointerCodec{ - ecache: make(map[reflect.Type]ValueEncoder), - dcache: make(map[reflect.Type]ValueDecoder), - } + return &PointerCodec{} } // EncodeValue handles encoding a pointer by either encoding it to BSON Null if the pointer is nil @@ -52,24 +47,19 @@ func (pc *PointerCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val return vw.WriteNull() } - pc.l.RLock() - enc, ok := pc.ecache[val.Type()] - pc.l.RUnlock() - if ok { - if enc == nil { - return ErrNoEncoder{Type: val.Type()} + typ := val.Type() + if v, ok := pc.ecache.Load(typ); ok { + if v == nil { + return ErrNoEncoder{Type: typ} } - return enc.EncodeValue(ec, vw, val.Elem()) + return v.EncodeValue(ec, vw, val.Elem()) } - - enc, err := ec.LookupEncoder(val.Type().Elem()) - pc.l.Lock() - pc.ecache[val.Type()] = enc - pc.l.Unlock() + // TODO(charlie): handle concurrent requests for the same type + enc, err := ec.LookupEncoder(typ.Elem()) + enc = pc.ecache.LoadOrStore(typ, enc) if err != nil { return err } - return enc.EncodeValue(ec, vw, val.Elem()) } @@ -80,36 +70,31 @@ func (pc *PointerCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val return ValueDecoderError{Name: "PointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val} } + typ := val.Type() if vr.Type() == bsontype.Null { - val.Set(reflect.Zero(val.Type())) + val.Set(reflect.Zero(typ)) return vr.ReadNull() } if vr.Type() == bsontype.Undefined { - val.Set(reflect.Zero(val.Type())) + val.Set(reflect.Zero(typ)) return vr.ReadUndefined() } if val.IsNil() { - val.Set(reflect.New(val.Type().Elem())) + val.Set(reflect.New(typ.Elem())) } - pc.l.RLock() - dec, ok := pc.dcache[val.Type()] - pc.l.RUnlock() - if ok { - if dec == nil { - return ErrNoDecoder{Type: val.Type()} + if v, ok := pc.dcache.Load(typ); ok { + if v == nil { + return ErrNoDecoder{Type: typ} } - return dec.DecodeValue(dc, vr, val.Elem()) + return v.DecodeValue(dc, vr, val.Elem()) } - - dec, err := dc.LookupDecoder(val.Type().Elem()) - pc.l.Lock() - pc.dcache[val.Type()] = dec - pc.l.Unlock() + // TODO(charlie): handle concurrent requests for the same type + dec, err := dc.LookupDecoder(typ.Elem()) + dec = pc.dcache.LoadOrStore(typ, dec) if err != nil { return err } - return dec.DecodeValue(dc, vr, val.Elem()) } diff --git a/bson/bsoncodec/registry.go b/bson/bsoncodec/registry.go index 930de28490..f309ee2b39 100644 --- a/bson/bsoncodec/registry.go +++ b/bson/bsoncodec/registry.go @@ -216,72 +216,42 @@ func (rb *RegistryBuilder) RegisterTypeMapEntry(bt bsontype.Type, rt reflect.Typ // // Deprecated: Use NewRegistry instead. func (rb *RegistryBuilder) Build() *Registry { - registry := new(Registry) - - registry.typeEncoders = make(map[reflect.Type]ValueEncoder, len(rb.registry.typeEncoders)) - for t, enc := range rb.registry.typeEncoders { - registry.typeEncoders[t] = enc - } - - registry.typeDecoders = make(map[reflect.Type]ValueDecoder, len(rb.registry.typeDecoders)) - for t, dec := range rb.registry.typeDecoders { - registry.typeDecoders[t] = dec - } - - registry.interfaceEncoders = make([]interfaceValueEncoder, len(rb.registry.interfaceEncoders)) - copy(registry.interfaceEncoders, rb.registry.interfaceEncoders) - - registry.interfaceDecoders = make([]interfaceValueDecoder, len(rb.registry.interfaceDecoders)) - copy(registry.interfaceDecoders, rb.registry.interfaceDecoders) - - registry.kindEncoders = make(map[reflect.Kind]ValueEncoder) - for kind, enc := range rb.registry.kindEncoders { - registry.kindEncoders[kind] = enc + r := &Registry{ + interfaceEncoders: append([]interfaceValueEncoder(nil), rb.registry.interfaceEncoders...), + interfaceDecoders: append([]interfaceValueDecoder(nil), rb.registry.interfaceDecoders...), + typeEncoders: rb.registry.typeEncoders.Clone(), + typeDecoders: rb.registry.typeDecoders.Clone(), + kindEncoders: rb.registry.kindEncoders.Clone(), + kindDecoders: rb.registry.kindDecoders.Clone(), } - - registry.kindDecoders = make(map[reflect.Kind]ValueDecoder) - for kind, dec := range rb.registry.kindDecoders { - registry.kindDecoders[kind] = dec - } - - registry.typeMap = make(map[bsontype.Type]reflect.Type) - for bt, rt := range rb.registry.typeMap { - registry.typeMap[bt] = rt - } - - return registry + rb.registry.typeMap.Range(func(k, v interface{}) bool { + if k != nil && v != nil { + r.typeMap.Store(k, v) + } + return true + }) + return r } // A Registry is used to store and retrieve codecs for types and interfaces. This type is the main // typed passed around and Encoders and Decoders are constructed from it. type Registry struct { - typeEncoders map[reflect.Type]ValueEncoder - typeDecoders map[reflect.Type]ValueDecoder - interfaceEncoders []interfaceValueEncoder interfaceDecoders []interfaceValueDecoder - - kindEncoders map[reflect.Kind]ValueEncoder - kindDecoders map[reflect.Kind]ValueDecoder - - typeMap map[bsontype.Type]reflect.Type - - mu sync.RWMutex + typeEncoders *typeEncoderCache + typeDecoders *typeDecoderCache + kindEncoders *kindEncoderCache + kindDecoders *kindDecoderCache + typeMap sync.Map // map[bsontype.Type]reflect.Type } // NewRegistry creates a new empty Registry. func NewRegistry() *Registry { return &Registry{ - typeEncoders: make(map[reflect.Type]ValueEncoder), - typeDecoders: make(map[reflect.Type]ValueDecoder), - - interfaceEncoders: make([]interfaceValueEncoder, 0), - interfaceDecoders: make([]interfaceValueDecoder, 0), - - kindEncoders: make(map[reflect.Kind]ValueEncoder), - kindDecoders: make(map[reflect.Kind]ValueDecoder), - - typeMap: make(map[bsontype.Type]reflect.Type), + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), } } @@ -296,7 +266,7 @@ func NewRegistry() *Registry { // // RegisterTypeEncoder should not be called concurrently with any other Registry method. func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) { - r.typeEncoders[valueType] = enc + r.typeEncoders.Store(valueType, enc) } // RegisterTypeDecoder registers the provided ValueDecoder for the provided type. @@ -310,7 +280,7 @@ func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) // // RegisterTypeDecoder should not be called concurrently with any other Registry method. func (r *Registry) RegisterTypeDecoder(valueType reflect.Type, dec ValueDecoder) { - r.typeDecoders[valueType] = dec + r.typeDecoders.Store(valueType, dec) } // RegisterKindEncoder registers the provided ValueEncoder for the provided kind. @@ -326,7 +296,7 @@ func (r *Registry) RegisterTypeDecoder(valueType reflect.Type, dec ValueDecoder) // // RegisterKindEncoder should not be called concurrently with any other Registry method. func (r *Registry) RegisterKindEncoder(kind reflect.Kind, enc ValueEncoder) { - r.kindEncoders[kind] = enc + r.kindEncoders.Store(kind, enc) } // RegisterKindDecoder registers the provided ValueDecoder for the provided kind. @@ -342,7 +312,7 @@ func (r *Registry) RegisterKindEncoder(kind reflect.Kind, enc ValueEncoder) { // // RegisterKindDecoder should not be called concurrently with any other Registry method. func (r *Registry) RegisterKindDecoder(kind reflect.Kind, dec ValueDecoder) { - r.kindDecoders[kind] = dec + r.kindDecoders.Store(kind, dec) } // RegisterInterfaceEncoder registers an encoder for the provided interface type iface. This encoder will @@ -401,7 +371,7 @@ func (r *Registry) RegisterInterfaceDecoder(iface reflect.Type, dec ValueDecoder // // reg.RegisterTypeMapEntry(bsontype.EmbeddedDocument, reflect.TypeOf(bson.Raw{})) func (r *Registry) RegisterTypeMapEntry(bt bsontype.Type, rt reflect.Type) { - r.typeMap[bt] = rt + r.typeMap.Store(bt, rt) } // LookupEncoder returns the first matching encoder in the Registry. It uses the following lookup @@ -418,9 +388,7 @@ func (r *Registry) RegisterTypeMapEntry(bt bsontype.Type, rt reflect.Type) { // If no encoder is found, an error of type ErrNoEncoder is returned. LookupEncoder is safe for // concurrent use by multiple goroutines after all codecs and encoders are registered. func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { - r.mu.RLock() enc, found := r.lookupTypeEncoder(valueType) - r.mu.RUnlock() if found { if enc == nil { return nil, ErrNoEncoder{Type: valueType} @@ -430,36 +398,26 @@ func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { enc, found = r.lookupInterfaceEncoder(valueType, true) if found { - r.mu.Lock() - r.typeEncoders[valueType] = enc - r.mu.Unlock() - return enc, nil + return r.typeEncoders.LoadOrStore(valueType, enc), nil } - if valueType == nil { - r.mu.Lock() - r.typeEncoders[valueType] = nil - r.mu.Unlock() + r.storeTypeEncoder(valueType, nil) return nil, ErrNoEncoder{Type: valueType} } - enc, found = r.kindEncoders[valueType.Kind()] - if !found { - r.mu.Lock() - r.typeEncoders[valueType] = nil - r.mu.Unlock() - return nil, ErrNoEncoder{Type: valueType} + if v, ok := r.kindEncoders.Load(valueType.Kind()); ok { + return r.storeTypeEncoder(valueType, v), nil } + r.storeTypeEncoder(valueType, nil) + return nil, ErrNoEncoder{Type: valueType} +} - r.mu.Lock() - r.typeEncoders[valueType] = enc - r.mu.Unlock() - return enc, nil +func (r *Registry) storeTypeEncoder(rt reflect.Type, enc ValueEncoder) ValueEncoder { + return r.typeEncoders.LoadOrStore(rt, enc) } -func (r *Registry) lookupTypeEncoder(valueType reflect.Type) (ValueEncoder, bool) { - enc, found := r.typeEncoders[valueType] - return enc, found +func (r *Registry) lookupTypeEncoder(rt reflect.Type) (ValueEncoder, bool) { + return r.typeEncoders.Load(rt) } func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool) (ValueEncoder, bool) { @@ -475,7 +433,7 @@ func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool // ahead in interfaceEncoders defaultEnc, found := r.lookupInterfaceEncoder(valueType, false) if !found { - defaultEnc = r.kindEncoders[valueType.Kind()] + defaultEnc, _ = r.kindEncoders.Load(valueType.Kind()) } return newCondAddrEncoder(ienc.ve, defaultEnc), true } @@ -500,10 +458,7 @@ func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) { if valueType == nil { return nil, ErrNilType } - decodererr := ErrNoDecoder{Type: valueType} - r.mu.RLock() dec, found := r.lookupTypeDecoder(valueType) - r.mu.RUnlock() if found { if dec == nil { return nil, ErrNoDecoder{Type: valueType} @@ -513,29 +468,22 @@ func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) { dec, found = r.lookupInterfaceDecoder(valueType, true) if found { - r.mu.Lock() - r.typeDecoders[valueType] = dec - r.mu.Unlock() - return dec, nil + return r.storeTypeDecoder(valueType, dec), nil } - dec, found = r.kindDecoders[valueType.Kind()] - if !found { - r.mu.Lock() - r.typeDecoders[valueType] = nil - r.mu.Unlock() - return nil, decodererr + if v, ok := r.kindDecoders.Load(valueType.Kind()); ok { + return r.storeTypeDecoder(valueType, v), nil } - - r.mu.Lock() - r.typeDecoders[valueType] = dec - r.mu.Unlock() - return dec, nil + r.storeTypeDecoder(valueType, nil) + return nil, ErrNoDecoder{Type: valueType} } func (r *Registry) lookupTypeDecoder(valueType reflect.Type) (ValueDecoder, bool) { - dec, found := r.typeDecoders[valueType] - return dec, found + return r.typeDecoders.Load(valueType) +} + +func (r *Registry) storeTypeDecoder(typ reflect.Type, dec ValueDecoder) ValueDecoder { + return r.typeDecoders.LoadOrStore(typ, dec) } func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool) (ValueDecoder, bool) { @@ -548,7 +496,7 @@ func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool // ahead in interfaceDecoders defaultDec, found := r.lookupInterfaceDecoder(valueType, false) if !found { - defaultDec = r.kindDecoders[valueType.Kind()] + defaultDec, _ = r.kindDecoders.Load(valueType.Kind()) } return newCondAddrDecoder(idec.vd, defaultDec), true } @@ -561,11 +509,11 @@ func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool // // LookupTypeMapEntry should not be called concurrently with any other Registry method. func (r *Registry) LookupTypeMapEntry(bt bsontype.Type) (reflect.Type, error) { - t, ok := r.typeMap[bt] - if !ok || t == nil { + v, ok := r.typeMap.Load(bt) + if v == nil || !ok { return nil, ErrNoTypeMapEntry{Type: bt} } - return t, nil + return v.(reflect.Type), nil } type interfaceValueEncoder struct { diff --git a/bson/bsoncodec/registry_test.go b/bson/bsoncodec/registry_test.go index 9ed68ce566..d09f32be5e 100644 --- a/bson/bsoncodec/registry_test.go +++ b/bson/bsoncodec/registry_test.go @@ -65,7 +65,7 @@ func TestRegistryBuilder(t *testing.T) { got := reg.typeEncoders for _, s := range want { wantT, wantC := s.t, s.c - gotC, exists := got[wantT] + gotC, exists := got.Load(wantT) if !exists { t.Errorf("Did not find type in the type registry: %v", wantT) } @@ -94,7 +94,7 @@ func TestRegistryBuilder(t *testing.T) { got := reg.kindEncoders for _, s := range want { wantK, wantC := s.k, s.c - gotC, exists := got[wantK] + gotC, exists := got.Load(wantK) if !exists { t.Errorf("Did not find kind in the kind registry: %v", wantK) } @@ -111,14 +111,14 @@ func TestRegistryBuilder(t *testing.T) { rb.RegisterDefaultEncoder(reflect.Map, codec) reg := rb.Build() - if reg.kindEncoders[reflect.Map] != codec { - t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Map], codec) + if reg.kindEncoders.get(reflect.Map) != codec { + t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec) } rb.RegisterDefaultEncoder(reflect.Map, codec2) reg = rb.Build() - if reg.kindEncoders[reflect.Map] != codec2 { - t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Map], codec2) + if reg.kindEncoders.get(reflect.Map) != codec2 { + t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec2) } }) t.Run("StructCodec", func(t *testing.T) { @@ -128,14 +128,14 @@ func TestRegistryBuilder(t *testing.T) { rb.RegisterDefaultEncoder(reflect.Struct, codec) reg := rb.Build() - if reg.kindEncoders[reflect.Struct] != codec { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Struct], codec) + if reg.kindEncoders.get(reflect.Struct) != codec { + t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec) } rb.RegisterDefaultEncoder(reflect.Struct, codec2) reg = rb.Build() - if reg.kindEncoders[reflect.Struct] != codec2 { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Struct], codec2) + if reg.kindEncoders.get(reflect.Struct) != codec2 { + t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec2) } }) t.Run("SliceCodec", func(t *testing.T) { @@ -145,14 +145,14 @@ func TestRegistryBuilder(t *testing.T) { rb.RegisterDefaultEncoder(reflect.Slice, codec) reg := rb.Build() - if reg.kindEncoders[reflect.Slice] != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Slice], codec) + if reg.kindEncoders.get(reflect.Slice) != codec { + t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec) } rb.RegisterDefaultEncoder(reflect.Slice, codec2) reg = rb.Build() - if reg.kindEncoders[reflect.Slice] != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Slice], codec2) + if reg.kindEncoders.get(reflect.Slice) != codec2 { + t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec2) } }) t.Run("ArrayCodec", func(t *testing.T) { @@ -162,14 +162,14 @@ func TestRegistryBuilder(t *testing.T) { rb.RegisterDefaultEncoder(reflect.Array, codec) reg := rb.Build() - if reg.kindEncoders[reflect.Array] != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Array], codec) + if reg.kindEncoders.get(reflect.Array) != codec { + t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec) } rb.RegisterDefaultEncoder(reflect.Array, codec2) reg = rb.Build() - if reg.kindEncoders[reflect.Array] != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Array], codec2) + if reg.kindEncoders.get(reflect.Array) != codec2 { + t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec2) } }) }) @@ -485,7 +485,7 @@ func TestRegistry(t *testing.T) { got := reg.typeEncoders for _, s := range want { wantT, wantC := s.t, s.c - gotC, exists := got[wantT] + gotC, exists := got.Load(wantT) if !exists { t.Errorf("type missing in registry: %v", wantT) } @@ -515,7 +515,7 @@ func TestRegistry(t *testing.T) { got := reg.kindEncoders for _, s := range want { wantK, wantC := s.k, s.c - gotC, exists := got[wantK] + gotC, exists := got.Load(wantK) if !exists { t.Errorf("type missing in registry: %v", wantK) } @@ -534,12 +534,12 @@ func TestRegistry(t *testing.T) { codec2 := &fakeCodec{num: 2} reg := NewRegistry() reg.RegisterKindEncoder(reflect.Map, codec) - if reg.kindEncoders[reflect.Map] != codec { - t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Map], codec) + if reg.kindEncoders.get(reflect.Map) != codec { + t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec) } reg.RegisterKindEncoder(reflect.Map, codec2) - if reg.kindEncoders[reflect.Map] != codec2 { - t.Errorf("map codec properly set: got %#v, want %#v", reg.kindEncoders[reflect.Map], codec2) + if reg.kindEncoders.get(reflect.Map) != codec2 { + t.Errorf("map codec properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec2) } }) t.Run("StructCodec", func(t *testing.T) { @@ -549,12 +549,12 @@ func TestRegistry(t *testing.T) { codec2 := &fakeCodec{num: 2} reg := NewRegistry() reg.RegisterKindEncoder(reflect.Struct, codec) - if reg.kindEncoders[reflect.Struct] != codec { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Struct], codec) + if reg.kindEncoders.get(reflect.Struct) != codec { + t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec) } reg.RegisterKindEncoder(reflect.Struct, codec2) - if reg.kindEncoders[reflect.Struct] != codec2 { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Struct], codec2) + if reg.kindEncoders.get(reflect.Struct) != codec2 { + t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec2) } }) t.Run("SliceCodec", func(t *testing.T) { @@ -564,12 +564,12 @@ func TestRegistry(t *testing.T) { codec2 := &fakeCodec{num: 2} reg := NewRegistry() reg.RegisterKindEncoder(reflect.Slice, codec) - if reg.kindEncoders[reflect.Slice] != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Slice], codec) + if reg.kindEncoders.get(reflect.Slice) != codec { + t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec) } reg.RegisterKindEncoder(reflect.Slice, codec2) - if reg.kindEncoders[reflect.Slice] != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Slice], codec2) + if reg.kindEncoders.get(reflect.Slice) != codec2 { + t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec2) } }) t.Run("ArrayCodec", func(t *testing.T) { @@ -579,12 +579,12 @@ func TestRegistry(t *testing.T) { codec2 := &fakeCodec{num: 2} reg := NewRegistry() reg.RegisterKindEncoder(reflect.Array, codec) - if reg.kindEncoders[reflect.Array] != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Array], codec) + if reg.kindEncoders.get(reflect.Array) != codec { + t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec) } reg.RegisterKindEncoder(reflect.Array, codec2) - if reg.kindEncoders[reflect.Array] != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Array], codec2) + if reg.kindEncoders.get(reflect.Array) != codec2 { + t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec2) } }) }) @@ -860,6 +860,52 @@ func TestRegistry(t *testing.T) { }) } +// get is only for testing as it does return if the value was found +func (c *kindEncoderCache) get(rt reflect.Kind) ValueEncoder { + e, _ := c.Load(rt) + return e +} + +func BenchmarkLookupEncoder(b *testing.B) { + type childStruct struct { + V1, V2, V3, V4 int + } + type nestedStruct struct { + childStruct + A struct{ C1, C2, C3, C4 childStruct } + B struct{ C1, C2, C3, C4 childStruct } + C struct{ M1, M2, M3, M4 map[int]int } + } + types := [...]reflect.Type{ + reflect.TypeOf(int64(1)), + reflect.TypeOf(&fakeCodec{}), + reflect.TypeOf(&testInterface1Impl{}), + reflect.TypeOf(&nestedStruct{}), + } + r := NewRegistry() + for _, typ := range types { + r.RegisterTypeEncoder(typ, &fakeCodec{}) + } + b.Run("Serial", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := r.LookupEncoder(types[i%len(types)]) + if err != nil { + b.Fatal(err) + } + } + }) + b.Run("Parallel", func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + for i := 0; pb.Next(); i++ { + _, err := r.LookupEncoder(types[i%len(types)]) + if err != nil { + b.Fatal(err) + } + } + }) + }) +} + type fakeType1 struct{} type fakeType2 struct{} type fakeType4 struct{} diff --git a/bson/bsoncodec/struct_codec.go b/bson/bsoncodec/struct_codec.go index 1dfdd98865..29ea76d19c 100644 --- a/bson/bsoncodec/struct_codec.go +++ b/bson/bsoncodec/struct_codec.go @@ -63,8 +63,7 @@ type Zeroer interface { // Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the // StructCodec registered. type StructCodec struct { - cache map[reflect.Type]*structDescription - l sync.RWMutex + cache sync.Map // map[reflect.Type]*structDescription parser StructTagParser // DecodeZeroStruct causes DecodeValue to delete any existing values from Go structs in the @@ -115,7 +114,6 @@ func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) structOpt := bsonoptions.MergeStructCodecOptions(opts...) codec := &StructCodec{ - cache: make(map[reflect.Type]*structDescription), parser: p, } @@ -502,13 +500,27 @@ func (sc *StructCodec) describeStruct( ) (*structDescription, error) { // We need to analyze the struct, including getting the tags, collecting // information about inlining, and create a map of the field name to the field. - sc.l.RLock() - ds, exists := sc.cache[t] - sc.l.RUnlock() - if exists { - return ds, nil + if v, ok := sc.cache.Load(t); ok { + return v.(*structDescription), nil } + // TODO(charlie): Only describe the struct once when called + // concurrently with the same type. + ds, err := sc.describeStructSlow(r, t, useJSONStructTags, errorOnDuplicates) + if err != nil { + return nil, err + } + if v, loaded := sc.cache.LoadOrStore(t, ds); loaded { + ds = v.(*structDescription) + } + return ds, nil +} +func (sc *StructCodec) describeStructSlow( + r *Registry, + t reflect.Type, + useJSONStructTags bool, + errorOnDuplicates bool, +) (*structDescription, error) { numFields := t.NumField() sd := &structDescription{ fm: make(map[string]fieldDescription, numFields), @@ -639,10 +651,6 @@ func (sc *StructCodec) describeStruct( sort.Sort(byIndex(sd.fl)) - sc.l.Lock() - sc.cache[t] = sd - sc.l.Unlock() - return sd, nil } diff --git a/bson/testdata/code.json.gz b/bson/testdata/code.json.gz new file mode 100644 index 0000000000..1572a92bfb Binary files /dev/null and b/bson/testdata/code.json.gz differ