// Plain text  |  496 lines  |  14.73 KB

// Copyright (c) 2009 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <algorithm>
#include <limits>

#include "net/websockets/websocket.h"

#include "base/message_loop.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_util.h"

namespace net {

// Default ports for ws:// and wss:// URLs. The Host header built by
// CreateClientHandshakeMessage() omits the port when it equals these.
static const int kWebSocketPort = 80;
static const int kSecureWebSocketPort = 443;

// Fixed lines that CheckHandshake() expects, in this order, at the start of
// the server's handshake response. The client handshake reuses the Upgrade
// and Connection lines verbatim.
static const char kServerHandshakeHeader[] =
    "HTTP/1.1 101 Web Socket Protocol Handshake\r\n";
static const char kUpgradeHeader[] = "Upgrade: WebSocket\r\n";
static const char kConnectionHeader[] = "Connection: Upgrade\r\n";

// String lengths of the lines above, excluding the terminating NUL.
static const size_t kServerHandshakeHeaderLength =
    sizeof(kServerHandshakeHeader) - 1;
static const size_t kUpgradeHeaderLength = sizeof(kUpgradeHeader) - 1;
static const size_t kConnectionHeaderLength = sizeof(kConnectionHeader) - 1;

// A request counts as secure exactly when its URL uses the "wss" scheme.
bool WebSocket::Request::is_secure() const {
  const bool uses_wss = url_.SchemeIs("wss");
  return uses_wss;
}

// Creates an unconnected WebSocket (state INITIALIZED, mode MODE_INCOMPLETE).
// |request| is stored in |request_| and |delegate| receives the event
// callbacks. The current MessageLoop is captured as |origin_loop_|:
// Connect(), Send(), Close() and the Do*() handlers DCHECK that they run on
// it, and SocketStream callbacks are re-posted to it.
// Read/write buffers are created lazily once the socket connects.
WebSocket::WebSocket(Request* request, WebSocketDelegate* delegate)
    : ready_state_(INITIALIZED),
      mode_(MODE_INCOMPLETE),
      request_(request),
      delegate_(delegate),
      origin_loop_(MessageLoop::current()),
      socket_stream_(NULL),
      max_pending_send_allowed_(0),
      current_read_buf_(NULL),
      read_consumed_len_(0),
      current_write_buf_(NULL) {
  DCHECK(request_.get());
  DCHECK(delegate_);
  DCHECK(origin_loop_);
}

// Destruction expects |socket_stream_| to have been released (done only in
// DoClose()) and |delegate_| to have been cleared (DoClose() or
// DetachDelegate()).
// NOTE(review): DCHECK(!delegate_) subsumes the first check, so the
// "ready_state_ == INITIALIZED" allowance there never takes effect. A
// never-connected WebSocket still needs DetachDelegate() before it is
// destroyed. Confirm which of the two checks reflects the intended contract.
WebSocket::~WebSocket() {
  DCHECK(ready_state_ == INITIALIZED || !delegate_);
  DCHECK(!socket_stream_);
  DCHECK(!delegate_);
}

void WebSocket::Connect() {
  DCHECK(ready_state_ == INITIALIZED);
  DCHECK(request_.get());
  DCHECK(delegate_);
  DCHECK(!socket_stream_);
  DCHECK(MessageLoop::current() == origin_loop_);

  socket_stream_ = new SocketStream(request_->url(), this);
  socket_stream_->set_context(request_->context());

  if (request_->host_resolver())
    socket_stream_->SetHostResolver(request_->host_resolver());
  if (request_->client_socket_factory())
    socket_stream_->SetClientSocketFactory(request_->client_socket_factory());

  AddRef();  // Release in DoClose().
  ready_state_ = CONNECTING;
  socket_stream_->Connect();
}

void WebSocket::Send(const std::string& msg) {
  DCHECK(ready_state_ == OPEN);
  DCHECK(MessageLoop::current() == origin_loop_);

  IOBufferWithSize* buf = new IOBufferWithSize(msg.size() + 2);
  char* p = buf->data();
  *p = '\0';
  memcpy(p + 1, msg.data(), msg.size());
  *(p + 1 + msg.size()) = '\xff';
  pending_write_bufs_.push_back(buf);
  SendPending();
}

// Requests that the connection be closed.
// - Never connected (INITIALIZED): just marks the object CLOSED.
// - CLOSED: nothing to do.
// - Otherwise: asks the SocketStream to close. The final teardown happens
//   later in DoClose().
void WebSocket::Close() {
  DCHECK(MessageLoop::current() == origin_loop_);

  switch (ready_state_) {
    case INITIALIZED:
      DCHECK(!socket_stream_);
      ready_state_ = CLOSED;
      break;
    case CLOSED:
      break;
    default:
      DCHECK(socket_stream_);
      socket_stream_->Close();
      break;
  }
}

// Stops all further delegate callbacks and closes the connection. Calling
// it again once the delegate is gone has no effect.
void WebSocket::DetachDelegate() {
  if (delegate_) {
    delegate_ = NULL;
    Close();
  }
}

// SocketStream callback: the transport is connected. Records the send limit,
// prepares an empty read buffer, and queues the client handshake. The
// handshake is sent from a SendPending() task posted to |origin_loop_|.
void WebSocket::OnConnected(SocketStream* socket_stream,
                            int max_pending_send_allowed) {
  DCHECK(socket_stream == socket_stream_);
  max_pending_send_allowed_ = max_pending_send_allowed;

  // The send limit is only a sizing hint for the read buffer's initial
  // capacity. AddToReadBuffer() grows it when needed.
  read_consumed_len_ = 0;
  current_read_buf_ = new GrowableIOBuffer();
  current_read_buf_->SetCapacity(max_pending_send_allowed_);

  DCHECK(!current_write_buf_);
  const std::string handshake = request_->CreateClientHandshakeMessage();
  IOBufferWithSize* handshake_buf = new IOBufferWithSize(handshake.size());
  memcpy(handshake_buf->data(), handshake.data(), handshake.size());
  pending_write_bufs_.push_back(handshake_buf);
  origin_loop_->PostTask(FROM_HERE,
                         NewRunnableMethod(this, &WebSocket::SendPending));
}

// SocketStream callback: |amount_sent| bytes of |current_write_buf_| were
// written. Once the front pending buffer is fully drained it is dropped.
// Either way, another SendPending() is scheduled to continue with whatever
// remains.
void WebSocket::OnSentData(SocketStream* socket_stream, int amount_sent) {
  DCHECK(socket_stream == socket_stream_);
  DCHECK(current_write_buf_);
  current_write_buf_->DidConsume(amount_sent);
  const int bytes_left = current_write_buf_->BytesRemaining();
  DCHECK_GE(bytes_left, 0);
  if (bytes_left == 0) {
    current_write_buf_ = NULL;
    pending_write_bufs_.pop_front();
  }
  origin_loop_->PostTask(
      FROM_HERE, NewRunnableMethod(this, &WebSocket::SendPending));
}

// SocketStream callback: buffers the incoming bytes right away. Parsing is
// deferred to DoReceivedData() on |origin_loop_|.
void WebSocket::OnReceivedData(SocketStream* socket_stream,
                               const char* data, int len) {
  DCHECK(socket_stream == socket_stream_);
  AddToReadBuffer(data, len);
  origin_loop_->PostTask(
      FROM_HERE, NewRunnableMethod(this, &WebSocket::DoReceivedData));
}

// SocketStream callback: the stream reported that it closed. Teardown runs
// later in DoClose() on |origin_loop_|.
void WebSocket::OnClose(SocketStream* socket_stream) {
  origin_loop_->PostTask(
      FROM_HERE, NewRunnableMethod(this, &WebSocket::DoClose));
}

void WebSocket::OnError(const SocketStream* socket_stream, int error) {
  origin_loop_->PostTask(FROM_HERE,
                         NewRunnableMethod(this, &WebSocket::DoError, error));
}

std::string WebSocket::Request::CreateClientHandshakeMessage() const {
  std::string msg;
  msg = "GET ";
  msg += url_.path();
  if (url_.has_query()) {
    msg += "?";
    msg += url_.query();
  }
  msg += " HTTP/1.1\r\n";
  msg += kUpgradeHeader;
  msg += kConnectionHeader;
  msg += "Host: ";
  msg += StringToLowerASCII(url_.host());
  if (url_.has_port()) {
    bool secure = is_secure();
    int port = url_.EffectiveIntPort();
    if ((!secure &&
         port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) ||
        (secure &&
         port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) {
      msg += ":";
      msg += IntToString(port);
    }
  }
  msg += "\r\n";
  msg += "Origin: ";
  // It's OK to lowercase the origin as the Origin header does not contain
  // the path or query portions, as per
  // http://tools.ietf.org/html/draft-abarth-origin-00.
  //
  // TODO(satorux): Should we trim the port portion here if it's 80 for
  // http:// or 443 for https:// ? Or can we assume it's done by the
  // client of the library?
  msg += StringToLowerASCII(origin_);
  msg += "\r\n";
  if (!protocol_.empty()) {
    msg += "WebSocket-Protocol: ";
    msg += protocol_;
    msg += "\r\n";
  }
  // TODO(ukai): Add cookie if necessary.
  msg += "\r\n";
  return msg;
}

int WebSocket::CheckHandshake() {
  DCHECK(current_read_buf_);
  DCHECK(ready_state_ == CONNECTING);
  mode_ = MODE_INCOMPLETE;
  const char *start = current_read_buf_->StartOfBuffer() + read_consumed_len_;
  const char *p = start;
  size_t len = current_read_buf_->offset() - read_consumed_len_;
  if (len < kServerHandshakeHeaderLength) {
    return -1;
  }
  if (!memcmp(p, kServerHandshakeHeader, kServerHandshakeHeaderLength)) {
    mode_ = MODE_NORMAL;
  } else {
    int eoh = HttpUtil::LocateEndOfHeaders(p, len);
    if (eoh < 0)
      return -1;
    scoped_refptr<HttpResponseHeaders> headers(
        new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(p, eoh)));
    if (headers->response_code() == 407) {
      mode_ = MODE_AUTHENTICATE;
      // TODO(ukai): Implement authentication handlers.
    }
    DLOG(INFO) << "non-normal websocket connection. "
               << "response_code=" << headers->response_code()
               << " mode=" << mode_;
    // Invalid response code.
    ready_state_ = CLOSED;
    return eoh;
  }
  const char* end = p + len + 1;
  p += kServerHandshakeHeaderLength;

  if (mode_ == MODE_NORMAL) {
    size_t header_size = end - p;
    if (header_size < kUpgradeHeaderLength)
      return -1;
    if (memcmp(p, kUpgradeHeader, kUpgradeHeaderLength)) {
      DLOG(INFO) << "Bad Upgrade Header "
                 << std::string(p, kUpgradeHeaderLength);
      ready_state_ = CLOSED;
      return p - start;
    }
    p += kUpgradeHeaderLength;

    header_size = end - p;
    if (header_size < kConnectionHeaderLength)
      return -1;
    if (memcmp(p, kConnectionHeader, kConnectionHeaderLength)) {
      DLOG(INFO) << "Bad Connection Header "
                 << std::string(p, kConnectionHeaderLength);
      ready_state_ = CLOSED;
      return p - start;
    }
    p += kConnectionHeaderLength;
  }
  int eoh = HttpUtil::LocateEndOfHeaders(start, len);
  if (eoh == -1)
    return eoh;
  scoped_refptr<HttpResponseHeaders> headers(
      new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(start, eoh)));
  if (!ProcessHeaders(*headers)) {
    DLOG(INFO) << "Process Headers failed: "
               << std::string(start, eoh);
    ready_state_ = CLOSED;
    return eoh;
  }
  switch (mode_) {
    case MODE_NORMAL:
      if (CheckResponseHeaders()) {
        ready_state_ = OPEN;
      } else {
        ready_state_ = CLOSED;
      }
      break;
    default:
      ready_state_ = CLOSED;
      break;
  }
  if (ready_state_ == CLOSED)
    DLOG(INFO) << "CheckHandshake mode=" << mode_
               << " " << std::string(start, eoh);
  return eoh;
}

// Looks up header |name| in |headers| and requires that it appears exactly
// once.
// Returns true and stores the value in |*value| when there is a single
// occurrence. Returns false, leaving |*value| untouched, when the header is
// missing or repeated.
static bool GetSingleHeader(const HttpResponseHeaders& headers,
                            const std::string& name,
                            std::string* value) {
  void* iter = NULL;
  std::string found;
  if (!headers.EnumerateHeader(&iter, name, &found))
    return false;

  // A second successful enumeration means the header is duplicated.
  std::string duplicate;
  const bool repeated = headers.EnumerateHeader(&iter, name, &duplicate);
  if (repeated)
    return false;

  *value = found;
  return true;
}

// Copies handshake response headers into members. Each header read must
// appear exactly once:
// - "websocket-origin" → |ws_origin_|
// - "websocket-location" → |ws_location_|
// - "websocket-protocol" → |ws_protocol_|, read only when the request named
//   a protocol
// Returns false if any required header is missing or duplicated.
bool WebSocket::ProcessHeaders(const HttpResponseHeaders& headers) {
  const bool has_origin_and_location =
      GetSingleHeader(headers, "websocket-origin", &ws_origin_) &&
      GetSingleHeader(headers, "websocket-location", &ws_location_);
  if (!has_origin_and_location)
    return false;

  return request_->protocol().empty() ||
         GetSingleHeader(headers, "websocket-protocol", &ws_protocol_);
}

// Verifies that the server echoed back what the client asked for:
// - the origin, compared after lower-casing the request's origin
// - the exact location
// - the exact protocol
bool WebSocket::CheckResponseHeaders() const {
  DCHECK(mode_ == MODE_NORMAL);
  return LowerCaseEqualsASCII(request_->origin(), ws_origin_.c_str()) &&
         request_->location() == ws_location_ &&
         request_->protocol() == ws_protocol_;
}

void WebSocket::SendPending() {
  DCHECK(MessageLoop::current() == origin_loop_);
  DCHECK(socket_stream_);
  if (!current_write_buf_) {
    if (pending_write_bufs_.empty())
      return;
    current_write_buf_ = new DrainableIOBuffer(
        pending_write_bufs_.front(), pending_write_bufs_.front()->size());
  }
  DCHECK_GT(current_write_buf_->BytesRemaining(), 0);
  bool sent = socket_stream_->SendData(
      current_write_buf_->data(),
      std::min(current_write_buf_->BytesRemaining(),
               max_pending_send_allowed_));
  DCHECK(sent);
}

// Runs on |origin_loop_| after OnReceivedData() has appended bytes to
// |current_read_buf_|.
// - CONNECTING: tries to parse the server handshake. On success it notifies
//   the delegate and processes any frame data that arrived in the same read.
// - OPEN: processes frame data.
// - CLOSED: ignores the data.
void WebSocket::DoReceivedData() {
  DCHECK(MessageLoop::current() == origin_loop_);
  switch (ready_state_) {
    case CONNECTING:
      {
        int eoh = CheckHandshake();
        if (eoh < 0) {
          // Not enough data. Retry when more data is available.
          return;
        }
        // Drop the consumed handshake bytes (or the rejected prefix).
        SkipReadBuffer(eoh);
      }
      // CheckHandshake() only ever sets OPEN on a fully valid handshake.
      if (ready_state_ != OPEN) {
        // Handshake failed.
        socket_stream_->Close();
        return;
      }
      if (delegate_)
        delegate_->OnOpen(this);
      if (current_read_buf_->offset() == read_consumed_len_) {
        // No remaining data after handshake message.
        break;
      }
      // Frame data followed the handshake response in the same buffer;
      // handle it now rather than waiting for the next read.
      // FALL THROUGH
    case OPEN:
      ProcessFrameData();
      break;

    case CLOSED:
      // Closed just after DoReceivedData is queued on |origin_loop_|.
      break;
    default:
      NOTREACHED();
      break;
  }
}

// Consumes complete frames from the unread part of |current_read_buf_|.
// - Frames whose type byte has the high bit set are length-prefixed. The
//   length uses base-128 digits, where a set high bit means more digits
//   follow. These frames are skipped.
// - Other frames are 0xFF-terminated. A frame of type 0x00 is delivered to
//   the delegate via OnMessage().
// A trailing, partially received frame stays in the buffer for the next
// call. A length that would overflow int closes the socket.
void WebSocket::ProcessFrameData() {
  DCHECK(current_read_buf_);
  const char* start_frame =
      current_read_buf_->StartOfBuffer() + read_consumed_len_;
  const char* next_frame = start_frame;
  const char* p = next_frame;
  const char* end =
      current_read_buf_->StartOfBuffer() + current_read_buf_->offset();
  while (p < end) {
    unsigned char frame_byte = static_cast<unsigned char>(*p++);
    if ((frame_byte & 0x80) == 0x80) {
      int length = 0;
      // True once a length byte without the continuation bit has been read.
      bool length_complete = false;
      while (p < end) {
        if (length > std::numeric_limits<int>::max() / 128) {
          // frame length overflow.
          socket_stream_->Close();
          return;
        }
        unsigned char c = static_cast<unsigned char>(*p++);
        length = length * 128 + (c & 0x7f);
        if ((c & 0x80) != 0x80) {
          length_complete = true;
          break;
        }
      }
      // Wait for more data if either the length bytes or the frame body have
      // not been fully received. A frame whose body ends exactly at |end| is
      // complete and is consumed here. Comparing the remaining byte count,
      // instead of computing |p + length|, avoids forming a pointer far past
      // the buffer.
      if (!length_complete || end - p < length)
        break;
      p += length;
      next_frame = p;
    } else {
      const char* msg_start = p;
      while (p < end && *p != '\xff')
        ++p;
      if (p < end && *p == '\xff') {
        if (frame_byte == 0x00 && delegate_)
          delegate_->OnMessage(this, std::string(msg_start, p - msg_start));
        ++p;
        next_frame = p;
      }
    }
  }
  SkipReadBuffer(next_frame - start_frame);
}

// Appends |len| bytes of |data| to |current_read_buf_| at its current
// offset, then advances the offset past them. If the spare capacity is not
// strictly larger than |len|, the buffer is first resized to exactly
// offset() + |len|.
void WebSocket::AddToReadBuffer(const char* data, int len) {
  DCHECK(current_read_buf_);
  const bool needs_growth = current_read_buf_->RemainingCapacity() <= len;
  if (needs_growth)
    current_read_buf_->SetCapacity(current_read_buf_->offset() + len);

  DCHECK(current_read_buf_->RemainingCapacity() >= len);
  // The copy targets data(), presumably the append position at offset().
  memcpy(current_read_buf_->data(), data, len);
  current_read_buf_->set_offset(current_read_buf_->offset() + len);
}

// Marks |len| more bytes of |current_read_buf_| as consumed.
// The unread bytes are compacted to the front of the buffer only when both
// conditions hold:
// - fewer bytes remain unread than have been consumed, so the move is cheap
// - the free tail is smaller than the consumed prefix, so space is worth
//   reclaiming
void WebSocket::SkipReadBuffer(int len) {
  if (!len)
    return;
  DCHECK_GT(len, 0);
  read_consumed_len_ += len;
  const int unread = current_read_buf_->offset() - read_consumed_len_;
  DCHECK_GE(unread, 0);

  if (unread >= read_consumed_len_ ||
      current_read_buf_->RemainingCapacity() >= read_consumed_len_)
    return;

  // Before: [ consumed | unread | free ]
  // After:  [ unread | free ]
  char* base = current_read_buf_->StartOfBuffer();
  memmove(base, base + read_consumed_len_, unread);
  read_consumed_len_ = 0;
  current_read_buf_->set_offset(unread);
}

// Runs on |origin_loop_| once the SocketStream reports it has closed.
// Every call clears |delegate_| and sets |ready_state_| to CLOSED. Only the
// first call (while |socket_stream_| is still set) does the rest:
// - drops the stream
// - notifies the previously attached delegate via OnClose()
// - releases the reference taken in Connect()
void WebSocket::DoClose() {
  DCHECK(MessageLoop::current() == origin_loop_);
  // Snapshot then clear the delegate so no further callbacks reach it.
  WebSocketDelegate* delegate = delegate_;
  delegate_ = NULL;
  ready_state_ = CLOSED;
  // Already torn down: do not notify or Release() a second time.
  if (!socket_stream_)
    return;
  socket_stream_ = NULL;
  if (delegate)
    delegate->OnClose(this);
  // Balances AddRef() in Connect(). Kept last because it may drop the last
  // reference to |this|.
  Release();
}

// Runs on |origin_loop_|. Forwards a SocketStream error to the delegate if
// one is still attached.
void WebSocket::DoError(int error) {
  DCHECK(MessageLoop::current() == origin_loop_);
  if (!delegate_)
    return;
  delegate_->OnError(this, error);
}

}  // namespace net