/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "RNN.h"

#include <algorithm>
#include <cstddef>
#include <vector>

#include "NeuralNetworksWrapper.h"
#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"

namespace android {
namespace nn {
namespace wrapper {

using ::testing::Each;
using ::testing::FloatNear;
using ::testing::Matcher;

namespace {

std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
                                           float max_abs_error = 1.e-5) {
  std::vector<Matcher<float>> matchers;
  matchers.reserve(values.size());
  for (const float& v : values) {
    matchers.emplace_back(FloatNear(v, max_abs_error));
  }
  return matchers;
}

// Flattened input sequence for BlackBoxTest. The test reads it in consecutive
// frames of input_size() floats and feeds the same frame to every batch.
static float rnn_input[] = {
    0.23689353,   0.285385,     0.037029743, -0.19858193,  -0.27569133,
    0.43773448,   0.60379338,   0.35562468,  -0.69424844,  -0.93421471,
    -0.87287879,  0.37144363,   -0.62476718, 0.23791671,   0.40060222,
    0.1356622,    -0.99774903,  -0.98858172, -0.38952237,  -0.47685933,
    0.31073618,   0.71511042,   -0.63767755, -0.31729108,  0.33468103,
    0.75801885,   0.30660987,   -0.37354088, 0.77002847,   -0.62747043,
    -0.68572164,  0.0069220066, 0.65791464,  0.35130811,   0.80834007,
    -0.61777675,  -0.21095741,  0.41213346,  0.73784804,   0.094794154,
    0.47791874,   0.86496925,   -0.53376222, 0.85315156,   0.10288584,
    0.86684,      -0.011186242, 0.10513687,  0.87825835,   0.59929144,
    0.62827742,   0.18899453,   0.31440187,  0.99059987,   0.87170351,
    -0.35091716,  0.74861872,   0.17831337,  0.2755419,    0.51864719,
    0.55084288,   0.58982027,   -0.47443086, 0.20875752,   -0.058871567,
    -0.66609079,  0.59098077,   0.73017097,  0.74604273,   0.32882881,
    -0.17503482,  0.22396147,   0.19379807,  0.29120302,   0.077113032,
    -0.70331609,  0.15804303,   -0.93407321, 0.40182066,   0.036301374,
    0.66521823,   0.0300982,    -0.7747041,  -0.02038002,  0.020698071,
    -0.90300065,  0.62870288,   -0.23068321, 0.27531278,   -0.095755219,
    -0.712036,    -0.17384434,  -0.50593495, -0.18646687,  -0.96508682,
    0.43519354,   0.14744234,   0.62589407,  0.1653645,    -0.10651493,
    -0.045277178, 0.99032974,   -0.88255352, -0.85147917,  0.28153265,
    0.19455957,   -0.55479527,  -0.56042433, 0.26048636,   0.84702539,
    0.47587705,   -0.074295521, -0.12287641, 0.70117295,   0.90532446,
    0.89782166,   0.79817224,   0.53402734,  -0.33286154,  0.073485017,
    -0.56172788,  -0.044897556, 0.89964068,  -0.067662835, 0.76863563,
    0.93455386,   -0.6324693,   -0.083922029};

// Expected outputs for BlackBoxTest. Each blank-line-separated group holds the
// values for one time step; the test reads num_units() floats per step and
// expects every batch to produce that same group.
static float rnn_golden_output[] = {
    0.496726,   0,          0.965996,  0,         0.0584254, 0,
    0,          0.12315,    0,         0,         0.612266,  0.456601,
    0,          0.52286,    1.16099,   0.0291232,

    0,          0,          0.524901,  0,         0,         0,
    0,          1.02116,    0,         1.35762,   0,         0.356909,
    0.436415,   0.0355727,  0,         0,

    0,          0,          0,         0.262335,  0,         0,
    0,          1.33992,    0,         2.9739,    0,         0,
    1.31914,    2.66147,    0,         0,

    0.942568,   0,          0,         0,         0.025507,  0,
    0,          0,          0.321429,  0.569141,  1.25274,   1.57719,
    0.8158,     1.21805,    0.586239,  0.25427,

    1.04436,    0,          0.630725,  0,         0.133801,  0.210693,
    0.363026,   0,          0.533426,  0,         1.25926,   0.722707,
    0,          1.22031,    1.30117,   0.495867,

    0.222187,   0,          0.72725,   0,         0.767003,  0,
    0,          0.147835,   0,         0,         0,         0.608758,
    0.469394,   0.00720298, 0.927537,  0,

    0.856974,   0.424257,   0,         0,         0.937329,  0,
    0,          0,          0.476425,  0,         0.566017,  0.418462,
    0.141911,   0.996214,   1.13063,   0,

    0.967899,   0,          0,         0,         0.0831304, 0,
    0,          1.00378,    0,         0,         0,         1.44818,
    1.01768,    0.943891,   0.502745,  0,

    0.940135,   0,          0,         0,         0,         0,
    0,          2.13243,    0,         0.71208,   0.123918,  1.53907,
    1.30225,    1.59644,    0.70222,   0,

    0.804329,   0,          0.430576,  0,         0.505872,  0.509603,
    0.343448,   0,          0.107756,  0.614544,  1.44549,   1.52311,
    0.0454298,  0.300267,   0.562784,  0.395095,

    0.228154,   0,          0.675323,  0,         1.70536,   0.766217,
    0,          0,          0,         0.735363,  0.0759267, 1.91017,
    0.941888,   0,          0,         0,

    0,          0,          1.5909,    0,         0,         0,
    0,          0.5755,     0,         0.184687,  0,         1.56296,
    0.625285,   0,          0,         0,

    0,          0,          0.0857888, 0,         0,         0,
    0,          0.488383,   0.252786,  0,         0,         0,
    1.02817,    1.85665,    0,         0,

    0.00981836, 0,          1.06371,   0,         0,         0,
    0,          0,          0,         0.290445,  0.316406,  0,
    0.304161,   1.25079,    0.0707152, 0,

    0.986264,   0.309201,   0,         0,         0,         0,
    0,          1.64896,    0.346248,  0,         0.918175,  0.78884,
    0.524981,   1.92076,    2.07013,   0.333244,

    0.415153,   0.210318,   0,         0,         0,         0,
    0,          2.02616,    0,         0.728256,  0.84183,   0.0907453,
    0.628881,   3.58099,    1.49974,   0};

}  // anonymous namespace

// X-macro: applies ACTION to the name of every float tensor that
// BasicRNNOpModel binds as an execution input (the activation parameter is
// bound separately). BasicRNNOpModel uses it to generate the member buffers,
// their setters and the setInput() calls.
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
  ACTION(Input)                                  \
  ACTION(Weights)                                \
  ACTION(RecurrentWeights)                       \
  ACTION(Bias)                                   \
  ACTION(HiddenStateIn)

// X-macro: applies ACTION to the name of every tensor that BasicRNNOpModel
// binds as an execution output, i.e. the result tensor and the hidden state
// that is swapped into HiddenStateIn before the next invocation.
// BasicRNNOpModel uses it to generate the member buffers and the setOutput()
// calls.
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
  ACTION(HiddenStateOut)               \
  ACTION(Output)

class BasicRNNOpModel {
 public:
  BasicRNNOpModel(uint32_t batches, uint32_t units, uint32_t size)
      : batches_(batches),
        units_(units),
        input_size_(size),
        activation_(kActivationRelu) {
    std::vector<uint32_t> inputs;

    OperandType InputTy(Type::TENSOR_FLOAT32, {batches_, input_size_});
    inputs.push_back(model_.addOperand(&InputTy));
    OperandType WeightTy(Type::TENSOR_FLOAT32, {units_, input_size_});
    inputs.push_back(model_.addOperand(&WeightTy));
    OperandType RecurrentWeightTy(Type::TENSOR_FLOAT32, {units_, units_});
    inputs.push_back(model_.addOperand(&RecurrentWeightTy));
    OperandType BiasTy(Type::TENSOR_FLOAT32, {units_});
    inputs.push_back(model_.addOperand(&BiasTy));
    OperandType HiddenStateTy(Type::TENSOR_FLOAT32, {batches_, units_});
    inputs.push_back(model_.addOperand(&HiddenStateTy));
    OperandType ActionParamTy(Type::INT32, {});
    inputs.push_back(model_.addOperand(&ActionParamTy));

    std::vector<uint32_t> outputs;

    outputs.push_back(model_.addOperand(&HiddenStateTy));
    OperandType OutputTy(Type::TENSOR_FLOAT32, {batches_, units_});
    outputs.push_back(model_.addOperand(&OutputTy));

    Input_.insert(Input_.end(), batches_ * input_size_, 0.f);
    HiddenStateIn_.insert(HiddenStateIn_.end(), batches_ * units_, 0.f);
    HiddenStateOut_.insert(HiddenStateOut_.end(), batches_ * units_, 0.f);
    Output_.insert(Output_.end(), batches_ * units_, 0.f);

    model_.addOperation(ANEURALNETWORKS_RNN, inputs, outputs);
    model_.identifyInputsAndOutputs(inputs, outputs);

    model_.finish();
  }

#define DefineSetter(X)                          \
  void Set##X(const std::vector<float>& f) {     \
    X##_.insert(X##_.end(), f.begin(), f.end()); \
  }

  FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);

#undef DefineSetter

  void SetInput(int offset, float* begin, float* end) {
    for (; begin != end; begin++, offset++) {
      Input_[offset] = *begin;
    }
  }

  void ResetHiddenState() {
    std::fill(HiddenStateIn_.begin(), HiddenStateIn_.end(), 0.f);
    std::fill(HiddenStateOut_.begin(), HiddenStateOut_.end(), 0.f);
  }

  const std::vector<float>& GetOutput() const { return Output_; }

  uint32_t input_size() const { return input_size_; }
  uint32_t num_units() const { return units_; }
  uint32_t num_batches() const { return batches_; }

  void Invoke() {
    ASSERT_TRUE(model_.isValid());

    HiddenStateIn_.swap(HiddenStateOut_);

    Compilation compilation(&model_);
    compilation.finish();
    Execution execution(&compilation);
#define SetInputOrWeight(X)                                                   \
  ASSERT_EQ(execution.setInput(RNN::k##X##Tensor, X##_.data(),                \
                               sizeof(float) * X##_.size()),                  \
            Result::NO_ERROR);

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);

#undef SetInputOrWeight

#define SetOutput(X)                                                           \
  ASSERT_EQ(execution.setOutput(RNN::k##X##Tensor, X##_.data(),                \
                                sizeof(float) * X##_.size()),                  \
            Result::NO_ERROR);

    FOR_ALL_OUTPUT_TENSORS(SetOutput);

#undef SetOutput

    ASSERT_EQ(execution.setInput(RNN::kActivationParam, &activation_,
                                 sizeof(activation_)),
              Result::NO_ERROR);

    ASSERT_EQ(execution.compute(), Result::NO_ERROR);
  }

 private:
  Model model_;

  const uint32_t batches_;
  const uint32_t units_;
  const uint32_t input_size_;

  const int activation_;

#define DefineTensor(X) std::vector<float> X##_;

  FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
  FOR_ALL_OUTPUT_TENSORS(DefineTensor);

#undef DefineTensor
};

// Runs a 2-batch, 16-unit RNN with 8 inputs per step over rnn_input, one time
// step per Invoke(), and compares every step's output with rnn_golden_output.
TEST(RNNOpTest, BlackBoxTest) {
  BasicRNNOpModel rnn(2, 16, 8);

  // Input weights: units x input_size = 16 x 8 values.
  rnn.SetWeights(
      {0.461459,    0.153381,   0.529743,    -0.00371218, 0.676267,   -0.211346,
       0.317493,    0.969689,   -0.343251,   0.186423,    0.398151,   0.152399,
       0.448504,    0.317662,   0.523556,    -0.323514,   0.480877,   0.333113,
       -0.757714,   -0.674487,  -0.643585,   0.217766,    -0.0251462, 0.79512,
       -0.595574,   -0.422444,  0.371572,    -0.452178,   -0.556069,  -0.482188,
       -0.685456,   -0.727851,  0.841829,    0.551535,    -0.232336,  0.729158,
       -0.00294906, -0.69754,   0.766073,    -0.178424,   0.369513,   -0.423241,
       0.548547,    -0.0152023, -0.757482,   -0.85491,    0.251331,   -0.989183,
       0.306261,    -0.340716,  0.886103,    -0.0726757,  -0.723523,  -0.784303,
       0.0354295,   0.566564,   -0.485469,   -0.620498,   0.832546,   0.697884,
       -0.279115,   0.294415,   -0.584313,   0.548772,    0.0648819,  0.968726,
       0.723834,    -0.0080452, -0.350386,   -0.272803,   0.115121,   -0.412644,
       -0.824713,   -0.992843,  -0.592904,   -0.417893,   0.863791,   -0.423461,
       -0.147601,   -0.770664,  -0.479006,   0.654782,    0.587314,   -0.639158,
       0.816969,    -0.337228,  0.659878,    0.73107,     0.754768,   -0.337042,
       0.0960841,   0.368357,   0.244191,    -0.817703,   -0.211223,  0.442012,
       0.37225,     -0.623598,  -0.405423,   0.455101,    0.673656,   -0.145345,
       -0.511346,   -0.901675,  -0.81252,    -0.127006,   0.809865,   -0.721884,
       0.636255,    0.868989,   -0.347973,   -0.10179,    -0.777449,  0.917274,
       0.819286,    0.206218,   -0.00785118, 0.167141,    0.45872,    0.972934,
       -0.276798,   0.837861,   0.747958,    -0.0151566,  -0.330057,  -0.469077,
       0.277308,    0.415818});

  // One bias value per unit.
  rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
               -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
               0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
               -0.37609905});

  // Recurrent weights: a units x units matrix with 0.1 on the diagonal and
  // zero everywhere else.
  const uint32_t unit_count = rnn.num_units();
  std::vector<float> recurrent_weights(unit_count * unit_count, 0.f);
  for (uint32_t u = 0; u < unit_count; ++u) {
    recurrent_weights[u * unit_count + u] = 0.1f;
  }
  rnn.SetRecurrentWeights(recurrent_weights);

  rnn.ResetHiddenState();

  const uint32_t frame_len = rnn.input_size();
  // NOTE(review): dividing by num_batches() yields 8 steps, although
  // rnn_input holds 16 frames of input_size() floats and every frame is fed
  // to both batches below -- confirm whether the remaining frames and golden
  // rows were meant to be exercised too.
  const int step_count = static_cast<int>(
      sizeof(rnn_input) / sizeof(float) / (frame_len * rnn.num_batches()));

  for (int step = 0; step < step_count; ++step) {
    // Both batches receive the same input frame.
    float* const frame_begin = rnn_input + step * frame_len;
    float* const frame_end = frame_begin + frame_len;
    rnn.SetInput(0, frame_begin, frame_end);
    rnn.SetInput(frame_len, frame_begin, frame_end);

    rnn.Invoke();

    // The golden row for this step is expected once per batch.
    const float* const golden_begin = rnn_golden_output + step * unit_count;
    const float* const golden_end = golden_begin + unit_count;
    std::vector<float> expected(golden_begin, golden_end);
    expected.insert(expected.end(), golden_begin, golden_end);

    EXPECT_THAT(rnn.GetOutput(),
                ::testing::ElementsAreArray(ArrayFloatNear(expected)));
  }
}

}  // namespace wrapper
}  // namespace nn
}  // namespace android