1 | //===- TensorSpec.h - type descriptor for a tensor --------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | #ifndef LLVM_ANALYSIS_TENSORSPEC_H |
10 | #define LLVM_ANALYSIS_TENSORSPEC_H |
11 | |
12 | #include "llvm/Config/llvm-config.h" |
13 | |
14 | #include "llvm/ADT/StringMap.h" |
15 | #include "llvm/IR/LLVMContext.h" |
16 | #include "llvm/Support/JSON.h" |
17 | |
18 | #include <memory> |
19 | #include <optional> |
20 | #include <vector> |
21 | |
22 | namespace llvm { |
23 | /// TensorSpec encapsulates the specification of a tensor: its dimensions, or |
24 | /// "shape" (row-major), its type (see TensorSpec::getDataType specializations |
25 | /// for supported types), its name and port (see "TensorFlow: Large-Scale |
26 | /// Machine Learning on Heterogeneous Distributed Systems", section 4.2, para 2: |
27 | /// https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45166.pdf) |
28 | /// |
/// Note that the design is motivated by TensorFlow, but it is not intended to
/// be TensorFlow-specific.
31 | /// |
/// Known tensor types, listed as M(CType, Name) entries. CType is the C type;
/// Name is the identifier used for the matching TensorType enumerator (and
/// therefore in TensorSpec equality checks), and can also be used, if needed,
/// when mapping to an underlying evaluator's type system. The main requirement
/// is that each CType has the same size and encoding (e.g. endianness) as the
/// corresponding type used by the evaluator.
#define SUPPORTED_TENSOR_TYPES(M)                                              \
  M(float, Float)                                                              \
  M(double, Double)                                                            \
  M(int8_t, Int8)                                                              \
  M(uint8_t, UInt8)                                                            \
  M(int16_t, Int16)                                                            \
  M(uint16_t, UInt16)                                                          \
  M(int32_t, Int32)                                                            \
  M(uint32_t, UInt32)                                                          \
  M(int64_t, Int64)                                                            \
  M(uint64_t, UInt64)

/// Element types a TensorSpec can describe: one enumerator per entry of
/// SUPPORTED_TENSOR_TYPES, in list order. Invalid (value 0) is the value a
/// spec's type holds when no supported type was assigned; Total is a trailing
/// sentinel whose value is one more than the number of supported types.
enum class TensorType {
  Invalid,
// Expands each M(CType, Name) entry to "Name,". The helper deliberately has no
// leading underscore: identifiers of the form _Uppercase are reserved for the
// C++ implementation.
#define TENSOR_TYPE_ENUM_MEMBER(CType, Name) Name,
  SUPPORTED_TENSOR_TYPES(TENSOR_TYPE_ENUM_MEMBER)
#undef TENSOR_TYPE_ENUM_MEMBER
  Total
};
57 | |
58 | class TensorSpec final { |
59 | public: |
60 | template <typename T> |
61 | static TensorSpec createSpec(const std::string &Name, |
62 | const std::vector<int64_t> &Shape, |
63 | int Port = 0) { |
64 | return TensorSpec(Name, Port, getDataType<T>(), sizeof(T), Shape); |
65 | } |
66 | |
67 | const std::string &name() const { return Name; } |
68 | int port() const { return Port; } |
69 | TensorType type() const { return Type; } |
70 | const std::vector<int64_t> &shape() const { return Shape; } |
71 | |
72 | bool operator==(const TensorSpec &Other) const { |
73 | return Name == Other.Name && Port == Other.Port && Type == Other.Type && |
74 | Shape == Other.Shape; |
75 | } |
76 | |
77 | bool operator!=(const TensorSpec &Other) const { return !(*this == Other); } |
78 | |
79 | /// Get the number of elements in a tensor with this shape. |
80 | size_t getElementCount() const { return ElementCount; } |
81 | /// Get the size, in bytes, of one element. |
82 | size_t getElementByteSize() const { return ElementSize; } |
83 | /// Get the total size of a memory buffer needed to store the whole tensor. |
84 | size_t getTotalTensorBufferSize() const { return ElementCount * ElementSize; } |
85 | |
86 | template <typename T> bool isElementType() const { |
87 | return getDataType<T>() == Type; |
88 | } |
89 | |
90 | TensorSpec(const std::string &NewName, const TensorSpec &Other) |
91 | : TensorSpec(NewName, Other.Port, Other.Type, Other.ElementSize, |
92 | Other.Shape) {} |
93 | |
94 | void toJSON(json::OStream &OS) const; |
95 | |
96 | private: |
97 | TensorSpec(const std::string &Name, int Port, TensorType Type, |
98 | size_t ElementSize, const std::vector<int64_t> &Shape); |
99 | |
100 | template <typename T> static TensorType getDataType(); |
101 | |
102 | std::string Name; |
103 | int Port = 0; |
104 | TensorType Type = TensorType::Invalid; |
105 | std::vector<int64_t> Shape; |
106 | size_t ElementCount = 0; |
107 | size_t ElementSize = 0; |
108 | }; |
109 | |
110 | /// For debugging. |
111 | std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec); |
112 | |
113 | /// Construct a TensorSpec from a JSON dictionary of the form: |
114 | /// { "name": <string>, |
115 | /// "port": <int>, |
116 | /// "type": <string. Use LLVM's types, e.g. float, double, int64_t>, |
117 | /// "shape": <array of ints> } |
118 | /// For the "type" field, see the C++ primitive types used in |
119 | /// TFUTILS_SUPPORTED_TYPES. |
120 | std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx, |
121 | const json::Value &Value); |
122 | |
123 | #define TFUTILS_GETDATATYPE_DEF(T, Name) \ |
124 | template <> TensorType TensorSpec::getDataType<T>(); |
125 | SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_DEF) |
126 | |
127 | #undef TFUTILS_GETDATATYPE_DEF |
128 | } // namespace llvm |
129 | |
130 | #endif // LLVM_ANALYSIS_TENSORSPEC_H |
131 | |