// Copyright 2015 The Chromium OS Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include <brillo/streams/tls_stream.h> #include <algorithm> #include <limits> #include <string> #include <vector> #include <openssl/err.h> #include <openssl/ssl.h> #include <base/bind.h> #include <base/memory/weak_ptr.h> #include <brillo/message_loops/message_loop.h> #include <brillo/secure_blob.h> #include <brillo/streams/openssl_stream_bio.h> #include <brillo/streams/stream_utils.h> #include <brillo/strings/string_utils.h> namespace { // SSL info callback which is called by OpenSSL when we enable logging level of // at least 3. This logs the information about the internal TLS handshake. void TlsInfoCallback(const SSL* /* ssl */, int where, int ret) { std::string reason; std::vector<std::string> info; if (where & SSL_CB_LOOP) info.push_back("loop"); if (where & SSL_CB_EXIT) info.push_back("exit"); if (where & SSL_CB_READ) info.push_back("read"); if (where & SSL_CB_WRITE) info.push_back("write"); if (where & SSL_CB_ALERT) { info.push_back("alert"); reason = ", reason: "; reason += SSL_alert_type_string_long(ret); reason += "/"; reason += SSL_alert_desc_string_long(ret); } if (where & SSL_CB_HANDSHAKE_START) info.push_back("handshake_start"); if (where & SSL_CB_HANDSHAKE_DONE) info.push_back("handshake_done"); VLOG(3) << "TLS progress info: " << brillo::string_utils::Join(",", info) << ", with status: " << ret << reason; } // Static variable to store the index of TlsStream private data in SSL context // used to store custom data for OnCertVerifyResults(). int ssl_ctx_private_data_index = -1; // Default trusted certificate store location. const char kCACertificatePath[] = #ifdef __ANDROID__ "/system/etc/security/cacerts_google"; #else "/usr/share/chromeos-ca-certificates"; #endif } // anonymous namespace namespace brillo { // Helper implementation of TLS stream used to hide most of OpenSSL inner // workings from the users of brillo::TlsStream. class TlsStream::TlsStreamImpl { public: TlsStreamImpl(); ~TlsStreamImpl(); bool Init(StreamPtr socket, const std::string& host, const base::Closure& success_callback, const Stream::ErrorCallback& error_callback, ErrorPtr* error); bool ReadNonBlocking(void* buffer, size_t size_to_read, size_t* size_read, bool* end_of_stream, ErrorPtr* error); bool WriteNonBlocking(const void* buffer, size_t size_to_write, size_t* size_written, ErrorPtr* error); bool Flush(ErrorPtr* error); bool Close(ErrorPtr* error); bool WaitForData(AccessMode mode, const base::Callback<void(AccessMode)>& callback, ErrorPtr* error); bool WaitForDataBlocking(AccessMode in_mode, base::TimeDelta timeout, AccessMode* out_mode, ErrorPtr* error); void CancelPendingAsyncOperations(); private: bool ReportError(ErrorPtr* error, const base::Location& location, const std::string& message); void DoHandshake(const base::Closure& success_callback, const Stream::ErrorCallback& error_callback); void RetryHandshake(const base::Closure& success_callback, const Stream::ErrorCallback& error_callback, Stream::AccessMode mode); int OnCertVerifyResults(int ok, X509_STORE_CTX* ctx); static int OnCertVerifyResultsStatic(int ok, X509_STORE_CTX* ctx); StreamPtr socket_; std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> ctx_{nullptr, SSL_CTX_free}; std::unique_ptr<SSL, decltype(&SSL_free)> ssl_{nullptr, SSL_free}; BIO* stream_bio_{nullptr}; bool need_more_read_{false}; bool need_more_write_{false}; base::WeakPtrFactory<TlsStreamImpl> weak_ptr_factory_{this}; DISALLOW_COPY_AND_ASSIGN(TlsStreamImpl); }; TlsStream::TlsStreamImpl::TlsStreamImpl() { SSL_load_error_strings(); SSL_library_init(); if (ssl_ctx_private_data_index < 0) { ssl_ctx_private_data_index = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); } } TlsStream::TlsStreamImpl::~TlsStreamImpl() { ssl_.reset(); ctx_.reset(); } bool TlsStream::TlsStreamImpl::ReadNonBlocking(void* buffer, size_t size_to_read, size_t* size_read, bool* end_of_stream, ErrorPtr* error) { const size_t max_int = std::numeric_limits<int>::max(); int size_int = static_cast<int>(std::min(size_to_read, max_int)); int ret = SSL_read(ssl_.get(), buffer, size_int); if (ret > 0) { *size_read = static_cast<size_t>(ret); if (end_of_stream) *end_of_stream = false; return true; } int err = SSL_get_error(ssl_.get(), ret); if (err == SSL_ERROR_ZERO_RETURN) { *size_read = 0; if (end_of_stream) *end_of_stream = true; return true; } if (err == SSL_ERROR_WANT_READ) { need_more_read_ = true; } else if (err == SSL_ERROR_WANT_WRITE) { // Writes might be required for SSL_read() because of possible TLS // re-negotiations which can happen at any time. need_more_write_ = true; } else { return ReportError(error, FROM_HERE, "Error reading from TLS socket"); } *size_read = 0; if (end_of_stream) *end_of_stream = false; return true; } bool TlsStream::TlsStreamImpl::WriteNonBlocking(const void* buffer, size_t size_to_write, size_t* size_written, ErrorPtr* error) { const size_t max_int = std::numeric_limits<int>::max(); int size_int = static_cast<int>(std::min(size_to_write, max_int)); int ret = SSL_write(ssl_.get(), buffer, size_int); if (ret > 0) { *size_written = static_cast<size_t>(ret); return true; } int err = SSL_get_error(ssl_.get(), ret); if (err == SSL_ERROR_WANT_READ) { // Reads might be required for SSL_write() because of possible TLS // re-negotiations which can happen at any time. need_more_read_ = true; } else if (err == SSL_ERROR_WANT_WRITE) { need_more_write_ = true; } else { return ReportError(error, FROM_HERE, "Error writing to TLS socket"); } *size_written = 0; return true; } bool TlsStream::TlsStreamImpl::Flush(ErrorPtr* error) { return socket_->FlushBlocking(error); } bool TlsStream::TlsStreamImpl::Close(ErrorPtr* error) { // 2 seconds should be plenty here. const base::TimeDelta kTimeout = base::TimeDelta::FromSeconds(2); // The retry count of 4 below is just arbitrary, to ensure we don't get stuck // here forever. We should rarely need to repeat SSL_shutdown anyway. for (int retry_count = 0; retry_count < 4; retry_count++) { int ret = SSL_shutdown(ssl_.get()); // We really don't care for bi-directional shutdown here. // Just make sure we only send the "close notify" alert to the remote peer. if (ret >= 0) break; int err = SSL_get_error(ssl_.get(), ret); if (err == SSL_ERROR_WANT_READ) { if (!socket_->WaitForDataBlocking(AccessMode::READ, kTimeout, nullptr, error)) { break; } } else if (err == SSL_ERROR_WANT_WRITE) { if (!socket_->WaitForDataBlocking(AccessMode::WRITE, kTimeout, nullptr, error)) { break; } } else { LOG(ERROR) << "SSL_shutdown returned error #" << err; ReportError(error, FROM_HERE, "Failed to shut down TLS socket"); break; } } return socket_->CloseBlocking(error); } bool TlsStream::TlsStreamImpl::WaitForData( AccessMode mode, const base::Callback<void(AccessMode)>& callback, ErrorPtr* error) { bool is_read = stream_utils::IsReadAccessMode(mode); bool is_write = stream_utils::IsWriteAccessMode(mode); is_read |= need_more_read_; is_write |= need_more_write_; need_more_read_ = false; need_more_write_ = false; if (is_read && SSL_pending(ssl_.get()) > 0) { callback.Run(AccessMode::READ); return true; } mode = stream_utils::MakeAccessMode(is_read, is_write); return socket_->WaitForData(mode, callback, error); } bool TlsStream::TlsStreamImpl::WaitForDataBlocking(AccessMode in_mode, base::TimeDelta timeout, AccessMode* out_mode, ErrorPtr* error) { bool is_read = stream_utils::IsReadAccessMode(in_mode); bool is_write = stream_utils::IsWriteAccessMode(in_mode); is_read |= need_more_read_; is_write |= need_more_write_; need_more_read_ = need_more_write_ = false; if (is_read && SSL_pending(ssl_.get()) > 0) { if (out_mode) *out_mode = AccessMode::READ; return true; } in_mode = stream_utils::MakeAccessMode(is_read, is_write); return socket_->WaitForDataBlocking(in_mode, timeout, out_mode, error); } void TlsStream::TlsStreamImpl::CancelPendingAsyncOperations() { socket_->CancelPendingAsyncOperations(); weak_ptr_factory_.InvalidateWeakPtrs(); } bool TlsStream::TlsStreamImpl::ReportError( ErrorPtr* error, const base::Location& location, const std::string& message) { const char* file = nullptr; int line = 0; const char* data = 0; int flags = 0; while (auto errnum = ERR_get_error_line_data(&file, &line, &data, &flags)) { char buf[256]; ERR_error_string_n(errnum, buf, sizeof(buf)); base::Location ssl_location{"Unknown", file, line, nullptr}; std::string ssl_message = buf; if (flags & ERR_TXT_STRING) { ssl_message += ": "; ssl_message += data; } Error::AddTo(error, ssl_location, "openssl", std::to_string(errnum), ssl_message); } Error::AddTo(error, location, "tls_stream", "failed", message); return false; } int TlsStream::TlsStreamImpl::OnCertVerifyResults(int ok, X509_STORE_CTX* ctx) { // OpenSSL already performs a comprehensive check of the certificate chain // (using X509_verify_cert() function) and calls back with the result of its // verification. // |ok| is set to 1 if the verification passed and 0 if an error was detected. // Here we can perform some additional checks if we need to, or simply log // the issues found. // For now, just log an error if it occurred. if (!ok) { LOG(ERROR) << "Server certificate validation failed: " << X509_verify_cert_error_string(X509_STORE_CTX_get_error(ctx)); } return ok; } int TlsStream::TlsStreamImpl::OnCertVerifyResultsStatic(int ok, X509_STORE_CTX* ctx) { // Obtain the pointer to the instance of TlsStream::TlsStreamImpl from the // SSL CTX object referenced by |ctx|. SSL* ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data( ctx, SSL_get_ex_data_X509_STORE_CTX_idx())); SSL_CTX* ssl_ctx = ssl ? SSL_get_SSL_CTX(ssl) : nullptr; TlsStream::TlsStreamImpl* self = nullptr; if (ssl_ctx) { self = static_cast<TlsStream::TlsStreamImpl*>(SSL_CTX_get_ex_data( ssl_ctx, ssl_ctx_private_data_index)); } return self ? self->OnCertVerifyResults(ok, ctx) : ok; } bool TlsStream::TlsStreamImpl::Init(StreamPtr socket, const std::string& host, const base::Closure& success_callback, const Stream::ErrorCallback& error_callback, ErrorPtr* error) { ctx_.reset(SSL_CTX_new(TLSv1_2_client_method())); if (!ctx_) return ReportError(error, FROM_HERE, "Cannot create SSL_CTX"); // Top cipher suites supported by both Google GFEs and OpenSSL (in server // preferred order). int res = SSL_CTX_set_cipher_list(ctx_.get(), "ECDHE-ECDSA-AES128-GCM-SHA256:" "ECDHE-ECDSA-AES256-GCM-SHA384:" "ECDHE-RSA-AES128-GCM-SHA256:" "ECDHE-RSA-AES256-GCM-SHA384"); if (res != 1) return ReportError(error, FROM_HERE, "Cannot set the cipher list"); res = SSL_CTX_load_verify_locations(ctx_.get(), nullptr, kCACertificatePath); if (res != 1) { return ReportError(error, FROM_HERE, "Failed to specify trusted certificate location"); } // Store a pointer to "this" into SSL_CTX instance. SSL_CTX_set_ex_data(ctx_.get(), ssl_ctx_private_data_index, this); // Ask OpenSSL to validate the server host from the certificate to match // the expected host name we are given: X509_VERIFY_PARAM* param = SSL_CTX_get0_param(ctx_.get()); X509_VERIFY_PARAM_set1_host(param, host.c_str(), host.size()); SSL_CTX_set_verify(ctx_.get(), SSL_VERIFY_PEER, &TlsStreamImpl::OnCertVerifyResultsStatic); socket_ = std::move(socket); ssl_.reset(SSL_new(ctx_.get())); // Enable TLS progress callback if VLOG level is >=3. if (VLOG_IS_ON(3)) SSL_set_info_callback(ssl_.get(), TlsInfoCallback); stream_bio_ = BIO_new_stream(socket_.get()); SSL_set_bio(ssl_.get(), stream_bio_, stream_bio_); SSL_set_connect_state(ssl_.get()); // We might have no message loop (e.g. we are in unit tests). if (MessageLoop::ThreadHasCurrent()) { MessageLoop::current()->PostTask( FROM_HERE, base::Bind(&TlsStreamImpl::DoHandshake, weak_ptr_factory_.GetWeakPtr(), success_callback, error_callback)); } else { DoHandshake(success_callback, error_callback); } return true; } void TlsStream::TlsStreamImpl::RetryHandshake( const base::Closure& success_callback, const Stream::ErrorCallback& error_callback, Stream::AccessMode /* mode */) { VLOG(1) << "Retrying TLS handshake"; DoHandshake(success_callback, error_callback); } void TlsStream::TlsStreamImpl::DoHandshake( const base::Closure& success_callback, const Stream::ErrorCallback& error_callback) { VLOG(1) << "Begin TLS handshake"; int res = SSL_do_handshake(ssl_.get()); if (res == 1) { VLOG(1) << "Handshake successful"; success_callback.Run(); return; } ErrorPtr error; int err = SSL_get_error(ssl_.get(), res); if (err == SSL_ERROR_WANT_READ) { VLOG(1) << "Waiting for read data..."; bool ok = socket_->WaitForData( Stream::AccessMode::READ, base::Bind(&TlsStreamImpl::RetryHandshake, weak_ptr_factory_.GetWeakPtr(), success_callback, error_callback), &error); if (ok) return; } else if (err == SSL_ERROR_WANT_WRITE) { VLOG(1) << "Waiting for write data..."; bool ok = socket_->WaitForData( Stream::AccessMode::WRITE, base::Bind(&TlsStreamImpl::RetryHandshake, weak_ptr_factory_.GetWeakPtr(), success_callback, error_callback), &error); if (ok) return; } else { ReportError(&error, FROM_HERE, "TLS handshake failed."); } error_callback.Run(error.get()); } ///////////////////////////////////////////////////////////////////////////// TlsStream::TlsStream(std::unique_ptr<TlsStreamImpl> impl) : impl_{std::move(impl)} {} TlsStream::~TlsStream() { if (impl_) { impl_->Close(nullptr); } } void TlsStream::Connect(StreamPtr socket, const std::string& host, const base::Callback<void(StreamPtr)>& success_callback, const Stream::ErrorCallback& error_callback) { std::unique_ptr<TlsStreamImpl> impl{new TlsStreamImpl}; std::unique_ptr<TlsStream> stream{new TlsStream{std::move(impl)}}; TlsStreamImpl* pimpl = stream->impl_.get(); ErrorPtr error; bool success = pimpl->Init(std::move(socket), host, base::Bind(success_callback, base::Passed(std::move(stream))), error_callback, &error); if (!success) error_callback.Run(error.get()); } bool TlsStream::IsOpen() const { return impl_ ? true : false; } bool TlsStream::SetSizeBlocking(uint64_t /* size */, ErrorPtr* error) { return stream_utils::ErrorOperationNotSupported(FROM_HERE, error); } bool TlsStream::Seek(int64_t /* offset */, Whence /* whence */, uint64_t* /* new_position*/, ErrorPtr* error) { return stream_utils::ErrorOperationNotSupported(FROM_HERE, error); } bool TlsStream::ReadNonBlocking(void* buffer, size_t size_to_read, size_t* size_read, bool* end_of_stream, ErrorPtr* error) { if (!impl_) return stream_utils::ErrorStreamClosed(FROM_HERE, error); return impl_->ReadNonBlocking(buffer, size_to_read, size_read, end_of_stream, error); } bool TlsStream::WriteNonBlocking(const void* buffer, size_t size_to_write, size_t* size_written, ErrorPtr* error) { if (!impl_) return stream_utils::ErrorStreamClosed(FROM_HERE, error); return impl_->WriteNonBlocking(buffer, size_to_write, size_written, error); } bool TlsStream::FlushBlocking(ErrorPtr* error) { if (!impl_) return stream_utils::ErrorStreamClosed(FROM_HERE, error); return impl_->Flush(error); } bool TlsStream::CloseBlocking(ErrorPtr* error) { if (impl_ && !impl_->Close(error)) return false; impl_.reset(); return true; } bool TlsStream::WaitForData(AccessMode mode, const base::Callback<void(AccessMode)>& callback, ErrorPtr* error) { if (!impl_) return stream_utils::ErrorStreamClosed(FROM_HERE, error); return impl_->WaitForData(mode, callback, error); } bool TlsStream::WaitForDataBlocking(AccessMode in_mode, base::TimeDelta timeout, AccessMode* out_mode, ErrorPtr* error) { if (!impl_) return stream_utils::ErrorStreamClosed(FROM_HERE, error); return impl_->WaitForDataBlocking(in_mode, timeout, out_mode, error); } void TlsStream::CancelPendingAsyncOperations() { if (impl_) impl_->CancelPendingAsyncOperations(); Stream::CancelPendingAsyncOperations(); } } // namespace brillo