// Copyright 2016 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "mojo/core/channel.h"
#include <stdint.h>
#include <windows.h>
#include <algorithm>
#include <limits>
#include <memory>
#include "base/bind.h"
#include "base/containers/queue.h"
#include "base/location.h"
#include "base/macros.h"
#include "base/memory/ref_counted.h"
#include "base/message_loop/message_loop_current.h"
#include "base/message_loop/message_pump_for_io.h"
#include "base/process/process_handle.h"
#include "base/synchronization/lock.h"
#include "base/task_runner.h"
#include "base/win/scoped_handle.h"
#include "base/win/win_util.h"
namespace mojo {
namespace core {
namespace {
class ChannelWin : public Channel,
public base::MessageLoopCurrent::DestructionObserver,
public base::MessagePumpForIO::IOHandler {
public:
ChannelWin(Delegate* delegate,
ConnectionParams connection_params,
scoped_refptr<base::TaskRunner> io_task_runner)
: Channel(delegate), self_(this), io_task_runner_(io_task_runner) {
if (connection_params.server_endpoint().is_valid()) {
handle_ = connection_params.TakeServerEndpoint()
.TakePlatformHandle()
.TakeHandle();
needs_connection_ = true;
} else {
handle_ =
connection_params.TakeEndpoint().TakePlatformHandle().TakeHandle();
}
CHECK(handle_.IsValid());
}
void Start() override {
io_task_runner_->PostTask(
FROM_HERE, base::BindOnce(&ChannelWin::StartOnIOThread, this));
}
void ShutDownImpl() override {
// Always shut down asynchronously when called through the public interface.
io_task_runner_->PostTask(
FROM_HERE, base::BindOnce(&ChannelWin::ShutDownOnIOThread, this));
}
void Write(MessagePtr message) override {
if (remote_process().is_valid()) {
// If we know the remote process handle, we transfer all outgoing handles
// to the process now rewriting them in the message.
std::vector<PlatformHandleInTransit> handles = message->TakeHandles();
for (auto& handle : handles) {
if (handle.handle().is_valid())
handle.TransferToProcess(remote_process().Clone());
}
message->SetHandles(std::move(handles));
}
bool write_error = false;
{
base::AutoLock lock(write_lock_);
if (reject_writes_)
return;
bool write_now = !delay_writes_ && outgoing_messages_.empty();
outgoing_messages_.emplace_back(std::move(message));
if (write_now && !WriteNoLock(outgoing_messages_.front()))
reject_writes_ = write_error = true;
}
if (write_error) {
// Do not synchronously invoke OnWriteError(). Write() may have been
// called by the delegate and we don't want to re-enter it.
io_task_runner_->PostTask(FROM_HERE,
base::BindOnce(&ChannelWin::OnWriteError, this,
Error::kDisconnected));
}
}
void LeakHandle() override {
DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
leak_handle_ = true;
}
bool GetReadPlatformHandles(const void* payload,
size_t payload_size,
size_t num_handles,
const void* extra_header,
size_t extra_header_size,
std::vector<PlatformHandle>* handles,
bool* deferred) override {
DCHECK(extra_header);
if (num_handles > std::numeric_limits<uint16_t>::max())
return false;
using HandleEntry = Channel::Message::HandleEntry;
size_t handles_size = sizeof(HandleEntry) * num_handles;
if (handles_size > extra_header_size)
return false;
handles->reserve(num_handles);
const HandleEntry* extra_header_handles =
reinterpret_cast<const HandleEntry*>(extra_header);
for (size_t i = 0; i < num_handles; i++) {
HANDLE handle_value =
base::win::Uint32ToHandle(extra_header_handles[i].handle);
if (remote_process().is_valid()) {
// If we know the remote process's handle, we assume it doesn't know
// ours; that means any handle values still belong to that process, and
// we need to transfer them to this process.
handle_value = PlatformHandleInTransit::TakeIncomingRemoteHandle(
handle_value, remote_process().get())
.ReleaseHandle();
}
handles->emplace_back(base::win::ScopedHandle(std::move(handle_value)));
}
return true;
}
private:
// May run on any thread.
~ChannelWin() override {}
void StartOnIOThread() {
base::MessageLoopCurrent::Get()->AddDestructionObserver(this);
base::MessageLoopCurrentForIO::Get()->RegisterIOHandler(handle_.Get(),
this);
if (needs_connection_) {
BOOL ok = ::ConnectNamedPipe(handle_.Get(), &connect_context_.overlapped);
if (ok) {
PLOG(ERROR) << "Unexpected success while waiting for pipe connection";
OnError(Error::kConnectionFailed);
return;
}
const DWORD err = GetLastError();
switch (err) {
case ERROR_PIPE_CONNECTED:
break;
case ERROR_IO_PENDING:
is_connect_pending_ = true;
AddRef();
return;
case ERROR_NO_DATA:
default:
OnError(Error::kConnectionFailed);
return;
}
}
// Now that we have registered our IOHandler, we can start writing.
{
base::AutoLock lock(write_lock_);
if (delay_writes_) {
delay_writes_ = false;
WriteNextNoLock();
}
}
// Keep this alive in case we synchronously run shutdown, via OnError(),
// as a result of a ReadFile() failure on the channel.
scoped_refptr<ChannelWin> keep_alive(this);
ReadMore(0);
}
void ShutDownOnIOThread() {
base::MessageLoopCurrent::Get()->RemoveDestructionObserver(this);
// TODO(https://crbug.com/583525): This function is expected to be called
// once, and |handle_| should be valid at this point.
CHECK(handle_.IsValid());
CancelIo(handle_.Get());
if (leak_handle_)
ignore_result(handle_.Take());
else
handle_.Close();
// Allow |this| to be destroyed as soon as no IO is pending.
self_ = nullptr;
}
// base::MessageLoopCurrent::DestructionObserver:
void WillDestroyCurrentMessageLoop() override {
DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
if (self_)
ShutDownOnIOThread();
}
// base::MessageLoop::IOHandler:
void OnIOCompleted(base::MessagePumpForIO::IOContext* context,
DWORD bytes_transfered,
DWORD error) override {
if (error != ERROR_SUCCESS) {
if (context == &write_context_) {
{
base::AutoLock lock(write_lock_);
reject_writes_ = true;
}
OnWriteError(Error::kDisconnected);
} else {
OnError(Error::kDisconnected);
}
} else if (context == &connect_context_) {
DCHECK(is_connect_pending_);
is_connect_pending_ = false;
ReadMore(0);
base::AutoLock lock(write_lock_);
if (delay_writes_) {
delay_writes_ = false;
WriteNextNoLock();
}
} else if (context == &read_context_) {
OnReadDone(static_cast<size_t>(bytes_transfered));
} else {
CHECK(context == &write_context_);
OnWriteDone(static_cast<size_t>(bytes_transfered));
}
Release();
}
void OnReadDone(size_t bytes_read) {
DCHECK(is_read_pending_);
is_read_pending_ = false;
if (bytes_read > 0) {
size_t next_read_size = 0;
if (OnReadComplete(bytes_read, &next_read_size)) {
ReadMore(next_read_size);
} else {
OnError(Error::kReceivedMalformedData);
}
} else if (bytes_read == 0) {
OnError(Error::kDisconnected);
}
}
void OnWriteDone(size_t bytes_written) {
if (bytes_written == 0)
return;
bool write_error = false;
{
base::AutoLock lock(write_lock_);
DCHECK(is_write_pending_);
is_write_pending_ = false;
DCHECK(!outgoing_messages_.empty());
Channel::MessagePtr message = std::move(outgoing_messages_.front());
outgoing_messages_.pop_front();
// Invalidate all the scoped handles so we don't attempt to close them.
std::vector<PlatformHandleInTransit> handles = message->TakeHandles();
for (auto& handle : handles)
handle.CompleteTransit();
// Overlapped WriteFile() to a pipe should always fully complete.
if (message->data_num_bytes() != bytes_written)
reject_writes_ = write_error = true;
else if (!WriteNextNoLock())
reject_writes_ = write_error = true;
}
if (write_error)
OnWriteError(Error::kDisconnected);
}
void ReadMore(size_t next_read_size_hint) {
DCHECK(!is_read_pending_);
size_t buffer_capacity = next_read_size_hint;
char* buffer = GetReadBuffer(&buffer_capacity);
DCHECK_GT(buffer_capacity, 0u);
BOOL ok =
::ReadFile(handle_.Get(), buffer, static_cast<DWORD>(buffer_capacity),
NULL, &read_context_.overlapped);
if (ok || GetLastError() == ERROR_IO_PENDING) {
is_read_pending_ = true;
AddRef();
} else {
OnError(Error::kDisconnected);
}
}
// Attempts to write a message directly to the channel. If the full message
// cannot be written, it's queued and a wait is initiated to write the message
// ASAP on the I/O thread.
bool WriteNoLock(const Channel::MessagePtr& message) {
BOOL ok = WriteFile(handle_.Get(), message->data(),
static_cast<DWORD>(message->data_num_bytes()), NULL,
&write_context_.overlapped);
if (ok || GetLastError() == ERROR_IO_PENDING) {
is_write_pending_ = true;
AddRef();
return true;
}
return false;
}
bool WriteNextNoLock() {
if (outgoing_messages_.empty())
return true;
return WriteNoLock(outgoing_messages_.front());
}
void OnWriteError(Error error) {
DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
DCHECK(reject_writes_);
if (error == Error::kDisconnected) {
// If we can't write because the pipe is disconnected then continue
// reading to fetch any in-flight messages, relying on end-of-stream to
// signal the actual disconnection.
if (is_read_pending_ || is_connect_pending_)
return;
}
OnError(error);
}
// Keeps the Channel alive at least until explicit shutdown on the IO thread.
scoped_refptr<Channel> self_;
// The pipe handle this Channel uses for communication.
base::win::ScopedHandle handle_;
// Indicates whether |handle_| must wait for a connection.
bool needs_connection_ = false;
const scoped_refptr<base::TaskRunner> io_task_runner_;
base::MessagePumpForIO::IOContext connect_context_;
base::MessagePumpForIO::IOContext read_context_;
bool is_connect_pending_ = false;
bool is_read_pending_ = false;
// Protects all fields potentially accessed on multiple threads via Write().
base::Lock write_lock_;
base::MessagePumpForIO::IOContext write_context_;
base::circular_deque<Channel::MessagePtr> outgoing_messages_;
bool delay_writes_ = true;
bool reject_writes_ = false;
bool is_write_pending_ = false;
bool leak_handle_ = false;
DISALLOW_COPY_AND_ASSIGN(ChannelWin);
};
} // namespace
// static
scoped_refptr<Channel> Channel::Create(
Delegate* delegate,
ConnectionParams connection_params,
scoped_refptr<base::TaskRunner> io_task_runner) {
return new ChannelWin(delegate, std::move(connection_params), io_task_runner);
}
} // namespace core
} // namespace mojo