//===- AMDGPUMetadataVerifier.cpp - HSA Metadata Verifier -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file
10/// Implements a verifier for AMDGPU HSA metadata.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
15#include "llvm/ADT/StringSwitch.h"
16#include "llvm/Support/AMDGPUMetadata.h"
17
18namespace llvm {
19namespace AMDGPU {
20namespace HSAMD {
21namespace V3 {
22
23bool MetadataVerifier::verifyScalar(
24 msgpack::DocNode &Node, msgpack::Type SKind,
25 function_ref<bool(msgpack::DocNode &)> verifyValue) {
26 if (!Node.isScalar())
27 return false;
28 if (Node.getKind() != SKind) {
29 if (Strict)
30 return false;
31 // If we are not strict, we interpret string values as "implicitly typed"
32 // and attempt to coerce them to the expected type here.
33 if (Node.getKind() != msgpack::Type::String)
34 return false;
35 StringRef StringValue = Node.getString();
36 Node.fromString(StringValue);
37 if (Node.getKind() != SKind)
38 return false;
39 }
40 if (verifyValue)
41 return verifyValue(Node);
42 return true;
43}
44
45bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
46 if (!verifyScalar(Node, msgpack::Type::UInt))
47 if (!verifyScalar(Node, msgpack::Type::Int))
48 return false;
49 return true;
50}
51
52bool MetadataVerifier::verifyArray(
53 msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
54 Optional<size_t> Size) {
55 if (!Node.isArray())
56 return false;
57 auto &Array = Node.getArray();
58 if (Size && Array.size() != *Size)
59 return false;
60 for (auto &Item : Array)
61 if (!verifyNode(Item))
62 return false;
63
64 return true;
65}
66
67bool MetadataVerifier::verifyEntry(
68 msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
69 function_ref<bool(msgpack::DocNode &)> verifyNode) {
70 auto Entry = MapNode.find(Key);
71 if (Entry == MapNode.end())
72 return !Required;
73 return verifyNode(Entry->second);
74}
75
76bool MetadataVerifier::verifyScalarEntry(
77 msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
78 msgpack::Type SKind,
79 function_ref<bool(msgpack::DocNode &)> verifyValue) {
80 return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) {
81 return verifyScalar(Node, SKind, verifyValue);
82 });
83}
84
85bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
86 StringRef Key, bool Required) {
87 return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) {
88 return verifyInteger(Node);
89 });
90}
91
92bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
93 if (!Node.isMap())
94 return false;
95 auto &ArgsMap = Node.getMap();
96
97 if (!verifyScalarEntry(ArgsMap, ".name", false,
98 msgpack::Type::String))
99 return false;
100 if (!verifyScalarEntry(ArgsMap, ".type_name", false,
101 msgpack::Type::String))
102 return false;
103 if (!verifyIntegerEntry(ArgsMap, ".size", true))
104 return false;
105 if (!verifyIntegerEntry(ArgsMap, ".offset", true))
106 return false;
107 if (!verifyScalarEntry(ArgsMap, ".value_kind", true,
108 msgpack::Type::String,
109 [](msgpack::DocNode &SNode) {
110 return StringSwitch<bool>(SNode.getString())
111 .Case("by_value", true)
112 .Case("global_buffer", true)
113 .Case("dynamic_shared_pointer", true)
114 .Case("sampler", true)
115 .Case("image", true)
116 .Case("pipe", true)
117 .Case("queue", true)
118 .Case("hidden_global_offset_x", true)
119 .Case("hidden_global_offset_y", true)
120 .Case("hidden_global_offset_z", true)
121 .Case("hidden_none", true)
122 .Case("hidden_printf_buffer", true)
123 .Case("hidden_hostcall_buffer", true)
124 .Case("hidden_default_queue", true)
125 .Case("hidden_completion_action", true)
126 .Case("hidden_multigrid_sync_arg", true)
127 .Default(false);
128 }))
129 return false;
130 if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
131 return false;
132 if (!verifyScalarEntry(ArgsMap, ".address_space", false,
133 msgpack::Type::String,
134 [](msgpack::DocNode &SNode) {
135 return StringSwitch<bool>(SNode.getString())
136 .Case("private", true)
137 .Case("global", true)
138 .Case("constant", true)
139 .Case("local", true)
140 .Case("generic", true)
141 .Case("region", true)
142 .Default(false);
143 }))
144 return false;
145 if (!verifyScalarEntry(ArgsMap, ".access", false,
146 msgpack::Type::String,
147 [](msgpack::DocNode &SNode) {
148 return StringSwitch<bool>(SNode.getString())
149 .Case("read_only", true)
150 .Case("write_only", true)
151 .Case("read_write", true)
152 .Default(false);
153 }))
154 return false;
155 if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
156 msgpack::Type::String,
157 [](msgpack::DocNode &SNode) {
158 return StringSwitch<bool>(SNode.getString())
159 .Case("read_only", true)
160 .Case("write_only", true)
161 .Case("read_write", true)
162 .Default(false);
163 }))
164 return false;
165 if (!verifyScalarEntry(ArgsMap, ".is_const", false,
166 msgpack::Type::Boolean))
167 return false;
168 if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
169 msgpack::Type::Boolean))
170 return false;
171 if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
172 msgpack::Type::Boolean))
173 return false;
174 if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
175 msgpack::Type::Boolean))
176 return false;
177
178 return true;
179}
180
181bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
182 if (!Node.isMap())
183 return false;
184 auto &KernelMap = Node.getMap();
185
186 if (!verifyScalarEntry(KernelMap, ".name", true,
187 msgpack::Type::String))
188 return false;
189 if (!verifyScalarEntry(KernelMap, ".symbol", true,
190 msgpack::Type::String))
191 return false;
192 if (!verifyScalarEntry(KernelMap, ".language", false,
193 msgpack::Type::String,
194 [](msgpack::DocNode &SNode) {
195 return StringSwitch<bool>(SNode.getString())
196 .Case("OpenCL C", true)
197 .Case("OpenCL C++", true)
198 .Case("HCC", true)
199 .Case("HIP", true)
200 .Case("OpenMP", true)
201 .Case("Assembler", true)
202 .Default(false);
203 }))
204 return false;
205 if (!verifyEntry(
206 KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) {
207 return verifyArray(
208 Node,
209 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
210 }))
211 return false;
212 if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) {
213 return verifyArray(Node, [this](msgpack::DocNode &Node) {
214 return verifyKernelArgs(Node);
215 });
216 }))
217 return false;
218 if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
219 [this](msgpack::DocNode &Node) {
220 return verifyArray(Node,
221 [this](msgpack::DocNode &Node) {
222 return verifyInteger(Node);
223 },
224 3);
225 }))
226 return false;
227 if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
228 [this](msgpack::DocNode &Node) {
229 return verifyArray(Node,
230 [this](msgpack::DocNode &Node) {
231 return verifyInteger(Node);
232 },
233 3);
234 }))
235 return false;
236 if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
237 msgpack::Type::String))
238 return false;
239 if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
240 msgpack::Type::String))
241 return false;
242 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
243 return false;
244 if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
245 return false;
246 if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
247 return false;
248 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
249 return false;
250 if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
251 return false;
252 if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
253 return false;
254 if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
255 return false;
256 if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
257 return false;
258 if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
259 return false;
260 if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
261 return false;
262
263 return true;
264}
265
266bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
267 if (!HSAMetadataRoot.isMap())
268 return false;
269 auto &RootMap = HSAMetadataRoot.getMap();
270
271 if (!verifyEntry(
272 RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) {
273 return verifyArray(
274 Node,
275 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
276 }))
277 return false;
278 if (!verifyEntry(
279 RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) {
280 return verifyArray(Node, [this](msgpack::DocNode &Node) {
281 return verifyScalar(Node, msgpack::Type::String);
282 });
283 }))
284 return false;
285 if (!verifyEntry(RootMap, "amdhsa.kernels", true,
286 [this](msgpack::DocNode &Node) {
287 return verifyArray(Node, [this](msgpack::DocNode &Node) {
288 return verifyKernel(Node);
289 });
290 }))
291 return false;
292
293 return true;
294}
295
296} // end namespace V3
297} // end namespace HSAMD
298} // end namespace AMDGPU
299} // end namespace llvm
300