// // Copyright (C) 2013 The Android Open Source Project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // #include "shill/icmp.h" #include <netinet/in.h> #include <netinet/ip_icmp.h> #include <gtest/gtest.h> #include "shill/mock_log.h" #include "shill/net/ip_address.h" #include "shill/net/mock_sockets.h" using testing::_; using testing::HasSubstr; using testing::InSequence; using testing::Return; using testing::StrictMock; using testing::Test; namespace shill { namespace { // These binary blobs representing ICMP headers and their respective checksums // were taken directly from Wireshark ICMP packet captures and are given in big // endian. The checksum field is zeroed in |kIcmpEchoRequestEvenLen| and // |kIcmpEchoRequestOddLen| so the checksum can be calculated on the header in // IcmpTest.ComputeIcmpChecksum. const uint8_t kIcmpEchoRequestEvenLen[] = {0x08, 0x00, 0x00, 0x00, 0x71, 0x50, 0x00, 0x00}; const uint8_t kIcmpEchoRequestEvenLenChecksum[] = {0x86, 0xaf}; const uint8_t kIcmpEchoRequestOddLen[] = {0x08, 0x00, 0x00, 0x00, 0xac, 0x51, 0x00, 0x00, 0x00, 0x00, 0x01}; const uint8_t kIcmpEchoRequestOddLenChecksum[] = {0x4a, 0xae}; } // namespace class IcmpTest : public Test { public: IcmpTest() {} virtual ~IcmpTest() {} virtual void SetUp() { sockets_ = new StrictMock<MockSockets>(); // Passes ownership. icmp_.sockets_.reset(sockets_); } virtual void TearDown() { if (icmp_.IsStarted()) { EXPECT_CALL(*sockets_, Close(kSocketFD)); icmp_.Stop(); } EXPECT_FALSE(icmp_.IsStarted()); } protected: static const int kSocketFD; static const char kIPAddress[]; int GetSocket() { return icmp_.socket_; } bool StartIcmp() { return StartIcmpWithFD(kSocketFD); } bool StartIcmpWithFD(int fd) { EXPECT_CALL(*sockets_, Socket(AF_INET, SOCK_RAW, IPPROTO_ICMP)) .WillOnce(Return(fd)); EXPECT_CALL(*sockets_, SetNonBlocking(fd)).WillOnce(Return(0)); bool start_status = icmp_.Start(); EXPECT_TRUE(start_status); EXPECT_EQ(fd, icmp_.socket_); EXPECT_TRUE(icmp_.IsStarted()); return start_status; } uint16_t ComputeIcmpChecksum(const struct icmphdr &hdr, size_t len) { return Icmp::ComputeIcmpChecksum(hdr, len); } // Owned by Icmp, and tracked here only for mocks. MockSockets* sockets_; Icmp icmp_; }; const int IcmpTest::kSocketFD = 456; const char IcmpTest::kIPAddress[] = "10.0.1.1"; TEST_F(IcmpTest, Constructor) { EXPECT_EQ(-1, GetSocket()); EXPECT_FALSE(icmp_.IsStarted()); } TEST_F(IcmpTest, SocketOpenFail) { ScopedMockLog log; EXPECT_CALL(log, Log(logging::LOG_ERROR, _, HasSubstr("Could not create ICMP socket"))).Times(1); EXPECT_CALL(*sockets_, Socket(AF_INET, SOCK_RAW, IPPROTO_ICMP)) .WillOnce(Return(-1)); EXPECT_FALSE(icmp_.Start()); EXPECT_FALSE(icmp_.IsStarted()); } TEST_F(IcmpTest, SocketNonBlockingFail) { ScopedMockLog log; EXPECT_CALL(log, Log(logging::LOG_ERROR, _, HasSubstr("Could not set socket to be non-blocking"))).Times(1); EXPECT_CALL(*sockets_, Socket(_, _, _)).WillOnce(Return(kSocketFD)); EXPECT_CALL(*sockets_, SetNonBlocking(kSocketFD)).WillOnce(Return(-1)); EXPECT_CALL(*sockets_, Close(kSocketFD)); EXPECT_FALSE(icmp_.Start()); EXPECT_FALSE(icmp_.IsStarted()); } TEST_F(IcmpTest, StartMultipleTimes) { const int kFirstSocketFD = kSocketFD + 1; StartIcmpWithFD(kFirstSocketFD); EXPECT_CALL(*sockets_, Close(kFirstSocketFD)); StartIcmp(); } MATCHER_P(IsIcmpHeader, header, "") { return memcmp(arg, &header, sizeof(header)) == 0; } MATCHER_P(IsSocketAddress, address, "") { const struct sockaddr_in* sock_addr = reinterpret_cast<const struct sockaddr_in*>(arg); return sock_addr->sin_family == address.family() && memcmp(&sock_addr->sin_addr.s_addr, address.GetConstData(), address.GetLength()) == 0; } TEST_F(IcmpTest, TransmitEchoRequest) { StartIcmp(); // Address isn't valid. EXPECT_FALSE( icmp_.TransmitEchoRequest(IPAddress(IPAddress::kFamilyIPv4), 1, 1)); // IPv6 adresses aren't implemented. IPAddress ipv6_destination(IPAddress::kFamilyIPv6); EXPECT_TRUE(ipv6_destination.SetAddressFromString( "fe80::1aa9:5ff:abcd:1234")); EXPECT_FALSE(icmp_.TransmitEchoRequest(ipv6_destination, 1, 1)); IPAddress ipv4_destination(IPAddress::kFamilyIPv4); EXPECT_TRUE(ipv4_destination.SetAddressFromString(kIPAddress)); struct icmphdr icmp_header; memset(&icmp_header, 0, sizeof(icmp_header)); icmp_header.type = ICMP_ECHO; icmp_header.code = Icmp::kIcmpEchoCode; icmp_header.un.echo.id = 1; icmp_header.un.echo.sequence = 1; icmp_header.checksum = ComputeIcmpChecksum(icmp_header, sizeof(icmp_header)); EXPECT_CALL(*sockets_, SendTo(kSocketFD, IsIcmpHeader(icmp_header), sizeof(icmp_header), 0, IsSocketAddress(ipv4_destination), sizeof(sockaddr_in))) .WillOnce(Return(-1)) .WillOnce(Return(0)) .WillOnce(Return(sizeof(icmp_header) - 1)) .WillOnce(Return(sizeof(icmp_header))); { InSequence seq; ScopedMockLog log; EXPECT_CALL(log, Log(logging::LOG_ERROR, _, HasSubstr("Socket sendto failed"))).Times(1); EXPECT_CALL(log, Log(logging::LOG_ERROR, _, HasSubstr("less than the expected result"))).Times(2); EXPECT_FALSE(icmp_.TransmitEchoRequest(ipv4_destination, 1, 1)); EXPECT_FALSE(icmp_.TransmitEchoRequest(ipv4_destination, 1, 1)); EXPECT_FALSE(icmp_.TransmitEchoRequest(ipv4_destination, 1, 1)); EXPECT_TRUE(icmp_.TransmitEchoRequest(ipv4_destination, 1, 1)); } } TEST_F(IcmpTest, ComputeIcmpChecksum) { EXPECT_EQ(*reinterpret_cast<const uint16_t*>(kIcmpEchoRequestEvenLenChecksum), ComputeIcmpChecksum(*reinterpret_cast<const struct icmphdr*>( kIcmpEchoRequestEvenLen), sizeof(kIcmpEchoRequestEvenLen))); EXPECT_EQ(*reinterpret_cast<const uint16_t*>(kIcmpEchoRequestOddLenChecksum), ComputeIcmpChecksum(*reinterpret_cast<const struct icmphdr*>( kIcmpEchoRequestOddLen), sizeof(kIcmpEchoRequestOddLen))); } } // namespace shill