// Copyright 2016 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef MOJO_PUBLIC_CPP_BINDINGS_THREAD_SAFE_INTERFACE_PTR_H_
#define MOJO_PUBLIC_CPP_BINDINGS_THREAD_SAFE_INTERFACE_PTR_H_

#include <memory>

#include "base/macros.h"
#include "base/memory/ptr_util.h"
#include "base/memory/ref_counted.h"
#include "base/stl_util.h"
#include "base/synchronization/waitable_event.h"
#include "base/task_runner.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "mojo/public/cpp/bindings/associated_group.h"
#include "mojo/public/cpp/bindings/associated_interface_ptr.h"
#include "mojo/public/cpp/bindings/interface_ptr.h"
#include "mojo/public/cpp/bindings/message.h"
#include "mojo/public/cpp/bindings/sync_call_restrictions.h"
#include "mojo/public/cpp/bindings/sync_event_watcher.h"

// ThreadSafeInterfacePtr wraps a non-thread-safe InterfacePtr and proxies
// messages to it. Async calls are posted to the sequence that the InterfacePtr
// is bound to, and the responses are posted back. Sync calls are dispatched
// directly if the call is made on the sequence that the wrapped InterfacePtr is
// bound to, or posted otherwise. It's important to be aware that sync calls
// block both the calling sequence and the InterfacePtr sequence. That means
// that you cannot make sync calls through a ThreadSafeInterfacePtr if the
// underlying InterfacePtr is bound to a sequence that cannot block, like the IO
// thread.

namespace mojo {

// Instances of this class may be used from any sequence to serialize
// |Interface| messages and forward them elsewhere. In general you should use
// one of the ThreadSafeInterfacePtrBase helper aliases defined below, but this
// type may be useful if you need/want to manually manage the lifetime of the
// underlying proxy object which will be used to ultimately send messages.
template <typename Interface> class ThreadSafeForwarder : public MessageReceiverWithResponder { public: using ProxyType = typename Interface::Proxy_; using ForwardMessageCallback = base::Callback<void(Message)>; using ForwardMessageWithResponderCallback = base::Callback<void(Message, std::unique_ptr<MessageReceiver>)>; // Constructs a ThreadSafeForwarder through which Messages are forwarded to // |forward| or |forward_with_responder| by posting to |task_runner|. // // Any message sent through this forwarding interface will dispatch its reply, // if any, back to the sequence which called the corresponding interface // method. ThreadSafeForwarder( const scoped_refptr<base::SequencedTaskRunner>& task_runner, const ForwardMessageCallback& forward, const ForwardMessageWithResponderCallback& forward_with_responder, const AssociatedGroup& associated_group) : proxy_(this), task_runner_(task_runner), forward_(forward), forward_with_responder_(forward_with_responder), associated_group_(associated_group), sync_calls_(new InProgressSyncCalls()) {} ~ThreadSafeForwarder() override { // If there are ongoing sync calls signal their completion now. base::AutoLock l(sync_calls_->lock); for (const auto& pending_response : sync_calls_->pending_responses) pending_response->event.Signal(); } ProxyType& proxy() { return proxy_; } private: // MessageReceiverWithResponder implementation: bool PrefersSerializedMessages() override { // TSIP is primarily used because it emulates legacy IPC threading behavior. // In practice this means it's only for cross-process messaging and we can // just always assume messages should be serialized. return true; } bool Accept(Message* message) override { if (!message->associated_endpoint_handles()->empty()) { // If this DCHECK fails, it is likely because: // - This is a non-associated interface pointer setup using // PtrWrapper::BindOnTaskRunner( // InterfacePtrInfo<InterfaceType> ptr_info); // Please see the TODO in that method. 
// - This is an associated interface which hasn't been associated with a // message pipe. In other words, the corresponding // AssociatedInterfaceRequest hasn't been sent. DCHECK(associated_group_.GetController()); message->SerializeAssociatedEndpointHandles( associated_group_.GetController()); } task_runner_->PostTask(FROM_HERE, base::Bind(forward_, base::Passed(message))); return true; } bool AcceptWithResponder( Message* message, std::unique_ptr<MessageReceiver> responder) override { if (!message->associated_endpoint_handles()->empty()) { // Please see comment for the DCHECK in the previous method. DCHECK(associated_group_.GetController()); message->SerializeAssociatedEndpointHandles( associated_group_.GetController()); } // Async messages are always posted (even if |task_runner_| runs tasks on // this sequence) to guarantee that two async calls can't be reordered. if (!message->has_flag(Message::kFlagIsSync)) { auto reply_forwarder = std::make_unique<ForwardToCallingThread>(std::move(responder)); task_runner_->PostTask( FROM_HERE, base::Bind(forward_with_responder_, base::Passed(message), base::Passed(&reply_forwarder))); return true; } SyncCallRestrictions::AssertSyncCallAllowed(); // If the InterfacePtr is bound to this sequence, dispatch it directly. if (task_runner_->RunsTasksInCurrentSequence()) { forward_with_responder_.Run(std::move(*message), std::move(responder)); return true; } // If the InterfacePtr is bound on another sequence, post the call. // TODO(yzshen, watk): We block both this sequence and the InterfacePtr // sequence. Ideally only this sequence would block. 
auto response = base::MakeRefCounted<SyncResponseInfo>(); auto response_signaler = std::make_unique<SyncResponseSignaler>(response); task_runner_->PostTask( FROM_HERE, base::Bind(forward_with_responder_, base::Passed(message), base::Passed(&response_signaler))); // Save the pending SyncResponseInfo so that if the sync call deletes // |this|, we can signal the completion of the call to return from // SyncWatch(). auto sync_calls = sync_calls_; { base::AutoLock l(sync_calls->lock); sync_calls->pending_responses.push_back(response.get()); } auto assign_true = [](bool* b) { *b = true; }; bool event_signaled = false; SyncEventWatcher watcher(&response->event, base::Bind(assign_true, &event_signaled)); const bool* stop_flags[] = {&event_signaled}; watcher.SyncWatch(stop_flags, 1); { base::AutoLock l(sync_calls->lock); base::Erase(sync_calls->pending_responses, response.get()); } if (response->received) ignore_result(responder->Accept(&response->message)); return true; } // Data that we need to share between the sequences involved in a sync call. struct SyncResponseInfo : public base::RefCountedThreadSafe<SyncResponseInfo> { Message message; bool received = false; base::WaitableEvent event{base::WaitableEvent::ResetPolicy::MANUAL, base::WaitableEvent::InitialState::NOT_SIGNALED}; private: friend class base::RefCountedThreadSafe<SyncResponseInfo>; }; // A MessageReceiver that signals |response| when it either accepts the // response message, or is destructed. class SyncResponseSignaler : public MessageReceiver { public: explicit SyncResponseSignaler(scoped_refptr<SyncResponseInfo> response) : response_(response) {} ~SyncResponseSignaler() override { // If Accept() was not called we must still notify the waiter that the // sync call is finished. 
if (response_) response_->event.Signal(); } bool Accept(Message* message) override { response_->message = std::move(*message); response_->received = true; response_->event.Signal(); response_ = nullptr; return true; } private: scoped_refptr<SyncResponseInfo> response_; }; // A record of the pending sync responses for canceling pending sync calls // when the owning ThreadSafeForwarder is destructed. struct InProgressSyncCalls : public base::RefCountedThreadSafe<InProgressSyncCalls> { // |lock| protects access to |pending_responses|. base::Lock lock; std::vector<SyncResponseInfo*> pending_responses; }; class ForwardToCallingThread : public MessageReceiver { public: explicit ForwardToCallingThread(std::unique_ptr<MessageReceiver> responder) : responder_(std::move(responder)), caller_task_runner_(base::SequencedTaskRunnerHandle::Get()) {} ~ForwardToCallingThread() override { caller_task_runner_->DeleteSoon(FROM_HERE, std::move(responder_)); } private: bool Accept(Message* message) override { // The current instance will be deleted when this method returns, so we // have to relinquish the responder's ownership so it does not get // deleted. 
caller_task_runner_->PostTask( FROM_HERE, base::Bind(&ForwardToCallingThread::CallAcceptAndDeleteResponder, base::Passed(std::move(responder_)), base::Passed(std::move(*message)))); return true; } static void CallAcceptAndDeleteResponder( std::unique_ptr<MessageReceiver> responder, Message message) { ignore_result(responder->Accept(&message)); } std::unique_ptr<MessageReceiver> responder_; scoped_refptr<base::SequencedTaskRunner> caller_task_runner_; }; ProxyType proxy_; const scoped_refptr<base::SequencedTaskRunner> task_runner_; const ForwardMessageCallback forward_; const ForwardMessageWithResponderCallback forward_with_responder_; AssociatedGroup associated_group_; scoped_refptr<InProgressSyncCalls> sync_calls_; DISALLOW_COPY_AND_ASSIGN(ThreadSafeForwarder); }; template <typename InterfacePtrType> class ThreadSafeInterfacePtrBase : public base::RefCountedThreadSafe< ThreadSafeInterfacePtrBase<InterfacePtrType>> { public: using InterfaceType = typename InterfacePtrType::InterfaceType; using PtrInfoType = typename InterfacePtrType::PtrInfoType; explicit ThreadSafeInterfacePtrBase( std::unique_ptr<ThreadSafeForwarder<InterfaceType>> forwarder) : forwarder_(std::move(forwarder)) {} // Creates a ThreadSafeInterfacePtrBase wrapping an underlying non-thread-safe // InterfacePtrType which is bound to the calling sequence. All messages sent // via this thread-safe proxy will internally be sent by first posting to this // (the calling) sequence's TaskRunner. static scoped_refptr<ThreadSafeInterfacePtrBase> Create( InterfacePtrType interface_ptr) { scoped_refptr<PtrWrapper> wrapper = new PtrWrapper(std::move(interface_ptr)); return new ThreadSafeInterfacePtrBase(wrapper->CreateForwarder()); } // Creates a ThreadSafeInterfacePtrBase which binds the underlying // non-thread-safe InterfacePtrType on the specified TaskRunner. All messages // sent via this thread-safe proxy will internally be sent by first posting to // that TaskRunner. 
static scoped_refptr<ThreadSafeInterfacePtrBase> Create( PtrInfoType ptr_info, const scoped_refptr<base::SequencedTaskRunner>& bind_task_runner) { scoped_refptr<PtrWrapper> wrapper = new PtrWrapper(bind_task_runner); wrapper->BindOnTaskRunner(std::move(ptr_info)); return new ThreadSafeInterfacePtrBase(wrapper->CreateForwarder()); } InterfaceType* get() { return &forwarder_->proxy(); } InterfaceType* operator->() { return get(); } InterfaceType& operator*() { return *get(); } private: friend class base::RefCountedThreadSafe< ThreadSafeInterfacePtrBase<InterfacePtrType>>; struct PtrWrapperDeleter; // Helper class which owns an |InterfacePtrType| instance on an appropriate // sequence. This is kept alive as long its bound within some // ThreadSafeForwarder's callbacks. class PtrWrapper : public base::RefCountedThreadSafe<PtrWrapper, PtrWrapperDeleter> { public: explicit PtrWrapper(InterfacePtrType ptr) : PtrWrapper(base::SequencedTaskRunnerHandle::Get()) { ptr_ = std::move(ptr); associated_group_ = *ptr_.internal_state()->associated_group(); } explicit PtrWrapper( const scoped_refptr<base::SequencedTaskRunner>& task_runner) : task_runner_(task_runner) {} void BindOnTaskRunner(AssociatedInterfacePtrInfo<InterfaceType> ptr_info) { associated_group_ = AssociatedGroup(ptr_info.handle()); task_runner_->PostTask(FROM_HERE, base::Bind(&PtrWrapper::Bind, this, base::Passed(&ptr_info))); } void BindOnTaskRunner(InterfacePtrInfo<InterfaceType> ptr_info) { // TODO(yzhsen): At the momment we don't have a group controller // available. That means the user won't be able to pass associated // endpoints on this interface (at least not immediately). In order to fix // this, we need to create a MultiplexRouter immediately and bind it to // the interface pointer on the |task_runner_|. Therefore, MultiplexRouter // should be able to be created on a sequence different than the one that // it is supposed to listen on. 
crbug.com/682334 task_runner_->PostTask(FROM_HERE, base::Bind(&PtrWrapper::Bind, this, base::Passed(&ptr_info))); } std::unique_ptr<ThreadSafeForwarder<InterfaceType>> CreateForwarder() { return std::make_unique<ThreadSafeForwarder<InterfaceType>>( task_runner_, base::Bind(&PtrWrapper::Accept, this), base::Bind(&PtrWrapper::AcceptWithResponder, this), associated_group_); } private: friend struct PtrWrapperDeleter; ~PtrWrapper() {} void Bind(PtrInfoType ptr_info) { DCHECK(task_runner_->RunsTasksInCurrentSequence()); ptr_.Bind(std::move(ptr_info)); } void Accept(Message message) { ptr_.internal_state()->ForwardMessage(std::move(message)); } void AcceptWithResponder(Message message, std::unique_ptr<MessageReceiver> responder) { ptr_.internal_state()->ForwardMessageWithResponder(std::move(message), std::move(responder)); } void DeleteOnCorrectThread() const { if (!task_runner_->RunsTasksInCurrentSequence()) { // NOTE: This is only called when there are no more references to // |this|, so binding it unretained is both safe and necessary. task_runner_->PostTask(FROM_HERE, base::Bind(&PtrWrapper::DeleteOnCorrectThread, base::Unretained(this))); } else { delete this; } } InterfacePtrType ptr_; const scoped_refptr<base::SequencedTaskRunner> task_runner_; AssociatedGroup associated_group_; DISALLOW_COPY_AND_ASSIGN(PtrWrapper); }; struct PtrWrapperDeleter { static void Destruct(const PtrWrapper* interface_ptr) { interface_ptr->DeleteOnCorrectThread(); } }; ~ThreadSafeInterfacePtrBase() {} const std::unique_ptr<ThreadSafeForwarder<InterfaceType>> forwarder_; DISALLOW_COPY_AND_ASSIGN(ThreadSafeInterfacePtrBase); }; template <typename Interface> using ThreadSafeAssociatedInterfacePtr = ThreadSafeInterfacePtrBase<AssociatedInterfacePtr<Interface>>; template <typename Interface> using ThreadSafeInterfacePtr = ThreadSafeInterfacePtrBase<InterfacePtr<Interface>>; } // namespace mojo #endif // MOJO_PUBLIC_CPP_BINDINGS_THREAD_SAFE_INTERFACE_PTR_H_