// Copyright (c) 2012 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "chrome/test/chromedriver/net/websocket.h" #include <string.h> #include "base/base64.h" #include "base/bind.h" #include "base/bind_helpers.h" #include "base/memory/scoped_vector.h" #include "base/rand_util.h" #include "base/sha1.h" #include "base/strings/string_number_conversions.h" #include "base/strings/stringprintf.h" #include "net/base/address_list.h" #include "net/base/io_buffer.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/base/net_util.h" #include "net/base/sys_addrinfo.h" #include "net/http/http_response_headers.h" #include "net/http/http_util.h" #include "net/websockets/websocket_frame.h" #if defined(OS_WIN) #include <Winsock2.h> #endif namespace { bool ResolveHost(const std::string& host, net::IPAddressNumber* address) { struct addrinfo hints; memset(&hints, 0, sizeof(hints)); hints.ai_family = AF_UNSPEC; hints.ai_socktype = SOCK_STREAM; struct addrinfo* result; if (getaddrinfo(host.c_str(), NULL, &hints, &result)) return false; for (struct addrinfo* addr = result; addr; addr = addr->ai_next) { if (addr->ai_family == AF_INET || addr->ai_family == AF_INET6) { net::IPEndPoint end_point; if (!end_point.FromSockAddr(addr->ai_addr, addr->ai_addrlen)) { freeaddrinfo(result); return false; } *address = end_point.address(); } } freeaddrinfo(result); return true; } } // namespace WebSocket::WebSocket(const GURL& url, WebSocketListener* listener) : url_(url), listener_(listener), state_(INITIALIZED), write_buffer_(new net::DrainableIOBuffer(new net::IOBuffer(0), 0)), read_buffer_(new net::IOBufferWithSize(4096)) {} WebSocket::~WebSocket() { CHECK(thread_checker_.CalledOnValidThread()); } void WebSocket::Connect(const net::CompletionCallback& callback) { CHECK(thread_checker_.CalledOnValidThread()); CHECK_EQ(INITIALIZED, state_); net::IPAddressNumber address; if (!net::ParseIPLiteralToNumber(url_.HostNoBrackets(), &address)) { if (!ResolveHost(url_.HostNoBrackets(), &address)) { callback.Run(net::ERR_ADDRESS_UNREACHABLE); return; } } int port = 80; base::StringToInt(url_.port(), &port); net::AddressList addresses(net::IPEndPoint(address, port)); net::NetLog::Source source; socket_.reset(new net::TCPClientSocket(addresses, NULL, source)); state_ = CONNECTING; connect_callback_ = callback; int code = socket_->Connect(base::Bind( &WebSocket::OnSocketConnect, base::Unretained(this))); if (code != net::ERR_IO_PENDING) OnSocketConnect(code); } bool WebSocket::Send(const std::string& message) { CHECK(thread_checker_.CalledOnValidThread()); if (state_ != OPEN) return false; net::WebSocketFrameHeader header(net::WebSocketFrameHeader::kOpCodeText); header.final = true; header.masked = true; header.payload_length = message.length(); int header_size = net::GetWebSocketFrameHeaderSize(header); net::WebSocketMaskingKey masking_key = net::GenerateWebSocketMaskingKey(); std::string header_str; header_str.resize(header_size); CHECK_EQ(header_size, net::WriteWebSocketFrameHeader( header, &masking_key, &header_str[0], header_str.length())); std::string masked_message = message; net::MaskWebSocketFramePayload( masking_key, 0, &masked_message[0], masked_message.length()); Write(header_str + masked_message); return true; } void WebSocket::OnSocketConnect(int code) { if (code != net::OK) { Close(code); return; } base::Base64Encode(base::RandBytesAsString(16), &sec_key_); std::string handshake = base::StringPrintf( "GET %s HTTP/1.1\r\n" "Host: %s\r\n" "Upgrade: websocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Key: %s\r\n" "Sec-WebSocket-Version: 13\r\n" "Pragma: no-cache\r\n" "Cache-Control: no-cache\r\n" "\r\n", url_.path().c_str(), url_.host().c_str(), sec_key_.c_str()); Write(handshake); Read(); } void WebSocket::Write(const std::string& data) { pending_write_ += data; if (!write_buffer_->BytesRemaining()) ContinueWritingIfNecessary(); } void WebSocket::OnWrite(int code) { if (!socket_->IsConnected()) { // Supposedly if |StreamSocket| is closed, the error code may be undefined. Close(net::ERR_FAILED); return; } if (code < 0) { Close(code); return; } write_buffer_->DidConsume(code); ContinueWritingIfNecessary(); } void WebSocket::ContinueWritingIfNecessary() { if (!write_buffer_->BytesRemaining()) { if (pending_write_.empty()) return; write_buffer_ = new net::DrainableIOBuffer( new net::StringIOBuffer(pending_write_), pending_write_.length()); pending_write_.clear(); } int code = socket_->Write(write_buffer_.get(), write_buffer_->BytesRemaining(), base::Bind(&WebSocket::OnWrite, base::Unretained(this))); if (code != net::ERR_IO_PENDING) OnWrite(code); } void WebSocket::Read() { int code = socket_->Read(read_buffer_.get(), read_buffer_->size(), base::Bind(&WebSocket::OnRead, base::Unretained(this))); if (code != net::ERR_IO_PENDING) OnRead(code); } void WebSocket::OnRead(int code) { if (code <= 0) { Close(code ? code : net::ERR_FAILED); return; } if (state_ == CONNECTING) OnReadDuringHandshake(read_buffer_->data(), code); else if (state_ == OPEN) OnReadDuringOpen(read_buffer_->data(), code); if (state_ != CLOSED) Read(); } void WebSocket::OnReadDuringHandshake(const char* data, int len) { handshake_response_ += std::string(data, len); int headers_end = net::HttpUtil::LocateEndOfHeaders( handshake_response_.data(), handshake_response_.size(), 0); if (headers_end == -1) return; const char kMagicKey[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; std::string websocket_accept; base::Base64Encode(base::SHA1HashString(sec_key_ + kMagicKey), &websocket_accept); scoped_refptr<net::HttpResponseHeaders> headers( new net::HttpResponseHeaders( net::HttpUtil::AssembleRawHeaders( handshake_response_.data(), headers_end))); if (headers->response_code() != 101 || !headers->HasHeaderValue("Upgrade", "WebSocket") || !headers->HasHeaderValue("Connection", "Upgrade") || !headers->HasHeaderValue("Sec-WebSocket-Accept", websocket_accept)) { Close(net::ERR_FAILED); return; } std::string leftover_message = handshake_response_.substr(headers_end); handshake_response_.clear(); sec_key_.clear(); state_ = OPEN; InvokeConnectCallback(net::OK); if (!leftover_message.empty()) OnReadDuringOpen(leftover_message.c_str(), leftover_message.length()); } void WebSocket::OnReadDuringOpen(const char* data, int len) { ScopedVector<net::WebSocketFrameChunk> frame_chunks; CHECK(parser_.Decode(data, len, &frame_chunks)); for (size_t i = 0; i < frame_chunks.size(); ++i) { scoped_refptr<net::IOBufferWithSize> buffer = frame_chunks[i]->data; if (buffer.get()) next_message_ += std::string(buffer->data(), buffer->size()); if (frame_chunks[i]->final_chunk) { listener_->OnMessageReceived(next_message_); next_message_.clear(); } } } void WebSocket::InvokeConnectCallback(int code) { net::CompletionCallback temp = connect_callback_; connect_callback_.Reset(); CHECK(!temp.is_null()); temp.Run(code); } void WebSocket::Close(int code) { socket_->Disconnect(); if (!connect_callback_.is_null()) InvokeConnectCallback(code); if (state_ == OPEN) listener_->OnClose(); state_ = CLOSED; }