diff options
Diffstat (limited to 'zjit/src/bitset.rs')
| -rw-r--r-- | zjit/src/bitset.rs | 225 |
1 files changed, 225 insertions, 0 deletions
diff --git a/zjit/src/bitset.rs b/zjit/src/bitset.rs new file mode 100644 index 0000000000..986d537d9b --- /dev/null +++ b/zjit/src/bitset.rs @@ -0,0 +1,225 @@ +//! Optimized bitset implementation. + +type Entry = u128; + +const ENTRY_NUM_BITS: usize = Entry::BITS as usize; + +// TODO(max): Make a `SmallBitSet` and `LargeBitSet` and switch between them if `num_bits` fits in +// `Entry`. +#[derive(Clone)] +pub struct BitSet<T: Into<usize> + Copy> { + entries: Vec<Entry>, + num_bits: usize, + phantom: std::marker::PhantomData<T>, +} + +impl<T: Into<usize> + Copy> BitSet<T> { + pub fn with_capacity(num_bits: usize) -> Self { + let num_entries = num_bits.div_ceil(ENTRY_NUM_BITS); + Self { entries: vec![0; num_entries], num_bits, phantom: Default::default() } + } + + /// Returns whether the value was newly inserted: true if the set did not originally contain + /// the bit, and false otherwise. + pub fn insert(&mut self, idx: T) -> bool { + debug_assert!(idx.into() < self.num_bits); + let entry_idx = idx.into() / ENTRY_NUM_BITS; + let bit_idx = idx.into() % ENTRY_NUM_BITS; + let newly_inserted = (self.entries[entry_idx] & (1 << bit_idx)) == 0; + self.entries[entry_idx] |= 1 << bit_idx; + newly_inserted + } + + /// Set all bits to 1. + pub fn insert_all(&mut self) { + for i in 0..self.entries.len() { + self.entries[i] = !0; + } + } + + /// Clear a bit. Returns whether the bit was previously set. + pub fn remove(&mut self, idx: T) -> bool { + debug_assert!(idx.into() < self.num_bits); + let entry_idx = idx.into() / ENTRY_NUM_BITS; + let bit_idx = idx.into() % ENTRY_NUM_BITS; + let was_set = (self.entries[entry_idx] & (1 << bit_idx)) != 0; + self.entries[entry_idx] &= !(1 << bit_idx); + was_set + } + + pub fn get(&self, idx: T) -> bool { + debug_assert!(idx.into() < self.num_bits); + let entry_idx = idx.into() / ENTRY_NUM_BITS; + let bit_idx = idx.into() % ENTRY_NUM_BITS; + (self.entries[entry_idx] & (1 << bit_idx)) != 0 + } + + /// Modify `self` to only have bits set if they are also set in `other`. Returns true if `self` + /// was modified, and false otherwise. + /// `self` and `other` must have the same number of bits. + pub fn intersect_with(&mut self, other: &Self) -> bool { + assert_eq!(self.num_bits, other.num_bits); + let mut changed = false; + for i in 0..self.entries.len() { + let before = self.entries[i]; + self.entries[i] &= other.entries[i]; + changed |= self.entries[i] != before; + } + changed + } + + /// Modify `self` to have bits set if they are set in either `self` or `other`. Returns true if `self` + /// was modified, and false otherwise. + /// `self` and `other` must have the same number of bits. + pub fn union_with(&mut self, other: &Self) -> bool { + assert_eq!(self.num_bits, other.num_bits); + let mut changed = false; + for i in 0..self.entries.len() { + let before = self.entries[i]; + self.entries[i] |= other.entries[i]; + changed |= self.entries[i] != before; + } + changed + } + + /// Modify `self` to remove bits that are set in `other`. Returns true if `self` + /// was modified, and false otherwise. + /// `self` and `other` must have the same number of bits. + pub fn difference_with(&mut self, other: &Self) -> bool { + assert_eq!(self.num_bits, other.num_bits); + let mut changed = false; + for i in 0..self.entries.len() { + let before = self.entries[i]; + self.entries[i] &= !other.entries[i]; + changed |= self.entries[i] != before; + } + changed + } + + /// Check if two BitSets are equal. + /// `self` and `other` must have the same number of bits. + pub fn equals(&self, other: &Self) -> bool { + assert_eq!(self.num_bits, other.num_bits); + self.entries == other.entries + } + + /// Returns an iterator over the indices of set bits. + /// Only iterates over bits that are set, not all possible indices. + pub fn iter_set_bits(&self) -> impl Iterator<Item = usize> + '_ { + self.entries.iter().enumerate().flat_map(move |(entry_idx, &entry)| { + let mut bits = entry; + std::iter::from_fn(move || { + if bits == 0 { + return None; + } + let bit_pos = bits.trailing_zeros() as usize; + bits &= bits - 1; // Clear the lowest set bit + Some(entry_idx * ENTRY_NUM_BITS + bit_pos) + }) + }).filter(move |&idx| idx < self.num_bits) + } +} + +#[cfg(test)] +mod tests { + use super::BitSet; + + #[test] + #[should_panic] + fn get_over_capacity_panics() { + let set = BitSet::with_capacity(0); + assert!(!set.get(0usize)); + } + + #[test] + fn with_capacity_defaults_to_zero() { + let set = BitSet::with_capacity(4); + assert!(!set.get(0usize)); + assert!(!set.get(1usize)); + assert!(!set.get(2usize)); + assert!(!set.get(3usize)); + } + + #[test] + fn insert_sets_bit() { + let mut set = BitSet::with_capacity(4); + assert!(set.insert(1usize)); + assert!(set.get(1usize)); + } + + #[test] + fn insert_with_set_bit_returns_false() { + let mut set = BitSet::with_capacity(4); + assert!(set.insert(1usize)); + assert!(!set.insert(1usize)); + } + + #[test] + fn insert_all_sets_all_bits() { + let mut set = BitSet::with_capacity(4); + set.insert_all(); + assert!(set.get(0usize)); + assert!(set.get(1usize)); + assert!(set.get(2usize)); + assert!(set.get(3usize)); + } + + #[test] + #[should_panic] + fn intersect_with_panics_with_different_num_bits() { + let mut left: BitSet<usize> = BitSet::with_capacity(3); + let right = BitSet::with_capacity(4); + left.intersect_with(&right); + } + #[test] + fn intersect_with_keeps_only_common_bits() { + let mut left = BitSet::with_capacity(3); + let mut right = BitSet::with_capacity(3); + left.insert(0usize); + left.insert(1usize); + right.insert(1usize); + right.insert(2usize); + left.intersect_with(&right); + assert!(!left.get(0usize)); + assert!(left.get(1usize)); + assert!(!left.get(2usize)); + } + + #[test] + fn test_iter_set_bits() { + let mut set: BitSet<usize> = BitSet::with_capacity(10); + set.insert(1usize); + set.insert(5usize); + set.insert(9usize); + + let set_bits: Vec<usize> = set.iter_set_bits().collect(); + assert_eq!(set_bits, vec![1, 5, 9]); + } + + #[test] + fn test_iter_set_bits_empty() { + let set: BitSet<usize> = BitSet::with_capacity(10); + let set_bits: Vec<usize> = set.iter_set_bits().collect(); + assert_eq!(set_bits, vec![]); + } + + #[test] + fn test_iter_set_bits_all() { + let mut set: BitSet<usize> = BitSet::with_capacity(5); + set.insert_all(); + let set_bits: Vec<usize> = set.iter_set_bits().collect(); + assert_eq!(set_bits, vec![0, 1, 2, 3, 4]); + } + + #[test] + fn test_iter_set_bits_large() { + let mut set: BitSet<usize> = BitSet::with_capacity(200); + set.insert(0usize); + set.insert(127usize); + set.insert(128usize); + set.insert(199usize); + + let set_bits: Vec<usize> = set.iter_set_bits().collect(); + assert_eq!(set_bits, vec![0, 127, 128, 199]); + } +} |
