1#include "llvm/ADT/APFloat.h"
2#include "llvm/ADT/STLExtras.h"
3#include "llvm/IR/BasicBlock.h"
4#include "llvm/IR/Constants.h"
5#include "llvm/IR/DerivedTypes.h"
6#include "llvm/IR/Function.h"
7#include "llvm/IR/IRBuilder.h"
8#include "llvm/IR/LLVMContext.h"
9#include "llvm/IR/Module.h"
10#include "llvm/IR/Type.h"
11#include "llvm/IR/Verifier.h"
12#include <algorithm>
13#include <cctype>
14#include <cstdio>
15#include <cstdlib>
16#include <map>
17#include <memory>
18#include <string>
19#include <vector>
20
21using namespace llvm;
22
23//===----------------------------------------------------------------------===//
24// Lexer
25//===----------------------------------------------------------------------===//
26
// The lexer returns tokens [0-255] if it is an unknown character, otherwise one
// of these for known things.
enum Token {
  tok_eof = -1,

  // commands
  tok_def = -2,
  tok_extern = -3,

  // primary
  tok_identifier = -4,
  tok_number = -5
};

static std::string IdentifierStr; // Filled in if tok_identifier
static double NumVal;             // Filled in if tok_number

/// gettok - Return the next token from standard input.
///
/// One character of lookahead is kept in a function-local static, so the
/// lexer always reads from stdin and is not reentrant. On tok_identifier the
/// spelling is left in IdentifierStr; on tok_number the value is left in
/// NumVal. Any other character is returned as its own (non-negative) value.
static int gettok() {
  static int LastChar = ' ';

  // Loop (instead of recursing) after a comment so that a long run of comment
  // lines cannot grow the call stack.
  while (true) {
    // Skip any whitespace.
    while (isspace(LastChar))
      LastChar = getchar();

    if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
      IdentifierStr = LastChar;
      while (isalnum((LastChar = getchar())))
        IdentifierStr += LastChar;

      if (IdentifierStr == "def")
        return tok_def;
      if (IdentifierStr == "extern")
        return tok_extern;
      return tok_identifier;
    }

    if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
      std::string NumStr;
      do {
        NumStr += LastChar;
        LastChar = getchar();
      } while (isdigit(LastChar) || LastChar == '.');

      // NOTE: spellings such as "1.2.3" are accepted here; strtod converts
      // the longest valid prefix and silently ignores the rest.
      NumVal = strtod(NumStr.c_str(), nullptr);
      return tok_number;
    }

    if (LastChar == '#') {
      // Comment until end of line.
      do
        LastChar = getchar();
      while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');

      if (LastChar != EOF)
        continue; // Start over on the following line.
    }

    // Check for end of file. Don't eat the EOF.
    if (LastChar == EOF)
      return tok_eof;

    // Otherwise, just return the character as its ascii value.
    int ThisChar = LastChar;
    LastChar = getchar();
    return ThisChar;
  }
}
94
95//===----------------------------------------------------------------------===//
96// Abstract Syntax Tree (aka Parse Tree)
97//===----------------------------------------------------------------------===//
98
99namespace {
100
101/// ExprAST - Base class for all expression nodes.
102class ExprAST {
103public:
104 virtual ~ExprAST() = default;
105
106 virtual Value *codegen() = 0;
107};
108
109/// NumberExprAST - Expression class for numeric literals like "1.0".
110class NumberExprAST : public ExprAST {
111 double Val;
112
113public:
114 NumberExprAST(double Val) : Val(Val) {}
115
116 Value *codegen() override;
117};
118
119/// VariableExprAST - Expression class for referencing a variable, like "a".
120class VariableExprAST : public ExprAST {
121 std::string Name;
122
123public:
124 VariableExprAST(const std::string &Name) : Name(Name) {}
125
126 Value *codegen() override;
127};
128
129/// BinaryExprAST - Expression class for a binary operator.
130class BinaryExprAST : public ExprAST {
131 char Op;
132 std::unique_ptr<ExprAST> LHS, RHS;
133
134public:
135 BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
136 std::unique_ptr<ExprAST> RHS)
137 : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
138
139 Value *codegen() override;
140};
141
142/// CallExprAST - Expression class for function calls.
143class CallExprAST : public ExprAST {
144 std::string Callee;
145 std::vector<std::unique_ptr<ExprAST>> Args;
146
147public:
148 CallExprAST(const std::string &Callee,
149 std::vector<std::unique_ptr<ExprAST>> Args)
150 : Callee(Callee), Args(std::move(Args)) {}
151
152 Value *codegen() override;
153};
154
155/// PrototypeAST - This class represents the "prototype" for a function,
156/// which captures its name, and its argument names (thus implicitly the number
157/// of arguments the function takes).
158class PrototypeAST {
159 std::string Name;
160 std::vector<std::string> Args;
161
162public:
163 PrototypeAST(const std::string &Name, std::vector<std::string> Args)
164 : Name(Name), Args(std::move(Args)) {}
165
166 Function *codegen();
167 const std::string &getName() const { return Name; }
168};
169
170/// FunctionAST - This class represents a function definition itself.
171class FunctionAST {
172 std::unique_ptr<PrototypeAST> Proto;
173 std::unique_ptr<ExprAST> Body;
174
175public:
176 FunctionAST(std::unique_ptr<PrototypeAST> Proto,
177 std::unique_ptr<ExprAST> Body)
178 : Proto(std::move(Proto)), Body(std::move(Body)) {}
179
180 Function *codegen();
181};
182
183} // end anonymous namespace
184
185//===----------------------------------------------------------------------===//
186// Parser
187//===----------------------------------------------------------------------===//
188
189/// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current
190/// token the parser is looking at. getNextToken reads another token from the
191/// lexer and updates CurTok with its results.
192static int CurTok;
193static int getNextToken() { return CurTok = gettok(); }
194
195/// BinopPrecedence - This holds the precedence for each binary operator that is
196/// defined.
197static std::map<char, int> BinopPrecedence;
198
199/// GetTokPrecedence - Get the precedence of the pending binary operator token.
200static int GetTokPrecedence() {
201 if (!isascii(c: CurTok))
202 return -1;
203
204 // Make sure it's a declared binop.
205 int TokPrec = BinopPrecedence[CurTok];
206 if (TokPrec <= 0)
207 return -1;
208 return TokPrec;
209}
210
211/// LogError* - These are little helper functions for error handling.
212std::unique_ptr<ExprAST> LogError(const char *Str) {
213 fprintf(stderr, format: "Error: %s\n", Str);
214 return nullptr;
215}
216
217std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
218 LogError(Str);
219 return nullptr;
220}
221
222static std::unique_ptr<ExprAST> ParseExpression();
223
224/// numberexpr ::= number
225static std::unique_ptr<ExprAST> ParseNumberExpr() {
226 auto Result = std::make_unique<NumberExprAST>(args&: NumVal);
227 getNextToken(); // consume the number
228 return std::move(Result);
229}
230
231/// parenexpr ::= '(' expression ')'
232static std::unique_ptr<ExprAST> ParseParenExpr() {
233 getNextToken(); // eat (.
234 auto V = ParseExpression();
235 if (!V)
236 return nullptr;
237
238 if (CurTok != ')')
239 return LogError(Str: "expected ')'");
240 getNextToken(); // eat ).
241 return V;
242}
243
244/// identifierexpr
245/// ::= identifier
246/// ::= identifier '(' expression* ')'
247static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
248 std::string IdName = IdentifierStr;
249
250 getNextToken(); // eat identifier.
251
252 if (CurTok != '(') // Simple variable ref.
253 return std::make_unique<VariableExprAST>(args&: IdName);
254
255 // Call.
256 getNextToken(); // eat (
257 std::vector<std::unique_ptr<ExprAST>> Args;
258 if (CurTok != ')') {
259 while (true) {
260 if (auto Arg = ParseExpression())
261 Args.push_back(x: std::move(Arg));
262 else
263 return nullptr;
264
265 if (CurTok == ')')
266 break;
267
268 if (CurTok != ',')
269 return LogError(Str: "Expected ')' or ',' in argument list");
270 getNextToken();
271 }
272 }
273
274 // Eat the ')'.
275 getNextToken();
276
277 return std::make_unique<CallExprAST>(args&: IdName, args: std::move(Args));
278}
279
280/// primary
281/// ::= identifierexpr
282/// ::= numberexpr
283/// ::= parenexpr
284static std::unique_ptr<ExprAST> ParsePrimary() {
285 switch (CurTok) {
286 default:
287 return LogError(Str: "unknown token when expecting an expression");
288 case tok_identifier:
289 return ParseIdentifierExpr();
290 case tok_number:
291 return ParseNumberExpr();
292 case '(':
293 return ParseParenExpr();
294 }
295}
296
297/// binoprhs
298/// ::= ('+' primary)*
299static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
300 std::unique_ptr<ExprAST> LHS) {
301 // If this is a binop, find its precedence.
302 while (true) {
303 int TokPrec = GetTokPrecedence();
304
305 // If this is a binop that binds at least as tightly as the current binop,
306 // consume it, otherwise we are done.
307 if (TokPrec < ExprPrec)
308 return LHS;
309
310 // Okay, we know this is a binop.
311 int BinOp = CurTok;
312 getNextToken(); // eat binop
313
314 // Parse the primary expression after the binary operator.
315 auto RHS = ParsePrimary();
316 if (!RHS)
317 return nullptr;
318
319 // If BinOp binds less tightly with RHS than the operator after RHS, let
320 // the pending operator take RHS as its LHS.
321 int NextPrec = GetTokPrecedence();
322 if (TokPrec < NextPrec) {
323 RHS = ParseBinOpRHS(ExprPrec: TokPrec + 1, LHS: std::move(RHS));
324 if (!RHS)
325 return nullptr;
326 }
327
328 // Merge LHS/RHS.
329 LHS =
330 std::make_unique<BinaryExprAST>(args&: BinOp, args: std::move(LHS), args: std::move(RHS));
331 }
332}
333
334/// expression
335/// ::= primary binoprhs
336///
337static std::unique_ptr<ExprAST> ParseExpression() {
338 auto LHS = ParsePrimary();
339 if (!LHS)
340 return nullptr;
341
342 return ParseBinOpRHS(ExprPrec: 0, LHS: std::move(LHS));
343}
344
345/// prototype
346/// ::= id '(' id* ')'
347static std::unique_ptr<PrototypeAST> ParsePrototype() {
348 if (CurTok != tok_identifier)
349 return LogErrorP(Str: "Expected function name in prototype");
350
351 std::string FnName = IdentifierStr;
352 getNextToken();
353
354 if (CurTok != '(')
355 return LogErrorP(Str: "Expected '(' in prototype");
356
357 std::vector<std::string> ArgNames;
358 while (getNextToken() == tok_identifier)
359 ArgNames.push_back(x: IdentifierStr);
360 if (CurTok != ')')
361 return LogErrorP(Str: "Expected ')' in prototype");
362
363 // success.
364 getNextToken(); // eat ')'.
365
366 return std::make_unique<PrototypeAST>(args&: FnName, args: std::move(ArgNames));
367}
368
369/// definition ::= 'def' prototype expression
370static std::unique_ptr<FunctionAST> ParseDefinition() {
371 getNextToken(); // eat def.
372 auto Proto = ParsePrototype();
373 if (!Proto)
374 return nullptr;
375
376 if (auto E = ParseExpression())
377 return std::make_unique<FunctionAST>(args: std::move(Proto), args: std::move(E));
378 return nullptr;
379}
380
381/// toplevelexpr ::= expression
382static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
383 if (auto E = ParseExpression()) {
384 // Make an anonymous proto.
385 auto Proto = std::make_unique<PrototypeAST>(args: "__anon_expr",
386 args: std::vector<std::string>());
387 return std::make_unique<FunctionAST>(args: std::move(Proto), args: std::move(E));
388 }
389 return nullptr;
390}
391
392/// external ::= 'extern' prototype
393static std::unique_ptr<PrototypeAST> ParseExtern() {
394 getNextToken(); // eat extern.
395 return ParsePrototype();
396}
397
398//===----------------------------------------------------------------------===//
399// Code Generation
400//===----------------------------------------------------------------------===//
401
402static std::unique_ptr<LLVMContext> TheContext;
403static std::unique_ptr<Module> TheModule;
404static std::unique_ptr<IRBuilder<>> Builder;
405static std::map<std::string, Value *> NamedValues;
406
407Value *LogErrorV(const char *Str) {
408 LogError(Str);
409 return nullptr;
410}
411
412Value *NumberExprAST::codegen() {
413 return ConstantFP::get(Context&: *TheContext, V: APFloat(Val));
414}
415
416Value *VariableExprAST::codegen() {
417 // Look this variable up in the function.
418 Value *V = NamedValues[Name];
419 if (!V)
420 return LogErrorV(Str: "Unknown variable name");
421 return V;
422}
423
424Value *BinaryExprAST::codegen() {
425 Value *L = LHS->codegen();
426 Value *R = RHS->codegen();
427 if (!L || !R)
428 return nullptr;
429
430 switch (Op) {
431 case '+':
432 return Builder->CreateFAdd(L, R, Name: "addtmp");
433 case '-':
434 return Builder->CreateFSub(L, R, Name: "subtmp");
435 case '*':
436 return Builder->CreateFMul(L, R, Name: "multmp");
437 case '<':
438 L = Builder->CreateFCmpULT(LHS: L, RHS: R, Name: "cmptmp");
439 // Convert bool 0/1 to double 0.0 or 1.0
440 return Builder->CreateUIToFP(V: L, DestTy: Type::getDoubleTy(C&: *TheContext), Name: "booltmp");
441 default:
442 return LogErrorV(Str: "invalid binary operator");
443 }
444}
445
446Value *CallExprAST::codegen() {
447 // Look up the name in the global module table.
448 Function *CalleeF = TheModule->getFunction(Name: Callee);
449 if (!CalleeF)
450 return LogErrorV(Str: "Unknown function referenced");
451
452 // If argument mismatch error.
453 if (CalleeF->arg_size() != Args.size())
454 return LogErrorV(Str: "Incorrect # arguments passed");
455
456 std::vector<Value *> ArgsV;
457 for (unsigned i = 0, e = Args.size(); i != e; ++i) {
458 ArgsV.push_back(x: Args[i]->codegen());
459 if (!ArgsV.back())
460 return nullptr;
461 }
462
463 return Builder->CreateCall(Callee: CalleeF, Args: ArgsV, Name: "calltmp");
464}
465
466Function *PrototypeAST::codegen() {
467 // Make the function type: double(double,double) etc.
468 std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(C&: *TheContext));
469 FunctionType *FT =
470 FunctionType::get(Result: Type::getDoubleTy(C&: *TheContext), Params: Doubles, isVarArg: false);
471
472 Function *F =
473 Function::Create(Ty: FT, Linkage: Function::ExternalLinkage, N: Name, M: TheModule.get());
474
475 // Set names for all arguments.
476 unsigned Idx = 0;
477 for (auto &Arg : F->args())
478 Arg.setName(Args[Idx++]);
479
480 return F;
481}
482
483Function *FunctionAST::codegen() {
484 // First, check for an existing function from a previous 'extern' declaration.
485 Function *TheFunction = TheModule->getFunction(Name: Proto->getName());
486
487 if (!TheFunction)
488 TheFunction = Proto->codegen();
489
490 if (!TheFunction)
491 return nullptr;
492
493 // Create a new basic block to start insertion into.
494 BasicBlock *BB = BasicBlock::Create(Context&: *TheContext, Name: "entry", Parent: TheFunction);
495 Builder->SetInsertPoint(BB);
496
497 // Record the function arguments in the NamedValues map.
498 NamedValues.clear();
499 for (auto &Arg : TheFunction->args())
500 NamedValues[std::string(Arg.getName())] = &Arg;
501
502 if (Value *RetVal = Body->codegen()) {
503 // Finish off the function.
504 Builder->CreateRet(V: RetVal);
505
506 // Validate the generated code, checking for consistency.
507 verifyFunction(F: *TheFunction);
508
509 return TheFunction;
510 }
511
512 // Error reading body, remove function.
513 TheFunction->eraseFromParent();
514 return nullptr;
515}
516
517//===----------------------------------------------------------------------===//
518// Top-Level parsing and JIT Driver
519//===----------------------------------------------------------------------===//
520
521static void InitializeModule() {
522 // Open a new context and module.
523 TheContext = std::make_unique<LLVMContext>();
524 TheModule = std::make_unique<Module>(args: "my cool jit", args&: *TheContext);
525
526 // Create a new builder for the module.
527 Builder = std::make_unique<IRBuilder<>>(args&: *TheContext);
528}
529
530static void HandleDefinition() {
531 if (auto FnAST = ParseDefinition()) {
532 if (auto *FnIR = FnAST->codegen()) {
533 fprintf(stderr, format: "Read function definition:");
534 FnIR->print(OS&: errs());
535 fprintf(stderr, format: "\n");
536 }
537 } else {
538 // Skip token for error recovery.
539 getNextToken();
540 }
541}
542
543static void HandleExtern() {
544 if (auto ProtoAST = ParseExtern()) {
545 if (auto *FnIR = ProtoAST->codegen()) {
546 fprintf(stderr, format: "Read extern: ");
547 FnIR->print(OS&: errs());
548 fprintf(stderr, format: "\n");
549 }
550 } else {
551 // Skip token for error recovery.
552 getNextToken();
553 }
554}
555
556static void HandleTopLevelExpression() {
557 // Evaluate a top-level expression into an anonymous function.
558 if (auto FnAST = ParseTopLevelExpr()) {
559 if (auto *FnIR = FnAST->codegen()) {
560 fprintf(stderr, format: "Read top-level expression:");
561 FnIR->print(OS&: errs());
562 fprintf(stderr, format: "\n");
563
564 // Remove the anonymous expression.
565 FnIR->eraseFromParent();
566 }
567 } else {
568 // Skip token for error recovery.
569 getNextToken();
570 }
571}
572
573/// top ::= definition | external | expression | ';'
574static void MainLoop() {
575 while (true) {
576 fprintf(stderr, format: "ready> ");
577 switch (CurTok) {
578 case tok_eof:
579 return;
580 case ';': // ignore top-level semicolons.
581 getNextToken();
582 break;
583 case tok_def:
584 HandleDefinition();
585 break;
586 case tok_extern:
587 HandleExtern();
588 break;
589 default:
590 HandleTopLevelExpression();
591 break;
592 }
593 }
594}
595
596//===----------------------------------------------------------------------===//
597// Main driver code.
598//===----------------------------------------------------------------------===//
599
600int main() {
601 // Install standard binary operators.
602 // 1 is lowest precedence.
603 BinopPrecedence['<'] = 10;
604 BinopPrecedence['+'] = 20;
605 BinopPrecedence['-'] = 20;
606 BinopPrecedence['*'] = 40; // highest.
607
608 // Prime the first token.
609 fprintf(stderr, format: "ready> ");
610 getNextToken();
611
612 // Make the module, which holds all the code.
613 InitializeModule();
614
615 // Run the main "interpreter loop" now.
616 MainLoop();
617
618 // Print out all of the generated code.
619 TheModule->print(OS&: errs(), AAW: nullptr);
620
621 return 0;
622}
623

// Source: llvm/examples/Kaleidoscope/Chapter3/toy.cpp