// Copyright 2016 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MOJO_PUBLIC_CPP_BINDINGS_THREAD_SAFE_INTERFACE_PTR_H_
#define MOJO_PUBLIC_CPP_BINDINGS_THREAD_SAFE_INTERFACE_PTR_H_
#include <memory>
#include "base/macros.h"
#include "base/memory/ptr_util.h"
#include "base/memory/ref_counted.h"
#include "base/stl_util.h"
#include "base/synchronization/waitable_event.h"
#include "base/task_runner.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "mojo/public/cpp/bindings/associated_group.h"
#include "mojo/public/cpp/bindings/associated_interface_ptr.h"
#include "mojo/public/cpp/bindings/interface_ptr.h"
#include "mojo/public/cpp/bindings/message.h"
#include "mojo/public/cpp/bindings/sync_call_restrictions.h"
#include "mojo/public/cpp/bindings/sync_event_watcher.h"
// ThreadSafeInterfacePtr wraps a non-thread-safe InterfacePtr and proxies
// messages to it. Async calls are posted to the sequence that the InterfacePtr
// is bound to, and the responses are posted back. Sync calls are dispatched
// directly if the call is made on the sequence that the wrapped InterfacePtr is
// bound to, or posted otherwise. It's important to be aware that sync calls
// block both the calling sequence and the InterfacePtr sequence. That means
// that you cannot make sync calls through a ThreadSafeInterfacePtr if the
// underlying InterfacePtr is bound to a sequence that cannot block, like the IO
// thread.
namespace mojo {
// Instances of this class may be used from any sequence to serialize
// |Interface| messages and forward them elsewhere. In general you should use
// one of the ThreadSafeInterfacePtrBase helper aliases defined below, but this
// type may be useful if you need/want to manually manage the lifetime of the
// underlying proxy object which will be used to ultimately send messages.
template <typename Interface>
class ThreadSafeForwarder : public MessageReceiverWithResponder {
public:
// The proxy type declared by |Interface|. |proxy_| is an instance of it and is
// constructed with this forwarder as its argument.
using ProxyType = typename Interface::Proxy_;
// Signature of the callback that is posted to |task_runner| for every message
// passed to Accept().
using ForwardMessageCallback = base::Callback<void(Message)>;
// Signature of the callback used for every message passed to
// AcceptWithResponder(). The second argument is the MessageReceiver handed
// over alongside the message: a ForwardToCallingThread, a
// SyncResponseSignaler, or the caller's own responder, depending on the path
// taken in AcceptWithResponder().
using ForwardMessageWithResponderCallback =
base::Callback<void(Message, std::unique_ptr<MessageReceiver>)>;
// Constructs a ThreadSafeForwarder through which Messages are forwarded to
// |forward| or |forward_with_responder| by posting to |task_runner|.
//
// Any message sent through this forwarding interface will dispatch its reply,
// if any, back to the sequence which called the corresponding interface
// method.
//
// |associated_group| is copied into the forwarder; its controller is used to
// serialize associated endpoint handles attached to outgoing messages.
ThreadSafeForwarder(
const scoped_refptr<base::SequencedTaskRunner>& task_runner,
const ForwardMessageCallback& forward,
const ForwardMessageWithResponderCallback& forward_with_responder,
const AssociatedGroup& associated_group)
: proxy_(this),
task_runner_(task_runner),
forward_(forward),
forward_with_responder_(forward_with_responder),
associated_group_(associated_group),
sync_calls_(new InProgressSyncCalls()) {}
~ThreadSafeForwarder() override {
// If there are ongoing sync calls signal their completion now. Every
// response still registered in |sync_calls_| has its event signaled while
// holding the list's lock. |received| is not touched here, so a call that
// had not yet stored a reply will not invoke its responder.
base::AutoLock l(sync_calls_->lock);
for (const auto& pending_response : sync_calls_->pending_responses)
pending_response->event.Signal();
}
// Returns the proxy object owned by this forwarder.
ProxyType& proxy() { return proxy_; }
private:
// MessageReceiverWithResponder implementation:
bool PrefersSerializedMessages() override {
// TSIP is primarily used because it emulates legacy IPC threading behavior.
// In practice this means it's only for cross-process messaging and we can
// just always assume messages should be serialized.
return true;
}
// Forwards a message that carries no responder: serializes any associated
// endpoint handles it holds, then posts |forward_| bound to the message to
// |task_runner_|. Always returns true; the result of PostTask() is not
// inspected.
bool Accept(Message* message) override {
if (!message->associated_endpoint_handles()->empty()) {
// If this DCHECK fails, it is likely because:
// - This is a non-associated interface pointer setup using
// PtrWrapper::BindOnTaskRunner(
// InterfacePtrInfo<InterfaceType> ptr_info);
// Please see the TODO in that method.
// - This is an associated interface which hasn't been associated with a
// message pipe. In other words, the corresponding
// AssociatedInterfaceRequest hasn't been sent.
DCHECK(associated_group_.GetController());
message->SerializeAssociatedEndpointHandles(
associated_group_.GetController());
}
// The message contents are moved into the bound callback.
task_runner_->PostTask(FROM_HERE,
base::Bind(forward_, base::Passed(message)));
return true;
}
// Forwards a message together with |responder|. Three paths exist:
// - Message without kFlagIsSync: always posted to |task_runner_|, with
// |responder| wrapped in a ForwardToCallingThread.
// - Sync message while already on |task_runner_|'s sequence:
// |forward_with_responder_| is run inline with the caller's |responder|.
// - Sync message from any other sequence: posted with a
// SyncResponseSignaler, after which this method waits in SyncWatch() and
// hands the stored reply (if one was received) to |responder|.
// Always returns true.
bool AcceptWithResponder(
Message* message,
std::unique_ptr<MessageReceiver> responder) override {
if (!message->associated_endpoint_handles()->empty()) {
// Please see comment for the DCHECK in the previous method.
DCHECK(associated_group_.GetController());
message->SerializeAssociatedEndpointHandles(
associated_group_.GetController());
}
// Async messages are always posted (even if |task_runner_| runs tasks on
// this sequence) to guarantee that two async calls can't be reordered.
if (!message->has_flag(Message::kFlagIsSync)) {
// ForwardToCallingThread captures the current sequence's task runner in
// its constructor, so it must be created here, before posting.
auto reply_forwarder =
std::make_unique<ForwardToCallingThread>(std::move(responder));
task_runner_->PostTask(
FROM_HERE, base::Bind(forward_with_responder_, base::Passed(message),
base::Passed(&reply_forwarder)));
return true;
}
SyncCallRestrictions::AssertSyncCallAllowed();
// If the InterfacePtr is bound to this sequence, dispatch it directly.
if (task_runner_->RunsTasksInCurrentSequence()) {
forward_with_responder_.Run(std::move(*message), std::move(responder));
return true;
}
// If the InterfacePtr is bound on another sequence, post the call.
// TODO(yzshen, watk): We block both this sequence and the InterfacePtr
// sequence. Ideally only this sequence would block.
// |response| is shared with the SyncResponseSignaler that travels with the
// posted task; the signaler stores the reply in it and signals its event.
auto response = base::MakeRefCounted<SyncResponseInfo>();
auto response_signaler = std::make_unique<SyncResponseSignaler>(response);
task_runner_->PostTask(
FROM_HERE, base::Bind(forward_with_responder_, base::Passed(message),
base::Passed(&response_signaler)));
// Save the pending SyncResponseInfo so that if the sync call deletes
// |this|, we can signal the completion of the call to return from
// SyncWatch().
// A local reference to the list is taken so that it stays usable after
// SyncWatch() even if |this| (and therefore |sync_calls_|) is gone.
auto sync_calls = sync_calls_;
{
base::AutoLock l(sync_calls->lock);
sync_calls->pending_responses.push_back(response.get());
}
// |event_signaled| is written by the callback given to |watcher| and is the
// only stop flag handed to SyncWatch().
// NOTE(review): presumably SyncEventWatcher runs the callback once
// |response->event| is signaled and SyncWatch() returns when the flag
// becomes true - confirm against sync_event_watcher.h.
auto assign_true = [](bool* b) { *b = true; };
bool event_signaled = false;
SyncEventWatcher watcher(&response->event,
base::Bind(assign_true, &event_signaled));
const bool* stop_flags[] = {&event_signaled};
watcher.SyncWatch(stop_flags, 1);
// From here on only locals are used; no member of |this| is accessed.
{
base::AutoLock l(sync_calls->lock);
base::Erase(sync_calls->pending_responses, response.get());
}
// |received| is only set by SyncResponseSignaler::Accept(). If the event
// was signaled for another reason (signaler destroyed without a reply, or
// this forwarder destroyed) no reply is delivered to |responder|.
if (response->received)
ignore_result(responder->Accept(&response->message));
return true;
}
// Data that we need to share between the sequences involved in a sync call.
struct SyncResponseInfo
: public base::RefCountedThreadSafe<SyncResponseInfo> {
// The reply, valid only when |received| is true.
Message message;
// Set to true by SyncResponseSignaler::Accept() after |message| is stored.
bool received = false;
// Signaled when the sync call is finished, with or without a reply.
// Manual-reset, initially not signaled.
base::WaitableEvent event{base::WaitableEvent::ResetPolicy::MANUAL,
base::WaitableEvent::InitialState::NOT_SIGNALED};
private:
friend class base::RefCountedThreadSafe<SyncResponseInfo>;
};
// A MessageReceiver that signals |response| when it either accepts the
// response message, or is destructed.
class SyncResponseSignaler : public MessageReceiver {
public:
explicit SyncResponseSignaler(scoped_refptr<SyncResponseInfo> response)
: response_(response) {}
~SyncResponseSignaler() override {
// If Accept() was not called we must still notify the waiter that the
// sync call is finished.
// |response_| is reset to null at the end of Accept(), so a non-null
// value here means no reply was ever stored.
if (response_)
response_->event.Signal();
}
// Moves |message| into the shared SyncResponseInfo, marks it received,
// signals the event and drops this object's reference. Always returns true.
bool Accept(Message* message) override {
response_->message = std::move(*message);
response_->received = true;
response_->event.Signal();
response_ = nullptr;
return true;
}
private:
// Shared with the waiting AcceptWithResponder() call; null once Accept()
// has run.
scoped_refptr<SyncResponseInfo> response_;
};
// A record of the pending sync responses for canceling pending sync calls
// when the owning ThreadSafeForwarder is destructed.
struct InProgressSyncCalls
: public base::RefCountedThreadSafe<InProgressSyncCalls> {
// |lock| protects access to |pending_responses|.
base::Lock lock;
// Non-owning pointers. Each entry is added and later erased by the
// AcceptWithResponder() call that holds a scoped_refptr to the same object
// for the whole time the entry is in the list.
std::vector<SyncResponseInfo*> pending_responses;
};
// A MessageReceiver that wraps |responder| and relays an accepted message to
// it on the sequence this object was constructed on (the task runner is
// captured from base::SequencedTaskRunnerHandle::Get() in the constructor).
class ForwardToCallingThread : public MessageReceiver {
public:
explicit ForwardToCallingThread(std::unique_ptr<MessageReceiver> responder)
: responder_(std::move(responder)),
caller_task_runner_(base::SequencedTaskRunnerHandle::Get()) {}
~ForwardToCallingThread() override {
// Hand |responder_| to the calling sequence for deletion. If Accept() ran,
// |responder_| has already been moved into the task posted there.
caller_task_runner_->DeleteSoon(FROM_HERE, std::move(responder_));
}
private:
// Posts a task to the calling sequence that owns both |responder_| and the
// contents of |message| and runs CallAcceptAndDeleteResponder() there.
// Always returns true.
bool Accept(Message* message) override {
// The current instance will be deleted when this method returns, so we
// have to relinquish the responder's ownership so it does not get
// deleted.
caller_task_runner_->PostTask(
FROM_HERE,
base::Bind(&ForwardToCallingThread::CallAcceptAndDeleteResponder,
base::Passed(std::move(responder_)),
base::Passed(std::move(*message))));
return true;
}
// Delivers |message| to |responder| and ignores the result. |responder| is
// taken by value, so it is destroyed when this function returns.
static void CallAcceptAndDeleteResponder(
std::unique_ptr<MessageReceiver> responder,
Message message) {
ignore_result(responder->Accept(&message));
}
// The wrapped receiver; moved out by Accept() or by the destructor.
std::unique_ptr<MessageReceiver> responder_;
// Task runner of the sequence that constructed this object.
scoped_refptr<base::SequencedTaskRunner> caller_task_runner_;
};
// Proxy constructed with |this| as its argument; exposed through proxy().
ProxyType proxy_;
// Task runner that forwarded messages are posted to.
const scoped_refptr<base::SequencedTaskRunner> task_runner_;
// Callback bound to each message passed to Accept().
const ForwardMessageCallback forward_;
// Callback bound to (or run with) each message passed to
// AcceptWithResponder().
const ForwardMessageWithResponderCallback forward_with_responder_;
// Copy of the group passed to the constructor; supplies the controller used
// to serialize associated endpoint handles.
AssociatedGroup associated_group_;
// Ref-counted so that an in-flight sync call can keep using the list after
// this forwarder has been destroyed.
scoped_refptr<InProgressSyncCalls> sync_calls_;
DISALLOW_COPY_AND_ASSIGN(ThreadSafeForwarder);
};
// Ref-counted, thread-safe front end for an |InterfacePtrType|. It owns a
// ThreadSafeForwarder whose callbacks keep a PtrWrapper (and through it the
// real, non-thread-safe interface pointer) alive. Use one of the Create()
// factories to build an instance, then call interface methods via get(),
// operator-> or operator*.
template <typename InterfacePtrType>
class ThreadSafeInterfacePtrBase
    : public base::RefCountedThreadSafe<
          ThreadSafeInterfacePtrBase<InterfacePtrType>> {
 public:
  using InterfaceType = typename InterfacePtrType::InterfaceType;
  using PtrInfoType = typename InterfacePtrType::PtrInfoType;

  // Takes ownership of |forwarder|; every call made through this object goes
  // to that forwarder's proxy.
  explicit ThreadSafeInterfacePtrBase(
      std::unique_ptr<ThreadSafeForwarder<InterfaceType>> forwarder)
      : forwarder_(std::move(forwarder)) {}

  // Creates a ThreadSafeInterfacePtrBase wrapping an underlying non-thread-safe
  // InterfacePtrType which is bound to the calling sequence. All messages sent
  // via this thread-safe proxy will internally be sent by first posting to this
  // (the calling) sequence's TaskRunner.
  static scoped_refptr<ThreadSafeInterfacePtrBase> Create(
      InterfacePtrType interface_ptr) {
    scoped_refptr<PtrWrapper> ptr_wrapper(
        new PtrWrapper(std::move(interface_ptr)));
    return scoped_refptr<ThreadSafeInterfacePtrBase>(
        new ThreadSafeInterfacePtrBase(ptr_wrapper->CreateForwarder()));
  }

  // Creates a ThreadSafeInterfacePtrBase which binds the underlying
  // non-thread-safe InterfacePtrType on the specified TaskRunner. All messages
  // sent via this thread-safe proxy will internally be sent by first posting to
  // that TaskRunner.
  static scoped_refptr<ThreadSafeInterfacePtrBase> Create(
      PtrInfoType ptr_info,
      const scoped_refptr<base::SequencedTaskRunner>& bind_task_runner) {
    scoped_refptr<PtrWrapper> ptr_wrapper(new PtrWrapper(bind_task_runner));
    ptr_wrapper->BindOnTaskRunner(std::move(ptr_info));
    return scoped_refptr<ThreadSafeInterfacePtrBase>(
        new ThreadSafeInterfacePtrBase(ptr_wrapper->CreateForwarder()));
  }

  // Accessors for the forwarder's proxy, viewed as |InterfaceType|.
  InterfaceType* get() {
    auto& proxy = forwarder_->proxy();
    return &proxy;
  }
  InterfaceType* operator->() { return get(); }
  InterfaceType& operator*() { return *get(); }

 private:
  friend class base::RefCountedThreadSafe<
      ThreadSafeInterfacePtrBase<InterfacePtrType>>;

  struct PtrWrapperDeleter;

  // Helper class which owns an |InterfacePtrType| instance on an appropriate
  // sequence. This is kept alive as long as it's bound within some
  // ThreadSafeForwarder's callbacks.
  class PtrWrapper
      : public base::RefCountedThreadSafe<PtrWrapper, PtrWrapperDeleter> {
   public:
    // Wraps an already-bound |ptr|, using the current sequence's task runner,
    // and copies the pointer's associated group.
    explicit PtrWrapper(InterfacePtrType ptr)
        : PtrWrapper(base::SequencedTaskRunnerHandle::Get()) {
      ptr_ = std::move(ptr);
      associated_group_ = *ptr_.internal_state()->associated_group();
    }

    // Creates a wrapper with no pointer yet; BindOnTaskRunner() is expected to
    // supply one on |task_runner|.
    explicit PtrWrapper(
        const scoped_refptr<base::SequencedTaskRunner>& task_runner)
        : task_runner_(task_runner) {}

    // Records the associated group of |ptr_info|'s handle, then posts a task
    // to |task_runner_| that binds |ptr_| to |ptr_info|.
    void BindOnTaskRunner(AssociatedInterfacePtrInfo<InterfaceType> ptr_info) {
      associated_group_ = AssociatedGroup(ptr_info.handle());
      auto bind_closure =
          base::Bind(&PtrWrapper::Bind, this, base::Passed(&ptr_info));
      task_runner_->PostTask(FROM_HERE, std::move(bind_closure));
    }

    // Posts a task to |task_runner_| that binds |ptr_| to |ptr_info|. No
    // associated group is recorded on this path.
    void BindOnTaskRunner(InterfacePtrInfo<InterfaceType> ptr_info) {
      // TODO(yzshen): At the moment we don't have a group controller
      // available. That means the user won't be able to pass associated
      // endpoints on this interface (at least not immediately). In order to fix
      // this, we need to create a MultiplexRouter immediately and bind it to
      // the interface pointer on the |task_runner_|. Therefore, MultiplexRouter
      // should be able to be created on a sequence different than the one that
      // it is supposed to listen on. crbug.com/682334
      auto bind_closure =
          base::Bind(&PtrWrapper::Bind, this, base::Passed(&ptr_info));
      task_runner_->PostTask(FROM_HERE, std::move(bind_closure));
    }

    // Builds a forwarder that posts to |task_runner_| and whose callbacks are
    // bound to this wrapper's Accept() / AcceptWithResponder().
    std::unique_ptr<ThreadSafeForwarder<InterfaceType>> CreateForwarder() {
      auto accept_callback = base::Bind(&PtrWrapper::Accept, this);
      auto accept_with_responder_callback =
          base::Bind(&PtrWrapper::AcceptWithResponder, this);
      return std::make_unique<ThreadSafeForwarder<InterfaceType>>(
          task_runner_, accept_callback, accept_with_responder_callback,
          associated_group_);
    }

   private:
    friend struct PtrWrapperDeleter;

    // Private: destruction goes through PtrWrapperDeleter.
    ~PtrWrapper() = default;

    // Binds |ptr_|; must run on |task_runner_|'s sequence.
    void Bind(PtrInfoType ptr_info) {
      DCHECK(task_runner_->RunsTasksInCurrentSequence());
      ptr_.Bind(std::move(ptr_info));
    }

    // Hands a message with no responder to the wrapped pointer.
    void Accept(Message message) {
      ptr_.internal_state()->ForwardMessage(std::move(message));
    }

    // Hands a message and its |responder| to the wrapped pointer.
    void AcceptWithResponder(Message message,
                             std::unique_ptr<MessageReceiver> responder) {
      ptr_.internal_state()->ForwardMessageWithResponder(std::move(message),
                                                         std::move(responder));
    }

    // Deletes |this| if already on |task_runner_|'s sequence; otherwise
    // re-posts itself to that sequence.
    void DeleteOnCorrectThread() const {
      if (task_runner_->RunsTasksInCurrentSequence()) {
        delete this;
        return;
      }
      // NOTE: This is only called when there are no more references to
      // |this|, so binding it unretained is both safe and necessary.
      task_runner_->PostTask(FROM_HERE,
                             base::Bind(&PtrWrapper::DeleteOnCorrectThread,
                                        base::Unretained(this)));
    }

    InterfacePtrType ptr_;
    const scoped_refptr<base::SequencedTaskRunner> task_runner_;
    AssociatedGroup associated_group_;

    DISALLOW_COPY_AND_ASSIGN(PtrWrapper);
  };

  // Ref-count traits for PtrWrapper: routes destruction through
  // PtrWrapper::DeleteOnCorrectThread().
  struct PtrWrapperDeleter {
    static void Destruct(const PtrWrapper* wrapper) {
      wrapper->DeleteOnCorrectThread();
    }
  };

  // Private: instances are destroyed through the ref-count base class.
  ~ThreadSafeInterfacePtrBase() = default;

  const std::unique_ptr<ThreadSafeForwarder<InterfaceType>> forwarder_;

  DISALLOW_COPY_AND_ASSIGN(ThreadSafeInterfacePtrBase);
};
// Convenience alias: ThreadSafeInterfacePtrBase instantiated for
// AssociatedInterfacePtr<Interface>.
template <typename Interface>
using ThreadSafeAssociatedInterfacePtr =
ThreadSafeInterfacePtrBase<AssociatedInterfacePtr<Interface>>;
// Convenience alias: ThreadSafeInterfacePtrBase instantiated for
// InterfacePtr<Interface>.
template <typename Interface>
using ThreadSafeInterfacePtr =
ThreadSafeInterfacePtrBase<InterfacePtr<Interface>>;
} // namespace mojo
#endif // MOJO_PUBLIC_CPP_BINDINGS_THREAD_SAFE_INTERFACE_PTR_H_