普通文本  |  524行  |  17.77 KB

// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <string>
#include <vector>

#include "base/memory/ref_counted.h"
#include "base/string_split.h"
#include "base/string_util.h"
#include "googleurl/src/gurl.h"
#include "net/base/cookie_policy.h"
#include "net/base/cookie_store.h"
#include "net/base/net_errors.h"
#include "net/base/sys_addrinfo.h"
#include "net/base/transport_security_state.h"
#include "net/socket_stream/socket_stream.h"
#include "net/url_request/url_request_context.h"
#include "net/websockets/websocket_job.h"
#include "net/websockets/websocket_throttle.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/platform_test.h"

namespace net {

class MockSocketStream : public SocketStream {
 public:
  MockSocketStream(const GURL& url, SocketStream::Delegate* delegate)
      : SocketStream(url, delegate) {}
  virtual ~MockSocketStream() {}

  virtual void Connect() {}
  virtual bool SendData(const char* data, int len) {
    sent_data_ += std::string(data, len);
    return true;
  }

  virtual void Close() {}
  virtual void RestartWithAuth(
      const string16& username, const string16& password) {}
  virtual void DetachDelegate() {
    delegate_ = NULL;
  }

  const std::string& sent_data() const {
    return sent_data_;
  }

 private:
  std::string sent_data_;
};

class MockSocketStreamDelegate : public SocketStream::Delegate {
 public:
  MockSocketStreamDelegate()
      : amount_sent_(0) {}
  virtual ~MockSocketStreamDelegate() {}

  virtual void OnConnected(SocketStream* socket, int max_pending_send_allowed) {
  }
  virtual void OnSentData(SocketStream* socket, int amount_sent) {
    amount_sent_ += amount_sent;
  }
  virtual void OnReceivedData(SocketStream* socket,
                              const char* data, int len) {
    received_data_ += std::string(data, len);
  }
  virtual void OnClose(SocketStream* socket) {
  }

  size_t amount_sent() const { return amount_sent_; }
  const std::string& received_data() const { return received_data_; }

 private:
  int amount_sent_;
  std::string received_data_;
};

class MockCookieStore : public CookieStore {
 public:
  struct Entry {
    GURL url;
    std::string cookie_line;
    CookieOptions options;
  };
  MockCookieStore() {}

  virtual bool SetCookieWithOptions(const GURL& url,
                                    const std::string& cookie_line,
                                    const CookieOptions& options) {
    Entry entry;
    entry.url = url;
    entry.cookie_line = cookie_line;
    entry.options = options;
    entries_.push_back(entry);
    return true;
  }
  virtual std::string GetCookiesWithOptions(const GURL& url,
                                            const CookieOptions& options) {
    std::string result;
    for (size_t i = 0; i < entries_.size(); i++) {
      Entry &entry = entries_[i];
      if (url == entry.url) {
        if (!result.empty()) {
          result += "; ";
        }
        result += entry.cookie_line;
      }
    }
    return result;
  }
  virtual void DeleteCookie(const GURL& url,
                            const std::string& cookie_name) {}
  virtual CookieMonster* GetCookieMonster() { return NULL; }

  const std::vector<Entry>& entries() const { return entries_; }

 private:
  friend class base::RefCountedThreadSafe<MockCookieStore>;
  virtual ~MockCookieStore() {}

  std::vector<Entry> entries_;
};

class MockCookiePolicy : public CookiePolicy {
 public:
  MockCookiePolicy() : allow_all_cookies_(true) {}
  virtual ~MockCookiePolicy() {}

  void set_allow_all_cookies(bool allow_all_cookies) {
    allow_all_cookies_ = allow_all_cookies;
  }

  virtual int CanGetCookies(const GURL& url,
                            const GURL& first_party_for_cookies) const {
    if (allow_all_cookies_)
      return OK;
    return ERR_ACCESS_DENIED;
  }

  virtual int CanSetCookie(const GURL& url,
                           const GURL& first_party_for_cookies,
                           const std::string& cookie_line) const {
    if (allow_all_cookies_)
      return OK;
    return ERR_ACCESS_DENIED;
  }

 private:
  bool allow_all_cookies_;
};

class MockURLRequestContext : public URLRequestContext {
 public:
  MockURLRequestContext(CookieStore* cookie_store,
                        CookiePolicy* cookie_policy) {
    set_cookie_store(cookie_store);
    set_cookie_policy(cookie_policy);
    transport_security_state_ = new TransportSecurityState();
    set_transport_security_state(transport_security_state_.get());
    TransportSecurityState::DomainState state;
    state.expiry = base::Time::Now() + base::TimeDelta::FromSeconds(1000);
    transport_security_state_->EnableHost("upgrademe.com", state);
  }

 private:
  friend class base::RefCountedThreadSafe<MockURLRequestContext>;
  virtual ~MockURLRequestContext() {}

  scoped_refptr<TransportSecurityState> transport_security_state_;
};

class WebSocketJobTest : public PlatformTest {
 public:
  virtual void SetUp() {
    cookie_store_ = new MockCookieStore;
    cookie_policy_.reset(new MockCookiePolicy);
    context_ = new MockURLRequestContext(
        cookie_store_.get(), cookie_policy_.get());
  }
  virtual void TearDown() {
    cookie_store_ = NULL;
    cookie_policy_.reset();
    context_ = NULL;
    websocket_ = NULL;
    socket_ = NULL;
  }
 protected:
  void InitWebSocketJob(const GURL& url, MockSocketStreamDelegate* delegate) {
    websocket_ = new WebSocketJob(delegate);
    socket_ = new MockSocketStream(url, websocket_.get());
    websocket_->InitSocketStream(socket_.get());
    websocket_->set_context(context_.get());
    websocket_->state_ = WebSocketJob::CONNECTING;
    struct addrinfo addr;
    memset(&addr, 0, sizeof(struct addrinfo));
    addr.ai_family = AF_INET;
    addr.ai_addrlen = sizeof(struct sockaddr_in);
    struct sockaddr_in sa_in;
    memset(&sa_in, 0, sizeof(struct sockaddr_in));
    memcpy(&sa_in.sin_addr, "\x7f\0\0\1", 4);
    addr.ai_addr = reinterpret_cast<sockaddr*>(&sa_in);
    addr.ai_next = NULL;
    websocket_->addresses_.Copy(&addr, true);
    WebSocketThrottle::GetInstance()->PutInQueue(websocket_);
  }
  WebSocketJob::State GetWebSocketJobState() {
    return websocket_->state_;
  }
  void CloseWebSocketJob() {
    if (websocket_->socket_) {
      websocket_->socket_->DetachDelegate();
      WebSocketThrottle::GetInstance()->RemoveFromQueue(websocket_);
    }
    websocket_->state_ = WebSocketJob::CLOSED;
    websocket_->delegate_ = NULL;
    websocket_->socket_ = NULL;
  }
  SocketStream* GetSocket(SocketStreamJob* job) {
    return job->socket_.get();
  }

  scoped_refptr<MockCookieStore> cookie_store_;
  scoped_ptr<MockCookiePolicy> cookie_policy_;
  scoped_refptr<MockURLRequestContext> context_;
  scoped_refptr<WebSocketJob> websocket_;
  scoped_refptr<MockSocketStream> socket_;
};

TEST_F(WebSocketJobTest, SimpleHandshake) {
  GURL url("ws://example.com/demo");
  MockSocketStreamDelegate delegate;
  InitWebSocketJob(url, &delegate);

  static const char* kHandshakeRequestMessage =
      "GET /demo HTTP/1.1\r\n"
      "Host: example.com\r\n"
      "Connection: Upgrade\r\n"
      "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "Upgrade: WebSocket\r\n"
      "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
      "Origin: http://example.com\r\n"
      "\r\n"
      "^n:ds[4U";

  bool sent = websocket_->SendData(kHandshakeRequestMessage,
                                   strlen(kHandshakeRequestMessage));
  EXPECT_TRUE(sent);
  MessageLoop::current()->RunAllPending();
  EXPECT_EQ(kHandshakeRequestMessage, socket_->sent_data());
  EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState());
  websocket_->OnSentData(socket_.get(), strlen(kHandshakeRequestMessage));
  EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent());

  const char kHandshakeResponseMessage[] =
      "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: Upgrade\r\n"
      "Sec-WebSocket-Origin: http://example.com\r\n"
      "Sec-WebSocket-Location: ws://example.com/demo\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "\r\n"
      "8jKS'y:G*Co,Wxa-";

  websocket_->OnReceivedData(socket_.get(),
                             kHandshakeResponseMessage,
                             strlen(kHandshakeResponseMessage));
  MessageLoop::current()->RunAllPending();
  EXPECT_EQ(kHandshakeResponseMessage, delegate.received_data());
  EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState());
  CloseWebSocketJob();
}

TEST_F(WebSocketJobTest, SlowHandshake) {
  GURL url("ws://example.com/demo");
  MockSocketStreamDelegate delegate;
  InitWebSocketJob(url, &delegate);

  static const char* kHandshakeRequestMessage =
      "GET /demo HTTP/1.1\r\n"
      "Host: example.com\r\n"
      "Connection: Upgrade\r\n"
      "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "Upgrade: WebSocket\r\n"
      "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
      "Origin: http://example.com\r\n"
      "\r\n"
      "^n:ds[4U";

  bool sent = websocket_->SendData(kHandshakeRequestMessage,
                                   strlen(kHandshakeRequestMessage));
  EXPECT_TRUE(sent);
  // We assume request is sent in one data chunk (from WebKit)
  // We don't support streaming request.
  MessageLoop::current()->RunAllPending();
  EXPECT_EQ(kHandshakeRequestMessage, socket_->sent_data());
  EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState());
  websocket_->OnSentData(socket_.get(), strlen(kHandshakeRequestMessage));
  EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent());

  const char kHandshakeResponseMessage[] =
      "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: Upgrade\r\n"
      "Sec-WebSocket-Origin: http://example.com\r\n"
      "Sec-WebSocket-Location: ws://example.com/demo\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "\r\n"
      "8jKS'y:G*Co,Wxa-";

  std::vector<std::string> lines;
  base::SplitString(kHandshakeResponseMessage, '\n', &lines);
  for (size_t i = 0; i < lines.size() - 2; i++) {
    std::string line = lines[i] + "\r\n";
    SCOPED_TRACE("Line: " + line);
    websocket_->OnReceivedData(socket_,
                               line.c_str(),
                               line.size());
    MessageLoop::current()->RunAllPending();
    EXPECT_TRUE(delegate.received_data().empty());
    EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState());
  }
  websocket_->OnReceivedData(socket_.get(), "\r\n", 2);
  MessageLoop::current()->RunAllPending();
  EXPECT_TRUE(delegate.received_data().empty());
  EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState());
  websocket_->OnReceivedData(socket_.get(), "8jKS'y:G*Co,Wxa-", 16);
  EXPECT_EQ(kHandshakeResponseMessage, delegate.received_data());
  EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState());
  CloseWebSocketJob();
}

TEST_F(WebSocketJobTest, HandshakeWithCookie) {
  GURL url("ws://example.com/demo");
  GURL cookieUrl("http://example.com/demo");
  CookieOptions cookie_options;
  cookie_store_->SetCookieWithOptions(
      cookieUrl, "CR-test=1", cookie_options);
  cookie_options.set_include_httponly();
  cookie_store_->SetCookieWithOptions(
      cookieUrl, "CR-test-httponly=1", cookie_options);

  MockSocketStreamDelegate delegate;
  InitWebSocketJob(url, &delegate);

  static const char* kHandshakeRequestMessage =
      "GET /demo HTTP/1.1\r\n"
      "Host: example.com\r\n"
      "Connection: Upgrade\r\n"
      "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "Upgrade: WebSocket\r\n"
      "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
      "Origin: http://example.com\r\n"
      "Cookie: WK-test=1\r\n"
      "\r\n"
      "^n:ds[4U";

  static const char* kHandshakeRequestExpected =
      "GET /demo HTTP/1.1\r\n"
      "Host: example.com\r\n"
      "Connection: Upgrade\r\n"
      "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "Upgrade: WebSocket\r\n"
      "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
      "Origin: http://example.com\r\n"
      "Cookie: CR-test=1; CR-test-httponly=1\r\n"
      "\r\n"
      "^n:ds[4U";

  bool sent = websocket_->SendData(kHandshakeRequestMessage,
                                   strlen(kHandshakeRequestMessage));
  EXPECT_TRUE(sent);
  MessageLoop::current()->RunAllPending();
  EXPECT_EQ(kHandshakeRequestExpected, socket_->sent_data());
  EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState());
  websocket_->OnSentData(socket_, strlen(kHandshakeRequestExpected));
  EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent());

  const char kHandshakeResponseMessage[] =
      "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: Upgrade\r\n"
      "Sec-WebSocket-Origin: http://example.com\r\n"
      "Sec-WebSocket-Location: ws://example.com/demo\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "Set-Cookie: CR-set-test=1\r\n"
      "\r\n"
      "8jKS'y:G*Co,Wxa-";

  static const char* kHandshakeResponseExpected =
      "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: Upgrade\r\n"
      "Sec-WebSocket-Origin: http://example.com\r\n"
      "Sec-WebSocket-Location: ws://example.com/demo\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "\r\n"
      "8jKS'y:G*Co,Wxa-";

  websocket_->OnReceivedData(socket_.get(),
                             kHandshakeResponseMessage,
                             strlen(kHandshakeResponseMessage));
  MessageLoop::current()->RunAllPending();
  EXPECT_EQ(kHandshakeResponseExpected, delegate.received_data());
  EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState());

  EXPECT_EQ(3U, cookie_store_->entries().size());
  EXPECT_EQ(cookieUrl, cookie_store_->entries()[0].url);
  EXPECT_EQ("CR-test=1", cookie_store_->entries()[0].cookie_line);
  EXPECT_EQ(cookieUrl, cookie_store_->entries()[1].url);
  EXPECT_EQ("CR-test-httponly=1", cookie_store_->entries()[1].cookie_line);
  EXPECT_EQ(cookieUrl, cookie_store_->entries()[2].url);
  EXPECT_EQ("CR-set-test=1", cookie_store_->entries()[2].cookie_line);

  CloseWebSocketJob();
}

TEST_F(WebSocketJobTest, HandshakeWithCookieButNotAllowed) {
  GURL url("ws://example.com/demo");
  GURL cookieUrl("http://example.com/demo");
  CookieOptions cookie_options;
  cookie_store_->SetCookieWithOptions(
      cookieUrl, "CR-test=1", cookie_options);
  cookie_options.set_include_httponly();
  cookie_store_->SetCookieWithOptions(
      cookieUrl, "CR-test-httponly=1", cookie_options);
  cookie_policy_->set_allow_all_cookies(false);

  MockSocketStreamDelegate delegate;
  InitWebSocketJob(url, &delegate);

  static const char* kHandshakeRequestMessage =
      "GET /demo HTTP/1.1\r\n"
      "Host: example.com\r\n"
      "Connection: Upgrade\r\n"
      "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "Upgrade: WebSocket\r\n"
      "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
      "Origin: http://example.com\r\n"
      "Cookie: WK-test=1\r\n"
      "\r\n"
      "^n:ds[4U";

  static const char* kHandshakeRequestExpected =
      "GET /demo HTTP/1.1\r\n"
      "Host: example.com\r\n"
      "Connection: Upgrade\r\n"
      "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "Upgrade: WebSocket\r\n"
      "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
      "Origin: http://example.com\r\n"
      "\r\n"
      "^n:ds[4U";

  bool sent = websocket_->SendData(kHandshakeRequestMessage,
                                   strlen(kHandshakeRequestMessage));
  EXPECT_TRUE(sent);
  MessageLoop::current()->RunAllPending();
  EXPECT_EQ(kHandshakeRequestExpected, socket_->sent_data());
  EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState());
  websocket_->OnSentData(socket_, strlen(kHandshakeRequestExpected));
  EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent());

  const char kHandshakeResponseMessage[] =
      "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: Upgrade\r\n"
      "Sec-WebSocket-Origin: http://example.com\r\n"
      "Sec-WebSocket-Location: ws://example.com/demo\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "Set-Cookie: CR-set-test=1\r\n"
      "\r\n"
      "8jKS'y:G*Co,Wxa-";

  static const char* kHandshakeResponseExpected =
      "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: Upgrade\r\n"
      "Sec-WebSocket-Origin: http://example.com\r\n"
      "Sec-WebSocket-Location: ws://example.com/demo\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "\r\n"
      "8jKS'y:G*Co,Wxa-";

  websocket_->OnReceivedData(socket_.get(),
                             kHandshakeResponseMessage,
                             strlen(kHandshakeResponseMessage));
  MessageLoop::current()->RunAllPending();
  EXPECT_EQ(kHandshakeResponseExpected, delegate.received_data());
  EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState());

  EXPECT_EQ(2U, cookie_store_->entries().size());
  EXPECT_EQ(cookieUrl, cookie_store_->entries()[0].url);
  EXPECT_EQ("CR-test=1", cookie_store_->entries()[0].cookie_line);
  EXPECT_EQ(cookieUrl, cookie_store_->entries()[1].url);
  EXPECT_EQ("CR-test-httponly=1", cookie_store_->entries()[1].cookie_line);

  CloseWebSocketJob();
}

TEST_F(WebSocketJobTest, HSTSUpgrade) {
  GURL url("ws://upgrademe.com/");
  MockSocketStreamDelegate delegate;
  scoped_refptr<SocketStreamJob> job = SocketStreamJob::CreateSocketStreamJob(
      url, &delegate, *context_.get());
  EXPECT_TRUE(GetSocket(job.get())->is_secure());
  job->DetachDelegate();

  url = GURL("ws://donotupgrademe.com/");
  job = SocketStreamJob::CreateSocketStreamJob(
      url, &delegate, *context_.get());
  EXPECT_FALSE(GetSocket(job.get())->is_secure());
  job->DetachDelegate();
}

}  // namespace net