// Plain text  |  324 lines  |  10.11 KB

// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "mojo/public/cpp/bindings/lib/router.h"

#include <stdint.h>

#include <utility>

#include "base/bind.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/memory/ptr_util.h"
#include "base/stl_util.h"
#include "mojo/public/cpp/bindings/sync_call_restrictions.h"

namespace mojo {
namespace internal {

// ----------------------------------------------------------------------------

namespace {

// Asserts (debug builds only) that |router| is no longer usable, i.e. it is
// gone, has encountered an error, or is not valid. |message| is streamed into
// the DCHECK failure output.
void DCheckIfInvalid(const base::WeakPtr<Router>& router,
                     const std::string& message) {
  bool is_valid = false;
  if (router)
    is_valid = !router->encountered_error() && router->is_valid();
  DCHECK(!is_valid) << message;
}

// Responder handed to the incoming receiver together with a request that
// expects a response. Responses passed to Accept() are sent back through the
// (weakly referenced) Router. If the thunk is destroyed without Accept() ever
// having been called, an error is raised on the Router instead.
class ResponderThunk : public MessageReceiverWithStatus {
 public:
  // |router| may be invalidated at any time. |runner| is the task runner used
  // to bounce work back when called from another thread.
  explicit ResponderThunk(const base::WeakPtr<Router>& router,
                          scoped_refptr<base::SingleThreadTaskRunner> runner)
      : router_(router), task_runner_(std::move(runner)) {}

  ~ResponderThunk() override {
    // A response was delivered; nothing to report.
    if (accept_was_invoked_)
      return;

    // The Mojo application handled a message that was expecting a response
    // but never sent one. Raise an error so the calling application knows to
    // stop waiting for a response.
    if (!task_runner_->RunsTasksOnCurrentThread()) {
      task_runner_->PostTask(FROM_HERE,
                             base::Bind(&Router::RaiseError, router_));
      return;
    }

    // Even if we are on |task_runner_|'s thread but running from a different
    // task runner, calling Router::RaiseError() directly is okay: it raises
    // the error asynchronously from the correct task runner.
    if (router_)
      router_->RaiseError();
  }

  // MessageReceiver implementation:
  bool Accept(Message* message) override {
    DCHECK(task_runner_->RunsTasksOnCurrentThread());
    accept_was_invoked_ = true;
    DCHECK(message->has_flag(Message::kFlagIsResponse));

    // A vanished router means the response is simply not delivered.
    return router_ ? router_->Accept(message) : false;
  }

  // MessageReceiverWithStatus implementation:
  bool IsValid() override {
    DCHECK(task_runner_->RunsTasksOnCurrentThread());
    if (!router_)
      return false;
    return !router_->encountered_error() && router_->is_valid();
  }

  void DCheckInvalid(const std::string& message) override {
    // The router may only be inspected on its own thread, so hop over to it
    // when necessary.
    if (!task_runner_->RunsTasksOnCurrentThread()) {
      task_runner_->PostTask(FROM_HERE,
                             base::Bind(&DCheckIfInvalid, router_, message));
      return;
    }
    DCheckIfInvalid(router_, message);
  }

 private:
  base::WeakPtr<Router> router_;
  // Set once Accept() has been called with a response.
  bool accept_was_invoked_ = false;
  scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
};

}  // namespace

// ----------------------------------------------------------------------------

// Bookkeeping for one in-flight sync request. |in_response_received| points at
// the flag that AcceptWithResponder() waits on in SyncWatch().
Router::SyncResponseInfo::SyncResponseInfo(bool* in_response_received)
    : response_received(in_response_received) {}

Router::SyncResponseInfo::~SyncResponseInfo() = default;

// ----------------------------------------------------------------------------

// Sink installed at the end of the Router's filter chain. It forwards every
// message that survives the filters to Router::HandleIncomingMessage().
Router::HandleIncomingMessageThunk::HandleIncomingMessageThunk(Router* router)
    : router_(router) {}

Router::HandleIncomingMessageThunk::~HandleIncomingMessageThunk() = default;

bool Router::HandleIncomingMessageThunk::Accept(Message* message) {
  Router* const owner = router_;
  return owner->HandleIncomingMessage(message);
}

// ----------------------------------------------------------------------------

// Builds a router over |message_pipe|. Incoming messages travel
// connector -> |filters_| -> |thunk_| -> HandleIncomingMessage(); outgoing
// messages go straight to |connector_|. |runner| is the task runner the
// connector (and thus this router) operates on.
Router::Router(ScopedMessagePipeHandle message_pipe,
               FilterChain filters,
               bool expects_sync_requests,
               scoped_refptr<base::SingleThreadTaskRunner> runner)
    : thunk_(this),
      filters_(std::move(filters)),
      connector_(std::move(message_pipe),
                 Connector::SINGLE_THREADED_SEND,
                 std::move(runner)),
      incoming_receiver_(nullptr),
      next_request_id_(0),
      testing_mode_(false),
      pending_task_for_messages_(false),
      encountered_error_(false),
      weak_factory_(this) {
  // Terminate the filter chain with |thunk_| before wiring the chain's head
  // into the connector.
  filters_.SetSink(&thunk_);

  // NOTE(review): presumably lets incoming sync requests be dispatched while
  // this thread is blocked in SyncWatch(); confirm against Connector.
  if (expects_sync_requests) {
    connector_.AllowWokenUpBySyncWatchOnSameThread();
  }

  connector_.set_incoming_receiver(filters_.GetHead());
  connector_.set_connection_error_handler(
      base::Bind(&Router::OnConnectionError, base::Unretained(this)));
}

Router::~Router() = default;

// Sends a message that does not expect a response. Requests that do expect
// one must go through AcceptWithResponder(). Returns the connector's result.
bool Router::Accept(Message* message) {
  DCHECK(thread_checker_.CalledOnValidThread());
  DCHECK(!message->has_flag(Message::kFlagExpectsResponse));

  const bool sent = connector_.Accept(message);
  return sent;
}

// Sends |message|, which must expect a response, after stamping it with a
// fresh non-zero request id.
//  - Async request: |responder| is stored in |async_responders_| and invoked
//    from HandleMessageInternal() when the response arrives.
//  - Sync request: blocks in SyncWatch() until the wait ends. If a response
//    was received by then, it is handed to |responder|.
// Returns false only if the connector refused the message. In that case
// |responder| is NOT taken and remains owned by the caller. Returning true
// means this router has taken ownership of |responder|.
bool Router::AcceptWithResponder(Message* message, MessageReceiver* responder) {
  DCHECK(thread_checker_.CalledOnValidThread());
  DCHECK(message->has_flag(Message::kFlagExpectsResponse));

  // Request id 0 is never handed out, in case we want it to convey special
  // meaning in the future.
  uint64_t request_id;
  do {
    request_id = next_request_id_++;
  } while (request_id == 0);

  const bool is_sync = message->has_flag(Message::kFlagIsSync);
  message->set_request_id(request_id);
  if (!connector_.Accept(message))
    return false;

  if (!is_sync) {
    // We assume ownership of |responder|.
    async_responders_[request_id] = base::WrapUnique(responder);
    return true;
  }

  SyncCallRestrictions::AssertSyncCallAllowed();

  // From here on |responder| is ours, whatever happens.
  std::unique_ptr<MessageReceiver> owned_responder(responder);
  bool response_received = false;
  sync_responses_.insert(std::make_pair(
      request_id, base::WrapUnique(new SyncResponseInfo(&response_received))));

  // |this| may be destroyed while we are waiting in SyncWatch().
  base::WeakPtr<Router> self = weak_factory_.GetWeakPtr();
  connector_.SyncWatch(&response_received);
  if (!self)
    return true;

  auto entry = sync_responses_.find(request_id);
  DCHECK(entry != sync_responses_.end());
  DCHECK_EQ(&response_received, entry->second->response_received);
  if (response_received) {
    std::unique_ptr<Message> response = std::move(entry->second->response);
    ignore_result(owned_responder->Accept(response.get()));
  }
  sync_responses_.erase(entry);
  return true;
}

// Switches the router into testing mode. In this mode a rejected queued
// message does not raise an error (see HandleQueuedMessages()). The connector
// is likewise told not to enforce errors reported by its incoming receiver.
void Router::EnableTestingMode() {
  DCHECK(thread_checker_.CalledOnValidThread());

  testing_mode_ = true;
  connector_.set_enforce_errors_from_incoming_receiver(false);
}

// Entry point for messages that made it through the filter chain. Sync
// messages are dispatched immediately. Non-sync messages are deferred to
// HandleQueuedMessages() while a sync handle watcher callback is running, or
// while older messages are still queued, so earlier messages are dispatched
// first. Deferred messages always report success.
bool Router::HandleIncomingMessage(Message* message) {
  DCHECK(thread_checker_.CalledOnValidThread());

  const bool in_sync_callback =
      connector_.during_sync_handle_watcher_callback();
  const bool dispatch_now = message->has_flag(Message::kFlagIsSync) ||
                            (!in_sync_callback && pending_messages_.empty());
  if (dispatch_now)
    return HandleMessageInternal(message);

  // Move the contents into a heap-allocated Message so it can be queued.
  std::unique_ptr<Message> deferred(new Message);
  message->MoveTo(deferred.get());
  pending_messages_.push(std::move(deferred));

  // At most one HandleQueuedMessages() task is outstanding at a time.
  if (!pending_task_for_messages_) {
    pending_task_for_messages_ = true;
    connector_.task_runner()->PostTask(
        FROM_HERE,
        base::Bind(&Router::HandleQueuedMessages, weak_factory_.GetWeakPtr()));
  }
  return true;
}

// Posted task that dispatches, in FIFO order, the messages deferred by
// HandleIncomingMessage(). At most one such task is pending at a time
// (tracked by |pending_task_for_messages_|). Once the queue has been handled,
// it delivers any connection error that OnConnectionError() postponed while
// messages were still queued.
void Router::HandleQueuedMessages() {
  DCHECK(thread_checker_.CalledOnValidThread());
  DCHECK(pending_task_for_messages_);

  base::WeakPtr<Router> weak_self = weak_factory_.GetWeakPtr();
  while (!pending_messages_.empty()) {
    std::unique_ptr<Message> message(std::move(pending_messages_.front()));
    pending_messages_.pop();

    bool result = HandleMessageInternal(message.get());
    // Dispatching may have destroyed |this|.
    if (!weak_self)
      return;

    if (!result && !testing_mode_) {
      connector_.RaiseError();
      // The connection is now in an error state, so the remaining queued
      // messages must not be dispatched. Drop them. Leaving them queued would
      // make OnConnectionError() bail out (it defers while
      // |pending_messages_| is non-empty, expecting a pending task to drain
      // it). No task would ever do so, and the user's error handler would
      // never run.
      while (!pending_messages_.empty())
        pending_messages_.pop();
      break;
    }
  }

  pending_task_for_messages_ = false;

  // We may have already seen a connection error from the connector, but
  // haven't notified the user because we want to process all the queued
  // messages first. We should do it now.
  if (connector_.encountered_error() && !encountered_error_)
    OnConnectionError();
}

// Routes one message that has passed the filter chain:
//  - Requests expecting a response go to |incoming_receiver_| together with a
//    ResponderThunk that sends the reply back through this router.
//  - Responses are matched by request id against pending sync calls
//    (|sync_responses_|) or pending async responders (|async_responders_|).
//  - Everything else goes to |incoming_receiver_| as-is.
// Returns false if the message could not be handled (no receiver, unknown
// request id, or the receiver rejected it).
bool Router::HandleMessageInternal(Message* message) {
  if (message->has_flag(Message::kFlagExpectsResponse)) {
    if (!incoming_receiver_)
      return false;

    // Owned here until |incoming_receiver_| accepts it. If the request is
    // rejected, the thunk is destroyed without its Accept() having been
    // called, which raises an error on this router (see ~ResponderThunk()).
    std::unique_ptr<MessageReceiverWithStatus> responder(new ResponderThunk(
        weak_factory_.GetWeakPtr(), connector_.task_runner()));
    if (!incoming_receiver_->AcceptWithResponder(message, responder.get()))
      return false;

    // On success |incoming_receiver_| has taken ownership of the responder.
    ignore_result(responder.release());
    return true;
  }

  if (message->has_flag(Message::kFlagIsResponse)) {
    const uint64_t request_id = message->request_id();

    if (message->has_flag(Message::kFlagIsSync)) {
      auto it = sync_responses_.find(request_id);
      if (it == sync_responses_.end()) {
        DCHECK(testing_mode_);
        return false;
      }
      // Stash the response and set the flag that AcceptWithResponder() is
      // waiting on in SyncWatch().
      it->second->response.reset(new Message());
      message->MoveTo(it->second->response.get());
      *it->second->response_received = true;
      return true;
    }

    auto it = async_responders_.find(request_id);
    if (it == async_responders_.end()) {
      DCHECK(testing_mode_);
      return false;
    }
    // Take ownership and erase the entry before dispatching the response.
    std::unique_ptr<MessageReceiver> responder = std::move(it->second);
    async_responders_.erase(it);
    return responder->Accept(message);
  }

  if (!incoming_receiver_)
    return false;
  return incoming_receiver_->Accept(message);
}

// Installed as |connector_|'s connection error handler. It is also invoked
// from HandleQueuedMessages() and from a re-posted task when delivery had to
// be deferred. Runs the user's |error_handler_| at most once, guarded by
// |encountered_error_|.
void Router::OnConnectionError() {
  if (encountered_error_)
    return;

  // Queued messages are delivered first. HandleQueuedMessages() re-checks the
  // connector's error state after processing them and calls back here.
  if (!pending_messages_.empty()) {
    DCHECK(pending_task_for_messages_);
    return;
  }

  // We don't want the error handler to reenter an ongoing sync call, so retry
  // from a fresh task.
  if (connector_.during_sync_handle_watcher_callback()) {
    connector_.task_runner()->PostTask(
        FROM_HERE,
        base::Bind(&Router::OnConnectionError, weak_factory_.GetWeakPtr()));
    return;
  }

  encountered_error_ = true;
  if (error_handler_.is_null())
    return;
  error_handler_.Run();
}

// ----------------------------------------------------------------------------

}  // namespace internal
}  // namespace mojo