1//===--- TypeMismatchCheck.cpp - clang-tidy--------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TypeMismatchCheck.h"
10#include "clang/Lex/Lexer.h"
11#include "clang/Tooling/FixIt.h"
12#include "llvm/ADT/StringSet.h"
13#include <map>
14
15using namespace clang::ast_matchers;
16
17namespace clang::tidy::mpi {
18
19/// Check if a BuiltinType::Kind matches the MPI datatype.
20///
21/// \param MultiMap datatype group
22/// \param Kind buffer type kind
23/// \param MPIDatatype name of the MPI datatype
24///
25/// \returns true if the pair matches
26static bool
27isMPITypeMatching(const std::multimap<BuiltinType::Kind, StringRef> &MultiMap,
28 const BuiltinType::Kind Kind, StringRef MPIDatatype) {
29 auto ItPair = MultiMap.equal_range(x: Kind);
30 while (ItPair.first != ItPair.second) {
31 if (ItPair.first->second == MPIDatatype)
32 return true;
33 ++ItPair.first;
34 }
35 return false;
36}
37
38/// Check if the MPI datatype is a standard type.
39///
40/// \param MPIDatatype name of the MPI datatype
41///
42/// \returns true if the type is a standard type
43static bool isStandardMPIDatatype(StringRef MPIDatatype) {
44 static llvm::StringSet<> AllTypes = {"MPI_C_BOOL",
45 "MPI_CHAR",
46 "MPI_SIGNED_CHAR",
47 "MPI_UNSIGNED_CHAR",
48 "MPI_WCHAR",
49 "MPI_INT",
50 "MPI_LONG",
51 "MPI_SHORT",
52 "MPI_LONG_LONG",
53 "MPI_LONG_LONG_INT",
54 "MPI_UNSIGNED",
55 "MPI_UNSIGNED_SHORT",
56 "MPI_UNSIGNED_LONG",
57 "MPI_UNSIGNED_LONG_LONG",
58 "MPI_FLOAT",
59 "MPI_DOUBLE",
60 "MPI_LONG_DOUBLE",
61 "MPI_C_COMPLEX",
62 "MPI_C_FLOAT_COMPLEX",
63 "MPI_C_DOUBLE_COMPLEX",
64 "MPI_C_LONG_DOUBLE_COMPLEX",
65 "MPI_INT8_T",
66 "MPI_INT16_T",
67 "MPI_INT32_T",
68 "MPI_INT64_T",
69 "MPI_UINT8_T",
70 "MPI_UINT16_T",
71 "MPI_UINT32_T",
72 "MPI_UINT64_T",
73 "MPI_CXX_BOOL",
74 "MPI_CXX_FLOAT_COMPLEX",
75 "MPI_CXX_DOUBLE_COMPLEX",
76 "MPI_CXX_LONG_DOUBLE_COMPLEX"};
77
78 return AllTypes.contains(key: MPIDatatype);
79}
80
81/// Check if a BuiltinType matches the MPI datatype.
82///
83/// \param Builtin the builtin type
84/// \param BufferTypeName buffer type name, gets assigned
85/// \param MPIDatatype name of the MPI datatype
86/// \param LO language options
87///
88/// \returns true if the type matches
89static bool isBuiltinTypeMatching(const BuiltinType *Builtin,
90 std::string &BufferTypeName,
91 StringRef MPIDatatype,
92 const LangOptions &LO) {
93 static std::multimap<BuiltinType::Kind, StringRef> BuiltinMatches = {
94 // On some systems like PPC or ARM, 'char' is unsigned by default which is
95 // why distinct signedness for the buffer and MPI type is tolerated.
96 {BuiltinType::SChar, "MPI_CHAR"},
97 {BuiltinType::SChar, "MPI_SIGNED_CHAR"},
98 {BuiltinType::SChar, "MPI_UNSIGNED_CHAR"},
99 {BuiltinType::Char_S, "MPI_CHAR"},
100 {BuiltinType::Char_S, "MPI_SIGNED_CHAR"},
101 {BuiltinType::Char_S, "MPI_UNSIGNED_CHAR"},
102 {BuiltinType::UChar, "MPI_CHAR"},
103 {BuiltinType::UChar, "MPI_SIGNED_CHAR"},
104 {BuiltinType::UChar, "MPI_UNSIGNED_CHAR"},
105 {BuiltinType::Char_U, "MPI_CHAR"},
106 {BuiltinType::Char_U, "MPI_SIGNED_CHAR"},
107 {BuiltinType::Char_U, "MPI_UNSIGNED_CHAR"},
108 {BuiltinType::WChar_S, "MPI_WCHAR"},
109 {BuiltinType::WChar_U, "MPI_WCHAR"},
110 {BuiltinType::Bool, "MPI_C_BOOL"},
111 {BuiltinType::Bool, "MPI_CXX_BOOL"},
112 {BuiltinType::Short, "MPI_SHORT"},
113 {BuiltinType::Int, "MPI_INT"},
114 {BuiltinType::Long, "MPI_LONG"},
115 {BuiltinType::LongLong, "MPI_LONG_LONG"},
116 {BuiltinType::LongLong, "MPI_LONG_LONG_INT"},
117 {BuiltinType::UShort, "MPI_UNSIGNED_SHORT"},
118 {BuiltinType::UInt, "MPI_UNSIGNED"},
119 {BuiltinType::ULong, "MPI_UNSIGNED_LONG"},
120 {BuiltinType::ULongLong, "MPI_UNSIGNED_LONG_LONG"},
121 {BuiltinType::Float, "MPI_FLOAT"},
122 {BuiltinType::Double, "MPI_DOUBLE"},
123 {BuiltinType::LongDouble, "MPI_LONG_DOUBLE"}};
124
125 if (!isMPITypeMatching(MultiMap: BuiltinMatches, Kind: Builtin->getKind(), MPIDatatype)) {
126 BufferTypeName = std::string(Builtin->getName(Policy: LO));
127 return false;
128 }
129
130 return true;
131}
132
133/// Check if a complex float/double/long double buffer type matches
134/// the MPI datatype.
135///
136/// \param Complex buffer type
137/// \param BufferTypeName buffer type name, gets assigned
138/// \param MPIDatatype name of the MPI datatype
139/// \param LO language options
140///
141/// \returns true if the type matches or the buffer type is unknown
142static bool isCComplexTypeMatching(const ComplexType *const Complex,
143 std::string &BufferTypeName,
144 StringRef MPIDatatype,
145 const LangOptions &LO) {
146 static std::multimap<BuiltinType::Kind, StringRef> ComplexCMatches = {
147 {BuiltinType::Float, "MPI_C_COMPLEX"},
148 {BuiltinType::Float, "MPI_C_FLOAT_COMPLEX"},
149 {BuiltinType::Double, "MPI_C_DOUBLE_COMPLEX"},
150 {BuiltinType::LongDouble, "MPI_C_LONG_DOUBLE_COMPLEX"}};
151
152 const auto *Builtin =
153 Complex->getElementType().getTypePtr()->getAs<BuiltinType>();
154
155 if (Builtin &&
156 !isMPITypeMatching(MultiMap: ComplexCMatches, Kind: Builtin->getKind(), MPIDatatype)) {
157 BufferTypeName = (llvm::Twine(Builtin->getName(Policy: LO)) + " _Complex").str();
158 return false;
159 }
160 return true;
161}
162
163/// Check if a complex<float/double/long double> templated buffer type matches
164/// the MPI datatype.
165///
166/// \param Template buffer type
167/// \param BufferTypeName buffer type name, gets assigned
168/// \param MPIDatatype name of the MPI datatype
169/// \param LO language options
170///
171/// \returns true if the type matches or the buffer type is unknown
172static bool
173isCXXComplexTypeMatching(const TemplateSpecializationType *const Template,
174 std::string &BufferTypeName, StringRef MPIDatatype,
175 const LangOptions &LO) {
176 static std::multimap<BuiltinType::Kind, StringRef> ComplexCXXMatches = {
177 {BuiltinType::Float, "MPI_CXX_FLOAT_COMPLEX"},
178 {BuiltinType::Double, "MPI_CXX_DOUBLE_COMPLEX"},
179 {BuiltinType::LongDouble, "MPI_CXX_LONG_DOUBLE_COMPLEX"}};
180
181 if (Template->getAsCXXRecordDecl()->getName() != "complex")
182 return true;
183
184 const auto *Builtin = Template->template_arguments()[0]
185 .getAsType()
186 .getTypePtr()
187 ->getAs<BuiltinType>();
188
189 if (Builtin &&
190 !isMPITypeMatching(MultiMap: ComplexCXXMatches, Kind: Builtin->getKind(), MPIDatatype)) {
191 BufferTypeName =
192 (llvm::Twine("complex<") + Builtin->getName(Policy: LO) + ">").str();
193 return false;
194 }
195
196 return true;
197}
198
199/// Check if a fixed size width buffer type matches the MPI datatype.
200///
201/// \param Typedef buffer type
202/// \param BufferTypeName buffer type name, gets assigned
203/// \param MPIDatatype name of the MPI datatype
204///
205/// \returns true if the type matches or the buffer type is unknown
206static bool isTypedefTypeMatching(const TypedefType *const Typedef,
207 std::string &BufferTypeName,
208 StringRef MPIDatatype) {
209 static llvm::StringMap<StringRef> FixedWidthMatches = {
210 {"int8_t", "MPI_INT8_T"}, {"int16_t", "MPI_INT16_T"},
211 {"int32_t", "MPI_INT32_T"}, {"int64_t", "MPI_INT64_T"},
212 {"uint8_t", "MPI_UINT8_T"}, {"uint16_t", "MPI_UINT16_T"},
213 {"uint32_t", "MPI_UINT32_T"}, {"uint64_t", "MPI_UINT64_T"}};
214
215 const auto It = FixedWidthMatches.find(Typedef->getDecl()->getName());
216 // Check if the typedef is known and not matching the MPI datatype.
217 if (It != FixedWidthMatches.end() && It->getValue() != MPIDatatype) {
218 BufferTypeName = std::string(Typedef->getDecl()->getName());
219 return false;
220 }
221 return true;
222}
223
224/// Get the unqualified, dereferenced type of an argument.
225///
226/// \param CE call expression
227/// \param Idx argument index
228///
229/// \returns type of the argument
230static const Type *argumentType(const CallExpr *const CE, const size_t Idx) {
231 const QualType QT = CE->getArg(Arg: Idx)->IgnoreImpCasts()->getType();
232 return QT.getTypePtr()->getPointeeOrArrayElementType();
233}
234
235void TypeMismatchCheck::registerMatchers(MatchFinder *Finder) {
236 Finder->addMatcher(NodeMatch: callExpr().bind(ID: "CE"), Action: this);
237}
238
239void TypeMismatchCheck::check(const MatchFinder::MatchResult &Result) {
240 const auto *const CE = Result.Nodes.getNodeAs<CallExpr>(ID: "CE");
241 if (!CE->getDirectCallee())
242 return;
243
244 if (!FuncClassifier)
245 FuncClassifier.emplace(args&: *Result.Context);
246
247 const IdentifierInfo *Identifier = CE->getDirectCallee()->getIdentifier();
248 if (!Identifier || !FuncClassifier->isMPIType(IdentInfo: Identifier))
249 return;
250
251 // These containers are used, to capture buffer, MPI datatype pairs.
252 SmallVector<const Type *, 1> BufferTypes;
253 SmallVector<const Expr *, 1> BufferExprs;
254 SmallVector<StringRef, 1> MPIDatatypes;
255
256 // Adds a buffer, MPI datatype pair of an MPI call expression to the
257 // containers. For buffers, the type and expression is captured.
258 auto AddPair = [&CE, &Result, &BufferTypes, &BufferExprs, &MPIDatatypes](
259 const size_t BufferIdx, const size_t DatatypeIdx) {
260 // Skip null pointer constants and in place 'operators'.
261 if (CE->getArg(Arg: BufferIdx)->isNullPointerConstant(
262 Ctx&: *Result.Context, NPC: Expr::NPC_ValueDependentIsNull) ||
263 tooling::fixit::getText(Node: *CE->getArg(Arg: BufferIdx), Context: *Result.Context) ==
264 "MPI_IN_PLACE")
265 return;
266
267 StringRef MPIDatatype =
268 tooling::fixit::getText(Node: *CE->getArg(Arg: DatatypeIdx), Context: *Result.Context);
269
270 const Type *ArgType = argumentType(CE, Idx: BufferIdx);
271 // Skip unknown MPI datatypes and void pointers.
272 if (!isStandardMPIDatatype(MPIDatatype) || ArgType->isVoidType())
273 return;
274
275 BufferTypes.push_back(Elt: ArgType);
276 BufferExprs.push_back(Elt: CE->getArg(Arg: BufferIdx));
277 MPIDatatypes.push_back(Elt: MPIDatatype);
278 };
279
280 // Collect all buffer, MPI datatype pairs for the inspected call expression.
281 if (FuncClassifier->isPointToPointType(IdentInfo: Identifier)) {
282 AddPair(0, 2);
283 } else if (FuncClassifier->isCollectiveType(IdentInfo: Identifier)) {
284 if (FuncClassifier->isReduceType(IdentInfo: Identifier)) {
285 AddPair(0, 3);
286 AddPair(1, 3);
287 } else if (FuncClassifier->isScatterType(IdentInfo: Identifier) ||
288 FuncClassifier->isGatherType(IdentInfo: Identifier) ||
289 FuncClassifier->isAlltoallType(IdentInfo: Identifier)) {
290 AddPair(0, 2);
291 AddPair(3, 5);
292 } else if (FuncClassifier->isBcastType(IdentInfo: Identifier)) {
293 AddPair(0, 2);
294 }
295 }
296 checkArguments(BufferTypes, BufferExprs, MPIDatatypes, LO: getLangOpts());
297}
298
299void TypeMismatchCheck::checkArguments(ArrayRef<const Type *> BufferTypes,
300 ArrayRef<const Expr *> BufferExprs,
301 ArrayRef<StringRef> MPIDatatypes,
302 const LangOptions &LO) {
303 std::string BufferTypeName;
304
305 for (size_t I = 0; I < MPIDatatypes.size(); ++I) {
306 const Type *const BT = BufferTypes[I];
307 bool Error = false;
308
309 if (const auto *Typedef = BT->getAs<TypedefType>()) {
310 Error = !isTypedefTypeMatching(Typedef, BufferTypeName, MPIDatatype: MPIDatatypes[I]);
311 } else if (const auto *Complex = BT->getAs<ComplexType>()) {
312 Error =
313 !isCComplexTypeMatching(Complex, BufferTypeName, MPIDatatype: MPIDatatypes[I], LO);
314 } else if (const auto *Template = BT->getAs<TemplateSpecializationType>()) {
315 Error = !isCXXComplexTypeMatching(Template, BufferTypeName,
316 MPIDatatype: MPIDatatypes[I], LO);
317 } else if (const auto *Builtin = BT->getAs<BuiltinType>()) {
318 Error =
319 !isBuiltinTypeMatching(Builtin, BufferTypeName, MPIDatatype: MPIDatatypes[I], LO);
320 }
321
322 if (Error) {
323 const auto Loc = BufferExprs[I]->getSourceRange().getBegin();
324 diag(Loc, "buffer type '%0' does not match the MPI datatype '%1'")
325 << BufferTypeName << MPIDatatypes[I];
326 }
327 }
328}
329
330void TypeMismatchCheck::onEndOfTranslationUnit() { FuncClassifier.reset(); }
331} // namespace clang::tidy::mpi
332

source code of clang-tools-extra/clang-tidy/mpi/TypeMismatchCheck.cpp