1//===-- TCPSocket.cpp -----------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#if defined(_MSC_VER)
10#define _WINSOCK_DEPRECATED_NO_WARNINGS
11#endif
12
13#include "lldb/Host/common/TCPSocket.h"
14
15#include "lldb/Host/Config.h"
16#include "lldb/Host/MainLoop.h"
17#include "lldb/Utility/LLDBLog.h"
18#include "lldb/Utility/Log.h"
19
20#include "llvm/Config/llvm-config.h"
21#include "llvm/Support/Errno.h"
22#include "llvm/Support/WindowsError.h"
23#include "llvm/Support/raw_ostream.h"
24
25#if LLDB_ENABLE_POSIX
26#include <arpa/inet.h>
27#include <netinet/tcp.h>
28#include <sys/socket.h>
29#endif
30
31#if defined(_WIN32)
32#include <winsock2.h>
33#endif
34
#if defined(_WIN32)
// Winsock closes sockets with closesocket() and takes option values as char*.
#define CLOSE_SOCKET closesocket
using set_socket_option_arg_type = const char *;
#else
// On POSIX systems sockets are file descriptors released with close().
#include <unistd.h>
#define CLOSE_SOCKET ::close
using set_socket_option_arg_type = const void *;
#endif
43
44using namespace lldb;
45using namespace lldb_private;
46
47static Status GetLastSocketError() {
48 std::error_code EC;
49#ifdef _WIN32
50 EC = llvm::mapWindowsError(WSAGetLastError());
51#else
52 EC = std::error_code(errno, std::generic_category());
53#endif
54 return EC;
55}
56
// Socket type used for every socket this file creates: a stream (TCP) socket.
static constexpr int kType = SOCK_STREAM;
58
59TCPSocket::TCPSocket(bool should_close, bool child_processes_inherit)
60 : Socket(ProtocolTcp, should_close, child_processes_inherit) {}
61
62TCPSocket::TCPSocket(NativeSocket socket, const TCPSocket &listen_socket)
63 : Socket(ProtocolTcp, listen_socket.m_should_close_fd,
64 listen_socket.m_child_processes_inherit) {
65 m_socket = socket;
66}
67
68TCPSocket::TCPSocket(NativeSocket socket, bool should_close,
69 bool child_processes_inherit)
70 : Socket(ProtocolTcp, should_close, child_processes_inherit) {
71 m_socket = socket;
72}
73
74TCPSocket::~TCPSocket() { CloseListenSockets(); }
75
76bool TCPSocket::IsValid() const {
77 return m_socket != kInvalidSocketValue || m_listen_sockets.size() != 0;
78}
79
80// Return the port number that is being used by the socket.
81uint16_t TCPSocket::GetLocalPortNumber() const {
82 if (m_socket != kInvalidSocketValue) {
83 SocketAddress sock_addr;
84 socklen_t sock_addr_len = sock_addr.GetMaxLength();
85 if (::getsockname(fd: m_socket, addr: sock_addr, len: &sock_addr_len) == 0)
86 return sock_addr.GetPort();
87 } else if (!m_listen_sockets.empty()) {
88 SocketAddress sock_addr;
89 socklen_t sock_addr_len = sock_addr.GetMaxLength();
90 if (::getsockname(fd: m_listen_sockets.begin()->first, addr: sock_addr,
91 len: &sock_addr_len) == 0)
92 return sock_addr.GetPort();
93 }
94 return 0;
95}
96
97std::string TCPSocket::GetLocalIPAddress() const {
98 // We bound to port zero, so we need to figure out which port we actually
99 // bound to
100 if (m_socket != kInvalidSocketValue) {
101 SocketAddress sock_addr;
102 socklen_t sock_addr_len = sock_addr.GetMaxLength();
103 if (::getsockname(fd: m_socket, addr: sock_addr, len: &sock_addr_len) == 0)
104 return sock_addr.GetIPAddress();
105 }
106 return "";
107}
108
109uint16_t TCPSocket::GetRemotePortNumber() const {
110 if (m_socket != kInvalidSocketValue) {
111 SocketAddress sock_addr;
112 socklen_t sock_addr_len = sock_addr.GetMaxLength();
113 if (::getpeername(fd: m_socket, addr: sock_addr, len: &sock_addr_len) == 0)
114 return sock_addr.GetPort();
115 }
116 return 0;
117}
118
119std::string TCPSocket::GetRemoteIPAddress() const {
120 // We bound to port zero, so we need to figure out which port we actually
121 // bound to
122 if (m_socket != kInvalidSocketValue) {
123 SocketAddress sock_addr;
124 socklen_t sock_addr_len = sock_addr.GetMaxLength();
125 if (::getpeername(fd: m_socket, addr: sock_addr, len: &sock_addr_len) == 0)
126 return sock_addr.GetIPAddress();
127 }
128 return "";
129}
130
131std::string TCPSocket::GetRemoteConnectionURI() const {
132 if (m_socket != kInvalidSocketValue) {
133 return std::string(llvm::formatv(
134 Fmt: "connect://[{0}]:{1}", Vals: GetRemoteIPAddress(), Vals: GetRemotePortNumber()));
135 }
136 return "";
137}
138
139Status TCPSocket::CreateSocket(int domain) {
140 Status error;
141 if (IsValid())
142 error = Close();
143 if (error.Fail())
144 return error;
145 m_socket = Socket::CreateSocket(domain, type: kType, IPPROTO_TCP,
146 child_processes_inherit: m_child_processes_inherit, error);
147 return error;
148}
149
150Status TCPSocket::Connect(llvm::StringRef name) {
151
152 Log *log = GetLog(mask: LLDBLog::Communication);
153 LLDB_LOG(log, "Connect to host/port {0}", name);
154
155 Status error;
156 llvm::Expected<HostAndPort> host_port = DecodeHostAndPort(host_and_port: name);
157 if (!host_port)
158 return Status(host_port.takeError());
159
160 std::vector<SocketAddress> addresses =
161 SocketAddress::GetAddressInfo(hostname: host_port->hostname.c_str(), servname: nullptr,
162 AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
163 for (SocketAddress &address : addresses) {
164 error = CreateSocket(domain: address.GetFamily());
165 if (error.Fail())
166 continue;
167
168 address.SetPort(host_port->port);
169
170 if (llvm::sys::RetryAfterSignal(Fail: -1, F&: ::connect, As: GetNativeSocket(),
171 As: &address.sockaddr(),
172 As: address.GetLength()) == -1) {
173 Close();
174 continue;
175 }
176
177 if (SetOptionNoDelay() == -1) {
178 Close();
179 continue;
180 }
181
182 error.Clear();
183 return error;
184 }
185
186 error.SetErrorString("Failed to connect port");
187 return error;
188}
189
190Status TCPSocket::Listen(llvm::StringRef name, int backlog) {
191 Log *log = GetLog(mask: LLDBLog::Connection);
192 LLDB_LOG(log, "Listen to {0}", name);
193
194 Status error;
195 llvm::Expected<HostAndPort> host_port = DecodeHostAndPort(host_and_port: name);
196 if (!host_port)
197 return Status(host_port.takeError());
198
199 if (host_port->hostname == "*")
200 host_port->hostname = "0.0.0.0";
201 std::vector<SocketAddress> addresses = SocketAddress::GetAddressInfo(
202 hostname: host_port->hostname.c_str(), servname: nullptr, AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
203 for (SocketAddress &address : addresses) {
204 int fd = Socket::CreateSocket(domain: address.GetFamily(), type: kType, IPPROTO_TCP,
205 child_processes_inherit: m_child_processes_inherit, error);
206 if (error.Fail() || fd < 0)
207 continue;
208
209 // enable local address reuse
210 int option_value = 1;
211 set_socket_option_arg_type option_value_p =
212 reinterpret_cast<set_socket_option_arg_type>(&option_value);
213 if (::setsockopt(fd: fd, SOL_SOCKET, SO_REUSEADDR, optval: option_value_p,
214 optlen: sizeof(option_value)) == -1) {
215 CLOSE_SOCKET(fd: fd);
216 continue;
217 }
218
219 SocketAddress listen_address = address;
220 if(!listen_address.IsLocalhost())
221 listen_address.SetToAnyAddress(family: address.GetFamily(), port: host_port->port);
222 else
223 listen_address.SetPort(host_port->port);
224
225 int err =
226 ::bind(fd: fd, addr: &listen_address.sockaddr(), len: listen_address.GetLength());
227 if (err != -1)
228 err = ::listen(fd: fd, n: backlog);
229
230 if (err == -1) {
231 error = GetLastSocketError();
232 CLOSE_SOCKET(fd: fd);
233 continue;
234 }
235
236 if (host_port->port == 0) {
237 socklen_t sa_len = address.GetLength();
238 if (getsockname(fd: fd, addr: &address.sockaddr(), len: &sa_len) == 0)
239 host_port->port = address.GetPort();
240 }
241 m_listen_sockets[fd] = address;
242 }
243
244 if (m_listen_sockets.empty()) {
245 assert(error.Fail());
246 return error;
247 }
248 return Status();
249}
250
251void TCPSocket::CloseListenSockets() {
252 for (auto socket : m_listen_sockets)
253 CLOSE_SOCKET(fd: socket.first);
254 m_listen_sockets.clear();
255}
256
257Status TCPSocket::Accept(Socket *&conn_socket) {
258 Status error;
259 if (m_listen_sockets.size() == 0) {
260 error.SetErrorString("No open listening sockets!");
261 return error;
262 }
263
264 NativeSocket sock = kInvalidSocketValue;
265 NativeSocket listen_sock = kInvalidSocketValue;
266 lldb_private::SocketAddress AcceptAddr;
267 MainLoop accept_loop;
268 std::vector<MainLoopBase::ReadHandleUP> handles;
269 for (auto socket : m_listen_sockets) {
270 auto fd = socket.first;
271 auto inherit = this->m_child_processes_inherit;
272 auto io_sp = IOObjectSP(new TCPSocket(socket.first, false, inherit));
273 handles.emplace_back(args: accept_loop.RegisterReadObject(
274 object_sp: io_sp, callback: [fd, inherit, &sock, &AcceptAddr, &error,
275 &listen_sock](MainLoopBase &loop) {
276 socklen_t sa_len = AcceptAddr.GetMaxLength();
277 sock = AcceptSocket(sockfd: fd, addr: &AcceptAddr.sockaddr(), addrlen: &sa_len, child_processes_inherit: inherit,
278 error);
279 listen_sock = fd;
280 loop.RequestTermination();
281 }, error));
282 if (error.Fail())
283 return error;
284 }
285
286 bool accept_connection = false;
287 std::unique_ptr<TCPSocket> accepted_socket;
288 // Loop until we are happy with our connection
289 while (!accept_connection) {
290 accept_loop.Run();
291
292 if (error.Fail())
293 return error;
294
295 lldb_private::SocketAddress &AddrIn = m_listen_sockets[listen_sock];
296 if (!AddrIn.IsAnyAddr() && AcceptAddr != AddrIn) {
297 if (sock != kInvalidSocketValue) {
298 CLOSE_SOCKET(fd: sock);
299 sock = kInvalidSocketValue;
300 }
301 llvm::errs() << llvm::formatv(
302 Fmt: "error: rejecting incoming connection from {0} (expecting {1})",
303 Vals: AcceptAddr.GetIPAddress(), Vals: AddrIn.GetIPAddress());
304 continue;
305 }
306 accept_connection = true;
307 accepted_socket.reset(p: new TCPSocket(sock, *this));
308 }
309
310 if (!accepted_socket)
311 return error;
312
313 // Keep our TCP packets coming without any delays.
314 accepted_socket->SetOptionNoDelay();
315 error.Clear();
316 conn_socket = accepted_socket.release();
317 return error;
318}
319
320int TCPSocket::SetOptionNoDelay() {
321 return SetOption(IPPROTO_TCP, TCP_NODELAY, option_value: 1);
322}
323
324int TCPSocket::SetOptionReuseAddress() {
325 return SetOption(SOL_SOCKET, SO_REUSEADDR, option_value: 1);
326}
327

source code of lldb/source/Host/common/TCPSocket.cpp