1//===-- Exhaustive test template for math functions -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "src/__support/CPP/type_traits.h"
10#include "src/__support/FPUtil/FPBits.h"
11#include "test/UnitTest/FPMatcher.h"
12#include "test/UnitTest/Test.h"
13#include "utils/MPFRWrapper/MPFRUtils.h"
14
15#include <atomic>
16#include <functional>
17#include <iostream>
18#include <mutex>
19#include <sstream>
20#include <thread>
21#include <vector>
22
// To test exhaustively for inputs in the range [start, stop) in parallel:
// 1. Define a Checker class that publicly and virtually inherits from
//    LIBC_NAMESPACE::testing::Test (like UnaryOpChecker below) and provides:
//    - FloatType: the floating point type under test.
//    - FPBits: fputil::FPBits<FloatType>.
//    - StorageType: the integer type holding the bits of FloatType
//      (FPBits::StorageType).
//    - uint64_t check(start, stop, rounding_mode): a method that tests the
//      inputs in the given range under the given rounding mode and returns
//      the number of failures.
// 2. Use the LlvmLibcExhaustiveMathTest<Checker> class.
// 3. Call: test_full_range(start, stop, rounding)
//       or test_full_range_all_roundings(start, stop).
//    The number of worker threads is not a parameter; it is derived from
//    std::thread::hardware_concurrency().
// * For single input single output math functions, use the convenience
//   template: LlvmLibcUnaryOpExhaustiveMathTest<FloatType, Op, Func>.
36namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
37
// Function type of a single-argument math function that maps a value of type
// ValueT to a value of the same type (spelled with a trailing return type).
template <typename ValueT> using UnaryOp = auto(ValueT) -> ValueT;
39
40template <typename T, mpfr::Operation Op, UnaryOp<T> Func>
41struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
42 using FloatType = T;
43 using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>;
44 using StorageType = typename FPBits::StorageType;
45
46 static constexpr UnaryOp<FloatType> *FUNC = Func;
47
48 // Check in a range, return the number of failures.
49 uint64_t check(StorageType start, StorageType stop,
50 mpfr::RoundingMode rounding) {
51 mpfr::ForceRoundingMode r(rounding);
52 if (!r.success)
53 return (stop > start);
54 StorageType bits = start;
55 uint64_t failed = 0;
56 do {
57 FPBits xbits(bits);
58 FloatType x = xbits.get_val();
59 bool correct =
60 TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, x, FUNC(x), 0.5, rounding);
61 failed += (!correct);
62 // Uncomment to print out failed values.
63 // if (!correct) {
64 // TEST_MPFR_MATCH(Op::Operation, x, Op::func(x), 0.5, rounding);
65 // }
66 } while (bits++ < stop);
67 return failed;
68 }
69};
70
71// Checker class needs inherit from LIBC_NAMESPACE::testing::Test and provide
72// StorageType and check method.
73template <typename Checker>
74struct LlvmLibcExhaustiveMathTest
75 : public virtual LIBC_NAMESPACE::testing::Test,
76 public Checker {
77 using FloatType = typename Checker::FloatType;
78 using FPBits = typename Checker::FPBits;
79 using StorageType = typename Checker::StorageType;
80
81 static constexpr StorageType INCREMENT = (1 << 20);
82
83 // Break [start, stop) into `nthreads` subintervals and apply *check to each
84 // subinterval in parallel.
85 void test_full_range(StorageType start, StorageType stop,
86 mpfr::RoundingMode rounding) {
87 int n_threads = std::thread::hardware_concurrency();
88 std::vector<std::thread> thread_list;
89 std::mutex mx_cur_val;
90 int current_percent = -1;
91 StorageType current_value = start;
92 std::atomic<uint64_t> failed(0);
93
94 for (int i = 0; i < n_threads; ++i) {
95 thread_list.emplace_back([&, this]() {
96 while (true) {
97 StorageType range_begin, range_end;
98 int new_percent = -1;
99 {
100 std::lock_guard<std::mutex> lock(mx_cur_val);
101 if (current_value == stop)
102 return;
103
104 range_begin = current_value;
105 if (stop >= INCREMENT && stop - INCREMENT >= current_value) {
106 range_end = current_value + INCREMENT;
107 } else {
108 range_end = stop;
109 }
110 current_value = range_end;
111 int pc = 100.0 * (range_end - start) / (stop - start);
112 if (current_percent != pc) {
113 new_percent = pc;
114 current_percent = pc;
115 }
116 }
117 if (new_percent >= 0) {
118 std::stringstream msg;
119 msg << new_percent << "% is in process \r";
120 std::cout << msg.str() << std::flush;
121 }
122
123 uint64_t failed_in_range =
124 Checker::check(range_begin, range_end, rounding);
125 if (failed_in_range > 0) {
126 std::stringstream msg;
127 msg << "Test failed for " << std::dec << failed_in_range
128 << " inputs in range: " << range_begin << " to " << range_end
129 << " [0x" << std::hex << range_begin << ", 0x" << range_end
130 << "), [" << std::hexfloat << FPBits(range_begin).get_val()
131 << ", " << FPBits(range_end).get_val() << ")\n";
132 std::cerr << msg.str() << std::flush;
133
134 failed.fetch_add(i: failed_in_range);
135 }
136 }
137 });
138 }
139
140 for (auto &thread : thread_list) {
141 if (thread.joinable()) {
142 thread.join();
143 }
144 }
145
146 std::cout << std::endl;
147 std::cout << "Test " << ((failed > 0) ? "FAILED" : "PASSED") << std::endl;
148 ASSERT_EQ(failed.load(), uint64_t(0));
149 }
150
151 void test_full_range_all_roundings(StorageType start, StorageType stop) {
152 std::cout << "-- Testing for FE_TONEAREST in range [0x" << std::hex << start
153 << ", 0x" << stop << ") --" << std::dec << std::endl;
154 test_full_range(start, stop, rounding: mpfr::RoundingMode::Nearest);
155
156 std::cout << "-- Testing for FE_UPWARD in range [0x" << std::hex << start
157 << ", 0x" << stop << ") --" << std::dec << std::endl;
158 test_full_range(start, stop, rounding: mpfr::RoundingMode::Upward);
159
160 std::cout << "-- Testing for FE_DOWNWARD in range [0x" << std::hex << start
161 << ", 0x" << stop << ") --" << std::dec << std::endl;
162 test_full_range(start, stop, rounding: mpfr::RoundingMode::Downward);
163
164 std::cout << "-- Testing for FE_TOWARDZERO in range [0x" << std::hex
165 << start << ", 0x" << stop << ") --" << std::dec << std::endl;
166 test_full_range(start, stop, rounding: mpfr::RoundingMode::TowardZero);
167 };
168};
169
170template <typename FloatType, mpfr::Operation Op, UnaryOp<FloatType> Func>
171using LlvmLibcUnaryOpExhaustiveMathTest =
172 LlvmLibcExhaustiveMathTest<UnaryOpChecker<FloatType, Op, Func>>;
173

// Source: libc/test/src/math/exhaustive/exhaustive_test.h