1//===- llvm/ADT/CoalescingBitVector.h - A coalescing bitvector --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file A bitvector that uses an IntervalMap to coalesce adjacent elements
10/// into intervals.
11///
12//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_ADT_COALESCINGBITVECTOR_H
15#define LLVM_ADT_COALESCINGBITVECTOR_H
16
17#include "llvm/ADT/IntervalMap.h"
18#include "llvm/ADT/SmallVector.h"
19#include "llvm/ADT/iterator_range.h"
20#include "llvm/Support/Debug.h"
21#include "llvm/Support/raw_ostream.h"
22
23#include <algorithm>
24#include <initializer_list>
25
26namespace llvm {
27
28/// A bitvector that, under the hood, relies on an IntervalMap to coalesce
29/// elements into intervals. Good for representing sets which predominantly
30/// contain contiguous ranges. Bad for representing sets with lots of gaps
31/// between elements.
32///
33/// Compared to SparseBitVector, CoalescingBitVector offers more predictable
34/// performance for non-sequential find() operations.
35///
36/// \tparam IndexT - The type of the index into the bitvector.
37template <typename IndexT> class CoalescingBitVector {
38 static_assert(std::is_unsigned<IndexT>::value,
39 "Index must be an unsigned integer.");
40
41 using ThisT = CoalescingBitVector<IndexT>;
42
43 /// An interval map for closed integer ranges. The mapped values are unused.
44 using MapT = IntervalMap<IndexT, char>;
45
46 using UnderlyingIterator = typename MapT::const_iterator;
47
48 using IntervalT = std::pair<IndexT, IndexT>;
49
50public:
51 using Allocator = typename MapT::Allocator;
52
53 /// Construct by passing in a CoalescingBitVector<IndexT>::Allocator
54 /// reference.
55 CoalescingBitVector(Allocator &Alloc)
56 : Alloc(&Alloc), Intervals(Alloc) {}
57
58 /// \name Copy/move constructors and assignment operators.
59 /// @{
60
61 CoalescingBitVector(const ThisT &Other)
62 : Alloc(Other.Alloc), Intervals(*Other.Alloc) {
63 set(Other);
64 }
65
66 ThisT &operator=(const ThisT &Other) {
67 clear();
68 set(Other);
69 return *this;
70 }
71
72 CoalescingBitVector(ThisT &&Other) = delete;
73 ThisT &operator=(ThisT &&Other) = delete;
74
75 /// @}
76
77 /// Clear all the bits.
78 void clear() { Intervals.clear(); }
79
80 /// Check whether no bits are set.
81 bool empty() const { return Intervals.empty(); }
82
83 /// Count the number of set bits.
84 unsigned count() const {
85 unsigned Bits = 0;
86 for (auto It = Intervals.begin(), End = Intervals.end(); It != End; ++It)
87 Bits += 1 + It.stop() - It.start();
88 return Bits;
89 }
90
91 /// Set the bit at \p Index.
92 ///
93 /// This method does /not/ support setting a bit that has already been set,
94 /// for efficiency reasons. If possible, restructure your code to not set the
95 /// same bit multiple times, or use \ref test_and_set.
96 void set(IndexT Index) {
97 assert(!test(Index) && "Setting already-set bits not supported/efficient, "
98 "IntervalMap will assert");
99 insert(Index, Index);
100 }
101
102 /// Set the bits set in \p Other.
103 ///
104 /// This method does /not/ support setting already-set bits, see \ref set
105 /// for the rationale. For a safe set union operation, use \ref operator|=.
106 void set(const ThisT &Other) {
107 for (auto It = Other.Intervals.begin(), End = Other.Intervals.end();
108 It != End; ++It)
109 insert(It.start(), It.stop());
110 }
111
112 /// Set the bits at \p Indices. Used for testing, primarily.
113 void set(std::initializer_list<IndexT> Indices) {
114 for (IndexT Index : Indices)
115 set(Index);
116 }
117
118 /// Check whether the bit at \p Index is set.
119 bool test(IndexT Index) const {
120 const auto It = Intervals.find(Index);
121 if (It == Intervals.end())
122 return false;
123 assert(It.stop() >= Index && "Interval must end after Index");
124 return It.start() <= Index;
125 }
126
127 /// Set the bit at \p Index. Supports setting an already-set bit.
128 void test_and_set(IndexT Index) {
129 if (!test(Index))
130 set(Index);
131 }
132
133 /// Reset the bit at \p Index. Supports resetting an already-unset bit.
134 void reset(IndexT Index) {
135 auto It = Intervals.find(Index);
136 if (It == Intervals.end())
137 return;
138
139 // Split the interval containing Index into up to two parts: one from
140 // [Start, Index-1] and another from [Index+1, Stop]. If Index is equal to
141 // either Start or Stop, we create one new interval. If Index is equal to
142 // both Start and Stop, we simply erase the existing interval.
143 IndexT Start = It.start();
144 if (Index < Start)
145 // The index was not set.
146 return;
147 IndexT Stop = It.stop();
148 assert(Index <= Stop && "Wrong interval for index");
149 It.erase();
150 if (Start < Index)
151 insert(Start, Index - 1);
152 if (Index < Stop)
153 insert(Index + 1, Stop);
154 }
155
156 /// Set union. If \p RHS is guaranteed to not overlap with this, \ref set may
157 /// be a faster alternative.
158 void operator|=(const ThisT &RHS) {
159 // Get the overlaps between the two interval maps.
160 SmallVector<IntervalT, 8> Overlaps;
161 getOverlaps(RHS, Overlaps);
162
163 // Insert the non-overlapping parts of all the intervals from RHS.
164 for (auto It = RHS.Intervals.begin(), End = RHS.Intervals.end();
165 It != End; ++It) {
166 IndexT Start = It.start();
167 IndexT Stop = It.stop();
168 SmallVector<IntervalT, 8> NonOverlappingParts;
169 getNonOverlappingParts(Start, Stop, Overlaps, NonOverlappingParts);
170 for (IntervalT AdditivePortion : NonOverlappingParts)
171 insert(AdditivePortion.first, AdditivePortion.second);
172 }
173 }
174
175 /// Set intersection.
176 void operator&=(const ThisT &RHS) {
177 // Get the overlaps between the two interval maps (i.e. the intersection).
178 SmallVector<IntervalT, 8> Overlaps;
179 getOverlaps(RHS, Overlaps);
180 // Rebuild the interval map, including only the overlaps.
181 clear();
182 for (IntervalT Overlap : Overlaps)
183 insert(Overlap.first, Overlap.second);
184 }
185
186 /// Reset all bits present in \p Other.
187 void intersectWithComplement(const ThisT &Other) {
188 SmallVector<IntervalT, 8> Overlaps;
189 if (!getOverlaps(Other, Overlaps)) {
190 // If there is no overlap with Other, the intersection is empty.
191 return;
192 }
193
194 // Delete the overlapping intervals. Split up intervals that only partially
195 // intersect an overlap.
196 for (IntervalT Overlap : Overlaps) {
197 IndexT OlapStart, OlapStop;
198 std::tie(OlapStart, OlapStop) = Overlap;
199
200 auto It = Intervals.find(OlapStart);
201 IndexT CurrStart = It.start();
202 IndexT CurrStop = It.stop();
203 assert(CurrStart <= OlapStart && OlapStop <= CurrStop &&
204 "Expected some intersection!");
205
206 // Split the overlap interval into up to two parts: one from [CurrStart,
207 // OlapStart-1] and another from [OlapStop+1, CurrStop]. If OlapStart is
208 // equal to CurrStart, the first split interval is unnecessary. Ditto for
209 // when OlapStop is equal to CurrStop, we omit the second split interval.
210 It.erase();
211 if (CurrStart < OlapStart)
212 insert(CurrStart, OlapStart - 1);
213 if (OlapStop < CurrStop)
214 insert(OlapStop + 1, CurrStop);
215 }
216 }
217
218 bool operator==(const ThisT &RHS) const {
219 // We cannot just use std::equal because it checks the dereferenced values
220 // of an iterator pair for equality, not the iterators themselves. In our
221 // case that results in comparison of the (unused) IntervalMap values.
222 auto ItL = Intervals.begin();
223 auto ItR = RHS.Intervals.begin();
224 while (ItL != Intervals.end() && ItR != RHS.Intervals.end() &&
225 ItL.start() == ItR.start() && ItL.stop() == ItR.stop()) {
226 ++ItL;
227 ++ItR;
228 }
229 return ItL == Intervals.end() && ItR == RHS.Intervals.end();
230 }
231
232 bool operator!=(const ThisT &RHS) const { return !operator==(RHS); }
233
234 class const_iterator {
235 friend class CoalescingBitVector;
236
237 public:
238 using iterator_category = std::forward_iterator_tag;
239 using value_type = IndexT;
240 using difference_type = std::ptrdiff_t;
241 using pointer = value_type *;
242 using reference = value_type &;
243
244 private:
245 // For performance reasons, make the offset at the end different than the
246 // one used in \ref begin, to optimize the common `It == end()` pattern.
247 static constexpr unsigned kIteratorAtTheEndOffset = ~0u;
248
249 UnderlyingIterator MapIterator;
250 unsigned OffsetIntoMapIterator = 0;
251
252 // Querying the start/stop of an IntervalMap iterator can be very expensive.
253 // Cache these values for performance reasons.
254 IndexT CachedStart = IndexT();
255 IndexT CachedStop = IndexT();
256
257 void setToEnd() {
258 OffsetIntoMapIterator = kIteratorAtTheEndOffset;
259 CachedStart = IndexT();
260 CachedStop = IndexT();
261 }
262
263 /// MapIterator has just changed, reset the cached state to point to the
264 /// start of the new underlying iterator.
265 void resetCache() {
266 if (MapIterator.valid()) {
267 OffsetIntoMapIterator = 0;
268 CachedStart = MapIterator.start();
269 CachedStop = MapIterator.stop();
270 } else {
271 setToEnd();
272 }
273 }
274
275 /// Advance the iterator to \p Index, if it is contained within the current
276 /// interval. The public-facing method which supports advancing past the
277 /// current interval is \ref advanceToLowerBound.
278 void advanceTo(IndexT Index) {
279 assert(Index <= CachedStop && "Cannot advance to OOB index");
280 if (Index < CachedStart)
281 // We're already past this index.
282 return;
283 OffsetIntoMapIterator = Index - CachedStart;
284 }
285
286 const_iterator(UnderlyingIterator MapIt) : MapIterator(MapIt) {
287 resetCache();
288 }
289
290 public:
291 const_iterator() { setToEnd(); }
292
293 bool operator==(const const_iterator &RHS) const {
294 // Do /not/ compare MapIterator for equality, as this is very expensive.
295 // The cached start/stop values make that check unnecessary.
296 return std::tie(OffsetIntoMapIterator, CachedStart, CachedStop) ==
297 std::tie(RHS.OffsetIntoMapIterator, RHS.CachedStart,
298 RHS.CachedStop);
299 }
300
301 bool operator!=(const const_iterator &RHS) const {
302 return !operator==(RHS);
303 }
304
305 IndexT operator*() const { return CachedStart + OffsetIntoMapIterator; }
306
307 const_iterator &operator++() { // Pre-increment (++It).
308 if (CachedStart + OffsetIntoMapIterator < CachedStop) {
309 // Keep going within the current interval.
310 ++OffsetIntoMapIterator;
311 } else {
312 // We reached the end of the current interval: advance.
313 ++MapIterator;
314 resetCache();
315 }
316 return *this;
317 }
318
319 const_iterator operator++(int) { // Post-increment (It++).
320 const_iterator tmp = *this;
321 operator++();
322 return tmp;
323 }
324
325 /// Advance the iterator to the first set bit AT, OR AFTER, \p Index. If
326 /// no such set bit exists, advance to end(). This is like std::lower_bound.
327 /// This is useful if \p Index is close to the current iterator position.
328 /// However, unlike \ref find(), this has worst-case O(n) performance.
329 void advanceToLowerBound(IndexT Index) {
330 if (OffsetIntoMapIterator == kIteratorAtTheEndOffset)
331 return;
332
333 // Advance to the first interval containing (or past) Index, or to end().
334 while (Index > CachedStop) {
335 ++MapIterator;
336 resetCache();
337 if (OffsetIntoMapIterator == kIteratorAtTheEndOffset)
338 return;
339 }
340
341 advanceTo(Index);
342 }
343 };
344
345 const_iterator begin() const { return const_iterator(Intervals.begin()); }
346
347 const_iterator end() const { return const_iterator(); }
348
349 /// Return an iterator pointing to the first set bit AT, OR AFTER, \p Index.
350 /// If no such set bit exists, return end(). This is like std::lower_bound.
351 /// This has worst-case logarithmic performance (roughly O(log(gaps between
352 /// contiguous ranges))).
353 const_iterator find(IndexT Index) const {
354 auto UnderlyingIt = Intervals.find(Index);
355 if (UnderlyingIt == Intervals.end())
356 return end();
357 auto It = const_iterator(UnderlyingIt);
358 It.advanceTo(Index);
359 return It;
360 }
361
362 /// Return a range iterator which iterates over all of the set bits in the
363 /// half-open range [Start, End).
364 iterator_range<const_iterator> half_open_range(IndexT Start,
365 IndexT End) const {
366 assert(Start < End && "Not a valid range");
367 auto StartIt = find(Start);
368 if (StartIt == end() || *StartIt >= End)
369 return {end(), end()};
370 auto EndIt = StartIt;
371 EndIt.advanceToLowerBound(End);
372 return {StartIt, EndIt};
373 }
374
375 void print(raw_ostream &OS) const {
376 OS << "{";
377 for (auto It = Intervals.begin(), End = Intervals.end(); It != End;
378 ++It) {
379 OS << "[" << It.start();
380 if (It.start() != It.stop())
381 OS << ", " << It.stop();
382 OS << "]";
383 }
384 OS << "}";
385 }
386
387#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
388 LLVM_DUMP_METHOD void dump() const {
389 // LLDB swallows the first line of output after callling dump(). Add
390 // newlines before/after the braces to work around this.
391 dbgs() << "\n";
392 print(dbgs());
393 dbgs() << "\n";
394 }
395#endif
396
397private:
398 void insert(IndexT Start, IndexT End) { Intervals.insert(Start, End, 0); }
399
400 /// Record the overlaps between \p this and \p Other in \p Overlaps. Return
401 /// true if there is any overlap.
402 bool getOverlaps(const ThisT &Other,
403 SmallVectorImpl<IntervalT> &Overlaps) const {
404 for (IntervalMapOverlaps<MapT, MapT> I(Intervals, Other.Intervals);
405 I.valid(); ++I)
406 Overlaps.emplace_back(I.start(), I.stop());
407 assert(llvm::is_sorted(Overlaps,
408 [](IntervalT LHS, IntervalT RHS) {
409 return LHS.second < RHS.first;
410 }) &&
411 "Overlaps must be sorted");
412 return !Overlaps.empty();
413 }
414
415 /// Given the set of overlaps between this and some other bitvector, and an
416 /// interval [Start, Stop] from that bitvector, determine the portions of the
417 /// interval which do not overlap with this.
418 void getNonOverlappingParts(IndexT Start, IndexT Stop,
419 const SmallVectorImpl<IntervalT> &Overlaps,
420 SmallVectorImpl<IntervalT> &NonOverlappingParts) {
421 IndexT NextUncoveredBit = Start;
422 for (IntervalT Overlap : Overlaps) {
423 IndexT OlapStart, OlapStop;
424 std::tie(OlapStart, OlapStop) = Overlap;
425
426 // [Start;Stop] and [OlapStart;OlapStop] overlap iff OlapStart <= Stop
427 // and Start <= OlapStop.
428 bool DoesOverlap = OlapStart <= Stop && Start <= OlapStop;
429 if (!DoesOverlap)
430 continue;
431
432 // Cover the range [NextUncoveredBit, OlapStart). This puts the start of
433 // the next uncovered range at OlapStop+1.
434 if (NextUncoveredBit < OlapStart)
435 NonOverlappingParts.emplace_back(NextUncoveredBit, OlapStart - 1);
436 NextUncoveredBit = OlapStop + 1;
437 if (NextUncoveredBit > Stop)
438 break;
439 }
440 if (NextUncoveredBit <= Stop)
441 NonOverlappingParts.emplace_back(NextUncoveredBit, Stop);
442 }
443
444 Allocator *Alloc;
445 MapT Intervals;
446};
447
448} // namespace llvm
449
450#endif // LLVM_ADT_COALESCINGBITVECTOR_H
451