// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <algorithm>
#include <string>
#include <vector>

#include "base/memory/scoped_ptr.h"
#include "base/string_split.h"
#include "base/string_util.h"
#include "base/stringprintf.h"
#include "net/websockets/websocket_handshake.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "testing/platform_test.h"

namespace net {

class WebSocketHandshakeTest : public testing::Test {
 public:
  static void SetUpParameter(WebSocketHandshake* handshake,
                             uint32 number_1, uint32 number_2,
                             const std::string& key_1, const std::string& key_2,
                             const std::string& key_3) {
    WebSocketHandshake::Parameter* parameter =
        new WebSocketHandshake::Parameter;
    parameter->number_1_ = number_1;
    parameter->number_2_ = number_2;
    parameter->key_1_ = key_1;
    parameter->key_2_ = key_2;
    parameter->key_3_ = key_3;
    handshake->parameter_.reset(parameter);
  }

  static void ExpectHeaderEquals(const std::string& expected,
                          const std::string& actual) {
    std::vector<std::string> expected_lines;
    Tokenize(expected, "\r\n", &expected_lines);
    std::vector<std::string> actual_lines;
    Tokenize(actual, "\r\n", &actual_lines);
    // Request lines.
    EXPECT_EQ(expected_lines[0], actual_lines[0]);

    std::vector<std::string> expected_headers;
    for (size_t i = 1; i < expected_lines.size(); i++) {
      // Finish at first CRLF CRLF.  Note that /key_3/ might include CRLF.
      if (expected_lines[i] == "")
        break;
      expected_headers.push_back(expected_lines[i]);
    }
    sort(expected_headers.begin(), expected_headers.end());

    std::vector<std::string> actual_headers;
    for (size_t i = 1; i < actual_lines.size(); i++) {
      // Finish at first CRLF CRLF.  Note that /key_3/ might include CRLF.
      if (actual_lines[i] == "")
        break;
      actual_headers.push_back(actual_lines[i]);
    }
    sort(actual_headers.begin(), actual_headers.end());

    EXPECT_EQ(expected_headers.size(), actual_headers.size())
        << "expected:" << expected
        << "\nactual:" << actual;
    for (size_t i = 0; i < expected_headers.size(); i++) {
      EXPECT_EQ(expected_headers[i], actual_headers[i]);
    }
  }

  static void ExpectHandshakeMessageEquals(const std::string& expected,
                                           const std::string& actual) {
    // Headers.
    ExpectHeaderEquals(expected, actual);
    // Compare tailing \r\n\r\n<key3> (4 + 8 bytes).
    ASSERT_GT(expected.size(), 12U);
    const char* expected_key3 = expected.data() + expected.size() - 12;
    EXPECT_GT(actual.size(), 12U);
    if (actual.size() <= 12U)
      return;
    const char* actual_key3 = actual.data() + actual.size() - 12;
    EXPECT_TRUE(memcmp(expected_key3, actual_key3, 12) == 0)
        << "expected_key3:" << DumpKey(expected_key3, 12)
        << ", actual_key3:" << DumpKey(actual_key3, 12);
  }

  static std::string DumpKey(const char* buf, int len) {
    std::string s;
    for (int i = 0; i < len; i++) {
      if (isprint(buf[i]))
        s += base::StringPrintf("%c", buf[i]);
      else
        s += base::StringPrintf("\\x%02x", buf[i]);
    }
    return s;
  }

  static std::string GetResourceName(WebSocketHandshake* handshake) {
    return handshake->GetResourceName();
  }
  static std::string GetHostFieldValue(WebSocketHandshake* handshake) {
    return handshake->GetHostFieldValue();
  }
  static std::string GetOriginFieldValue(WebSocketHandshake* handshake) {
    return handshake->GetOriginFieldValue();
  }
};


TEST_F(WebSocketHandshakeTest, Connect) {
  const std::string kExpectedClientHandshakeMessage =
      "GET /demo HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: Upgrade\r\n"
      "Host: example.com\r\n"
      "Origin: http://example.com\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "Sec-WebSocket-Key1: 388P O503D&ul7 {K%gX( %7  15\r\n"
      "Sec-WebSocket-Key2: 1 N ?|k UT0or 3o  4 I97N 5-S3O 31\r\n"
      "\r\n"
      "\x47\x30\x22\x2D\x5A\x3F\x47\x58";

  scoped_ptr<WebSocketHandshake> handshake(
      new WebSocketHandshake(GURL("ws://example.com/demo"),
                             "http://example.com",
                             "ws://example.com/demo",
                             "sample"));
  SetUpParameter(handshake.get(), 777007543U, 114997259U,
                 "388P O503D&ul7 {K%gX( %7  15",
                 "1 N ?|k UT0or 3o  4 I97N 5-S3O 31",
                 std::string("\x47\x30\x22\x2D\x5A\x3F\x47\x58", 8));
  EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
  ExpectHandshakeMessageEquals(
      kExpectedClientHandshakeMessage,
      handshake->CreateClientHandshakeMessage());

  const char kResponse[] = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: Upgrade\r\n"
      "Sec-WebSocket-Origin: http://example.com\r\n"
      "Sec-WebSocket-Location: ws://example.com/demo\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "\r\n"
      "\x30\x73\x74\x33\x52\x6C\x26\x71\x2D\x32\x5A\x55\x5E\x77\x65\x75";
  // -1 drops the string literal's terminating '\0'.
  const int kResponseLength = sizeof(kResponse) - 1;

  EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
  // Too short to contain even the status line.
  EXPECT_EQ(-1, handshake->ReadServerHandshake(kResponse, 16));
  EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());

  // Feed the response one line at a time: status line, Upgrade, Connection,
  // Sec-WebSocket-Origin, Sec-WebSocket-Location, Sec-WebSocket-Protocol and
  // the blank line ending the headers.  The last piece is the 16-byte
  // response key, which is not fed here, so every prefix must be reported
  // as incomplete.
  std::vector<std::string> response_lines;
  base::SplitStringDontTrim(kResponse, '\n', &response_lines);
  ASSERT_EQ(8U, response_lines.size());

  std::string response;
  for (size_t i = 0; i + 1 < response_lines.size(); ++i) {
    // SplitStringDontTrim() drops the '\n' delimiter; put it back so that
    // |response| is an exact prefix of kResponse.
    response += response_lines[i];
    response += '\n';
    ASSERT_EQ(std::string(kResponse, response.size()), response);
    EXPECT_EQ(-1, handshake->ReadServerHandshake(
        response.data(), response.size())) << "after line " << i;
    EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode())
        << "after line " << i;
  }

  // The complete response (headers + response key) connects.
  EXPECT_EQ(kResponseLength,
            handshake->ReadServerHandshake(kResponse, kResponseLength));
  EXPECT_EQ(WebSocketHandshake::MODE_CONNECTED, handshake->mode());
}

TEST_F(WebSocketHandshakeTest, ServerSentData) {
  // Same opening handshake as in the Connect test, but the server's reply is
  // immediately followed by a data frame in the same buffer.
  scoped_ptr<WebSocketHandshake> handshake(
      new WebSocketHandshake(GURL("ws://example.com/demo"),
                             "http://example.com",
                             "ws://example.com/demo",
                             "sample"));
  SetUpParameter(handshake.get(), 777007543U, 114997259U,
                 "388P O503D&ul7 {K%gX( %7  15",
                 "1 N ?|k UT0or 3o  4 I97N 5-S3O 31",
                 std::string("\x47\x30\x22\x2D\x5A\x3F\x47\x58", 8));
  EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());

  const std::string kExpectedClientHandshakeMessage =
      "GET /demo HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: Upgrade\r\n"
      "Host: example.com\r\n"
      "Origin: http://example.com\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "Sec-WebSocket-Key1: 388P O503D&ul7 {K%gX( %7  15\r\n"
      "Sec-WebSocket-Key2: 1 N ?|k UT0or 3o  4 I97N 5-S3O 31\r\n"
      "\r\n"
      "\x47\x30\x22\x2D\x5A\x3F\x47\x58";
  ExpectHandshakeMessageEquals(kExpectedClientHandshakeMessage,
                               handshake->CreateClientHandshakeMessage());

  const char kResponse[] = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: Upgrade\r\n"
      "Sec-WebSocket-Origin: http://example.com\r\n"
      "Sec-WebSocket-Location: ws://example.com/demo\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "\r\n"
      "\x30\x73\x74\x33\x52\x6C\x26\x71\x2D\x32\x5A\x55\x5E\x77\x65\x75"
      "\0Hello\xff";
  // Everything received, excluding only the literal's own terminating '\0'.
  const std::string received(kResponse, sizeof(kResponse) - 1);
  // The 16-byte response key contains no 0x00 byte, so the first NUL in the
  // buffer is the start of the trailing frame, i.e. the end of the handshake.
  const int handshake_length = static_cast<int>(received.find('\0'));

  EXPECT_EQ(handshake_length,
            handshake->ReadServerHandshake(received.data(), received.size()));
  EXPECT_EQ(WebSocketHandshake::MODE_CONNECTED, handshake->mode());
}

TEST_F(WebSocketHandshakeTest, is_secure_false) {
  // A plain ws:// URL must not be reported as secure.
  WebSocketHandshake handshake(GURL("ws://example.com/demo"),
                               "http://example.com",
                               "ws://example.com/demo",
                               "sample");
  EXPECT_FALSE(handshake.is_secure());
}

TEST_F(WebSocketHandshakeTest, is_secure_true) {
  // The wss:// scheme marks the handshake as secure.
  WebSocketHandshake handshake(GURL("wss://example.com/demo"),
                               "http://example.com",
                               "wss://example.com/demo",
                               "sample");
  EXPECT_TRUE(handshake.is_secure());
}

TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_ResourceName) {
  WebSocketHandshake handshake(GURL("ws://example.com/Test?q=xxx&p=%20"),
                               "http://example.com",
                               "ws://example.com/demo",
                               "sample");
  // The resource name keeps the path's case and the query's escaping intact.
  const std::string resource_name = GetResourceName(&handshake);
  EXPECT_EQ("/Test?q=xxx&p=%20", resource_name);
}

TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_Host) {
  WebSocketHandshake handshake(GURL("ws://Example.Com/demo"),
                               "http://Example.Com",
                               "ws://Example.Com/demo",
                               "sample");
  // Both the Host and the Origin field values come out lowercased.
  const std::string host = GetHostFieldValue(&handshake);
  const std::string origin = GetOriginFieldValue(&handshake);
  EXPECT_EQ("example.com", host);
  EXPECT_EQ("http://example.com", origin);
}

TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_TrimPort80) {
  // Port 80 is the ws:// default, so it is omitted from the Host field.
  WebSocketHandshake handshake(GURL("ws://example.com:80/demo"),
                               "http://example.com",
                               "ws://example.com/demo",
                               "sample");
  EXPECT_EQ("example.com", GetHostFieldValue(&handshake));
}

TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_TrimPort443) {
  // Port 443 is the wss:// default, so it is omitted from the Host field.
  WebSocketHandshake handshake(GURL("wss://example.com:443/demo"),
                               "http://example.com",
                               "wss://example.com/demo",
                               "sample");
  EXPECT_EQ("example.com", GetHostFieldValue(&handshake));
}

TEST_F(WebSocketHandshakeTest,
       CreateClientHandshakeMessage_NonDefaultPortForWs) {
  // The location uses the same ws:// scheme as the URL under test (it was
  // previously a copy-pasted wss:// value), matching the WsBut443 test.
  scoped_ptr<WebSocketHandshake> handshake(
      new WebSocketHandshake(GURL("ws://example.com:8080/demo"),
                             "http://example.com",
                             "ws://example.com/demo",
                             "sample"));
  // :8080 should be preserved as it's not the default port for ws://.
  EXPECT_EQ("example.com:8080", GetHostFieldValue(handshake.get()));
}

TEST_F(WebSocketHandshakeTest,
       CreateClientHandshakeMessage_NonDefaultPortForWss) {
  // 4443 is not the wss:// default port, so it stays in the Host field.
  WebSocketHandshake handshake(GURL("wss://example.com:4443/demo"),
                               "http://example.com",
                               "wss://example.com/demo",
                               "sample");
  EXPECT_EQ("example.com:4443", GetHostFieldValue(&handshake));
}

TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_WsBut443) {
  // 443 is only the default for wss://; with ws:// it must be kept.
  WebSocketHandshake handshake(GURL("ws://example.com:443/demo"),
                               "http://example.com",
                               "ws://example.com/demo",
                               "sample");
  EXPECT_EQ("example.com:443", GetHostFieldValue(&handshake));
}

TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_WssBut80) {
  // 80 is only the default for ws://; with wss:// it must be kept.
  WebSocketHandshake handshake(GURL("wss://example.com:80/demo"),
                               "http://example.com",
                               "wss://example.com/demo",
                               "sample");
  EXPECT_EQ("example.com:80", GetHostFieldValue(&handshake));
}

}  // namespace net