// Copyright (c) 2012 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/websockets/websocket_job.h" #include <algorithm> #include "base/bind.h" #include "base/lazy_instance.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/base/net_log.h" #include "net/cookies/cookie_store.h" #include "net/http/http_network_session.h" #include "net/http/http_transaction_factory.h" #include "net/http/http_util.h" #include "net/spdy/spdy_session.h" #include "net/spdy/spdy_session_pool.h" #include "net/url_request/url_request_context.h" #include "net/websockets/websocket_handshake_handler.h" #include "net/websockets/websocket_net_log_params.h" #include "net/websockets/websocket_throttle.h" #include "url/gurl.h" static const int kMaxPendingSendAllowed = 32768; // 32 kilobytes. namespace { // lower-case header names. const char* const kCookieHeaders[] = { "cookie", "cookie2" }; const char* const kSetCookieHeaders[] = { "set-cookie", "set-cookie2" }; net::SocketStreamJob* WebSocketJobFactory( const GURL& url, net::SocketStream::Delegate* delegate, net::URLRequestContext* context, net::CookieStore* cookie_store) { net::WebSocketJob* job = new net::WebSocketJob(delegate); job->InitSocketStream(new net::SocketStream(url, job, context, cookie_store)); return job; } class WebSocketJobInitSingleton { private: friend struct base::DefaultLazyInstanceTraits<WebSocketJobInitSingleton>; WebSocketJobInitSingleton() { net::SocketStreamJob::RegisterProtocolFactory("ws", WebSocketJobFactory); net::SocketStreamJob::RegisterProtocolFactory("wss", WebSocketJobFactory); } }; static base::LazyInstance<WebSocketJobInitSingleton> g_websocket_job_init = LAZY_INSTANCE_INITIALIZER; } // anonymous namespace namespace net { // static void WebSocketJob::EnsureInit() { g_websocket_job_init.Get(); } WebSocketJob::WebSocketJob(SocketStream::Delegate* delegate) : delegate_(delegate), state_(INITIALIZED), waiting_(false), handshake_request_(new WebSocketHandshakeRequestHandler), handshake_response_(new WebSocketHandshakeResponseHandler), started_to_send_handshake_request_(false), handshake_request_sent_(0), response_cookies_save_index_(0), spdy_protocol_version_(0), save_next_cookie_running_(false), callback_pending_(false), weak_ptr_factory_(this), weak_ptr_factory_for_send_pending_(this) { } WebSocketJob::~WebSocketJob() { DCHECK_EQ(CLOSED, state_); DCHECK(!delegate_); DCHECK(!socket_.get()); } void WebSocketJob::Connect() { DCHECK(socket_.get()); DCHECK_EQ(state_, INITIALIZED); state_ = CONNECTING; socket_->Connect(); } bool WebSocketJob::SendData(const char* data, int len) { switch (state_) { case INITIALIZED: return false; case CONNECTING: return SendHandshakeRequest(data, len); case OPEN: { scoped_refptr<IOBufferWithSize> buffer = new IOBufferWithSize(len); memcpy(buffer->data(), data, len); if (current_send_buffer_.get() || !send_buffer_queue_.empty()) { send_buffer_queue_.push_back(buffer); return true; } current_send_buffer_ = new DrainableIOBuffer(buffer.get(), len); return SendDataInternal(current_send_buffer_->data(), current_send_buffer_->BytesRemaining()); } case CLOSING: case CLOSED: return false; } return false; } void WebSocketJob::Close() { if (state_ == CLOSED) return; state_ = CLOSING; if (current_send_buffer_.get()) { // Will close in SendPending. return; } state_ = CLOSED; CloseInternal(); } void WebSocketJob::RestartWithAuth(const AuthCredentials& credentials) { state_ = CONNECTING; socket_->RestartWithAuth(credentials); } void WebSocketJob::DetachDelegate() { state_ = CLOSED; WebSocketThrottle::GetInstance()->RemoveFromQueue(this); scoped_refptr<WebSocketJob> protect(this); weak_ptr_factory_.InvalidateWeakPtrs(); weak_ptr_factory_for_send_pending_.InvalidateWeakPtrs(); delegate_ = NULL; if (socket_.get()) socket_->DetachDelegate(); socket_ = NULL; if (!callback_.is_null()) { waiting_ = false; callback_.Reset(); Release(); // Balanced with OnStartOpenConnection(). } } int WebSocketJob::OnStartOpenConnection( SocketStream* socket, const CompletionCallback& callback) { DCHECK(callback_.is_null()); state_ = CONNECTING; addresses_ = socket->address_list(); if (!WebSocketThrottle::GetInstance()->PutInQueue(this)) { return ERR_WS_THROTTLE_QUEUE_TOO_LARGE; } if (delegate_) { int result = delegate_->OnStartOpenConnection(socket, callback); DCHECK_EQ(OK, result); } if (waiting_) { // PutInQueue() may set |waiting_| true for throttling. In this case, // Wakeup() will be called later. callback_ = callback; AddRef(); // Balanced when callback_ is cleared. return ERR_IO_PENDING; } return TrySpdyStream(); } void WebSocketJob::OnConnected( SocketStream* socket, int max_pending_send_allowed) { if (state_ == CLOSED) return; DCHECK_EQ(CONNECTING, state_); if (delegate_) delegate_->OnConnected(socket, max_pending_send_allowed); } void WebSocketJob::OnSentData(SocketStream* socket, int amount_sent) { DCHECK_NE(INITIALIZED, state_); DCHECK_GT(amount_sent, 0); if (state_ == CLOSED) return; if (state_ == CONNECTING) { OnSentHandshakeRequest(socket, amount_sent); return; } if (delegate_) { DCHECK(state_ == OPEN || state_ == CLOSING); if (!current_send_buffer_.get()) { VLOG(1) << "OnSentData current_send_buffer=NULL amount_sent=" << amount_sent; return; } current_send_buffer_->DidConsume(amount_sent); if (current_send_buffer_->BytesRemaining() > 0) return; // We need to report amount_sent of original buffer size, instead of // amount sent to |socket|. amount_sent = current_send_buffer_->size(); DCHECK_GT(amount_sent, 0); current_send_buffer_ = NULL; if (!weak_ptr_factory_for_send_pending_.HasWeakPtrs()) { base::MessageLoopForIO::current()->PostTask( FROM_HERE, base::Bind(&WebSocketJob::SendPending, weak_ptr_factory_for_send_pending_.GetWeakPtr())); } delegate_->OnSentData(socket, amount_sent); } } void WebSocketJob::OnReceivedData( SocketStream* socket, const char* data, int len) { DCHECK_NE(INITIALIZED, state_); if (state_ == CLOSED) return; if (state_ == CONNECTING) { OnReceivedHandshakeResponse(socket, data, len); return; } DCHECK(state_ == OPEN || state_ == CLOSING); if (delegate_ && len > 0) delegate_->OnReceivedData(socket, data, len); } void WebSocketJob::OnClose(SocketStream* socket) { state_ = CLOSED; WebSocketThrottle::GetInstance()->RemoveFromQueue(this); scoped_refptr<WebSocketJob> protect(this); weak_ptr_factory_.InvalidateWeakPtrs(); SocketStream::Delegate* delegate = delegate_; delegate_ = NULL; socket_ = NULL; if (!callback_.is_null()) { waiting_ = false; callback_.Reset(); Release(); // Balanced with OnStartOpenConnection(). } if (delegate) delegate->OnClose(socket); } void WebSocketJob::OnAuthRequired( SocketStream* socket, AuthChallengeInfo* auth_info) { if (delegate_) delegate_->OnAuthRequired(socket, auth_info); } void WebSocketJob::OnSSLCertificateError( SocketStream* socket, const SSLInfo& ssl_info, bool fatal) { if (delegate_) delegate_->OnSSLCertificateError(socket, ssl_info, fatal); } void WebSocketJob::OnError(const SocketStream* socket, int error) { if (delegate_ && error != ERR_PROTOCOL_SWITCHED) delegate_->OnError(socket, error); } void WebSocketJob::OnCreatedSpdyStream(int result) { DCHECK(spdy_websocket_stream_.get()); DCHECK(socket_.get()); DCHECK_NE(ERR_IO_PENDING, result); if (state_ == CLOSED) { result = ERR_ABORTED; } else if (result == OK) { state_ = CONNECTING; result = ERR_PROTOCOL_SWITCHED; } else { spdy_websocket_stream_.reset(); } CompleteIO(result); } void WebSocketJob::OnSentSpdyHeaders() { DCHECK_NE(INITIALIZED, state_); if (state_ != CONNECTING) return; size_t original_length = handshake_request_->original_length(); handshake_request_.reset(); if (delegate_) delegate_->OnSentData(socket_.get(), original_length); } void WebSocketJob::OnSpdyResponseHeadersUpdated( const SpdyHeaderBlock& response_headers) { DCHECK_NE(INITIALIZED, state_); if (state_ != CONNECTING) return; // TODO(toyoshim): Fallback to non-spdy connection? handshake_response_->ParseResponseHeaderBlock(response_headers, challenge_, spdy_protocol_version_); SaveCookiesAndNotifyHeadersComplete(); } void WebSocketJob::OnSentSpdyData(size_t bytes_sent) { DCHECK_NE(INITIALIZED, state_); DCHECK_NE(CONNECTING, state_); if (state_ == CLOSED) return; if (!spdy_websocket_stream_.get()) return; OnSentData(socket_.get(), static_cast<int>(bytes_sent)); } void WebSocketJob::OnReceivedSpdyData(scoped_ptr<SpdyBuffer> buffer) { DCHECK_NE(INITIALIZED, state_); DCHECK_NE(CONNECTING, state_); if (state_ == CLOSED) return; if (!spdy_websocket_stream_.get()) return; if (buffer) { OnReceivedData( socket_.get(), buffer->GetRemainingData(), buffer->GetRemainingSize()); } else { OnReceivedData(socket_.get(), NULL, 0); } } void WebSocketJob::OnCloseSpdyStream() { spdy_websocket_stream_.reset(); OnClose(socket_.get()); } bool WebSocketJob::SendHandshakeRequest(const char* data, int len) { DCHECK_EQ(state_, CONNECTING); if (started_to_send_handshake_request_) return false; if (!handshake_request_->ParseRequest(data, len)) return false; AddCookieHeaderAndSend(); return true; } void WebSocketJob::AddCookieHeaderAndSend() { bool allow = true; if (delegate_ && !delegate_->CanGetCookies(socket_.get(), GetURLForCookies())) allow = false; if (socket_.get() && delegate_ && state_ == CONNECTING) { handshake_request_->RemoveHeaders(kCookieHeaders, arraysize(kCookieHeaders)); if (allow && socket_->cookie_store()) { // Add cookies, including HttpOnly cookies. CookieOptions cookie_options; cookie_options.set_include_httponly(); socket_->cookie_store()->GetCookiesWithOptionsAsync( GetURLForCookies(), cookie_options, base::Bind(&WebSocketJob::LoadCookieCallback, weak_ptr_factory_.GetWeakPtr())); } else { DoSendData(); } } } void WebSocketJob::LoadCookieCallback(const std::string& cookie) { if (!cookie.empty()) // TODO(tyoshino): Sending cookie means that connection doesn't need // PRIVACY_MODE_ENABLED as cookies may be server-bound and channel id // wouldn't negatively affect privacy anyway. Need to restart connection // or refactor to determine cookie status prior to connecting. handshake_request_->AppendHeaderIfMissing("Cookie", cookie); DoSendData(); } void WebSocketJob::DoSendData() { if (spdy_websocket_stream_.get()) { scoped_ptr<SpdyHeaderBlock> headers(new SpdyHeaderBlock); handshake_request_->GetRequestHeaderBlock( socket_->url(), headers.get(), &challenge_, spdy_protocol_version_); spdy_websocket_stream_->SendRequest(headers.Pass()); } else { const std::string& handshake_request = handshake_request_->GetRawRequest(); handshake_request_sent_ = 0; socket_->net_log()->AddEvent( NetLog::TYPE_WEB_SOCKET_SEND_REQUEST_HEADERS, base::Bind(&NetLogWebSocketHandshakeCallback, &handshake_request)); socket_->SendData(handshake_request.data(), handshake_request.size()); } // Just buffered in |handshake_request_|. started_to_send_handshake_request_ = true; } void WebSocketJob::OnSentHandshakeRequest( SocketStream* socket, int amount_sent) { DCHECK_EQ(state_, CONNECTING); handshake_request_sent_ += amount_sent; DCHECK_LE(handshake_request_sent_, handshake_request_->raw_length()); if (handshake_request_sent_ >= handshake_request_->raw_length()) { // handshake request has been sent. // notify original size of handshake request to delegate. // Reset the handshake_request_ first in case this object is deleted by the // delegate. size_t original_length = handshake_request_->original_length(); handshake_request_.reset(); if (delegate_) delegate_->OnSentData(socket, original_length); } } void WebSocketJob::OnReceivedHandshakeResponse( SocketStream* socket, const char* data, int len) { DCHECK_EQ(state_, CONNECTING); if (handshake_response_->HasResponse()) { // If we already has handshake response, received data should be frame // data, not handshake message. received_data_after_handshake_.insert( received_data_after_handshake_.end(), data, data + len); return; } size_t response_length = handshake_response_->ParseRawResponse(data, len); if (!handshake_response_->HasResponse()) { // not yet. we need more data. return; } // handshake message is completed. std::string raw_response = handshake_response_->GetRawResponse(); socket_->net_log()->AddEvent( NetLog::TYPE_WEB_SOCKET_READ_RESPONSE_HEADERS, base::Bind(&NetLogWebSocketHandshakeCallback, &raw_response)); if (len - response_length > 0) { // If we received extra data, it should be frame data. DCHECK(received_data_after_handshake_.empty()); received_data_after_handshake_.assign(data + response_length, data + len); } SaveCookiesAndNotifyHeadersComplete(); } void WebSocketJob::SaveCookiesAndNotifyHeadersComplete() { // handshake message is completed. DCHECK(handshake_response_->HasResponse()); // Extract cookies from the handshake response into a temporary vector. response_cookies_.clear(); response_cookies_save_index_ = 0; handshake_response_->GetHeaders( kSetCookieHeaders, arraysize(kSetCookieHeaders), &response_cookies_); // Now, loop over the response cookies, and attempt to persist each. SaveNextCookie(); } void WebSocketJob::NotifyHeadersComplete() { // Remove cookie headers, with malformed headers preserved. // Actual handshake should be done in Blink. handshake_response_->RemoveHeaders( kSetCookieHeaders, arraysize(kSetCookieHeaders)); std::string handshake_response = handshake_response_->GetResponse(); handshake_response_.reset(); std::vector<char> received_data(handshake_response.begin(), handshake_response.end()); received_data.insert(received_data.end(), received_data_after_handshake_.begin(), received_data_after_handshake_.end()); received_data_after_handshake_.clear(); state_ = OPEN; DCHECK(!received_data.empty()); if (delegate_) delegate_->OnReceivedData( socket_.get(), &received_data.front(), received_data.size()); WebSocketThrottle::GetInstance()->RemoveFromQueue(this); } void WebSocketJob::SaveNextCookie() { if (!socket_.get() || !delegate_ || state_ != CONNECTING) return; callback_pending_ = false; save_next_cookie_running_ = true; if (socket_->cookie_store()) { GURL url_for_cookies = GetURLForCookies(); CookieOptions options; options.set_include_httponly(); // Loop as long as SetCookieWithOptionsAsync completes synchronously. Since // CookieMonster's asynchronous operation APIs queue the callback to run it // on the thread where the API was called, there won't be race. I.e. unless // the callback is run synchronously, it won't be run in parallel with this // method. while (!callback_pending_ && response_cookies_save_index_ < response_cookies_.size()) { std::string cookie = response_cookies_[response_cookies_save_index_]; response_cookies_save_index_++; if (!delegate_->CanSetCookie( socket_.get(), url_for_cookies, cookie, &options)) continue; callback_pending_ = true; socket_->cookie_store()->SetCookieWithOptionsAsync( url_for_cookies, cookie, options, base::Bind(&WebSocketJob::OnCookieSaved, weak_ptr_factory_.GetWeakPtr())); } } save_next_cookie_running_ = false; if (callback_pending_) return; response_cookies_.clear(); response_cookies_save_index_ = 0; NotifyHeadersComplete(); } void WebSocketJob::OnCookieSaved(bool cookie_status) { // Tell the caller of SetCookieWithOptionsAsync() that this completion // callback is invoked. // - If the caller checks callback_pending earlier than this callback, the // caller exits to let this method continue iteration. // - Otherwise, the caller continues iteration. callback_pending_ = false; // Resume SaveNextCookie if the caller of SetCookieWithOptionsAsync() exited // the loop. Otherwise, return. if (save_next_cookie_running_) return; SaveNextCookie(); } GURL WebSocketJob::GetURLForCookies() const { GURL url = socket_->url(); std::string scheme = socket_->is_secure() ? "https" : "http"; url::Replacements<char> replacements; replacements.SetScheme(scheme.c_str(), url::Component(0, scheme.length())); return url.ReplaceComponents(replacements); } const AddressList& WebSocketJob::address_list() const { return addresses_; } int WebSocketJob::TrySpdyStream() { if (!socket_.get()) return ERR_FAILED; // Check if we have a SPDY session available. HttpTransactionFactory* factory = socket_->context()->http_transaction_factory(); if (!factory) return OK; scoped_refptr<HttpNetworkSession> session = factory->GetSession(); if (!session.get() || !session->params().enable_websocket_over_spdy) return OK; SpdySessionPool* spdy_pool = session->spdy_session_pool(); PrivacyMode privacy_mode = socket_->privacy_mode(); const SpdySessionKey key(HostPortPair::FromURL(socket_->url()), socket_->proxy_server(), privacy_mode); // Forbid wss downgrade to SPDY without SSL. // TODO(toyoshim): Does it realize the same policy with HTTP? base::WeakPtr<SpdySession> spdy_session = spdy_pool->FindAvailableSession(key, *socket_->net_log()); if (!spdy_session) return OK; SSLInfo ssl_info; bool was_npn_negotiated; NextProto protocol_negotiated = kProtoUnknown; bool use_ssl = spdy_session->GetSSLInfo( &ssl_info, &was_npn_negotiated, &protocol_negotiated); if (socket_->is_secure() && !use_ssl) return OK; // Create SpdyWebSocketStream. spdy_protocol_version_ = spdy_session->GetProtocolVersion(); spdy_websocket_stream_.reset(new SpdyWebSocketStream(spdy_session, this)); int result = spdy_websocket_stream_->InitializeStream( socket_->url(), MEDIUM, *socket_->net_log()); if (result == OK) { OnConnected(socket_.get(), kMaxPendingSendAllowed); return ERR_PROTOCOL_SWITCHED; } if (result != ERR_IO_PENDING) { spdy_websocket_stream_.reset(); return OK; } return ERR_IO_PENDING; } void WebSocketJob::SetWaiting() { waiting_ = true; } bool WebSocketJob::IsWaiting() const { return waiting_; } void WebSocketJob::Wakeup() { if (!waiting_) return; waiting_ = false; DCHECK(!callback_.is_null()); base::MessageLoopForIO::current()->PostTask( FROM_HERE, base::Bind(&WebSocketJob::RetryPendingIO, weak_ptr_factory_.GetWeakPtr())); } void WebSocketJob::RetryPendingIO() { int result = TrySpdyStream(); // In the case of ERR_IO_PENDING, CompleteIO() will be called from // OnCreatedSpdyStream(). if (result != ERR_IO_PENDING) CompleteIO(result); } void WebSocketJob::CompleteIO(int result) { // |callback_| may be null if OnClose() or DetachDelegate() was called. if (!callback_.is_null()) { CompletionCallback callback = callback_; callback_.Reset(); callback.Run(result); Release(); // Balanced with OnStartOpenConnection(). } } bool WebSocketJob::SendDataInternal(const char* data, int length) { if (spdy_websocket_stream_.get()) return ERR_IO_PENDING == spdy_websocket_stream_->SendData(data, length); if (socket_.get()) return socket_->SendData(data, length); return false; } void WebSocketJob::CloseInternal() { if (spdy_websocket_stream_.get()) spdy_websocket_stream_->Close(); if (socket_.get()) socket_->Close(); } void WebSocketJob::SendPending() { if (current_send_buffer_.get()) return; // Current buffer has been sent. Try next if any. if (send_buffer_queue_.empty()) { // No more data to send. if (state_ == CLOSING) CloseInternal(); return; } scoped_refptr<IOBufferWithSize> next_buffer = send_buffer_queue_.front(); send_buffer_queue_.pop_front(); current_send_buffer_ = new DrainableIOBuffer(next_buffer.get(), next_buffer->size()); SendDataInternal(current_send_buffer_->data(), current_send_buffer_->BytesRemaining()); } } // namespace net