//
// Copyright (C) 2013 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "shill/icmp.h"
#include <netinet/in.h>
#include <netinet/ip_icmp.h>
#include <string.h>
#include <gtest/gtest.h>
#include "shill/mock_log.h"
#include "shill/net/ip_address.h"
#include "shill/net/mock_sockets.h"
using testing::_;
using testing::HasSubstr;
using testing::InSequence;
using testing::Return;
using testing::StrictMock;
using testing::Test;
namespace shill {
namespace {
// Reference data for IcmpTest.ComputeIcmpChecksum.
//
// These binary blobs representing ICMP headers and their respective checksums
// were taken directly from Wireshark ICMP packet captures and are given in big
// endian. The checksum field is zeroed in |kIcmpEchoRequestEvenLen| and
// |kIcmpEchoRequestOddLen| so the checksum can be calculated on the header in
// IcmpTest.ComputeIcmpChecksum.

// Even-length (8 byte) message: type 0x08, code 0x00, a zeroed two-byte
// checksum field, followed by the remaining four header bytes.
const uint8_t kIcmpEchoRequestEvenLen[] = {0x08, 0x00, 0x00, 0x00,
0x71, 0x50, 0x00, 0x00};
// Checksum expected for |kIcmpEchoRequestEvenLen|.
const uint8_t kIcmpEchoRequestEvenLenChecksum[] = {0x86, 0xaf};
// Odd-length (11 byte) message: same leading layout as above (type 0x08,
// code 0x00, zeroed checksum) with three additional trailing bytes, so the
// checksum routine is exercised on input that is not a multiple of two bytes.
const uint8_t kIcmpEchoRequestOddLen[] = {0x08, 0x00, 0x00, 0x00, 0xac, 0x51,
0x00, 0x00, 0x00, 0x00, 0x01};
// Checksum expected for |kIcmpEchoRequestOddLen|.
const uint8_t kIcmpEchoRequestOddLenChecksum[] = {0x4a, 0xae};
} // namespace
// Fixture for the Icmp unit tests. The Icmp instance under test has its
// socket layer replaced by a StrictMock<MockSockets>, so any socket call that
// a test has not explicitly expected is reported as a failure.
class IcmpTest : public Test {
 public:
  IcmpTest() {}
  virtual ~IcmpTest() {}

  // Creates the socket mock and installs it into |icmp_|.
  virtual void SetUp() {
    sockets_ = new StrictMock<MockSockets>();
    // From here on |icmp_| owns the mock; |sockets_| is only an alias used to
    // set expectations.
    icmp_.sockets_.reset(sockets_);
  }

  // Stops |icmp_| if a test left it running (expecting |kSocketFD| to be
  // closed as a result) and checks that it ends up stopped.
  virtual void TearDown() {
    if (icmp_.IsStarted()) {
      EXPECT_CALL(*sockets_, Close(kSocketFD));
      icmp_.Stop();
    }
    EXPECT_FALSE(icmp_.IsStarted());
  }

 protected:
  static const int kSocketFD;
  static const char kIPAddress[];

  // Exposes the descriptor currently held by |icmp_|.
  int GetSocket() { return icmp_.socket_; }

  // Starts |icmp_| with the mocked socket layer handing out |kSocketFD|.
  bool StartIcmp() { return StartIcmpWithFD(kSocketFD); }

  // Starts |icmp_| while the mocked socket layer returns |fd| for the raw
  // ICMP socket and reports success when it is made non-blocking. Verifies
  // that the start succeeded and that |icmp_| recorded |fd|. Returns the
  // value produced by Icmp::Start().
  bool StartIcmpWithFD(int fd) {
    EXPECT_CALL(*sockets_, Socket(AF_INET, SOCK_RAW, IPPROTO_ICMP))
        .WillOnce(Return(fd));
    EXPECT_CALL(*sockets_, SetNonBlocking(fd)).WillOnce(Return(0));

    const bool started = icmp_.Start();
    EXPECT_TRUE(started);
    EXPECT_EQ(fd, icmp_.socket_);
    EXPECT_TRUE(icmp_.IsStarted());
    return started;
  }

  // Forwards to Icmp::ComputeIcmpChecksum so that test bodies can reach it
  // through the fixture.
  uint16_t ComputeIcmpChecksum(const struct icmphdr &hdr, size_t len) {
    return Icmp::ComputeIcmpChecksum(hdr, len);
  }

  // Owned by Icmp, and tracked here only for mocks.
  MockSockets* sockets_;
  Icmp icmp_;
};
// Fake descriptor value that the mocked socket layer hands out in
// StartIcmp(); TearDown() expects this descriptor to be closed.
const int IcmpTest::kSocketFD = 456;
// IPv4 destination address used by the TransmitEchoRequest test.
const char IcmpTest::kIPAddress[] = "10.0.1.1";
// A freshly constructed Icmp must not be started and must hold the -1
// sentinel in place of a socket descriptor.
TEST_F(IcmpTest, Constructor) {
  EXPECT_FALSE(icmp_.IsStarted());
  EXPECT_EQ(-1, GetSocket());
}
// When creating the raw ICMP socket fails, Start() must return false, log an
// error and leave the object stopped.
TEST_F(IcmpTest, SocketOpenFail) {
  ScopedMockLog log;

  // The socket layer refuses to create the socket.
  EXPECT_CALL(*sockets_, Socket(AF_INET, SOCK_RAW, IPPROTO_ICMP))
      .WillOnce(Return(-1));
  // Exactly one error describing the failure is expected.
  EXPECT_CALL(log, Log(logging::LOG_ERROR, _,
                       HasSubstr("Could not create ICMP socket")))
      .Times(1);

  EXPECT_FALSE(icmp_.Start());
  EXPECT_FALSE(icmp_.IsStarted());
}
// When the socket is created but cannot be switched to non-blocking mode,
// Start() must log an error, close the socket again, return false and leave
// the object stopped.
TEST_F(IcmpTest, SocketNonBlockingFail) {
  ScopedMockLog log;

  // Socket creation succeeds, making it non-blocking fails, and the
  // descriptor is then expected to be closed.
  EXPECT_CALL(*sockets_, Socket(_, _, _)).WillOnce(Return(kSocketFD));
  EXPECT_CALL(*sockets_, SetNonBlocking(kSocketFD)).WillOnce(Return(-1));
  EXPECT_CALL(*sockets_, Close(kSocketFD));
  // Exactly one error describing the failure is expected.
  EXPECT_CALL(log, Log(logging::LOG_ERROR, _,
                       HasSubstr("Could not set socket to be non-blocking")))
      .Times(1);

  EXPECT_FALSE(icmp_.Start());
  EXPECT_FALSE(icmp_.IsStarted());
}
// Starting an already started Icmp a second time must close the descriptor
// obtained by the first start. The second start uses |kSocketFD|, which
// TearDown() expects to be closed.
TEST_F(IcmpTest, StartMultipleTimes) {
  const int kInitialSocketFD = kSocketFD + 1;
  StartIcmpWithFD(kInitialSocketFD);

  // The descriptor from the first start is released by the second start.
  EXPECT_CALL(*sockets_, Close(kInitialSocketFD));
  StartIcmp();
}
// Matches a buffer pointer whose first sizeof(header) bytes are byte-for-byte
// identical to |header|.
MATCHER_P(IsIcmpHeader, header, "") {
  const int difference = memcmp(arg, &header, sizeof(header));
  return difference == 0;
}
// Matches a socket address pointer that, when viewed as a sockaddr_in, has
// the same family as |address| and the same raw address bytes.
MATCHER_P(IsSocketAddress, address, "") {
  const struct sockaddr_in* candidate =
      reinterpret_cast<const struct sockaddr_in*>(arg);
  // The address bytes are only compared once the families agree.
  if (!(candidate->sin_family == address.family())) {
    return false;
  }
  return memcmp(&candidate->sin_addr.s_addr, address.GetConstData(),
                address.GetLength()) == 0;
}
// Exercises Icmp::TransmitEchoRequest. Destinations that must be rejected are
// tried first (no SendTo() call is expected for them), followed by four sends
// to a valid IPv4 destination for which the mocked SendTo() returns, in
// order: -1, 0, one byte less than the header size, and the full header size.
// Only the last of those sends is expected to succeed.
TEST_F(IcmpTest, TransmitEchoRequest) {
  StartIcmp();

  // Address isn't valid.
  EXPECT_FALSE(
      icmp_.TransmitEchoRequest(IPAddress(IPAddress::kFamilyIPv4), 1, 1));

  // IPv6 addresses aren't implemented.
  IPAddress ipv6_target(IPAddress::kFamilyIPv6);
  EXPECT_TRUE(ipv6_target.SetAddressFromString("fe80::1aa9:5ff:abcd:1234"));
  EXPECT_FALSE(icmp_.TransmitEchoRequest(ipv6_target, 1, 1));

  IPAddress ipv4_target(IPAddress::kFamilyIPv4);
  EXPECT_TRUE(ipv4_target.SetAddressFromString(kIPAddress));

  // Build the echo request header that is expected on the wire for id 1 and
  // sequence number 1. The checksum is computed last, over the header with
  // its checksum field still zero.
  struct icmphdr expected_header;
  memset(&expected_header, 0, sizeof(expected_header));
  expected_header.type = ICMP_ECHO;
  expected_header.code = Icmp::kIcmpEchoCode;
  expected_header.un.echo.id = 1;
  expected_header.un.echo.sequence = 1;
  expected_header.checksum =
      ComputeIcmpChecksum(expected_header, sizeof(expected_header));

  // Four sends of that header to the IPv4 destination, each with a different
  // mocked result.
  EXPECT_CALL(*sockets_, SendTo(kSocketFD,
                                IsIcmpHeader(expected_header),
                                sizeof(expected_header),
                                0,
                                IsSocketAddress(ipv4_target),
                                sizeof(sockaddr_in)))
      .WillOnce(Return(-1))
      .WillOnce(Return(0))
      .WillOnce(Return(sizeof(expected_header) - 1))
      .WillOnce(Return(sizeof(expected_header)));

  {
    InSequence seq;
    ScopedMockLog log;
    // In order: one "sendto failed" error, then two "short write" errors.
    EXPECT_CALL(log, Log(logging::LOG_ERROR, _,
                         HasSubstr("Socket sendto failed")))
        .Times(1);
    EXPECT_CALL(log, Log(logging::LOG_ERROR, _,
                         HasSubstr("less than the expected result")))
        .Times(2);

    // The first three sends consume the -1, 0 and short-write results.
    for (int failed_send = 0; failed_send < 3; ++failed_send) {
      EXPECT_FALSE(icmp_.TransmitEchoRequest(ipv4_target, 1, 1));
    }
    // The fourth send reports the full header size and succeeds.
    EXPECT_TRUE(icmp_.TransmitEchoRequest(ipv4_target, 1, 1));
  }
}
// Verifies Icmp::ComputeIcmpChecksum against the reference checksums for an
// even-length and an odd-length ICMP message.
//
// The reference blobs are plain uint8_t arrays and therefore only guaranteed
// to be byte-aligned. Instead of dereferencing them through a
// reinterpret_cast'ed icmphdr / uint16_t pointer (a potentially misaligned
// access), each message is copied into a union that is suitably aligned for
// struct icmphdr, and each expected checksum is copied into a real uint16_t.
// The byte-wise copies keep the in-memory representation, and hence the
// values being compared, exactly as before.
TEST_F(IcmpTest, ComputeIcmpChecksum) {
  // Even-length message.
  {
    union {
      struct icmphdr header;
      uint8_t bytes[sizeof(kIcmpEchoRequestEvenLen)];
    } message;
    memset(&message, 0, sizeof(message));
    memcpy(message.bytes, kIcmpEchoRequestEvenLen,
           sizeof(kIcmpEchoRequestEvenLen));

    uint16_t expected_checksum = 0;
    memcpy(&expected_checksum, kIcmpEchoRequestEvenLenChecksum,
           sizeof(expected_checksum));

    EXPECT_EQ(expected_checksum,
              ComputeIcmpChecksum(message.header,
                                  sizeof(kIcmpEchoRequestEvenLen)));
  }

  // Odd-length message: the bytes that follow the header live in the same
  // union, directly after |header|, so the checksum routine can read all
  // sizeof(kIcmpEchoRequestOddLen) bytes starting at the header.
  {
    union {
      struct icmphdr header;
      uint8_t bytes[sizeof(kIcmpEchoRequestOddLen)];
    } message;
    memset(&message, 0, sizeof(message));
    memcpy(message.bytes, kIcmpEchoRequestOddLen,
           sizeof(kIcmpEchoRequestOddLen));

    uint16_t expected_checksum = 0;
    memcpy(&expected_checksum, kIcmpEchoRequestOddLenChecksum,
           sizeof(expected_checksum));

    EXPECT_EQ(expected_checksum,
              ComputeIcmpChecksum(message.header,
                                  sizeof(kIcmpEchoRequestOddLen)));
  }
}
} // namespace shill