普通文本  |  155行  |  4.36 KB

// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/websockets/websocket_handshake_draft75.h"

#include "base/memory/ref_counted.h"
#include "base/string_util.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_util.h"

namespace net {

// Fixed CRLF-terminated lines of the draft-75 handshake. The status line is
// matched against the start of the server response; the Upgrade and
// Connection lines are both sent in the client request and expected right
// after the status line in the server response. Each *Length constant is the
// literal's size without its trailing NUL, so it can be used with memcmp.
const char WebSocketHandshakeDraft75::kServerHandshakeHeader[] =
    "HTTP/1.1 101 Web Socket Protocol Handshake\r\n";
const size_t WebSocketHandshakeDraft75::kServerHandshakeHeaderLength =
    sizeof WebSocketHandshakeDraft75::kServerHandshakeHeader - 1;

const char WebSocketHandshakeDraft75::kUpgradeHeader[] =
    "Upgrade: WebSocket\r\n";
const size_t WebSocketHandshakeDraft75::kUpgradeHeaderLength =
    sizeof WebSocketHandshakeDraft75::kUpgradeHeader - 1;

const char WebSocketHandshakeDraft75::kConnectionHeader[] =
    "Connection: Upgrade\r\n";
const size_t WebSocketHandshakeDraft75::kConnectionHeaderLength =
    sizeof WebSocketHandshakeDraft75::kConnectionHeader - 1;

// Hands every handshake parameter straight to the WebSocketHandshake base
// class; this constructor performs no initialization of its own.
WebSocketHandshakeDraft75::WebSocketHandshakeDraft75(
    const GURL& url, const std::string& origin,
    const std::string& location, const std::string& protocol)
    : WebSocketHandshake(url, origin, location, protocol) {}

// No explicit cleanup is performed here.
WebSocketHandshakeDraft75::~WebSocketHandshakeDraft75() {}

// Builds the client's opening handshake request, in this order:
//   - the GET request line for the resource name;
//   - the fixed Upgrade and Connection lines;
//   - the Host and Origin lines;
//   - a WebSocket-Protocol line, only when |protocol_| is non-empty;
//   - the empty line that ends the header block.
std::string WebSocketHandshakeDraft75::CreateClientHandshakeMessage() {
  std::string request("GET ");
  request.append(GetResourceName()).append(" HTTP/1.1\r\n");
  request.append(kUpgradeHeader).append(kConnectionHeader);

  request.append("Host: ").append(GetHostFieldValue()).append("\r\n");

  request.append("Origin: ").append(GetOriginFieldValue()).append("\r\n");

  if (!protocol_.empty())
    request.append("WebSocket-Protocol: ").append(protocol_).append("\r\n");

  // TODO(ukai): Add cookie if necessary.
  request.append("\r\n");
  return request;
}

// Parses the server's opening handshake from |data| (|len| bytes received so
// far).
//
// Returns -1 when more bytes are needed before a decision can be made; |mode_|
// is then MODE_INCOMPLETE or MODE_NORMAL. Otherwise returns the number of
// bytes the caller should treat as consumed, with |mode_| set to
// MODE_CONNECTED (handshake accepted) or MODE_FAILED.
//
// The response must begin with kServerHandshakeHeader, kUpgradeHeader and
// kConnectionHeader. They must appear exactly (case-sensitive memcmp) and in
// that order. The complete header block is then parsed by ProcessHeaders()
// and validated by CheckResponseHeaders().
int WebSocketHandshakeDraft75::ReadServerHandshake(
    const char* data, size_t len) {
  mode_ = MODE_INCOMPLETE;
  if (len < kServerHandshakeHeaderLength)
    return -1;

  if (memcmp(data, kServerHandshakeHeader, kServerHandshakeHeaderLength)) {
    // Unexpected status line. Wait for the whole header block so the caller
    // can skip it, then report failure explicitly so a non-negative return
    // is never paired with MODE_INCOMPLETE.
    int eoh = HttpUtil::LocateEndOfHeaders(data, len);
    if (eoh < 0)
      return -1;
    mode_ = MODE_FAILED;
    DVLOG(1) << "Bad Server Handshake Header "
             << std::string(data, kServerHandshakeHeaderLength);
    return eoh;
  }
  mode_ = MODE_NORMAL;

  const char* p = data + kServerHandshakeHeaderLength;
  const char* const end = data + len;

  // The Upgrade line must immediately follow the status line.
  if (static_cast<size_t>(end - p) < kUpgradeHeaderLength)
    return -1;
  if (memcmp(p, kUpgradeHeader, kUpgradeHeaderLength)) {
    mode_ = MODE_FAILED;
    DVLOG(1) << "Bad Upgrade Header " << std::string(p, kUpgradeHeaderLength);
    return p - data;
  }
  p += kUpgradeHeaderLength;

  // The Connection line must immediately follow the Upgrade line.
  if (static_cast<size_t>(end - p) < kConnectionHeaderLength)
    return -1;
  if (memcmp(p, kConnectionHeader, kConnectionHeaderLength)) {
    mode_ = MODE_FAILED;
    DVLOG(1) << "Bad Connection Header "
             << std::string(p, kConnectionHeaderLength);
    return p - data;
  }

  int eoh = HttpUtil::LocateEndOfHeaders(data, len);
  if (eoh < 0)
    return -1;

  scoped_refptr<HttpResponseHeaders> headers(
      new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh)));
  if (!ProcessHeaders(*headers)) {
    DVLOG(1) << "Process Headers failed: " << std::string(data, eoh);
    mode_ = MODE_FAILED;
  } else {
    // |mode_| is still MODE_NORMAL here, as CheckResponseHeaders() DCHECKs.
    mode_ = CheckResponseHeaders() ? MODE_CONNECTED : MODE_FAILED;
  }
  return eoh;
}

// Extracts the handshake-related fields from the parsed response |headers|
// into ws_origin_, ws_location_ and (when needed) ws_protocol_.
//
// Returns false as soon as a required header cannot be read via
// GetSingleHeader(). Lookups happen in order: origin, then location, then
// protocol. The protocol header is only required when the client asked for a
// protocol; otherwise its presence or absence is ignored.
bool WebSocketHandshakeDraft75::ProcessHeaders(
    const HttpResponseHeaders& headers) {
  const bool have_origin_and_location =
      GetSingleHeader(headers, "websocket-origin", &ws_origin_) &&
      GetSingleHeader(headers, "websocket-location", &ws_location_);
  if (!have_origin_and_location)
    return false;

  return protocol_.empty() ||
         GetSingleHeader(headers, "websocket-protocol", &ws_protocol_);
}

// Compares the values captured by ProcessHeaders() against what the client
// requested. The checks are:
//   - origin: the lower-cased client origin_ must equal ws_origin_;
//   - location: location_ must equal ws_location_ exactly;
//   - protocol: compared exactly, but only when the client requested one.
// Expects to be called while |mode_| is MODE_NORMAL.
bool WebSocketHandshakeDraft75::CheckResponseHeaders() const {
  DCHECK(mode_ == MODE_NORMAL);
  const bool origin_matches =
      LowerCaseEqualsASCII(origin_, ws_origin_.c_str());
  const bool location_matches = (ws_location_ == location_);
  const bool protocol_matches =
      protocol_.empty() || (ws_protocol_ == protocol_);
  return origin_matches && location_matches && protocol_matches;
}

}  // namespace net