// Copyright (c) 2011 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/websockets/websocket_handshake_draft75.h" #include "base/memory/ref_counted.h" #include "base/string_util.h" #include "net/http/http_response_headers.h" #include "net/http/http_util.h" namespace net { const char WebSocketHandshakeDraft75::kServerHandshakeHeader[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"; const size_t WebSocketHandshakeDraft75::kServerHandshakeHeaderLength = sizeof(kServerHandshakeHeader) - 1; const char WebSocketHandshakeDraft75::kUpgradeHeader[] = "Upgrade: WebSocket\r\n"; const size_t WebSocketHandshakeDraft75::kUpgradeHeaderLength = sizeof(kUpgradeHeader) - 1; const char WebSocketHandshakeDraft75::kConnectionHeader[] = "Connection: Upgrade\r\n"; const size_t WebSocketHandshakeDraft75::kConnectionHeaderLength = sizeof(kConnectionHeader) - 1; WebSocketHandshakeDraft75::WebSocketHandshakeDraft75( const GURL& url, const std::string& origin, const std::string& location, const std::string& protocol) : WebSocketHandshake(url, origin, location, protocol) { } WebSocketHandshakeDraft75::~WebSocketHandshakeDraft75() { } std::string WebSocketHandshakeDraft75::CreateClientHandshakeMessage() { std::string msg; msg = "GET "; msg += GetResourceName(); msg += " HTTP/1.1\r\n"; msg += kUpgradeHeader; msg += kConnectionHeader; msg += "Host: "; msg += GetHostFieldValue(); msg += "\r\n"; msg += "Origin: "; msg += GetOriginFieldValue(); msg += "\r\n"; if (!protocol_.empty()) { msg += "WebSocket-Protocol: "; msg += protocol_; msg += "\r\n"; } // TODO(ukai): Add cookie if necessary. msg += "\r\n"; return msg; } int WebSocketHandshakeDraft75::ReadServerHandshake( const char* data, size_t len) { mode_ = MODE_INCOMPLETE; if (len < kServerHandshakeHeaderLength) { return -1; } if (!memcmp(data, kServerHandshakeHeader, kServerHandshakeHeaderLength)) { mode_ = MODE_NORMAL; } else { int eoh = HttpUtil::LocateEndOfHeaders(data, len); if (eoh < 0) return -1; return eoh; } const char* p = data + kServerHandshakeHeaderLength; const char* end = data + len; if (mode_ == MODE_NORMAL) { size_t header_size = end - p; if (header_size < kUpgradeHeaderLength) return -1; if (memcmp(p, kUpgradeHeader, kUpgradeHeaderLength)) { mode_ = MODE_FAILED; DVLOG(1) << "Bad Upgrade Header " << std::string(p, kUpgradeHeaderLength); return p - data; } p += kUpgradeHeaderLength; header_size = end - p; if (header_size < kConnectionHeaderLength) return -1; if (memcmp(p, kConnectionHeader, kConnectionHeaderLength)) { mode_ = MODE_FAILED; DVLOG(1) << "Bad Connection Header " << std::string(p, kConnectionHeaderLength); return p - data; } p += kConnectionHeaderLength; } int eoh = HttpUtil::LocateEndOfHeaders(data, len); if (eoh == -1) return eoh; scoped_refptr<HttpResponseHeaders> headers( new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh))); if (!ProcessHeaders(*headers)) { DVLOG(1) << "Process Headers failed: " << std::string(data, eoh); mode_ = MODE_FAILED; } switch (mode_) { case MODE_NORMAL: if (CheckResponseHeaders()) { mode_ = MODE_CONNECTED; } else { mode_ = MODE_FAILED; } break; default: mode_ = MODE_FAILED; break; } return eoh; } bool WebSocketHandshakeDraft75::ProcessHeaders( const HttpResponseHeaders& headers) { if (!GetSingleHeader(headers, "websocket-origin", &ws_origin_)) return false; if (!GetSingleHeader(headers, "websocket-location", &ws_location_)) return false; // If |protocol_| is not specified by client, we don't care if there's // protocol field or not as specified in the spec. if (!protocol_.empty() && !GetSingleHeader(headers, "websocket-protocol", &ws_protocol_)) return false; return true; } bool WebSocketHandshakeDraft75::CheckResponseHeaders() const { DCHECK(mode_ == MODE_NORMAL); if (!LowerCaseEqualsASCII(origin_, ws_origin_.c_str())) return false; if (location_ != ws_location_) return false; if (!protocol_.empty() && protocol_ != ws_protocol_) return false; return true; } } // namespace net