I have come up with a solution based on the trie data structure, as described here. Tries make it relatively fast to determine whether any of the stored sets is a subset of a given set (Savnik, 2013).
The solution then looks as follows:
- Create a trie
- Iterate through the given sets
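- In each iteration, go through the sets in the trie and check whether each of them is disjoint from the new set.
- If a set is not disjoint, it already intersects the new set and is kept. If it is disjoint, remove it and add to the trie its extensions by each element of the new set, skipping any extension that is a superset of a set already in the trie.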
The worst-case runtime is $O(n \, m \, c)$, where $n$ is the number of input sets, $m$ is the maximal number of solutions obtained when only $n' \le n$ of the input sets are considered, and $c$ is the time factor from the subset lookups.
The code is below. I have implemented the algorithm based on the Python package datrie, which is a wrapper around an efficient C implementation of a trie. The code below is written in Cython but can easily be converted to pure Python by removing or exchanging the Cython-specific commands.
The extended trie implementation:
from datrie cimport BaseTrie, BaseState, BaseIterator
cdef bint has_subset_c(BaseTrie trie, BaseState trieState, str setarr,
                       int index, int size):
    """Return True if a key stored in `trie` is a subset of setarr[index:size].

    Starting from the trie node held by `trieState`, every character
    setarr[i] (index <= i < size) is tried as the next step of a trie path.
    A stored key is found as soon as the walk reaches a terminal node.

    A key can only be matched if its characters occur in `setarr` in the
    same relative order as in the key. Keys and queries must therefore be
    built with the same character order (e.g. both sorted).

    `trieState` itself is never advanced; it is only copied from.
    """
    cdef BaseState child = BaseState(trie)
    cdef int i
    trieState.copy_to(child)
    for i in range(index, size):
        if child.walk(setarr[i]):
            # Continue with the characters after position i. Restarting at i
            # would only re-try setarr[i], which cannot match again in a key
            # without repeated characters.
            if child.is_terminal() or has_subset_c(trie, child, setarr,
                                                   i + 1, size):
                return True
            # Go back to the starting node before trying the next character.
            trieState.copy_to(child)
    return False
cdef class SetTrie():
    """A collection of integer sets stored as keys of a datrie ``BaseTrie``.

    Element ``i`` of a set is encoded as the character ``chr(i)``. Every key
    is kept in ascending character order, so that ``has_subset_c`` (which
    only matches keys whose characters appear in the query in the same
    order) finds any stored subset of a sorted query. All values stored in
    the trie are 0.

    Elements are expected to be positive integers; the integer-to-character
    mapping reportedly does not work for 0.

    NOTE(review): the attributes ``trie``, ``touched`` and ``maxSizeBound``
    are not declared in this class body. For a ``cdef class`` they presumably
    are declared in a matching .pxd file -- confirm.
    """

    def __init__(self, alphabet, initSet=()):
        """Create the trie.

        alphabet: iterable of the integers that may occur in stored sets, or
            an int n as shorthand for range(n).
        initSet: iterable of integers; each element is stored as its own
            singleton set.
        """
        if not hasattr(alphabet, "__iter__"):
            alphabet = range(alphabet)
        self.trie = BaseTrie("".join(chr(i) for i in alphabet))
        self.touched = False
        # Upper bound on the length of any stored key; prune() relies on it.
        self.maxSizeBound = 0
        for i in initSet:
            self.trie[chr(i)] = 0
            self.touched = True
            self.maxSizeBound = 1

    def has_subset(self, superset):
        """Return True if some stored set is a subset of `superset` (ints)."""
        cdef BaseState trieState = BaseState(self.trie)
        # Keys are stored sorted, so the query must be sorted as well.
        cdef str setarr = "".join(sorted({chr(i) for i in superset}))
        return bool(has_subset_c(self.trie, trieState, setarr, 0, len(setarr)))

    def extend(self, sets):
        """Store every set in `sets` (an iterable of iterables of ints)."""
        cdef str key
        for s in sets:
            key = "".join(sorted({chr(i) for i in s}))
            self.trie[key] = 0
            self.touched = True
            if len(key) > self.maxSizeBound:
                self.maxSizeBound = len(key)

    def delete_supersets(self):
        """Remove every stored set that has another stored set as a subset.

        Each key is removed and re-inserted only if no remaining key is a
        subset of it. A key is handled only after the iterator has advanced
        past it (presumably so the iterator's current node is never the one
        being deleted).
        """
        cdef str elem
        cdef BaseState trieState = BaseState(self.trie)
        cdef BaseIterator trieIter = BaseIterator(BaseState(self.trie))
        if trieIter.next():
            elem = trieIter.key()
            while trieIter.next():
                self.trie._delitem(elem)
                if not has_subset_c(self.trie, trieState, elem, 0, len(elem)):
                    self.trie._setitem(elem, 0)
                elem = trieIter.key()
            # Last key: pop it (keeping its value) and restore it if no
            # other key is a subset of it.
            if has_subset_c(self.trie, trieState, elem, 0, len(elem)):
                val = self.trie.pop(elem)
                if not has_subset_c(self.trie, trieState, elem, 0, len(elem)):
                    self.trie._setitem(elem, val)

    def update_by_settrie(self, SetTrie setTrie, maxSize=float("inf"),
                          initialize=True):
        """Call update() with every set stored in `setTrie`, in trie order.

        If `initialize` is true and this trie has not been touched yet, the
        first set of `setTrie` seeds this trie with one singleton per
        element instead of being passed to update().
        """
        cdef BaseIterator trieIter = BaseIterator(BaseState(setTrie.trie))
        cdef str s
        if initialize and not self.touched and trieIter.next():
            for s in trieIter.key():
                self.trie._setitem(s, 0)
            self.touched = True
            if self.maxSizeBound < 1:
                self.maxSizeBound = 1
        while trieIter.next():
            # Keys of another SetTrie are already sets of 1-char strings.
            self.update(set(trieIter.key()), maxSize, True)

    def update(self, otherSet, maxSize=float("inf"), isStrSet=False):
        """Make every stored set intersect `otherSet`.

        Stored sets that already intersect `otherSet` are kept. Each stored
        set that is disjoint from it is removed. If its size is below
        `maxSize`, it is replaced by its extensions with each single element
        of `otherSet`; an extension is skipped if a set still in the trie is
        a subset of it.

        otherSet: iterable of ints, or (when isStrSet is true) a set of
            1-character strings.
        """
        cdef str subset, newSubset, elem
        cdef list disjointList = []
        cdef BaseTrie trie = self.trie
        cdef int l
        cdef BaseIterator trieIter = BaseIterator(BaseState(self.trie))
        cdef BaseState trieState
        if not isStrSet:
            otherSet = set(chr(i) for i in otherSet)
        # Collect and delete the disjoint sets. A key is deleted only after
        # the iterator has moved past it.
        if trieIter.next():
            subset = trieIter.key()
            while trieIter.next():
                if otherSet.isdisjoint(subset):
                    disjointList.append(subset)
                    trie._delitem(subset)
                subset = trieIter.key()
            if otherSet.isdisjoint(subset):
                disjointList.append(subset)
                trie._delitem(subset)
        trieState = BaseState(self.trie)
        for subset in disjointList:
            l = len(subset)
            if l < maxSize:
                if l + 1 > self.maxSizeBound:
                    self.maxSizeBound = l + 1
                for elem in otherSet:
                    # Insert elem at its sorted position: stored keys must
                    # stay sorted, otherwise has_subset_c misses subsets
                    # whose characters appear in a different order.
                    newSubset = "".join(sorted(subset + elem))
                    trieState.rewind()
                    if not has_subset_c(self.trie, trieState, newSubset, 0,
                                        len(newSubset)):
                        trie[newSubset] = 0

    def get_frozensets(self):
        """Return a generator of frozensets of ints, one per stored set."""
        return (frozenset(ord(t) for t in subset) for subset in self.trie)

    def clear(self):
        """Remove all stored sets and mark the trie as untouched."""
        self.touched = False
        self.maxSizeBound = 0
        self.trie.clear()

    def prune(self, maxSize):
        """Delete stored sets with more than `maxSize` elements.

        The scan only runs if the recorded upper bound on key length exceeds
        `maxSize`. Returns True if at least one set was deleted.
        """
        cdef bint changed = False
        cdef BaseIterator trieIter
        cdef str k
        if self.maxSizeBound > maxSize:
            self.maxSizeBound = maxSize
            trieIter = BaseIterator(BaseState(self.trie))
            # Delete one step behind the iterator, as in update().
            k = ''
            while trieIter.next():
                if len(k) > maxSize:
                    self.trie._delitem(k)
                    changed = True
                k = trieIter.key()
            if len(k) > maxSize:
                self.trie._delitem(k)
                changed = True
        return changed

    def __nonzero__(self):
        # True once any set has been inserted since creation / last clear().
        return self.touched

    def __repr__(self):
        return str([set(ord(t) for t in subset) for subset in self.trie])
This can be used as follows:
def cover_sets(sets, alphabet=None):
    """Return the sets kept by SetTrie that intersect every set in `sets`.

    sets: non-empty sequence of sets of positive integers.
    alphabet: integers that may occur in the sets. Defaults to the union of
        all elements of `sets`; the previous code hard-coded range(10).

    Returns a generator of frozensets.
    """
    if alphabet is None:
        alphabet = sorted(set().union(*sets))
    # Every element of the first set starts out as its own singleton.
    # (The previous `*([i] for i in sets[0])` passed one-element lists as
    # extra positional arguments, which SetTrie.__init__ cannot accept.)
    strie = SetTrie(alphabet, sets[0])
    for s in sets[1:]:
        strie.update(s)
    return strie.get_frozensets()
Timing:
from timeit import timeit
s1 = {1, 2, 3}
s2 = {3, 4, 5}
s3 = {5, 6}
# Plain-Python equivalent of IPython's `%timeit cover_sets([s1, s2, s3])`,
# using the imported timeit function. Note that cover_sets returns a
# generator, so building the frozensets themselves is not timed.
number = 10000
seconds = timeit(lambda: cover_sets([s1, s2, s3]), number=number)
print("{:.1f} µs per loop ({} loops)".format(seconds / number * 1e6, number))
Result:
37.8 µs ± 2.97 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
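Note that the trie implementation above works only with set elements greater than 0 (not equal to 0). Otherwise, the integer-to-character mapping (element i is stored as chr(i)) does not work properly. This problem can be solved with an index shift, e.g. by storing element i as chr(i + 1) and mapping back with ord(c) - 1.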