// // Copyright (C) 2012 The Android Open Source Project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // #include "shill/net/netlink_message.h" #include <limits.h> #include <algorithm> #include <map> #include <memory> #include <string> #include <base/format_macros.h> #include <base/logging.h> #include <base/stl_util.h> #include <base/strings/stringprintf.h> #include "shill/net/netlink_packet.h" using base::StringAppendF; using base::StringPrintf; using std::map; using std::min; using std::string; namespace shill { const uint32_t NetlinkMessage::kBroadcastSequenceNumber = 0; const uint16_t NetlinkMessage::kIllegalMessageType = UINT16_MAX; // NetlinkMessage ByteString NetlinkMessage::EncodeHeader(uint32_t sequence_number) { ByteString result; if (message_type_ == kIllegalMessageType) { LOG(ERROR) << "Message type not set"; return result; } sequence_number_ = sequence_number; if (sequence_number_ == kBroadcastSequenceNumber) { LOG(ERROR) << "Couldn't get a legal sequence number"; return result; } // Build netlink header. nlmsghdr header; size_t nlmsghdr_with_pad = NLMSG_ALIGN(sizeof(header)); header.nlmsg_len = nlmsghdr_with_pad; header.nlmsg_type = message_type_; header.nlmsg_flags = NLM_F_REQUEST | flags_; header.nlmsg_seq = sequence_number_; header.nlmsg_pid = getpid(); // Netlink header + pad. result.Append(ByteString(reinterpret_cast<unsigned char*>(&header), sizeof(header))); result.Resize(nlmsghdr_with_pad); // Zero-fill pad space (if any). return result; } bool NetlinkMessage::InitAndStripHeader(NetlinkPacket* packet) { const nlmsghdr& header = packet->GetNlMsgHeader(); message_type_ = header.nlmsg_type; flags_ = header.nlmsg_flags; sequence_number_ = header.nlmsg_seq; return true; } bool NetlinkMessage::InitFromPacket(NetlinkPacket* packet, NetlinkMessage::MessageContext context) { if (!packet) { LOG(ERROR) << "Null |packet| parameter"; return false; } if (!InitAndStripHeader(packet)) { return false; } return true; } // static void NetlinkMessage::PrintBytes(int log_level, const unsigned char* buf, size_t num_bytes) { VLOG(log_level) << "Netlink Message -- Examining Bytes"; if (!buf) { VLOG(log_level) << "<NULL Buffer>"; return; } if (num_bytes >= sizeof(nlmsghdr)) { PrintHeader(log_level, reinterpret_cast<const nlmsghdr*>(buf)); buf += sizeof(nlmsghdr); num_bytes -= sizeof(nlmsghdr); } else { VLOG(log_level) << "Not enough bytes (" << num_bytes << ") for a complete nlmsghdr (requires " << sizeof(nlmsghdr) << ")."; } PrintPayload(log_level, buf, num_bytes); } // static void NetlinkMessage::PrintPacket(int log_level, const NetlinkPacket& packet) { VLOG(log_level) << "Netlink Message -- Examining Packet"; if (!packet.IsValid()) { VLOG(log_level) << "<Invalid Buffer>"; return; } PrintHeader(log_level, &packet.GetNlMsgHeader()); const ByteString& payload = packet.GetPayload(); PrintPayload(log_level, payload.GetConstData(), payload.GetLength()); } // static void NetlinkMessage::PrintHeader(int log_level, const nlmsghdr* header) { const unsigned char* buf = reinterpret_cast<const unsigned char*>(header); VLOG(log_level) << StringPrintf( "len: %02x %02x %02x %02x = %u bytes", buf[0], buf[1], buf[2], buf[3], header->nlmsg_len); VLOG(log_level) << StringPrintf( "type | flags: %02x %02x %02x %02x - type:%u flags:%s%s%s%s%s", buf[4], buf[5], buf[6], buf[7], header->nlmsg_type, ((header->nlmsg_flags & NLM_F_REQUEST) ? " REQUEST" : ""), ((header->nlmsg_flags & NLM_F_MULTI) ? " MULTI" : ""), ((header->nlmsg_flags & NLM_F_ACK) ? " ACK" : ""), ((header->nlmsg_flags & NLM_F_ECHO) ? " ECHO" : ""), ((header->nlmsg_flags & NLM_F_DUMP_INTR) ? " BAD-SEQ" : "")); VLOG(log_level) << StringPrintf( "sequence: %02x %02x %02x %02x = %u", buf[8], buf[9], buf[10], buf[11], header->nlmsg_seq); VLOG(log_level) << StringPrintf( "pid: %02x %02x %02x %02x = %u", buf[12], buf[13], buf[14], buf[15], header->nlmsg_pid); } // static void NetlinkMessage::PrintPayload(int log_level, const unsigned char* buf, size_t num_bytes) { while (num_bytes) { string output; size_t bytes_this_row = min(num_bytes, static_cast<size_t>(32)); for (size_t i = 0; i < bytes_this_row; ++i) { StringAppendF(&output, " %02x", *buf++); } VLOG(log_level) << output; num_bytes -= bytes_this_row; } } // ErrorAckMessage. const uint16_t ErrorAckMessage::kMessageType = NLMSG_ERROR; bool ErrorAckMessage::InitFromPacket(NetlinkPacket* packet, NetlinkMessage::MessageContext context) { if (!packet) { LOG(ERROR) << "Null |const_msg| parameter"; return false; } if (!InitAndStripHeader(packet)) { return false; } // Get the error code from the payload. return packet->ConsumeData(sizeof(error_), &error_); } ByteString ErrorAckMessage::Encode(uint32_t sequence_number) { LOG(ERROR) << "We're not supposed to send errors or Acks to the kernel"; return ByteString(); } string ErrorAckMessage::ToString() const { string output; if (error()) { StringAppendF(&output, "NETLINK_ERROR 0x%" PRIx32 ": %s", -error_, strerror(-error_)); } else { StringAppendF(&output, "ACK"); } return output; } void ErrorAckMessage::Print(int header_log_level, int /*detail_log_level*/) const { VLOG(header_log_level) << ToString(); } // NoopMessage. const uint16_t NoopMessage::kMessageType = NLMSG_NOOP; ByteString NoopMessage::Encode(uint32_t sequence_number) { LOG(ERROR) << "We're not supposed to send NOOP to the kernel"; return ByteString(); } void NoopMessage::Print(int header_log_level, int /*detail_log_level*/) const { VLOG(header_log_level) << ToString(); } // DoneMessage. const uint16_t DoneMessage::kMessageType = NLMSG_DONE; ByteString DoneMessage::Encode(uint32_t sequence_number) { return EncodeHeader(sequence_number); } void DoneMessage::Print(int header_log_level, int /*detail_log_level*/) const { VLOG(header_log_level) << ToString(); } // OverrunMessage. const uint16_t OverrunMessage::kMessageType = NLMSG_OVERRUN; ByteString OverrunMessage::Encode(uint32_t sequence_number) { LOG(ERROR) << "We're not supposed to send Overruns to the kernel"; return ByteString(); } void OverrunMessage::Print(int header_log_level, int /*detail_log_level*/) const { VLOG(header_log_level) << ToString(); } // UnknownMessage. ByteString UnknownMessage::Encode(uint32_t sequence_number) { LOG(ERROR) << "We're not supposed to send UNKNOWN messages to the kernel"; return ByteString(); } void UnknownMessage::Print(int header_log_level, int /*detail_log_level*/) const { int total_bytes = message_body_.GetLength(); const uint8_t* const_data = message_body_.GetConstData(); string output = StringPrintf("%d bytes:", total_bytes); for (int i = 0; i < total_bytes; ++i) { StringAppendF(&output, " 0x%02x", const_data[i]); } VLOG(header_log_level) << output; } // // Factory class. // bool NetlinkMessageFactory::AddFactoryMethod(uint16_t message_type, FactoryMethod factory) { if (ContainsKey(factories_, message_type)) { LOG(WARNING) << "Message type " << message_type << " already exists."; return false; } if (message_type == NetlinkMessage::kIllegalMessageType) { LOG(ERROR) << "Not installing factory for illegal message type."; return false; } factories_[message_type] = factory; return true; } NetlinkMessage* NetlinkMessageFactory::CreateMessage( NetlinkPacket* packet, NetlinkMessage::MessageContext context) const { std::unique_ptr<NetlinkMessage> message; auto message_type = packet->GetMessageType(); if (message_type == NoopMessage::kMessageType) { message.reset(new NoopMessage()); } else if (message_type == DoneMessage::kMessageType) { message.reset(new DoneMessage()); } else if (message_type == OverrunMessage::kMessageType) { message.reset(new OverrunMessage()); } else if (message_type == ErrorAckMessage::kMessageType) { message.reset(new ErrorAckMessage()); } else if (ContainsKey(factories_, message_type)) { map<uint16_t, FactoryMethod>::const_iterator factory; factory = factories_.find(message_type); message.reset(factory->second.Run(*packet)); } // If no factory exists for this message _or_ if a factory exists but it // failed, there'll be no message. Handle either of those cases, by // creating an |UnknownMessage|. if (!message) { message.reset(new UnknownMessage(message_type, packet->GetPayload())); } if (!message->InitFromPacket(packet, context)) { LOG(ERROR) << "Message did not initialize properly"; return nullptr; } return message.release(); } } // namespace shill.