diff --git a/set.go b/set.go index 6ee28a2..b795583 100644 --- a/set.go +++ b/set.go @@ -39,26 +39,26 @@ package mapset // represents an unordered set of data and a large number of // operations that can be applied to that set. type Set[T comparable] interface { - // Adds an element to the set. Returns whether + // Add adds an element to the set. Returns whether // the item was added. Add(val T) bool - // Returns the number of elements in the set. + // Cardinality returns the number of elements in the set. Cardinality() int - // Removes all elements from the set, leaving + // Clear removes all elements from the set, leaving // the empty set. Clear() - // Returns a clone of the set using the same + // Clone returns a clone of the set using the same // implementation, duplicating all keys. Clone() Set[T] - // Returns whether the given items + // Contains returns whether the given items // are all in the set. Contains(val ...T) bool - // Returns the difference between this set + // Difference returns the difference between this set // and other. The returned set will contain // all elements of this set that are not also // elements of other. @@ -69,7 +69,7 @@ type Set[T comparable] interface { // panic. Difference(other Set[T]) Set[T] - // Determines if two sets are equal to each + // Equal determines if two sets are equal to each // other. If they have the same cardinality // and contain the same elements, they are // considered equal. The order in which @@ -80,7 +80,7 @@ type Set[T comparable] interface { // method. Otherwise, Equal will panic. Equal(other Set[T]) bool - // Returns a new set containing only the elements + // Intersect returns a new set containing only the elements // that exist only in both sets. // // Note that the argument to Intersect @@ -89,7 +89,7 @@ type Set[T comparable] interface { // panic. Intersect(other Set[T]) Set[T] - // Determines if every element in this set is in + // IsProperSubset determines if every element in this set is in // the other set but the two sets are not equal. // // Note that the argument to IsProperSubset @@ -98,7 +98,7 @@ type Set[T comparable] interface { // will panic. IsProperSubset(other Set[T]) bool - // Determines if every element in the other set + // IsProperSuperset determines if every element in the other set // is in this set but the two sets are not // equal. // @@ -108,7 +108,7 @@ type Set[T comparable] interface { // panic. IsProperSuperset(other Set[T]) bool - // Determines if every element in this set is in + // IsSubset determines if every element in this set is in // the other set. // // Note that the argument to IsSubset @@ -117,7 +117,7 @@ type Set[T comparable] interface { // panic. IsSubset(other Set[T]) bool - // Determines if every element in the other set + // IsSuperset determines if every element in the other set // is in this set. // // Note that the argument to IsSuperset @@ -126,26 +126,26 @@ type Set[T comparable] interface { // panic. IsSuperset(other Set[T]) bool - // Iterates over elements and executes the passed func against each element. + // Each iterates over elements and executes the passed func against each element. // If passed func returns true, stop iteration at the time. Each(func(T) bool) - // Returns a channel of elements that you can + // Iter returns a channel of elements that you can // range over. Iter() <-chan T - // Returns an Iterator object that you can + // Iterator returns an Iterator object that you can // use to range over the set. Iterator() *Iterator[T] - // Remove a single element from the set. + // Remove removes a single element from the set. Remove(i T) - // Provides a convenient string representation + // String provides a convenient string representation // of the current state of the set. String() string - // Returns a new set with all elements which are + // SymmetricDifference returns a new set with all elements which are // in either this set or the other set but not in both. // // Note that the argument to SymmetricDifference @@ -154,7 +154,7 @@ type Set[T comparable] interface { // will panic. SymmetricDifference(other Set[T]) Set[T] - // Returns a new set with all elements in both sets. + // Union returns a new set with all elements in both sets. // // Note that the argument to Union must be of the // same type as the receiver of the method. @@ -164,7 +164,7 @@ type Set[T comparable] interface { // Pop removes and returns an arbitrary item from the set. Pop() (T, bool) - // Returns the members of the set as a slice. + // ToSlice returns the members of the set as a slice. ToSlice() []T // MarshalJSON will marshal the set into a JSON-based representation. @@ -182,7 +182,7 @@ func NewSet[T comparable](vals ...T) Set[T] { for _, item := range vals { s.Add(item) } - return &s + return s } // NewThreadUnsafeSet creates and returns a new set with the given elements. @@ -192,5 +192,5 @@ func NewThreadUnsafeSet[T comparable](vals ...T) Set[T] { for _, item := range vals { s.Add(item) } - return &s + return s } diff --git a/threadsafe.go b/threadsafe.go index 1a137a9..da694b7 100644 --- a/threadsafe.go +++ b/threadsafe.go @@ -32,160 +32,159 @@ type threadSafeSet[T comparable] struct { uss threadUnsafeSet[T] } -func newThreadSafeSet[T comparable]() threadSafeSet[T] { - newUss := newThreadUnsafeSet[T]() - return threadSafeSet[T]{ - uss: newUss, +func newThreadSafeSet[T comparable]() *threadSafeSet[T] { + return &threadSafeSet[T]{ + uss: newThreadUnsafeSet[T](), } } -func (s *threadSafeSet[T]) Add(v T) bool { - s.Lock() - ret := s.uss.Add(v) - s.Unlock() +func (t *threadSafeSet[T]) Add(v T) bool { + t.Lock() + ret := t.uss.Add(v) + t.Unlock() return ret } -func (s *threadSafeSet[T]) Contains(v ...T) bool { - s.RLock() - ret := s.uss.Contains(v...) - s.RUnlock() +func (t *threadSafeSet[T]) Contains(v ...T) bool { + t.RLock() + ret := t.uss.Contains(v...) + t.RUnlock() return ret } -func (s *threadSafeSet[T]) IsSubset(other Set[T]) bool { +func (t *threadSafeSet[T]) IsSubset(other Set[T]) bool { o := other.(*threadSafeSet[T]) - s.RLock() + t.RLock() o.RLock() - ret := s.uss.IsSubset(&o.uss) - s.RUnlock() + ret := t.uss.IsSubset(o.uss) + t.RUnlock() o.RUnlock() return ret } -func (s *threadSafeSet[T]) IsProperSubset(other Set[T]) bool { +func (t *threadSafeSet[T]) IsProperSubset(other Set[T]) bool { o := other.(*threadSafeSet[T]) - s.RLock() - defer s.RUnlock() + t.RLock() + defer t.RUnlock() o.RLock() defer o.RUnlock() - return s.uss.IsProperSubset(&o.uss) + return t.uss.IsProperSubset(o.uss) } -func (s *threadSafeSet[T]) IsSuperset(other Set[T]) bool { - return other.IsSubset(s) +func (t *threadSafeSet[T]) IsSuperset(other Set[T]) bool { + return other.IsSubset(t) } -func (s *threadSafeSet[T]) IsProperSuperset(other Set[T]) bool { - return other.IsProperSubset(s) +func (t *threadSafeSet[T]) IsProperSuperset(other Set[T]) bool { + return other.IsProperSubset(t) } -func (s *threadSafeSet[T]) Union(other Set[T]) Set[T] { +func (t *threadSafeSet[T]) Union(other Set[T]) Set[T] { o := other.(*threadSafeSet[T]) - s.RLock() + t.RLock() o.RLock() - unsafeUnion := s.uss.Union(&o.uss).(*threadUnsafeSet[T]) - ret := &threadSafeSet[T]{uss: *unsafeUnion} - s.RUnlock() + unsafeUnion := t.uss.Union(o.uss).(threadUnsafeSet[T]) + ret := &threadSafeSet[T]{uss: unsafeUnion} + t.RUnlock() o.RUnlock() return ret } -func (s *threadSafeSet[T]) Intersect(other Set[T]) Set[T] { +func (t *threadSafeSet[T]) Intersect(other Set[T]) Set[T] { o := other.(*threadSafeSet[T]) - s.RLock() + t.RLock() o.RLock() - unsafeIntersection := s.uss.Intersect(&o.uss).(*threadUnsafeSet[T]) - ret := &threadSafeSet[T]{uss: *unsafeIntersection} - s.RUnlock() + unsafeIntersection := t.uss.Intersect(o.uss).(threadUnsafeSet[T]) + ret := &threadSafeSet[T]{uss: unsafeIntersection} + t.RUnlock() o.RUnlock() return ret } -func (s *threadSafeSet[T]) Difference(other Set[T]) Set[T] { +func (t *threadSafeSet[T]) Difference(other Set[T]) Set[T] { o := other.(*threadSafeSet[T]) - s.RLock() + t.RLock() o.RLock() - unsafeDifference := s.uss.Difference(&o.uss).(*threadUnsafeSet[T]) - ret := &threadSafeSet[T]{uss: *unsafeDifference} - s.RUnlock() + unsafeDifference := t.uss.Difference(o.uss).(threadUnsafeSet[T]) + ret := &threadSafeSet[T]{uss: unsafeDifference} + t.RUnlock() o.RUnlock() return ret } -func (s *threadSafeSet[T]) SymmetricDifference(other Set[T]) Set[T] { +func (t *threadSafeSet[T]) SymmetricDifference(other Set[T]) Set[T] { o := other.(*threadSafeSet[T]) - s.RLock() + t.RLock() o.RLock() - unsafeDifference := s.uss.SymmetricDifference(&o.uss).(*threadUnsafeSet[T]) - ret := &threadSafeSet[T]{uss: *unsafeDifference} - s.RUnlock() + unsafeDifference := t.uss.SymmetricDifference(o.uss).(threadUnsafeSet[T]) + ret := &threadSafeSet[T]{uss: unsafeDifference} + t.RUnlock() o.RUnlock() return ret } -func (s *threadSafeSet[T]) Clear() { - s.Lock() - s.uss = newThreadUnsafeSet[T]() - s.Unlock() +func (t *threadSafeSet[T]) Clear() { + t.Lock() + t.uss.Clear() + t.Unlock() } -func (s *threadSafeSet[T]) Remove(v T) { - s.Lock() - delete(s.uss, v) - s.Unlock() +func (t *threadSafeSet[T]) Remove(v T) { + t.Lock() + delete(t.uss, v) + t.Unlock() } -func (s *threadSafeSet[T]) Cardinality() int { - s.RLock() - defer s.RUnlock() - return len(s.uss) +func (t *threadSafeSet[T]) Cardinality() int { + t.RLock() + defer t.RUnlock() + return len(t.uss) } -func (s *threadSafeSet[T]) Each(cb func(T) bool) { - s.RLock() - for elem := range s.uss { +func (t *threadSafeSet[T]) Each(cb func(T) bool) { + t.RLock() + for elem := range t.uss { if cb(elem) { break } } - s.RUnlock() + t.RUnlock() } -func (s *threadSafeSet[T]) Iter() <-chan T { +func (t *threadSafeSet[T]) Iter() <-chan T { ch := make(chan T) go func() { - s.RLock() + t.RLock() - for elem := range s.uss { + for elem := range t.uss { ch <- elem } close(ch) - s.RUnlock() + t.RUnlock() }() return ch } -func (s *threadSafeSet[T]) Iterator() *Iterator[T] { +func (t *threadSafeSet[T]) Iterator() *Iterator[T] { iterator, ch, stopCh := newIterator[T]() go func() { - s.RLock() + t.RLock() L: - for elem := range s.uss { + for elem := range t.uss { select { case <-stopCh: break L @@ -193,68 +192,68 @@ func (s *threadSafeSet[T]) Iterator() *Iterator[T] { } } close(ch) - s.RUnlock() + t.RUnlock() }() return iterator } -func (s *threadSafeSet[T]) Equal(other Set[T]) bool { +func (t *threadSafeSet[T]) Equal(other Set[T]) bool { o := other.(*threadSafeSet[T]) - s.RLock() + t.RLock() o.RLock() - ret := s.uss.Equal(&o.uss) - s.RUnlock() + ret := t.uss.Equal(o.uss) + t.RUnlock() o.RUnlock() return ret } -func (s *threadSafeSet[T]) Clone() Set[T] { - s.RLock() +func (t *threadSafeSet[T]) Clone() Set[T] { + t.RLock() - unsafeClone := s.uss.Clone().(*threadUnsafeSet[T]) - ret := &threadSafeSet[T]{uss: *unsafeClone} - s.RUnlock() + unsafeClone := t.uss.Clone().(threadUnsafeSet[T]) + ret := &threadSafeSet[T]{uss: unsafeClone} + t.RUnlock() return ret } -func (s *threadSafeSet[T]) String() string { - s.RLock() - ret := s.uss.String() - s.RUnlock() +func (t *threadSafeSet[T]) String() string { + t.RLock() + ret := t.uss.String() + t.RUnlock() return ret } -func (s *threadSafeSet[T]) Pop() (T, bool) { - s.Lock() - defer s.Unlock() - return s.uss.Pop() +func (t *threadSafeSet[T]) Pop() (T, bool) { + t.Lock() + defer t.Unlock() + return t.uss.Pop() } -func (s *threadSafeSet[T]) ToSlice() []T { - keys := make([]T, 0, s.Cardinality()) - s.RLock() - for elem := range s.uss { +func (t *threadSafeSet[T]) ToSlice() []T { + keys := make([]T, 0, t.Cardinality()) + t.RLock() + for elem := range t.uss { keys = append(keys, elem) } - s.RUnlock() + t.RUnlock() return keys } -func (s *threadSafeSet[T]) MarshalJSON() ([]byte, error) { - s.RLock() - b, err := s.uss.MarshalJSON() - s.RUnlock() +func (t *threadSafeSet[T]) MarshalJSON() ([]byte, error) { + t.RLock() + b, err := t.uss.MarshalJSON() + t.RUnlock() return b, err } -func (s *threadSafeSet[T]) UnmarshalJSON(p []byte) error { - s.RLock() - err := s.uss.UnmarshalJSON(p) - s.RUnlock() +func (t *threadSafeSet[T]) UnmarshalJSON(p []byte) error { + t.RLock() + err := t.uss.UnmarshalJSON(p) + t.RUnlock() return err } diff --git a/threadunsafe.go b/threadunsafe.go index dfc5c8f..36e2fd9 100644 --- a/threadunsafe.go +++ b/threadunsafe.go @@ -35,42 +35,47 @@ import ( type threadUnsafeSet[T comparable] map[T]struct{} // Assert concrete type:threadUnsafeSet adheres to Set interface. -var _ Set[string] = (*threadUnsafeSet[string])(nil) +var _ Set[string] = (threadUnsafeSet[string])(nil) func newThreadUnsafeSet[T comparable]() threadUnsafeSet[T] { return make(threadUnsafeSet[T]) } -func (s *threadUnsafeSet[T]) Add(v T) bool { - prevLen := len(*s) - (*s)[v] = struct{}{} - return prevLen != len(*s) +func (s threadUnsafeSet[T]) Add(v T) bool { + prevLen := len(s) + s[v] = struct{}{} + return prevLen != len(s) } // private version of Add which doesn't return a value -func (s *threadUnsafeSet[T]) add(v T) { - (*s)[v] = struct{}{} +func (s threadUnsafeSet[T]) add(v T) { + s[v] = struct{}{} } -func (s *threadUnsafeSet[T]) Cardinality() int { - return len(*s) +func (s threadUnsafeSet[T]) Cardinality() int { + return len(s) } -func (s *threadUnsafeSet[T]) Clear() { - *s = newThreadUnsafeSet[T]() +func (s threadUnsafeSet[T]) Clear() { + // Constructions like this are optimised by compiler, and replaced by + // mapclear() function, defined in + // https://github.com/golang/go/blob/29bbca5c2c1ad41b2a9747890d183b6dd3a4ace4/src/runtime/map.go#L993) + for key := range s { + delete(s, key) + } } -func (s *threadUnsafeSet[T]) Clone() Set[T] { +func (s threadUnsafeSet[T]) Clone() Set[T] { clonedSet := make(threadUnsafeSet[T], s.Cardinality()) - for elem := range *s { + for elem := range s { clonedSet.add(elem) } - return &clonedSet + return clonedSet } -func (s *threadUnsafeSet[T]) Contains(v ...T) bool { +func (s threadUnsafeSet[T]) Contains(v ...T) bool { for _, val := range v { - if _, ok := (*s)[val]; !ok { + if _, ok := s[val]; !ok { return false } } @@ -78,38 +83,38 @@ func (s *threadUnsafeSet[T]) Contains(v ...T) bool { } // private version of Contains for a single element v -func (s *threadUnsafeSet[T]) contains(v T) bool { - _, ok := (*s)[v] +func (s threadUnsafeSet[T]) contains(v T) (ok bool) { + _, ok = s[v] return ok } -func (s *threadUnsafeSet[T]) Difference(other Set[T]) Set[T] { - o := other.(*threadUnsafeSet[T]) +func (s threadUnsafeSet[T]) Difference(other Set[T]) Set[T] { + o := other.(threadUnsafeSet[T]) diff := newThreadUnsafeSet[T]() - for elem := range *s { + for elem := range s { if !o.contains(elem) { diff.add(elem) } } - return &diff + return diff } -func (s *threadUnsafeSet[T]) Each(cb func(T) bool) { - for elem := range *s { +func (s threadUnsafeSet[T]) Each(cb func(T) bool) { + for elem := range s { if cb(elem) { break } } } -func (s *threadUnsafeSet[T]) Equal(other Set[T]) bool { - o := other.(*threadUnsafeSet[T]) +func (s threadUnsafeSet[T]) Equal(other Set[T]) bool { + o := other.(threadUnsafeSet[T]) if s.Cardinality() != other.Cardinality() { return false } - for elem := range *s { + for elem := range s { if !o.contains(elem) { return false } @@ -117,41 +122,41 @@ func (s *threadUnsafeSet[T]) Equal(other Set[T]) bool { return true } -func (s *threadUnsafeSet[T]) Intersect(other Set[T]) Set[T] { - o := other.(*threadUnsafeSet[T]) +func (s threadUnsafeSet[T]) Intersect(other Set[T]) Set[T] { + o := other.(threadUnsafeSet[T]) intersection := newThreadUnsafeSet[T]() // loop over smaller set if s.Cardinality() < other.Cardinality() { - for elem := range *s { + for elem := range s { if o.contains(elem) { intersection.add(elem) } } } else { - for elem := range *o { + for elem := range o { if s.contains(elem) { intersection.add(elem) } } } - return &intersection + return intersection } -func (s *threadUnsafeSet[T]) IsProperSubset(other Set[T]) bool { +func (s threadUnsafeSet[T]) IsProperSubset(other Set[T]) bool { return s.Cardinality() < other.Cardinality() && s.IsSubset(other) } -func (s *threadUnsafeSet[T]) IsProperSuperset(other Set[T]) bool { +func (s threadUnsafeSet[T]) IsProperSuperset(other Set[T]) bool { return s.Cardinality() > other.Cardinality() && s.IsSuperset(other) } -func (s *threadUnsafeSet[T]) IsSubset(other Set[T]) bool { - o := other.(*threadUnsafeSet[T]) +func (s threadUnsafeSet[T]) IsSubset(other Set[T]) bool { + o := other.(threadUnsafeSet[T]) if s.Cardinality() > other.Cardinality() { return false } - for elem := range *s { + for elem := range s { if !o.contains(elem) { return false } @@ -159,14 +164,14 @@ func (s *threadUnsafeSet[T]) IsSubset(other Set[T]) bool { return true } -func (s *threadUnsafeSet[T]) IsSuperset(other Set[T]) bool { +func (s threadUnsafeSet[T]) IsSuperset(other Set[T]) bool { return other.IsSubset(s) } -func (s *threadUnsafeSet[T]) Iter() <-chan T { +func (s threadUnsafeSet[T]) Iter() <-chan T { ch := make(chan T) go func() { - for elem := range *s { + for elem := range s { ch <- elem } close(ch) @@ -175,12 +180,12 @@ func (s *threadUnsafeSet[T]) Iter() <-chan T { return ch } -func (s *threadUnsafeSet[T]) Iterator() *Iterator[T] { +func (s threadUnsafeSet[T]) Iterator() *Iterator[T] { iterator, ch, stopCh := newIterator[T]() go func() { L: - for elem := range *s { + for elem := range s { select { case <-stopCh: break L @@ -193,56 +198,57 @@ func (s *threadUnsafeSet[T]) Iterator() *Iterator[T] { return iterator } -// TODO: how can we make this properly , return T but can't return nil. -func (s *threadUnsafeSet[T]) Pop() (v T, ok bool) { - for item := range *s { - delete(*s, item) +// Pop returns a popped item in case set is not empty, or nil-value of T +// if set is already empty +func (s threadUnsafeSet[T]) Pop() (v T, ok bool) { + for item := range s { + delete(s, item) return item, true } - return + return v, false } -func (s *threadUnsafeSet[T]) Remove(v T) { - delete(*s, v) +func (s threadUnsafeSet[T]) Remove(v T) { + delete(s, v) } -func (s *threadUnsafeSet[T]) String() string { - items := make([]string, 0, len(*s)) +func (s threadUnsafeSet[T]) String() string { + items := make([]string, 0, len(s)) - for elem := range *s { + for elem := range s { items = append(items, fmt.Sprintf("%v", elem)) } return fmt.Sprintf("Set{%s}", strings.Join(items, ", ")) } -func (s *threadUnsafeSet[T]) SymmetricDifference(other Set[T]) Set[T] { - o := other.(*threadUnsafeSet[T]) +func (s threadUnsafeSet[T]) SymmetricDifference(other Set[T]) Set[T] { + o := other.(threadUnsafeSet[T]) sd := newThreadUnsafeSet[T]() - for elem := range *s { + for elem := range s { if !o.contains(elem) { sd.add(elem) } } - for elem := range *o { + for elem := range o { if !s.contains(elem) { sd.add(elem) } } - return &sd + return sd } -func (s *threadUnsafeSet[T]) ToSlice() []T { +func (s threadUnsafeSet[T]) ToSlice() []T { keys := make([]T, 0, s.Cardinality()) - for elem := range *s { + for elem := range s { keys = append(keys, elem) } return keys } -func (s *threadUnsafeSet[T]) Union(other Set[T]) Set[T] { - o := other.(*threadUnsafeSet[T]) +func (s threadUnsafeSet[T]) Union(other Set[T]) Set[T] { + o := other.(threadUnsafeSet[T]) n := s.Cardinality() if o.Cardinality() > n { @@ -250,20 +256,20 @@ func (s *threadUnsafeSet[T]) Union(other Set[T]) Set[T] { } unionedSet := make(threadUnsafeSet[T], n) - for elem := range *s { + for elem := range s { unionedSet.add(elem) } - for elem := range *o { + for elem := range o { unionedSet.add(elem) } - return &unionedSet + return unionedSet } // MarshalJSON creates a JSON array from the set, it marshals all elements -func (s *threadUnsafeSet[T]) MarshalJSON() ([]byte, error) { +func (s threadUnsafeSet[T]) MarshalJSON() ([]byte, error) { items := make([]string, 0, s.Cardinality()) - for elem := range *s { + for elem := range s { b, err := json.Marshal(elem) if err != nil { return nil, err @@ -277,7 +283,7 @@ func (s *threadUnsafeSet[T]) MarshalJSON() ([]byte, error) { // UnmarshalJSON recreates a set from a JSON array, it only decodes // primitive types. Numbers are decoded as json.Number. -func (s *threadUnsafeSet[T]) UnmarshalJSON(b []byte) error { +func (s threadUnsafeSet[T]) UnmarshalJSON(b []byte) error { var i []any d := json.NewDecoder(bytes.NewReader(b))