rustc_index/
bit_set.rs

1use std::marker::PhantomData;
2#[cfg(not(feature = "nightly"))]
3use std::mem;
4use std::ops::{BitAnd, BitAndAssign, BitOrAssign, Bound, Not, Range, RangeBounds, Shl};
5use std::rc::Rc;
6use std::{fmt, iter, slice};
7
8use Chunk::*;
9#[cfg(feature = "nightly")]
10use rustc_macros::{Decodable_NoContext, Encodable_NoContext};
11use smallvec::{SmallVec, smallvec};
12
13use crate::{Idx, IndexVec};
14
15#[cfg(test)]
16mod tests;
17
/// The unit in which `DenseBitSet` (and `Mixed` chunks) store their bits.
type Word = u64;
/// Size of a `Word`, in bytes.
const WORD_BYTES: usize = size_of::<Word>();
/// Size of a `Word`, in bits.
const WORD_BITS: usize = WORD_BYTES * 8;
21
// The choice of chunk size has some trade-offs.
//
// A big chunk size tends to favour cases where many large `ChunkedBitSet`s are
// present, because they require fewer `Chunk`s, reducing the number of
// allocations and reducing peak memory usage. Also, fewer chunk operations are
// required, though more of them might be `Mixed`.
//
// A small chunk size tends to favour cases where many small `ChunkedBitSet`s
// are present, because less space is wasted at the end of the final chunk (if
// it's not full).
/// Number of `Word`s backing a single chunk.
const CHUNK_WORDS: usize = 32;
/// Number of bits (i.e. elements) covered by a single full chunk.
const CHUNK_BITS: usize = CHUNK_WORDS * WORD_BITS; // 2048 bits

/// ChunkSize is small to keep `Chunk` small. The static assertion ensures it's
/// not too small.
type ChunkSize = u16;
// Compile-time check that any per-chunk size or count up to `CHUNK_BITS` fits
// in a `ChunkSize`.
const _: () = assert!(CHUNK_BITS <= ChunkSize::MAX as usize);
39
/// In-place set operations between a bitset and another set-like value of
/// type `Rhs`.
///
/// Each method returns a `bool`; the inherent wrappers generated by
/// `bit_relations_inherent_impls!` document it as "`true` if `self` changed".
pub trait BitRelations<Rhs> {
    /// Sets `self = self | other`.
    fn union(&mut self, other: &Rhs) -> bool;
    /// Sets `self = self - other`.
    fn subtract(&mut self, other: &Rhs) -> bool;
    /// Sets `self = self & other`.
    fn intersect(&mut self, other: &Rhs) -> bool;
}
45
46#[inline]
47fn inclusive_start_end<T: Idx>(
48    range: impl RangeBounds<T>,
49    domain: usize,
50) -> Option<(usize, usize)> {
51    // Both start and end are inclusive.
52    let start = match range.start_bound().cloned() {
53        Bound::Included(start) => start.index(),
54        Bound::Excluded(start) => start.index() + 1,
55        Bound::Unbounded => 0,
56    };
57    let end = match range.end_bound().cloned() {
58        Bound::Included(end) => end.index(),
59        Bound::Excluded(end) => end.index().checked_sub(1)?,
60        Bound::Unbounded => domain - 1,
61    };
62    assert!(end < domain);
63    if start > end {
64        return None;
65    }
66    Some((start, end))
67}
68
// Expands to inherent `union`/`subtract`/`intersect` methods that forward to
// the `BitRelations<Rhs>` impl of `Self`. Invoke it inside an `impl` block.
macro_rules! bit_relations_inherent_impls {
    () => {
        /// Sets `self = self | other` and returns `true` if `self` changed
        /// (i.e., if new bits were added).
        pub fn union<Rhs>(&mut self, other: &Rhs) -> bool
        where
            Self: BitRelations<Rhs>,
        {
            <Self as BitRelations<Rhs>>::union(self, other)
        }

        /// Sets `self = self - other` and returns `true` if `self` changed
        /// (i.e., if any bits were removed).
        pub fn subtract<Rhs>(&mut self, other: &Rhs) -> bool
        where
            Self: BitRelations<Rhs>,
        {
            <Self as BitRelations<Rhs>>::subtract(self, other)
        }

        /// Sets `self = self & other` and returns `true` if `self` changed
        /// (i.e., if any bits were removed).
        pub fn intersect<Rhs>(&mut self, other: &Rhs) -> bool
        where
            Self: BitRelations<Rhs>,
        {
            <Self as BitRelations<Rhs>>::intersect(self, other)
        }
    };
}
99
/// A fixed-size bitset type with a dense representation.
///
/// Note 1: Since this bitset is dense, if your domain is big, and/or relatively
/// homogeneous (for example, with long runs of bits set or unset), then it may
/// be preferable to instead use a [MixedBitSet], or an
/// [IntervalSet](crate::interval::IntervalSet). They should be more suited to
/// sparse, or highly-compressible, domains.
///
/// Note 2: Use [`GrowableBitSet`] if you need support for resizing after creation.
///
/// `T` is an index type, typically a newtyped `usize` wrapper, but it can also
/// just be `usize`.
///
/// All operations that involve an element will panic if the element is equal
/// to or greater than the domain size. All operations that involve two bitsets
/// will panic if the bitsets have differing domain sizes.
#[cfg_attr(feature = "nightly", derive(Decodable_NoContext, Encodable_NoContext))]
#[derive(Eq, PartialEq, Hash)]
pub struct DenseBitSet<T> {
    /// Number of elements in the domain; valid elements have an index below
    /// this value.
    domain_size: usize,
    /// The bits, packed into `Word`s. Bits of the final word that lie outside
    /// the domain are kept clear (see `clear_excess_bits`).
    words: SmallVec<[Word; 2]>,
    /// Ties the set to its index type `T` without storing a `T`.
    marker: PhantomData<T>,
}
124
// Methods that need no `Idx` bound on `T`.
impl<T> DenseBitSet<T> {
    /// Gets the domain size, i.e. the exclusive upper bound on the index of
    /// any element this set may hold.
    pub fn domain_size(&self) -> usize {
        self.domain_size
    }
}
131
132impl<T: Idx> DenseBitSet<T> {
133    /// Creates a new, empty bitset with a given `domain_size`.
134    #[inline]
135    pub fn new_empty(domain_size: usize) -> DenseBitSet<T> {
136        let num_words = num_words(domain_size);
137        DenseBitSet { domain_size, words: smallvec![0; num_words], marker: PhantomData }
138    }
139
140    /// Creates a new, filled bitset with a given `domain_size`.
141    #[inline]
142    pub fn new_filled(domain_size: usize) -> DenseBitSet<T> {
143        let num_words = num_words(domain_size);
144        let mut result =
145            DenseBitSet { domain_size, words: smallvec![!0; num_words], marker: PhantomData };
146        result.clear_excess_bits();
147        result
148    }
149
150    /// Clear all elements.
151    #[inline]
152    pub fn clear(&mut self) {
153        self.words.fill(0);
154    }
155
    /// Clear excess bits in the final word.
    ///
    /// Delegates to `clear_excess_bits_in_final_word`, passing this set's
    /// domain size and its words.
    fn clear_excess_bits(&mut self) {
        clear_excess_bits_in_final_word(self.domain_size, &mut self.words);
    }
160
161    /// Count the number of set bits in the set.
162    pub fn count(&self) -> usize {
163        self.words.iter().map(|e| e.count_ones() as usize).sum()
164    }
165
166    /// Returns `true` if `self` contains `elem`.
167    #[inline]
168    pub fn contains(&self, elem: T) -> bool {
169        assert!(elem.index() < self.domain_size);
170        let (word_index, mask) = word_index_and_mask(elem);
171        (self.words[word_index] & mask) != 0
172    }
173
174    /// Is `self` is a (non-strict) superset of `other`?
175    #[inline]
176    pub fn superset(&self, other: &DenseBitSet<T>) -> bool {
177        assert_eq!(self.domain_size, other.domain_size);
178        self.words.iter().zip(&other.words).all(|(a, b)| (a & b) == *b)
179    }
180
181    /// Is the set empty?
182    #[inline]
183    pub fn is_empty(&self) -> bool {
184        self.words.iter().all(|a| *a == 0)
185    }
186
187    /// Insert `elem`. Returns whether the set has changed.
188    #[inline]
189    pub fn insert(&mut self, elem: T) -> bool {
190        assert!(
191            elem.index() < self.domain_size,
192            "inserting element at index {} but domain size is {}",
193            elem.index(),
194            self.domain_size,
195        );
196        let (word_index, mask) = word_index_and_mask(elem);
197        let word_ref = &mut self.words[word_index];
198        let word = *word_ref;
199        let new_word = word | mask;
200        *word_ref = new_word;
201        new_word != word
202    }
203
    /// Inserts every element of the range `elems` into the set.
    ///
    /// The range is resolved by `inclusive_start_end`; if that yields no
    /// `(start, end)` pair (an empty range) this is a no-op.
    #[inline]
    pub fn insert_range(&mut self, elems: impl RangeBounds<T>) {
        let Some((start, end)) = inclusive_start_end(elems, self.domain_size) else {
            return;
        };

        // NOTE(review): the mask arithmetic below relies on each mask having
        // exactly one bit set; `word_index_and_mask` is defined elsewhere.
        let (start_word_index, start_mask) = word_index_and_mask(start);
        let (end_word_index, end_mask) = word_index_and_mask(end);

        // Set all words in between start and end (exclusively of both).
        for word_index in (start_word_index + 1)..end_word_index {
            self.words[word_index] = !0;
        }

        if start_word_index != end_word_index {
            // Start and end are in different words, so we handle each in turn.
            //
            // We set all leading bits. This includes the start_mask bit.
            self.words[start_word_index] |= !(start_mask - 1);
            // And all trailing bits (i.e. from 0..=end) in the end word,
            // including the end.
            self.words[end_word_index] |= end_mask | (end_mask - 1);
        } else {
            // Start and end share a word: `end_mask - start_mask` covers the
            // bits from start's bit up to (but excluding) end's bit, and
            // `end_mask` adds end's bit itself.
            self.words[start_word_index] |= end_mask | (end_mask - start_mask);
        }
    }
230
231    /// Sets all bits to true.
232    pub fn insert_all(&mut self) {
233        self.words.fill(!0);
234        self.clear_excess_bits();
235    }
236
237    /// Returns `true` if the set has changed.
238    #[inline]
239    pub fn remove(&mut self, elem: T) -> bool {
240        assert!(elem.index() < self.domain_size);
241        let (word_index, mask) = word_index_and_mask(elem);
242        let word_ref = &mut self.words[word_index];
243        let word = *word_ref;
244        let new_word = word & !mask;
245        *word_ref = new_word;
246        new_word != word
247    }
248
    /// Iterates over the indices of set bits in a sorted order.
    ///
    /// The returned `BitIter` borrows this set's words.
    #[inline]
    pub fn iter(&self) -> BitIter<'_, T> {
        BitIter::new(&self.words)
    }
254
    /// Searches `range` from its end towards its start for a set bit and
    /// returns the element found, or `None` if the range is empty or no bit
    /// in it is set.
    ///
    /// NOTE(review): this presumes `max_bit` (defined elsewhere) returns the
    /// position of the highest set bit of a word, which would make the result
    /// the largest element in `range` — confirm against its definition.
    pub fn last_set_in(&self, range: impl RangeBounds<T>) -> Option<T> {
        let (start, end) = inclusive_start_end(range, self.domain_size)?;
        let (start_word_index, _) = word_index_and_mask(start);
        let (end_word_index, end_mask) = word_index_and_mask(end);

        // Look at the end word first, keeping only the bits at or below `end`.
        let end_word = self.words[end_word_index] & (end_mask | (end_mask - 1));
        if end_word != 0 {
            let pos = max_bit(end_word) + WORD_BITS * end_word_index;
            // The bit found may sit before `start` when start and end share
            // a word.
            if start <= pos {
                return Some(T::new(pos));
            }
        }

        // We exclude end_word_index from the range here, because we don't want
        // to limit ourselves to *just* the last word: the bits set in it may be
        // after `end`, so it may not work out.
        if let Some(offset) =
            self.words[start_word_index..end_word_index].iter().rposition(|&w| w != 0)
        {
            // `offset` is relative to the slice, so rebase it onto `self.words`.
            let word_idx = start_word_index + offset;
            let start_word = self.words[word_idx];
            let pos = max_bit(start_word) + WORD_BITS * word_idx;
            // In the start word the bit found may still precede `start`.
            if start <= pos {
                return Some(T::new(pos));
            }
        }

        None
    }
284
285    bit_relations_inherent_impls! {}
286
    /// Sets `self = self | !other`.
    ///
    /// Unlike the `BitRelations` operations, this does not report whether
    /// `self` changed.
    ///
    /// FIXME: Incorporate this into [`BitRelations`] and fill out
    /// implementations for other bitset types, if needed.
    ///
    /// # Panics
    ///
    /// Panics if the two sets have different domain sizes.
    pub fn union_not(&mut self, other: &DenseBitSet<T>) {
        assert_eq!(self.domain_size, other.domain_size);

        // FIXME(Zalathar): If we were to forcibly _set_ all excess bits before
        // the bitwise update, and then clear them again afterwards, we could
        // quickly and accurately detect whether the update changed anything.
        // But that's only worth doing if there's an actual use-case.

        bitwise(&mut self.words, &other.words, |a, b| a | !b);
        // The bitwise update `a | !b` can result in the last word containing
        // out-of-domain bits, so we need to clear them.
        self.clear_excess_bits();
    }
304}
305
306// dense REL dense
307impl<T: Idx> BitRelations<DenseBitSet<T>> for DenseBitSet<T> {
308    fn union(&mut self, other: &DenseBitSet<T>) -> bool {
309        assert_eq!(self.domain_size, other.domain_size);
310        bitwise(&mut self.words, &other.words, |a, b| a | b)
311    }
312
313    fn subtract(&mut self, other: &DenseBitSet<T>) -> bool {
314        assert_eq!(self.domain_size, other.domain_size);
315        bitwise(&mut self.words, &other.words, |a, b| a & !b)
316    }
317
318    fn intersect(&mut self, other: &DenseBitSet<T>) -> bool {
319        assert_eq!(self.domain_size, other.domain_size);
320        bitwise(&mut self.words, &other.words, |a, b| a & b)
321    }
322}
323
/// Converts a `GrowableBitSet` into a `DenseBitSet` by moving out its
/// `bit_set` field.
impl<T: Idx> From<GrowableBitSet<T>> for DenseBitSet<T> {
    fn from(bit_set: GrowableBitSet<T>) -> Self {
        bit_set.bit_set
    }
}
329
330impl<T> Clone for DenseBitSet<T> {
331    fn clone(&self) -> Self {
332        DenseBitSet {
333            domain_size: self.domain_size,
334            words: self.words.clone(),
335            marker: PhantomData,
336        }
337    }
338
339    fn clone_from(&mut self, from: &Self) {
340        self.domain_size = from.domain_size;
341        self.words.clone_from(&from.words);
342    }
343}
344
345impl<T: Idx> fmt::Debug for DenseBitSet<T> {
346    fn fmt(&self, w: &mut fmt::Formatter<'_>) -> fmt::Result {
347        w.debug_list().entries(self.iter()).finish()
348    }
349}
350
351impl<T: Idx> ToString for DenseBitSet<T> {
352    fn to_string(&self) -> String {
353        let mut result = String::new();
354        let mut sep = '[';
355
356        // Note: this is a little endian printout of bytes.
357
358        // i tracks how many bits we have printed so far.
359        let mut i = 0;
360        for word in &self.words {
361            let mut word = *word;
362            for _ in 0..WORD_BYTES {
363                // for each byte in `word`:
364                let remain = self.domain_size - i;
365                // If less than a byte remains, then mask just that many bits.
366                let mask = if remain <= 8 { (1 << remain) - 1 } else { 0xFF };
367                assert!(mask <= 0xFF);
368                let byte = word & mask;
369
370                result.push_str(&format!("{sep}{byte:02x}"));
371
372                if remain <= 8 {
373                    break;
374                }
375                word >>= 8;
376                i += 8;
377                sep = '-';
378            }
379            sep = '|';
380        }
381        result.push(']');
382
383        result
384    }
385}
386
/// An iterator over the set bits of a slice of `Word`s, yielding each one as
/// an index of type `T`.
pub struct BitIter<'a, T: Idx> {
    /// A copy of the current word, but with any already-visited bits cleared.
    /// (This lets us use `trailing_zeros()` to find the next set bit.) When it
    /// is reduced to 0, we move onto the next word.
    word: Word,

    /// The offset (measured in bits) of the current word.
    offset: usize,

    /// Underlying iterator over the words.
    iter: slice::Iter<'a, Word>,

    /// Records the index type `T` that is yielded; no `T` is stored.
    marker: PhantomData<T>,
}
401
402impl<'a, T: Idx> BitIter<'a, T> {
403    #[inline]
404    fn new(words: &'a [Word]) -> BitIter<'a, T> {
405        // We initialize `word` and `offset` to degenerate values. On the first
406        // call to `next()` we will fall through to getting the first word from
407        // `iter`, which sets `word` to the first word (if there is one) and
408        // `offset` to 0. Doing it this way saves us from having to maintain
409        // additional state about whether we have started.
410        BitIter {
411            word: 0,
412            offset: usize::MAX - (WORD_BITS - 1),
413            iter: words.iter(),
414            marker: PhantomData,
415        }
416    }
417}
418
419impl<'a, T: Idx> Iterator for BitIter<'a, T> {
420    type Item = T;
421    fn next(&mut self) -> Option<T> {
422        loop {
423            if self.word != 0 {
424                // Get the position of the next set bit in the current word,
425                // then clear the bit.
426                let bit_pos = self.word.trailing_zeros() as usize;
427                self.word ^= 1 << bit_pos;
428                return Some(T::new(bit_pos + self.offset));
429            }
430
431            // Move onto the next word. `wrapping_add()` is needed to handle
432            // the degenerate initial value given to `offset` in `new()`.
433            self.word = *self.iter.next()?;
434            self.offset = self.offset.wrapping_add(WORD_BITS);
435        }
436    }
437}
438
/// A fixed-size bitset type with a partially dense, partially sparse
/// representation. The bitset is broken into chunks, and chunks that are all
/// zeros or all ones are represented and handled very efficiently.
///
/// This type is especially efficient for sets that typically have a large
/// `domain_size` with significant stretches of all zeros or all ones, and also
/// some stretches with lots of 0s and 1s mixed in a way that causes trouble
/// for `IntervalSet`.
///
/// Best used via `MixedBitSet`, rather than directly, because `MixedBitSet`
/// has better performance for small bitsets.
///
/// `T` is an index type, typically a newtyped `usize` wrapper, but it can also
/// just be `usize`.
///
/// All operations that involve an element will panic if the element is equal
/// to or greater than the domain size. All operations that involve two bitsets
/// will panic if the bitsets have differing domain sizes.
#[derive(PartialEq, Eq)]
pub struct ChunkedBitSet<T> {
    /// Number of elements in the domain; valid elements have an index below
    /// this value.
    domain_size: usize,

    /// The chunks. Each one contains exactly CHUNK_BITS values, except the
    /// last one which contains 1..=CHUNK_BITS values.
    chunks: Box<[Chunk]>,

    /// Ties the set to its index type `T` without storing a `T`.
    marker: PhantomData<T>,
}
467
// Note: the chunk domain size is duplicated in each variant. This is a bit
// inconvenient, but it allows the type size to be smaller than if we had an
// outer struct containing a chunk domain size plus the `Chunk`, because the
// compiler can place the chunk domain size after the tag.
/// One chunk of a `ChunkedBitSet`. The first field of every variant is the
/// chunk's domain size (the number of elements the chunk covers).
#[derive(Clone, Debug, PartialEq, Eq)]
enum Chunk {
    /// A chunk that is all zeros; we don't represent the zeros explicitly.
    /// The `ChunkSize` is always non-zero.
    Zeros(ChunkSize),

    /// A chunk that is all ones; we don't represent the ones explicitly.
    /// `ChunkSize` is always non-zero.
    Ones(ChunkSize),

    /// A chunk that has a mix of zeros and ones, which are represented
    /// explicitly and densely. It never has all zeros or all ones.
    ///
    /// If this is the final chunk there may be excess, unused words. This
    /// turns out to be both simpler and have better performance than
    /// allocating the minimum number of words, largely because we avoid having
    /// to store the length, which would make this type larger. These excess
    /// words are always zero, as are any excess bits in the final in-use word.
    ///
    /// The first `ChunkSize` field is always non-zero.
    ///
    /// The second `ChunkSize` field is the count of 1s set in the chunk, and
    /// must satisfy `0 < count < chunk_domain_size`.
    ///
    /// The words are within an `Rc` because it's surprisingly common to
    /// duplicate an entire chunk, e.g. in `ChunkedBitSet::clone_from()`, or
    /// when a `Mixed` chunk is union'd into a `Zeros` chunk. When we do need
    /// to modify a chunk we use `Rc::make_mut`.
    Mixed(ChunkSize, ChunkSize, Rc<[Word; CHUNK_WORDS]>),
}

// This type is used a lot. Make sure it doesn't unintentionally get bigger.
// The check only applies on targets with 64-bit pointers.
#[cfg(target_pointer_width = "64")]
crate::static_assert_size!(Chunk, 16);
506
// Methods that need no `Idx` bound on `T`.
impl<T> ChunkedBitSet<T> {
    /// Gets the domain size, i.e. the exclusive upper bound on the index of
    /// any element this set may hold.
    pub fn domain_size(&self) -> usize {
        self.domain_size
    }

    /// Test-only consistency check: verifies that the number of chunks
    /// matches the domain size, then asks every chunk to validate itself.
    #[cfg(test)]
    fn assert_valid(&self) {
        // An empty domain must be represented by no chunks at all.
        if self.domain_size == 0 {
            assert!(self.chunks.is_empty());
            return;
        }

        // All chunks before the last are full, and the chunks together must
        // be able to hold the whole domain.
        assert!((self.chunks.len() - 1) * CHUNK_BITS <= self.domain_size);
        assert!(self.chunks.len() * CHUNK_BITS >= self.domain_size);
        for chunk in self.chunks.iter() {
            chunk.assert_valid();
        }
    }
}
526
527impl<T: Idx> ChunkedBitSet<T> {
528    /// Creates a new bitset with a given `domain_size` and chunk kind.
529    fn new(domain_size: usize, is_empty: bool) -> Self {
530        let chunks = if domain_size == 0 {
531            Box::new([])
532        } else {
533            // All the chunks have a chunk_domain_size of `CHUNK_BITS` except
534            // the final one.
535            let final_chunk_domain_size = {
536                let n = domain_size % CHUNK_BITS;
537                if n == 0 { CHUNK_BITS } else { n }
538            };
539            let mut chunks =
540                vec![Chunk::new(CHUNK_BITS, is_empty); num_chunks(domain_size)].into_boxed_slice();
541            *chunks.last_mut().unwrap() = Chunk::new(final_chunk_domain_size, is_empty);
542            chunks
543        };
544        ChunkedBitSet { domain_size, chunks, marker: PhantomData }
545    }
546
547    /// Creates a new, empty bitset with a given `domain_size`.
548    #[inline]
549    pub fn new_empty(domain_size: usize) -> Self {
550        ChunkedBitSet::new(domain_size, /* is_empty */ true)
551    }
552
553    /// Creates a new, filled bitset with a given `domain_size`.
554    #[inline]
555    pub fn new_filled(domain_size: usize) -> Self {
556        ChunkedBitSet::new(domain_size, /* is_empty */ false)
557    }
558
559    pub fn clear(&mut self) {
560        let domain_size = self.domain_size();
561        *self = ChunkedBitSet::new_empty(domain_size);
562    }
563
    /// Test-only access to the underlying chunks.
    #[cfg(test)]
    fn chunks(&self) -> &[Chunk] {
        &self.chunks
    }
568
569    /// Count the number of bits in the set.
570    pub fn count(&self) -> usize {
571        self.chunks.iter().map(|chunk| chunk.count()).sum()
572    }
573
574    pub fn is_empty(&self) -> bool {
575        self.chunks.iter().all(|chunk| matches!(chunk, Zeros(..)))
576    }
577
578    /// Returns `true` if `self` contains `elem`.
579    #[inline]
580    pub fn contains(&self, elem: T) -> bool {
581        assert!(elem.index() < self.domain_size);
582        let chunk = &self.chunks[chunk_index(elem)];
583        match &chunk {
584            Zeros(_) => false,
585            Ones(_) => true,
586            Mixed(_, _, words) => {
587                let (word_index, mask) = chunk_word_index_and_mask(elem);
588                (words[word_index] & mask) != 0
589            }
590        }
591    }
592
    /// Returns a `ChunkedBitIter` over this set.
    #[inline]
    pub fn iter(&self) -> ChunkedBitIter<'_, T> {
        ChunkedBitIter::new(self)
    }
597
    /// Insert `elem`. Returns whether the set has changed.
    ///
    /// The chunk holding `elem` changes representation as needed: a `Zeros`
    /// chunk becomes `Mixed` (or `Ones` if it covers a single element), and a
    /// `Mixed` chunk becomes `Ones` once its last missing bit is set.
    ///
    /// # Panics
    ///
    /// Panics if `elem`'s index is not below the domain size.
    pub fn insert(&mut self, elem: T) -> bool {
        assert!(elem.index() < self.domain_size);
        let chunk_index = chunk_index(elem);
        let chunk = &mut self.chunks[chunk_index];
        match *chunk {
            Zeros(chunk_domain_size) => {
                if chunk_domain_size > 1 {
                    // Allocate an all-zero word array, then set the one bit.
                    #[cfg(feature = "nightly")]
                    let mut words = {
                        // We take some effort to avoid copying the words.
                        let words = Rc::<[Word; CHUNK_WORDS]>::new_zeroed();
                        // SAFETY: `words` can safely be all zeroes.
                        unsafe { words.assume_init() }
                    };
                    #[cfg(not(feature = "nightly"))]
                    let mut words = {
                        // FIXME: unconditionally use `Rc::new_zeroed` once it is stable (#63291).
                        let words = mem::MaybeUninit::<[Word; CHUNK_WORDS]>::zeroed();
                        // SAFETY: `words` can safely be all zeroes.
                        let words = unsafe { words.assume_init() };
                        // Unfortunate possibly-large copy
                        Rc::new(words)
                    };
                    // The `Rc` was just created, so it is uniquely owned and
                    // `get_mut` cannot fail.
                    let words_ref = Rc::get_mut(&mut words).unwrap();

                    let (word_index, mask) = chunk_word_index_and_mask(elem);
                    words_ref[word_index] |= mask;
                    *chunk = Mixed(chunk_domain_size, 1, words);
                } else {
                    // A one-element chunk goes straight from all zeros to all
                    // ones; `Mixed` requires `0 < count < chunk_domain_size`.
                    *chunk = Ones(chunk_domain_size);
                }
                true
            }
            // Every bit of the chunk is already set.
            Ones(_) => false,
            Mixed(chunk_domain_size, ref mut count, ref mut words) => {
                // We skip all the work if the bit is already set.
                let (word_index, mask) = chunk_word_index_and_mask(elem);
                if (words[word_index] & mask) == 0 {
                    *count += 1;
                    if *count < chunk_domain_size {
                        // `make_mut` clones the words first if they are shared.
                        let words = Rc::make_mut(words);
                        words[word_index] |= mask;
                    } else {
                        // That was the last missing bit: the chunk is now all
                        // ones and the words are dropped.
                        *chunk = Ones(chunk_domain_size);
                    }
                    true
                } else {
                    false
                }
            }
        }
    }
651
652    /// Sets all bits to true.
653    pub fn insert_all(&mut self) {
654        for chunk in self.chunks.iter_mut() {
655            *chunk = match *chunk {
656                Zeros(chunk_domain_size)
657                | Ones(chunk_domain_size)
658                | Mixed(chunk_domain_size, ..) => Ones(chunk_domain_size),
659            }
660        }
661    }
662
    /// Removes `elem`. Returns `true` if the set has changed.
    ///
    /// The chunk holding `elem` changes representation as needed: a `Ones`
    /// chunk becomes `Mixed` (or `Zeros` if it covers a single element), and
    /// a `Mixed` chunk becomes `Zeros` once its last set bit is cleared.
    ///
    /// # Panics
    ///
    /// Panics if `elem`'s index is not below the domain size.
    pub fn remove(&mut self, elem: T) -> bool {
        assert!(elem.index() < self.domain_size);
        let chunk_index = chunk_index(elem);
        let chunk = &mut self.chunks[chunk_index];
        match *chunk {
            // No bit of the chunk is set, so there is nothing to remove.
            Zeros(_) => false,
            Ones(chunk_domain_size) => {
                if chunk_domain_size > 1 {
                    // Allocate an all-zero word array, fill in the in-use
                    // bits, then clear the one being removed.
                    #[cfg(feature = "nightly")]
                    let mut words = {
                        // We take some effort to avoid copying the words.
                        let words = Rc::<[Word; CHUNK_WORDS]>::new_zeroed();
                        // SAFETY: `words` can safely be all zeroes.
                        unsafe { words.assume_init() }
                    };
                    #[cfg(not(feature = "nightly"))]
                    let mut words = {
                        // FIXME: unconditionally use `Rc::new_zeroed` once it is stable (#63291).
                        let words = mem::MaybeUninit::<[Word; CHUNK_WORDS]>::zeroed();
                        // SAFETY: `words` can safely be all zeroes.
                        let words = unsafe { words.assume_init() };
                        // Unfortunate possibly-large copy
                        Rc::new(words)
                    };
                    // The `Rc` was just created, so it is uniquely owned and
                    // `get_mut` cannot fail.
                    let words_ref = Rc::get_mut(&mut words).unwrap();

                    // Set only the bits in use.
                    let num_words = num_words(chunk_domain_size as usize);
                    words_ref[..num_words].fill(!0);
                    clear_excess_bits_in_final_word(
                        chunk_domain_size as usize,
                        &mut words_ref[..num_words],
                    );
                    let (word_index, mask) = chunk_word_index_and_mask(elem);
                    words_ref[word_index] &= !mask;
                    *chunk = Mixed(chunk_domain_size, chunk_domain_size - 1, words);
                } else {
                    // A one-element chunk goes straight from all ones to all
                    // zeros; `Mixed` requires `0 < count < chunk_domain_size`.
                    *chunk = Zeros(chunk_domain_size);
                }
                true
            }
            Mixed(chunk_domain_size, ref mut count, ref mut words) => {
                // We skip all the work if the bit is already clear.
                let (word_index, mask) = chunk_word_index_and_mask(elem);
                if (words[word_index] & mask) != 0 {
                    *count -= 1;
                    if *count > 0 {
                        // `make_mut` clones the words first if they are shared.
                        let words = Rc::make_mut(words);
                        words[word_index] &= !mask;
                    } else {
                        // That was the last set bit: the chunk is now all
                        // zeros and the words are dropped.
                        *chunk = Zeros(chunk_domain_size);
                    }
                    true
                } else {
                    false
                }
            }
        }
    }
723
724    fn chunk_iter(&self, chunk_index: usize) -> ChunkIter<'_> {
725        match self.chunks.get(chunk_index) {
726            Some(Zeros(_chunk_domain_size)) => ChunkIter::Zeros,
727            Some(Ones(chunk_domain_size)) => ChunkIter::Ones(0..*chunk_domain_size as usize),
728            Some(Mixed(chunk_domain_size, _, words)) => {
729                let num_words = num_words(*chunk_domain_size as usize);
730                ChunkIter::Mixed(BitIter::new(&words[0..num_words]))
731            }
732            None => ChunkIter::Finished,
733        }
734    }
735
736    bit_relations_inherent_impls! {}
737}
738
739impl<T: Idx> BitRelations<ChunkedBitSet<T>> for ChunkedBitSet<T> {
    /// Sets `self = self | other`, one chunk pair at a time, and returns
    /// `true` if any chunk of `self` changed.
    ///
    /// # Panics
    ///
    /// Panics if the two bitsets have different domain sizes.
    fn union(&mut self, other: &ChunkedBitSet<T>) -> bool {
        assert_eq!(self.domain_size, other.domain_size);
        debug_assert_eq!(self.chunks.len(), other.chunks.len());

        let mut changed = false;
        for (mut self_chunk, other_chunk) in self.chunks.iter_mut().zip(other.chunks.iter()) {
            match (&mut self_chunk, &other_chunk) {
                // Nothing to add, or `self_chunk` is already full.
                (_, Zeros(_)) | (Ones(_), _) => {}
                (Zeros(self_chunk_domain_size), Ones(other_chunk_domain_size))
                | (Mixed(self_chunk_domain_size, ..), Ones(other_chunk_domain_size))
                | (Zeros(self_chunk_domain_size), Mixed(other_chunk_domain_size, ..)) => {
                    // `other_chunk` fully overwrites `self_chunk`
                    debug_assert_eq!(self_chunk_domain_size, other_chunk_domain_size);
                    *self_chunk = other_chunk.clone();
                    changed = true;
                }
                (
                    Mixed(self_chunk_domain_size, self_chunk_count, self_chunk_words),
                    Mixed(_other_chunk_domain_size, _other_chunk_count, other_chunk_words),
                ) => {
                    // First check if the operation would change
                    // `self_chunk.words`. If not, we can avoid allocating some
                    // words, and this happens often enough that it's a
                    // performance win. Also, we only need to operate on the
                    // in-use words, hence the slicing.
                    let op = |a, b| a | b;
                    let num_words = num_words(*self_chunk_domain_size as usize);
                    if bitwise_changes(
                        &self_chunk_words[0..num_words],
                        &other_chunk_words[0..num_words],
                        op,
                    ) {
                        // `make_mut` clones the words first if they are shared.
                        let self_chunk_words = Rc::make_mut(self_chunk_words);
                        let has_changed = bitwise(
                            &mut self_chunk_words[0..num_words],
                            &other_chunk_words[0..num_words],
                            op,
                        );
                        debug_assert!(has_changed);
                        // Recompute the count of set bits from the updated words.
                        *self_chunk_count = self_chunk_words[0..num_words]
                            .iter()
                            .map(|w| w.count_ones() as ChunkSize)
                            .sum();
                        // A `Mixed` chunk must not be full; switch to `Ones`.
                        if *self_chunk_count == *self_chunk_domain_size {
                            *self_chunk = Ones(*self_chunk_domain_size);
                        }
                        changed = true;
                    }
                }
            }
        }
        changed
    }
793
794    fn subtract(&mut self, other: &ChunkedBitSet<T>) -> bool {
795        assert_eq!(self.domain_size, other.domain_size);
796        debug_assert_eq!(self.chunks.len(), other.chunks.len());
797
798        let mut changed = false;
799        for (mut self_chunk, other_chunk) in self.chunks.iter_mut().zip(other.chunks.iter()) {
800            match (&mut self_chunk, &other_chunk) {
801                (Zeros(..), _) | (_, Zeros(..)) => {}
802                (
803                    Ones(self_chunk_domain_size) | Mixed(self_chunk_domain_size, _, _),
804                    Ones(other_chunk_domain_size),
805                ) => {
806                    debug_assert_eq!(self_chunk_domain_size, other_chunk_domain_size);
807                    changed = true;
808                    *self_chunk = Zeros(*self_chunk_domain_size);
809                }
810                (
811                    Ones(self_chunk_domain_size),
812                    Mixed(other_chunk_domain_size, other_chunk_count, other_chunk_words),
813                ) => {
814                    debug_assert_eq!(self_chunk_domain_size, other_chunk_domain_size);
815                    changed = true;
816                    let num_words = num_words(*self_chunk_domain_size as usize);
817                    debug_assert!(num_words > 0 && num_words <= CHUNK_WORDS);
818                    let mut tail_mask =
819                        1 << (*other_chunk_domain_size - ((num_words - 1) * WORD_BITS) as u16) - 1;
820                    let mut self_chunk_words = **other_chunk_words;
821                    for word in self_chunk_words[0..num_words].iter_mut().rev() {
822                        *word = !*word & tail_mask;
823                        tail_mask = u64::MAX;
824                    }
825                    let self_chunk_count = *self_chunk_domain_size - *other_chunk_count;
826                    debug_assert_eq!(
827                        self_chunk_count,
828                        self_chunk_words[0..num_words]
829                            .iter()
830                            .map(|w| w.count_ones() as ChunkSize)
831                            .sum()
832                    );
833                    *self_chunk =
834                        Mixed(*self_chunk_domain_size, self_chunk_count, Rc::new(self_chunk_words));
835                }
836                (
837                    Mixed(self_chunk_domain_size, self_chunk_count, self_chunk_words),
838                    Mixed(_other_chunk_domain_size, _other_chunk_count, other_chunk_words),
839                ) => {
840                    // See [`<Self as BitRelations<ChunkedBitSet<T>>>::union`] for the explanation
841                    let op = |a: u64, b: u64| a & !b;
842                    let num_words = num_words(*self_chunk_domain_size as usize);
843                    if bitwise_changes(
844                        &self_chunk_words[0..num_words],
845                        &other_chunk_words[0..num_words],
846                        op,
847                    ) {
848                        let self_chunk_words = Rc::make_mut(self_chunk_words);
849                        let has_changed = bitwise(
850                            &mut self_chunk_words[0..num_words],
851                            &other_chunk_words[0..num_words],
852                            op,
853                        );
854                        debug_assert!(has_changed);
855                        *self_chunk_count = self_chunk_words[0..num_words]
856                            .iter()
857                            .map(|w| w.count_ones() as ChunkSize)
858                            .sum();
859                        if *self_chunk_count == 0 {
860                            *self_chunk = Zeros(*self_chunk_domain_size);
861                        }
862                        changed = true;
863                    }
864                }
865            }
866        }
867        changed
868    }
869
870    fn intersect(&mut self, other: &ChunkedBitSet<T>) -> bool {
871        assert_eq!(self.domain_size, other.domain_size);
872        debug_assert_eq!(self.chunks.len(), other.chunks.len());
873
874        let mut changed = false;
875        for (mut self_chunk, other_chunk) in self.chunks.iter_mut().zip(other.chunks.iter()) {
876            match (&mut self_chunk, &other_chunk) {
877                (Zeros(..), _) | (_, Ones(..)) => {}
878                (
879                    Ones(self_chunk_domain_size),
880                    Zeros(other_chunk_domain_size) | Mixed(other_chunk_domain_size, ..),
881                )
882                | (Mixed(self_chunk_domain_size, ..), Zeros(other_chunk_domain_size)) => {
883                    debug_assert_eq!(self_chunk_domain_size, other_chunk_domain_size);
884                    changed = true;
885                    *self_chunk = other_chunk.clone();
886                }
887                (
888                    Mixed(self_chunk_domain_size, self_chunk_count, self_chunk_words),
889                    Mixed(_other_chunk_domain_size, _other_chunk_count, other_chunk_words),
890                ) => {
891                    // See [`<Self as BitRelations<ChunkedBitSet<T>>>::union`] for the explanation
892                    let op = |a, b| a & b;
893                    let num_words = num_words(*self_chunk_domain_size as usize);
894                    if bitwise_changes(
895                        &self_chunk_words[0..num_words],
896                        &other_chunk_words[0..num_words],
897                        op,
898                    ) {
899                        let self_chunk_words = Rc::make_mut(self_chunk_words);
900                        let has_changed = bitwise(
901                            &mut self_chunk_words[0..num_words],
902                            &other_chunk_words[0..num_words],
903                            op,
904                        );
905                        debug_assert!(has_changed);
906                        *self_chunk_count = self_chunk_words[0..num_words]
907                            .iter()
908                            .map(|w| w.count_ones() as ChunkSize)
909                            .sum();
910                        if *self_chunk_count == 0 {
911                            *self_chunk = Zeros(*self_chunk_domain_size);
912                        }
913                        changed = true;
914                    }
915                }
916            }
917        }
918
919        changed
920    }
921}
922
923impl<T: Idx> BitRelations<ChunkedBitSet<T>> for DenseBitSet<T> {
924    fn union(&mut self, other: &ChunkedBitSet<T>) -> bool {
925        sequential_update(|elem| self.insert(elem), other.iter())
926    }
927
928    fn subtract(&mut self, _other: &ChunkedBitSet<T>) -> bool {
929        unimplemented!("implement if/when necessary");
930    }
931
932    fn intersect(&mut self, other: &ChunkedBitSet<T>) -> bool {
933        assert_eq!(self.domain_size(), other.domain_size);
934        let mut changed = false;
935        for (i, chunk) in other.chunks.iter().enumerate() {
936            let mut words = &mut self.words[i * CHUNK_WORDS..];
937            if words.len() > CHUNK_WORDS {
938                words = &mut words[..CHUNK_WORDS];
939            }
940            match chunk {
941                Zeros(..) => {
942                    for word in words {
943                        if *word != 0 {
944                            changed = true;
945                            *word = 0;
946                        }
947                    }
948                }
949                Ones(..) => (),
950                Mixed(_, _, data) => {
951                    for (i, word) in words.iter_mut().enumerate() {
952                        let new_val = *word & data[i];
953                        if new_val != *word {
954                            changed = true;
955                            *word = new_val;
956                        }
957                    }
958                }
959            }
960        }
961        changed
962    }
963}
964
965impl<T> Clone for ChunkedBitSet<T> {
966    fn clone(&self) -> Self {
967        ChunkedBitSet {
968            domain_size: self.domain_size,
969            chunks: self.chunks.clone(),
970            marker: PhantomData,
971        }
972    }
973
974    /// WARNING: this implementation of clone_from will panic if the two
975    /// bitsets have different domain sizes. This constraint is not inherent to
976    /// `clone_from`, but it works with the existing call sites and allows a
977    /// faster implementation, which is important because this function is hot.
978    fn clone_from(&mut self, from: &Self) {
979        assert_eq!(self.domain_size, from.domain_size);
980        debug_assert_eq!(self.chunks.len(), from.chunks.len());
981
982        self.chunks.clone_from(&from.chunks)
983    }
984}
985
/// Iterator over the elements of a `ChunkedBitSet`, walking its chunks in
/// index order and delegating to a per-chunk sub-iterator.
pub struct ChunkedBitIter<'a, T: Idx> {
    // The set being iterated.
    bit_set: &'a ChunkedBitSet<T>,

    // The index of the current chunk.
    chunk_index: usize,

    // The sub-iterator for the current chunk. Its items are offsets within
    // the chunk; `chunk_index * CHUNK_BITS` is added to form the element.
    chunk_iter: ChunkIter<'a>,
}
995
996impl<'a, T: Idx> ChunkedBitIter<'a, T> {
997    #[inline]
998    fn new(bit_set: &'a ChunkedBitSet<T>) -> ChunkedBitIter<'a, T> {
999        ChunkedBitIter { bit_set, chunk_index: 0, chunk_iter: bit_set.chunk_iter(0) }
1000    }
1001}
1002
1003impl<'a, T: Idx> Iterator for ChunkedBitIter<'a, T> {
1004    type Item = T;
1005
1006    fn next(&mut self) -> Option<T> {
1007        loop {
1008            match &mut self.chunk_iter {
1009                ChunkIter::Zeros => {}
1010                ChunkIter::Ones(iter) => {
1011                    if let Some(next) = iter.next() {
1012                        return Some(T::new(next + self.chunk_index * CHUNK_BITS));
1013                    }
1014                }
1015                ChunkIter::Mixed(iter) => {
1016                    if let Some(next) = iter.next() {
1017                        return Some(T::new(next + self.chunk_index * CHUNK_BITS));
1018                    }
1019                }
1020                ChunkIter::Finished => return None,
1021            }
1022            self.chunk_index += 1;
1023            self.chunk_iter = self.bit_set.chunk_iter(self.chunk_index);
1024        }
1025    }
1026}
1027
1028impl Chunk {
1029    #[cfg(test)]
1030    fn assert_valid(&self) {
1031        match *self {
1032            Zeros(chunk_domain_size) | Ones(chunk_domain_size) => {
1033                assert!(chunk_domain_size as usize <= CHUNK_BITS);
1034            }
1035            Mixed(chunk_domain_size, count, ref words) => {
1036                assert!(chunk_domain_size as usize <= CHUNK_BITS);
1037                assert!(0 < count && count < chunk_domain_size);
1038
1039                // Check the number of set bits matches `count`.
1040                assert_eq!(
1041                    words.iter().map(|w| w.count_ones() as ChunkSize).sum::<ChunkSize>(),
1042                    count
1043                );
1044
1045                // Check the not-in-use words are all zeroed.
1046                let num_words = num_words(chunk_domain_size as usize);
1047                if num_words < CHUNK_WORDS {
1048                    assert_eq!(
1049                        words[num_words..]
1050                            .iter()
1051                            .map(|w| w.count_ones() as ChunkSize)
1052                            .sum::<ChunkSize>(),
1053                        0
1054                    );
1055                }
1056            }
1057        }
1058    }
1059
1060    fn new(chunk_domain_size: usize, is_empty: bool) -> Self {
1061        debug_assert!(0 < chunk_domain_size && chunk_domain_size <= CHUNK_BITS);
1062        let chunk_domain_size = chunk_domain_size as ChunkSize;
1063        if is_empty { Zeros(chunk_domain_size) } else { Ones(chunk_domain_size) }
1064    }
1065
1066    /// Count the number of 1s in the chunk.
1067    fn count(&self) -> usize {
1068        match *self {
1069            Zeros(_) => 0,
1070            Ones(chunk_domain_size) => chunk_domain_size as usize,
1071            Mixed(_, count, _) => count as usize,
1072        }
1073    }
1074}
1075
/// Per-chunk iteration state used by `ChunkedBitIter`.
enum ChunkIter<'a> {
    // The chunk yields no elements.
    Zeros,
    // Yields the values of the range; `ChunkedBitIter::next` treats them as
    // offsets within the current chunk.
    Ones(Range<usize>),
    // Yields values from a `BitIter`; `ChunkedBitIter::next` treats them as
    // offsets within the current chunk.
    Mixed(BitIter<'a, usize>),
    // Makes `ChunkedBitIter::next` return `None`, ending the iteration.
    Finished,
}
1082
1083// Applies a function to mutate a bitset, and returns true if any
1084// of the applications return true
1085fn sequential_update<T: Idx>(
1086    mut self_update: impl FnMut(T) -> bool,
1087    it: impl Iterator<Item = T>,
1088) -> bool {
1089    it.fold(false, |changed, elem| self_update(elem) | changed)
1090}
1091
1092impl<T: Idx> fmt::Debug for ChunkedBitSet<T> {
1093    fn fmt(&self, w: &mut fmt::Formatter<'_>) -> fmt::Result {
1094        w.debug_list().entries(self.iter()).finish()
1095    }
1096}
1097
1098/// Sets `out_vec[i] = op(out_vec[i], in_vec[i])` for each index `i` in both
1099/// slices. The slices must have the same length.
1100///
1101/// Returns true if at least one bit in `out_vec` was changed.
1102///
1103/// ## Warning
1104/// Some bitwise operations (e.g. union-not, xor) can set output bits that were
1105/// unset in in both inputs. If this happens in the last word/chunk of a bitset,
1106/// it can cause the bitset to contain out-of-domain values, which need to
1107/// be cleared with `clear_excess_bits_in_final_word`. This also makes the
1108/// "changed" return value unreliable, because the change might have only
1109/// affected excess bits.
1110#[inline]
1111fn bitwise<Op>(out_vec: &mut [Word], in_vec: &[Word], op: Op) -> bool
1112where
1113    Op: Fn(Word, Word) -> Word,
1114{
1115    assert_eq!(out_vec.len(), in_vec.len());
1116    let mut changed = 0;
1117    for (out_elem, in_elem) in iter::zip(out_vec, in_vec) {
1118        let old_val = *out_elem;
1119        let new_val = op(old_val, *in_elem);
1120        *out_elem = new_val;
1121        // This is essentially equivalent to a != with changed being a bool, but
1122        // in practice this code gets auto-vectorized by the compiler for most
1123        // operators. Using != here causes us to generate quite poor code as the
1124        // compiler tries to go back to a boolean on each loop iteration.
1125        changed |= old_val ^ new_val;
1126    }
1127    changed != 0
1128}
1129
1130/// Does this bitwise operation change `out_vec`?
1131#[inline]
1132fn bitwise_changes<Op>(out_vec: &[Word], in_vec: &[Word], op: Op) -> bool
1133where
1134    Op: Fn(Word, Word) -> Word,
1135{
1136    assert_eq!(out_vec.len(), in_vec.len());
1137    for (out_elem, in_elem) in iter::zip(out_vec, in_vec) {
1138        let old_val = *out_elem;
1139        let new_val = op(old_val, *in_elem);
1140        if old_val != new_val {
1141            return true;
1142        }
1143    }
1144    false
1145}
1146
/// A bitset with a mixed representation, using `DenseBitSet` for small and
/// medium bitsets, and `ChunkedBitSet` for large bitsets, i.e. those with
/// enough bits for at least two chunks. This is a good choice for many bitsets
/// that can have large domain sizes (e.g. 5000+).
///
/// The representation is picked once, by `new_empty`, from the domain size:
/// `Small` when it is at most `CHUNK_BITS`, `Large` otherwise.
///
/// `T` is an index type, typically a newtyped `usize` wrapper, but it can also
/// just be `usize`.
///
/// All operations that involve an element will panic if the element is equal
/// to or greater than the domain size. All operations that involve two bitsets
/// will panic if the bitsets have differing domain sizes.
#[derive(PartialEq, Eq)]
pub enum MixedBitSet<T> {
    // Dense representation, used for domain sizes up to `CHUNK_BITS`.
    Small(DenseBitSet<T>),
    // Chunked representation, used for domain sizes above `CHUNK_BITS`.
    Large(ChunkedBitSet<T>),
}
1163
1164impl<T> MixedBitSet<T> {
1165    pub fn domain_size(&self) -> usize {
1166        match self {
1167            MixedBitSet::Small(set) => set.domain_size(),
1168            MixedBitSet::Large(set) => set.domain_size(),
1169        }
1170    }
1171}
1172
1173impl<T: Idx> MixedBitSet<T> {
1174    #[inline]
1175    pub fn new_empty(domain_size: usize) -> MixedBitSet<T> {
1176        if domain_size <= CHUNK_BITS {
1177            MixedBitSet::Small(DenseBitSet::new_empty(domain_size))
1178        } else {
1179            MixedBitSet::Large(ChunkedBitSet::new_empty(domain_size))
1180        }
1181    }
1182
1183    #[inline]
1184    pub fn is_empty(&self) -> bool {
1185        match self {
1186            MixedBitSet::Small(set) => set.is_empty(),
1187            MixedBitSet::Large(set) => set.is_empty(),
1188        }
1189    }
1190
1191    #[inline]
1192    pub fn contains(&self, elem: T) -> bool {
1193        match self {
1194            MixedBitSet::Small(set) => set.contains(elem),
1195            MixedBitSet::Large(set) => set.contains(elem),
1196        }
1197    }
1198
1199    #[inline]
1200    pub fn insert(&mut self, elem: T) -> bool {
1201        match self {
1202            MixedBitSet::Small(set) => set.insert(elem),
1203            MixedBitSet::Large(set) => set.insert(elem),
1204        }
1205    }
1206
1207    pub fn insert_all(&mut self) {
1208        match self {
1209            MixedBitSet::Small(set) => set.insert_all(),
1210            MixedBitSet::Large(set) => set.insert_all(),
1211        }
1212    }
1213
1214    #[inline]
1215    pub fn remove(&mut self, elem: T) -> bool {
1216        match self {
1217            MixedBitSet::Small(set) => set.remove(elem),
1218            MixedBitSet::Large(set) => set.remove(elem),
1219        }
1220    }
1221
1222    pub fn iter(&self) -> MixedBitIter<'_, T> {
1223        match self {
1224            MixedBitSet::Small(set) => MixedBitIter::Small(set.iter()),
1225            MixedBitSet::Large(set) => MixedBitIter::Large(set.iter()),
1226        }
1227    }
1228
1229    #[inline]
1230    pub fn clear(&mut self) {
1231        match self {
1232            MixedBitSet::Small(set) => set.clear(),
1233            MixedBitSet::Large(set) => set.clear(),
1234        }
1235    }
1236
1237    bit_relations_inherent_impls! {}
1238}
1239
1240impl<T> Clone for MixedBitSet<T> {
1241    fn clone(&self) -> Self {
1242        match self {
1243            MixedBitSet::Small(set) => MixedBitSet::Small(set.clone()),
1244            MixedBitSet::Large(set) => MixedBitSet::Large(set.clone()),
1245        }
1246    }
1247
1248    /// WARNING: this implementation of clone_from may panic if the two
1249    /// bitsets have different domain sizes. This constraint is not inherent to
1250    /// `clone_from`, but it works with the existing call sites and allows a
1251    /// faster implementation, which is important because this function is hot.
1252    fn clone_from(&mut self, from: &Self) {
1253        match (self, from) {
1254            (MixedBitSet::Small(set), MixedBitSet::Small(from)) => set.clone_from(from),
1255            (MixedBitSet::Large(set), MixedBitSet::Large(from)) => set.clone_from(from),
1256            _ => panic!("MixedBitSet size mismatch"),
1257        }
1258    }
1259}
1260
1261impl<T: Idx> BitRelations<MixedBitSet<T>> for MixedBitSet<T> {
1262    fn union(&mut self, other: &MixedBitSet<T>) -> bool {
1263        match (self, other) {
1264            (MixedBitSet::Small(set), MixedBitSet::Small(other)) => set.union(other),
1265            (MixedBitSet::Large(set), MixedBitSet::Large(other)) => set.union(other),
1266            _ => panic!("MixedBitSet size mismatch"),
1267        }
1268    }
1269
1270    fn subtract(&mut self, other: &MixedBitSet<T>) -> bool {
1271        match (self, other) {
1272            (MixedBitSet::Small(set), MixedBitSet::Small(other)) => set.subtract(other),
1273            (MixedBitSet::Large(set), MixedBitSet::Large(other)) => set.subtract(other),
1274            _ => panic!("MixedBitSet size mismatch"),
1275        }
1276    }
1277
1278    fn intersect(&mut self, _other: &MixedBitSet<T>) -> bool {
1279        unimplemented!("implement if/when necessary");
1280    }
1281}
1282
1283impl<T: Idx> fmt::Debug for MixedBitSet<T> {
1284    fn fmt(&self, w: &mut fmt::Formatter<'_>) -> fmt::Result {
1285        match self {
1286            MixedBitSet::Small(set) => set.fmt(w),
1287            MixedBitSet::Large(set) => set.fmt(w),
1288        }
1289    }
1290}
1291
/// Iterator over the elements of a `MixedBitSet`; each variant wraps the
/// iterator of the corresponding `MixedBitSet` representation.
pub enum MixedBitIter<'a, T: Idx> {
    // Iterating a `MixedBitSet::Small`.
    Small(BitIter<'a, T>),
    // Iterating a `MixedBitSet::Large`.
    Large(ChunkedBitIter<'a, T>),
}
1296
1297impl<'a, T: Idx> Iterator for MixedBitIter<'a, T> {
1298    type Item = T;
1299    fn next(&mut self) -> Option<T> {
1300        match self {
1301            MixedBitIter::Small(iter) => iter.next(),
1302            MixedBitIter::Large(iter) => iter.next(),
1303        }
1304    }
1305}
1306
/// A resizable bitset type with a dense representation.
///
/// `T` is an index type, typically a newtyped `usize` wrapper, but it can also
/// just be `usize`.
///
/// Unlike the fixed-size bitsets, `insert` and `remove` first grow the set so
/// that its domain covers the element, and `contains` returns `false` for an
/// element whose word lies beyond the currently allocated words.
#[derive(Clone, Debug, PartialEq)]
pub struct GrowableBitSet<T: Idx> {
    // The underlying dense set; its domain size and word storage are enlarged
    // by `ensure`.
    bit_set: DenseBitSet<T>,
}
1318
1319impl<T: Idx> Default for GrowableBitSet<T> {
1320    fn default() -> Self {
1321        GrowableBitSet::new_empty()
1322    }
1323}
1324
1325impl<T: Idx> GrowableBitSet<T> {
1326    /// Ensure that the set can hold at least `min_domain_size` elements.
1327    pub fn ensure(&mut self, min_domain_size: usize) {
1328        if self.bit_set.domain_size < min_domain_size {
1329            self.bit_set.domain_size = min_domain_size;
1330        }
1331
1332        let min_num_words = num_words(min_domain_size);
1333        if self.bit_set.words.len() < min_num_words {
1334            self.bit_set.words.resize(min_num_words, 0)
1335        }
1336    }
1337
1338    pub fn new_empty() -> GrowableBitSet<T> {
1339        GrowableBitSet { bit_set: DenseBitSet::new_empty(0) }
1340    }
1341
1342    pub fn with_capacity(capacity: usize) -> GrowableBitSet<T> {
1343        GrowableBitSet { bit_set: DenseBitSet::new_empty(capacity) }
1344    }
1345
1346    /// Returns `true` if the set has changed.
1347    #[inline]
1348    pub fn insert(&mut self, elem: T) -> bool {
1349        self.ensure(elem.index() + 1);
1350        self.bit_set.insert(elem)
1351    }
1352
1353    /// Returns `true` if the set has changed.
1354    #[inline]
1355    pub fn remove(&mut self, elem: T) -> bool {
1356        self.ensure(elem.index() + 1);
1357        self.bit_set.remove(elem)
1358    }
1359
1360    #[inline]
1361    pub fn is_empty(&self) -> bool {
1362        self.bit_set.is_empty()
1363    }
1364
1365    #[inline]
1366    pub fn contains(&self, elem: T) -> bool {
1367        let (word_index, mask) = word_index_and_mask(elem);
1368        self.bit_set.words.get(word_index).is_some_and(|word| (word & mask) != 0)
1369    }
1370
1371    #[inline]
1372    pub fn iter(&self) -> BitIter<'_, T> {
1373        self.bit_set.iter()
1374    }
1375
1376    #[inline]
1377    pub fn len(&self) -> usize {
1378        self.bit_set.count()
1379    }
1380}
1381
1382impl<T: Idx> From<DenseBitSet<T>> for GrowableBitSet<T> {
1383    fn from(bit_set: DenseBitSet<T>) -> Self {
1384        Self { bit_set }
1385    }
1386}
1387
/// A fixed-size 2D bit matrix type with a dense representation.
///
/// `R` and `C` are index types used to identify rows and columns respectively;
/// typically newtyped `usize` wrappers, but they can also just be `usize`.
///
/// All operations that involve a row and/or column index will panic if the
/// index exceeds the relevant bound.
#[cfg_attr(feature = "nightly", derive(Decodable_NoContext, Encodable_NoContext))]
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct BitMatrix<R: Idx, C: Idx> {
    // Number of rows; valid row indices are `0..num_rows`.
    num_rows: usize,
    // Number of columns; valid column indices are `0..num_columns`.
    num_columns: usize,
    // Row-major storage: each row occupies `num_words(num_columns)`
    // consecutive words (see `range`).
    words: SmallVec<[Word; 2]>,
    // Ties the row and column index types to the matrix without storing them.
    marker: PhantomData<(R, C)>,
}
1403
1404impl<R: Idx, C: Idx> BitMatrix<R, C> {
1405    /// Creates a new `rows x columns` matrix, initially empty.
1406    pub fn new(num_rows: usize, num_columns: usize) -> BitMatrix<R, C> {
1407        // For every element, we need one bit for every other
1408        // element. Round up to an even number of words.
1409        let words_per_row = num_words(num_columns);
1410        BitMatrix {
1411            num_rows,
1412            num_columns,
1413            words: smallvec![0; num_rows * words_per_row],
1414            marker: PhantomData,
1415        }
1416    }
1417
1418    /// Creates a new matrix, with `row` used as the value for every row.
1419    pub fn from_row_n(row: &DenseBitSet<C>, num_rows: usize) -> BitMatrix<R, C> {
1420        let num_columns = row.domain_size();
1421        let words_per_row = num_words(num_columns);
1422        assert_eq!(words_per_row, row.words.len());
1423        BitMatrix {
1424            num_rows,
1425            num_columns,
1426            words: iter::repeat(&row.words).take(num_rows).flatten().cloned().collect(),
1427            marker: PhantomData,
1428        }
1429    }
1430
1431    pub fn rows(&self) -> impl Iterator<Item = R> {
1432        (0..self.num_rows).map(R::new)
1433    }
1434
1435    /// The range of bits for a given row.
1436    fn range(&self, row: R) -> (usize, usize) {
1437        let words_per_row = num_words(self.num_columns);
1438        let start = row.index() * words_per_row;
1439        (start, start + words_per_row)
1440    }
1441
1442    /// Sets the cell at `(row, column)` to true. Put another way, insert
1443    /// `column` to the bitset for `row`.
1444    ///
1445    /// Returns `true` if this changed the matrix.
1446    pub fn insert(&mut self, row: R, column: C) -> bool {
1447        assert!(row.index() < self.num_rows && column.index() < self.num_columns);
1448        let (start, _) = self.range(row);
1449        let (word_index, mask) = word_index_and_mask(column);
1450        let words = &mut self.words[..];
1451        let word = words[start + word_index];
1452        let new_word = word | mask;
1453        words[start + word_index] = new_word;
1454        word != new_word
1455    }
1456
1457    /// Do the bits from `row` contain `column`? Put another way, is
1458    /// the matrix cell at `(row, column)` true?  Put yet another way,
1459    /// if the matrix represents (transitive) reachability, can
1460    /// `row` reach `column`?
1461    pub fn contains(&self, row: R, column: C) -> bool {
1462        assert!(row.index() < self.num_rows && column.index() < self.num_columns);
1463        let (start, _) = self.range(row);
1464        let (word_index, mask) = word_index_and_mask(column);
1465        (self.words[start + word_index] & mask) != 0
1466    }
1467
1468    /// Returns those indices that are true in rows `a` and `b`. This
1469    /// is an *O*(*n*) operation where *n* is the number of elements
1470    /// (somewhat independent from the actual size of the
1471    /// intersection, in particular).
1472    pub fn intersect_rows(&self, row1: R, row2: R) -> Vec<C> {
1473        assert!(row1.index() < self.num_rows && row2.index() < self.num_rows);
1474        let (row1_start, row1_end) = self.range(row1);
1475        let (row2_start, row2_end) = self.range(row2);
1476        let mut result = Vec::with_capacity(self.num_columns);
1477        for (base, (i, j)) in (row1_start..row1_end).zip(row2_start..row2_end).enumerate() {
1478            let mut v = self.words[i] & self.words[j];
1479            for bit in 0..WORD_BITS {
1480                if v == 0 {
1481                    break;
1482                }
1483                if v & 0x1 != 0 {
1484                    result.push(C::new(base * WORD_BITS + bit));
1485                }
1486                v >>= 1;
1487            }
1488        }
1489        result
1490    }
1491
1492    /// Adds the bits from row `read` to the bits from row `write`, and
1493    /// returns `true` if anything changed.
1494    ///
1495    /// This is used when computing transitive reachability because if
1496    /// you have an edge `write -> read`, because in that case
1497    /// `write` can reach everything that `read` can (and
1498    /// potentially more).
1499    pub fn union_rows(&mut self, read: R, write: R) -> bool {
1500        assert!(read.index() < self.num_rows && write.index() < self.num_rows);
1501        let (read_start, read_end) = self.range(read);
1502        let (write_start, write_end) = self.range(write);
1503        let words = &mut self.words[..];
1504        let mut changed = 0;
1505        for (read_index, write_index) in iter::zip(read_start..read_end, write_start..write_end) {
1506            let word = words[write_index];
1507            let new_word = word | words[read_index];
1508            words[write_index] = new_word;
1509            // See `bitwise` for the rationale.
1510            changed |= word ^ new_word;
1511        }
1512        changed != 0
1513    }
1514
1515    /// Adds the bits from `with` to the bits from row `write`, and
1516    /// returns `true` if anything changed.
1517    pub fn union_row_with(&mut self, with: &DenseBitSet<C>, write: R) -> bool {
1518        assert!(write.index() < self.num_rows);
1519        assert_eq!(with.domain_size(), self.num_columns);
1520        let (write_start, write_end) = self.range(write);
1521        bitwise(&mut self.words[write_start..write_end], &with.words, |a, b| a | b)
1522    }
1523
1524    /// Sets every cell in `row` to true.
1525    pub fn insert_all_into_row(&mut self, row: R) {
1526        assert!(row.index() < self.num_rows);
1527        let (start, end) = self.range(row);
1528        let words = &mut self.words[..];
1529        for index in start..end {
1530            words[index] = !0;
1531        }
1532        clear_excess_bits_in_final_word(self.num_columns, &mut self.words[..end]);
1533    }
1534
1535    /// Gets a slice of the underlying words.
1536    pub fn words(&self) -> &[Word] {
1537        &self.words
1538    }
1539
1540    /// Iterates through all the columns set to true in a given row of
1541    /// the matrix.
1542    pub fn iter(&self, row: R) -> BitIter<'_, C> {
1543        assert!(row.index() < self.num_rows);
1544        let (start, end) = self.range(row);
1545        BitIter::new(&self.words[start..end])
1546    }
1547
1548    /// Returns the number of elements in `row`.
1549    pub fn count(&self, row: R) -> usize {
1550        let (start, end) = self.range(row);
1551        self.words[start..end].iter().map(|e| e.count_ones() as usize).sum()
1552    }
1553}
1554
1555impl<R: Idx, C: Idx> fmt::Debug for BitMatrix<R, C> {
1556    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
1557        /// Forces its contents to print in regular mode instead of alternate mode.
1558        struct OneLinePrinter<T>(T);
1559        impl<T: fmt::Debug> fmt::Debug for OneLinePrinter<T> {
1560            fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
1561                write!(fmt, "{:?}", self.0)
1562            }
1563        }
1564
1565        write!(fmt, "BitMatrix({}x{}) ", self.num_rows, self.num_columns)?;
1566        let items = self.rows().flat_map(|r| self.iter(r).map(move |c| (r, c)));
1567        fmt.debug_set().entries(items.map(OneLinePrinter)).finish()
1568    }
1569}
1570
/// A fixed-column-size, variable-row-size 2D bit matrix with a moderately
/// sparse representation.
///
/// Initially, every row has no explicit representation. If any bit within a row
/// is set, the entire row is instantiated as `Some(<DenseBitSet>)`.
/// Furthermore, any previously uninstantiated rows prior to it will be
/// instantiated as `None`. Those prior rows may themselves become fully
/// instantiated later on if any of their bits are set.
///
/// `R` and `C` are index types used to identify rows and columns respectively;
/// typically newtyped `usize` wrappers, but they can also just be `usize`.
#[derive(Clone, Debug)]
pub struct SparseBitMatrix<R, C>
where
    R: Idx,
    C: Idx,
{
    // Domain size given to each row's `DenseBitSet` when it is instantiated.
    num_columns: usize,
    // One slot per row index; `None` marks a row with no instantiated bitset.
    rows: IndexVec<R, Option<DenseBitSet<C>>>,
}
1591
1592impl<R: Idx, C: Idx> SparseBitMatrix<R, C> {
1593    /// Creates a new empty sparse bit matrix with no rows or columns.
1594    pub fn new(num_columns: usize) -> Self {
1595        Self { num_columns, rows: IndexVec::new() }
1596    }
1597
1598    fn ensure_row(&mut self, row: R) -> &mut DenseBitSet<C> {
1599        // Instantiate any missing rows up to and including row `row` with an empty `DenseBitSet`.
1600        // Then replace row `row` with a full `DenseBitSet` if necessary.
1601        self.rows.get_or_insert_with(row, || DenseBitSet::new_empty(self.num_columns))
1602    }
1603
1604    /// Sets the cell at `(row, column)` to true. Put another way, insert
1605    /// `column` to the bitset for `row`.
1606    ///
1607    /// Returns `true` if this changed the matrix.
1608    pub fn insert(&mut self, row: R, column: C) -> bool {
1609        self.ensure_row(row).insert(column)
1610    }
1611
1612    /// Sets the cell at `(row, column)` to false. Put another way, delete
1613    /// `column` from the bitset for `row`. Has no effect if `row` does not
1614    /// exist.
1615    ///
1616    /// Returns `true` if this changed the matrix.
1617    pub fn remove(&mut self, row: R, column: C) -> bool {
1618        match self.rows.get_mut(row) {
1619            Some(Some(row)) => row.remove(column),
1620            _ => false,
1621        }
1622    }
1623
1624    /// Sets all columns at `row` to false. Has no effect if `row` does
1625    /// not exist.
1626    pub fn clear(&mut self, row: R) {
1627        if let Some(Some(row)) = self.rows.get_mut(row) {
1628            row.clear();
1629        }
1630    }
1631
1632    /// Do the bits from `row` contain `column`? Put another way, is
1633    /// the matrix cell at `(row, column)` true?  Put yet another way,
1634    /// if the matrix represents (transitive) reachability, can
1635    /// `row` reach `column`?
1636    pub fn contains(&self, row: R, column: C) -> bool {
1637        self.row(row).is_some_and(|r| r.contains(column))
1638    }
1639
1640    /// Adds the bits from row `read` to the bits from row `write`, and
1641    /// returns `true` if anything changed.
1642    ///
1643    /// This is used when computing transitive reachability because if
1644    /// you have an edge `write -> read`, because in that case
1645    /// `write` can reach everything that `read` can (and
1646    /// potentially more).
1647    pub fn union_rows(&mut self, read: R, write: R) -> bool {
1648        if read == write || self.row(read).is_none() {
1649            return false;
1650        }
1651
1652        self.ensure_row(write);
1653        if let (Some(read_row), Some(write_row)) = self.rows.pick2_mut(read, write) {
1654            write_row.union(read_row)
1655        } else {
1656            unreachable!()
1657        }
1658    }
1659
1660    /// Insert all bits in the given row.
1661    pub fn insert_all_into_row(&mut self, row: R) {
1662        self.ensure_row(row).insert_all();
1663    }
1664
1665    pub fn rows(&self) -> impl Iterator<Item = R> {
1666        self.rows.indices()
1667    }
1668
1669    /// Iterates through all the columns set to true in a given row of
1670    /// the matrix.
1671    pub fn iter(&self, row: R) -> impl Iterator<Item = C> {
1672        self.row(row).into_iter().flat_map(|r| r.iter())
1673    }
1674
1675    pub fn row(&self, row: R) -> Option<&DenseBitSet<C>> {
1676        self.rows.get(row)?.as_ref()
1677    }
1678
1679    /// Intersects `row` with `set`. `set` can be either `DenseBitSet` or
1680    /// `ChunkedBitSet`. Has no effect if `row` does not exist.
1681    ///
1682    /// Returns true if the row was changed.
1683    pub fn intersect_row<Set>(&mut self, row: R, set: &Set) -> bool
1684    where
1685        DenseBitSet<C>: BitRelations<Set>,
1686    {
1687        match self.rows.get_mut(row) {
1688            Some(Some(row)) => row.intersect(set),
1689            _ => false,
1690        }
1691    }
1692
1693    /// Subtracts `set` from `row`. `set` can be either `DenseBitSet` or
1694    /// `ChunkedBitSet`. Has no effect if `row` does not exist.
1695    ///
1696    /// Returns true if the row was changed.
1697    pub fn subtract_row<Set>(&mut self, row: R, set: &Set) -> bool
1698    where
1699        DenseBitSet<C>: BitRelations<Set>,
1700    {
1701        match self.rows.get_mut(row) {
1702            Some(Some(row)) => row.subtract(set),
1703            _ => false,
1704        }
1705    }
1706
1707    /// Unions `row` with `set`. `set` can be either `DenseBitSet` or
1708    /// `ChunkedBitSet`.
1709    ///
1710    /// Returns true if the row was changed.
1711    pub fn union_row<Set>(&mut self, row: R, set: &Set) -> bool
1712    where
1713        DenseBitSet<C>: BitRelations<Set>,
1714    {
1715        self.ensure_row(row).union(set)
1716    }
1717}
1718
1719#[inline]
1720fn num_words<T: Idx>(domain_size: T) -> usize {
1721    (domain_size.index() + WORD_BITS - 1) / WORD_BITS
1722}
1723
1724#[inline]
1725fn num_chunks<T: Idx>(domain_size: T) -> usize {
1726    assert!(domain_size.index() > 0);
1727    (domain_size.index() + CHUNK_BITS - 1) / CHUNK_BITS
1728}
1729
1730#[inline]
1731fn word_index_and_mask<T: Idx>(elem: T) -> (usize, Word) {
1732    let elem = elem.index();
1733    let word_index = elem / WORD_BITS;
1734    let mask = 1 << (elem % WORD_BITS);
1735    (word_index, mask)
1736}
1737
1738#[inline]
1739fn chunk_index<T: Idx>(elem: T) -> usize {
1740    elem.index() / CHUNK_BITS
1741}
1742
1743#[inline]
1744fn chunk_word_index_and_mask<T: Idx>(elem: T) -> (usize, Word) {
1745    let chunk_elem = elem.index() % CHUNK_BITS;
1746    word_index_and_mask(chunk_elem)
1747}
1748
1749fn clear_excess_bits_in_final_word(domain_size: usize, words: &mut [Word]) {
1750    let num_bits_in_final_word = domain_size % WORD_BITS;
1751    if num_bits_in_final_word > 0 {
1752        let mask = (1 << num_bits_in_final_word) - 1;
1753        words[words.len() - 1] &= mask;
1754    }
1755}
1756
1757#[inline]
1758fn max_bit(word: Word) -> usize {
1759    WORD_BITS - 1 - word.leading_zeros() as usize
1760}
1761
/// Integral type that can serve as the backing storage of a `FiniteBitSet`.
pub trait FiniteBitSetTy:
    Sized
    + Copy
    + Clone
    + PartialEq
    + BitAnd<Output = Self>
    + BitAndAssign
    + BitOrAssign
    + Not<Output = Self>
    + Shl
{
    /// Number of bits the type can represent, e.g. 64 for `u64`.
    const DOMAIN_SIZE: u32;

    /// Bit pattern of a `FiniteBitSet` in which no bit is set.
    const EMPTY: Self;
    /// Bit pattern of a `FiniteBitSet` in which every bit is set.
    const FILLED: Self;

    /// The integer zero in this type.
    const ZERO: Self;
    /// The integer one in this type.
    const ONE: Self;

    /// Left shift that reports failure as `None` instead of panicking or
    /// wrapping.
    fn checked_shl(self, rhs: u32) -> Option<Self>;
    /// Right shift that reports failure as `None` instead of panicking or
    /// wrapping.
    fn checked_shr(self, rhs: u32) -> Option<Self>;
}
1792
1793impl FiniteBitSetTy for u32 {
1794    const DOMAIN_SIZE: u32 = 32;
1795
1796    const FILLED: Self = Self::MAX;
1797    const EMPTY: Self = Self::MIN;
1798
1799    const ONE: Self = 1u32;
1800    const ZERO: Self = 0u32;
1801
1802    fn checked_shl(self, rhs: u32) -> Option<Self> {
1803        self.checked_shl(rhs)
1804    }
1805
1806    fn checked_shr(self, rhs: u32) -> Option<Self> {
1807        self.checked_shr(rhs)
1808    }
1809}
1810
1811impl std::fmt::Debug for FiniteBitSet<u32> {
1812    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1813        write!(f, "{:032b}", self.0)
1814    }
1815}
1816
/// A fixed-sized bitset type represented by an integer type. Indices outside the range
/// representable by `T` are never stored: `set` and `clear` leave the set unchanged for
/// them, and `contains` returns `None`.
#[cfg_attr(feature = "nightly", derive(Decodable_NoContext, Encodable_NoContext))]
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct FiniteBitSet<T: FiniteBitSetTy>(pub T);
1822
1823impl<T: FiniteBitSetTy> FiniteBitSet<T> {
1824    /// Creates a new, empty bitset.
1825    pub fn new_empty() -> Self {
1826        Self(T::EMPTY)
1827    }
1828
1829    /// Sets the `index`th bit.
1830    pub fn set(&mut self, index: u32) {
1831        self.0 |= T::ONE.checked_shl(index).unwrap_or(T::ZERO);
1832    }
1833
1834    /// Unsets the `index`th bit.
1835    pub fn clear(&mut self, index: u32) {
1836        self.0 &= !T::ONE.checked_shl(index).unwrap_or(T::ZERO);
1837    }
1838
1839    /// Sets the `i`th to `j`th bits.
1840    pub fn set_range(&mut self, range: Range<u32>) {
1841        let bits = T::FILLED
1842            .checked_shl(range.end - range.start)
1843            .unwrap_or(T::ZERO)
1844            .not()
1845            .checked_shl(range.start)
1846            .unwrap_or(T::ZERO);
1847        self.0 |= bits;
1848    }
1849
1850    /// Is the set empty?
1851    pub fn is_empty(&self) -> bool {
1852        self.0 == T::EMPTY
1853    }
1854
1855    /// Returns the domain size of the bitset.
1856    pub fn within_domain(&self, index: u32) -> bool {
1857        index < T::DOMAIN_SIZE
1858    }
1859
1860    /// Returns if the `index`th bit is set.
1861    pub fn contains(&self, index: u32) -> Option<bool> {
1862        self.within_domain(index)
1863            .then(|| ((self.0.checked_shr(index).unwrap_or(T::ONE)) & T::ONE) == T::ONE)
1864    }
1865}
1866
1867impl<T: FiniteBitSetTy> Default for FiniteBitSet<T> {
1868    fn default() -> Self {
1869        Self::new_empty()
1870    }
1871}