1 | //===- TFUtils.h - utilities for TFLite -------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | #ifndef LLVM_ANALYSIS_UTILS_TFUTILS_H |
10 | #define LLVM_ANALYSIS_UTILS_TFUTILS_H |
11 | |
12 | #include "llvm/Config/llvm-config.h" |
13 | |
14 | #ifdef LLVM_HAVE_TFLITE |
#include "llvm/ADT/StringMap.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/JSON.h"

#include <memory>
#include <optional>
#include <vector>
22 | |
23 | namespace llvm { |
24 | |
25 | /// Load a SavedModel, find the given inputs and outputs, and setup storage |
26 | /// for input tensors. The user is responsible for correctly dimensioning the |
27 | /// input tensors and setting their values before calling evaluate(). |
28 | /// To initialize: |
29 | /// - construct the object |
30 | /// - initialize the input tensors using initInput. Indices must correspond to |
31 | /// indices in the InputNames used at construction. |
32 | /// To use: |
33 | /// - set input values by using getInput to get each input tensor, and then |
34 | /// setting internal scalars, for all dimensions (tensors are row-major: |
35 | /// https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/c/c_api.h#L205) |
36 | /// - call evaluate. The input tensors' values are not consumed after this, and |
37 | /// may still be read. |
38 | /// - use the outputs in the output vector |
39 | class TFModelEvaluatorImpl; |
40 | class EvaluationResultImpl; |
41 | |
42 | class TFModelEvaluator final { |
43 | public: |
44 | /// The result of a model evaluation. Handles the lifetime of the output |
45 | /// tensors, which means that their values need to be used before |
46 | /// the EvaluationResult's dtor is called. |
47 | class EvaluationResult { |
48 | public: |
49 | EvaluationResult(const EvaluationResult &) = delete; |
50 | EvaluationResult &operator=(const EvaluationResult &Other) = delete; |
51 | |
52 | EvaluationResult(EvaluationResult &&Other); |
53 | EvaluationResult &operator=(EvaluationResult &&Other); |
54 | |
55 | ~EvaluationResult(); |
56 | |
57 | /// Get a (const) pointer to the first element of the tensor at Index. |
58 | template <typename T> T *getTensorValue(size_t Index) { |
59 | return static_cast<T *>(getUntypedTensorValue(Index)); |
60 | } |
61 | |
62 | template <typename T> const T *getTensorValue(size_t Index) const { |
63 | return static_cast<T *>(getUntypedTensorValue(Index)); |
64 | } |
65 | |
66 | /// Get a (const) pointer to the untyped data of the tensor. |
67 | void *getUntypedTensorValue(size_t Index); |
68 | const void *getUntypedTensorValue(size_t Index) const; |
69 | |
70 | private: |
71 | friend class TFModelEvaluator; |
72 | EvaluationResult(std::unique_ptr<EvaluationResultImpl> Impl); |
73 | std::unique_ptr<EvaluationResultImpl> Impl; |
74 | }; |
75 | |
76 | TFModelEvaluator(StringRef SavedModelPath, |
77 | const std::vector<TensorSpec> &InputSpecs, |
78 | const std::vector<TensorSpec> &OutputSpecs, |
79 | const char *Tags = "serve" ); |
80 | |
81 | ~TFModelEvaluator(); |
82 | TFModelEvaluator(const TFModelEvaluator &) = delete; |
83 | TFModelEvaluator(TFModelEvaluator &&) = delete; |
84 | |
85 | /// Evaluate the model, assuming it is valid. Returns std::nullopt if the |
86 | /// evaluation fails or the model is invalid, or an EvaluationResult |
87 | /// otherwise. The inputs are assumed to have been already provided via |
88 | /// getInput(). When returning std::nullopt, it also invalidates this object. |
89 | std::optional<EvaluationResult> evaluate(); |
90 | |
91 | /// Provides access to the input vector. |
92 | template <typename T> T *getInput(size_t Index) { |
93 | return static_cast<T *>(getUntypedInput(Index)); |
94 | } |
95 | |
96 | /// Returns true if the model was loaded successfully, false |
97 | /// otherwise. |
98 | bool isValid() const { return !!Impl; } |
99 | |
100 | /// Untyped access to input. |
101 | void *getUntypedInput(size_t Index); |
102 | |
103 | private: |
104 | std::unique_ptr<TFModelEvaluatorImpl> Impl; |
105 | }; |
106 | |
107 | } // namespace llvm |
108 | |
109 | #endif // LLVM_HAVE_TFLITE |
110 | #endif // LLVM_ANALYSIS_UTILS_TFUTILS_H |
111 | |