【Golang】Set实现

Posted by 西维蜀黍 on 2020-03-22, Last Modified on 2021-09-21

Interface

// Set is an unordered collection of distinct elements. Only part of the
// interface is shown in this excerpt.
type Set interface {
	// Add adds an element to the set. It reports whether the item was
	// added (false if it was already present).
	Add(i interface{}) bool

	// Cardinality returns the number of elements in the set.
	Cardinality() int

	// Clear removes all elements from the set, leaving it empty.
	Clear()

	// Clone returns a clone of the set using the same implementation,
	// duplicating all keys.
	Clone() Set

	// Contains reports whether all of the given items are in the set.
	Contains(i ...interface{}) bool

	// Difference returns the difference between this set and other: a set
	// containing every element of this set that is not also an element of
	// other.
	//
	// The argument must have the same concrete type as the receiver;
	// otherwise Difference panics.
	Difference(other Set) Set

	// Equal reports whether the two sets are equal, i.e. they have the same
	// cardinality and contain the same elements. The order in which the
	// elements were added is irrelevant.
	//
	// The argument must have the same concrete type as the receiver;
	// otherwise Equal panics.
	Equal(other Set) bool

	// Intersect returns a new set containing only the elements present in
	// both sets.
	//
	// The argument must have the same concrete type as the receiver;
	// otherwise Intersect panics.
	Intersect(other Set) Set

	// The remaining methods are omitted from this excerpt.
	...
}

创建Set

// NewSet returns a pointer to a new thread-safe set seeded with the values
// in s. Each value is inserted through Add, so repeated values are stored
// only once.
func NewSet(s ...interface{}) Set {
	result := newThreadSafeSet()
	for idx := 0; idx < len(s); idx++ {
		result.Add(s[idx])
	}
	return &result
}

Set的内部实现

// threadUnsafeSet is the map-backed set implementation: elements are the map
// keys and the zero-size struct{} values carry no data. It does no locking of
// its own, so concurrent use must be synchronized by the caller
// (threadSafeSet wraps it with a sync.RWMutex for that purpose).
type threadUnsafeSet map[interface{}]struct{}

// threadSafeSet guards a threadUnsafeSet with an embedded sync.RWMutex.
// Mutating methods (e.g. Add) take the write lock; read-only methods
// (e.g. Union) take the read lock. Because it contains a mutex, a
// threadSafeSet must not be copied after first use; share it by pointer.
type threadSafeSet struct {
	s threadUnsafeSet // underlying elements; access only while holding the lock
	sync.RWMutex
}

// newThreadSafeSet returns a threadSafeSet backed by newThreadUnsafeSet()
// (presumably a fresh, empty map — confirm against its definition).
// The struct is returned by value. Callers should take its address before
// sharing it, as NewSet does, and must not copy it afterwards, since it
// embeds a sync.RWMutex.
func newThreadSafeSet() threadSafeSet {
	return threadSafeSet{s: newThreadUnsafeSet()}
}

添加

// Add inserts i into the set while holding the write lock. It reports
// whether i was newly added (false if it was already present).
func (set *threadSafeSet) Add(i interface{}) bool {
	set.Lock()
	// Deferred so the lock is released even if the insert panics. That
	// happens, for example, when i's dynamic type is not comparable (a
	// slice, map or func); without defer the set would stay locked forever.
	defer set.Unlock()
	return set.s.Add(i)
}

// Add inserts i and reports whether it was newly added. It returns false,
// leaving the set unchanged, when i is already a member.
func (set *threadUnsafeSet) Add(i interface{}) bool {
	m := *set
	if _, exists := m[i]; exists {
		return false
	}
	m[i] = struct{}{}
	return true
}

包含

// Contains reports whether every one of the given items is a member of the
// set. With no arguments it reports true.
func (set *threadUnsafeSet) Contains(i ...interface{}) bool {
	m := *set
	for idx := range i {
		if _, present := m[i[idx]]; !present {
			return false
		}
	}
	return true
}

长度、清除和删除

// Cardinality reports how many distinct elements the set holds.
func (set *threadUnsafeSet) Cardinality() int {
	m := *set
	return len(m)
}

// Clear empties the set by replacing the underlying map with the value
// returned by newThreadUnsafeSet (presumably a new empty map — confirm).
// The old map is not modified, so any copy of it that is still held
// elsewhere keeps the previous contents.
func (set *threadUnsafeSet) Clear() {
	*set = newThreadUnsafeSet()
}

// Remove deletes i from the set. Removing an element that is not present
// is a no-op.
func (set *threadUnsafeSet) Remove(i interface{}) {
	m := *set
	delete(m, i)
}

相等

// Equal reports whether set and other hold exactly the same elements.
// other must be a *threadUnsafeSet; any other implementation makes the type
// assertion panic.
func (set *threadUnsafeSet) Equal(other Set) bool {
	rhs := other.(*threadUnsafeSet)

	if len(*set) != len(*rhs) {
		return false
	}
	for elem := range *set {
		if _, ok := (*rhs)[elem]; !ok {
			return false
		}
	}
	return true
}

子集

// IsSubset reports whether every element of set is also in other. The empty
// set is a subset of any set. other must be a *threadUnsafeSet; any other
// implementation makes the type assertion panic.
func (set *threadUnsafeSet) IsSubset(other Set) bool {
	superset := other.(*threadUnsafeSet)
	for elem := range *set {
		if _, found := (*superset)[elem]; !found {
			return false
		}
	}
	return true
}

并集

// Union returns a new thread-safe set holding every element that belongs to
// set, to other, or to both. Neither input is modified.
//
// other must be a *threadSafeSet; any other Set implementation makes the
// type assertion panic.
func (set *threadSafeSet) Union(other Set) Set {
	o := other.(*threadSafeSet)

	set.RLock()
	defer set.RUnlock()
	// sync.RWMutex prohibits recursive read locking: a second RLock on the
	// same mutex can deadlock if a writer queues up in between. For
	// s.Union(s), the lock is therefore taken only once.
	if o != set {
		// NOTE(review): a.Union(b) and b.Union(a) running concurrently take
		// these read locks in opposite orders; with writers queued on both
		// sets this can still deadlock. Confirm whether a global lock order
		// is needed.
		o.RLock()
		defer o.RUnlock()
	}

	unsafeUnion := set.s.Union(&o.s).(*threadUnsafeSet)
	return &threadSafeSet{s: *unsafeUnion}
}

// Union returns a new set containing every element of set and of other.
// Neither input is modified. other must be a *threadUnsafeSet; any other
// implementation makes the type assertion panic.
func (set *threadUnsafeSet) Union(other Set) Set {
	rhs := other.(*threadUnsafeSet)

	result := newThreadUnsafeSet()
	for _, src := range []*threadUnsafeSet{set, rhs} {
		for elem := range *src {
			result.Add(elem)
		}
	}
	return &result
}

Reference