// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "mojo/public/cpp/bindings/message.h" #include <stddef.h> #include <stdint.h> #include <stdlib.h> #include <algorithm> #include <utility> #include "base/bind.h" #include "base/lazy_instance.h" #include "base/logging.h" #include "base/numerics/safe_math.h" #include "base/strings/stringprintf.h" #include "base/threading/thread_local.h" #include "mojo/public/cpp/bindings/associated_group_controller.h" #include "mojo/public/cpp/bindings/lib/array_internal.h" #include "mojo/public/cpp/bindings/lib/unserialized_message_context.h" namespace mojo { namespace { base::LazyInstance<base::ThreadLocalPointer<internal::MessageDispatchContext>>:: Leaky g_tls_message_dispatch_context = LAZY_INSTANCE_INITIALIZER; base::LazyInstance<base::ThreadLocalPointer<SyncMessageResponseContext>>::Leaky g_tls_sync_response_context = LAZY_INSTANCE_INITIALIZER; void DoNotifyBadMessage(Message message, const std::string& error) { message.NotifyBadMessage(error); } template <typename HeaderType> void AllocateHeaderFromBuffer(internal::Buffer* buffer, HeaderType** header) { *header = buffer->AllocateAndGet<HeaderType>(); (*header)->num_bytes = sizeof(HeaderType); } void WriteMessageHeader(uint32_t name, uint32_t flags, size_t payload_interface_id_count, internal::Buffer* payload_buffer) { if (payload_interface_id_count > 0) { // Version 2 internal::MessageHeaderV2* header; AllocateHeaderFromBuffer(payload_buffer, &header); header->version = 2; header->name = name; header->flags = flags; // The payload immediately follows the header. header->payload.Set(header + 1); } else if (flags & (Message::kFlagExpectsResponse | Message::kFlagIsResponse)) { // Version 1 internal::MessageHeaderV1* header; AllocateHeaderFromBuffer(payload_buffer, &header); header->version = 1; header->name = name; header->flags = flags; } else { internal::MessageHeader* header; AllocateHeaderFromBuffer(payload_buffer, &header); header->version = 0; header->name = name; header->flags = flags; } } void CreateSerializedMessageObject(uint32_t name, uint32_t flags, size_t payload_size, size_t payload_interface_id_count, std::vector<ScopedHandle>* handles, ScopedMessageHandle* out_handle, internal::Buffer* out_buffer) { ScopedMessageHandle handle; MojoResult rv = mojo::CreateMessage(&handle); DCHECK_EQ(MOJO_RESULT_OK, rv); DCHECK(handle.is_valid()); void* buffer; uint32_t buffer_size; size_t total_size = internal::ComputeSerializedMessageSize( flags, payload_size, payload_interface_id_count); DCHECK(base::IsValueInRangeForNumericType<uint32_t>(total_size)); DCHECK(!handles || base::IsValueInRangeForNumericType<uint32_t>(handles->size())); rv = MojoAppendMessageData( handle->value(), static_cast<uint32_t>(total_size), handles ? reinterpret_cast<MojoHandle*>(handles->data()) : nullptr, handles ? static_cast<uint32_t>(handles->size()) : 0, nullptr, &buffer, &buffer_size); DCHECK_EQ(MOJO_RESULT_OK, rv); if (handles) { // Handle ownership has been taken by MojoAppendMessageData. for (size_t i = 0; i < handles->size(); ++i) ignore_result(handles->at(i).release()); } internal::Buffer payload_buffer(handle.get(), total_size, buffer, buffer_size); // Make sure we zero the memory first! memset(payload_buffer.data(), 0, total_size); WriteMessageHeader(name, flags, payload_interface_id_count, &payload_buffer); *out_handle = std::move(handle); *out_buffer = std::move(payload_buffer); } void SerializeUnserializedContext(MojoMessageHandle message, uintptr_t context_value) { auto* context = reinterpret_cast<internal::UnserializedMessageContext*>(context_value); void* buffer; uint32_t buffer_size; MojoResult attach_result = MojoAppendMessageData( message, 0, nullptr, 0, nullptr, &buffer, &buffer_size); if (attach_result != MOJO_RESULT_OK) return; internal::Buffer payload_buffer(MessageHandle(message), 0, buffer, buffer_size); WriteMessageHeader(context->message_name(), context->message_flags(), 0 /* payload_interface_id_count */, &payload_buffer); // We need to copy additional header data which may have been set after // message construction, as this codepath may be reached at some arbitrary // time between message send and message dispatch. static_cast<internal::MessageHeader*>(buffer)->interface_id = context->header()->interface_id; if (context->header()->flags & (Message::kFlagExpectsResponse | Message::kFlagIsResponse)) { DCHECK_GE(context->header()->version, 1u); static_cast<internal::MessageHeaderV1*>(buffer)->request_id = context->header()->request_id; } internal::SerializationContext serialization_context; context->Serialize(&serialization_context, &payload_buffer); // TODO(crbug.com/753433): Support lazy serialization of associated endpoint // handles. See corresponding TODO in the bindings generator for proof that // this DCHECK is indeed valid. DCHECK(serialization_context.associated_endpoint_handles()->empty()); if (!serialization_context.handles()->empty()) payload_buffer.AttachHandles(serialization_context.mutable_handles()); payload_buffer.Seal(); } void DestroyUnserializedContext(uintptr_t context) { delete reinterpret_cast<internal::UnserializedMessageContext*>(context); } ScopedMessageHandle CreateUnserializedMessageObject( std::unique_ptr<internal::UnserializedMessageContext> context) { ScopedMessageHandle handle; MojoResult rv = mojo::CreateMessage(&handle); DCHECK_EQ(MOJO_RESULT_OK, rv); DCHECK(handle.is_valid()); rv = MojoSetMessageContext( handle->value(), reinterpret_cast<uintptr_t>(context.release()), &SerializeUnserializedContext, &DestroyUnserializedContext, nullptr); DCHECK_EQ(MOJO_RESULT_OK, rv); return handle; } } // namespace Message::Message() = default; Message::Message(Message&& other) : handle_(std::move(other.handle_)), payload_buffer_(std::move(other.payload_buffer_)), handles_(std::move(other.handles_)), associated_endpoint_handles_( std::move(other.associated_endpoint_handles_)), transferable_(other.transferable_), serialized_(other.serialized_) { other.transferable_ = false; other.serialized_ = false; #if defined(ENABLE_IPC_FUZZER) interface_name_ = other.interface_name_; method_name_ = other.method_name_; #endif } Message::Message(std::unique_ptr<internal::UnserializedMessageContext> context) : Message(CreateUnserializedMessageObject(std::move(context))) {} Message::Message(uint32_t name, uint32_t flags, size_t payload_size, size_t payload_interface_id_count, std::vector<ScopedHandle>* handles) { CreateSerializedMessageObject(name, flags, payload_size, payload_interface_id_count, handles, &handle_, &payload_buffer_); transferable_ = true; serialized_ = true; } Message::Message(ScopedMessageHandle handle) { DCHECK(handle.is_valid()); uintptr_t context_value = 0; MojoResult get_context_result = MojoGetMessageContext(handle->value(), nullptr, &context_value); if (get_context_result == MOJO_RESULT_NOT_FOUND) { // It's a serialized message. Extract handles if possible. uint32_t num_bytes; void* buffer; uint32_t num_handles = 0; MojoResult rv = MojoGetMessageData(handle->value(), nullptr, &buffer, &num_bytes, nullptr, &num_handles); if (rv == MOJO_RESULT_RESOURCE_EXHAUSTED) { handles_.resize(num_handles); rv = MojoGetMessageData(handle->value(), nullptr, &buffer, &num_bytes, reinterpret_cast<MojoHandle*>(handles_.data()), &num_handles); } else { // No handles, so it's safe to retransmit this message if the caller // really wants to. transferable_ = true; } if (rv != MOJO_RESULT_OK) { // Failed to deserialize handles. Leave the Message uninitialized. return; } payload_buffer_ = internal::Buffer(buffer, num_bytes, num_bytes); serialized_ = true; } else { DCHECK_EQ(MOJO_RESULT_OK, get_context_result); auto* context = reinterpret_cast<internal::UnserializedMessageContext*>(context_value); // Dummy data address so common header accessors still behave properly. The // choice is V1 reflects unserialized message capabilities: we may or may // not need to support request IDs (which require at least V1), but we never // (for now, anyway) need to support associated interface handles (V2). payload_buffer_ = internal::Buffer(context->header(), sizeof(internal::MessageHeaderV1), sizeof(internal::MessageHeaderV1)); transferable_ = true; serialized_ = false; } handle_ = std::move(handle); } Message::~Message() = default; Message& Message::operator=(Message&& other) { handle_ = std::move(other.handle_); payload_buffer_ = std::move(other.payload_buffer_); handles_ = std::move(other.handles_); associated_endpoint_handles_ = std::move(other.associated_endpoint_handles_); transferable_ = other.transferable_; other.transferable_ = false; serialized_ = other.serialized_; other.serialized_ = false; #if defined(ENABLE_IPC_FUZZER) interface_name_ = other.interface_name_; method_name_ = other.method_name_; #endif return *this; } void Message::Reset() { handle_.reset(); payload_buffer_.Reset(); handles_.clear(); associated_endpoint_handles_.clear(); transferable_ = false; serialized_ = false; } const uint8_t* Message::payload() const { if (version() < 2) return data() + header()->num_bytes; DCHECK(!header_v2()->payload.is_null()); return static_cast<const uint8_t*>(header_v2()->payload.Get()); } uint32_t Message::payload_num_bytes() const { DCHECK_GE(data_num_bytes(), header()->num_bytes); size_t num_bytes; if (version() < 2) { num_bytes = data_num_bytes() - header()->num_bytes; } else { auto payload_begin = reinterpret_cast<uintptr_t>(header_v2()->payload.Get()); auto payload_end = reinterpret_cast<uintptr_t>(header_v2()->payload_interface_ids.Get()); if (!payload_end) payload_end = reinterpret_cast<uintptr_t>(data() + data_num_bytes()); DCHECK_GE(payload_end, payload_begin); num_bytes = payload_end - payload_begin; } DCHECK(base::IsValueInRangeForNumericType<uint32_t>(num_bytes)); return static_cast<uint32_t>(num_bytes); } uint32_t Message::payload_num_interface_ids() const { auto* array_pointer = version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get(); return array_pointer ? static_cast<uint32_t>(array_pointer->size()) : 0; } const uint32_t* Message::payload_interface_ids() const { auto* array_pointer = version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get(); return array_pointer ? array_pointer->storage() : nullptr; } void Message::AttachHandlesFromSerializationContext( internal::SerializationContext* context) { if (context->handles()->empty() && context->associated_endpoint_handles()->empty()) { // No handles attached, so no extra serialization work. return; } if (context->associated_endpoint_handles()->empty()) { // Attaching only non-associated handles is easier since we don't have to // modify the message header. Faster path for that. payload_buffer_.AttachHandles(context->mutable_handles()); return; } // Allocate a new message with enough space to hold all attached handles. Copy // this message's contents into the new one and use it to replace ourself. // // TODO(rockot): We could avoid the extra full message allocation by instead // growing the buffer and carefully moving its contents around. This errs on // the side of less complexity with probably only marginal performance cost. uint32_t payload_size = payload_num_bytes(); mojo::Message new_message(name(), header()->flags, payload_size, context->associated_endpoint_handles()->size(), context->mutable_handles()); std::swap(*context->mutable_associated_endpoint_handles(), new_message.associated_endpoint_handles_); memcpy(new_message.payload_buffer()->AllocateAndGet(payload_size), payload(), payload_size); *this = std::move(new_message); } ScopedMessageHandle Message::TakeMojoMessage() { // If there are associated endpoints transferred, // SerializeAssociatedEndpointHandles() must be called before this method. DCHECK(associated_endpoint_handles_.empty()); DCHECK(transferable_); payload_buffer_.Seal(); auto handle = std::move(handle_); Reset(); return handle; } void Message::NotifyBadMessage(const std::string& error) { DCHECK(handle_.is_valid()); mojo::NotifyBadMessage(handle_.get(), error); } void Message::SerializeAssociatedEndpointHandles( AssociatedGroupController* group_controller) { if (associated_endpoint_handles_.empty()) return; DCHECK_GE(version(), 2u); DCHECK(header_v2()->payload_interface_ids.is_null()); DCHECK(payload_buffer_.is_valid()); DCHECK(handle_.is_valid()); size_t size = associated_endpoint_handles_.size(); internal::Array_Data<uint32_t>::BufferWriter handle_writer; handle_writer.Allocate(size, &payload_buffer_); header_v2()->payload_interface_ids.Set(handle_writer.data()); for (size_t i = 0; i < size; ++i) { ScopedInterfaceEndpointHandle& handle = associated_endpoint_handles_[i]; DCHECK(handle.pending_association()); handle_writer->storage()[i] = group_controller->AssociateInterface(std::move(handle)); } associated_endpoint_handles_.clear(); } bool Message::DeserializeAssociatedEndpointHandles( AssociatedGroupController* group_controller) { if (!serialized_) return true; associated_endpoint_handles_.clear(); uint32_t num_ids = payload_num_interface_ids(); if (num_ids == 0) return true; associated_endpoint_handles_.reserve(num_ids); uint32_t* ids = header_v2()->payload_interface_ids.Get()->storage(); bool result = true; for (uint32_t i = 0; i < num_ids; ++i) { auto handle = group_controller->CreateLocalEndpointHandle(ids[i]); if (IsValidInterfaceId(ids[i]) && !handle.is_valid()) { // |ids[i]| itself is valid but handle creation failed. In that case, mark // deserialization as failed but continue to deserialize the rest of // handles. result = false; } associated_endpoint_handles_.push_back(std::move(handle)); ids[i] = kInvalidInterfaceId; } return result; } void Message::SerializeIfNecessary() { MojoResult rv = MojoSerializeMessage(handle_->value(), nullptr); if (rv == MOJO_RESULT_FAILED_PRECONDITION) return; // Reconstruct this Message instance from the serialized message's handle. *this = Message(std::move(handle_)); } std::unique_ptr<internal::UnserializedMessageContext> Message::TakeUnserializedContext( const internal::UnserializedMessageContext::Tag* tag) { DCHECK(handle_.is_valid()); uintptr_t context_value = 0; MojoResult rv = MojoGetMessageContext(handle_->value(), nullptr, &context_value); if (rv == MOJO_RESULT_NOT_FOUND) return nullptr; DCHECK_EQ(MOJO_RESULT_OK, rv); auto* context = reinterpret_cast<internal::UnserializedMessageContext*>(context_value); if (context->tag() != tag) return nullptr; // Detach the context from the message. rv = MojoSetMessageContext(handle_->value(), 0, nullptr, nullptr, nullptr); DCHECK_EQ(MOJO_RESULT_OK, rv); return base::WrapUnique(context); } bool MessageReceiver::PrefersSerializedMessages() { return false; } PassThroughFilter::PassThroughFilter() {} PassThroughFilter::~PassThroughFilter() {} bool PassThroughFilter::Accept(Message* message) { return true; } SyncMessageResponseContext::SyncMessageResponseContext() : outer_context_(current()) { g_tls_sync_response_context.Get().Set(this); } SyncMessageResponseContext::~SyncMessageResponseContext() { DCHECK_EQ(current(), this); g_tls_sync_response_context.Get().Set(outer_context_); } // static SyncMessageResponseContext* SyncMessageResponseContext::current() { return g_tls_sync_response_context.Get().Get(); } void SyncMessageResponseContext::ReportBadMessage(const std::string& error) { GetBadMessageCallback().Run(error); } ReportBadMessageCallback SyncMessageResponseContext::GetBadMessageCallback() { DCHECK(!response_.IsNull()); return base::BindOnce(&DoNotifyBadMessage, std::move(response_)); } MojoResult ReadMessage(MessagePipeHandle handle, Message* message) { ScopedMessageHandle message_handle; MojoResult rv = ReadMessageNew(handle, &message_handle, MOJO_READ_MESSAGE_FLAG_NONE); if (rv != MOJO_RESULT_OK) return rv; *message = Message(std::move(message_handle)); return MOJO_RESULT_OK; } void ReportBadMessage(const std::string& error) { internal::MessageDispatchContext* context = internal::MessageDispatchContext::current(); DCHECK(context); context->GetBadMessageCallback().Run(error); } ReportBadMessageCallback GetBadMessageCallback() { internal::MessageDispatchContext* context = internal::MessageDispatchContext::current(); DCHECK(context); return context->GetBadMessageCallback(); } namespace internal { MessageHeaderV2::MessageHeaderV2() = default; MessageDispatchContext::MessageDispatchContext(Message* message) : outer_context_(current()), message_(message) { g_tls_message_dispatch_context.Get().Set(this); } MessageDispatchContext::~MessageDispatchContext() { DCHECK_EQ(current(), this); g_tls_message_dispatch_context.Get().Set(outer_context_); } // static MessageDispatchContext* MessageDispatchContext::current() { return g_tls_message_dispatch_context.Get().Get(); } ReportBadMessageCallback MessageDispatchContext::GetBadMessageCallback() { DCHECK(!message_->IsNull()); return base::BindOnce(&DoNotifyBadMessage, std::move(*message_)); } // static void SyncMessageResponseSetup::SetCurrentSyncResponseMessage(Message* message) { SyncMessageResponseContext* context = SyncMessageResponseContext::current(); if (context) context->response_ = std::move(*message); } } // namespace internal } // namespace mojo