// C++ source  |  215 lines  |  6.97 KB

/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "DnsTlsTransport"
//#define LOG_NDEBUG 0

#include "DnsTlsTransport.h"

#include <arpa/inet.h>
#include <arpa/nameser.h>

#include "DnsTlsSocketFactory.h"
#include "IDnsTlsSocketFactory.h"

#include "log/log.h"

namespace android {
namespace net {

// Registers |query| with mQueries and dispatches it, returning a future for its result.
//
// If mQueries refuses to record the query, an already-failed (deferred) future carrying
// Response::internal_error is returned. Otherwise the query is sent on the existing socket,
// or, when there is no socket yet, doConnect() opens one and (re)sends every pending query,
// including this one. Runs entirely under mLock.
std::future<DnsTlsTransport::Result> DnsTlsTransport::query(const netdutils::Slice query) {
    std::lock_guard guard(mLock);

    auto pending = mQueries.recordQuery(query);
    if (!pending) {
        // Nothing was recorded, so no response can ever arrive: hand back a failed result.
        return std::async(std::launch::deferred, [] {
            return Result{ .code = Response::internal_error };
        });
    }

    if (mSocket) {
        sendQuery(pending->query);
    } else {
        ALOGV("No socket for query.  Opening socket and sending.");
        doConnect();
    }

    return std::move(pending->result);
}

// Sends |q| on the current socket using q.newId in place of the query's original ID.
//
// The first two bytes of q.query (the DNS header's ID field) are stripped before sending;
// the socket is given q.newId alongside the rest of the message. On success the query is
// marked as tried in mQueries. Returns the result of mSocket->query().
// Callers invoke this only while mSocket is non-null.
bool DnsTlsTransport::sendQuery(const DnsTlsQueryMap::Query q) {
    if (!mSocket->query(q.newId, netdutils::drop(q.query, 2))) {
        return false;
    }
    mQueries.markTried(q.newId);
    return true;
}

// Creates a new socket through mFactory and resends every pending query on it.
//
// If socket creation fails, every pending query in mQueries is cleared (failed).
// Resending stops at the first query the socket does not accept; the remaining queries
// are not sent now. Callers hold mLock.
void DnsTlsTransport::doConnect() {
    ALOGV("Constructing new socket");
    mSocket = mFactory->createDnsTlsSocket(mServer, mMark, this, &mCache);

    if (!mSocket) {
        ALOGV("Initialization failed.");
        mSocket.reset();
        ALOGV("Failing all pending queries.");
        mQueries.clear();
        return;
    }

    auto pendingQueries = mQueries.getAll();
    ALOGV("Initialization succeeded.  Reissuing %zu queries.", pendingQueries.size());
    for (auto& pendingQuery : pendingQueries) {
        const bool accepted = sendQuery(pendingQuery);
        if (!accepted) {
            break;
        }
    }
}

// Socket callback: hands a received response, by move, to mQueries (which presumably
// matches it to the pending query it answers — confirm against DnsTlsQueryMap).
// NOTE(review): unlike query(), onClosed() and doReconnect(), this does not take mLock;
// that is only safe if DnsTlsQueryMap synchronizes internally — confirm.
void DnsTlsTransport::onResponse(std::vector<uint8_t> response) {
    mQueries.onResponse(std::move(response));
}

// Socket callback: the current socket has closed.
//
// Unless the transport is being destroyed (mClosing), starts doReconnect() on a new
// thread stored in mReconnectThread, after first joining any thread left over from a
// previous reconnect.
void DnsTlsTransport::onClosed() {
    std::lock_guard guard(mLock);
    // The destructor sets mClosing under mLock and then joins only the mReconnectThread it
    // sees, so no new reconnect thread may be started once mClosing is set.
    if (mClosing) {
        return;
    }
    // Move remaining operations to a new thread.
    // This is necessary because
    // 1. onClosed is currently running on a thread that blocks mSocket's destructor
    // 2. doReconnect will call that destructor
    if (mReconnectThread) {
        // Complete cleanup of a previous reconnect thread, if present.
        mReconnectThread->join();
        // Joining a thread that is trying to acquire mLock, while holding mLock,
        // looks like it risks a deadlock.  However, a deadlock will not occur because
        // once onClosed is called, it cannot be called again until after doReconnect
        // acquires mLock.
    }
    mReconnectThread.reset(new std::thread(&DnsTlsTransport::doReconnect, this));
}

// Body of the reconnect thread started by onClosed().
//
// Does nothing if the transport is closing. Otherwise calls mQueries.cleanup() (presumably
// discarding queries that should not be retried — confirm against DnsTlsQueryMap), then
// either reconnects to resend what is left or, when nothing is pending, destroys the
// socket and goes idle. Runs entirely under mLock.
void DnsTlsTransport::doReconnect() {
    std::lock_guard guard(mLock);
    if (mClosing) {
        return;
    }

    mQueries.cleanup();
    if (mQueries.empty()) {
        ALOGV("No pending queries.  Going idle.");
        mSocket.reset();
        return;
    }

    ALOGV("Fast reconnect to retry remaining queries");
    doConnect();
}

// Tears the transport down in an order that keeps concurrent callbacks safe:
//   1. Under mLock: clear (fail) all pending queries and set mClosing, so that later
//      onClosed() / doReconnect() calls return immediately.
//   2. Without mLock: join any reconnect thread, which may be blocked waiting for mLock.
//   3. Destroy the socket last, while this object's fields are still valid.
// The order of these steps must not change.
DnsTlsTransport::~DnsTlsTransport() {
    ALOGV("Destructor");
    {
        std::lock_guard guard(mLock);
        ALOGV("Locked destruction procedure");
        mQueries.clear();
        mClosing = true;
    }
    // It's possible that a reconnect thread was spawned and waiting for mLock.
    // It's safe for that thread to run now because mClosing is true (and mQueries is empty),
    // but we need to wait for it to finish before allowing destruction to proceed.
    if (mReconnectThread) {
        ALOGV("Waiting for reconnect thread to terminate");
        mReconnectThread->join();
        mReconnectThread.reset();
    }
    // Ensure that the socket is destroyed, and can clean up its callback threads,
    // before any of this object's fields become invalid.
    mSocket.reset();
    ALOGV("Destructor completed");
}

// static
// Probes |server| with a single DNS-over-TLS query to check that it is a working resolver.
//
// Sends an AAAA query for "<random>-dnsotls-ds.metric.gstatic.com" through a temporary
// DnsTlsTransport (built with |mark| and a fresh DnsTlsSocketFactory) and blocks until the
// result is available. Returns true if the query succeeds and the reply has a complete
// fixed-size DNS header whose QDCOUNT is 1. |netid| is used only for logging.
//
// TODO: Use this function to preheat the session cache.
// That may require moving it to DnsTlsDispatcher.
bool DnsTlsTransport::validate(const DnsTlsServer& server, unsigned netid, uint32_t mark) {
    ALOGV("Beginning validation on %u", netid);
    // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will look up through a
    // temporary transport in order to prove that it is actually a working DNS over TLS server.
    static constexpr char kDnsSafeChars[] =
            "abcdefghijklmnopqrstuvwxyz"
            "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
            "0123456789";
    // sizeof() of a string-literal array includes the trailing NUL. Exclude it so that the
    // random label can never contain a '\0' byte.
    static constexpr size_t kNumDnsSafeChars = sizeof(kDnsSafeChars) - 1;
    const auto c = [](uint8_t rnd) -> uint8_t {
        return kDnsSafeChars[rnd % kNumDnsSafeChars];
    };
    uint8_t rnd[8];
    arc4random_buf(rnd, std::size(rnd));
    // We could try to use res_mkquery() here, but it's basically the same.
    uint8_t query[] = {
        rnd[6], rnd[7],  // [0-1]   query ID
        1, 0,  // [2-3]   flags; query[2] = 1 for recursion desired (RD).
        0, 1,  // [4-5]   QDCOUNT (number of queries)
        0, 0,  // [6-7]   ANCOUNT (number of answers)
        0, 0,  // [8-9]   NSCOUNT (number of name server records)
        0, 0,  // [10-11] ARCOUNT (number of additional records)
        // First label: 6 random characters + "-dnsotls-ds" (11 characters) = 17 bytes.
        17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
            '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
        6, 'm', 'e', 't', 'r', 'i', 'c',
        7, 'g', 's', 't', 'a', 't', 'i', 'c',
        3, 'c', 'o', 'm',
        0,  // null terminator of FQDN (root TLD)
        0, ns_t_aaaa,  // QTYPE
        0, ns_c_in     // QCLASS
    };
    const size_t qlen = std::size(query);

    DnsTlsSocketFactory factory;
    DnsTlsTransport transport(server, mark, &factory);
    auto r = transport.query(netdutils::Slice(query, qlen)).get();
    if (r.code != Response::success) {
        ALOGV("query failed");
        return false;
    }

    const std::vector<uint8_t>& recvbuf = r.response;
    if (recvbuf.size() < NS_HFIXEDSZ) {
        ALOGW("short response: %zu", recvbuf.size());
        return false;
    }

    // Header bytes [4-5] hold QDCOUNT and [6-7] hold ANCOUNT, big-endian.
    const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
    if (qdcount != 1) {
        ALOGW("reply query count != 1: %d", qdcount);
        return false;
    }

    const int ancount = (recvbuf[6] << 8) | recvbuf[7];
    ALOGV("%u answer count: %d", netid, ancount);

    // TODO: Further validate the response contents (check for valid AAAA record, ...).
    // Note that currently, integration tests rely on this function accepting a
    // response with zero records.
#if 0
    for (size_t i = 0; i < recvbuf.size(); i++) {
        ALOGD("recvbuf[%zu] = %d %c", i, recvbuf[i], recvbuf[i]);
    }
#endif
    return true;
}

}  // end of namespace net
}  // end of namespace android