// Copyright (c) 2009 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/tools/flip_server/sm_connection.h"
#include <errno.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <list>
#include <string>
#include "net/tools/flip_server/constants.h"
#include "net/tools/flip_server/flip_config.h"
#include "net/tools/flip_server/http_interface.h"
#include "net/tools/flip_server/spdy_interface.h"
#include "net/tools/flip_server/spdy_ssl.h"
#include "net/tools/flip_server/streamer_interface.h"
namespace net {
// static
bool SMConnection::force_spdy_ = false;
SMConnection::SMConnection(EpollServer* epoll_server,
SSLState* ssl_state,
MemoryCache* memory_cache,
FlipAcceptor* acceptor,
std::string log_prefix)
: last_read_time_(0),
fd_(-1),
events_(0),
registered_in_epoll_server_(false),
initialized_(false),
protocol_detected_(false),
connection_complete_(false),
connection_pool_(NULL),
epoll_server_(epoll_server),
ssl_state_(ssl_state),
memory_cache_(memory_cache),
acceptor_(acceptor),
read_buffer_(kSpdySegmentSize * 40),
sm_spdy_interface_(NULL),
sm_http_interface_(NULL),
sm_streamer_interface_(NULL),
sm_interface_(NULL),
log_prefix_(log_prefix),
max_bytes_sent_per_dowrite_(4096),
ssl_(NULL) {
}
SMConnection::~SMConnection() {
if (initialized())
Reset();
}
EpollServer* SMConnection::epoll_server() {
return epoll_server_;
}
void SMConnection::ReadyToSend() {
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "Setting ready to send: EPOLLIN | EPOLLOUT";
epoll_server_->SetFDReady(fd_, EPOLLIN | EPOLLOUT);
}
void SMConnection::EnqueueDataFrame(DataFrame* df) {
output_list_.push_back(df);
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "EnqueueDataFrame: "
<< "size = " << df->size << ": Setting FD ready.";
ReadyToSend();
}
void SMConnection::InitSMConnection(SMConnectionPoolInterface* connection_pool,
SMInterface* sm_interface,
EpollServer* epoll_server,
int fd,
std::string server_ip,
std::string server_port,
std::string remote_ip,
bool use_ssl) {
if (initialized_) {
LOG(FATAL) << "Attempted to initialize already initialized server";
return;
}
client_ip_ = remote_ip;
if (fd == -1) {
// If fd == -1, then we are initializing a new connection that will
// connect to the backend.
//
// ret: -1 == error
// 0 == connection in progress
// 1 == connection complete
// TODO(kelindsay): is_numeric_host_address value needs to be detected
server_ip_ = server_ip;
server_port_ = server_port;
int ret = CreateConnectedSocket(&fd_,
server_ip,
server_port,
true,
acceptor_->disable_nagle_);
if (ret < 0) {
LOG(ERROR) << "-1 Could not create connected socket";
return;
} else if (ret == 1) {
DCHECK_NE(-1, fd_);
connection_complete_ = true;
VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "Connection complete to: " << server_ip_ << ":"
<< server_port_ << " ";
}
VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "Connecting to server: " << server_ip_ << ":"
<< server_port_ << " ";
} else {
// If fd != -1 then we are initializing a connection that has just been
// accepted from the listen socket.
connection_complete_ = true;
if (epoll_server_ && registered_in_epoll_server_ && fd_ != -1) {
epoll_server_->UnregisterFD(fd_);
}
if (fd_ != -1) {
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "Closing pre-existing fd";
close(fd_);
fd_ = -1;
}
fd_ = fd;
}
registered_in_epoll_server_ = false;
// Set the last read time here as the idle checker will start from
// now.
last_read_time_ = time(NULL);
initialized_ = true;
connection_pool_ = connection_pool;
epoll_server_ = epoll_server;
if (sm_interface) {
sm_interface_ = sm_interface;
protocol_detected_ = true;
}
read_buffer_.Clear();
epoll_server_->RegisterFD(fd_, this, EPOLLIN | EPOLLOUT | EPOLLET);
if (use_ssl) {
ssl_ = CreateSSLContext(ssl_state_->ssl_ctx);
SSL_set_fd(ssl_, fd_);
PrintSslError();
}
}
void SMConnection::CorkSocket() {
int state = 1;
int rv = setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &state, sizeof(state));
if (rv < 0)
VLOG(1) << "setsockopt(CORK): " << errno;
}
void SMConnection::UncorkSocket() {
int state = 0;
int rv = setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &state, sizeof(state));
if (rv < 0)
VLOG(1) << "setsockopt(CORK): " << errno;
}
int SMConnection::Send(const char* data, int len, int flags) {
int rv = 0;
CorkSocket();
if (ssl_) {
ssize_t bytes_written = 0;
// Write smallish chunks to SSL so that we don't have large
// multi-packet TLS records to receive before being able to handle
// the data. We don't have to be too careful here, because our data
// frames are already getting chunked appropriately, and those are
// the most likely "big" frames.
while (len > 0) {
const int kMaxTLSRecordSize = 1500;
const char* ptr = &(data[bytes_written]);
int chunksize = std::min(len, kMaxTLSRecordSize);
rv = SSL_write(ssl_, ptr, chunksize);
VLOG(2) << "SSLWrite(" << chunksize << " bytes): " << rv;
if (rv <= 0) {
switch (SSL_get_error(ssl_, rv)) {
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
case SSL_ERROR_WANT_ACCEPT:
case SSL_ERROR_WANT_CONNECT:
rv = -2;
break;
default:
PrintSslError();
break;
}
break;
}
bytes_written += rv;
len -= rv;
if (rv != chunksize)
break; // If we couldn't write everything, we're implicitly stalled
}
// If we wrote some data, return that count. Otherwise
// return the stall error.
if (bytes_written > 0)
rv = bytes_written;
} else {
rv = send(fd_, data, len, flags);
}
if (!(flags & MSG_MORE))
UncorkSocket();
return rv;
}
void SMConnection::OnRegistration(EpollServer* eps, int fd, int event_mask) {
registered_in_epoll_server_ = true;
}
void SMConnection::OnEvent(int fd, EpollEvent* event) {
events_ |= event->in_events;
HandleEvents();
if (events_) {
event->out_ready_mask = events_;
events_ = 0;
}
}
void SMConnection::OnUnregistration(int fd, bool replaced) {
registered_in_epoll_server_ = false;
}
void SMConnection::OnShutdown(EpollServer* eps, int fd) {
Cleanup("OnShutdown");
return;
}
void SMConnection::Cleanup(const char* cleanup) {
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Cleanup: " << cleanup;
if (!initialized_)
return;
Reset();
if (connection_pool_)
connection_pool_->SMConnectionDone(this);
if (sm_interface_)
sm_interface_->ResetForNewConnection();
last_read_time_ = 0;
}
void SMConnection::HandleEvents() {
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Received: "
<< EpollServer::EventMaskToString(events_).c_str();
if (events_ & EPOLLIN) {
if (!DoRead())
goto handle_close_or_error;
}
if (events_ & EPOLLOUT) {
// Check if we have connected or not
if (connection_complete_ == false) {
int sock_error;
socklen_t sock_error_len = sizeof(sock_error);
int ret = getsockopt(fd_, SOL_SOCKET, SO_ERROR, &sock_error,
&sock_error_len);
if (ret != 0) {
VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "getsockopt error: " << errno << ": " << strerror(errno);
goto handle_close_or_error;
}
if (sock_error == 0) {
connection_complete_ = true;
VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "Connection complete to " << server_ip_ << ":"
<< server_port_ << " ";
} else if (sock_error == EINPROGRESS) {
return;
} else {
VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "error connecting to server";
goto handle_close_or_error;
}
}
if (!DoWrite())
goto handle_close_or_error;
}
if (events_ & (EPOLLHUP | EPOLLERR)) {
VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "!!! Got HUP or ERR";
goto handle_close_or_error;
}
return;
handle_close_or_error:
Cleanup("HandleEvents");
}
// Decide if SPDY was negotiated.
bool SMConnection::WasSpdyNegotiated() {
if (force_spdy())
return true;
// If this is an SSL connection, check if NPN specifies SPDY.
if (ssl_) {
const unsigned char *npn_proto;
unsigned int npn_proto_len;
SSL_get0_next_proto_negotiated(ssl_, &npn_proto, &npn_proto_len);
if (npn_proto_len > 0) {
std::string npn_proto_str((const char *)npn_proto, npn_proto_len);
VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "NPN protocol detected: " << npn_proto_str;
if (!strncmp(reinterpret_cast<const char*>(npn_proto),
"spdy/2", npn_proto_len))
return true;
}
}
return false;
}
bool SMConnection::SetupProtocolInterfaces() {
DCHECK(!protocol_detected_);
protocol_detected_ = true;
bool spdy_negotiated = WasSpdyNegotiated();
bool using_ssl = ssl_ != NULL;
if (using_ssl)
VLOG(1) << (SSL_session_reused(ssl_) ? "Resumed" : "Renegotiated")
<< " SSL Session.";
if (acceptor_->spdy_only_ && !spdy_negotiated) {
VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "SPDY proxy only, closing HTTPS connection.";
return false;
}
switch (acceptor_->flip_handler_type_) {
case FLIP_HANDLER_HTTP_SERVER:
{
DCHECK(!spdy_negotiated);
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< (sm_http_interface_ ? "Creating" : "Reusing")
<< " HTTP interface.";
if (!sm_http_interface_)
sm_http_interface_ = new HttpSM(this,
NULL,
epoll_server_,
memory_cache_,
acceptor_);
sm_interface_ = sm_http_interface_;
}
break;
case FLIP_HANDLER_PROXY:
{
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< (sm_streamer_interface_ ? "Creating" : "Reusing")
<< " PROXY Streamer interface.";
if (!sm_streamer_interface_) {
sm_streamer_interface_ = new StreamerSM(this,
NULL,
epoll_server_,
acceptor_);
sm_streamer_interface_->set_is_request();
}
sm_interface_ = sm_streamer_interface_;
// If spdy is not negotiated, the streamer interface will proxy all
// data to the origin server.
if (!spdy_negotiated)
break;
}
// Otherwise fall through into the case below.
case FLIP_HANDLER_SPDY_SERVER:
{
DCHECK(spdy_negotiated);
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< (sm_spdy_interface_ ? "Creating" : "Reusing")
<< " SPDY interface.";
if (!sm_spdy_interface_)
sm_spdy_interface_ = new SpdySM(this,
NULL,
epoll_server_,
memory_cache_,
acceptor_);
sm_interface_ = sm_spdy_interface_;
}
break;
}
CorkSocket();
if (!sm_interface_->PostAcceptHook())
return false;
return true;
}
bool SMConnection::DoRead() {
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "DoRead()";
while (!read_buffer_.Full()) {
char* bytes;
int size;
if (fd_ == -1) {
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "DoRead(): fd_ == -1. Invalid FD. Returning false";
return false;
}
read_buffer_.GetWritablePtr(&bytes, &size);
ssize_t bytes_read = 0;
if (ssl_) {
bytes_read = SSL_read(ssl_, bytes, size);
if (bytes_read < 0) {
int err = SSL_get_error(ssl_, bytes_read);
switch (err) {
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
case SSL_ERROR_WANT_ACCEPT:
case SSL_ERROR_WANT_CONNECT:
events_ &= ~EPOLLIN;
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "DoRead: SSL WANT_XXX: " << err;
goto done;
default:
PrintSslError();
goto error_or_close;
}
}
} else {
bytes_read = recv(fd_, bytes, size, MSG_DONTWAIT);
}
int stored_errno = errno;
if (bytes_read == -1) {
switch (stored_errno) {
case EAGAIN:
events_ &= ~EPOLLIN;
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "Got EAGAIN while reading";
goto done;
case EINTR:
VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "Got EINTR while reading";
continue;
default:
VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "While calling recv, got error: "
<< (ssl_?"(ssl error)":strerror(stored_errno));
goto error_or_close;
}
} else if (bytes_read > 0) {
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "read " << bytes_read
<< " bytes";
last_read_time_ = time(NULL);
// If the protocol hasn't been detected yet, set up the handlers
// we'll need.
if (!protocol_detected_) {
if (!SetupProtocolInterfaces()) {
LOG(ERROR) << "Error setting up protocol interfaces.";
goto error_or_close;
}
}
read_buffer_.AdvanceWritablePtr(bytes_read);
if (!DoConsumeReadData())
goto error_or_close;
continue;
} else { // bytes_read == 0
VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "0 bytes read with recv call.";
}
goto error_or_close;
}
done:
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "DoRead done!";
return true;
error_or_close:
VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "DoRead(): error_or_close. "
<< "Cleaning up, then returning false";
Cleanup("DoRead");
return false;
}
bool SMConnection::DoConsumeReadData() {
char* bytes;
int size;
read_buffer_.GetReadablePtr(&bytes, &size);
while (size != 0) {
size_t bytes_consumed = sm_interface_->ProcessReadInput(bytes, size);
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "consumed "
<< bytes_consumed << " bytes";
if (bytes_consumed == 0) {
break;
}
read_buffer_.AdvanceReadablePtr(bytes_consumed);
if (sm_interface_->MessageFullyRead()) {
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "HandleRequestFullyRead: Setting EPOLLOUT";
HandleResponseFullyRead();
events_ |= EPOLLOUT;
} else if (sm_interface_->Error()) {
LOG(ERROR) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "Framer error detected: Setting EPOLLOUT: "
<< sm_interface_->ErrorAsString();
// this causes everything to be closed/cleaned up.
events_ |= EPOLLOUT;
return false;
}
read_buffer_.GetReadablePtr(&bytes, &size);
}
return true;
}
void SMConnection::HandleResponseFullyRead() {
sm_interface_->Cleanup();
}
bool SMConnection::DoWrite() {
size_t bytes_sent = 0;
int flags = MSG_NOSIGNAL | MSG_DONTWAIT;
if (fd_ == -1) {
VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "DoWrite: fd == -1. Returning false.";
return false;
}
if (output_list_.empty()) {
VLOG(2) << log_prefix_ << "DoWrite: Output list empty.";
if (sm_interface_) {
sm_interface_->GetOutput();
}
if (output_list_.empty()) {
events_ &= ~EPOLLOUT;
}
}
while (!output_list_.empty()) {
VLOG(2) << log_prefix_ << "DoWrite: Items in output list: "
<< output_list_.size();
if (bytes_sent >= max_bytes_sent_per_dowrite_) {
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< " byte sent >= max bytes sent per write: Setting EPOLLOUT: "
<< bytes_sent;
events_ |= EPOLLOUT;
break;
}
if (sm_interface_ && output_list_.size() < 2) {
sm_interface_->GetOutput();
}
DataFrame* data_frame = output_list_.front();
const char* bytes = data_frame->data;
int size = data_frame->size;
bytes += data_frame->index;
size -= data_frame->index;
DCHECK_GE(size, 0);
if (size <= 0) {
output_list_.pop_front();
delete data_frame;
continue;
}
flags = MSG_NOSIGNAL | MSG_DONTWAIT;
// Look for a queue size > 1 because |this| frame is remains on the list
// until it has finished sending.
if (output_list_.size() > 1) {
VLOG(2) << log_prefix_ << "Outlist size: " << output_list_.size()
<< ": Adding MSG_MORE flag";
flags |= MSG_MORE;
}
VLOG(2) << log_prefix_ << "Attempting to send " << size << " bytes.";
ssize_t bytes_written = Send(bytes, size, flags);
int stored_errno = errno;
if (bytes_written == -1) {
switch (stored_errno) {
case EAGAIN:
events_ &= ~EPOLLOUT;
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "Got EAGAIN while writing";
goto done;
case EINTR:
VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "Got EINTR while writing";
continue;
default:
VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "While calling send, got error: " << stored_errno
<< ": " << (ssl_?"":strerror(stored_errno));
goto error_or_close;
}
} else if (bytes_written > 0) {
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Wrote: "
<< bytes_written << " bytes";
data_frame->index += bytes_written;
bytes_sent += bytes_written;
continue;
} else if (bytes_written == -2) {
// -2 handles SSL_ERROR_WANT_* errors
events_ &= ~EPOLLOUT;
goto done;
}
VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "0 bytes written with send call.";
goto error_or_close;
}
done:
UncorkSocket();
return true;
error_or_close:
VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT
<< "DoWrite: error_or_close. Returning false "
<< "after cleaning up";
Cleanup("DoWrite");
UncorkSocket();
return false;
}
void SMConnection::Reset() {
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Resetting";
if (ssl_) {
SSL_shutdown(ssl_);
PrintSslError();
SSL_free(ssl_);
PrintSslError();
ssl_ = NULL;
}
if (registered_in_epoll_server_) {
epoll_server_->UnregisterFD(fd_);
registered_in_epoll_server_ = false;
}
if (fd_ >= 0) {
VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Closing connection";
close(fd_);
fd_ = -1;
}
read_buffer_.Clear();
initialized_ = false;
protocol_detected_ = false;
events_ = 0;
for (std::list<DataFrame*>::iterator i =
output_list_.begin();
i != output_list_.end();
++i) {
delete *i;
}
output_list_.clear();
}
// static
SMConnection* SMConnection::NewSMConnection(EpollServer* epoll_server,
SSLState *ssl_state,
MemoryCache* memory_cache,
FlipAcceptor *acceptor,
std::string log_prefix) {
return new SMConnection(epoll_server, ssl_state, memory_cache,
acceptor, log_prefix);
}
} // namespace net