1/*
2 *
3 * Copyright 2018 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19#ifndef GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H
20#define GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H
21
22#include <array>
23#include <functional>
24
25#include <grpcpp/impl/codegen/call.h>
26#include <grpcpp/impl/codegen/call_op_set_interface.h>
27#include <grpcpp/impl/codegen/client_interceptor.h>
28#include <grpcpp/impl/codegen/intercepted_channel.h>
29#include <grpcpp/impl/codegen/server_interceptor.h>
30
31#include <grpc/impl/codegen/grpc_types.h>
32
33namespace grpc {
34namespace internal {
35
36class InterceptorBatchMethodsImpl
37 : public experimental::InterceptorBatchMethods {
38 public:
39 InterceptorBatchMethodsImpl() {
40 for (auto i = static_cast<experimental::InterceptionHookPoints>(0);
41 i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS;
42 i = static_cast<experimental::InterceptionHookPoints>(
43 static_cast<size_t>(i) + 1)) {
44 hooks_[static_cast<size_t>(i)] = false;
45 }
46 }
47
48 ~InterceptorBatchMethodsImpl() {}
49
50 bool QueryInterceptionHookPoint(
51 experimental::InterceptionHookPoints type) override {
52 return hooks_[static_cast<size_t>(type)];
53 }
54
55 void Proceed() override {
56 if (call_->client_rpc_info() != nullptr) {
57 return ProceedClient();
58 }
59 GPR_CODEGEN_ASSERT(call_->server_rpc_info() != nullptr);
60 ProceedServer();
61 }
62
63 void Hijack() override {
64 // Only the client can hijack when sending down initial metadata
65 GPR_CODEGEN_ASSERT(!reverse_ && ops_ != nullptr &&
66 call_->client_rpc_info() != nullptr);
67 // It is illegal to call Hijack twice
68 GPR_CODEGEN_ASSERT(!ran_hijacking_interceptor_);
69 auto* rpc_info = call_->client_rpc_info();
70 rpc_info->hijacked_ = true;
71 rpc_info->hijacked_interceptor_ = current_interceptor_index_;
72 ClearHookPoints();
73 ops_->SetHijackingState();
74 ran_hijacking_interceptor_ = true;
75 rpc_info->RunInterceptor(interceptor_methods: this, pos: current_interceptor_index_);
76 }
77
78 void AddInterceptionHookPoint(experimental::InterceptionHookPoints type) {
79 hooks_[static_cast<size_t>(type)] = true;
80 }
81
82 ByteBuffer* GetSerializedSendMessage() override {
83 GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr);
84 if (*orig_send_message_ != nullptr) {
85 GPR_CODEGEN_ASSERT(serializer_(*orig_send_message_).ok());
86 *orig_send_message_ = nullptr;
87 }
88 return send_message_;
89 }
90
91 const void* GetSendMessage() override {
92 GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr);
93 return *orig_send_message_;
94 }
95
96 void ModifySendMessage(const void* message) override {
97 GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr);
98 *orig_send_message_ = message;
99 }
100
101 bool GetSendMessageStatus() override { return !*fail_send_message_; }
102
103 std::multimap<grpc::string, grpc::string>* GetSendInitialMetadata() override {
104 return send_initial_metadata_;
105 }
106
107 Status GetSendStatus() override {
108 return Status(static_cast<StatusCode>(*code_), *error_message_,
109 *error_details_);
110 }
111
112 void ModifySendStatus(const Status& status) override {
113 *code_ = static_cast<grpc_status_code>(status.error_code());
114 *error_details_ = status.error_details();
115 *error_message_ = status.error_message();
116 }
117
118 std::multimap<grpc::string, grpc::string>* GetSendTrailingMetadata()
119 override {
120 return send_trailing_metadata_;
121 }
122
123 void* GetRecvMessage() override { return recv_message_; }
124
125 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata()
126 override {
127 return recv_initial_metadata_->map();
128 }
129
130 Status* GetRecvStatus() override { return recv_status_; }
131
132 void FailHijackedSendMessage() override {
133 GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>(
134 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]);
135 *fail_send_message_ = true;
136 }
137
138 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata()
139 override {
140 return recv_trailing_metadata_->map();
141 }
142
143 void SetSendMessage(ByteBuffer* buf, const void** msg,
144 bool* fail_send_message,
145 std::function<Status(const void*)> serializer) {
146 send_message_ = buf;
147 orig_send_message_ = msg;
148 fail_send_message_ = fail_send_message;
149 serializer_ = serializer;
150 }
151
152 void SetSendInitialMetadata(
153 std::multimap<grpc::string, grpc::string>* metadata) {
154 send_initial_metadata_ = metadata;
155 }
156
157 void SetSendStatus(grpc_status_code* code, grpc::string* error_details,
158 grpc::string* error_message) {
159 code_ = code;
160 error_details_ = error_details;
161 error_message_ = error_message;
162 }
163
164 void SetSendTrailingMetadata(
165 std::multimap<grpc::string, grpc::string>* metadata) {
166 send_trailing_metadata_ = metadata;
167 }
168
169 void SetRecvMessage(void* message, bool* hijacked_recv_message_failed) {
170 recv_message_ = message;
171 hijacked_recv_message_failed_ = hijacked_recv_message_failed;
172 }
173
174 void SetRecvInitialMetadata(MetadataMap* map) {
175 recv_initial_metadata_ = map;
176 }
177
178 void SetRecvStatus(Status* status) { recv_status_ = status; }
179
180 void SetRecvTrailingMetadata(MetadataMap* map) {
181 recv_trailing_metadata_ = map;
182 }
183
184 std::unique_ptr<ChannelInterface> GetInterceptedChannel() override {
185 auto* info = call_->client_rpc_info();
186 if (info == nullptr) {
187 return std::unique_ptr<ChannelInterface>(nullptr);
188 }
189 // The intercepted channel starts from the interceptor just after the
190 // current interceptor
191 return std::unique_ptr<ChannelInterface>(new InterceptedChannel(
192 info->channel(), current_interceptor_index_ + 1));
193 }
194
195 void FailHijackedRecvMessage() override {
196 GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>(
197 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)]);
198 *hijacked_recv_message_failed_ = true;
199 }
200
201 // Clears all state
202 void ClearState() {
203 reverse_ = false;
204 ran_hijacking_interceptor_ = false;
205 ClearHookPoints();
206 }
207
208 // Prepares for Post_recv operations
209 void SetReverse() {
210 reverse_ = true;
211 ran_hijacking_interceptor_ = false;
212 ClearHookPoints();
213 }
214
215 // This needs to be set before interceptors are run
216 void SetCall(Call* call) { call_ = call; }
217
218 // This needs to be set before interceptors are run using RunInterceptors().
219 // Alternatively, RunInterceptors(std::function<void(void)> f) can be used.
220 void SetCallOpSetInterface(CallOpSetInterface* ops) { ops_ = ops; }
221
222 // SetCall should have been called before this.
223 // Returns true if the interceptors list is empty
224 bool InterceptorsListEmpty() {
225 auto* client_rpc_info = call_->client_rpc_info();
226 if (client_rpc_info != nullptr) {
227 if (client_rpc_info->interceptors_.size() == 0) {
228 return true;
229 } else {
230 return false;
231 }
232 }
233
234 auto* server_rpc_info = call_->server_rpc_info();
235 if (server_rpc_info == nullptr ||
236 server_rpc_info->interceptors_.size() == 0) {
237 return true;
238 }
239 return false;
240 }
241
242 // This should be used only by subclasses of CallOpSetInterface. SetCall and
243 // SetCallOpSetInterface should have been called before this. After all the
244 // interceptors are done running, either ContinueFillOpsAfterInterception or
245 // ContinueFinalizeOpsAfterInterception will be called. Note that neither of
246 // them is invoked if there were no interceptors registered.
247 bool RunInterceptors() {
248 GPR_CODEGEN_ASSERT(ops_);
249 auto* client_rpc_info = call_->client_rpc_info();
250 if (client_rpc_info != nullptr) {
251 if (client_rpc_info->interceptors_.size() == 0) {
252 return true;
253 } else {
254 RunClientInterceptors();
255 return false;
256 }
257 }
258
259 auto* server_rpc_info = call_->server_rpc_info();
260 if (server_rpc_info == nullptr ||
261 server_rpc_info->interceptors_.size() == 0) {
262 return true;
263 }
264 RunServerInterceptors();
265 return false;
266 }
267
268 // Returns true if no interceptors are run. Returns false otherwise if there
269 // are interceptors registered. After the interceptors are done running \a f
270 // will be invoked. This is to be used only by BaseAsyncRequest and
271 // SyncRequest.
272 bool RunInterceptors(std::function<void(void)> f) {
273 // This is used only by the server for initial call request
274 GPR_CODEGEN_ASSERT(reverse_ == true);
275 GPR_CODEGEN_ASSERT(call_->client_rpc_info() == nullptr);
276 auto* server_rpc_info = call_->server_rpc_info();
277 if (server_rpc_info == nullptr ||
278 server_rpc_info->interceptors_.size() == 0) {
279 return true;
280 }
281 callback_ = std::move(f);
282 RunServerInterceptors();
283 return false;
284 }
285
286 private:
287 void RunClientInterceptors() {
288 auto* rpc_info = call_->client_rpc_info();
289 if (!reverse_) {
290 current_interceptor_index_ = 0;
291 } else {
292 if (rpc_info->hijacked_) {
293 current_interceptor_index_ = rpc_info->hijacked_interceptor_;
294 } else {
295 current_interceptor_index_ = rpc_info->interceptors_.size() - 1;
296 }
297 }
298 rpc_info->RunInterceptor(interceptor_methods: this, pos: current_interceptor_index_);
299 }
300
301 void RunServerInterceptors() {
302 auto* rpc_info = call_->server_rpc_info();
303 if (!reverse_) {
304 current_interceptor_index_ = 0;
305 } else {
306 current_interceptor_index_ = rpc_info->interceptors_.size() - 1;
307 }
308 rpc_info->RunInterceptor(interceptor_methods: this, pos: current_interceptor_index_);
309 }
310
311 void ProceedClient() {
312 auto* rpc_info = call_->client_rpc_info();
313 if (rpc_info->hijacked_ && !reverse_ &&
314 current_interceptor_index_ == rpc_info->hijacked_interceptor_ &&
315 !ran_hijacking_interceptor_) {
316 // We now need to provide hijacked recv ops to this interceptor
317 ClearHookPoints();
318 ops_->SetHijackingState();
319 ran_hijacking_interceptor_ = true;
320 rpc_info->RunInterceptor(interceptor_methods: this, pos: current_interceptor_index_);
321 return;
322 }
323 if (!reverse_) {
324 current_interceptor_index_++;
325 // We are going down the stack of interceptors
326 if (current_interceptor_index_ < rpc_info->interceptors_.size()) {
327 if (rpc_info->hijacked_ &&
328 current_interceptor_index_ > rpc_info->hijacked_interceptor_) {
329 // This is a hijacked RPC and we are done with hijacking
330 ops_->ContinueFillOpsAfterInterception();
331 } else {
332 rpc_info->RunInterceptor(interceptor_methods: this, pos: current_interceptor_index_);
333 }
334 } else {
335 // we are done running all the interceptors without any hijacking
336 ops_->ContinueFillOpsAfterInterception();
337 }
338 } else {
339 // We are going up the stack of interceptors
340 if (current_interceptor_index_ > 0) {
341 // Continue running interceptors
342 current_interceptor_index_--;
343 rpc_info->RunInterceptor(interceptor_methods: this, pos: current_interceptor_index_);
344 } else {
345 // we are done running all the interceptors without any hijacking
346 ops_->ContinueFinalizeResultAfterInterception();
347 }
348 }
349 }
350
351 void ProceedServer() {
352 auto* rpc_info = call_->server_rpc_info();
353 if (!reverse_) {
354 current_interceptor_index_++;
355 if (current_interceptor_index_ < rpc_info->interceptors_.size()) {
356 return rpc_info->RunInterceptor(interceptor_methods: this, pos: current_interceptor_index_);
357 } else if (ops_) {
358 return ops_->ContinueFillOpsAfterInterception();
359 }
360 } else {
361 // We are going up the stack of interceptors
362 if (current_interceptor_index_ > 0) {
363 // Continue running interceptors
364 current_interceptor_index_--;
365 return rpc_info->RunInterceptor(interceptor_methods: this, pos: current_interceptor_index_);
366 } else if (ops_) {
367 return ops_->ContinueFinalizeResultAfterInterception();
368 }
369 }
370 GPR_CODEGEN_ASSERT(callback_);
371 callback_();
372 }
373
374 void ClearHookPoints() {
375 for (auto i = static_cast<experimental::InterceptionHookPoints>(0);
376 i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS;
377 i = static_cast<experimental::InterceptionHookPoints>(
378 static_cast<size_t>(i) + 1)) {
379 hooks_[static_cast<size_t>(i)] = false;
380 }
381 }
382
383 std::array<bool,
384 static_cast<size_t>(
385 experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS)>
386 hooks_;
387
388 size_t current_interceptor_index_ = 0; // Current iterator
389 bool reverse_ = false;
390 bool ran_hijacking_interceptor_ = false;
391 Call* call_ = nullptr; // The Call object is present along with CallOpSet
392 // object/callback
393 CallOpSetInterface* ops_ = nullptr;
394 std::function<void(void)> callback_;
395
396 ByteBuffer* send_message_ = nullptr;
397 bool* fail_send_message_ = nullptr;
398 const void** orig_send_message_ = nullptr;
399 std::function<Status(const void*)> serializer_;
400
401 std::multimap<grpc::string, grpc::string>* send_initial_metadata_;
402
403 grpc_status_code* code_ = nullptr;
404 grpc::string* error_details_ = nullptr;
405 grpc::string* error_message_ = nullptr;
406
407 std::multimap<grpc::string, grpc::string>* send_trailing_metadata_ = nullptr;
408
409 void* recv_message_ = nullptr;
410 bool* hijacked_recv_message_failed_ = nullptr;
411
412 MetadataMap* recv_initial_metadata_ = nullptr;
413
414 Status* recv_status_ = nullptr;
415
416 MetadataMap* recv_trailing_metadata_ = nullptr;
417};
418
419// A special implementation of InterceptorBatchMethods to send a Cancel
420// notification down the interceptor stack
421class CancelInterceptorBatchMethods
422 : public experimental::InterceptorBatchMethods {
423 public:
424 bool QueryInterceptionHookPoint(
425 experimental::InterceptionHookPoints type) override {
426 if (type == experimental::InterceptionHookPoints::PRE_SEND_CANCEL) {
427 return true;
428 } else {
429 return false;
430 }
431 }
432
433 void Proceed() override {
434 // This is a no-op. For actual continuation of the RPC simply needs to
435 // return from the Intercept method
436 }
437
438 void Hijack() override {
439 // Only the client can hijack when sending down initial metadata
440 GPR_CODEGEN_ASSERT(false &&
441 "It is illegal to call Hijack on a method which has a "
442 "Cancel notification");
443 }
444
445 ByteBuffer* GetSerializedSendMessage() override {
446 GPR_CODEGEN_ASSERT(false &&
447 "It is illegal to call GetSendMessage on a method which "
448 "has a Cancel notification");
449 return nullptr;
450 }
451
452 bool GetSendMessageStatus() override {
453 GPR_CODEGEN_ASSERT(
454 false &&
455 "It is illegal to call GetSendMessageStatus on a method which "
456 "has a Cancel notification");
457 return false;
458 }
459
460 const void* GetSendMessage() override {
461 GPR_CODEGEN_ASSERT(
462 false &&
463 "It is illegal to call GetOriginalSendMessage on a method which "
464 "has a Cancel notification");
465 return nullptr;
466 }
467
468 void ModifySendMessage(const void* /*message*/) override {
469 GPR_CODEGEN_ASSERT(
470 false &&
471 "It is illegal to call ModifySendMessage on a method which "
472 "has a Cancel notification");
473 }
474
475 std::multimap<grpc::string, grpc::string>* GetSendInitialMetadata() override {
476 GPR_CODEGEN_ASSERT(false &&
477 "It is illegal to call GetSendInitialMetadata on a "
478 "method which has a Cancel notification");
479 return nullptr;
480 }
481
482 Status GetSendStatus() override {
483 GPR_CODEGEN_ASSERT(false &&
484 "It is illegal to call GetSendStatus on a method which "
485 "has a Cancel notification");
486 return Status();
487 }
488
489 void ModifySendStatus(const Status& /*status*/) override {
490 GPR_CODEGEN_ASSERT(false &&
491 "It is illegal to call ModifySendStatus on a method "
492 "which has a Cancel notification");
493 return;
494 }
495
496 std::multimap<grpc::string, grpc::string>* GetSendTrailingMetadata()
497 override {
498 GPR_CODEGEN_ASSERT(false &&
499 "It is illegal to call GetSendTrailingMetadata on a "
500 "method which has a Cancel notification");
501 return nullptr;
502 }
503
504 void* GetRecvMessage() override {
505 GPR_CODEGEN_ASSERT(false &&
506 "It is illegal to call GetRecvMessage on a method which "
507 "has a Cancel notification");
508 return nullptr;
509 }
510
511 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata()
512 override {
513 GPR_CODEGEN_ASSERT(false &&
514 "It is illegal to call GetRecvInitialMetadata on a "
515 "method which has a Cancel notification");
516 return nullptr;
517 }
518
519 Status* GetRecvStatus() override {
520 GPR_CODEGEN_ASSERT(false &&
521 "It is illegal to call GetRecvStatus on a method which "
522 "has a Cancel notification");
523 return nullptr;
524 }
525
526 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata()
527 override {
528 GPR_CODEGEN_ASSERT(false &&
529 "It is illegal to call GetRecvTrailingMetadata on a "
530 "method which has a Cancel notification");
531 return nullptr;
532 }
533
534 std::unique_ptr<ChannelInterface> GetInterceptedChannel() override {
535 GPR_CODEGEN_ASSERT(false &&
536 "It is illegal to call GetInterceptedChannel on a "
537 "method which has a Cancel notification");
538 return std::unique_ptr<ChannelInterface>(nullptr);
539 }
540
541 void FailHijackedRecvMessage() override {
542 GPR_CODEGEN_ASSERT(false &&
543 "It is illegal to call FailHijackedRecvMessage on a "
544 "method which has a Cancel notification");
545 }
546
547 void FailHijackedSendMessage() override {
548 GPR_CODEGEN_ASSERT(false &&
549 "It is illegal to call FailHijackedSendMessage on a "
550 "method which has a Cancel notification");
551 }
552};
553} // namespace internal
554} // namespace grpc
555
556#endif // GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H
557

source code of include/grpcpp/impl/codegen/interceptor_common.h