1//===- SMTAPI.h -------------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file defines a generic SMT solver API, which serves as the base class
// for every solver-specific SMT class.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_SUPPORT_SMTAPI_H
15#define LLVM_SUPPORT_SMTAPI_H
16
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <memory>
#include <optional>
22
23namespace llvm {
24
25/// Generic base class for SMT sorts
26class SMTSort {
27public:
28 SMTSort() = default;
29 virtual ~SMTSort() = default;
30
31 /// Returns true if the sort is a bitvector, calls isBitvectorSortImpl().
32 virtual bool isBitvectorSort() const { return isBitvectorSortImpl(); }
33
34 /// Returns true if the sort is a floating-point, calls isFloatSortImpl().
35 virtual bool isFloatSort() const { return isFloatSortImpl(); }
36
37 /// Returns true if the sort is a boolean, calls isBooleanSortImpl().
38 virtual bool isBooleanSort() const { return isBooleanSortImpl(); }
39
40 /// Returns the bitvector size, fails if the sort is not a bitvector
41 /// Calls getBitvectorSortSizeImpl().
42 virtual unsigned getBitvectorSortSize() const {
43 assert(isBitvectorSort() && "Not a bitvector sort!");
44 unsigned Size = getBitvectorSortSizeImpl();
45 assert(Size && "Size is zero!");
46 return Size;
47 };
48
49 /// Returns the floating-point size, fails if the sort is not a floating-point
50 /// Calls getFloatSortSizeImpl().
51 virtual unsigned getFloatSortSize() const {
52 assert(isFloatSort() && "Not a floating-point sort!");
53 unsigned Size = getFloatSortSizeImpl();
54 assert(Size && "Size is zero!");
55 return Size;
56 };
57
58 virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0;
59
60 bool operator<(const SMTSort &Other) const {
61 llvm::FoldingSetNodeID ID1, ID2;
62 Profile(ID&: ID1);
63 Other.Profile(ID&: ID2);
64 return ID1 < ID2;
65 }
66
67 friend bool operator==(SMTSort const &LHS, SMTSort const &RHS) {
68 return LHS.equal_to(other: RHS);
69 }
70
71 virtual void print(raw_ostream &OS) const = 0;
72
73 LLVM_DUMP_METHOD void dump() const;
74
75protected:
76 /// Query the SMT solver and returns true if two sorts are equal (same kind
77 /// and bit width). This does not check if the two sorts are the same objects.
78 virtual bool equal_to(SMTSort const &other) const = 0;
79
80 /// Query the SMT solver and checks if a sort is bitvector.
81 virtual bool isBitvectorSortImpl() const = 0;
82
83 /// Query the SMT solver and checks if a sort is floating-point.
84 virtual bool isFloatSortImpl() const = 0;
85
86 /// Query the SMT solver and checks if a sort is boolean.
87 virtual bool isBooleanSortImpl() const = 0;
88
89 /// Query the SMT solver and returns the sort bit width.
90 virtual unsigned getBitvectorSortSizeImpl() const = 0;
91
92 /// Query the SMT solver and returns the sort bit width.
93 virtual unsigned getFloatSortSizeImpl() const = 0;
94};
95
96/// Shared pointer for SMTSorts, used by SMTSolver API.
97using SMTSortRef = const SMTSort *;
98
99/// Generic base class for SMT exprs
100class SMTExpr {
101public:
102 SMTExpr() = default;
103 virtual ~SMTExpr() = default;
104
105 bool operator<(const SMTExpr &Other) const {
106 llvm::FoldingSetNodeID ID1, ID2;
107 Profile(ID&: ID1);
108 Other.Profile(ID&: ID2);
109 return ID1 < ID2;
110 }
111
112 virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0;
113
114 friend bool operator==(SMTExpr const &LHS, SMTExpr const &RHS) {
115 return LHS.equal_to(other: RHS);
116 }
117
118 virtual void print(raw_ostream &OS) const = 0;
119
120 LLVM_DUMP_METHOD void dump() const;
121
122protected:
123 /// Query the SMT solver and returns true if two sorts are equal (same kind
124 /// and bit width). This does not check if the two sorts are the same objects.
125 virtual bool equal_to(SMTExpr const &other) const = 0;
126};
127
128/// Shared pointer for SMTExprs, used by SMTSolver API.
129using SMTExprRef = const SMTExpr *;
130
131/// Generic base class for SMT Solvers
132///
133/// This class is responsible for wrapping all sorts and expression generation,
134/// through the mk* methods. It also provides methods to create SMT expressions
135/// straight from clang's AST, through the from* methods.
136class SMTSolver {
137public:
138 SMTSolver() = default;
139 virtual ~SMTSolver() = default;
140
141 LLVM_DUMP_METHOD void dump() const;
142
143 // Returns an appropriate floating-point sort for the given bitwidth.
144 SMTSortRef getFloatSort(unsigned BitWidth) {
145 switch (BitWidth) {
146 case 16:
147 return getFloat16Sort();
148 case 32:
149 return getFloat32Sort();
150 case 64:
151 return getFloat64Sort();
152 case 128:
153 return getFloat128Sort();
154 default:;
155 }
156 llvm_unreachable("Unsupported floating-point bitwidth!");
157 }
158
159 // Returns a boolean sort.
160 virtual SMTSortRef getBoolSort() = 0;
161
162 // Returns an appropriate bitvector sort for the given bitwidth.
163 virtual SMTSortRef getBitvectorSort(const unsigned BitWidth) = 0;
164
165 // Returns a floating-point sort of width 16
166 virtual SMTSortRef getFloat16Sort() = 0;
167
168 // Returns a floating-point sort of width 32
169 virtual SMTSortRef getFloat32Sort() = 0;
170
171 // Returns a floating-point sort of width 64
172 virtual SMTSortRef getFloat64Sort() = 0;
173
174 // Returns a floating-point sort of width 128
175 virtual SMTSortRef getFloat128Sort() = 0;
176
177 // Returns an appropriate sort for the given AST.
178 virtual SMTSortRef getSort(const SMTExprRef &AST) = 0;
179
180 /// Given a constraint, adds it to the solver
181 virtual void addConstraint(const SMTExprRef &Exp) const = 0;
182
183 /// Creates a bitvector addition operation
184 virtual SMTExprRef mkBVAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
185
186 /// Creates a bitvector subtraction operation
187 virtual SMTExprRef mkBVSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
188
189 /// Creates a bitvector multiplication operation
190 virtual SMTExprRef mkBVMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
191
192 /// Creates a bitvector signed modulus operation
193 virtual SMTExprRef mkBVSRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
194
195 /// Creates a bitvector unsigned modulus operation
196 virtual SMTExprRef mkBVURem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
197
198 /// Creates a bitvector signed division operation
199 virtual SMTExprRef mkBVSDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
200
201 /// Creates a bitvector unsigned division operation
202 virtual SMTExprRef mkBVUDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
203
204 /// Creates a bitvector logical shift left operation
205 virtual SMTExprRef mkBVShl(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
206
207 /// Creates a bitvector arithmetic shift right operation
208 virtual SMTExprRef mkBVAshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
209
210 /// Creates a bitvector logical shift right operation
211 virtual SMTExprRef mkBVLshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
212
213 /// Creates a bitvector negation operation
214 virtual SMTExprRef mkBVNeg(const SMTExprRef &Exp) = 0;
215
216 /// Creates a bitvector not operation
217 virtual SMTExprRef mkBVNot(const SMTExprRef &Exp) = 0;
218
219 /// Creates a bitvector xor operation
220 virtual SMTExprRef mkBVXor(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
221
222 /// Creates a bitvector or operation
223 virtual SMTExprRef mkBVOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
224
225 /// Creates a bitvector and operation
226 virtual SMTExprRef mkBVAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
227
228 /// Creates a bitvector unsigned less-than operation
229 virtual SMTExprRef mkBVUlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
230
231 /// Creates a bitvector signed less-than operation
232 virtual SMTExprRef mkBVSlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
233
234 /// Creates a bitvector unsigned greater-than operation
235 virtual SMTExprRef mkBVUgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
236
237 /// Creates a bitvector signed greater-than operation
238 virtual SMTExprRef mkBVSgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
239
240 /// Creates a bitvector unsigned less-equal-than operation
241 virtual SMTExprRef mkBVUle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
242
243 /// Creates a bitvector signed less-equal-than operation
244 virtual SMTExprRef mkBVSle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
245
246 /// Creates a bitvector unsigned greater-equal-than operation
247 virtual SMTExprRef mkBVUge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
248
249 /// Creates a bitvector signed greater-equal-than operation
250 virtual SMTExprRef mkBVSge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
251
252 /// Creates a boolean not operation
253 virtual SMTExprRef mkNot(const SMTExprRef &Exp) = 0;
254
255 /// Creates a boolean equality operation
256 virtual SMTExprRef mkEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
257
258 /// Creates a boolean and operation
259 virtual SMTExprRef mkAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
260
261 /// Creates a boolean or operation
262 virtual SMTExprRef mkOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
263
264 /// Creates a boolean ite operation
265 virtual SMTExprRef mkIte(const SMTExprRef &Cond, const SMTExprRef &T,
266 const SMTExprRef &F) = 0;
267
268 /// Creates a bitvector sign extension operation
269 virtual SMTExprRef mkBVSignExt(unsigned i, const SMTExprRef &Exp) = 0;
270
271 /// Creates a bitvector zero extension operation
272 virtual SMTExprRef mkBVZeroExt(unsigned i, const SMTExprRef &Exp) = 0;
273
274 /// Creates a bitvector extract operation
275 virtual SMTExprRef mkBVExtract(unsigned High, unsigned Low,
276 const SMTExprRef &Exp) = 0;
277
278 /// Creates a bitvector concat operation
279 virtual SMTExprRef mkBVConcat(const SMTExprRef &LHS,
280 const SMTExprRef &RHS) = 0;
281
282 /// Creates a predicate that checks for overflow in a bitvector addition
283 /// operation
284 virtual SMTExprRef mkBVAddNoOverflow(const SMTExprRef &LHS,
285 const SMTExprRef &RHS,
286 bool isSigned) = 0;
287
288 /// Creates a predicate that checks for underflow in a signed bitvector
289 /// addition operation
290 virtual SMTExprRef mkBVAddNoUnderflow(const SMTExprRef &LHS,
291 const SMTExprRef &RHS) = 0;
292
293 /// Creates a predicate that checks for overflow in a signed bitvector
294 /// subtraction operation
295 virtual SMTExprRef mkBVSubNoOverflow(const SMTExprRef &LHS,
296 const SMTExprRef &RHS) = 0;
297
298 /// Creates a predicate that checks for underflow in a bitvector subtraction
299 /// operation
300 virtual SMTExprRef mkBVSubNoUnderflow(const SMTExprRef &LHS,
301 const SMTExprRef &RHS,
302 bool isSigned) = 0;
303
304 /// Creates a predicate that checks for overflow in a signed bitvector
305 /// division/modulus operation
306 virtual SMTExprRef mkBVSDivNoOverflow(const SMTExprRef &LHS,
307 const SMTExprRef &RHS) = 0;
308
309 /// Creates a predicate that checks for overflow in a bitvector negation
310 /// operation
311 virtual SMTExprRef mkBVNegNoOverflow(const SMTExprRef &Exp) = 0;
312
313 /// Creates a predicate that checks for overflow in a bitvector multiplication
314 /// operation
315 virtual SMTExprRef mkBVMulNoOverflow(const SMTExprRef &LHS,
316 const SMTExprRef &RHS,
317 bool isSigned) = 0;
318
319 /// Creates a predicate that checks for underflow in a signed bitvector
320 /// multiplication operation
321 virtual SMTExprRef mkBVMulNoUnderflow(const SMTExprRef &LHS,
322 const SMTExprRef &RHS) = 0;
323
324 /// Creates a floating-point negation operation
325 virtual SMTExprRef mkFPNeg(const SMTExprRef &Exp) = 0;
326
327 /// Creates a floating-point isInfinite operation
328 virtual SMTExprRef mkFPIsInfinite(const SMTExprRef &Exp) = 0;
329
330 /// Creates a floating-point isNaN operation
331 virtual SMTExprRef mkFPIsNaN(const SMTExprRef &Exp) = 0;
332
333 /// Creates a floating-point isNormal operation
334 virtual SMTExprRef mkFPIsNormal(const SMTExprRef &Exp) = 0;
335
336 /// Creates a floating-point isZero operation
337 virtual SMTExprRef mkFPIsZero(const SMTExprRef &Exp) = 0;
338
339 /// Creates a floating-point multiplication operation
340 virtual SMTExprRef mkFPMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
341
342 /// Creates a floating-point division operation
343 virtual SMTExprRef mkFPDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
344
345 /// Creates a floating-point remainder operation
346 virtual SMTExprRef mkFPRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
347
348 /// Creates a floating-point addition operation
349 virtual SMTExprRef mkFPAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
350
351 /// Creates a floating-point subtraction operation
352 virtual SMTExprRef mkFPSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
353
354 /// Creates a floating-point less-than operation
355 virtual SMTExprRef mkFPLt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
356
357 /// Creates a floating-point greater-than operation
358 virtual SMTExprRef mkFPGt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
359
360 /// Creates a floating-point less-than-or-equal operation
361 virtual SMTExprRef mkFPLe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
362
363 /// Creates a floating-point greater-than-or-equal operation
364 virtual SMTExprRef mkFPGe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
365
366 /// Creates a floating-point equality operation
367 virtual SMTExprRef mkFPEqual(const SMTExprRef &LHS,
368 const SMTExprRef &RHS) = 0;
369
370 /// Creates a floating-point conversion from floatint-point to floating-point
371 /// operation
372 virtual SMTExprRef mkFPtoFP(const SMTExprRef &From, const SMTSortRef &To) = 0;
373
374 /// Creates a floating-point conversion from signed bitvector to
375 /// floatint-point operation
376 virtual SMTExprRef mkSBVtoFP(const SMTExprRef &From,
377 const SMTSortRef &To) = 0;
378
379 /// Creates a floating-point conversion from unsigned bitvector to
380 /// floatint-point operation
381 virtual SMTExprRef mkUBVtoFP(const SMTExprRef &From,
382 const SMTSortRef &To) = 0;
383
384 /// Creates a floating-point conversion from floatint-point to signed
385 /// bitvector operation
386 virtual SMTExprRef mkFPtoSBV(const SMTExprRef &From, unsigned ToWidth) = 0;
387
388 /// Creates a floating-point conversion from floatint-point to unsigned
389 /// bitvector operation
390 virtual SMTExprRef mkFPtoUBV(const SMTExprRef &From, unsigned ToWidth) = 0;
391
392 /// Creates a new symbol, given a name and a sort
393 virtual SMTExprRef mkSymbol(const char *Name, SMTSortRef Sort) = 0;
394
395 // Returns an appropriate floating-point rounding mode.
396 virtual SMTExprRef getFloatRoundingMode() = 0;
397
398 // If the a model is available, returns the value of a given bitvector symbol
399 virtual llvm::APSInt getBitvector(const SMTExprRef &Exp, unsigned BitWidth,
400 bool isUnsigned) = 0;
401
402 // If the a model is available, returns the value of a given boolean symbol
403 virtual bool getBoolean(const SMTExprRef &Exp) = 0;
404
405 /// Constructs an SMTExprRef from a boolean.
406 virtual SMTExprRef mkBoolean(const bool b) = 0;
407
408 /// Constructs an SMTExprRef from a finite APFloat.
409 virtual SMTExprRef mkFloat(const llvm::APFloat Float) = 0;
410
411 /// Constructs an SMTExprRef from an APSInt and its bit width
412 virtual SMTExprRef mkBitvector(const llvm::APSInt Int, unsigned BitWidth) = 0;
413
414 /// Given an expression, extract the value of this operand in the model.
415 virtual bool getInterpretation(const SMTExprRef &Exp, llvm::APSInt &Int) = 0;
416
417 /// Given an expression extract the value of this operand in the model.
418 virtual bool getInterpretation(const SMTExprRef &Exp,
419 llvm::APFloat &Float) = 0;
420
421 /// Check if the constraints are satisfiable
422 virtual std::optional<bool> check() const = 0;
423
424 /// Push the current solver state
425 virtual void push() = 0;
426
427 /// Pop the previous solver state
428 virtual void pop(unsigned NumStates = 1) = 0;
429
430 /// Reset the solver and remove all constraints.
431 virtual void reset() = 0;
432
433 /// Checks if the solver supports floating-points.
434 virtual bool isFPSupported() = 0;
435
436 virtual void print(raw_ostream &OS) const = 0;
437};
438
439/// Shared pointer for SMTSolvers.
440using SMTSolverRef = std::shared_ptr<SMTSolver>;
441
442/// Convenience method to create and Z3Solver object
443SMTSolverRef CreateZ3Solver();
444
445} // namespace llvm
446
447#endif
448

source code of llvm/include/llvm/Support/SMTAPI.h