// Copyright 2017 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "mojo/core/channel.h" #include "base/bind.h" #include "base/memory/ptr_util.h" #include "base/message_loop/message_loop.h" #include "base/threading/thread.h" #include "mojo/core/platform_handle_utils.h" #include "mojo/public/cpp/platform/platform_channel.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" namespace mojo { namespace core { namespace { class TestChannel : public Channel { public: TestChannel(Channel::Delegate* delegate) : Channel(delegate) {} char* GetReadBufferTest(size_t* buffer_capacity) { return GetReadBuffer(buffer_capacity); } bool OnReadCompleteTest(size_t bytes_read, size_t* next_read_size_hint) { return OnReadComplete(bytes_read, next_read_size_hint); } MOCK_METHOD7(GetReadPlatformHandles, bool(const void* payload, size_t payload_size, size_t num_handles, const void* extra_header, size_t extra_header_size, std::vector<PlatformHandle>* handles, bool* deferred)); MOCK_METHOD0(Start, void()); MOCK_METHOD0(ShutDownImpl, void()); MOCK_METHOD0(LeakHandle, void()); void Write(MessagePtr message) override {} protected: ~TestChannel() override {} }; // Not using GMock as I don't think it supports movable types. class MockChannelDelegate : public Channel::Delegate { public: MockChannelDelegate() {} size_t GetReceivedPayloadSize() const { return payload_size_; } const void* GetReceivedPayload() const { return payload_.get(); } protected: void OnChannelMessage(const void* payload, size_t payload_size, std::vector<PlatformHandle> handles) override { payload_.reset(new char[payload_size]); memcpy(payload_.get(), payload, payload_size); payload_size_ = payload_size; } // Notify that an error has occured and the Channel will cease operation. void OnChannelError(Channel::Error error) override {} private: size_t payload_size_ = 0; std::unique_ptr<char[]> payload_; }; Channel::MessagePtr CreateDefaultMessage(bool legacy_message) { const size_t payload_size = 100; Channel::MessagePtr message = std::make_unique<Channel::Message>( payload_size, 0, legacy_message ? Channel::Message::MessageType::NORMAL_LEGACY : Channel::Message::MessageType::NORMAL); char* payload = static_cast<char*>(message->mutable_payload()); for (size_t i = 0; i < payload_size; i++) { payload[i] = static_cast<char>(i); } return message; } void TestMemoryEqual(const void* data1, size_t data1_size, const void* data2, size_t data2_size) { ASSERT_EQ(data1_size, data2_size); const unsigned char* data1_char = static_cast<const unsigned char*>(data1); const unsigned char* data2_char = static_cast<const unsigned char*>(data2); for (size_t i = 0; i < data1_size; i++) { // ASSERT so we don't log tons of errors if the data is different. ASSERT_EQ(data1_char[i], data2_char[i]); } } void TestMessagesAreEqual(Channel::Message* message1, Channel::Message* message2, bool legacy_messages) { // If any of the message is null, this is probably not what you wanted to // test. ASSERT_NE(nullptr, message1); ASSERT_NE(nullptr, message2); ASSERT_EQ(message1->payload_size(), message2->payload_size()); EXPECT_EQ(message1->has_handles(), message2->has_handles()); TestMemoryEqual(message1->payload(), message1->payload_size(), message2->payload(), message2->payload_size()); if (legacy_messages) return; ASSERT_EQ(message1->extra_header_size(), message2->extra_header_size()); TestMemoryEqual(message1->extra_header(), message1->extra_header_size(), message2->extra_header(), message2->extra_header_size()); } TEST(ChannelTest, LegacyMessageDeserialization) { Channel::MessagePtr message = CreateDefaultMessage(true /* legacy_message */); Channel::MessagePtr deserialized_message = Channel::Message::Deserialize(message->data(), message->data_num_bytes()); TestMessagesAreEqual(message.get(), deserialized_message.get(), true /* legacy_message */); } TEST(ChannelTest, NonLegacyMessageDeserialization) { Channel::MessagePtr message = CreateDefaultMessage(false /* legacy_message */); Channel::MessagePtr deserialized_message = Channel::Message::Deserialize(message->data(), message->data_num_bytes()); TestMessagesAreEqual(message.get(), deserialized_message.get(), false /* legacy_message */); } TEST(ChannelTest, OnReadLegacyMessage) { size_t buffer_size = 100 * 1024; Channel::MessagePtr message = CreateDefaultMessage(true /* legacy_message */); MockChannelDelegate channel_delegate; scoped_refptr<TestChannel> channel = new TestChannel(&channel_delegate); char* read_buffer = channel->GetReadBufferTest(&buffer_size); ASSERT_LT(message->data_num_bytes(), buffer_size); // Bad test. Increase buffer // size. memcpy(read_buffer, message->data(), message->data_num_bytes()); size_t next_read_size_hint = 0; EXPECT_TRUE(channel->OnReadCompleteTest(message->data_num_bytes(), &next_read_size_hint)); TestMemoryEqual(message->payload(), message->payload_size(), channel_delegate.GetReceivedPayload(), channel_delegate.GetReceivedPayloadSize()); } TEST(ChannelTest, OnReadNonLegacyMessage) { size_t buffer_size = 100 * 1024; Channel::MessagePtr message = CreateDefaultMessage(false /* legacy_message */); MockChannelDelegate channel_delegate; scoped_refptr<TestChannel> channel = new TestChannel(&channel_delegate); char* read_buffer = channel->GetReadBufferTest(&buffer_size); ASSERT_LT(message->data_num_bytes(), buffer_size); // Bad test. Increase buffer // size. memcpy(read_buffer, message->data(), message->data_num_bytes()); size_t next_read_size_hint = 0; EXPECT_TRUE(channel->OnReadCompleteTest(message->data_num_bytes(), &next_read_size_hint)); TestMemoryEqual(message->payload(), message->payload_size(), channel_delegate.GetReceivedPayload(), channel_delegate.GetReceivedPayloadSize()); } class ChannelTestShutdownAndWriteDelegate : public Channel::Delegate { public: ChannelTestShutdownAndWriteDelegate( PlatformChannelEndpoint endpoint, scoped_refptr<base::TaskRunner> task_runner, scoped_refptr<Channel> client_channel, std::unique_ptr<base::Thread> client_thread, base::RepeatingClosure quit_closure) : quit_closure_(std::move(quit_closure)), client_channel_(std::move(client_channel)), client_thread_(std::move(client_thread)) { channel_ = Channel::Create(this, ConnectionParams(std::move(endpoint)), std::move(task_runner)); channel_->Start(); } ~ChannelTestShutdownAndWriteDelegate() override { channel_->ShutDown(); } // Channel::Delegate implementation void OnChannelMessage(const void* payload, size_t payload_size, std::vector<PlatformHandle> handles) override { ++message_count_; // If |client_channel_| exists then close it and its thread. if (client_channel_) { // Write a fresh message, making our channel readable again. Channel::MessagePtr message = CreateDefaultMessage(false); client_thread_->task_runner()->PostTask( FROM_HERE, base::BindOnce(&Channel::Write, client_channel_, base::Passed(&message))); // Close the channel and wait for it to shutdown. client_channel_->ShutDown(); client_channel_ = nullptr; client_thread_->Stop(); client_thread_ = nullptr; } // Write a message to the channel, to verify whether this triggers an // OnChannelError callback before all messages were read. Channel::MessagePtr message = CreateDefaultMessage(false); channel_->Write(std::move(message)); } void OnChannelError(Channel::Error error) override { EXPECT_EQ(2, message_count_); quit_closure_.Run(); } base::RepeatingClosure quit_closure_; int message_count_ = 0; scoped_refptr<Channel> channel_; scoped_refptr<Channel> client_channel_; std::unique_ptr<base::Thread> client_thread_; }; TEST(ChannelTest, PeerShutdownDuringRead) { base::MessageLoop message_loop(base::MessageLoop::TYPE_IO); PlatformChannel channel; // Create a "client" Channel with one end of the pipe, and Start() it. std::unique_ptr<base::Thread> client_thread = std::make_unique<base::Thread>("clientio_thread"); client_thread->StartWithOptions( base::Thread::Options(base::MessageLoop::TYPE_IO, 0)); scoped_refptr<Channel> client_channel = Channel::Create(nullptr, ConnectionParams(channel.TakeRemoteEndpoint()), client_thread->task_runner()); client_channel->Start(); // On the "client" IO thread, create and write a message. Channel::MessagePtr message = CreateDefaultMessage(false); client_thread->task_runner()->PostTask( FROM_HERE, base::BindOnce(&Channel::Write, client_channel, base::Passed(&message))); // Create a "server" Channel with the other end of the pipe, and process the // messages from it. The |server_delegate| will ShutDown the client end of // the pipe after the first message, and quit the RunLoop when OnChannelError // is received. base::RunLoop run_loop; ChannelTestShutdownAndWriteDelegate server_delegate( channel.TakeLocalEndpoint(), message_loop.task_runner(), std::move(client_channel), std::move(client_thread), run_loop.QuitClosure()); run_loop.Run(); } } // namespace } // namespace core } // namespace mojo