// Copyright (c) 2011 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include <string> #include <vector> #include "base/memory/ref_counted.h" #include "base/string_split.h" #include "base/string_util.h" #include "googleurl/src/gurl.h" #include "net/base/cookie_policy.h" #include "net/base/cookie_store.h" #include "net/base/net_errors.h" #include "net/base/sys_addrinfo.h" #include "net/base/transport_security_state.h" #include "net/socket_stream/socket_stream.h" #include "net/url_request/url_request_context.h" #include "net/websockets/websocket_job.h" #include "net/websockets/websocket_throttle.h" #include "testing/gtest/include/gtest/gtest.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/platform_test.h" namespace net { class MockSocketStream : public SocketStream { public: MockSocketStream(const GURL& url, SocketStream::Delegate* delegate) : SocketStream(url, delegate) {} virtual ~MockSocketStream() {} virtual void Connect() {} virtual bool SendData(const char* data, int len) { sent_data_ += std::string(data, len); return true; } virtual void Close() {} virtual void RestartWithAuth( const string16& username, const string16& password) {} virtual void DetachDelegate() { delegate_ = NULL; } const std::string& sent_data() const { return sent_data_; } private: std::string sent_data_; }; class MockSocketStreamDelegate : public SocketStream::Delegate { public: MockSocketStreamDelegate() : amount_sent_(0) {} virtual ~MockSocketStreamDelegate() {} virtual void OnConnected(SocketStream* socket, int max_pending_send_allowed) { } virtual void OnSentData(SocketStream* socket, int amount_sent) { amount_sent_ += amount_sent; } virtual void OnReceivedData(SocketStream* socket, const char* data, int len) { received_data_ += std::string(data, len); } virtual void OnClose(SocketStream* socket) { } size_t amount_sent() const { return amount_sent_; } const std::string& received_data() const { return received_data_; } private: int amount_sent_; std::string received_data_; }; class MockCookieStore : public CookieStore { public: struct Entry { GURL url; std::string cookie_line; CookieOptions options; }; MockCookieStore() {} virtual bool SetCookieWithOptions(const GURL& url, const std::string& cookie_line, const CookieOptions& options) { Entry entry; entry.url = url; entry.cookie_line = cookie_line; entry.options = options; entries_.push_back(entry); return true; } virtual std::string GetCookiesWithOptions(const GURL& url, const CookieOptions& options) { std::string result; for (size_t i = 0; i < entries_.size(); i++) { Entry &entry = entries_[i]; if (url == entry.url) { if (!result.empty()) { result += "; "; } result += entry.cookie_line; } } return result; } virtual void DeleteCookie(const GURL& url, const std::string& cookie_name) {} virtual CookieMonster* GetCookieMonster() { return NULL; } const std::vector<Entry>& entries() const { return entries_; } private: friend class base::RefCountedThreadSafe<MockCookieStore>; virtual ~MockCookieStore() {} std::vector<Entry> entries_; }; class MockCookiePolicy : public CookiePolicy { public: MockCookiePolicy() : allow_all_cookies_(true) {} virtual ~MockCookiePolicy() {} void set_allow_all_cookies(bool allow_all_cookies) { allow_all_cookies_ = allow_all_cookies; } virtual int CanGetCookies(const GURL& url, const GURL& first_party_for_cookies) const { if (allow_all_cookies_) return OK; return ERR_ACCESS_DENIED; } virtual int CanSetCookie(const GURL& url, const GURL& first_party_for_cookies, const std::string& cookie_line) const { if (allow_all_cookies_) return OK; return ERR_ACCESS_DENIED; } private: bool allow_all_cookies_; }; class MockURLRequestContext : public URLRequestContext { public: MockURLRequestContext(CookieStore* cookie_store, CookiePolicy* cookie_policy) { set_cookie_store(cookie_store); set_cookie_policy(cookie_policy); transport_security_state_ = new TransportSecurityState(); set_transport_security_state(transport_security_state_.get()); TransportSecurityState::DomainState state; state.expiry = base::Time::Now() + base::TimeDelta::FromSeconds(1000); transport_security_state_->EnableHost("upgrademe.com", state); } private: friend class base::RefCountedThreadSafe<MockURLRequestContext>; virtual ~MockURLRequestContext() {} scoped_refptr<TransportSecurityState> transport_security_state_; }; class WebSocketJobTest : public PlatformTest { public: virtual void SetUp() { cookie_store_ = new MockCookieStore; cookie_policy_.reset(new MockCookiePolicy); context_ = new MockURLRequestContext( cookie_store_.get(), cookie_policy_.get()); } virtual void TearDown() { cookie_store_ = NULL; cookie_policy_.reset(); context_ = NULL; websocket_ = NULL; socket_ = NULL; } protected: void InitWebSocketJob(const GURL& url, MockSocketStreamDelegate* delegate) { websocket_ = new WebSocketJob(delegate); socket_ = new MockSocketStream(url, websocket_.get()); websocket_->InitSocketStream(socket_.get()); websocket_->set_context(context_.get()); websocket_->state_ = WebSocketJob::CONNECTING; struct addrinfo addr; memset(&addr, 0, sizeof(struct addrinfo)); addr.ai_family = AF_INET; addr.ai_addrlen = sizeof(struct sockaddr_in); struct sockaddr_in sa_in; memset(&sa_in, 0, sizeof(struct sockaddr_in)); memcpy(&sa_in.sin_addr, "\x7f\0\0\1", 4); addr.ai_addr = reinterpret_cast<sockaddr*>(&sa_in); addr.ai_next = NULL; websocket_->addresses_.Copy(&addr, true); WebSocketThrottle::GetInstance()->PutInQueue(websocket_); } WebSocketJob::State GetWebSocketJobState() { return websocket_->state_; } void CloseWebSocketJob() { if (websocket_->socket_) { websocket_->socket_->DetachDelegate(); WebSocketThrottle::GetInstance()->RemoveFromQueue(websocket_); } websocket_->state_ = WebSocketJob::CLOSED; websocket_->delegate_ = NULL; websocket_->socket_ = NULL; } SocketStream* GetSocket(SocketStreamJob* job) { return job->socket_.get(); } scoped_refptr<MockCookieStore> cookie_store_; scoped_ptr<MockCookiePolicy> cookie_policy_; scoped_refptr<MockURLRequestContext> context_; scoped_refptr<WebSocketJob> websocket_; scoped_refptr<MockSocketStream> socket_; }; TEST_F(WebSocketJobTest, SimpleHandshake) { GURL url("ws://example.com/demo"); MockSocketStreamDelegate delegate; InitWebSocketJob(url, &delegate); static const char* kHandshakeRequestMessage = "GET /demo HTTP/1.1\r\n" "Host: example.com\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Key2: 12998 5 Y3 1 .P00\r\n" "Sec-WebSocket-Protocol: sample\r\n" "Upgrade: WebSocket\r\n" "Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5\r\n" "Origin: http://example.com\r\n" "\r\n" "^n:ds[4U"; bool sent = websocket_->SendData(kHandshakeRequestMessage, strlen(kHandshakeRequestMessage)); EXPECT_TRUE(sent); MessageLoop::current()->RunAllPending(); EXPECT_EQ(kHandshakeRequestMessage, socket_->sent_data()); EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState()); websocket_->OnSentData(socket_.get(), strlen(kHandshakeRequestMessage)); EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent()); const char kHandshakeResponseMessage[] = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" "Upgrade: WebSocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Origin: http://example.com\r\n" "Sec-WebSocket-Location: ws://example.com/demo\r\n" "Sec-WebSocket-Protocol: sample\r\n" "\r\n" "8jKS'y:G*Co,Wxa-"; websocket_->OnReceivedData(socket_.get(), kHandshakeResponseMessage, strlen(kHandshakeResponseMessage)); MessageLoop::current()->RunAllPending(); EXPECT_EQ(kHandshakeResponseMessage, delegate.received_data()); EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState()); CloseWebSocketJob(); } TEST_F(WebSocketJobTest, SlowHandshake) { GURL url("ws://example.com/demo"); MockSocketStreamDelegate delegate; InitWebSocketJob(url, &delegate); static const char* kHandshakeRequestMessage = "GET /demo HTTP/1.1\r\n" "Host: example.com\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Key2: 12998 5 Y3 1 .P00\r\n" "Sec-WebSocket-Protocol: sample\r\n" "Upgrade: WebSocket\r\n" "Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5\r\n" "Origin: http://example.com\r\n" "\r\n" "^n:ds[4U"; bool sent = websocket_->SendData(kHandshakeRequestMessage, strlen(kHandshakeRequestMessage)); EXPECT_TRUE(sent); // We assume request is sent in one data chunk (from WebKit) // We don't support streaming request. MessageLoop::current()->RunAllPending(); EXPECT_EQ(kHandshakeRequestMessage, socket_->sent_data()); EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState()); websocket_->OnSentData(socket_.get(), strlen(kHandshakeRequestMessage)); EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent()); const char kHandshakeResponseMessage[] = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" "Upgrade: WebSocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Origin: http://example.com\r\n" "Sec-WebSocket-Location: ws://example.com/demo\r\n" "Sec-WebSocket-Protocol: sample\r\n" "\r\n" "8jKS'y:G*Co,Wxa-"; std::vector<std::string> lines; base::SplitString(kHandshakeResponseMessage, '\n', &lines); for (size_t i = 0; i < lines.size() - 2; i++) { std::string line = lines[i] + "\r\n"; SCOPED_TRACE("Line: " + line); websocket_->OnReceivedData(socket_, line.c_str(), line.size()); MessageLoop::current()->RunAllPending(); EXPECT_TRUE(delegate.received_data().empty()); EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState()); } websocket_->OnReceivedData(socket_.get(), "\r\n", 2); MessageLoop::current()->RunAllPending(); EXPECT_TRUE(delegate.received_data().empty()); EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState()); websocket_->OnReceivedData(socket_.get(), "8jKS'y:G*Co,Wxa-", 16); EXPECT_EQ(kHandshakeResponseMessage, delegate.received_data()); EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState()); CloseWebSocketJob(); } TEST_F(WebSocketJobTest, HandshakeWithCookie) { GURL url("ws://example.com/demo"); GURL cookieUrl("http://example.com/demo"); CookieOptions cookie_options; cookie_store_->SetCookieWithOptions( cookieUrl, "CR-test=1", cookie_options); cookie_options.set_include_httponly(); cookie_store_->SetCookieWithOptions( cookieUrl, "CR-test-httponly=1", cookie_options); MockSocketStreamDelegate delegate; InitWebSocketJob(url, &delegate); static const char* kHandshakeRequestMessage = "GET /demo HTTP/1.1\r\n" "Host: example.com\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Key2: 12998 5 Y3 1 .P00\r\n" "Sec-WebSocket-Protocol: sample\r\n" "Upgrade: WebSocket\r\n" "Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5\r\n" "Origin: http://example.com\r\n" "Cookie: WK-test=1\r\n" "\r\n" "^n:ds[4U"; static const char* kHandshakeRequestExpected = "GET /demo HTTP/1.1\r\n" "Host: example.com\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Key2: 12998 5 Y3 1 .P00\r\n" "Sec-WebSocket-Protocol: sample\r\n" "Upgrade: WebSocket\r\n" "Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5\r\n" "Origin: http://example.com\r\n" "Cookie: CR-test=1; CR-test-httponly=1\r\n" "\r\n" "^n:ds[4U"; bool sent = websocket_->SendData(kHandshakeRequestMessage, strlen(kHandshakeRequestMessage)); EXPECT_TRUE(sent); MessageLoop::current()->RunAllPending(); EXPECT_EQ(kHandshakeRequestExpected, socket_->sent_data()); EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState()); websocket_->OnSentData(socket_, strlen(kHandshakeRequestExpected)); EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent()); const char kHandshakeResponseMessage[] = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" "Upgrade: WebSocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Origin: http://example.com\r\n" "Sec-WebSocket-Location: ws://example.com/demo\r\n" "Sec-WebSocket-Protocol: sample\r\n" "Set-Cookie: CR-set-test=1\r\n" "\r\n" "8jKS'y:G*Co,Wxa-"; static const char* kHandshakeResponseExpected = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" "Upgrade: WebSocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Origin: http://example.com\r\n" "Sec-WebSocket-Location: ws://example.com/demo\r\n" "Sec-WebSocket-Protocol: sample\r\n" "\r\n" "8jKS'y:G*Co,Wxa-"; websocket_->OnReceivedData(socket_.get(), kHandshakeResponseMessage, strlen(kHandshakeResponseMessage)); MessageLoop::current()->RunAllPending(); EXPECT_EQ(kHandshakeResponseExpected, delegate.received_data()); EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState()); EXPECT_EQ(3U, cookie_store_->entries().size()); EXPECT_EQ(cookieUrl, cookie_store_->entries()[0].url); EXPECT_EQ("CR-test=1", cookie_store_->entries()[0].cookie_line); EXPECT_EQ(cookieUrl, cookie_store_->entries()[1].url); EXPECT_EQ("CR-test-httponly=1", cookie_store_->entries()[1].cookie_line); EXPECT_EQ(cookieUrl, cookie_store_->entries()[2].url); EXPECT_EQ("CR-set-test=1", cookie_store_->entries()[2].cookie_line); CloseWebSocketJob(); } TEST_F(WebSocketJobTest, HandshakeWithCookieButNotAllowed) { GURL url("ws://example.com/demo"); GURL cookieUrl("http://example.com/demo"); CookieOptions cookie_options; cookie_store_->SetCookieWithOptions( cookieUrl, "CR-test=1", cookie_options); cookie_options.set_include_httponly(); cookie_store_->SetCookieWithOptions( cookieUrl, "CR-test-httponly=1", cookie_options); cookie_policy_->set_allow_all_cookies(false); MockSocketStreamDelegate delegate; InitWebSocketJob(url, &delegate); static const char* kHandshakeRequestMessage = "GET /demo HTTP/1.1\r\n" "Host: example.com\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Key2: 12998 5 Y3 1 .P00\r\n" "Sec-WebSocket-Protocol: sample\r\n" "Upgrade: WebSocket\r\n" "Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5\r\n" "Origin: http://example.com\r\n" "Cookie: WK-test=1\r\n" "\r\n" "^n:ds[4U"; static const char* kHandshakeRequestExpected = "GET /demo HTTP/1.1\r\n" "Host: example.com\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Key2: 12998 5 Y3 1 .P00\r\n" "Sec-WebSocket-Protocol: sample\r\n" "Upgrade: WebSocket\r\n" "Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5\r\n" "Origin: http://example.com\r\n" "\r\n" "^n:ds[4U"; bool sent = websocket_->SendData(kHandshakeRequestMessage, strlen(kHandshakeRequestMessage)); EXPECT_TRUE(sent); MessageLoop::current()->RunAllPending(); EXPECT_EQ(kHandshakeRequestExpected, socket_->sent_data()); EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState()); websocket_->OnSentData(socket_, strlen(kHandshakeRequestExpected)); EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent()); const char kHandshakeResponseMessage[] = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" "Upgrade: WebSocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Origin: http://example.com\r\n" "Sec-WebSocket-Location: ws://example.com/demo\r\n" "Sec-WebSocket-Protocol: sample\r\n" "Set-Cookie: CR-set-test=1\r\n" "\r\n" "8jKS'y:G*Co,Wxa-"; static const char* kHandshakeResponseExpected = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" "Upgrade: WebSocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Origin: http://example.com\r\n" "Sec-WebSocket-Location: ws://example.com/demo\r\n" "Sec-WebSocket-Protocol: sample\r\n" "\r\n" "8jKS'y:G*Co,Wxa-"; websocket_->OnReceivedData(socket_.get(), kHandshakeResponseMessage, strlen(kHandshakeResponseMessage)); MessageLoop::current()->RunAllPending(); EXPECT_EQ(kHandshakeResponseExpected, delegate.received_data()); EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState()); EXPECT_EQ(2U, cookie_store_->entries().size()); EXPECT_EQ(cookieUrl, cookie_store_->entries()[0].url); EXPECT_EQ("CR-test=1", cookie_store_->entries()[0].cookie_line); EXPECT_EQ(cookieUrl, cookie_store_->entries()[1].url); EXPECT_EQ("CR-test-httponly=1", cookie_store_->entries()[1].cookie_line); CloseWebSocketJob(); } TEST_F(WebSocketJobTest, HSTSUpgrade) { GURL url("ws://upgrademe.com/"); MockSocketStreamDelegate delegate; scoped_refptr<SocketStreamJob> job = SocketStreamJob::CreateSocketStreamJob( url, &delegate, *context_.get()); EXPECT_TRUE(GetSocket(job.get())->is_secure()); job->DetachDelegate(); url = GURL("ws://donotupgrademe.com/"); job = SocketStreamJob::CreateSocketStreamJob( url, &delegate, *context_.get()); EXPECT_FALSE(GetSocket(job.get())->is_secure()); job->DetachDelegate(); } } // namespace net