/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "Callbacks.h"
#include "TestHarness.h"
#include "Utils.h"

#include <android-base/logging.h>
#include <android/hardware/neuralnetworks/1.0/IDevice.h>
#include <android/hardware/neuralnetworks/1.0/IExecutionCallback.h>
#include <android/hardware/neuralnetworks/1.0/IPreparedModel.h>
#include <android/hardware/neuralnetworks/1.0/IPreparedModelCallback.h>
#include <android/hardware/neuralnetworks/1.0/types.h>
#include <android/hidl/allocator/1.0/IAllocator.h>
#include <android/hidl/memory/1.0/IMemory.h>
#include <hidlmemory/mapping.h>
#include <iostream>

namespace android {
namespace hardware {
namespace neuralnetworks {

namespace generated_tests {
using ::android::hardware::neuralnetworks::V1_0::implementation::ExecutionCallback;
using ::android::hardware::neuralnetworks::V1_0::implementation::PreparedModelCallback;
using ::test_helper::filter;
using ::test_helper::for_all;
using ::test_helper::for_each;
using ::test_helper::resize_accordingly;
using ::test_helper::MixedTyped;
using ::test_helper::MixedTypedExampleType;
using ::test_helper::Float32Operands;
using ::test_helper::Int32Operands;
using ::test_helper::Quant8Operands;
using ::test_helper::compare;

// Copies every operand of element type T from the raw buffer `src` into the
// matching vector held by `*dst`. The byte offset and byte length of each
// operand come from the RequestArgument stored at the operand's index in `ra`.
template <typename T>
void copy_back_(MixedTyped* dst, const std::vector<RequestArgument>& ra, char* src) {
    for_each<T>(*dst, [&ra, src](int index, std::vector<T>& values) {
        const auto& where = ra[index].location;
        // The destination vector must already hold exactly as many elements as
        // the request argument describes; on a mismatch nothing is copied for
        // this operand.
        ASSERT_EQ(values.size(), where.length / sizeof(T));
        memcpy(values.data(), src + where.offset, where.length);
    });
}

// Copies the operands of every supported element type (float, int32_t and
// uint8_t, in that order) from the raw buffer `src` into `*dst`, using `ra` to
// locate each operand inside the buffer.
void copy_back(MixedTyped* dst, const std::vector<RequestArgument>& ra, char* src) {
    using CopyFn = void (*)(MixedTyped*, const std::vector<RequestArgument>&, char*);
    const CopyFn perTypeCopies[] = {
        &copy_back_<float>,
        &copy_back_<int32_t>,
        &copy_back_<uint8_t>,
    };
    for (CopyFn copyOperands : perTypeCopies) {
        copyOperands(dst, ra, src);
    }
}

// Runs each example through an already prepared model and verifies the outputs.
// This is the shared driver for the models and examples generated by
// test_generator.py from ml/nn/runtime/test/spec.
//
// For every (inputs, golden outputs) pair in `examples`, the function describes
// all operands as RequestArguments, packs the input values into one shared
// memory pool, reserves a second pool for the outputs, executes the request on
// `preparedModel`, and waits for the execution callback. The outputs are then
// copied back and checked with compare() against the golden values after both
// sides have been passed through filter() with `is_ignored`. `fpRange` is
// handed to compare() unchanged.
void EvaluatePreparedModel(sp<IPreparedModel>& preparedModel, std::function<bool(int)> is_ignored,
                           const std::vector<MixedTypedExampleType>& examples,
                           float fpRange = 1e-5f) {
    // Positions of the input and output pools within the request's pool list.
    const uint32_t kInputPool = 0;
    const uint32_t kOutputPool = 1;

    int exampleNumber = 1;
    for (auto& example : examples) {
        // Tag any failure below with the 1-based number of the current example.
        SCOPED_TRACE(exampleNumber++);

        const MixedTyped& inputValues = example.first;
        const MixedTyped& goldenValues = example.second;

        std::vector<RequestArgument> inputArgs;
        std::vector<RequestArgument> outputArgs;
        uint32_t inputBytes = 0;
        uint32_t outputBytes = 0;

        // First pass over the inputs: record each operand's byte length and the
        // total pool size. Offsets are assigned right after this pass and the
        // values themselves are copied once the pool has been mapped.
        for_all(inputValues, [&inputArgs, &inputBytes](int index, auto, auto byteCount) {
            if (static_cast<size_t>(index) >= inputArgs.size()) inputArgs.resize(index + 1);
            RequestArgument present = {
                .location = {.poolIndex = kInputPool,
                             .offset = 0,
                             .length = static_cast<uint32_t>(byteCount)},
                .dimensions = {},
            };
            RequestArgument omitted = {
                .hasNoValue = true,
            };
            // An input that carries no bytes is passed as "no value".
            if (byteCount) {
                inputArgs[index] = present;
            } else {
                inputArgs[index] = omitted;
            }
            inputBytes += byteCount;
        });
        // Lay the inputs out back to back inside the input pool.
        {
            size_t nextOffset = 0;
            for (RequestArgument& argument : inputArgs) {
                if (!argument.hasNoValue) argument.location.offset = nextOffset;
                nextOffset += argument.location.length;
            }
        }

        MixedTyped actualValues;  // receives the results produced by the model

        // Size the result holder like the golden data, then describe each output.
        resize_accordingly(goldenValues, actualValues);
        for_all(goldenValues, [&outputArgs, &outputBytes](int index, auto, auto byteCount) {
            if (static_cast<size_t>(index) >= outputArgs.size()) outputArgs.resize(index + 1);
            RequestArgument described = {
                .location = {.poolIndex = kOutputPool,
                             .offset = 0,
                             .length = static_cast<uint32_t>(byteCount)},
                .dimensions = {},
            };
            outputArgs[index] = described;
            outputBytes += byteCount;
        });
        // Lay the outputs out back to back inside the output pool.
        {
            size_t nextOffset = 0;
            for (RequestArgument& argument : outputArgs) {
                argument.location.offset = nextOffset;
                nextOffset += argument.location.length;
            }
        }

        // One shared memory pool for all inputs and one for all outputs.
        std::vector<hidl_memory> pools = {nn::allocateSharedMemory(inputBytes),
                                          nn::allocateSharedMemory(outputBytes)};
        ASSERT_NE(0ull, pools[kInputPool].size());
        ASSERT_NE(0ull, pools[kOutputPool].size());

        // Map both pools and obtain byte pointers into them.
        sp<IMemory> inputMapping = mapMemory(pools[kInputPool]);
        sp<IMemory> outputMapping = mapMemory(pools[kOutputPool]);
        ASSERT_NE(nullptr, inputMapping.get());
        ASSERT_NE(nullptr, outputMapping.get());
        char* inputBase = reinterpret_cast<char*>(static_cast<void*>(inputMapping->getPointer()));
        char* outputBase =
            reinterpret_cast<char*>(static_cast<void*>(outputMapping->getPointer()));
        ASSERT_NE(nullptr, inputBase);
        ASSERT_NE(nullptr, outputBase);

        // The writes into the pools are bracketed by update() and commit().
        inputMapping->update();
        outputMapping->update();

        // Second pass over the inputs: copy each operand's bytes to the offset
        // computed for it above.
        for_all(inputValues, [&inputArgs, inputBase](int index, auto data, auto byteCount) {
            const char* first = reinterpret_cast<const char*>(data);
            const char* last = first + byteCount;
            std::copy(first, last, inputBase + inputArgs[index].location.offset);
        });

        inputMapping->commit();
        outputMapping->commit();

        // Launch the execution; the launch itself must succeed as a transaction.
        sp<ExecutionCallback> executionCallback = new ExecutionCallback();
        ASSERT_NE(nullptr, executionCallback.get());
        Return<ErrorStatus> launchStatus = preparedModel->execute(
            {.inputs = inputArgs, .outputs = outputArgs, .pools = pools}, executionCallback);
        ASSERT_TRUE(launchStatus.isOk());
        EXPECT_EQ(ErrorStatus::NONE, static_cast<ErrorStatus>(launchStatus));

        // Block until the callback fires, then check the status it reported.
        executionCallback->wait();
        ErrorStatus finishedStatus = executionCallback->getStatus();
        EXPECT_EQ(ErrorStatus::NONE, finishedStatus);

        // Pull the produced outputs out of the output pool.
        outputMapping->read();
        copy_back(&actualValues, outputArgs, outputBase);
        outputMapping->commit();

        // Apply the same is_ignored filter to both sides ("don't care" operands).
        MixedTyped goldenKept = filter(goldenValues, is_ignored);
        MixedTyped actualKept = filter(actualValues, is_ignored);

        // Float results only need to be "close enough", as bounded by fpRange.
        compare(goldenKept, actualKept, fpRange);
    }
}

void Execute(const sp<V1_0::IDevice>& device, std::function<V1_0::Model(void)> create_model,
             std::function<bool(int)> is_ignored,
             const std::vector<MixedTypedExampleType>& examples) {
    V1_0::Model model = create_model();

    // see if service can handle model
    bool fullySupportsModel = false;
    Return<void> supportedCall = device->getSupportedOperations(
        model, [&fullySupportsModel](ErrorStatus status, const hidl_vec<bool>& supported) {
            ASSERT_EQ(ErrorStatus::NONE, status);
            ASSERT_NE(0ul, supported.size());
            fullySupportsModel =
                std::all_of(supported.begin(), supported.end(), [](bool valid) { return valid; });
        });
    ASSERT_TRUE(supportedCall.isOk());

    // launch prepare model
    sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
    ASSERT_NE(nullptr, preparedModelCallback.get());
    Return<ErrorStatus> prepareLaunchStatus = device->prepareModel(model, preparedModelCallback);
    ASSERT_TRUE(prepareLaunchStatus.isOk());
    ASSERT_EQ(ErrorStatus::NONE, static_cast<ErrorStatus>(prepareLaunchStatus));

    // retrieve prepared model
    preparedModelCallback->wait();
    ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
    sp<IPreparedModel> preparedModel = preparedModelCallback->getPreparedModel();

    // early termination if vendor service cannot fully prepare model
    if (!fullySupportsModel && prepareReturnStatus != ErrorStatus::NONE) {
        ASSERT_EQ(nullptr, preparedModel.get());
        LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
                     "prepare model that it does not support.";
        std::cout << "[          ]   Early termination of test because vendor service cannot "
                     "prepare model that it does not support."
                  << std::endl;
        return;
    }
    EXPECT_EQ(ErrorStatus::NONE, prepareReturnStatus);
    ASSERT_NE(nullptr, preparedModel.get());

    EvaluatePreparedModel(preparedModel, is_ignored, examples);
}

void Execute(const sp<V1_1::IDevice>& device, std::function<V1_1::Model(void)> create_model,
             std::function<bool(int)> is_ignored,
             const std::vector<MixedTypedExampleType>& examples) {
    V1_1::Model model = create_model();

    // see if service can handle model
    bool fullySupportsModel = false;
    Return<void> supportedCall = device->getSupportedOperations_1_1(
        model, [&fullySupportsModel](ErrorStatus status, const hidl_vec<bool>& supported) {
            ASSERT_EQ(ErrorStatus::NONE, status);
            ASSERT_NE(0ul, supported.size());
            fullySupportsModel =
                std::all_of(supported.begin(), supported.end(), [](bool valid) { return valid; });
        });
    ASSERT_TRUE(supportedCall.isOk());

    // launch prepare model
    sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
    ASSERT_NE(nullptr, preparedModelCallback.get());
    Return<ErrorStatus> prepareLaunchStatus = device->prepareModel_1_1(
        model, ExecutionPreference::FAST_SINGLE_ANSWER, preparedModelCallback);
    ASSERT_TRUE(prepareLaunchStatus.isOk());
    ASSERT_EQ(ErrorStatus::NONE, static_cast<ErrorStatus>(prepareLaunchStatus));

    // retrieve prepared model
    preparedModelCallback->wait();
    ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
    sp<IPreparedModel> preparedModel = preparedModelCallback->getPreparedModel();

    // early termination if vendor service cannot fully prepare model
    if (!fullySupportsModel && prepareReturnStatus != ErrorStatus::NONE) {
        ASSERT_EQ(nullptr, preparedModel.get());
        LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
                     "prepare model that it does not support.";
        std::cout << "[          ]   Early termination of test because vendor service cannot "
                     "prepare model that it does not support."
                  << std::endl;
        return;
    }
    EXPECT_EQ(ErrorStatus::NONE, prepareReturnStatus);
    ASSERT_NE(nullptr, preparedModel.get());

    // If in relaxed mode, set the error range to be 5ULP of FP16.
    float fpRange = !model.relaxComputationFloat32toFloat16 ? 1e-5f : 5.0f * 0.0009765625f;
    EvaluatePreparedModel(preparedModel, is_ignored, examples, fpRange);
}

}  // namespace generated_tests

}  // namespace neuralnetworks
}  // namespace hardware
}  // namespace android