1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /**
21  * OpenSSL socket implementation, in large parts ported from C++.
22  */
23 module thrift.transport.ssl;
24 
25 import core.exception : onOutOfMemoryError;
26 import core.stdc.errno : errno, EINTR;
27 import core.sync.mutex : Mutex;
28 import core.memory : GC;
29 import core.stdc.config;
30 import core.stdc.stdlib : free, malloc;
31 import std.ascii : toUpper;
32 import std.array : empty, front, popFront;
33 import std.conv : emplace, to;
34 import std.exception : enforce;
35 import std.socket : Address, InternetAddress, Internet6Address, Socket;
36 import std..string : toStringz;
37 import deimos.openssl.err;
38 import deimos.openssl.rand;
39 import deimos.openssl.ssl;
40 import deimos.openssl.x509v3;
41 import thrift.base;
42 import thrift.internal.ssl;
43 import thrift.transport.base;
44 import thrift.transport.socket;
45 
46 /**
47  * SSL encrypted socket implementation using OpenSSL.
48  *
49  * Note:
50  * On Posix systems which do not have the BSD-specific SO_NOSIGPIPE flag, you
51  * might want to ignore the SIGPIPE signal, as OpenSSL might try to write to
52  * a closed socket if the peer disconnects abruptly:
53  * ---
54  * import core.stdc.signal;
55  * import core.sys.posix.signal;
56  * signal(SIGPIPE, SIG_IGN);
57  * ---
58  */
final class TSSLSocket : TSocket {
  /**
   * Creates an instance that wraps an already created, connected (!) socket.
   *
   * Params:
   *   context = The SSL socket context to use. A reference to it is stored so
   *     that it doesn't get cleaned up while the socket is used.
   *   socket = Already created, connected socket object.
   */
  this(TSSLContext context, Socket socket) {
    super(socket);
    context_ = context;
    serverSide_ = context.serverSide;
    accessManager_ = context.accessManager;
  }

  /**
   * Creates a new unconnected socket that will connect to the given host
   * on the given port.
   *
   * Params:
   *   context = The SSL socket context to use. A reference to it is stored so
   *     that it doesn't get cleaned up while the socket is used.
   *   host = Remote host.
   *   port = Remote port.
   */
  this(TSSLContext context, string host, ushort port) {
    super(host, port);
    context_ = context;
    serverSide_ = context.serverSide;
    accessManager_ = context.accessManager;
  }

  /**
   * Whether an SSL session exists on an open socket and has not been shut
   * down in both directions.
   *
   * This is false until the SSL handshake has been performed, which happens
   * lazily on the first read(), write() or flush().
   */
  override bool isOpen() @property {
    if (ssl_ is null || !super.isOpen()) return false;

    auto shutdown = SSL_get_shutdown(ssl_);
    bool shutdownReceived = (shutdown & SSL_RECEIVED_SHUTDOWN) != 0;
    bool shutdownSent = (shutdown & SSL_SENT_SHUTDOWN) != 0;
    return !(shutdownReceived && shutdownSent);
  }

  /**
   * Peeks at one byte of decrypted data.
   *
   * Returns: false if the session is not open or SSL_peek reported that no
   *   data is available (return value 0), true if a byte was peeked.
   * Throws: TSSLException (via getSSLException) if SSL_peek fails.
   */
  override bool peek() {
    if (!isOpen) return false;
    checkHandshake();

    byte bt;
    auto rc = SSL_peek(ssl_, &bt, bt.sizeof);
    enforce(rc >= 0, getSSLException("SSL_peek"));

    if (rc == 0) {
      ERR_clear_error();
    }
    return (rc > 0);
  }

  /**
   * Connects the underlying socket. Only valid for client-side sockets; the
   * SSL handshake itself is deferred until the first I/O call.
   */
  override void open() {
    enforce(!serverSide_, "Cannot open a server-side SSL socket.");
    if (isOpen) return;
    super.open();
  }

  /**
   * Shuts down the SSL session (if one was established) and closes the
   * underlying socket.
   *
   * This deliberately does not check isOpen first. isOpen is false before the
   * handshake has happened and after a complete bidirectional shutdown. In
   * both cases the OS socket (and possibly the SSL object) still has to be
   * released.
   */
  override void close() {
    if (ssl_ !is null) {
      // Two-step SSL shutdown.
      auto rc = SSL_shutdown(ssl_);
      if (rc == 0) {
        rc = SSL_shutdown(ssl_);
      }
      if (rc < 0) {
        // Do not throw an exception here as leaving the transport "open" will
        // probably produce only more errors, and the chance we can do
        // something about the error e.g. by retrying is very low.
        logError("Error shutting down SSL: %s", getSSLException());
      }

      SSL_free(ssl_);
      ssl_ = null;
      ERR_remove_state(0);
    }
    if (super.isOpen()) super.close();
  }

  /**
   * Reads decrypted data into buf, performing the handshake first if needed.
   *
   * Returns: The number of bytes read (SSL_read may return fewer than
   *   buf.length).
   * Throws: TSSLException on SSL errors, or if every retry was interrupted
   *   by EINTR.
   */
  override size_t read(ubyte[] buf) {
    checkHandshake();

    // SSL_read takes an int length; clamp instead of letting the cast wrap.
    // Short reads are already possible, so callers must cope with them.
    immutable len = buf.length > int.max ? int.max : cast(int)buf.length;

    int bytes;
    foreach (_; 0 .. maxRecvRetries) {
      bytes = SSL_read(ssl_, buf.ptr, len);
      if (bytes >= 0) break;

      auto errnoCopy = errno;
      if (SSL_get_error(ssl_, bytes) == SSL_ERROR_SYSCALL) {
        if (ERR_get_error() == 0 && errnoCopy == EINTR) {
          // FIXME: Windows.
          continue;
        }
      }
      throw getSSLException("SSL_read");
    }
    // If all retries were consumed by EINTR, bytes is still negative and must
    // not be returned (it would turn into a huge size_t).
    enforce(bytes >= 0,
      new TSSLException("SSL_read: maximum number of retries exceeded"));
    return bytes;
  }

  /**
   * Encrypts and writes all of buf, performing the handshake first if needed.
   *
   * Throws: TSSLException if SSL_write reports an error.
   */
  override void write(in ubyte[] buf) {
    checkHandshake();

    // Loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX, and to
    // split buffers larger than SSL_write's int length parameter.
    size_t written = 0;
    while (written < buf.length) {
      immutable remaining = buf.length - written;
      immutable chunk = remaining > int.max ? int.max : cast(int)remaining;
      auto bytes = SSL_write(ssl_, buf.ptr + written, chunk);
      if (bytes <= 0) {
        throw getSSLException("SSL_write");
      }
      written += bytes;
    }
  }

  /**
   * Flushes the SSL write BIO.
   *
   * Throws: TSSLException if there is no write BIO or BIO_flush fails.
   */
  override void flush() {
    checkHandshake();

    auto bio = SSL_get_wbio(ssl_);
    enforce(bio !is null, new TSSLException("SSL_get_wbio returned null"));

    auto rc = BIO_flush(bio);
    enforce(rc == 1, getSSLException("BIO_flush"));
  }

  /**
   * Whether to use client or server side SSL handshake protocol.
   */
  bool serverSide() @property const {
    return serverSide_;
  }

  /// Ditto
  void serverSide(bool value) @property {
    serverSide_ = value;
  }

  /**
   * The access manager to use.
   */
  void accessManager(TAccessManager value) @property {
    accessManager_ = value;
  }

private:
  /**
   * Ensures the underlying socket is open and, on first use, creates the SSL
   * object, performs the client or server handshake and authorizes the peer.
   *
   * Throws: TTransportException (NOT_OPEN) if the socket is closed;
   *   TSSLException (via getSSLException) if setup or the handshake fails;
   *   whatever authorize throws if the peer is rejected.
   */
  void checkHandshake() {
    enforce(super.isOpen(), new TTransportException(
      TTransportException.Type.NOT_OPEN));

    if (ssl_ !is null) return;
    ssl_ = context_.createSSL();

    // On any failure below, discard the SSL object. Otherwise the early
    // return above would make later calls use a session whose handshake or
    // authorization never succeeded.
    scope (failure) {
      SSL_free(ssl_);
      ssl_ = null;
    }

    enforce(SSL_set_fd(ssl_, cast(int)socketHandle) == 1,
      getSSLException("SSL_set_fd"));

    int rc;
    if (serverSide_) {
      rc = SSL_accept(ssl_);
    } else {
      rc = SSL_connect(ssl_);
    }
    enforce(rc > 0, getSSLException());

    auto peerAddress = getPeerAddress();
    authorize(ssl_, accessManager_, peerAddress,
      (serverSide_ ? peerAddress.toHostNameString() : host));
  }

  bool serverSide_;
  SSL* ssl_;
  TSSLContext context_;
  TAccessManager accessManager_;
}
233 
234 /**
235  * Represents an OpenSSL context with certification settings, etc. and handles
236  * initialization/teardown.
237  *
238  * OpenSSL is initialized when the first instance of this class is created
239  * and shut down when the last one is destroyed (thread-safe).
240  */
class TSSLContext {
  /**
   * Creates a new SSL context, initializing OpenSSL if this is the first
   * live context.
   *
   * The context has SSLv2 and SSLv3 disabled and SSL_MODE_AUTO_RETRY set.
   *
   * Throws: TSSLException (via getSSLException) if SSL_CTX_new fails.
   */
  this() {
    initMutex_.lock();
    scope(exit) initMutex_.unlock();

    if (count_ == 0) {
      initializeOpenSSL();
      randomize();
    }
    count_++;

    ctx_ = SSL_CTX_new(SSLv23_method());
    // Check for failure before ctx_ is handed to any other OpenSSL call.
    enforce(ctx_, getSSLException("SSL_CTX_new"));
    SSL_CTX_set_options(ctx_, SSL_OP_NO_SSLv2);
    SSL_CTX_set_options(ctx_, SSL_OP_NO_SSLv3);   // THRIFT-3164
    SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
  }

  /**
   * Frees the SSL context and cleans up OpenSSL once the last context is
   * gone.
   */
  ~this() {
    initMutex_.lock();
    scope(exit) initMutex_.unlock();

    if (ctx_ !is null) {
      SSL_CTX_free(ctx_);
      ctx_ = null;
    }

    count_--;
    if (count_ == 0) {
      cleanupOpenSSL();
    }
  }

  /**
   * Ciphers to be used in SSL handshake process.
   *
   * The string must be in the colon-delimited OpenSSL notation described in
   * ciphers(1), for example: "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH".
   */
  void ciphers(string enable) @property {
    auto rc = SSL_CTX_set_cipher_list(ctx_, toStringz(enable));

    enforce(ERR_peek_error() == 0, getSSLException("SSL_CTX_set_cipher_list"));
    enforce(rc > 0, new TSSLException("None of specified ciphers are supported"));
  }

  /**
   * Whether peer is required to present a valid certificate.
   */
  void authenticate(bool required) @property {
    int mode;
    if (required) {
      mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT |
        SSL_VERIFY_CLIENT_ONCE;
    } else {
      mode = SSL_VERIFY_NONE;
    }
    SSL_CTX_set_verify(ctx_, mode, null);
  }

  /**
   * Load server certificate.
   *
   * Params:
   *   path = Path to the certificate (chain) file.
   *   format = Certificate file format. Defaults to PEM, which is currently
   *     the only one supported.
   *
   * Throws: TTransportException (BAD_ARGS) if path or format is null,
   *   TSSLException if the format is unsupported or loading fails.
   */
  void loadCertificate(string path, string format = "PEM") {
    enforce(path !is null && format !is null, new TTransportException(
      "loadCertificate: either <path> or <format> is null",
      TTransportException.Type.BAD_ARGS));

    if (format == "PEM") {
      enforce(SSL_CTX_use_certificate_chain_file(ctx_, toStringz(path)),
        getSSLException(
          `Could not load SSL server certificate from file "` ~ path ~ `"`
        )
      );
    } else {
      throw new TSSLException("Unsupported certificate format: " ~ format);
    }
  }

  /**
   * Load private key.
   *
   * Params:
   *   path = Path to the private key file.
   *   format = Private key file format. Defaults to PEM, which is currently
   *     the only one supported.
   *
   * Throws: TTransportException (BAD_ARGS) if path or format is null,
   *   TSSLException if the format is unsupported or loading fails.
   */
  void loadPrivateKey(string path, string format = "PEM") {
    enforce(path !is null && format !is null, new TTransportException(
      "loadPrivateKey: either <path> or <format> is NULL",
      TTransportException.Type.BAD_ARGS));

    if (format == "PEM") {
      enforce(SSL_CTX_use_PrivateKey_file(ctx_, toStringz(path), SSL_FILETYPE_PEM),
        getSSLException(
          `Could not load SSL private key from file "` ~ path ~ `"`
        )
      );
    } else {
      throw new TSSLException("Unsupported private key format: " ~ format);
    }
  }

  /**
   * Load trusted certificates from specified file (in PEM format).
   *
   * Params:
   *   path = Path to the file containing the trusted certificates.
   *
   * Throws: TTransportException (BAD_ARGS) if path is null, TSSLException if
   *   loading fails.
   */
  void loadTrustedCertificates(string path) {
    enforce(path !is null, new TTransportException(
      "loadTrustedCertificates: <path> is NULL",
      TTransportException.Type.BAD_ARGS));

    enforce(SSL_CTX_load_verify_locations(ctx_, toStringz(path), null),
      getSSLException(
        `Could not load SSL trusted certificate list from file "` ~ path ~ `"`
      )
    );
  }

  /**
   * Called during OpenSSL initialization to seed the OpenSSL entropy pool.
   *
   * Defaults to simply calling RAND_poll(), but it can be overridden if a
   * different, perhaps more secure implementation is desired.
   */
  void randomize() {
    RAND_poll();
  }

  /**
   * Whether to use client or server side SSL handshake protocol.
   */
  bool serverSide() @property const {
    return serverSide_;
  }

  /// Ditto
  void serverSide(bool value) @property {
    serverSide_ = value;
  }

  /**
   * The access manager to use.
   *
   * For client-side contexts without an explicitly set manager, a
   * TDefaultClientAccessManager is created on first access.
   */
  TAccessManager accessManager() @property {
    if (!serverSide_ && !accessManager_) {
      accessManager_ = new TDefaultClientAccessManager;
    }
    return accessManager_;
  }

  /// Ditto
  void accessManager(TAccessManager value) @property {
    accessManager_ = value;
  }

  /**
   * Creates a new SSL object from this context.
   *
   * Throws: TSSLException (via getSSLException) if SSL_new fails.
   */
  SSL* createSSL() out (result) {
    assert(result);
  } body {
    auto result = SSL_new(ctx_);
    enforce(result, getSSLException("SSL_new"));
    return result;
  }

protected:
  /**
   * Override this method for custom password callback. It may be called
   * multiple times at any time during a session as necessary.
   *
   * Params:
   *   size = Maximum length of password, including null byte.
   */
  string getPassword(int size) nothrow out(result) {
    assert(result.length < size);
  } body {
    return "";
  }

  /**
   * Makes OpenSSL use getPassword() instead of its default password
   * callback.
   */
  void overrideDefaultPasswordCallback() {
    SSL_CTX_set_default_passwd_cb(ctx_, &passwordCallback);
    SSL_CTX_set_default_passwd_cb_userdata(ctx_, cast(void*)this);
  }

  SSL_CTX* ctx_;

private:
  bool serverSide_;
  TAccessManager accessManager_;

  shared static this() {
    initMutex_ = new Mutex();
  }

  /// Initializes the library and installs the thread-safety callbacks.
  static void initializeOpenSSL() {
    if (initialized_) {
      return;
    }
    initialized_ = true;

    SSL_library_init();
    SSL_load_error_strings();

    mutexes_ = new Mutex[CRYPTO_num_locks()];
    foreach (ref m; mutexes_) {
      m = new Mutex;
    }

    import thrift.internal.traits;
    // As per the OpenSSL threads manpage, this isn't needed on Windows.
    version (Posix) {
      CRYPTO_set_id_callback(assumeNothrow(&threadIdCallback));
    }
    CRYPTO_set_locking_callback(assumeNothrow(&lockingCallback));
    CRYPTO_set_dynlock_create_callback(assumeNothrow(&dynlockCreateCallback));
    CRYPTO_set_dynlock_lock_callback(assumeNothrow(&dynlockLockCallback));
    CRYPTO_set_dynlock_destroy_callback(assumeNothrow(&dynlockDestroyCallback));
  }

  /// Unregisters the callbacks installed above and frees library state.
  static void cleanupOpenSSL() {
    if (!initialized_) return;

    initialized_ = false;
    // Mirror initializeOpenSSL(): the id callback is only installed on Posix.
    version (Posix) {
      CRYPTO_set_id_callback(null);
    }
    CRYPTO_set_locking_callback(null);
    CRYPTO_set_dynlock_create_callback(null);
    CRYPTO_set_dynlock_lock_callback(null);
    CRYPTO_set_dynlock_destroy_callback(null);
    CRYPTO_cleanup_all_ex_data();
    ERR_free_strings();
    ERR_remove_state(0);
  }

  static extern(C) {
    version (Posix) {
      import core.sys.posix.pthread : pthread_self;
      c_ulong threadIdCallback() {
        return cast(c_ulong)pthread_self();
      }
    }

    void lockingCallback(int mode, int n, const(char)* file, int line) {
      if (mode & CRYPTO_LOCK) {
        mutexes_[n].lock();
      } else {
        mutexes_[n].unlock();
      }
    }

    CRYPTO_dynlock_value* dynlockCreateCallback(const(char)* file, int line) {
      enum size =  __traits(classInstanceSize, Mutex);
      // Check the raw pointer before building a slice from it.
      auto ptr = malloc(size);
      if (ptr is null) onOutOfMemoryError();
      auto mem = ptr[0 .. size];
      // The Mutex lives in malloc'ed memory, so let the GC scan it.
      GC.addRange(mem.ptr, size);
      auto mutex = emplace!Mutex(mem);
      return cast(CRYPTO_dynlock_value*)mutex;
    }

    void dynlockLockCallback(int mode, CRYPTO_dynlock_value* l,
      const(char)* file, int line)
    {
      if (l is null) return;
      if (mode & CRYPTO_LOCK) {
        (cast(Mutex)l).lock();
      } else {
        (cast(Mutex)l).unlock();
      }
    }

    void dynlockDestroyCallback(CRYPTO_dynlock_value* l,
      const(char)* file, int line)
    {
      GC.removeRange(l);
      destroy(cast(Mutex)l);
      free(l);
    }

    int passwordCallback(char* password, int size, int, void* data) nothrow {
      auto context = cast(TSSLContext) data;
      auto userPassword = context.getPassword(size);
      auto len = userPassword.length;
      if (len > size) {
        len = size;
      }
      password[0 .. len] = userPassword[0 .. len]; // TODO: \0 handling correct?
      return cast(int)len;
    }
  }

  static __gshared bool initialized_;
  static __gshared Mutex initMutex_;
  static __gshared Mutex[] mutexes_;
  static __gshared uint count_;
}
544 
545 /**
546  * Decides whether a remote host is legitimate or not.
547  *
 * It is usually set on a TSSLContext; every TSSLSocket created with that
 * context picks it up.
550  */
class TAccessManager {
  ///
  enum Decision {
    DENY = -1, /// Deny access.
    SKIP =  0, /// Cannot decide, move on to next check (deny if last).
    ALLOW = 1  /// Allow access.
  }

  /**
   * Determines whether a peer should be granted access or not based on its
   * IP address.
   *
   * Called once after the SSL handshake completes successfully and before
   * the peer certificate is examined.
   *
   * If a valid decision (ALLOW or DENY) is returned, the peer certificate
   * will not be verified.
   *
   * The base implementation always returns DENY.
   */
  Decision verify(Address address) {
    return Decision.DENY;
  }

  /**
   * Determines whether a peer should be granted access or not based on a
   * name from its certificate.
   *
   * Called every time a DNS subjectAltName/common name is extracted from the
   * peer's certificate.
   *
   * Params:
   *   host = The actual host name string from the socket connection.
   *   certHost = A host name string from the certificate.
   *
   * The base implementation always returns DENY.
   */
  Decision verify(string host, const(char)[] certHost) {
    return Decision.DENY;
  }

  /**
   * Determines whether a peer should be granted access or not based on an IP
   * address from its certificate.
   *
   * Called every time an IP subjectAltName is extracted from the peer's
   * certificate.
   *
   * Params:
   *   address = The actual address from the socket connection.
   *   certAddress = The raw IP address bytes from the certificate
   *     (presumably 4 bytes for IPv4, 16 for IPv6).
   *
   * The base implementation always returns DENY.
   */
  Decision verify(Address address, ubyte[] certAddress) {
    return Decision.DENY;
  }
}
603 
604 /**
 * Default access manager implementation, which checks the host name or IP
 * address of the connection against the names/addresses in the peer
 * certificate.
607  */
class TDefaultClientAccessManager : TAccessManager {
  /// The connection address alone is never decisive; always SKIP.
  override Decision verify(Address address) {
    return Decision.SKIP;
  }

  /**
   * Allows access if host matches the certificate name certHost (which may
   * contain wildcards), otherwise SKIPs.
   */
  override Decision verify(string host, const(char)[] certHost) {
    if (host.empty || certHost.empty) {
      return Decision.SKIP;
    }
    return (matchName(host, certHost) ? Decision.ALLOW : Decision.SKIP);
  }

  /**
   * Allows access if the connection address equals the IP address from the
   * certificate (4 bytes for IPv4, 16 bytes for IPv6), otherwise SKIPs.
   */
  override Decision verify(Address address, ubyte[] certAddress) {
    bool match;
    if (certAddress.length == 4) {
      if (auto ia = cast(InternetAddress)address) {
        // InternetAddress.addr is the IPv4 address as a number in host byte
        // order, while the certificate stores the bytes in network order.
        // Build the big-endian byte sequence explicitly rather than
        // reinterpreting the number's memory.
        immutable a = ia.addr;
        immutable ubyte[4] networkOrder = [
          cast(ubyte)(a >> 24), cast(ubyte)(a >> 16),
          cast(ubyte)(a >> 8), cast(ubyte)a
        ];
        match = (networkOrder[] == certAddress[]);
      }
    } else if (certAddress.length == 16) {
      if (auto ia = cast(Internet6Address)address) {
        match = (ia.addr() == certAddress[]);
      }
    }
    return (match ? Decision.ALLOW : Decision.SKIP);
  }
}
634 
private {
  /**
   * Compares a host name against a certificate name pattern, ignoring ASCII
   * case.
   *
   * Each "*" in the pattern swallows host characters up to (not including)
   * the next '.', so a wildcard never spans more than one domain label. Every
   * other pattern character must equal the corresponding host character.
   *
   * Params:
   *   host = Host name to match, typically the SSL remote peer.
   *   pattern = Host name pattern, typically from the SSL certificate.
   *
   * Returns: true if host and pattern are both fully consumed by the match,
   *   false otherwise.
   */
  bool matchName(const(char)[] host, const(char)[] pattern) {
    for (;;) {
      if (host.empty || pattern.empty) break;

      immutable expected = pattern.front;
      if (toUpper(expected) == toUpper(host.front)) {
        pattern.popFront;
        host.popFront;
        continue;
      }
      if (expected != '*') break;

      // Wildcard: skip the remainder of the current host label.
      while (!host.empty && host.front != '.') {
        host.popFront;
      }
      pattern.popFront;
    }
    return (pattern.empty && host.empty);
  }

  unittest {
    enforce(matchName("thrift.apache.org", "*.apache.org"));
    enforce(!matchName("thrift.apache.org", "apache.org"));
    enforce(matchName("thrift.apache.org", "thrift.*.*"));
    enforce(matchName("", ""));
    enforce(!matchName("", "*"));
  }
}
671 
672 /**
673  * SSL-level exception.
674  */
class TSSLException : TTransportException {
  /**
   * Creates an SSL exception with the given message.
   *
   * It is a TTransportException of type INTERNAL_ERROR, so handlers for
   * generic transport failures also catch it. file and line default to the
   * caller's location.
   */
  this(string msg, string file = __FILE__, size_t line = __LINE__,
       Throwable next = null) {
    enum kind = TTransportException.Type.INTERNAL_ERROR;
    super(msg, kind, file, line, next);
  }
}