// Copyright (c) 2009 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/socket/socket_test_util.h"
#include <algorithm>
#include "base/basictypes.h"
#include "base/compiler_specific.h"
#include "base/message_loop.h"
#include "net/base/ssl_info.h"
#include "net/socket/socket.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace net {
MockClientSocket::MockClientSocket()
: ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)),
connected_(false) {
}
void MockClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) {
NOTREACHED();
}
void MockClientSocket::GetSSLCertRequestInfo(
net::SSLCertRequestInfo* cert_request_info) {
NOTREACHED();
}
SSLClientSocket::NextProtoStatus
MockClientSocket::GetNextProto(std::string* proto) {
proto->clear();
return SSLClientSocket::kNextProtoUnsupported;
}
void MockClientSocket::Disconnect() {
connected_ = false;
}
bool MockClientSocket::IsConnected() const {
return connected_;
}
bool MockClientSocket::IsConnectedAndIdle() const {
return connected_;
}
int MockClientSocket::GetPeerName(struct sockaddr* name, socklen_t* namelen) {
memset(reinterpret_cast<char *>(name), 0, *namelen);
return net::OK;
}
void MockClientSocket::RunCallbackAsync(net::CompletionCallback* callback,
int result) {
MessageLoop::current()->PostTask(FROM_HERE,
method_factory_.NewRunnableMethod(
&MockClientSocket::RunCallback, callback, result));
}
void MockClientSocket::RunCallback(net::CompletionCallback* callback,
int result) {
if (callback)
callback->Run(result);
}
MockTCPClientSocket::MockTCPClientSocket(const net::AddressList& addresses,
net::SocketDataProvider* data)
: addresses_(addresses),
data_(data),
read_offset_(0),
read_data_(false, net::ERR_UNEXPECTED),
need_read_data_(true),
peer_closed_connection_(false),
pending_buf_(NULL),
pending_buf_len_(0),
pending_callback_(NULL) {
DCHECK(data_);
data_->Reset();
}
int MockTCPClientSocket::Connect(net::CompletionCallback* callback,
LoadLog* load_log) {
if (connected_)
return net::OK;
connected_ = true;
if (data_->connect_data().async) {
RunCallbackAsync(callback, data_->connect_data().result);
return net::ERR_IO_PENDING;
}
return data_->connect_data().result;
}
bool MockTCPClientSocket::IsConnected() const {
return connected_ && !peer_closed_connection_;
}
int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len,
net::CompletionCallback* callback) {
if (!connected_)
return net::ERR_UNEXPECTED;
// If the buffer is already in use, a read is already in progress!
DCHECK(pending_buf_ == NULL);
// Store our async IO data.
pending_buf_ = buf;
pending_buf_len_ = buf_len;
pending_callback_ = callback;
if (need_read_data_) {
read_data_ = data_->GetNextRead();
if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
// This MockRead is just a marker to instruct us to set
// peer_closed_connection_. Skip it and get the next one.
read_data_ = data_->GetNextRead();
peer_closed_connection_ = true;
}
// ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
// to complete the async IO manually later (via OnReadComplete).
if (read_data_.result == ERR_IO_PENDING) {
DCHECK(callback); // We need to be using async IO in this case.
return ERR_IO_PENDING;
}
need_read_data_ = false;
}
return CompleteRead();
}
int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len,
net::CompletionCallback* callback) {
DCHECK(buf);
DCHECK_GT(buf_len, 0);
if (!connected_)
return net::ERR_UNEXPECTED;
std::string data(buf->data(), buf_len);
net::MockWriteResult write_result = data_->OnWrite(data);
if (write_result.async) {
RunCallbackAsync(callback, write_result.result);
return net::ERR_IO_PENDING;
}
return write_result.result;
}
void MockTCPClientSocket::OnReadComplete(const MockRead& data) {
// There must be a read pending.
DCHECK(pending_buf_);
// You can't complete a read with another ERR_IO_PENDING status code.
DCHECK_NE(ERR_IO_PENDING, data.result);
// Since we've been waiting for data, need_read_data_ should be true.
DCHECK(need_read_data_);
read_data_ = data;
need_read_data_ = false;
// The caller is simulating that this IO completes right now. Don't
// let CompleteRead() schedule a callback.
read_data_.async = false;
net::CompletionCallback* callback = pending_callback_;
int rv = CompleteRead();
RunCallback(callback, rv);
}
int MockTCPClientSocket::CompleteRead() {
DCHECK(pending_buf_);
DCHECK(pending_buf_len_ > 0);
// Save the pending async IO data and reset our |pending_| state.
net::IOBuffer* buf = pending_buf_;
int buf_len = pending_buf_len_;
net::CompletionCallback* callback = pending_callback_;
pending_buf_ = NULL;
pending_buf_len_ = 0;
pending_callback_ = NULL;
int result = read_data_.result;
DCHECK(result != ERR_IO_PENDING);
if (read_data_.data) {
if (read_data_.data_len - read_offset_ > 0) {
result = std::min(buf_len, read_data_.data_len - read_offset_);
memcpy(buf->data(), read_data_.data + read_offset_, result);
read_offset_ += result;
if (read_offset_ == read_data_.data_len) {
need_read_data_ = true;
read_offset_ = 0;
}
} else {
result = 0; // EOF
}
}
if (read_data_.async) {
DCHECK(callback);
RunCallbackAsync(callback, result);
return net::ERR_IO_PENDING;
}
return result;
}
class MockSSLClientSocket::ConnectCallback :
public net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback> {
public:
ConnectCallback(MockSSLClientSocket *ssl_client_socket,
net::CompletionCallback* user_callback,
int rv)
: ALLOW_THIS_IN_INITIALIZER_LIST(
net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback>(
this, &ConnectCallback::Wrapper)),
ssl_client_socket_(ssl_client_socket),
user_callback_(user_callback),
rv_(rv) {
}
private:
void Wrapper(int rv) {
if (rv_ == net::OK)
ssl_client_socket_->connected_ = true;
user_callback_->Run(rv_);
delete this;
}
MockSSLClientSocket* ssl_client_socket_;
net::CompletionCallback* user_callback_;
int rv_;
};
MockSSLClientSocket::MockSSLClientSocket(
net::ClientSocket* transport_socket,
const std::string& hostname,
const net::SSLConfig& ssl_config,
net::SSLSocketDataProvider* data)
: transport_(transport_socket),
data_(data) {
DCHECK(data_);
}
MockSSLClientSocket::~MockSSLClientSocket() {
Disconnect();
}
void MockSSLClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) {
ssl_info->Reset();
}
int MockSSLClientSocket::Connect(net::CompletionCallback* callback,
LoadLog* load_log) {
ConnectCallback* connect_callback = new ConnectCallback(
this, callback, data_->connect.result);
int rv = transport_->Connect(connect_callback, load_log);
if (rv == net::OK) {
delete connect_callback;
if (data_->connect.async) {
RunCallbackAsync(callback, data_->connect.result);
return net::ERR_IO_PENDING;
}
if (data_->connect.result == net::OK)
connected_ = true;
return data_->connect.result;
}
return rv;
}
void MockSSLClientSocket::Disconnect() {
MockClientSocket::Disconnect();
if (transport_ != NULL)
transport_->Disconnect();
}
int MockSSLClientSocket::Read(net::IOBuffer* buf, int buf_len,
net::CompletionCallback* callback) {
return transport_->Read(buf, buf_len, callback);
}
int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len,
net::CompletionCallback* callback) {
return transport_->Write(buf, buf_len, callback);
}
MockRead StaticSocketDataProvider::GetNextRead() {
MockRead rv = reads_[read_index_];
if (reads_[read_index_].result != OK ||
reads_[read_index_].data_len != 0)
read_index_++; // Don't advance past an EOF.
return rv;
}
MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) {
if (!writes_) {
// Not using mock writes; succeed synchronously.
return MockWriteResult(false, data.length());
}
// Check that what we are writing matches the expectation.
// Then give the mocked return value.
net::MockWrite* w = &writes_[write_index_++];
int result = w->result;
if (w->data) {
// Note - we can simulate a partial write here. If the expected data
// is a match, but shorter than the write actually written, that is legal.
// Example:
// Application writes "foobarbaz" (9 bytes)
// Expected write was "foo" (3 bytes)
// This is a success, and we return 3 to the application.
std::string expected_data(w->data, w->data_len);
EXPECT_GE(data.length(), expected_data.length());
std::string actual_data(data.substr(0, w->data_len));
EXPECT_EQ(expected_data, actual_data);
if (expected_data != actual_data)
return MockWriteResult(false, net::ERR_UNEXPECTED);
if (result == net::OK)
result = w->data_len;
}
return MockWriteResult(w->async, result);
}
void StaticSocketDataProvider::Reset() {
read_index_ = 0;
write_index_ = 0;
}
DynamicSocketDataProvider::DynamicSocketDataProvider()
: short_read_limit_(0),
allow_unconsumed_reads_(false) {
}
MockRead DynamicSocketDataProvider::GetNextRead() {
if (reads_.empty())
return MockRead(false, ERR_UNEXPECTED);
MockRead result = reads_.front();
if (short_read_limit_ == 0 || result.data_len <= short_read_limit_) {
reads_.pop_front();
} else {
result.data_len = short_read_limit_;
reads_.front().data += result.data_len;
reads_.front().data_len -= result.data_len;
}
return result;
}
void DynamicSocketDataProvider::Reset() {
reads_.clear();
}
void DynamicSocketDataProvider::SimulateRead(const char* data) {
if (!allow_unconsumed_reads_) {
EXPECT_TRUE(reads_.empty()) << "Unconsumed read: " << reads_.front().data;
}
reads_.push_back(MockRead(data));
}
void MockClientSocketFactory::AddSocketDataProvider(
SocketDataProvider* data) {
mock_data_.Add(data);
}
void MockClientSocketFactory::AddSSLSocketDataProvider(
SSLSocketDataProvider* data) {
mock_ssl_data_.Add(data);
}
void MockClientSocketFactory::ResetNextMockIndexes() {
mock_data_.ResetNextIndex();
mock_ssl_data_.ResetNextIndex();
}
MockTCPClientSocket* MockClientSocketFactory::GetMockTCPClientSocket(
int index) const {
return tcp_client_sockets_[index];
}
MockSSLClientSocket* MockClientSocketFactory::GetMockSSLClientSocket(
int index) const {
return ssl_client_sockets_[index];
}
ClientSocket* MockClientSocketFactory::CreateTCPClientSocket(
const AddressList& addresses) {
SocketDataProvider* data_provider = mock_data_.GetNext();
MockTCPClientSocket* socket =
new MockTCPClientSocket(addresses, data_provider);
data_provider->set_socket(socket);
tcp_client_sockets_.push_back(socket);
return socket;
}
SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket(
ClientSocket* transport_socket,
const std::string& hostname,
const SSLConfig& ssl_config) {
MockSSLClientSocket* socket =
new MockSSLClientSocket(transport_socket, hostname, ssl_config,
mock_ssl_data_.GetNext());
ssl_client_sockets_.push_back(socket);
return socket;
}
int TestSocketRequest::WaitForResult() {
return callback_.WaitForResult();
}
void TestSocketRequest::RunWithParams(const Tuple1<int>& params) {
callback_.RunWithParams(params);
(*completion_count_)++;
request_order_->push_back(this);
}
// static
const int ClientSocketPoolTest::kIndexOutOfBounds = -1;
// static
const int ClientSocketPoolTest::kRequestNotFound = -2;
void ClientSocketPoolTest::SetUp() {
completion_count_ = 0;
}
void ClientSocketPoolTest::TearDown() {
// The tests often call Reset() on handles at the end which may post
// DoReleaseSocket() tasks.
// Pending tasks created by client_socket_pool_base_unittest.cc are
// posted two milliseconds into the future and thus won't become
// scheduled until that time.
// We wait a few milliseconds to make sure that all such future tasks
// are ready to run, before calling RunAllPending(). This will work
// correctly even if Sleep() finishes late (and it should never finish
// early), as all we have to ensure is that actual wall-time has progressed
// past the scheduled starting time of the pending task.
PlatformThread::Sleep(10);
MessageLoop::current()->RunAllPending();
}
int ClientSocketPoolTest::GetOrderOfRequest(size_t index) {
index--;
if (index >= requests_.size())
return kIndexOutOfBounds;
for (size_t i = 0; i < request_order_.size(); i++)
if (requests_[index] == request_order_[i])
return i + 1;
return kRequestNotFound;
}
bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) {
ScopedVector<TestSocketRequest>::iterator i;
for (i = requests_.begin(); i != requests_.end(); ++i) {
if ((*i)->handle()->is_initialized()) {
if (keep_alive == NO_KEEP_ALIVE)
(*i)->handle()->socket()->Disconnect();
(*i)->handle()->Reset();
MessageLoop::current()->RunAllPending();
return true;
}
}
return false;
}
void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) {
bool released_one;
do {
released_one = ReleaseOneConnection(keep_alive);
} while (released_one);
}
} // namespace net