1//===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This implements Semantic Analysis for HLSL constructs.
9//===----------------------------------------------------------------------===//
10
11#include "clang/Sema/SemaHLSL.h"
12#include "clang/Basic/DiagnosticSema.h"
13#include "clang/Basic/LLVM.h"
14#include "clang/Basic/TargetInfo.h"
15#include "clang/Sema/Sema.h"
16#include "llvm/ADT/STLExtras.h"
17#include "llvm/ADT/StringExtras.h"
18#include "llvm/ADT/StringRef.h"
19#include "llvm/Support/ErrorHandling.h"
20#include "llvm/TargetParser/Triple.h"
21#include <iterator>
22
23using namespace clang;
24
25SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
26
27Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
28 SourceLocation KwLoc, IdentifierInfo *Ident,
29 SourceLocation IdentLoc,
30 SourceLocation LBrace) {
31 // For anonymous namespace, take the location of the left brace.
32 DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
33 HLSLBufferDecl *Result = HLSLBufferDecl::Create(
34 C&: getASTContext(), LexicalParent, CBuffer, KwLoc, ID: Ident, IDLoc: IdentLoc, LBrace);
35
36 SemaRef.PushOnScopeChains(Result, BufferScope);
37 SemaRef.PushDeclContext(BufferScope, Result);
38
39 return Result;
40}
41
42void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
43 auto *BufDecl = cast<HLSLBufferDecl>(Val: Dcl);
44 BufDecl->setRBraceLoc(RBrace);
45 SemaRef.PopDeclContext();
46}
47
48HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
49 const AttributeCommonInfo &AL,
50 int X, int Y, int Z) {
51 if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
52 if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
53 Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
54 Diag(AL.getLoc(), diag::note_conflicting_attribute);
55 }
56 return nullptr;
57 }
58 return ::new (getASTContext())
59 HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
60}
61
62HLSLShaderAttr *
63SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
64 HLSLShaderAttr::ShaderType ShaderType) {
65 if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
66 if (NT->getType() != ShaderType) {
67 Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
68 Diag(AL.getLoc(), diag::note_conflicting_attribute);
69 }
70 return nullptr;
71 }
72 return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL);
73}
74
75HLSLParamModifierAttr *
76SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
77 HLSLParamModifierAttr::Spelling Spelling) {
78 // We can only merge an `in` attribute with an `out` attribute. All other
79 // combinations of duplicated attributes are ill-formed.
80 if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
81 if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
82 (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
83 D->dropAttr<HLSLParamModifierAttr>();
84 SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
85 return HLSLParamModifierAttr::Create(
86 getASTContext(), /*MergedSpelling=*/true, AdjustedRange,
87 HLSLParamModifierAttr::Keyword_inout);
88 }
89 Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
90 Diag(PA->getLocation(), diag::note_conflicting_attribute);
91 return nullptr;
92 }
93 return HLSLParamModifierAttr::Create(getASTContext(), AL);
94}
95
96void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
97 auto &TargetInfo = getASTContext().getTargetInfo();
98
99 if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
100 return;
101
102 StringRef Env = TargetInfo.getTriple().getEnvironmentName();
103 HLSLShaderAttr::ShaderType ShaderType;
104 if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) {
105 if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
106 // The entry point is already annotated - check that it matches the
107 // triple.
108 if (Shader->getType() != ShaderType) {
109 Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
110 << Shader;
111 FD->setInvalidDecl();
112 }
113 } else {
114 // Implicitly add the shader attribute if the entry function isn't
115 // explicitly annotated.
116 FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), ShaderType,
117 FD->getBeginLoc()));
118 }
119 } else {
120 switch (TargetInfo.getTriple().getEnvironment()) {
121 case llvm::Triple::UnknownEnvironment:
122 case llvm::Triple::Library:
123 break;
124 default:
125 llvm_unreachable("Unhandled environment in triple");
126 }
127 }
128}
129
130void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
131 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
132 assert(ShaderAttr && "Entry point has no shader attribute");
133 HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
134
135 switch (ST) {
136 case HLSLShaderAttr::Pixel:
137 case HLSLShaderAttr::Vertex:
138 case HLSLShaderAttr::Geometry:
139 case HLSLShaderAttr::Hull:
140 case HLSLShaderAttr::Domain:
141 case HLSLShaderAttr::RayGeneration:
142 case HLSLShaderAttr::Intersection:
143 case HLSLShaderAttr::AnyHit:
144 case HLSLShaderAttr::ClosestHit:
145 case HLSLShaderAttr::Miss:
146 case HLSLShaderAttr::Callable:
147 if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
148 DiagnoseAttrStageMismatch(NT, ST,
149 {HLSLShaderAttr::Compute,
150 HLSLShaderAttr::Amplification,
151 HLSLShaderAttr::Mesh});
152 FD->setInvalidDecl();
153 }
154 break;
155
156 case HLSLShaderAttr::Compute:
157 case HLSLShaderAttr::Amplification:
158 case HLSLShaderAttr::Mesh:
159 if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
160 Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
161 << HLSLShaderAttr::ConvertShaderTypeToStr(ST);
162 FD->setInvalidDecl();
163 }
164 break;
165 }
166
167 for (ParmVarDecl *Param : FD->parameters()) {
168 if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) {
169 CheckSemanticAnnotation(EntryPoint: FD, Param, AnnotationAttr: AnnotationAttr);
170 } else {
171 // FIXME: Handle struct parameters where annotations are on struct fields.
172 // See: https://github.com/llvm/llvm-project/issues/57875
173 Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation);
174 Diag(Param->getLocation(), diag::note_previous_decl) << Param;
175 FD->setInvalidDecl();
176 }
177 }
178 // FIXME: Verify return type semantic annotation.
179}
180
181void SemaHLSL::CheckSemanticAnnotation(
182 FunctionDecl *EntryPoint, const Decl *Param,
183 const HLSLAnnotationAttr *AnnotationAttr) {
184 auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
185 assert(ShaderAttr && "Entry point has no shader attribute");
186 HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
187
188 switch (AnnotationAttr->getKind()) {
189 case attr::HLSLSV_DispatchThreadID:
190 case attr::HLSLSV_GroupIndex:
191 if (ST == HLSLShaderAttr::Compute)
192 return;
193 DiagnoseAttrStageMismatch(AnnotationAttr, ST, {HLSLShaderAttr::Compute});
194 break;
195 default:
196 llvm_unreachable("Unknown HLSLAnnotationAttr");
197 }
198}
199
200void SemaHLSL::DiagnoseAttrStageMismatch(
201 const Attr *A, HLSLShaderAttr::ShaderType Stage,
202 std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) {
203 SmallVector<StringRef, 8> StageStrings;
204 llvm::transform(AllowedStages, std::back_inserter(x&: StageStrings),
205 [](HLSLShaderAttr::ShaderType ST) {
206 return StringRef(
207 HLSLShaderAttr::ConvertShaderTypeToStr(ST));
208 });
209 Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
210 << A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage)
211 << (AllowedStages.size() != 1) << join(StageStrings, ", ");
212}
213

source code of clang/lib/Sema/SemaHLSL.cpp