1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 module thrift.transport.socket;
20 
21 import core.thread : Thread;
22 import core.time : dur, Duration;
23 import std.array : empty;
24 import std.conv : text, to;
25 import std.exception : enforce;
26 import std.socket;
27 import thrift.base;
28 import thrift.transport.base;
29 import thrift.internal.socket;
30 
/**
 * Common parts of a socket TTransport implementation, regardless of how the
 * actual I/O is performed (sync/async).
 *
 * Holds the wrapped std.socket Socket, the host/port used to establish the
 * connection, the configured send/receive timeouts and a lazily populated
 * cache of the peer address.
 */
abstract class TSocketBase : TBaseTransport {
  /**
   * Constructor that takes an already created, connected (!) socket.
   *
   * host and port are left at their initial values (null and 0) in this case.
   *
   * Params:
   *   socket = Already created, connected socket object.
   */
  this(Socket socket) {
    socket_ = socket;
    setSocketOpts();
  }

  /**
   * Creates a new unconnected socket that will connect to the given host
   * on the given port.
   *
   * No socket is created here; that is left to the connecting subclass.
   *
   * Params:
   *   host = Remote host.
   *   port = Remote port.
   */
  this(string host, ushort port) {
    host_ = host;
    port_ = port;
  }

  /**
   * Checks whether the socket is connected, i.e. whether a socket object is
   * currently wrapped.
   */
  override bool isOpen() @property {
    return socket_ !is null;
  }

  /**
   * Writes as much data to the socket as there can be in a single OS call.
   *
   * Params:
   *   buf = Data to write.
   *
   * Returns: The actual number of bytes written. Never more than buf.length.
   */
  abstract size_t writeSome(in ubyte[] buf) out (written) {
    // DMD @@BUG@@: Enabling this e.g. fails the contract in the
    // async_test_server, because buf.length evaluates to 0 here, even though
    // in the method body it correctly is 27 (equal to the return value).
    version (none) assert(written <= buf.length, text("Implementation wrote " ~
      "more data than requested to?! (", written, " vs. ", buf.length, ")"));
  } body {
    assert(0, "DMD bug? – Why would contracts work for interfaces, but not " ~
      "for abstract methods? " ~
      "(Error: function […] in and out contracts require function body");
  }

  /**
   * Returns the actual address of the peer the socket is connected to.
   *
   * In contrast, the host and port properties contain the address used to
   * establish the connection, and are not updated after the connection.
   *
   * The result of the first successful lookup is cached in peerAddress_.
   *
   * The socket must be open when calling this.
   *
   * Throws: TTransportException (NOT_OPEN) if the socket is not open.
   */
  Address getPeerAddress() {
    enforce(isOpen, new TTransportException("Cannot get peer host for " ~
      "closed socket.", TTransportException.Type.NOT_OPEN));

    if (!peerAddress_) {
      peerAddress_ = socket_.remoteAddress();
      assert(peerAddress_);
    }

    return peerAddress_;
  }

  /**
   * The host the socket is connected to or will connect to. Null if an
   * already connected socket was used to construct the object.
   */
  string host() const @property {
    return host_;
  }

  /**
   * The port the socket is connected to or will connect to. Zero if an
   * already connected socket was used to construct the object.
   */
  ushort port() const @property {
    return port_;
  }

  /// The socket send timeout.
  Duration sendTimeout() const @property {
    return sendTimeout_;
  }

  /// Ditto
  void sendTimeout(Duration value) @property {
    sendTimeout_ = value;
  }

  /// The socket receiving timeout. Values smaller than 500 ms are not
  /// supported on Windows.
  Duration recvTimeout() const @property {
    return recvTimeout_;
  }

  /// Ditto
  void recvTimeout(Duration value) @property {
    recvTimeout_ = value;
  }

  /**
   * Returns the OS handle of the underlying socket.
   *
   * Should not usually be used directly, but access to it can be necessary
   * to interface with C libraries.
   *
   * Throws: TTransportException (NOT_OPEN) if no socket is currently wrapped
   *   (previously this dereferenced a null socket).
   */
  typeof(socket_.handle()) socketHandle() @property {
    // Check the field itself rather than isOpen, since that is exactly what
    // is dereferenced below (and isOpen may be overridden by subclasses).
    enforce(socket_ !is null, new TTransportException(
      "Cannot get handle of closed socket.", TTransportException.Type.NOT_OPEN));
    return socket_.handle();
  }

protected:
  /**
   * Sets the needed socket options.
   *
   * Turns SO_LINGER off (failures are logged) and tries to enable
   * TCP_NODELAY (failures are ignored, as the option does not apply to
   * non-TCP sockets).
   */
  void setSocketOpts() {
    try {
      alias SocketOptionLevel.SOCKET lvlSock;
      Linger l;
      l.on = 0;
      l.time = 0;
      socket_.setOption(lvlSock, SocketOption.LINGER, l);
    } catch (SocketException e) {
      logError("Could not set socket option: %s", e);
    }

    // Just try to disable Nagle's algorithm – this will fail if we are passed
    // in a non-TCP socket via the Socket-accepting constructor.
    try {
      socket_.setOption(SocketOptionLevel.TCP, SocketOption.TCP_NODELAY, true);
    } catch (SocketException e) {}
  }

  /// Remote host.
  string host_;

  /// Remote port.
  ushort port_;

  /// Timeout for sending.
  Duration sendTimeout_;

  /// Timeout for receiving.
  Duration recvTimeout_;

  /// Cached peer address (filled by getPeerAddress()).
  Address peerAddress_;

  /// Cached peer host name.
  string peerHost_;

  /// Cached peer port.
  ushort peerPort_;

  /// Wrapped socket object; null while the transport is closed.
  Socket socket_;
}
200 
/**
 * Socket implementation of the TTransport interface.
 *
 * Due to the limitations of std.socket, currently only TCP/IP sockets are
 * supported (i.e. Unix domain sockets are not).
 */
class TSocket : TSocketBase {
  ///
  this(Socket socket) {
    super(socket);
  }

  ///
  this(string host, ushort port) {
    super(host, port);
  }

  /**
   * Connects the socket.
   *
   * Resolves host/port and tries every returned address in turn until one
   * connection succeeds. Does nothing if the socket is already open.
   *
   * Throws: TTransportException (NOT_OPEN) if the host is empty, the port is
   *   zero, the host cannot be resolved, or no address could be connected
   *   to. In the last case the per-address SocketExceptions are chained as a
   *   TCompoundOperationException.
   */
  override void open() {
    if (isOpen) return;

    enforce(!host_.empty, new TTransportException(
      "Cannot open socket to null host.", TTransportException.Type.NOT_OPEN));
    enforce(port_ != 0, new TTransportException(
      "Cannot open socket to port zero.", TTransportException.Type.NOT_OPEN));

    Address[] addrs;
    try {
      addrs = getAddress(host_, port_);
    } catch (SocketException e) {
      throw new TTransportException("Could not resolve given host string.",
        TTransportException.Type.NOT_OPEN, __FILE__, __LINE__, e);
    }

    // The new connection may reach a different peer than a previous one.
    peerAddress_ = null;

    Exception[] errors;
    foreach (addr; addrs) {
      try {
        connectTo(addr);
        return;
      } catch (SocketException e) {
        errors ~= e;
      }
    }

    // Every address failed (or resolution yielded none at all). Need to throw
    // a TTransportException to abide the TTransport API.
    import std.algorithm : map;
    import std.array : join;
    import std.range : zip;
    throw new TTransportException(
      text("Failed to connect to ", host_, ":", port_, "."),
      TTransportException.Type.NOT_OPEN,
      __FILE__, __LINE__,
      new TCompoundOperationException(
        text(
          "All addresses tried failed (",
          zip(addrs, errors).map!(p => text(p[0], `: "`, p[1].msg, `"`)).join(", "),
          ")."
        ),
        errors
      )
    );
  }

  /*
   * Creates a fresh TCP socket for addr's address family, applies the socket
   * options and connects it.
   *
   * If anything fails (a SocketException from connect, or e.g. a
   * TTransportException from setting the timeouts), the new socket is closed
   * and socket_ is reset to null. This ensures no descriptor is leaked and
   * the transport never reports itself open with an unconnected socket.
   */
  private void connectTo(Address addr) {
    socket_ = new TcpSocket(addr.addressFamily);
    scope (failure) {
      if (socket_ !is null) {
        socket_.close();
        socket_ = null;
      }
    }
    setSocketOpts();
    socket_.connect(addr);
  }

  /**
   * Closes the socket. Does nothing if it is not open.
   */
  override void close() {
    if (!isOpen) return;

    socket_.close();
    socket_ = null;

    // The cached address belongs to the connection that was just closed.
    peerAddress_ = null;
  }

  /**
   * Checks whether there is data to read, without consuming it.
   *
   * Returns: false if the socket is not open or the peeking recv() reported
   *   0 bytes, true if at least one byte is available.
   *
   * Throws: TTransportException (UNKNOWN) if peeking fails. Where
   *   connresetOnPeerShutdown is set, ECONNRESET instead closes the socket
   *   and returns false.
   */
  override bool peek() {
    if (!isOpen) return false;

    ubyte buf;
    auto r = socket_.receive((&buf)[0 .. 1], SocketFlags.PEEK);
    if (r == -1) {
      auto lastErrno = getSocketErrno();
      static if (connresetOnPeerShutdown) {
        if (lastErrno == ECONNRESET) {
          close();
          return false;
        }
      }
      throw new TTransportException("Peeking into socket failed: " ~
        socketErrnoString(lastErrno), TTransportException.Type.UNKNOWN);
    }
    return (r > 0);
  }

  /**
   * Reads up to buf.length bytes from the socket.
   *
   * A recv() call interrupted with INTERRUPTED_ERRNO is retried up to
   * maxRecvRetries times.
   *
   * Returns: The number of bytes received (as returned by recv()). Where
   *   connresetOnPeerShutdown is set, ECONNRESET also yields 0.
   *
   * Throws: TTransportException – NOT_OPEN if the socket is not open or the
   *   error indicates a closed connection (the socket is closed then),
   *   TIMED_OUT on TIMEOUT_ERRNO, UNKNOWN otherwise.
   */
  override size_t read(ubyte[] buf) {
    enforce(isOpen, new TTransportException(
      "Cannot read if socket is not open.", TTransportException.Type.NOT_OPEN));

    typeof(getSocketErrno()) lastErrno;
    // Wider than maxRecvRetries_ (ushort) so the counter cannot wrap around
    // and loop forever when maxRecvRetries_ == ushort.max.
    uint tries;
    while (tries++ <= maxRecvRetries_) {
      auto r = socket_.receive(cast(void[])buf);

      // If recv went fine, immediately return.
      if (r >= 0) return r;

      // Something went wrong, find out how to handle it.
      lastErrno = getSocketErrno();

      if (lastErrno == INTERRUPTED_ERRNO) {
        // If the syscall was interrupted, just try again.
        continue;
      }

      static if (connresetOnPeerShutdown) {
        // Some platforms report an orderly peer shutdown as ECONNRESET;
        // treat it like end of stream.
        if (lastErrno == ECONNRESET) {
          return 0;
        }
      }

      // Not an error which is handled in a special way, just leave the loop.
      break;
    }

    if (isSocketCloseErrno(lastErrno)) {
      close();
      throw new TTransportException("Receiving failed, closing socket: " ~
        socketErrnoString(lastErrno), TTransportException.Type.NOT_OPEN);
    } else if (lastErrno == TIMEOUT_ERRNO) {
      throw new TTransportException(TTransportException.Type.TIMED_OUT);
    } else {
      throw new TTransportException("Receiving from socket failed: " ~
        socketErrnoString(lastErrno), TTransportException.Type.UNKNOWN);
    }
  }

  /**
   * Writes the whole buffer, calling writeSome() until everything is sent.
   *
   * Throws: TTransportException (TIMED_OUT) if writeSome() makes no progress,
   *   plus anything writeSome() throws.
   */
  override void write(in ubyte[] buf) {
    size_t sent;
    while (sent < buf.length) {
      auto b = writeSome(buf[sent .. $]);
      if (b == 0) {
        // This should only happen if the timeout set with SO_SNDTIMEO expired.
        throw new TTransportException("send() timeout expired.",
          TTransportException.Type.TIMED_OUT);
      }
      sent += b;
    }
    assert(sent == buf.length);
  }

  /**
   * Performs a single send() call.
   *
   * Returns: The number of bytes sent; 0 for an empty buffer or if send()
   *   failed with WOULD_BLOCK_ERRNO.
   *
   * Throws: TTransportException – NOT_OPEN if the socket is not open or the
   *   error indicates a closed connection (the socket is closed then),
   *   UNKNOWN otherwise.
   */
  override size_t writeSome(in ubyte[] buf) {
    enforce(isOpen, new TTransportException(
      "Cannot write if socket is not open.", TTransportException.Type.NOT_OPEN));

    // send() legitimately returns 0 for an empty buffer, which would
    // otherwise be reported as an error below.
    if (buf.length == 0) return 0;

    auto r = socket_.send(buf);

    // Everything went well, just return the number of bytes written.
    if (r > 0) return r;

    // Handle error conditions.
    if (r < 0) {
      auto lastErrno = getSocketErrno();

      if (lastErrno == WOULD_BLOCK_ERRNO) {
        // Not an exceptional error per se – even with blocking sockets,
        // EAGAIN apparently is returned sometimes on out-of-resource
        // conditions (see the C++ implementation for details). Also, this
        // allows using TSocket with non-blocking sockets e.g. in
        // TNonblockingServer.
        return 0;
      }

      auto type = TTransportException.Type.UNKNOWN;
      if (isSocketCloseErrno(lastErrno)) {
        type = TTransportException.Type.NOT_OPEN;
        close();
      }

      throw new TTransportException("Sending to socket failed: " ~
        socketErrnoString(lastErrno), type);
    }

    // send() should never return 0 for a non-empty buffer.
    throw new TTransportException("Sending to socket failed (0 bytes written).",
      TTransportException.Type.UNKNOWN);
  }

  // Overriding only the setters would otherwise hide the inherited getters
  // when the properties are accessed through a TSocket reference.
  alias TSocketBase.sendTimeout sendTimeout;
  alias TSocketBase.recvTimeout recvTimeout;

  /// Stores the send timeout and applies it as SO_SNDTIMEO if connected.
  override void sendTimeout(Duration value) @property {
    super.sendTimeout(value);
    setTimeout(SocketOption.SNDTIMEO, value);
  }

  /// Stores the receive timeout and applies it as SO_RCVTIMEO if connected.
  override void recvTimeout(Duration value) @property {
    super.recvTimeout(value);
    setTimeout(SocketOption.RCVTIMEO, value);
  }

  /**
   * Maximum number of times read() retries a recv() call that was
   * interrupted (INTERRUPTED_ERRNO). Other errors are not retried.
   */
  ushort maxRecvRetries() @property const {
    return maxRecvRetries_;
  }

  /// Ditto
  void maxRecvRetries(ushort value) @property {
    maxRecvRetries_ = value;
  }

  /// Ditto
  enum DEFAULT_MAX_RECV_RETRIES = 5;

protected:
  /// Applies the base options plus the configured send/receive timeouts.
  override void setSocketOpts() {
    super.setSocketOpts();
    setTimeout(SocketOption.SNDTIMEO, sendTimeout_);
    setTimeout(SocketOption.RCVTIMEO, recvTimeout_);
  }

  /**
   * Applies a SNDTIMEO/RCVTIMEO timeout to the socket, if one exists.
   *
   * Throws: TTransportException (UNKNOWN) if setting the option fails.
   */
  void setTimeout(SocketOption type, Duration value) {
    assert(type == SocketOption.SNDTIMEO || type == SocketOption.RCVTIMEO);
    version (Windows) {
      if (value > dur!"hnsecs"(0) && value < dur!"msecs"(500)) {
        logError(
          "Socket %s timeout of %s ms might be raised to 500 ms on Windows.",
          (type == SocketOption.SNDTIMEO) ? "send" : "receive",
          value.total!"msecs"
        );
      }
    }

    if (socket_) {
      try {
        socket_.setOption(SocketOptionLevel.SOCKET, type, value);
      } catch (SocketException e) {
        throw new TTransportException(
          "Could not set timeout.",
          TTransportException.Type.UNKNOWN,
          __FILE__,
          __LINE__,
          e
        );
      }
    }
  }

  /// Maximum number of recv() retries.
  ushort maxRecvRetries_  = DEFAULT_MAX_RECV_RETRIES;
}