C++程序  |  148行  |  4.47 KB

/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Top level driver for models and examples generated by test_generator.py

#include "Bridge.h"
#include "NeuralNetworksWrapper.h"
#include "TestHarness.h"

#include <gtest/gtest.h>
#include <cassert>
#include <cmath>
#include <fstream>
#include <iostream>
#include <map>

// Uncomment the following line to generate DOT graphs.
//
// #define GRAPH GRAPH

namespace generated_tests {
using namespace android::nn::wrapper;
using namespace test_helper;

// Forwards `model` to the bridge-level graphDump() when the GRAPH macro is
// defined (see the commented-out #define above); compiles to a no-op otherwise.
//   name:  passed through unchanged to bridge_tests::graphDump().
//   model: wrapper model; its handle is reinterpreted as a ModelBuilder pointer.
void graphDump([[maybe_unused]] const char* name, [[maybe_unused]] const Model& model) {
#ifdef GRAPH
    const auto* builder =
            reinterpret_cast<const ::android::nn::ModelBuilder*>(model.getHandle());
    ::android::nn::bridge_tests::graphDump(name, builder);
#endif
}

// Writes the T-typed vectors that for_each<T> yields for `test` to `os`, one
// line per vector, formatted as `    aliased_output<idx>: [v0, v1, ...],`.
// (for_each is a test-harness helper not visible here; presumably it visits
// every T-typed operand of `test` — confirm against TestHarness.h.)
template <typename T>
static void print(std::ostream& os, const MixedTyped& test) {
    for_each<T>(test, [&os](int idx, const std::vector<T>& values) {
        os << "    aliased_output" << idx << ": [";
        const char* separator = "";
        for (const T& value : values) {
            // Unary plus promotes narrow integer types (e.g. uint8_t) so they
            // are streamed as numbers rather than as characters.
            os << separator << +value;
            separator = ", ";
        }
        os << "],\n";
    });
}

static void printAll(std::ostream& os, const MixedTyped& test) {
    print<float>(os, test);
    print<int32_t>(os, test);
    print<uint8_t>(os, test);
}

// Test driver for the test cases generated from the specs under
// ml/nn/runtime/test/specs.
//
// Builds one model via `createModel`, creates a Compilation for it, and then,
// for every example, runs a fresh Execution and compares the computed outputs
// with the expected ("golden") ones.
//
//   createModel: callback that populates the Model passed to it.
//   isIgnored:   predicate handed to filter() to drop "don't care" outputs
//                from both the golden and the computed results.
//   examples:    pairs whose `first` holds the inputs and whose `second` holds
//                the golden outputs.
//   dumpFile:    when non-empty, path of a file (truncated on open) that
//                receives a textual dump of the computed outputs of every
//                example.
//
// Failures are reported through gtest assertions; the function returns early
// on the first fatal failure it detects.
static void execute(std::function<void(Model*)> createModel,
                    std::function<bool(int)> isIgnored,
                    std::vector<MixedTypedExampleType>& examples,
                    std::string dumpFile = "") {
    Model model;
    createModel(&model);
    // NOTE(review): the results of model.finish() and compilation.finish()
    // (below) are not inspected here — confirm whether they report a status
    // that should be asserted on.
    model.finish();
    graphDump("", model);

    const bool dumpToFile = !dumpFile.empty();
    std::ofstream dumpStream;
    if (dumpToFile) {
        dumpStream.open(dumpFile, std::ofstream::trunc);
        ASSERT_TRUE(dumpStream.is_open());
    }

    Compilation compilation(&model);
    compilation.finish();

    // If in relaxed mode, set the error range to be 5ULP of FP16
    // (0.0009765625 == 2^-10).
    const float fpRange = model.isRelaxed() ? 5.0f * 0.0009765625f : 1e-5f;

    int exampleNo = 0;
    for (auto& example : examples) {
        SCOPED_TRACE(exampleNo);
        // TODO: We leave it as a copy here.
        // Should verify if the input gets modified by the test later.
        MixedTyped inputs = example.first;
        const MixedTyped& golden = example.second;

        Execution execution(&compilation);

        // A fatal gtest assertion inside a lambda only returns from that
        // lambda, not from execute(). Track the binding status explicitly so
        // that compute() is never run with inputs or outputs left unset.
        bool bindingsOk = true;

        // Set all inputs; zero-sized operands are passed as a null buffer.
        for_all(inputs, [&execution, &bindingsOk](int idx, const void* p, size_t size) {
            const void* buffer = size == 0 ? nullptr : p;
            const auto status = execution.setInput(idx, buffer, size);
            bindingsOk = bindingsOk && (Result::NO_ERROR == status);
            ASSERT_EQ(Result::NO_ERROR, status) << "setInput failed for input " << idx;
        });
        if (!bindingsOk) {
            return;
        }

        MixedTyped test;
        // Go through all typed outputs: size `test` like the golden outputs,
        // then bind each of its buffers as an execution output.
        resize_accordingly(golden, test);
        for_all(test, [&execution, &bindingsOk](int idx, void* p, size_t size) {
            void* buffer = size == 0 ? nullptr : p;
            const auto status = execution.setOutput(idx, buffer, size);
            bindingsOk = bindingsOk && (Result::NO_ERROR == status);
            ASSERT_EQ(Result::NO_ERROR, status) << "setOutput failed for output " << idx;
        });
        if (!bindingsOk) {
            return;
        }

        const Result r = execution.compute();
        ASSERT_EQ(Result::NO_ERROR, r);

        // Dump all outputs for the slicing tool
        if (dumpToFile) {
            dumpStream << "output" << exampleNo << " = {\n";
            printAll(dumpStream, test);
            // all outputs are done
            dumpStream << "}\n";
        }

        // Filter out don't cares
        MixedTyped filteredGolden = filter(golden, isIgnored);
        MixedTyped filteredTest = filter(test, isIgnored);

        // We want "close-enough" results for float
        compare(filteredGolden, filteredTest, fpRange);
        exampleNo++;
    }
}

};  // namespace generated_tests

using namespace android::nn::wrapper;

// Mixed-typed examples: file-scope alias for the harness's example type (the
// driver above reads `first` as the inputs and `second` as the golden outputs).
using MixedTypedExample = test_helper::MixedTypedExampleType;

// gtest fixture with no per-test state of its own. Presumably the fixture the
// generated test cases (included at the bottom of this file) are registered
// against — confirm against the generator's output.
class GeneratedTests : public ::testing::Test {
protected:
    // Intentionally empty. `override` makes the compiler verify that this
    // still matches ::testing::Test::SetUp() instead of silently declaring a
    // new, never-called virtual if the base signature ever changes.
    void SetUp() override {}
};

// Testcases generated from runtime/test/specs/*.mod.py
//
// The generated file is textually included into this translation unit, so it
// sees every declaration above (execute(), MixedTypedExample, GeneratedTests).
// The using-directives are presumably here so that the generated code can
// refer to the test_helper and generated_tests names without qualification —
// confirm against the generator's output before removing either of them.
using namespace test_helper;
using namespace generated_tests;
#include "generated/all_generated_tests.cpp"
// End of testcases generated from runtime/test/specs/*.mod.py