Skip to content

Commit fff59c0

Browse files
committed
Implement BSON Marshaler support
1 parent b710ba4 commit fff59c0

File tree

6 files changed

+140
-3
lines changed

6 files changed

+140
-3
lines changed

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
module github.com/deckarep/golang-set/v2
22

33
go 1.18
4+
5+
require go.mongodb.org/mongo-driver v1.16.0

go.sum

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
2+
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
3+
go.mongodb.org/mongo-driver v1.16.0 h1:tpRsfBJMROVHKpdGyc1BBEzzjDUWjItxbVSZ8Ls4BQ4=
4+
go.mongodb.org/mongo-driver v1.16.0/go.mod h1:oB6AhJQvFQL4LEHyXi6aJzQJtBiTQHiAd83l0GdFaiw=

set.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ SOFTWARE.
3535
// that can enforce mutual exclusion through other means.
3636
package mapset
3737

38+
import "go.mongodb.org/mongo-driver/bson/bsontype"
39+
3840
// Set is the primary interface provided by the mapset package. It
3941
// represents an unordered set of data and a large number of
4042
// operations that can be applied to that set.
@@ -192,8 +194,15 @@ type Set[T comparable] interface {
192194
MarshalJSON() ([]byte, error)
193195

194196
// UnmarshalJSON will unmarshal a JSON-based byte slice into a full Set datastructure.
195-
// For this to work, set subtypes must implemented the Marshal/Unmarshal interface.
197+
// For this to work, set subtypes must implement the Marshal/Unmarshal interface.
196198
UnmarshalJSON(b []byte) error
199+
200+
// MarshalBSONValue will marshal the set into a BSON-based representation.
201+
MarshalBSONValue() (bsontype.Type, []byte, error)
202+
203+
// UnmarshalBSONValue will unmarshal a BSON-based byte slice into a full Set datastructure.
204+
// For this to work, set subtypes must implement the Marshal/Unmarshal interface.
205+
UnmarshalBSONValue(bt bsontype.Type, b []byte) error
197206
}
198207

199208
// NewSet creates and returns a new set with the given elements.

threadsafe.go

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ SOFTWARE.
2525

2626
package mapset
2727

28-
import "sync"
28+
import (
29+
"sync"
30+
31+
"go.mongodb.org/mongo-driver/bson/bsontype"
32+
)
2933

3034
type threadSafeSet[T comparable] struct {
3135
sync.RWMutex
@@ -291,9 +295,25 @@ func (t *threadSafeSet[T]) MarshalJSON() ([]byte, error) {
291295
}
292296

293297
func (t *threadSafeSet[T]) UnmarshalJSON(p []byte) error {
294-
t.RLock()
298+
t.Lock()
295299
err := t.uss.UnmarshalJSON(p)
300+
t.Unlock()
301+
302+
return err
303+
}
304+
305+
func (t *threadSafeSet[T]) MarshalBSONValue() (bsontype.Type, []byte, error) {
306+
t.RLock()
307+
bt, b, err := t.uss.MarshalBSONValue()
296308
t.RUnlock()
297309

310+
return bt, b, err
311+
}
312+
313+
func (t *threadSafeSet[T]) UnmarshalBSONValue(bt bsontype.Type, p []byte) error {
314+
t.Lock()
315+
err := t.uss.UnmarshalBSONValue(bt, p)
316+
t.Unlock()
317+
298318
return err
299319
}

threadsafe_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ import (
3232
"sync"
3333
"sync/atomic"
3434
"testing"
35+
36+
"go.mongodb.org/mongo-driver/bson"
3537
)
3638

3739
const N = 1000
@@ -625,3 +627,79 @@ func Test_MarshalJSON(t *testing.T) {
625627
t.Errorf("Expected no difference, got: %v", expected.Difference(actual))
626628
}
627629
}
630+
631+
func Test_UnmarshalBSONValue(t *testing.T) {
632+
tp, s, initErr := bson.MarshalValue(
633+
bson.A{"1", "2", "3", "test"},
634+
)
635+
636+
if initErr != nil {
637+
t.Errorf("Init Error should be nil: %v", initErr)
638+
639+
return
640+
}
641+
642+
if tp != bson.TypeArray {
643+
t.Errorf("Encoded Type should be bson.Array, got: %v", tp)
644+
645+
return
646+
}
647+
648+
expected := NewSet("1", "2", "3", "test")
649+
actual := NewSet[string]()
650+
err := bson.UnmarshalValue(bson.TypeArray, s, actual)
651+
if err != nil {
652+
t.Errorf("Error should be nil: %v", err)
653+
}
654+
655+
if !expected.Equal(actual) {
656+
t.Errorf("Expected no difference, got: %v", expected.Difference(actual))
657+
}
658+
}
659+
func TestThreadUnsafeSet_UnmarshalBSONValue(t *testing.T) {
660+
tp, s, initErr := bson.MarshalValue(
661+
bson.A{int64(1), int64(2), int64(3)},
662+
)
663+
664+
if initErr != nil {
665+
t.Errorf("Init Error should be nil: %v", initErr)
666+
667+
return
668+
}
669+
670+
if tp != bson.TypeArray {
671+
t.Errorf("Encoded Type should be bson.Array, got: %v", tp)
672+
673+
return
674+
}
675+
676+
expected := NewThreadUnsafeSet[int64](1, 2, 3)
677+
actual := NewThreadUnsafeSet[int64]()
678+
err := actual.UnmarshalBSONValue(bson.TypeArray, []byte(s))
679+
if err != nil {
680+
t.Errorf("Error should be nil: %v", err)
681+
}
682+
if !expected.Equal(actual) {
683+
t.Errorf("Expected no difference, got: %v", expected.Difference(actual))
684+
}
685+
}
686+
func Test_MarshalBSONValue(t *testing.T) {
687+
expected := NewSet("1", "test")
688+
689+
_, b, err := bson.MarshalValue(
690+
NewSet("1", "test"),
691+
)
692+
if err != nil {
693+
t.Errorf("Error should be nil: %v", err)
694+
}
695+
696+
actual := NewSet[string]()
697+
err = bson.UnmarshalValue(bson.TypeArray, b, actual)
698+
if err != nil {
699+
t.Errorf("Error should be nil: %v", err)
700+
}
701+
702+
if !expected.Equal(actual) {
703+
t.Errorf("Expected no difference, got: %v", expected.Difference(actual))
704+
}
705+
}

threadunsafe.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ import (
2929
"encoding/json"
3030
"fmt"
3131
"strings"
32+
33+
"go.mongodb.org/mongo-driver/bson"
34+
"go.mongodb.org/mongo-driver/bson/bsontype"
3235
)
3336

3437
type threadUnsafeSet[T comparable] map[T]struct{}
@@ -328,3 +331,24 @@ func (s threadUnsafeSet[T]) UnmarshalJSON(b []byte) error {
328331

329332
return nil
330333
}
334+
335+
// MarshalBSON creates a BSON array from the set.
336+
func (s threadUnsafeSet[T]) MarshalBSONValue() (bsontype.Type, []byte, error) {
337+
return bson.MarshalValue(s.ToSlice())
338+
}
339+
340+
// UnmarshalBSON recreates a set from a BSON array.
341+
func (s threadUnsafeSet[T]) UnmarshalBSONValue(bt bsontype.Type, b []byte) error {
342+
if bt != bson.TypeArray {
343+
return fmt.Errorf("must use BSON Array to unmarshal Set")
344+
}
345+
346+
var i []T
347+
err := bson.UnmarshalValue(bt, b, &i)
348+
if err != nil {
349+
return err
350+
}
351+
s.Append(i...)
352+
353+
return nil
354+
}

0 commit comments

Comments
 (0)