Skip to content

Commit cbe8aa4

Browse files
charlieviethmatthewdaleqingyang-hu
committed
GODRIVER-2914 bsoncodec/bsonrw: eliminate encoding allocations (#1323)
Co-authored-by: Matt Dale <[email protected]> Co-authored-by: Qingyang Hu <[email protected]>
1 parent 7f8b1d0 commit cbe8aa4

File tree

8 files changed

+119
-84
lines changed

8 files changed

+119
-84
lines changed

bson/bsoncodec/slice_codec.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func (sc SliceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val re
6262
}
6363

6464
// If we have a []primitive.E we want to treat it as a document instead of as an array.
65-
if val.Type().ConvertibleTo(tD) {
65+
if val.Type() == tD || val.Type().ConvertibleTo(tD) {
6666
d := val.Convert(tD).Interface().(primitive.D)
6767

6868
dw, err := vw.WriteDocument()

bson/bsoncodec/struct_codec.go

Lines changed: 24 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -190,15 +190,14 @@ func (sc *StructCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val
190190
encoder := desc.encoder
191191

192192
var zero bool
193-
rvInterface := rv.Interface()
194193
if cz, ok := encoder.(CodecZeroer); ok {
195-
zero = cz.IsTypeZero(rvInterface)
194+
zero = cz.IsTypeZero(rv.Interface())
196195
} else if rv.Kind() == reflect.Interface {
197196
// isZero will not treat an interface rv as an interface, so we need to check for the
198197
// zero interface separately.
199198
zero = rv.IsNil()
200199
} else {
201-
zero = isZero(rvInterface, sc.EncodeOmitDefaultStruct || ec.omitZeroStruct)
200+
zero = isZero(rv, sc.EncodeOmitDefaultStruct || ec.omitZeroStruct)
202201
}
203202
if desc.omitEmpty && zero {
204203
continue
@@ -392,56 +391,32 @@ func (sc *StructCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val
392391
return nil
393392
}
394393

395-
func isZero(i interface{}, omitZeroStruct bool) bool {
396-
v := reflect.ValueOf(i)
397-
398-
// check the value validity
399-
if !v.IsValid() {
400-
return true
394+
func isZero(v reflect.Value, omitZeroStruct bool) bool {
395+
kind := v.Kind()
396+
if (kind != reflect.Ptr || !v.IsNil()) && v.Type().Implements(tZeroer) {
397+
return v.Interface().(Zeroer).IsZero()
401398
}
402-
403-
if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
404-
return z.IsZero()
405-
}
406-
407-
switch v.Kind() {
408-
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
409-
return v.Len() == 0
410-
case reflect.Bool:
411-
return !v.Bool()
412-
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
413-
return v.Int() == 0
414-
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
415-
return v.Uint() == 0
416-
case reflect.Float32, reflect.Float64:
417-
return v.Float() == 0
418-
case reflect.Interface, reflect.Ptr:
419-
return v.IsNil()
420-
case reflect.Struct:
399+
if kind == reflect.Struct {
421400
if !omitZeroStruct {
422401
return false
423402
}
424-
425-
// TODO(GODRIVER-2820): Update the logic to be able to handle private struct fields.
426-
// TODO Use condition "reflect.Zero(v.Type()).Equal(v)" instead.
427-
428403
vt := v.Type()
429404
if vt == tTime {
430405
return v.Interface().(time.Time).IsZero()
431406
}
432-
for i := 0; i < v.NumField(); i++ {
433-
if vt.Field(i).PkgPath != "" && !vt.Field(i).Anonymous {
407+
numField := vt.NumField()
408+
for i := 0; i < numField; i++ {
409+
ff := vt.Field(i)
410+
if ff.PkgPath != "" && !ff.Anonymous {
434411
continue // Private field
435412
}
436-
fld := v.Field(i)
437-
if !isZero(fld.Interface(), omitZeroStruct) {
413+
if !isZero(v.Field(i), omitZeroStruct) {
438414
return false
439415
}
440416
}
441417
return true
442418
}
443-
444-
return false
419+
return !v.IsValid() || v.IsZero()
445420
}
446421

447422
type structDescription struct {
@@ -708,21 +683,21 @@ func getInlineField(val reflect.Value, index []int) (reflect.Value, error) {
708683

709684
// DeepZero returns recursive zero object
710685
func deepZero(st reflect.Type) (result reflect.Value) {
711-
result = reflect.Indirect(reflect.New(st))
712-
713-
if result.Kind() == reflect.Struct {
714-
for i := 0; i < result.NumField(); i++ {
715-
if f := result.Field(i); f.Kind() == reflect.Ptr {
716-
if f.CanInterface() {
717-
if ft := reflect.TypeOf(f.Interface()); ft.Elem().Kind() == reflect.Struct {
718-
result.Field(i).Set(recursivePointerTo(deepZero(ft.Elem())))
719-
}
686+
if st.Kind() == reflect.Struct {
687+
numField := st.NumField()
688+
for i := 0; i < numField; i++ {
689+
if result == emptyValue {
690+
result = reflect.Indirect(reflect.New(st))
691+
}
692+
f := result.Field(i)
693+
if f.CanInterface() {
694+
if f.Type().Kind() == reflect.Struct {
695+
result.Field(i).Set(recursivePointerTo(deepZero(f.Type().Elem())))
720696
}
721697
}
722698
}
723699
}
724-
725-
return
700+
return result
726701
}
727702

728703
// recursivePointerTo calls reflect.New(v.Type) but recursively for its fields inside

bson/bsoncodec/struct_codec_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
package bsoncodec
88

99
import (
10+
"reflect"
1011
"testing"
1112
"time"
1213

@@ -147,7 +148,7 @@ func TestIsZero(t *testing.T) {
147148
t.Run(tc.description, func(t *testing.T) {
148149
t.Parallel()
149150

150-
got := isZero(tc.value, tc.omitZeroStruct)
151+
got := isZero(reflect.ValueOf(tc.value), tc.omitZeroStruct)
151152
assert.Equal(t, tc.want, got, "expected and actual isZero return are different")
152153
})
153154
}

bson/bsoncodec/types.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ var tValueUnmarshaler = reflect.TypeOf((*ValueUnmarshaler)(nil)).Elem()
3434
var tMarshaler = reflect.TypeOf((*Marshaler)(nil)).Elem()
3535
var tUnmarshaler = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
3636
var tProxy = reflect.TypeOf((*Proxy)(nil)).Elem()
37+
var tZeroer = reflect.TypeOf((*Zeroer)(nil)).Elem()
3738

3839
var tBinary = reflect.TypeOf(primitive.Binary{})
3940
var tUndefined = reflect.TypeOf(primitive.Undefined{})

bson/bsonrw/copier.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ func (c Copier) AppendDocumentBytes(dst []byte, src ValueReader) ([]byte, error)
193193
}
194194

195195
vw := vwPool.Get().(*valueWriter)
196-
defer vwPool.Put(vw)
196+
defer putValueWriter(vw)
197197

198198
vw.reset(dst)
199199

@@ -213,7 +213,7 @@ func (c Copier) AppendArrayBytes(dst []byte, src ValueReader) ([]byte, error) {
213213
}
214214

215215
vw := vwPool.Get().(*valueWriter)
216-
defer vwPool.Put(vw)
216+
defer putValueWriter(vw)
217217

218218
vw.reset(dst)
219219

@@ -258,7 +258,7 @@ func (c Copier) AppendValueBytes(dst []byte, src ValueReader) (bsontype.Type, []
258258
}
259259

260260
vw := vwPool.Get().(*valueWriter)
261-
defer vwPool.Put(vw)
261+
defer putValueWriter(vw)
262262

263263
start := len(dst)
264264

bson/bsonrw/value_reader.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -739,8 +739,7 @@ func (vr *valueReader) ReadValue() (ValueReader, error) {
739739
return nil, ErrEOA
740740
}
741741

742-
_, err = vr.readCString()
743-
if err != nil {
742+
if err := vr.skipCString(); err != nil {
744743
return nil, err
745744
}
746745

@@ -794,6 +793,15 @@ func (vr *valueReader) readByte() (byte, error) {
794793
return vr.d[vr.offset-1], nil
795794
}
796795

796+
func (vr *valueReader) skipCString() error {
797+
idx := bytes.IndexByte(vr.d[vr.offset:], 0x00)
798+
if idx < 0 {
799+
return io.EOF
800+
}
801+
vr.offset += int64(idx) + 1
802+
return nil
803+
}
804+
797805
func (vr *valueReader) readCString() (string, error) {
798806
idx := bytes.IndexByte(vr.d[vr.offset:], 0x00)
799807
if idx < 0 {

bson/bsonrw/value_writer.go

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ var vwPool = sync.Pool{
2828
},
2929
}
3030

31+
func putValueWriter(vw *valueWriter) {
32+
if vw != nil {
33+
vw.w = nil // don't leak the writer
34+
vwPool.Put(vw)
35+
}
36+
}
37+
3138
// BSONValueWriterPool is a pool for BSON ValueWriters.
3239
//
3340
// Deprecated: BSONValueWriterPool will not be supported in Go Driver 2.0.
@@ -149,32 +156,21 @@ type valueWriter struct {
149156
}
150157

151158
func (vw *valueWriter) advanceFrame() {
152-
if vw.frame+1 >= int64(len(vw.stack)) { // We need to grow the stack
153-
length := len(vw.stack)
154-
if length+1 >= cap(vw.stack) {
155-
// double it
156-
buf := make([]vwState, 2*cap(vw.stack)+1)
157-
copy(buf, vw.stack)
158-
vw.stack = buf
159-
}
160-
vw.stack = vw.stack[:length+1]
161-
}
162159
vw.frame++
160+
if vw.frame >= int64(len(vw.stack)) {
161+
vw.stack = append(vw.stack, vwState{})
162+
}
163163
}
164164

165165
func (vw *valueWriter) push(m mode) {
166166
vw.advanceFrame()
167167

168168
// Clean the stack
169-
vw.stack[vw.frame].mode = m
170-
vw.stack[vw.frame].key = ""
171-
vw.stack[vw.frame].arrkey = 0
172-
vw.stack[vw.frame].start = 0
169+
vw.stack[vw.frame] = vwState{mode: m}
173170

174-
vw.stack[vw.frame].mode = m
175171
switch m {
176172
case mDocument, mArray, mCodeWithScope:
177-
vw.reserveLength()
173+
vw.reserveLength() // WARN: this is not needed
178174
}
179175
}
180176

@@ -213,6 +209,7 @@ func newValueWriter(w io.Writer) *valueWriter {
213209
return vw
214210
}
215211

212+
// TODO: only used in tests
216213
func newValueWriterFromSlice(buf []byte) *valueWriter {
217214
vw := new(valueWriter)
218215
stack := make([]vwState, 1, 5)
@@ -249,17 +246,16 @@ func (vw *valueWriter) invalidTransitionError(destination mode, name string, mod
249246
}
250247

251248
func (vw *valueWriter) writeElementHeader(t bsontype.Type, destination mode, callerName string, addmodes ...mode) error {
252-
switch vw.stack[vw.frame].mode {
249+
frame := &vw.stack[vw.frame]
250+
switch frame.mode {
253251
case mElement:
254-
key := vw.stack[vw.frame].key
252+
key := frame.key
255253
if !isValidCString(key) {
256254
return errors.New("BSON element key cannot contain null bytes")
257255
}
258-
259-
vw.buf = bsoncore.AppendHeader(vw.buf, t, key)
256+
vw.appendHeader(t, key)
260257
case mValue:
261-
// TODO: Do this with a cache of the first 1000 or so array keys.
262-
vw.buf = bsoncore.AppendHeader(vw.buf, t, strconv.Itoa(vw.stack[vw.frame].arrkey))
258+
vw.appendIntHeader(t, frame.arrkey)
263259
default:
264260
modes := []mode{mElement, mValue}
265261
if addmodes != nil {
@@ -601,9 +597,11 @@ func (vw *valueWriter) writeLength() error {
601597
if length > maxSize {
602598
return errMaxDocumentSizeExceeded{size: int64(len(vw.buf))}
603599
}
604-
length = length - int(vw.stack[vw.frame].start)
605-
start := vw.stack[vw.frame].start
600+
frame := &vw.stack[vw.frame]
601+
length = length - int(frame.start)
602+
start := frame.start
606603

604+
_ = vw.buf[start+3] // BCE
607605
vw.buf[start+0] = byte(length)
608606
vw.buf[start+1] = byte(length >> 8)
609607
vw.buf[start+2] = byte(length >> 16)
@@ -612,5 +610,31 @@ func (vw *valueWriter) writeLength() error {
612610
}
613611

614612
func isValidCString(cs string) bool {
615-
return !strings.ContainsRune(cs, '\x00')
613+
// Disallow the zero byte in a cstring because the zero byte is used as the
614+
// terminating character.
615+
//
616+
// It's safe to check bytes instead of runes because all multibyte UTF-8
617+
// code points start with (binary) 11xxxxxx or 10xxxxxx, so 00000000 (i.e.
618+
// 0) will never be part of a multibyte UTF-8 code point. This logic is the
619+
// same as the "r < utf8.RuneSelf" case in strings.IndexRune but can be
620+
// inlined.
621+
//
622+
// https://cs.opensource.google/go/go/+/refs/tags/go1.21.1:src/strings/strings.go;l=127
623+
return strings.IndexByte(cs, 0) == -1
624+
}
625+
626+
// appendHeader is the same as bsoncore.AppendHeader but does not check if the
627+
// key is a valid C string since the caller has already checked for that.
628+
//
629+
// The caller of this function must check if key is a valid C string.
630+
func (vw *valueWriter) appendHeader(t bsontype.Type, key string) {
631+
vw.buf = bsoncore.AppendType(vw.buf, t)
632+
vw.buf = append(vw.buf, key...)
633+
vw.buf = append(vw.buf, 0x00)
634+
}
635+
636+
func (vw *valueWriter) appendIntHeader(t bsontype.Type, key int) {
637+
vw.buf = bsoncore.AppendType(vw.buf, t)
638+
vw.buf = strconv.AppendInt(vw.buf, int64(key), 10)
639+
vw.buf = append(vw.buf, 0x00)
616640
}

bson/marshal.go

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ package bson
99
import (
1010
"bytes"
1111
"encoding/json"
12+
"sync"
1213

1314
"go.mongodb.org/mongo-driver/bson/bsoncodec"
1415
"go.mongodb.org/mongo-driver/bson/bsonrw"
@@ -141,6 +142,13 @@ func MarshalAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{
141142
return MarshalAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val)
142143
}
143144

145+
// Pool of buffers for marshalling BSON.
146+
var bufPool = sync.Pool{
147+
New: func() interface{} {
148+
return new(bytes.Buffer)
149+
},
150+
}
151+
144152
// MarshalAppendWithContext will encode val as a BSON document using Registry r and EncodeContext ec and append the
145153
// bytes to dst. If dst is not large enough to hold the bytes, it will be grown. If val is not a type that can be
146154
// transformed into a document, MarshalValueAppendWithContext should be used instead.
@@ -162,8 +170,26 @@ func MarshalAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{
162170
//
163171
// See [Encoder] for more examples.
164172
func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}) ([]byte, error) {
165-
sw := new(bsonrw.SliceWriter)
166-
*sw = dst
173+
sw := bufPool.Get().(*bytes.Buffer)
174+
defer func() {
175+
// Proper usage of a sync.Pool requires each entry to have approximately
176+
// the same memory cost. To obtain this property when the stored type
177+
// contains a variably-sized buffer, we add a hard limit on the maximum
178+
// buffer to place back in the pool. We limit the size to 16MiB because
179+
// that's the maximum wire message size supported by any current MongoDB
180+
// server.
181+
//
182+
// Comment based on
183+
// https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/fmt/print.go;l=147
184+
//
185+
// Recycle byte slices that are smaller than 16MiB and at least half
186+
// occupied.
187+
if sw.Cap() < 16*1024*1024 && sw.Cap()/2 < sw.Len() {
188+
bufPool.Put(sw)
189+
}
190+
}()
191+
192+
sw.Reset()
167193
vw := bvwPool.Get(sw)
168194
defer bvwPool.Put(vw)
169195

@@ -184,7 +210,7 @@ func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interf
184210
return nil, err
185211
}
186212

187-
return *sw, nil
213+
return append(dst, sw.Bytes()...), nil
188214
}
189215

190216
// MarshalValue returns the BSON encoding of val.

0 commit comments

Comments
 (0)