1 | //===- BalancedPartitioningTest.cpp - BalancedPartitioning tests ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "llvm/Support/BalancedPartitioning.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Testing/Support/SupportHelpers.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <algorithm>
#include <iterator>
#include <optional>
#include <random>
#include <vector>
13 | |
14 | using testing::Each; |
15 | using testing::Field; |
16 | using testing::Not; |
17 | using testing::UnorderedElementsAre; |
18 | using testing::UnorderedElementsAreArray; |
19 | |
20 | namespace llvm { |
21 | |
22 | void PrintTo(const BPFunctionNode &Node, std::ostream *OS) { |
23 | raw_os_ostream ROS(*OS); |
24 | Node.dump(OS&: ROS); |
25 | } |
26 | |
27 | class BalancedPartitioningTest : public ::testing::Test { |
28 | protected: |
29 | BalancedPartitioningConfig Config; |
30 | BalancedPartitioning Bp; |
31 | BalancedPartitioningTest() : Bp(Config) {} |
32 | |
33 | static std::vector<BPFunctionNode::IDT> |
34 | getIds(std::vector<BPFunctionNode> Nodes) { |
35 | std::vector<BPFunctionNode::IDT> Ids; |
36 | for (auto &N : Nodes) |
37 | Ids.push_back(x: N.Id); |
38 | return Ids; |
39 | } |
40 | }; |
41 | |
42 | TEST_F(BalancedPartitioningTest, Basic) { |
43 | std::vector<BPFunctionNode> Nodes = { |
44 | BPFunctionNode(0, {1, 2}), BPFunctionNode(2, {3, 4}), |
45 | BPFunctionNode(1, {1, 2}), BPFunctionNode(3, {3, 4}), |
46 | BPFunctionNode(4, {4}), |
47 | }; |
48 | |
49 | Bp.run(Nodes); |
50 | |
51 | auto NodeIs = [](BPFunctionNode::IDT Id, std::optional<uint32_t> Bucket) { |
52 | return AllOf(matchers: Field(field_name: "Id" , field: &BPFunctionNode::Id, matcher: Id), |
53 | matchers: Field(field_name: "Bucket" , field: &BPFunctionNode::Bucket, matcher: Bucket)); |
54 | }; |
55 | |
56 | EXPECT_THAT(Nodes, |
57 | UnorderedElementsAre(NodeIs(0, 0), NodeIs(1, 1), NodeIs(2, 2), |
58 | NodeIs(3, 3), NodeIs(4, 4))); |
59 | } |
60 | |
61 | TEST_F(BalancedPartitioningTest, Large) { |
62 | const int ProblemSize = 1000; |
63 | std::vector<BPFunctionNode::UtilityNodeT> AllUNs; |
64 | for (int i = 0; i < ProblemSize; i++) |
65 | AllUNs.emplace_back(args&: i); |
66 | |
67 | std::mt19937 RNG; |
68 | std::vector<BPFunctionNode> Nodes; |
69 | for (int i = 0; i < ProblemSize; i++) { |
70 | std::vector<BPFunctionNode::UtilityNodeT> UNs; |
71 | int SampleSize = |
72 | std::uniform_int_distribution<int>(0, AllUNs.size() - 1)(RNG); |
73 | std::sample(first: AllUNs.begin(), last: AllUNs.end(), out: std::back_inserter(x&: UNs), |
74 | n: SampleSize, g&: RNG); |
75 | Nodes.emplace_back(args&: i, args&: UNs); |
76 | } |
77 | |
78 | auto OrigIds = getIds(Nodes); |
79 | |
80 | Bp.run(Nodes); |
81 | |
82 | EXPECT_THAT( |
83 | Nodes, Each(Not(Field("Bucket" , &BPFunctionNode::Bucket, std::nullopt)))); |
84 | EXPECT_THAT(getIds(Nodes), UnorderedElementsAreArray(OrigIds)); |
85 | } |
86 | |
87 | TEST_F(BalancedPartitioningTest, MoveGain) { |
88 | BalancedPartitioning::SignaturesT Signatures = { |
89 | {.LeftCount: 10, .RightCount: 10, .CachedGainLR: 10.f, .CachedGainRL: 0.f, .CachedGainIsValid: true}, // 0 |
90 | {.LeftCount: 10, .RightCount: 10, .CachedGainLR: 0.f, .CachedGainRL: 10.f, .CachedGainIsValid: true}, // 1 |
91 | {.LeftCount: 10, .RightCount: 10, .CachedGainLR: 0.f, .CachedGainRL: 20.f, .CachedGainIsValid: true}, // 2 |
92 | }; |
93 | EXPECT_FLOAT_EQ(Bp.moveGain(BPFunctionNode(0, {}), true, Signatures), 0.f); |
94 | EXPECT_FLOAT_EQ(Bp.moveGain(BPFunctionNode(0, {0, 1}), true, Signatures), |
95 | 10.f); |
96 | EXPECT_FLOAT_EQ(Bp.moveGain(BPFunctionNode(0, {1, 2}), false, Signatures), |
97 | 30.f); |
98 | } |
99 | |
100 | } // end namespace llvm |
101 | |