diff --git a/go.mod b/go.mod index ae7f41e..0a507bf 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/deckarep/golang-set/v2 go 1.18 + +require go.mongodb.org/mongo-driver v1.17.4 diff --git a/go.sum b/go.sum index e69de29..a669c19 100644 --- a/go.sum +++ b/go.sum @@ -0,0 +1,6 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +go.mongodb.org/mongo-driver v1.16.0 h1:tpRsfBJMROVHKpdGyc1BBEzzjDUWjItxbVSZ8Ls4BQ4= +go.mongodb.org/mongo-driver v1.16.0/go.mod h1:oB6AhJQvFQL4LEHyXi6aJzQJtBiTQHiAd83l0GdFaiw= +go.mongodb.org/mongo-driver v1.17.4 h1:jUorfmVzljjr0FLzYQsGP8cgN/qzzxlY9Vh0C9KFXVw= +go.mongodb.org/mongo-driver v1.17.4/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ= diff --git a/set.go b/set.go index e9409aa..86c9785 100644 --- a/set.go +++ b/set.go @@ -35,6 +35,8 @@ SOFTWARE. // that can enforce mutual exclusion through other means. package mapset +import "go.mongodb.org/mongo-driver/bson/bsontype" + // Set is the primary interface provided by the mapset package. It // represents an unordered set of data and a large number of // operations that can be applied to that set. @@ -196,8 +198,15 @@ type Set[T comparable] interface { MarshalJSON() ([]byte, error) // UnmarshalJSON will unmarshal a JSON-based byte slice into a full Set datastructure. - // For this to work, set subtypes must implemented the Marshal/Unmarshal interface. + // For this to work, set subtypes must implement the Marshal/Unmarshal interface. UnmarshalJSON(b []byte) error + + // MarshalBSONValue will marshal the set into a BSON-based representation. + MarshalBSONValue() (bsontype.Type, []byte, error) + + // UnmarshalBSONValue will unmarshal a BSON-based byte slice into a full Set datastructure. + // For this to work, set subtypes must implement the Marshal/Unmarshal interface. + UnmarshalBSONValue(bt bsontype.Type, b []byte) error } // NewSet creates and returns a new set with the given elements. diff --git a/threadsafe.go b/threadsafe.go index 0f3e593..28a83d4 100644 --- a/threadsafe.go +++ b/threadsafe.go @@ -25,7 +25,11 @@ SOFTWARE. package mapset -import "sync" +import ( + "sync" + + "go.mongodb.org/mongo-driver/bson/bsontype" +) type threadSafeSet[T comparable] struct { sync.RWMutex @@ -305,9 +309,25 @@ func (t *threadSafeSet[T]) MarshalJSON() ([]byte, error) { } func (t *threadSafeSet[T]) UnmarshalJSON(p []byte) error { - t.RLock() + t.Lock() err := t.uss.UnmarshalJSON(p) + t.Unlock() + + return err +} + +func (t *threadSafeSet[T]) MarshalBSONValue() (bsontype.Type, []byte, error) { + t.RLock() + bt, b, err := t.uss.MarshalBSONValue() t.RUnlock() + return bt, b, err +} + +func (t *threadSafeSet[T]) UnmarshalBSONValue(bt bsontype.Type, p []byte) error { + t.Lock() + err := t.uss.UnmarshalBSONValue(bt, p) + t.Unlock() + return err } diff --git a/threadsafe_test.go b/threadsafe_test.go index ed15d02..7fbeba5 100644 --- a/threadsafe_test.go +++ b/threadsafe_test.go @@ -33,6 +33,8 @@ import ( "sync" "sync/atomic" "testing" + + "go.mongodb.org/mongo-driver/bson" ) const N = 1000 @@ -683,3 +685,81 @@ func Test_DeadlockOnEachCallbackWhenPanic(t *testing.T) { t.Errorf("Expected widgets to have 5 elements, but has %d", card) } } + +func Test_UnmarshalBSONValue(t *testing.T) { + tp, s, initErr := bson.MarshalValue( + bson.A{"1", "2", "3", "test"}, + ) + + if initErr != nil { + t.Errorf("Init Error should be nil: %v", initErr) + + return + } + + if tp != bson.TypeArray { + t.Errorf("Encoded Type should be bson.Array, got: %v", tp) + + return + } + + expected := NewSet("1", "2", "3", "test") + actual := NewSet[string]() + err := bson.UnmarshalValue(bson.TypeArray, s, actual) + if err != nil { + t.Errorf("Error should be nil: %v", err) + } + + if !expected.Equal(actual) { + t.Errorf("Expected no difference, got: %v", expected.Difference(actual)) + } +} + +func TestThreadUnsafeSet_UnmarshalBSONValue(t *testing.T) { + tp, s, initErr := bson.MarshalValue( + bson.A{int64(1), int64(2), int64(3)}, + ) + + if initErr != nil { + t.Errorf("Init Error should be nil: %v", initErr) + + return + } + + if tp != bson.TypeArray { + t.Errorf("Encoded Type should be bson.Array, got: %v", tp) + + return + } + + expected := NewThreadUnsafeSet[int64](1, 2, 3) + actual := NewThreadUnsafeSet[int64]() + err := actual.UnmarshalBSONValue(bson.TypeArray, []byte(s)) + if err != nil { + t.Errorf("Error should be nil: %v", err) + } + if !expected.Equal(actual) { + t.Errorf("Expected no difference, got: %v", expected.Difference(actual)) + } +} + +func Test_MarshalBSONValue(t *testing.T) { + expected := NewSet("1", "test") + + _, b, err := bson.MarshalValue( + NewSet("1", "test"), + ) + if err != nil { + t.Errorf("Error should be nil: %v", err) + } + + actual := NewSet[string]() + err = bson.UnmarshalValue(bson.TypeArray, b, actual) + if err != nil { + t.Errorf("Error should be nil: %v", err) + } + + if !expected.Equal(actual) { + t.Errorf("Expected no difference, got: %v", expected.Difference(actual)) + } +} diff --git a/threadunsafe.go b/threadunsafe.go index c95d32b..351354f 100644 --- a/threadunsafe.go +++ b/threadunsafe.go @@ -29,6 +29,9 @@ import ( "encoding/json" "fmt" "strings" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsontype" ) type threadUnsafeSet[T comparable] map[T]struct{} @@ -350,3 +353,24 @@ func (s *threadUnsafeSet[T]) UnmarshalJSON(b []byte) error { return nil } + +// MarshalBSON creates a BSON array from the set. +func (s threadUnsafeSet[T]) MarshalBSONValue() (bsontype.Type, []byte, error) { + return bson.MarshalValue(s.ToSlice()) +} + +// UnmarshalBSON recreates a set from a BSON array. +func (s threadUnsafeSet[T]) UnmarshalBSONValue(bt bsontype.Type, b []byte) error { + if bt != bson.TypeArray { + return fmt.Errorf("must use BSON Array to unmarshal Set") + } + + var i []T + err := bson.UnmarshalValue(bt, b, &i) + if err != nil { + return err + } + s.Append(i...) + + return nil +}