// Copyright (c) 2009 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/websockets/websocket_throttle.h"
#include <algorithm>
#include <string>
#include "base/message_loop.h"
#include "base/ref_counted.h"
#include "base/singleton.h"
#include "base/string_util.h"
#include "net/base/io_buffer.h"
#include "net/base/sys_addrinfo.h"
#include "net/socket_stream/socket_stream.h"
namespace net {
// Builds the key used to group connections by destination address. The key
// has the form "<ai_family>:<hex bytes>".
// For AF_INET and AF_INET6 only the IP address bytes are hex-encoded, so the
// port in the sockaddr does not affect the key. For any other family the
// whole sockaddr (ai_addrlen bytes) is encoded.
static std::string AddrinfoToHashkey(const struct addrinfo* addrinfo) {
  // Default: encode the entire socket address.
  const void* addr_bytes = addrinfo->ai_addr;
  size_t addr_size = addrinfo->ai_addrlen;

  if (addrinfo->ai_family == AF_INET) {
    const struct sockaddr_in* const v4 =
        reinterpret_cast<const sockaddr_in*>(addrinfo->ai_addr);
    addr_bytes = &v4->sin_addr;
    addr_size = 4;
  } else if (addrinfo->ai_family == AF_INET6) {
    const struct sockaddr_in6* const v6 =
        reinterpret_cast<const sockaddr_in6*>(addrinfo->ai_addr);
    addr_bytes = &v6->sin6_addr;
    addr_size = sizeof(v6->sin6_addr);
  }

  return StringPrintf("%d:%s",
                      addrinfo->ai_family,
                      HexEncode(addr_bytes, addr_size).c_str());
}
// Per-SocketStream state used by WebSocketThrottle.
// An instance is attached to its SocketStream as UserData keyed by
// WebSocketState::kKeyName, so the SocketStream owns it. It lives from the
// start of the connection until the end of the handshake response headers
// has been seen in the read data, or until the socket is closed.
// It does not validate the handshake response; it only looks for the end of
// the HTTP headers in the received bytes.
class WebSocketThrottle::WebSocketState : public SocketStream::UserData {
 public:
  // |addrs| is copied (AddressList is designed to be passed by value).
  // The keys derived from it then stay identical to the ones used when this
  // state was queued, even if the SocketStream's own list is later replaced.
  explicit WebSocketState(const AddressList& addrs)
      : address_list_(addrs),
        callback_(NULL),
        waiting_(false),
        handshake_finished_(false),
        buffer_(NULL) {
  }
  ~WebSocketState() {}

  // Returns OK if the connection may proceed immediately. If this state was
  // marked as waiting, it stores |callback| and returns ERR_IO_PENDING;
  // |callback| is later run with OK from Wakeup().
  int OnStartOpenConnection(CompletionCallback* callback) {
    DCHECK(!callback_);
    if (!waiting_)
      return OK;
    callback_ = callback;
    return ERR_IO_PENDING;
  }

  // Scans |len| bytes of received |data| for the end of the response
  // headers and sets HandshakeFinished() once it is found. Because the
  // headers may be split across several reads, data is accumulated in
  // |buffer_| until the end is seen. Always returns OK; |callback| is unused.
  int OnRead(const char* data, int len, CompletionCallback* callback) {
    DCHECK(!waiting_);
    DCHECK(!callback_);
    DCHECK(!handshake_finished_);
    // Initial size and minimum growth step for |buffer_|.
    static const int kBufferSize = 8192;

    if (!buffer_) {
      // Fast path: the complete header block arrived in a single read, so
      // nothing has to be buffered.
      int eoh = HttpUtil::LocateEndOfHeaders(data, len, 0);
      if (eoh > 0) {
        handshake_finished_ = true;
        return OK;
      }
      buffer_ = new GrowableIOBuffer();
      buffer_->SetCapacity(kBufferSize);
    }
    // Make room for all |len| bytes after the current offset. A single read
    // may be larger than kBufferSize, so growing by a fixed step is not
    // enough and would let the memcpy below overrun the buffer.
    if (buffer_->RemainingCapacity() < len)
      buffer_->SetCapacity(buffer_->offset() + len + kBufferSize);
    memcpy(buffer_->data(), data, len);
    buffer_->set_offset(buffer_->offset() + len);

    int eoh = HttpUtil::LocateEndOfHeaders(buffer_->StartOfBuffer(),
                                           buffer_->offset(), 0);
    handshake_finished_ = (eoh > 0);
    return OK;
  }

  const AddressList& address_list() const { return address_list_; }
  void SetWaiting() { waiting_ = true; }
  bool IsWaiting() const { return waiting_; }
  bool HandshakeFinished() const { return handshake_finished_; }

  // Lets a waiting connection proceed. It clears the waiting flag and posts
  // a task that runs the callback saved by OnStartOpenConnection() with OK.
  void Wakeup() {
    waiting_ = false;
    // Hand |callback_| to a ref-counted runner so the posted task does not
    // touch this object, which may be released before the task runs.
    scoped_refptr<CompletionCallbackRunner> runner =
        new CompletionCallbackRunner(callback_);
    callback_ = NULL;
    MessageLoopForIO::current()->PostTask(
        FROM_HERE,
        NewRunnableMethod(runner.get(),
                          &CompletionCallbackRunner::Run));
  }

  static const char* kKeyName;

 private:
  // Ref-counted holder that runs a CompletionCallback with OK.
  class CompletionCallbackRunner
      : public base::RefCountedThreadSafe<CompletionCallbackRunner> {
   public:
    explicit CompletionCallbackRunner(CompletionCallback* callback)
        : callback_(callback) {
      DCHECK(callback_);
    }
    void Run() {
      callback_->Run(OK);
    }
   private:
    friend class base::RefCountedThreadSafe<CompletionCallbackRunner>;

    virtual ~CompletionCallbackRunner() {}

    CompletionCallback* callback_;

    DISALLOW_COPY_AND_ASSIGN(CompletionCallbackRunner);
  };

  // Copy of the SocketStream's addresses taken at construction; used to
  // compute the per-address queue keys.
  const AddressList address_list_;

  // Callback saved by OnStartOpenConnection() while waiting; run by Wakeup().
  CompletionCallback* callback_;

  // True while waiting for another websocket connection to be established.
  // False once this websocket is allowed to connect and perform its handshake.
  bool waiting_;

  // True once the end of the handshake response headers has been seen.
  // The state is then removed from the queues and detached from its
  // SocketStream's UserData.
  bool handshake_finished_;

  // Accumulates read data until the end of the headers is found. It stays
  // NULL as long as the fast path (all headers in one read) applies.
  scoped_refptr<GrowableIOBuffer> buffer_;

  DISALLOW_COPY_AND_ASSIGN(WebSocketState);
};

const char* WebSocketThrottle::WebSocketState::kKeyName = "WebSocketState";
// Registers this object as the SocketStreamThrottle for both WebSocket URL
// schemes, "ws" first and then "wss".
WebSocketThrottle::WebSocketThrottle() {
  static const char* const kSchemes[] = { "ws", "wss" };
  for (size_t i = 0; i < sizeof(kSchemes) / sizeof(kSchemes[0]); ++i)
    SocketStreamThrottle::RegisterSocketStreamThrottle(kSchemes[i], this);
}
// By the time the throttle goes away, every socket must have been removed
// through OnRead() (handshake finished) or OnClose(). Otherwise the
// per-address ConnectingQueue objects allocated in PutInQueue() would leak.
WebSocketThrottle::~WebSocketThrottle() {
  DCHECK(queue_.empty());      // No tracked sockets remain...
  DCHECK(addr_map_.empty());   // ...and no per-address queues are left.
}
// Called when |socket| is about to connect. It creates the throttle state for
// |socket| and registers it in the queues; registration marks the state as
// waiting if another connection to one of the same addresses is queued.
// Returns OK to proceed now, or ERR_IO_PENDING, in which case |callback| is
// run once the earlier connection is done.
int WebSocketThrottle::OnStartOpenConnection(
    SocketStream* socket, CompletionCallback* callback) {
  WebSocketState* const new_state =
      new WebSocketState(socket->address_list());
  // PutInQueue attaches |new_state| to |socket| as UserData.
  PutInQueue(socket, new_state);
  return new_state->OnStartOpenConnection(callback);
}
// Feeds received bytes to |socket|'s throttle state. Once the end of the
// handshake response headers is seen, the socket leaves the queues and any
// connection that was waiting on it may be woken up.
int WebSocketThrottle::OnRead(SocketStream* socket,
                              const char* data, int len,
                              CompletionCallback* callback) {
  WebSocketState* state = static_cast<WebSocketState*>(
      socket->GetUserData(WebSocketState::kKeyName));
  // Without state the handshake has already completed; nothing to track.
  if (state == NULL)
    return OK;

  const int rv = state->OnRead(data, len, callback);
  if (!state->HandshakeFinished())
    return rv;

  // Handshake response seen. Drop |socket| from the queues (which also
  // detaches |state| from it) and let the next connection go.
  RemoveFromQueue(socket, state);
  WakeupSocketIfNecessary();
  return rv;
}
int WebSocketThrottle::OnWrite(SocketStream* socket,
const char* data, int len,
CompletionCallback* callback) {
// Do nothing.
return OK;
}
// Called when |socket| closes. If it is still tracked (the handshake never
// finished), it is removed from the queues and any connection that was
// waiting on it may be woken up.
void WebSocketThrottle::OnClose(SocketStream* socket) {
  WebSocketState* state = static_cast<WebSocketState*>(
      socket->GetUserData(WebSocketState::kKeyName));
  if (state != NULL) {
    RemoveFromQueue(socket, state);
    WakeupSocketIfNecessary();
  }
}
// Attaches |state| to |socket| and appends it to the global queue. It is also
// appended to the per-address queue of every address in |socket|'s address
// list. If any of those per-address queues already existed, an earlier
// connection to that address is still pending, so |state| is marked as
// waiting.
void WebSocketThrottle::PutInQueue(SocketStream* socket,
                                   WebSocketState* state) {
  socket->SetUserData(WebSocketState::kKeyName, state);
  queue_.push_back(state);

  for (const struct addrinfo* ai = socket->address_list().head();
       ai != NULL;
       ai = ai->ai_next) {
    const std::string key = AddrinfoToHashkey(ai);
    ConnectingAddressMap::iterator found = addr_map_.find(key);
    if (found != addr_map_.end()) {
      // Someone is already connecting to this address; line up behind them.
      found->second->push_back(state);
      state->SetWaiting();
      continue;
    }
    // First connection to this address: start a new per-address queue.
    ConnectingQueue* per_address = new ConnectingQueue();
    per_address->push_back(state);
    addr_map_[key] = per_address;
  }
}
// Removes |state| from the global queue and from the per-address queue of
// each of its addresses, deleting per-address queues that become empty.
// Finally it detaches |state| from |socket|.
//
// |state| is not necessarily at the front of its per-address queues. A
// socket closed (OnClose) while still waiting behind an earlier connection
// sits in the middle of them. The entry is therefore searched for rather than
// popped from the front. Popping would drop the entry of the connection that
// is actually in progress and leave |state| queued after it was released.
void WebSocketThrottle::RemoveFromQueue(SocketStream* socket,
                                        WebSocketState* state) {
  ConnectingQueue::iterator pos =
      std::find(queue_.begin(), queue_.end(), state);
  if (pos != queue_.end()) {
    queue_.erase(pos);

    // Use |state|'s address list, the same one WakeupSocketIfNecessary()
    // consults, so the keys match those it was queued under.
    const AddressList& address_list = state->address_list();
    for (const struct addrinfo* addrinfo = address_list.head();
         addrinfo != NULL;
         addrinfo = addrinfo->ai_next) {
      std::string addrkey = AddrinfoToHashkey(addrinfo);
      ConnectingAddressMap::iterator iter = addr_map_.find(addrkey);
      DCHECK(iter != addr_map_.end());
      if (iter == addr_map_.end())
        continue;
      ConnectingQueue* queue = iter->second;
      ConnectingQueue::iterator entry =
          std::find(queue->begin(), queue->end(), state);
      DCHECK(entry != queue->end());
      if (entry != queue->end())
        queue->erase(entry);
      if (queue->empty()) {
        delete queue;
        addr_map_.erase(iter);
      }
    }
  }
  // Done last: detaching may release |state|, so it must not be used after.
  socket->SetUserData(WebSocketState::kKeyName, NULL);
}
void WebSocketThrottle::WakeupSocketIfNecessary() {
for (ConnectingQueue::iterator iter = queue_.begin();
iter != queue_.end();
++iter) {
WebSocketState* state = *iter;
if (!state->IsWaiting())
continue;
bool should_wakeup = true;
const AddressList& address_list = state->address_list();
for (const struct addrinfo* addrinfo = address_list.head();
addrinfo != NULL;
addrinfo = addrinfo->ai_next) {
std::string addrkey = AddrinfoToHashkey(addrinfo);
ConnectingAddressMap::iterator iter = addr_map_.find(addrkey);
DCHECK(iter != addr_map_.end());
ConnectingQueue* queue = iter->second;
if (state != queue->front()) {
should_wakeup = false;
break;
}
}
if (should_wakeup)
state->Wakeup();
}
}
/* static */
// Forces creation of the WebSocketThrottle singleton. Its constructor
// registers it for the "ws" and "wss" schemes.
void WebSocketThrottle::Init() {
  static_cast<void>(Singleton<WebSocketThrottle>::get());
}
} // namespace net