1 | //===- llvm/unittest/Support/DivisionByConstantTest.cpp -------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "llvm/ADT/APInt.h" |
10 | #include "llvm/Support/DivisionByConstantInfo.h" |
11 | #include "gtest/gtest.h" |
12 | |
13 | using namespace llvm; |
14 | |
15 | namespace { |
16 | |
17 | template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) { |
18 | APInt N(Bits, 0); |
19 | do { |
20 | TestFn(N); |
21 | } while (++N != 0); |
22 | } |
23 | |
24 | APInt MULHS(APInt X, APInt Y) { |
25 | unsigned Bits = X.getBitWidth(); |
26 | unsigned WideBits = 2 * Bits; |
27 | return (X.sext(width: WideBits) * Y.sext(width: WideBits)).lshr(shiftAmt: Bits).trunc(width: Bits); |
28 | } |
29 | |
30 | APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor, |
31 | SignedDivisionByConstantInfo Magics) { |
32 | unsigned Bits = Numerator.getBitWidth(); |
33 | |
34 | APInt Factor(Bits, 0); |
35 | APInt ShiftMask(Bits, -1); |
36 | if (Divisor.isOne() || Divisor.isAllOnes()) { |
37 | // If d is +1/-1, we just multiply the numerator by +1/-1. |
38 | Factor = Divisor.getSExtValue(); |
39 | Magics.Magic = 0; |
40 | Magics.ShiftAmount = 0; |
41 | ShiftMask = 0; |
42 | } else if (Divisor.isStrictlyPositive() && Magics.Magic.isNegative()) { |
43 | // If d > 0 and m < 0, add the numerator. |
44 | Factor = 1; |
45 | } else if (Divisor.isNegative() && Magics.Magic.isStrictlyPositive()) { |
46 | // If d < 0 and m > 0, subtract the numerator. |
47 | Factor = -1; |
48 | } |
49 | |
50 | // Multiply the numerator by the magic value. |
51 | APInt Q = MULHS(X: Numerator, Y: Magics.Magic); |
52 | |
53 | // (Optionally) Add/subtract the numerator using Factor. |
54 | Factor = Numerator * Factor; |
55 | Q = Q + Factor; |
56 | |
57 | // Shift right algebraic by shift value. |
58 | Q = Q.ashr(ShiftAmt: Magics.ShiftAmount); |
59 | |
60 | // Extract the sign bit, mask it and add it to the quotient. |
61 | unsigned SignShift = Bits - 1; |
62 | APInt T = Q.lshr(shiftAmt: SignShift); |
63 | T = T & ShiftMask; |
64 | return Q + T; |
65 | } |
66 | |
67 | TEST(SignedDivisionByConstantTest, Test) { |
68 | for (unsigned Bits = 1; Bits <= 32; ++Bits) { |
69 | if (Bits < 3) |
70 | continue; // Not supported by `SignedDivisionByConstantInfo::get()`. |
71 | if (Bits > 12) |
72 | continue; // Unreasonably slow. |
73 | EnumerateAPInts(Bits, TestFn: [Bits](const APInt &Divisor) { |
74 | if (Divisor.isZero()) |
75 | return; // Division by zero is undefined behavior. |
76 | SignedDivisionByConstantInfo Magics; |
77 | if (!(Divisor.isOne() || Divisor.isAllOnes())) |
78 | Magics = SignedDivisionByConstantInfo::get(D: Divisor); |
79 | EnumerateAPInts(Bits, TestFn: [Divisor, Magics, Bits](const APInt &Numerator) { |
80 | if (Numerator.isMinSignedValue() && Divisor.isAllOnes()) |
81 | return; // Overflow is undefined behavior. |
82 | APInt NativeResult = Numerator.sdiv(RHS: Divisor); |
83 | APInt MagicResult = SignedDivideUsingMagic(Numerator, Divisor, Magics); |
84 | ASSERT_EQ(MagicResult, NativeResult) |
85 | << " ... given the operation: srem i" << Bits << " " << Numerator |
86 | << ", " << Divisor; |
87 | }); |
88 | }); |
89 | } |
90 | } |
91 | |
92 | APInt MULHU(APInt X, APInt Y) { |
93 | unsigned Bits = X.getBitWidth(); |
94 | unsigned WideBits = 2 * Bits; |
95 | return (X.zext(width: WideBits) * Y.zext(width: WideBits)).lshr(shiftAmt: Bits).trunc(width: Bits); |
96 | } |
97 | |
98 | APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor, |
99 | bool LZOptimization, |
100 | bool AllowEvenDivisorOptimization, bool ForceNPQ, |
101 | UnsignedDivisionByConstantInfo Magics) { |
102 | assert(!Divisor.isOne() && "Division by 1 is not supported using Magic." ); |
103 | |
104 | unsigned Bits = Numerator.getBitWidth(); |
105 | |
106 | if (LZOptimization) { |
107 | unsigned LeadingZeros = Numerator.countl_zero(); |
108 | // Clip to the number of leading zeros in the divisor. |
109 | LeadingZeros = std::min(a: LeadingZeros, b: Divisor.countl_zero()); |
110 | if (LeadingZeros > 0) { |
111 | Magics = UnsignedDivisionByConstantInfo::get( |
112 | D: Divisor, LeadingZeros, AllowEvenDivisorOptimization); |
113 | assert(!Magics.IsAdd && "Should use cheap fixup now" ); |
114 | } |
115 | } |
116 | |
117 | assert(Magics.PreShift < Divisor.getBitWidth() && |
118 | "We shouldn't generate an undefined shift!" ); |
119 | assert(Magics.PostShift < Divisor.getBitWidth() && |
120 | "We shouldn't generate an undefined shift!" ); |
121 | assert((!Magics.IsAdd || Magics.PreShift == 0) && "Unexpected pre-shift" ); |
122 | unsigned PreShift = Magics.PreShift; |
123 | unsigned PostShift = Magics.PostShift; |
124 | bool UseNPQ = Magics.IsAdd; |
125 | |
126 | APInt NPQFactor = |
127 | UseNPQ ? APInt::getSignedMinValue(numBits: Bits) : APInt::getZero(numBits: Bits); |
128 | |
129 | APInt Q = Numerator.lshr(shiftAmt: PreShift); |
130 | |
131 | // Multiply the numerator by the magic value. |
132 | Q = MULHU(X: Q, Y: Magics.Magic); |
133 | |
134 | if (UseNPQ || ForceNPQ) { |
135 | APInt NPQ = Numerator - Q; |
136 | |
137 | // For vectors we might have a mix of non-NPQ/NPQ paths, so use |
138 | // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero. |
139 | APInt NPQ_Scalar = NPQ.lshr(shiftAmt: 1); |
140 | (void)NPQ_Scalar; |
141 | NPQ = MULHU(X: NPQ, Y: NPQFactor); |
142 | assert(!UseNPQ || NPQ == NPQ_Scalar); |
143 | |
144 | Q = NPQ + Q; |
145 | } |
146 | |
147 | Q = Q.lshr(shiftAmt: PostShift); |
148 | |
149 | return Q; |
150 | } |
151 | |
152 | TEST(UnsignedDivisionByConstantTest, Test) { |
153 | for (unsigned Bits = 1; Bits <= 32; ++Bits) { |
154 | if (Bits < 2) |
155 | continue; // Not supported by `UnsignedDivisionByConstantInfo::get()`. |
156 | if (Bits > 10) |
157 | continue; // Unreasonably slow. |
158 | EnumerateAPInts(Bits, TestFn: [Bits](const APInt &Divisor) { |
159 | if (Divisor.isZero()) |
160 | return; // Division by zero is undefined behavior. |
161 | if (Divisor.isOne()) |
162 | return; // Division by one is the numerator. |
163 | |
164 | const UnsignedDivisionByConstantInfo Magics = |
165 | UnsignedDivisionByConstantInfo::get(D: Divisor); |
166 | EnumerateAPInts(Bits, TestFn: [Divisor, Magics, Bits](const APInt &Numerator) { |
167 | APInt NativeResult = Numerator.udiv(RHS: Divisor); |
168 | for (bool LZOptimization : {true, false}) { |
169 | for (bool AllowEvenDivisorOptimization : {true, false}) { |
170 | for (bool ForceNPQ : {false, true}) { |
171 | APInt MagicResult = UnsignedDivideUsingMagic( |
172 | Numerator, Divisor, LZOptimization, |
173 | AllowEvenDivisorOptimization, ForceNPQ, Magics); |
174 | ASSERT_EQ(MagicResult, NativeResult) |
175 | << " ... given the operation: urem i" << Bits << " " |
176 | << Numerator << ", " << Divisor |
177 | << " (allow LZ optimization = " |
178 | << LZOptimization << ", allow even divisior optimization = " |
179 | << AllowEvenDivisorOptimization << ", force NPQ = " |
180 | << ForceNPQ << ")" ; |
181 | } |
182 | } |
183 | } |
184 | }); |
185 | }); |
186 | } |
187 | } |
188 | |
189 | } // end anonymous namespace |
190 | |