// Copyright 2015 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "mojo/public/cpp/bindings/interface_endpoint_client.h" #include <stdint.h> #include "base/bind.h" #include "base/location.h" #include "base/logging.h" #include "base/macros.h" #include "base/memory/ptr_util.h" #include "base/sequenced_task_runner.h" #include "base/stl_util.h" #include "mojo/public/cpp/bindings/associated_group.h" #include "mojo/public/cpp/bindings/associated_group_controller.h" #include "mojo/public/cpp/bindings/interface_endpoint_controller.h" #include "mojo/public/cpp/bindings/lib/task_runner_helper.h" #include "mojo/public/cpp/bindings/lib/validation_util.h" #include "mojo/public/cpp/bindings/sync_call_restrictions.h" namespace mojo { // ---------------------------------------------------------------------------- namespace { void DetermineIfEndpointIsConnected( const base::WeakPtr<InterfaceEndpointClient>& client, base::OnceCallback<void(bool)> callback) { std::move(callback).Run(client && !client->encountered_error()); } // When receiving an incoming message which expects a repsonse, // InterfaceEndpointClient creates a ResponderThunk object and passes it to the // incoming message receiver. When the receiver finishes processing the message, // it can provide a response using this object. class ResponderThunk : public MessageReceiverWithStatus { public: explicit ResponderThunk( const base::WeakPtr<InterfaceEndpointClient>& endpoint_client, scoped_refptr<base::SequencedTaskRunner> runner) : endpoint_client_(endpoint_client), accept_was_invoked_(false), task_runner_(std::move(runner)) {} ~ResponderThunk() override { if (!accept_was_invoked_) { // The Service handled a message that was expecting a response // but did not send a response. // We raise an error to signal the calling application that an error // condition occurred. Without this the calling application would have no // way of knowing it should stop waiting for a response. if (task_runner_->RunsTasksInCurrentSequence()) { // Please note that even if this code is run from a different task // runner on the same thread as |task_runner_|, it is okay to directly // call InterfaceEndpointClient::RaiseError(), because it will raise // error from the correct task runner asynchronously. if (endpoint_client_) { endpoint_client_->RaiseError(); } } else { task_runner_->PostTask( FROM_HERE, base::Bind(&InterfaceEndpointClient::RaiseError, endpoint_client_)); } } } // MessageReceiver implementation: bool PrefersSerializedMessages() override { return endpoint_client_ && endpoint_client_->PrefersSerializedMessages(); } bool Accept(Message* message) override { DCHECK(task_runner_->RunsTasksInCurrentSequence()); accept_was_invoked_ = true; DCHECK(message->has_flag(Message::kFlagIsResponse)); bool result = false; if (endpoint_client_) result = endpoint_client_->Accept(message); return result; } // MessageReceiverWithStatus implementation: bool IsConnected() override { DCHECK(task_runner_->RunsTasksInCurrentSequence()); return endpoint_client_ && !endpoint_client_->encountered_error(); } void IsConnectedAsync(base::OnceCallback<void(bool)> callback) override { if (task_runner_->RunsTasksInCurrentSequence()) { DetermineIfEndpointIsConnected(endpoint_client_, std::move(callback)); } else { task_runner_->PostTask( FROM_HERE, base::BindOnce(&DetermineIfEndpointIsConnected, endpoint_client_, std::move(callback))); } } private: base::WeakPtr<InterfaceEndpointClient> endpoint_client_; bool accept_was_invoked_; scoped_refptr<base::SequencedTaskRunner> task_runner_; DISALLOW_COPY_AND_ASSIGN(ResponderThunk); }; } // namespace // ---------------------------------------------------------------------------- InterfaceEndpointClient::SyncResponseInfo::SyncResponseInfo( bool* in_response_received) : response_received(in_response_received) {} InterfaceEndpointClient::SyncResponseInfo::~SyncResponseInfo() {} // ---------------------------------------------------------------------------- InterfaceEndpointClient::HandleIncomingMessageThunk::HandleIncomingMessageThunk( InterfaceEndpointClient* owner) : owner_(owner) {} InterfaceEndpointClient::HandleIncomingMessageThunk:: ~HandleIncomingMessageThunk() {} bool InterfaceEndpointClient::HandleIncomingMessageThunk::Accept( Message* message) { return owner_->HandleValidatedMessage(message); } // ---------------------------------------------------------------------------- InterfaceEndpointClient::InterfaceEndpointClient( ScopedInterfaceEndpointHandle handle, MessageReceiverWithResponderStatus* receiver, std::unique_ptr<MessageReceiver> payload_validator, bool expect_sync_requests, scoped_refptr<base::SequencedTaskRunner> runner, uint32_t interface_version) : expect_sync_requests_(expect_sync_requests), handle_(std::move(handle)), incoming_receiver_(receiver), thunk_(this), filters_(&thunk_), task_runner_(std::move(runner)), control_message_proxy_(this), control_message_handler_(interface_version), weak_ptr_factory_(this) { DCHECK(handle_.is_valid()); // TODO(yzshen): the way to use validator (or message filter in general) // directly is a little awkward. if (payload_validator) filters_.Append(std::move(payload_validator)); if (handle_.pending_association()) { handle_.SetAssociationEventHandler(base::Bind( &InterfaceEndpointClient::OnAssociationEvent, base::Unretained(this))); } else { InitControllerIfNecessary(); } } InterfaceEndpointClient::~InterfaceEndpointClient() { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (controller_) handle_.group_controller()->DetachEndpointClient(handle_); } AssociatedGroup* InterfaceEndpointClient::associated_group() { if (!associated_group_) associated_group_ = std::make_unique<AssociatedGroup>(handle_); return associated_group_.get(); } ScopedInterfaceEndpointHandle InterfaceEndpointClient::PassHandle() { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK(!has_pending_responders()); if (!handle_.is_valid()) return ScopedInterfaceEndpointHandle(); handle_.SetAssociationEventHandler( ScopedInterfaceEndpointHandle::AssociationEventCallback()); if (controller_) { controller_ = nullptr; handle_.group_controller()->DetachEndpointClient(handle_); } return std::move(handle_); } void InterfaceEndpointClient::AddFilter( std::unique_ptr<MessageReceiver> filter) { filters_.Append(std::move(filter)); } void InterfaceEndpointClient::RaiseError() { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (!handle_.pending_association()) handle_.group_controller()->RaiseError(); } void InterfaceEndpointClient::CloseWithReason(uint32_t custom_reason, const std::string& description) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); auto handle = PassHandle(); handle.ResetWithReason(custom_reason, description); } bool InterfaceEndpointClient::PrefersSerializedMessages() { auto* controller = handle_.group_controller(); return controller && controller->PrefersSerializedMessages(); } bool InterfaceEndpointClient::Accept(Message* message) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK(!message->has_flag(Message::kFlagExpectsResponse)); DCHECK(!handle_.pending_association()); // This has to been done even if connection error has occurred. For example, // the message contains a pending associated request. The user may try to use // the corresponding associated interface pointer after sending this message. // That associated interface pointer has to join an associated group in order // to work properly. if (!message->associated_endpoint_handles()->empty()) message->SerializeAssociatedEndpointHandles(handle_.group_controller()); if (encountered_error_) return false; InitControllerIfNecessary(); return controller_->SendMessage(message); } bool InterfaceEndpointClient::AcceptWithResponder( Message* message, std::unique_ptr<MessageReceiver> responder) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK(message->has_flag(Message::kFlagExpectsResponse)); DCHECK(!handle_.pending_association()); // Please see comments in Accept(). if (!message->associated_endpoint_handles()->empty()) message->SerializeAssociatedEndpointHandles(handle_.group_controller()); if (encountered_error_) return false; InitControllerIfNecessary(); // Reserve 0 in case we want it to convey special meaning in the future. uint64_t request_id = next_request_id_++; if (request_id == 0) request_id = next_request_id_++; message->set_request_id(request_id); bool is_sync = message->has_flag(Message::kFlagIsSync); if (!controller_->SendMessage(message)) return false; if (!is_sync) { async_responders_[request_id] = std::move(responder); return true; } SyncCallRestrictions::AssertSyncCallAllowed(); bool response_received = false; sync_responses_.insert(std::make_pair( request_id, std::make_unique<SyncResponseInfo>(&response_received))); base::WeakPtr<InterfaceEndpointClient> weak_self = weak_ptr_factory_.GetWeakPtr(); controller_->SyncWatch(&response_received); // Make sure that this instance hasn't been destroyed. if (weak_self) { DCHECK(base::ContainsKey(sync_responses_, request_id)); auto iter = sync_responses_.find(request_id); DCHECK_EQ(&response_received, iter->second->response_received); if (response_received) { ignore_result(responder->Accept(&iter->second->response)); } else { DVLOG(1) << "Mojo sync call returns without receiving a response. " << "Typcially it is because the interface has been " << "disconnected."; } sync_responses_.erase(iter); } return true; } bool InterfaceEndpointClient::HandleIncomingMessage(Message* message) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); return filters_.Accept(message); } void InterfaceEndpointClient::NotifyError( const base::Optional<DisconnectReason>& reason) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (encountered_error_) return; encountered_error_ = true; // Response callbacks may hold on to resource, and there's no need to keep // them alive any longer. Note that it's allowed that a pending response // callback may own this endpoint, so we simply move the responders onto the // stack here and let them be destroyed when the stack unwinds. AsyncResponderMap responders = std::move(async_responders_); control_message_proxy_.OnConnectionError(); if (error_handler_) { std::move(error_handler_).Run(); } else if (error_with_reason_handler_) { if (reason) { std::move(error_with_reason_handler_) .Run(reason->custom_reason, reason->description); } else { std::move(error_with_reason_handler_).Run(0, std::string()); } } } void InterfaceEndpointClient::QueryVersion( const base::Callback<void(uint32_t)>& callback) { control_message_proxy_.QueryVersion(callback); } void InterfaceEndpointClient::RequireVersion(uint32_t version) { control_message_proxy_.RequireVersion(version); } void InterfaceEndpointClient::FlushForTesting() { control_message_proxy_.FlushForTesting(); } void InterfaceEndpointClient::InitControllerIfNecessary() { if (controller_ || handle_.pending_association()) return; controller_ = handle_.group_controller()->AttachEndpointClient(handle_, this, task_runner_); if (expect_sync_requests_) controller_->AllowWokenUpBySyncWatchOnSameThread(); } void InterfaceEndpointClient::OnAssociationEvent( ScopedInterfaceEndpointHandle::AssociationEvent event) { if (event == ScopedInterfaceEndpointHandle::ASSOCIATED) { InitControllerIfNecessary(); } else if (event == ScopedInterfaceEndpointHandle::PEER_CLOSED_BEFORE_ASSOCIATION) { task_runner_->PostTask(FROM_HERE, base::Bind(&InterfaceEndpointClient::NotifyError, weak_ptr_factory_.GetWeakPtr(), handle_.disconnect_reason())); } } bool InterfaceEndpointClient::HandleValidatedMessage(Message* message) { DCHECK_EQ(handle_.id(), message->interface_id()); if (encountered_error_) { // This message is received after error has been encountered. For associated // interfaces, this means the remote side sends a // PeerAssociatedEndpointClosed event but continues to send more messages // for the same interface. Close the pipe because this shouldn't happen. DVLOG(1) << "A message is received for an interface after it has been " << "disconnected. Closing the pipe."; return false; } if (message->has_flag(Message::kFlagExpectsResponse)) { std::unique_ptr<MessageReceiverWithStatus> responder = std::make_unique<ResponderThunk>(weak_ptr_factory_.GetWeakPtr(), task_runner_); if (mojo::internal::ControlMessageHandler::IsControlMessage(message)) { return control_message_handler_.AcceptWithResponder(message, std::move(responder)); } else { return incoming_receiver_->AcceptWithResponder(message, std::move(responder)); } } else if (message->has_flag(Message::kFlagIsResponse)) { uint64_t request_id = message->request_id(); if (message->has_flag(Message::kFlagIsSync)) { auto it = sync_responses_.find(request_id); if (it == sync_responses_.end()) return false; it->second->response = std::move(*message); *it->second->response_received = true; return true; } auto it = async_responders_.find(request_id); if (it == async_responders_.end()) return false; std::unique_ptr<MessageReceiver> responder = std::move(it->second); async_responders_.erase(it); return responder->Accept(message); } else { if (mojo::internal::ControlMessageHandler::IsControlMessage(message)) return control_message_handler_.Accept(message); return incoming_receiver_->Accept(message); } } } // namespace mojo