1//===- BalancedPartitioningTest.cpp - BalancedPartitioning tests ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/Support/BalancedPartitioning.h"
10#include "llvm/Testing/Support/SupportHelpers.h"
11#include "gmock/gmock.h"
12#include "gtest/gtest.h"
13
14using testing::Each;
15using testing::Field;
16using testing::Not;
17using testing::UnorderedElementsAre;
18using testing::UnorderedElementsAreArray;
19
20namespace llvm {
21
22void PrintTo(const BPFunctionNode &Node, std::ostream *OS) {
23 raw_os_ostream ROS(*OS);
24 Node.dump(OS&: ROS);
25}
26
27class BalancedPartitioningTest : public ::testing::Test {
28protected:
29 BalancedPartitioningConfig Config;
30 BalancedPartitioning Bp;
31 BalancedPartitioningTest() : Bp(Config) {}
32
33 static std::vector<BPFunctionNode::IDT>
34 getIds(std::vector<BPFunctionNode> Nodes) {
35 std::vector<BPFunctionNode::IDT> Ids;
36 for (auto &N : Nodes)
37 Ids.push_back(x: N.Id);
38 return Ids;
39 }
40};
41
42TEST_F(BalancedPartitioningTest, Basic) {
43 std::vector<BPFunctionNode> Nodes = {
44 BPFunctionNode(0, {1, 2}), BPFunctionNode(2, {3, 4}),
45 BPFunctionNode(1, {1, 2}), BPFunctionNode(3, {3, 4}),
46 BPFunctionNode(4, {4}),
47 };
48
49 Bp.run(Nodes);
50
51 auto NodeIs = [](BPFunctionNode::IDT Id, std::optional<uint32_t> Bucket) {
52 return AllOf(matchers: Field(field_name: "Id", field: &BPFunctionNode::Id, matcher: Id),
53 matchers: Field(field_name: "Bucket", field: &BPFunctionNode::Bucket, matcher: Bucket));
54 };
55
56 EXPECT_THAT(Nodes,
57 UnorderedElementsAre(NodeIs(0, 0), NodeIs(1, 1), NodeIs(2, 2),
58 NodeIs(3, 3), NodeIs(4, 4)));
59}
60
61TEST_F(BalancedPartitioningTest, Large) {
62 const int ProblemSize = 1000;
63 std::vector<BPFunctionNode::UtilityNodeT> AllUNs;
64 for (int i = 0; i < ProblemSize; i++)
65 AllUNs.emplace_back(args&: i);
66
67 std::mt19937 RNG;
68 std::vector<BPFunctionNode> Nodes;
69 for (int i = 0; i < ProblemSize; i++) {
70 std::vector<BPFunctionNode::UtilityNodeT> UNs;
71 int SampleSize =
72 std::uniform_int_distribution<int>(0, AllUNs.size() - 1)(RNG);
73 std::sample(first: AllUNs.begin(), last: AllUNs.end(), out: std::back_inserter(x&: UNs),
74 n: SampleSize, g&: RNG);
75 Nodes.emplace_back(args&: i, args&: UNs);
76 }
77
78 auto OrigIds = getIds(Nodes);
79
80 Bp.run(Nodes);
81
82 EXPECT_THAT(
83 Nodes, Each(Not(Field("Bucket", &BPFunctionNode::Bucket, std::nullopt))));
84 EXPECT_THAT(getIds(Nodes), UnorderedElementsAreArray(OrigIds));
85}
86
87TEST_F(BalancedPartitioningTest, MoveGain) {
88 BalancedPartitioning::SignaturesT Signatures = {
89 {.LeftCount: 10, .RightCount: 10, .CachedGainLR: 10.f, .CachedGainRL: 0.f, .CachedGainIsValid: true}, // 0
90 {.LeftCount: 10, .RightCount: 10, .CachedGainLR: 0.f, .CachedGainRL: 10.f, .CachedGainIsValid: true}, // 1
91 {.LeftCount: 10, .RightCount: 10, .CachedGainLR: 0.f, .CachedGainRL: 20.f, .CachedGainIsValid: true}, // 2
92 };
93 EXPECT_FLOAT_EQ(Bp.moveGain(BPFunctionNode(0, {}), true, Signatures), 0.f);
94 EXPECT_FLOAT_EQ(Bp.moveGain(BPFunctionNode(0, {0, 1}), true, Signatures),
95 10.f);
96 EXPECT_FLOAT_EQ(Bp.moveGain(BPFunctionNode(0, {1, 2}), false, Signatures),
97 30.f);
98}
99
100} // end namespace llvm
101

source code of llvm/unittests/Support/BalancedPartitioningTest.cpp