// Copyright 2016 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <windows.h>
#include <limits>
#include <utility>
#include "base/debug/alias.h"
#include "base/memory/platform_shared_memory_region.h"
#include "base/numerics/safe_conversions.h"
#include "base/strings/string_piece.h"
#include "mojo/core/broker.h"
#include "mojo/core/broker_messages.h"
#include "mojo/core/channel.h"
#include "mojo/core/platform_handle_utils.h"
#include "mojo/public/cpp/platform/named_platform_channel.h"
namespace mojo {
namespace core {
namespace {
// Capacity, in bytes, of the stack buffer WaitForBrokerMessage() reads a
// single broker message into from the sync pipe.
constexpr size_t kMaxBrokerMessageSize = 256;
// Moves the platform handles attached to |message| into the caller-provided
// |out_handles| array, which must have room for |num_handles| entries.
// Returns false, without taking any handles from the message, when the
// message carries a different number of handles than |num_handles|;
// otherwise returns true.
bool TakeHandlesFromBrokerMessage(Channel::Message* message,
                                  size_t num_handles,
                                  PlatformHandle* out_handles) {
  const bool has_expected_count = (message->num_handles() == num_handles);
  if (!has_expected_count) {
    DLOG(ERROR) << "Received unexpected number of handles in broker message";
    return false;
  }

  std::vector<PlatformHandleInTransit> in_transit = message->TakeHandles();
  DCHECK_EQ(in_transit.size(), num_handles);
  DCHECK(out_handles);
  for (size_t idx = 0; idx != num_handles; ++idx)
    out_handles[idx] = in_transit[idx].TakeHandle();
  return true;
}
// Reads one broker message from |pipe_handle| and checks its type.
//
// Performs a single ::ReadFile (no OVERLAPPED structure) of at most
// kMaxBrokerMessageSize bytes into a stack buffer and deserializes the bytes
// into a Channel::Message.
//
// Returns nullptr if ::ReadFile fails. Crashes via CHECK if the bytes do not
// deserialize into a message whose payload holds at least a
// BrokerMessageHeader, or if that header's type is not |expected_type|.
// Otherwise returns the deserialized message.
Channel::MessagePtr WaitForBrokerMessage(HANDLE pipe_handle,
                                         BrokerMessageType expected_type) {
  char buffer[kMaxBrokerMessageSize];
  DWORD bytes_read = 0;
  // NOTE(review): if this is a message-mode pipe, a message longer than the
  // buffer makes ::ReadFile fail with ERROR_MORE_DATA and it is reported by
  // the branch below as a read error. Confirm the pipe mode with the creator.
  BOOL result = ::ReadFile(pipe_handle, buffer, kMaxBrokerMessageSize,
                           &bytes_read, nullptr);
  if (!result) {
    // The pipe may be broken if the browser side has been closed, e.g. during
    // browser shutdown. In that case the ReadFile call will fail and we
    // shouldn't continue waiting.
    PLOG(ERROR) << "Error reading broker pipe";
    return nullptr;
  }
  Channel::MessagePtr message =
      Channel::Message::Deserialize(buffer, static_cast<size_t>(bytes_read));
  if (!message || message->payload_size() < sizeof(BrokerMessageHeader)) {
    LOG(ERROR) << "Invalid broker message";
    // Keep the raw bytes and their length referenced so they can be inspected
    // in the crash dump produced by the CHECK below.
    base::debug::Alias(&buffer[0]);
    base::debug::Alias(&bytes_read);
    CHECK(false);
    // Not reached at runtime; keeps every control path returning a value.
    return nullptr;
  }
  // Safe to read: the payload was just verified to hold at least a header.
  const BrokerMessageHeader* header =
      reinterpret_cast<const BrokerMessageHeader*>(message->payload());
  if (header->type != expected_type) {
    LOG(ERROR) << "Unexpected broker message type";
    // Same crash-dump diagnostics as above.
    base::debug::Alias(&buffer[0]);
    base::debug::Alias(&bytes_read);
    CHECK(false);
    // Not reached at runtime; keeps every control path returning a value.
    return nullptr;
  }
  return message;
}
} // namespace
// Takes ownership of the synchronous broker pipe |handle| and reads the INIT
// message from it.
//
// The inviter endpoint is obtained in one of two ways:
//  - from the single platform handle attached to the INIT message, or
//  - when the message does not carry exactly one handle, by connecting to the
//    named pipe whose name follows the InitData in the payload.
// If the INIT message cannot be read, inviter_endpoint_ is left untouched.
Broker::Broker(PlatformHandle handle) : sync_channel_(std::move(handle)) {
  CHECK(sync_channel_.is_valid());
  Channel::MessagePtr message = WaitForBrokerMessage(
      sync_channel_.GetHandle().Get(), BrokerMessageType::INIT);

  // If we fail to read a message (broken pipe), just return early. The inviter
  // handle will be null and callers must handle this gracefully.
  if (!message)
    return;

  PlatformHandle endpoint_handle;
  if (TakeHandlesFromBrokerMessage(message.get(), 1, &endpoint_handle)) {
    inviter_endpoint_ = PlatformChannelEndpoint(std::move(endpoint_handle));
    return;
  }

  // Without exactly one attached handle, the payload is expected to carry a
  // pipe name instead, laid out as:
  //   [BrokerMessageHeader][InitData][pipe_name_length x base::char16]
  const size_t fixed_size = sizeof(BrokerMessageHeader) + sizeof(InitData);
  const size_t payload_size = message->payload_size();
  CHECK_GE(payload_size, fixed_size);
  const BrokerMessageHeader* header =
      static_cast<const BrokerMessageHeader*>(message->payload());
  const InitData* data = reinterpret_cast<const InitData*>(header + 1);
  CHECK(data->pipe_name_length);

  // Check the declared name length against the bytes actually present by
  // dividing the remainder, not by multiplying the length read from the
  // message. Where size_t is 32 bits, |pipe_name_length * sizeof(char16)| can
  // wrap around. An oversized length would then pass the check, and the
  // StringPiece16 below would read past the end of the payload.
  const size_t name_bytes = payload_size - fixed_size;
  CHECK_EQ(name_bytes % sizeof(base::char16), 0u);
  CHECK_EQ(static_cast<size_t>(data->pipe_name_length),
           name_bytes / sizeof(base::char16));

  const base::char16* name_data =
      reinterpret_cast<const base::char16*>(data + 1);
  inviter_endpoint_ = NamedPlatformChannel::ConnectToServer(
      base::StringPiece16(name_data, data->pipe_name_length).as_string());
}
// Members release their own resources; nothing extra to tear down.
Broker::~Broker() = default;
// Transfers the inviter endpoint obtained during construction to the caller.
// inviter_endpoint_ is left in its moved-from state afterwards.
PlatformChannelEndpoint Broker::GetInviterEndpoint() {
  PlatformChannelEndpoint endpoint(std::move(inviter_endpoint_));
  return endpoint;
}
// Asks the broker over the sync pipe for a writable shared memory region of
// |num_bytes| bytes.
//
// Sends a BUFFER_REQUEST message, then reads the BUFFER_RESPONSE, which must
// carry exactly one handle plus BufferResponseData holding the region's GUID.
// Returns an invalid (default-constructed) region if the request cannot be
// written fully, or if the response cannot be read or is missing its handle
// or data. Crashes via checked_cast if |num_bytes| does not fit in uint32_t.
base::WritableSharedMemoryRegion Broker::GetWritableSharedMemoryRegion(
    size_t num_bytes) {
  // Held across both the write and the read so that request/response pairs
  // from concurrent callers cannot interleave on the sync channel.
  base::AutoLock lock(lock_);

  BufferRequestData* buffer_request = nullptr;
  Channel::MessagePtr out_message = CreateBrokerMessage(
      BrokerMessageType::BUFFER_REQUEST, 0, 0, &buffer_request);
  buffer_request->size = base::checked_cast<uint32_t>(num_bytes);

  DWORD bytes_written = 0;
  BOOL result =
      ::WriteFile(sync_channel_.GetHandle().Get(), out_message->data(),
                  static_cast<DWORD>(out_message->data_num_bytes()),
                  &bytes_written, nullptr);
  if (!result) {
    PLOG(ERROR) << "Error sending sync broker message";
    return base::WritableSharedMemoryRegion();
  }
  if (static_cast<size_t>(bytes_written) != out_message->data_num_bytes()) {
    // WriteFile succeeded, so GetLastError() is unrelated to this failure;
    // use LOG rather than PLOG to avoid reporting a misleading error code.
    LOG(ERROR) << "Short write sending sync broker message";
    return base::WritableSharedMemoryRegion();
  }

  Channel::MessagePtr response = WaitForBrokerMessage(
      sync_channel_.GetHandle().Get(), BrokerMessageType::BUFFER_RESPONSE);
  PlatformHandle handle;
  if (!response || !TakeHandlesFromBrokerMessage(response.get(), 1, &handle))
    return base::WritableSharedMemoryRegion();

  BufferResponseData* data = nullptr;
  if (!GetBrokerMessageData(response.get(), &data))
    return base::WritableSharedMemoryRegion();

  return base::WritableSharedMemoryRegion::Deserialize(
      base::subtle::PlatformSharedMemoryRegion::Take(
          CreateSharedMemoryRegionHandleFromPlatformHandles(std::move(handle),
                                                            PlatformHandle()),
          base::subtle::PlatformSharedMemoryRegion::Mode::kWritable, num_bytes,
          base::UnguessableToken::Deserialize(data->guid_high,
                                              data->guid_low)));
}
} // namespace core
} // namespace mojo