1//===- TestOpsSyntax.cpp - Operations for testing syntax ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestOpsSyntax.h"
10#include "TestDialect.h"
11#include "TestOps.h"
12#include "mlir/IR/OpImplementation.h"
13#include "llvm/Support/Base64.h"
14
15using namespace mlir;
16using namespace test;
17
18//===----------------------------------------------------------------------===//
19// Test Format* operations
20//===----------------------------------------------------------------------===//
21
22//===----------------------------------------------------------------------===//
23// Parsing
24
25static ParseResult parseCustomOptionalOperand(
26 OpAsmParser &parser,
27 std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
28 if (succeeded(result: parser.parseOptionalLParen())) {
29 optOperand.emplace();
30 if (parser.parseOperand(result&: *optOperand) || parser.parseRParen())
31 return failure();
32 }
33 return success();
34}
35
36static ParseResult parseCustomDirectiveOperands(
37 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
38 std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
39 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
40 if (parser.parseOperand(result&: operand))
41 return failure();
42 if (succeeded(result: parser.parseOptionalComma())) {
43 optOperand.emplace();
44 if (parser.parseOperand(result&: *optOperand))
45 return failure();
46 }
47 if (parser.parseArrow() || parser.parseLParen() ||
48 parser.parseOperandList(result&: varOperands) || parser.parseRParen())
49 return failure();
50 return success();
51}
52static ParseResult
53parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
54 Type &optOperandType,
55 SmallVectorImpl<Type> &varOperandTypes) {
56 if (parser.parseColon())
57 return failure();
58
59 if (parser.parseType(result&: operandType))
60 return failure();
61 if (succeeded(result: parser.parseOptionalComma())) {
62 if (parser.parseType(result&: optOperandType))
63 return failure();
64 }
65 if (parser.parseArrow() || parser.parseLParen() ||
66 parser.parseTypeList(result&: varOperandTypes) || parser.parseRParen())
67 return failure();
68 return success();
69}
70static ParseResult
71parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
72 Type optOperandType,
73 const SmallVectorImpl<Type> &varOperandTypes) {
74 if (parser.parseKeyword(keyword: "type_refs_capture"))
75 return failure();
76
77 Type operandType2, optOperandType2;
78 SmallVector<Type, 1> varOperandTypes2;
79 if (parseCustomDirectiveResults(parser, operandType&: operandType2, optOperandType&: optOperandType2,
80 varOperandTypes&: varOperandTypes2))
81 return failure();
82
83 if (operandType != operandType2 || optOperandType != optOperandType2 ||
84 varOperandTypes != varOperandTypes2)
85 return failure();
86
87 return success();
88}
89static ParseResult parseCustomDirectiveOperandsAndTypes(
90 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
91 std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
92 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
93 Type &operandType, Type &optOperandType,
94 SmallVectorImpl<Type> &varOperandTypes) {
95 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
96 parseCustomDirectiveResults(parser, operandType, optOperandType,
97 varOperandTypes))
98 return failure();
99 return success();
100}
101static ParseResult parseCustomDirectiveRegions(
102 OpAsmParser &parser, Region &region,
103 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
104 if (parser.parseRegion(region))
105 return failure();
106 if (failed(result: parser.parseOptionalComma()))
107 return success();
108 std::unique_ptr<Region> varRegion = std::make_unique<Region>();
109 if (parser.parseRegion(region&: *varRegion))
110 return failure();
111 varRegions.emplace_back(Args: std::move(varRegion));
112 return success();
113}
114static ParseResult
115parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
116 SmallVectorImpl<Block *> &varSuccessors) {
117 if (parser.parseSuccessor(dest&: successor))
118 return failure();
119 if (failed(result: parser.parseOptionalComma()))
120 return success();
121 Block *varSuccessor;
122 if (parser.parseSuccessor(dest&: varSuccessor))
123 return failure();
124 varSuccessors.append(NumInputs: 2, Elt: varSuccessor);
125 return success();
126}
127static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
128 IntegerAttr &attr,
129 IntegerAttr &optAttr) {
130 if (parser.parseAttribute(result&: attr))
131 return failure();
132 if (succeeded(result: parser.parseOptionalComma())) {
133 if (parser.parseAttribute(result&: optAttr))
134 return failure();
135 }
136 return success();
137}
138static ParseResult parseCustomDirectiveSpacing(OpAsmParser &parser,
139 mlir::StringAttr &attr) {
140 return parser.parseAttribute(result&: attr);
141}
142static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
143 NamedAttrList &attrs) {
144 return parser.parseOptionalAttrDict(result&: attrs);
145}
146static ParseResult parseCustomDirectiveOptionalOperandRef(
147 OpAsmParser &parser,
148 std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
149 int64_t operandCount = 0;
150 if (parser.parseInteger(result&: operandCount))
151 return failure();
152 bool expectedOptionalOperand = operandCount == 0;
153 return success(isSuccess: expectedOptionalOperand != optOperand.has_value());
154}
155
156//===----------------------------------------------------------------------===//
157// Printing
158
159static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
160 Value optOperand) {
161 if (optOperand)
162 printer << "(" << optOperand << ") ";
163}
164
165static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
166 Value operand, Value optOperand,
167 OperandRange varOperands) {
168 printer << operand;
169 if (optOperand)
170 printer << ", " << optOperand;
171 printer << " -> (" << varOperands << ")";
172}
173static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
174 Type operandType, Type optOperandType,
175 TypeRange varOperandTypes) {
176 printer << " : " << operandType;
177 if (optOperandType)
178 printer << ", " << optOperandType;
179 printer << " -> (" << varOperandTypes << ")";
180}
181static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
182 Operation *op, Type operandType,
183 Type optOperandType,
184 TypeRange varOperandTypes) {
185 printer << " type_refs_capture ";
186 printCustomDirectiveResults(printer, op, operandType, optOperandType,
187 varOperandTypes);
188}
189static void printCustomDirectiveOperandsAndTypes(
190 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
191 OperandRange varOperands, Type operandType, Type optOperandType,
192 TypeRange varOperandTypes) {
193 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
194 printCustomDirectiveResults(printer, op, operandType, optOperandType,
195 varOperandTypes);
196}
197static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
198 Region &region,
199 MutableArrayRef<Region> varRegions) {
200 printer.printRegion(blocks&: region);
201 if (!varRegions.empty()) {
202 printer << ", ";
203 for (Region &region : varRegions)
204 printer.printRegion(blocks&: region);
205 }
206}
207static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
208 Block *successor,
209 SuccessorRange varSuccessors) {
210 printer << successor;
211 if (!varSuccessors.empty())
212 printer << ", " << varSuccessors.front();
213}
214static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
215 Attribute attribute,
216 Attribute optAttribute) {
217 printer << attribute;
218 if (optAttribute)
219 printer << ", " << optAttribute;
220}
221static void printCustomDirectiveSpacing(OpAsmPrinter &printer, Operation *op,
222 Attribute attribute) {
223 printer << attribute;
224}
225static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
226 DictionaryAttr attrs) {
227 printer.printOptionalAttrDict(attrs: attrs.getValue());
228}
229
230static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
231 Operation *op,
232 Value optOperand) {
233 printer << (optOperand ? "1" : "0");
234}
235//===----------------------------------------------------------------------===//
236// Test parser.
237//===----------------------------------------------------------------------===//
238
239ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
240 OperationState &result) {
241 if (parser.parseOptionalColon())
242 return success();
243 uint64_t numResults;
244 if (parser.parseInteger(numResults))
245 return failure();
246
247 IndexType type = parser.getBuilder().getIndexType();
248 for (unsigned i = 0; i < numResults; ++i)
249 result.addTypes(type);
250 return success();
251}
252
253void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
254 if (unsigned numResults = getNumResults())
255 p << " : " << numResults;
256}
257
258ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
259 OperationState &result) {
260 StringRef keyword;
261 if (parser.parseKeyword(&keyword))
262 return failure();
263 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
264 return success();
265}
266
267void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); }
268
269ParseResult ParseB64BytesOp::parse(OpAsmParser &parser,
270 OperationState &result) {
271 std::vector<char> bytes;
272 if (parser.parseBase64Bytes(&bytes))
273 return failure();
274 result.addAttribute("b64", parser.getBuilder().getStringAttr(
275 StringRef(&bytes.front(), bytes.size())));
276 return success();
277}
278
279void ParseB64BytesOp::print(OpAsmPrinter &p) {
280 p << " \"" << llvm::encodeBase64(getB64()) << "\"";
281}
282
283::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
284 ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
285 ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
286 OpaqueProperties properties, ::mlir::RegionRange regions,
287 ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
288 inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
289 return ::mlir::success();
290}
291
292//===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
294
295ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
296 OperationState &result) {
297 if (parser.parseKeyword("wraps"))
298 return failure();
299
300 // Parse the wrapped op in a region
301 Region &body = *result.addRegion();
302 body.push_back(new Block);
303 Block &block = body.back();
304 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
305 if (!wrappedOp)
306 return failure();
307
308 // Create a return terminator in the inner region, pass as operand to the
309 // terminator the returned values from the wrapped operation.
310 SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
311 OpBuilder builder(parser.getContext());
312 builder.setInsertionPointToEnd(&block);
313 builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
314
315 // Get the results type for the wrapping op from the terminator operands.
316 Operation &returnOp = body.back().back();
317 result.types.append(returnOp.operand_type_begin(),
318 returnOp.operand_type_end());
319
320 // Use the location of the wrapped op for the "test.wrapping_region" op.
321 result.location = wrappedOp->getLoc();
322
323 return success();
324}
325
326void WrappingRegionOp::print(OpAsmPrinter &p) {
327 p << " wraps ";
328 p.printGenericOp(&getRegion().front().front());
329}
330
331//===----------------------------------------------------------------------===//
332// Test PrettyPrintedRegionOp - exercising the following parser APIs
333// parseGenericOperationAfterOpName
334// parseCustomOperationName
335//===----------------------------------------------------------------------===//
336
337ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
338 OperationState &result) {
339
340 SMLoc loc = parser.getCurrentLocation();
341 Location currLocation = parser.getEncodedSourceLoc(loc);
342
343 // Parse the operands.
344 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
345 if (parser.parseOperandList(operands))
346 return failure();
347
348 // Check if we are parsing the pretty-printed version
349 // test.pretty_printed_region start <inner-op> end : <functional-type>
350 // Else fallback to parsing the "non pretty-printed" version.
351 if (!succeeded(parser.parseOptionalKeyword("start")))
352 return parser.parseGenericOperationAfterOpName(result,
353 llvm::ArrayRef(operands));
354
355 FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
356 if (failed(parseOpNameInfo))
357 return failure();
358
359 StringAttr innerOpName = parseOpNameInfo->getIdentifier();
360
361 FunctionType opFntype;
362 std::optional<Location> explicitLoc;
363 if (parser.parseKeyword("end") || parser.parseColon() ||
364 parser.parseType(opFntype) ||
365 parser.parseOptionalLocationSpecifier(explicitLoc))
366 return failure();
367
368 // If location of the op is explicitly provided, then use it; Else use
369 // the parser's current location.
370 Location opLoc = explicitLoc.value_or(currLocation);
371
372 // Derive the SSA-values for op's operands.
373 if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
374 result.operands))
375 return failure();
376
377 // Add a region for op.
378 Region &region = *result.addRegion();
379
380 // Create a basic-block inside op's region.
381 Block &block = region.emplaceBlock();
382
383 // Create and insert an "inner-op" operation in the block.
384 // Just for testing purposes, we can assume that inner op is a binary op with
385 // result and operand types all same as the test-op's first operand.
386 Type innerOpType = opFntype.getInput(0);
387 Value lhs = block.addArgument(innerOpType, opLoc);
388 Value rhs = block.addArgument(innerOpType, opLoc);
389
390 OpBuilder builder(parser.getBuilder().getContext());
391 builder.setInsertionPointToStart(&block);
392
393 Operation *innerOp =
394 builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
395
396 // Insert a return statement in the block returning the inner-op's result.
397 builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
398
399 // Populate the op operation-state with result-type and location.
400 result.addTypes(opFntype.getResults());
401 result.location = innerOp->getLoc();
402
403 return success();
404}
405
406void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
407 p << ' ';
408 p.printOperands(getOperands());
409
410 Operation &innerOp = getRegion().front().front();
411 // Assuming that region has a single non-terminator inner-op, if the inner-op
412 // meets some criteria (which in this case is a simple one based on the name
413 // of inner-op), then we can print the entire region in a succinct way.
414 // Here we assume that the prototype of "test.special.op" can be trivially
415 // derived while parsing it back.
416 if (innerOp.getName().getStringRef().equals("test.special.op")) {
417 p << " start test.special.op end";
418 } else {
419 p << " (";
420 p.printRegion(getRegion());
421 p << ")";
422 }
423
424 p << " : ";
425 p.printFunctionalType(*this);
426}
427
428//===----------------------------------------------------------------------===//
429// Test PolyForOp - parse list of region arguments.
430//===----------------------------------------------------------------------===//
431
432ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
433 SmallVector<OpAsmParser::Argument, 4> ivsInfo;
434 // Parse list of region arguments without a delimiter.
435 if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None))
436 return failure();
437
438 // Parse the body region.
439 Region *body = result.addRegion();
440 for (auto &iv : ivsInfo)
441 iv.type = parser.getBuilder().getIndexType();
442 return parser.parseRegion(*body, ivsInfo);
443}
444
445void PolyForOp::print(OpAsmPrinter &p) {
446 p << " ";
447 llvm::interleaveComma(getRegion().getArguments(), p, [&](auto arg) {
448 p.printRegionArgument(arg, /*argAttrs =*/{}, /*omitType=*/true);
449 });
450 p << " ";
451 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
452}
453
454void PolyForOp::getAsmBlockArgumentNames(Region &region,
455 OpAsmSetValueNameFn setNameFn) {
456 auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
457 if (!arrayAttr)
458 return;
459 auto args = getRegion().front().getArguments();
460 auto e = std::min(arrayAttr.size(), args.size());
461 for (unsigned i = 0; i < e; ++i) {
462 if (auto strAttr = dyn_cast<StringAttr>(arrayAttr[i]))
463 setNameFn(args[i], strAttr.getValue());
464 }
465}
466
467//===----------------------------------------------------------------------===//
468// TestAttrWithLoc - parse/printOptionalLocationSpecifier
469//===----------------------------------------------------------------------===//
470
471static ParseResult parseOptionalLoc(OpAsmParser &p, Attribute &loc) {
472 std::optional<Location> result;
473 SMLoc sourceLoc = p.getCurrentLocation();
474 if (p.parseOptionalLocationSpecifier(result))
475 return failure();
476 if (result)
477 loc = *result;
478 else
479 loc = p.getEncodedSourceLoc(loc: sourceLoc);
480 return success();
481}
482
483static void printOptionalLoc(OpAsmPrinter &p, Operation *op, Attribute loc) {
484 p.printOptionalLocationSpecifier(loc: cast<LocationAttr>(Val&: loc));
485}
486
487#define GET_OP_CLASSES
488#include "TestOpsSyntax.cpp.inc"
489
/// Registers with the test dialect every op class listed by the generated
/// TestOpsSyntax.cpp.inc.
void TestDialect::registerOpsSyntax() {
  // With GET_OP_LIST defined, the generated .inc file expands to the list of
  // op classes used as the template arguments of addOperations.
  addOperations<
#define GET_OP_LIST
#include "TestOpsSyntax.cpp.inc"
      >();
}
496

// Source: mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp