/*
 * Copyright (C) 2016 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "dns_responder.h"

#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <iostream>
#include <vector>

#include <log/log.h>

namespace test {

// Adapters that normalize the two incompatible strerror_r() flavours into a
// single "pointer to message, or nullptr on failure" result. Overload
// resolution picks the one matching whatever the C library declares.
//
// XSI flavour: returns 0 on success and stores the message in |buf|.
static inline const char* strerrorResult(int rv, const char* buf) {
    return rv == 0 ? buf : nullptr;
}
// GNU flavour: returns the message directly (it may or may not live in |buf|).
static inline const char* strerrorResult(const char* rv, const char* /*buf*/) {
    return rv;
}

// Returns the textual description of the current errno value, or an empty
// string if the description cannot be obtained.
std::string errno2str() {
    char error_msg[512] = { 0 };
    const char* msg = strerrorResult(
            strerror_r(errno, error_msg, sizeof(error_msg)), error_msg);
    return msg != nullptr ? std::string(msg) : std::string();
}

// Like ALOGI, but appends ": [<errno>] <errno text>" to the formatted message,
// using errno2str() for the text. |fmt| must be a string literal (it is
// concatenated with the suffix), and at least one argument must follow it
// because __VA_ARGS__ is unconditionally followed by a comma.
#define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())

// Renders |len| bytes starting at |buffer| as an upper-case hexadecimal
// string, two characters per input byte and no separators.
std::string str2hex(const char* buffer, size_t len) {
    static const char kHexDigits[] = "0123456789ABCDEF";
    std::string out;
    out.reserve(len * 2);
    for (const char* p = buffer; p != buffer + len; ++p) {
        // Go through uint8_t so bytes >= 0x80 are not sign-extended.
        const uint8_t byte = static_cast<uint8_t>(*p);
        out.push_back(kHexDigits[byte >> 4]);
        out.push_back(kHexDigits[byte & 0x0F]);
    }
    return out;
}

// Formats the host part of a socket address in numeric form via
// getnameinfo(NI_NUMERICHOST). Returns an empty string if the conversion
// fails.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char numeric_host[NI_MAXHOST] = { 0 };
    const int status = getnameinfo(sa, sa_len, numeric_host,
                                   sizeof(numeric_host), nullptr, 0,
                                   NI_NUMERICHOST);
    return status == 0 ? std::string(numeric_host) : std::string();
}

/* DNS struct helpers */

const char* dnstype2str(unsigned dnstype) {
    static std::unordered_map<unsigned, const char*> kTypeStrs = {
        { ns_type::ns_t_a, "A" },
        { ns_type::ns_t_ns, "NS" },
        { ns_type::ns_t_md, "MD" },
        { ns_type::ns_t_mf, "MF" },
        { ns_type::ns_t_cname, "CNAME" },
        { ns_type::ns_t_soa, "SOA" },
        { ns_type::ns_t_mb, "MB" },
        { ns_type::ns_t_mb, "MG" },
        { ns_type::ns_t_mr, "MR" },
        { ns_type::ns_t_null, "NULL" },
        { ns_type::ns_t_wks, "WKS" },
        { ns_type::ns_t_ptr, "PTR" },
        { ns_type::ns_t_hinfo, "HINFO" },
        { ns_type::ns_t_minfo, "MINFO" },
        { ns_type::ns_t_mx, "MX" },
        { ns_type::ns_t_txt, "TXT" },
        { ns_type::ns_t_rp, "RP" },
        { ns_type::ns_t_afsdb, "AFSDB" },
        { ns_type::ns_t_x25, "X25" },
        { ns_type::ns_t_isdn, "ISDN" },
        { ns_type::ns_t_rt, "RT" },
        { ns_type::ns_t_nsap, "NSAP" },
        { ns_type::ns_t_nsap_ptr, "NSAP-PTR" },
        { ns_type::ns_t_sig, "SIG" },
        { ns_type::ns_t_key, "KEY" },
        { ns_type::ns_t_px, "PX" },
        { ns_type::ns_t_gpos, "GPOS" },
        { ns_type::ns_t_aaaa, "AAAA" },
        { ns_type::ns_t_loc, "LOC" },
        { ns_type::ns_t_nxt, "NXT" },
        { ns_type::ns_t_eid, "EID" },
        { ns_type::ns_t_nimloc, "NIMLOC" },
        { ns_type::ns_t_srv, "SRV" },
        { ns_type::ns_t_naptr, "NAPTR" },
        { ns_type::ns_t_kx, "KX" },
        { ns_type::ns_t_cert, "CERT" },
        { ns_type::ns_t_a6, "A6" },
        { ns_type::ns_t_dname, "DNAME" },
        { ns_type::ns_t_sink, "SINK" },
        { ns_type::ns_t_opt, "OPT" },
        { ns_type::ns_t_apl, "APL" },
        { ns_type::ns_t_tkey, "TKEY" },
        { ns_type::ns_t_tsig, "TSIG" },
        { ns_type::ns_t_ixfr, "IXFR" },
        { ns_type::ns_t_axfr, "AXFR" },
        { ns_type::ns_t_mailb, "MAILB" },
        { ns_type::ns_t_maila, "MAILA" },
        { ns_type::ns_t_any, "ANY" },
        { ns_type::ns_t_zxfr, "ZXFR" },
    };
    auto it = kTypeStrs.find(dnstype);
    static const char* kUnknownStr{ "UNKNOWN" };
    if (it == kTypeStrs.end()) return kUnknownStr;
    return it->second;
}

// Maps a DNS class code to a human-readable name (e.g. ns_c_in -> "Internet").
// Returns "UNKNOWN" for codes that are not in the table. The returned pointer
// refers to a string literal and must not be freed.
const char* dnsclass2str(unsigned dnsclass) {
    static const std::unordered_map<unsigned, const char*> kClassStrs = {
        { ns_class::ns_c_in , "Internet" },
        { 2, "CSNet" },
        { ns_class::ns_c_chaos, "ChaosNet" },
        { ns_class::ns_c_hs, "Hesiod" },
        { ns_class::ns_c_none, "none" },
        { ns_class::ns_c_any, "any" }
    };
    static const char* const kUnknownStr = "UNKNOWN";
    const auto it = kClassStrs.find(dnsclass);
    return it == kClassStrs.end() ? kUnknownStr : it->second;
}

// A domain name as it appears in a DNS message.
struct DNSName {
    // Dotted textual form. Every label parsed from the wire is followed by a
    // '.', so a non-empty parsed name always ends with '.'.
    std::string name;
    // Parses a wire-format name from [buffer, buffer_end). Returns a pointer
    // just past the parsed name, or nullptr on failure.
    const char* read(const char* buffer, const char* buffer_end);
    // Serializes |name| in wire format into [buffer, buffer_end). Returns a
    // pointer just past the written bytes, or nullptr on failure.
    char* write(char* buffer, const char* buffer_end) const;
    // Returns |name| as a C string; the pointer is owned by this object.
    const char* toString() const;
private:
    // Parses a single label (or the terminating zero byte) at |buffer| and
    // appends it to |name|. Sets |*last| when the terminator was consumed.
    // Returns a pointer past the consumed bytes, or nullptr on failure.
    const char* parseField(const char* buffer, const char* buffer_end,
                           bool* last);
};

// Exposes the dotted name as a NUL-terminated C string. The pointer is owned
// by |name| and is only valid until |name| is next modified.
const char* DNSName::toString() const {
    return name.data();
}

const char* DNSName::read(const char* buffer, const char* buffer_end) {
    const char* cur = buffer;
    bool last = false;
    do {
        cur = parseField(cur, buffer_end, &last);
        if (cur == nullptr) {
            ALOGI("parsing failed at line %d", __LINE__);
            return nullptr;
        }
    } while (!last);
    return cur;
}

// Serializes |name| in DNS wire format (length-prefixed labels followed by a
// zero byte) into [buffer, buffer_end).
// Returns a pointer just past the written bytes, or nullptr if the name is
// malformed (a label is not terminated by '.', or is too long) or the output
// buffer is too small.
char* DNSName::write(char* buffer, const char* buffer_end) const {
    // The label length shares its byte with a 2-bit field type (see
    // parseField), so only 6 bits are available for the length.
    constexpr size_t kMaxLabelLen = 63;
    char* buffer_cur = buffer;
    for (size_t pos = 0 ; pos < name.size() ; ) {
        const size_t dot_pos = name.find('.', pos);
        if (dot_pos == std::string::npos) {
            // Sanity check, should never happen unless parseField is broken.
            ALOGI("logic error: all names are expected to end with a '.'");
            return nullptr;
        }
        const size_t len = dot_pos - pos;
        if (len > kMaxLabelLen) {
            ALOGI("name component '%s' is %zu long, but max is %zu",
                    name.substr(pos, len).c_str(), len, kMaxLabelLen);
            return nullptr;
        }
        if (buffer_cur + sizeof(uint8_t) + len > buffer_end) {
            ALOGI("buffer overflow at line %d", __LINE__);
            return nullptr;
        }
        *buffer_cur++ = static_cast<char>(len);
        buffer_cur = std::copy(std::next(name.begin(), pos),
                               std::next(name.begin(), dot_pos),
                               buffer_cur);
        pos = dot_pos + 1;
    }
    // Write final zero, making sure there is still room for it.
    if (buffer_cur >= buffer_end) {
        ALOGI("buffer overflow at line %d", __LINE__);
        return nullptr;
    }
    *buffer_cur++ = 0;
    return buffer_cur;
}

// Parses one name field starting at |buffer|.
// The first byte carries a 2-bit field type in its top bits and a 6-bit
// value in its low bits. Type 0 is a plain label whose length is the 6-bit
// value; a zero length marks the end of the name and sets |*last|. Type 3
// (name compression) is recognized but not supported.
// On success the label plus a trailing '.' is appended to |name| and a
// pointer past the consumed bytes is returned; on failure returns nullptr.
const char* DNSName::parseField(const char* buffer, const char* buffer_end,
                                bool* last) {
    if (buffer + sizeof(uint8_t) > buffer_end) {
        ALOGI("parsing failed at line %d", __LINE__);
        return nullptr;
    }
    // Read the byte as unsigned: where plain char is signed, shifting a
    // byte >= 0x80 directly would sign-extend and yield a bogus field type.
    const uint8_t head = static_cast<uint8_t>(*buffer);
    const unsigned field_type = head >> 6;
    const unsigned ofs = head & 0x3F;
    const char* cur = buffer + sizeof(uint8_t);
    if (field_type == 0) {
        // length + name component
        if (ofs == 0) {
            *last = true;
            return cur;
        }
        if (cur + ofs > buffer_end) {
            ALOGI("parsing failed at line %d", __LINE__);
            return nullptr;
        }
        name.append(cur, ofs);
        name.push_back('.');
        return cur + ofs;
    } else if (field_type == 3) {
        ALOGI("name compression not implemented");
        return nullptr;
    }
    ALOGI("invalid name field type");
    return nullptr;
}

// One entry of the question section of a DNS message.
struct DNSQuestion {
    DNSName qname;    // name being queried
    unsigned qtype;   // 16-bit type code on the wire, stored in host order
    unsigned qclass;  // 16-bit class code on the wire, stored in host order
    // Parses a question from [buffer, buffer_end). Returns a pointer just
    // past it, or nullptr on failure.
    const char* read(const char* buffer, const char* buffer_end);
    // Serializes the question into [buffer, buffer_end). Returns a pointer
    // just past the written bytes, or nullptr on failure.
    char* write(char* buffer, const char* buffer_end) const;
    // Returns a "Q<name,type,class>" description for logging.
    std::string toString() const;
};

// Parses a question (name, 16-bit type, 16-bit class) from
// [buffer, buffer_end). Returns a pointer just past the question, or nullptr
// if the name is malformed or the buffer is too short.
const char* DNSQuestion::read(const char* buffer, const char* buffer_end) {
    const char* cur = qname.read(buffer, buffer_end);
    if (cur == nullptr) {
        ALOGI("parsing failed at line %d", __LINE__);
        return nullptr;
    }
    if (cur + 2*sizeof(uint16_t) > buffer_end) {
        ALOGI("parsing failed at line %d", __LINE__);
        return nullptr;
    }
    // The fields follow a variable-length name and are therefore not
    // necessarily 2-byte aligned; copy them out instead of dereferencing a
    // casted pointer.
    uint16_t raw_type;
    uint16_t raw_class;
    memcpy(&raw_type, cur, sizeof(raw_type));
    memcpy(&raw_class, cur + sizeof(raw_type), sizeof(raw_class));
    qtype = ntohs(raw_type);
    qclass = ntohs(raw_class);
    return cur + 2*sizeof(uint16_t);
}

// Serializes the question (name, 16-bit type, 16-bit class) into
// [buffer, buffer_end). Returns a pointer just past the written bytes, or
// nullptr if the name cannot be written or the buffer is too small.
char* DNSQuestion::write(char* buffer, const char* buffer_end) const {
    char* buffer_cur = qname.write(buffer, buffer_end);
    if (buffer_cur == nullptr) return nullptr;
    if (buffer_cur + 2*sizeof(uint16_t) > buffer_end) {
        ALOGI("buffer overflow on line %d", __LINE__);
        return nullptr;
    }
    // The destination follows a variable-length name and may be unaligned,
    // so store the fields with memcpy rather than through a casted pointer.
    const uint16_t raw_type = htons(qtype);
    const uint16_t raw_class = htons(qclass);
    memcpy(buffer_cur, &raw_type, sizeof(raw_type));
    memcpy(buffer_cur + sizeof(raw_type), &raw_class, sizeof(raw_class));
    return buffer_cur + 2*sizeof(uint16_t);
}

// Returns a "Q<name,type,class>" description of the question for logging.
// The string is built by concatenation so that arbitrarily long names cannot
// overrun a fixed-size formatting buffer.
std::string DNSQuestion::toString() const {
    std::string out("Q<");
    out += qname.toString();
    out += ',';
    out += dnstype2str(qtype);
    out += ',';
    out += dnsclass2str(qclass);
    out += '>';
    return out;
}

// One resource record, as used in the answer, authority and additional
// sections of a DNS message.
struct DNSRecord {
    DNSName name;             // owner name of the record
    unsigned rtype;           // 16-bit type code, host order
    unsigned rclass;          // 16-bit class code, host order
    unsigned ttl;             // 32-bit TTL field, host order
    std::vector<char> rdata;  // raw record payload, copied verbatim
    // Parses a record from [buffer, buffer_end). Returns a pointer just past
    // it, or nullptr on failure.
    const char* read(const char* buffer, const char* buffer_end);
    // Serializes the record into [buffer, buffer_end). Returns a pointer just
    // past the written bytes, or nullptr on failure.
    char* write(char* buffer, const char* buffer_end) const;
    // Returns an "R<name,type,class>" description for logging.
    std::string toString() const;
private:
    // Wire layout of the fixed-size fields between the name and the payload.
    // All fields are in network byte order; the struct is packed so that its
    // size matches the wire size.
    struct IntFields {
        uint16_t rtype;
        uint16_t rclass;
        uint32_t ttl;
        uint16_t rdlen;
    } __attribute__((__packed__));

    // Reads the fixed-size fields at |buffer| into rtype/rclass/ttl and
    // stores the payload length in |*rdlen|. Returns a pointer past the
    // fields, or nullptr if the buffer is too short.
    const char* readIntFields(const char* buffer, const char* buffer_end,
            unsigned* rdlen);
    // Writes the fixed-size fields, with |rdlen| as payload length, at
    // |buffer|. Returns a pointer past the fields, or nullptr if the buffer
    // is too small.
    char* writeIntFields(unsigned rdlen, char* buffer,
                         const char* buffer_end) const;
};

// Parses a resource record (name, fixed-size fields, payload) from
// [buffer, buffer_end). Returns a pointer just past the record, or nullptr if
// any part is malformed or extends beyond the buffer.
const char* DNSRecord::read(const char* buffer, const char* buffer_end) {
    const char* cur = name.read(buffer, buffer_end);
    if (cur == nullptr) {
        ALOGI("parsing failed at line %d", __LINE__);
        return nullptr;
    }
    unsigned rdlen = 0;
    cur = readIntFields(cur, buffer_end, &rdlen);
    if (cur == nullptr) {
        ALOGI("parsing failed at line %d", __LINE__);
        return nullptr;
    }
    // |rdlen| comes straight from the packet. Compare it against the bytes
    // that remain instead of computing cur + rdlen, which could form a
    // pointer far beyond the end of the buffer.
    if (rdlen > static_cast<size_t>(buffer_end - cur)) {
        ALOGI("parsing failed at line %d", __LINE__);
        return nullptr;
    }
    rdata.assign(cur, cur + rdlen);
    return cur + rdlen;
}

// Serializes the record as name, fixed-size fields and payload into
// [buffer, buffer_end). Returns a pointer just past the written bytes, or
// nullptr as soon as any of the three parts does not fit.
char* DNSRecord::write(char* buffer, const char* buffer_end) const {
    char* out = name.write(buffer, buffer_end);
    if (out != nullptr) {
        out = writeIntFields(rdata.size(), out, buffer_end);
    }
    if (out == nullptr) {
        return nullptr;
    }
    const size_t payload_len = rdata.size();
    if (out + payload_len > buffer_end) {
        ALOGI("buffer overflow on line %d", __LINE__);
        return nullptr;
    }
    return std::copy(rdata.begin(), rdata.end(), out);
}

// Returns an "R<name,type,class>" description of the record for logging.
// The string is built by concatenation so that arbitrarily long names cannot
// overrun a fixed-size formatting buffer.
std::string DNSRecord::toString() const {
    std::string out("R<");
    out += name.toString();
    out += ',';
    out += dnstype2str(rtype);
    out += ',';
    out += dnsclass2str(rclass);
    out += '>';
    return out;
}

const char* DNSRecord::readIntFields(const char* buffer, const char* buffer_end,
                                     unsigned* rdlen) {
    if (buffer + sizeof(IntFields) > buffer_end ) {
        ALOGI("parsing failed at line %d", __LINE__);
        return nullptr;
    }
    const auto& intfields = *reinterpret_cast<const IntFields*>(buffer);
    rtype = ntohs(intfields.rtype);
    rclass = ntohs(intfields.rclass);
    ttl = ntohl(intfields.ttl);
    *rdlen = ntohs(intfields.rdlen);
    return buffer + sizeof(IntFields);
}

char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer,
                                const char* buffer_end) const {
    if (buffer + sizeof(IntFields) > buffer_end ) {
        ALOGI("buffer overflow on line %d", __LINE__);
        return nullptr;
    }
    auto& intfields = *reinterpret_cast<IntFields*>(buffer);
    intfields.rtype = htons(rtype);
    intfields.rclass = htons(rclass);
    intfields.ttl = htonl(ttl);
    intfields.rdlen = htons(rdlen);
    return buffer + sizeof(IntFields);
}

// In-memory form of a complete DNS message: the decoded fixed header fields
// plus the four sections that follow the header on the wire.
struct DNSHeader {
    unsigned id;      // 16-bit message id (wire bytes 0-1), host order
    bool ra;          // bit 7 of wire byte 3
    uint8_t rcode;    // bits 0-3 of wire byte 3
    bool qr;          // bit 7 of wire byte 2; set on responses
    uint8_t opcode;   // bits 3-6 of wire byte 2
    bool aa;          // bit 2 of wire byte 2
    bool tr;          // bit 1 of wire byte 2
    bool rd;          // bit 0 of wire byte 2
    std::vector<DNSQuestion> questions;
    std::vector<DNSRecord> answers;
    std::vector<DNSRecord> authorities;
    std::vector<DNSRecord> additionals;
    // Parses a whole message from [buffer, buffer_end). Returns a pointer
    // just past the parsed data, or nullptr on failure.
    const char* read(const char* buffer, const char* buffer_end);
    // Serializes the whole message into [buffer, buffer_end). Returns a
    // pointer just past the written bytes, or nullptr on failure.
    char* write(char* buffer, const char* buffer_end) const;
    // Currently a stub that returns an empty string.
    std::string toString() const;

private:
    // Wire layout of the fixed header. Multi-byte fields are in network byte
    // order; the struct is packed so that its size matches the wire size.
    struct Header {
        uint16_t id;
        uint8_t flags0;
        uint8_t flags1;
        uint16_t qdcount;
        uint16_t ancount;
        uint16_t nscount;
        uint16_t arcount;
    } __attribute__((__packed__));

    // Decodes the fixed header at |buffer| into the members above and
    // reports the four section counts through the out-parameters. Returns a
    // pointer past the header, or a null pointer if the buffer is too short.
    const char* readHeader(const char* buffer, const char* buffer_end,
                           unsigned* qdcount, unsigned* ancount,
                           unsigned* nscount, unsigned* arcount);
};

const char* DNSHeader::read(const char* buffer, const char* buffer_end) {
    unsigned qdcount;
    unsigned ancount;
    unsigned nscount;
    unsigned arcount;
    const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount,
                                 &nscount, &arcount);
    if (cur == nullptr) {
        ALOGI("parsing failed at line %d", __LINE__);
        return nullptr;
    }
    if (qdcount) {
        questions.resize(qdcount);
        for (unsigned i = 0 ; i < qdcount ; ++i) {
            cur = questions[i].read(cur, buffer_end);
            if (cur == nullptr) {
                ALOGI("parsing failed at line %d", __LINE__);
                return nullptr;
            }
        }
    }
    if (ancount) {
        answers.resize(ancount);
        for (unsigned i = 0 ; i < ancount ; ++i) {
            cur = answers[i].read(cur, buffer_end);
            if (cur == nullptr) {
                ALOGI("parsing failed at line %d", __LINE__);
                return nullptr;
            }
        }
    }
    if (nscount) {
        authorities.resize(nscount);
        for (unsigned i = 0 ; i < nscount ; ++i) {
            cur = authorities[i].read(cur, buffer_end);
            if (cur == nullptr) {
                ALOGI("parsing failed at line %d", __LINE__);
                return nullptr;
            }
        }
    }
    if (arcount) {
        additionals.resize(arcount);
        for (unsigned i = 0 ; i < arcount ; ++i) {
            cur = additionals[i].read(cur, buffer_end);
            if (cur == nullptr) {
                ALOGI("parsing failed at line %d", __LINE__);
                return nullptr;
            }
        }
    }
    return cur;
}

char* DNSHeader::write(char* buffer, const char* buffer_end) const {
    if (buffer + sizeof(Header) > buffer_end) {
        ALOGI("buffer overflow on line %d", __LINE__);
        return nullptr;
    }
    Header& header = *reinterpret_cast<Header*>(buffer);
    // bytes 0-1
    header.id = htons(id);
    // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
    header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd;
    // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
    header.flags1 = rcode;
    // rest of header
    header.qdcount = htons(questions.size());
    header.ancount = htons(answers.size());
    header.nscount = htons(authorities.size());
    header.arcount = htons(additionals.size());
    char* buffer_cur = buffer + sizeof(Header);
    for (const DNSQuestion& question : questions) {
        buffer_cur = question.write(buffer_cur, buffer_end);
        if (buffer_cur == nullptr) return nullptr;
    }
    for (const DNSRecord& answer : answers) {
        buffer_cur = answer.write(buffer_cur, buffer_end);
        if (buffer_cur == nullptr) return nullptr;
    }
    for (const DNSRecord& authority : authorities) {
        buffer_cur = authority.write(buffer_cur, buffer_end);
        if (buffer_cur == nullptr) return nullptr;
    }
    for (const DNSRecord& additional : additionals) {
        buffer_cur = additional.write(buffer_cur, buffer_end);
        if (buffer_cur == nullptr) return nullptr;
    }
    return buffer_cur;
}

// Placeholder: a textual description of the message is not implemented yet,
// so this always yields an empty string.
std::string DNSHeader::toString() const {
    // TODO
    return {};
}

const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end,
                                  unsigned* qdcount, unsigned* ancount,
                                  unsigned* nscount, unsigned* arcount) {
    if (buffer + sizeof(Header) > buffer_end)
        return 0;
    const auto& header = *reinterpret_cast<const Header*>(buffer);
    // bytes 0-1
    id = ntohs(header.id);
    // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
    qr = header.flags0 >> 7;
    opcode = (header.flags0 >> 3) & 0x0F;
    aa = (header.flags0 >> 2) & 1;
    tr = (header.flags0 >> 1) & 1;
    rd = header.flags0 & 1;
    // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
    ra = header.flags1 >> 7;
    rcode = header.flags1 & 0xF;
    // rest of header
    *qdcount = ntohs(header.qdcount);
    *ancount = ntohs(header.ancount);
    *nscount = ntohs(header.nscount);
    *arcount = ntohs(header.arcount);
    return buffer + sizeof(Header);
}

/* DNS responder */

// Stores the listening endpoint and tuning parameters. No socket is opened
// and no thread is started here; that happens in startServer().
DNSResponder::DNSResponder(std::string listen_address,
                           std::string listen_service, int poll_timeout_ms,
                           uint16_t error_rcode, double response_probability)
    : listen_address_(std::move(listen_address)),
      listen_service_(std::move(listen_service)),
      poll_timeout_ms_(poll_timeout_ms),
      error_rcode_(error_rcode),
      response_probability_(response_probability),
      socket_(-1),
      epoll_fd_(-1),
      terminate_(false) {}

// Shuts the server down if it is still running. The result of stopServer()
// is deliberately ignored: it merely reports whether a server was running.
DNSResponder::~DNSResponder() {
    static_cast<void>(stopServer());
}

// Registers |addr| as the response for queries of (|name|, |type|). If a
// mapping for that key already exists it is overwritten and the change is
// logged. Holds mappings_mutex_ for the duration of the call.
void DNSResponder::addMapping(const char* name, ns_type type,
        const char* addr) {
    std::lock_guard<std::mutex> guard(mappings_mutex_);
    const auto existing = mappings_.find(QueryKey(name, type));
    if (existing == mappings_.end()) {
        // First registration for this key.
        mappings_.emplace(std::piecewise_construct,
                          std::forward_as_tuple(name, type),
                          std::forward_as_tuple(addr));
        return;
    }
    ALOGI("Overwriting mapping for (%s, %s), previous address %s, new "
        "address %s", name, dnstype2str(type), existing->second.c_str(),
        addr);
    existing->second = addr;
}

// Removes the mapping registered for (|name|, |type|). If no such mapping
// exists, logs the fact and leaves the table unchanged. Holds mappings_mutex_
// for the duration of the call.
void DNSResponder::removeMapping(const char* name, ns_type type) {
    std::lock_guard<std::mutex> lock(mappings_mutex_);
    auto it = mappings_.find(QueryKey(name, type));
    // The original test was inverted: it bailed out for existing mappings and
    // called erase(end()) for missing ones, which is undefined behaviour.
    if (it == mappings_.end()) {
        ALOGI("Cannot remove mapping from (%s, %s), not present", name,
            dnstype2str(type));
        return;
    }
    mappings_.erase(it);
}

void DNSResponder::setResponseProbability(double response_probability) {
    response_probability_ = response_probability;
}

// The server counts as running while it owns a socket; socket_ is -1
// whenever no server socket is open.
bool DNSResponder::running() const {
    const bool has_socket = (socket_ != -1);
    return has_socket;
}

// Starts the responder: resolves the configured listen address/service,
// binds a non-blocking UDP socket to the first usable result, registers it
// with a fresh epoll instance and spawns the request handler thread.
// Returns true on success. Returns false if the server is already running or
// if any step fails; on failure every descriptor opened here is closed again.
bool DNSResponder::startServer() {
    if (running()) {
        ALOGI("server already running");
        return false;
    }
    // Plain assignments instead of designated initializers: the latter must
    // follow declaration order (ai_flags comes first in addrinfo) to be
    // accepted by all compilers.
    addrinfo ai_hints{};
    ai_hints.ai_flags = AI_PASSIVE;
    ai_hints.ai_family = AF_UNSPEC;
    ai_hints.ai_socktype = SOCK_DGRAM;
    addrinfo* ai_res = nullptr;
    int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
                         &ai_hints, &ai_res);
    if (rv) {
        ALOGI("getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
            listen_service_.c_str(), gai_strerror(rv));
        return false;
    }
    // Try every resolved address until one can be bound.
    int s = -1;
    for (const addrinfo* ai = ai_res ; ai ; ai = ai->ai_next) {
        s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
        if (s < 0) continue;
        const int one = 1;
        // Best effort: failing to set SO_REUSEPORT is logged but not fatal.
        if (setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one)) < 0) {
            APLOGI("setsockopt(SO_REUSEPORT) failed for socket %d", s);
        }
        if (bind(s, ai->ai_addr, ai->ai_addrlen)) {
            APLOGI("bind failed for socket %d", s);
            close(s);
            s = -1;
            continue;
        }
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        ALOGI("bound to UDP %s:%s", host_str.c_str(), listen_service_.c_str());
        break;
    }
    freeaddrinfo(ai_res);
    if (s < 0) {
        ALOGI("bind() failed");
        return false;
    }

    int flags = fcntl(s, F_GETFL, 0);
    if (flags < 0) flags = 0;
    if (fcntl(s, F_SETFL, flags | O_NONBLOCK) < 0) {
        APLOGI("fcntl(F_SETFL) failed for socket %d", s);
        close(s);
        return false;
    }

    int ep_fd = epoll_create(1);
    if (ep_fd < 0) {
        // APLOGI already appends errno and its description; the former
        // hand-rolled strerror_r() handling was redundant and mis-detected
        // failure with the GNU strerror_r() variant.
        APLOGI("epoll_create() failed for socket %d", s);
        close(s);
        return false;
    }
    epoll_event ev{};
    ev.events = EPOLLIN;
    ev.data.fd = s;
    if (epoll_ctl(ep_fd, EPOLL_CTL_ADD, s, &ev) < 0) {
        APLOGI("epoll_ctl() failed for socket %d", s);
        close(ep_fd);
        close(s);
        return false;
    }

    // Publish the descriptors before the handler thread starts using them.
    epoll_fd_ = ep_fd;
    socket_ = s;
    {
        std::lock_guard<std::mutex> lock(update_mutex_);
        handler_thread_ = std::thread(&DNSResponder::requestHandler, this);
    }
    ALOGI("server started successfully");
    return true;
}

// Stops the responder: asks the handler thread to exit, waits for it, and
// closes the epoll descriptor and the socket.
// Returns true if a running server was stopped, false if the server was not
// running or a stop already appears to be in progress.
// Holds update_mutex_ for the whole call.
bool DNSResponder::stopServer() {
    std::lock_guard<std::mutex> lock(update_mutex_);
    if (!running()) {
        ALOGI("server not running");
        return false;
    }
    // terminate_ is only set below and cleared again before returning, so
    // finding it set here is unexpected.
    if (terminate_) {
        ALOGI("LOGIC ERROR");
        return false;
    }
    ALOGI("stopping server");
    // requestHandler() polls this flag once per loop iteration, so the join
    // below can take up to one poll timeout.
    terminate_ = true;
    handler_thread_.join();
    // Only close the descriptors once the handler thread no longer uses them.
    close(epoll_fd_);
    close(socket_);
    terminate_ = false;
    // Marks the server as stopped; see running().
    socket_ = -1;
    ALOGI("server stopped successfully");
    return true;
}

// Returns a snapshot (copy) of all (name, type) pairs recorded so far. The
// copy is taken while holding queries_mutex_.
std::vector<std::pair<std::string, ns_type >> DNSResponder::queries() const {
    std::lock_guard<std::mutex> guard(queries_mutex_);
    std::vector<std::pair<std::string, ns_type >> snapshot(queries_);
    return snapshot;
}

// Forgets all recorded queries. Holds queries_mutex_ while doing so.
void DNSResponder::clearQueries() {
    std::lock_guard<std::mutex> guard(queries_mutex_);
    queries_.erase(queries_.begin(), queries_.end());
}

// Body of the handler thread. Until terminate_ is set it waits (with a
// timeout of poll_timeout_ms_) for the socket to become readable, reads one
// datagram, lets handleDNSRequest() build a response and sends it back to
// the sender. Every response produced is re-parsed as a sanity check.
// The loop ends when terminate_ is set or epoll_wait() fails with an error
// other than EINTR.
void DNSResponder::requestHandler() {
    epoll_event evs[1];
    while (!terminate_) {
        int n = epoll_wait(epoll_fd_, evs, 1, poll_timeout_ms_);
        if (n == 0) continue;
        if (n < 0) {
            // A signal interrupting the wait is not an error; retry instead
            // of silently killing the handler thread.
            if (errno == EINTR) continue;
            ALOGI("epoll_wait() failed");
            // TODO(imaipi): terminate on error.
            return;
        }
        char buffer[4096];
        sockaddr_storage sa;
        socklen_t sa_len = sizeof(sa);
        ssize_t len;
        do {
            len = recvfrom(socket_, buffer, sizeof(buffer), 0,
                           reinterpret_cast<sockaddr*>(&sa), &sa_len);
        } while (len < 0 && (errno == EAGAIN || errno == EINTR));
        if (len <= 0) {
            ALOGI("recvfrom() failed");
            continue;
        }
        ALOGI("read %zd bytes", len);
        char response[4096];
        size_t response_len = sizeof(response);
        if (handleDNSRequest(buffer, len, response, &response_len) &&
            response_len > 0) {
            const ssize_t sent = sendto(socket_, response, response_len, 0,
                                        reinterpret_cast<const sockaddr*>(&sa),
                                        sa_len);
            std::string host_str =
                addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len);
            if (sent > 0) {
                ALOGI("sent %zd bytes to %s", sent, host_str.c_str());
            } else {
                APLOGI("sendto() failed for %s", host_str.c_str());
            }
            // Test that the response is actually a correct DNS message.
            // Use the length of the generated response, not the sendto()
            // result, which is -1 when sending failed.
            const char* response_end = response + response_len;
            DNSHeader header;
            const char* cur = header.read(response, response_end);
            if (cur == nullptr) ALOGI("response is flawed");

        } else {
            ALOGI("not responding");
        }
    }
}

bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len,
                                    char* response, size_t* response_len)
                                    const {
    ALOGI("request: '%s'", str2hex(buffer, len).c_str());
    const char* buffer_end = buffer + len;
    DNSHeader header;
    const char* cur = header.read(buffer, buffer_end);
    // TODO(imaipi): for now, unparsable messages are silently dropped, fix.
    if (cur == nullptr) {
        ALOGI("failed to parse query");
        return false;
    }
    if (header.qr) {
        ALOGI("response received instead of a query");
        return false;
    }
    if (header.opcode != ns_opcode::ns_o_query) {
        ALOGI("unsupported request opcode received");
        return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
                                 response_len);
    }
    if (header.questions.empty()) {
        ALOGI("no questions present");
        return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
                                 response_len);
    }
    if (!header.answers.empty()) {
        ALOGI("already %zu answers present in query", header.answers.size());
        return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
                                 response_len);
    }
    {
        std::lock_guard<std::mutex> lock(queries_mutex_);
        for (const DNSQuestion& question : header.questions) {
            queries_.push_back(make_pair(question.qname.name,
                                         ns_type(question.qtype)));
        }
    }

    // Ignore requests with the preset probability.
    auto constexpr bound = std::numeric_limits<unsigned>::max();
    if (arc4random_uniform(bound) > bound*response_probability_) {
        ALOGI("returning SRVFAIL in accordance with probability distribution");
        return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
                                 response_len);
    }

    for (const DNSQuestion& question : header.questions) {
        if (question.qclass != ns_class::ns_c_in &&
            question.qclass != ns_class::ns_c_any) {
            ALOGI("unsupported question class %u", question.qclass);
            return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
                                     response_len);
        }
        if (!addAnswerRecords(question, &header.answers)) {
            return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
                                     response_len);
        }
    }
    header.qr = true;
    char* response_cur = header.write(response, response + *response_len);
    if (response_cur == nullptr) {
        return false;
    }
    *response_len = response_cur - response;
    return true;
}

// Looks up the mapping registered for |question| and, if there is one,
// appends a matching record (class IN, TTL 5 seconds) to |answers|.
// Returns true if a record was added or no mapping exists (in which case
// nothing is added). Returns false if the mapped address cannot be parsed or
// the question type is neither A nor AAAA.
// NOTE(review): mappings_ is read here without taking mappings_mutex_, unlike
// addMapping()/removeMapping() - confirm whether that is safe.
bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
                                    std::vector<DNSRecord>* answers) const {
    const auto entry =
        mappings_.find(QueryKey(question.qname.name, question.qtype));
    if (entry == mappings_.end()) {
        // TODO(imaipi): handle correctly
        ALOGI("no mapping found for %s %s, lazily refusing to add an answer",
            question.qname.name.c_str(), dnstype2str(question.qtype));
        return true;
    }
    const char* const addr_text = entry->second.c_str();
    ALOGI("mapping found for %s %s: %s", question.qname.name.c_str(),
        dnstype2str(question.qtype), addr_text);

    DNSRecord record;
    record.name = question.qname;
    record.rtype = question.qtype;
    record.rclass = ns_class::ns_c_in;
    record.ttl = 5;  // seconds

    // The payload is the binary address: 4 bytes for A, 16 bytes for AAAA.
    switch (question.qtype) {
    case ns_type::ns_t_a:
        record.rdata.resize(4);
        if (inet_pton(AF_INET, addr_text, record.rdata.data()) != 1) {
            ALOGI("inet_pton(AF_INET, %s) failed", addr_text);
            return false;
        }
        break;
    case ns_type::ns_t_aaaa:
        record.rdata.resize(16);
        if (inet_pton(AF_INET6, addr_text, record.rdata.data()) != 1) {
            ALOGI("inet_pton(AF_INET6, %s) failed", addr_text);
            return false;
        }
        break;
    default:
        ALOGI("unhandled qtype %s", dnstype2str(question.qtype));
        return false;
    }
    answers->push_back(std::move(record));
    return true;
}

// Turns |header| into an error reply carrying |rcode|: drops all answer,
// authority and additional records, marks the message as a response and
// serializes it into |response|. On entry |*response_len| is the capacity of
// |response|; on success it is set to the number of bytes written.
// Returns false if the reply does not fit into |response|.
bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode,
                                     char* response, size_t* response_len)
                                     const {
    // An error reply keeps the question section but carries no records.
    for (std::vector<DNSRecord>* section :
            { &header->answers, &header->authorities, &header->additionals }) {
        section->clear();
    }
    header->rcode = rcode;
    header->qr = true;
    char* const written_end =
        header->write(response, response + *response_len);
    if (written_end != nullptr) {
        *response_len = written_end - response;
        return true;
    }
    return false;
}

}  // namespace test